AlexWortega committed on
Commit 47f8396 · verified · 1 Parent(s): ccdcfe1

Upload train_jepa.py with huggingface_hub
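The commit message indicates the file was pushed with huggingface_hub. As a minimal sketch (not the exact command used here), an upload like this can be done with HfApi.upload_file; the repo id below is a placeholder since the target repository is not shown on this page:

from huggingface_hub import HfApi

api = HfApi()  # uses the token stored by `huggingface-cli login` by default
api.upload_file(
    path_or_fileobj="train_jepa.py",
    path_in_repo="train_jepa.py",
    repo_id="<user>/<repo>",  # placeholder, not taken from this commit
    commit_message="Upload train_jepa.py with huggingface_hub",
)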

Files changed (1)
  1. train_jepa.py +241 -0
train_jepa.py ADDED
@@ -0,0 +1,241 @@
#!/usr/bin/env python3
"""
Training script for Spatial JEPA on The Well datasets.

Usage:
python train_jepa.py --dataset turbulent_radiative_layer_2D --batch_size 16
python train_jepa.py --dataset active_matter --streaming --epochs 50
"""
import argparse
import logging
import math
import os
import time

import torch
import torch.nn as nn
from torch.amp import GradScaler, autocast
from tqdm import tqdm

from data_pipeline import create_dataloader, prepare_batch, get_channel_info
from jepa import JEPA

logging.basicConfig(level=logging.WARNING)  # suppress noisy library logs
logger = logging.getLogger("train_jepa")
logger.setLevel(logging.INFO)
_handler = logging.StreamHandler()
_handler.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] %(message)s", datefmt="%H:%M:%S"))
logger.addHandler(_handler)
logger.propagate = False


def cosine_lr(step, warmup, total, base_lr, min_lr=1e-6):
    if step < warmup:
        return base_lr * step / max(warmup, 1)
    progress = (step - warmup) / max(total - warmup, 1)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(progress * math.pi))


def cosine_ema(step, total, start=0.996, end=1.0):
    """EMA decay schedule: ramps from start to end over training."""
    progress = step / max(total, 1)
    return end - (end - start) * (1 + math.cos(progress * math.pi)) / 2


def train(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Device: {device}")

    # ---- Data ----
    logger.info(f"Loading dataset: {args.dataset} (streaming={args.streaming})")
    train_loader, train_dataset = create_dataloader(
        dataset_name=args.dataset,
        split="train",
        batch_size=args.batch_size,
        n_steps_input=args.n_input,
        n_steps_output=args.n_output,
        num_workers=args.workers,
        streaming=args.streaming,
        local_path=args.local_path,
    )

    ch_info = get_channel_info(train_dataset)
    logger.info(f"Channel info: {ch_info}")

    c_in = ch_info["input_channels"]
    c_out = ch_info["output_channels"]

    # JEPA uses same channel count for input and target
    # If they differ, we use max and pad in forward
    assert c_in == c_out, (
        f"JEPA expects same input/output channels, got {c_in} vs {c_out}. "
        "Set n_input == n_output or use different architecture."
    )

    # ---- Model ----
    model = JEPA(
        in_channels=c_in,
        latent_channels=args.latent_ch,
        base_ch=args.base_ch,
        pred_hidden=args.pred_hidden,
        ema_decay=args.ema_start,
    ).to(device)

    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"Trainable parameters: {n_params:,}")

    # ---- Optimizer ----
    # Only optimize online encoder + predictor (target is EMA)
    trainable = list(model.online_encoder.parameters()) + list(model.predictor.parameters())
    optimizer = torch.optim.AdamW(trainable, lr=args.lr, weight_decay=args.wd)
    scaler = GradScaler("cuda", enabled=args.amp)

    # ---- Resume ----
    start_epoch = 0
    global_step = 0
    if args.resume and os.path.exists(args.resume):
        ckpt = torch.load(args.resume, map_location=device, weights_only=False)
        model.load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optimizer"])
        scaler.load_state_dict(ckpt["scaler"])
        start_epoch = ckpt["epoch"] + 1
        global_step = ckpt["global_step"]
        logger.info(f"Resumed from epoch {start_epoch}, step {global_step}")

    # ---- Training ----
    os.makedirs(args.ckpt_dir, exist_ok=True)
    total_steps = args.epochs * len(train_loader)

    try:
        import wandb

        if args.wandb:
            wandb.init(project="the-well-jepa", config=vars(args))
    except ImportError:
        args.wandb = False

    logger.info(f"Starting training: {args.epochs} epochs, ~{total_steps} steps")

    for epoch in range(start_epoch, args.epochs):
        model.train()
        epoch_loss = 0.0
        epoch_metrics = {"sim": 0, "var": 0, "cov": 0}
        n_batches = 0
        t0 = time.time()

        pbar = tqdm(train_loader, desc=f"Epoch {epoch}", leave=False)
        for batch in pbar:
            try:
                x_input, x_target = prepare_batch(batch, device)
            except Exception as e:
                logger.warning(f"Batch error: {e}, skipping")
                continue

            # LR schedule
            lr = cosine_lr(global_step, args.warmup, total_steps, args.lr)
            for pg in optimizer.param_groups:
                pg["lr"] = lr

            # EMA schedule
            ema = cosine_ema(global_step, total_steps, args.ema_start, args.ema_end)
            model.set_ema_decay(ema)

            optimizer.zero_grad(set_to_none=True)

            with autocast(device_type="cuda", dtype=torch.bfloat16, enabled=args.amp):
                loss, metrics = model.compute_loss(x_input, x_target)

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(trainable, args.grad_clip)
            scaler.step(optimizer)
            scaler.update()

            # EMA update
            model.update_target()

            epoch_loss += loss.item()
            for k in epoch_metrics:
                epoch_metrics[k] += metrics[k]
            n_batches += 1
            global_step += 1

            pbar.set_postfix(
                loss=f"{loss.item():.4f}",
                sim=f"{metrics['sim']:.4f}",
                ema=f"{ema:.4f}",
            )

            if args.wandb:
                wandb.log(
                    {"train/loss": loss.item(), "train/lr": lr, "train/ema": ema, **{f"train/{k}": v for k, v in metrics.items()}},
                    step=global_step,
                )

        avg_loss = epoch_loss / max(n_batches, 1)
        avg_m = {k: v / max(n_batches, 1) for k, v in epoch_metrics.items()}
        elapsed = time.time() - t0
        logger.info(
            f"Epoch {epoch}: loss={avg_loss:.4f}, sim={avg_m['sim']:.4f}, "
            f"var={avg_m['var']:.4f}, cov={avg_m['cov']:.4f}, "
            f"time={elapsed:.1f}s"
        )

        # Checkpoint
        if (epoch + 1) % args.save_every == 0 or epoch == args.epochs - 1:
            ckpt_path = os.path.join(args.ckpt_dir, f"jepa_ep{epoch:04d}.pt")
            torch.save(
                {
                    "epoch": epoch,
                    "global_step": global_step,
                    "model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "scaler": scaler.state_dict(),
                    "args": vars(args),
                    "ch_info": ch_info,
                },
                ckpt_path,
            )
            logger.info(f"Saved {ckpt_path}")

    logger.info("Training complete.")


def main():
    p = argparse.ArgumentParser(description="Train Spatial JEPA on The Well")
    # Data
    p.add_argument("--dataset", default="turbulent_radiative_layer_2D")
    p.add_argument("--streaming", action="store_true", default=True)
    p.add_argument("--no-streaming", dest="streaming", action="store_false")
    p.add_argument("--local_path", default=None)
    p.add_argument("--batch_size", type=int, default=16)
    p.add_argument("--workers", type=int, default=0)
    p.add_argument("--n_input", type=int, default=1)
    p.add_argument("--n_output", type=int, default=1)
    # Model
    p.add_argument("--latent_ch", type=int, default=128)
    p.add_argument("--base_ch", type=int, default=32)
    p.add_argument("--pred_hidden", type=int, default=256)
    # Optimization
    p.add_argument("--lr", type=float, default=3e-4)
    p.add_argument("--wd", type=float, default=0.05)
    p.add_argument("--warmup", type=int, default=500)
    p.add_argument("--grad_clip", type=float, default=1.0)
    p.add_argument("--amp", action="store_true", default=True)
    p.add_argument("--no-amp", dest="amp", action="store_false")
    p.add_argument("--epochs", type=int, default=100)
    p.add_argument("--ema_start", type=float, default=0.996)
    p.add_argument("--ema_end", type=float, default=1.0)
    # Checkpointing
    p.add_argument("--ckpt_dir", default="checkpoints/jepa")
    p.add_argument("--save_every", type=int, default=5)
    p.add_argument("--resume", default=None)
    # Logging
    p.add_argument("--wandb", action="store_true", default=False)

    args = p.parse_args()
    train(args)


if __name__ == "__main__":
    main()
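For orientation, a quick standalone sanity check of the two schedules defined above. This is a sketch, not part of the uploaded file: it assumes it is run from the repo root so that train_jepa.py and the sibling modules it imports (data_pipeline, jepa) are importable, and the step counts are arbitrary example values:

from train_jepa import cosine_lr, cosine_ema

total, warmup = 10_000, 500
for step in (0, warmup, total // 2, total):
    lr = cosine_lr(step, warmup, total, base_lr=3e-4)
    m = cosine_ema(step, total, start=0.996, end=1.0)
    print(f"step={step:>6d}  lr={lr:.2e}  ema={m:.4f}")

# Expected shape: lr warms up linearly from 0 to base_lr (3e-4) over the first
# `warmup` steps, then follows a cosine decay down to min_lr (1e-6) at the last
# step; the EMA decay ramps from 0.996 at step 0 to 1.0 at the end of training.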