| """S23DR 2026 submission: learned wireframe prediction from fused point clouds. |
| |
| Pipeline: raw sample -> point fusion -> priority sample 2048 -> model -> post-process -> wireframe |
| """ |
| from pathlib import Path |
| from tqdm import tqdm |
| import json |
| import os |
| import sys |
| import time |
|
|
| import numpy as np |
| import torch |
|
|
|
|
def empty_solution():
    """Fallback wireframe: two vertices at the origin joined by one edge.

    Used whenever fusion or prediction fails so every sample still has a
    syntactically valid (vertices, edges) entry in the submission.
    """
    vertices = np.zeros((2, 3))
    edges = [(0, 1)]
    return vertices, edges
|
|
|
|
| |
| |
| |
|
|
| |
# Make the bundled s23dr_2026_example package importable regardless of the
# working directory the competition harness launches us from.
SCRIPT_DIR = Path(__file__).resolve().parent
sys.path.insert(0, str(SCRIPT_DIR))
|
|
| from s23dr_2026_example.point_fusion import build_compact_scene, FuserConfig |
| from s23dr_2026_example.cache_scenes import ( |
| _compute_group_and_class, _compute_smart_center_scale, |
| ) |
| from s23dr_2026_example.make_sampled_cache import _priority_sample |
|
|
| |
| from s23dr_2026_example.tokenizer import EdgeDepthSequenceConfig |
| from s23dr_2026_example.model import EdgeDepthSegmentsModel |
| from s23dr_2026_example.segment_postprocess import merge_vertices_iterative |
| from s23dr_2026_example.varifold import segments_to_vertices_edges |
| from s23dr_2026_example.postprocess_v2 import snap_to_point_cloud, snap_horizontal |
|
|
# Inference hyperparameters.
SEQ_LEN = 4096        # points fed to the model per sample (priority-sampled)
COLMAP_QUOTA = 3072   # per-sample quota for COLMAP-sourced points
DEPTH_QUOTA = 1024    # per-sample quota for depth-derived points
CONF_THRESH = 0.5     # keep segments whose sigmoid confidence exceeds this
MERGE_THRESH = 0.4    # NOTE(review): unused in this file — presumably the vertex
                      # merge radius for merge_vertices_iterative; confirm/remove
SNAP_RADIUS = 0.5     # max distance for snap_to_point_cloud
|
|
|
|
def fuse_and_sample(sample, cfg, rng):
    """Run point fusion + priority sampling on a raw dataset sample.

    Args:
        sample: raw dataset record (schema defined by the HF dataset loader).
        cfg: FuserConfig forwarded to build_compact_scene.
        rng: numpy RandomState used by the fuser for any stochastic steps.

    Returns a dict with xyz_norm, class_id, source, mask, center, scale, etc.
    ready for model inference. Returns None if fusion fails.
    """
    try:
        scene = build_compact_scene(sample, cfg, rng)
    except Exception as e:
        # Best effort: a failed fusion downgrades this sample to the empty
        # solution at the call site instead of aborting the whole run.
        print(f" Fusion failed: {e}")
        return None

    xyz = scene["xyz"]
    source = scene["source"]

    # Too few fused points to be worth running the model on.
    if len(xyz) < 10:
        return None

    # NOTE(review): behind_gest_id presumably tags points lying behind some
    # predicted surface (-1 = none) — confirm semantics in point_fusion.
    behind_id = scene.get("behind_gest_id", np.full(len(xyz), -1, dtype=np.int16))
    group_id, class_id = _compute_group_and_class(
        scene["visible_src"], scene["visible_id"], behind_id, source)

    # Normalization frame shared by inputs and (via predict_sample) outputs.
    center, scale = _compute_smart_center_scale(xyz, source)

    # Downsample to SEQ_LEN slots with per-source quotas; mask flags valid slots.
    indices, mask = _priority_sample(source, group_id, SEQ_LEN, COLMAP_QUOTA, DEPTH_QUOTA)

    xyz_norm = (xyz[indices] - center) / scale

    result = {
        "xyz_norm": xyz_norm.astype(np.float32),
        "class_id": class_id[indices].astype(np.int64),
        "source": source[indices].astype(np.int64),
        "mask": mask,
        "center": center.astype(np.float32),
        "scale": np.float32(scale),
    }

    # Optional per-point features — only present when the fuser produced them.
    if "behind_gest_id" in scene:
        # Clip -1 ("not behind anything") up to 0 so the value is a valid
        # embedding index in the model's behind_emb table.
        behind = np.clip(scene["behind_gest_id"][indices].astype(np.int16), 0, None)
        result["behind"] = behind.astype(np.int64)
    if "n_views_voted" in scene:
        result["n_views_voted"] = scene["n_views_voted"][indices].astype(np.float32)
    if "vote_frac" in scene:
        result["vote_frac"] = scene["vote_frac"][indices].astype(np.float32)

    # Carried along for completeness; not consumed by build_tokens_single.
    result["visible_src"] = scene["visible_src"][indices].astype(np.int64)
    result["visible_id"] = scene["visible_id"][indices].astype(np.int64)

    return result
|
|
|
|
def load_model(checkpoint_path, device):
    """Rebuild EdgeDepthSegmentsModel from a training checkpoint.

    Args:
        checkpoint_path: torch checkpoint containing "model" (state dict) and
            optionally "args" (dict of training hyperparameters).
        device: torch.device to place the model on.

    Returns the model on `device`, weights loaded strictly, in eval mode.
    """
    # weights_only=False because the checkpoint also stores the args dict;
    # this is full pickle deserialization — only load trusted checkpoints.
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
    args = ckpt.get("args", {})

    norm_class = torch.nn.RMSNorm if args.get("rms_norm") else None
    seq_cfg = EdgeDepthSequenceConfig(
        seq_len=SEQ_LEN, colmap_points=COLMAP_QUOTA, depth_points=DEPTH_QUOTA)

    # The fallback values below must mirror the training script's defaults so
    # an older checkpoint missing some "args" keys is rebuilt identically —
    # TODO confirm against the training argparse definitions.
    model = EdgeDepthSegmentsModel(
        seq_cfg=seq_cfg,
        segments=args.get("segments", 64),
        hidden=args.get("hidden", 256),
        num_heads=args.get("num_heads", 4),
        kv_heads_cross=args.get("kv_heads_cross", 2),
        kv_heads_self=args.get("kv_heads_self", 2),
        dim_feedforward=args.get("ff", 1024),
        dropout=args.get("dropout", 0.1),
        latent_tokens=args.get("latent_tokens", 256),
        latent_layers=args.get("latent_layers", 7),
        decoder_layers=args.get("decoder_layers", 3),
        cross_attn_interval=args.get("cross_attn_interval", 4),
        norm_class=norm_class,
        activation=args.get("activation", "gelu"),
        segment_conf=args.get("segment_conf", True),
        behind_emb_dim=args.get("behind_emb_dim", 8),
        use_vote_features=args.get("vote_features", True),
        arch=args.get("arch", "perceiver"),
        encoder_layers=args.get("encoder_layers", 4),
        pre_encoder_layers=args.get("pre_encoder_layers", 0),
        segment_param=args.get("segment_param", "midpoint_dir_len"),
        qk_norm=args.get("qk_norm", True),
    ).to(device)

    # Strip torch.compile's "_orig_mod." prefix so keys saved from a compiled
    # model match this un-compiled instance.
    state = ckpt["model"]
    fixed = {k.replace("segmenter._orig_mod.", "segmenter."): v
             for k, v in state.items()}
    model.load_state_dict(fixed, strict=True)
    model.eval()
    return model
|
|
|
|
def build_tokens_single(sample_dict, model, device):
    """Build token tensor for a single sample (no DataLoader).

    Returns (tokens, masks): a (1, T, D) float token tensor assembled from
    coordinates, Fourier features, and embeddings, plus the (1, T) bool mask.
    """
    def batched(key, dtype):
        # (T, ...) numpy -> (1, T, ...) tensor on the target device.
        return torch.as_tensor(sample_dict[key], dtype=dtype).unsqueeze(0).to(device)

    xyz = batched("xyz_norm", torch.float32)
    cid = batched("class_id", torch.long)
    src = batched("source", torch.long)
    masks = batched("mask", torch.bool)

    B, T, _ = xyz.shape
    tok = model.tokenizer
    if tok.pos_enc is not None:
        fourier = tok.pos_enc(xyz.reshape(-1, 3)).reshape(B, T, -1)
    else:
        fourier = xyz.new_zeros(B, T, 0)

    parts = [xyz, fourier, tok.label_emb(cid), tok.src_emb(src.clamp(0, 1))]

    if tok.behind_emb_dim > 0:
        if "behind" in sample_dict:
            beh = batched("behind", torch.long)
        else:
            beh = xyz.new_zeros(B, T, dtype=torch.long)
        parts.append(tok.behind_emb(beh))

    if tok.use_vote_features:
        if "n_views_voted" in sample_dict and "vote_frac" in sample_dict:
            nv = ((batched("n_views_voted", torch.float32) - 2.7) / 1.0).unsqueeze(-1)
            vf = ((batched("vote_frac", torch.float32) - 0.5) / 0.25).unsqueeze(-1)
            parts.extend([nv, vf])
        else:
            parts.extend([xyz.new_zeros(B, T, 1), xyz.new_zeros(B, T, 1)])

    return torch.cat(parts, dim=-1), masks
|
|
|
|
def predict_sample(sample_dict, model, device):
    """Run model inference + post-processing on a fused sample.

    Args:
        sample_dict: output of fuse_and_sample (normalized points + metadata).
        model: loaded EdgeDepthSegmentsModel.
        device: torch.device the model lives on.

    Returns (vertices, edges) in world space.
    """
    tokens, masks = build_tokens_single(sample_dict, model, device)
    scale = float(sample_dict["scale"])
    center = sample_dict["center"]

    # Inference only; fp16 autocast is enabled solely on CUDA devices.
    with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16,
                                         enabled=(device.type == 'cuda')):
        out = model.forward_tokens(tokens, masks)

    segs = out["segments"][0].float().cpu()
    conf = torch.sigmoid(out["conf"][0].float()).cpu().numpy() if "conf" in out else None

    # Confidence gating (only when the model was trained with a conf head).
    if conf is not None:
        keep = conf > CONF_THRESH
        segs = segs[keep]
    if len(segs) < 1:
        return empty_solution()

    # De-normalize back to world coordinates.
    # NOTE(review): "+ center" relies on broadcasting — segs is presumably
    # (N, 2, 3) endpoint pairs so the last axis aligns with center; confirm.
    segs_world = segs.numpy() * scale + center

    # Segments -> shared vertex list + edge index pairs.
    pv, pe = segments_to_vertices_edges(torch.tensor(segs_world))
    pv, pe = pv.numpy(), np.array(pe, dtype=np.int32)

    # Merge near-duplicate vertices (and remap edges accordingly).
    pv, pe = merge_vertices_iterative(pv, pe)

    # Snap predicted vertices onto the fused point cloud (valid points only),
    # using the per-point class ids as guidance.
    xyz_norm = sample_dict["xyz_norm"]
    mask = sample_dict["mask"]
    cid = sample_dict["class_id"]
    xyz_world = xyz_norm[mask] * scale + center
    cid_valid = cid[mask]
    pv = snap_to_point_cloud(pv, xyz_world, cid_valid, snap_radius=SNAP_RADIUS)

    # NOTE(review): presumably aligns near-horizontal structure (e.g. shared
    # heights) — see postprocess_v2.snap_horizontal for exact semantics.
    pv = snap_horizontal(pv, pe)

    # Degenerate wireframe after post-processing -> fall back to empty solution.
    if len(pv) < 2 or len(pe) < 1:
        return empty_solution()

    edges = [(int(a), int(b)) for a, b in pe]
    return pv, edges
|
|
|
|
| |
| |
| |
|
|
| if __name__ == "__main__": |
| t_start = time.time() |
|
|
| |
| param_path = Path("params.json") |
| with param_path.open() as f: |
| params = json.load(f) |
| print(f"Competition: {params.get('competition_id', '?')}") |
| print(f"Dataset: {params.get('dataset', '?')}") |
|
|
| |
| data_path = Path("/tmp/data") |
| if not data_path.exists(): |
| from huggingface_hub import snapshot_download |
| snapshot_download( |
| repo_id=params["dataset"], |
| local_dir="/tmp/data", |
| repo_type="dataset", |
| ) |
|
|
| from datasets import load_dataset |
| data_files = { |
| "validation": [str(p) for p in data_path.rglob("*public*/**/*.tar")], |
| "test": [str(p) for p in data_path.rglob("*private*/**/*.tar")], |
| } |
| print(f"Data files: {data_files}") |
| dataset = load_dataset( |
| str(data_path / "hoho22k_2026_test_x_anon.py"), |
| data_files=data_files, |
| trust_remote_code=True, |
| writer_batch_size=100, |
| ) |
| print(f"Loaded: {dataset}") |
|
|
| |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| print(f"Device: {device}") |
| checkpoint_path = SCRIPT_DIR / "checkpoint.pt" |
| model = load_model(checkpoint_path, device) |
| print(f"Model loaded: {sum(p.numel() for p in model.parameters()):,} params") |
|
|
| |
| cfg = FuserConfig() |
| rng = np.random.RandomState(2718) |
|
|
| |
| solution = [] |
| total_samples = sum(len(dataset[s]) for s in dataset) |
| processed = 0 |
|
|
| for subset_name in dataset: |
| print(f"\nProcessing {subset_name} ({len(dataset[subset_name])} samples)...") |
|
|
| for sample in tqdm(dataset[subset_name], desc=subset_name): |
| order_id = sample["order_id"] |
|
|
| |
| fused = fuse_and_sample(sample, cfg, rng) |
| if fused is None: |
| pred_v, pred_e = empty_solution() |
| else: |
| try: |
| pred_v, pred_e = predict_sample(fused, model, device) |
| except Exception as e: |
| print(f" Predict failed for {order_id}: {e}") |
| pred_v, pred_e = empty_solution() |
|
|
| solution.append({ |
| "order_id": order_id, |
| "wf_vertices": pred_v.tolist() if isinstance(pred_v, np.ndarray) else pred_v, |
| "wf_edges": [(int(a), int(b)) for a, b in pred_e], |
| }) |
| processed += 1 |
|
|
| if processed % 50 == 0: |
| elapsed = time.time() - t_start |
| rate = elapsed / processed |
| remaining = (total_samples - processed) * rate |
| print(f" [{processed}/{total_samples}] " |
| f"{elapsed:.0f}s elapsed, ~{remaining:.0f}s remaining") |
|
|
| |
| with open("submission.json", "w") as f: |
| json.dump(solution, f) |
|
|
| elapsed = time.time() - t_start |
| print(f"\nDone. {processed} samples in {elapsed:.0f}s ({elapsed/max(processed,1):.1f}s/sample)") |
| print(f"Saved submission.json ({len(solution)} entries)") |
|
|