# --- Hugging Face model-page residue (kept as comments so the file parses) ---
# Tags: Other, PyTorch, 3d-reconstruction, wireframe, building, point-cloud,
#       s23dr, cvpr-2026
# Author: jacklangerman — commit "4096-release (#1)", revision 0f31e57
"""S23DR 2026 submission: learned wireframe prediction from fused point clouds.
Pipeline: raw sample -> point fusion -> priority sample 2048 -> model -> post-process -> wireframe
"""
from pathlib import Path
from tqdm import tqdm
import json
import os
import sys
import time
import numpy as np
import torch
def empty_solution():
    """Fallback wireframe: two vertices at the origin joined by one edge."""
    vertices = np.zeros((2, 3))
    edges = [(0, 1)]
    return vertices, edges
# ---------------------------------------------------------------------------
# Point fusion + sampling (from cache_scenes.py / make_sampled_cache.py)
# ---------------------------------------------------------------------------
# Add our package to path
SCRIPT_DIR = Path(__file__).resolve().parent
sys.path.insert(0, str(SCRIPT_DIR))
from s23dr_2026_example.point_fusion import build_compact_scene, FuserConfig
from s23dr_2026_example.cache_scenes import (
_compute_group_and_class, _compute_smart_center_scale,
)
from s23dr_2026_example.make_sampled_cache import _priority_sample
# Tokenizer / model imports
from s23dr_2026_example.tokenizer import EdgeDepthSequenceConfig
from s23dr_2026_example.model import EdgeDepthSegmentsModel
from s23dr_2026_example.segment_postprocess import merge_vertices_iterative
from s23dr_2026_example.varifold import segments_to_vertices_edges
from s23dr_2026_example.postprocess_v2 import snap_to_point_cloud, snap_horizontal
# Inference hyper-parameters — must match the values the checkpoint was trained with.
SEQ_LEN = 4096        # total points fed to the model per scene
COLMAP_QUOTA = 3072   # at most this many points sampled from the COLMAP source
DEPTH_QUOTA = 1024    # at most this many points sampled from fused depth
CONF_THRESH = 0.5     # drop predicted segments with sigmoid confidence below this
MERGE_THRESH = 0.4    # NOTE(review): defined but never referenced in this file — confirm whether merge_vertices_iterative should receive it
SNAP_RADIUS = 0.5     # search radius when snapping predicted vertices to the point cloud
def fuse_and_sample(sample, cfg, rng):
    """Fuse a raw dataset sample into a point cloud and priority-sample it.

    Returns a dict (xyz_norm, class_id, source, mask, center, scale, plus
    optional per-point features) ready for model inference, or None when
    fusion fails or produces fewer than 10 points.
    """
    try:
        scene = build_compact_scene(sample, cfg, rng)
    except Exception as e:  # best-effort: caller substitutes an empty wireframe
        print(f" Fusion failed: {e}")
        return None

    points = scene["xyz"]
    src = scene["source"]
    if len(points) < 10:
        return None

    # Per-point group and semantic class ids (same logic as cache_scenes.py).
    behind_id = scene.get("behind_gest_id", np.full(len(points), -1, dtype=np.int16))
    group_id, class_id = _compute_group_and_class(
        scene["visible_src"], scene["visible_id"], behind_id, src)

    # Normalization parameters for this scene.
    center, scale = _compute_smart_center_scale(points, src)

    # Priority-sample down to the model's fixed sequence length.
    indices, mask = _priority_sample(src, group_id, SEQ_LEN, COLMAP_QUOTA, DEPTH_QUOTA)

    out = {
        "xyz_norm": ((points[indices] - center) / scale).astype(np.float32),
        "class_id": class_id[indices].astype(np.int64),
        "source": src[indices].astype(np.int64),
        "mask": mask,
        "center": center.astype(np.float32),
        "scale": np.float32(scale),
    }

    # Optional per-point features, only when fusion produced them.
    if "behind_gest_id" in scene:
        behind = np.clip(scene["behind_gest_id"][indices].astype(np.int16), 0, None)
        out["behind"] = behind.astype(np.int64)
    for key in ("n_views_voted", "vote_frac"):
        if key in scene:
            out[key] = scene[key][indices].astype(np.float32)

    # Visibility metadata consumed by the snap post-processing step.
    out["visible_src"] = scene["visible_src"][indices].astype(np.int64)
    out["visible_id"] = scene["visible_id"][indices].astype(np.int64)
    return out
def load_model(checkpoint_path, device):
    """Rebuild EdgeDepthSegmentsModel from a training checkpoint and load weights.

    Hyper-parameters are read from the checkpoint's "args" dict, falling back
    to the training defaults. Returns the model in eval mode on `device`.
    """
    # NOTE: weights_only=False unpickles arbitrary objects — only load trusted checkpoints.
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
    args = ckpt.get("args", {})

    # (constructor kwarg, checkpoint args key, training default)
    hyperparams = [
        ("segments", "segments", 64),
        ("hidden", "hidden", 256),
        ("num_heads", "num_heads", 4),
        ("kv_heads_cross", "kv_heads_cross", 2),
        ("kv_heads_self", "kv_heads_self", 2),
        ("dim_feedforward", "ff", 1024),
        ("dropout", "dropout", 0.1),
        ("latent_tokens", "latent_tokens", 256),
        ("latent_layers", "latent_layers", 7),
        ("decoder_layers", "decoder_layers", 3),
        ("cross_attn_interval", "cross_attn_interval", 4),
        ("activation", "activation", "gelu"),
        ("segment_conf", "segment_conf", True),
        ("behind_emb_dim", "behind_emb_dim", 8),
        ("use_vote_features", "vote_features", True),
        ("arch", "arch", "perceiver"),
        ("encoder_layers", "encoder_layers", 4),
        ("pre_encoder_layers", "pre_encoder_layers", 0),
        ("segment_param", "segment_param", "midpoint_dir_len"),
        ("qk_norm", "qk_norm", True),
    ]
    kwargs = {name: args.get(key, default) for name, key, default in hyperparams}
    kwargs["norm_class"] = torch.nn.RMSNorm if args.get("rms_norm") else None
    kwargs["seq_cfg"] = EdgeDepthSequenceConfig(
        seq_len=SEQ_LEN, colmap_points=COLMAP_QUOTA, depth_points=DEPTH_QUOTA)

    model = EdgeDepthSegmentsModel(**kwargs).to(device)

    # Strip the "_orig_mod" prefix that torch.compile inserts into state-dict keys.
    state = {k.replace("segmenter._orig_mod.", "segmenter."): v
             for k, v in ckpt["model"].items()}
    model.load_state_dict(state, strict=True)
    model.eval()
    return model
def build_tokens_single(sample_dict, model, device):
    """Assemble the (1, T, F) token tensor and boolean mask for one fused sample.

    Mirrors the tokenizer's training-time feature layout without going through
    a DataLoader. The normalization constants below must match training.
    """
    def to_dev(key, dtype):
        # Lift a numpy field to a batched tensor on the target device.
        return torch.as_tensor(sample_dict[key], dtype=dtype).unsqueeze(0).to(device)

    xyz = to_dev("xyz_norm", torch.float32)
    masks = to_dev("mask", torch.bool)
    B, T, _ = xyz.shape
    tok = model.tokenizer

    # Fourier positional features; zero-width when the tokenizer has none.
    if tok.pos_enc is not None:
        fourier = tok.pos_enc(xyz.reshape(-1, 3)).reshape(B, T, -1)
    else:
        fourier = xyz.new_zeros(B, T, 0)

    parts = [
        xyz,
        fourier,
        tok.label_emb(to_dev("class_id", torch.long)),
        tok.src_emb(to_dev("source", torch.long).clamp(0, 1)),
    ]

    if tok.behind_emb_dim > 0:
        if "behind" in sample_dict:
            beh = to_dev("behind", torch.long)
        else:
            beh = xyz.new_zeros(B, T, dtype=torch.long)
        parts.append(tok.behind_emb(beh))

    if tok.use_vote_features:
        if "n_views_voted" in sample_dict and "vote_frac" in sample_dict:
            # Center/scale constants match the training pipeline — TODO confirm against trainer.
            nv = ((to_dev("n_views_voted", torch.float32) - 2.7) / 1.0).unsqueeze(-1)
            vf = ((to_dev("vote_frac", torch.float32) - 0.5) / 0.25).unsqueeze(-1)
            parts.extend([nv, vf])
        else:
            parts.extend([xyz.new_zeros(B, T, 1), xyz.new_zeros(B, T, 1)])

    return torch.cat(parts, dim=-1), masks
def predict_sample(sample_dict, model, device):
    """Run model inference on one fused sample and post-process into a wireframe.

    Returns (vertices, edges) in world coordinates; falls back to
    empty_solution() when no confident segments or no valid graph survives.
    """
    tokens, masks = build_tokens_single(sample_dict, model, device)
    scale = float(sample_dict["scale"])
    center = sample_dict["center"]

    use_amp = device.type == 'cuda'
    with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16,
                                         enabled=use_amp):
        out = model.forward_tokens(tokens, masks)
    segs = out["segments"][0].float().cpu()
    conf = torch.sigmoid(out["conf"][0].float()).cpu().numpy() if "conf" in out else None

    # Keep only confident segments (when the model emits confidences).
    if conf is not None:
        segs = segs[conf > CONF_THRESH]
    if len(segs) < 1:
        return empty_solution()

    # Denormalize segment endpoints back to world coordinates.
    segs_world = segs.numpy() * scale + center

    # Segments -> vertex/edge graph, then merge nearby vertices.
    pv, pe = segments_to_vertices_edges(torch.tensor(segs_world))
    pv, pe = pv.numpy(), np.array(pe, dtype=np.int32)
    pv, pe = merge_vertices_iterative(pv, pe)

    # Snap vertices to the (denormalized) fused point cloud, then level edges.
    valid = sample_dict["mask"]
    cloud_world = sample_dict["xyz_norm"][valid] * scale + center
    pv = snap_to_point_cloud(pv, cloud_world, sample_dict["class_id"][valid],
                             snap_radius=SNAP_RADIUS)
    pv = snap_horizontal(pv, pe)

    if len(pv) < 2 or len(pe) < 1:
        return empty_solution()
    return pv, [(int(a), int(b)) for a, b in pe]
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    t_start = time.time()

    # Competition parameters (dataset repo id, etc.) supplied by the harness.
    param_path = Path("params.json")
    with param_path.open() as f:
        params = json.load(f)
    print(f"Competition: {params.get('competition_id', '?')}")
    print(f"Dataset: {params.get('dataset', '?')}")

    # Download test data once; /tmp/data is reused if it already exists.
    data_path = Path("/tmp/data")
    if not data_path.exists():
        from huggingface_hub import snapshot_download
        snapshot_download(
            repo_id=params["dataset"],
            local_dir="/tmp/data",
            repo_type="dataset",
        )
    from datasets import load_dataset
    # Shards named "*public*" become the validation split, "*private*" the test split.
    data_files = {
        "validation": [str(p) for p in data_path.rglob("*public*/**/*.tar")],
        "test": [str(p) for p in data_path.rglob("*private*/**/*.tar")],
    }
    print(f"Data files: {data_files}")
    dataset = load_dataset(
        str(data_path / "hoho22k_2026_test_x_anon.py"),
        data_files=data_files,
        trust_remote_code=True,
        writer_batch_size=100,
    )
    print(f"Loaded: {dataset}")

    # Model + checkpoint (checkpoint.pt must sit next to this script).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")
    checkpoint_path = SCRIPT_DIR / "checkpoint.pt"
    model = load_model(checkpoint_path, device)
    print(f"Model loaded: {sum(p.numel() for p in model.parameters()):,} params")

    # Point fusion config; fixed seed keeps the sampling reproducible.
    cfg = FuserConfig()
    rng = np.random.RandomState(2718)

    # Process all samples; any failure falls back to an empty wireframe so the
    # submission always contains one entry per order_id.
    solution = []
    total_samples = sum(len(dataset[s]) for s in dataset)
    processed = 0
    for subset_name in dataset:
        print(f"\nProcessing {subset_name} ({len(dataset[subset_name])} samples)...")
        for sample in tqdm(dataset[subset_name], desc=subset_name):
            order_id = sample["order_id"]
            # Fuse + sample
            fused = fuse_and_sample(sample, cfg, rng)
            if fused is None:
                pred_v, pred_e = empty_solution()
            else:
                try:
                    pred_v, pred_e = predict_sample(fused, model, device)
                except Exception as e:
                    print(f" Predict failed for {order_id}: {e}")
                    pred_v, pred_e = empty_solution()
            solution.append({
                "order_id": order_id,
                "wf_vertices": pred_v.tolist() if isinstance(pred_v, np.ndarray) else pred_v,
                "wf_edges": [(int(a), int(b)) for a, b in pred_e],
            })
            processed += 1
            # Coarse ETA every 50 samples (tqdm already shows per-split progress).
            if processed % 50 == 0:
                elapsed = time.time() - t_start
                rate = elapsed / processed
                remaining = (total_samples - processed) * rate
                print(f" [{processed}/{total_samples}] "
                      f"{elapsed:.0f}s elapsed, ~{remaining:.0f}s remaining")

    # Write the submission file expected by the scorer.
    with open("submission.json", "w") as f:
        json.dump(solution, f)
    elapsed = time.time() - t_start
    print(f"\nDone. {processed} samples in {elapsed:.0f}s ({elapsed/max(processed,1):.1f}s/sample)")
    print(f"Saved submission.json ({len(solution)} entries)")