#!/usr/bin/env python3
"""Stage 2: priority-sample cached .pt scenes into fixed-size .npz files.

Reads the per-scene .pt files produced by cache_scenes.py, priority-samples
a fixed number of points (2048 or 4096), normalizes, and writes one .npz per
scene (~50KB at 2048, ~100KB at 4096).

A fixed seed is used so every scene gets one deterministic sample: there is
no per-epoch sampling augmentation, and every epoch sees the same points.

Usage:
  python -m s23dr_2026_example.make_sampled_cache \\
      --in-dir cache/full --out-dir cache/sampled_2048 --seq-len 2048
  python -m s23dr_2026_example.make_sampled_cache \\
      --in-dir cache/full --out-dir cache/sampled_4096 --seq-len 4096

The 3:1 colmap:depth quota ratio is fixed: at seq_len=2048 that's
colmap=1536/depth=512; at seq_len=4096 that's colmap=3072/depth=1024.
"""
from __future__ import annotations

import argparse
import time
from pathlib import Path

import numpy as np
import torch


# Priority sampling (same logic as train.py)
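# group_id >= 0 marks a valid point; tiers 0-4 appear to encode descending
# sampling priority (tier 0 drawn first). Each source (0 = colmap, 1 = depth)
# fills its quota tier by tier; once the tiers are exhausted, the in-source
# fallback re-draws from all valid points of that source, which can duplicate
# earlier picks. Only when a source has no valid points at all does its
# leftover quota spill over to the other source.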
def _priority_sample(source, group_id, seq_len, colmap_quota, depth_quota):
    def pick(src_id, quota):
        base = source == src_id
        picked, remaining = [], quota
        for tier in range(5):
            if remaining <= 0:
                break
            pool = np.where(base & (group_id == tier))[0]
            if len(pool) == 0:
                continue
            np.random.shuffle(pool)
            take = min(remaining, len(pool))
            picked.append(pool[:take])
            remaining -= take
        if remaining > 0:
            pool = np.where(base & (group_id >= 0))[0]
            if len(pool) > 0:
                np.random.shuffle(pool)
                picked.append(pool[:min(remaining, len(pool))])
                remaining -= min(remaining, len(pool))
        return np.concatenate(picked) if picked else np.array([], dtype=np.int64), remaining

    idx_c, rem_c = pick(0, colmap_quota)
    idx_d, rem_d = pick(1, depth_quota)

    if rem_c > 0:
        extra = np.setdiff1d(np.where((source == 1) & (group_id >= 0))[0], idx_d)
        np.random.shuffle(extra)
        idx_d = np.concatenate([idx_d, extra[:rem_c]])
    if rem_d > 0:
        extra = np.setdiff1d(np.where((source == 0) & (group_id >= 0))[0], idx_c)
        np.random.shuffle(extra)
        idx_c = np.concatenate([idx_c, extra[:rem_d]])

    indices = np.concatenate([idx_c, idx_d])
    num_valid = len(indices)
    if num_valid < seq_len:
        if num_valid == 0:
            return np.zeros(seq_len, dtype=np.int64), np.zeros(seq_len, dtype=bool)
        indices = np.concatenate([indices, np.full(seq_len - num_valid, indices[-1])])
    mask = np.zeros(seq_len, dtype=bool)
    mask[:num_valid] = True
    return indices[:seq_len], mask
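
# Toy illustration (hypothetical data, not from the dataset): 4 colmap and
# 4 depth points with seq_len=8 and a 6/2 quota split. The colmap side has
# only 4 unique points against a quota of 6, so the in-source fallback fills
# the quota with 2 duplicate picks and the mask marks all 8 slots valid:
#   src = np.array([0, 0, 0, 0, 1, 1, 1, 1], dtype=np.uint8)
#   gid = np.array([0, 1, 2, 0, 0, 1, 1, 2], dtype=np.int8)
#   idx, mask = _priority_sample(src, gid, 8, 6, 2)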


def _process_sample(d, seq_len, colmap_q, depth_q):
    """Sample and normalize one cached scene dict into a small npz-ready dict."""
    xyz = np.asarray(d["xyz"], np.float32)
    source = np.asarray(d["source"], np.uint8)
    group_id = np.asarray(d["group_id"], np.int8)
    class_id = np.asarray(d["class_id"], np.uint8)
    vis_src = np.asarray(d["visible_src"], np.uint8)
    vis_id = np.asarray(d["visible_id"], np.int16)
    center = np.asarray(d["center"], np.float32)
    scale = float(d["scale"])
    gt_v = np.asarray(d["gt_vertices"], np.float32)
    gt_e = np.asarray(d["gt_edges"], np.int32)

    indices, mask = _priority_sample(source, group_id, seq_len, colmap_q, depth_q)
    xyz_norm = ((xyz[indices] - center) / scale).astype(np.float32)
    gt_seg = np.stack([gt_v[gt_e[:, 0]], gt_v[gt_e[:, 1]]], axis=1)
    gt_seg_norm = ((gt_seg - center) / scale).astype(np.float32)

    result = {
        "xyz_norm": xyz_norm,
        "class_id": class_id[indices].astype(np.uint8),
        "source": source[indices].astype(np.uint8),
        "mask": mask,
        "gt_segments": gt_seg_norm,
        "scale": np.float32(scale),
        "center": center,
        "gt_vertices": gt_v,
        "gt_edges": gt_e,
        "visible_src": vis_src[indices].astype(np.uint8),
        "visible_id": vis_id[indices].astype(np.int16),
    }
    if "behind_gest_id" in d:
        result["behind"] = np.asarray(d["behind_gest_id"], np.int16)[indices]
    if "n_views_voted" in d:
        result["n_views_voted"] = np.asarray(d["n_views_voted"], np.uint8)[indices]
    if "vote_frac" in d:
        result["vote_frac"] = np.asarray(d["vote_frac"], np.float32)[indices]
    if "gt_edge_classes" in d:
        result["gt_edge_classes"] = np.asarray(d["gt_edge_classes"], np.int64)
    return result
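
# Loading a produced .npz downstream might look like this (illustrative):
#   d = np.load("cache/sampled_2048/<scene_id>.npz")
#   pts, cls, mask = d["xyz_norm"], d["class_id"], d["mask"]
# xyz_norm is (seq_len, 3) in the normalized (centered, scaled) frame and
# gt_segments is (num_edges, 2, 3) in that same frame; center and scale are
# stored alongside, presumably so predictions can be de-normalized back to
# world coordinates.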


def main():
    p = argparse.ArgumentParser(description="Stage 2: cached .pt -> sampled .npz")
    p.add_argument("--in-dir", required=True, help="Directory of .pt files from cache_scenes.py")
    p.add_argument("--out-dir", required=True, help="Output directory for .npz files")
    p.add_argument("--seq-len", type=int, default=2048, help="Points per sample (2048 or 4096)")
    p.add_argument("--seed", type=int, default=7)
    args = p.parse_args()

    colmap_q = args.seq_len * 3 // 4
    depth_q = args.seq_len - colmap_q
    print(f"seq_len={args.seq_len}  colmap={colmap_q}  depth={depth_q}")

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    np.random.seed(args.seed)
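    # Note: a single global seed makes a full fresh run deterministic, but
    # the skip-if-exists resume below changes how many random draws precede
    # each scene, so a resumed run can sample the remaining scenes
    # differently than an uninterrupted one would have.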

    files = sorted(Path(args.in_dir).glob("*.pt"))
    print(f"Found {len(files)} .pt files in {args.in_dir}")

    done = 0
    t0 = time.perf_counter()
    for f in files:
        out_f = out_dir / (f.stem + ".npz")
        if out_f.exists():
            done += 1
            continue
        d = torch.load(f, weights_only=False)
        result = _process_sample(d, args.seq_len, colmap_q, depth_q)
        np.savez(out_f, **result)
        done += 1
        if done % 2000 == 0:
            rate = done / (time.perf_counter() - t0)
            print(f"  {done}/{len(files)} [{rate:.0f}/s]")

    elapsed = time.perf_counter() - t0
    print(f"Done. {done} files in {elapsed:.0f}s -> {out_dir}")


if __name__ == "__main__":
    main()