guychuk commited on
Commit
bc4c081
·
verified ·
1 Parent(s): 7c60349

Add per-batch heartbeat logging + CUDA visibility log

Browse files
Files changed (1) hide show
  1. train_v2_prod.py +21 -3
train_v2_prod.py CHANGED
@@ -469,10 +469,13 @@ def get_last_pos(past_features, past_length):
469
  return past_features[torch.arange(B, device=past_features.device), idx, :3]
470
 
471
 
472
- def train_one_epoch(model, loader, optimizer, device, grad_clip=1.0):
 
473
  model.train()
474
  sums = {"nll": 0.0, "ade": 0.0, "jepa": 0.0, "total": 0.0, "n": 0}
475
- for batch in loader:
 
 
476
  past_f = batch["past_features"].to(device)
477
  past_l = batch["past_length"].to(device)
478
  target = batch["target_pos"].to(device)
@@ -492,6 +495,13 @@ def train_one_epoch(model, loader, optimizer, device, grad_clip=1.0):
492
  sums["jepa"] += losses["jepa"].item() * bs
493
  sums["total"] += losses["total"].item() * bs
494
  sums["n"] += bs
 
 
 
 
 
 
 
495
  n = max(sums["n"], 1)
496
  return {k: v / n for k, v in sums.items() if k != "n"} | {
497
  "ade_train": sums["ade"] / n
@@ -617,7 +627,15 @@ def main():
617
  np.random.seed(args.seed)
618
 
619
  device = "cuda" if torch.cuda.is_available() else "cpu"
620
- print(f"[v2] device={device} tag={args.tag} lambda_jepa={args.lambda_jepa}")
 
 
 
 
 
 
 
 
621
 
622
  if HAS_TRACKIO and args.trackio_name:
623
  trackio.init(project="flight-jepa-v2", name=args.trackio_name,
 
469
  return past_features[torch.arange(B, device=past_features.device), idx, :3]
470
 
471
 
472
+ def train_one_epoch(model, loader, optimizer, device, grad_clip=1.0,
473
+ log_every: int = 50):
474
  model.train()
475
  sums = {"nll": 0.0, "ade": 0.0, "jepa": 0.0, "total": 0.0, "n": 0}
476
+ t0 = time.time()
477
+ n_batches = len(loader) if hasattr(loader, "__len__") else 0
478
+ for bi, batch in enumerate(loader):
479
  past_f = batch["past_features"].to(device)
480
  past_l = batch["past_length"].to(device)
481
  target = batch["target_pos"].to(device)
 
495
  sums["jepa"] += losses["jepa"].item() * bs
496
  sums["total"] += losses["total"].item() * bs
497
  sums["n"] += bs
498
+
499
+ if (bi + 1) % log_every == 0 or bi == 0:
500
+ dt = time.time() - t0
501
+ rate = (bi + 1) / max(dt, 0.001)
502
+ print(f" [batch {bi+1}/{n_batches}] {dt:.1f}s elapsed, "
503
+ f"{rate:.1f} batch/s, loss={losses['total'].item():.4f}",
504
+ flush=True)
505
  n = max(sums["n"], 1)
506
  return {k: v / n for k, v in sums.items() if k != "n"} | {
507
  "ade_train": sums["ade"] / n
 
627
  np.random.seed(args.seed)
628
 
629
  device = "cuda" if torch.cuda.is_available() else "cpu"
630
+ print(f"[v2] device={device} tag={args.tag} lambda_jepa={args.lambda_jepa}",
631
+ flush=True)
632
+ if device == "cuda":
633
+ print(f"[v2] cuda device: {torch.cuda.get_device_name(0)} "
634
+ f"vram={torch.cuda.get_device_properties(0).total_memory/1e9:.1f}GB",
635
+ flush=True)
636
+ else:
637
+ print("[v2] WARNING: CUDA not available, training on CPU. "
638
+ "This will be very slow.", flush=True)
639
 
640
  if HAS_TRACKIO and args.trackio_name:
641
  trackio.init(project="flight-jepa-v2", name=args.trackio_name,