"""Tokenizer: learned embeddings + Fourier features for the point cloud tokens.
The EdgeDepthSequenceBuilder holds the learned embedding tables (label, source,
behind) and the random Fourier positional encoding. At training time,
build_tokens() in data.py applies these to pre-sampled point indices on GPU.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Tuple
import numpy as np
import torch
import torch.nn as nn
from .point_fusion import NUM_ADE, NUM_GEST
# -- Config --
@dataclass(frozen=True)
class EdgeDepthSequenceConfig:
    """Immutable shape/encoding settings for the point-cloud token sequence."""
    # Total tokens per sample; with the defaults this equals
    # colmap_points + depth_points (1280 + 768 == 2048).
    seq_len: int = 2048
    # Token budget for points from the COLMAP source — presumably the SfM
    # reconstruction; confirm against build_tokens() in data.py.
    colmap_points: int = 1280
    # Token budget for points from the depth source.
    depth_points: int = 768
    # When True, xyz positions are augmented with random Fourier features
    # (see FourierFeatures / EdgeDepthSequenceBuilder.pos_enc).
    use_fourier: bool = True
    # Number of random frequencies; adds 2 * fourier_dim channels (sin + cos)
    # on top of the raw 3-d position.
    fourier_dim: int = 32
    # Std-dev of the random projection matrix B (frequency bandwidth).
    fourier_scale: float = 10.0
# -- Fourier positional encoding --
class FourierFeatures(nn.Module):
    """Random Fourier feature encoding of low-dimensional coordinates.

    Inputs are projected through a Gaussian matrix B (fixed by default,
    learnable on request) and mapped to ``[sin(2*pi*xB^T), cos(2*pi*xB^T)]``,
    so the output carries ``2 * fourier_dim`` channels.
    """

    def __init__(self, in_dim: int = 3, fourier_dim: int = 64,
                 scale: float = 10.0, seed: int = 0,
                 learnable: bool = False):
        super().__init__()
        # Seeded generator keeps the projection matrix reproducible across runs.
        rng = torch.Generator().manual_seed(seed)
        proj = scale * torch.randn(fourier_dim, in_dim, generator=rng)
        if learnable:
            # Trained jointly with the rest of the model.
            self.B = nn.Parameter(proj)
        else:
            # Frozen, but persisted in the state dict so checkpoints restore it.
            self.register_buffer("B", proj, persistent=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map (..., in_dim) coordinates to (..., 2 * fourier_dim) features."""
        angles = (2.0 * np.pi) * x.matmul(self.B.t())
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
# -- Sequence builder (holds embeddings) --
class EdgeDepthSequenceBuilder(nn.Module):
    """Holds the learned embedding tables used to tokenize point clouds.

    The training-time build_tokens() (in data.py) consumes these, e.g.
    self.label_emb(class_id) and self.src_emb(source).
    """

    def __init__(self, cfg: EdgeDepthSequenceConfig, label_emb_dim: int = 16,
                 src_emb_dim: int = 2, behind_emb_dim: int = 8,
                 fourier_seed: int = 0, use_vote_features: bool = False,
                 learnable_fourier: bool = False):
        super().__init__()
        self.cfg = cfg
        self.use_vote_features = use_vote_features
        self.behind_emb_dim = behind_emb_dim

        # 11 structural labels plus other_house and non_house.
        self.num_labels = 13
        self.label_emb = nn.Embedding(self.num_labels, label_emb_dim)
        # Two point sources — presumably COLMAP vs. depth (cf. config); confirm.
        self.src_emb = nn.Embedding(2, src_emb_dim)
        if behind_emb_dim > 0:
            # NUM_GEST classes plus one extra slot (presumably "none") — confirm.
            self.behind_emb = nn.Embedding(NUM_GEST + 1, behind_emb_dim)

        # Positional encoding: raw xyz, optionally widened by Fourier features.
        if cfg.use_fourier:
            self.pos_enc = FourierFeatures(
                in_dim=3, fourier_dim=cfg.fourier_dim,
                scale=cfg.fourier_scale, seed=fourier_seed,
                learnable=learnable_fourier,
            )
        else:
            self.pos_enc = None
        pos_dim = 3 if self.pos_enc is None else 3 + 2 * cfg.fourier_dim

        # Per-token feature width: position + label + source + behind (+ votes).
        vote_dim = 2 if use_vote_features else 0  # n_views_voted + vote_frac
        self.out_dim = pos_dim + label_emb_dim + src_emb_dim + behind_emb_dim + vote_dim