Rawal Khirodkar committed
Commit ba23d94 · 1 Parent(s): aab83b1

Initial sapiens2-normal Space (HF download at startup, all 4 sizes)

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .gitattributes +1 -0
  2. .gitignore +5 -0
  3. README.md +16 -5
  4. app.py +167 -0
  5. assets/configs/sapiens2_0.4b_normal_metasim_render_people-1024x768.py +304 -0
  6. assets/configs/sapiens2_0.8b_normal_metasim_render_people-1024x768.py +304 -0
  7. assets/configs/sapiens2_1b_normal_metasim_render_people-1024x768.py +306 -0
  8. assets/configs/sapiens2_5b_normal_metasim_render_people-1024x768.py +312 -0
  9. assets/images/68204.png +3 -0
  10. assets/images/68210.png +3 -0
  11. assets/images/68658.png +3 -0
  12. assets/images/68666.png +3 -0
  13. assets/images/68691.png +3 -0
  14. assets/images/68956.png +3 -0
  15. assets/images/pexels-amresh444-17315601.png +3 -0
  16. assets/images/pexels-gabby-k-6311686.png +3 -0
  17. assets/images/pexels-julia-m-cameron-4145040.png +3 -0
  18. assets/images/pexels-marcus-aurelius-6787357.png +3 -0
  19. assets/images/pexels-mo-saeed-3616599-5409085.png +3 -0
  20. assets/images/pexels-riedelmax-27355495.png +3 -0
  21. assets/images/pexels-sergeymakashin-5368660.png +3 -0
  22. assets/images/pexels-vinicius-wiesehofer-289347-4219918.png +3 -0
  23. requirements.txt +21 -0
  24. sapiens/__init__.py +14 -0
  25. sapiens/backbones/__init__.py +10 -0
  26. sapiens/backbones/sapiens.py +611 -0
  27. sapiens/backbones/sapiens2.py +916 -0
  28. sapiens/backbones/standalone/sapiens.py +648 -0
  29. sapiens/backbones/standalone/sapiens2.py +908 -0
  30. sapiens/dense/__init__.py +21 -0
  31. sapiens/dense/configs/albedo/render_people/sapiens2_0.4b_albedo_render_people-1024x768.py +274 -0
  32. sapiens/dense/configs/albedo/render_people/sapiens2_0.8b_albedo_render_people-1024x768.py +275 -0
  33. sapiens/dense/configs/albedo/render_people/sapiens2_1b_albedo_render_people-1024x768.py +274 -0
  34. sapiens/dense/configs/albedo/render_people/sapiens2_5b_albedo_render_people-1024x768.py +280 -0
  35. sapiens/dense/configs/normal/metasim_render_people/sapiens2_0.4b_normal_metasim_render_people-1024x768.py +304 -0
  36. sapiens/dense/configs/normal/metasim_render_people/sapiens2_0.8b_normal_metasim_render_people-1024x768.py +304 -0
  37. sapiens/dense/configs/normal/metasim_render_people/sapiens2_1b_normal_metasim_render_people-1024x768.py +306 -0
  38. sapiens/dense/configs/normal/metasim_render_people/sapiens2_5b_normal_metasim_render_people-1024x768.py +312 -0
  39. sapiens/dense/configs/pointmap/render_people/sapiens2_0.4b_pointmap_render_people-1024x768.py +322 -0
  40. sapiens/dense/configs/pointmap/render_people/sapiens2_0.8b_pointmap_render_people-1024x768.py +325 -0
  41. sapiens/dense/configs/pointmap/render_people/sapiens2_1b_pointmap_render_people-1024x768.py +319 -0
  42. sapiens/dense/configs/pointmap/render_people/sapiens2_5b_pointmap_render_people-1024x768.py +329 -0
  43. sapiens/dense/configs/seg/shutterstock_goliath/sapiens2_0.4b_seg_shutterstock_goliath-1024x768.py +364 -0
  44. sapiens/dense/configs/seg/shutterstock_goliath/sapiens2_0.8b_seg_shutterstock_goliath-1024x768.py +368 -0
  45. sapiens/dense/configs/seg/shutterstock_goliath/sapiens2_1b_seg_shutterstock_goliath-1024x768.py +366 -0
  46. sapiens/dense/configs/seg/shutterstock_goliath/sapiens2_5b_seg_shutterstock_goliath-1024x768.py +365 -0
  47. sapiens/dense/scripts/albedo/train/sapiens2_0.4b/node.sh +58 -0
  48. sapiens/dense/scripts/albedo/train/sapiens2_0.8b/node.sh +59 -0
  49. sapiens/dense/scripts/albedo/train/sapiens2_1b/node.sh +59 -0
  50. sapiens/dense/scripts/albedo/train/sapiens2_5b/node.sh +60 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,5 @@
+ __pycache__
+ *.pyc
+ default.profraw
+ .DS_Store
+ *.log
README.md CHANGED
@@ -1,12 +1,23 @@
  ---
  title: Sapiens2 Normal
- emoji: 🏢
- colorFrom: green
- colorTo: gray
+ emoji: 🧊
+ colorFrom: purple
+ colorTo: blue
  sdk: gradio
- sdk_version: 6.13.0
+ sdk_version: 4.42.0
  app_file: app.py
+ python_version: "3.12"
  pinned: false
+ license: other
+ license_name: sapiens2-license
+ license_link: https://github.com/facebookresearch/sapiens2/blob/main/LICENSE.md
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Sapiens2: Surface Normal Estimation
+ ### ICLR 2026
+
+ Per-pixel surface-normal estimation (3-channel unit vectors in camera frame).
+
+ - **Code:** [github.com/facebookresearch/sapiens2](https://github.com/facebookresearch/sapiens2)
+ - **Models:** [Sapiens2 collection](https://huggingface.co/facebook/sapiens2)
+ - **Paper:** https://openreview.net/pdf?id=IVAlYCqdvW
app.py ADDED
@@ -0,0 +1,167 @@
+ """Sapiens2 surface-normal Gradio Space.
+
+ Image → per-pixel surface normals. Visualized by RGB-encoding the unit-length
+ (x, y, z) normal: r = (x + 1) / 2, g = (y + 1) / 2, b = (z + 1) / 2.
+ """
+
+ import sys
+ import os
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+
+ import tempfile
+
+ import cv2
+ import gradio as gr
+ import numpy as np
+ import spaces
+ import torch
+ import torch.nn.functional as F
+ from PIL import Image
+
+ from huggingface_hub import hf_hub_download
+ from sapiens.dense.models import NormalEstimator, init_model  # NormalEstimator triggers registry
+ _ = NormalEstimator
+
+
+ # -----------------------------------------------------------------------------
+ # Config
+
+ ASSETS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
+ CONFIGS_DIR = os.path.join(ASSETS_DIR, "configs")
+
+ NORMAL_MODELS = {
+     "0.4B": {
+         "repo": "facebook/sapiens2-normal-0.4b",
+         "filename": "sapiens2_0.4b_normal.safetensors",
+         "config": os.path.join(CONFIGS_DIR, "sapiens2_0.4b_normal_metasim_render_people-1024x768.py"),
+     },
+     "0.8B": {
+         "repo": "facebook/sapiens2-normal-0.8b",
+         "filename": "sapiens2_0.8b_normal.safetensors",
+         "config": os.path.join(CONFIGS_DIR, "sapiens2_0.8b_normal_metasim_render_people-1024x768.py"),
+     },
+     "1B": {
+         "repo": "facebook/sapiens2-normal-1b",
+         "filename": "sapiens2_1b_normal.safetensors",
+         "config": os.path.join(CONFIGS_DIR, "sapiens2_1b_normal_metasim_render_people-1024x768.py"),
+     },
+     "5B": {
+         "repo": "facebook/sapiens2-normal-5b",
+         "filename": "sapiens2_5b_normal.safetensors",
+         "config": os.path.join(CONFIGS_DIR, "sapiens2_5b_normal_metasim_render_people-1024x768.py"),
+     },
+ }
+ DEFAULT_SIZE = "1B"
+
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+ # -----------------------------------------------------------------------------
+ # Model cache
+
+ _normal_model_cache: dict = {}
+
+
+ def _get_normal_model(size: str):
+     if size not in _normal_model_cache:
+         spec = NORMAL_MODELS[size]
+         ckpt = hf_hub_download(repo_id=spec["repo"], filename=spec["filename"])
+         model = init_model(spec["config"], ckpt, device=DEVICE)
+         _normal_model_cache[size] = model
+     return _normal_model_cache[size]
+
+
+ print("[startup] pre-loading all normal sizes ...")
+ for _size in NORMAL_MODELS:
+     _get_normal_model(_size)
+ print("[startup] ready.")
+
+
+ # -----------------------------------------------------------------------------
+ # Inference
+
+ def _estimate_normal(image_bgr: np.ndarray, model) -> np.ndarray:
+     h0, w0 = image_bgr.shape[:2]
+     data = model.pipeline(dict(img=image_bgr))
+     data = model.data_preprocessor(data)
+     inputs = data["inputs"]
+     if inputs.ndim == 3:
+         inputs = inputs.unsqueeze(0)
+
+     with torch.no_grad():
+         normals = model(inputs)  # (1, 3, H, W)
+
+     # Unit-length normalization, interpolate to original size, cast to numpy
+     normals = normals / normals.norm(dim=1, keepdim=True).clamp_min(1e-6)
+     normals = F.interpolate(normals, size=(h0, w0), mode="bilinear", align_corners=False)
+     normals = normals[0].cpu().float().numpy()  # (3, H, W) in [-1, 1]
+     return normals.transpose(1, 2, 0)  # (H, W, 3)
+
+
+ def _normal_to_rgb(normal_hwc: np.ndarray) -> np.ndarray:
+     rgb = (((normal_hwc + 1.0) / 2.0) * 255.0).clip(0, 255).astype(np.uint8)
+     return rgb[:, :, ::-1]  # match training viz channel order
+
+
+ # -----------------------------------------------------------------------------
+ # Gradio handler
+
+ @spaces.GPU(duration=120)
+ def predict(image: Image.Image, size: str):
+     if image is None:
+         return None, None
+
+     image_rgb = np.array(image.convert("RGB"))
+     image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
+
+     model = _get_normal_model(size)
+     normals = _estimate_normal(image_bgr, model)  # (H, W, 3) in [-1, 1]
+     rgb = _normal_to_rgb(normals)
+
+     with tempfile.NamedTemporaryFile(delete=False, suffix=".npy") as f:
+         np.save(f.name, normals.astype(np.float32))
+         npy_path = f.name
+
+     return Image.fromarray(rgb), npy_path
+
+
+ # -----------------------------------------------------------------------------
+ # UI
+
+ EXAMPLES = sorted(
+     os.path.join(ASSETS_DIR, "images", n)
+     for n in os.listdir(os.path.join(ASSETS_DIR, "images"))
+     if n.lower().endswith((".jpg", ".jpeg", ".png"))
+ )
+
+ with gr.Blocks(title="Sapiens2 Normal", theme=gr.themes.Default()) as demo:
+     gr.Markdown(
+         "# Sapiens2: Surface Normal Estimation\n"
+         "### ICLR 2026\n"
+         "Per-pixel surface-normal estimation. Output is RGB-encoded (x, y, z → R, G, B).\n\n"
+         "[Code](https://github.com/facebookresearch/sapiens2) · "
+         "[Models](https://huggingface.co/facebook/sapiens2) · "
+         "[Paper](https://openreview.net/pdf?id=IVAlYCqdvW)"
+     )
+     with gr.Row():
+         with gr.Column():
+             inp = gr.Image(label="Input", type="pil")
+             size = gr.Radio(
+                 choices=list(NORMAL_MODELS.keys()),
+                 value=DEFAULT_SIZE,
+                 label="Model size",
+             )
+             run = gr.Button("Run", variant="primary")
+             gr.Examples(examples=EXAMPLES, inputs=inp, examples_per_page=14)
+         with gr.Column():
+             out_img = gr.Image(label="Surface normal (RGB-encoded)", type="pil")
+             out_npy = gr.File(label="Raw normals (.npy float32 [-1, 1])")
+
+     run.click(predict, inputs=[inp, size], outputs=[out_img, out_npy])
+
+
+ if __name__ == "__main__":
+     if torch.cuda.is_available():
+         torch.backends.cuda.matmul.allow_tf32 = True
+         torch.backends.cudnn.allow_tf32 = True
+     demo.launch(share=False)
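The handler above returns both the RGB visualization and the raw normals as a downloadable .npy. A minimal sketch of consuming that file offline, assuming a download saved as normals.npy (the array layout follows _estimate_normal above):

    import numpy as np

    normals = np.load("normals.npy")            # (H, W, 3) float32, values in [-1, 1]
    lengths = np.linalg.norm(normals, axis=-1)  # ~1.0 everywhere: unit-normalized upstream
    print(normals.shape, lengths.min(), lengths.max())

    # Re-create the visualization the Space renders (see _normal_to_rgb); note the
    # app also reverses the channel order to match its training-time visualization.
    rgb = (((normals + 1.0) / 2.0) * 255.0).clip(0, 255).astype(np.uint8)[:, :, ::-1]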
assets/configs/sapiens2_0.4b_normal_metasim_render_people-1024x768.py ADDED
@@ -0,0 +1,304 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import os
+
+ _CHECKPOINT_ROOT = os.path.expanduser(
+     os.environ.get("SAPIENS_CHECKPOINT_ROOT", "~/sapiens2_host")
+ )
+ _DATA_ROOT = os.path.expanduser(os.environ.get("DATA_ROOT", "~/sapiens_data"))
+
+ warmup_iters = 500
+ num_iters = 2e4
+
+ # ------------------------------------------------------------------------------
+ vis_every_iters = 100
+ log_every_iters = 10
+ save_every_iters = 1000
+ val_every_iters = 1000
+
+ # # debug
+ # vis_every_iters = 1
+ # log_every_iters = 1
+ # val_every_iters = 2
+ # save_every_iters = 1000
+
+ load_from = None
+ resume = False
+
+ # ------------------------------------------------------------------
+ model_name = "sapiens2_0.4b"
+ embed_dim = 1024
+ num_layers = 24
+ num_heads = 16
+
+ layer_decay_rate = 0.8
+ pretrained_checkpoint = f"{_CHECKPOINT_ROOT}/pretrain/sapiens2_0.4b_pretrain.safetensors"
+
+ ##-----------------------------------------------------------------
+ image_size = (1024, 768)  ## height x width
+ patch_size = 16
+
+ # ------------------------------------------------------------------
+ use_fsdp = True
+ # use_fsdp = False
+
+ use_compile = True
+ # use_compile = False
+
+ ## DDP config
+ if use_fsdp is False:
+     accelerator_cfg = dict(
+         type="DDP",
+         log_with="tensorboard",
+         # find_unused_parameters=True,
+         gradient_accumulation_steps=1,  # only accumulation=1 is supported. Otherwise, the LR scheduler will be off.
+         max_interval=num_iters,
+         # mixed_precision="bf16",  # Options: ‘no’,‘fp16’,‘bf16’ or ‘fp8’.
+         step_scheduler_with_optimizer=False,  ## schedule independent of n_gpus
+     )
+
+ else:
+     accelerator_cfg = dict(
+         type="FSDP",
+         log_with="tensorboard",
+         gradient_accumulation_steps=1,  # only accumulation=1 is supported. Otherwise, the LR scheduler will be off.
+         max_interval=num_iters,
+         mixed_precision="bf16",  # Options: ‘no’,‘fp16’,‘bf16’ or ‘fp8’.
+         step_scheduler_with_optimizer=False,
+         fsdp_cfg=dict(
+             fsdp_version=2,  # DTensor-based engine
+             state_dict_type="SHARDED_STATE_DICT",  # SHARDED_STATE_DICT | FULL_STATE_DICT
+             # state_dict_type="FULL_STATE_DICT",  # TODO: resume from this is not working
+             mixed_precision=dict(
+                 param_dtype="bf16",
+                 reduce_dtype="bf16",
+             ),
+             cpu_ram_efficient_loading=False,
+         ),
+     )
+
+ if use_compile:
+     accelerator_cfg["compile_cfg"] = dict(
+         backend="inductor",
+         mode="default",  # Options: "default", "reduce-overhead", "max-autotune"
+         fullgraph=False,
+         dynamic=False,
+     )
+
+ # ------------------------------------------------------------------
+ randomness = dict(seed=0, deterministic=False, diff_rank_seed=True)
+ logger = dict(
+     type="Logger",
+     log_interval=log_every_iters,
+ )
+ checkpoint = dict(
+     type="Checkpointer",
+     save_interval=save_every_iters,
+ )
+
+ visualizer = dict(
+     type="NormalVisualizer",
+     vis_interval=vis_every_iters,
+     vis_max_samples=8,
+     vis_image_width=384,
+     vis_image_height=512,
+ )
+
+
+ ##-----------------------------------------------------------------
+ train_pipeline = [
+     dict(type="PhotoMetricDistortion"),
+     dict(type="RandomDownUpSampleImage", scale_range=(0.1, 0.7), prob=0.2),
+     dict(
+         type="NormalRandomScale",
+         scale_min=0.5,
+         scale_max=2.0,
+         prob=0.3,
+     ),
+     dict(
+         type="NormalRandomCropContinuous",
+         ar_range=(0.5, 2.0),
+         area_range=(0.4, 1.0),
+         num_attempts=8,
+         prob=0.3,
+     ),
+     dict(
+         type="NormalRandomFlip",
+         prob=0.3,
+     ),
+     dict(type="NormalResize", height=1024, width=768),
+     dict(
+         type="RandomGaussianBlur", prob=0.3, kernel_size=(3, 3), sigma_range=(0.1, 2.0)
+     ),
+     dict(type="RandomGaussianNoise", prob=0.3, var_range=(5.0, 20.0)),
+     dict(type="RandomSolarize", prob=0.3, threshold=128),
+     dict(type="NormalGenerateTarget"),
+     dict(
+         type="NormalPackInputs",
+         meta_keys=(
+             "img_path",
+             "ori_shape",
+         ),
+     ),
+ ]
+
+ val_pipeline = [
+     dict(type="NormalResize", height=1024, width=768, test_mode=True),
+     dict(
+         type="NormalPackInputs",
+         test_mode=True,
+         meta_keys=(
+             "img_path",
+             "orig_img_height",
+             "orig_img_width",
+             "img_shape",
+             "pad_shape",
+         ),
+     ),
+ ]
+
+ test_pipeline = [
+     dict(type="NormalResizePadImage", height=1024, width=768, pad_val=0),
+     dict(
+         type="NormalPackInputs",
+         meta_keys=(
+             "img_path",
+             "orig_img_height",
+             "orig_img_width",
+             "padding_size",
+         ),
+     ),
+ ]
+
+ metasim_dataset = dict(
+     type="NormalMetaSimDataset",
+     airstore_template="airstore://codec_avatar_sapiens_metasim_v1_no_user_data",
+     json_path=f"{_DATA_ROOT}/seg/data/metasim/meta_data_v1.json",
+ )
+
+ render_people_dataset = dict(
+     type="NormalRenderPeopleBodyDataset",  ## body only
+     data_root=f"{_DATA_ROOT}/synthetic",
+     seg_data_root=f"{_DATA_ROOT}/RenderPeople/part_seg",
+ )
+
+ multihuman_render_people_dataset = dict(
+     type="NormalRenderPeopleMultihumanDataset",
+     data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_multi_human",
+     normal_extension=".npz",
+     seg_data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_multi_human/part_seg",  ## supervise on face for multihuman
+ )
+
+ # train_datasets = 2 * [metasim_dataset] + [
+ #     render_people_dataset,
+ #     multihuman_render_people_dataset,
+ # ]
+
+ # train_datasets = [render_people_dataset]
+ # train_datasets = [multihuman_render_people_dataset]
+ train_datasets = [metasim_dataset]
+
+ train_dataloader = dict(
+     batch_size=1,
+     num_workers=4,
+     persistent_workers=True,
+     shuffle=True,
+     dataset=dict(
+         type="CombinedDataset", datasets=train_datasets, pipeline=train_pipeline
+     ),
+ )
+
+ val_dataloader = dict(
+     batch_size=4,
+     num_workers=4,
+     persistent_workers=True,
+     multiprocessing_context="spawn",
+     # num_workers=0,  # debug
+     # persistent_workers=False,  # debug
+     shuffle=False,
+     dataset=dict(
+         type="NormalRenderPeopleBodyDataset",  ## body only
+         # num_samples=100,  ## debug: only use N samples for validation
+         test_mode=True,
+         data_root=f"{_DATA_ROOT}/seg/data/metasim/evaluation",
+         pipeline=val_pipeline,
+     ),
+ )
+
+ val_cfg = dict(
+     val_interval=val_every_iters,
+     evaluator=dict(
+         type="NormalEvaluator",
+     ),
+ )
+
+ data_preprocessor = dict(
+     type="ImagePreprocessor",
+     mean=[123.675, 116.28, 103.53],
+     std=[58.395, 57.12, 57.375],
+     bgr_to_rgb=True,  ## convert from bgr to rgb for pretrained models
+ )
+
+ ##-----------------------------------------------------------------
+ model = dict(
+     type="NormalEstimator",
+     backbone=dict(
+         type="Sapiens2",
+         arch=model_name,
+         img_size=image_size,
+         patch_size=patch_size,
+         final_norm=True,
+         use_tokenizer=False,
+         with_cls_token=True,
+         out_type="featmap",
+         init_cfg=dict(type="Pretrained", checkpoint=pretrained_checkpoint),
+     ),
+     decode_head=dict(
+         type="NormalHead",
+         in_channels=embed_dim,
+         upsample_channels=[768, 512, 256, 128],  ## 1K resolution
+         conv_out_channels=[64, 32, 16],
+         conv_kernel_sizes=[3, 3, 3],
+         loss_decode=[
+             dict(
+                 type="NormalCosineSimilarityLoss",
+                 loss_weight=10.0,
+             ),
+             dict(type="L1Loss", loss_weight=1.0),
+             dict(type="NormalGradL1Loss", loss_weight=10.0),
+         ],
+     ),
+ )
+
+
+ ##-----------------------------------------------------------------
+ optimizer = dict(
+     type="AdamW",
+     lr=5e-4,
+     betas=(0.9, 0.999),
+     weight_decay=0.1,
+     paramwise_cfg=dict(
+         num_layers=num_layers,
+         layer_decay_rate=layer_decay_rate,
+     ),
+     fused=True,
+ )
+
+ scheduler = dict(
+     type="SequentialLR",
+     milestones=[warmup_iters],
+     schedulers=[
+         dict(type="LinearLR", start_factor=1e-3, total_iters=warmup_iters),
+         dict(
+             type="PolynomialLR",
+             total_iters=num_iters - warmup_iters,
+             power=1.0,
+         ),
+     ],
+ )
+
+ clip_grad = dict(mode="norm", max_norm=2.0, norm_type=2.0)
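The scheduler dict above names standard PyTorch schedulers. A minimal stand-alone sketch of the same schedule in plain torch.optim (the toy model and optimizer are stand-ins, not the training code itself):

    import torch
    from torch.optim.lr_scheduler import LinearLR, PolynomialLR, SequentialLR

    warmup_iters, num_iters = 500, int(2e4)
    model = torch.nn.Linear(8, 8)  # stand-in for the real model
    opt = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.1)

    # Linear warmup from lr * 1e-3 up to lr, then linear (power=1.0) decay to zero.
    sched = SequentialLR(
        opt,
        schedulers=[
            LinearLR(opt, start_factor=1e-3, total_iters=warmup_iters),
            PolynomialLR(opt, total_iters=num_iters - warmup_iters, power=1.0),
        ],
        milestones=[warmup_iters],
    )
    for _ in range(3):  # one sched.step() per training iteration
        opt.step()
        sched.step()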
assets/configs/sapiens2_0.8b_normal_metasim_render_people-1024x768.py ADDED
@@ -0,0 +1,304 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import os
+
+ _CHECKPOINT_ROOT = os.path.expanduser(
+     os.environ.get("SAPIENS_CHECKPOINT_ROOT", "~/sapiens2_host")
+ )
+ _DATA_ROOT = os.path.expanduser(os.environ.get("DATA_ROOT", "~/sapiens_data"))
+
+ warmup_iters = 500
+ num_iters = 1e4
+
+ # ------------------------------------------------------------------------------
+ vis_every_iters = 100
+ log_every_iters = 10
+ save_every_iters = 1000
+ val_every_iters = 1000
+
+ # # debug
+ # vis_every_iters = 1
+ # log_every_iters = 1
+ # val_every_iters = 2
+ # save_every_iters = 1000
+
+ load_from = None
+ resume = False
+
+ # ------------------------------------------------------------------
+ model_name = "sapiens2_0.8b"
+ embed_dim = 1280
+ num_layers = 32
+ num_heads = 16
+
+ layer_decay_rate = 0.85
+ pretrained_checkpoint = f"{_CHECKPOINT_ROOT}/pretrain/sapiens2_0.8b_pretrain.safetensors"
+
+ ##-----------------------------------------------------------------
+ image_size = (1024, 768)  ## height x width
+ patch_size = 16
+
+ # ------------------------------------------------------------------
+ use_fsdp = True
+ # use_fsdp = False
+
+ use_compile = True
+ # use_compile = False
+
+ ## DDP config
+ if use_fsdp is False:
+     accelerator_cfg = dict(
+         type="DDP",
+         log_with="tensorboard",
+         # find_unused_parameters=True,
+         gradient_accumulation_steps=1,  # only accumulation=1 is supported. Otherwise, the LR scheduler will be off.
+         max_interval=num_iters,
+         # mixed_precision="bf16",  # Options: ‘no’,‘fp16’,‘bf16’ or ‘fp8’.
+         step_scheduler_with_optimizer=False,  ## schedule independent of n_gpus
+     )
+
+ else:
+     accelerator_cfg = dict(
+         type="FSDP",
+         log_with="tensorboard",
+         gradient_accumulation_steps=1,  # only accumulation=1 is supported. Otherwise, the LR scheduler will be off.
+         max_interval=num_iters,
+         mixed_precision="bf16",  # Options: ‘no’,‘fp16’,‘bf16’ or ‘fp8’.
+         step_scheduler_with_optimizer=False,
+         fsdp_cfg=dict(
+             fsdp_version=2,  # DTensor-based engine
+             state_dict_type="SHARDED_STATE_DICT",  # SHARDED_STATE_DICT | FULL_STATE_DICT
+             # state_dict_type="FULL_STATE_DICT",  # TODO: resume from this is not working
+             mixed_precision=dict(
+                 param_dtype="bf16",
+                 reduce_dtype="bf16",
+             ),
+             cpu_ram_efficient_loading=False,
+         ),
+     )
+
+ if use_compile:
+     accelerator_cfg["compile_cfg"] = dict(
+         backend="inductor",
+         mode="default",  # Options: "default", "reduce-overhead", "max-autotune"
+         fullgraph=False,
+         dynamic=False,
+     )
+
+ # ------------------------------------------------------------------
+ randomness = dict(seed=0, deterministic=False, diff_rank_seed=True)
+ logger = dict(
+     type="Logger",
+     log_interval=log_every_iters,
+ )
+ checkpoint = dict(
+     type="Checkpointer",
+     save_interval=save_every_iters,
+ )
+
+ visualizer = dict(
+     type="NormalVisualizer",
+     vis_interval=vis_every_iters,
+     vis_max_samples=8,
+     vis_image_width=384,
+     vis_image_height=512,
+ )
+
+
+ ##-----------------------------------------------------------------
+ train_pipeline = [
+     dict(type="PhotoMetricDistortion"),
+     dict(type="RandomDownUpSampleImage", scale_range=(0.1, 0.7), prob=0.2),
+     dict(
+         type="NormalRandomScale",
+         scale_min=0.5,
+         scale_max=2.0,
+         prob=0.3,
+     ),
+     dict(
+         type="NormalRandomCropContinuous",
+         ar_range=(0.5, 2.0),
+         area_range=(0.4, 1.0),
+         num_attempts=8,
+         prob=0.3,
+     ),
+     dict(
+         type="NormalRandomFlip",
+         prob=0.3,
+     ),
+     dict(type="NormalResize", height=1024, width=768),
+     dict(
+         type="RandomGaussianBlur", prob=0.3, kernel_size=(3, 3), sigma_range=(0.1, 2.0)
+     ),
+     dict(type="RandomGaussianNoise", prob=0.3, var_range=(5.0, 20.0)),
+     dict(type="RandomSolarize", prob=0.3, threshold=128),
+     dict(type="NormalGenerateTarget"),
+     dict(
+         type="NormalPackInputs",
+         meta_keys=(
+             "img_path",
+             "ori_shape",
+         ),
+     ),
+ ]
+
+ val_pipeline = [
+     dict(type="NormalResize", height=1024, width=768, test_mode=True),
+     dict(
+         type="NormalPackInputs",
+         test_mode=True,
+         meta_keys=(
+             "img_path",
+             "orig_img_height",
+             "orig_img_width",
+             "img_shape",
+             "pad_shape",
+         ),
+     ),
+ ]
+
+ test_pipeline = [
+     dict(type="NormalResizePadImage", height=1024, width=768, pad_val=0),
+     dict(
+         type="NormalPackInputs",
+         meta_keys=(
+             "img_path",
+             "orig_img_height",
+             "orig_img_width",
+             "padding_size",
+         ),
+     ),
+ ]
+
+ metasim_dataset = dict(
+     type="NormalMetaSimDataset",
+     airstore_template="airstore://codec_avatar_sapiens_metasim_v1_no_user_data",
+     json_path=f"{_DATA_ROOT}/seg/data/metasim/meta_data_v1.json",
+ )
+
+ render_people_dataset = dict(
+     type="NormalRenderPeopleBodyDataset",  ## body only
+     data_root=f"{_DATA_ROOT}/synthetic",
+     seg_data_root=f"{_DATA_ROOT}/RenderPeople/part_seg",
+ )
+
+ multihuman_render_people_dataset = dict(
+     type="NormalRenderPeopleMultihumanDataset",
+     data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_multi_human",
+     normal_extension=".npz",
+     seg_data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_multi_human/part_seg",  ## supervise on face for multihuman
+ )
+
+ # train_datasets = 2 * [metasim_dataset] + [
+ #     render_people_dataset,
+ #     multihuman_render_people_dataset,
+ # ]
+
+ # train_datasets = [render_people_dataset]
+ # train_datasets = [multihuman_render_people_dataset]
+ train_datasets = [metasim_dataset]
+
+ train_dataloader = dict(
+     batch_size=1,
+     num_workers=4,
+     persistent_workers=True,
+     shuffle=True,
+     dataset=dict(
+         type="CombinedDataset", datasets=train_datasets, pipeline=train_pipeline
+     ),
+ )
+
+ val_dataloader = dict(
+     batch_size=4,
+     num_workers=4,
+     persistent_workers=True,
+     multiprocessing_context="spawn",
+     # num_workers=0,  # debug
+     # persistent_workers=False,  # debug
+     shuffle=False,
+     dataset=dict(
+         type="NormalRenderPeopleBodyDataset",  ## body only
+         # num_samples=100,  ## debug: only use N samples for validation
+         test_mode=True,
+         data_root=f"{_DATA_ROOT}/seg/data/metasim/evaluation",
+         pipeline=val_pipeline,
+     ),
+ )
+
+ val_cfg = dict(
+     val_interval=val_every_iters,
+     evaluator=dict(
+         type="NormalEvaluator",
+     ),
+ )
+
+ data_preprocessor = dict(
+     type="ImagePreprocessor",
+     mean=[123.675, 116.28, 103.53],
+     std=[58.395, 57.12, 57.375],
+     bgr_to_rgb=True,  ## convert from bgr to rgb for pretrained models
+ )
+
+ ##-----------------------------------------------------------------
+ model = dict(
+     type="NormalEstimator",
+     backbone=dict(
+         type="Sapiens2",
+         arch=model_name,
+         img_size=image_size,
+         patch_size=patch_size,
+         final_norm=True,
+         use_tokenizer=False,
+         with_cls_token=True,
+         out_type="featmap",
+         init_cfg=dict(type="Pretrained", checkpoint=pretrained_checkpoint),
+     ),
+     decode_head=dict(
+         type="NormalHead",
+         in_channels=embed_dim,
+         upsample_channels=[768, 512, 256, 128],  ## 1K resolution
+         conv_out_channels=[64, 32, 16],
+         conv_kernel_sizes=[3, 3, 3],
+         loss_decode=[
+             dict(
+                 type="NormalCosineSimilarityLoss",
+                 loss_weight=10.0,
+             ),
+             dict(type="L1Loss", loss_weight=1.0),
+             dict(type="NormalGradL1Loss", loss_weight=10.0),
+         ],
+     ),
+ )
+
+
+ ##-----------------------------------------------------------------
+ optimizer = dict(
+     type="AdamW",
+     lr=5e-4,
+     betas=(0.9, 0.999),
+     weight_decay=0.1,
+     paramwise_cfg=dict(
+         num_layers=num_layers,
+         layer_decay_rate=layer_decay_rate,
+     ),
+     fused=True,
+ )
+
+ scheduler = dict(
+     type="SequentialLR",
+     milestones=[warmup_iters],
+     schedulers=[
+         dict(type="LinearLR", start_factor=1e-3, total_iters=warmup_iters),
+         dict(
+             type="PolynomialLR",
+             total_iters=num_iters - warmup_iters,
+             power=1.0,
+         ),
+     ],
+ )
+
+ clip_grad = dict(mode="norm", max_norm=4.0, norm_type=2.0)
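This 0.8b config matches the 0.4b one except for the backbone size (embed_dim 1280, 32 layers), layer_decay_rate 0.85, num_iters 1e4, and a looser gradient clip (max_norm 4.0 instead of 2.0). Assuming the trainer applies clip_grad via PyTorch's norm-based clipping, the equivalent call is roughly:

    import torch

    # clip_grad = dict(mode="norm", max_norm=4.0, norm_type=2.0), presumably applied as:
    model = torch.nn.Linear(8, 8)  # stand-in for the real model
    model(torch.randn(2, 8)).sum().backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=4.0, norm_type=2.0)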
assets/configs/sapiens2_1b_normal_metasim_render_people-1024x768.py ADDED
@@ -0,0 +1,306 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import os
+
+ _CHECKPOINT_ROOT = os.path.expanduser(
+     os.environ.get("SAPIENS_CHECKPOINT_ROOT", "~/sapiens2_host")
+ )
+ _DATA_ROOT = os.path.expanduser(os.environ.get("DATA_ROOT", "~/sapiens_data"))
+
+ warmup_iters = 500
+ num_iters = 4e4  ## 32 nodes, 8 gpus: 256 gpus. bs: 3, global bs: 768. num samples: 1e6. 1e6/768 = 1302. 1 epoch = 1e3 iters.
+ # num_iters = 1e4  ## light finetune
+
+ # ------------------------------------------------------------------------------
+ vis_every_iters = 100
+ log_every_iters = 10
+ save_every_iters = 1000
+ val_every_iters = 1000
+
+ # # debug
+ # vis_every_iters = 1
+ # log_every_iters = 1
+ # val_every_iters = 2
+ # save_every_iters = 1000
+
+ load_from = None
+ resume = False
+
+ # ------------------------------------------------------------------
+ model_name = "sapiens2_1b"
+ embed_dim = 1536
+ num_layers = 40
+ num_heads = 24
+ layer_decay_rate = 0.9
+
+ pretrained_checkpoint = f"{_CHECKPOINT_ROOT}/pretrain/sapiens2_1b_pretrain.safetensors"
+
+ ##-----------------------------------------------------------------
+ image_size = (1024, 768)  ## height x width
+ patch_size = 16
+
+ # ------------------------------------------------------------------
+ use_fsdp = True
+ # use_fsdp = False
+
+ use_compile = True
+ # use_compile = False
+
+ ## DDP config
+ if use_fsdp is False:
+     accelerator_cfg = dict(
+         type="DDP",
+         log_with="tensorboard",
+         # find_unused_parameters=True,
+         gradient_accumulation_steps=1,  # only accumulation=1 is supported. Otherwise, the LR scheduler will be off.
+         max_interval=num_iters,
+         # mixed_precision="bf16",  # Options: ‘no’,‘fp16’,‘bf16’ or ‘fp8’.
+         step_scheduler_with_optimizer=False,  ## schedule independent of n_gpus
+     )
+
+ else:
+     accelerator_cfg = dict(
+         type="FSDP",
+         log_with="tensorboard",
+         gradient_accumulation_steps=1,  # only accumulation=1 is supported. Otherwise, the LR scheduler will be off.
+         max_interval=num_iters,
+         mixed_precision="bf16",  # Options: ‘no’,‘fp16’,‘bf16’ or ‘fp8’.
+         step_scheduler_with_optimizer=False,
+         fsdp_cfg=dict(
+             fsdp_version=2,  # DTensor-based engine
+             state_dict_type="SHARDED_STATE_DICT",  # SHARDED_STATE_DICT | FULL_STATE_DICT
+             # state_dict_type="FULL_STATE_DICT",  # TODO: resume from this is not working
+             mixed_precision=dict(
+                 param_dtype="bf16",
+                 reduce_dtype="bf16",
+             ),
+             cpu_ram_efficient_loading=False,
+         ),
+     )
+
+ if use_compile:
+     accelerator_cfg["compile_cfg"] = dict(
+         backend="inductor",
+         mode="default",  # Options: "default", "reduce-overhead", "max-autotune"
+         fullgraph=False,
+         dynamic=False,
+     )
+
+ # ------------------------------------------------------------------
+ randomness = dict(seed=0, deterministic=False, diff_rank_seed=True)
+ logger = dict(
+     type="Logger",
+     log_interval=log_every_iters,
+ )
+ checkpoint = dict(
+     type="Checkpointer",
+     save_interval=save_every_iters,
+ )
+
+ visualizer = dict(
+     type="NormalVisualizer",
+     vis_interval=vis_every_iters,
+     vis_max_samples=4,
+     vis_image_width=384,
+     vis_image_height=512,
+ )
+
+
+ ##-----------------------------------------------------------------
+ train_pipeline = [
+     dict(type="PhotoMetricDistortion"),
+     dict(type="RandomDownUpSampleImage", scale_range=(0.1, 0.7), prob=0.2),
+     dict(
+         type="NormalRandomScale",
+         scale_min=0.5,
+         scale_max=2.0,
+         prob=0.3,
+     ),
+     dict(
+         type="NormalRandomCropContinuous",
+         ar_range=(0.5, 2.0),
+         area_range=(0.4, 1.0),
+         num_attempts=8,
+         prob=0.3,
+     ),
+     dict(
+         type="NormalRandomFlip",
+         prob=0.3,
+     ),
+     dict(type="NormalResize", height=1024, width=768),
+     dict(
+         type="RandomGaussianBlur", prob=0.3, kernel_size=(3, 3), sigma_range=(0.1, 2.0)
+     ),
+     dict(type="RandomGaussianNoise", prob=0.3, var_range=(5.0, 20.0)),
+     dict(type="RandomSolarize", prob=0.3, threshold=128),
+     dict(type="NormalGenerateTarget"),
+     dict(
+         type="NormalPackInputs",
+         meta_keys=(
+             "img_path",
+             "ori_shape",
+         ),
+     ),
+ ]
+
+ val_pipeline = [
+     dict(type="NormalResize", height=1024, width=768, test_mode=True),
+     dict(
+         type="NormalPackInputs",
+         test_mode=True,
+         meta_keys=(
+             "img_path",
+             "orig_img_height",
+             "orig_img_width",
+             "img_shape",
+             "pad_shape",
+         ),
+     ),
+ ]
+
+ test_pipeline = [
+     dict(type="NormalResizePadImage", height=1024, width=768, pad_val=0),
+     dict(
+         type="NormalPackInputs",
+         meta_keys=(
+             "img_path",
+             "orig_img_height",
+             "orig_img_width",
+             "padding_size",
+         ),
+     ),
+ ]
+
+ metasim_dataset = dict(
+     type="NormalMetaSimDataset",
+     airstore_template="airstore://codec_avatar_sapiens_metasim_v1_no_user_data",
+     json_path=f"{_DATA_ROOT}/seg/data/metasim/meta_data_v1.json",
+ )
+
+ render_people_dataset = dict(
+     type="NormalRenderPeopleBodyDataset",  ## body only
+     data_root=f"{_DATA_ROOT}/synthetic",
+     seg_data_root=f"{_DATA_ROOT}/RenderPeople/part_seg",
+ )
+
+ multihuman_render_people_dataset = dict(
+     type="NormalRenderPeopleMultihumanDataset",
+     data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_multi_human",
+     normal_extension=".npz",
+     seg_data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_multi_human/part_seg",  ## supervise on face for multihuman
+ )
+
+ # train_datasets = 2 * [metasim_dataset] + [
+ #     render_people_dataset,
+ #     multihuman_render_people_dataset,
+ # ]
+
+ # train_datasets = [render_people_dataset]
+ # train_datasets = [multihuman_render_people_dataset]
+ train_datasets = [metasim_dataset]
+
+ train_dataloader = dict(
+     batch_size=1,
+     num_workers=4,
+     persistent_workers=True,
+     shuffle=True,
+     dataset=dict(
+         type="CombinedDataset", datasets=train_datasets, pipeline=train_pipeline
+     ),
+ )
+
+ val_dataloader = dict(
+     batch_size=4,
+     num_workers=4,
+     persistent_workers=True,
+     multiprocessing_context="spawn",
+     # num_workers=0,  # debug
+     # persistent_workers=False,  # debug
+     shuffle=False,
+     dataset=dict(
+         type="NormalRenderPeopleBodyDataset",  ## body only
+         # num_samples=100,  ## debug: only use N samples for validation
+         test_mode=True,
+         data_root=f"{_DATA_ROOT}/seg/data/metasim/evaluation",
+         pipeline=val_pipeline,
+     ),
+ )
+
+ val_cfg = dict(
+     val_interval=val_every_iters,
+     evaluator=dict(
+         type="NormalEvaluator",
+     ),
+ )
+
+ data_preprocessor = dict(
+     type="ImagePreprocessor",
+     mean=[123.675, 116.28, 103.53],
+     std=[58.395, 57.12, 57.375],
+     bgr_to_rgb=True,  ## convert from bgr to rgb for pretrained models
+ )
+
+ ##-----------------------------------------------------------------
+ model = dict(
+     type="NormalEstimator",
+     backbone=dict(
+         type="Sapiens2",
+         arch=model_name,
+         img_size=image_size,
+         patch_size=patch_size,
+         final_norm=True,
+         use_tokenizer=False,
+         # with_cls_token=False,
+         with_cls_token=True,
+         out_type="featmap",
+         init_cfg=dict(type="Pretrained", checkpoint=pretrained_checkpoint),
+     ),
+     decode_head=dict(
+         type="NormalHead",
+         in_channels=embed_dim,
+         upsample_channels=[768, 512, 256, 128],  ## 1K resolution
+         conv_out_channels=[64, 32, 16],
+         conv_kernel_sizes=[3, 3, 3],
+         loss_decode=[
+             dict(
+                 type="NormalCosineSimilarityLoss",
+                 loss_weight=10.0,
+             ),
+             dict(type="L1Loss", loss_weight=1.0),
+             dict(type="NormalGradL1Loss", loss_weight=10.0),
+         ],
+     ),
+ )
+
+
+ ##-----------------------------------------------------------------
+ optimizer = dict(
+     type="AdamW",
+     lr=5e-4,
+     betas=(0.9, 0.999),
+     weight_decay=0.1,
+     paramwise_cfg=dict(
+         num_layers=num_layers,
+         layer_decay_rate=layer_decay_rate,
+     ),
+     fused=True,
+ )
+
+ scheduler = dict(
+     type="SequentialLR",
+     milestones=[warmup_iters],
+     schedulers=[
+         dict(type="LinearLR", start_factor=1e-3, total_iters=warmup_iters),
+         dict(
+             type="PolynomialLR",
+             total_iters=num_iters - warmup_iters,
+             power=1.0,
+         ),
+     ],
+ )
+
+ clip_grad = dict(mode="norm", max_norm=4.0, norm_type=2.0)
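The inline comment on num_iters carries the epoch arithmetic; a quick check of those numbers (all taken from the comment itself):

    gpus = 32 * 8                       # 32 nodes x 8 GPUs = 256
    global_bs = gpus * 3                # per-GPU batch size 3 -> global batch 768
    iters_per_epoch = 1e6 / global_bs   # ~1302, i.e. roughly 1e3 iters per epoch
    epochs = 4e4 / iters_per_epoch      # ~30 epochs at num_iters = 4e4
    print(global_bs, round(iters_per_epoch), round(epochs, 1))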
assets/configs/sapiens2_5b_normal_metasim_render_people-1024x768.py ADDED
@@ -0,0 +1,312 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import os
+
+ _CHECKPOINT_ROOT = os.path.expanduser(
+     os.environ.get("SAPIENS_CHECKPOINT_ROOT", "~/sapiens2_host")
+ )
+ _DATA_ROOT = os.path.expanduser(os.environ.get("DATA_ROOT", "~/sapiens_data"))
+
+ warmup_iters = 500
+ num_iters = 4e4  ## 32 nodes, 8 gpus: 256 gpus. bs: 3, global bs: 768. num samples: 1e6. 1e6/768 = 1302. 1 epoch = 1e3 iters.
+ # num_iters = 1e4  ## light finetune
+
+ # ------------------------------------------------------------------------------
+ vis_every_iters = 100
+ log_every_iters = 10
+ save_every_iters = 1000
+ val_every_iters = 1000
+
+ # # debug
+ # vis_every_iters = 1
+ # log_every_iters = 1
+ # val_every_iters = 2
+ # save_every_iters = 1000
+
+ load_from = None
+ resume = False
+
+ # ------------------------------------------------------------------
+ model_name = "sapiens2_5b"
+ embed_dim = 2432
+ num_layers = 56
+ num_heads = 32
+ layer_decay_rate = 0.94
+
+ pretrained_checkpoint = f"{_CHECKPOINT_ROOT}/pretrain/sapiens2_5b_pretrain.safetensors"
+
+ ##-----------------------------------------------------------------
+ image_size = (1024, 768)  ## height x width
+ patch_size = 16
+
+ # ------------------------------------------------------------------
+ use_fsdp = True
+ # use_fsdp = False
+
+ use_compile = True
+ # use_compile = False
+
+ ## DDP config
+ if use_fsdp is False:
+     accelerator_cfg = dict(
+         type="DDP",
+         log_with="tensorboard",
+         # find_unused_parameters=True,
+         gradient_accumulation_steps=1,  # only accumulation=1 is supported. Otherwise, the LR scheduler will be off.
+         max_interval=num_iters,
+         # mixed_precision="bf16",  # Options: ‘no’,‘fp16’,‘bf16’ or ‘fp8’.
+         step_scheduler_with_optimizer=False,  ## schedule independent of n_gpus
+     )
+
+ else:
+     accelerator_cfg = dict(
+         type="FSDP",
+         log_with="tensorboard",
+         gradient_accumulation_steps=1,  # only accumulation=1 is supported. Otherwise, the LR scheduler will be off.
+         max_interval=num_iters,
+         mixed_precision="bf16",  # Options: ‘no’,‘fp16’,‘bf16’ or ‘fp8’.
+         step_scheduler_with_optimizer=False,
+         fsdp_cfg=dict(
+             fsdp_version=2,  # DTensor-based engine
+             state_dict_type="SHARDED_STATE_DICT",  # SHARDED_STATE_DICT | FULL_STATE_DICT
+             # state_dict_type="FULL_STATE_DICT",  # TODO: resume from this is not working
+             mixed_precision=dict(
+                 param_dtype="bf16",
+                 reduce_dtype="bf16",
+             ),
+             cpu_ram_efficient_loading=False,
+         ),
+         # parallelism_cfg=dict(
+         #     dp_shard_size=2,  # Fully Sharded Data Parallel degree
+         #     dp_replicate_size=1,  # Data Parallel degree
+         #     tp_size=1,  # Tensor Parallel degree
+         #     cp_size=4,  # Context Parallel degree
+         # ),
+     )
+
+ if use_compile:
+     accelerator_cfg["compile_cfg"] = dict(
+         backend="inductor",
+         mode="default",  # Options: "default", "reduce-overhead", "max-autotune"
+         fullgraph=False,
+         dynamic=False,
+     )
+
+ # ------------------------------------------------------------------
+ randomness = dict(seed=0, deterministic=False, diff_rank_seed=True)
+ logger = dict(
+     type="Logger",
+     log_interval=log_every_iters,
+ )
+ checkpoint = dict(
+     type="Checkpointer",
+     save_interval=save_every_iters,
+ )
+
+ visualizer = dict(
+     type="NormalVisualizer",
+     vis_interval=vis_every_iters,
+     vis_max_samples=4,
+     vis_image_width=384,
+     vis_image_height=512,
+ )
+
+
+ ##-----------------------------------------------------------------
+ train_pipeline = [
+     dict(type="PhotoMetricDistortion"),
+     dict(type="RandomDownUpSampleImage", scale_range=(0.1, 0.7), prob=0.2),
+     dict(
+         type="NormalRandomScale",
+         scale_min=0.5,
+         scale_max=2.0,
+         prob=0.3,
+     ),
+     dict(
+         type="NormalRandomCropContinuous",
+         ar_range=(0.5, 2.0),
+         area_range=(0.4, 1.0),
+         num_attempts=8,
+         prob=0.3,
+     ),
+     dict(
+         type="NormalRandomFlip",
+         prob=0.3,
+     ),
+     dict(type="NormalResize", height=1024, width=768),
+     dict(
+         type="RandomGaussianBlur", prob=0.3, kernel_size=(3, 3), sigma_range=(0.1, 2.0)
+     ),
+     dict(type="RandomGaussianNoise", prob=0.3, var_range=(5.0, 20.0)),
+     dict(type="RandomSolarize", prob=0.3, threshold=128),
+     dict(type="NormalGenerateTarget"),
+     dict(
+         type="NormalPackInputs",
+         meta_keys=(
+             "img_path",
+             "ori_shape",
+         ),
+     ),
+ ]
+
+ val_pipeline = [
+     dict(type="NormalResize", height=1024, width=768, test_mode=True),
+     dict(
+         type="NormalPackInputs",
+         test_mode=True,
+         meta_keys=(
+             "img_path",
+             "orig_img_height",
+             "orig_img_width",
+             "img_shape",
+             "pad_shape",
+         ),
+     ),
+ ]
+
+ test_pipeline = [
+     dict(type="NormalResizePadImage", height=1024, width=768, pad_val=0),
+     dict(
+         type="NormalPackInputs",
+         meta_keys=(
+             "img_path",
+             "orig_img_height",
+             "orig_img_width",
+             "img_shape",
+         ),
+     ),
+ ]
+
+ metasim_dataset = dict(
+     type="NormalMetaSimDataset",
+     airstore_template="airstore://codec_avatar_sapiens_metasim_v1_no_user_data",
+     json_path=f"{_DATA_ROOT}/seg/data/metasim/meta_data_v1.json",
+ )
+
+ render_people_dataset = dict(
+     type="NormalRenderPeopleBodyDataset",  ## body only
+     data_root=f"{_DATA_ROOT}/synthetic",
+     seg_data_root=f"{_DATA_ROOT}/RenderPeople/part_seg",
+ )
+
+ multihuman_render_people_dataset = dict(
+     type="NormalRenderPeopleMultihumanDataset",
+     data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_multi_human",
+     normal_extension=".npz",
+     seg_data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_multi_human/part_seg",  ## supervise on face for multihuman
+ )
+
+ # train_datasets = 2 * [metasim_dataset] + [
+ #     render_people_dataset,
+ #     multihuman_render_people_dataset,
+ # ]
+
+ # train_datasets = [render_people_dataset]
+ # train_datasets = [multihuman_render_people_dataset]
+ train_datasets = [metasim_dataset]
+
+ train_dataloader = dict(
+     batch_size=1,
+     num_workers=4,
+     persistent_workers=True,
+     shuffle=True,
+     dataset=dict(
+         type="CombinedDataset", datasets=train_datasets, pipeline=train_pipeline
+     ),
+ )
+
+ val_dataloader = dict(
+     batch_size=4,
+     num_workers=4,
+     persistent_workers=True,
+     multiprocessing_context="spawn",
+     # num_workers=0,  # debug
+     # persistent_workers=False,  # debug
+     shuffle=False,
+     dataset=dict(
+         type="NormalRenderPeopleBodyDataset",  ## body only
+         # num_samples=100,  ## debug: only use N samples for validation
+         test_mode=True,
+         data_root=f"{_DATA_ROOT}/seg/data/metasim/evaluation",
+         pipeline=val_pipeline,
+     ),
+ )
+
+ val_cfg = dict(
+     val_interval=val_every_iters,
+     evaluator=dict(
+         type="NormalEvaluator",
+     ),
+ )
+
+ data_preprocessor = dict(
+     type="ImagePreprocessor",
+     mean=[123.675, 116.28, 103.53],
+     std=[58.395, 57.12, 57.375],
+     bgr_to_rgb=True,  ## convert from bgr to rgb for pretrained models
+ )
+
+ ##-----------------------------------------------------------------
+ model = dict(
+     type="NormalEstimator",
+     backbone=dict(
+         type="Sapiens2",
+         arch=model_name,
+         img_size=image_size,
+         patch_size=patch_size,
+         final_norm=True,
+         use_tokenizer=False,
+         with_cls_token=True,
+         out_type="featmap",
+         init_cfg=dict(type="Pretrained", checkpoint=pretrained_checkpoint),
+     ),
+     decode_head=dict(
+         type="NormalHead",
+         in_channels=embed_dim,
+         upsample_channels=[1536, 768, 512, 256],  ## 1K resolution
+         conv_out_channels=[128, 64, 32],
+         conv_kernel_sizes=[3, 3, 3],
+         loss_decode=[
+             dict(
+                 type="NormalCosineSimilarityLoss",
+                 loss_weight=10.0,
+             ),
+             dict(type="L1Loss", loss_weight=1.0),
+             dict(type="NormalGradL1Loss", loss_weight=10.0),
+         ],
+     ),
+ )
+
+
+ ##-----------------------------------------------------------------
+ optimizer = dict(
+     type="AdamW",
+     # lr=5e-4,
+     lr=1e-4,
+     betas=(0.9, 0.999),
+     weight_decay=0.1,
+     paramwise_cfg=dict(
+         num_layers=num_layers,
+         layer_decay_rate=layer_decay_rate,
+     ),
+     fused=True,
+ )
+
+ scheduler = dict(
+     type="SequentialLR",
+     milestones=[warmup_iters],
+     schedulers=[
+         dict(type="LinearLR", start_factor=1e-3, total_iters=warmup_iters),
+         dict(
+             type="PolynomialLR",
+             total_iters=num_iters - warmup_iters,
+             power=1.0,
+         ),
+     ],
+ )
+
+ clip_grad = dict(mode="norm", max_norm=4.0, norm_type=2.0)
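paramwise_cfg requests layer-wise LR decay (rate 0.94 across 56 layers here, versus 0.8 for 0.4b). A sketch of the usual decay rule, assuming the common convention that the block at depth i gets base_lr * rate ** (num_layers - i):

    base_lr, rate, num_layers = 1e-4, 0.94, 56

    # Deeper blocks (larger i) stay near base_lr; the earliest blocks are scaled down hardest.
    lrs = [base_lr * rate ** (num_layers - i) for i in range(num_layers + 1)]
    print(f"first block lr ~ {lrs[0]:.1e}, last block lr = {lrs[-1]:.1e}")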
assets/images/68204.png ADDED

Git LFS Details

  • SHA256: 9b0268cb801ed164864a4b5f6d131e0ac5cc2fbd149a6467d5d0c97da47122c2
  • Pointer size: 132 Bytes
  • Size of remote file: 4.29 MB
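These "Pointer size: 132 Bytes" entries are the small text stubs git stores in place of the image once *.png is routed through LFS (the .gitattributes change above). For reference, a pointer file is just three lines, e.g. for this image (the exact byte count lives in the real pointer and is not shown on this page):

    version https://git-lfs.github.com/spec/v1
    oid sha256:9b0268cb801ed164864a4b5f6d131e0ac5cc2fbd149a6467d5d0c97da47122c2
    size <byte count of the 4.29 MB file>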
assets/images/68210.png ADDED

Git LFS Details

  • SHA256: dbe5f80498af4ebd1ff09ae4184f37c20ba981e53bd554c3cc78d39ae0ee7fd7
  • Pointer size: 132 Bytes
  • Size of remote file: 3.93 MB
assets/images/68658.png ADDED

Git LFS Details

  • SHA256: 61a68b619bd17235e683324f2826ce0693322e45ab8c86f1c057851ecb333ac7
  • Pointer size: 132 Bytes
  • Size of remote file: 5.1 MB
assets/images/68666.png ADDED

Git LFS Details

  • SHA256: ea3047e6c2ccb485fdb3966aa2325e803cbf49c27c0bff00287b44bc16f18914
  • Pointer size: 132 Bytes
  • Size of remote file: 4.56 MB
assets/images/68691.png ADDED

Git LFS Details

  • SHA256: fae39e4055c1b297af7068cdddfeeba8d685363281b839d8c5afac1980204b57
  • Pointer size: 132 Bytes
  • Size of remote file: 3.74 MB
assets/images/68956.png ADDED

Git LFS Details

  • SHA256: eee1f27082b10999d0fa848121ecb06cda3386b1a864b9aa0f59ae78261f8908
  • Pointer size: 132 Bytes
  • Size of remote file: 4.15 MB
assets/images/pexels-amresh444-17315601.png ADDED

Git LFS Details

  • SHA256: 4e17ee1b229147e4b52e8348a6ef426bc9e9a2f90738e776e15b26b325abb9b3
  • Pointer size: 132 Bytes
  • Size of remote file: 3.5 MB
assets/images/pexels-gabby-k-6311686.png ADDED

Git LFS Details

  • SHA256: 3f10eded3fb05ab04b963f7b9fd2e183d8d4e81b20569b1c6b0653549639421f
  • Pointer size: 132 Bytes
  • Size of remote file: 3.65 MB
assets/images/pexels-julia-m-cameron-4145040.png ADDED

Git LFS Details

  • SHA256: 459cf0280667b028ffbca16aa11188780d7a0205c0defec02916ff3cbaeecb72
  • Pointer size: 132 Bytes
  • Size of remote file: 2.92 MB
assets/images/pexels-marcus-aurelius-6787357.png ADDED

Git LFS Details

  • SHA256: 7d35452f76492125eaf7d5783aa9fd6b0d5990ebe0579fe9dfd58a9d634f4955
  • Pointer size: 132 Bytes
  • Size of remote file: 3.3 MB
assets/images/pexels-mo-saeed-3616599-5409085.png ADDED

Git LFS Details

  • SHA256: 7c1ca7afd6c2a654e94ef59d5fb56fca4f3cde5fb5216f6b218c34a7b8c143dc
  • Pointer size: 132 Bytes
  • Size of remote file: 3.13 MB
assets/images/pexels-riedelmax-27355495.png ADDED

Git LFS Details

  • SHA256: 4141d2f5f718f162ea1f6710c06b28b5cb51fd69598fde35948f8f3491228164
  • Pointer size: 132 Bytes
  • Size of remote file: 3.73 MB
assets/images/pexels-sergeymakashin-5368660.png ADDED

Git LFS Details

  • SHA256: af8f5a8f26dd102d87d94c1be36ec903791fe8e6d951c68ebb9ebcfc6d7397bb
  • Pointer size: 132 Bytes
  • Size of remote file: 4.08 MB
assets/images/pexels-vinicius-wiesehofer-289347-4219918.png ADDED

Git LFS Details

  • SHA256: a6eef5eee15b81fe65ea95627e9a46040b9889466689b3c1ca6ed273e02fe84f
  • Pointer size: 132 Bytes
  • Size of remote file: 3.63 MB
requirements.txt ADDED
@@ -0,0 +1,21 @@
+ gradio==4.42.0
+ spaces
+
+ torch==2.7.1
+ torchvision==0.22.1
+
+ numpy
+ opencv-python
+ pillow
+ matplotlib
+ safetensors
+ huggingface_hub
+
+ # Sapiens2 deps (sapiens2 source is vendored under ./sapiens/, not pip-installed).
+ tqdm
+ scipy
+ iopath
+ prettytable
+ termcolor
+ accelerate
+ rich
sapiens/__init__.py ADDED
@@ -0,0 +1,14 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from .version import __version__
+ from .engine import *
+ from .backbones import *
+ from .dense import *
+ from .pose import *
+ from .registry import *
+
+ __all__ = ["__version__"]
sapiens/backbones/__init__.py ADDED
@@ -0,0 +1,10 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from .sapiens import Sapiens
+ from .sapiens2 import Sapiens2
+
+ __all__ = ["Sapiens", "Sapiens2"]
sapiens/backbones/sapiens.py ADDED
@@ -0,0 +1,611 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import math
+ from typing import Sequence
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from sapiens.engine.models.base_model import BaseModel
+ from sapiens.registry import MODELS
+ from torch.nn import Linear, Sequential
+
+
+ # ----------------------------------------------------------------------------
+ def to_2tuple(x):
+     if isinstance(x, (str, bytes)):
+         return (x, x)
+     if isinstance(x, Sequence):
+         x = tuple(x)
+         if len(x) == 2:
+             return x
+         raise ValueError("Expected scalar or length-2 iterable")
+     return (x, x)
+
+
+ def resize_pos_embed(
+     pos_embed, src_shape, dst_shape, mode="bicubic", num_extra_tokens=1
+ ):
+     if src_shape[0] == dst_shape[0] and src_shape[1] == dst_shape[1]:
+         return pos_embed
+     assert pos_embed.ndim == 3, "shape of pos_embed must be [1, L, C]"
+     _, L, C = pos_embed.shape
+     src_h, src_w = src_shape
+     assert L == src_h * src_w + num_extra_tokens, (
+         f"The length of `pos_embed` ({L}) doesn't match the expected "
+         f"shape ({src_h}*{src_w}+{num_extra_tokens}). Please check the "
+         "`img_size` argument."
+     )
+     extra_tokens = pos_embed[:, :num_extra_tokens]
+
+     src_weight = pos_embed[:, num_extra_tokens:]
+     src_weight = src_weight.reshape(1, src_h, src_w, C).permute(0, 3, 1, 2)
+
+     # The bicubic interpolation algorithm only accepts float32
+     dst_weight = F.interpolate(
+         src_weight.float(), size=dst_shape, align_corners=False, mode=mode
+     )
+     dst_weight = torch.flatten(dst_weight, 2).transpose(1, 2)
+     dst_weight = dst_weight.to(src_weight.dtype)
+
+     return torch.cat((extra_tokens, dst_weight), dim=1)
+
+
+ # ----------------------------------------------------------------------------
+ class PatchEmbed(nn.Module):
+     def __init__(
+         self,
+         in_channels=3,
+         embed_dims=768,
+         kernel_size=16,
+         stride=16,
+         padding="corner",
+         dilation=1,
+         bias=True,
+         input_size=None,
+     ):
+         super().__init__()
+
+         self.embed_dims = embed_dims
+         if stride is None:
+             stride = kernel_size
+
+         kernel_size = to_2tuple(kernel_size)
+         stride = to_2tuple(stride)
+         dilation = to_2tuple(dilation)
+         padding = 0  # the `padding` argument is ignored; this variant never pads
+         padding = to_2tuple(padding)
+
+         self.projection = nn.Conv2d(
+             in_channels=in_channels,
+             out_channels=embed_dims,
+             kernel_size=kernel_size,
+             stride=stride,
+             padding=padding,
+             dilation=dilation,
+             bias=bias,
+         )
+
+         if input_size:
+             input_size = to_2tuple(input_size)
+             self.init_input_size = input_size
+
+             h_out = (
+                 input_size[0] + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1
+             ) // stride[0] + 1
+             w_out = (
+                 input_size[1] + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1
+             ) // stride[1] + 1
+             self.init_out_size = (h_out, w_out)
+         else:
+             self.init_input_size = None
+             self.init_out_size = None
+
+     def forward(self, x):
+         x = self.projection(x)
+         out_size = (x.shape[2], x.shape[3])
+         x = x.flatten(2).transpose(1, 2)
+         return x, out_size
+
+
+ # ----------------------------------------------------------------------------
+ class LayerScale(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         inplace: bool = False,
+         data_format: str = "channels_last",
+         scale: float = 1e-5,
+     ):
+         super().__init__()
+         assert data_format in (
+             "channels_last",
+             "channels_first",
+         ), "'data_format' could only be channels_last or channels_first."
+         self.inplace = inplace
+         self.data_format = data_format
+         self.weight = nn.Parameter(torch.ones(dim) * scale)
+
+     def forward(self, x) -> torch.Tensor:
+         if self.data_format == "channels_first":
+             shape = tuple((1, -1, *(1 for _ in range(x.dim() - 2))))
+         else:
+             shape = tuple((*(1 for _ in range(x.dim() - 1)), -1))
+         if self.inplace:
+             return x.mul_(self.weight.view(*shape))
+         else:
+             return x * self.weight.view(*shape)
+
+
+ # ----------------------------------------------------------------------------
+ class FFN(nn.Module):
+     def __init__(
+         self,
+         embed_dims=256,
+         feedforward_channels=1024,
+         num_fcs=2,
+         ffn_drop=0.0,
+         add_identity=True,
+         layer_scale_init_value=0.0,
+     ):
+         super().__init__()
+         assert num_fcs >= 2, f"num_fcs should be no less than 2. got {num_fcs}."
+         self.embed_dims = embed_dims
+         self.feedforward_channels = feedforward_channels
+         self.num_fcs = num_fcs
+
+         layers = []
+         in_channels = embed_dims
+         for _ in range(num_fcs - 1):
+             layers.append(
+                 Sequential(
+                     Linear(in_channels, feedforward_channels),
+                     nn.GELU(),
+                     nn.Dropout(ffn_drop),
+                 )
+             )
+             in_channels = feedforward_channels
+         layers.append(Linear(feedforward_channels, embed_dims))
+         layers.append(nn.Dropout(ffn_drop))
+         self.layers = Sequential(*layers)
+         self.dropout_layer = nn.Identity()
+         self.add_identity = add_identity
+
+         if layer_scale_init_value > 0:
+             self.gamma2 = LayerScale(embed_dims, scale=layer_scale_init_value)
+         else:
+             self.gamma2 = nn.Identity()
+
+     def forward(self, x, identity=None):
+         out = self.layers(x)
+         out = self.gamma2(out)
+         if not self.add_identity:
+             return out
+         if identity is None:
+             identity = x
+         return identity + out
+
+
+ # ----------------------------------------------------------------------------
+ class MultiheadAttention(nn.Module):
+     def __init__(
+         self,
+         embed_dims,
+         num_heads,
+         input_dims=None,
+         attn_drop=0.0,
+         proj_drop=0.0,
+         qkv_bias=True,
+         proj_bias=True,
+         v_shortcut=False,
+     ):
+         super(MultiheadAttention, self).__init__()
+
+         self.input_dims = input_dims or embed_dims
+         self.embed_dims = embed_dims
+         self.num_heads = num_heads
+         self.v_shortcut = v_shortcut
+
+         self.head_dims = embed_dims // num_heads
+         self.scaled_dot_product_attention = F.scaled_dot_product_attention
+
+         self.qkv = nn.Linear(self.input_dims, embed_dims * 3, bias=qkv_bias)
+         self.attn_drop = attn_drop
+         self.proj = nn.Linear(embed_dims, embed_dims, bias=proj_bias)
+         self.proj_drop = nn.Dropout(proj_drop)
+         self.gamma1 = nn.Identity()
+
+     def forward(self, x):
+         B, N, _ = x.shape
+         qkv = (
+             self.qkv(x)
+             .reshape(B, N, 3, self.num_heads, self.head_dims)
+             .permute(2, 0, 3, 1, 4)
+         )
+         q, k, v = qkv[0], qkv[1], qkv[2]
+
+         attn_drop = self.attn_drop if self.training else 0.0
+         x = self.scaled_dot_product_attention(q, k, v, dropout_p=attn_drop)
+         x = x.transpose(1, 2).reshape(B, N, self.embed_dims)
+
+         x = self.proj(x)
+         x = self.gamma1(self.proj_drop(x))
+
+         if self.v_shortcut:
+             x = v.squeeze(1) + x
+         return x
+
+
+ # ----------------------------------------------------------------------------
+ class TransformerEncoderLayer(nn.Module):
+     def __init__(
+         self,
+         embed_dims,
+         num_heads,
+         feedforward_channels,
+         drop_rate=0.0,
+         attn_drop_rate=0.0,
+         num_fcs=2,
+         qkv_bias=True,
+     ):
+         super(TransformerEncoderLayer, self).__init__()
+
+         self.embed_dims = embed_dims
+         self.ln1 = nn.LayerNorm(self.embed_dims, eps=1e-6, elementwise_affine=True)
+         self.attn = MultiheadAttention(
+             embed_dims=embed_dims,
+             num_heads=num_heads,
+             attn_drop=attn_drop_rate,
+             proj_drop=drop_rate,
+             qkv_bias=qkv_bias,
+         )
+
+         self.ln2 = nn.LayerNorm(self.embed_dims, eps=1e-6, elementwise_affine=True)
+         self.ffn = FFN(
+             embed_dims=embed_dims,
+             feedforward_channels=feedforward_channels,
+             num_fcs=num_fcs,
+             ffn_drop=drop_rate,
+             add_identity=True,
+         )
+
+     @property
+     def norm1(self):
+         return self.ln1
+
+     @property
+     def norm2(self):
+         return self.ln2
+
+     def forward(self, x):
+         x = x + self.attn(self.ln1(x))
+         x = self.ffn(self.ln2(x), identity=x)
+         return x
+
+
+ # ----------------------------------------------------------------------------
+ @MODELS.register_module()
+ class Sapiens(BaseModel):
+     arch_zoo = {
+         **dict.fromkeys(  ## this is vit-large
+             ["0.3b", "sapiens_0.3b"],
+             {
+                 "embed_dims": 1024,
+                 "num_layers": 24,
+                 "num_heads": 16,
+                 "feedforward_channels": 1024 * 4,
+             },
+         ),
+         **dict.fromkeys(  ## this is vit-huge
+             ["0.6b", "sapiens_0.6b"],
+             {
+                 "embed_dims": 1280,
+                 "num_layers": 32,
+                 "num_heads": 16,
+                 "feedforward_channels": 1280 * 4,
+             },
+         ),
+         **dict.fromkeys(  ## this is vit-g
+             ["1b", "sapiens_1b"],
+             {
+                 "embed_dims": 1536,
+                 "num_layers": 40,
+                 "num_heads": 24,
+                 "feedforward_channels": 1536 * 4,
+             },
+         ),
+         **dict.fromkeys(
+             ["2b", "sapiens_2b"],
+             {
+                 "embed_dims": 1920,
+                 "num_layers": 48,
+                 "num_heads": 32,
+                 "feedforward_channels": 1920 * 4,
+             },
+         ),
+     }
+     num_extra_tokens = 1  # class token
+     OUT_TYPES = {"raw", "cls_token", "featmap"}
+
+     def __init__(
+         self,
+         arch="base",
+         img_size=1024,
+         patch_size=16,
+         in_channels=3,
+         out_indices=-1,
+         drop_rate=0.0,
+         qkv_bias=True,
+         final_norm=True,
+         out_type="cls_token",
+         with_cls_token=True,
+         frozen_stages=-1,
+         interpolate_mode="bicubic",
+         patch_cfg=dict(),
+         layer_cfgs=dict(),
+         init_cfg=None,
+     ):
+         super(Sapiens, self).__init__(init_cfg=init_cfg)
+
+         arch = arch.lower()
+         assert arch in set(self.arch_zoo), (
+             f"Arch {arch} is not in default archs {set(self.arch_zoo)}"
+         )
+         self.arch_settings = self.arch_zoo[arch]
+
+         self.embed_dims = self.arch_settings["embed_dims"]
+         self.num_layers = self.arch_settings["num_layers"]
+         self.img_size = to_2tuple(img_size)
+         self.patch_size = patch_size
+
+         # Set patch embedding
+         _patch_cfg = dict(
+             in_channels=in_channels,
+             input_size=img_size,
+             embed_dims=self.embed_dims,
+             kernel_size=patch_size,
+             stride=patch_size,
+             bias=True,
+         )
+         _patch_cfg.update(patch_cfg)
+         self.patch_embed = PatchEmbed(**_patch_cfg)
+         self.patch_resolution = self.patch_embed.init_out_size
+         num_patches = self.patch_resolution[0] * self.patch_resolution[1]
+
+         # Set out type
+         if out_type not in self.OUT_TYPES:
+             raise ValueError(
+                 f"Unsupported `out_type` {out_type}, please "
+                 f"choose from {self.OUT_TYPES}"
+             )
+         self.out_type = out_type
+
+         # Set cls token
+         self.with_cls_token = with_cls_token
+         if with_cls_token:
+             self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
+         elif out_type != "cls_token":
+             self.cls_token = None
+             self.num_extra_tokens = 0
+         else:
+             raise ValueError('with_cls_token must be True when `out_type="cls_token"`.')
+
+         # Set position embedding
+         self.interpolate_mode = interpolate_mode
+         self.pos_embed = nn.Parameter(
+             torch.zeros(1, num_patches + self.num_extra_tokens, self.embed_dims)
+         )
+         self.drop_after_pos = nn.Dropout(p=drop_rate)
+
+         if isinstance(out_indices, int):
+             out_indices = [out_indices]
+         assert isinstance(out_indices, Sequence), (
+             f'"out_indices" must be a sequence or int, got {type(out_indices)} instead.'
+         )
+         for i, index in enumerate(out_indices):
+             if index < 0:
+                 out_indices[i] = self.num_layers + index
+             assert 0 <= out_indices[i] <= self.num_layers, (
+                 f"Invalid out_indices {index}"
+             )
+         self.out_indices = out_indices
+
+         self.layers = nn.Sequential()
+         if isinstance(layer_cfgs, dict):
+             layer_cfgs = [layer_cfgs] * self.num_layers
+         for i in range(self.num_layers):
+             _layer_cfg = dict(
+                 embed_dims=self.embed_dims,
+                 num_heads=self.arch_settings["num_heads"],
+                 feedforward_channels=self.arch_settings["feedforward_channels"],
+                 drop_rate=drop_rate,
+                 qkv_bias=qkv_bias,
+             )
+             _layer_cfg.update(layer_cfgs[i])
+             self.layers.append(TransformerEncoderLayer(**_layer_cfg))
+
+         self.frozen_stages = frozen_stages
+         self.pre_norm = nn.Identity()
+
+         self.final_norm = final_norm
+         if final_norm:
+             self.ln1 = nn.LayerNorm(self.embed_dims, eps=1e-6, elementwise_affine=True)
+
+         # freeze stages only when self.frozen_stages > 0
+         if self.frozen_stages > 0:
+             self._freeze_stages()
+
+         self._register_load_state_dict_pre_hook(self._prepare_pos_embed)
+
+         self.init_weights()
+
+         return
+
+     def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs):
+         name = prefix + "pos_embed"
+         if name not in state_dict.keys():
+             return
+
+         ckpt_pos_embed_shape = state_dict[name].shape
+
+         from sapiens.engine.logger import Logger
+
+         logger = Logger.get_current_instance()
+         rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
+
+         # Handle class token removal if needed
+         if not self.with_cls_token:
+             if ckpt_pos_embed_shape[1] == self.pos_embed.shape[1] + 1:
+                 # Remove cls token from state dict if it's not used
+                 state_dict[name] = state_dict[name][:, 1:]
+                 ckpt_pos_embed_shape = state_dict[name].shape
+             elif ckpt_pos_embed_shape[1] % 2 == 1:
+                 # Remove class token when interpolation is required
+                 if rank == 0:
+                     logger.info(
+                         "Note: removing the class token from pretrained weights"
+                     )
+                 state_dict[name] = state_dict[name][:, 1:]
+                 ckpt_pos_embed_shape = state_dict[name].shape
+
+         # Skip if shapes already match
+         if self.pos_embed.shape == ckpt_pos_embed_shape:
+             return
+
+         if rank == 0:
+             logger.info(
+                 f"Resize the pos_embed shape from {ckpt_pos_embed_shape} "
+                 f"to {self.pos_embed.shape}."
+             )
+
+         # Calculate grid dimensions
+         pos_h, pos_w = self.patch_embed.init_out_size
+         assert pos_h >= pos_w  # for vertical aspect ratio or square
+
+         # Number of non-extra tokens in checkpoint
+         num_vis = ckpt_pos_embed_shape[1] - self.num_extra_tokens
+
+         # Determine original grid shape
+         side = int(math.sqrt(num_vis))
+         factor = int(math.sqrt((num_vis * self.patch_size * self.patch_size) // 12))
+
+         # Set old grid based on aspect ratio detection
+         if side * side == num_vis:
+             old_grid = (side, side)  # square grid
+         elif 4 * factor * 3 * factor == num_vis * self.patch_size * self.patch_size:
+             old_grid = (
+                 (factor * 4) // self.patch_size,
+                 (factor * 3) // self.patch_size,
+             )  # 4:3 ratio
+         else:
+             if rank == 0:
+                 logger.warning(
+                     f"Original pos_embed tokens ({num_vis}) form neither a square "
+                     "nor a 4:3 grid; keeping the current pos_embed."
+                 )
+             state_dict[name] = self.pos_embed
+             return
+
+         # Resize position embedding
+         new_grid = (pos_h, pos_w)
+         state_dict[name] = resize_pos_embed(
+             state_dict[name],
+             old_grid,
+             new_grid,
+             mode=self.interpolate_mode,
+             num_extra_tokens=self.num_extra_tokens,
+         )
+
+     @property
+     def norm1(self):
+         return self.ln1
+
+     @property
+     def norm2(self):
+         return self.ln2
+
+     @staticmethod
+     def resize_pos_embed(*args, **kwargs):
+         """Interface for backward-compatibility."""
+         return resize_pos_embed(*args, **kwargs)
+
+     def _freeze_stages(self):
+         # freeze position embedding
+         if self.pos_embed is not None:
+             self.pos_embed.requires_grad = False
+
+         # set dropout to eval mode
+         self.drop_after_pos.eval()
+         # freeze patch embedding
+         self.patch_embed.eval()
+         for param in self.patch_embed.parameters():
+             param.requires_grad = False
+         # freeze pre-norm
+         for param in self.pre_norm.parameters():
+             param.requires_grad = False
+         # freeze cls_token
+         if self.cls_token is not None:
+             self.cls_token.requires_grad = False
+         # freeze layers
+         for i in range(1, self.frozen_stages + 1):
+             m = self.layers[i - 1]
+             m.eval()
+             for param in m.parameters():
+                 param.requires_grad = False
+         # freeze the last layer norm
+         if self.frozen_stages == len(self.layers):
+             if self.final_norm:
+                 self.ln1.eval()
+                 for param in self.ln1.parameters():
+                     param.requires_grad = False
+
+             if self.out_type == "avg_featmap":
+                 self.ln2.eval()
+                 for param in self.ln2.parameters():
+                     param.requires_grad = False
+
+     def forward(self, x):
+         B = x.shape[0]
+         x, patch_resolution = self.patch_embed(x)
+
+         if self.cls_token is not None:
+             cls_token = self.cls_token.expand(B, -1, -1)
+             x = torch.cat((cls_token, x), dim=1)
+
+         x = x + resize_pos_embed(
+             self.pos_embed,
+             self.patch_resolution,
+             patch_resolution,
+             mode=self.interpolate_mode,
+             num_extra_tokens=self.num_extra_tokens,
+         )
+         x = self.drop_after_pos(x)
+
+         x = self.pre_norm(x)  ## B x (num tokens) x embed_dim
+
+         outs = []
+         for i, layer in enumerate(self.layers):
+             x = layer(x)
+
+             if i == len(self.layers) - 1 and self.final_norm:
+                 x = self.ln1(x)
+
+             if i in self.out_indices:
+                 outs.append(self._format_output(x, patch_resolution))
+
+         return tuple(outs)
+
+     def _format_output(self, x, hw):
+         if self.out_type == "raw":
+             return x
+         if self.out_type == "cls_token":
+             return x[:, 0]
+
+         patch_token = x[:, self.num_extra_tokens :]
+         if self.out_type == "featmap":
+             B = x.size(0)
+             # (B, N, C) -> (B, H, W, C) -> (B, C, H, W)
+             return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2)
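
(A standalone sketch, not part of the commit, of the bicubic pos-embed resize used by this backbone on a toy grid; pure torch, shapes are illustrative only.)

    # Mirrors resize_pos_embed above: split extra tokens, interpolate the grid, re-concatenate.
    import torch
    import torch.nn.functional as F

    pos_embed = torch.randn(1, 1 + 8 * 6, 32)                # 1 cls token + an 8x6 patch grid, C=32
    extra, grid = pos_embed[:, :1], pos_embed[:, 1:]
    grid = grid.reshape(1, 8, 6, 32).permute(0, 3, 1, 2)      # (1, C, 8, 6)
    grid = F.interpolate(grid.float(), size=(16, 12), mode="bicubic", align_corners=False)
    grid = grid.flatten(2).transpose(1, 2)                    # (1, 16*12, C)
    resized = torch.cat((extra, grid), dim=1)
    print(resized.shape)                                      # torch.Size([1, 193, 32])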
sapiens/backbones/sapiens2.py ADDED
@@ -0,0 +1,916 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import math
+ from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from sapiens.engine.models.base_model import BaseModel
+ from sapiens.registry import MODELS
+ from torch import Tensor
+ from torch.nn.init import trunc_normal_
+ from torch.utils.checkpoint import checkpoint
+
+
+ # ----------------------------------------------------------------------------
+ def to_2tuple(x):
+     if isinstance(x, (str, bytes)):
+         return (x, x)
+     if isinstance(x, Sequence):
+         x = tuple(x)
+         if len(x) == 2:
+             return x
+         raise ValueError("Expected scalar or length-2 iterable")
+     return (x, x)
+
+
+ class RopePositionEmbedding(nn.Module):
+     def __init__(
+         self,
+         embed_dim: int,
+         *,
+         num_heads: int,
+         base: float | None = 100.0,
+         min_period: float | None = None,
+         max_period: float | None = None,
+         normalize_coords: Literal["min", "max", "separate"] = "separate",
+         shift_coords: float | None = None,
+         jitter_coords: float | None = None,
+         rescale_coords: float | None = None,
+         dtype: torch.dtype | None = None,
+         device: torch.device | None = None,
+     ):
+         super().__init__()
+         assert embed_dim % (4 * num_heads) == 0
+         both_periods = min_period is not None and max_period is not None
+         if (base is None and not both_periods) or (base is not None and both_periods):
+             raise ValueError(
+                 "Either `base` or `min_period`+`max_period` must be provided."
+             )
+
+         D_head = embed_dim // num_heads
+         self.base = base
+         self.min_period = min_period
+         self.max_period = max_period
+         self.D_head = D_head
+         self.normalize_coords = normalize_coords
+         self.shift_coords = shift_coords
+         self.jitter_coords = jitter_coords
+         self.rescale_coords = rescale_coords
+
+         # Needs persistent=True because we do teacher.load_state_dict(student.state_dict()) to initialize the teacher
+         self.dtype = dtype or torch.float32  # Don't rely on self.periods.dtype
+         self.register_buffer(
+             "periods",
+             torch.empty(D_head // 4, device=device, dtype=self.dtype),
+             persistent=True,
+         )
+         self._init_weights()
+
+     def forward(self, *, H: int, W: int) -> tuple[Tensor, Tensor]:
+         device = self.periods.device
+         dtype = self.dtype
+         dd = {"device": device, "dtype": dtype}
+         # Prepare coords in range [-1, +1]
+         if self.normalize_coords == "max":
+             max_HW = max(H, W)
+             coords_h = torch.arange(0.5, H, **dd) / max_HW  # [H]
+             coords_w = torch.arange(0.5, W, **dd) / max_HW  # [W]
+         elif self.normalize_coords == "min":
+             min_HW = min(H, W)
+             coords_h = torch.arange(0.5, H, **dd) / min_HW  # [H]
+             coords_w = torch.arange(0.5, W, **dd) / min_HW  # [W]
+         elif self.normalize_coords == "separate":
+             coords_h = torch.arange(0.5, H, **dd) / H  # [H]
+             coords_w = torch.arange(0.5, W, **dd) / W  # [W]
+         else:
+             raise ValueError(f"Unknown normalize_coords: {self.normalize_coords}")
+         coords = torch.stack(
+             torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1
+         )  # [H, W, 2]
+         coords = coords.flatten(0, 1)  # [HW, 2]
+         coords = 2.0 * coords - 1.0  # Shift range [0, 1] to [-1, +1]
+
+         # Shift coords by adding a uniform value in [-shift, shift]
+         if self.training and self.shift_coords is not None:
+             shift_hw = torch.empty(2, **dd).uniform_(
+                 -self.shift_coords, self.shift_coords
+             )
+             coords += shift_hw[None, :]
+
+         # Jitter coords by multiplying the range [-1, 1] by a log-uniform value in [1/jitter, jitter]
+         if self.training and self.jitter_coords is not None:
+             jitter_max = math.log(self.jitter_coords)  # math.log; numpy is not imported here
+             jitter_min = -jitter_max
+             jitter_hw = torch.empty(2, **dd).uniform_(jitter_min, jitter_max).exp()
+             coords *= jitter_hw[None, :]
+
+         # Rescale coords by multiplying the range [-1, 1] by a log-uniform value in [1/rescale, rescale]
+         if self.training and self.rescale_coords is not None:
+             rescale_max = math.log(self.rescale_coords)  # math.log; numpy is not imported here
+             rescale_min = -rescale_max
+             rescale_hw = torch.empty(1, **dd).uniform_(rescale_min, rescale_max).exp()
+             coords *= rescale_hw
+
+         # Prepare angles and sin/cos
+         angles = (
+             2 * math.pi * coords[:, :, None] / self.periods[None, None, :]
+         )  # [HW, 2, D//4]
+         angles = angles.flatten(1, 2)  # [HW, D//2]
+         angles = angles.tile(2)  # [HW, D]
+         cos = torch.cos(angles)  # [HW, D]
+         sin = torch.sin(angles)  # [HW, D]
+
+         return (sin, cos)  # 2 * [HW, D]
+
+     def _init_weights(self):
+         device = self.periods.device
+         dtype = self.dtype
+         if self.base is not None:
+             periods = self.base ** (
+                 2
+                 * torch.arange(self.D_head // 4, device=device, dtype=dtype)
+                 / (self.D_head // 2)
+             )  # [D//4]
+         else:
+             base = self.max_period / self.min_period
+             exponents = torch.linspace(
+                 0, 1, self.D_head // 4, device=device, dtype=dtype
+             )  # [D//4] range [0, 1]
+             periods = base**exponents  # range [1, max_period / min_period]
+             periods = periods / base  # range [min_period / max_period, 1]
+             periods = periods * self.max_period  # range [min_period, max_period]
+         self.periods.data = periods
+
+
+ # -------------------------------------------------------------------------------
+ class Tokenizer(nn.Module):
+     """Stacked window self-attention that emits one token per window
+     by re-using TransformerEncoderLayer blocks."""
+
+     def __init__(
+         self,
+         embed_dims: int,
+         window_size: int = 4,
+         num_heads: int = 4,
+         num_tokenizer_layers: int = 1,
+         qkv_bias: bool = True,
+         use_qk_norm: bool = False,
+         chunk_size: int = 1024,  # max windows per chunk
+     ):
+         super().__init__()
+         self.ws = window_size
+         self.chunk_size = chunk_size
+
+         # local absolute positional embeddings for [CLS] + patch tokens
+         self.local_pos_embed = nn.Parameter(
+             torch.zeros(1, 1 + window_size * window_size, embed_dims)
+         )
+         trunc_normal_(self.local_pos_embed, std=0.02)
+
+         # build N identical TransformerEncoderLayer blocks
+         self.blocks = nn.ModuleList(
+             [
+                 TransformerEncoderLayer2(
+                     embed_dims=embed_dims,
+                     num_heads=num_heads,
+                     feedforward_channels=embed_dims * 4,  # standard FFN size
+                     qkv_bias=qkv_bias,
+                     use_qk_norm=use_qk_norm,
+                 )
+                 for _ in range(num_tokenizer_layers)
+             ]
+         )
+
+         # shared CLS token for pooling
+         self.w_cls = nn.Parameter(torch.zeros(1, 1, embed_dims))
+         trunc_normal_(self.w_cls, std=0.02)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         hw: Tuple[int, int],
+     ) -> Tuple[torch.Tensor, Tuple[int, int]]:
+         """Args:
+             x : B, N, C (N = H*W)
+             hw : (H, W) before reduction
+         Returns:
+             x_ : B, (H/ws)*(W/ws), C
+             hw_: (H/ws, W/ws)
+         """
+         B, N, C = x.shape
+         H, W = hw
+         ws = self.ws
+         assert H % ws == 0 and W % ws == 0, (
+             f"Image size {H}×{W} must be divisible by window {ws}."
+         )
+
+         # reshape tokens -> non-overlapping windows
+         x = x.view(B, H, W, C)
+
+         ph, pw = H // ws, W // ws  ## ints in eager mode
+         ph, pw = int(ph), int(pw)  ## ints in scripting mode
+         x = x.view(B, ph, ws, pw, ws, C)  # B, H/ws, ws, W/ws, ws, C
+         x = x.permute(0, 1, 3, 2, 4, 5)  # B, H/ws, W/ws, ws, ws, C
+         x = x.contiguous().view(B * ph * pw, ws * ws, C)  # (B*(H/ws)*(W/ws), ws², C)
+
+         total_windows = x.size(0)
+         chunk_size = int(min(self.chunk_size, total_windows))
+         token_out = x.new_empty(total_windows, C)
+
+         use_ckpt = self.training and torch.is_grad_enabled()
+
+         def _run_blocks(t: torch.Tensor) -> torch.Tensor:
+             for blk in self.blocks:
+                 t = blk(t)
+             return t
+
+         for i in range(0, total_windows, chunk_size):
+             chunk = x[i : i + chunk_size]  # (m, ws², C)
+             m = chunk.size(0)
+             cls = self.w_cls.expand(m, -1, -1)  # (m, 1, C)
+             chunk = torch.cat([cls, chunk], dim=1)  # (m, 1+ws², C)
+             chunk = chunk + self.local_pos_embed  # add local PE
+
+             if use_ckpt:
+                 chunk = checkpoint(_run_blocks, chunk, use_reentrant=False)
+             else:
+                 chunk = _run_blocks(chunk)
+
+             token_out[i : i + m] = chunk[:, 0]  # take CLS out
+
+         token = token_out.view(B, ph * pw, C)  # (B, (H/ws)*(W/ws), C)
+         return token, (ph, pw)
+
+
+ # -------------------------------------------------------------------------------
+ class GroupedQueryAttention(nn.Module):
+     def __init__(
+         self,
+         embed_dims,
+         num_heads,
+         num_kv_heads=None,
+         input_dims=None,
+         attn_drop=0.0,
+         proj_drop=0.0,
+         qkv_bias=True,
+         qk_scale=None,
+         proj_bias=True,
+         use_qk_norm=True,
+         v_shortcut=False,
+         layer_scale_init_value=0.0,
+     ):
+         super().__init__()
+         # Core dims
+         self.embed_dims = embed_dims
+         self.num_heads = num_heads
+         self.num_kv_heads = num_kv_heads or num_heads
+         assert self.num_heads % self.num_kv_heads == 0, (
+             "num_kv_heads must divide num_heads"
+         )
+         self.head_dim = embed_dims // num_heads
+         self.input_dims = input_dims or embed_dims
+         # Features
+         self.attn_drop = attn_drop
+         self.v_shortcut = v_shortcut
+         self.use_qk_norm = use_qk_norm
+
+         # Attention operation selection
+         # (scale is computed for reference only; F.scaled_dot_product_attention
+         # applies its own default scaling)
+         if qk_scale is not None:
+             scale = qk_scale
+         else:
+             scale = self.head_dim**-0.5
+
+         assert qk_scale is None, "qk_scale is not supported"
+         self.attn_op = F.scaled_dot_product_attention
+
+         # Q/K/V projections
+         self.wq = nn.Linear(self.input_dims, embed_dims, bias=qkv_bias)
+         self.wk = nn.Linear(
+             self.input_dims, self.num_kv_heads * self.head_dim, bias=qkv_bias
+         )
+         self.wv = nn.Linear(
+             self.input_dims, self.num_kv_heads * self.head_dim, bias=qkv_bias
+         )
+
+         if self.use_qk_norm:
+             self.q_norm = nn.RMSNorm(self.head_dim, eps=1e-6)
+             self.k_norm = nn.RMSNorm(self.head_dim, eps=1e-6)
+
+         # Output projection + dropout
+         self.proj = nn.Linear(embed_dims, embed_dims, bias=proj_bias)
+         self.proj_drop = nn.Dropout(proj_drop)
+
+         # Optional LayerScale
+         if layer_scale_init_value > 0:
+             self.gamma = LayerScale(embed_dims, scale=layer_scale_init_value)
+         else:
+             self.gamma = nn.Identity()
+
+     def apply_rope(
+         self, q: Tensor, k: Tensor, rope: Tensor | Tuple[Tensor, Tensor]
+     ) -> Tuple[Tensor, Tensor]:
+         # All operations use the dtype of rope; the output is cast back to the dtype of q and k
+         q_dtype = q.dtype
+         k_dtype = k.dtype
+         sin, cos = rope
+         rope_dtype = sin.dtype
+         q = q.to(dtype=rope_dtype)
+         k = k.to(dtype=rope_dtype)
+         N = q.shape[-2]
+         prefix = N - sin.shape[-2]  ## extra tokens
+         assert prefix >= 0
+         q_prefix = q[:, :, :prefix, :]
+         q = self._rope_apply(q[:, :, prefix:, :], sin, cos)  # [B, head, hw, D//head]
+         q = torch.cat((q_prefix, q), dim=-2)  # [B, head, N, D//head]
+         k_prefix = k[:, :, :prefix, :]
+         k = self._rope_apply(k[:, :, prefix:, :], sin, cos)  # [B, head, hw, D//head]
+         k = torch.cat((k_prefix, k), dim=-2)  # [B, head, N, D//head]
+         q = q.to(dtype=q_dtype)
+         k = k.to(dtype=k_dtype)
+         return q, k
+
+     def _rope_rotate_half(self, x: Tensor) -> Tensor:
+         # x:   [ x0  x1  x2  x3  x4  x5]
+         # out: [-x3 -x4 -x5  x0  x1  x2]
+         x1, x2 = x.chunk(2, dim=-1)
+         return torch.cat([-x2, x1], dim=-1)
+
+     def _rope_apply(self, x: Tensor, sin: Tensor, cos: Tensor) -> Tensor:
+         # x:   [..., D], e.g. [x0, x1, x2, x3, x4, x5]
+         # sin: [..., D], e.g. [sin0, sin1, sin2, sin0, sin1, sin2]
+         # cos: [..., D], e.g. [cos0, cos1, cos2, cos0, cos1, cos2]
+         return (x * cos) + (self._rope_rotate_half(x) * sin)
+
+     def forward(self, x, rope=None):
+         B, N, _ = x.shape
+         # Q: (B, N, num_heads, head_dim)
+         q = self.wq(x).view(B, N, self.num_heads, self.head_dim)
+         # K/V: (B, N, num_kv_heads, head_dim)
+         k = self.wk(x).view(B, N, self.num_kv_heads, self.head_dim)
+         v = self.wv(x).view(B, N, self.num_kv_heads, self.head_dim)
+
+         # (B, heads, N, head_dim)
+         q = q.permute(0, 2, 1, 3)
+         k = k.permute(0, 2, 1, 3)
+         v = v.permute(0, 2, 1, 3)
+
+         if self.use_qk_norm:
+             q = self.q_norm(q)
+             k = self.k_norm(k)
+
+         # Repeat KV heads if group ratio > 1
+         if self.num_kv_heads != self.num_heads:
+             factor = self.num_heads // self.num_kv_heads
+             k = k.repeat_interleave(factor, dim=1)
+             v = v.repeat_interleave(factor, dim=1)
+
+         if rope is not None:
+             q, k = self.apply_rope(q, k, rope)
+
+         # Scaled dot-product attention
+         attn_out = self.attn_op(
+             q, k, v, dropout_p=self.attn_drop if self.training else 0.0
+         )  # (B, num_heads, N, head_dim)
+
+         # Merge heads -> (B, N, embed_dims)
+         out = attn_out.permute(0, 2, 1, 3).reshape(B, N, self.embed_dims)
+
+         # Output projection + drop + layer scale
+         out = self.proj(out)
+         out = self.gamma(self.proj_drop(out))
+
+         # Optional V-shortcut (only when MQA)
+         if self.v_shortcut and self.num_kv_heads == 1:
+             raise NotImplementedError
+         return out
+
+
+ # -------------------------------------------------------------------------------
+ class TransformerEncoderLayer2(nn.Module):
+     def __init__(
+         self,
+         embed_dims,
+         num_heads,
+         num_kv_heads=None,
+         feedforward_channels=None,
+         drop_rate=0.0,
+         attn_drop_rate=0.0,
+         layer_scale_init_value=0.0,
+         use_qk_norm=True,
+         qkv_bias=True,
+     ):
+         super(TransformerEncoderLayer2, self).__init__()
+
+         self.embed_dims = embed_dims
+         self.ln1 = nn.RMSNorm(self.embed_dims, eps=1e-6)
+         self.attn = GroupedQueryAttention(
+             embed_dims=embed_dims,
+             num_heads=num_heads,
+             num_kv_heads=num_kv_heads,
+             attn_drop=attn_drop_rate,
+             proj_drop=drop_rate,
+             qkv_bias=qkv_bias,
+             layer_scale_init_value=layer_scale_init_value,
+             use_qk_norm=use_qk_norm,
+         )
+
+         self.ln2 = nn.RMSNorm(self.embed_dims, eps=1e-6)
+         self.ffn = SwiGLUFFN(
+             embed_dims=embed_dims,
+             feedforward_channels=feedforward_channels,
+         )
+
+     @property
+     def norm1(self):
+         return self.ln1
+
+     @property
+     def norm2(self):
+         return self.ln2
+
+     def forward(self, x, rope=None):
+         x = x + self.attn(self.ln1(x), rope=rope)
+         x = self.ffn(self.ln2(x), identity=x)
+         return x
+
+
+ ##-----------------------------------
+ @MODELS.register_module()
+ class Sapiens2(BaseModel):
+     arch_zoo = {
+         **dict.fromkeys(
+             ["sapiens2_0.1b"],
+             {
+                 "embed_dims": 768,
+                 "num_layers": 12,
+                 "num_heads": 12,
+                 "feedforward_channels": 768 * 4,
+                 "num_tokenizer_layers": 2,
+             },
+         ),
+         **dict.fromkeys(
+             ["sapiens2_0.4b"],
+             {
+                 "embed_dims": 1024,
+                 "num_layers": 24,
+                 "num_heads": 16,
+                 "feedforward_channels": 1024 * 4,
+                 "num_tokenizer_layers": 2,
+             },
+         ),
+         **dict.fromkeys(
+             ["sapiens2_0.8b"],
+             {
+                 "embed_dims": 1280,
+                 "num_layers": 32,
+                 "num_heads": 16,
+                 "feedforward_channels": 1280 * 4,
+                 "num_tokenizer_layers": 3,
+             },
+         ),
+         **dict.fromkeys(
+             ["sapiens2_1b"],
+             {
+                 "embed_dims": 1536,
+                 "num_layers": 40,
+                 "num_heads": 24,
+                 "feedforward_channels": 1536 * 4,
+                 "num_tokenizer_layers": 4,
+             },
+         ),
+         **dict.fromkeys(
+             ["sapiens2_5b"],
+             {
+                 "embed_dims": 2432,
+                 "num_layers": 56,
+                 "num_heads": 32,
+                 "feedforward_channels": 2432 * 4,
+                 "num_tokenizer_layers": 6,
+             },
+         ),
+     }
+
+     num_extra_tokens = 1  # class token
+     OUT_TYPES = {"raw", "cls_token", "featmap"}
+
+     def __init__(
+         self,
+         arch="sapiens2_1b",
+         img_size=(1024, 768),
+         patch_size=16,
+         in_channels=3,
+         out_indices=-1,
+         drop_rate=0.0,
+         window_size=4,
+         use_tokenizer=False,  ## enable for 4K-resolution inputs
+         use_qk_norm=True,
+         qkv_bias=True,
+         final_norm=True,
+         out_type="raw",
+         with_cls_token=True,
+         layer_scale_init_value=1e-4,  ## non-zero init activates LayerScale
+         frozen_stages=-1,
+         patch_cfg=dict(),
+         layer_cfgs=dict(),
+         pos_embed_rope_base: float = 100.0,
+         pos_embed_rope_min_period: float | None = None,
+         pos_embed_rope_max_period: float | None = None,
+         pos_embed_rope_normalize_coords: Literal["min", "max", "separate"] = "separate",
+         pos_embed_rope_shift_coords: float | None = None,
+         pos_embed_rope_jitter_coords: float | None = None,
+         pos_embed_rope_rescale_coords: float | None = None,
+         pos_embed_rope_dtype: str = "bf16",
+         n_storage_tokens: int = 8,
+         init_cfg=None,
+     ):
+         super(Sapiens2, self).__init__(init_cfg=init_cfg)
+
+         arch = arch.lower()
+         assert arch in set(self.arch_zoo), (
+             f"Arch {arch} is not in default archs {set(self.arch_zoo)}"
+         )
+         self.arch_settings = self.arch_zoo[arch]
+
+         self.embed_dims = self.arch_settings["embed_dims"]
+         self.num_layers = self.arch_settings["num_layers"]
+         self.patch_size = patch_size
+
+         self.window_size = window_size
+         img_size = to_2tuple(img_size)
+         encoder_img_size = (
+             (img_size[0] // window_size, img_size[1] // window_size)
+             if use_tokenizer
+             else img_size
+         )
+         self.img_size = to_2tuple(encoder_img_size)
+
+         # Set patch embedding
+         _patch_cfg = dict(
+             in_channels=in_channels,
+             input_size=self.img_size,
+             embed_dims=self.embed_dims,
+             kernel_size=patch_size,
+             stride=patch_size,
+             bias=True,
+         )
+         _patch_cfg.update(patch_cfg)
+         self.patch_embed = PatchEmbed(**_patch_cfg)
+         self.patch_resolution = self.patch_embed.init_out_size
+         num_patches = self.patch_resolution[0] * self.patch_resolution[1]
+
+         self.rope_embed = RopePositionEmbedding(
+             embed_dim=self.embed_dims,
+             num_heads=self.arch_settings["num_heads"],
+             base=pos_embed_rope_base,
+             min_period=pos_embed_rope_min_period,
+             max_period=pos_embed_rope_max_period,
+             normalize_coords=pos_embed_rope_normalize_coords,
+             shift_coords=pos_embed_rope_shift_coords,
+             jitter_coords=pos_embed_rope_jitter_coords,
+             rescale_coords=pos_embed_rope_rescale_coords,
+             dtype=torch.bfloat16 if pos_embed_rope_dtype == "bf16" else torch.float32,
+         )
+
+         # Set out type
+         if out_type not in self.OUT_TYPES:
+             raise ValueError(
+                 f"Unsupported `out_type` {out_type}, please "
+                 f"choose from {self.OUT_TYPES}"
+             )
+         self.out_type = out_type
+
+         if use_tokenizer:
+             self.tokenizer = Tokenizer(
+                 embed_dims=self.embed_dims,
+                 window_size=self.window_size,
+                 num_heads=self.arch_settings["num_heads"],
+                 num_tokenizer_layers=self.arch_settings["num_tokenizer_layers"],
+                 qkv_bias=True,
+                 use_qk_norm=False,
+             )
+         else:
+             self.tokenizer = None
+
+         # Set cls + storage tokens
+         self.with_cls_token = with_cls_token
+         if with_cls_token:
+             self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
+         elif out_type != "cls_token":
+             self.cls_token = None
+             self.num_extra_tokens = 0
+         else:
+             raise ValueError('with_cls_token must be True when `out_type="cls_token"`.')
+
+         ## registers
+         self.n_storage_tokens = int(n_storage_tokens)
+         self.storage_tokens = (
+             nn.Parameter(torch.zeros(1, self.n_storage_tokens, self.embed_dims))
+             if self.n_storage_tokens > 0
+             else None
+         )
+         # how many non-patch tokens are at the front
+         self.num_extra_tokens = (
+             1 if self.cls_token is not None else 0
+         ) + self.n_storage_tokens
+
+         if isinstance(out_indices, int):
+             out_indices = [out_indices]
+         assert isinstance(out_indices, Sequence), (
+             f'"out_indices" must be a sequence or int, got {type(out_indices)} instead.'
+         )
+         for i, index in enumerate(out_indices):
+             if index < 0:
+                 out_indices[i] = self.num_layers + index
+             assert 0 <= out_indices[i] <= self.num_layers, (
+                 f"Invalid out_indices {index}"
+             )
+         self.out_indices = out_indices
+
+         self.blocks = nn.Sequential()
+         if isinstance(layer_cfgs, dict):
+             layer_cfgs = [layer_cfgs] * self.num_layers
+
+         mhsa_early, mhsa_late = 8, 8
+         for i in range(self.num_layers):
+             if i < mhsa_early or i >= self.num_layers - mhsa_late:
+                 num_kv_heads = None  ## use MHSA
+             else:
+                 num_kv_heads = self.arch_settings["num_heads"] // 2  # use GQA
+
+             _layer_cfg = dict(
+                 embed_dims=self.embed_dims,
+                 num_heads=self.arch_settings["num_heads"],
+                 num_kv_heads=num_kv_heads,
+                 feedforward_channels=self.arch_settings["feedforward_channels"],
+                 use_qk_norm=use_qk_norm,
+                 layer_scale_init_value=layer_scale_init_value,
+                 drop_rate=drop_rate,
+                 qkv_bias=qkv_bias,
+             )
+             _layer_cfg.update(layer_cfgs[i])
+             self.blocks.append(TransformerEncoderLayer2(**_layer_cfg))
+
+         self.frozen_stages = frozen_stages
+
+         self.final_norm = final_norm
+         if final_norm:
+             self.ln1 = nn.RMSNorm(self.embed_dims, eps=1e-6)
+
+         # freeze stages only when self.frozen_stages > 0
+         if self.frozen_stages > 0:
+             self._freeze_stages()
+
+         ## load init weights
+         self.init_weights()
+
+         return
+
+     def init_weights(self):
+         if self.init_cfg is not None:
+             super(Sapiens2, self).init_weights()
+             return
+
+         # Initialize class token and storage token embeddings
+         if self.with_cls_token:
+             trunc_normal_(self.cls_token, std=0.02)
+
+         if self.storage_tokens is not None:
+             trunc_normal_(self.storage_tokens, std=0.02)
+
+         # Apply custom initialization to all submodules
+         self.apply(self._init_weights)
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             # Use a truncated normal distribution for linear layer weights
+             trunc_normal_(m.weight, std=0.02)
+             if m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+
+         elif isinstance(m, (nn.LayerNorm, nn.RMSNorm)):
+             # Initialize normalization layers to act as an identity function
+             if hasattr(m, "bias") and m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+             if hasattr(m, "weight") and m.weight is not None:
+                 nn.init.constant_(m.weight, 1.0)
+
+         elif isinstance(m, nn.Conv2d):
+             # Initialize conv layer weights like linear layers
+             trunc_normal_(m.weight, std=0.02)
+             if m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+
+     def _freeze_stages(self):
+         ## freeze tokenizer
+         if self.frozen_stages >= 1 and self.tokenizer is not None:
+             self.tokenizer.eval()
+             for param in self.tokenizer.parameters():
+                 param.requires_grad = False
+
+         # freeze patch embedding
+         self.patch_embed.eval()
+         for param in self.patch_embed.parameters():
+             param.requires_grad = False
+         # freeze cls_token
+         if self.cls_token is not None:
+             self.cls_token.requires_grad = False
+         if self.storage_tokens is not None:
+             self.storage_tokens.requires_grad = False
+         # freeze layers
+         for i in range(1, self.frozen_stages + 1):
+             m = self.blocks[i - 1]
+             m.eval()
+             for param in m.parameters():
+                 param.requires_grad = False
+
+         # freeze the last layer norm
+         if self.frozen_stages == len(self.blocks):
+             if self.final_norm:
+                 self.ln1.eval()
+                 for param in self.ln1.parameters():
+                     param.requires_grad = False
+
+     def forward(self, x):
+         B = x.shape[0]
+
+         x, patch_resolution = self.patch_embed(x)  # (B, num_patches, C)
+         if self.tokenizer is not None:
+             x, patch_resolution = self.tokenizer(x, patch_resolution)
+
+         # prepend [CLS] and storage tokens
+         prepend = []
+         if self.cls_token is not None:
+             prepend.append(self.cls_token.expand(B, -1, -1))
+         if self.storage_tokens is not None:
+             prepend.append(self.storage_tokens.expand(B, -1, -1))
+         if len(prepend) > 0:
+             x = torch.cat(prepend + [x], dim=1)
+
+         rope_sincos = self.rope_embed(H=patch_resolution[0], W=patch_resolution[1])
+         outs = []
+         for i, layer in enumerate(self.blocks):
+             x = layer(x, rope=rope_sincos)
+
+             if i == len(self.blocks) - 1 and self.final_norm:
+                 x = self.ln1(x)
+
+             if i in self.out_indices:
+                 outs.append(self._format_output(x, patch_resolution))
+
+         return tuple(outs)
+
+     def _format_output(self, x, hw):
+         if self.out_type == "raw":
+             return x
+         if self.out_type == "cls_token":
+             return x[:, 0]
+
+         patch_token = x[:, self.num_extra_tokens :]
+         if self.out_type == "featmap":
+             B = x.size(0)
+             # (B, N, C) -> (B, H, W, C) -> (B, C, H, W)
+             return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2)
+
+     @property
+     def norm1(self):
+         return self.ln1
+
+
+ # ----------------------------------------------------------------------------
+ class LayerScale(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         inplace: bool = False,
+         data_format: str = "channels_last",
+         scale: float = 1e-5,
+     ):
+         super().__init__()
+         assert data_format in (
+             "channels_last",
+             "channels_first",
+         ), "'data_format' could only be channels_last or channels_first."
+         self.inplace = inplace
+         self.data_format = data_format
+         self.weight = nn.Parameter(torch.ones(dim) * scale)
+
+     def forward(self, x) -> torch.Tensor:
+         if self.data_format == "channels_first":
+             shape = tuple((1, -1, *(1 for _ in range(x.dim() - 2))))
+         else:
+             shape = tuple((*(1 for _ in range(x.dim() - 1)), -1))
+         if self.inplace:
+             return x.mul_(self.weight.view(*shape))
+         else:
+             return x * self.weight.view(*shape)
+
+
+ # ----------------------------------------------------------------------------
+ class PatchEmbed(nn.Module):
+     def __init__(
+         self,
+         in_channels=3,
+         embed_dims=768,
+         kernel_size=16,
+         stride=16,
+         padding="corner",
+         dilation=1,
+         bias=True,
+         input_size=None,
+     ):
+         super().__init__()
+
+         self.embed_dims = embed_dims
+         if stride is None:
+             stride = kernel_size
+
+         kernel_size = to_2tuple(kernel_size)
+         stride = to_2tuple(stride)
+         dilation = to_2tuple(dilation)
+         padding = 0  # the `padding` argument is ignored; this variant never pads
+         padding = to_2tuple(padding)
+
+         self.projection = nn.Conv2d(
+             in_channels=in_channels,
+             out_channels=embed_dims,
+             kernel_size=kernel_size,
+             stride=stride,
+             padding=padding,
+             dilation=dilation,
+             bias=bias,
+         )
+
+         if input_size:
+             input_size = to_2tuple(input_size)
+             self.init_input_size = input_size
+             h_out = (
+                 input_size[0] + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1
+             ) // stride[0] + 1
+             w_out = (
+                 input_size[1] + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1
+             ) // stride[1] + 1
+             self.init_out_size = (h_out, w_out)
+         else:
+             self.init_input_size = None
+             self.init_out_size = None
+
+     def forward(self, x):
+         x = self.projection(x)
+         out_size = (x.shape[2], x.shape[3])
+         x = x.flatten(2).transpose(1, 2)
+         return x, out_size
+
+
+ # ----------------------------------------------------------------------------
+ class SwiGLUFFN(nn.Module):
+     """SwiGLU FFN layer.
+     https://github.com/facebookresearch/dinov2/blob/main/dinov2/layers/swiglu_ffn.py
+     """  # noqa
+
+     def __init__(
+         self,
+         embed_dims: int,
+         feedforward_channels: Optional[int] = None,
+         out_dims: Optional[int] = None,
+         layer_scale_init_value: float = 0.0,
+         bias: bool = True,
+         add_identity: bool = True,
+     ) -> None:
+         super().__init__()
+         self.embed_dims = embed_dims
+         self.out_dims = out_dims or embed_dims
+         hidden_dims = feedforward_channels or embed_dims
+
+         self.w12 = nn.Linear(self.embed_dims, 2 * hidden_dims, bias=bias)
+         self.w3 = nn.Linear(hidden_dims, self.out_dims, bias=bias)
+
+         if layer_scale_init_value > 0:
+             self.gamma2 = LayerScale(dim=embed_dims, scale=layer_scale_init_value)
+         else:
+             self.gamma2 = nn.Identity()
+
+         self.add_identity = add_identity
+
+     def forward(
+         self, x: torch.Tensor, identity: Optional[torch.Tensor] = None
+     ) -> torch.Tensor:
+         x12 = self.w12(x)
+         x1, x2 = x12.chunk(2, dim=-1)
+         hidden = F.silu(x1) * x2
+         out = self.w3(hidden)
+         out = self.gamma2(out)
+
+         if self.out_dims != self.embed_dims or not self.add_identity:
+             # skip the residual: output dims differ or add_identity is disabled
+             return out
+
+         if identity is None:
+             identity = x
+         return identity + out
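
(A standalone sketch, not part of the commit, of the rotate-half RoPE step that GroupedQueryAttention applies to q and k above; pure torch on toy shapes, mirroring _rope_rotate_half/_rope_apply.)

    import torch

    B, heads, HW, D = 2, 4, 6, 8
    q = torch.randn(B, heads, HW, D)
    angles = torch.randn(HW, D // 2).tile(2)   # same [HW, D] layout as RopePositionEmbedding
    sin, cos = torch.sin(angles), torch.cos(angles)

    x1, x2 = q.chunk(2, dim=-1)
    rotated = torch.cat([-x2, x1], dim=-1)     # rotate-half: [x0..x3, x4..x7] -> [-x4..-x7, x0..x3]
    q_rope = q * cos + rotated * sin           # (HW, D) broadcasts over (B, heads, HW, D)
    print(q_rope.shape)                        # torch.Size([2, 4, 6, 8])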
sapiens/backbones/standalone/sapiens.py ADDED
@@ -0,0 +1,648 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ from typing import Sequence
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from torch.nn import Linear, Sequential
14
+
15
+
16
+ # ----------------------------------------------------------------------------
17
+ def to_2tuple(x):
18
+ if isinstance(x, (str, bytes)):
19
+ return (x, x)
20
+ if isinstance(x, Sequence):
21
+ x = tuple(x)
22
+ if len(x) == 2:
23
+ return x
24
+ raise ValueError("Expected scalar or length-2 iterable")
25
+ return (x, x)
26
+
27
+
28
+ def resize_pos_embed(
29
+ pos_embed, src_shape, dst_shape, mode="bicubic", num_extra_tokens=1
30
+ ):
31
+ if src_shape[0] == dst_shape[0] and src_shape[1] == dst_shape[1]:
32
+ return pos_embed
33
+ assert pos_embed.ndim == 3, "shape of pos_embed must be [1, L, C]"
34
+ _, L, C = pos_embed.shape
35
+ src_h, src_w = src_shape
36
+ assert L == src_h * src_w + num_extra_tokens, (
37
+ f"The length of `pos_embed` ({L}) doesn't match the expected "
38
+ f"shape ({src_h}*{src_w}+{num_extra_tokens}). Please check the"
39
+ "`img_size` argument."
40
+ )
41
+ extra_tokens = pos_embed[:, :num_extra_tokens]
42
+
43
+ src_weight = pos_embed[:, num_extra_tokens:]
44
+ src_weight = src_weight.reshape(1, src_h, src_w, C).permute(0, 3, 1, 2)
45
+
46
+ # The cubic interpolate algorithm only accepts float32
47
+ dst_weight = F.interpolate(
48
+ src_weight.float(), size=dst_shape, align_corners=False, mode=mode
49
+ )
50
+ dst_weight = torch.flatten(dst_weight, 2).transpose(1, 2)
51
+ dst_weight = dst_weight.to(src_weight.dtype)
52
+
53
+ return torch.cat((extra_tokens, dst_weight), dim=1)
54
+
55
+
56
+ # ----------------------------------------------------------------------------
57
+ class AdaptivePadding(nn.Module):
58
+ def __init__(self, kernel_size=1, stride=1, dilation=1, padding="corner"):
59
+ super().__init__()
60
+ assert padding in ("same", "corner")
61
+
62
+ kernel_size = to_2tuple(kernel_size)
63
+ stride = to_2tuple(stride)
64
+ dilation = to_2tuple(dilation)
65
+
66
+ self.padding = padding
67
+ self.kernel_size = kernel_size
68
+ self.stride = stride
69
+ self.dilation = dilation
70
+
71
+ def get_pad_shape(self, input_shape):
72
+ input_h, input_w = input_shape
73
+ kernel_h, kernel_w = self.kernel_size
74
+ stride_h, stride_w = self.stride
75
+ output_h = math.ceil(input_h / stride_h)
76
+ output_w = math.ceil(input_w / stride_w)
77
+ pad_h = max(
78
+ (output_h - 1) * stride_h + (kernel_h - 1) * self.dilation[0] + 1 - input_h,
79
+ 0,
80
+ )
81
+ pad_w = max(
82
+ (output_w - 1) * stride_w + (kernel_w - 1) * self.dilation[1] + 1 - input_w,
83
+ 0,
84
+ )
85
+ return pad_h, pad_w
86
+
87
+ def forward(self, x):
88
+ pad_h, pad_w = self.get_pad_shape(x.size()[-2:])
89
+ if pad_h > 0 or pad_w > 0:
90
+ if self.padding == "corner":
91
+ x = F.pad(x, [0, pad_w, 0, pad_h])
92
+ elif self.padding == "same":
93
+ x = F.pad(
94
+ x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]
95
+ )
96
+ return x
97
+
98
+
99
+ # ----------------------------------------------------------------------------
100
+ class PatchEmbed(nn.Module):
101
+ def __init__(
102
+ self,
103
+ in_channels=3,
104
+ embed_dims=768,
105
+ kernel_size=16,
106
+ stride=16,
107
+ padding="corner",
108
+ dilation=1,
109
+ bias=True,
110
+ input_size=None,
111
+ ):
112
+ super().__init__()
113
+
114
+ self.embed_dims = embed_dims
115
+ if stride is None:
116
+ stride = kernel_size
117
+
118
+ kernel_size = to_2tuple(kernel_size)
119
+ stride = to_2tuple(stride)
120
+ dilation = to_2tuple(dilation)
121
+
122
+ if isinstance(padding, str):
123
+ self.adaptive_padding = AdaptivePadding(
124
+ kernel_size=kernel_size,
125
+ stride=stride,
126
+ dilation=dilation,
127
+ padding=padding,
128
+ )
129
+ padding = 0
130
+ else:
131
+ self.adaptive_padding = None
132
+ padding = to_2tuple(padding)
133
+
134
+ self.projection = nn.Conv2d(
135
+ in_channels=in_channels,
136
+ out_channels=embed_dims,
137
+ kernel_size=kernel_size,
138
+ stride=stride,
139
+ padding=padding,
140
+ dilation=dilation,
141
+ bias=bias,
142
+ )
143
+
144
+ if input_size:
145
+ input_size = to_2tuple(input_size)
146
+ self.init_input_size = input_size
147
+ if self.adaptive_padding:
148
+ pad_h, pad_w = self.adaptive_padding.get_pad_shape(input_size)
149
+ input_h, input_w = input_size
150
+ input_h = input_h + pad_h
151
+ input_w = input_w + pad_w
152
+ input_size = (input_h, input_w)
153
+
154
+ h_out = (
155
+ input_size[0] + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1
156
+ ) // stride[0] + 1
157
+ w_out = (
158
+ input_size[1] + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1
159
+ ) // stride[1] + 1
160
+ self.init_out_size = (h_out, w_out)
161
+ else:
162
+ self.init_input_size = None
163
+ self.init_out_size = None
164
+
165
+ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, tuple[int, int]]:
166
+ if self.adaptive_padding:
167
+ x = self.adaptive_padding(x)
168
+
169
+ x = self.projection(x)
170
+ out_size = (x.shape[2], x.shape[3])
171
+ x = x.flatten(2).transpose(1, 2)
172
+ return x, out_size
173
+
174
+
175
+ # ----------------------------------------------------------------------------
176
+ class LayerScale(nn.Module):
177
+ def __init__(
178
+ self,
179
+ dim: int,
180
+ inplace: bool = False,
181
+ data_format: str = "channels_last",
182
+ scale: float = 1e-5,
183
+ ):
184
+ super().__init__()
185
+ assert data_format in (
186
+ "channels_last",
187
+ "channels_first",
188
+ ), "'data_format' could only be channels_last or channels_first."
189
+ self.inplace = inplace
190
+ self.data_format = data_format
191
+ self.weight = nn.Parameter(torch.ones(dim) * scale)
192
+
193
+ def forward(self, x) -> torch.Tensor:
194
+ if self.data_format == "channels_first":
195
+ shape = tuple((1, -1, *(1 for _ in range(x.dim() - 2))))
196
+ else:
197
+ shape = tuple((*(1 for _ in range(x.dim() - 1)), -1))
198
+ if self.inplace:
199
+ return x.mul_(self.weight.view(*shape))
200
+ else:
201
+ return x * self.weight.view(*shape)
202
+
203
+
204
+ # ----------------------------------------------------------------------------
205
+ class FFN(nn.Module):
206
+ def __init__(
207
+ self,
208
+ embed_dims=256,
209
+ feedforward_channels=1024,
210
+ num_fcs=2,
211
+ ffn_drop=0.0,
212
+ add_identity=True,
213
+ layer_scale_init_value=0.0,
214
+ ):
215
+ super().__init__()
216
+ assert num_fcs >= 2, f"num_fcs should be no less than 2, got {num_fcs}."
217
+ self.embed_dims = embed_dims
218
+ self.feedforward_channels = feedforward_channels
219
+ self.num_fcs = num_fcs
220
+
221
+ layers = []
222
+ in_channels = embed_dims
223
+ for _ in range(num_fcs - 1):
224
+ layers.append(
225
+ Sequential(
226
+ Linear(in_channels, feedforward_channels),
227
+ nn.GELU(),
228
+ nn.Dropout(ffn_drop),
229
+ )
230
+ )
231
+ in_channels = feedforward_channels
232
+ layers.append(Linear(feedforward_channels, embed_dims))
233
+ layers.append(nn.Dropout(ffn_drop))
234
+ self.layers = Sequential(*layers)
235
+ self.dropout_layer = nn.Identity()
236
+ self.add_identity = add_identity
237
+
238
+ if layer_scale_init_value > 0:
239
+ self.gamma2 = LayerScale(embed_dims, scale=layer_scale_init_value)
240
+ else:
241
+ self.gamma2 = nn.Identity()
242
+
243
+ def forward(self, x, identity=None):
244
+ out = self.layers(x)
245
+ out = self.gamma2(out)
246
+ if not self.add_identity:
247
+ return out
248
+ if identity is None:
249
+ identity = x
250
+ return identity + out
251
+
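# [Editorial sketch, not part of this commit] The FFN is shape-preserving and
# adds the residual internally when `add_identity=True` (the default).
ffn = FFN(embed_dims=256, feedforward_channels=1024)
y = ffn(torch.randn(2, 10, 256))
assert y.shape == (2, 10, 256)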
252
+
253
+ # ----------------------------------------------------------------------------
254
+ class MultiheadAttention(nn.Module):
255
+ def __init__(
256
+ self,
257
+ embed_dims,
258
+ num_heads,
259
+ input_dims=None,
260
+ attn_drop=0.0,
261
+ proj_drop=0.0,
262
+ qkv_bias=True,
263
+ proj_bias=True,
264
+ v_shortcut=False,
265
+ ):
266
+ super(MultiheadAttention, self).__init__()
267
+
268
+ self.input_dims = input_dims or embed_dims
269
+ self.embed_dims = embed_dims
270
+ self.num_heads = num_heads
271
+ self.v_shortcut = v_shortcut
272
+
273
+ self.head_dims = embed_dims // num_heads
274
+ self.scaled_dot_product_attention = F.scaled_dot_product_attention
275
+
276
+ self.qkv = nn.Linear(self.input_dims, embed_dims * 3, bias=qkv_bias)
277
+ self.attn_drop = attn_drop
278
+ self.proj = nn.Linear(embed_dims, embed_dims, bias=proj_bias)
279
+ self.proj_drop = nn.Dropout(proj_drop)
280
+ self.gamma1 = nn.Identity()
281
+
282
+ def forward(self, x):
283
+ B, N, _ = x.shape
284
+ qkv = (
285
+ self.qkv(x)
286
+ .reshape(B, N, 3, self.num_heads, self.head_dims)
287
+ .permute(2, 0, 3, 1, 4)
288
+ )
289
+ q, k, v = qkv[0], qkv[1], qkv[2]
290
+
291
+ attn_drop = self.attn_drop if self.training else 0.0
292
+ x = self.scaled_dot_product_attention(q, k, v, dropout_p=attn_drop)
293
+ x = x.transpose(1, 2).reshape(B, N, self.embed_dims)
294
+
295
+ x = self.proj(x)
296
+ x = self.gamma1(self.proj_drop(x))
297
+
298
+ if self.v_shortcut:
299
+ x = v.squeeze(1) + x
300
+ return x
301
+
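# [Editorial sketch, not part of this commit] Self-attention maps
# (B, N, C) -> (B, N, C); SDPA dropout is active only in training mode.
attn = MultiheadAttention(embed_dims=768, num_heads=12).eval()
out = attn(torch.randn(2, 197, 768))
assert out.shape == (2, 197, 768)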
302
+
303
+ # ----------------------------------------------------------------------------
304
+ class TransformerEncoderLayer(nn.Module):
305
+ def __init__(
306
+ self,
307
+ embed_dims,
308
+ num_heads,
309
+ feedforward_channels,
310
+ drop_rate=0.0,
311
+ attn_drop_rate=0.0,
312
+ num_fcs=2,
313
+ qkv_bias=True,
314
+ ):
315
+ super(TransformerEncoderLayer, self).__init__()
316
+
317
+ self.embed_dims = embed_dims
318
+ self.ln1 = nn.LayerNorm(self.embed_dims, eps=1e-6, elementwise_affine=True)
319
+ self.attn = MultiheadAttention(
320
+ embed_dims=embed_dims,
321
+ num_heads=num_heads,
322
+ attn_drop=attn_drop_rate,
323
+ proj_drop=drop_rate,
324
+ qkv_bias=qkv_bias,
325
+ )
326
+
327
+ self.ln2 = nn.LayerNorm(self.embed_dims, eps=1e-6, elementwise_affine=True)
328
+ self.ffn = FFN(
329
+ embed_dims=embed_dims,
330
+ feedforward_channels=feedforward_channels,
331
+ num_fcs=num_fcs,
332
+ ffn_drop=drop_rate,
333
+ add_identity=True,
334
+ )
335
+
336
+ @property
337
+ def norm1(self):
338
+ return self.ln1
339
+
340
+ @property
341
+ def norm2(self):
342
+ return self.ln2
343
+
344
+ def forward(self, x):
345
+ x = x + self.attn(self.ln1(x))
346
+ x = self.ffn(self.ln2(x), identity=x)
347
+ return x
348
+
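# [Editorial sketch, not part of this commit] One pre-norm block:
# x + Attn(LN(x)), then the FFN adds its own residual via `identity=x`.
layer = TransformerEncoderLayer(embed_dims=768, num_heads=12,
                                feedforward_channels=3072)
y = layer(torch.randn(2, 197, 768))
assert y.shape == (2, 197, 768)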
349
+
350
+ # ----------------------------------------------------------------------------
351
+ class Sapiens(nn.Module):
352
+ arch_zoo = {
353
+ **dict.fromkeys( ## this is vit-large
354
+ ["0.3b", "sapiens_0.3b"],
355
+ {
356
+ "embed_dims": 1024,
357
+ "num_layers": 24,
358
+ "num_heads": 16,
359
+ "feedforward_channels": 1024 * 4,
360
+ },
361
+ ),
362
+ **dict.fromkeys( ## this is vit-huge
363
+ ["0.6b", "sapiens_0.6b"],
364
+ {
365
+ "embed_dims": 1280,
366
+ "num_layers": 32,
367
+ "num_heads": 16,
368
+ "feedforward_channels": 1280 * 4,
369
+ },
370
+ ),
371
+ **dict.fromkeys( ## this is vit-g
372
+ ["1b", "sapiens_1b"],
373
+ {
374
+ "embed_dims": 1536,
375
+ "num_layers": 40,
376
+ "num_heads": 24,
377
+ "feedforward_channels": 1536 * 4,
378
+ },
379
+ ),
380
+ **dict.fromkeys(
381
+ ["2b", "sapiens_2b"],
382
+ {
383
+ "embed_dims": 1920,
384
+ "num_layers": 48,
385
+ "num_heads": 32,
386
+ "feedforward_channels": 1920 * 4,
387
+ },
388
+ ),
389
+ }
390
+ num_extra_tokens = 1 # class token
391
+ OUT_TYPES = {"raw", "cls_token", "featmap", "avg_featmap"}
392
+
393
+ def __init__(
394
+ self,
395
+ arch="base",
396
+ img_size=224,
397
+ patch_size=16,
398
+ in_channels=3,
399
+ out_indices=-1,
400
+ drop_rate=0.0,
401
+ qkv_bias=True,
402
+ final_norm=True,
403
+ out_type="cls_token",
404
+ with_cls_token=True,
405
+ frozen_stages=-1,
406
+ interpolate_mode="bicubic",
407
+ patch_cfg=dict(),
408
+ layer_cfgs=dict(),
409
+ ):
410
+ super(Sapiens, self).__init__()
411
+
412
+ arch = arch.lower()
413
+ assert arch in set(self.arch_zoo), (
414
+ f"Arch {arch} is not in default archs {set(self.arch_zoo)}"
415
+ )
416
+ self.arch_settings = self.arch_zoo[arch]
417
+
418
+ self.embed_dims = self.arch_settings["embed_dims"]
419
+ self.num_layers = self.arch_settings["num_layers"]
420
+ self.img_size = to_2tuple(img_size)
421
+ self.patch_size = patch_size
422
+
423
+ # Set patch embedding
424
+ _patch_cfg = dict(
425
+ in_channels=in_channels,
426
+ input_size=img_size,
427
+ embed_dims=self.embed_dims,
428
+ kernel_size=patch_size,
429
+ stride=patch_size,
430
+ bias=True,
431
+ )
432
+ _patch_cfg.update(patch_cfg)
433
+ self.patch_embed = PatchEmbed(**_patch_cfg)
434
+ self.patch_resolution = self.patch_embed.init_out_size
435
+ num_patches = self.patch_resolution[0] * self.patch_resolution[1]
436
+
437
+ # Set out type
438
+ if out_type not in self.OUT_TYPES:
439
+ raise ValueError(
440
+ f"Unsupported `out_type` {out_type}, please "
441
+ f"choose from {self.OUT_TYPES}"
442
+ )
443
+ self.out_type = out_type
444
+
445
+ # Set cls token
446
+ self.with_cls_token = with_cls_token
447
+ if with_cls_token:
448
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
449
+ elif out_type != "cls_token":
450
+ self.cls_token = None
451
+ self.num_extra_tokens = 0
452
+ else:
453
+ raise ValueError('with_cls_token must be True when `out_type="cls_token"`.')
454
+
455
+ # Set position embedding
456
+ self.interpolate_mode = interpolate_mode
457
+ self.pos_embed = nn.Parameter(
458
+ torch.zeros(1, num_patches + self.num_extra_tokens, self.embed_dims)
459
+ )
460
+ self.drop_after_pos = nn.Dropout(p=drop_rate)
461
+
462
+ if isinstance(out_indices, int):
463
+ out_indices = [out_indices]
464
+ assert isinstance(out_indices, Sequence), (
465
+ f'"out_indices" must by a sequence or int, get {type(out_indices)} instead.'
466
+ )
467
+ for i, index in enumerate(out_indices):
468
+ if index < 0:
469
+ out_indices[i] = self.num_layers + index
470
+ assert 0 <= out_indices[i] <= self.num_layers, (
471
+ f"Invalid out_indices {index}"
472
+ )
473
+ self.out_indices = out_indices
474
+
475
+ self.layers = nn.Sequential()
476
+ if isinstance(layer_cfgs, dict):
477
+ layer_cfgs = [layer_cfgs] * self.num_layers
478
+ for i in range(self.num_layers):
479
+ _layer_cfg = dict(
480
+ embed_dims=self.embed_dims,
481
+ num_heads=self.arch_settings["num_heads"],
482
+ feedforward_channels=self.arch_settings["feedforward_channels"],
483
+ drop_rate=drop_rate,
484
+ qkv_bias=qkv_bias,
485
+ )
486
+ _layer_cfg.update(layer_cfgs[i])
487
+ self.layers.append(TransformerEncoderLayer(**_layer_cfg))
488
+
489
+ self.frozen_stages = frozen_stages
490
+ self.pre_norm = nn.Identity()
491
+
492
+ self.final_norm = final_norm
493
+ if final_norm:
494
+ self.ln1 = nn.LayerNorm(self.embed_dims, eps=1e-6, elementwise_affine=True)
495
+
496
+ # freeze stages only when self.frozen_stages > 0
497
+ if self.frozen_stages > 0:
498
+ self._freeze_stages()
499
+
500
+ self._register_load_state_dict_pre_hook(self._prepare_pos_embed)
501
+
502
+ return
503
+
504
+ def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs):
505
+ name = prefix + "pos_embed"
506
+ if name not in state_dict.keys():
507
+ return
508
+
509
+ ckpt_pos_embed_shape = state_dict[name].shape
510
+
511
+ # Handle class token removal if needed
512
+ if not self.with_cls_token:
513
+ if ckpt_pos_embed_shape[1] == self.pos_embed.shape[1] + 1:
514
+ # Remove cls token from state dict if it's not used
515
+ state_dict[name] = state_dict[name][:, 1:]
516
+ ckpt_pos_embed_shape = state_dict[name].shape
517
+ elif ckpt_pos_embed_shape[1] % 2 == 1:
518
+ # Remove class token when interpolation is required
519
+ state_dict[name] = state_dict[name][:, 1:]
520
+ ckpt_pos_embed_shape = state_dict[name].shape
521
+
522
+ # Skip if shapes already match
523
+ if self.pos_embed.shape == ckpt_pos_embed_shape:
524
+ return
525
+
526
+ # Calculate grid dimensions
527
+ pos_h, pos_w = self.patch_embed.init_out_size
528
+ assert pos_h >= pos_w # for vertical aspect ratio or square
529
+
530
+ # Number of non-extra tokens in checkpoint
531
+ num_vis = ckpt_pos_embed_shape[1] - self.num_extra_tokens
532
+
533
+ # Determine original grid shape
534
+ side = int(math.sqrt(num_vis))
535
+ factor = int(math.sqrt((num_vis * self.patch_size * self.patch_size) // 12))
536
+
537
+ # Set old grid based on aspect ratio detection
538
+ if side * side == num_vis:
539
+ old_grid = (side, side) # square grid
540
+ elif 4 * factor * 3 * factor == num_vis * self.patch_size * self.patch_size:
541
+ old_grid = (
542
+ (factor * 4) // self.patch_size,
543
+ (factor * 3) // self.patch_size,
544
+ ) # 4:3 ratio
545
+ else:
546
+ state_dict[name] = self.pos_embed
547
+ return
548
+
549
+ # Resize position embedding
550
+ new_grid = (pos_h, pos_w)
551
+ state_dict[name] = resize_pos_embed(
552
+ state_dict[name],
553
+ old_grid,
554
+ new_grid,
555
+ mode=self.interpolate_mode,
556
+ num_extra_tokens=self.num_extra_tokens,
557
+ )
558
+
559
+ @property
560
+ def norm1(self):
561
+ return self.ln1
562
+
563
+ @property
+ def norm2(self):
+ # NOTE: `ln2` is never created in this class, so accessing `norm2`
+ # (or freezing with out_type="avg_featmap") raises AttributeError.
+ return self.ln2
566
+
567
+ @staticmethod
568
+ def resize_pos_embed(*args, **kwargs):
569
+ """Interface for backward-compatibility."""
570
+ return resize_pos_embed(*args, **kwargs)
571
+
572
+ def _freeze_stages(self):
573
+ # freeze position embedding
574
+ if self.pos_embed is not None:
575
+ self.pos_embed.requires_grad = False
576
+
577
+ # set dropout to eval model
578
+ self.drop_after_pos.eval()
579
+ # freeze patch embedding
580
+ self.patch_embed.eval()
581
+ for param in self.patch_embed.parameters():
582
+ param.requires_grad = False
583
+ # freeze pre-norm
584
+ for param in self.pre_norm.parameters():
585
+ param.requires_grad = False
586
+ # freeze cls_token
587
+ if self.cls_token is not None:
588
+ self.cls_token.requires_grad = False
589
+ # freeze layers
590
+ for i in range(1, self.frozen_stages + 1):
591
+ m = self.layers[i - 1]
592
+ m.eval()
593
+ for param in m.parameters():
594
+ param.requires_grad = False
595
+ # freeze the last layer norm
596
+ if self.frozen_stages == len(self.layers):
597
+ if self.final_norm:
598
+ self.ln1.eval()
599
+ for param in self.ln1.parameters():
600
+ param.requires_grad = False
601
+
602
+ if self.out_type == "avg_featmap":
603
+ self.ln2.eval()
604
+ for param in self.ln2.parameters():
605
+ param.requires_grad = False
606
+
607
+ def forward(self, x):
608
+ B = x.shape[0]
609
+ x, patch_resolution = self.patch_embed(x)
610
+
611
+ if self.cls_token is not None:
612
+ cls_token = self.cls_token.expand(B, -1, -1)
613
+ x = torch.cat((cls_token, x), dim=1)
614
+
615
+ x = x + resize_pos_embed(
616
+ self.pos_embed,
617
+ self.patch_resolution,
618
+ patch_resolution,
619
+ mode=self.interpolate_mode,
620
+ num_extra_tokens=self.num_extra_tokens,
621
+ )
622
+ x = self.drop_after_pos(x)
623
+
624
+ x = self.pre_norm(x) ## B x (num tokens) x embed_dim
625
+
626
+ outs = []
627
+ for i, layer in enumerate(self.layers):
628
+ x = layer(x)
629
+
630
+ if i == len(self.layers) - 1 and self.final_norm:
631
+ x = self.ln1(x)
632
+
633
+ if i in self.out_indices:
634
+ outs.append(self._format_output(x, patch_resolution))
635
+
636
+ return tuple(outs)
637
+
638
+ def _format_output(self, x, hw):
639
+ if self.out_type == "raw":
640
+ return x
641
+ if self.out_type == "cls_token":
642
+ return x[:, 0]
643
+
644
+ patch_token = x[:, self.num_extra_tokens :]
645
+ if self.out_type == "featmap":
+ B = x.size(0)
+ # (B, N, C) -> (B, H, W, C) -> (B, C, H, W)
+ return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2)
+ # NOTE: "avg_featmap" is listed in OUT_TYPES but has no branch here, so
+ # selecting it would fall through and return None.
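# [Editorial sketch, not part of this commit] End-to-end featmap pass through
# the backbone, assuming `resize_pos_embed` (defined earlier in this file) is
# available. "0.3b" and the 1024x768 input follow the arch_zoo above.
backbone = Sapiens(arch="0.3b", img_size=(1024, 768), patch_size=16,
                   out_type="featmap")
feats = backbone(torch.randn(1, 3, 1024, 768))
assert feats[0].shape == (1, 1024, 64, 48)   # (B, C, H/16, W/16)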
sapiens/backbones/standalone/sapiens2.py ADDED
@@ -0,0 +1,908 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from torch import Tensor
14
+ from torch.nn.init import trunc_normal_
15
+ from torch.utils.checkpoint import checkpoint
16
+
17
+
18
+ # ----------------------------------------------------------------------------
19
+ def to_2tuple(x):
20
+ if isinstance(x, (str, bytes)):
21
+ return (x, x)
22
+ if isinstance(x, Sequence):
23
+ x = tuple(x)
24
+ if len(x) == 2:
25
+ return x
26
+ raise ValueError("Expected scalar or length-2 iterable")
27
+ return (x, x)
28
+
29
+
30
+ class RopePositionEmbedding(nn.Module):
31
+ def __init__(
32
+ self,
33
+ embed_dim: int,
34
+ *,
35
+ num_heads: int,
36
+ base: float | None = 100.0,
37
+ min_period: float | None = None,
38
+ max_period: float | None = None,
39
+ normalize_coords: Literal["min", "max", "separate"] = "separate",
40
+ shift_coords: float | None = None,
41
+ jitter_coords: float | None = None,
42
+ rescale_coords: float | None = None,
43
+ dtype: torch.dtype | None = None,
44
+ device: torch.device | None = None,
45
+ ):
46
+ super().__init__()
47
+ assert embed_dim % (4 * num_heads) == 0
48
+ both_periods = min_period is not None and max_period is not None
49
+ if (base is None and not both_periods) or (base is not None and both_periods):
50
+ raise ValueError(
51
+ "Either `base` or `min_period`+`max_period` must be provided."
52
+ )
53
+
54
+ D_head = embed_dim // num_heads
55
+ self.base = base
56
+ self.min_period = min_period
57
+ self.max_period = max_period
58
+ self.D_head = D_head
59
+ self.normalize_coords = normalize_coords
60
+ self.shift_coords = shift_coords
61
+ self.jitter_coords = jitter_coords
62
+ self.rescale_coords = rescale_coords
63
+
64
+ # Needs persistent=True because we do teacher.load_state_dict(student.state_dict()) to initialize the teacher
65
+ self.dtype = dtype or torch.float32 # Don't rely on self.periods.dtype
66
+ self.register_buffer(
67
+ "periods",
68
+ torch.empty(D_head // 4, device=device, dtype=self.dtype),
69
+ persistent=True,
70
+ )
71
+ self._init_weights()
72
+
73
+ def forward(self, *, H: int, W: int) -> tuple[Tensor, Tensor]:
74
+ device = self.periods.device
75
+ dtype = self.dtype
76
+ dd = {"device": device, "dtype": dtype}
77
+ # Prepare coords in range [-1, +1]
78
+ if self.normalize_coords == "max":
79
+ max_HW = max(H, W)
80
+ coords_h = torch.arange(0.5, H, **dd) / max_HW # [H]
81
+ coords_w = torch.arange(0.5, W, **dd) / max_HW # [W]
82
+ elif self.normalize_coords == "min":
83
+ min_HW = min(H, W)
84
+ coords_h = torch.arange(0.5, H, **dd) / min_HW # [H]
85
+ coords_w = torch.arange(0.5, W, **dd) / min_HW # [W]
86
+ elif self.normalize_coords == "separate":
87
+ coords_h = torch.arange(0.5, H, **dd) / H # [H]
88
+ coords_w = torch.arange(0.5, W, **dd) / W # [W]
89
+ else:
90
+ raise ValueError(f"Unknown normalize_coords: {self.normalize_coords}")
91
+ coords = torch.stack(
92
+ torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1
93
+ ) # [H, W, 2]
94
+ coords = coords.flatten(0, 1) # [HW, 2]
95
+ coords = 2.0 * coords - 1.0 # Shift range [0, 1] to [-1, +1]
96
+
97
+ # Shift coords by adding a uniform value in [-shift, shift]
98
+ if self.training and self.shift_coords is not None:
99
+ shift_hw = torch.empty(2, **dd).uniform_(
100
+ -self.shift_coords, self.shift_coords
101
+ )
102
+ coords += shift_hw[None, :]
103
+
104
+ # Jitter coords by multiplying the range [-1, 1] by a log-uniform value in [1/jitter, jitter]
105
+ if self.training and self.jitter_coords is not None:
106
+ jitter_max = math.log(self.jitter_coords)  # math, not np: numpy is not imported here
107
+ jitter_min = -jitter_max
108
+ jitter_hw = torch.empty(2, **dd).uniform_(jitter_min, jitter_max).exp()
109
+ coords *= jitter_hw[None, :]
110
+
111
+ # Rescale coords by multiplying the range [-1, 1] by a log-uniform value in [1/rescale, rescale]
112
+ if self.training and self.rescale_coords is not None:
113
+ rescale_max = math.log(self.rescale_coords)  # math, not np: numpy is not imported here
114
+ rescale_min = -rescale_max
115
+ rescale_hw = torch.empty(1, **dd).uniform_(rescale_min, rescale_max).exp()
116
+ coords *= rescale_hw
117
+
118
+ # Prepare angles and sin/cos
119
+ angles = (
120
+ 2 * math.pi * coords[:, :, None] / self.periods[None, None, :]
121
+ ) # [HW, 2, D//4]
122
+ angles = angles.flatten(1, 2) # [HW, D//2]
123
+ angles = angles.tile(2) # [HW, D]
124
+ cos = torch.cos(angles) # [HW, D]
125
+ sin = torch.sin(angles) # [HW, D]
126
+
127
+ return (sin, cos) # 2 * [HW, D]
128
+
129
+ def _init_weights(self):
130
+ device = self.periods.device
131
+ dtype = self.dtype
132
+ if self.base is not None:
133
+ periods = self.base ** (
134
+ 2
135
+ * torch.arange(self.D_head // 4, device=device, dtype=dtype)
136
+ / (self.D_head // 2)
137
+ ) # [D//4]
138
+ else:
139
+ base = self.max_period / self.min_period
140
+ exponents = torch.linspace(
141
+ 0, 1, self.D_head // 4, device=device, dtype=dtype
142
+ ) # [D//4] range [0, 1]
143
+ periods = base**exponents # range [1, max_period / min_period]
144
+ periods = periods / base # range [min_period / max_period, 1]
145
+ periods = periods * self.max_period # range [min_period, max_period]
146
+ self.periods.data = periods
147
+
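# [Editorial sketch, not part of this commit] RoPE tables for a 64x48 patch
# grid: one sin/cos row of width D_head per spatial position.
rope = RopePositionEmbedding(embed_dim=1024, num_heads=16, base=100.0)
sin, cos = rope(H=64, W=48)
assert sin.shape == cos.shape == (64 * 48, 1024 // 16)   # [HW, D_head]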
148
+
149
+ # -------------------------------------------------------------------------------
150
+ class Tokenizer(nn.Module):
151
+ """Stacked window self‑attention that emits one token per window
152
+ by re‑using TransformerEncoderLayer2 blocks."""
153
+
154
+ def __init__(
155
+ self,
156
+ embed_dims: int,
157
+ window_size: int = 4,
158
+ num_heads: int = 4,
159
+ num_tokenizer_layers: int = 1,
160
+ qkv_bias: bool = True,
161
+ use_qk_norm: bool = False,
162
+ chunk_size: int = 1024, # max windows per chunk
163
+ ):
164
+ super().__init__()
165
+ self.ws = window_size
166
+ self.chunk_size = chunk_size
167
+
168
+ # local absolute positional embeddings for [CLS] + patch tokens
169
+ self.local_pos_embed = nn.Parameter(
170
+ torch.zeros(1, 1 + window_size * window_size, embed_dims)
171
+ )
172
+ trunc_normal_(self.local_pos_embed, std=0.02)
173
+
174
+ # build N identical TransformerEncoderLayer blocks
175
+ self.blocks = nn.ModuleList(
176
+ [
177
+ TransformerEncoderLayer2(
178
+ embed_dims=embed_dims,
179
+ num_heads=num_heads,
180
+ feedforward_channels=embed_dims * 4, # standard FFN size
181
+ qkv_bias=qkv_bias,
182
+ use_qk_norm=use_qk_norm,
183
+ )
184
+ for _ in range(num_tokenizer_layers)
185
+ ]
186
+ )
187
+
188
+ # shared CLS token for pooling
189
+ self.w_cls = nn.Parameter(torch.zeros(1, 1, embed_dims))
190
+ trunc_normal_(self.w_cls, std=0.02)
191
+
192
+ def forward(
193
+ self,
194
+ x: torch.Tensor,
195
+ hw: Tuple[int, int],
196
+ ) -> Tuple[torch.Tensor, Tuple[int, int]]:
197
+ """Args:
198
+ x : B, N, C (N = H*W)
199
+ hw : (H, W) before reduction
200
+ Returns:
201
+ x_ : B, (H/ws)*(W/ws), C
202
+ hw_: (H/ws, W/ws)
203
+ """
204
+ B, N, C = x.shape
205
+ H, W = hw
206
+ ws = self.ws
207
+ assert H % ws == 0 and W % ws == 0, (
208
+ f"Image size {H}×{W} must be divisible by window {ws}."
209
+ )
210
+
211
+ # reshape tokens → non‑overlapping windows
212
+ x = x.view(B, H, W, C)
213
+
214
+ ph, pw = H // ws, W // ws ## ints in eager mode
215
+ ph, pw = int(ph), int(pw) ## ints in scripting mode
216
+ x = x.view(B, ph, ws, pw, ws, C) # B, H/ws, ws, W/ws, ws, C
217
+ x = x.permute(0, 1, 3, 2, 4, 5) # B, H/ws, W/ws, ws, ws, C
218
+ x = x.contiguous().view(B * ph * pw, ws * ws, C) # (B*(H/ws)*(W/ws), ws², C)
219
+
220
+ total_windows = x.size(0)
221
+ chunk_size = int(min(self.chunk_size, total_windows))
222
+ token_out = x.new_empty(total_windows, C)
223
+
224
+ use_ckpt = self.training and torch.is_grad_enabled()
225
+
226
+ def _run_blocks(t: torch.Tensor) -> torch.Tensor:
227
+ for blk in self.blocks:
228
+ t = blk(t)
229
+ return t
230
+
231
+ for i in range(0, total_windows, chunk_size):
232
+ chunk = x[i : i + chunk_size] # (m, ws², C)
233
+ m = chunk.size(0)
234
+ cls = self.w_cls.expand(m, -1, -1) # (m, 1, C)
235
+ chunk = torch.cat([cls, chunk], dim=1) # (m, 1+ws², C)
236
+ chunk = chunk + self.local_pos_embed # add local PE
237
+
238
+ if use_ckpt:
239
+ chunk = checkpoint(_run_blocks, chunk, use_reentrant=False)
240
+ else:
241
+ chunk = _run_blocks(chunk)
242
+
243
+ token_out[i : i + m] = chunk[:, 0] # take CLS out
244
+
245
+ token = token_out.view(B, ph * pw, C) # (B, (H/ws)*(W/ws), C)
246
+ return token, (ph, pw)
247
+
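# [Editorial sketch, not part of this commit] The tokenizer pools each 4x4
# window into a single CLS readout, shrinking a 64x48 grid to 16x12.
tok = Tokenizer(embed_dims=256, window_size=4, num_heads=4,
                num_tokenizer_layers=1).eval()
y, hw = tok(torch.randn(1, 64 * 48, 256), (64, 48))
assert hw == (16, 12) and y.shape == (1, 16 * 12, 256)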
248
+
249
+ # -------------------------------------------------------------------------------
250
+ class GroupedQueryAttention(nn.Module):
251
+ def __init__(
252
+ self,
253
+ embed_dims,
254
+ num_heads,
255
+ num_kv_heads=None,
256
+ input_dims=None,
257
+ attn_drop=0.0,
258
+ proj_drop=0.0,
259
+ qkv_bias=True,
260
+ qk_scale=None,
261
+ proj_bias=True,
262
+ use_qk_norm=True,
263
+ v_shortcut=False,
264
+ layer_scale_init_value=0.0,
265
+ ):
266
+ super().__init__()
267
+ # Core dims
268
+ self.embed_dims = embed_dims
269
+ self.num_heads = num_heads
270
+ self.num_kv_heads = num_kv_heads or num_heads
271
+ assert self.num_heads % self.num_kv_heads == 0, (
272
+ "num_kv_heads must divide num_heads"
273
+ )
274
+ self.head_dim = embed_dims // num_heads
275
+ self.input_dims = input_dims or embed_dims
276
+ # Features
277
+ self.attn_drop = attn_drop
278
+ self.v_shortcut = v_shortcut
279
+ self.use_qk_norm = use_qk_norm
280
+
281
+ # Attention operation selection. A custom `qk_scale` is not supported:
+ # F.scaled_dot_product_attention applies its default 1/sqrt(head_dim) scale.
+ assert qk_scale is None, "qk_scale is not supported"
+ self.attn_op = F.scaled_dot_product_attention
289
+
290
+ # Q/K/V projections
291
+ self.wq = nn.Linear(self.input_dims, embed_dims, bias=qkv_bias)
292
+ self.wk = nn.Linear(
293
+ self.input_dims, self.num_kv_heads * self.head_dim, bias=qkv_bias
294
+ )
295
+ self.wv = nn.Linear(
296
+ self.input_dims, self.num_kv_heads * self.head_dim, bias=qkv_bias
297
+ )
298
+
299
+ if self.use_qk_norm:
300
+ self.q_norm = nn.RMSNorm(self.head_dim, eps=1e-6)
301
+ self.k_norm = nn.RMSNorm(self.head_dim, eps=1e-6)
302
+
303
+ # Output projection + dropout
304
+ self.proj = nn.Linear(embed_dims, embed_dims, bias=proj_bias)
305
+ self.proj_drop = nn.Dropout(proj_drop)
306
+
307
+ # Optional LayerScale
308
+ if layer_scale_init_value > 0:
309
+ self.gamma = LayerScale(embed_dims, scale=layer_scale_init_value)
310
+ else:
311
+ self.gamma = nn.Identity()
312
+
313
+ def apply_rope(
314
+ self, q: Tensor, k: Tensor, rope: Tensor | Tuple[Tensor, Tensor]
315
+ ) -> Tuple[Tensor, Tensor]:
316
+ # All operations will use the dtype of rope, the output is cast back to the dtype of q and k
317
+ q_dtype = q.dtype
318
+ k_dtype = k.dtype
319
+ sin, cos = rope
320
+ rope_dtype = sin.dtype
321
+ q = q.to(dtype=rope_dtype)
322
+ k = k.to(dtype=rope_dtype)
323
+ N = q.shape[-2]
324
+ prefix = N - sin.shape[-2] ## extra tokens
325
+ assert prefix >= 0
326
+ q_prefix = q[:, :, :prefix, :]
327
+ q = self._rope_apply(q[:, :, prefix:, :], sin, cos) # [B, head, hw, D//head]
328
+ q = torch.cat((q_prefix, q), dim=-2) # [B, head, N, D//head]
329
+ k_prefix = k[:, :, :prefix, :]
330
+ k = self._rope_apply(k[:, :, prefix:, :], sin, cos) # [B, head, hw, D//head]
331
+ k = torch.cat((k_prefix, k), dim=-2) # [B, head, N, D//head]
332
+ q = q.to(dtype=q_dtype)
333
+ k = k.to(dtype=k_dtype)
334
+ return q, k
335
+
336
+ def _rope_rotate_half(self, x: Tensor) -> Tensor:
337
+ # x: [ x0 x1 x2 x3 x4 x5]
338
+ # out: [-x3 -x4 -x5 x0 x1 x2]
339
+ x1, x2 = x.chunk(2, dim=-1)
340
+ return torch.cat([-x2, x1], dim=-1)
341
+
342
+ def _rope_apply(self, x: Tensor, sin: Tensor, cos: Tensor) -> Tensor:
343
+ # x: [..., D], eg [x0, x1, x2, x3, x4, x5]
344
+ # sin: [..., D], eg [sin0, sin1, sin2, sin0, sin1, sin2]
345
+ # cos: [..., D], eg [cos0, cos1, cos2, cos0, cos1, cos2]
346
+ return (x * cos) + (self._rope_rotate_half(x) * sin)
347
+
348
+ def forward(self, x, rope=None):
349
+ B, N, _ = x.shape
350
+ # Q: (B, N, num_heads, head_dim)
351
+ q = self.wq(x).view(B, N, self.num_heads, self.head_dim)
352
+ # K/V: (B, N, num_kv_heads, head_dim)
353
+ k = self.wk(x).view(B, N, self.num_kv_heads, self.head_dim)
354
+ v = self.wv(x).view(B, N, self.num_kv_heads, self.head_dim)
355
+
356
+ # (B, heads, N, head_dim)
357
+ q = q.permute(0, 2, 1, 3)
358
+ k = k.permute(0, 2, 1, 3)
359
+ v = v.permute(0, 2, 1, 3)
360
+
361
+ if self.use_qk_norm:
362
+ q = self.q_norm(q)
363
+ k = self.k_norm(k)
364
+
365
+ # Repeat KV heads if group ratio >1
366
+ if self.num_kv_heads != self.num_heads:
367
+ factor = self.num_heads // self.num_kv_heads
368
+ k = k.repeat_interleave(factor, dim=1)
369
+ v = v.repeat_interleave(factor, dim=1)
370
+
371
+ if rope is not None:
372
+ q, k = self.apply_rope(q, k, rope)
373
+
374
+ # Scaled dot-product attention
375
+ attn_out = self.attn_op(
376
+ q, k, v, dropout_p=self.attn_drop if self.training else 0.0
377
+ ) # (B, num_heads, N, head_dim)
378
+
379
+ # Merge heads -> (B, N, embed_dims)
380
+ out = attn_out.permute(0, 2, 1, 3).reshape(B, N, self.embed_dims)
381
+
382
+ # Output projection + drop + layer scale
383
+ out = self.proj(out)
384
+ out = self.gamma(self.proj_drop(out))
385
+
386
+ # V-shortcut is only meaningful for MQA (num_kv_heads == 1) and is not implemented
387
+ if self.v_shortcut and self.num_kv_heads == 1:
388
+ raise NotImplementedError
389
+ return out
390
+
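# [Editorial sketch, not part of this commit] GQA with 16 query heads sharing
# 8 KV heads; K/V are repeat-interleaved to match before SDPA. Relies on
# nn.RMSNorm (PyTorch >= 2.4) via the default use_qk_norm=True.
gqa = GroupedQueryAttention(embed_dims=1024, num_heads=16, num_kv_heads=8).eval()
out = gqa(torch.randn(2, 3072, 1024))
assert out.shape == (2, 3072, 1024)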
391
+
392
+ # -------------------------------------------------------------------------------
393
+ class TransformerEncoderLayer2(nn.Module):
394
+ def __init__(
395
+ self,
396
+ embed_dims,
397
+ num_heads,
398
+ num_kv_heads=None,
399
+ feedforward_channels=None,
400
+ drop_rate=0.0,
401
+ attn_drop_rate=0.0,
402
+ layer_scale_init_value=0.0,
403
+ use_qk_norm=True,
404
+ qkv_bias=True,
405
+ ):
406
+ super(TransformerEncoderLayer2, self).__init__()
407
+
408
+ self.embed_dims = embed_dims
409
+ self.ln1 = nn.RMSNorm(self.embed_dims, eps=1e-6)
410
+ self.attn = GroupedQueryAttention(
411
+ embed_dims=embed_dims,
412
+ num_heads=num_heads,
413
+ num_kv_heads=num_kv_heads,
414
+ attn_drop=attn_drop_rate,
415
+ proj_drop=drop_rate,
416
+ qkv_bias=qkv_bias,
417
+ layer_scale_init_value=layer_scale_init_value,
418
+ use_qk_norm=use_qk_norm,
419
+ )
420
+
421
+ self.ln2 = nn.RMSNorm(self.embed_dims, eps=1e-6)
422
+ self.ffn = SwiGLUFFN(
423
+ embed_dims=embed_dims,
424
+ feedforward_channels=feedforward_channels,
425
+ )
426
+
427
+ @property
428
+ def norm1(self):
429
+ return self.ln1
430
+
431
+ @property
432
+ def norm2(self):
433
+ return self.ln2
434
+
435
+ def forward(self, x, rope=None):
436
+ x = x + self.attn(self.ln1(x), rope=rope)
437
+ x = self.ffn(self.ln2(x), identity=x)
438
+ return x
439
+
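# [Editorial sketch, not part of this commit] A block applied with RoPE; here
# the input holds patch tokens only, so the rope "prefix" of extra tokens is 0.
blk = TransformerEncoderLayer2(embed_dims=1024, num_heads=16,
                               feedforward_channels=4096).eval()
rope_tables = RopePositionEmbedding(embed_dim=1024, num_heads=16, base=100.0)
x = torch.randn(1, 64 * 48, 1024)
assert blk(x, rope=rope_tables(H=64, W=48)).shape == x.shape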
440
+
441
+ ##-----------------------------------
442
+ class Sapiens2(nn.Module):
443
+ arch_zoo = {
444
+ **dict.fromkeys(
445
+ ["sapiens2_0.1b"],
446
+ {
447
+ "embed_dims": 768,
448
+ "num_layers": 12,
449
+ "num_heads": 12,
450
+ "feedforward_channels": 768 * 4,
451
+ "num_tokenizer_layers": 2,
452
+ },
453
+ ),
454
+ **dict.fromkeys(
455
+ ["sapiens2_0.4b"],
456
+ {
457
+ "embed_dims": 1024,
458
+ "num_layers": 24,
459
+ "num_heads": 16,
460
+ "feedforward_channels": 1024 * 4,
461
+ "num_tokenizer_layers": 2,
462
+ },
463
+ ),
464
+ **dict.fromkeys(
465
+ ["sapiens2_0.8b"],
466
+ {
467
+ "embed_dims": 1280,
468
+ "num_layers": 32,
469
+ "num_heads": 16,
470
+ "feedforward_channels": 1280 * 4,
471
+ "num_tokenizer_layers": 3,
472
+ },
473
+ ),
474
+ **dict.fromkeys(
475
+ ["sapiens2_1b"],
476
+ {
477
+ "embed_dims": 1536,
478
+ "num_layers": 40,
479
+ "num_heads": 24,
480
+ "feedforward_channels": 1536 * 4,
481
+ "num_tokenizer_layers": 4,
482
+ },
483
+ ),
484
+ **dict.fromkeys(
485
+ ["sapiens2_5b"],
486
+ {
487
+ "embed_dims": 2432,
488
+ "num_layers": 56,
489
+ "num_heads": 32,
490
+ "feedforward_channels": 2432 * 4,
491
+ "num_tokenizer_layers": 6,
492
+ },
493
+ ),
494
+ }
495
+
496
+ num_extra_tokens = 1 # class token
497
+ OUT_TYPES = {"raw", "cls_token", "featmap"}
498
+
499
+ def __init__(
500
+ self,
501
+ arch="sapiens2_1b",
502
+ img_size=(1024, 768),
503
+ patch_size=16,
504
+ in_channels=3,
505
+ out_indices=-1,
506
+ drop_rate=0.0,
507
+ window_size=4,
508
+ use_tokenizer=False, ## enable for 4k-resolution inputs
509
+ use_qk_norm=True,
510
+ qkv_bias=True,
511
+ final_norm=True,
512
+ out_type="raw",
513
+ with_cls_token=True,
514
+ layer_scale_init_value=1e-4, ## non-zero init activates LayerScale
515
+ frozen_stages=-1,
516
+ patch_cfg=dict(),
517
+ layer_cfgs=dict(),
518
+ pos_embed_rope_base: float = 100.0,
519
+ pos_embed_rope_min_period: float | None = None,
520
+ pos_embed_rope_max_period: float | None = None,
521
+ pos_embed_rope_normalize_coords: Literal["min", "max", "separate"] = "separate",
522
+ pos_embed_rope_shift_coords: float | None = None,
523
+ pos_embed_rope_jitter_coords: float | None = None,
524
+ pos_embed_rope_rescale_coords: float | None = None,
525
+ pos_embed_rope_dtype: str = "bf16",
526
+ n_storage_tokens: int = 8,
527
+ ):
528
+ super().__init__()
529
+
530
+ arch = arch.lower()
531
+ assert arch in set(self.arch_zoo), (
532
+ f"Arch {arch} is not in default archs {set(self.arch_zoo)}"
533
+ )
534
+ self.arch_settings = self.arch_zoo[arch]
535
+
536
+ self.embed_dims = self.arch_settings["embed_dims"]
537
+ self.num_layers = self.arch_settings["num_layers"]
538
+ self.patch_size = patch_size
539
+
540
+ self.window_size = window_size
541
+ img_size = to_2tuple(img_size)
542
+ encoder_img_size = (
543
+ (img_size[0] // window_size, img_size[1] // window_size)
544
+ if use_tokenizer
545
+ else img_size
546
+ )
547
+ self.img_size = to_2tuple(encoder_img_size)
548
+
549
+ # Set patch embedding
550
+ _patch_cfg = dict(
551
+ in_channels=in_channels,
552
+ input_size=self.img_size,
553
+ embed_dims=self.embed_dims,
554
+ kernel_size=patch_size,
555
+ stride=patch_size,
556
+ bias=True,
557
+ )
558
+ _patch_cfg.update(patch_cfg)
559
+ self.patch_embed = PatchEmbed(**_patch_cfg)
560
+ self.patch_resolution = self.patch_embed.init_out_size
561
+ num_patches = self.patch_resolution[0] * self.patch_resolution[1]
562
+
563
+ self.rope_embed = RopePositionEmbedding(
564
+ embed_dim=self.embed_dims,
565
+ num_heads=self.arch_settings["num_heads"],
566
+ base=pos_embed_rope_base,
567
+ min_period=pos_embed_rope_min_period,
568
+ max_period=pos_embed_rope_max_period,
569
+ normalize_coords=pos_embed_rope_normalize_coords,
570
+ shift_coords=pos_embed_rope_shift_coords,
571
+ jitter_coords=pos_embed_rope_jitter_coords,
572
+ rescale_coords=pos_embed_rope_rescale_coords,
573
+ dtype=torch.bfloat16 if pos_embed_rope_dtype == "bf16" else torch.float32,
574
+ )
575
+
576
+ # Set out type
577
+ if out_type not in self.OUT_TYPES:
578
+ raise ValueError(
579
+ f"Unsupported `out_type` {out_type}, please "
580
+ f"choose from {self.OUT_TYPES}"
581
+ )
582
+ self.out_type = out_type
583
+
584
+ if use_tokenizer:
585
+ self.tokenizer = Tokenizer(
586
+ embed_dims=self.embed_dims,
587
+ window_size=self.window_size,
588
+ num_heads=self.arch_settings["num_heads"],
589
+ num_tokenizer_layers=self.arch_settings["num_tokenizer_layers"],
590
+ qkv_bias=True,
591
+ use_qk_norm=False,
592
+ )
593
+ else:
594
+ self.tokenizer = None
595
+
596
+ # Set cls + storage tokens
597
+ self.with_cls_token = with_cls_token
598
+ if with_cls_token:
599
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
600
+ elif out_type != "cls_token":
601
+ self.cls_token = None
602
+ self.num_extra_tokens = 0
603
+ else:
604
+ raise ValueError('with_cls_token must be True when `out_type="cls_token"`.')
605
+
606
+ ## registers
607
+ self.n_storage_tokens = int(n_storage_tokens)
608
+ self.storage_tokens = (
609
+ nn.Parameter(torch.zeros(1, self.n_storage_tokens, self.embed_dims))
610
+ if self.n_storage_tokens > 0
611
+ else None
612
+ )
613
+ # how many non-patch tokens are at the front
614
+ self.num_extra_tokens = (
615
+ 1 if self.cls_token is not None else 0
616
+ ) + self.n_storage_tokens
617
+
618
+ if isinstance(out_indices, int):
619
+ out_indices = [out_indices]
620
+ assert isinstance(out_indices, Sequence), (
621
+ f'"out_indices" must by a sequence or int, get {type(out_indices)} instead.'
622
+ )
623
+ for i, index in enumerate(out_indices):
624
+ if index < 0:
625
+ out_indices[i] = self.num_layers + index
626
+ assert 0 <= out_indices[i] <= self.num_layers, (
627
+ f"Invalid out_indices {index}"
628
+ )
629
+ self.out_indices = out_indices
630
+
631
+ self.blocks = nn.Sequential()
632
+ if isinstance(layer_cfgs, dict):
633
+ layer_cfgs = [layer_cfgs] * self.num_layers
634
+
635
+ mhsa_early, mhsa_late = 8, 8
636
+ for i in range(self.num_layers):
637
+ if i < mhsa_early or i >= self.num_layers - mhsa_late:
638
+ num_kv_heads = None ## use MHSA
639
+ else:
640
+ num_kv_heads = self.arch_settings["num_heads"] // 2 # Use GQA
641
+
642
+ _layer_cfg = dict(
643
+ embed_dims=self.embed_dims,
644
+ num_heads=self.arch_settings["num_heads"],
645
+ num_kv_heads=num_kv_heads,
646
+ feedforward_channels=self.arch_settings["feedforward_channels"],
647
+ use_qk_norm=use_qk_norm,
648
+ layer_scale_init_value=layer_scale_init_value,
649
+ drop_rate=drop_rate,
650
+ qkv_bias=qkv_bias,
651
+ )
652
+ _layer_cfg.update(layer_cfgs[i])
653
+ self.blocks.append(TransformerEncoderLayer2(**_layer_cfg))
654
+
655
+ self.frozen_stages = frozen_stages
656
+
657
+ self.final_norm = final_norm
658
+ if final_norm:
659
+ self.ln1 = nn.RMSNorm(self.embed_dims, eps=1e-6)
660
+
661
+ # freeze stages only when self.frozen_stages > 0
662
+ if self.frozen_stages > 0:
663
+ self._freeze_stages()
664
+
665
+ ## load init weights
666
+ self.init_weights()
667
+
668
+ return
669
+
670
+ def init_weights(self):
671
+ # Initialize class token and storagr token embeddings
672
+ if self.with_cls_token:
673
+ trunc_normal_(self.cls_token, std=0.02)
674
+
675
+ if self.storage_tokens is not None:
676
+ trunc_normal_(self.storage_tokens, std=0.02)
677
+
678
+ # Apply custom initialization to all submodules
679
+ self.apply(self._init_weights)
680
+
681
+ def _init_weights(self, m):
682
+ if isinstance(m, nn.Linear):
683
+ # Use a truncated normal distribution for linear layer weights
684
+ trunc_normal_(m.weight, std=0.02)
685
+ if m.bias is not None:
686
+ nn.init.constant_(m.bias, 0)
687
+
688
+ elif isinstance(m, (nn.LayerNorm, nn.RMSNorm)):
689
+ # Initialize normalization layers to act as an identity function
690
+ if hasattr(m, "bias") and m.bias is not None:
691
+ nn.init.constant_(m.bias, 0)
692
+ if hasattr(m, "weight") and m.weight is not None:
693
+ nn.init.constant_(m.weight, 1.0)
694
+
695
+ elif isinstance(m, nn.Conv2d):
696
+ # Initialize conv layer weights like linear layers
697
+ trunc_normal_(m.weight, std=0.02)
698
+ if m.bias is not None:
699
+ nn.init.constant_(m.bias, 0)
700
+
701
+ def _freeze_stages(self):
702
+ ## freeze tokenizer
703
+ if self.frozen_stages >= 1 and self.tokenizer is not None:
704
+ self.tokenizer.eval()
705
+ for param in self.tokenizer.parameters():
706
+ param.requires_grad = False
707
+
708
+ # freeze patch embedding
709
+ self.patch_embed.eval()
710
+ for param in self.patch_embed.parameters():
711
+ param.requires_grad = False
712
+ # freeze cls_token
713
+ if self.cls_token is not None:
714
+ self.cls_token.requires_grad = False
715
+ if self.storage_tokens is not None:
716
+ self.storage_tokens.requires_grad = False
717
+ # freeze layers
718
+ for i in range(1, self.frozen_stages + 1):
719
+ m = self.blocks[i - 1]
720
+ m.eval()
721
+ for param in m.parameters():
722
+ param.requires_grad = False
723
+
724
+ # freeze the last layer norm
725
+ if self.frozen_stages == len(self.blocks):
726
+ if self.final_norm:
727
+ self.ln1.eval()
728
+ for param in self.ln1.parameters():
729
+ param.requires_grad = False
730
+
731
+ def forward(self, x):
732
+ B = x.shape[0]
733
+
734
+ x, patch_resolution = self.patch_embed(x) # (B, 256*256, C)
735
+ if self.tokenizer is not None:
736
+ x, patch_resolution = self.tokenizer(x, patch_resolution)
737
+
738
+ # prepend [CLS] and storage tokens
739
+ prepend = []
740
+ if self.cls_token is not None:
741
+ prepend.append(self.cls_token.expand(B, -1, -1))
742
+ if self.storage_tokens is not None:
743
+ prepend.append(self.storage_tokens.expand(B, -1, -1))
744
+ if len(prepend) > 0:
745
+ x = torch.cat(prepend + [x], dim=1)
746
+
747
+ rope_sincos = self.rope_embed(H=patch_resolution[0], W=patch_resolution[1])
748
+ outs = []
749
+ for i, layer in enumerate(self.blocks):
750
+ x = layer(x, rope=rope_sincos)
751
+
752
+ if i == len(self.blocks) - 1 and self.final_norm:
753
+ x = self.ln1(x)
754
+
755
+ if i in self.out_indices:
756
+ outs.append(self._format_output(x, patch_resolution))
757
+
758
+ return tuple(outs)
759
+
760
+ def _format_output(self, x, hw):
761
+ if self.out_type == "raw":
762
+ return x
763
+ if self.out_type == "cls_token":
764
+ return x[:, 0]
765
+
766
+ patch_token = x[:, self.num_extra_tokens :]
767
+ if self.out_type == "featmap":
768
+ B = x.size(0)
769
+ # (B, N, C) -> (B, H, W, C) -> (B, C, H, W)
770
+ return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2)
771
+
772
+ @property
773
+ def norm1(self):
774
+ return self.ln1
775
+
776
+
777
+ # ----------------------------------------------------------------------------
778
+ class LayerScale(nn.Module):
779
+ def __init__(
780
+ self,
781
+ dim: int,
782
+ inplace: bool = False,
783
+ data_format: str = "channels_last",
784
+ scale: float = 1e-5,
785
+ ):
786
+ super().__init__()
787
+ assert data_format in (
788
+ "channels_last",
789
+ "channels_first",
790
+ ), "'data_format' could only be channels_last or channels_first."
791
+ self.inplace = inplace
792
+ self.data_format = data_format
793
+ self.weight = nn.Parameter(torch.ones(dim) * scale)
794
+
795
+ def forward(self, x) -> torch.Tensor:
796
+ if self.data_format == "channels_first":
797
+ shape = tuple((1, -1, *(1 for _ in range(x.dim() - 2))))
798
+ else:
799
+ shape = tuple((*(1 for _ in range(x.dim() - 1)), -1))
800
+ if self.inplace:
801
+ return x.mul_(self.weight.view(*shape))
802
+ else:
803
+ return x * self.weight.view(*shape)
804
+
805
+
806
+ # ----------------------------------------------------------------------------
807
+ class PatchEmbed(nn.Module):
808
+ def __init__(
809
+ self,
810
+ in_channels=3,
811
+ embed_dims=768,
812
+ kernel_size=16,
813
+ stride=16,
814
+ padding="corner",
815
+ dilation=1,
816
+ bias=True,
817
+ input_size=None,
818
+ ):
819
+ super().__init__()
820
+
821
+ self.embed_dims = embed_dims
822
+ if stride is None:
823
+ stride = kernel_size
824
+
825
+ kernel_size = to_2tuple(kernel_size)
826
+ stride = to_2tuple(stride)
827
+ dilation = to_2tuple(dilation)
828
+ # This standalone variant ignores the `padding` argument ("corner"
+ # adaptive padding is dropped); the projection conv uses zero padding.
+ padding = to_2tuple(0)
830
+
831
+ self.projection = nn.Conv2d(
832
+ in_channels=in_channels,
833
+ out_channels=embed_dims,
834
+ kernel_size=kernel_size,
835
+ stride=stride,
836
+ padding=padding,
837
+ dilation=dilation,
838
+ bias=bias,
839
+ )
840
+
841
+ if input_size:
842
+ input_size = to_2tuple(input_size)
843
+ self.init_input_size = input_size
844
+ h_out = (
845
+ input_size[0] + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1
846
+ ) // stride[0] + 1
847
+ w_out = (
848
+ input_size[1] + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1
849
+ ) // stride[1] + 1
850
+ self.init_out_size = (h_out, w_out)
851
+ else:
852
+ self.init_input_size = None
853
+ self.init_out_size = None
854
+
855
+ def forward(self, x):
856
+ x = self.projection(x)
857
+ out_size = (x.shape[2], x.shape[3])
858
+ x = x.flatten(2).transpose(1, 2)
859
+ return x, out_size
860
+
861
+
862
+ # ----------------------------------------------------------------------------
863
+ class SwiGLUFFN(nn.Module):
864
+ """SwiGLU FFN layer.
865
+ https://github.com/facebookresearch/dinov2/blob/main/dinov2/layers/swiglu_ffn.py
866
+ """ # noqa
867
+
868
+ def __init__(
869
+ self,
870
+ embed_dims: int,
871
+ feedforward_channels: Optional[int] = None,
872
+ out_dims: Optional[int] = None,
873
+ layer_scale_init_value: float = 0.0,
874
+ bias: bool = True,
875
+ add_identity: bool = True,
876
+ ) -> None:
877
+ super().__init__()
878
+ self.embed_dims = embed_dims
879
+ self.out_dims = out_dims or embed_dims
880
+ hidden_dims = feedforward_channels or embed_dims
881
+
882
+ self.w12 = nn.Linear(self.embed_dims, 2 * hidden_dims, bias=bias)
883
+ self.w3 = nn.Linear(hidden_dims, self.out_dims, bias=bias)
884
+
885
+ if layer_scale_init_value > 0:
886
+ self.gamma2 = LayerScale(dim=embed_dims, scale=layer_scale_init_value)
887
+ else:
888
+ self.gamma2 = nn.Identity()
889
+
890
+ self.add_identity = add_identity
891
+
892
+ def forward(
893
+ self, x: torch.Tensor, identity: Optional[torch.Tensor] = None
894
+ ) -> torch.Tensor:
895
+ x12 = self.w12(x)
896
+ x1, x2 = x12.chunk(2, dim=-1)
897
+ hidden = F.silu(x1) * x2
898
+ out = self.w3(hidden)
899
+ out = self.gamma2(out)
900
+
901
+ if self.out_dims != self.embed_dims or not self.add_identity:
902
+ # due to the dimension inconsistence or user setting
903
+ # not to apply residual operation
904
+ return out
905
+
906
+ if identity is None:
907
+ identity = x
908
+ return identity + out
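# [Editorial sketch, not part of this commit] SwiGLU: w12 projects to
# 2*hidden, which is split so that silu(x1) gates x2 before w3 projects back.
ffn = SwiGLUFFN(embed_dims=256, feedforward_channels=1024)
y = ffn(torch.randn(2, 10, 256))
assert y.shape == (2, 10, 256)   # residual added since out_dims == embed_dims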
sapiens/dense/__init__.py ADDED
@@ -0,0 +1,21 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import pathlib
8
+ import pkgutil
9
+
10
+ from .. import __version__
11
+
12
+ _src = pathlib.Path(__file__).with_name("src")
13
+ __path__ = pkgutil.extend_path(__path__, __name__) # allow namespace merge
14
+ __path__.append(str(_src))
15
+ del pathlib, pkgutil, _src
16
+
17
+
18
+ # -----------------------------------------------------
19
+ from importlib import import_module as _imp
20
+
21
+ _pkg = _imp(__name__ + ".src") # runs src/__init__.py
sapiens/dense/configs/albedo/render_people/sapiens2_0.4b_albedo_render_people-1024x768.py ADDED
@@ -0,0 +1,274 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+
9
+ _CHECKPOINT_ROOT = os.path.expanduser(
10
+ os.environ.get("SAPIENS_CHECKPOINT_ROOT", "~/sapiens2_host")
11
+ )
12
+ _DATA_ROOT = os.path.expanduser(os.environ.get("DATA_ROOT", "~/sapiens_data"))
13
+
14
+ warmup_iters = 500
15
+ num_iters = 2e4 ## 32 nodes, 8 gpus: 256 gpus. bs: 3, global bs: 768. num samples: 1e6. 1e6/768 = 1302. 1 epoch = 1e3 iters.
16
+
17
+ # ------------------------------------------------------------------------------
18
+ vis_every_iters = 100
19
+ log_every_iters = 10
20
+ save_every_iters = 1000
21
+ val_every_iters = 1000
22
+
23
+ # # debug
24
+ # vis_every_iters = 1
25
+ # log_every_iters = 1
26
+ # val_every_iters = 10
27
+ # save_every_iters = 1000
28
+
29
+ load_from = None
30
+ resume = False
31
+
32
+ # ------------------------------------------------------------------
33
+ model_name = "sapiens2_0.4b"
34
+ embed_dim = 1024
35
+ num_layers = 24
36
+ num_heads = 16
37
+
38
+ layer_decay_rate = 0.8
39
+ pretrained_checkpoint = f"{_CHECKPOINT_ROOT}/pretrain/sapiens2_0.4b_pretrain.safetensors"
40
+
41
+ ##-----------------------------------------------------------------
42
+ image_size = (1024, 768) ## height x width
43
+ patch_size = 16
44
+
45
+ # ------------------------------------------------------------------
46
+ use_fsdp = True
47
+ # use_fsdp = False
48
+
49
+ use_compile = True
50
+ # use_compile = False
51
+
52
+ ## DDP config
53
+ if use_fsdp is False:
54
+ accelerator_cfg = dict(
55
+ type="DDP",
56
+ log_with="tensorboard",
57
+ # find_unused_parameters=True,
58
+ gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off.
59
+ max_interval=num_iters,
60
+ mixed_precision="bf16", # Options: ‘no’,‘fp16’,‘bf16’ or ‘fp8’.
61
+ step_scheduler_with_optimizer=False, ## schedule independent of n_gpus
62
+ )
63
+
64
+ else:
65
+ accelerator_cfg = dict(
66
+ type="FSDP",
67
+ log_with="tensorboard",
68
+ gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off.
69
+ max_interval=num_iters,
70
+ mixed_precision="bf16", # Options: ‘no’,‘fp16’,‘bf16’ or ‘fp8’.
71
+ step_scheduler_with_optimizer=False,
72
+ fsdp_cfg=dict(
73
+ fsdp_version=2, # DTensor-based engine
74
+ state_dict_type="SHARDED_STATE_DICT", # SHARDED_STATE_DICT | FULL_STATE_DICT
75
+ # state_dict_type="FULL_STATE_DICT", # TODO: resume from this is not working
76
+ mixed_precision=dict(
77
+ param_dtype="bf16",
78
+ reduce_dtype="bf16",
79
+ ),
80
+ cpu_ram_efficient_loading=False,
81
+ ),
82
+ )
83
+
84
+ if use_compile:
85
+ accelerator_cfg["compile_cfg"] = dict(
86
+ backend="inductor",
87
+ mode="default", # Options: "default", "reduce-overhead", "max-autotune"
88
+ fullgraph=False,
89
+ dynamic=False,
90
+ )
91
+
92
+ # ------------------------------------------------------------------
93
+ randomness = dict(seed=0, deterministic=False, diff_rank_seed=True)
94
+ logger = dict(
95
+ type="Logger",
96
+ log_interval=log_every_iters,
97
+ )
98
+ checkpoint = dict(
99
+ type="Checkpointer",
100
+ save_interval=save_every_iters,
101
+ )
102
+
103
+ visualizer = dict(
104
+ type="AlbedoVisualizer",
105
+ vis_interval=vis_every_iters,
106
+ vis_max_samples=4,
107
+ vis_image_width=384,
108
+ vis_image_height=512,
109
+ )
110
+
111
+
112
+ ##-----------------------------------------------------------------
113
+ train_pipeline = [
114
+ dict(
115
+ type="AlbedoRandomScale",
116
+ scale_min=0.5,
117
+ scale_max=2.0,
118
+ prob=0.3,
119
+ ),
120
+ dict(
121
+ type="AlbedoRandomCropContinuous",
122
+ ar_range=(0.5, 2.0),
123
+ area_range=(0.4, 1.0),
124
+ num_attempts=8,
125
+ prob=0.3,
126
+ ),
127
+ dict(
128
+ type="AlbedoRandomFlip",
129
+ prob=0.3,
130
+ ),
131
+ dict(type="AlbedoResize", height=1024, width=768),
132
+ dict(type="RandomGaussianNoise", prob=0.2, var_range=(5.0, 20.0)),
133
+ dict(
134
+ type="AlbedoPackInputs",
135
+ meta_keys=(
136
+ "img_path",
137
+ "ori_shape",
138
+ ),
139
+ ),
140
+ ]
141
+
142
+ val_pipeline = [
143
+ dict(type="AlbedoResize", height=1024, width=768, test_mode=True),
144
+ dict(
145
+ type="AlbedoPackInputs",
146
+ test_mode=True,
147
+ meta_keys=(
148
+ "img_path",
149
+ "orig_img_height",
150
+ "orig_img_width",
151
+ "img_shape",
152
+ "pad_shape",
153
+ ),
154
+ ),
155
+ ]
156
+
157
+ test_pipeline = [
158
+ dict(type="AlbedoResizePadImage", height=1024, width=768, pad_val=0),
159
+ dict(
160
+ type="AlbedoPackInputs",
161
+ meta_keys=(
162
+ "img_path",
163
+ "orig_img_height",
164
+ "orig_img_width",
165
+ "padding_size",
166
+ ),
167
+ ),
168
+ ]
169
+
170
+
171
+ render_people_dataset = dict(
172
+ type="AlbedoRenderPeopleDataset",
173
+ data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_albedo",
174
+ )
175
+
176
+ train_datasets = [render_people_dataset]
177
+
178
+ train_dataloader = dict(
179
+ batch_size=1,
180
+ num_workers=4,
181
+ persistent_workers=True,
182
+ shuffle=True,
183
+ dataset=dict(
184
+ type="CombinedDataset", datasets=train_datasets, pipeline=train_pipeline
185
+ ),
186
+ )
187
+
188
+ val_dataloader = dict(
189
+ batch_size=4,
190
+ num_workers=4,
191
+ persistent_workers=True,
192
+ multiprocessing_context="spawn",
193
+ # num_workers=0, # debug
194
+ # persistent_workers=False, # debug
195
+ shuffle=False,
196
+ dataset=dict(
197
+ type="AlbedoRenderPeopleDataset",
198
+ test_mode=True,
199
+ data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_albedo_test",
200
+ pipeline=val_pipeline,
201
+ ),
202
+ )
203
+
204
+ val_cfg = dict(
205
+ val_interval=val_every_iters,
206
+ evaluator=dict(
207
+ type="AlbedoEvaluator",
208
+ ),
209
+ )
210
+
211
+ data_preprocessor = dict(
212
+ type="ImagePreprocessor",
213
+ mean=[123.675, 116.28, 103.53],
214
+ std=[58.395, 57.12, 57.375],
215
+ bgr_to_rgb=True, ## convert from bgr to rgb for pretrained models
216
+ )
217
+
218
+ ##-----------------------------------------------------------------
219
+ model = dict(
220
+ type="AlbedoEstimator",
221
+ backbone=dict(
222
+ type="Sapiens2",
223
+ arch=model_name,
224
+ img_size=image_size,
225
+ patch_size=patch_size,
226
+ final_norm=True,
227
+ use_tokenizer=False,
228
+ with_cls_token=True,
229
+ out_type="featmap",
230
+ init_cfg=dict(type="Pretrained", checkpoint=pretrained_checkpoint),
231
+ ),
232
+ decode_head=dict(
233
+ type="AlbedoHead",
234
+ in_channels=embed_dim,
235
+ upsample_channels=[768, 512, 256, 128], ## 1K resolution
236
+ conv_out_channels=[64, 32, 16],
237
+ conv_kernel_sizes=[3, 3, 3],
238
+ loss_decode=[
239
+ dict(type="L1Loss", loss_weight=2.0),
240
+ dict(type="AlbedoGradL1Loss", loss_weight=2.0),
241
+ # dict(type="AlbedoLowFreqL1Loss", down_sample=32, loss_weight=1.0),
242
+ dict(type="AlbedoChromaticityL1Loss", loss_weight=1.0),
243
+ ],
244
+ ),
245
+ )
246
+
247
+
248
+ ##-----------------------------------------------------------------
249
+ optimizer = dict(
250
+ type="AdamW",
251
+ lr=5e-4,
252
+ betas=(0.9, 0.999),
253
+ weight_decay=0.1,
254
+ paramwise_cfg=dict(
255
+ num_layers=num_layers,
256
+ layer_decay_rate=layer_decay_rate,
257
+ ),
258
+ fused=True,
259
+ )
260
+
261
+ scheduler = dict(
262
+ type="SequentialLR",
263
+ milestones=[warmup_iters],
264
+ schedulers=[
265
+ dict(type="LinearLR", start_factor=1e-3, total_iters=warmup_iters),
266
+ dict(
267
+ type="PolynomialLR",
268
+ total_iters=num_iters - warmup_iters,
269
+ power=1.0,
270
+ ),
271
+ ],
272
+ )
273
+
274
+ clip_grad = dict(mode="norm", max_norm=2.0, norm_type=2.0)
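# [Editorial sketch, not part of this commit] The schedule above (500-iter
# linear warmup, then linear decay via power=1.0) expressed with plain
# torch.optim schedulers; the layer-wise LR decay from `paramwise_cfg` is a
# framework feature and is not reproduced here.
import torch
from torch.optim.lr_scheduler import LinearLR, PolynomialLR, SequentialLR

opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))],
                        lr=5e-4, betas=(0.9, 0.999), weight_decay=0.1)
sched = SequentialLR(
    opt,
    schedulers=[
        LinearLR(opt, start_factor=1e-3, total_iters=500),
        PolynomialLR(opt, total_iters=int(2e4) - 500, power=1.0),
    ],
    milestones=[500],
)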
sapiens/dense/configs/albedo/render_people/sapiens2_0.8b_albedo_render_people-1024x768.py ADDED
@@ -0,0 +1,275 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+
9
+ _CHECKPOINT_ROOT = os.path.expanduser(
10
+ os.environ.get("SAPIENS_CHECKPOINT_ROOT", "~/sapiens2_host")
11
+ )
12
+ _DATA_ROOT = os.path.expanduser(os.environ.get("DATA_ROOT", "~/sapiens_data"))
13
+
14
+ warmup_iters = 500
15
+ num_iters = 2e4 ## 32 nodes, 8 gpus: 256 gpus. bs: 3, global bs: 768. num samples: 1e6. 1e6/768 = 1302. 1 epoch = 1e3 iters.
16
+
17
+ # ------------------------------------------------------------------------------
18
+ vis_every_iters = 100
19
+ log_every_iters = 10
20
+ save_every_iters = 1000
21
+ # val_every_iters = 2000
22
+ val_every_iters = 10000
23
+
24
+ # # debug
25
+ # vis_every_iters = 1
26
+ # log_every_iters = 1
27
+ # val_every_iters = 10
28
+ # save_every_iters = 1000
29
+
30
+ load_from = None
31
+ resume = False
32
+
33
+ # ------------------------------------------------------------------
34
+ model_name = "sapiens2_0.8b"
35
+ embed_dim = 1280
36
+ num_layers = 32
37
+ num_heads = 16
38
+
39
+ layer_decay_rate = 0.85
40
+ pretrained_checkpoint = f"{_CHECKPOINT_ROOT}/pretrain/sapiens2_0.8b_pretrain.safetensors"
41
+
42
+ ##-----------------------------------------------------------------
43
+ image_size = (1024, 768) ## height x width
44
+ patch_size = 16
45
+
46
+ # ------------------------------------------------------------------
47
+ use_fsdp = True
48
+ # use_fsdp = False
49
+
50
+ use_compile = True
51
+ # use_compile = False
52
+
53
+ ## DDP config
54
+ if use_fsdp is False:
55
+ accelerator_cfg = dict(
56
+ type="DDP",
57
+ log_with="tensorboard",
58
+ # find_unused_parameters=True,
59
+ gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off.
60
+ max_interval=num_iters,
61
+ mixed_precision="bf16", # Options: ‘no’,‘fp16’,‘bf16’ or ‘fp8’.
62
+ step_scheduler_with_optimizer=False, ## schedule independent of n_gpus
63
+ )
64
+
65
+ else:
66
+ accelerator_cfg = dict(
67
+ type="FSDP",
68
+ log_with="tensorboard",
69
+ gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off.
70
+ max_interval=num_iters,
71
+ mixed_precision="bf16", # Options: ‘no’,‘fp16’,‘bf16’ or ‘fp8’.
72
+ step_scheduler_with_optimizer=False,
73
+ fsdp_cfg=dict(
74
+ fsdp_version=2, # DTensor-based engine
75
+ state_dict_type="SHARDED_STATE_DICT", # SHARDED_STATE_DICT | FULL_STATE_DICT
76
+ # state_dict_type="FULL_STATE_DICT", # TODO: resume from this is not working
77
+ mixed_precision=dict(
78
+ param_dtype="bf16",
79
+ reduce_dtype="bf16",
80
+ ),
81
+ cpu_ram_efficient_loading=False,
82
+ ),
83
+ )
84
+
85
+ if use_compile:
86
+ accelerator_cfg["compile_cfg"] = dict(
87
+ backend="inductor",
88
+ mode="default", # Options: "default", "reduce-overhead", "max-autotune"
89
+ fullgraph=False,
90
+ dynamic=False,
91
+ )
92
+
93
+ # ------------------------------------------------------------------
94
+ randomness = dict(seed=0, deterministic=False, diff_rank_seed=True)
95
+ logger = dict(
96
+ type="Logger",
97
+ log_interval=log_every_iters,
98
+ )
99
+ checkpoint = dict(
100
+ type="Checkpointer",
101
+ save_interval=save_every_iters,
102
+ )
103
+
104
+ visualizer = dict(
105
+ type="AlbedoVisualizer",
106
+ vis_interval=vis_every_iters,
107
+ vis_max_samples=4,
108
+ vis_image_width=384,
109
+ vis_image_height=512,
110
+ )
111
+
112
+
113
+ ##-----------------------------------------------------------------
114
+ train_pipeline = [
115
+ dict(
116
+ type="AlbedoRandomScale",
117
+ scale_min=0.5,
118
+ scale_max=2.0,
119
+ prob=0.3,
120
+ ),
121
+ dict(
122
+ type="AlbedoRandomCropContinuous",
123
+ ar_range=(0.5, 2.0),
124
+ area_range=(0.4, 1.0),
125
+ num_attempts=8,
126
+ prob=0.3,
127
+ ),
128
+ dict(
129
+ type="AlbedoRandomFlip",
130
+ prob=0.3,
131
+ ),
132
+ dict(type="AlbedoResize", height=1024, width=768),
133
+ dict(type="RandomGaussianNoise", prob=0.2, var_range=(5.0, 20.0)),
134
+ dict(
135
+ type="AlbedoPackInputs",
136
+ meta_keys=(
137
+ "img_path",
138
+ "ori_shape",
139
+ ),
140
+ ),
141
+ ]
142
+
143
+ val_pipeline = [
144
+ dict(type="AlbedoResize", height=1024, width=768, test_mode=True),
145
+ dict(
146
+ type="AlbedoPackInputs",
147
+ test_mode=True,
148
+ meta_keys=(
149
+ "img_path",
150
+ "orig_img_height",
151
+ "orig_img_width",
152
+ "img_shape",
153
+ "pad_shape",
154
+ ),
155
+ ),
156
+ ]
157
+
158
+ test_pipeline = [
159
+ dict(type="AlbedoResizePadImage", height=1024, width=768, pad_val=0),
160
+ dict(
161
+ type="AlbedoPackInputs",
162
+ meta_keys=(
163
+ "img_path",
164
+ "orig_img_height",
165
+ "orig_img_width",
166
+ "padding_size",
167
+ ),
168
+ ),
169
+ ]
170
+
171
+
172
+ render_people_dataset = dict(
173
+ type="AlbedoRenderPeopleDataset",
174
+ data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_albedo",
175
+ )
176
+
177
+ train_datasets = [render_people_dataset]
178
+
179
+ train_dataloader = dict(
180
+ batch_size=1,
181
+ num_workers=4,
182
+ persistent_workers=True,
183
+ shuffle=True,
184
+ dataset=dict(
185
+ type="CombinedDataset", datasets=train_datasets, pipeline=train_pipeline
186
+ ),
187
+ )
188
+
189
+ val_dataloader = dict(
190
+ batch_size=4,
191
+ num_workers=4,
192
+ persistent_workers=True,
193
+ multiprocessing_context="spawn",
194
+ # num_workers=0, # debug
195
+ # persistent_workers=False, # debug
196
+ shuffle=False,
197
+ dataset=dict(
198
+ type="AlbedoRenderPeopleDataset",
199
+ test_mode=True,
200
+ data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_albedo_test",
201
+ pipeline=val_pipeline,
202
+ ),
203
+ )
204
+
205
+ val_cfg = dict(
206
+ val_interval=val_every_iters,
207
+ evaluator=dict(
208
+ type="AlbedoEvaluator",
209
+ ),
210
+ )
211
+
212
+ data_preprocessor = dict(
213
+ type="ImagePreprocessor",
214
+ mean=[123.675, 116.28, 103.53],
215
+ std=[58.395, 57.12, 57.375],
216
+ bgr_to_rgb=True, ## convert from bgr to rgb for pretrained models
217
+ )
218
+
219
+ ##-----------------------------------------------------------------
220
+ model = dict(
221
+ type="AlbedoEstimator",
222
+ backbone=dict(
223
+ type="Sapiens2",
224
+ arch=model_name,
225
+ img_size=image_size,
226
+ patch_size=patch_size,
227
+ final_norm=True,
228
+ use_tokenizer=False,
229
+ with_cls_token=True,
230
+ out_type="featmap",
231
+ init_cfg=dict(type="Pretrained", checkpoint=pretrained_checkpoint),
232
+ ),
233
+ decode_head=dict(
234
+ type="AlbedoHead",
235
+ in_channels=embed_dim,
236
+ upsample_channels=[768, 512, 256, 128], ## 1K resolution
237
+ conv_out_channels=[64, 32, 16],
238
+ conv_kernel_sizes=[3, 3, 3],
239
+ loss_decode=[
240
+ dict(type="L1Loss", loss_weight=2.0),
241
+ dict(type="AlbedoGradL1Loss", loss_weight=2.0),
242
+ # dict(type="AlbedoLowFreqL1Loss", down_sample=32, loss_weight=1.0),
243
+ dict(type="AlbedoChromaticityL1Loss", loss_weight=1.0),
244
+ ],
245
+ ),
246
+ )
247
+
248
+
249
+ ##-----------------------------------------------------------------
250
+ optimizer = dict(
251
+ type="AdamW",
252
+ lr=5e-4,
253
+ betas=(0.9, 0.999),
254
+ weight_decay=0.1,
255
+ paramwise_cfg=dict(
256
+ num_layers=num_layers,
257
+ layer_decay_rate=layer_decay_rate,
258
+ ),
259
+ fused=True,
260
+ )
261
+
262
+ scheduler = dict(
263
+ type="SequentialLR",
264
+ milestones=[warmup_iters],
265
+ schedulers=[
266
+ dict(type="LinearLR", start_factor=1e-3, total_iters=warmup_iters),
267
+ dict(
268
+ type="PolynomialLR",
269
+ total_iters=num_iters - warmup_iters,
270
+ power=1.0,
271
+ ),
272
+ ],
273
+ )
274
+
275
+ clip_grad = dict(mode="norm", max_norm=4.0, norm_type=2.0)
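The paramwise_cfg above implies layer-wise learning-rate decay: each transformer block's LR is scaled down by layer_decay_rate once per layer of depth below the head. A small sketch of that scaling rule under the usual LLRD indexing convention; the repository's actual parameter-grouping code is not part of this diff:

num_layers, layer_decay_rate, base_lr = 32, 0.85, 5e-4

def lr_for_layer(layer_id: int) -> float:
    # assumption: layer_id 0 is the patch embedding, layer_id num_layers the head
    return base_lr * layer_decay_rate ** (num_layers - layer_id)

print(f"stem lr: {lr_for_layer(0):.2e}, top block lr: {lr_for_layer(num_layers):.2e}")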
sapiens/dense/configs/albedo/render_people/sapiens2_1b_albedo_render_people-1024x768.py ADDED
@@ -0,0 +1,274 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+
+_CHECKPOINT_ROOT = os.path.expanduser(
+    os.environ.get("SAPIENS_CHECKPOINT_ROOT", "~/sapiens2_host")
+)
+_DATA_ROOT = os.path.expanduser(os.environ.get("DATA_ROOT", "~/sapiens_data"))
+
+warmup_iters = 500
+num_iters = 4e4  ## 32 nodes, 8 gpus: 256 gpus. bs: 3, global bs: 768. num samples: 1e6. 1e6/768 = 1302. 1 epoch = 1e3 iters.
+
+# ------------------------------------------------------------------------------
+vis_every_iters = 100
+log_every_iters = 10
+save_every_iters = 1000
+val_every_iters = 1000
+
+# # # debug
+# vis_every_iters = 1
+# log_every_iters = 1
+# val_every_iters = 10
+# save_every_iters = 1000
+
+load_from = None
+resume = False
+
+# ------------------------------------------------------------------
+model_name = "sapiens2_1b"
+embed_dim = 1536
+num_layers = 40
+num_heads = 24
+layer_decay_rate = 0.9
+
+pretrained_checkpoint = f"{_CHECKPOINT_ROOT}/pretrain/sapiens2_1b_pretrain.safetensors"
+
+##-----------------------------------------------------------------
+image_size = (1024, 768)  ## height x width
+patch_size = 16
+
+# ------------------------------------------------------------------
+use_fsdp = True
+# use_fsdp = False
+
+use_compile = True
+# use_compile = False
+
+## DDP config
+if use_fsdp is False:
+    accelerator_cfg = dict(
+        type="DDP",
+        log_with="tensorboard",
+        # find_unused_parameters=True,
+        gradient_accumulation_steps=1,  # only accumulation=1 is supported. Otherwise, the LR scheduler will be off.
+        max_interval=num_iters,
+        mixed_precision="bf16",  # Options: 'no', 'fp16', 'bf16' or 'fp8'.
+        step_scheduler_with_optimizer=False,  ## schedule independent of n_gpus
+    )
+
+else:
+    accelerator_cfg = dict(
+        type="FSDP",
+        log_with="tensorboard",
+        gradient_accumulation_steps=1,  # only accumulation=1 is supported. Otherwise, the LR scheduler will be off.
+        max_interval=num_iters,
+        mixed_precision="bf16",  # Options: 'no', 'fp16', 'bf16' or 'fp8'.
+        step_scheduler_with_optimizer=False,
+        fsdp_cfg=dict(
+            fsdp_version=2,  # DTensor-based engine
+            state_dict_type="SHARDED_STATE_DICT",  # SHARDED_STATE_DICT | FULL_STATE_DICT
+            # state_dict_type="FULL_STATE_DICT",  # TODO: resume from this is not working
+            mixed_precision=dict(
+                param_dtype="bf16",
+                reduce_dtype="bf16",
+            ),
+            cpu_ram_efficient_loading=False,
+        ),
+    )
+
+if use_compile:
+    accelerator_cfg["compile_cfg"] = dict(
+        backend="inductor",
+        mode="default",  # Options: "default", "reduce-overhead", "max-autotune"
+        fullgraph=False,
+        dynamic=False,
+    )
+
+# ------------------------------------------------------------------
+randomness = dict(seed=0, deterministic=False, diff_rank_seed=True)
+logger = dict(
+    type="Logger",
+    log_interval=log_every_iters,
+)
+checkpoint = dict(
+    type="Checkpointer",
+    save_interval=save_every_iters,
+)
+
+visualizer = dict(
+    type="AlbedoVisualizer",
+    vis_interval=vis_every_iters,
+    vis_max_samples=4,
+    vis_image_width=384,
+    vis_image_height=512,
+)
+
+
+##-----------------------------------------------------------------
+train_pipeline = [
+    dict(
+        type="AlbedoRandomScale",
+        scale_min=0.5,
+        scale_max=2.0,
+        prob=0.3,
+    ),
+    dict(
+        type="AlbedoRandomCropContinuous",
+        ar_range=(0.5, 2.0),
+        area_range=(0.4, 1.0),
+        num_attempts=8,
+        prob=0.3,
+    ),
+    dict(
+        type="AlbedoRandomFlip",
+        prob=0.3,
+    ),
+    dict(type="AlbedoResize", height=1024, width=768),
+    dict(type="RandomGaussianNoise", prob=0.2, var_range=(5.0, 20.0)),
+    dict(
+        type="AlbedoPackInputs",
+        meta_keys=(
+            "img_path",
+            "ori_shape",
+        ),
+    ),
+]
+
+val_pipeline = [
+    dict(type="AlbedoResize", height=1024, width=768, test_mode=True),
+    dict(
+        type="AlbedoPackInputs",
+        test_mode=True,
+        meta_keys=(
+            "img_path",
+            "orig_img_height",
+            "orig_img_width",
+            "img_shape",
+            "pad_shape",
+        ),
+    ),
+]
+
+test_pipeline = [
+    dict(type="AlbedoResizePadImage", height=1024, width=768, pad_val=0),
+    dict(
+        type="AlbedoPackInputs",
+        meta_keys=(
+            "img_path",
+            "orig_img_height",
+            "orig_img_width",
+            "padding_size",
+        ),
+    ),
+]
+
+
+render_people_dataset = dict(
+    type="AlbedoRenderPeopleDataset",
+    data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_albedo",
+)
+
+train_datasets = [render_people_dataset]
+
+train_dataloader = dict(
+    batch_size=1,
+    num_workers=4,
+    persistent_workers=True,
+    shuffle=True,
+    dataset=dict(
+        type="CombinedDataset", datasets=train_datasets, pipeline=train_pipeline
+    ),
+)
+
+val_dataloader = dict(
+    batch_size=4,
+    num_workers=4,
+    persistent_workers=True,
+    multiprocessing_context="spawn",
+    # num_workers=0,  # debug
+    # persistent_workers=False,  # debug
+    shuffle=False,
+    dataset=dict(
+        type="AlbedoRenderPeopleDataset",
+        test_mode=True,
+        data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_albedo_test",
+        pipeline=val_pipeline,
+    ),
+)
+
+val_cfg = dict(
+    val_interval=val_every_iters,
+    evaluator=dict(
+        type="AlbedoEvaluator",
+    ),
+)
+
+data_preprocessor = dict(
+    type="ImagePreprocessor",
+    mean=[123.675, 116.28, 103.53],
+    std=[58.395, 57.12, 57.375],
+    bgr_to_rgb=True,  ## convert from bgr to rgb for pretrained models
+)
+
+##-----------------------------------------------------------------
+model = dict(
+    type="AlbedoEstimator",
+    backbone=dict(
+        type="Sapiens2",
+        arch=model_name,
+        img_size=image_size,
+        patch_size=patch_size,
+        final_norm=True,
+        use_tokenizer=False,
+        with_cls_token=True,
+        out_type="featmap",
+        init_cfg=dict(type="Pretrained", checkpoint=pretrained_checkpoint),
+    ),
+    decode_head=dict(
+        type="AlbedoHead",
+        in_channels=embed_dim,
+        upsample_channels=[768, 512, 256, 128],  ## 1K resolution
+        conv_out_channels=[64, 32, 16],
+        conv_kernel_sizes=[3, 3, 3],
+        loss_decode=[
+            dict(type="L1Loss", loss_weight=2.0),
+            dict(type="AlbedoGradL1Loss", loss_weight=2.0),
+            # dict(type="AlbedoLowFreqL1Loss", down_sample=32, loss_weight=1.0),
+            dict(type="AlbedoChromaticityL1Loss", loss_weight=1.0),
+        ],
+    ),
+)
+
+
+##-----------------------------------------------------------------
+optimizer = dict(
+    type="AdamW",
+    lr=5e-4,
+    betas=(0.9, 0.999),
+    weight_decay=0.1,
+    paramwise_cfg=dict(
+        num_layers=num_layers,
+        layer_decay_rate=layer_decay_rate,
+    ),
+    fused=True,
+)
+
+scheduler = dict(
+    type="SequentialLR",
+    milestones=[warmup_iters],
+    schedulers=[
+        dict(type="LinearLR", start_factor=1e-3, total_iters=warmup_iters),
+        dict(
+            type="PolynomialLR",
+            total_iters=num_iters - warmup_iters,
+            power=1.0,
+        ),
+    ],
+)
+
+clip_grad = dict(mode="norm", max_norm=4.0, norm_type=2.0)
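The clip_grad dict (mode="norm") most plausibly corresponds to global gradient-norm clipping; a hedged sketch of the equivalent raw PyTorch call, since how the trainer actually consumes this dict is not visible in this diff:

import torch

module = torch.nn.Linear(4, 4)
module(torch.randn(2, 4)).sum().backward()
# rescales all gradients so their combined L2 norm is at most max_norm
torch.nn.utils.clip_grad_norm_(module.parameters(), max_norm=4.0, norm_type=2.0)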
sapiens/dense/configs/albedo/render_people/sapiens2_5b_albedo_render_people-1024x768.py ADDED
@@ -0,0 +1,280 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+
+_CHECKPOINT_ROOT = os.path.expanduser(
+    os.environ.get("SAPIENS_CHECKPOINT_ROOT", "~/sapiens2_host")
+)
+_DATA_ROOT = os.path.expanduser(os.environ.get("DATA_ROOT", "~/sapiens_data"))
+
+warmup_iters = 500
+num_iters = 4e4  ## 32 nodes, 8 gpus: 256 gpus. bs: 3, global bs: 768. num samples: 1e6. 1e6/768 = 1302. 1 epoch = 1e3 iters.
+
+# ------------------------------------------------------------------------------
+vis_every_iters = 100
+log_every_iters = 10
+save_every_iters = 1000
+# val_every_iters = 1000
+# val_every_iters = 20000
+val_every_iters = 40000
+
+# # debug
+# vis_every_iters = 1
+# log_every_iters = 1
+# val_every_iters = 10
+# save_every_iters = 1000
+
+load_from = None
+resume = False
+
+# ------------------------------------------------------------------
+model_name = "sapiens2_5b"
+embed_dim = 2432
+num_layers = 56
+num_heads = 32
+layer_decay_rate = 0.94
+
+pretrained_checkpoint = f"{_CHECKPOINT_ROOT}/pretrain/sapiens2_5b_pretrain.safetensors"
+
+##-----------------------------------------------------------------
+image_size = (1024, 768)  ## height x width
+patch_size = 16
+
+# ------------------------------------------------------------------
+use_fsdp = True
+# use_fsdp = False
+
+use_compile = True
+# use_compile = False
+
+## DDP config
+if use_fsdp is False:
+    accelerator_cfg = dict(
+        type="DDP",
+        log_with="tensorboard",
+        # find_unused_parameters=True,
+        gradient_accumulation_steps=1,  # only accumulation=1 is supported. Otherwise, the LR scheduler will be off.
+        max_interval=num_iters,
+        mixed_precision="bf16",  # Options: 'no', 'fp16', 'bf16' or 'fp8'.
+        step_scheduler_with_optimizer=False,  ## schedule independent of n_gpus
+    )
+
+else:
+    accelerator_cfg = dict(
+        type="FSDP",
+        log_with="tensorboard",
+        gradient_accumulation_steps=1,  # only accumulation=1 is supported. Otherwise, the LR scheduler will be off.
+        max_interval=num_iters,
+        mixed_precision="bf16",  # Options: 'no', 'fp16', 'bf16' or 'fp8'.
+        step_scheduler_with_optimizer=False,
+        fsdp_cfg=dict(
+            fsdp_version=2,  # DTensor-based engine
+            state_dict_type="SHARDED_STATE_DICT",  # SHARDED_STATE_DICT | FULL_STATE_DICT
+            # state_dict_type="FULL_STATE_DICT",  # TODO: resume from this is not working
+            mixed_precision=dict(
+                param_dtype="bf16",
+                reduce_dtype="bf16",
+            ),
+            cpu_ram_efficient_loading=False,
+        ),
+        # parallelism_cfg=dict(
+        #     dp_shard_size=2,  # Fully Sharded Data Parallel degree
+        #     dp_replicate_size=1,  # Data Parallel degree
+        #     tp_size=1,  # Tensor Parallel degree
+        #     cp_size=4,  # Context Parallel degree
+        # ),
+    )
+
+if use_compile:
+    accelerator_cfg["compile_cfg"] = dict(
+        backend="inductor",
+        mode="default",  # Options: "default", "reduce-overhead", "max-autotune"
+        fullgraph=False,
+        dynamic=False,
+    )
+
+# ------------------------------------------------------------------
+randomness = dict(seed=0, deterministic=False, diff_rank_seed=True)
+logger = dict(
+    type="Logger",
+    log_interval=log_every_iters,
+)
+checkpoint = dict(
+    type="Checkpointer",
+    save_interval=save_every_iters,
+)
+
+visualizer = dict(
+    type="AlbedoVisualizer",
+    vis_interval=vis_every_iters,
+    vis_max_samples=4,
+    vis_image_width=384,
+    vis_image_height=512,
+)
+
+
+##-----------------------------------------------------------------
+train_pipeline = [
+    dict(
+        type="AlbedoRandomScale",
+        scale_min=0.5,
+        scale_max=2.0,
+        prob=0.3,
+    ),
+    dict(
+        type="AlbedoRandomCropContinuous",
+        ar_range=(0.5, 2.0),
+        area_range=(0.4, 1.0),
+        num_attempts=8,
+        prob=0.3,
+    ),
+    dict(
+        type="AlbedoRandomFlip",
+        prob=0.3,
+    ),
+    dict(type="AlbedoResize", height=1024, width=768),
+    dict(type="RandomGaussianNoise", prob=0.2, var_range=(5.0, 20.0)),
+    dict(
+        type="AlbedoPackInputs",
+        meta_keys=(
+            "img_path",
+            "ori_shape",
+        ),
+    ),
+]
+
+val_pipeline = [
+    dict(type="AlbedoResize", height=1024, width=768, test_mode=True),
+    dict(
+        type="AlbedoPackInputs",
+        test_mode=True,
+        meta_keys=(
+            "img_path",
+            "orig_img_height",
+            "orig_img_width",
+            "img_shape",
+            "pad_shape",
+        ),
+    ),
+]
+
+test_pipeline = [
+    dict(type="AlbedoResizePadImage", height=1024, width=768, pad_val=0),
+    dict(
+        type="AlbedoPackInputs",
+        meta_keys=(
+            "img_path",
+            "orig_img_height",
+            "orig_img_width",
+            "padding_size",
+        ),
+    ),
+]
+
+
+render_people_dataset = dict(
+    type="AlbedoRenderPeopleDataset",
+    data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_albedo",
+)
+
+train_datasets = [render_people_dataset]
+
+train_dataloader = dict(
+    batch_size=1,
+    num_workers=4,
+    persistent_workers=True,
+    shuffle=True,
+    dataset=dict(
+        type="CombinedDataset", datasets=train_datasets, pipeline=train_pipeline
+    ),
+)
+
+val_dataloader = dict(
+    batch_size=4,
+    num_workers=4,
+    persistent_workers=True,
+    multiprocessing_context="spawn",
+    shuffle=False,
+    dataset=dict(
+        type="AlbedoRenderPeopleDataset",
+        test_mode=True,
+        data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_albedo_test",
+        pipeline=val_pipeline,
+    ),
+)
+
+val_cfg = dict(
+    val_interval=val_every_iters,
+    evaluator=dict(
+        type="AlbedoEvaluator",
+    ),
+)
+
+data_preprocessor = dict(
+    type="ImagePreprocessor",
+    mean=[123.675, 116.28, 103.53],
+    std=[58.395, 57.12, 57.375],
+    bgr_to_rgb=True,  ## convert from bgr to rgb for pretrained models
+)
+
+##-----------------------------------------------------------------
+model = dict(
+    type="AlbedoEstimator",
+    backbone=dict(
+        type="Sapiens2",
+        arch=model_name,
+        img_size=image_size,
+        patch_size=patch_size,
+        final_norm=True,
+        use_tokenizer=False,
+        with_cls_token=True,
+        out_type="featmap",
+        init_cfg=dict(type="Pretrained", checkpoint=pretrained_checkpoint),
+    ),
+    decode_head=dict(
+        type="AlbedoHead",
+        in_channels=embed_dim,
+        upsample_channels=[1536, 768, 512, 256],  ## 1K resolution
+        conv_out_channels=[64, 32, 16],
+        conv_kernel_sizes=[3, 3, 3],
+        loss_decode=[
+            dict(type="L1Loss", loss_weight=2.0),
+            dict(type="AlbedoGradL1Loss", loss_weight=2.0),
+            # dict(type="AlbedoLowFreqL1Loss", down_sample=32, loss_weight=1.0),
+            dict(type="AlbedoChromaticityL1Loss", loss_weight=1.0),
+        ],
+    ),
+)
+
+
+##-----------------------------------------------------------------
+optimizer = dict(
+    type="AdamW",
+    lr=5e-4,
+    betas=(0.9, 0.999),
+    weight_decay=0.1,
+    paramwise_cfg=dict(
+        num_layers=num_layers,
+        layer_decay_rate=layer_decay_rate,
+    ),
+    fused=True,
+)
+
+scheduler = dict(
+    type="SequentialLR",
+    milestones=[warmup_iters],
+    schedulers=[
+        dict(type="LinearLR", start_factor=1e-3, total_iters=warmup_iters),
+        dict(
+            type="PolynomialLR",
+            total_iters=num_iters - warmup_iters,
+            power=1.0,
+        ),
+    ],
+)
+
+clip_grad = dict(mode="norm", max_norm=4.0, norm_type=2.0)
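The data_preprocessor block above uses the standard ImageNet statistics on 0-255 pixel values and flips BGR to RGB. A minimal NumPy sketch of that transform, assuming the ImagePreprocessor class (not included in this diff) does nothing more exotic:

import numpy as np

MEAN = np.array([123.675, 116.28, 103.53], dtype=np.float32)
STD = np.array([58.395, 57.12, 57.375], dtype=np.float32)

def preprocess(img_bgr: np.ndarray) -> np.ndarray:
    # OpenCV loads images as BGR; reverse the channel axis to get RGB
    img_rgb = img_bgr[..., ::-1].astype(np.float32)
    return (img_rgb - MEAN) / STD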
sapiens/dense/configs/normal/metasim_render_people/sapiens2_0.4b_normal_metasim_render_people-1024x768.py ADDED
@@ -0,0 +1,304 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+
+_CHECKPOINT_ROOT = os.path.expanduser(
+    os.environ.get("SAPIENS_CHECKPOINT_ROOT", "~/sapiens2_host")
+)
+_DATA_ROOT = os.path.expanduser(os.environ.get("DATA_ROOT", "~/sapiens_data"))
+
+warmup_iters = 500
+num_iters = 2e4
+
+# ------------------------------------------------------------------------------
+vis_every_iters = 100
+log_every_iters = 10
+save_every_iters = 1000
+val_every_iters = 1000
+
+# # debug
+# vis_every_iters = 1
+# log_every_iters = 1
+# val_every_iters = 2
+# save_every_iters = 1000
+
+load_from = None
+resume = False
+
+# ------------------------------------------------------------------
+model_name = "sapiens2_0.4b"
+embed_dim = 1024
+num_layers = 24
+num_heads = 16
+
+layer_decay_rate = 0.8
+pretrained_checkpoint = f"{_CHECKPOINT_ROOT}/pretrain/sapiens2_0.4b_pretrain.safetensors"
+
+##-----------------------------------------------------------------
+image_size = (1024, 768)  ## height x width
+patch_size = 16
+
+# ------------------------------------------------------------------
+use_fsdp = True
+# use_fsdp = False
+
+use_compile = True
+# use_compile = False
+
+## DDP config
+if use_fsdp is False:
+    accelerator_cfg = dict(
+        type="DDP",
+        log_with="tensorboard",
+        # find_unused_parameters=True,
+        gradient_accumulation_steps=1,  # only accumulation=1 is supported. Otherwise, the LR scheduler will be off.
+        max_interval=num_iters,
+        # mixed_precision="bf16",  # Options: 'no', 'fp16', 'bf16' or 'fp8'.
+        step_scheduler_with_optimizer=False,  ## schedule independent of n_gpus
+    )
+
+else:
+    accelerator_cfg = dict(
+        type="FSDP",
+        log_with="tensorboard",
+        gradient_accumulation_steps=1,  # only accumulation=1 is supported. Otherwise, the LR scheduler will be off.
+        max_interval=num_iters,
+        mixed_precision="bf16",  # Options: 'no', 'fp16', 'bf16' or 'fp8'.
+        step_scheduler_with_optimizer=False,
+        fsdp_cfg=dict(
+            fsdp_version=2,  # DTensor-based engine
+            state_dict_type="SHARDED_STATE_DICT",  # SHARDED_STATE_DICT | FULL_STATE_DICT
+            # state_dict_type="FULL_STATE_DICT",  # TODO: resume from this is not working
+            mixed_precision=dict(
+                param_dtype="bf16",
+                reduce_dtype="bf16",
+            ),
+            cpu_ram_efficient_loading=False,
+        ),
+    )
+
+if use_compile:
+    accelerator_cfg["compile_cfg"] = dict(
+        backend="inductor",
+        mode="default",  # Options: "default", "reduce-overhead", "max-autotune"
+        fullgraph=False,
+        dynamic=False,
+    )
+
+# ------------------------------------------------------------------
+randomness = dict(seed=0, deterministic=False, diff_rank_seed=True)
+logger = dict(
+    type="Logger",
+    log_interval=log_every_iters,
+)
+checkpoint = dict(
+    type="Checkpointer",
+    save_interval=save_every_iters,
+)
+
+visualizer = dict(
+    type="NormalVisualizer",
+    vis_interval=vis_every_iters,
+    vis_max_samples=8,
+    vis_image_width=384,
+    vis_image_height=512,
+)
+
+
+##-----------------------------------------------------------------
+train_pipeline = [
+    dict(type="PhotoMetricDistortion"),
+    dict(type="RandomDownUpSampleImage", scale_range=(0.1, 0.7), prob=0.2),
+    dict(
+        type="NormalRandomScale",
+        scale_min=0.5,
+        scale_max=2.0,
+        prob=0.3,
+    ),
+    dict(
+        type="NormalRandomCropContinuous",
+        ar_range=(0.5, 2.0),
+        area_range=(0.4, 1.0),
+        num_attempts=8,
+        prob=0.3,
+    ),
+    dict(
+        type="NormalRandomFlip",
+        prob=0.3,
+    ),
+    dict(type="NormalResize", height=1024, width=768),
+    dict(
+        type="RandomGaussianBlur", prob=0.3, kernel_size=(3, 3), sigma_range=(0.1, 2.0)
+    ),
+    dict(type="RandomGaussianNoise", prob=0.3, var_range=(5.0, 20.0)),
+    dict(type="RandomSolarize", prob=0.3, threshold=128),
+    dict(type="NormalGenerateTarget"),
+    dict(
+        type="NormalPackInputs",
+        meta_keys=(
+            "img_path",
+            "ori_shape",
+        ),
+    ),
+]
+
+val_pipeline = [
+    dict(type="NormalResize", height=1024, width=768, test_mode=True),
+    dict(
+        type="NormalPackInputs",
+        test_mode=True,
+        meta_keys=(
+            "img_path",
+            "orig_img_height",
+            "orig_img_width",
+            "img_shape",
+            "pad_shape",
+        ),
+    ),
+]
+
+test_pipeline = [
+    dict(type="NormalResizePadImage", height=1024, width=768, pad_val=0),
+    dict(
+        type="NormalPackInputs",
+        meta_keys=(
+            "img_path",
+            "orig_img_height",
+            "orig_img_width",
+            "padding_size",
+        ),
+    ),
+]
+
+metasim_dataset = dict(
+    type="NormalMetaSimDataset",
+    airstore_template="airstore://codec_avatar_sapiens_metasim_v1_no_user_data",
+    json_path=f"{_DATA_ROOT}/seg/data/metasim/meta_data_v1.json",
+)
+
+render_people_dataset = dict(
+    type="NormalRenderPeopleBodyDataset",  ## body only
+    data_root=f"{_DATA_ROOT}/synthetic",
+    seg_data_root=f"{_DATA_ROOT}/RenderPeople/part_seg",
+)
+
+multihuman_render_people_dataset = dict(
+    type="NormalRenderPeopleMultihumanDataset",
+    data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_multi_human",
+    normal_extension=".npz",
+    seg_data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_multi_human/part_seg",  ## supervise on face for multihuman
+)
+
+# train_datasets = 2 * [metasim_dataset] + [
+#     render_people_dataset,
+#     multihuman_render_people_dataset,
+# ]
+
+# train_datasets = [render_people_dataset]
+# train_datasets = [multihuman_render_people_dataset]
+train_datasets = [metasim_dataset]
+
+train_dataloader = dict(
+    batch_size=1,
+    num_workers=4,
+    persistent_workers=True,
+    shuffle=True,
+    dataset=dict(
+        type="CombinedDataset", datasets=train_datasets, pipeline=train_pipeline
+    ),
+)
+
+val_dataloader = dict(
+    batch_size=4,
+    num_workers=4,
+    persistent_workers=True,
+    multiprocessing_context="spawn",
+    # num_workers=0,  # debug
+    # persistent_workers=False,  # debug
+    shuffle=False,
+    dataset=dict(
+        type="NormalRenderPeopleBodyDataset",  ## body only
+        # num_samples=100,  ## debug: only use N samples for validation
+        test_mode=True,
+        data_root=f"{_DATA_ROOT}/seg/data/metasim/evaluation",
+        pipeline=val_pipeline,
+    ),
+)
+
+val_cfg = dict(
+    val_interval=val_every_iters,
+    evaluator=dict(
+        type="NormalEvaluator",
+    ),
+)
+
+data_preprocessor = dict(
+    type="ImagePreprocessor",
+    mean=[123.675, 116.28, 103.53],
+    std=[58.395, 57.12, 57.375],
+    bgr_to_rgb=True,  ## convert from bgr to rgb for pretrained models
+)
+
+##-----------------------------------------------------------------
+model = dict(
+    type="NormalEstimator",
+    backbone=dict(
+        type="Sapiens2",
+        arch=model_name,
+        img_size=image_size,
+        patch_size=patch_size,
+        final_norm=True,
+        use_tokenizer=False,
+        with_cls_token=True,
+        out_type="featmap",
+        init_cfg=dict(type="Pretrained", checkpoint=pretrained_checkpoint),
+    ),
+    decode_head=dict(
+        type="NormalHead",
+        in_channels=embed_dim,
+        upsample_channels=[768, 512, 256, 128],  ## 1K resolution
+        conv_out_channels=[64, 32, 16],
+        conv_kernel_sizes=[3, 3, 3],
+        loss_decode=[
+            dict(
+                type="NormalCosineSimilarityLoss",
+                loss_weight=10.0,
+            ),
+            dict(type="L1Loss", loss_weight=1.0),
+            dict(type="NormalGradL1Loss", loss_weight=10.0),
+        ],
+    ),
+)
+
+
+##-----------------------------------------------------------------
+optimizer = dict(
+    type="AdamW",
+    lr=5e-4,
+    betas=(0.9, 0.999),
+    weight_decay=0.1,
+    paramwise_cfg=dict(
+        num_layers=num_layers,
+        layer_decay_rate=layer_decay_rate,
+    ),
+    fused=True,
+)
+
+scheduler = dict(
+    type="SequentialLR",
+    milestones=[warmup_iters],
+    schedulers=[
+        dict(type="LinearLR", start_factor=1e-3, total_iters=warmup_iters),
+        dict(
+            type="PolynomialLR",
+            total_iters=num_iters - warmup_iters,
+            power=1.0,
+        ),
+    ],
+)
+
+clip_grad = dict(mode="norm", max_norm=2.0, norm_type=2.0)
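Because _CHECKPOINT_ROOT and _DATA_ROOT are read from the environment at import time, these configs can be pointed at other checkpoint and dataset locations without editing them. The paths below are placeholders, not locations anywhere in this repo:

import os

os.environ["SAPIENS_CHECKPOINT_ROOT"] = "/mnt/checkpoints/sapiens2"  # placeholder
os.environ["DATA_ROOT"] = "/mnt/data/sapiens"  # placeholder
# any config imported after this point resolves its paths under these roots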
sapiens/dense/configs/normal/metasim_render_people/sapiens2_0.8b_normal_metasim_render_people-1024x768.py ADDED
@@ -0,0 +1,304 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+
+_CHECKPOINT_ROOT = os.path.expanduser(
+    os.environ.get("SAPIENS_CHECKPOINT_ROOT", "~/sapiens2_host")
+)
+_DATA_ROOT = os.path.expanduser(os.environ.get("DATA_ROOT", "~/sapiens_data"))
+
+warmup_iters = 500
+num_iters = 1e4
+
+# ------------------------------------------------------------------------------
+vis_every_iters = 100
+log_every_iters = 10
+save_every_iters = 1000
+val_every_iters = 1000
+
+# # debug
+# vis_every_iters = 1
+# log_every_iters = 1
+# val_every_iters = 2
+# save_every_iters = 1000
+
+load_from = None
+resume = False
+
+# ------------------------------------------------------------------
+model_name = "sapiens2_0.8b"
+embed_dim = 1280
+num_layers = 32
+num_heads = 16
+
+layer_decay_rate = 0.85
+pretrained_checkpoint = f"{_CHECKPOINT_ROOT}/pretrain/sapiens2_0.8b_pretrain.safetensors"
+
+##-----------------------------------------------------------------
+image_size = (1024, 768)  ## height x width
+patch_size = 16
+
+# ------------------------------------------------------------------
+use_fsdp = True
+# use_fsdp = False
+
+use_compile = True
+# use_compile = False
+
+## DDP config
+if use_fsdp is False:
+    accelerator_cfg = dict(
+        type="DDP",
+        log_with="tensorboard",
+        # find_unused_parameters=True,
+        gradient_accumulation_steps=1,  # only accumulation=1 is supported. Otherwise, the LR scheduler will be off.
+        max_interval=num_iters,
+        # mixed_precision="bf16",  # Options: 'no', 'fp16', 'bf16' or 'fp8'.
+        step_scheduler_with_optimizer=False,  ## schedule independent of n_gpus
+    )
+
+else:
+    accelerator_cfg = dict(
+        type="FSDP",
+        log_with="tensorboard",
+        gradient_accumulation_steps=1,  # only accumulation=1 is supported. Otherwise, the LR scheduler will be off.
+        max_interval=num_iters,
+        mixed_precision="bf16",  # Options: 'no', 'fp16', 'bf16' or 'fp8'.
+        step_scheduler_with_optimizer=False,
+        fsdp_cfg=dict(
+            fsdp_version=2,  # DTensor-based engine
+            state_dict_type="SHARDED_STATE_DICT",  # SHARDED_STATE_DICT | FULL_STATE_DICT
+            # state_dict_type="FULL_STATE_DICT",  # TODO: resume from this is not working
+            mixed_precision=dict(
+                param_dtype="bf16",
+                reduce_dtype="bf16",
+            ),
+            cpu_ram_efficient_loading=False,
+        ),
+    )
+
+if use_compile:
+    accelerator_cfg["compile_cfg"] = dict(
+        backend="inductor",
+        mode="default",  # Options: "default", "reduce-overhead", "max-autotune"
+        fullgraph=False,
+        dynamic=False,
+    )
+
+# ------------------------------------------------------------------
+randomness = dict(seed=0, deterministic=False, diff_rank_seed=True)
+logger = dict(
+    type="Logger",
+    log_interval=log_every_iters,
+)
+checkpoint = dict(
+    type="Checkpointer",
+    save_interval=save_every_iters,
+)
+
+visualizer = dict(
+    type="NormalVisualizer",
+    vis_interval=vis_every_iters,
+    vis_max_samples=8,
+    vis_image_width=384,
+    vis_image_height=512,
+)
+
+
+##-----------------------------------------------------------------
+train_pipeline = [
+    dict(type="PhotoMetricDistortion"),
+    dict(type="RandomDownUpSampleImage", scale_range=(0.1, 0.7), prob=0.2),
+    dict(
+        type="NormalRandomScale",
+        scale_min=0.5,
+        scale_max=2.0,
+        prob=0.3,
+    ),
+    dict(
+        type="NormalRandomCropContinuous",
+        ar_range=(0.5, 2.0),
+        area_range=(0.4, 1.0),
+        num_attempts=8,
+        prob=0.3,
+    ),
+    dict(
+        type="NormalRandomFlip",
+        prob=0.3,
+    ),
+    dict(type="NormalResize", height=1024, width=768),
+    dict(
+        type="RandomGaussianBlur", prob=0.3, kernel_size=(3, 3), sigma_range=(0.1, 2.0)
+    ),
+    dict(type="RandomGaussianNoise", prob=0.3, var_range=(5.0, 20.0)),
+    dict(type="RandomSolarize", prob=0.3, threshold=128),
+    dict(type="NormalGenerateTarget"),
+    dict(
+        type="NormalPackInputs",
+        meta_keys=(
+            "img_path",
+            "ori_shape",
+        ),
+    ),
+]
+
+val_pipeline = [
+    dict(type="NormalResize", height=1024, width=768, test_mode=True),
+    dict(
+        type="NormalPackInputs",
+        test_mode=True,
+        meta_keys=(
+            "img_path",
+            "orig_img_height",
+            "orig_img_width",
+            "img_shape",
+            "pad_shape",
+        ),
+    ),
+]
+
+test_pipeline = [
+    dict(type="NormalResizePadImage", height=1024, width=768, pad_val=0),
+    dict(
+        type="NormalPackInputs",
+        meta_keys=(
+            "img_path",
+            "orig_img_height",
+            "orig_img_width",
+            "padding_size",
+        ),
+    ),
+]
+
+metasim_dataset = dict(
+    type="NormalMetaSimDataset",
+    airstore_template="airstore://codec_avatar_sapiens_metasim_v1_no_user_data",
+    json_path=f"{_DATA_ROOT}/seg/data/metasim/meta_data_v1.json",
+)
+
+render_people_dataset = dict(
+    type="NormalRenderPeopleBodyDataset",  ## body only
+    data_root=f"{_DATA_ROOT}/synthetic",
+    seg_data_root=f"{_DATA_ROOT}/RenderPeople/part_seg",
+)
+
+multihuman_render_people_dataset = dict(
+    type="NormalRenderPeopleMultihumanDataset",
+    data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_multi_human",
+    normal_extension=".npz",
+    seg_data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_multi_human/part_seg",  ## supervise on face for multihuman
+)
+
+# train_datasets = 2 * [metasim_dataset] + [
+#     render_people_dataset,
+#     multihuman_render_people_dataset,
+# ]
+
+# train_datasets = [render_people_dataset]
+# train_datasets = [multihuman_render_people_dataset]
+train_datasets = [metasim_dataset]
+
+train_dataloader = dict(
+    batch_size=1,
+    num_workers=4,
+    persistent_workers=True,
+    shuffle=True,
+    dataset=dict(
+        type="CombinedDataset", datasets=train_datasets, pipeline=train_pipeline
+    ),
+)
+
+val_dataloader = dict(
+    batch_size=4,
+    num_workers=4,
+    persistent_workers=True,
+    multiprocessing_context="spawn",
+    # num_workers=0,  # debug
+    # persistent_workers=False,  # debug
+    shuffle=False,
+    dataset=dict(
+        type="NormalRenderPeopleBodyDataset",  ## body only
+        # num_samples=100,  ## debug: only use N samples for validation
+        test_mode=True,
+        data_root=f"{_DATA_ROOT}/seg/data/metasim/evaluation",
+        pipeline=val_pipeline,
+    ),
+)
+
+val_cfg = dict(
+    val_interval=val_every_iters,
+    evaluator=dict(
+        type="NormalEvaluator",
+    ),
+)
+
+data_preprocessor = dict(
+    type="ImagePreprocessor",
+    mean=[123.675, 116.28, 103.53],
+    std=[58.395, 57.12, 57.375],
+    bgr_to_rgb=True,  ## convert from bgr to rgb for pretrained models
+)
+
+##-----------------------------------------------------------------
+model = dict(
+    type="NormalEstimator",
+    backbone=dict(
+        type="Sapiens2",
+        arch=model_name,
+        img_size=image_size,
+        patch_size=patch_size,
+        final_norm=True,
+        use_tokenizer=False,
+        with_cls_token=True,
+        out_type="featmap",
+        init_cfg=dict(type="Pretrained", checkpoint=pretrained_checkpoint),
+    ),
+    decode_head=dict(
+        type="NormalHead",
+        in_channels=embed_dim,
+        upsample_channels=[768, 512, 256, 128],  ## 1K resolution
+        conv_out_channels=[64, 32, 16],
+        conv_kernel_sizes=[3, 3, 3],
+        loss_decode=[
+            dict(
+                type="NormalCosineSimilarityLoss",
+                loss_weight=10.0,
+            ),
+            dict(type="L1Loss", loss_weight=1.0),
+            dict(type="NormalGradL1Loss", loss_weight=10.0),
+        ],
+    ),
+)
+
+
+##-----------------------------------------------------------------
+optimizer = dict(
+    type="AdamW",
+    lr=5e-4,
+    betas=(0.9, 0.999),
+    weight_decay=0.1,
+    paramwise_cfg=dict(
+        num_layers=num_layers,
+        layer_decay_rate=layer_decay_rate,
+    ),
+    fused=True,
+)
+
+scheduler = dict(
+    type="SequentialLR",
+    milestones=[warmup_iters],
+    schedulers=[
+        dict(type="LinearLR", start_factor=1e-3, total_iters=warmup_iters),
+        dict(
+            type="PolynomialLR",
+            total_iters=num_iters - warmup_iters,
+            power=1.0,
+        ),
+    ],
+)
+
+clip_grad = dict(mode="norm", max_norm=4.0, norm_type=2.0)
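The dominant loss term in these normal configs, NormalCosineSimilarityLoss with weight 10.0, presumably penalizes the angle between predicted and ground-truth unit normals; a hedged PyTorch sketch of such a loss (the repository's own implementation is not part of this diff):

import torch
import torch.nn.functional as F

def normal_cosine_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # pred, target: (B, 3, H, W) normal maps; cosine over the channel dimension
    cos = F.cosine_similarity(pred, target, dim=1)  # (B, H, W), values in [-1, 1]
    return (1.0 - cos).mean()  # 0 when perfectly aligned, 2 when opposite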
sapiens/dense/configs/normal/metasim_render_people/sapiens2_1b_normal_metasim_render_people-1024x768.py ADDED
@@ -0,0 +1,306 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+
+_CHECKPOINT_ROOT = os.path.expanduser(
+    os.environ.get("SAPIENS_CHECKPOINT_ROOT", "~/sapiens2_host")
+)
+_DATA_ROOT = os.path.expanduser(os.environ.get("DATA_ROOT", "~/sapiens_data"))
+
+warmup_iters = 500
+num_iters = 4e4  ## 32 nodes, 8 gpus: 256 gpus. bs: 3, global bs: 768. num samples: 1e6. 1e6/768 = 1302. 1 epoch = 1e3 iters.
+# num_iters = 1e4  ## light finetune
+
+# ------------------------------------------------------------------------------
+vis_every_iters = 100
+log_every_iters = 10
+save_every_iters = 1000
+val_every_iters = 1000
+
+# # debug
+# vis_every_iters = 1
+# log_every_iters = 1
+# val_every_iters = 2
+# save_every_iters = 1000
+
+load_from = None
+resume = False
+
+# ------------------------------------------------------------------
+model_name = "sapiens2_1b"
+embed_dim = 1536
+num_layers = 40
+num_heads = 24
+layer_decay_rate = 0.9
+
+pretrained_checkpoint = f"{_CHECKPOINT_ROOT}/pretrain/sapiens2_1b_pretrain.safetensors"
+
+##-----------------------------------------------------------------
+image_size = (1024, 768)  ## height x width
+patch_size = 16
+
+# ------------------------------------------------------------------
+use_fsdp = True
+# use_fsdp = False
+
+use_compile = True
+# use_compile = False
+
+## DDP config
+if use_fsdp is False:
+    accelerator_cfg = dict(
+        type="DDP",
+        log_with="tensorboard",
+        # find_unused_parameters=True,
+        gradient_accumulation_steps=1,  # only accumulation=1 is supported. Otherwise, the LR scheduler will be off.
+        max_interval=num_iters,
+        # mixed_precision="bf16",  # Options: 'no', 'fp16', 'bf16' or 'fp8'.
+        step_scheduler_with_optimizer=False,  ## schedule independent of n_gpus
+    )
+
+else:
+    accelerator_cfg = dict(
+        type="FSDP",
+        log_with="tensorboard",
+        gradient_accumulation_steps=1,  # only accumulation=1 is supported. Otherwise, the LR scheduler will be off.
+        max_interval=num_iters,
+        mixed_precision="bf16",  # Options: 'no', 'fp16', 'bf16' or 'fp8'.
+        step_scheduler_with_optimizer=False,
+        fsdp_cfg=dict(
+            fsdp_version=2,  # DTensor-based engine
+            state_dict_type="SHARDED_STATE_DICT",  # SHARDED_STATE_DICT | FULL_STATE_DICT
+            # state_dict_type="FULL_STATE_DICT",  # TODO: resume from this is not working
+            mixed_precision=dict(
+                param_dtype="bf16",
+                reduce_dtype="bf16",
+            ),
+            cpu_ram_efficient_loading=False,
+        ),
+    )
+
+if use_compile:
+    accelerator_cfg["compile_cfg"] = dict(
+        backend="inductor",
+        mode="default",  # Options: "default", "reduce-overhead", "max-autotune"
+        fullgraph=False,
+        dynamic=False,
+    )
+
+# ------------------------------------------------------------------
+randomness = dict(seed=0, deterministic=False, diff_rank_seed=True)
+logger = dict(
+    type="Logger",
+    log_interval=log_every_iters,
+)
+checkpoint = dict(
+    type="Checkpointer",
+    save_interval=save_every_iters,
+)
+
+visualizer = dict(
+    type="NormalVisualizer",
+    vis_interval=vis_every_iters,
+    vis_max_samples=4,
+    vis_image_width=384,
+    vis_image_height=512,
+)
+
+
+##-----------------------------------------------------------------
+train_pipeline = [
+    dict(type="PhotoMetricDistortion"),
+    dict(type="RandomDownUpSampleImage", scale_range=(0.1, 0.7), prob=0.2),
+    dict(
+        type="NormalRandomScale",
+        scale_min=0.5,
+        scale_max=2.0,
+        prob=0.3,
+    ),
+    dict(
+        type="NormalRandomCropContinuous",
+        ar_range=(0.5, 2.0),
+        area_range=(0.4, 1.0),
+        num_attempts=8,
+        prob=0.3,
+    ),
+    dict(
+        type="NormalRandomFlip",
+        prob=0.3,
+    ),
+    dict(type="NormalResize", height=1024, width=768),
+    dict(
+        type="RandomGaussianBlur", prob=0.3, kernel_size=(3, 3), sigma_range=(0.1, 2.0)
+    ),
+    dict(type="RandomGaussianNoise", prob=0.3, var_range=(5.0, 20.0)),
+    dict(type="RandomSolarize", prob=0.3, threshold=128),
+    dict(type="NormalGenerateTarget"),
+    dict(
+        type="NormalPackInputs",
+        meta_keys=(
+            "img_path",
+            "ori_shape",
+        ),
+    ),
+]
+
+val_pipeline = [
+    dict(type="NormalResize", height=1024, width=768, test_mode=True),
+    dict(
+        type="NormalPackInputs",
+        test_mode=True,
+        meta_keys=(
+            "img_path",
+            "orig_img_height",
+            "orig_img_width",
+            "img_shape",
+            "pad_shape",
+        ),
+    ),
+]
+
+test_pipeline = [
+    dict(type="NormalResizePadImage", height=1024, width=768, pad_val=0),
+    dict(
+        type="NormalPackInputs",
+        meta_keys=(
+            "img_path",
+            "orig_img_height",
+            "orig_img_width",
+            "padding_size",
+        ),
+    ),
+]
+
+metasim_dataset = dict(
+    type="NormalMetaSimDataset",
+    airstore_template="airstore://codec_avatar_sapiens_metasim_v1_no_user_data",
+    json_path=f"{_DATA_ROOT}/seg/data/metasim/meta_data_v1.json",
+)
+
+render_people_dataset = dict(
+    type="NormalRenderPeopleBodyDataset",  ## body only
+    data_root=f"{_DATA_ROOT}/synthetic",
+    seg_data_root=f"{_DATA_ROOT}/RenderPeople/part_seg",
+)
+
+multihuman_render_people_dataset = dict(
+    type="NormalRenderPeopleMultihumanDataset",
+    data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_multi_human",
+    normal_extension=".npz",
+    seg_data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_multi_human/part_seg",  ## supervise on face for multihuman
+)
+
+# train_datasets = 2 * [metasim_dataset] + [
+#     render_people_dataset,
+#     multihuman_render_people_dataset,
+# ]
+
+# train_datasets = [render_people_dataset]
+# train_datasets = [multihuman_render_people_dataset]
+train_datasets = [metasim_dataset]
+
+train_dataloader = dict(
+    batch_size=1,
+    num_workers=4,
+    persistent_workers=True,
+    shuffle=True,
+    dataset=dict(
+        type="CombinedDataset", datasets=train_datasets, pipeline=train_pipeline
+    ),
+)
+
+val_dataloader = dict(
+    batch_size=4,
+    num_workers=4,
+    persistent_workers=True,
+    multiprocessing_context="spawn",
+    # num_workers=0,  # debug
+    # persistent_workers=False,  # debug
+    shuffle=False,
+    dataset=dict(
+        type="NormalRenderPeopleBodyDataset",  ## body only
+        # num_samples=100,  ## debug: only use N samples for validation
+        test_mode=True,
+        data_root=f"{_DATA_ROOT}/seg/data/metasim/evaluation",
+        pipeline=val_pipeline,
+    ),
+)
+
+val_cfg = dict(
+    val_interval=val_every_iters,
+    evaluator=dict(
+        type="NormalEvaluator",
+    ),
+)
+
+data_preprocessor = dict(
+    type="ImagePreprocessor",
+    mean=[123.675, 116.28, 103.53],
+    std=[58.395, 57.12, 57.375],
+    bgr_to_rgb=True,  ## convert from bgr to rgb for pretrained models
+)
+
+##-----------------------------------------------------------------
+model = dict(
+    type="NormalEstimator",
+    backbone=dict(
+        type="Sapiens2",
+        arch=model_name,
+        img_size=image_size,
+        patch_size=patch_size,
+        final_norm=True,
+        use_tokenizer=False,
+        # with_cls_token=False,
+        with_cls_token=True,
+        out_type="featmap",
+        init_cfg=dict(type="Pretrained", checkpoint=pretrained_checkpoint),
+    ),
+    decode_head=dict(
+        type="NormalHead",
+        in_channels=embed_dim,
+        upsample_channels=[768, 512, 256, 128],  ## 1K resolution
+        conv_out_channels=[64, 32, 16],
+        conv_kernel_sizes=[3, 3, 3],
+        loss_decode=[
+            dict(
+                type="NormalCosineSimilarityLoss",
+                loss_weight=10.0,
+            ),
+            dict(type="L1Loss", loss_weight=1.0),
+            dict(type="NormalGradL1Loss", loss_weight=10.0),
+        ],
+    ),
+)
+
+
+##-----------------------------------------------------------------
+optimizer = dict(
+    type="AdamW",
+    lr=5e-4,
+    betas=(0.9, 0.999),
+    weight_decay=0.1,
+    paramwise_cfg=dict(
+        num_layers=num_layers,
+        layer_decay_rate=layer_decay_rate,
+    ),
+    fused=True,
+)
+
+scheduler = dict(
+    type="SequentialLR",
+    milestones=[warmup_iters],
+    schedulers=[
+        dict(type="LinearLR", start_factor=1e-3, total_iters=warmup_iters),
+        dict(
+            type="PolynomialLR",
+            total_iters=num_iters - warmup_iters,
+            power=1.0,
+        ),
+    ],
+)
+
+clip_grad = dict(mode="norm", max_norm=4.0, norm_type=2.0)
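At image_size = (1024, 768) and patch_size = 16, the Sapiens2 backbone sees a 64 x 48 grid of patch tokens (3072 tokens, plus a CLS token since with_cls_token=True). A quick check of that arithmetic:

image_height, image_width, patch_size = 1024, 768, 16
tokens_h, tokens_w = image_height // patch_size, image_width // patch_size
assert (tokens_h, tokens_w) == (64, 48)
assert tokens_h * tokens_w == 3072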
sapiens/dense/configs/normal/metasim_render_people/sapiens2_5b_normal_metasim_render_people-1024x768.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+
9
+ _CHECKPOINT_ROOT = os.path.expanduser(
10
+ os.environ.get("SAPIENS_CHECKPOINT_ROOT", "~/sapiens2_host")
11
+ )
12
+ _DATA_ROOT = os.path.expanduser(os.environ.get("DATA_ROOT", "~/sapiens_data"))
13
+
14
+ warmup_iters = 500
15
+ num_iters = 4e4 ## 32 nodes, 8 gpus: 256 gpus. bs: 3, global bs: 768. num samples: 1e6. 1e6/768 = 1302. 1 epoch = 1e3 iters.
16
+ # num_iters = 1e4 ## light finetune
17
+
18
+ # ------------------------------------------------------------------------------
19
+ vis_every_iters = 100
20
+ log_every_iters = 10
21
+ save_every_iters = 1000
22
+ val_every_iters = 1000
23
+
24
+ # # debug
25
+ # vis_every_iters = 1
26
+ # log_every_iters = 1
27
+ # val_every_iters = 2
28
+ # save_every_iters = 1000
29
+
30
+ load_from = None
31
+ resume = False
32
+
33
+ # ------------------------------------------------------------------
34
+ model_name = "sapiens2_5b"
35
+ embed_dim = 2432
36
+ num_layers = 56
37
+ num_heads = 32
38
+ layer_decay_rate = 0.94
39
+
40
+ pretrained_checkpoint = f"{_CHECKPOINT_ROOT}/pretrain/sapiens2_5b_pretrain.safetensors"
41
+
42
+ ##-----------------------------------------------------------------
43
+ image_size = (1024, 768) ## height x width
44
+ patch_size = 16
45
+
46
+ # ------------------------------------------------------------------
47
+ use_fsdp = True
48
+ # use_fsdp = False
49
+
50
+ use_compile = True
51
+ # use_compile = False
52
+
53
+ ## DDP config
54
+ if use_fsdp is False:
55
+ accelerator_cfg = dict(
56
+ type="DDP",
57
+ log_with="tensorboard",
58
+ # find_unused_parameters=True,
59
+ gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off.
60
+ max_interval=num_iters,
61
+ # mixed_precision="bf16", # Options: ‘no’,‘fp16’,‘bf16’ or ‘fp8’.
62
+ step_scheduler_with_optimizer=False, ## schedule independent of n_gpus
63
+ )
64
+
65
+ else:
66
+ accelerator_cfg = dict(
67
+ type="FSDP",
68
+ log_with="tensorboard",
69
+ gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off.
70
+ max_interval=num_iters,
71
+ mixed_precision="bf16", # Options: ‘no’,‘fp16’,‘bf16’ or ‘fp8’.
72
+ step_scheduler_with_optimizer=False,
73
+ fsdp_cfg=dict(
74
+ fsdp_version=2, # DTensor-based engine
75
+ state_dict_type="SHARDED_STATE_DICT", # SHARDED_STATE_DICT | FULL_STATE_DICT
76
+ # state_dict_type="FULL_STATE_DICT", # TODO: resume from this is not working
77
+ mixed_precision=dict(
78
+ param_dtype="bf16",
79
+ reduce_dtype="bf16",
80
+ ),
81
+ cpu_ram_efficient_loading=False,
82
+ ),
83
+ # parallelism_cfg=dict(
84
+ # dp_shard_size=2, # Fully Sharded Data Parallel degree
85
+ # dp_replicate_size=1, # Data Parallel degree
86
+ # tp_size=1, # Tensor Parallel degree
87
+ # cp_size=4, # Context Parallel degree
88
+ # ),
89
+ )
90
+
91
+ if use_compile:
92
+ accelerator_cfg["compile_cfg"] = dict(
93
+ backend="inductor",
94
+ mode="default", # Options: "default", "reduce-overhead", "max-autotune"
95
+ fullgraph=False,
96
+ dynamic=False,
97
+ )
98
+
99
+ # ------------------------------------------------------------------
100
+ randomness = dict(seed=0, deterministic=False, diff_rank_seed=True)
101
+ logger = dict(
102
+ type="Logger",
103
+ log_interval=log_every_iters,
104
+ )
105
+ checkpoint = dict(
106
+ type="Checkpointer",
107
+ save_interval=save_every_iters,
108
+ )
109
+
110
+ visualizer = dict(
111
+ type="NormalVisualizer",
112
+ vis_interval=vis_every_iters,
113
+ vis_max_samples=4,
114
+ vis_image_width=384,
115
+ vis_image_height=512,
116
+ )
117
+
118
+
119
+ ##-----------------------------------------------------------------
120
+ train_pipeline = [
121
+ dict(type="PhotoMetricDistortion"),
122
+ dict(type="RandomDownUpSampleImage", scale_range=(0.1, 0.7), prob=0.2),
123
+ dict(
124
+ type="NormalRandomScale",
125
+ scale_min=0.5,
126
+ scale_max=2.0,
127
+ prob=0.3,
128
+ ),
129
+ dict(
130
+ type="NormalRandomCropContinuous",
131
+ ar_range=(0.5, 2.0),
132
+ area_range=(0.4, 1.0),
133
+ num_attempts=8,
134
+ prob=0.3,
135
+ ),
136
+ dict(
137
+ type="NormalRandomFlip",
138
+ prob=0.3,
139
+ ),
140
+ dict(type="NormalResize", height=1024, width=768),
141
+ dict(
142
+ type="RandomGaussianBlur", prob=0.3, kernel_size=(3, 3), sigma_range=(0.1, 2.0)
143
+ ),
144
+ dict(type="RandomGaussianNoise", prob=0.3, var_range=(5.0, 20.0)),
145
+ dict(type="RandomSolarize", prob=0.3, threshold=128),
146
+ dict(type="NormalGenerateTarget"),
147
+ dict(
148
+ type="NormalPackInputs",
149
+ meta_keys=(
150
+ "img_path",
151
+ "ori_shape",
152
+ ),
153
+ ),
154
+ ]
155
+
156
+ val_pipeline = [
157
+ dict(type="NormalResize", height=1024, width=768, test_mode=True),
158
+ dict(
159
+ type="NormalPackInputs",
160
+ test_mode=True,
161
+ meta_keys=(
162
+ "img_path",
163
+ "orig_img_height",
164
+ "orig_img_width",
165
+ "img_shape",
166
+ "pad_shape",
167
+ ),
168
+ ),
169
+ ]
170
+
171
+ test_pipeline = [
172
+ dict(type="NormalResizePadImage", height=1024, width=768, pad_val=0),
173
+ dict(
174
+ type="NormalPackInputs",
175
+ meta_keys=(
176
+ "img_path",
177
+ "orig_img_height",
178
+ "orig_img_width",
179
+ "img_shape",
180
+ ),
181
+ ),
182
+ ]
183
+
184
+ metasim_dataset = dict(
185
+ type="NormalMetaSimDataset",
186
+ airstore_template="airstore://codec_avatar_sapiens_metasim_v1_no_user_data",
187
+ json_path=f"{_DATA_ROOT}/seg/data/metasim/meta_data_v1.json",
188
+ )
189
+
190
+ render_people_dataset = dict(
191
+ type="NormalRenderPeopleBodyDataset", ## body only
192
+ data_root=f"{_DATA_ROOT}/synthetic",
193
+ seg_data_root=f"{_DATA_ROOT}/RenderPeople/part_seg",
194
+ )
195
+
196
+ multihuman_render_people_dataset = dict(
197
+ type="NormalRenderPeopleMultihumanDataset",
198
+ data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_multi_human",
199
+ normal_extension=".npz",
200
+ seg_data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_multi_human/part_seg", ## supervise on face for multihuman
201
+ )
202
+
203
+ # train_datasets = 2 * [metasim_dataset] + [
204
+ # render_people_dataset,
205
+ # multihuman_render_people_dataset,
206
+ # ]
207
+
208
+ # train_datasets = [render_people_dataset]
209
+ # train_datasets = [multihuman_render_people_dataset]
210
+ train_datasets = [metasim_dataset]
211
+
212
+ train_dataloader = dict(
213
+ batch_size=1,
214
+ num_workers=4,
215
+ persistent_workers=True,
216
+ shuffle=True,
217
+ dataset=dict(
218
+ type="CombinedDataset", datasets=train_datasets, pipeline=train_pipeline
219
+ ),
220
+ )
221
+
222
+ val_dataloader = dict(
223
+ batch_size=4,
224
+ num_workers=4,
225
+ persistent_workers=True,
226
+ multiprocessing_context="spawn",
227
+ # num_workers=0, # debug
228
+ # persistent_workers=False, # debug
229
+ shuffle=False,
230
+ dataset=dict(
231
+ type="NormalRenderPeopleBodyDataset", ## body only
232
+ # num_samples=100, ## debug: only use N samples for validation
233
+ test_mode=True,
234
+ data_root=f"{_DATA_ROOT}/seg/data/metasim/evaluation",
235
+ pipeline=val_pipeline,
236
+ ),
237
+ )
238
+
239
+ val_cfg = dict(
240
+ val_interval=val_every_iters,
241
+ evaluator=dict(
242
+ type="NormalEvaluator",
243
+ ),
244
+ )
245
+
246
+ data_preprocessor = dict(
247
+ type="ImagePreprocessor",
248
+ mean=[123.675, 116.28, 103.53],
249
+ std=[58.395, 57.12, 57.375],
250
+ bgr_to_rgb=True, ## convert from bgr to rgb for pretrained models
251
+ )
252
+
253
+ ##-----------------------------------------------------------------
254
+ model = dict(
255
+ type="NormalEstimator",
256
+ backbone=dict(
257
+ type="Sapiens2",
258
+ arch=model_name,
259
+ img_size=image_size,
260
+ patch_size=patch_size,
261
+ final_norm=True,
262
+ use_tokenizer=False,
263
+ with_cls_token=True,
264
+ out_type="featmap",
265
+ init_cfg=dict(type="Pretrained", checkpoint=pretrained_checkpoint),
266
+ ),
267
+ decode_head=dict(
268
+ type="NormalHead",
269
+ in_channels=embed_dim,
270
+ upsample_channels=[1536, 768, 512, 256], ## 1K resolution
271
+ conv_out_channels=[128, 64, 32],
272
+ conv_kernel_sizes=[3, 3, 3],
273
+ loss_decode=[
274
+ dict(
275
+ type="NormalCosineSimilarityLoss",
276
+ loss_weight=10.0,
277
+ ),
278
+ dict(type="L1Loss", loss_weight=1.0),
279
+ dict(type="NormalGradL1Loss", loss_weight=10.0),
280
+ ],
281
+ ),
282
+ )
283
+
284
+
285
+ ##-----------------------------------------------------------------
286
+ optimizer = dict(
287
+ type="AdamW",
288
+ # lr=5e-4,
289
+ lr=1e-4,
290
+ betas=(0.9, 0.999),
291
+ weight_decay=0.1,
292
+ paramwise_cfg=dict(
293
+ num_layers=num_layers,
294
+ layer_decay_rate=layer_decay_rate,
295
+ ),
296
+ fused=True,
297
+ )
298
+
299
+ scheduler = dict(
300
+ type="SequentialLR",
301
+ milestones=[warmup_iters],
302
+ schedulers=[
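+ ## linear warmup from 1e-3 x base lr up to base lr over warmup_iters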
303
+ dict(type="LinearLR", start_factor=1e-3, total_iters=warmup_iters),
304
+ dict(
305
+ type="PolynomialLR",
306
+ total_iters=num_iters - warmup_iters,
307
+ power=1.0,
308
+ ),
309
+ ],
310
+ )
311
+
312
+ clip_grad = dict(mode="norm", max_norm=4.0, norm_type=2.0)
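
These configs are plain Python modules that define top-level dicts (model, optimizer, scheduler, ...). As a minimal sketch, assuming only the standard library (the repo's actual config loader may differ), one way a trainer could import such a file:

import importlib.util

def load_config(path: str):
    # Execute the config file as a module and return it; its top-level
    # names (model, optimizer, scheduler, ...) become attributes.
    spec = importlib.util.spec_from_file_location("cfg", path)
    cfg = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(cfg)
    return cfg

# Hypothetical usage:
# cfg = load_config("sapiens/dense/configs/normal/metasim_render_people/sapiens2_5b_normal_metasim_render_people-1024x768.py")
# print(cfg.model["backbone"]["arch"], cfg.optimizer["lr"])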
sapiens/dense/configs/pointmap/render_people/sapiens2_0.4b_pointmap_render_people-1024x768.py ADDED
@@ -0,0 +1,322 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+
9
+ _CHECKPOINT_ROOT = os.path.expanduser(
10
+ os.environ.get("SAPIENS_CHECKPOINT_ROOT", "~/sapiens2_host")
11
+ )
12
+ _DATA_ROOT = os.path.expanduser(os.environ.get("DATA_ROOT", "~/sapiens_data"))
13
+
14
+ warmup_iters = 500
15
+ num_iters = 2e4 ## 32 nodes x 8 gpus = 256 gpus. bs: 3, global bs: 768. num samples: 1e6. 1e6/768 ~= 1302, so 1 epoch ~= 1.3e3 iters.
16
+
17
+ # ------------------------------------------------------------------------------
18
+ vis_every_iters = 100
19
+ log_every_iters = 10
20
+ save_every_iters = 1000
21
+ val_every_iters = 1000
22
+
23
+ # # debug
24
+ # vis_every_iters = 1
25
+ # log_every_iters = 1
26
+ # val_every_iters = 2
27
+ # save_every_iters = 1000
28
+
29
+ load_from = None
30
+ resume = False
31
+
32
+ # ------------------------------------------------------------------
33
+ model_name = "sapiens2_0.4b"
34
+ embed_dim = 1024
35
+ num_layers = 24
36
+ num_heads = 16
37
+
38
+ layer_decay_rate = 0.8
39
+ pretrained_checkpoint = f"{_CHECKPOINT_ROOT}/pretrain/sapiens2_0.4b_pretrain.safetensors"
40
+
41
+ ##-----------------------------------------------------------------
42
+ image_size = (1024, 768) ## height x width
43
+
44
+ patch_size = 16
45
+ num_tokens = (image_size[0] // patch_size) * (image_size[1] // patch_size)
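+ # = (1024 // 16) * (768 // 16) = 64 * 48 = 3072 patch tokens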
46
+ canonical_focal_length = 768.0
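+ ## canonical-intrinsics assumption: targets use a fixed focal length equal to the 768 px image width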
47
+
48
+ # ------------------------------------------------------------------
49
+ use_fsdp = True
50
+ # use_fsdp = False
51
+
52
+ use_compile = True
53
+ # use_compile = False
54
+
55
+ ## DDP config
56
+ if use_fsdp is False:
57
+ accelerator_cfg = dict(
58
+ type="DDP",
59
+ log_with="tensorboard",
60
+ # find_unused_parameters=True,
61
+ gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off.
62
+ max_interval=num_iters,
63
+ # mixed_precision="bf16", # Options: ‘no’,‘fp16’,‘bf16’ or ‘fp8’.
64
+ step_scheduler_with_optimizer=False, ## schedule independent of n_gpus
65
+ )
66
+
67
+ else:
68
+ accelerator_cfg = dict(
69
+ type="FSDP",
70
+ log_with="tensorboard",
71
+ gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off.
72
+ max_interval=num_iters,
73
+ # mixed_precision="bf16", # Options: ‘no’,‘fp16’,‘bf16’ or ‘fp8’.
74
+ step_scheduler_with_optimizer=False,
75
+ fsdp_cfg=dict(
76
+ fsdp_version=2, # DTensor-based engine
77
+ state_dict_type="SHARDED_STATE_DICT", # SHARDED_STATE_DICT | FULL_STATE_DICT
78
+ cpu_ram_efficient_loading=False,
79
+ ),
80
+ )
81
+
82
+ ## Note: to merge FSDP-sharded weights into a single checkpoint, run:
83
+ # accelerate merge-weights pytorch_model_fsdp_0/ .
84
+
85
+ if use_compile:
86
+ accelerator_cfg["compile_cfg"] = dict(
87
+ backend="inductor",
88
+ mode="default", # Options: "default", "reduce-overhead", "max-autotune"
89
+ fullgraph=False,
90
+ dynamic=False,
91
+ )
92
+
93
+ # ------------------------------------------------------------------
94
+ randomness = dict(seed=0, deterministic=False, diff_rank_seed=True)
95
+ logger = dict(
96
+ type="Logger",
97
+ log_interval=log_every_iters,
98
+ )
99
+ checkpoint = dict(
100
+ type="Checkpointer",
101
+ save_interval=save_every_iters,
102
+ )
103
+
104
+ visualizer = dict(
105
+ type="PointmapVisualizer",
106
+ vis_interval=vis_every_iters,
107
+ vis_max_samples=4,
108
+ vis_image_width=384,
109
+ vis_image_height=512,
110
+ )
111
+
112
+
113
+ ##-----------------------------------------------------------------
114
+ train_pipeline = [
115
+ dict(type="PhotoMetricDistortion"),
116
+ dict(
117
+ type="PointmapRandomScale",
118
+ scale_min=0.5,
119
+ scale_max=2.0,
120
+ prob=0.3,
121
+ ),
122
+ dict(
123
+ type="PointmapRandomCropContinuous",
124
+ ar_range=(0.5, 2.0),
125
+ area_range=(0.4, 1.0),
126
+ num_attempts=8,
127
+ prob=0.3,
128
+ ),
129
+ dict(
130
+ type="PointmapRandomFlip",
131
+ prob=0.3,
132
+ ),
133
+ dict(type="PointmapResize", height=1024, width=768),
134
+ ## the target must be at the same resolution as the output, otherwise we get artifacts.
135
+ dict(
136
+ type="PointmapGenerateTarget",
137
+ canonical_focal_length=canonical_focal_length,
138
+ target_downsample_factor=1,
139
+ ),
140
+ dict(
141
+ type="PointmapPackInputs",
142
+ meta_keys=(
143
+ "img_path",
144
+ "ori_shape",
145
+ "img_shape",
146
+ "pad_shape",
147
+ "scale",
148
+ "flip",
149
+ "flip_direction",
150
+ "original_K",
151
+ "K",
152
+ "M",
153
+ ),
154
+ ),
155
+ ]
156
+
157
+ val_pipeline = [
158
+ dict(type="PointmapResize", height=1024, width=768),
159
+ dict(type="PointmapGenerateTarget", canonical_focal_length=canonical_focal_length),
160
+ dict(
161
+ type="PointmapPackInputs",
162
+ meta_keys=(
163
+ "img_path",
164
+ "orig_img_height",
165
+ "orig_img_width",
166
+ "img_shape",
167
+ "pad_shape",
168
+ "scale",
169
+ "padding_size",
170
+ "K",
171
+ "M",
172
+ ),
173
+ ),
174
+ ]
175
+
176
+ test_pipeline = [
177
+ dict(type="PointmapResizePadImage", height=1024, width=768, pad_val=0),
178
+ dict(
179
+ type="PointmapPackInputs",
180
+ meta_keys=(
181
+ "img_path",
182
+ "orig_img_height",
183
+ "orig_img_width",
184
+ "img_shape",
185
+ "pad_shape",
186
+ "scale",
187
+ "padding_size",
188
+ "K",
189
+ "M",
190
+ ),
191
+ ),
192
+ ]
193
+
194
+ render_people_dataset = dict(
195
+ type="PointmapRenderPeopleDataset",
196
+ data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_v2",
197
+ )
198
+
199
+ train_datasets = [render_people_dataset]
200
+
201
+ train_dataloader = dict(
202
+ batch_size=1,
203
+ num_workers=4,
204
+ persistent_workers=True,
205
+ shuffle=True,
206
+ dataset=dict(
207
+ type="CombinedDataset", datasets=train_datasets, pipeline=train_pipeline
208
+ ),
209
+ )
210
+
211
+ val_dataloader = dict(
212
+ batch_size=4,
213
+ num_workers=4,
214
+ persistent_workers=True,
215
+ multiprocessing_context="spawn",
216
+ # num_workers=0, # debug
217
+ # persistent_workers=False, # debug
218
+ shuffle=False,
219
+ dataset=dict(
220
+ type="PointmapRenderPeopleDataset",
221
+ # num_samples=100, ## debug: only use N samples for validation
222
+ test_mode=True,
223
+ data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_v2_test",
224
+ pipeline=val_pipeline,
225
+ ),
226
+ )
227
+
228
+ val_cfg = dict(
229
+ val_interval=val_every_iters,
230
+ evaluator=dict(
231
+ type="PointmapEvaluator",
232
+ ),
233
+ )
234
+
235
+ data_preprocessor = dict(
236
+ type="ImagePreprocessor",
237
+ mean=[123.675, 116.28, 103.53],
238
+ std=[58.395, 57.12, 57.375],
239
+ bgr_to_rgb=True, ## convert from bgr to rgb for pretrained models
240
+ )
241
+
242
+ ##-----------------------------------------------------------------
243
+ model = dict(
244
+ type="PointmapEstimator",
245
+ canonical_focal_length=canonical_focal_length,
246
+ backbone=dict(
247
+ type="Sapiens2",
248
+ arch=model_name,
249
+ img_size=image_size,
250
+ patch_size=patch_size,
251
+ final_norm=True,
252
+ use_tokenizer=False,
253
+ with_cls_token=True,
254
+ out_type="featmap",
255
+ init_cfg=dict(type="Pretrained", checkpoint=pretrained_checkpoint),
256
+ ),
257
+ decode_head=dict(
258
+ type="PointmapHead",
259
+ in_channels=embed_dim,
260
+ upsample_channels=[1536, 768, 512, 256],
261
+ conv_out_channels=[64, 32, 16],
262
+ conv_kernel_sizes=[3, 3, 3],
263
+ scale_conv_out_channels=(1536, 512, 128),
264
+ scale_conv_kernel_sizes=(1, 1, 1),
265
+ scale_final_layer=(
266
+ (num_tokens // ((2 * 2 * 2) * (2 * 2 * 2))) * 128,
267
+ 512,
268
+ 128,
269
+ 1,
270
+ ), ## scale regress
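+ ## 3072 tokens / (8 * 8) = 48 after three 2x downsamples per spatial dim; 48 * 128 channels = 6144-dim input to the scale regressor (assuming the branch flattens its final feature map)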
271
+ loss_decode=[
272
+ dict(type="L1Loss", loss_weight=2.0), ## on pointmap, XYZ
273
+ dict(
274
+ type="MultiscaleL1Loss",
275
+ loss_weight=1.0,
276
+ scale_factor=2,
277
+ ),
278
+ dict(type="SiLogLoss", loss_weight=1.0), ## only applies silog loss
279
+ dict(
280
+ type="PointmapIntrinsicsConsistencyLoss",
281
+ loss_weight=1.0,
282
+ ),
283
+ dict(
284
+ type="PointmapShiftInvariantL1Loss",
285
+ loss_weight=1.0,
286
+ ),
287
+ dict(type="PointmapNormalLoss", loss_weight=2.0),
288
+ dict(
289
+ type="PointmapScaleL1Loss", loss_weight=4.0
290
+ ), ## Canonical XYZ = scale * XYZ
291
+ ],
292
+ ),
293
+ )
294
+
295
+
296
+ ##-----------------------------------------------------------------
297
+ optimizer = dict(
298
+ type="AdamW",
299
+ lr=5e-4,
300
+ betas=(0.9, 0.999),
301
+ weight_decay=0.1,
302
+ paramwise_cfg=dict(
303
+ num_layers=num_layers,
304
+ layer_decay_rate=layer_decay_rate,
305
+ ),
306
+ fused=True,
307
+ )
308
+
309
+ scheduler = dict(
310
+ type="SequentialLR",
311
+ milestones=[warmup_iters],
312
+ schedulers=[
313
+ dict(type="LinearLR", start_factor=1e-3, total_iters=warmup_iters),
314
+ dict(
315
+ type="PolynomialLR",
316
+ total_iters=num_iters - warmup_iters,
317
+ power=1.0,
318
+ ),
319
+ ],
320
+ )
321
+
322
+ clip_grad = dict(mode="norm", max_norm=2.0, norm_type=2.0)
sapiens/dense/configs/pointmap/render_people/sapiens2_0.8b_pointmap_render_people-1024x768.py ADDED
@@ -0,0 +1,325 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+
9
+ _CHECKPOINT_ROOT = os.path.expanduser(
10
+ os.environ.get("SAPIENS_CHECKPOINT_ROOT", "~/sapiens2_host")
11
+ )
12
+ _DATA_ROOT = os.path.expanduser(os.environ.get("DATA_ROOT", "~/sapiens_data"))
13
+
14
+ warmup_iters = 500
15
+ num_iters = 2e4 ## 32 nodes x 8 gpus = 256 gpus. bs: 3, global bs: 768. num samples: 1e6. 1e6/768 ~= 1302, so 1 epoch ~= 1.3e3 iters.
16
+ # num_iters = 1e4 ## light finetune
17
+
18
+ # ------------------------------------------------------------------------------
19
+ vis_every_iters = 100
20
+ log_every_iters = 10
21
+ save_every_iters = 1000
22
+ val_every_iters = 1000
23
+
24
+ # # debug
25
+ # vis_every_iters = 1
26
+ # log_every_iters = 1
27
+ # val_every_iters = 2
28
+ # save_every_iters = 1000
29
+
30
+ load_from = None
31
+ resume = False
32
+
33
+ # ------------------------------------------------------------------
34
+ model_name = "sapiens2_0.8b"
35
+ embed_dim = 1280
36
+ num_layers = 32
37
+ num_heads = 16
38
+
39
+ layer_decay_rate = 0.85
40
+ pretrained_checkpoint = f"{_CHECKPOINT_ROOT}/pretrain/sapiens2_0.8b_pretrain.safetensors"
41
+
42
+ ##-----------------------------------------------------------------
43
+ image_size = (1024, 768) ## height x width
44
+
45
+ patch_size = 16
46
+ num_tokens = (image_size[0] // patch_size) * (image_size[1] // patch_size)
47
+ canonical_focal_length = 768.0
48
+
49
+ # ------------------------------------------------------------------
50
+ use_fsdp = True
51
+ # use_fsdp = False
52
+
53
+ use_compile = True
54
+ # use_compile = False
55
+
56
+ ## DDP config
57
+ if use_fsdp is False:
58
+ accelerator_cfg = dict(
59
+ type="DDP",
60
+ log_with="tensorboard",
61
+ # find_unused_parameters=True,
62
+ gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off.
63
+ max_interval=num_iters,
64
+ # mixed_precision="bf16", # Options: ‘no’,‘fp16’,‘bf16’ or ‘fp8’.
65
+ step_scheduler_with_optimizer=False, ## schedule independent of n_gpus
66
+ )
67
+
68
+ else:
69
+ accelerator_cfg = dict(
70
+ type="FSDP",
71
+ log_with="tensorboard",
72
+ gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off.
73
+ max_interval=num_iters,
74
+ # mixed_precision="bf16", # Options: ‘no’,‘fp16’,‘bf16’ or ‘fp8’.
75
+ step_scheduler_with_optimizer=False,
76
+ fsdp_cfg=dict(
77
+ fsdp_version=2, # DTensor-based engine
78
+ state_dict_type="SHARDED_STATE_DICT", # SHARDED_STATE_DICT | FULL_STATE_DICT
79
+ # state_dict_type="FULL_STATE_DICT", # TODO: resuming from this is not working yet
80
+ # mixed_precision=dict(
81
+ # param_dtype="bf16",
82
+ # reduce_dtype="bf16",
83
+ # ),
84
+ cpu_ram_efficient_loading=False,
85
+ ),
86
+ )
87
+
88
+ if use_compile:
89
+ accelerator_cfg["compile_cfg"] = dict(
90
+ backend="inductor",
91
+ mode="default", # Options: "default", "reduce-overhead", "max-autotune"
92
+ fullgraph=False,
93
+ dynamic=False,
94
+ )
95
+
96
+ # ------------------------------------------------------------------
97
+ randomness = dict(seed=0, deterministic=False, diff_rank_seed=True)
98
+ logger = dict(
99
+ type="Logger",
100
+ log_interval=log_every_iters,
101
+ )
102
+ checkpoint = dict(
103
+ type="Checkpointer",
104
+ save_interval=save_every_iters,
105
+ )
106
+
107
+ visualizer = dict(
108
+ type="PointmapVisualizer",
109
+ vis_interval=vis_every_iters,
110
+ vis_max_samples=4,
111
+ vis_image_width=384,
112
+ vis_image_height=512,
113
+ )
114
+
115
+
116
+ ##-----------------------------------------------------------------
117
+ train_pipeline = [
118
+ dict(type="PhotoMetricDistortion"),
119
+ dict(
120
+ type="PointmapRandomScale",
121
+ scale_min=0.5,
122
+ scale_max=2.0,
123
+ prob=0.3,
124
+ ),
125
+ dict(
126
+ type="PointmapRandomCropContinuous",
127
+ ar_range=(0.5, 2.0),
128
+ area_range=(0.4, 1.0),
129
+ num_attempts=8,
130
+ prob=0.3,
131
+ ),
132
+ dict(
133
+ type="PointmapRandomFlip",
134
+ prob=0.3,
135
+ ),
136
+ dict(type="PointmapResize", height=1024, width=768),
137
+ ## the target must be at the same resolution as the output, otherwise we get artifacts.
138
+ dict(
139
+ type="PointmapGenerateTarget",
140
+ canonical_focal_length=canonical_focal_length,
141
+ target_downsample_factor=1,
142
+ ),
143
+ dict(
144
+ type="PointmapPackInputs",
145
+ meta_keys=(
146
+ "img_path",
147
+ "ori_shape",
148
+ "img_shape",
149
+ "pad_shape",
150
+ "scale",
151
+ "flip",
152
+ "flip_direction",
153
+ "original_K",
154
+ "K",
155
+ "M",
156
+ ),
157
+ ),
158
+ ]
159
+
160
+ val_pipeline = [
161
+ dict(type="PointmapResize", height=1024, width=768),
162
+ dict(type="PointmapGenerateTarget", canonical_focal_length=canonical_focal_length),
163
+ dict(
164
+ type="PointmapPackInputs",
165
+ meta_keys=(
166
+ "img_path",
167
+ "orig_img_height",
168
+ "orig_img_width",
169
+ "img_shape",
170
+ "pad_shape",
171
+ "scale",
172
+ "padding_size",
173
+ "K",
174
+ "M",
175
+ ),
176
+ ),
177
+ ]
178
+
179
+ test_pipeline = [
180
+ dict(type="PointmapResizePadImage", height=1024, width=768, pad_val=0),
181
+ dict(
182
+ type="PointmapPackInputs",
183
+ meta_keys=(
184
+ "img_path",
185
+ "orig_img_height",
186
+ "orig_img_width",
187
+ "img_shape",
188
+ "pad_shape",
189
+ "scale",
190
+ "padding_size",
191
+ "K",
192
+ "M",
193
+ ),
194
+ ),
195
+ ]
196
+
197
+ render_people_dataset = dict(
198
+ type="PointmapRenderPeopleDataset",
199
+ data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_v2",
200
+ )
201
+
202
+ train_datasets = [render_people_dataset]
203
+
204
+ train_dataloader = dict(
205
+ batch_size=1,
206
+ num_workers=4,
207
+ persistent_workers=True,
208
+ shuffle=True,
209
+ dataset=dict(
210
+ type="CombinedDataset", datasets=train_datasets, pipeline=train_pipeline
211
+ ),
212
+ )
213
+
214
+ val_dataloader = dict(
215
+ batch_size=4,
216
+ num_workers=4,
217
+ persistent_workers=True,
218
+ multiprocessing_context="spawn",
219
+ # num_workers=0, # debug
220
+ # persistent_workers=False, # debug
221
+ shuffle=False,
222
+ dataset=dict(
223
+ type="PointmapRenderPeopleDataset",
224
+ # num_samples=100, ## debug: only use N samples for validation
225
+ test_mode=True,
226
+ data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_v2_test",
227
+ pipeline=val_pipeline,
228
+ ),
229
+ )
230
+
231
+ val_cfg = dict(
232
+ val_interval=val_every_iters,
233
+ evaluator=dict(
234
+ type="PointmapEvaluator",
235
+ ),
236
+ )
237
+
238
+ data_preprocessor = dict(
239
+ type="ImagePreprocessor",
240
+ mean=[123.675, 116.28, 103.53],
241
+ std=[58.395, 57.12, 57.375],
242
+ bgr_to_rgb=True, ## convert from bgr to rgb for pretrained models
243
+ )
244
+
245
+ ##-----------------------------------------------------------------
246
+ model = dict(
247
+ type="PointmapEstimator",
248
+ canonical_focal_length=canonical_focal_length,
249
+ backbone=dict(
250
+ type="Sapiens2",
251
+ arch=model_name,
252
+ img_size=image_size,
253
+ patch_size=patch_size,
254
+ final_norm=True,
255
+ use_tokenizer=False,
256
+ with_cls_token=True,
257
+ out_type="featmap",
258
+ init_cfg=dict(type="Pretrained", checkpoint=pretrained_checkpoint),
259
+ ),
260
+ decode_head=dict(
261
+ type="PointmapHead",
262
+ in_channels=embed_dim,
263
+ upsample_channels=[1536, 768, 512, 256],
264
+ conv_out_channels=[64, 32, 16],
265
+ conv_kernel_sizes=[3, 3, 3],
266
+ scale_conv_out_channels=(1536, 512, 128),
267
+ scale_conv_kernel_sizes=(1, 1, 1),
268
+ scale_final_layer=(
269
+ (num_tokens // ((2 * 2 * 2) * (2 * 2 * 2))) * 128,
270
+ 512,
271
+ 128,
272
+ 1,
273
+ ), ## scale regress
274
+ loss_decode=[
275
+ dict(type="L1Loss", loss_weight=2.0), ## on pointmap, XYZ
276
+ dict(
277
+ type="MultiscaleL1Loss",
278
+ loss_weight=1.0,
279
+ scale_factor=2,
280
+ ),
281
+ dict(type="SiLogLoss", loss_weight=1.0), ## only applies silog loss
282
+ dict(
283
+ type="PointmapIntrinsicsConsistencyLoss",
284
+ loss_weight=1.0,
285
+ ),
286
+ dict(
287
+ type="PointmapShiftInvariantL1Loss",
288
+ loss_weight=1.0,
289
+ ),
290
+ dict(type="PointmapNormalLoss", loss_weight=2.0),
291
+ dict(
292
+ type="PointmapScaleL1Loss", loss_weight=4.0
293
+ ), ## Canonical XYZ = scale * XYZ
294
+ ],
295
+ ),
296
+ )
297
+
298
+
299
+ ##-----------------------------------------------------------------
300
+ optimizer = dict(
301
+ type="AdamW",
302
+ lr=5e-4,
303
+ betas=(0.9, 0.999),
304
+ weight_decay=0.1,
305
+ paramwise_cfg=dict(
306
+ num_layers=num_layers,
307
+ layer_decay_rate=layer_decay_rate,
308
+ ),
309
+ fused=True,
310
+ )
311
+
312
+ scheduler = dict(
313
+ type="SequentialLR",
314
+ milestones=[warmup_iters],
315
+ schedulers=[
316
+ dict(type="LinearLR", start_factor=1e-3, total_iters=warmup_iters),
317
+ dict(
318
+ type="PolynomialLR",
319
+ total_iters=num_iters - warmup_iters,
320
+ power=1.0,
321
+ ),
322
+ ],
323
+ )
324
+
325
+ clip_grad = dict(mode="norm", max_norm=2.0, norm_type=2.0)
sapiens/dense/configs/pointmap/render_people/sapiens2_1b_pointmap_render_people-1024x768.py ADDED
@@ -0,0 +1,319 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+
9
+ _CHECKPOINT_ROOT = os.path.expanduser(
10
+ os.environ.get("SAPIENS_CHECKPOINT_ROOT", "~/sapiens2_host")
11
+ )
12
+ _DATA_ROOT = os.path.expanduser(os.environ.get("DATA_ROOT", "~/sapiens_data"))
13
+
14
+ warmup_iters = 500
15
+ num_iters = 4e4 ## 32 nodes x 8 gpus = 256 gpus. bs: 3, global bs: 768. num samples: 1e6. 1e6/768 ~= 1302, so 1 epoch ~= 1.3e3 iters.
16
+
17
+ # ------------------------------------------------------------------------------
18
+ vis_every_iters = 100
19
+ log_every_iters = 10
20
+ save_every_iters = 1000
21
+ val_every_iters = 1000
22
+
23
+ # # debug
24
+ # vis_every_iters = 1
25
+ # log_every_iters = 1
26
+ # val_every_iters = 2
27
+ # save_every_iters = 1000
28
+
29
+ load_from = None
30
+ resume = False
31
+
32
+ # ------------------------------------------------------------------
33
+ model_name = "sapiens2_1b"
34
+ embed_dim = 1536
35
+ num_layers = 40
36
+ num_heads = 24
37
+ layer_decay_rate = 0.9
38
+
39
+ pretrained_checkpoint = f"{_CHECKPOINT_ROOT}/pretrain/sapiens2_1b_pretrain.safetensors"
40
+
41
+ ##-----------------------------------------------------------------
42
+ image_size = (1024, 768) ## height x width
43
+
44
+ patch_size = 16
45
+ num_tokens = (image_size[0] // patch_size) * (image_size[1] // patch_size)
46
+ canonical_focal_length = 768.0
47
+
48
+ # ------------------------------------------------------------------
49
+ use_fsdp = True
50
+ # use_fsdp = False
51
+
52
+ use_compile = True
53
+ # use_compile = False
54
+
55
+ ## DDP config
56
+ if use_fsdp is False:
57
+ accelerator_cfg = dict(
58
+ type="DDP",
59
+ log_with="tensorboard",
60
+ # find_unused_parameters=True,
61
+ gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off.
62
+ max_interval=num_iters,
63
+ # mixed_precision="bf16", # Options: ‘no’,‘fp16’,‘bf16’ or ‘fp8’.
64
+ step_scheduler_with_optimizer=False, ## schedule independent of n_gpus
65
+ )
66
+
67
+ else:
68
+ accelerator_cfg = dict(
69
+ type="FSDP",
70
+ log_with="tensorboard",
71
+ gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off.
72
+ max_interval=num_iters,
73
+ # mixed_precision="bf16", # Options: ‘no’,‘fp16’,‘bf16’ or ‘fp8’.
74
+ step_scheduler_with_optimizer=False,
75
+ fsdp_cfg=dict(
76
+ fsdp_version=2, # DTensor-based engine
77
+ state_dict_type="SHARDED_STATE_DICT", # SHARDED_STATE_DICT | FULL_STATE_DICT
78
+ # state_dict_type="FULL_STATE_DICT", # TODO: resuming from this is not working yet
79
+ # mixed_precision=dict(
80
+ # param_dtype="bf16",
81
+ # reduce_dtype="bf16",
82
+ # ),
83
+ cpu_ram_efficient_loading=False,
84
+ ),
85
+ )
86
+
87
+ if use_compile:
88
+ accelerator_cfg["compile_cfg"] = dict(
89
+ backend="inductor",
90
+ mode="default", # Options: "default", "reduce-overhead", "max-autotune"
91
+ fullgraph=False,
92
+ dynamic=False,
93
+ )
94
+
95
+ # ------------------------------------------------------------------
96
+ randomness = dict(seed=0, deterministic=False, diff_rank_seed=True)
97
+ logger = dict(
98
+ type="Logger",
99
+ log_interval=log_every_iters,
100
+ )
101
+ checkpoint = dict(
102
+ type="Checkpointer",
103
+ save_interval=save_every_iters,
104
+ )
105
+
106
+ visualizer = dict(
107
+ type="PointmapVisualizer",
108
+ vis_interval=vis_every_iters,
109
+ vis_max_samples=4,
110
+ vis_image_width=384,
111
+ vis_image_height=512,
112
+ )
113
+
114
+
115
+ ##-----------------------------------------------------------------
116
+ train_pipeline = [
117
+ dict(type="PhotoMetricDistortion"),
118
+ dict(
119
+ type="PointmapRandomScale",
120
+ scale_min=0.5,
121
+ scale_max=2.0,
122
+ prob=0.3,
123
+ ),
124
+ dict(
125
+ type="PointmapRandomCropContinuous",
126
+ ar_range=(0.5, 2.0),
127
+ area_range=(0.4, 1.0),
128
+ num_attempts=8,
129
+ prob=0.3,
130
+ ),
131
+ dict(
132
+ type="PointmapRandomFlip",
133
+ prob=0.3,
134
+ ),
135
+ dict(type="PointmapResize", height=1024, width=768),
136
+ ## the target must be at the same resolution as the output, otherwise we get artifacts.
137
+ dict(
138
+ type="PointmapGenerateTarget",
139
+ canonical_focal_length=canonical_focal_length,
140
+ target_downsample_factor=1,
141
+ ),
142
+ dict(
143
+ type="PointmapPackInputs",
144
+ meta_keys=(
145
+ "img_path",
146
+ "ori_shape",
147
+ "img_shape",
148
+ "pad_shape",
149
+ "scale",
150
+ "flip",
151
+ "flip_direction",
152
+ "original_K",
153
+ "K",
154
+ "M",
155
+ ),
156
+ ),
157
+ ]
158
+
159
+ val_pipeline = [
160
+ dict(type="PointmapResize", height=1024, width=768),
161
+ dict(type="PointmapGenerateTarget", canonical_focal_length=canonical_focal_length),
162
+ dict(
163
+ type="PointmapPackInputs",
164
+ meta_keys=(
165
+ "img_path",
166
+ "orig_img_height",
167
+ "orig_img_width",
168
+ "img_shape",
169
+ "pad_shape",
170
+ "scale",
171
+ "padding_size",
172
+ "K",
173
+ "M",
174
+ ),
175
+ ),
176
+ ]
177
+
178
+ test_pipeline = [
179
+ dict(type="PointmapResizePadImage", height=1024, width=768, pad_val=0),
180
+ dict(
181
+ type="PointmapPackInputs",
182
+ meta_keys=(
183
+ "img_path",
184
+ "orig_img_height",
185
+ "orig_img_width",
186
+ "padding_size",
187
+ ),
188
+ ),
189
+ ]
190
+
191
+ render_people_dataset = dict(
192
+ type="PointmapRenderPeopleDataset",
193
+ data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_v2",
194
+ )
195
+
196
+ train_datasets = [render_people_dataset]
197
+
198
+ train_dataloader = dict(
199
+ batch_size=1,
200
+ num_workers=4,
201
+ persistent_workers=True,
202
+ shuffle=True,
203
+ dataset=dict(
204
+ type="CombinedDataset", datasets=train_datasets, pipeline=train_pipeline
205
+ ),
206
+ )
207
+
208
+ val_dataloader = dict(
209
+ batch_size=4,
210
+ num_workers=4,
211
+ persistent_workers=True,
212
+ multiprocessing_context="spawn",
213
+ # num_workers=0, # debug
214
+ # persistent_workers=False, # debug
215
+ shuffle=False,
216
+ dataset=dict(
217
+ type="PointmapRenderPeopleDataset",
218
+ # num_samples=100, ## debug: only use N samples for validation
219
+ test_mode=True,
220
+ data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_v2_test",
221
+ pipeline=val_pipeline,
222
+ ),
223
+ )
224
+
225
+ val_cfg = dict(
226
+ val_interval=val_every_iters,
227
+ evaluator=dict(
228
+ type="PointmapEvaluator",
229
+ ),
230
+ )
231
+
232
+ data_preprocessor = dict(
233
+ type="ImagePreprocessor",
234
+ mean=[123.675, 116.28, 103.53],
235
+ std=[58.395, 57.12, 57.375],
236
+ bgr_to_rgb=True, ## convert from bgr to rgb for pretrained models
237
+ )
238
+
239
+ ##-----------------------------------------------------------------
240
+ model = dict(
241
+ type="PointmapEstimator",
242
+ canonical_focal_length=canonical_focal_length,
243
+ backbone=dict(
244
+ type="Sapiens2",
245
+ arch=model_name,
246
+ img_size=image_size,
247
+ patch_size=patch_size,
248
+ final_norm=True,
249
+ use_tokenizer=False,
250
+ with_cls_token=True,
251
+ out_type="featmap",
252
+ init_cfg=dict(type="Pretrained", checkpoint=pretrained_checkpoint),
253
+ ),
254
+ decode_head=dict(
255
+ type="PointmapHead",
256
+ in_channels=embed_dim,
257
+ upsample_channels=[1536, 768, 512, 256],
258
+ conv_out_channels=[64, 32, 16],
259
+ conv_kernel_sizes=[3, 3, 3],
260
+ scale_conv_out_channels=(1536, 512, 128),
261
+ scale_conv_kernel_sizes=(1, 1, 1),
262
+ scale_final_layer=(
263
+ (num_tokens // ((2 * 2 * 2) * (2 * 2 * 2))) * 128,
264
+ 512,
265
+ 128,
266
+ 1,
267
+ ), ## scale regress
268
+ loss_decode=[
269
+ dict(type="L1Loss", loss_weight=2.0), ## on pointmap, XYZ
270
+ dict(
271
+ type="MultiscaleL1Loss",
272
+ loss_weight=1.0,
273
+ scale_factor=2,
274
+ ),
275
+ dict(type="SiLogLoss", loss_weight=1.0), ## only applies silog loss
276
+ dict(
277
+ type="PointmapIntrinsicsConsistencyLoss",
278
+ loss_weight=1.0,
279
+ ),
280
+ dict(
281
+ type="PointmapShiftInvariantL1Loss",
282
+ loss_weight=1.0,
283
+ ),
284
+ dict(type="PointmapNormalLoss", loss_weight=2.0),
285
+ dict(
286
+ type="PointmapScaleL1Loss", loss_weight=4.0
287
+ ), ## Canonical XYZ = scale * XYZ
288
+ ],
289
+ ),
290
+ )
291
+
292
+
293
+ ##-----------------------------------------------------------------
294
+ optimizer = dict(
295
+ type="AdamW",
296
+ lr=5e-4,
297
+ betas=(0.9, 0.999),
298
+ weight_decay=0.1,
299
+ paramwise_cfg=dict(
300
+ num_layers=num_layers,
301
+ layer_decay_rate=layer_decay_rate,
302
+ ),
303
+ fused=True,
304
+ )
305
+
306
+ scheduler = dict(
307
+ type="SequentialLR",
308
+ milestones=[warmup_iters],
309
+ schedulers=[
310
+ dict(type="LinearLR", start_factor=1e-3, total_iters=warmup_iters),
311
+ dict(
312
+ type="PolynomialLR",
313
+ total_iters=num_iters - warmup_iters,
314
+ power=1.0,
315
+ ),
316
+ ],
317
+ )
318
+
319
+ clip_grad = dict(mode="norm", max_norm=4.0, norm_type=2.0)
sapiens/dense/configs/pointmap/render_people/sapiens2_5b_pointmap_render_people-1024x768.py ADDED
@@ -0,0 +1,329 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+
9
+ _CHECKPOINT_ROOT = os.path.expanduser(
10
+ os.environ.get("SAPIENS_CHECKPOINT_ROOT", "~/sapiens2_host")
11
+ )
12
+ _DATA_ROOT = os.path.expanduser(os.environ.get("DATA_ROOT", "~/sapiens_data"))
13
+
14
+ warmup_iters = 500
15
+ num_iters = 4e4 ## 32 nodes x 8 gpus = 256 gpus. bs: 1, global bs: 256. num samples: 1e6. 1e6/256 ~= 3906, so 1 epoch ~= 3906 iters.
16
+
17
+ ## debug
18
+ # warmup_iters = 100
19
+ # num_iters = 300
20
+
21
+ # ------------------------------------------------------------------------------
22
+ vis_every_iters = 100
23
+ log_every_iters = 10
24
+ save_every_iters = 1000
25
+ val_every_iters = 1000
26
+
27
+ # # debug
28
+ # vis_every_iters = 1
29
+ # log_every_iters = 1
30
+ # val_every_iters = 10
31
+
32
+ load_from = None
33
+ resume = False
34
+
35
+ # ------------------------------------------------------------------
36
+ model_name = "sapiens2_5b"
37
+ embed_dim = 2432
38
+ num_layers = 56
39
+ num_heads = 32
40
+ layer_decay_rate = 0.94
41
+
42
+ pretrained_checkpoint = f"{_CHECKPOINT_ROOT}/pretrain/sapiens2_5b_pretrain.safetensors"
43
+
44
+
45
+ ##-----------------------------------------------------------------
46
+ image_size = (1024, 768) ## height x width
47
+
48
+ patch_size = 16
49
+ num_tokens = (image_size[0] // patch_size) * (image_size[1] // patch_size)
50
+ canonical_focal_length = 768.0
51
+
52
+ # ------------------------------------------------------------------
53
+ use_fsdp = True
54
+ # use_fsdp = False
55
+
56
+ use_compile = True
57
+ # use_compile = False
58
+
59
+ ## DDP config
60
+ if use_fsdp is False:
61
+ accelerator_cfg = dict(
62
+ type="DDP",
63
+ log_with="tensorboard",
64
+ # find_unused_parameters=True,
65
+ gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off.
66
+ max_interval=num_iters,
67
+ # mixed_precision="bf16", # Options: ‘no’,‘fp16’,‘bf16’ or ‘fp8’.
68
+ step_scheduler_with_optimizer=False, ## schedule independent of n_gpus
69
+ )
70
+
71
+ else:
72
+ accelerator_cfg = dict(
73
+ type="FSDP",
74
+ log_with="tensorboard",
75
+ gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off.
76
+ max_interval=num_iters,
77
+ # mixed_precision="bf16", # Options: ‘no’,‘fp16’,‘bf16’ or ‘fp8’.
78
+ step_scheduler_with_optimizer=False,
79
+ fsdp_cfg=dict(
80
+ fsdp_version=2, # DTensor-based engine
81
+ state_dict_type="SHARDED_STATE_DICT", # SHARDED_STATE_DICT | FULL_STATE_DICT
82
+ # state_dict_type="FULL_STATE_DICT", # TODO: resuming from this is not working yet
83
+ # mixed_precision=dict(
84
+ # param_dtype="bf16",
85
+ # reduce_dtype="bf16",
86
+ # ),
87
+ cpu_ram_efficient_loading=False,
88
+ ),
89
+ )
90
+
91
+ if use_compile:
92
+ accelerator_cfg["compile_cfg"] = dict(
93
+ backend="inductor",
94
+ mode="default", # Options: "default", "reduce-overhead", "max-autotune"
95
+ fullgraph=False,
96
+ dynamic=False,
97
+ )
98
+
99
+ # ------------------------------------------------------------------
100
+ randomness = dict(seed=0, deterministic=False, diff_rank_seed=True)
101
+ logger = dict(
102
+ type="Logger",
103
+ log_interval=log_every_iters,
104
+ )
105
+ checkpoint = dict(
106
+ type="Checkpointer",
107
+ save_interval=save_every_iters,
108
+ )
109
+
110
+ visualizer = dict(
111
+ type="PointmapVisualizer",
112
+ vis_interval=vis_every_iters,
113
+ vis_max_samples=4,
114
+ vis_image_width=384,
115
+ vis_image_height=512,
116
+ )
117
+
118
+
119
+ ##-----------------------------------------------------------------
120
+ train_pipeline = [
121
+ dict(type="PhotoMetricDistortion"),
122
+ dict(
123
+ type="PointmapRandomScale",
124
+ scale_min=0.5,
125
+ scale_max=2.0,
126
+ prob=0.3,
127
+ ),
128
+ dict(
129
+ type="PointmapRandomCropContinuous",
130
+ ar_range=(0.5, 2.0),
131
+ area_range=(0.4, 1.0),
132
+ num_attempts=8,
133
+ prob=0.3,
134
+ ),
135
+ dict(
136
+ type="PointmapRandomFlip",
137
+ prob=0.3,
138
+ ),
139
+ dict(type="PointmapResize", height=1024, width=768),
140
+ ## the target must be at the same resolution as the output, otherwise we get artifacts.
141
+ dict(
142
+ type="PointmapGenerateTarget",
143
+ canonical_focal_length=canonical_focal_length,
144
+ target_downsample_factor=1,
145
+ ),
146
+ dict(
147
+ type="PointmapPackInputs",
148
+ meta_keys=(
149
+ "img_path",
150
+ "ori_shape",
151
+ "img_shape",
152
+ "pad_shape",
153
+ "scale",
154
+ "flip",
155
+ "flip_direction",
156
+ "original_K",
157
+ "K",
158
+ "M",
159
+ ),
160
+ ),
161
+ ]
162
+
163
+ val_pipeline = [
164
+ dict(type="PointmapResize", height=1024, width=768),
165
+ dict(type="PointmapGenerateTarget", canonical_focal_length=canonical_focal_length),
166
+ dict(
167
+ type="PointmapPackInputs",
168
+ meta_keys=(
169
+ "img_path",
170
+ "orig_img_height",
171
+ "orig_img_width",
172
+ "img_shape",
173
+ "pad_shape",
174
+ "scale",
175
+ "padding_size",
176
+ "K",
177
+ "M",
178
+ ),
179
+ ),
180
+ ]
181
+
182
+ test_pipeline = [
183
+ dict(type="PointmapResizePadImage", height=1024, width=768, pad_val=0),
184
+ dict(
185
+ type="PointmapPackInputs",
186
+ meta_keys=(
187
+ "img_path",
188
+ "orig_img_height",
189
+ "orig_img_width",
190
+ "img_shape",
191
+ "pad_shape",
192
+ "scale",
193
+ "padding_size",
194
+ "K",
195
+ "M",
196
+ ),
197
+ ),
198
+ ]
199
+
200
+ render_people_dataset = dict(
201
+ type="PointmapRenderPeopleDataset",
202
+ data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_v2",
203
+ )
204
+
205
+ train_datasets = [render_people_dataset]
206
+
207
+ train_dataloader = dict(
208
+ batch_size=1,
209
+ num_workers=4,
210
+ persistent_workers=True,
211
+ shuffle=True,
212
+ dataset=dict(
213
+ type="CombinedDataset", datasets=train_datasets, pipeline=train_pipeline
214
+ ),
215
+ )
216
+
217
+ val_dataloader = dict(
218
+ batch_size=1,
219
+ num_workers=4,
220
+ persistent_workers=True,
221
+ shuffle=False,
222
+ dataset=dict(
223
+ type="PointmapRenderPeopleDataset",
224
+ # num_samples=100, ## only use N samples for validation
225
+ test_mode=True,
226
+ data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_v2_test",
227
+ pipeline=val_pipeline,
228
+ ),
229
+ )
230
+
231
+ val_cfg = dict(
232
+ val_interval=val_every_iters,
233
+ evaluator=dict(
234
+ type="PointmapEvaluator",
235
+ ),
236
+ )
237
+
238
+ data_preprocessor = dict(
239
+ type="ImagePreprocessor",
240
+ mean=[123.675, 116.28, 103.53],
241
+ std=[58.395, 57.12, 57.375],
242
+ bgr_to_rgb=True, ## convert from bgr to rgb for pretrained models
243
+ )
244
+
245
+ ##-----------------------------------------------------------------
246
+ model = dict(
247
+ type="PointmapEstimator",
248
+ canonical_focal_length=canonical_focal_length,
249
+ backbone=dict(
250
+ type="Sapiens2",
251
+ arch=model_name,
252
+ img_size=image_size,
253
+ patch_size=patch_size,
254
+ final_norm=True,
255
+ use_tokenizer=False,
256
+ with_cls_token=True,
257
+ out_type="featmap",
258
+ init_cfg=dict(type="Pretrained", checkpoint=pretrained_checkpoint),
259
+ ),
260
+ decode_head=dict(
261
+ type="PointmapHead",
262
+ in_channels=embed_dim,
263
+ # upsample_channels=[1536, 768, 512, 256],
264
+ # conv_out_channels=[64, 32, 16],
265
+ # conv_kernel_sizes=[3, 3, 3],
266
+ upsample_channels=[1536, 768, 768, 768], ## 1K resolution
267
+ conv_out_channels=[128, 64, 32],
268
+ conv_kernel_sizes=[3, 3, 3],
269
+ scale_conv_out_channels=(1536, 512, 128),
270
+ scale_conv_kernel_sizes=(1, 1, 1),
271
+ scale_final_layer=(
272
+ (num_tokens // ((2 * 2 * 2) * (2 * 2 * 2))) * 128,
273
+ 512,
274
+ 128,
275
+ 1,
276
+ ), ## scale regress
277
+ loss_decode=[
278
+ dict(type="L1Loss", loss_weight=2.0), ## on pointmap, XYZ
279
+ dict(
280
+ type="MultiscaleL1Loss",
281
+ loss_weight=1.0,
282
+ scale_factor=2,
283
+ ),
284
+ dict(type="SiLogLoss", loss_weight=1.0), ## only applies silog loss
285
+ dict(
286
+ type="PointmapIntrinsicsConsistencyLoss",
287
+ loss_weight=1.0,
288
+ ),
289
+ dict(
290
+ type="PointmapShiftInvariantL1Loss",
291
+ loss_weight=1.0,
292
+ ),
293
+ dict(type="PointmapNormalLoss", loss_weight=2.0),
294
+ dict(
295
+ type="PointmapScaleL1Loss", loss_weight=4.0
296
+ ), ## Canonical XYZ = scale * XYZ
297
+ ],
298
+ ),
299
+ )
300
+
301
+
302
+ ##-----------------------------------------------------------------
303
+ optimizer = dict(
304
+ type="AdamW",
305
+ # lr=5e-4,
306
+ lr=1e-4,
307
+ betas=(0.9, 0.999),
308
+ weight_decay=0.1,
309
+ paramwise_cfg=dict(
310
+ num_layers=num_layers,
311
+ layer_decay_rate=layer_decay_rate,
312
+ ),
313
+ fused=True,
314
+ )
315
+
316
+ scheduler = dict(
317
+ type="SequentialLR",
318
+ milestones=[warmup_iters],
319
+ schedulers=[
320
+ dict(type="LinearLR", start_factor=1e-3, total_iters=warmup_iters),
321
+ dict(
322
+ type="PolynomialLR",
323
+ total_iters=num_iters - warmup_iters,
324
+ power=1.0,
325
+ ),
326
+ ],
327
+ )
328
+
329
+ clip_grad = dict(mode="norm", max_norm=4.0, norm_type=2.0)
sapiens/dense/configs/seg/shutterstock_goliath/sapiens2_0.4b_seg_shutterstock_goliath-1024x768.py ADDED
@@ -0,0 +1,364 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+
9
+ _CHECKPOINT_ROOT = os.path.expanduser(
10
+ os.environ.get("SAPIENS_CHECKPOINT_ROOT", "~/sapiens2_host")
11
+ )
12
+ _DATA_ROOT = os.path.expanduser(os.environ.get("DATA_ROOT", "~/sapiens_data"))
13
+
14
+ warmup_iters = 500
15
+ num_iters = 2e4
16
+
17
+ # ------------------------------------------------------------------------------
18
+ vis_every_iters = 100
19
+ log_every_iters = 10
20
+ save_every_iters = 1000
21
+ val_every_iters = 1000
22
+
23
+ # # debug
24
+ # vis_every_iters = 1
25
+ # log_every_iters = 1
26
+ # val_every_iters = 2
27
+ # save_every_iters = 1000
28
+
29
+ load_from = None
30
+ resume = False
31
+
32
+ # ------------------------------------------------------------------
33
+ model_name = "sapiens2_0.4b"
34
+ embed_dim = 1024
35
+ num_layers = 24
36
+ num_heads = 16
37
+
38
+ layer_decay_rate = 0.8
39
+ pretrained_checkpoint = f"{_CHECKPOINT_ROOT}/pretrain/sapiens2_0.4b_pretrain.safetensors"
40
+
41
+ num_classes = 29 ## 29 classes
42
+ CLASS_WEIGHT = [
43
+ 0.1,
44
+ 10,
45
+ 10,
46
+ 3,
47
+ 2,
48
+ 4,
49
+ 4,
50
+ 2,
51
+ 2,
52
+ 6,
53
+ 10,
54
+ 3,
55
+ 3,
56
+ 1,
57
+ 4,
58
+ 4,
59
+ 2,
60
+ 2,
61
+ 6,
62
+ 10,
63
+ 3,
64
+ 3,
65
+ 1,
66
+ 1,
67
+ 10,
68
+ 10,
69
+ 10,
70
+ 10,
71
+ 10,
72
+ ] ## 29 classes
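+ ## background (index 0) is down-weighted to 0.1; rare part classes are boosted up to 10x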
73
+
74
+ ##-----------------------------------------------------------------
75
+ image_size = (1024, 768) ## height x width
76
+ patch_size = 16
77
+
78
+ # ------------------------------------------------------------------
79
+ # use_fsdp = True
80
+ use_fsdp = False
81
+
82
+ use_compile = True
83
+ # use_compile = False
84
+
85
+ ## DDP config
86
+ if use_fsdp is False:
87
+ accelerator_cfg = dict(
88
+ type="DDP",
89
+ log_with="tensorboard",
90
+ # find_unused_parameters=True,
91
+ gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off.
92
+ max_interval=num_iters,
93
+ mixed_precision="bf16", # Options: ‘no’,‘fp16’,‘bf16’ or ‘fp8’.
94
+ step_scheduler_with_optimizer=False, ## schedule independent of n_gpus
95
+ )
96
+
97
+ else:
98
+ accelerator_cfg = dict(
99
+ type="FSDP",
100
+ log_with="tensorboard",
101
+ gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off.
102
+ max_interval=num_iters,
103
+ mixed_precision="bf16", # Options: ‘no’,‘fp16’,‘bf16’ or ‘fp8’.
104
+ step_scheduler_with_optimizer=False,
105
+ fsdp_cfg=dict(
106
+ fsdp_version=2, # DTensor-based engine
107
+ state_dict_type="SHARDED_STATE_DICT", # SHARDED_STATE_DICT | FULL_STATE_DICT
108
+ # state_dict_type="FULL_STATE_DICT", # TODO: resuming from this is not working yet
109
+ mixed_precision=dict(
110
+ param_dtype="bf16",
111
+ reduce_dtype="bf16",
112
+ ),
113
+ cpu_ram_efficient_loading=False,
114
+ ),
115
+ )
116
+
117
+ if use_compile:
118
+ accelerator_cfg["compile_cfg"] = dict(
119
+ backend="inductor",
120
+ mode="default", # Options: "default", "reduce-overhead", "max-autotune"
121
+ fullgraph=False,
122
+ dynamic=False,
123
+ )
124
+
125
+ # ------------------------------------------------------------------
126
+ randomness = dict(seed=0, deterministic=False, diff_rank_seed=True)
127
+ logger = dict(
128
+ type="Logger",
129
+ log_interval=log_every_iters,
130
+ )
131
+ checkpoint = dict(
132
+ type="Checkpointer",
133
+ save_interval=save_every_iters,
134
+ )
135
+
136
+ visualizer = dict(
137
+ type="SegVisualizer",
138
+ vis_interval=vis_every_iters,
139
+ vis_max_samples=4,
140
+ vis_image_width=384,
141
+ vis_image_height=512,
142
+ class_palette_type="dome29",
143
+ )
144
+
145
+
146
+ ##-----------------------------------------------------------------
147
+ train_pipeline = [
148
+ dict(
149
+ type="SegRandomBackground",
150
+ prob=0.8,
151
+ skip_key="is_itw",
152
+ background_images_root=f"{_DATA_ROOT}/BG-20k/train",
153
+ ),
154
+ dict(
155
+ type="SegRandomResize",
156
+ base_height=1024,
157
+ base_width=768,
158
+ ratio_range=(0.4, 2.0),
159
+ keep_ratio=True,
160
+ ),
161
+ dict(
162
+ type="SegRandomCrop",
163
+ crop_height=1024,
164
+ crop_width=768,
165
+ prob=0.3,
166
+ cat_max_ratio=0.75,
167
+ ),
168
+ dict(
169
+ type="RandomGaussianBlur", prob=0.3, kernel_size=(3, 3), sigma_range=(0.1, 2.0)
170
+ ),
171
+ dict(type="RandomGaussianNoise", prob=0.3, var_range=(5.0, 20.0)),
172
+ dict(
173
+ type="SegRandomRotate", prob=0.5, degree=60, seg_pad_val=0
174
+ ), ## pixels left black by the rotation are labeled as background
175
+ dict(
176
+ type="SegRandomHorizontalFlip",
177
+ prob=0.5,
178
+ swap_seg_labels=[
179
+ (5, 14),
180
+ (6, 15),
181
+ (7, 16),
182
+ (8, 17),
183
+ (9, 18),
184
+ (10, 19),
185
+ (11, 20),
186
+ (12, 21),
187
+ ],
188
+ ), ## swap left/right part labels of the 29-class taxonomy on horizontal flip
189
+ dict(type="PhotoMetricDistortion"),
190
+ dict(type="SegResize", height=1024, width=768, keep_ratio=False),
191
+ dict(type="SegPackInputs"),
192
+ ]
193
+
194
+ val_pipeline = [
195
+ dict(type="SegResize", height=1024, width=768, keep_ratio=False, test_mode=True),
196
+ dict(type="SegPackInputs", test_mode=True),
197
+ ]
198
+
199
+ test_pipeline = [
200
+ dict(type="SegResize", height=1024, width=768, keep_ratio=False, test_mode=True),
201
+ dict(type="SegPackInputs", test_mode=True),
202
+ ]
203
+
204
+ ##------------------------------------------------------------------------
205
+ dataset_dome_train = dict(
206
+ type="SegDomeClass29Dataset",
207
+ ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/sociopticon_body_segmentation_33_train:2024092600.json",
208
+ )
209
+
210
+ dataset_shutterstock_train = dict(
211
+ type="SegShutterstockClass29Dataset",
212
+ ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/itw_shutterstock_body_segmentation_51_train:2024121600.json",
213
+ )
214
+
215
+ dataset_ca3_wide_train = dict(
216
+ type="SegDomeClass29Dataset",
217
+ ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/ca3_wide_angle_body_segmentation_33_train:2024091700.json",
218
+ )
219
+
220
+ dataset_caa_train = dict(
221
+ type="SegDomeClass29Dataset",
222
+ ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/cca_segmentation_33_train:2024092400.json",
223
+ )
224
+
225
+ dataset_ca3_zoom_train = dict(
226
+ type="SegShutterstockClass29Dataset",
227
+ ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/ca3_zoom_in_body_segmentation_50_train:2024091700.json",
228
+ )
229
+
230
+ dataset_lighticon_train = dict(
231
+ type="SegShutterstockClass29Dataset",
232
+ ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/lighticon_lightful_body_segmentation_51_train:2025021900.json",
233
+ )
234
+
235
+ dataset_internal_train = dict(
236
+ type="SegInternalClass29Dataset",
237
+ # ann_file=f"{_DATA_ROOT}/annotations/stylized_sapiens/20250807/Internal_segmentation_32:2025080700.json",
238
+ ann_file=f"{_DATA_ROOT}/annotations/internal_dataset/20251103/internal_keypoint_344_segmentation_32_train:2025091500.json",
239
+ )
240
+
241
+ train_datasets = [
242
+ dataset_dome_train,
243
+ dataset_ca3_wide_train,
244
+ dataset_caa_train,
245
+ dataset_ca3_zoom_train,
246
+ dataset_lighticon_train,
247
+ dataset_internal_train,
248
+ ] + 2 * [dataset_shutterstock_train]
249
+
250
+ train_dataloader = dict(
251
+ batch_size=1,
252
+ num_workers=8,
253
+ persistent_workers=True,
254
+ shuffle=True,
255
+ dataset=dict(
256
+ type="CombinedDataset", datasets=train_datasets, pipeline=train_pipeline
257
+ ),
258
+ )
259
+
260
+ val_dataloader = dict(
261
+ batch_size=4,
262
+ num_workers=4,
263
+ persistent_workers=True,
264
+ multiprocessing_context="spawn", ## avoids fork error with airstore
265
+ # num_workers=0, # debug
266
+ # persistent_workers=False, # debug
267
+ shuffle=False,
268
+ dataset=dict(
269
+ type="SegShutterstockClass29Dataset",
270
+ ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/itw_shutterstock_body_segmentation_51_test:2024121600.json",
271
+ test_mode=True,
272
+ pipeline=val_pipeline,
273
+ ),
274
+ collate_fn=dict(type="eval_collate"),
275
+ )
276
+
277
+ val_cfg = dict(
278
+ val_interval=val_every_iters,
279
+ evaluator=dict(type="SegEvaluator", class_names="dome29", nan_to_num=0.0),
280
+ )
281
+
282
+ data_preprocessor = dict(
283
+ type="ImagePreprocessor",
284
+ mean=[123.675, 116.28, 103.53],
285
+ std=[58.395, 57.12, 57.375],
286
+ bgr_to_rgb=True, ## convert from bgr to rgb for pretrained models
287
+ )
288
+
289
+ ##-----------------------------------------------------------------
290
+ model = dict(
291
+ type="SegEstimator",
292
+ backbone=dict(
293
+ type="Sapiens2",
294
+ arch=model_name,
295
+ img_size=image_size,
296
+ patch_size=patch_size,
297
+ final_norm=True,
298
+ use_tokenizer=False,
299
+ with_cls_token=True,
300
+ out_type="featmap",
301
+ init_cfg=dict(type="Pretrained", checkpoint=pretrained_checkpoint),
302
+ ),
303
+ decode_head=dict(
304
+ type="SegHead",
305
+ in_channels=embed_dim,
306
+ deconv_out_channels=(
307
+ 512,
308
+ 256,
309
+ 128,
310
+ 64,
311
+ ), ## each deconv upsamples 2x; four stages give 16x, i.e. the full 1024x768 output.
312
+ deconv_kernel_sizes=(4, 4, 4, 4),
313
+ conv_out_channels=(64, 64),
314
+ conv_kernel_sizes=(1, 1),
315
+ num_classes=num_classes,
316
+ loss_decode=[
317
+ dict(
318
+ type="CrossEntropyLoss",
319
+ loss_weight=1.0,
320
+ reduction="none",
321
+ class_weight=CLASS_WEIGHT,
322
+ ignore_index=255,
323
+ ),
324
+ dict(
325
+ type="DiceLoss",
326
+ loss_weight=1.0,
327
+ reduction="none",
328
+ activate=True,
329
+ use_sigmoid=False,
330
+ include_background=False,
331
+ ignore_index=255,
332
+ ),
333
+ ],
334
+ ),
335
+ )
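
For intuition, a minimal sketch (plain PyTorch, not this repo's CrossEntropyLoss wrapper) of how per-class weights like CLASS_WEIGHT above act in a per-pixel cross-entropy:

import torch
import torch.nn.functional as F

w = torch.ones(29)
w[0] = 0.1                                # background down-weighted, as in CLASS_WEIGHT
logits = torch.randn(2, 29, 4, 4)         # (N, num_classes, H, W)
target = torch.randint(0, 29, (2, 4, 4))  # class index per pixel
loss = F.cross_entropy(logits, target, weight=w, ignore_index=255)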
336
+
337
+
338
+ ##-----------------------------------------------------------------
339
+ optimizer = dict(
340
+ type="AdamW",
341
+ lr=5e-4,
342
+ betas=(0.9, 0.999),
343
+ weight_decay=0.1,
344
+ paramwise_cfg=dict(
345
+ num_layers=num_layers,
346
+ layer_decay_rate=layer_decay_rate,
347
+ ),
348
+ fused=True, ## use fused AdamW
349
+ )
350
+
351
+ scheduler = dict(
352
+ type="SequentialLR",
353
+ milestones=[warmup_iters],
354
+ schedulers=[
355
+ dict(type="LinearLR", start_factor=1e-3, total_iters=warmup_iters),
356
+ dict(
357
+ type="PolynomialLR",
358
+ total_iters=num_iters - warmup_iters,
359
+ power=1.0,
360
+ ),
361
+ ],
362
+ )
363
+
364
+ clip_grad = dict(mode="norm", max_norm=2.0, norm_type=2.0)
sapiens/dense/configs/seg/shutterstock_goliath/sapiens2_0.8b_seg_shutterstock_goliath-1024x768.py ADDED
@@ -0,0 +1,368 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+
9
+ _CHECKPOINT_ROOT = os.path.expanduser(
10
+ os.environ.get("SAPIENS_CHECKPOINT_ROOT", "~/sapiens2_host")
11
+ )
12
+ _DATA_ROOT = os.path.expanduser(os.environ.get("DATA_ROOT", "~/sapiens_data"))
13
+
14
+ warmup_iters = 500
15
+ num_iters = 3e4 ## bs: 5; 16 gpus
16
+
17
+ # ------------------------------------------------------------------------------
18
+ vis_every_iters = 100
19
+ log_every_iters = 10
20
+ save_every_iters = 1000
21
+ val_every_iters = 1000
22
+
23
+ # # # # debug
24
+ # vis_every_iters = 1
25
+ # log_every_iters = 1
26
+ # val_every_iters = 2
27
+ # save_every_iters = 1000
28
+
29
+ load_from = None
30
+ resume = False
31
+
32
+ # ------------------------------------------------------------------
33
+ model_name = "sapiens2_0.8b"
34
+ embed_dim = 1280
35
+ num_layers = 32
36
+ num_heads = 16
37
+
38
+ layer_decay_rate = 0.85
39
+ pretrained_checkpoint = f"{_CHECKPOINT_ROOT}/pretrain/sapiens2_0.8b_pretrain.safetensors"
40
+
41
+ num_classes = 29 ## 29 classes
42
+ CLASS_WEIGHT = [
43
+ 0.1,
44
+ 10,
45
+ 10,
46
+ 3,
47
+ 2,
48
+ 4,
49
+ 4,
50
+ 2,
51
+ 2,
52
+ 6,
53
+ 10,
54
+ 3,
55
+ 3,
56
+ 1,
57
+ 4,
58
+ 4,
59
+ 2,
60
+ 2,
61
+ 6,
62
+ 10,
63
+ 3,
64
+ 3,
65
+ 1,
66
+ 1,
67
+ 10,
68
+ 10,
69
+ 10,
70
+ 10,
71
+ 10,
72
+ ] ## 29 classes
73
+
74
+ ##-----------------------------------------------------------------
75
+ image_size = (1024, 768) ## height x width
76
+ patch_size = 16
77
+
78
+ # ------------------------------------------------------------------
79
+ # use_fsdp = True
80
+ use_fsdp = False
81
+
82
+ use_compile = True
83
+ # use_compile = False
84
+
85
+ ## DDP config
86
+ if use_fsdp is False:
87
+ accelerator_cfg = dict(
88
+ type="DDP",
89
+ log_with="tensorboard",
90
+ # find_unused_parameters=True,
91
+ gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off.
92
+ max_interval=num_iters,
93
+ # mixed_precision="bf16", # Options: ‘no’,‘fp16’,‘bf16’ or ‘fp8’.
94
+ step_scheduler_with_optimizer=False, ## schedule independent of n_gpus
95
+ )
96
+
97
+ else:
98
+ accelerator_cfg = dict(
99
+ type="FSDP",
100
+ log_with="tensorboard",
101
+ gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off.
102
+ max_interval=num_iters,
103
+ mixed_precision="bf16", # Options: ‘no’,‘fp16’,‘bf16’ or ‘fp8’.
104
+ step_scheduler_with_optimizer=False,
105
+ fsdp_cfg=dict(
106
+ fsdp_version=2, # DTensor-based engine
107
+ state_dict_type="SHARDED_STATE_DICT", # SHARDED_STATE_DICT | FULL_STATE_DICT
108
+ # state_dict_type="FULL_STATE_DICT", # TODO: resume from this is not working
109
+ mixed_precision=dict(
110
+ param_dtype="bf16",
111
+ reduce_dtype="bf16",
112
+ ),
113
+ cpu_ram_efficient_loading=False,
114
+ ),
115
+ )
116
+
117
+ ## Note: to merge sharded weight using FSDP
118
+ # accelerate merge-weights pytorch_model_fsdp_0/ .
119
+
120
+ if use_compile:
121
+ accelerator_cfg["compile_cfg"] = dict(
122
+ backend="inductor",
123
+ mode="default", # Options: "default", "reduce-overhead", "max-autotune"
124
+ fullgraph=False,
125
+ dynamic=False,
126
+ )
127
+
128
+ # ------------------------------------------------------------------
129
+ randomness = dict(seed=0, deterministic=False, diff_rank_seed=True)
130
+ logger = dict(
131
+ type="Logger",
132
+ log_interval=log_every_iters,
133
+ )
134
+ checkpoint = dict(
135
+ type="Checkpointer",
136
+ save_interval=save_every_iters,
137
+ )
138
+
139
+ visualizer = dict(
140
+ type="SegVisualizer",
141
+ vis_interval=vis_every_iters,
142
+ vis_max_samples=4,
143
+ vis_image_width=384,
144
+ vis_image_height=512,
145
+ class_palette_type="dome29",
146
+ )
147
+
148
+
149
+ ##-----------------------------------------------------------------
150
+ train_pipeline = [
151
+ dict(
152
+ type="SegRandomBackground",
153
+ prob=0.8,
154
+ skip_key="is_itw",
155
+ background_images_root=f"{_DATA_ROOT}/BG-20k/train",
156
+ ),
157
+ dict(
158
+ type="SegRandomResize",
159
+ base_height=1024,
160
+ base_width=768,
161
+ ratio_range=(0.4, 2.0),
162
+ keep_ratio=True,
163
+ ),
164
+ dict(
165
+ type="SegRandomCrop",
166
+ crop_height=1024,
167
+ crop_width=768,
168
+ prob=0.3,
169
+ cat_max_ratio=0.75,
170
+ ),
171
+ dict(
172
+ type="RandomGaussianBlur", prob=0.3, kernel_size=(3, 3), sigma_range=(0.1, 2.0)
173
+ ),
174
+ dict(type="RandomGaussianNoise", prob=0.3, var_range=(5.0, 20.0)),
175
+ dict(
176
+ type="SegRandomRotate", prob=0.5, degree=60, seg_pad_val=0
177
+ ), ## the black pixels are set as background
178
+ dict(
179
+ type="SegRandomHorizontalFlip",
180
+ prob=0.5,
181
+ swap_seg_labels=[
182
+ (5, 14),
183
+ (6, 15),
184
+ (7, 16),
185
+ (8, 17),
186
+ (9, 18),
187
+ (10, 19),
188
+ (11, 20),
189
+ (12, 21),
190
+ ],
191
+ ), ## for the 29 classes,
192
+ dict(type="PhotoMetricDistortion"),
193
+ dict(type="SegResize", height=1024, width=768, keep_ratio=False),
194
+ dict(type="SegPackInputs"),
195
+ ]
196
+
197
+ val_pipeline = [
198
+ dict(type="SegResize", height=1024, width=768, keep_ratio=False, test_mode=True),
199
+ dict(type="SegPackInputs", test_mode=True),
200
+ ]
201
+
202
+ test_pipeline = [
203
+ dict(type="SegResize", height=1024, width=768, keep_ratio=False, test_mode=True),
204
+ dict(type="SegPackInputs", test_mode=True),
205
+ ]
206
+
207
+ ##------------------------------------------------------------------------
208
+ dataset_dome_train = dict(
209
+ type="SegDomeClass29Dataset",
210
+ ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/sociopticon_body_segmentation_33_train:2024092600.json",
211
+ )
212
+
213
+ dataset_shutterstock_train = dict(
214
+ type="SegShutterstockClass29Dataset",
215
+ ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/itw_shutterstock_body_segmentation_51_train:2024121600.json",
216
+ )
217
+
218
+ dataset_ca3_wide_train = dict(
219
+ type="SegDomeClass29Dataset",
220
+ ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/ca3_wide_angle_body_segmentation_33_train:2024091700.json",
221
+ )
222
+
223
+ dataset_caa_train = dict(
224
+ type="SegDomeClass29Dataset",
225
+ ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/cca_segmentation_33_train:2024092400.json",
226
+ )
227
+
228
+ dataset_ca3_zoom_train = dict(
229
+ type="SegShutterstockClass29Dataset",
230
+ ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/ca3_zoom_in_body_segmentation_50_train:2024091700.json",
231
+ )
232
+
233
+ dataset_lighticon_train = dict(
234
+ type="SegShutterstockClass29Dataset",
235
+ ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/lighticon_lightful_body_segmentation_51_train:2025021900.json",
236
+ )
237
+
238
+ dataset_internal_train = dict(
239
+ type="SegInternalClass29Dataset",
240
+ # ann_file=f"{_DATA_ROOT}/annotations/stylized_sapiens/20250807/Internal_segmentation_32:2025080700.json",
241
+ ann_file=f"{_DATA_ROOT}/annotations/internal_dataset/20251103/internal_keypoint_344_segmentation_32_train:2025091500.json",
242
+ )
243
+
244
+ train_datasets = [
245
+ dataset_dome_train,
246
+ dataset_ca3_wide_train,
247
+ dataset_caa_train,
248
+ dataset_ca3_zoom_train,
249
+ dataset_lighticon_train,
250
+ dataset_internal_train,
251
+ ] + 2 * [dataset_shutterstock_train]
252
+
253
+ train_dataloader = dict(
254
+ batch_size=1,
255
+ num_workers=4,
256
+ persistent_workers=True,
257
+ shuffle=True,
258
+ dataset=dict(
259
+ type="CombinedDataset", datasets=train_datasets, pipeline=train_pipeline
260
+ ),
261
+ )
262
+
263
+ val_dataloader = dict(
264
+ batch_size=4,
265
+ num_workers=4,
266
+ persistent_workers=True,
267
+ multiprocessing_context="spawn", ## avoids fork error with airstore
268
+ # num_workers=0, # debug
269
+ # persistent_workers=False, # debug
270
+ shuffle=False,
271
+ dataset=dict(
272
+ type="SegShutterstockClass29Dataset",
273
+ # num_samples=40, ## only use N samples for validation
274
+ ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/itw_shutterstock_body_segmentation_51_test:2024121600.json",
275
+ test_mode=True,
276
+ pipeline=val_pipeline,
277
+ ),
278
+ collate_fn=dict(type="eval_collate"),
279
+ )
280
+
281
+ val_cfg = dict(
282
+ val_interval=val_every_iters,
283
+ evaluator=dict(type="SegEvaluator", class_names="dome29", nan_to_num=0.0),
284
+ )
285
+
286
+ data_preprocessor = dict(
287
+ type="ImagePreprocessor",
288
+ mean=[123.675, 116.28, 103.53],
289
+ std=[58.395, 57.12, 57.375],
290
+ bgr_to_rgb=True, ## convert from bgr to rgb for pretrained models
291
+ )
292
+
293
+ ##-----------------------------------------------------------------
294
+ model = dict(
295
+ type="SegEstimator",
296
+ backbone=dict(
297
+ type="Sapiens2",
298
+ arch=model_name,
299
+ img_size=image_size,
300
+ patch_size=patch_size,
301
+ final_norm=True,
302
+ use_tokenizer=False,
303
+ with_cls_token=True,
304
+ out_type="featmap",
305
+ init_cfg=dict(type="Pretrained", checkpoint=pretrained_checkpoint),
306
+ ),
307
+ decode_head=dict(
308
+ type="SegHead",
309
+ in_channels=embed_dim,
310
+ deconv_out_channels=(
311
+ 512,
312
+ 256,
313
+ 128,
314
+ 64,
315
+ ), ## this will 2x at each step. so total is 16x. 1K output.
316
+ deconv_kernel_sizes=(4, 4, 4, 4),
317
+ conv_out_channels=(64, 64),
318
+ conv_kernel_sizes=(1, 1),
319
+ num_classes=num_classes,
320
+ loss_decode=[
321
+ dict(
322
+ type="CrossEntropyLoss",
323
+ loss_weight=1.0,
324
+ reduction="none",
325
+ class_weight=CLASS_WEIGHT,
326
+ ignore_index=255,
327
+ ),
328
+ dict(
329
+ type="DiceLoss",
330
+ loss_weight=1.0,
331
+ reduction="none",
332
+ activate=True,
333
+ use_sigmoid=False,
334
+ include_background=False,
335
+ ignore_index=255,
336
+ ),
337
+ ],
338
+ ),
339
+ )
340
+
341
+
342
+ ##-----------------------------------------------------------------
343
+ optimizer = dict(
344
+ type="AdamW",
345
+ lr=5e-4,
346
+ betas=(0.9, 0.999),
347
+ weight_decay=0.1,
348
+ paramwise_cfg=dict(
349
+ num_layers=num_layers,
350
+ layer_decay_rate=layer_decay_rate,
351
+ ),
352
+ fused=True,
353
+ )
354
+
355
+ scheduler = dict(
356
+ type="SequentialLR",
357
+ milestones=[warmup_iters],
358
+ schedulers=[
359
+ dict(type="LinearLR", start_factor=1e-3, total_iters=warmup_iters),
360
+ dict(
361
+ type="PolynomialLR",
362
+ total_iters=num_iters - warmup_iters,
363
+ power=1.0,
364
+ ),
365
+ ],
366
+ )
367
+
368
+ clip_grad = dict(mode="norm", max_norm=4.0, norm_type=2.0)
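paramwise_cfg pairs num_layers=32 with layer_decay_rate=0.85, i.e. layer-wise learning-rate decay: parameters closer to the input get exponentially smaller learning rates than the head. A hedged sketch of the usual BEiT-style convention (the exact parameter grouping used by this trainer may differ):

def lr_scale(layer_id: int, num_layers: int = 32, decay: float = 0.85) -> float:
    # layer_id 0 = patch/pos embeddings, 1..num_layers = transformer blocks,
    # num_layers + 1 = decode head (full LR).
    return decay ** (num_layers + 1 - layer_id)

base_lr = 5e-4
print(base_lr * lr_scale(0))   # embeddings: ~2.3e-6, heavily damped
print(base_lr * lr_scale(32))  # last block: 4.25e-4
print(base_lr * lr_scale(33))  # head: 5e-4, undamped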
sapiens/dense/configs/seg/shutterstock_goliath/sapiens2_1b_seg_shutterstock_goliath-1024x768.py ADDED
@@ -0,0 +1,366 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import os
+
+ _CHECKPOINT_ROOT = os.path.expanduser(
+     os.environ.get("SAPIENS_CHECKPOINT_ROOT", "~/sapiens2_host")
+ )
+ _DATA_ROOT = os.path.expanduser(os.environ.get("DATA_ROOT", "~/sapiens_data"))
+
+ warmup_iters = 500
+ num_iters = 2e4
+ # num_iters = 4e4
+
+ # ------------------------------------------------------------------------------
+ vis_every_iters = 100
+ log_every_iters = 10
+ save_every_iters = 1000
+ val_every_iters = 1000
+
+ # # # # debug
+ # vis_every_iters = 1
+ # log_every_iters = 1
+ # val_every_iters = 2
+ # save_every_iters = 1000
+
+ load_from = None
+ resume = False
+
+ # ------------------------------------------------------------------
+ model_name = "sapiens2_1b"
+ embed_dim = 1536
+ num_layers = 40
+ num_heads = 24
+
+ layer_decay_rate = 0.9
+ pretrained_checkpoint = f"{_CHECKPOINT_ROOT}/pretrain/sapiens2_1b_pretrain.safetensors"
+
+ num_classes = 29  ## 29 classes
+ CLASS_WEIGHT = [
+     0.1,
+     10,
+     10,
+     3,
+     2,
+     4,
+     4,
+     2,
+     2,
+     6,
+     10,
+     3,
+     3,
+     1,
+     4,
+     4,
+     2,
+     2,
+     6,
+     10,
+     3,
+     3,
+     1,
+     1,
+     10,
+     10,
+     10,
+     10,
+     10,
+ ]  ## 29 classes
+
+ ##-----------------------------------------------------------------
+ image_size = (1024, 768)  ## height x width
+ patch_size = 16
+
+ # ------------------------------------------------------------------
+ use_fsdp = True
+ # use_fsdp = False
+
+ use_compile = True
+ # use_compile = False
+
+ ## DDP config
+ if use_fsdp is False:
+     accelerator_cfg = dict(
+         type="DDP",
+         log_with="tensorboard",
+         # find_unused_parameters=True,
+         gradient_accumulation_steps=1,  # only accumulation=1 is supported; otherwise the LR scheduler will be off.
+         max_interval=num_iters,
+         # mixed_precision="bf16",  # Options: 'no', 'fp16', 'bf16' or 'fp8'.
+         step_scheduler_with_optimizer=False,  ## schedule independent of n_gpus
+     )
+
+ else:
+     accelerator_cfg = dict(
+         type="FSDP",
+         log_with="tensorboard",
+         gradient_accumulation_steps=1,  # only accumulation=1 is supported; otherwise the LR scheduler will be off.
+         max_interval=num_iters,
+         mixed_precision="bf16",  # Options: 'no', 'fp16', 'bf16' or 'fp8'.
+         step_scheduler_with_optimizer=False,
+         fsdp_cfg=dict(
+             fsdp_version=2,  # DTensor-based engine
+             state_dict_type="SHARDED_STATE_DICT",  # SHARDED_STATE_DICT | FULL_STATE_DICT
+             # state_dict_type="FULL_STATE_DICT",  # TODO: resuming from this is not working
+             mixed_precision=dict(
+                 param_dtype="bf16",
+                 reduce_dtype="bf16",
+             ),
+             cpu_ram_efficient_loading=False,
+         ),
+     )
+
+ if use_compile:
+     accelerator_cfg["compile_cfg"] = dict(
+         backend="inductor",
+         mode="default",  # Options: "default", "reduce-overhead", "max-autotune"
+         fullgraph=False,
+         dynamic=False,
+     )
+
+ # ------------------------------------------------------------------
+ randomness = dict(seed=0, deterministic=False, diff_rank_seed=True)
+ logger = dict(
+     type="Logger",
+     log_interval=log_every_iters,
+ )
+ checkpoint = dict(
+     type="Checkpointer",
+     save_interval=save_every_iters,
+ )
+
+ visualizer = dict(
+     type="SegVisualizer",
+     vis_interval=vis_every_iters,
+     vis_max_samples=4,
+     vis_image_width=384,
+     vis_image_height=512,
+     class_palette_type="dome29",
+ )
+
+
+ ##-----------------------------------------------------------------
+ train_pipeline = [
+     dict(
+         type="SegRandomBackground",
+         prob=0.8,
+         skip_key="is_itw",
+         background_images_root=f"{_DATA_ROOT}/BG-20k/train",
+     ),
+     dict(
+         type="SegRandomResize",
+         base_height=1024,
+         base_width=768,
+         ratio_range=(0.4, 2.0),
+         keep_ratio=True,
+     ),
+     dict(
+         type="SegRandomCrop",
+         crop_height=1024,
+         crop_width=768,
+         prob=0.3,
+         cat_max_ratio=0.75,
+     ),
+     dict(
+         type="RandomGaussianBlur", prob=0.3, kernel_size=(3, 3), sigma_range=(0.1, 2.0)
+     ),
+     dict(type="RandomGaussianNoise", prob=0.3, var_range=(5.0, 20.0)),
+     dict(
+         type="SegRandomRotate", prob=0.5, degree=60, seg_pad_val=0
+     ),  ## the black pixels are set as background
+     dict(
+         type="SegRandomHorizontalFlip",
+         prob=0.5,
+         swap_seg_labels=[
+             (5, 14),
+             (6, 15),
+             (7, 16),
+             (8, 17),
+             (9, 18),
+             (10, 19),
+             (11, 20),
+             (12, 21),
+         ],
+     ),  ## left-right class pairs for the 29-class layout
+     dict(type="PhotoMetricDistortion"),
+     dict(type="SegResize", height=1024, width=768, keep_ratio=False),
+     dict(type="SegPackInputs"),
+ ]
+
+ val_pipeline = [
+     dict(type="SegResize", height=1024, width=768, keep_ratio=False, test_mode=True),
+     dict(type="SegPackInputs", test_mode=True),
+ ]
+
+ test_pipeline = [
+     dict(type="SegResize", height=1024, width=768, keep_ratio=False, test_mode=True),
+     dict(type="SegPackInputs", test_mode=True),
+ ]
+
+ ##------------------------------------------------------------------------
+ dataset_dome_train = dict(
+     type="SegDomeClass29Dataset",
+     ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/sociopticon_body_segmentation_33_train:2024092600.json",
+ )
+
+ dataset_shutterstock_train = dict(
+     type="SegShutterstockClass29Dataset",
+     ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/itw_shutterstock_body_segmentation_51_train:2024121600.json",
+ )
+
+ dataset_ca3_wide_train = dict(
+     type="SegDomeClass29Dataset",
+     ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/ca3_wide_angle_body_segmentation_33_train:2024091700.json",
+ )
+
+ dataset_caa_train = dict(
+     type="SegDomeClass29Dataset",
+     ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/cca_segmentation_33_train:2024092400.json",
+ )
+
+ dataset_ca3_zoom_train = dict(
+     type="SegShutterstockClass29Dataset",
+     ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/ca3_zoom_in_body_segmentation_50_train:2024091700.json",
+ )
+
+ dataset_lighticon_train = dict(
+     type="SegShutterstockClass29Dataset",
+     ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/lighticon_lightful_body_segmentation_51_train:2025021900.json",
+ )
+
+ dataset_internal_train = dict(
+     type="SegInternalClass29Dataset",
+     # ann_file=f"{_DATA_ROOT}/annotations/stylized_sapiens/20250807/Internal_segmentation_32:2025080700.json",
+     ann_file=f"{_DATA_ROOT}/annotations/internal_dataset/20251103/internal_keypoint_344_segmentation_32_train:2025091500.json",
+ )
+
+ train_datasets = [
+     dataset_dome_train,
+     dataset_ca3_wide_train,
+     dataset_caa_train,
+     dataset_ca3_zoom_train,
+     dataset_lighticon_train,
+     dataset_internal_train,
+ ] + 2 * [dataset_shutterstock_train]  ## oversample the in-the-wild shutterstock set 2x
+
+ train_dataloader = dict(
+     batch_size=1,
+     num_workers=4,
+     persistent_workers=True,
+     shuffle=True,
+     dataset=dict(
+         type="CombinedDataset", datasets=train_datasets, pipeline=train_pipeline
+     ),
+ )
+
+ val_dataloader = dict(
+     batch_size=4,
+     num_workers=4,
+     persistent_workers=True,
+     multiprocessing_context="spawn",  ## avoids fork error with airstore
+     # num_workers=0,  # debug
+     # persistent_workers=False,  # debug
+     shuffle=False,
+     dataset=dict(
+         type="SegShutterstockClass29Dataset",
+         # num_samples=40,  ## only use N samples for validation
+         ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/itw_shutterstock_body_segmentation_51_test:2024121600.json",
+         test_mode=True,
+         pipeline=val_pipeline,
+     ),
+     collate_fn=dict(type="eval_collate"),
+ )
+
+ val_cfg = dict(
+     val_interval=val_every_iters,
+     evaluator=dict(type="SegEvaluator", class_names="dome29", nan_to_num=0.0),
+ )
+
+ data_preprocessor = dict(
+     type="ImagePreprocessor",
+     mean=[123.675, 116.28, 103.53],
+     std=[58.395, 57.12, 57.375],
+     bgr_to_rgb=True,  ## convert from BGR to RGB for the pretrained backbone
+ )
+
+ ##-----------------------------------------------------------------
+ model = dict(
+     type="SegEstimator",
+     backbone=dict(
+         type="Sapiens2",
+         arch=model_name,
+         img_size=image_size,
+         patch_size=patch_size,
+         final_norm=True,
+         use_tokenizer=False,
+         with_cls_token=True,
+         out_type="featmap",
+         init_cfg=dict(type="Pretrained", checkpoint=pretrained_checkpoint),
+     ),
+     decode_head=dict(
+         type="SegHead",
+         in_channels=embed_dim,
+         deconv_out_channels=(
+             512,
+             256,
+             128,
+             64,
+         ),  ## each deconv upsamples 2x, so four stages give 16x (1K output)
+         deconv_kernel_sizes=(4, 4, 4, 4),
+         conv_out_channels=(64, 64),
+         conv_kernel_sizes=(1, 1),
+         num_classes=num_classes,
+         loss_decode=[
+             dict(
+                 type="CrossEntropyLoss",
+                 loss_weight=1.0,
+                 reduction="none",
+                 class_weight=CLASS_WEIGHT,
+                 ignore_index=255,
+             ),
+             dict(
+                 type="DiceLoss",
+                 loss_weight=1.0,
+                 reduction="none",
+                 activate=True,
+                 use_sigmoid=False,
+                 include_background=False,
+                 ignore_index=255,
+             ),
+         ],
+     ),
+ )
+
+
+ ##-----------------------------------------------------------------
+ optimizer = dict(
+     type="AdamW",
+     lr=5e-4,
+     betas=(0.9, 0.999),
+     weight_decay=0.1,
+     paramwise_cfg=dict(
+         num_layers=num_layers,
+         layer_decay_rate=layer_decay_rate,
+     ),
+     fused=True,
+ )
+
+ scheduler = dict(
+     type="SequentialLR",
+     milestones=[warmup_iters],
+     schedulers=[
+         dict(type="LinearLR", start_factor=1e-3, total_iters=warmup_iters),
+         dict(
+             type="PolynomialLR",
+             total_iters=num_iters - warmup_iters,
+             power=1.0,
+         ),
+     ],
+ )
+
+ clip_grad = dict(mode="norm", max_norm=4.0, norm_type=2.0)
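SegRandomHorizontalFlip must relabel left/right body parts when it mirrors the image, which is what swap_seg_labels encodes. A minimal NumPy sketch of the idea, assuming an HxW integer label map (the real transform class ships with the repo and may differ in detail):

import numpy as np

SWAPS = [(5, 14), (6, 15), (7, 16), (8, 17), (9, 18), (10, 19), (11, 20), (12, 21)]

def hflip_with_swaps(seg: np.ndarray) -> np.ndarray:
    flipped = seg[:, ::-1].copy()   # mirror the mask horizontally
    out = flipped.copy()
    for a, b in SWAPS:              # then swap each left/right class pair
        out[flipped == a] = b
        out[flipped == b] = a
    return out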
sapiens/dense/configs/seg/shutterstock_goliath/sapiens2_5b_seg_shutterstock_goliath-1024x768.py ADDED
@@ -0,0 +1,365 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import os
+
+ _CHECKPOINT_ROOT = os.path.expanduser(
+     os.environ.get("SAPIENS_CHECKPOINT_ROOT", "~/sapiens2_host")
+ )
+ _DATA_ROOT = os.path.expanduser(os.environ.get("DATA_ROOT", "~/sapiens_data"))
+
+ warmup_iters = 500
+ num_iters = 5e4  ## for H200; batch size is 4
+
+ # ------------------------------------------------------------------------------
+ vis_every_iters = 100
+ log_every_iters = 10
+ save_every_iters = 2000
+ # val_every_iters = 2000
+ val_every_iters = 10000
+
+ # # # # debug
+ # vis_every_iters = 1
+ # log_every_iters = 1
+ # val_every_iters = 10
+
+ load_from = None
+ resume = False
+
+ # ------------------------------------------------------------------
+ model_name = "sapiens2_5b"
+ embed_dim = 2432
+ num_layers = 56
+ num_heads = 32
+
+ layer_decay_rate = 0.94
+ pretrained_checkpoint = f"{_CHECKPOINT_ROOT}/pretrain/sapiens2_5b_pretrain.safetensors"
+
+ num_classes = 29  ## 29 classes
+ CLASS_WEIGHT = [
+     0.1,
+     10,
+     10,
+     3,
+     2,
+     4,
+     4,
+     2,
+     2,
+     6,
+     10,
+     3,
+     3,
+     1,
+     4,
+     4,
+     2,
+     2,
+     6,
+     10,
+     3,
+     3,
+     1,
+     1,
+     10,
+     10,
+     10,
+     10,
+     10,
+ ]  ## 29 classes
+
+ ##-----------------------------------------------------------------
+ image_size = (1024, 768)  ## height x width
+ patch_size = 16
+
+ # ------------------------------------------------------------------
+ use_fsdp = True
+ # use_fsdp = False
+
+ use_compile = True
+ # use_compile = False
+
+ ## DDP config
+ if use_fsdp is False:
+     accelerator_cfg = dict(
+         type="DDP",
+         log_with="tensorboard",
+         # find_unused_parameters=True,
+         gradient_accumulation_steps=1,  # only accumulation=1 is supported; otherwise the LR scheduler will be off.
+         max_interval=num_iters,
+         # mixed_precision="bf16",  # Options: 'no', 'fp16', 'bf16' or 'fp8'.
+         step_scheduler_with_optimizer=False,  ## schedule independent of n_gpus
+     )
+
+ else:
+     accelerator_cfg = dict(
+         type="FSDP",
+         log_with="tensorboard",
+         gradient_accumulation_steps=1,  # only accumulation=1 is supported; otherwise the LR scheduler will be off.
+         max_interval=num_iters,
+         mixed_precision="bf16",  # Options: 'no', 'fp16', 'bf16' or 'fp8'.
+         step_scheduler_with_optimizer=False,
+         fsdp_cfg=dict(
+             fsdp_version=2,  # DTensor-based engine
+             state_dict_type="SHARDED_STATE_DICT",  # SHARDED_STATE_DICT | FULL_STATE_DICT
+             # state_dict_type="FULL_STATE_DICT",  # TODO: resuming from this is not working
+             mixed_precision=dict(
+                 param_dtype="bf16",
+                 reduce_dtype="bf16",
+             ),
+             cpu_ram_efficient_loading=False,
+         ),
+     )
+
+ ## Note: to merge sharded weights saved by FSDP:
+ # accelerate merge-weights pytorch_model_fsdp_0/ .
+
+ if use_compile:
+     accelerator_cfg["compile_cfg"] = dict(
+         backend="inductor",
+         mode="default",  # Options: "default", "reduce-overhead", "max-autotune"
+         fullgraph=False,
+         dynamic=False,
+     )
+
+ # ------------------------------------------------------------------
+ randomness = dict(seed=0, deterministic=False, diff_rank_seed=True)
+ logger = dict(
+     type="Logger",
+     log_interval=log_every_iters,
+ )
+ checkpoint = dict(
+     type="Checkpointer",
+     save_interval=save_every_iters,
+ )
+
+ visualizer = dict(
+     type="SegVisualizer",
+     vis_interval=vis_every_iters,
+     vis_max_samples=4,
+     vis_image_width=384,
+     vis_image_height=512,
+     class_palette_type="dome29",
+ )
+
+
+ ##-----------------------------------------------------------------
+ train_pipeline = [
+     dict(
+         type="SegRandomBackground",
+         prob=0.8,
+         skip_key="is_itw",
+         background_images_root=f"{_DATA_ROOT}/BG-20k/train",
+     ),
+     dict(
+         type="SegRandomResize",
+         base_height=1024,
+         base_width=768,
+         ratio_range=(0.4, 2.0),
+         keep_ratio=True,
+     ),
+     dict(
+         type="SegRandomCrop",
+         crop_height=1024,
+         crop_width=768,
+         prob=0.3,
+         cat_max_ratio=0.75,
+     ),
+     dict(
+         type="RandomGaussianBlur", prob=0.3, kernel_size=(3, 3), sigma_range=(0.1, 2.0)
+     ),
+     dict(type="RandomGaussianNoise", prob=0.3, var_range=(5.0, 20.0)),
+     dict(
+         type="SegRandomRotate", prob=0.5, degree=60, seg_pad_val=0
+     ),  ## the black pixels are set as background
+     dict(
+         type="SegRandomHorizontalFlip",
+         prob=0.5,
+         swap_seg_labels=[
+             (5, 14),
+             (6, 15),
+             (7, 16),
+             (8, 17),
+             (9, 18),
+             (10, 19),
+             (11, 20),
+             (12, 21),
+         ],
+     ),  ## left-right class pairs for the 29-class layout
+     dict(type="PhotoMetricDistortion"),
+     dict(type="SegResize", height=1024, width=768, keep_ratio=False),
+     dict(type="SegPackInputs"),
+ ]
+
+ val_pipeline = [
+     dict(type="SegResize", height=1024, width=768, keep_ratio=False),
+     dict(type="SegPackInputs"),
+ ]
+
+ test_pipeline = [
+     dict(type="SegResize", height=1024, width=768, keep_ratio=False),
+     dict(type="SegPackInputs"),
+ ]
+
+ ##------------------------------------------------------------------------
+ dataset_dome_train = dict(
+     type="SegDomeClass29Dataset",
+     ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/sociopticon_body_segmentation_33_train:2024092600.json",
+ )
+
+ dataset_shutterstock_train = dict(
+     type="SegShutterstockClass29Dataset",
+     ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/itw_shutterstock_body_segmentation_51_train:2024121600.json",
+ )
+
+ dataset_ca3_wide_train = dict(
+     type="SegDomeClass29Dataset",
+     ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/ca3_wide_angle_body_segmentation_33_train:2024091700.json",
+ )
+
+ dataset_caa_train = dict(
+     type="SegDomeClass29Dataset",
+     ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/cca_segmentation_33_train:2024092400.json",
+ )
+
+ dataset_ca3_zoom_train = dict(
+     type="SegShutterstockClass29Dataset",
+     ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/ca3_zoom_in_body_segmentation_50_train:2024091700.json",
+ )
+
+ dataset_lighticon_train = dict(
+     type="SegShutterstockClass29Dataset",
+     ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/lighticon_lightful_body_segmentation_51_train:2025021900.json",
+ )
+
+ dataset_internal_train = dict(
+     type="SegInternalClass29Dataset",
+     # ann_file=f"{_DATA_ROOT}/annotations/stylized_sapiens/20250807/Internal_segmentation_32:2025080700.json",
+     ann_file=f"{_DATA_ROOT}/annotations/internal_dataset/20251103/internal_keypoint_344_segmentation_32_train:2025091500.json",
+ )
+
+ train_datasets = [
+     dataset_dome_train,
+     dataset_ca3_wide_train,
+     dataset_caa_train,
+     dataset_ca3_zoom_train,
+     dataset_lighticon_train,
+     dataset_internal_train,
+ ] + 2 * [dataset_shutterstock_train]  ## oversample the in-the-wild shutterstock set 2x
+
+ train_dataloader = dict(
+     batch_size=1,
+     num_workers=4,
+     persistent_workers=True,
+     shuffle=True,
+     dataset=dict(
+         type="CombinedDataset", datasets=train_datasets, pipeline=train_pipeline
+     ),
+ )
+
+ val_dataloader = dict(
+     batch_size=1,
+     num_workers=4,
+     persistent_workers=True,
+     multiprocessing_context="spawn",  ## avoids fork error with airstore
+     shuffle=False,
+     dataset=dict(
+         type="SegShutterstockClass29Dataset",
+         ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/itw_shutterstock_body_segmentation_51_test:2024121600.json",
+         test_mode=True,
+         pipeline=val_pipeline,
+     ),
+     collate_fn=dict(type="eval_collate"),
+ )
+
+ val_cfg = dict(
+     val_interval=val_every_iters,
+     evaluator=dict(type="SegEvaluator", class_names="dome29", nan_to_num=0.0),
+ )
+
+ data_preprocessor = dict(
+     type="ImagePreprocessor",
+     mean=[123.675, 116.28, 103.53],
+     std=[58.395, 57.12, 57.375],
+     bgr_to_rgb=True,  ## convert from BGR to RGB for the pretrained backbone
+ )
+
+ ##-----------------------------------------------------------------
+ model = dict(
+     type="SegEstimator",
+     backbone=dict(
+         type="Sapiens2",
+         arch=model_name,
+         img_size=image_size,
+         patch_size=patch_size,
+         final_norm=True,
+         use_tokenizer=False,
+         with_cls_token=True,
+         out_type="featmap",
+         init_cfg=dict(type="Pretrained", checkpoint=pretrained_checkpoint),
+     ),
+     decode_head=dict(
+         type="SegHead",
+         in_channels=embed_dim,
+         deconv_out_channels=(
+             512,
+             256,
+             128,
+             64,
+         ),  ## each deconv upsamples 2x, so four stages give 16x (1K output)
+         deconv_kernel_sizes=(4, 4, 4, 4),
+         conv_out_channels=(64, 64),
+         conv_kernel_sizes=(1, 1),
+         num_classes=num_classes,
+         loss_decode=[
+             dict(
+                 type="CrossEntropyLoss",
+                 loss_weight=1.0,
+                 reduction="none",
+                 class_weight=CLASS_WEIGHT,
+                 ignore_index=255,
+             ),
+             dict(
+                 type="DiceLoss",
+                 loss_weight=1.0,
+                 reduction="none",
+                 activate=True,
+                 use_sigmoid=False,
+                 include_background=False,
+                 ignore_index=255,
+             ),
+         ],
+     ),
+ )
+
+
+ ##-----------------------------------------------------------------
+ optimizer = dict(
+     type="AdamW",
+     lr=5e-4,
+     betas=(0.9, 0.999),
+     weight_decay=0.1,
+     paramwise_cfg=dict(
+         num_layers=num_layers,
+         layer_decay_rate=layer_decay_rate,
+     ),
+     fused=True,
+ )
+
+ scheduler = dict(
+     type="SequentialLR",
+     milestones=[warmup_iters],
+     schedulers=[
+         dict(type="LinearLR", start_factor=1e-3, total_iters=warmup_iters),
+         dict(
+             type="PolynomialLR",
+             total_iters=num_iters - warmup_iters,
+             power=1.0,
+         ),
+     ],
+ )
+
+ clip_grad = dict(mode="norm", max_norm=4.0, norm_type=2.0)
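All four configs share the same data_preprocessor: per-channel ImageNet mean/std normalization plus a BGR-to-RGB channel flip. A minimal sketch of that step, assuming uint8 BGR input as loaded by OpenCV (the actual class lives in the training library):

import numpy as np

MEAN = np.array([123.675, 116.28, 103.53], dtype=np.float32)
STD = np.array([58.395, 57.12, 57.375], dtype=np.float32)

def preprocess(img_bgr: np.ndarray) -> np.ndarray:
    img = img_bgr[..., ::-1].astype(np.float32)  # BGR -> RGB
    return (img - MEAN) / STD                    # per-channel ImageNet statistics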
sapiens/dense/scripts/albedo/train/sapiens2_0.4b/node.sh ADDED
@@ -0,0 +1,58 @@
+ #!/bin/bash
+
+ cd "$(dirname "$(realpath "$0")")/../../../.." || exit
+
+ #-------------------------------------------------------------------------------
+ DEVICES=0,1,2,3,4,5,6,7
+ # DEVICES=0
+
+ #-------------------------------------------------------------------------------
+ TASK="albedo"
+ DATASET="render_people"
+ MODEL="sapiens2_0.4b_${TASK}_${DATASET}-1024x768"
+
+ CONFIG_FILE="configs/${TASK}/$DATASET/${MODEL}.py"
+ TRAIN_BATCH_SIZE_PER_GPU=20
+
+ #-------------------------------------------------------------------------------
+ # mode='debug'
+ mode='multi-gpu'
+
+ #-------------------------------------------------------------------------------
+ OUTPUT_DIR="Outputs/${TASK}/train/${MODEL}/node"
+ OUTPUT_DIR="$(echo "${OUTPUT_DIR}/$(date +"%m-%d-%Y_%H:%M:%S")")"
+
+ #-------------------------------------------------------------------------------
+ OPTIONS="train_dataloader.batch_size=$TRAIN_BATCH_SIZE_PER_GPU"
+ OPTIONS="${OPTIONS}${LOAD_FROM:+ load_from=$LOAD_FROM}"
+ CMD_RESUME="${RESUME_FROM:+--resume $RESUME_FROM}"
+
+ export TF_CPP_MIN_LOG_LEVEL=2
+ PORT=$(( ((RANDOM<<15)|RANDOM) % 63001 + 2000 ))
+
+ #-------------------------------------------------------------------------------
+ if [ "$mode" = "debug" ]; then
+     export TORCH_DISTRIBUTED_DEBUG=DETAIL
+     TRAIN_BATCH_SIZE_PER_GPU=1
+     OPTIONS="train_dataloader.batch_size=${TRAIN_BATCH_SIZE_PER_GPU} train_dataloader.num_workers=0 train_dataloader.persistent_workers=False"
+     OPTIONS="${OPTIONS}${LOAD_FROM:+ load_from=$LOAD_FROM}"
+
+     CUDA_VISIBLE_DEVICES=${DEVICES} python tools/train.py ${CONFIG_FILE} \
+         --work-dir ${OUTPUT_DIR} \
+         --cfg-options ${OPTIONS} \
+         ${CMD_RESUME}
+
+ elif [ "$mode" = "multi-gpu" ]; then
+     NUM_GPUS=$(echo $DEVICES | tr -s ',' ' ' | wc -w)
+
+     LOG_FILE="${OUTPUT_DIR}/log.txt"
+     mkdir -p ${OUTPUT_DIR}
+     touch ${LOG_FILE}
+
+     CUDA_VISIBLE_DEVICES=${DEVICES} PORT=${PORT} 'tools/dist_train.sh' ${CONFIG_FILE} \
+         ${NUM_GPUS} \
+         --work-dir ${OUTPUT_DIR} \
+         --cfg-options ${OPTIONS} \
+         ${CMD_RESUME} \
+         | tee ${LOG_FILE}
+ fi
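The PORT line combines two 15-bit $RANDOM draws into a 30-bit value and maps it into [2000, 65000], so concurrent runs rarely collide on the rendezvous port. A quick Python check of the same arithmetic:

import random

def random_port() -> int:
    r = (random.getrandbits(15) << 15) | random.getrandbits(15)  # 30 random bits
    return r % 63001 + 2000                                      # -> [2000, 65000]

assert all(2000 <= random_port() <= 65000 for _ in range(10_000))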
sapiens/dense/scripts/albedo/train/sapiens2_0.8b/node.sh ADDED
@@ -0,0 +1,59 @@
+ #!/bin/bash
+
+ cd "$(dirname "$(realpath "$0")")/../../../.." || exit
+
+ #-------------------------------------------------------------------------------
+ DEVICES=0,1,2,3,4,5,6,7
+ # DEVICES=0
+
+ #-------------------------------------------------------------------------------
+ TASK="albedo"
+ DATASET="render_people"
+ MODEL="sapiens2_0.8b_${TASK}_${DATASET}-1024x768"
+
+ CONFIG_FILE="configs/${TASK}/$DATASET/${MODEL}.py"
+ TRAIN_BATCH_SIZE_PER_GPU=12
+ LOAD_FROM=''
+
+ #-------------------------------------------------------------------------------
+ # mode='debug'
+ mode='multi-gpu'
+
+ #-------------------------------------------------------------------------------
+ OUTPUT_DIR="Outputs/${TASK}/train/${MODEL}/node"
+ OUTPUT_DIR="$(echo "${OUTPUT_DIR}/$(date +"%m-%d-%Y_%H:%M:%S")")"
+
+ #-------------------------------------------------------------------------------
+ OPTIONS="train_dataloader.batch_size=$TRAIN_BATCH_SIZE_PER_GPU"
+ OPTIONS="${OPTIONS}${LOAD_FROM:+ load_from=$LOAD_FROM}"
+ CMD_RESUME="${RESUME_FROM:+--resume $RESUME_FROM}"
+
+ export TF_CPP_MIN_LOG_LEVEL=2
+ PORT=$(( ((RANDOM<<15)|RANDOM) % 63001 + 2000 ))
+
+ #-------------------------------------------------------------------------------
+ if [ "$mode" = "debug" ]; then
+     export TORCH_DISTRIBUTED_DEBUG=DETAIL
+     TRAIN_BATCH_SIZE_PER_GPU=1
+     OPTIONS="train_dataloader.batch_size=${TRAIN_BATCH_SIZE_PER_GPU} train_dataloader.num_workers=0 train_dataloader.persistent_workers=False"
+     OPTIONS="${OPTIONS}${LOAD_FROM:+ load_from=$LOAD_FROM}"
+
+     CUDA_VISIBLE_DEVICES=${DEVICES} python tools/train.py ${CONFIG_FILE} \
+         --work-dir ${OUTPUT_DIR} \
+         --cfg-options ${OPTIONS} \
+         ${CMD_RESUME}
+
+ elif [ "$mode" = "multi-gpu" ]; then
+     NUM_GPUS=$(echo $DEVICES | tr -s ',' ' ' | wc -w)
+
+     LOG_FILE="${OUTPUT_DIR}/log.txt"
+     mkdir -p ${OUTPUT_DIR}
+     touch ${LOG_FILE}
+
+     CUDA_VISIBLE_DEVICES=${DEVICES} PORT=${PORT} 'tools/dist_train.sh' ${CONFIG_FILE} \
+         ${NUM_GPUS} \
+         --work-dir ${OUTPUT_DIR} \
+         --cfg-options ${OPTIONS} \
+         ${CMD_RESUME} \
+         | tee ${LOG_FILE}
+ fi
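The ${LOAD_FROM:+ load_from=$LOAD_FROM} expansions append an override only when the variable is set and non-empty, so the LOAD_FROM='' default above leaves OPTIONS untouched. A Python mirror of the idiom (reading an environment variable here purely for illustration):

import os

def opt_if_set(name: str, key: str) -> str:
    val = os.environ.get(name, "")
    return f" {key}={val}" if val else ""  # empty or unset -> contributes nothing

options = "train_dataloader.batch_size=12" + opt_if_set("LOAD_FROM", "load_from")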
sapiens/dense/scripts/albedo/train/sapiens2_1b/node.sh ADDED
@@ -0,0 +1,59 @@
+ #!/bin/bash
+
+ cd "$(dirname "$(realpath "$0")")/../../../.." || exit
+
+ #-------------------------------------------------------------------------------
+ DEVICES=0,1,2,3,4,5,6,7
+ # DEVICES=0
+
+ #-------------------------------------------------------------------------------
+ TASK="albedo"
+ DATASET="render_people"
+ MODEL="sapiens2_1b_${TASK}_${DATASET}-1024x768"
+
+ CONFIG_FILE="configs/${TASK}/$DATASET/${MODEL}.py"
+ TRAIN_BATCH_SIZE_PER_GPU=7
+
+
+ #-------------------------------------------------------------------------------
+ # mode='debug'
+ mode='multi-gpu'
+
+ #-------------------------------------------------------------------------------
+ OUTPUT_DIR="Outputs/${TASK}/train/${MODEL}/node"
+ OUTPUT_DIR="$(echo "${OUTPUT_DIR}/$(date +"%m-%d-%Y_%H:%M:%S")")"
+
+ #-------------------------------------------------------------------------------
+ OPTIONS="train_dataloader.batch_size=$TRAIN_BATCH_SIZE_PER_GPU"
+ OPTIONS="${OPTIONS}${LOAD_FROM:+ load_from=$LOAD_FROM}"
+ CMD_RESUME="${RESUME_FROM:+--resume $RESUME_FROM}"
+
+ export TF_CPP_MIN_LOG_LEVEL=2
+ PORT=$(( ((RANDOM<<15)|RANDOM) % 63001 + 2000 ))
+
+ #-------------------------------------------------------------------------------
+ if [ "$mode" = "debug" ]; then
+     export TORCH_DISTRIBUTED_DEBUG=DETAIL
+     TRAIN_BATCH_SIZE_PER_GPU=1
+     OPTIONS="train_dataloader.batch_size=${TRAIN_BATCH_SIZE_PER_GPU} train_dataloader.num_workers=0 train_dataloader.persistent_workers=False"
+     OPTIONS="${OPTIONS}${LOAD_FROM:+ load_from=$LOAD_FROM}"
+
+     CUDA_VISIBLE_DEVICES=${DEVICES} python tools/train.py ${CONFIG_FILE} \
+         --work-dir ${OUTPUT_DIR} \
+         --cfg-options ${OPTIONS} \
+         ${CMD_RESUME}
+
+ elif [ "$mode" = "multi-gpu" ]; then
+     NUM_GPUS=$(echo $DEVICES | tr -s ',' ' ' | wc -w)
+
+     LOG_FILE="${OUTPUT_DIR}/log.txt"
+     mkdir -p ${OUTPUT_DIR}
+     touch ${LOG_FILE}
+
+     CUDA_VISIBLE_DEVICES=${DEVICES} PORT=${PORT} 'tools/dist_train.sh' ${CONFIG_FILE} \
+         ${NUM_GPUS} \
+         --work-dir ${OUTPUT_DIR} \
+         --cfg-options ${OPTIONS} \
+         ${CMD_RESUME} \
+         | tee ${LOG_FILE}
+ fi
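In debug mode the script rebuilds OPTIONS to force batch size 1 and synchronous data loading; each space-separated key=value is a dotted override into the nested config. A hedged sketch of how such --cfg-options overrides are typically folded into a config dict (the real parser lives in tools/train.py and may differ):

def apply_override(cfg: dict, key: str, value):
    *path, leaf = key.split(".")     # walk the dotted path, creating dicts as needed
    node = cfg
    for part in path:
        node = node.setdefault(part, {})
    node[leaf] = value

cfg = {"train_dataloader": {"batch_size": 1, "num_workers": 4}}
apply_override(cfg, "train_dataloader.batch_size", 7)
apply_override(cfg, "train_dataloader.persistent_workers", False)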
sapiens/dense/scripts/albedo/train/sapiens2_5b/node.sh ADDED
@@ -0,0 +1,60 @@
+ #!/bin/bash
+
+ cd "$(dirname "$(realpath "$0")")/../../../.." || exit
+
+ #-------------------------------------------------------------------------------
+ DEVICES=0,1,2,3,4,5,6,7
+ # DEVICES=0
+
+ #-------------------------------------------------------------------------------
+ TASK="albedo"
+ DATASET="render_people"
+ MODEL="sapiens2_5b_${TASK}_${DATASET}-1024x768"
+
+ CONFIG_FILE="configs/${TASK}/$DATASET/${MODEL}.py"
+ TRAIN_BATCH_SIZE_PER_GPU=3
+
+ # LOAD_FROM=""
+
+ #-------------------------------------------------------------------------------
+ # mode='debug'
+ mode='multi-gpu'
+
+ #-------------------------------------------------------------------------------
+ OUTPUT_DIR="Outputs/${TASK}/train/${MODEL}/node"
+ OUTPUT_DIR="$(echo "${OUTPUT_DIR}/$(date +"%m-%d-%Y_%H:%M:%S")")"
+
+ #-------------------------------------------------------------------------------
+ OPTIONS="train_dataloader.batch_size=$TRAIN_BATCH_SIZE_PER_GPU"
+ OPTIONS="${OPTIONS}${LOAD_FROM:+ load_from=$LOAD_FROM}"
+ CMD_RESUME="${RESUME_FROM:+--resume $RESUME_FROM}"
+
+ export TF_CPP_MIN_LOG_LEVEL=2
+ PORT=$(( ((RANDOM<<15)|RANDOM) % 63001 + 2000 ))
+
+ #-------------------------------------------------------------------------------
+ if [ "$mode" = "debug" ]; then
+     export TORCH_DISTRIBUTED_DEBUG=DETAIL
+     TRAIN_BATCH_SIZE_PER_GPU=1
+     OPTIONS="train_dataloader.batch_size=${TRAIN_BATCH_SIZE_PER_GPU} train_dataloader.num_workers=0 train_dataloader.persistent_workers=False"
+     OPTIONS="${OPTIONS}${LOAD_FROM:+ load_from=$LOAD_FROM}"
+
+     CUDA_VISIBLE_DEVICES=${DEVICES} python tools/train.py ${CONFIG_FILE} \
+         --work-dir ${OUTPUT_DIR} \
+         --cfg-options ${OPTIONS} \
+         ${CMD_RESUME}
+
+ elif [ "$mode" = "multi-gpu" ]; then
+     NUM_GPUS=$(echo $DEVICES | tr -s ',' ' ' | wc -w)
+
+     LOG_FILE="${OUTPUT_DIR}/log.txt"
+     mkdir -p ${OUTPUT_DIR}
+     touch ${LOG_FILE}
+
+     CUDA_VISIBLE_DEVICES=${DEVICES} PORT=${PORT} 'tools/dist_train.sh' ${CONFIG_FILE} \
+         ${NUM_GPUS} \
+         --work-dir ${OUTPUT_DIR} \
+         --cfg-options ${OPTIONS} \
+         ${CMD_RESUME} \
+         | tee ${LOG_FILE}
+ fi
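Across the four node.sh scripts the per-GPU batch size shrinks as the backbone grows (20, 12, 7, 3), so on the default 8-GPU node the global batch per step works out as below; the per-GPU numbers come straight from TRAIN_BATCH_SIZE_PER_GPU in each script:

PER_GPU = {"0.4b": 20, "0.8b": 12, "1b": 7, "5b": 3}  # TRAIN_BATCH_SIZE_PER_GPU per script
NUM_GPUS = 8                                          # DEVICES=0,1,2,3,4,5,6,7
for size, bs in PER_GPU.items():
    print(f"sapiens2_{size}: {bs} x {NUM_GPUS} = {bs * NUM_GPUS} images/step")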