| """Tokenizer: learned embeddings + Fourier features for the point cloud tokens. |
| |
| The EdgeDepthSequenceBuilder holds the learned embedding tables (label, source, |
| behind) and the random Fourier positional encoding. At training time, |
| build_tokens() in data.py applies these to pre-sampled point indices on GPU. |
| """ |
| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
| from typing import Tuple |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
|
|
| from .point_fusion import NUM_ADE, NUM_GEST |
|
|
|
|
| |
|
|
@dataclass(frozen=True)
class EdgeDepthSequenceConfig:
    """Immutable settings for the point-cloud token sequence.

    Instances are frozen, so fields cannot be reassigned after construction.
    The default ``colmap_points`` and ``depth_points`` sum to the default
    ``seq_len`` (1280 + 768 == 2048); this class itself does not enforce
    that relationship.
    """

    # Sequence length and per-source point budgets.
    # NOTE(review): presumably the number of tokens drawn from each point
    # source -- confirm against the sampling code.
    seq_len: int = 2048
    colmap_points: int = 1280
    depth_points: int = 768

    # Positional-encoding switches. EdgeDepthSequenceBuilder forwards
    # ``fourier_dim`` and ``fourier_scale`` to FourierFeatures when
    # ``use_fourier`` is true.
    use_fourier: bool = True
    fourier_dim: int = 32
    fourier_scale: float = 10.0
|
|
|
|
| |
|
|
class FourierFeatures(nn.Module):
    """Random Fourier feature map ``x -> [sin(2*pi*x B^T), cos(2*pi*x B^T)]``.

    The projection matrix ``B`` has shape ``(fourier_dim, in_dim)``. It is
    drawn from a standard normal using a dedicated generator seeded with
    ``seed`` and then multiplied by ``scale``.

    Args:
        in_dim: Size of the last dimension of the input.
        fourier_dim: Number of projection rows; the output has
            ``2 * fourier_dim`` features.
        scale: Multiplier applied to the sampled matrix.
        seed: Seed for the private generator used to sample ``B``.
        learnable: If true, ``B`` is stored as an ``nn.Parameter``;
            otherwise it is stored as a persistent buffer.
    """

    def __init__(self, in_dim: int = 3, fourier_dim: int = 64,
                 scale: float = 10.0, seed: int = 0,
                 learnable: bool = False):
        super().__init__()
        # A dedicated generator makes the draw depend only on ``seed`` and
        # leaves the global RNG stream untouched.
        rng = torch.Generator()
        rng.manual_seed(seed)
        unit_normal = torch.randn(fourier_dim, in_dim, generator=rng)
        frequencies = unit_normal * scale

        if not learnable:
            # Persistent buffer: saved in the state dict, not a parameter.
            self.register_buffer("B", frequencies, persistent=True)
        else:
            self.B = nn.Parameter(frequencies)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` (last dim ``in_dim``) into ``2 * fourier_dim`` features.

        The sine half comes first, followed by the cosine half, concatenated
        along the last dimension.
        """
        phase = (2.0 * np.pi) * torch.matmul(x, self.B.t())
        return torch.cat((phase.sin(), phase.cos()), dim=-1)
|
|
|
|
| |
|
|
class EdgeDepthSequenceBuilder(nn.Module):
    """Holds the learned embeddings used to tokenize point clouds.

    Used by the model at training time: build_tokens() calls
    self.label_emb(class_id), self.src_emb(source), etc.

    Attributes:
        cfg: The configuration object passed to the constructor.
        num_labels: Number of rows in ``label_emb`` (fixed at 13).
        label_emb: ``nn.Embedding(num_labels, label_emb_dim)``.
        src_emb: ``nn.Embedding(2, src_emb_dim)``.
        behind_emb_dim: Width of the "behind" embedding as passed in.
        behind_emb: ``nn.Embedding(NUM_GEST + 1, behind_emb_dim)`` when
            ``behind_emb_dim > 0``, otherwise ``None``.
        pos_enc: A ``FourierFeatures`` module when ``cfg.use_fourier`` is
            true, otherwise ``None``.
        use_vote_features: The flag as passed to the constructor.
        out_dim: Sum of the position, label, source, behind and vote widths.
    """

    def __init__(self, cfg: EdgeDepthSequenceConfig, label_emb_dim: int = 16,
                 src_emb_dim: int = 2, behind_emb_dim: int = 8,
                 fourier_seed: int = 0, use_vote_features: bool = False,
                 learnable_fourier: bool = False):
        """Create the embedding tables and the optional positional encoder.

        Args:
            cfg: Sequence configuration; ``use_fourier``, ``fourier_dim`` and
                ``fourier_scale`` are read here.
            label_emb_dim: Width of the label embedding.
            src_emb_dim: Width of the source embedding.
            behind_emb_dim: Width of the "behind" embedding; a value of 0
                or less disables that table.
            fourier_seed: Seed forwarded to ``FourierFeatures``.
            use_vote_features: When true, 2 extra dims are counted in
                ``out_dim``.
            learnable_fourier: Forwarded to ``FourierFeatures`` as
                ``learnable``.
        """
        super().__init__()
        self.cfg = cfg

        # Embedding tables are created in a fixed order so that default
        # weight initialisation consumes the global RNG identically.
        self.num_labels = 13
        self.label_emb = nn.Embedding(self.num_labels, label_emb_dim)
        self.src_emb = nn.Embedding(2, src_emb_dim)
        self.behind_emb_dim = behind_emb_dim
        if behind_emb_dim > 0:
            self.behind_emb = nn.Embedding(NUM_GEST + 1, behind_emb_dim)
        else:
            # Always define the attribute (mirroring ``pos_enc`` below) so a
            # disabled table reads as None instead of raising AttributeError.
            self.behind_emb = None

        if cfg.use_fourier:
            self.pos_enc = FourierFeatures(
                in_dim=3, fourier_dim=cfg.fourier_dim,
                scale=cfg.fourier_scale, seed=fourier_seed,
                learnable=learnable_fourier,
            )
            # 3 position dims plus the sin and cos halves of the encoding.
            pos_dim = 3 + 2 * cfg.fourier_dim
        else:
            self.pos_enc = None
            pos_dim = 3

        vote_dim = 2 if use_vote_features else 0
        self.use_vote_features = use_vote_features
        self.out_dim = pos_dim + label_emb_dim + src_emb_dim + behind_emb_dim + vote_dim
|
|