# https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/prompt_encoder.py
from typing import Optional, Union
import numpy as np
import torch
from torch import nn
from .common import PatchEncoder, group_with_centers_and_knn
class PositionEmbeddingRandom(nn.Module):
    """
    Positional encoding using random spatial frequencies (random Fourier
    features) for 3-D coordinates.

    Args:
        num_pos_feats: Number of random frequencies. The output embedding has
            2 * num_pos_feats channels (one sin and one cos per frequency).
        scale: Standard deviation of the random frequency matrix. ``None`` or
            a non-positive value falls back to 1.0.
    """

    def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
        super().__init__()
        if scale is None or scale <= 0.0:
            scale = 1.0
        # Buffer (not a Parameter): moves with .to()/.cuda() and is saved in
        # the state dict, but is not trained.
        self.register_buffer(
            "positional_encoding_gaussian_matrix",
            scale * torch.randn((3, num_pos_feats)),
        )

    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
        """Positionally encode points that are normalized to [-1,1]."""
        # assuming coords are in [-1, 1] and have d_1 x ... x d_n x D shape
        coords = coords @ self.positional_encoding_gaussian_matrix
        # 2*pi maps the normalized coordinates onto a full sin/cos period,
        # as in the Fourier-feature formulation (Tancik et al., 2020).
        coords = 2 * np.pi * coords
        # outputs d_1 x ... x d_n x C shape
        return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)

    def forward(self, coords: torch.Tensor) -> torch.Tensor:
        """
        Args:
            coords: shape (..., coord_dim), normalized coordinates in [-1, 1].

        Returns:
            torch.Tensor: shape (..., 2 * num_pos_feats), positional encoding.

        Raises:
            ValueError: If any coordinate lies outside [-1, 1] (with a small
                numerical tolerance).
        """
        if (coords < -1 - 1e-6).any() or (coords > 1 + 1e-6).any():
            # Include the offending range in the exception instead of
            # printing to stdout.
            raise ValueError(
                "Input coordinates must be normalized to [-1, 1]; "
                f"got range [{coords.min().item()}, {coords.max().item()}]."
            )
        return self._pe_encoding(coords)
class PointEncoder(nn.Module):
    """Embeds point prompts (positive/negative click points)."""

    def __init__(self, embed_dim: int):
        super().__init__()
        self.embed_dim = embed_dim
        # Each random frequency yields a (sin, cos) channel pair, so
        # embed_dim // 2 frequencies produce an embed_dim-wide encoding.
        self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
        self.num_point_embeddings: int = 2  # pos/neg point
        point_embeddings = [
            nn.Embedding(1, embed_dim) for _ in range(self.num_point_embeddings)
        ]
        self.point_embeddings = nn.ModuleList(point_embeddings)

    def forward(self, points: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Embeds point prompts.

        Args:
            points: [..., 3]. Point coordinates (normalized to [-1, 1], as
                required by the positional-encoding layer).
            labels: [...], integer (or boolean). Point labels
                (0 = negative, 1 = positive).

        Returns:
            torch.Tensor: [..., embed_dim]. Embedded points.
        """
        assert points.shape[:-1] == labels.shape
        # Invoke the module (not .forward) so nn.Module hooks are respected.
        point_embedding = self.pe_layer(points)
        # Add a learned, label-specific offset to each point's encoding.
        point_embedding[labels == 0] += self.point_embeddings[0].weight
        point_embedding[labels == 1] += self.point_embeddings[1].weight
        return point_embedding
class MaskEncoder(nn.Module):
    """Embeds optional per-point mask logits into dense per-center embeddings."""

    def __init__(
        self,
        embed_dim,
        in_channels=4,
        radius=None,
        centralize_features=False,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.in_channels = in_channels  # (x, y, z, logit)
        self.radius = radius
        self.centralize_features = centralize_features
        self.patch_encoder = PatchEncoder(in_channels, embed_dim, [128, 512])
        # Learned fallback used when no mask prompt is provided.
        self.no_mask_embed = nn.Embedding(1, embed_dim)

    def forward(
        self,
        masks: Optional[torch.Tensor],
        coords: torch.Tensor,
        centers: torch.Tensor,
        knn_idx: torch.Tensor,
        center_idx: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Embeds mask inputs.

        Args:
            masks: [B * M, N], float. Mask inputs (logits), or None when no
                mask prompt is given.
            coords: [B, N, 3]. Point coordinates.
            centers: [B, L, 3]. Center coordinates.
            knn_idx: [B, L, K]. KNN indices.
            center_idx: [B, L]. Index of center in the point cloud. Optional.

        Returns:
            torch.Tensor: [B * M, L, embed_dim]. Dense embeddings
            ([B, L, embed_dim] when masks is None).
        """
        if masks is None:
            # Broadcast the learned "no mask" embedding to every center.
            dense_embeddings = self.no_mask_embed.weight.reshape(1, 1, -1).expand(
                centers.shape[0], centers.shape[1], -1
            )
        else:
            # Stop gradients from flowing back into the mask prompt.
            masks = masks.detach()
            # Group the mask logits (as an extra feature channel) around each
            # center using the precomputed KNN neighborhoods.
            patches = group_with_centers_and_knn(
                coords,
                masks.unsqueeze(-1),
                centers,
                knn_idx,
                radius=self.radius,
                center_idx=center_idx,
                centralize_features=self.centralize_features,
            )
            dense_embeddings = self.patch_encoder(patches)
        return dense_embeddings