Other
PyTorch
3d-reconstruction
wireframe
building
point-cloud
s23dr
cvpr-2026
File size: 9,390 Bytes
f4487da
 
0f31e57
 
f4487da
 
 
 
 
 
 
 
 
 
 
0f31e57
f4487da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
"""Data loading for pre-sampled HF datasets.

Expects pre-sampled npz blobs with xyz_norm (not full PCD).
Supports both 2048-point and 4096-point datasets.
Use make_sampled_cache.py to produce these from full point clouds.
"""
from __future__ import annotations

from pathlib import Path

import numpy as np
import torch

from .tokenizer import EdgeDepthSequenceConfig

# Default token budget for 2048-point datasets: the COLMAP share plus the
# depth share add up to the sequence length (4096-point datasets use
# 3072/1024 instead).
COLMAP_POINTS = 1536
DEPTH_POINTS = 512
SEQ_LEN = COLMAP_POINTS + DEPTH_POINTS  # == 2048


# ---------------------------------------------------------------------------
# Datasets
# ---------------------------------------------------------------------------

def _load_bad_sample_ids():
    """Load the set of known-bad sample IDs (misaligned GT, extreme scale)."""
    bad_file = Path(__file__).parent / "bad_samples.txt"
    if not bad_file.exists():
        return set()
    return set(line.strip() for line in bad_file.read_text().splitlines() if line.strip())


class HFCachedDataset(torch.utils.data.Dataset):
    """Load pre-sampled HuggingFace dataset into memory.

    Every row of the source dataset is decoded once, in the constructor, into
    a dict of numpy arrays. Rows whose ``order_id`` is listed in
    bad_samples.txt are skipped. Augmentation is applied lazily, per access,
    in ``__getitem__``.
    """

    def __init__(self, hf_dataset, aug_rotate=False, aug_jitter=0.0,
                 aug_drop=0.0, aug_flip=False):
        """Decode all rows of ``hf_dataset`` into memory.

        Args:
            hf_dataset: sized iterable of rows, each with an ``order_id`` and a
                ``data`` field holding ``.npz`` archive bytes (presumably a
                ``datasets.Dataset`` -- confirm with callers).
            aug_rotate, aug_jitter, aug_drop, aug_flip: augmentation settings
                forwarded to _process_sample on every ``__getitem__`` call.

        Raises:
            ValueError: if a decoded archive has no ``xyz_norm`` array, i.e.
                the dataset holds full point clouds rather than pre-sampled
                blobs.
        """
        import io as _io
        bad_ids = _load_bad_sample_ids()
        n_total = len(hf_dataset)
        print(f"Pre-decoding {n_total} samples into memory...")
        self.samples = []
        self.order_ids = []
        n_skipped = 0
        for i, sample in enumerate(hf_dataset):
            order_id = sample["order_id"]
            # NOTE(review): bad_ids holds strings read from a text file; if
            # order_id is not a str it can never match -- confirm column type.
            if order_id in bad_ids:
                n_skipped += 1
                continue
            # np.load on .npz bytes returns an NpzFile that keeps the archive
            # open; dict() materialises every array, after which the context
            # manager closes the archive.
            with np.load(_io.BytesIO(sample["data"])) as npz:
                d = dict(npz)
            if "xyz_norm" not in d:
                raise ValueError(
                    f"Sample {order_id} missing 'xyz_norm' -- this looks like "
                    f"a full PCD dataset, not pre-sampled. Use make_sampled_cache.py first.")
            self.samples.append(d)
            self.order_ids.append(order_id)
            if (i + 1) % 2000 == 0:
                print(f"  {i+1}/{n_total}...")
        print(f"  Done. {len(self.samples)} samples in memory"
              f" ({n_skipped} bad samples filtered).")
        self.aug_rotate = aug_rotate
        self.aug_jitter = aug_jitter
        self.aug_drop = aug_drop
        self.aug_flip = aug_flip

    def __len__(self):
        """Number of samples kept after bad-ID filtering."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the processed (and, if enabled, augmented) sample ``idx``.

        The dict from _process_sample is extended with ``sample_id``, the
        row's ``order_id``.
        """
        out = _process_sample(self.samples[idx], self.aug_rotate,
                              self.aug_jitter, self.aug_drop, self.aug_flip)
        out["sample_id"] = self.order_ids[idx]
        return out


def _process_sample(d, aug_rotate, aug_jitter=0.0, aug_drop=0.0, aug_flip=False):
    """Process a pre-sampled npz dict into training tensors.

    Args:
        aug_rotate: random yaw rotation
        aug_jitter: std of Gaussian noise added to point positions (0=disabled)
        aug_drop: fraction of points to randomly drop (0=disabled)
        aug_flip: random mirror along X axis (50% chance)
    """
    xyz_norm = d["xyz_norm"].copy()
    gt_seg = d["gt_segments"].copy()
    mask = d["mask"].copy()

    if aug_rotate:
        theta = np.random.rand() * 2 * np.pi
        cos_t, sin_t = np.cos(theta), np.sin(theta)
        x, z = xyz_norm[:, 0].copy(), xyz_norm[:, 2].copy()
        xyz_norm[:, 0] = x * cos_t - z * sin_t
        xyz_norm[:, 2] = x * sin_t + z * cos_t
        for ep in range(2):
            sx, sz = gt_seg[:, ep, 0].copy(), gt_seg[:, ep, 2].copy()
            gt_seg[:, ep, 0] = sx * cos_t - sz * sin_t
            gt_seg[:, ep, 2] = sx * sin_t + sz * cos_t

    if aug_flip and np.random.rand() < 0.5:
        xyz_norm[:, 0] = -xyz_norm[:, 0]
        gt_seg[:, :, 0] = -gt_seg[:, :, 0]

    if aug_jitter > 0:
        valid = mask.astype(bool)
        xyz_norm[valid] += np.random.randn(valid.sum(), 3).astype(np.float32) * aug_jitter

    if aug_drop > 0:
        valid_idx = np.where(mask)[0]
        n_drop = int(len(valid_idx) * aug_drop)
        if n_drop > 0:
            drop_idx = np.random.choice(valid_idx, n_drop, replace=False)
            mask[drop_idx] = False

    result = {
        "xyz_norm": torch.as_tensor(xyz_norm, dtype=torch.float32),
        "class_id": torch.as_tensor(d["class_id"], dtype=torch.long),
        "source": torch.as_tensor(d["source"], dtype=torch.long),
        "mask": torch.as_tensor(mask),
        "gt_segments": torch.as_tensor(gt_seg, dtype=torch.float32),
        "scale": torch.tensor(float(d["scale"]), dtype=torch.float32),
        "center": torch.as_tensor(d["center"], dtype=torch.float32),
        "gt_vertices": d["gt_vertices"],
        "gt_edges": d["gt_edges"],
        "visible_src": torch.as_tensor(d["visible_src"], dtype=torch.long),
        "visible_id": torch.as_tensor(d["visible_id"], dtype=torch.long),
    }
    if "behind" in d:
        result["behind"] = torch.as_tensor(
            np.clip(np.asarray(d["behind"], dtype=np.int16), 0, None), dtype=torch.long)
    if "n_views_voted" in d:
        result["n_views_voted"] = torch.as_tensor(d["n_views_voted"], dtype=torch.float32)
    if "vote_frac" in d:
        result["vote_frac"] = torch.as_tensor(d["vote_frac"], dtype=torch.float32)
    return result


# ---------------------------------------------------------------------------
# Collation + DataLoader
# ---------------------------------------------------------------------------

def collate(batch):
    """Merge a list of per-sample dicts (from _process_sample) into one batch.

    Fixed-size tensors are stacked along a new leading batch dimension. The
    per-sample ``gt_segments`` are kept as a Python list, and the raw sample
    dicts are kept under ``meta``. The per-sample key ``scale`` is exposed as
    ``scales`` in the output.

    The optional fields behind, n_views_voted and vote_frac are stacked only
    when present, and every sample must agree on whether it has them.

    Raises:
        KeyError: if an optional field is present in some samples but not
            in all of them.
    """
    def _stacked(key):
        return torch.stack([sample[key] for sample in batch])

    out = {
        "xyz_norm": _stacked("xyz_norm"),
        "class_id": _stacked("class_id"),
        "source": _stacked("source"),
        "mask": _stacked("mask"),
        "gt_segments": [sample["gt_segments"] for sample in batch],
        "scales": _stacked("scale"),
        "meta": batch,
    }
    # Inspect every sample, not only batch[0], so a cache that mixes data
    # versions fails loudly instead of silently dropping a field.
    for field in ("behind", "n_views_voted", "vote_frac"):
        has_field = [field in sample for sample in batch]
        if not any(has_field):
            continue
        n_missing = has_field.count(False)
        if n_missing:
            raise KeyError(
                f"Field '{field}' present in some batch samples but missing in "
                f"{n_missing}/{len(batch)}. Mixed data versions in cache?")
        out[field] = _stacked(field)
    return out


def build_loader(cache_dir, batch_size, aug_rotate=False, aug_jitter=0.0,
                 aug_drop=0.0, aug_flip=False):
    """Create a shuffled DataLoader over a pre-sampled HuggingFace dataset.

    Args:
        cache_dir: dataset spec of the form ``'hf://<repo>[:<split>]'``, e.g.
            ``'hf://user/name:train'``. The split defaults to ``'train'``.
        batch_size: number of scenes per batch.
        aug_rotate, aug_jitter, aug_drop, aug_flip: augmentation settings
            forwarded to HFCachedDataset (applied in _process_sample).

    Returns:
        torch.utils.data.DataLoader whose batches are built by collate().

    Raises:
        ValueError: if cache_dir is not an ``hf://`` spec, or if its repo or
            split part is empty.
    """
    prefix = "hf://"
    if not cache_dir.startswith(prefix):
        raise ValueError(
            f"cache_dir must be 'hf://repo:split' format, got: {cache_dir}. "
            f"Local .pt caches are no longer supported in the training path.")
    parts = cache_dir[len(prefix):].split(":")
    repo = parts[0]
    split = parts[1] if len(parts) > 1 else "train"
    # Reject empty pieces here with a clear message rather than letting
    # load_dataset fail on an empty repo or split name.
    if not repo or not split:
        raise ValueError(
            f"cache_dir must be 'hf://repo:split' format with a non-empty repo "
            f"and split, got: {cache_dir}")
    # Function-scope import: the `datasets` package is only needed when a
    # loader is actually built.
    from datasets import load_dataset
    hf_ds = load_dataset(repo, split=split)
    ds = HFCachedDataset(hf_ds, aug_rotate=aug_rotate, aug_jitter=aug_jitter,
                          aug_drop=aug_drop, aug_flip=aug_flip)
    # num_workers=0 keeps loading in the main process. Note _process_sample
    # draws from the global np.random state, which worker processes would
    # need to seed individually if this is ever raised.
    loader = torch.utils.data.DataLoader(
        ds, batch_size=batch_size, shuffle=True,
        num_workers=0, collate_fn=collate,
    )
    print(f"Dataset: {len(ds)} scenes, batch_size={batch_size}")
    return loader


# ---------------------------------------------------------------------------
# Token building (GPU)
# ---------------------------------------------------------------------------

def build_tokens(batch, model, device):
    """Apply Fourier features + learned embeddings on ``device``.

    Each per-point token is the concatenation, along the last dimension, of:
    raw xyz; Fourier features from ``tok.pos_enc`` (width 0 when pos_enc is
    None); the class-label embedding; and the source embedding, with source
    ids clamped to [0, 1]. It then optionally includes the ``behind``
    embedding and two normalised vote features.

    Args:
        batch: dict produced by collate().
        model: object whose ``tokenizer`` exposes pos_enc, label_emb,
            src_emb, behind_emb_dim, behind_emb and use_vote_features.
        device: torch device the tokens are built on.

    Returns:
        (tokens, masks, gt, scales, meta). tokens has shape (B, T, D), and
        tokens and masks live on ``device``. gt is the list of per-sample
        gt_segments moved to ``device``. scales is returned as collated
        (not moved). meta is the list of raw sample dicts.

    Raises:
        KeyError: if the tokenizer uses vote features but the batch lacks
            ``n_views_voted`` or ``vote_frac``.
    """
    xyz = batch["xyz_norm"].to(device)
    cid = batch["class_id"].to(device)
    src = batch["source"].to(device)
    masks = batch["mask"].to(device)
    gt = [g.to(device) for g in batch["gt_segments"]]
    scales = batch["scales"]

    B, T, _ = xyz.shape
    tok = model.tokenizer
    fourier = tok.pos_enc(xyz.reshape(-1, 3)).reshape(B, T, -1) \
        if tok.pos_enc is not None else xyz.new_zeros(B, T, 0)
    parts = [xyz, fourier, tok.label_emb(cid), tok.src_emb(src.clamp(0, 1))]
    if tok.behind_emb_dim > 0:
        if "behind" in batch:
            beh = batch["behind"].to(device)
        else:
            # Data predates the 'behind' field: embed index 0 for every
            # point. This is intended for evaluating on old data. Nothing
            # upstream enforces the field during training either --
            # _process_sample only copies 'behind' when it is present -- so
            # training on such data silently uses index 0 as well.
            beh = xyz.new_zeros(B, T, dtype=torch.long)
        parts.append(tok.behind_emb(beh))
    if tok.use_vote_features:
        if "n_views_voted" not in batch or "vote_frac" not in batch:
            raise KeyError(
                "Model expects vote features (--vote-features) but data is missing "
                "'n_views_voted'/'vote_frac'. Use v2 dataset or regenerate cache.")
        # Normalize to ~zero mean, unit variance (dataset stats: nv~2.7+/-1.0, vf~0.5+/-0.25)
        nv = ((batch["n_views_voted"].to(device).float() - 2.7) / 1.0).unsqueeze(-1)
        vf = ((batch["vote_frac"].to(device).float() - 0.5) / 0.25).unsqueeze(-1)
        parts.extend([nv, vf])
    tokens = torch.cat(parts, dim=-1)
    return tokens, masks, gt, scales, batch["meta"]