| """Data loading for pre-sampled HF datasets. |
| |
| Expects pre-sampled npz blobs with xyz_norm (not full PCD). |
| Supports both 2048-point and 4096-point datasets. |
| Use make_sampled_cache.py to produce these from full point clouds. |
| """ |
| from __future__ import annotations |
|
|
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
|
|
| from .tokenizer import EdgeDepthSequenceConfig |
|
|
| |
# Per-scene point budget: COLMAP_POINTS + DEPTH_POINTS == SEQ_LEN (1536 + 512 = 2048).
# NOTE(review): these match the 2048-point dataset variant; the module docstring
# also mentions 4096-point datasets -- confirm where the larger budget is set.
SEQ_LEN = 2048        # total pre-sampled points per scene
COLMAP_POINTS = 1536  # presumably points drawn from the COLMAP cloud -- TODO confirm
DEPTH_POINTS = 512    # presumably points drawn from the depth cloud -- TODO confirm
|
|
|
|
| |
| |
| |
|
|
| def _load_bad_sample_ids(): |
| """Load the set of known-bad sample IDs (misaligned GT, extreme scale).""" |
| bad_file = Path(__file__).parent / "bad_samples.txt" |
| if not bad_file.exists(): |
| return set() |
| return set(line.strip() for line in bad_file.read_text().splitlines() if line.strip()) |
|
|
|
|
class HFCachedDataset(torch.utils.data.Dataset):
    """Hold a pre-sampled HuggingFace dataset fully decoded in RAM.

    Every record's npz blob is decompressed once at construction time, so
    __getitem__ only performs (cheap) per-sample processing/augmentation.
    """

    def __init__(self, hf_dataset, aug_rotate=False, aug_jitter=0.0,
                 aug_drop=0.0, aug_flip=False):
        import io as _io
        blacklist = _load_bad_sample_ids()
        total = len(hf_dataset)
        print(f"Pre-decoding {total} samples into memory...")
        self.samples = []
        self.order_ids = []
        n_skipped = 0
        for idx, record in enumerate(hf_dataset):
            order_id = record["order_id"]
            # Drop samples on the known-bad list instead of training on them.
            if order_id in blacklist:
                n_skipped += 1
                continue
            arrays = dict(np.load(_io.BytesIO(record["data"])))
            if "xyz_norm" not in arrays:
                raise ValueError(
                    f"Sample {record['order_id']} missing 'xyz_norm' -- this looks like "
                    f"a full PCD dataset, not pre-sampled. Use make_sampled_cache.py first.")
            self.samples.append(arrays)
            self.order_ids.append(order_id)
            # Periodic progress line for long decodes.
            if (idx + 1) % 2000 == 0:
                print(f" {idx + 1}/{total}...")
        print(f" Done. {len(self.samples)} samples in memory"
              f" ({n_skipped} bad samples filtered).")
        self.aug_rotate = aug_rotate
        self.aug_jitter = aug_jitter
        self.aug_drop = aug_drop
        self.aug_flip = aug_flip

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        item = _process_sample(self.samples[idx], self.aug_rotate,
                               self.aug_jitter, self.aug_drop, self.aug_flip)
        item["sample_id"] = self.order_ids[idx]
        return item
|
|
|
|
| def _process_sample(d, aug_rotate, aug_jitter=0.0, aug_drop=0.0, aug_flip=False): |
| """Process a pre-sampled npz dict into training tensors. |
| |
| Args: |
| aug_rotate: random yaw rotation |
| aug_jitter: std of Gaussian noise added to point positions (0=disabled) |
| aug_drop: fraction of points to randomly drop (0=disabled) |
| aug_flip: random mirror along X axis (50% chance) |
| """ |
| xyz_norm = d["xyz_norm"].copy() |
| gt_seg = d["gt_segments"].copy() |
| mask = d["mask"].copy() |
|
|
| if aug_rotate: |
| theta = np.random.rand() * 2 * np.pi |
| cos_t, sin_t = np.cos(theta), np.sin(theta) |
| x, z = xyz_norm[:, 0].copy(), xyz_norm[:, 2].copy() |
| xyz_norm[:, 0] = x * cos_t - z * sin_t |
| xyz_norm[:, 2] = x * sin_t + z * cos_t |
| for ep in range(2): |
| sx, sz = gt_seg[:, ep, 0].copy(), gt_seg[:, ep, 2].copy() |
| gt_seg[:, ep, 0] = sx * cos_t - sz * sin_t |
| gt_seg[:, ep, 2] = sx * sin_t + sz * cos_t |
|
|
| if aug_flip and np.random.rand() < 0.5: |
| xyz_norm[:, 0] = -xyz_norm[:, 0] |
| gt_seg[:, :, 0] = -gt_seg[:, :, 0] |
|
|
| if aug_jitter > 0: |
| valid = mask.astype(bool) |
| xyz_norm[valid] += np.random.randn(valid.sum(), 3).astype(np.float32) * aug_jitter |
|
|
| if aug_drop > 0: |
| valid_idx = np.where(mask)[0] |
| n_drop = int(len(valid_idx) * aug_drop) |
| if n_drop > 0: |
| drop_idx = np.random.choice(valid_idx, n_drop, replace=False) |
| mask[drop_idx] = False |
|
|
| result = { |
| "xyz_norm": torch.as_tensor(xyz_norm, dtype=torch.float32), |
| "class_id": torch.as_tensor(d["class_id"], dtype=torch.long), |
| "source": torch.as_tensor(d["source"], dtype=torch.long), |
| "mask": torch.as_tensor(mask), |
| "gt_segments": torch.as_tensor(gt_seg, dtype=torch.float32), |
| "scale": torch.tensor(float(d["scale"]), dtype=torch.float32), |
| "center": torch.as_tensor(d["center"], dtype=torch.float32), |
| "gt_vertices": d["gt_vertices"], |
| "gt_edges": d["gt_edges"], |
| "visible_src": torch.as_tensor(d["visible_src"], dtype=torch.long), |
| "visible_id": torch.as_tensor(d["visible_id"], dtype=torch.long), |
| } |
| if "behind" in d: |
| result["behind"] = torch.as_tensor( |
| np.clip(np.asarray(d["behind"], dtype=np.int16), 0, None), dtype=torch.long) |
| if "n_views_voted" in d: |
| result["n_views_voted"] = torch.as_tensor(d["n_views_voted"], dtype=torch.float32) |
| if "vote_frac" in d: |
| result["vote_frac"] = torch.as_tensor(d["vote_frac"], dtype=torch.float32) |
| return result |
|
|
|
|
| |
| |
| |
|
|
def collate(batch):
    """Stack samples into batched tensors."""
    def _stack(key):
        # Stack the per-sample tensors for one field along a new batch dim.
        return torch.stack([sample[key] for sample in batch])

    out = {
        "xyz_norm": _stack("xyz_norm"),
        "class_id": _stack("class_id"),
        "source": _stack("source"),
        "mask": _stack("mask"),
        # Ragged per-scene GT: kept as a plain list, not stacked.
        "gt_segments": [sample["gt_segments"] for sample in batch],
        "scales": _stack("scale"),
        "meta": batch,
    }

    # Optional fields must be all-present or all-absent within a batch.
    for field in ("behind", "n_views_voted", "vote_frac"):
        presence = [field in sample for sample in batch]
        if not any(presence):
            continue
        n_missing = presence.count(False)
        if n_missing:
            raise KeyError(
                f"Field '{field}' present in some batch samples but missing in "
                f"{n_missing}/{len(batch)}. Mixed data versions in cache?")
        out[field] = _stack(field)
    return out
|
|
|
|
def build_loader(cache_dir, batch_size, aug_rotate=False, aug_jitter=0.0,
                 aug_drop=0.0, aug_flip=False):
    """Create a DataLoader from HF dataset.

    cache_dir should be 'hf://repo/name:split' format.
    """
    if not cache_dir.startswith("hf://"):
        raise ValueError(
            f"cache_dir must be 'hf://repo:split' format, got: {cache_dir}. "
            f"Local .pt caches are no longer supported in the training path.")
    # Strip the scheme, then take repo before the first ':' and an optional split.
    repo, *rest = cache_dir[5:].split(":")
    split = rest[0] if rest else "train"
    from datasets import load_dataset
    dataset = HFCachedDataset(load_dataset(repo, split=split),
                              aug_rotate=aug_rotate, aug_jitter=aug_jitter,
                              aug_drop=aug_drop, aug_flip=aug_flip)
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True,
        num_workers=0, collate_fn=collate,
    )
    print(f"Dataset: {len(dataset)} scenes, batch_size={batch_size}")
    return loader
|
|
|
|
| |
| |
| |
|
|
def build_tokens(batch, model, device):
    """Apply Fourier features + learned embeddings on GPU."""
    pts = batch["xyz_norm"].to(device)
    class_ids = batch["class_id"].to(device)
    sources = batch["source"].to(device)
    pad_masks = batch["mask"].to(device)
    gt_segs = [seg.to(device) for seg in batch["gt_segments"]]
    scales = batch["scales"]

    B, T, _ = pts.shape
    tok = model.tokenizer
    if tok.pos_enc is not None:
        fourier = tok.pos_enc(pts.reshape(-1, 3)).reshape(B, T, -1)
    else:
        # No positional encoder configured: contribute zero feature channels.
        fourier = pts.new_zeros(B, T, 0)
    feats = [pts, fourier, tok.label_emb(class_ids), tok.src_emb(sources.clamp(0, 1))]
    if tok.behind_emb_dim > 0:
        if "behind" in batch:
            behind = batch["behind"].to(device)
        else:
            # Field absent from this batch: embed all-zero occlusion codes.
            behind = pts.new_zeros(B, T, dtype=torch.long)
        feats.append(tok.behind_emb(behind))
    if tok.use_vote_features:
        if "n_views_voted" not in batch or "vote_frac" not in batch:
            raise KeyError(
                "Model expects vote features (--vote-features) but data is missing "
                "'n_views_voted'/'vote_frac'. Use v2 dataset or regenerate cache.")
        # Center/scale with fixed constants before concatenating as channels.
        views = ((batch["n_views_voted"].to(device).float() - 2.7) / 1.0).unsqueeze(-1)
        frac = ((batch["vote_frac"].to(device).float() - 0.5) / 0.25).unsqueeze(-1)
        feats.extend([views, frac])
    return torch.cat(feats, dim=-1), pad_masks, gt_segs, scales, batch["meta"]
|