Add per-batch heartbeat logging + CUDA visibility log
train_v2_prod.py (+21 −3)
train_v2_prod.py
CHANGED
|
@@ -469,10 +469,13 @@ def get_last_pos(past_features, past_length):
|
|
| 469 |
return past_features[torch.arange(B, device=past_features.device), idx, :3]
|
| 470 |
|
| 471 |
|
| 472 |
-
def train_one_epoch(model, loader, optimizer, device, grad_clip=1.0
|
|
|
|
| 473 |
model.train()
|
| 474 |
sums = {"nll": 0.0, "ade": 0.0, "jepa": 0.0, "total": 0.0, "n": 0}
|
| 475 |
-
|
|
|
|
|
|
|
| 476 |
past_f = batch["past_features"].to(device)
|
| 477 |
past_l = batch["past_length"].to(device)
|
| 478 |
target = batch["target_pos"].to(device)
|
|
@@ -492,6 +495,13 @@ def train_one_epoch(model, loader, optimizer, device, grad_clip=1.0):
|
|
| 492 |
sums["jepa"] += losses["jepa"].item() * bs
|
| 493 |
sums["total"] += losses["total"].item() * bs
|
| 494 |
sums["n"] += bs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 495 |
n = max(sums["n"], 1)
|
| 496 |
return {k: v / n for k, v in sums.items() if k != "n"} | {
|
| 497 |
"ade_train": sums["ade"] / n
|
|
@@ -617,7 +627,15 @@ def main():
|
|
| 617 |
np.random.seed(args.seed)
|
| 618 |
|
| 619 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 620 |
-
print(f"[v2] device={device} tag={args.tag} lambda_jepa={args.lambda_jepa}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 621 |
|
| 622 |
if HAS_TRACKIO and args.trackio_name:
|
| 623 |
trackio.init(project="flight-jepa-v2", name=args.trackio_name,
|
|
|
|
| 469 |
return past_features[torch.arange(B, device=past_features.device), idx, :3]
|
| 470 |
|
| 471 |
|
| 472 |
+
def train_one_epoch(model, loader, optimizer, device, grad_clip=1.0,
|
| 473 |
+
log_every: int = 50):
|
| 474 |
model.train()
|
| 475 |
sums = {"nll": 0.0, "ade": 0.0, "jepa": 0.0, "total": 0.0, "n": 0}
|
| 476 |
+
t0 = time.time()
|
| 477 |
+
n_batches = len(loader) if hasattr(loader, "__len__") else 0
|
| 478 |
+
for bi, batch in enumerate(loader):
|
| 479 |
past_f = batch["past_features"].to(device)
|
| 480 |
past_l = batch["past_length"].to(device)
|
| 481 |
target = batch["target_pos"].to(device)
|
|
|
|
| 495 |
sums["jepa"] += losses["jepa"].item() * bs
|
| 496 |
sums["total"] += losses["total"].item() * bs
|
| 497 |
sums["n"] += bs
|
| 498 |
+
|
| 499 |
+
if (bi + 1) % log_every == 0 or bi == 0:
|
| 500 |
+
dt = time.time() - t0
|
| 501 |
+
rate = (bi + 1) / max(dt, 0.001)
|
| 502 |
+
print(f" [batch {bi+1}/{n_batches}] {dt:.1f}s elapsed, "
|
| 503 |
+
f"{rate:.1f} batch/s, loss={losses['total'].item():.4f}",
|
| 504 |
+
flush=True)
|
| 505 |
n = max(sums["n"], 1)
|
| 506 |
return {k: v / n for k, v in sums.items() if k != "n"} | {
|
| 507 |
"ade_train": sums["ade"] / n
|
|
|
|
| 627 |
np.random.seed(args.seed)
|
| 628 |
|
| 629 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 630 |
+
print(f"[v2] device={device} tag={args.tag} lambda_jepa={args.lambda_jepa}",
|
| 631 |
+
flush=True)
|
| 632 |
+
if device == "cuda":
|
| 633 |
+
print(f"[v2] cuda device: {torch.cuda.get_device_name(0)} "
|
| 634 |
+
f"vram={torch.cuda.get_device_properties(0).total_memory/1e9:.1f}GB",
|
| 635 |
+
flush=True)
|
| 636 |
+
else:
|
| 637 |
+
print("[v2] WARNING: CUDA not available, training on CPU. "
|
| 638 |
+
"This will be very slow.", flush=True)
|
| 639 |
|
| 640 |
if HAS_TRACKIO and args.trackio_name:
|
| 641 |
trackio.init(project="flight-jepa-v2", name=args.trackio_name,
|