rogermt committed
Commit 8b62ba9 · verified · 1 Parent(s): d5802bb

Add checkpointing, resume, CIFAR OOM fix, --sinkhorn-batch flag


Major changes:
- trainer.py: Checkpoints saved every N steps + after each phase completion.
Phase-level checkpoints (phase1_complete.pt, phase2_complete.pt, etc.)
enable resuming from any phase.
- main.py: --resume-phase N loads phase N-1 checkpoint and skips completed phases.
--sinkhorn-batch overrides Sinkhorn batch size (for OOM on smaller GPUs).
--checkpoint-dir and --checkpoint-every for checkpoint control.
- config.yaml: CIFAR-10 sinkhorn.batch_size reduced 128→32 for T4 16GB VRAM.
Pool batches increased 2500→10000 to compensate (same total pool entries).
  torch.cuda.empty_cache() called after pool building to free Sinkhorn memory.
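
For context, a minimal sketch of the resume path this commit enables. The main.py wiring itself is not part of this diff; the model, data loader, config, and device objects below are assumed to be constructed as usual, so this is illustrative only:

    # Hypothetical usage: resume from the end of Phase 1, mirroring what
    # `--resume-phase 2` is described to do in the commit message above.
    import torch
    from trainer import NSGFPlusPlusTrainer

    ckpt = torch.load("checkpoints/phase1_complete.pt", map_location=device)
    nsgf_model.load_state_dict(ckpt["nsgf_model_state"])  # key written by _save_checkpoint

    trainer = NSGFPlusPlusTrainer(
        nsgf_model, nsf_model, phase_predictor,
        data_loader, config, device=device,
        checkpoint_dir="checkpoints",
    )
    results = trainer.train_all(resume_phase=2)  # Phase 1 skipped; NSF + predictor trained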

Files changed (1)
  trainer.py  +130 -46
trainer.py CHANGED
@@ -6,6 +6,7 @@ Implements:
 3. NSF (Neural Straight Flow) training for NSGF++
 4. Phase-transition time predictor training
 5. End-to-end NSGF++ training pipeline
+6. Checkpointing and resume support
 
 Reference: arXiv:2401.14069, Section 4.2–4.4, Appendix D, E
 """
@@ -26,17 +27,26 @@ from model import VelocityMLP, VelocityUNet, PhaseTransitionPredictor
 logger = logging.getLogger(__name__)
 
 
+def _save_checkpoint(path: str, **kwargs):
+    """Save a checkpoint dict to disk."""
+    os.makedirs(os.path.dirname(path) if os.path.dirname(path) else ".", exist_ok=True)
+    torch.save(kwargs, path)
+    logger.info(f"Checkpoint saved: {path}")
+
+
 class NSGFTrainer:
     """Trainer for the Neural Sinkhorn Gradient Flow model.
 
     Loss (Eq. 14): L(θ) = E_{(x,v,t) ~ pool} ||v_θ(x, t) - v̂(x)||²
     """
     def __init__(self, model: nn.Module, data_loader: DatasetLoader,
-                 config: dict, device: str = "cpu"):
+                 config: dict, device: str = "cpu",
+                 checkpoint_dir: str = "checkpoints"):
         self.model = model.to(device)
         self.data_loader = data_loader
         self.config = config
         self.device = device
+        self.checkpoint_dir = checkpoint_dir
 
         sink_cfg = config.get("sinkhorn", {})
         self.potential_computer = SinkhornPotentialComputer(
@@ -57,6 +67,7 @@ class NSGFTrainer:
             betas=(train_cfg.get("beta1", 0.9), train_cfg.get("beta2", 0.999)),
             weight_decay=train_cfg.get("weight_decay", 0.0),
         )
+        self.checkpoint_every = config.get("checkpoint_every", 5000)
 
     def build_trajectory_pool(self, num_batches: Optional[int] = None):
         if num_batches is None:
@@ -74,15 +85,18 @@
             if (batch_idx + 1) % max(1, num_batches // 10) == 0:
                 logger.info(f"  Pool building: {batch_idx + 1}/{num_batches}, pool size: {len(self.pool)}")
         logger.info(f"Trajectory pool built. Total entries: {len(self.pool)}")
+        # Free GPU memory used during Sinkhorn computation
+        if self.device != "cpu":
+            torch.cuda.empty_cache()
         # Pre-concatenate for O(1) sampling during training
         self.pool.finalize()
         logger.info("Trajectory pool finalized (pre-concatenated for fast sampling).")
 
-    def train(self) -> Dict[str, list]:
+    def train(self, start_step: int = 0) -> Dict[str, list]:
         self.model.train()
         history = {"loss": [], "step": []}
-        logger.info(f"Starting NSGF velocity field matching: {self.num_iterations} iterations")
-        for step in range(self.num_iterations):
+        logger.info(f"Starting NSGF velocity field matching: {self.num_iterations} iterations (from step {start_step})")
+        for step in range(start_step, self.num_iterations):
             x_batch, v_batch, t_batch = self.pool.sample(self.train_batch_size, self.device)
             t_normalized = t_batch / max(self.gradient_flow.num_steps, 1.0)
             v_pred = self.model(x_batch, t_normalized)
@@ -90,11 +104,19 @@
             self.optimizer.zero_grad()
             loss.backward()
             self.optimizer.step()
-            if (step + 1) % 500 == 0 or step == 0:
+            if (step + 1) % 500 == 0 or step == start_step:
                 loss_val = loss.item()
                 history["loss"].append(loss_val)
                 history["step"].append(step + 1)
                 logger.info(f"  Step {step + 1}/{self.num_iterations}, Loss: {loss_val:.6f}")
+            if (step + 1) % self.checkpoint_every == 0:
+                _save_checkpoint(
+                    os.path.join(self.checkpoint_dir, "nsgf_checkpoint.pt"),
+                    model_state=self.model.state_dict(),
+                    optimizer_state=self.optimizer.state_dict(),
+                    step=step + 1,
+                    history=history,
+                )
         logger.info("NSGF training complete.")
         return history
 
@@ -106,7 +128,8 @@ class NSFTrainer:
     """
     def __init__(self, model: nn.Module, nsgf_model: nn.Module,
                  data_loader: DatasetLoader, config: dict,
-                 nsgf_num_steps: int = 5, device: str = "cpu"):
+                 nsgf_num_steps: int = 5, device: str = "cpu",
+                 checkpoint_dir: str = "checkpoints"):
         self.model = model.to(device)
         self.nsgf_model = nsgf_model.to(device)
         self.nsgf_model.eval()
@@ -114,6 +137,7 @@ class NSFTrainer:
         self.config = config
         self.device = device
         self.nsgf_num_steps = nsgf_num_steps
+        self.checkpoint_dir = checkpoint_dir
 
         train_cfg = config.get("nsf_training", config.get("training", {}))
         self.num_iterations = train_cfg.get("num_iterations", 100000)
@@ -124,6 +148,7 @@ class NSFTrainer:
             betas=(train_cfg.get("beta1", 0.9), train_cfg.get("beta2", 0.999)),
             weight_decay=train_cfg.get("weight_decay", 0.0),
         )
+        self.checkpoint_every = config.get("checkpoint_every", 5000)
 
     @torch.no_grad()
     def _generate_nsgf_samples(self, n: int) -> torch.Tensor:
@@ -135,11 +160,11 @@ class NSFTrainer:
             X = X + dt * v
         return X
 
-    def train(self) -> Dict[str, list]:
+    def train(self, start_step: int = 0) -> Dict[str, list]:
         self.model.train()
         history = {"loss": [], "step": []}
-        logger.info(f"Starting NSF training: {self.num_iterations} iterations")
-        for step in range(self.num_iterations):
+        logger.info(f"Starting NSF training: {self.num_iterations} iterations (from step {start_step})")
+        for step in range(start_step, self.num_iterations):
             P0 = self._generate_nsgf_samples(self.train_batch_size)
             P1 = self.data_loader.sample_target(self.train_batch_size, self.device)
             t = torch.rand(self.train_batch_size, device=self.device)
@@ -154,11 +179,19 @@ class NSFTrainer:
             self.optimizer.zero_grad()
             loss.backward()
             self.optimizer.step()
-            if (step + 1) % 500 == 0 or step == 0:
+            if (step + 1) % 500 == 0 or step == start_step:
                 loss_val = loss.item()
                 history["loss"].append(loss_val)
                 history["step"].append(step + 1)
                 logger.info(f"  Step {step + 1}/{self.num_iterations}, Loss: {loss_val:.6f}")
+            if (step + 1) % self.checkpoint_every == 0:
+                _save_checkpoint(
+                    os.path.join(self.checkpoint_dir, "nsf_checkpoint.pt"),
+                    model_state=self.model.state_dict(),
+                    optimizer_state=self.optimizer.state_dict(),
+                    step=step + 1,
+                    history=history,
+                )
         logger.info("NSF training complete.")
         return history
 
@@ -169,7 +202,8 @@ class PhaseTransitionTrainer:
     """
     def __init__(self, predictor: PhaseTransitionPredictor, nsgf_model: nn.Module,
                  data_loader: DatasetLoader, config: dict,
-                 nsgf_num_steps: int = 5, device: str = "cpu"):
+                 nsgf_num_steps: int = 5, device: str = "cpu",
+                 checkpoint_dir: str = "checkpoints"):
         self.predictor = predictor.to(device)
         self.nsgf_model = nsgf_model.to(device)
         self.nsgf_model.eval()
@@ -177,11 +211,13 @@ class PhaseTransitionTrainer:
         self.config = config
         self.device = device
         self.nsgf_num_steps = nsgf_num_steps
+        self.checkpoint_dir = checkpoint_dir
         tp_cfg = config.get("time_predictor", {})
         self.num_iterations = tp_cfg.get("num_iterations", 40000)
         self.batch_size = tp_cfg.get("batch_size", 128)
         self.lr = tp_cfg.get("learning_rate", 1e-4)
         self.optimizer = optim.Adam(self.predictor.parameters(), lr=self.lr, betas=(0.9, 0.999))
+        self.checkpoint_every = config.get("checkpoint_every", 5000)
 
     @torch.no_grad()
     def _generate_nsgf_samples(self, n: int) -> torch.Tensor:
@@ -193,11 +229,11 @@ class PhaseTransitionTrainer:
             X = X + dt * v
         return X
 
-    def train(self) -> Dict[str, list]:
+    def train(self, start_step: int = 0) -> Dict[str, list]:
         self.predictor.train()
         history = {"loss": [], "step": []}
-        logger.info(f"Starting phase predictor training: {self.num_iterations} iterations")
-        for step in range(self.num_iterations):
+        logger.info(f"Starting phase predictor training: {self.num_iterations} iterations (from step {start_step})")
+        for step in range(start_step, self.num_iterations):
             P0 = self._generate_nsgf_samples(self.batch_size)
             P1 = self.data_loader.sample_target(self.batch_size, self.device)
             t = torch.rand(self.batch_size, device=self.device)
@@ -211,59 +247,107 @@ class PhaseTransitionTrainer:
             self.optimizer.zero_grad()
             loss.backward()
             self.optimizer.step()
-            if (step + 1) % 1000 == 0 or step == 0:
+            if (step + 1) % 1000 == 0 or step == start_step:
                 loss_val = loss.item()
                 history["loss"].append(loss_val)
                 history["step"].append(step + 1)
                 logger.info(f"  Step {step + 1}/{self.num_iterations}, Loss: {loss_val:.6f}")
+            if (step + 1) % self.checkpoint_every == 0:
+                _save_checkpoint(
+                    os.path.join(self.checkpoint_dir, "predictor_checkpoint.pt"),
+                    model_state=self.predictor.state_dict(),
+                    optimizer_state=self.optimizer.state_dict(),
+                    step=step + 1,
+                    history=history,
+                )
         logger.info("Phase predictor training complete.")
         return history
 
 
 class NSGFPlusPlusTrainer:
-    """End-to-end NSGF++ trainer (Algorithm 3 / Appendix D)."""
+    """End-to-end NSGF++ trainer (Algorithm 3 / Appendix D).
+
+    Saves checkpoints after each phase so training can be resumed.
+    """
     def __init__(self, nsgf_model: nn.Module, nsf_model: nn.Module,
                  phase_predictor: PhaseTransitionPredictor,
-                 data_loader: DatasetLoader, config: dict, device: str = "cpu"):
+                 data_loader: DatasetLoader, config: dict, device: str = "cpu",
+                 checkpoint_dir: str = "checkpoints"):
         self.nsgf_model = nsgf_model
         self.nsf_model = nsf_model
         self.phase_predictor = phase_predictor
         self.data_loader = data_loader
         self.config = config
         self.device = device
+        self.checkpoint_dir = checkpoint_dir
 
-    def train_all(self) -> Dict[str, Any]:
+    def train_all(self, resume_phase: int = 1) -> Dict[str, Any]:
+        """Train all phases. resume_phase: 1=start from NSGF, 2=skip to NSF, 3=skip to predictor."""
         results = {}
+        os.makedirs(self.checkpoint_dir, exist_ok=True)
 
-        logger.info("=" * 60)
-        logger.info("Phase 1: Training NSGF model")
-        logger.info("=" * 60)
-        nsgf_trainer = NSGFTrainer(
-            model=self.nsgf_model, data_loader=self.data_loader,
-            config=self.config, device=self.device,
-        )
-        nsgf_trainer.build_trajectory_pool()
-        results["nsgf"] = nsgf_trainer.train()
+        if resume_phase <= 1:
+            logger.info("=" * 60)
+            logger.info("Phase 1: Training NSGF model")
+            logger.info("=" * 60)
+            nsgf_trainer = NSGFTrainer(
+                model=self.nsgf_model, data_loader=self.data_loader,
+                config=self.config, device=self.device,
+                checkpoint_dir=self.checkpoint_dir,
+            )
+            nsgf_trainer.build_trajectory_pool()
+            results["nsgf"] = nsgf_trainer.train()
+            _save_checkpoint(
+                os.path.join(self.checkpoint_dir, "phase1_complete.pt"),
+                nsgf_model_state=self.nsgf_model.state_dict(),
+                phase=1,
+            )
+            del nsgf_trainer.pool
+            if self.device != "cpu":
+                torch.cuda.empty_cache()
+        else:
+            logger.info(f"Skipping Phase 1 (resuming from phase {resume_phase})")
 
-        logger.info("=" * 60)
-        logger.info("Phase 2: Training NSF (Neural Straight Flow) model")
-        logger.info("=" * 60)
-        nsgf_steps = self.config.get("sinkhorn", {}).get("num_steps", 5)
-        nsf_trainer = NSFTrainer(
-            model=self.nsf_model, nsgf_model=self.nsgf_model,
-            data_loader=self.data_loader, config=self.config,
-            nsgf_num_steps=nsgf_steps, device=self.device,
-        )
-        results["nsf"] = nsf_trainer.train()
+        if resume_phase <= 2:
+            logger.info("=" * 60)
+            logger.info("Phase 2: Training NSF (Neural Straight Flow) model")
+            logger.info("=" * 60)
+            nsgf_steps = self.config.get("sinkhorn", {}).get("num_steps", 5)
+            nsf_trainer = NSFTrainer(
+                model=self.nsf_model, nsgf_model=self.nsgf_model,
+                data_loader=self.data_loader, config=self.config,
+                nsgf_num_steps=nsgf_steps, device=self.device,
+                checkpoint_dir=self.checkpoint_dir,
+            )
+            results["nsf"] = nsf_trainer.train()
+            _save_checkpoint(
+                os.path.join(self.checkpoint_dir, "phase2_complete.pt"),
+                nsgf_model_state=self.nsgf_model.state_dict(),
+                nsf_model_state=self.nsf_model.state_dict(),
+                phase=2,
+            )
+        else:
+            logger.info(f"Skipping Phase 2 (resuming from phase {resume_phase})")
 
-        logger.info("=" * 60)
-        logger.info("Phase 3: Training phase-transition time predictor")
-        logger.info("=" * 60)
-        pt_trainer = PhaseTransitionTrainer(
-            predictor=self.phase_predictor, nsgf_model=self.nsgf_model,
-            data_loader=self.data_loader, config=self.config,
-            nsgf_num_steps=nsgf_steps, device=self.device,
-        )
-        results["phase_predictor"] = pt_trainer.train()
+        if resume_phase <= 3:
+            logger.info("=" * 60)
+            logger.info("Phase 3: Training phase-transition time predictor")
+            logger.info("=" * 60)
+            nsgf_steps = self.config.get("sinkhorn", {}).get("num_steps", 5)
+            pt_trainer = PhaseTransitionTrainer(
+                predictor=self.phase_predictor, nsgf_model=self.nsgf_model,
+                data_loader=self.data_loader, config=self.config,
+                nsgf_num_steps=nsgf_steps, device=self.device,
+                checkpoint_dir=self.checkpoint_dir,
+            )
+            results["phase_predictor"] = pt_trainer.train()
+            _save_checkpoint(
+                os.path.join(self.checkpoint_dir, "phase3_complete.pt"),
+                nsgf_model_state=self.nsgf_model.state_dict(),
+                nsf_model_state=self.nsf_model.state_dict(),
+                predictor_state=self.phase_predictor.state_dict(),
+                phase=3,
+            )
 
         logger.info("=" * 60)
         logger.info("NSGF++ training complete!")