bdck committed on
Commit
0232680
·
verified ·
1 Parent(s): 8847f53

Upload point_sam/model/prompt_encoder.py

Browse files
Files changed (1) hide show
  1. point_sam/model/prompt_encoder.py +131 -0
point_sam/model/prompt_encoder.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/prompt_encoder.py
2
+ from typing import Optional, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ from torch import nn
7
+
8
+ from .common import PatchEncoder, group_with_centers_and_knn
9
+
10
+
11
class PositionEmbeddingRandom(nn.Module):
    """Positional encoding using random spatial frequencies.

    A random Gaussian matrix of shape ``(3, num_pos_feats)`` is drawn once at
    construction time and stored as a buffer, so it is part of the module
    state but is not a trainable parameter. Coordinates are projected with
    this matrix, scaled by ``2 * pi`` and passed through ``sin`` and ``cos``;
    the two halves are concatenated, so the encoding has
    ``2 * num_pos_feats`` channels.
    """

    def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
        """
        Args:
            num_pos_feats: number of random frequencies. The encoding width
                is twice this value (sin and cos halves).
            scale: multiplier applied to the random matrix. ``None`` or any
                non-positive value falls back to 1.0.
        """
        super().__init__()
        if scale is None or scale <= 0.0:
            scale = 1.0
        self.register_buffer(
            "positional_encoding_gaussian_matrix",
            scale * torch.randn((3, num_pos_feats)),
        )

    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
        """Positionally encode points that are normalized to [-1,1].

        Args:
            coords: shape (d_1, ..., d_n, 3).

        Returns:
            torch.Tensor: shape (d_1, ..., d_n, 2 * num_pos_feats).
        """
        projected = coords @ self.positional_encoding_gaussian_matrix
        # With the 2 * pi factor, a unit change in the projected value spans
        # exactly one sin/cos period.
        projected = 2 * np.pi * projected
        return torch.cat([torch.sin(projected), torch.cos(projected)], dim=-1)

    def forward(self, coords: torch.Tensor) -> torch.Tensor:
        """Validate the coordinate range and return the positional encoding.

        Args:
            coords: shape (..., 3), normalized coordinates in [-1, 1]. A
                tolerance of 1e-6 is allowed on both bounds.

        Returns:
            torch.Tensor: shape (..., 2 * num_pos_feats), positional encoding
            (sin half followed by cos half).

        Raises:
            ValueError: if any coordinate lies outside [-1, 1] beyond the
                tolerance. The message reports the observed min and max.
        """
        if (coords < -1 - 1e-6).any() or (coords > 1 + 1e-6).any():
            # Report the offending bounds in the exception itself rather than
            # printing them, so they reach whoever handles the error.
            raise ValueError(
                "Input coordinates must be normalized to [-1, 1]. "
                f"Got min={coords.min().item()}, max={coords.max().item()}."
            )
        # TODO: whether to convert to float?
        return self._pe_encoding(coords)
47
+
48
+
49
class PointEncoder(nn.Module):
    """Embeds labeled 3D point prompts.

    Each point receives a random-frequency positional encoding of its
    coordinates plus a learned embedding selected by its label (0 or 1).
    """

    def __init__(self, embed_dim: int):
        """
        Args:
            embed_dim: channel width of the produced point embeddings.
        """
        super().__init__()
        self.embed_dim = embed_dim
        # The positional encoding concatenates sin and cos halves, so half of
        # embed_dim frequencies are requested.
        self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)

        self.num_point_embeddings: int = 2  # pos/neg point
        point_embeddings = [
            nn.Embedding(1, embed_dim) for _ in range(self.num_point_embeddings)
        ]
        self.point_embeddings = nn.ModuleList(point_embeddings)

    def forward(self, points: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Embeds point prompts.

        Args:
            points: [..., 3]. Point coordinates.
            labels: [...], integer (or boolean). Point labels.

        Returns:
            torch.Tensor: [..., embed_dim]. Embedded points. Points whose
            label is neither 0 nor 1 carry only the positional encoding.
        """
        assert points.shape[:-1] == labels.shape
        # Call the module rather than .forward() so that any hooks registered
        # on pe_layer are executed.
        point_embedding = self.pe_layer(points)
        # Add the learned per-label embedding in place; the [1, embed_dim]
        # weight broadcasts over all selected points.
        point_embedding[labels == 0] += self.point_embeddings[0].weight
        point_embedding[labels == 1] += self.point_embeddings[1].weight
        return point_embedding
76
+
77
+
78
class MaskEncoder(nn.Module):
    """Encodes an optional per-point mask into per-center dense embeddings.

    When a mask is supplied, it is grouped into patches together with the
    point coordinates and encoded by a ``PatchEncoder``. When no mask is
    supplied, a single learned "no mask" embedding is broadcast instead.
    """

    def __init__(
        self,
        embed_dim,
        in_channels=4,
        radius=None,
        centralize_features=False,
    ):
        """
        Args:
            embed_dim: channel width of the produced embeddings.
            in_channels: per-point input channels, (x, y, z, logit).
            radius: forwarded to ``group_with_centers_and_knn``.
            centralize_features: forwarded to ``group_with_centers_and_knn``.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.in_channels = in_channels  # (x, y, z, logit)
        self.radius = radius
        self.centralize_features = centralize_features

        self.patch_encoder = PatchEncoder(in_channels, embed_dim, [128, 512])
        self.no_mask_embed = nn.Embedding(1, embed_dim)

    def forward(
        self,
        masks: Union[torch.Tensor, None],
        coords: torch.Tensor,
        centers: torch.Tensor,
        knn_idx: torch.Tensor,
        center_idx: torch.Tensor = None,
    ) -> torch.Tensor:
        """Embeds mask inputs.

        Args:
            masks: [B * M, N], float, or None. Mask inputs; detached before
                use so no gradient flows back into them.
            coords: [B, N, 3]. Point coordinates.
            centers: [B, L, 3]. Center coordinates.
            knn_idx: [B, L, K]. KNN indices.
            center_idx: [B, L]. Index of center in the point cloud.

        Returns:
            torch.Tensor: [B * M, L, embed_dim]. Dense embeddings. When
            ``masks`` is None the result is the "no mask" embedding expanded
            to the first two dimensions of ``centers``.
        """
        if masks is None:
            # No mask prompt: broadcast the single learned embedding over
            # every center (expand creates a view, not a copy).
            placeholder = self.no_mask_embed.weight.reshape(1, 1, -1)
            return placeholder.expand(centers.shape[0], centers.shape[1], -1)

        # Add a trailing channel axis to the detached mask before grouping.
        mask_channel = masks.detach().unsqueeze(-1)
        grouped = group_with_centers_and_knn(
            coords,
            mask_channel,
            centers,
            knn_idx,
            radius=self.radius,
            center_idx=center_idx,
            centralize_features=self.centralize_features,
        )
        return self.patch_encoder(grouped)