Other
PyTorch
3d-reconstruction
wireframe
building
point-cloud
s23dr
cvpr-2026
jacklangerman's picture
4096-release (#1)
0f31e57
"""Data loading for pre-sampled HF datasets.
Expects pre-sampled npz blobs with xyz_norm (not full PCD).
Supports both 2048-point and 4096-point datasets.
Use make_sampled_cache.py to produce these from full point clouds.
"""
from __future__ import annotations
from pathlib import Path
import numpy as np
import torch
from .tokenizer import EdgeDepthSequenceConfig
# Default per-sample token budget for 2048-point datasets, split into
# COLMAP_POINTS + DEPTH_POINTS tokens. 4096-point datasets use 3072 + 1024.
# NOTE(review): the names suggest the two budgets are points drawn from the
# COLMAP reconstruction and from depth maps respectively -- confirm against
# make_sampled_cache.py.
COLMAP_POINTS = 1536
DEPTH_POINTS = 512
SEQ_LEN = COLMAP_POINTS + DEPTH_POINTS  # == 2048
# ---------------------------------------------------------------------------
# Datasets
# ---------------------------------------------------------------------------
def _load_bad_sample_ids():
"""Load the set of known-bad sample IDs (misaligned GT, extreme scale)."""
bad_file = Path(__file__).parent / "bad_samples.txt"
if not bad_file.exists():
return set()
return set(line.strip() for line in bad_file.read_text().splitlines() if line.strip())
class HFCachedDataset(torch.utils.data.Dataset):
    """Load a pre-sampled HuggingFace dataset into memory.

    Every row's ``data`` blob (npz bytes) is decoded once, up front, into a
    dict of numpy arrays. Rows whose ``order_id`` is listed in
    ``bad_samples.txt`` are skipped. The augmentation flags are forwarded to
    ``_process_sample`` on every ``__getitem__``.
    """

    def __init__(self, hf_dataset, aug_rotate=False, aug_jitter=0.0,
                 aug_drop=0.0, aug_flip=False):
        """
        Args:
            hf_dataset: Sized iterable of rows, each with an ``order_id`` and a
                ``data`` field holding npz bytes.
            aug_rotate, aug_jitter, aug_drop, aug_flip: augmentation settings,
                passed unchanged to ``_process_sample``.

        Raises:
            ValueError: if a decoded sample has no ``xyz_norm`` array (i.e. it
                is a full point cloud rather than a pre-sampled one).
        """
        bad_ids = _load_bad_sample_ids()
        n_total = len(hf_dataset)
        print(f"Pre-decoding {n_total} samples into memory...")
        self.samples = []
        self.order_ids = []
        n_skipped = 0
        for i, sample in enumerate(hf_dataset):
            order_id = sample["order_id"]
            if order_id in bad_ids:
                n_skipped += 1
                continue
            d = self._decode_npz(sample["data"])
            if "xyz_norm" not in d:
                raise ValueError(
                    f"Sample {order_id} missing 'xyz_norm' -- this looks like "
                    f"a full PCD dataset, not pre-sampled. Use make_sampled_cache.py first.")
            self.samples.append(d)
            self.order_ids.append(order_id)
            # Progress counts every row seen, including skipped bad ones.
            if (i + 1) % 2000 == 0:
                print(f" {i+1}/{n_total}...")
        print(f" Done. {len(self.samples)} samples in memory"
              f" ({n_skipped} bad samples filtered).")
        self.aug_rotate = aug_rotate
        self.aug_jitter = aug_jitter
        self.aug_drop = aug_drop
        self.aug_flip = aug_flip

    @staticmethod
    def _decode_npz(blob):
        """Decode npz bytes into a plain dict of numpy arrays.

        ``dict(npz)`` reads every array eagerly, so the NpzFile (and the zip
        archive it holds open) can be closed as soon as the dict is built.
        """
        import io as _io
        with np.load(_io.BytesIO(blob)) as npz:
            return dict(npz)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the processed tensors for sample ``idx`` plus its ``sample_id``.

        Augmentations are redrawn on every call (they use ``np.random``).
        """
        out = _process_sample(self.samples[idx], self.aug_rotate,
                              self.aug_jitter, self.aug_drop, self.aug_flip)
        out["sample_id"] = self.order_ids[idx]
        return out
def _process_sample(d, aug_rotate, aug_jitter=0.0, aug_drop=0.0, aug_flip=False):
"""Process a pre-sampled npz dict into training tensors.
Args:
aug_rotate: random yaw rotation
aug_jitter: std of Gaussian noise added to point positions (0=disabled)
aug_drop: fraction of points to randomly drop (0=disabled)
aug_flip: random mirror along X axis (50% chance)
"""
xyz_norm = d["xyz_norm"].copy()
gt_seg = d["gt_segments"].copy()
mask = d["mask"].copy()
if aug_rotate:
theta = np.random.rand() * 2 * np.pi
cos_t, sin_t = np.cos(theta), np.sin(theta)
x, z = xyz_norm[:, 0].copy(), xyz_norm[:, 2].copy()
xyz_norm[:, 0] = x * cos_t - z * sin_t
xyz_norm[:, 2] = x * sin_t + z * cos_t
for ep in range(2):
sx, sz = gt_seg[:, ep, 0].copy(), gt_seg[:, ep, 2].copy()
gt_seg[:, ep, 0] = sx * cos_t - sz * sin_t
gt_seg[:, ep, 2] = sx * sin_t + sz * cos_t
if aug_flip and np.random.rand() < 0.5:
xyz_norm[:, 0] = -xyz_norm[:, 0]
gt_seg[:, :, 0] = -gt_seg[:, :, 0]
if aug_jitter > 0:
valid = mask.astype(bool)
xyz_norm[valid] += np.random.randn(valid.sum(), 3).astype(np.float32) * aug_jitter
if aug_drop > 0:
valid_idx = np.where(mask)[0]
n_drop = int(len(valid_idx) * aug_drop)
if n_drop > 0:
drop_idx = np.random.choice(valid_idx, n_drop, replace=False)
mask[drop_idx] = False
result = {
"xyz_norm": torch.as_tensor(xyz_norm, dtype=torch.float32),
"class_id": torch.as_tensor(d["class_id"], dtype=torch.long),
"source": torch.as_tensor(d["source"], dtype=torch.long),
"mask": torch.as_tensor(mask),
"gt_segments": torch.as_tensor(gt_seg, dtype=torch.float32),
"scale": torch.tensor(float(d["scale"]), dtype=torch.float32),
"center": torch.as_tensor(d["center"], dtype=torch.float32),
"gt_vertices": d["gt_vertices"],
"gt_edges": d["gt_edges"],
"visible_src": torch.as_tensor(d["visible_src"], dtype=torch.long),
"visible_id": torch.as_tensor(d["visible_id"], dtype=torch.long),
}
if "behind" in d:
result["behind"] = torch.as_tensor(
np.clip(np.asarray(d["behind"], dtype=np.int16), 0, None), dtype=torch.long)
if "n_views_voted" in d:
result["n_views_voted"] = torch.as_tensor(d["n_views_voted"], dtype=torch.float32)
if "vote_frac" in d:
result["vote_frac"] = torch.as_tensor(d["vote_frac"], dtype=torch.float32)
return result
# ---------------------------------------------------------------------------
# Collation + DataLoader
# ---------------------------------------------------------------------------
def collate(batch):
    """Merge a list of per-sample dicts into one batch dict.

    Fixed-size fields are stacked along a new leading batch dimension.
    ``gt_segments`` stays a Python list (segment counts may differ per
    sample), ``scale`` is stacked under the key ``scales``, and the original
    sample dicts are kept under ``meta``.

    The optional fields ``behind``, ``n_views_voted`` and ``vote_frac`` are
    stacked only if present. When present, every sample must have them.

    Raises:
        KeyError: if an optional field appears in some samples but not all
            (a sign of mixed data versions in the cache).
    """
    def stacked(key):
        return torch.stack([sample[key] for sample in batch])

    out = {key: stacked(key) for key in ("xyz_norm", "class_id", "source", "mask")}
    out["gt_segments"] = [sample["gt_segments"] for sample in batch]
    out["scales"] = stacked("scale")
    out["meta"] = batch

    for field in ("behind", "n_views_voted", "vote_frac"):
        has_field = [field in sample for sample in batch]
        if not any(has_field):
            continue
        n_missing = has_field.count(False)
        if n_missing:
            raise KeyError(
                f"Field '{field}' present in some batch samples but missing in "
                f"{n_missing}/{len(batch)}. Mixed data versions in cache?")
        out[field] = stacked(field)
    return out
def build_loader(cache_dir, batch_size, aug_rotate=False, aug_jitter=0.0,
aug_drop=0.0, aug_flip=False):
"""Create a DataLoader from HF dataset.
cache_dir should be 'hf://repo/name:split' format.
"""
if not cache_dir.startswith("hf://"):
raise ValueError(
f"cache_dir must be 'hf://repo:split' format, got: {cache_dir}. "
f"Local .pt caches are no longer supported in the training path.")
parts = cache_dir[5:].split(":")
repo = parts[0]
split = parts[1] if len(parts) > 1 else "train"
from datasets import load_dataset
hf_ds = load_dataset(repo, split=split)
ds = HFCachedDataset(hf_ds, aug_rotate=aug_rotate, aug_jitter=aug_jitter,
aug_drop=aug_drop, aug_flip=aug_flip)
loader = torch.utils.data.DataLoader(
ds, batch_size=batch_size, shuffle=True,
num_workers=0, collate_fn=collate,
)
print(f"Dataset: {len(ds)} scenes, batch_size={batch_size}")
return loader
# ---------------------------------------------------------------------------
# Token building (GPU)
# ---------------------------------------------------------------------------
def build_tokens(batch, model, device):
    """Apply Fourier features + learned embeddings on ``device``.

    Each point's token is the last-dim concatenation of:
      xyz (3)
      | Fourier features of xyz from ``tok.pos_enc`` (width 0 if it is None)
      | class-label embedding of ``class_id``
      | source embedding of ``source`` clamped to [0, 1]
      | behind embedding (only if ``tok.behind_emb_dim > 0``)
      | standardized ``n_views_voted`` and ``vote_frac`` (only if
        ``tok.use_vote_features``), one channel each.

    Args:
        batch: dict produced by ``collate``.
        model: object whose ``.tokenizer`` provides ``pos_enc``, ``label_emb``,
            ``src_emb``, ``behind_emb_dim``, ``behind_emb`` and
            ``use_vote_features``.
        device: device the returned tensors are moved to.

    Returns:
        ``(tokens, masks, gt, scales, meta)``: tokens of shape (B, T, D); masks
        and the per-sample GT segment list on ``device``; ``scales`` left on
        their original device; ``batch["meta"]`` unchanged.

    Raises:
        KeyError: if the tokenizer expects vote features that the batch lacks.
    """
    xyz = batch["xyz_norm"].to(device)
    cid = batch["class_id"].to(device)
    src = batch["source"].to(device)
    masks = batch["mask"].to(device)
    gt = [g.to(device) for g in batch["gt_segments"]]
    scales = batch["scales"]
    B, T, _ = xyz.shape
    tok = model.tokenizer

    if tok.pos_enc is not None:
        fourier = tok.pos_enc(xyz.reshape(-1, 3)).reshape(B, T, -1)
    else:
        fourier = xyz.new_zeros(B, T, 0)
    parts = [xyz, fourier, tok.label_emb(cid), tok.src_emb(src.clamp(0, 1))]

    if tok.behind_emb_dim > 0:
        if "behind" in batch:
            beh = batch["behind"].to(device)
        else:
            # Data without 'behind' (older caches): every point gets embedding
            # index 0. Nothing upstream enforces the field -- _process_sample
            # and collate both treat it as optional -- so training on such data
            # also silently takes this path rather than failing.
            beh = xyz.new_zeros(B, T, dtype=torch.long)
        parts.append(tok.behind_emb(beh))

    if tok.use_vote_features:
        if "n_views_voted" not in batch or "vote_frac" not in batch:
            raise KeyError(
                "Model expects vote features (--vote-features) but data is missing "
                "'n_views_voted'/'vote_frac'. Use v2 dataset or regenerate cache.")
        # Standardize to ~zero mean / unit variance with fixed dataset
        # statistics: n_views_voted ~ 2.7 +/- 1.0, vote_frac ~ 0.5 +/- 0.25.
        nv_mean, nv_std = 2.7, 1.0
        vf_mean, vf_std = 0.5, 0.25
        nv = ((batch["n_views_voted"].to(device).float() - nv_mean) / nv_std).unsqueeze(-1)
        vf = ((batch["vote_frac"].to(device).float() - vf_mean) / vf_std).unsqueeze(-1)
        parts.extend([nv, vf])

    tokens = torch.cat(parts, dim=-1)
    return tokens, masks, gt, scales, batch["meta"]