# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import Sequence
import torch
import torch.nn as nn
import torch.nn.functional as F
from sapiens.engine.models.base_model import BaseModel
from sapiens.registry import MODELS
from torch.nn import Linear, Sequential
# ----------------------------------------------------------------------------
def to_2tuple(x):
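    """Convert a scalar (or length-2 sequence) to a 2-tuple."""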
if isinstance(x, (str, bytes)):
return (x, x)
if isinstance(x, Sequence):
x = tuple(x)
if len(x) == 2:
return x
raise ValueError("Expected scalar or length-2 iterable")
return (x, x)
def resize_pos_embed(
pos_embed, src_shape, dst_shape, mode="bicubic", num_extra_tokens=1
):
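    """Resize a position embedding of shape [1, L, C] from the `src_shape`
    patch grid to `dst_shape`, leaving the first `num_extra_tokens`
    (e.g. the class token) untouched.
    """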
if src_shape[0] == dst_shape[0] and src_shape[1] == dst_shape[1]:
return pos_embed
assert pos_embed.ndim == 3, "shape of pos_embed must be [1, L, C]"
_, L, C = pos_embed.shape
src_h, src_w = src_shape
    assert L == src_h * src_w + num_extra_tokens, (
        f"The length of `pos_embed` ({L}) doesn't match the expected "
        f"shape ({src_h}*{src_w}+{num_extra_tokens}). Please check the "
        "`img_size` argument."
    )
extra_tokens = pos_embed[:, :num_extra_tokens]
src_weight = pos_embed[:, num_extra_tokens:]
src_weight = src_weight.reshape(1, src_h, src_w, C).permute(0, 3, 1, 2)
# The cubic interpolate algorithm only accepts float32
dst_weight = F.interpolate(
src_weight.float(), size=dst_shape, align_corners=False, mode=mode
)
dst_weight = torch.flatten(dst_weight, 2).transpose(1, 2)
dst_weight = dst_weight.to(src_weight.dtype)
return torch.cat((extra_tokens, dst_weight), dim=1)
# ----------------------------------------------------------------------------
class PatchEmbed(nn.Module):
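    """Split an image into patches with a strided Conv2d and flatten the
    result into a sequence of patch embeddings.
    """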
def __init__(
self,
in_channels=3,
embed_dims=768,
kernel_size=16,
stride=16,
padding="corner",
dilation=1,
bias=True,
input_size=None,
):
super().__init__()
self.embed_dims = embed_dims
if stride is None:
stride = kernel_size
kernel_size = to_2tuple(kernel_size)
stride = to_2tuple(stride)
dilation = to_2tuple(dilation)
        # NOTE: the `padding` argument (default "corner") is not honored here;
        # zero padding is always used, so input spatial sizes should be
        # divisible by the stride.
        padding = to_2tuple(0)
self.projection = nn.Conv2d(
in_channels=in_channels,
out_channels=embed_dims,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=bias,
)
if input_size:
input_size = to_2tuple(input_size)
self.init_input_size = input_size
h_out = (
input_size[0] + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1
) // stride[0] + 1
w_out = (
input_size[1] + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1
) // stride[1] + 1
self.init_out_size = (h_out, w_out)
else:
self.init_input_size = None
self.init_out_size = None
def forward(self, x):
x = self.projection(x)
out_size = (x.shape[2], x.shape[3])
x = x.flatten(2).transpose(1, 2)
return x, out_size
# ----------------------------------------------------------------------------
class LayerScale(nn.Module):
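    """Per-channel learnable scaling (LayerScale), initialized to `scale`."""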
def __init__(
self,
dim: int,
inplace: bool = False,
data_format: str = "channels_last",
scale: float = 1e-5,
):
super().__init__()
assert data_format in (
"channels_last",
"channels_first",
), "'data_format' could only be channels_last or channels_first."
self.inplace = inplace
self.data_format = data_format
self.weight = nn.Parameter(torch.ones(dim) * scale)
def forward(self, x) -> torch.Tensor:
if self.data_format == "channels_first":
shape = tuple((1, -1, *(1 for _ in range(x.dim() - 2))))
else:
shape = tuple((*(1 for _ in range(x.dim() - 1)), -1))
if self.inplace:
return x.mul_(self.weight.view(*shape))
else:
return x * self.weight.view(*shape)
# ----------------------------------------------------------------------------
class FFN(nn.Module):
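    """Feed-forward block: (Linear -> GELU -> Dropout) x (num_fcs - 1),
    followed by a final Linear + Dropout, with an optional residual
    connection and optional LayerScale.
    """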
def __init__(
self,
embed_dims=256,
feedforward_channels=1024,
num_fcs=2,
ffn_drop=0.0,
add_identity=True,
layer_scale_init_value=0.0,
):
super().__init__()
        assert num_fcs >= 2, f"num_fcs should be at least 2, got {num_fcs}."
self.embed_dims = embed_dims
self.feedforward_channels = feedforward_channels
self.num_fcs = num_fcs
layers = []
in_channels = embed_dims
for _ in range(num_fcs - 1):
layers.append(
Sequential(
Linear(in_channels, feedforward_channels),
nn.GELU(),
nn.Dropout(ffn_drop),
)
)
in_channels = feedforward_channels
layers.append(Linear(feedforward_channels, embed_dims))
layers.append(nn.Dropout(ffn_drop))
self.layers = Sequential(*layers)
self.dropout_layer = nn.Identity()
self.add_identity = add_identity
if layer_scale_init_value > 0:
self.gamma2 = LayerScale(embed_dims, scale=layer_scale_init_value)
else:
self.gamma2 = nn.Identity()
    def forward(self, x, identity=None):
        out = self.layers(x)
        out = self.gamma2(out)
        # `dropout_layer` is nn.Identity by default; it is applied here so the
        # residual path matches the layer's declared components.
        out = self.dropout_layer(out)
        if not self.add_identity:
            return out
        if identity is None:
            identity = x
        return identity + out
# ----------------------------------------------------------------------------
class MultiheadAttention(nn.Module):
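    """Multi-head self-attention backed by
    `torch.nn.functional.scaled_dot_product_attention`.
    """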
def __init__(
self,
embed_dims,
num_heads,
input_dims=None,
attn_drop=0.0,
proj_drop=0.0,
qkv_bias=True,
proj_bias=True,
v_shortcut=False,
):
super(MultiheadAttention, self).__init__()
self.input_dims = input_dims or embed_dims
self.embed_dims = embed_dims
self.num_heads = num_heads
self.v_shortcut = v_shortcut
        assert embed_dims % num_heads == 0, (
            f"embed_dims ({embed_dims}) must be divisible by num_heads ({num_heads})"
        )
        self.head_dims = embed_dims // num_heads
self.scaled_dot_product_attention = F.scaled_dot_product_attention
self.qkv = nn.Linear(self.input_dims, embed_dims * 3, bias=qkv_bias)
self.attn_drop = attn_drop
self.proj = nn.Linear(embed_dims, embed_dims, bias=proj_bias)
self.proj_drop = nn.Dropout(proj_drop)
self.gamma1 = nn.Identity()
def forward(self, x):
B, N, _ = x.shape
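        # (B, N, 3C) -> (B, N, 3, num_heads, head_dims) -> (3, B, num_heads, N, head_dims)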
qkv = (
self.qkv(x)
.reshape(B, N, 3, self.num_heads, self.head_dims)
.permute(2, 0, 3, 1, 4)
)
q, k, v = qkv[0], qkv[1], qkv[2]
attn_drop = self.attn_drop if self.training else 0.0
x = self.scaled_dot_product_attention(q, k, v, dropout_p=attn_drop)
x = x.transpose(1, 2).reshape(B, N, self.embed_dims)
x = self.proj(x)
x = self.gamma1(self.proj_drop(x))
if self.v_shortcut:
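            # NOTE: v_shortcut assumes num_heads == 1, so v is (B, 1, N, C).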
x = v.squeeze(1) + x
return x
# ----------------------------------------------------------------------------
class TransformerEncoderLayer(nn.Module):
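    """Pre-norm transformer encoder block: x + Attn(LN(x)), followed by an
    FFN with a residual connection.
    """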
def __init__(
self,
embed_dims,
num_heads,
feedforward_channels,
drop_rate=0.0,
attn_drop_rate=0.0,
num_fcs=2,
qkv_bias=True,
):
super(TransformerEncoderLayer, self).__init__()
self.embed_dims = embed_dims
self.ln1 = nn.LayerNorm(self.embed_dims, eps=1e-6, elementwise_affine=True)
self.attn = MultiheadAttention(
embed_dims=embed_dims,
num_heads=num_heads,
attn_drop=attn_drop_rate,
proj_drop=drop_rate,
qkv_bias=qkv_bias,
)
self.ln2 = nn.LayerNorm(self.embed_dims, eps=1e-6, elementwise_affine=True)
self.ffn = FFN(
embed_dims=embed_dims,
feedforward_channels=feedforward_channels,
num_fcs=num_fcs,
ffn_drop=drop_rate,
add_identity=True,
)
@property
def norm1(self):
return self.ln1
@property
def norm2(self):
return self.ln2
def forward(self, x):
x = x + self.attn(self.ln1(x))
x = self.ffn(self.ln2(x), identity=x)
return x
# ----------------------------------------------------------------------------
@MODELS.register_module()
class Sapiens(BaseModel):
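    """Sapiens Vision Transformer backbone.

    Architectures are selected by name from `arch_zoo`; the position
    embedding is resized on the fly when the input resolution differs from
    `img_size`, and `out_type` selects raw tokens, the class token, or a
    feature map.
    """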
arch_zoo = {
**dict.fromkeys( ## this is vit-large
["0.3b", "sapiens_0.3b"],
{
"embed_dims": 1024,
"num_layers": 24,
"num_heads": 16,
"feedforward_channels": 1024 * 4,
},
),
**dict.fromkeys( ## this is vit-huge
["0.6b", "sapiens_0.6b"],
{
"embed_dims": 1280,
"num_layers": 32,
"num_heads": 16,
"feedforward_channels": 1280 * 4,
},
),
**dict.fromkeys( ## this is vit-g
["1b", "sapiens_1b"],
{
"embed_dims": 1536,
"num_layers": 40,
"num_heads": 24,
"feedforward_channels": 1536 * 4,
},
),
**dict.fromkeys(
["2b", "sapiens_2b"],
{
"embed_dims": 1920,
"num_layers": 48,
"num_heads": 32,
"feedforward_channels": 1920 * 4,
},
),
}
num_extra_tokens = 1 # class token
OUT_TYPES = {"raw", "cls_token", "featmap"}
def __init__(
self,
arch="base",
img_size=1024,
patch_size=16,
in_channels=3,
out_indices=-1,
drop_rate=0.0,
qkv_bias=True,
final_norm=True,
out_type="cls_token",
with_cls_token=True,
frozen_stages=-1,
interpolate_mode="bicubic",
patch_cfg=dict(),
layer_cfgs=dict(),
init_cfg=None,
):
super(Sapiens, self).__init__(init_cfg=init_cfg)
arch = arch.lower()
assert arch in set(self.arch_zoo), (
f"Arch {arch} is not in default archs {set(self.arch_zoo)}"
)
self.arch_settings = self.arch_zoo[arch]
self.embed_dims = self.arch_settings["embed_dims"]
self.num_layers = self.arch_settings["num_layers"]
self.img_size = to_2tuple(img_size)
self.patch_size = patch_size
# Set patch embedding
_patch_cfg = dict(
in_channels=in_channels,
input_size=img_size,
embed_dims=self.embed_dims,
kernel_size=patch_size,
stride=patch_size,
bias=True,
)
_patch_cfg.update(patch_cfg)
self.patch_embed = PatchEmbed(**_patch_cfg)
self.patch_resolution = self.patch_embed.init_out_size
num_patches = self.patch_resolution[0] * self.patch_resolution[1]
# Set out type
if out_type not in self.OUT_TYPES:
raise ValueError(
f"Unsupported `out_type` {out_type}, please "
f"choose from {self.OUT_TYPES}"
)
self.out_type = out_type
# Set cls token
self.with_cls_token = with_cls_token
if with_cls_token:
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
elif out_type != "cls_token":
self.cls_token = None
self.num_extra_tokens = 0
else:
raise ValueError('with_cls_token must be True when `out_type="cls_token"`.')
# Set position embedding
self.interpolate_mode = interpolate_mode
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches + self.num_extra_tokens, self.embed_dims)
)
self.drop_after_pos = nn.Dropout(p=drop_rate)
        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), (
            f'"out_indices" must be a sequence or int, got {type(out_indices)} instead.'
        )
        # Copy to a list so negative indices can be normalized in place even
        # when a tuple is passed.
        out_indices = list(out_indices)
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = self.num_layers + index
            assert 0 <= out_indices[i] < self.num_layers, (
                f"Invalid out_indices {index}"
            )
        self.out_indices = out_indices
self.layers = nn.Sequential()
if isinstance(layer_cfgs, dict):
layer_cfgs = [layer_cfgs] * self.num_layers
for i in range(self.num_layers):
_layer_cfg = dict(
embed_dims=self.embed_dims,
num_heads=self.arch_settings["num_heads"],
feedforward_channels=self.arch_settings["feedforward_channels"],
drop_rate=drop_rate,
qkv_bias=qkv_bias,
)
_layer_cfg.update(layer_cfgs[i])
self.layers.append(TransformerEncoderLayer(**_layer_cfg))
self.frozen_stages = frozen_stages
self.pre_norm = nn.Identity()
self.final_norm = final_norm
if final_norm:
self.ln1 = nn.LayerNorm(self.embed_dims, eps=1e-6, elementwise_affine=True)
# freeze stages only when self.frozen_stages > 0
if self.frozen_stages > 0:
self._freeze_stages()
self._register_load_state_dict_pre_hook(self._prepare_pos_embed)
        self.init_weights()
def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs):
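        """Load-state-dict pre-hook: adapt a checkpoint's `pos_embed` to this
        model's patch grid, removing the class token when it is unused.
        """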
name = prefix + "pos_embed"
if name not in state_dict.keys():
return
ckpt_pos_embed_shape = state_dict[name].shape
from sapiens.engine.logger import Logger
logger = Logger.get_current_instance()
rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
# Handle class token removal if needed
if not self.with_cls_token:
if ckpt_pos_embed_shape[1] == self.pos_embed.shape[1] + 1:
# Remove cls token from state dict if it's not used
state_dict[name] = state_dict[name][:, 1:]
ckpt_pos_embed_shape = state_dict[name].shape
            elif ckpt_pos_embed_shape[1] % 2 == 1:
                # Heuristic: an odd token count suggests a leading class
                # token; strip it before interpolating the grid.
if rank == 0:
logger.info(
"Note: removing the class token from pretrained weights"
)
state_dict[name] = state_dict[name][:, 1:]
ckpt_pos_embed_shape = state_dict[name].shape
# Skip if shapes already match
if self.pos_embed.shape == ckpt_pos_embed_shape:
return
if rank == 0:
logger.info(
f"Resize the pos_embed shape from {ckpt_pos_embed_shape} "
f"to {self.pos_embed.shape}."
)
# Calculate grid dimensions
pos_h, pos_w = self.patch_embed.init_out_size
assert pos_h >= pos_w # for vertical aspect ratio or square
# Number of non-extra tokens in checkpoint
num_vis = ckpt_pos_embed_shape[1] - self.num_extra_tokens
# Determine original grid shape
side = int(math.sqrt(num_vis))
factor = int(math.sqrt((num_vis * self.patch_size * self.patch_size) // 12))
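        # For a 4:3 (H:W) grid of (4k, 3k) patches, num_vis = 12 * k^2, so
        # factor = sqrt(num_vis * patch_size^2 / 12) = k * patch_size.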
# Set old grid based on aspect ratio detection
if side * side == num_vis:
old_grid = (side, side) # square grid
elif 4 * factor * 3 * factor == num_vis * self.patch_size * self.patch_size:
old_grid = (
(factor * 4) // self.patch_size,
(factor * 3) // self.patch_size,
) # 4:3 ratio
else:
            if rank == 0:
                logger.warning(
                    f"Original pos_embed grid ({num_vis} tokens) is neither "
                    "square nor 4:3; keeping the model's own pos_embed."
                )
state_dict[name] = self.pos_embed
return
# Resize position embedding
new_grid = (pos_h, pos_w)
state_dict[name] = resize_pos_embed(
state_dict[name],
old_grid,
new_grid,
mode=self.interpolate_mode,
num_extra_tokens=self.num_extra_tokens,
)
@property
def norm1(self):
return self.ln1
@staticmethod
def resize_pos_embed(*args, **kwargs):
"""Interface for backward-compatibility."""
return resize_pos_embed(*args, **kwargs)
def _freeze_stages(self):
# freeze position embedding
if self.pos_embed is not None:
self.pos_embed.requires_grad = False
# set dropout to eval model
self.drop_after_pos.eval()
# freeze patch embedding
self.patch_embed.eval()
for param in self.patch_embed.parameters():
param.requires_grad = False
# freeze pre-norm
for param in self.pre_norm.parameters():
param.requires_grad = False
# freeze cls_token
if self.cls_token is not None:
self.cls_token.requires_grad = False
# freeze layers
for i in range(1, self.frozen_stages + 1):
m = self.layers[i - 1]
m.eval()
for param in m.parameters():
param.requires_grad = False
# freeze the last layer norm
if self.frozen_stages == len(self.layers):
if self.final_norm:
self.ln1.eval()
for param in self.ln1.parameters():
param.requires_grad = False
def forward(self, x):
B = x.shape[0]
x, patch_resolution = self.patch_embed(x)
if self.cls_token is not None:
cls_token = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_token, x), dim=1)
x = x + resize_pos_embed(
self.pos_embed,
self.patch_resolution,
patch_resolution,
mode=self.interpolate_mode,
num_extra_tokens=self.num_extra_tokens,
)
x = self.drop_after_pos(x)
x = self.pre_norm(x) ## B x (num tokens) x embed_dim
outs = []
for i, layer in enumerate(self.layers):
x = layer(x)
if i == len(self.layers) - 1 and self.final_norm:
x = self.ln1(x)
if i in self.out_indices:
outs.append(self._format_output(x, patch_resolution))
return tuple(outs)
def _format_output(self, x, hw):
if self.out_type == "raw":
return x
if self.out_type == "cls_token":
return x[:, 0]
        patch_token = x[:, self.num_extra_tokens :]
        if self.out_type == "featmap":
            B = x.size(0)
            # (B, N, C) -> (B, H, W, C) -> (B, C, H, W)
            return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2)
        raise ValueError(f"Unsupported out_type {self.out_type}")
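# ----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not the shipped entry point):
# builds the smallest backbone and runs a forward pass. The argument values
# below are example choices, and running this requires the `sapiens` package
# that provides `BaseModel` and `MODELS` above.
if __name__ == "__main__":
    model = Sapiens(
        arch="0.3b",
        img_size=1024,
        patch_size=16,
        out_type="featmap",
        out_indices=-1,
    )
    model.eval()
    with torch.no_grad():
        feats = model(torch.randn(1, 3, 1024, 1024))
    # One output per entry in out_indices; a featmap is (B, C, H/16, W/16).
    print(feats[0].shape)  # torch.Size([1, 1024, 64, 64])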