import torch
import torch.nn as nn
import torch.nn.functional as F


class PatchEmbed(nn.Module):
    """Splits each channel's signal into (optionally overlapping) fixed-length
    patches and projects them linearly into the embedding space."""

    def __init__(self, fs: int = 200, patch_seconds: float = 1.0, overlap_seconds: float = 0.1, embed_dim: int = 512):
        super().__init__()
        # Patch and overlap lengths in samples.
        self.patch_size = int(round(patch_seconds * fs))
        self.overlap_size = int(round(overlap_seconds * fs))
        # Stride between consecutive patches; equals patch_size when overlap is 0.
        self.step = self.patch_size - self.overlap_size
        self.linear = nn.Linear(self.patch_size, embed_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, C, T) -> patches: (B, C, num_patches, patch_size)
        patches = x.unfold(dimension=-1, size=self.patch_size, step=self.step)
        # -> (B, C, num_patches, embed_dim)
        return self.linear(patches)

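# Illustrative shape check (assumed numbers, not from the original source):
# with the defaults fs=200, patch_seconds=1.0, overlap_seconds=0.1, we get
# patch_size == 200 samples and step == 180, so a 10 s signal of 2000 samples
# yields (2000 - 200) // 180 + 1 == 11 patches per channel, and PatchEmbed
# maps x of shape (B, C, 2000) to (B, C, 11, embed_dim).
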
class PosEnc(nn.Module):
    """Positional encoding over 4-D coordinates (x, y, z, patch time index):
    a fixed Fourier-feature branch plus a small learned branch."""

    def __init__(self, n_freqs: int = 4, embed_dim: int = 512):
        super().__init__()
        # All combinations of per-axis frequencies: shape (4, n_freqs**4)
        # after the transpose, one column per frequency combination.
        freqs = torch.linspace(1.0, 10.0, n_freqs)
        self.register_buffer("freq_matrix", torch.cartesian_prod(freqs, freqs, freqs, freqs).transpose(1, 0))

        # sin and cos for each frequency combination.
        fourier_features_dim = 2 * (n_freqs**4)

        self.fourier_linear = nn.Linear(fourier_features_dim, embed_dim, bias=False)
        self.learned_linear = nn.Sequential(nn.Linear(4, embed_dim, bias=False), nn.GELU(), nn.LayerNorm(embed_dim))
        self.final_norm = nn.LayerNorm(embed_dim)

    def forward(self, coords: torch.Tensor) -> torch.Tensor:
        # coords: (..., 4) -> phases: (..., n_freqs**4)
        phases = torch.matmul(coords, self.freq_matrix)

        fourier_features = torch.cat([torch.sin(phases), torch.cos(phases)], dim=-1)
        fourier_emb = self.fourier_linear(fourier_features)
        learned_emb = self.learned_linear(coords)

        return self.final_norm(fourier_emb + learned_emb)

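# Shape sketch (illustrative, assuming the default n_freqs=4): coords of shape
# (B, N, 4) give phases of shape (B, N, 4**4) == (B, N, 256), sin/cos features
# of shape (B, N, 512), and a final embedding of shape (B, N, embed_dim).
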
class TransformerBlock(nn.Module):
    """Pre-norm transformer block; also returns the FFN output so the encoder
    can expose per-layer features for the auxiliary loss."""

    def __init__(self, embed_dim: int, heads: int, dropout: float = 0.0):
        super().__init__()
        assert embed_dim % heads == 0, "embed_dim must be divisible by heads"

        self.pre_attn_norm = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=heads, dropout=dropout, batch_first=True)

        self.pre_ffn_norm = nn.LayerNorm(embed_dim)
        self.ffn = nn.Sequential(nn.Linear(embed_dim, 4 * embed_dim), nn.GELU(), nn.Linear(4 * embed_dim, embed_dim))

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Self-attention sub-layer with residual connection.
        attn_in = self.pre_attn_norm(x)
        attn_out, _ = self.attn(attn_in, attn_in, attn_in)
        x = x + attn_out

        # Feed-forward sub-layer with residual connection.
        ffn_in = self.pre_ffn_norm(x)
        ffn_out = self.ffn(ffn_in)
        x = x + ffn_out

        return x, ffn_out

class TransformerEncoderDecoder(nn.Module):
    """Stack of pre-norm transformer blocks with a final LayerNorm; collects
    each block's FFN output as intermediate features."""

    def __init__(self, embed_dim: int = 512, depth: int = 16, heads: int = 8):
        super().__init__()
        self.layers = nn.ModuleList([TransformerBlock(embed_dim, heads) for _ in range(depth)])
        self.final_norm = nn.LayerNorm(embed_dim)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, list[torch.Tensor]]:
        intermediate = []
        for layer in self.layers:
            x, ffn_out = layer(x)
            intermediate.append(ffn_out)
        return self.final_norm(x), intermediate

class MAEDecoder(nn.Module):
    """MAE-style decoder: scatters encoded visible tokens back into a
    full-length sequence, fills masked positions with a learned mask token,
    and predicts the raw signal patch at every position."""

    def __init__(self, embed_dim: int = 512, decoder_depth: int = 4, decoder_heads: int = 8, patch_size: int = 200):
        super().__init__()
        # Learned embedding standing in for every masked patch.
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        nn.init.normal_(self.mask_token, std=0.02)

        self.decoder = TransformerEncoderDecoder(embed_dim=embed_dim, depth=decoder_depth, heads=decoder_heads)

        # Projects each decoded token back to a raw patch of samples.
        self.predict = nn.Linear(embed_dim, patch_size, bias=True)

    def forward(self, x_visible: torch.Tensor, pos_enc: nn.Module, coords: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # mask is True at visible positions, False at masked ones.
        B, N_total, D = coords.shape[0], coords.shape[1], x_visible.shape[-1]

        # Start from mask tokens everywhere, then scatter the encoded visible
        # tokens back into their original positions.
        x_full = self.mask_token.expand(B, N_total, D).clone()
        for i in range(B):
            x_full[i, mask[i]] = x_visible[i]

        # Re-add positional information so the decoder knows where each
        # (channel, time) token lives.
        pos_emb = pos_enc(coords)
        x_full = x_full + pos_emb

        x_decoded, _ = self.decoder(x_full)
        prediction = self.predict(x_decoded)

        return prediction

def generate_mask(coords: torch.Tensor, mask_ratio: float = 0.55, spatial_radius: float = 3.0, temporal_radius: float = 3.0) -> torch.Tensor:
    """Block-wise spatio-temporal masking: repeatedly pick a random seed token
    and mask every token within the given spatial and temporal radii, then trim
    back so each batch element has exactly int(mask_ratio * N) masked tokens.

    Returns a (B, N) bool tensor where True marks a visible token and False a
    masked one."""
    B, N, _ = coords.shape
    device = coords.device

    num_masked_target = int(mask_ratio * N)

    # Start with everything visible.
    mask = torch.ones(B, N, dtype=torch.bool, device=device)

    for b in range(B):
        spatial_coords = coords[b, :, :3]
        temporal_coords = coords[b, :, 3]

        # Grow masked blocks until at least the target count is reached.
        while (~mask[b]).sum() < num_masked_target:
            seed_idx = torch.randint(0, N, (1,)).item()

            seed_spatial = spatial_coords[seed_idx]
            dists_spatial = torch.norm(spatial_coords - seed_spatial, dim=1)

            seed_temporal = temporal_coords[seed_idx]
            dists_temporal = torch.abs(temporal_coords - seed_temporal)

            # Mask the seed's spatio-temporal neighborhood.
            in_block = (dists_spatial <= spatial_radius) & (dists_temporal <= temporal_radius)
            mask[b, in_block] = False

        # The last block usually overshoots; randomly unmask the excess so the
        # masked count is exactly num_masked_target for every batch element.
        masked_indices = torch.where(~mask[b])[0]
        num_current_masked = len(masked_indices)

        if num_current_masked > num_masked_target:
            shuffled_indices = masked_indices[torch.randperm(num_current_masked, device=device)]
            excess_indices = shuffled_indices[num_masked_target:]
            mask[b, excess_indices] = True

    return mask

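# Illustrative invariant (assumed shapes, not from the original source): every
# row of the returned mask carries exactly int(mask_ratio * N) masked entries,
# which is what lets MAE.forward reshape boolean gathers back to (B, n, -1):
#
#   coords = torch.randn(2, 100, 4)
#   mask = generate_mask(coords, mask_ratio=0.55)
#   assert ((~mask).sum(dim=1) == 55).all()
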
class MAE(nn.Module):
    """Masked autoencoder over multichannel time series: tokens are overlapping
    per-channel patches, masking is block-wise in space and time, and an
    auxiliary head reconstructs masked patches from a single global token
    pooled from the encoder's intermediate features."""

    def __init__(
        self,
        # Patching
        fs: int = 200,
        patch_seconds: float = 1.0,
        overlap_seconds: float = 0.1,
        # Transformer sizes
        embed_dim: int = 512,
        encoder_depth: int = 12,
        encoder_heads: int = 8,
        decoder_depth: int = 4,
        decoder_heads: int = 8,
        # Training
        mask_ratio: float = 0.55,
        aux_loss_weight: float = 0.1,
    ):
        super().__init__()

        self.embed_dim = embed_dim
        self.mask_ratio = mask_ratio
        self.aux_loss_weight = aux_loss_weight

        self.patch_embed = PatchEmbed(fs, patch_seconds, overlap_seconds, embed_dim)
        self.patch_size = self.patch_embed.patch_size
        self.step = self.patch_embed.step

        self.pos_enc = PosEnc(n_freqs=4, embed_dim=embed_dim)

        self.encoder = TransformerEncoderDecoder(embed_dim=embed_dim, depth=encoder_depth, heads=encoder_heads)
        self.decoder = MAEDecoder(embed_dim=embed_dim, decoder_depth=decoder_depth, decoder_heads=decoder_heads, patch_size=self.patch_size)

        # Auxiliary head: per-layer FFN outputs are concatenated along the
        # feature axis and pooled into one global token via a learned query.
        self.aux_dim = encoder_depth * embed_dim

        self.aux_query = nn.Parameter(torch.empty(1, 1, self.aux_dim))
        nn.init.normal_(self.aux_query, std=0.02)

        self.aux_linear = nn.Linear(self.aux_dim, embed_dim, bias=False)
        self.aux_predict = nn.Sequential(nn.Linear(embed_dim, embed_dim), nn.GELU(), nn.Linear(embed_dim, self.patch_size))

    def prepare_coords(self, xyz: torch.Tensor, num_patches: int) -> torch.Tensor:
        """Builds per-token 4-D coordinates (x, y, z, patch index) from the
        per-channel sensor positions xyz of shape (B, C, 3)."""
        B, C, _ = xyz.shape
        device = xyz.device

        time_idx = torch.arange(num_patches, device=device, dtype=torch.float32)

        # (B, C, 3) -> (B, C, num_patches, 3): same position for every patch.
        spat = xyz.unsqueeze(2).expand(-1, -1, num_patches, -1)
        # (num_patches,) -> (B, C, num_patches, 1): same time axis for every channel.
        time = time_idx.view(1, 1, num_patches, 1).expand(B, C, -1, -1)

        coords = torch.cat([spat, time], dim=-1)

        # Flatten channels and patches into one token axis: (B, C * num_patches, 4).
        return coords.flatten(1, 2)

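    # Token-ordering note (illustrative): coords.flatten(1, 2) is channel-major,
    # matching tokens.flatten(1, 2) in forward below, so token index
    # c * num_patches + p refers to channel c, patch p in both tensors.
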
    def forward(self, x: torch.Tensor, xyz: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # x: (B, C, T) raw signals; xyz: (B, C, 3) sensor positions.
        B, _, _ = x.shape

        # Patchify: (B, C, num_patches, patch_size).
        patches = x.unfold(-1, self.patch_size, self.step)
        num_patches = patches.shape[2]

        # Embed patches, reusing the projection from PatchEmbed.
        tokens = self.patch_embed.linear(patches)

        # Flatten channels and patches into one token axis.
        tokens_flat = tokens.flatten(1, 2)
        patches_flat = patches.flatten(1, 2)

        coords = self.prepare_coords(xyz, num_patches)

        # Block-wise mask; True = visible. generate_mask guarantees the same
        # visible count in every batch element, so the boolean gathers below
        # can be reshaped back to (B, n_vis, -1).
        mask = generate_mask(coords, mask_ratio=self.mask_ratio)
        n_vis = mask[0].sum().item()

        x_vis = tokens_flat[mask].view(B, n_vis, -1)
        coords_vis = coords[mask].view(B, n_vis, -1)

        # Encode only the visible tokens (standard MAE).
        pe_vis = self.pos_enc(coords_vis)
        x_vis = x_vis + pe_vis
        x_encoded, intermediates = self.encoder(x_vis)

        # Main reconstruction: decode the full sequence with mask tokens.
        predictions_main = self.decoder(x_visible=x_encoded, pos_enc=self.pos_enc, coords=coords, mask=mask)

        # Auxiliary branch: concatenate every layer's FFN output along the
        # feature axis -> (B, n_vis, encoder_depth * embed_dim).
        aux_input = torch.cat(intermediates, dim=-1)

        # Attention-pool the visible tokens into one global token with the
        # learned query: scores (B, n_vis, 1), softmax over the token axis.
        attn_scores = torch.matmul(aux_input, self.aux_query.transpose(1, 2))
        attn_weights = F.softmax(attn_scores, dim=1)
        global_token = torch.sum(attn_weights * aux_input, dim=1, keepdim=True)

        global_emb = self.aux_linear(global_token)

        # Predict each masked patch from the global token plus that patch's
        # positional encoding.
        n_masked = (~mask[0]).sum().item()
        coords_masked = coords[~mask].view(B, n_masked, -1)
        pe_masked = self.pos_enc(coords_masked)

        global_expanded = global_emb.expand(-1, n_masked, -1)
        aux_pred_in = global_expanded + pe_masked
        predictions_aux = self.aux_predict(aux_pred_in)

        # Both losses are computed on masked patches only.
        target_masked = patches_flat[~mask].view(B, n_masked, -1)

        pred_main_masked = predictions_main[~mask].view(B, n_masked, -1)
        loss_main = F.l1_loss(pred_main_masked, target_masked)
        loss_aux = F.l1_loss(predictions_aux, target_masked)

        total_loss = loss_main + self.aux_loss_weight * loss_aux

        return total_loss, predictions_main, mask
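
# Minimal smoke-test sketch (not part of the original source): the batch size,
# channel count, duration, and coordinate scale below are assumptions chosen
# purely to exercise the forward pass once on random data.
if __name__ == "__main__":
    torch.manual_seed(0)

    B, C, fs, seconds = 2, 8, 200, 10
    x = torch.randn(B, C, fs * seconds)  # (B, C, T) raw signals
    xyz = torch.randn(B, C, 3)           # (B, C, 3) sensor positions

    model = MAE(fs=fs)
    loss, predictions, mask = model(x, xyz)

    print("loss:", loss.item())
    print("predictions:", tuple(predictions.shape))  # (B, C * num_patches, patch_size)
    print("mask:", tuple(mask.shape))                # (B, C * num_patches)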