"""
Training script for S23DR 2026.

Usage:
    python -m s23dr_2026_example.train --cache-dir hf://usm3d/s23dr-2026-sampled_2048_v2:train --steps 80000 --aug-rotate
"""
from __future__ import annotations

import sys
from pathlib import Path as _Path

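# When this file is run directly as a script (not via `python -m`), put the
# package's parent directory on sys.path and set __package__ so the relative
# imports below still resolve.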
if __package__ is None or __package__ == "":
    _here = _Path(__file__).resolve().parent
    if str(_here.parent) not in sys.path:
        sys.path.insert(0, str(_here.parent))
    __package__ = _here.name

import argparse
import gc
import json
import math
import subprocess
import time
from pathlib import Path

import numpy as np
import torch

from .tokenizer import EdgeDepthSequenceConfig
from .model import EdgeDepthSegmentsModel
from .data import build_loader, build_tokens, HFCachedDataset, collate as _collate
from .losses import compute_loss, _loss_inner


def main():
    p = argparse.ArgumentParser(description="S23DR 2026 training")
    p.add_argument("--cache-dir", default=None, help="HF dataset path (hf://repo:split)")
    p.add_argument("--val-cache-dir", default="", help="Separate cache for validation")
    p.add_argument("--seq-len", type=int, default=2048,
                   help="Input sequence length (2048 or 4096, must match dataset)")
    p.add_argument("--arch", choices=["perceiver", "transformer"], default="perceiver",
                   help="perceiver=latent bottleneck, transformer=full self-attention encoder")
    p.add_argument("--segments", type=int, default=32)
    p.add_argument("--hidden", type=int, default=128)
    p.add_argument("--ff", type=int, default=512)
    p.add_argument("--latent-tokens", type=int, default=128)
    p.add_argument("--latent-layers", type=int, default=7)
    p.add_argument("--encoder-layers", type=int, default=4,
                   help="Encoder layers (transformer arch only)")
    p.add_argument("--pre-encoder-layers", type=int, default=0,
                   help="Self-attn layers on full token sequence before perceiver bottleneck")
    p.add_argument("--decoder-layers", type=int, default=3)
    p.add_argument("--decoder-input-xattn", action="store_true",
                   help="Add cross-attention from segment queries to input tokens in each decoder layer")
    p.add_argument("--qk-norm", action="store_true",
                   help="Normalize Q and K per-head with learned temperature (stabilizes wide models)")
    p.add_argument("--qk-norm-type", choices=["l2", "rms"], default="l2",
                   help="QK-norm type: l2 (unit sphere) or rms (RMSNorm, preserves magnitudes)")
    p.add_argument("--learnable-fourier", action="store_true",
                   help="Make Fourier positional encoding learnable (vs fixed random)")
    p.add_argument("--num-heads", type=int, default=4, help="Attention heads")
    p.add_argument("--kv-heads-cross", type=int, default=2,
                   help="KV heads for cross-attention (GQA; 0 = standard MHA)")
    p.add_argument("--kv-heads-self", type=int, default=2,
                   help="KV heads for self-attention (GQA; 0 = standard MHA)")
    p.add_argument("--cross-attn-interval", type=int, default=4,
                   help="Perceiver cross-attention frequency (every N latent layers)")
    p.add_argument("--dropout", type=float, default=0.1)
    p.add_argument("--weight-decay", type=float, default=0.01, help="AdamW weight decay")
    p.add_argument("--steps", type=int, default=5000)
    p.add_argument("--batch-size", type=int, default=32)
    p.add_argument("--lr", type=float, default=3e-4)
    p.add_argument("--adam-betas", default="0.9,0.95", help="AdamW beta1,beta2")
    p.add_argument("--warmup", type=int, default=200, help="LR warmup steps")
    p.add_argument("--cosine-decay", action="store_true",
                   help="Cosine decay LR after warmup (to lr*0.01 at end)")
    p.add_argument("--cooldown-start", type=int, default=0,
                   help="Step to begin linear cooldown to lr*0.01 (0=disabled, constant LR after warmup)")
    p.add_argument("--cooldown-steps", type=int, default=0,
                   help="Number of steps for linear cooldown (0=no cooldown)")
    p.add_argument("--seed", type=int, default=7)
    p.add_argument("--deterministic", action="store_true",
                   help="Force deterministic mode (disables torch.compile, slower but bit-reproducible)")
    p.add_argument("--varifold-weight", type=float, default=0.0)
    p.add_argument("--varifold-cross-only", action="store_true",
                   help="Drop varifold self-energy (avoids O(S^2) spike, sinkhorn handles repulsion)")
    p.add_argument("--sinkhorn-weight", type=float, default=1.0)
    p.add_argument("--sinkhorn-eps", type=float, default=0.1,
                   help="Sinkhorn regularization (larger = softer matching, stronger gradients)")
    p.add_argument("--sinkhorn-eps-start", type=float, default=None,
                   help="Starting eps for epsilon annealing (anneals to --sinkhorn-eps). None=no annealing.")
    p.add_argument("--sinkhorn-eps-schedule", choices=["linear", "sqrt", "none"], default="none",
                   help="Eps annealing schedule: linear, sqrt, or none (default: no annealing)")
    p.add_argument("--sinkhorn-iters", type=int, default=20,
                   help="Sinkhorn iterations")
    p.add_argument("--sinkhorn-dustbin", type=float, default=0.3,
                   help="Sinkhorn dustbin cost in normalized space")
    p.add_argument("--endpoint-weight", type=float, default=0.0,
                   help="Weight for endpoint distance loss (sinkhorn-matched, symmetric)")
    p.add_argument("--endpoint-warmup", type=int, default=0,
                   help="Steps to linearly warm up endpoint weight from 0 (0=instant)")
    p.add_argument("--aug-rotate", action="store_true")
    p.add_argument("--aug-jitter", type=float, default=0.0,
                   help="Point position jitter std in normalized space (0=disabled, try 0.005)")
    p.add_argument("--aug-drop", type=float, default=0.0,
                   help="Fraction of points to randomly drop (0=disabled, try 0.1)")
    p.add_argument("--aug-flip", action="store_true",
                   help="Random mirror along X axis (50%% chance)")
    p.add_argument("--rms-norm", action="store_true", default=True,
                   help="Use RMSNorm (default). Use --no-rms-norm for LayerNorm")
    p.add_argument("--no-rms-norm", dest="rms_norm", action="store_false")
    p.add_argument("--activation", default="gelu", help="FFN activation: gelu, relu, relu_sq")
    p.add_argument("--behind-emb-dim", type=int, default=8,
                   help="Behind-gestalt embedding dim (0 to disable)")
    p.add_argument("--vote-features", action="store_true",
                   help="Add n_views_voted + vote_frac as raw token features (requires v2 data)")
    p.add_argument("--segment-param", choices=["midpoint_halfvec", "midpoint_dir_len"],
                   default="midpoint_halfvec",
                   help="Output parameterization: halfvec (default) or decoupled direction+length")
    p.add_argument("--length-floor", type=float, default=0.0,
                   help="Minimum segment length for midpoint_dir_len (0=no floor)")
    p.add_argument("--segment-conf", action="store_true",
                   help="Add per-segment confidence head (use with --conf-thresh at eval)")
    p.add_argument("--conf-weight", type=float, default=0.0,
                   help="Weight for confidence loss (requires --segment-conf)")
| p.add_argument("--conf-mode", choices=["sinkhorn", "sinkhorn_detach"], default="sinkhorn", |
| help="Confidence training: 'match'=BCE, 'sinkhorn'=OT mass, 'sinkhorn_detach'=OT mass (detached)") |
| p.add_argument("--conf-clamp-min", type=float, default=None, |
| help="Clamp conf logits to this minimum before sigmoid (e.g., -5)") |
| p.add_argument("--conf-head-wd", type=float, default=None, |
| help="Separate weight decay for conf head (default: same as other params)") |
| p.add_argument("--ema-decay", type=float, default=0.0, |
| help="EMA decay rate (0=disabled, try 0.9999). Saves EMA weights in checkpoints.") |
| p.add_argument("--out-dir", default=str(_Path(__file__).resolve().parent / "runs")) |
| p.add_argument("--resume", default="") |
| p.add_argument("--cpu", action="store_true") |
| p.add_argument("--args-from", default=None, |
| help="Load defaults from a run's args.json (CLI flags override)") |
|
|
| |
| |
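    # Optional two-pass parsing: if --args-from points at a previous run's args.json,
    # install its values as new argparse defaults and re-parse, so flags given
    # explicitly on the command line still take precedence over the saved ones.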
    raw_args = p.parse_args()
    if raw_args.args_from is not None:
        import json as _json
        args_path = _Path(raw_args.args_from)
        if not args_path.exists():
            raise FileNotFoundError(f"--args-from file not found: {args_path}")
        saved = _json.loads(args_path.read_text())
        valid_dests = {a.dest for a in p._actions}
        defaults = {}
        for k, v in saved.items():
            if k in valid_dests and k != "args_from":
                defaults[k] = v
        p.set_defaults(**defaults)
        args = p.parse_args()
        print(f"Loaded defaults from {args_path} (CLI flags override)")
    else:
        args = raw_args

    if not args.cache_dir:
        p.error("--cache-dir is required (either directly or via --args-from)")

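    # Fail fast on flag combinations that conflict or would silently have no effect.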
| if args.arch == "transformer": |
| perceiver_only = [] |
| if args.latent_tokens != 128: |
| perceiver_only.append(f"--latent-tokens={args.latent_tokens}") |
| if args.latent_layers != 7: |
| perceiver_only.append(f"--latent-layers={args.latent_layers}") |
| if args.pre_encoder_layers != 0: |
| perceiver_only.append(f"--pre-encoder-layers={args.pre_encoder_layers}") |
| if args.cross_attn_interval != 4: |
| perceiver_only.append(f"--cross-attn-interval={args.cross_attn_interval}") |
| if perceiver_only: |
| raise ValueError( |
| f"Args {', '.join(perceiver_only)} have no effect with --arch transformer. " |
| f"Use --arch perceiver or remove them.") |
| if args.conf_weight > 0 and not args.segment_conf: |
| raise ValueError("--conf-weight requires --segment-conf") |
    if args.conf_weight > 0 and args.conf_mode in ("sinkhorn", "sinkhorn_detach") and args.sinkhorn_weight == 0:
        raise ValueError("--conf-mode sinkhorn requires --sinkhorn-weight > 0")
    if args.cosine_decay and args.cooldown_start > 0:
        raise ValueError("--cosine-decay and --cooldown-start are mutually exclusive")

    device = torch.device("cpu" if args.cpu else ("cuda" if torch.cuda.is_available() else "cpu"))
    print(f"Device: {device}")
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

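    # Run directory name: timestamp + short args hash + PID suffix, so runs started
    # in the same second (even with identical args) still land in distinct directories.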
    import hashlib, os
    args_hash = hashlib.md5(json.dumps(vars(args), sort_keys=True).encode()).hexdigest()[:4]
    run_tag = time.strftime("%Y%m%d_%H%M%S") + f"_{args_hash}_{os.getpid() % 10000:04d}"
    out_dir = Path(args.out_dir) / run_tag
    out_dir.mkdir(parents=True, exist_ok=True)
    (out_dir / "checkpoints").mkdir(exist_ok=True)

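    # Mirror stdout/stderr into <run_dir>/train.log so console output is archived with the run.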
    import sys as _sys
    _log_path = out_dir / "train.log"
    class _Tee:
        def __init__(self, path, stream):
            self._file = open(path, "a")
            self._stream = stream
        def write(self, data):
            self._stream.write(data)
            self._file.write(data)
            self._file.flush()
        def flush(self):
            self._stream.flush()
            self._file.flush()
    _sys.stdout = _Tee(_log_path, _sys.stdout)
    _sys.stderr = _Tee(_log_path, _sys.stderr)

    git_sha = subprocess.run(["git", "rev-parse", "HEAD"], capture_output=True, text=True,
                             cwd=str(_Path(__file__).parent)).stdout.strip()
    git_dirty = subprocess.run(["git", "diff", "--quiet"], capture_output=True,
                               cwd=str(_Path(__file__).parent)).returncode != 0
    run_info = {**vars(args), "git_sha": git_sha, "git_dirty": git_dirty}
    (out_dir / "args.json").write_text(json.dumps(run_info, indent=2, sort_keys=True) + "\n")

    if args.varifold_cross_only:
        from . import losses as L
        L.VARIFOLD_CROSS_ONLY = True
        print("Varifold: cross-only mode (no self-energy)")

    seq_len = args.seq_len
    norm_class = torch.nn.RMSNorm if args.rms_norm else None
    seq_cfg = EdgeDepthSequenceConfig(seq_len=seq_len)
    model = EdgeDepthSegmentsModel(
        seq_cfg=seq_cfg, segments=args.segments, hidden=args.hidden,
        num_heads=args.num_heads, kv_heads_cross=args.kv_heads_cross,
        kv_heads_self=args.kv_heads_self,
        dim_feedforward=args.ff, dropout=args.dropout,
        latent_tokens=args.latent_tokens, latent_layers=args.latent_layers,
        decoder_layers=args.decoder_layers, cross_attn_interval=args.cross_attn_interval,
        norm_class=norm_class, activation=args.activation,
        segment_conf=args.segment_conf,
        segment_param=args.segment_param,
        length_floor=args.length_floor,
        arch=args.arch, encoder_layers=args.encoder_layers,
        pre_encoder_layers=args.pre_encoder_layers,
        behind_emb_dim=args.behind_emb_dim,
        use_vote_features=args.vote_features,
        decoder_input_xattn=args.decoder_input_xattn,
        qk_norm=args.qk_norm,
        qk_norm_type=args.qk_norm_type,
        learnable_fourier=args.learnable_fourier,
    ).to(device)

    try:
        from torchinfo import summary
        summary(model.segmenter,
                input_data=[torch.zeros(1, seq_len, model.tokenizer.out_dim, device=device),
                            torch.ones(1, seq_len, device=device, dtype=torch.bool)],
                col_names=("input_size", "output_size", "num_params"), verbose=1)
    except ImportError:
        pass
    print(f"Total params: {sum(p.numel() for p in model.parameters()):,}")

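    # Enable TF32 matmuls; in deterministic mode skip torch.compile entirely,
    # otherwise compile the segmenter and the inner loss function on CUDA.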
| torch.set_float32_matmul_precision("high") |
| if args.deterministic: |
| torch.use_deterministic_algorithms(True) |
| torch.backends.cudnn.deterministic = True |
| torch.backends.cudnn.benchmark = False |
| import os |
| os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":16:8") |
| print("Deterministic mode: no torch.compile, bit-reproducible but ~3x slower") |
| elif device.type == "cuda": |
| model.segmenter = torch.compile(model.segmenter, mode="reduce-overhead", fullgraph=True) |
| from . import losses as L |
| L._loss_fn = torch.compile(_loss_inner, mode="reduce-overhead", fullgraph=True) |
| print("Compiled model + loss (reduce-overhead, fullgraph)") |
|
|
| |
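    # Optional EMA shadow copy of the model, updated after every optimizer step
    # and saved alongside the raw weights in checkpoints.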
    ema_model = None
    if args.ema_decay > 0:
        from copy import deepcopy
        ema_model = deepcopy(model).eval()
        for p_ema in ema_model.parameters():
            p_ema.requires_grad_(False)
        print(f"EMA enabled (decay={args.ema_decay})")

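    # Resume: if the checkpoint was written from a torch.compile'd model, strip the
    # "_orig_mod." prefix from its parameter names before loading.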
    start_step = 0
    if args.resume:
        ckpt = torch.load(args.resume, map_location=device, weights_only=False)
        try:
            model.load_state_dict(ckpt["model"])
        except RuntimeError:
            state = {k.replace("segmenter._orig_mod.", "segmenter."): v
                     for k, v in ckpt["model"].items()}
            model.load_state_dict(state)
        start_step = ckpt.get("step", 0)
        print(f"Resumed from {args.resume} at step {start_step}")

    betas = tuple(float(x) for x in args.adam_betas.split(","))

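    # Optionally put the confidence head in its own AdamW parameter group so it can
    # use a different weight decay than the rest of the model.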
    conf_wd = args.conf_head_wd if args.conf_head_wd is not None else args.weight_decay
    if args.conf_head_wd is not None:
        conf_decay_params = []
        other_params = []
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            if 'conf_head' in name:
                conf_decay_params.append(param)
            else:
                other_params.append(param)
        param_groups = [
            {"params": other_params, "weight_decay": args.weight_decay},
            {"params": conf_decay_params, "weight_decay": conf_wd},
        ]
        print(f"Conf head WD: {conf_wd} ({len(conf_decay_params)} params)")
    else:
        param_groups = model.parameters()

    opt = torch.optim.AdamW(param_groups, lr=args.lr, weight_decay=args.weight_decay,
                            betas=betas)
    if args.resume and "optimizer" in ckpt:
        opt.load_state_dict(ckpt["optimizer"])

    torch.manual_seed(args.seed + 7919)
    np.random.seed(args.seed + 7919)
    train_loader = build_loader(args.cache_dir, args.batch_size, aug_rotate=args.aug_rotate,
                                aug_jitter=args.aug_jitter, aug_drop=args.aug_drop,
                                aug_flip=args.aug_flip)
    val_loader = build_loader(args.val_cache_dir, args.batch_size) if args.val_cache_dir else None
    data_iter = iter(train_loader)

    log_int = max(1, min(50, args.steps // 20))
    ckpt_int = 5000
    val_int = ckpt_int if val_loader else 0

    global_step = start_step
    loss_ema, loss_sq_ema = 0.0, 0.0
    t_start = time.perf_counter()

    print(f"Training for {args.steps} steps | {args.segments}seg "
          f"{args.hidden}h {args.latent_tokens}x{args.latent_layers}L "
          f"{args.decoder_layers}D")

    try:
        next_batch = next(data_iter)
    except StopIteration:
        data_iter = iter(train_loader)
        next_batch = next(data_iter)

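    # Collect once, freeze the survivors, and disable automatic GC so the cyclic
    # collector does not pause the training loop (it is briefly re-enabled around
    # checkpoint saves below).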
    gc.collect()
    gc.freeze()
    gc.disable()

    amp_ctx = torch.autocast(device_type='cuda', dtype=torch.bfloat16,
                             enabled=(device.type == 'cuda'))

    while global_step < args.steps:
        tokens, masks, gt_list, scales, meta = build_tokens(next_batch, model, device)

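        # Anneal the Sinkhorn epsilon from --sinkhorn-eps-start down to --sinkhorn-eps
        # over roughly the first 80% of training (1/sqrt(t) or linear schedule).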
        if args.sinkhorn_eps_start is not None and args.sinkhorn_eps_start != args.sinkhorn_eps:
            if args.sinkhorn_eps_schedule == "sqrt":
                ratio_sq = (args.sinkhorn_eps_start / args.sinkhorn_eps) ** 2
                t0 = max(args.steps * 0.8 / max(ratio_sq - 1, 1e-6), 1.0)
                current_eps = args.sinkhorn_eps_start / math.sqrt(1 + global_step / t0)
                current_eps = max(current_eps, args.sinkhorn_eps)
            else:
                frac = min(global_step / max(args.steps * 0.8, 1), 1.0)
                current_eps = args.sinkhorn_eps_start + frac * (args.sinkhorn_eps - args.sinkhorn_eps_start)
        else:
            current_eps = args.sinkhorn_eps

        with amp_ctx:
            out = model.forward_tokens(tokens, masks)
            pred = out["segments"]
            conf = out.get("conf")

            if args.endpoint_warmup > 0 and global_step < args.endpoint_warmup:
                current_ep_w = args.endpoint_weight * global_step / args.endpoint_warmup
            else:
                current_ep_w = args.endpoint_weight

            loss, terms = compute_loss(pred, gt_list, scales.to(device), device,
                                       args.varifold_weight, args.sinkhorn_weight,
                                       endpoint_w=current_ep_w,
                                       conf_logits=conf, conf_weight=args.conf_weight,
                                       conf_mode=args.conf_mode,
                                       sinkhorn_eps=current_eps,
                                       sinkhorn_iters=args.sinkhorn_iters,
                                       sinkhorn_dustbin=args.sinkhorn_dustbin,
                                       conf_clamp_min=args.conf_clamp_min)

        loss_val = loss.item()

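        # Running mean/std of the loss (faster EMA during the first ~100 steps of a run);
        # batches with a non-finite loss, or a loss above mean + 5*std (floored at 0.5),
        # are logged to skipped_samples.jsonl and skipped below.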
        if global_step - start_step < 100:
            loss_ema = loss_val if global_step == start_step else 0.9 * loss_ema + 0.1 * loss_val
            loss_sq_ema = loss_val**2 if global_step == start_step else 0.9 * loss_sq_ema + 0.1 * loss_val**2
        else:
            loss_ema = 0.99 * loss_ema + 0.01 * loss_val
            loss_sq_ema = 0.99 * loss_sq_ema + 0.01 * loss_val**2
        loss_std = max(math.sqrt(max(loss_sq_ema - loss_ema**2, 0)), 1e-6)
        spike_thresh = loss_ema + 5 * loss_std

        if not math.isfinite(loss_val) or loss_val > max(spike_thresh, 0.5):
            sample_ids = [m.get("sample_id", "?") for m in meta]
            skip_reason = f"loss={loss_val:.2f} > thresh={spike_thresh:.2f}"
            print(f"Step {global_step}: {skip_reason}, skipping (samples: {sample_ids[:3]})")
            with open(out_dir / "skipped_samples.jsonl", "a") as f:
                f.write(json.dumps({"step": global_step, "reason": skip_reason,
                                    "samples": sample_ids}) + "\n")
            try:
                next_batch = next(data_iter)
            except StopIteration:
                data_iter = iter(train_loader)
                next_batch = next(data_iter)
            continue

        opt.zero_grad()
        loss.backward()

        try:
            next_batch = next(data_iter)
        except StopIteration:
            data_iter = iter(train_loader)
            next_batch = next(data_iter)

        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)

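        # LR schedule: linear warmup, then either cosine decay to lr*0.01, a linear
        # cooldown window starting at --cooldown-start, or a constant LR.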
        if global_step < args.warmup:
            lr = args.lr * (global_step + 1) / max(1, args.warmup)
        elif args.cosine_decay:
            progress = (global_step - args.warmup) / max(1, args.steps - args.warmup)
            lr = args.lr * (0.01 + 0.99 * 0.5 * (1 + math.cos(math.pi * progress)))
        elif args.cooldown_start > 0 and global_step >= args.cooldown_start:
            progress = (global_step - args.cooldown_start) / max(1, args.cooldown_steps)
            lr = args.lr * max(0.01, 1.0 - 0.99 * min(1.0, progress))
        else:
            lr = args.lr
        for pg in opt.param_groups:
            pg["lr"] = lr
        opt.step()
        global_step += 1

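        # EMA update: p_ema = decay * p_ema + (1 - decay) * p_model.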
        if ema_model is not None:
            decay = args.ema_decay
            with torch.no_grad():
                for p_ema, p_model in zip(ema_model.parameters(), model.parameters()):
                    p_ema.lerp_(p_model, 1.0 - decay)

| entry = {"step": global_step, "ts": time.time(), "loss": loss.item(), "lr": lr} |
| entry.update({k: v.item() for k, v in terms.items()}) |
| if global_step % log_int == 0: |
| grad_norm = sum(p.grad.norm().item()**2 for p in model.parameters() |
| if p.grad is not None) ** 0.5 |
| entry["grad_norm"] = grad_norm |
|
|
| if global_step % log_int == 0: |
| ms = (time.perf_counter() - t_start) / log_int * 1000 |
| t_start = time.perf_counter() |
| t_str = " ".join(f"{k}={v:.4f}" for k, v in terms.items()) |
| print(f"[{global_step}/{args.steps}] loss={loss.item():.4f} {t_str} " |
| f"lr={lr:.2e} gnorm={entry.get('grad_norm', 0):.3f} [{ms:.0f}ms/step]") |
|
|
        if val_int > 0 and global_step % val_int == 0:
            try:
                vl_list = []
                with torch.no_grad(), amp_ctx:
                    for vb in val_loader:
                        vt, vm, vg, vs, _ = build_tokens(vb, model, device)
                        vo = model.forward_tokens(vt, vm)
                        vl, _ = compute_loss(vo["segments"], vg, vs.to(device), device,
                                             args.varifold_weight, args.sinkhorn_weight)
                        if math.isfinite(vl.item()):
                            vl_list.append(vl.item())
                if vl_list:
                    val_loss = float(np.mean(vl_list))
                    print(f" val_loss={val_loss:.4f}")
                    entry["val_loss"] = val_loss
            except Exception as e:
                print(f" val eval failed: {e}")

| with open(out_dir / "history.jsonl", "a") as f: |
| f.write(json.dumps(entry) + "\n") |
|
|
        if global_step % ckpt_int == 0:
            try:
                gc.enable(); gc.collect(); gc.freeze(); gc.disable()
                torch.cuda.empty_cache()
                save_dict = {"step": global_step, "model": model.state_dict(),
                             "optimizer": opt.state_dict(), "args": vars(args)}
                if ema_model is not None:
                    save_dict["ema_model"] = ema_model.state_dict()
                torch.save(save_dict, out_dir / "checkpoints" / f"step{global_step:06d}.pt")
            except Exception as e:
                print(f" checkpoint save failed: {e}")

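    # Final checkpoint: model + optimizer state, plus the EMA weights if enabled.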
| save_dict = {"step": global_step, "model": model.state_dict(), |
| "optimizer": opt.state_dict(), "args": vars(args)} |
| if ema_model is not None: |
| save_dict["ema_model"] = ema_model.state_dict() |
| torch.save(save_dict, out_dir / "checkpoints" / "final.pt") |
| print(f"Done. {global_step} steps. Output: {out_dir}") |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|