#!/usr/bin/env python3 """ Training script for S23DR 2026. Usage: python -m s23dr_2026_example.train --cache-dir hf://usm3d/s23dr-2026-sampled_2048_v2:train --steps 80000 --aug-rotate """ from __future__ import annotations import sys from pathlib import Path as _Path if __package__ is None or __package__ == "": _here = _Path(__file__).resolve().parent if str(_here.parent) not in sys.path: sys.path.insert(0, str(_here.parent)) __package__ = _here.name import argparse import gc import json import math import subprocess import time from pathlib import Path import numpy as np import torch from .tokenizer import EdgeDepthSequenceConfig from .model import EdgeDepthSegmentsModel from .data import build_loader, build_tokens from .losses import compute_loss, _loss_inner # Re-export for eval scripts from .data import HFCachedDataset, collate as _collate # noqa: F401 # --------------------------------------------------------------------------- # Main # --------------------------------------------------------------------------- def main(): p = argparse.ArgumentParser(description="S23DR 2026 training") p.add_argument("--cache-dir", default=None, help="HF dataset path (hf://repo:split)") p.add_argument("--val-cache-dir", default="", help="Separate cache for validation") p.add_argument("--seq-len", type=int, default=2048, help="Input sequence length (2048 or 4096, must match dataset)") p.add_argument("--arch", choices=["perceiver", "transformer"], default="perceiver", help="perceiver=latent bottleneck, transformer=full self-attention encoder") p.add_argument("--segments", type=int, default=32) p.add_argument("--hidden", type=int, default=128) p.add_argument("--ff", type=int, default=512) p.add_argument("--latent-tokens", type=int, default=128) p.add_argument("--latent-layers", type=int, default=7) p.add_argument("--encoder-layers", type=int, default=4, help="Encoder layers (transformer arch only)") p.add_argument("--pre-encoder-layers", type=int, default=0, help="Self-attn layers on full token sequence before perceiver bottleneck") p.add_argument("--decoder-layers", type=int, default=3) p.add_argument("--decoder-input-xattn", action="store_true", help="Add cross-attention from segment queries to input tokens in each decoder layer") p.add_argument("--qk-norm", action="store_true", help="Normalize Q and K per-head with learned temperature (stabilizes wide models)") p.add_argument("--qk-norm-type", choices=["l2", "rms"], default="l2", help="QK-norm type: l2 (unit sphere) or rms (RMSNorm, preserves magnitudes)") p.add_argument("--learnable-fourier", action="store_true", help="Make Fourier positional encoding learnable (vs fixed random)") p.add_argument("--num-heads", type=int, default=4, help="Attention heads") p.add_argument("--kv-heads-cross", type=int, default=2, help="KV heads for cross-attention (GQA; 0 = standard MHA)") p.add_argument("--kv-heads-self", type=int, default=2, help="KV heads for self-attention (GQA; 0 = standard MHA)") p.add_argument("--cross-attn-interval", type=int, default=4, help="Perceiver cross-attention frequency (every N latent layers)") p.add_argument("--dropout", type=float, default=0.1) p.add_argument("--weight-decay", type=float, default=0.01, help="AdamW weight decay") p.add_argument("--steps", type=int, default=5000) p.add_argument("--batch-size", type=int, default=32) p.add_argument("--lr", type=float, default=3e-4) p.add_argument("--adam-betas", default="0.9,0.95", help="AdamW beta1,beta2") p.add_argument("--warmup", type=int, default=200, help="LR warmup steps") p.add_argument("--cosine-decay", action="store_true", help="Cosine decay LR after warmup (to lr*0.01 at end)") p.add_argument("--cooldown-start", type=int, default=0, help="Step to begin linear cooldown to lr*0.01 (0=disabled, constant LR after warmup)") p.add_argument("--cooldown-steps", type=int, default=0, help="Number of steps for linear cooldown (0=no cooldown)") p.add_argument("--seed", type=int, default=7) p.add_argument("--deterministic", action="store_true", help="Force deterministic mode (disables torch.compile, slower but bit-reproducible)") p.add_argument("--varifold-weight", type=float, default=0.0) p.add_argument("--varifold-cross-only", action="store_true", help="Drop varifold self-energy (avoids O(S^2) spike, sinkhorn handles repulsion)") p.add_argument("--sinkhorn-weight", type=float, default=1.0) p.add_argument("--sinkhorn-eps", type=float, default=0.1, help="Sinkhorn regularization (larger = softer matching, stronger gradients)") p.add_argument("--sinkhorn-eps-start", type=float, default=None, help="Starting eps for epsilon annealing (anneals to --sinkhorn-eps). None=no annealing.") p.add_argument("--sinkhorn-eps-schedule", choices=["linear", "sqrt", "none"], default="none", help="Eps annealing schedule: linear, sqrt, or none (default: no annealing)") p.add_argument("--sinkhorn-iters", type=int, default=20, help="Sinkhorn iterations") p.add_argument("--sinkhorn-dustbin", type=float, default=0.3, help="Sinkhorn dustbin cost in normalized space") p.add_argument("--endpoint-weight", type=float, default=0.0, help="Weight for endpoint distance loss (sinkhorn-matched, symmetric)") p.add_argument("--endpoint-warmup", type=int, default=0, help="Steps to linearly warm up endpoint weight from 0 (0=instant)") p.add_argument("--aug-rotate", action="store_true") p.add_argument("--aug-jitter", type=float, default=0.0, help="Point position jitter std in normalized space (0=disabled, try 0.005)") p.add_argument("--aug-drop", type=float, default=0.0, help="Fraction of points to randomly drop (0=disabled, try 0.1)") p.add_argument("--aug-flip", action="store_true", help="Random mirror along X axis (50%% chance)") p.add_argument("--rms-norm", action="store_true", default=True, help="Use RMSNorm (default). Use --no-rms-norm for LayerNorm") p.add_argument("--no-rms-norm", dest="rms_norm", action="store_false") p.add_argument("--activation", default="gelu", help="FFN activation: gelu, relu, relu_sq") p.add_argument("--behind-emb-dim", type=int, default=8, help="Behind-gestalt embedding dim (0 to disable)") p.add_argument("--vote-features", action="store_true", help="Add n_views_voted + vote_frac as raw token features (requires v2 data)") p.add_argument("--segment-param", choices=["midpoint_halfvec", "midpoint_dir_len"], default="midpoint_halfvec", help="Output parameterization: halfvec (default) or decoupled direction+length") p.add_argument("--length-floor", type=float, default=0.0, help="Minimum segment length for midpoint_dir_len (0=no floor)") p.add_argument("--segment-conf", action="store_true", help="Add per-segment confidence head (use with --conf-thresh at eval)") p.add_argument("--conf-weight", type=float, default=0.0, help="Weight for confidence loss (requires --segment-conf)") p.add_argument("--conf-mode", choices=["sinkhorn", "sinkhorn_detach"], default="sinkhorn", help="Confidence training: 'match'=BCE, 'sinkhorn'=OT mass, 'sinkhorn_detach'=OT mass (detached)") p.add_argument("--conf-clamp-min", type=float, default=None, help="Clamp conf logits to this minimum before sigmoid (e.g., -5)") p.add_argument("--conf-head-wd", type=float, default=None, help="Separate weight decay for conf head (default: same as other params)") p.add_argument("--ema-decay", type=float, default=0.0, help="EMA decay rate (0=disabled, try 0.9999). Saves EMA weights in checkpoints.") p.add_argument("--out-dir", default=str(_Path(__file__).resolve().parent / "runs")) p.add_argument("--resume", default="") p.add_argument("--cpu", action="store_true") p.add_argument("--args-from", default=None, help="Load defaults from a run's args.json (CLI flags override)") # If --args-from is specified, load defaults from that JSON file first, # then let CLI flags override. raw_args = p.parse_args() if raw_args.args_from is not None: import json as _json args_path = _Path(raw_args.args_from) if not args_path.exists(): raise FileNotFoundError(f"--args-from file not found: {args_path}") saved = _json.loads(args_path.read_text()) valid_dests = {a.dest for a in p._actions} defaults = {} for k, v in saved.items(): if k in valid_dests and k != "args_from": defaults[k] = v p.set_defaults(**defaults) args = p.parse_args() print(f"Loaded defaults from {args_path} (CLI flags override)") else: args = raw_args # Validate required args if not args.cache_dir: p.error("--cache-dir is required (either directly or via --args-from)") # Validate arg compatibility if args.arch == "transformer": perceiver_only = [] if args.latent_tokens != 128: perceiver_only.append(f"--latent-tokens={args.latent_tokens}") if args.latent_layers != 7: perceiver_only.append(f"--latent-layers={args.latent_layers}") if args.pre_encoder_layers != 0: perceiver_only.append(f"--pre-encoder-layers={args.pre_encoder_layers}") if args.cross_attn_interval != 4: perceiver_only.append(f"--cross-attn-interval={args.cross_attn_interval}") if perceiver_only: raise ValueError( f"Args {', '.join(perceiver_only)} have no effect with --arch transformer. " f"Use --arch perceiver or remove them.") if args.conf_weight > 0 and not args.segment_conf: raise ValueError("--conf-weight requires --segment-conf") if args.conf_mode in ("sinkhorn", "sinkhorn_detach") and args.sinkhorn_weight == 0: raise ValueError("--conf-mode sinkhorn requires --sinkhorn-weight > 0") if args.cosine_decay and args.cooldown_start > 0: raise ValueError("--cosine-decay and --cooldown-start are mutually exclusive") device = torch.device("cpu" if args.cpu else ("cuda" if torch.cuda.is_available() else "cpu")) print(f"Device: {device}") torch.manual_seed(args.seed) np.random.seed(args.seed) # Output import hashlib, os args_hash = hashlib.md5(json.dumps(vars(args), sort_keys=True).encode()).hexdigest()[:4] run_tag = time.strftime("%Y%m%d_%H%M%S") + f"_{args_hash}_{os.getpid() % 10000:04d}" out_dir = Path(args.out_dir) / run_tag out_dir.mkdir(parents=True, exist_ok=True) (out_dir / "checkpoints").mkdir(exist_ok=True) # Tee stdout/stderr to run dir import sys as _sys _log_path = out_dir / "train.log" class _Tee: def __init__(self, path, stream): self._file = open(path, "a") self._stream = stream def write(self, data): self._stream.write(data) self._file.write(data) self._file.flush() def flush(self): self._stream.flush() self._file.flush() _sys.stdout = _Tee(_log_path, _sys.stdout) _sys.stderr = _Tee(_log_path, _sys.stderr) git_sha = subprocess.run(["git", "rev-parse", "HEAD"], capture_output=True, text=True, cwd=str(_Path(__file__).parent)).stdout.strip() git_dirty = subprocess.run(["git", "diff", "--quiet"], capture_output=True, cwd=str(_Path(__file__).parent)).returncode != 0 run_info = {**vars(args), "git_sha": git_sha, "git_dirty": git_dirty} (out_dir / "args.json").write_text(json.dumps(run_info, indent=2, sort_keys=True) + "\n") # Set varifold cross-only mode before compile if args.varifold_cross_only: from . import losses as L L.VARIFOLD_CROSS_ONLY = True print("Varifold: cross-only mode (no self-energy)") # Model seq_len = args.seq_len norm_class = torch.nn.RMSNorm if args.rms_norm else None seq_cfg = EdgeDepthSequenceConfig(seq_len=seq_len) model = EdgeDepthSegmentsModel( seq_cfg=seq_cfg, segments=args.segments, hidden=args.hidden, num_heads=args.num_heads, kv_heads_cross=args.kv_heads_cross, kv_heads_self=args.kv_heads_self, dim_feedforward=args.ff, dropout=args.dropout, latent_tokens=args.latent_tokens, latent_layers=args.latent_layers, decoder_layers=args.decoder_layers, cross_attn_interval=args.cross_attn_interval, norm_class=norm_class, activation=args.activation, segment_conf=args.segment_conf, segment_param=args.segment_param, length_floor=args.length_floor, arch=args.arch, encoder_layers=args.encoder_layers, pre_encoder_layers=args.pre_encoder_layers, behind_emb_dim=args.behind_emb_dim, use_vote_features=args.vote_features, decoder_input_xattn=args.decoder_input_xattn, qk_norm=args.qk_norm, qk_norm_type=args.qk_norm_type, learnable_fourier=args.learnable_fourier, ).to(device) try: from torchinfo import summary summary(model.segmenter, input_data=[torch.zeros(1, seq_len, model.tokenizer.out_dim, device=device), torch.ones(1, seq_len, device=device, dtype=torch.bool)], col_names=("input_size", "output_size", "num_params"), verbose=1) except ImportError: pass print(f"Total params: {sum(p.numel() for p in model.parameters()):,}") # Compile (skip in deterministic mode for bit-reproducibility) torch.set_float32_matmul_precision("high") if args.deterministic: torch.use_deterministic_algorithms(True) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False import os os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":16:8") print("Deterministic mode: no torch.compile, bit-reproducible but ~3x slower") elif device.type == "cuda": model.segmenter = torch.compile(model.segmenter, mode="reduce-overhead", fullgraph=True) from . import losses as L L._loss_fn = torch.compile(_loss_inner, mode="reduce-overhead", fullgraph=True) print("Compiled model + loss (reduce-overhead, fullgraph)") # EMA ema_model = None if args.ema_decay > 0: from copy import deepcopy ema_model = deepcopy(model).eval() for p_ema in ema_model.parameters(): p_ema.requires_grad_(False) print(f"EMA enabled (decay={args.ema_decay})") # Resume start_step = 0 if args.resume: ckpt = torch.load(args.resume, map_location=device, weights_only=False) try: model.load_state_dict(ckpt["model"]) except RuntimeError: state = {k.replace("segmenter._orig_mod.", "segmenter."): v for k, v in ckpt["model"].items()} model.load_state_dict(state) start_step = ckpt.get("step", 0) print(f"Resumed from {args.resume} at step {start_step}") betas = tuple(float(x) for x in args.adam_betas.split(",")) # Optimizer: AdamW with optional separate conf_head weight decay conf_wd = args.conf_head_wd if args.conf_head_wd is not None else args.weight_decay if args.conf_head_wd is not None: conf_decay_params = [] other_params = [] for name, param in model.named_parameters(): if not param.requires_grad: continue if 'conf_head' in name: conf_decay_params.append(param) else: other_params.append(param) param_groups = [ {"params": other_params, "weight_decay": args.weight_decay}, {"params": conf_decay_params, "weight_decay": conf_wd}, ] print(f"Conf head WD: {conf_wd} ({len(conf_decay_params)} params)") else: param_groups = model.parameters() opt = torch.optim.AdamW(param_groups, lr=args.lr, weight_decay=args.weight_decay, betas=betas) if args.resume and "optimizer" in ckpt: opt.load_state_dict(ckpt["optimizer"]) # Data torch.manual_seed(args.seed + 7919) np.random.seed(args.seed + 7919) train_loader = build_loader(args.cache_dir, args.batch_size, aug_rotate=args.aug_rotate, aug_jitter=args.aug_jitter, aug_drop=args.aug_drop, aug_flip=args.aug_flip) val_loader = build_loader(args.val_cache_dir, args.batch_size) if args.val_cache_dir else None data_iter = iter(train_loader) # Intervals log_int = max(1, min(50, args.steps // 20)) ckpt_int = 5000 val_int = ckpt_int if val_loader else 0 # Training loop global_step = start_step loss_ema, loss_sq_ema = 0.0, 0.0 t_start = time.perf_counter() print(f"Training for {args.steps} steps | {args.segments}seg " f"{args.hidden}h {args.latent_tokens}x{args.latent_layers}L " f"{args.decoder_layers}D") # Pre-fetch first batch try: next_batch = next(data_iter) except StopIteration: data_iter = iter(train_loader) next_batch = next(data_iter) # Freeze GC after setup to eliminate stalls during training gc.collect() gc.freeze() gc.disable() amp_ctx = torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=(device.type == 'cuda')) while global_step < args.steps: tokens, masks, gt_list, scales, meta = build_tokens(next_batch, model, device) # Epsilon annealing if args.sinkhorn_eps_start is not None and args.sinkhorn_eps_start != args.sinkhorn_eps: if args.sinkhorn_eps_schedule == "sqrt": ratio_sq = (args.sinkhorn_eps_start / args.sinkhorn_eps) ** 2 t0 = max(args.steps * 0.8 / max(ratio_sq - 1, 1e-6), 1.0) current_eps = args.sinkhorn_eps_start / math.sqrt(1 + global_step / t0) current_eps = max(current_eps, args.sinkhorn_eps) else: frac = min(global_step / max(args.steps * 0.8, 1), 1.0) current_eps = args.sinkhorn_eps_start + frac * (args.sinkhorn_eps - args.sinkhorn_eps_start) else: current_eps = args.sinkhorn_eps with amp_ctx: out = model.forward_tokens(tokens, masks) pred = out["segments"] conf = out.get("conf") # Endpoint weight warmup if args.endpoint_warmup > 0 and global_step < args.endpoint_warmup: current_ep_w = args.endpoint_weight * global_step / args.endpoint_warmup else: current_ep_w = args.endpoint_weight loss, terms = compute_loss(pred, gt_list, scales.to(device), device, args.varifold_weight, args.sinkhorn_weight, endpoint_w=current_ep_w, conf_logits=conf, conf_weight=args.conf_weight, conf_mode=args.conf_mode, sinkhorn_eps=current_eps, sinkhorn_iters=args.sinkhorn_iters, sinkhorn_dustbin=args.sinkhorn_dustbin, conf_clamp_min=args.conf_clamp_min) loss_val = loss.item() # Adaptive loss spike detection if global_step < 100: loss_ema = loss_val if global_step == start_step else 0.9 * loss_ema + 0.1 * loss_val loss_sq_ema = loss_val**2 if global_step == start_step else 0.9 * loss_sq_ema + 0.1 * loss_val**2 else: loss_ema = 0.99 * loss_ema + 0.01 * loss_val loss_sq_ema = 0.99 * loss_sq_ema + 0.01 * loss_val**2 loss_std = max(math.sqrt(max(loss_sq_ema - loss_ema**2, 0)), 1e-6) spike_thresh = loss_ema + 5 * loss_std # Skip on total loss spike or NaN if not math.isfinite(loss_val) or loss_val > max(spike_thresh, 0.5): sample_ids = [m.get("sample_id", "?") for m in meta] skip_reason = f"loss={loss_val:.2f} > thresh={spike_thresh:.2f}" print(f"Step {global_step}: {skip_reason}, skipping (samples: {sample_ids[:3]})") with open(out_dir / "skipped_samples.jsonl", "a") as f: f.write(json.dumps({"step": global_step, "reason": skip_reason, "samples": sample_ids}) + "\n") try: next_batch = next(data_iter) except StopIteration: data_iter = iter(train_loader) next_batch = next(data_iter) continue opt.zero_grad() loss.backward() # Fetch next batch while GPU finishes backward try: next_batch = next(data_iter) except StopIteration: data_iter = iter(train_loader) next_batch = next(data_iter) torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) # LR schedule: warmup -> constant -> optional cooldown or cosine if global_step < args.warmup: lr = args.lr * (global_step + 1) / max(1, args.warmup) elif args.cosine_decay: progress = (global_step - args.warmup) / max(1, args.steps - args.warmup) lr = args.lr * (0.01 + 0.99 * 0.5 * (1 + math.cos(math.pi * progress))) elif args.cooldown_start > 0 and global_step >= args.cooldown_start: progress = (global_step - args.cooldown_start) / max(1, args.cooldown_steps) lr = args.lr * max(0.01, 1.0 - 0.99 * min(1.0, progress)) else: lr = args.lr for pg in opt.param_groups: pg["lr"] = lr opt.step() global_step += 1 # EMA update if ema_model is not None: decay = args.ema_decay with torch.no_grad(): for p_ema, p_model in zip(ema_model.parameters(), model.parameters()): p_ema.lerp_(p_model, 1.0 - decay) # Log entry = {"step": global_step, "ts": time.time(), "loss": loss.item(), "lr": lr} entry.update({k: v.item() for k, v in terms.items()}) if global_step % log_int == 0: grad_norm = sum(p.grad.norm().item()**2 for p in model.parameters() if p.grad is not None) ** 0.5 entry["grad_norm"] = grad_norm if global_step % log_int == 0: ms = (time.perf_counter() - t_start) / log_int * 1000 t_start = time.perf_counter() t_str = " ".join(f"{k}={v:.4f}" for k, v in terms.items()) print(f"[{global_step}/{args.steps}] loss={loss.item():.4f} {t_str} " f"lr={lr:.2e} gnorm={entry.get('grad_norm', 0):.3f} [{ms:.0f}ms/step]") if val_int > 0 and global_step % val_int == 0: try: vl_list = [] with torch.no_grad(), amp_ctx: for vb in val_loader: vt, vm, vg, vs, _ = build_tokens(vb, model, device) vo = model.forward_tokens(vt, vm) vl, _ = compute_loss(vo["segments"], vg, vs.to(device), device, args.varifold_weight, args.sinkhorn_weight) if math.isfinite(vl.item()): vl_list.append(vl.item()) if vl_list: val_loss = float(np.mean(vl_list)) print(f" val_loss={val_loss:.4f}") entry["val_loss"] = val_loss except Exception as e: print(f" val eval failed: {e}") # Write log entry with open(out_dir / "history.jsonl", "a") as f: f.write(json.dumps(entry) + "\n") if global_step % ckpt_int == 0: try: gc.enable(); gc.collect(); gc.freeze(); gc.disable() torch.cuda.empty_cache() save_dict = {"step": global_step, "model": model.state_dict(), "optimizer": opt.state_dict(), "args": vars(args)} if ema_model is not None: save_dict["ema_model"] = ema_model.state_dict() torch.save(save_dict, out_dir / "checkpoints" / f"step{global_step:06d}.pt") except Exception as e: print(f" checkpoint save failed: {e}") # Final save save_dict = {"step": global_step, "model": model.state_dict(), "optimizer": opt.state_dict(), "args": vars(args)} if ema_model is not None: save_dict["ema_model"] = ema_model.state_dict() torch.save(save_dict, out_dir / "checkpoints" / "final.pt") print(f"Done. {global_step} steps. Output: {out_dir}") if __name__ == "__main__": main()