"""S23DR 2026 submission: learned wireframe prediction from fused point clouds.

Pipeline:
    raw sample -> point fusion -> priority sample (SEQ_LEN = 4096 points)
    -> model -> post-process -> wireframe
"""
from pathlib import Path
from tqdm import tqdm
import json
import os
import sys
import time

import numpy as np
import torch


def empty_solution():
    """Return the fallback wireframe: two vertices at the origin and one edge.

    Used whenever fusion or prediction cannot produce a usable result, so every
    sample still gets a (vertices, edges) entry in the submission.
    """
    return np.zeros((2, 3)), [(0, 1)]


# ---------------------------------------------------------------------------
# Point fusion + sampling (from cache_scenes.py / make_sampled_cache.py)
# ---------------------------------------------------------------------------

# Add our package to path so the bundled s23dr_2026_example package is
# importable regardless of the current working directory.
SCRIPT_DIR = Path(__file__).resolve().parent
sys.path.insert(0, str(SCRIPT_DIR))

from s23dr_2026_example.point_fusion import build_compact_scene, FuserConfig  # noqa: E402
from s23dr_2026_example.cache_scenes import (  # noqa: E402
    _compute_group_and_class,
    _compute_smart_center_scale,
)
from s23dr_2026_example.make_sampled_cache import _priority_sample  # noqa: E402

# Tokenizer / model imports
from s23dr_2026_example.tokenizer import EdgeDepthSequenceConfig  # noqa: E402
from s23dr_2026_example.model import EdgeDepthSegmentsModel  # noqa: E402
from s23dr_2026_example.segment_postprocess import merge_vertices_iterative  # noqa: E402
from s23dr_2026_example.varifold import segments_to_vertices_edges  # noqa: E402
from s23dr_2026_example.postprocess_v2 import snap_to_point_cloud, snap_horizontal  # noqa: E402

SEQ_LEN = 4096       # number of sampled points / tokens per scene
COLMAP_QUOTA = 3072  # point budget passed as colmap_points / COLMAP quota
DEPTH_QUOTA = 1024   # point budget passed as depth_points / depth quota
CONF_THRESH = 0.5    # keep segments whose sigmoid(conf) exceeds this
# NOTE(review): MERGE_THRESH is not used anywhere in this file;
# merge_vertices_iterative is called with its own defaults. Confirm whether it
# should be passed through.
MERGE_THRESH = 0.4
SNAP_RADIUS = 0.5    # snap_radius for snap_to_point_cloud (world units)


def fuse_and_sample(sample, cfg, rng):
    """Run point fusion + priority sampling on a raw dataset sample.

    Args:
        sample: one raw dataset record, passed to build_compact_scene.
        cfg: FuserConfig for point fusion.
        rng: numpy RandomState forwarded to build_compact_scene.

    Returns:
        A dict ready for model inference with ``xyz_norm`` (sampled points,
        shifted by ``center`` and divided by ``scale``), ``class_id``,
        ``source``, ``mask``, ``center``, ``scale``, ``visible_src`` and
        ``visible_id``, plus ``behind``, ``n_views_voted`` and ``vote_frac``
        when the fused scene provides them.

        Returns None if build_compact_scene raises or the fused scene has
        fewer than 10 points. Exceptions from the grouping / normalisation /
        sampling helpers are NOT caught here.
    """
    try:
        scene = build_compact_scene(sample, cfg, rng)
    except Exception as e:  # best effort: caller substitutes empty_solution()
        print(f" Fusion failed: {e}")
        return None

    xyz = scene["xyz"]
    source = scene["source"]
    if len(xyz) < 10:
        return None

    # Compute group_id and class_id (same as cache_scenes.py). Scenes without
    # "behind_gest_id" use -1 for every point.
    behind_id = scene.get("behind_gest_id", np.full(len(xyz), -1, dtype=np.int16))
    group_id, class_id = _compute_group_and_class(
        scene["visible_src"], scene["visible_id"], behind_id, source)

    # Normalize: the model sees (xyz - center) / scale; predict_sample inverts it.
    center, scale = _compute_smart_center_scale(xyz, source)

    # Priority sample SEQ_LEN indices under the COLMAP / depth quotas.
    indices, mask = _priority_sample(source, group_id, SEQ_LEN, COLMAP_QUOTA, DEPTH_QUOTA)
    xyz_norm = (xyz[indices] - center) / scale

    result = {
        "xyz_norm": xyz_norm.astype(np.float32),
        "class_id": class_id[indices].astype(np.int64),
        "source": source[indices].astype(np.int64),
        "mask": mask,
        "center": center.astype(np.float32),
        "scale": np.float32(scale),
    }

    # Optional fields. Negative behind ids are clamped to 0 before embedding.
    if "behind_gest_id" in scene:
        behind = np.clip(scene["behind_gest_id"][indices].astype(np.int16), 0, None)
        result["behind"] = behind.astype(np.int64)
    if "n_views_voted" in scene:
        result["n_views_voted"] = scene["n_views_voted"][indices].astype(np.float32)
    if "vote_frac" in scene:
        result["vote_frac"] = scene["vote_frac"][indices].astype(np.float32)

    # Visible src/id for snap post-processing
    result["visible_src"] = scene["visible_src"][indices].astype(np.int64)
    result["visible_id"] = scene["visible_id"][indices].astype(np.int64)
    return result


def load_model(checkpoint_path, device):
    """Load an EdgeDepthSegmentsModel from a training checkpoint.

    The checkpoint must contain ``"model"`` (a state dict) and may contain
    ``"args"`` (a dict of training hyper-parameters). Missing args fall back
    to the defaults below. Returns the model on ``device`` in eval mode.
    """
    # weights_only=False unpickles arbitrary objects: only load trusted,
    # self-produced checkpoints with this.
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
    args = ckpt.get("args", {})
    norm_class = torch.nn.RMSNorm if args.get("rms_norm") else None
    seq_cfg = EdgeDepthSequenceConfig(
        seq_len=SEQ_LEN, colmap_points=COLMAP_QUOTA, depth_points=DEPTH_QUOTA)

    model = EdgeDepthSegmentsModel(
        seq_cfg=seq_cfg,
        segments=args.get("segments", 64),
        hidden=args.get("hidden", 256),
        num_heads=args.get("num_heads", 4),
        kv_heads_cross=args.get("kv_heads_cross", 2),
        kv_heads_self=args.get("kv_heads_self", 2),
        dim_feedforward=args.get("ff", 1024),
        dropout=args.get("dropout", 0.1),
        latent_tokens=args.get("latent_tokens", 256),
        latent_layers=args.get("latent_layers", 7),
        decoder_layers=args.get("decoder_layers", 3),
        cross_attn_interval=args.get("cross_attn_interval", 4),
        norm_class=norm_class,
        activation=args.get("activation", "gelu"),
        segment_conf=args.get("segment_conf", True),
        behind_emb_dim=args.get("behind_emb_dim", 8),
        use_vote_features=args.get("vote_features", True),
        arch=args.get("arch", "perceiver"),
        encoder_layers=args.get("encoder_layers", 4),
        pre_encoder_layers=args.get("pre_encoder_layers", 0),
        segment_param=args.get("segment_param", "midpoint_dir_len"),
        qk_norm=args.get("qk_norm", True),
    ).to(device)

    # Strip torch.compile's "_orig_mod." wrapper prefix wherever it appears:
    # "segmenter._orig_mod.x" -> "segmenter.x" (compiled submodule) and
    # "_orig_mod.x" -> "x" (whole model compiled).
    state = {k.replace("_orig_mod.", ""): v for k, v in ckpt["model"].items()}
    model.load_state_dict(state, strict=True)
    model.eval()
    return model


def build_tokens_single(sample_dict, model, device):
    """Build the token tensor and mask for one fused sample (no DataLoader).

    Features are concatenated along the last axis in the order:
      1. raw ``xyz``;
      2. the tokenizer's Fourier position encoding (width 0 if absent);
      3. class embedding;
      4. source embedding;
      5. optionally the behind embedding;
      6. optionally two normalized vote features.

    Returns:
        (tokens, masks) with a leading batch dimension of 1.
    """
    def batched(key, dtype):
        # Add a leading batch dimension of 1 and move to the target device.
        return torch.as_tensor(sample_dict[key], dtype=dtype).unsqueeze(0).to(device)

    xyz = batched("xyz_norm", torch.float32)
    cid = batched("class_id", torch.long)
    src = batched("source", torch.long)
    masks = batched("mask", torch.bool)
    B, T, _ = xyz.shape
    tok = model.tokenizer

    if tok.pos_enc is not None:
        fourier = tok.pos_enc(xyz.reshape(-1, 3)).reshape(B, T, -1)
    else:
        fourier = xyz.new_zeros(B, T, 0)
    # Source ids are clamped to {0, 1} before the source embedding lookup.
    parts = [xyz, fourier, tok.label_emb(cid), tok.src_emb(src.clamp(0, 1))]

    if tok.behind_emb_dim > 0:
        if "behind" in sample_dict:
            beh = batched("behind", torch.long)
        else:
            beh = xyz.new_zeros(B, T, dtype=torch.long)
        parts.append(tok.behind_emb(beh))

    if tok.use_vote_features:
        if "n_views_voted" in sample_dict and "vote_frac" in sample_dict:
            # NOTE(review): these shift/scale constants presumably match the
            # normalization used at training time; keep them in sync.
            nv = ((batched("n_views_voted", torch.float32) - 2.7) / 1.0).unsqueeze(-1)
            vf = ((batched("vote_frac", torch.float32) - 0.5) / 0.25).unsqueeze(-1)
            parts.extend([nv, vf])
        else:
            parts.extend([xyz.new_zeros(B, T, 1), xyz.new_zeros(B, T, 1)])

    return torch.cat(parts, dim=-1), masks


def predict_sample(sample_dict, model, device):
    """Run model inference + post-processing on a fused sample.

    Returns:
        (vertices, edges) in world space, where vertices is an array of 3-D
        points and edges is a list of (int, int) vertex-index pairs.

        Falls back to empty_solution() when:
          - no segment passes CONF_THRESH;
          - post-processing leaves fewer than two vertices or no edges;
          - any vertex coordinate is non-finite.
    """
    tokens, masks = build_tokens_single(sample_dict, model, device)
    scale = float(sample_dict["scale"])
    center = sample_dict["center"]

    # fp16 autocast on CUDA only; on other devices the context is disabled.
    with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16,
                                         enabled=(device.type == 'cuda')):
        out = model.forward_tokens(tokens, masks)

    segs = out["segments"][0].float().cpu()

    # Confidence filter (only when the model returns a "conf" output).
    if "conf" in out:
        keep = torch.sigmoid(out["conf"][0].float()).cpu() > CONF_THRESH
        segs = segs[keep]
        if len(segs) < 1:
            return empty_solution()

    # To world space: invert the (xyz - center) / scale normalization.
    segs_world = segs.numpy() * scale + center

    # Vertices + edges from segments. Edges are kept as an (E, 2) int array,
    # even when E == 0.
    pv, pe = segments_to_vertices_edges(torch.tensor(segs_world))
    pv = pv.numpy()
    pe = np.array(pe, dtype=np.int32).reshape(-1, 2)

    # Merge
    pv, pe = merge_vertices_iterative(pv, pe)

    # Snap to the point cloud, using only sampled points where mask is True,
    # mapped back to world space.
    mask = sample_dict["mask"]
    xyz_world = sample_dict["xyz_norm"][mask] * scale + center
    cid_valid = sample_dict["class_id"][mask]
    pv = snap_to_point_cloud(pv, xyz_world, cid_valid, snap_radius=SNAP_RADIUS)

    # Horizontal snap
    pv = snap_horizontal(pv, pe)

    if len(pv) < 2 or len(pe) < 1:
        return empty_solution()
    # Non-finite coordinates (e.g. fp16 overflow) would be serialized by
    # json.dump as NaN/Infinity, which is not valid JSON.
    if not np.all(np.isfinite(pv)):
        return empty_solution()

    edges = [(int(a), int(b)) for a, b in pe]
    return pv, edges


def _predict_wireframe(sample, order_id, cfg, rng, model, device):
    """Fuse, sample and predict one raw sample; never raises.

    Any failure in fusion/sampling or prediction is logged and replaced by
    empty_solution(), so one bad sample cannot abort the whole submission.
    """
    try:
        fused = fuse_and_sample(sample, cfg, rng)
    except Exception as e:  # e.g. grouping / normalisation / sampling errors
        print(f" Sampling failed for {order_id}: {e}")
        return empty_solution()
    if fused is None:
        return empty_solution()
    try:
        return predict_sample(fused, model, device)
    except Exception as e:
        print(f" Predict failed for {order_id}: {e}")
        return empty_solution()


def _load_test_dataset(params):
    """Download the test data to /tmp/data if missing, then load it.

    Returns the dataset dict with "validation" (public) and "test" (private)
    splits, built from the *.tar shards found under /tmp/data.
    """
    data_path = Path("/tmp/data")
    if not data_path.exists():
        from huggingface_hub import snapshot_download
        snapshot_download(
            repo_id=params["dataset"],
            local_dir="/tmp/data",
            repo_type="dataset",
        )

    from datasets import load_dataset
    data_files = {
        "validation": [str(p) for p in data_path.rglob("*public*/**/*.tar")],
        "test": [str(p) for p in data_path.rglob("*private*/**/*.tar")],
    }
    print(f"Data files: {data_files}")
    dataset = load_dataset(
        str(data_path / "hoho22k_2026_test_x_anon.py"),
        data_files=data_files,
        trust_remote_code=True,
        writer_batch_size=100,
    )
    print(f"Loaded: {dataset}")
    return dataset


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    """Predict wireframes for every test sample and write submission.json."""
    t_start = time.time()

    # Load params
    param_path = Path("params.json")
    with param_path.open() as f:
        params = json.load(f)
    print(f"Competition: {params.get('competition_id', '?')}")
    print(f"Dataset: {params.get('dataset', '?')}")

    # Load test data
    dataset = _load_test_dataset(params)

    # Load model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")
    checkpoint_path = SCRIPT_DIR / "checkpoint.pt"
    model = load_model(checkpoint_path, device)
    print(f"Model loaded: {sum(p.numel() for p in model.parameters()):,} params")

    # Point fusion config; one seeded RNG shared across all samples.
    cfg = FuserConfig()
    rng = np.random.RandomState(2718)

    # Process all samples
    solution = []
    total_samples = sum(len(dataset[s]) for s in dataset)
    processed = 0
    # ETA is based on per-sample time only (excludes download / model load).
    t_loop = time.time()
    for subset_name in dataset:
        print(f"\nProcessing {subset_name} ({len(dataset[subset_name])} samples)...")
        for sample in tqdm(dataset[subset_name], desc=subset_name):
            order_id = sample["order_id"]
            pred_v, pred_e = _predict_wireframe(sample, order_id, cfg, rng, model, device)

            solution.append({
                "order_id": order_id,
                "wf_vertices": pred_v.tolist() if isinstance(pred_v, np.ndarray) else pred_v,
                "wf_edges": [(int(a), int(b)) for a, b in pred_e],
            })

            processed += 1
            if processed % 50 == 0:
                now = time.time()
                elapsed = now - t_start
                rate = (now - t_loop) / processed
                remaining = (total_samples - processed) * rate
                print(f" [{processed}/{total_samples}] "
                      f"{elapsed:.0f}s elapsed, ~{remaining:.0f}s remaining")

    # Save
    with open("submission.json", "w") as f:
        json.dump(solution, f)

    elapsed = time.time() - t_start
    print(f"\nDone. {processed} samples in {elapsed:.0f}s ({elapsed/max(processed,1):.1f}s/sample)")
    print(f"Saved submission.json ({len(solution)} entries)")


if __name__ == "__main__":
    main()