"""Data loading for pre-sampled HF datasets.

Expects pre-sampled npz blobs with xyz_norm (not full PCD).
Supports both 2048-point and 4096-point datasets.
Use make_sampled_cache.py to produce these from full point clouds.
"""
from __future__ import annotations

from pathlib import Path

import numpy as np
import torch

from .tokenizer import EdgeDepthSequenceConfig

# Default token budget (for 2048-point datasets; 4096 uses 3072/1024)
SEQ_LEN = 2048
COLMAP_POINTS = 1536
DEPTH_POINTS = 512


# ---------------------------------------------------------------------------
# Datasets
# ---------------------------------------------------------------------------

def _load_bad_sample_ids():
    """Load the set of known-bad sample IDs (misaligned GT, extreme scale).

    Reads ``bad_samples.txt`` located next to this module, one ID per line.
    Surrounding whitespace is stripped and blank lines are ignored.

    Returns:
        A set of ID strings; empty when the file does not exist.
    """
    bad_file = Path(__file__).parent / "bad_samples.txt"
    if not bad_file.exists():
        return set()
    return {line.strip() for line in bad_file.read_text().splitlines() if line.strip()}


class HFCachedDataset(torch.utils.data.Dataset):
    """Load pre-sampled HuggingFace dataset into memory.

    Every row of ``hf_dataset`` is decoded once in ``__init__`` (its ``"data"``
    field is read as an npz archive) and kept as a dict of NumPy arrays.
    Rows whose ``"order_id"`` is listed in ``bad_samples.txt`` are skipped.
    Augmentation is applied lazily, per access, in ``__getitem__``.
    """

    def __init__(self, hf_dataset, aug_rotate=False, aug_jitter=0.0,
                 aug_drop=0.0, aug_flip=False):
        """Decode all samples up front and store the augmentation settings.

        Args:
            hf_dataset: sized iterable of rows exposing ``"order_id"`` and
                ``"data"`` (npz bytes).
            aug_rotate: forwarded to ``_process_sample``.
            aug_jitter: forwarded to ``_process_sample``.
            aug_drop: forwarded to ``_process_sample``.
            aug_flip: forwarded to ``_process_sample``.

        Raises:
            ValueError: if a decoded sample has no ``"xyz_norm"`` array.
        """
        import io as _io

        bad_ids = _load_bad_sample_ids()
        print(f"Pre-decoding {len(hf_dataset)} samples into memory...")
        self.samples = []
        self.order_ids = []
        n_skipped = 0
        for i, sample in enumerate(hf_dataset):
            if sample["order_id"] in bad_ids:
                n_skipped += 1
                continue
            # dict(...) reads every array eagerly, so the archive handle can be
            # closed right away instead of lingering until garbage collection.
            with np.load(_io.BytesIO(sample["data"])) as npz:
                d = dict(npz)
            if "xyz_norm" not in d:
                raise ValueError(
                    f"Sample {sample['order_id']} missing 'xyz_norm' -- this looks like "
                    f"a full PCD dataset, not pre-sampled. Use make_sampled_cache.py first.")
            self.samples.append(d)
            self.order_ids.append(sample["order_id"])
            # Progress is counted over rows that were kept; skipped rows
            # `continue` before reaching this point.
            if (i + 1) % 2000 == 0:
                print(f" {i+1}/{len(hf_dataset)}...")
        print(f" Done. {len(self.samples)} samples in memory"
              f" ({n_skipped} bad samples filtered).")
        self.aug_rotate = aug_rotate
        self.aug_jitter = aug_jitter
        self.aug_drop = aug_drop
        self.aug_flip = aug_flip

    def __len__(self):
        """Return the number of samples kept in memory."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the processed (and possibly augmented) sample at ``idx``.

        The result of ``_process_sample`` is extended with ``"sample_id"``,
        the row's original ``order_id``.
        """
        out = _process_sample(self.samples[idx], self.aug_rotate,
                              self.aug_jitter, self.aug_drop, self.aug_flip)
        out["sample_id"] = self.order_ids[idx]
        return out


def _process_sample(d, aug_rotate, aug_jitter=0.0, aug_drop=0.0, aug_flip=False):
    """Process a pre-sampled npz dict into training tensors.

    ``xyz_norm``, ``gt_segments`` and ``mask`` are copied before any
    augmentation, so the cached arrays in ``d`` are never modified.
    Augmentations run in a fixed order: rotate, flip, jitter, drop.

    Args:
        d: dict of NumPy arrays for one sample. Must contain ``xyz_norm``,
            ``gt_segments``, ``mask``, ``class_id``, ``source``, ``scale``,
            ``center``, ``gt_vertices``, ``gt_edges``, ``visible_src`` and
            ``visible_id``; ``behind``, ``n_views_voted`` and ``vote_frac``
            are optional.
        aug_rotate: random yaw rotation
        aug_jitter: std of Gaussian noise added to point positions (0=disabled)
        aug_drop: fraction of points to randomly drop (0=disabled)
        aug_flip: random mirror along X axis (50% chance)

    Returns:
        dict of tensors (plus raw ``gt_vertices`` / ``gt_edges`` passed
        through unchanged) ready for ``collate``.
    """
    xyz_norm = d["xyz_norm"].copy()
    gt_seg = d["gt_segments"].copy()
    mask = d["mask"].copy()

    if aug_rotate:
        # Rotate in the X/Z plane (columns 0 and 2); column 1 is untouched.
        theta = np.random.rand() * 2 * np.pi
        cos_t, sin_t = np.cos(theta), np.sin(theta)
        x, z = xyz_norm[:, 0].copy(), xyz_norm[:, 2].copy()
        xyz_norm[:, 0] = x * cos_t - z * sin_t
        xyz_norm[:, 2] = x * sin_t + z * cos_t
        # Apply the same rotation to both endpoints of every GT segment.
        for ep in range(2):
            sx, sz = gt_seg[:, ep, 0].copy(), gt_seg[:, ep, 2].copy()
            gt_seg[:, ep, 0] = sx * cos_t - sz * sin_t
            gt_seg[:, ep, 2] = sx * sin_t + sz * cos_t

    if aug_flip and np.random.rand() < 0.5:
        xyz_norm[:, 0] = -xyz_norm[:, 0]
        gt_seg[:, :, 0] = -gt_seg[:, :, 0]

    if aug_jitter > 0:
        # Only perturb points that are marked valid by the mask.
        valid = mask.astype(bool)
        xyz_norm[valid] += np.random.randn(valid.sum(), 3).astype(np.float32) * aug_jitter

    if aug_drop > 0:
        # "Dropping" a point means clearing its mask entry; the point itself
        # stays in xyz_norm so the sequence length is unchanged.
        valid_idx = np.where(mask)[0]
        n_drop = int(len(valid_idx) * aug_drop)
        if n_drop > 0:
            drop_idx = np.random.choice(valid_idx, n_drop, replace=False)
            mask[drop_idx] = False

    result = {
        "xyz_norm": torch.as_tensor(xyz_norm, dtype=torch.float32),
        "class_id": torch.as_tensor(d["class_id"], dtype=torch.long),
        "source": torch.as_tensor(d["source"], dtype=torch.long),
        "mask": torch.as_tensor(mask),
        "gt_segments": torch.as_tensor(gt_seg, dtype=torch.float32),
        "scale": torch.tensor(float(d["scale"]), dtype=torch.float32),
        "center": torch.as_tensor(d["center"], dtype=torch.float32),
        "gt_vertices": d["gt_vertices"],
        "gt_edges": d["gt_edges"],
        "visible_src": torch.as_tensor(d["visible_src"], dtype=torch.long),
        "visible_id": torch.as_tensor(d["visible_id"], dtype=torch.long),
    }
    if "behind" in d:
        # Negative values are clamped to 0 before conversion to long.
        result["behind"] = torch.as_tensor(
            np.clip(np.asarray(d["behind"], dtype=np.int16), 0, None),
            dtype=torch.long)
    if "n_views_voted" in d:
        result["n_views_voted"] = torch.as_tensor(d["n_views_voted"], dtype=torch.float32)
    if "vote_frac" in d:
        result["vote_frac"] = torch.as_tensor(d["vote_frac"], dtype=torch.float32)
    return result


# ---------------------------------------------------------------------------
# Collation + DataLoader
# ---------------------------------------------------------------------------

def collate(batch):
    """Stack samples into batched tensors.

    Args:
        batch: list of dicts as returned by ``HFCachedDataset.__getitem__``.

    Returns:
        dict with stacked ``xyz_norm``, ``class_id``, ``source``, ``mask`` and
        ``scales``; ``gt_segments`` as a per-sample list (not stacked);
        ``meta`` holding the original list of sample dicts; and any of the
        optional fields ``behind`` / ``n_views_voted`` / ``vote_frac`` stacked
        when present.

    Raises:
        KeyError: if an optional field is present in some samples of the
            batch but missing in others.
    """
    out = {
        "xyz_norm": torch.stack([d["xyz_norm"] for d in batch]),
        "class_id": torch.stack([d["class_id"] for d in batch]),
        "source": torch.stack([d["source"] for d in batch]),
        "mask": torch.stack([d["mask"] for d in batch]),
        "gt_segments": [d["gt_segments"] for d in batch],
        "scales": torch.stack([d["scale"] for d in batch]),
        "meta": batch,
    }
    # Optional fields: check ALL samples, not just batch[0].
    # If any sample has it, all must have it (no mixed data versions).
    for field in ("behind", "n_views_voted", "vote_frac"):
        if any(field in d for d in batch):
            missing = [i for i, d in enumerate(batch) if field not in d]
            if missing:
                raise KeyError(
                    f"Field '{field}' present in some batch samples but missing in "
                    f"{len(missing)}/{len(batch)}. Mixed data versions in cache?")
            out[field] = torch.stack([d[field] for d in batch])
    return out


def build_loader(cache_dir, batch_size, aug_rotate=False, aug_jitter=0.0,
                 aug_drop=0.0, aug_flip=False):
    """Create a DataLoader from HF dataset.

    cache_dir should be 'hf://repo/name:split' format. When the ':split'
    suffix is omitted, the "train" split is used.

    Args:
        cache_dir: dataset spec string starting with ``hf://``.
        batch_size: batch size passed to the DataLoader.
        aug_rotate, aug_jitter, aug_drop, aug_flip: augmentation settings
            forwarded to ``HFCachedDataset``.

    Returns:
        A shuffling ``torch.utils.data.DataLoader`` (``num_workers=0``) that
        uses ``collate`` as its collate function.

    Raises:
        ValueError: if ``cache_dir`` does not start with ``hf://``.
    """
    if not cache_dir.startswith("hf://"):
        raise ValueError(
            f"cache_dir must be 'hf://repo:split' format, got: {cache_dir}. "
            f"Local .pt caches are no longer supported in the training path.")
    parts = cache_dir[5:].split(":")
    repo = parts[0]
    split = parts[1] if len(parts) > 1 else "train"

    # Imported lazily so the module can be imported without `datasets`.
    from datasets import load_dataset
    hf_ds = load_dataset(repo, split=split)
    ds = HFCachedDataset(hf_ds, aug_rotate=aug_rotate, aug_jitter=aug_jitter,
                         aug_drop=aug_drop, aug_flip=aug_flip)

    loader = torch.utils.data.DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        collate_fn=collate,
    )
    print(f"Dataset: {len(ds)} scenes, batch_size={batch_size}")
    return loader


# ---------------------------------------------------------------------------
# Token building (GPU)
# ---------------------------------------------------------------------------

def build_tokens(batch, model, device):
    """Apply Fourier features + learned embeddings on GPU.

    Per-point features are concatenated along the last dimension in this
    order: raw xyz, positional encoding (empty when ``tok.pos_enc`` is None),
    label embedding, source embedding, optional "behind" embedding, and
    optional normalised vote features.

    Args:
        batch: dict produced by ``collate``.
        model: object exposing a ``tokenizer`` with ``pos_enc``, ``label_emb``,
            ``src_emb``, ``behind_emb_dim``, ``behind_emb`` and
            ``use_vote_features``.
        device: target device for the tensors.

    Returns:
        Tuple ``(tokens, masks, gt, scales, meta)``. ``tokens``, ``masks`` and
        the entries of ``gt`` are on ``device``; ``scales`` and ``meta`` are
        returned exactly as found in ``batch`` (not moved).

    Raises:
        KeyError: if the tokenizer expects vote features but the batch lacks
            ``n_views_voted`` or ``vote_frac``.
    """
    xyz = batch["xyz_norm"].to(device)
    cid = batch["class_id"].to(device)
    src = batch["source"].to(device)
    masks = batch["mask"].to(device)
    gt = [g.to(device) for g in batch["gt_segments"]]
    scales = batch["scales"]

    B, T, _ = xyz.shape
    tok = model.tokenizer
    fourier = tok.pos_enc(xyz.reshape(-1, 3)).reshape(B, T, -1) \
        if tok.pos_enc is not None else xyz.new_zeros(B, T, 0)
    # Source ids are clamped into {0, 1} before the embedding lookup.
    parts = [xyz, fourier, tok.label_emb(cid), tok.src_emb(src.clamp(0, 1))]

    if tok.behind_emb_dim > 0:
        if "behind" in batch:
            beh = batch["behind"].to(device)
        else:
            # Data doesn't have behind -- use zeros (embed index 0).
            # This is intentional for eval on old data.
            # NOTE: nothing in this module rejects samples without 'behind'
            # (_process_sample treats the field as optional), so training
            # code that needs to fail fast must check for it itself.
            beh = xyz.new_zeros(B, T, dtype=torch.long)
        parts.append(tok.behind_emb(beh))

    if tok.use_vote_features:
        if "n_views_voted" not in batch or "vote_frac" not in batch:
            raise KeyError(
                "Model expects vote features (--vote-features) but data is missing "
                "'n_views_voted'/'vote_frac'. Use v2 dataset or regenerate cache.")
        # Normalize to ~zero mean, unit variance (dataset stats: nv~2.7+/-1.0, vf~0.5+/-0.25)
        nv = ((batch["n_views_voted"].to(device).float() - 2.7) / 1.0).unsqueeze(-1)
        vf = ((batch["vote_frac"].to(device).float() - 0.5) / 0.25).unsqueeze(-1)
        parts.extend([nv, vf])

    tokens = torch.cat(parts, dim=-1)
    return tokens, masks, gt, scales, batch["meta"]