rogermt committed on
Commit
a365009
·
verified ·
1 Parent(s): 8b62ba9

main.py: --resume-phase, --checkpoint-dir, --sinkhorn-batch flags

Browse files
Files changed (1) hide show
  1. main.py +43 -17
main.py CHANGED
@@ -8,11 +8,12 @@ Orchestrates the full experiment pipeline:
8
  5. Evaluate (W2 for 2D, FID/IS for images)
9
  6. Visualize results
10
 
 
 
11
  Usage:
12
  python main.py --experiment 2d --dataset 8gaussians --steps 10
13
- python main.py --experiment 2d --dataset 8gaussians --steps 10 --device cuda
14
- python main.py --experiment 2d --dataset 8gaussians --steps 10 --device cpu
15
  python main.py --experiment mnist --device cuda
 
16
  python main.py --experiment cifar10 --device cuda
17
 
18
  Reference: arXiv:2401.14069 (Neural Sinkhorn Gradient Flow)
@@ -93,12 +94,10 @@ def run_2d_experiment(config: dict, args):
93
  data_loader=data_loader,
94
  config=config,
95
  device=device,
 
96
  )
97
 
98
- # Build trajectory pool
99
  trainer.build_trajectory_pool()
100
-
101
- # Train velocity field
102
  history = trainer.train()
103
 
104
  train_time = time.time() - start_time
@@ -116,8 +115,6 @@ def run_2d_experiment(config: dict, args):
116
  )
117
 
118
  samples = sampler.sample(num_eval)
119
-
120
- # Also get trajectory for visualization
121
  trajectory = sampler.sample_trajectory(min(200, num_eval))
122
 
123
  # ---- Evaluation ----
@@ -139,17 +136,14 @@ def run_2d_experiment(config: dict, args):
139
  title=f"NSGF — {config['dataset']} ({num_steps} steps), W2={metrics.get('w2', 0):.4f}",
140
  save_path=f"results/nsgf_2d_{config['dataset']}_{num_steps}steps.png",
141
  )
142
-
143
  plot_2d_trajectory(
144
  trajectory, test_samples,
145
  title=f"NSGF Trajectory — {config['dataset']}",
146
  save_path=f"results/nsgf_trajectory_{config['dataset']}_{num_steps}steps.png",
147
  )
148
 
149
- # Save model
150
  torch.save(model.state_dict(), f"results/nsgf_2d_{config['dataset']}.pt")
151
  logger.info("Model saved.")
152
-
153
  return metrics
154
 
155
 
@@ -159,6 +153,8 @@ def run_image_experiment(config: dict, args, dataset_name: str):
159
  Reference: Section 5.2, Appendix E.2
160
  """
161
  device = get_device(args)
 
 
162
  logger.info(f"Running {dataset_name.upper()} experiment on {device}")
163
 
164
  # Override from args
@@ -168,6 +164,11 @@ def run_image_experiment(config: dict, args, dataset_name: str):
168
  config["nsgf_training"]["num_iterations"] = args.train_iters
169
  config["nsf_training"]["num_iterations"] = args.train_iters
170
  config["time_predictor"]["num_iterations"] = args.train_iters
 
 
 
 
 
171
 
172
  # Setup
173
  data_loader = DatasetLoader(config)
@@ -177,6 +178,25 @@ def run_image_experiment(config: dict, args, dataset_name: str):
177
  nsf_model = create_velocity_unet(config)
178
  phase_predictor = create_phase_predictor(config)
179
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
  logger.info(f"NSGF UNet params: {sum(p.numel() for p in nsgf_model.parameters()):,}")
181
  logger.info(f"NSF UNet params: {sum(p.numel() for p in nsf_model.parameters()):,}")
182
  logger.info(f"Phase predictor params: {sum(p.numel() for p in phase_predictor.parameters()):,}")
@@ -191,9 +211,10 @@ def run_image_experiment(config: dict, args, dataset_name: str):
191
  data_loader=data_loader,
192
  config=config,
193
  device=device,
 
194
  )
195
 
196
- results = pp_trainer.train_all()
197
 
198
  train_time = time.time() - start_time
199
  logger.info(f"Training completed in {train_time:.1f}s")
@@ -215,7 +236,6 @@ def run_image_experiment(config: dict, args, dataset_name: str):
215
  )
216
 
217
  logger.info(f"Generating {num_gen} samples...")
218
- # Generate in batches to avoid OOM
219
  batch_size = 128
220
  all_samples = []
221
  for i in range(0, num_gen, batch_size):
@@ -225,10 +245,7 @@ def run_image_experiment(config: dict, args, dataset_name: str):
225
  generated = torch.cat(all_samples, dim=0)
226
 
227
  # ---- Evaluation ----
228
- # Get test set
229
- eval_loader = data_loader
230
- test_samples = eval_loader.get_test_samples(num_gen, device="cpu")
231
-
232
  evaluator = Evaluation(config, device)
233
  metrics = evaluator.evaluate(generated, test_samples)
234
 
@@ -248,7 +265,7 @@ def run_image_experiment(config: dict, args, dataset_name: str):
248
  save_path=f"results/nsgf_pp_{dataset_name}_samples.png",
249
  )
250
 
251
- # Save models
252
  torch.save(nsgf_model.state_dict(), f"results/nsgf_{dataset_name}_nsgf.pt")
253
  torch.save(nsf_model.state_dict(), f"results/nsgf_{dataset_name}_nsf.pt")
254
  torch.save(phase_predictor.state_dict(), f"results/nsgf_{dataset_name}_predictor.pt")
@@ -268,11 +285,20 @@ def main():
268
  parser.add_argument("--steps", type=int, default=None, help="Number of flow steps")
269
  parser.add_argument("--pool-batches", type=int, default=None, help="Pool building batches")
270
  parser.add_argument("--train-iters", type=int, default=None, help="Training iterations")
 
 
271
  parser.add_argument("--config", type=str, default="config.yaml", help="Config file path")
272
  parser.add_argument("--seed", type=int, default=42, help="Random seed")
273
  parser.add_argument("--device", type=str, default=None,
274
  choices=["cpu", "cuda"],
275
  help="Force device (default: auto-detect)")
 
 
 
 
 
 
 
276
 
277
  args = parser.parse_args()
278
 
 
8
  5. Evaluate (W2 for 2D, FID/IS for images)
9
  6. Visualize results
10
 
11
+ Supports --resume-phase to continue from a checkpoint after interruption.
12
+
13
  Usage:
14
  python main.py --experiment 2d --dataset 8gaussians --steps 10
 
 
15
  python main.py --experiment mnist --device cuda
16
+ python main.py --experiment mnist --resume-phase 2 # skip Phase 1, load checkpoint
17
  python main.py --experiment cifar10 --device cuda
18
 
19
  Reference: arXiv:2401.14069 (Neural Sinkhorn Gradient Flow)
 
94
  data_loader=data_loader,
95
  config=config,
96
  device=device,
97
+ checkpoint_dir=args.checkpoint_dir,
98
  )
99
 
 
100
  trainer.build_trajectory_pool()
 
 
101
  history = trainer.train()
102
 
103
  train_time = time.time() - start_time
 
115
  )
116
 
117
  samples = sampler.sample(num_eval)
 
 
118
  trajectory = sampler.sample_trajectory(min(200, num_eval))
119
 
120
  # ---- Evaluation ----
 
136
  title=f"NSGF — {config['dataset']} ({num_steps} steps), W2={metrics.get('w2', 0):.4f}",
137
  save_path=f"results/nsgf_2d_{config['dataset']}_{num_steps}steps.png",
138
  )
 
139
  plot_2d_trajectory(
140
  trajectory, test_samples,
141
  title=f"NSGF Trajectory — {config['dataset']}",
142
  save_path=f"results/nsgf_trajectory_{config['dataset']}_{num_steps}steps.png",
143
  )
144
 
 
145
  torch.save(model.state_dict(), f"results/nsgf_2d_{config['dataset']}.pt")
146
  logger.info("Model saved.")
 
147
  return metrics
148
 
149
 
 
153
  Reference: Section 5.2, Appendix E.2
154
  """
155
  device = get_device(args)
156
+ checkpoint_dir = args.checkpoint_dir
157
+ resume_phase = args.resume_phase
158
  logger.info(f"Running {dataset_name.upper()} experiment on {device}")
159
 
160
  # Override from args
 
164
  config["nsgf_training"]["num_iterations"] = args.train_iters
165
  config["nsf_training"]["num_iterations"] = args.train_iters
166
  config["time_predictor"]["num_iterations"] = args.train_iters
167
+ if args.sinkhorn_batch:
168
+ config["sinkhorn"]["batch_size"] = args.sinkhorn_batch
169
+
170
+ # Inject checkpoint_every into config for trainers to read
171
+ config["checkpoint_every"] = args.checkpoint_every
172
 
173
  # Setup
174
  data_loader = DatasetLoader(config)
 
178
  nsf_model = create_velocity_unet(config)
179
  phase_predictor = create_phase_predictor(config)
180
 
181
+ # Load checkpoints if resuming
182
+ if resume_phase > 1:
183
+ ckpt_path = os.path.join(checkpoint_dir, f"phase{resume_phase - 1}_complete.pt")
184
+ if os.path.exists(ckpt_path):
185
+ ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
186
+ if "nsgf_model_state" in ckpt:
187
+ nsgf_model.load_state_dict(ckpt["nsgf_model_state"])
188
+ logger.info(f"Loaded NSGF model from {ckpt_path}")
189
+ if "nsf_model_state" in ckpt:
190
+ nsf_model.load_state_dict(ckpt["nsf_model_state"])
191
+ logger.info(f"Loaded NSF model from {ckpt_path}")
192
+ if "predictor_state" in ckpt:
193
+ phase_predictor.load_state_dict(ckpt["predictor_state"])
194
+ logger.info(f"Loaded phase predictor from {ckpt_path}")
195
+ else:
196
+ logger.error(f"Checkpoint not found: {ckpt_path}")
197
+ logger.error(f"Cannot resume from phase {resume_phase} without phase {resume_phase - 1} checkpoint.")
198
+ sys.exit(1)
199
+
200
  logger.info(f"NSGF UNet params: {sum(p.numel() for p in nsgf_model.parameters()):,}")
201
  logger.info(f"NSF UNet params: {sum(p.numel() for p in nsf_model.parameters()):,}")
202
  logger.info(f"Phase predictor params: {sum(p.numel() for p in phase_predictor.parameters()):,}")
 
211
  data_loader=data_loader,
212
  config=config,
213
  device=device,
214
+ checkpoint_dir=checkpoint_dir,
215
  )
216
 
217
+ results = pp_trainer.train_all(resume_phase=resume_phase)
218
 
219
  train_time = time.time() - start_time
220
  logger.info(f"Training completed in {train_time:.1f}s")
 
236
  )
237
 
238
  logger.info(f"Generating {num_gen} samples...")
 
239
  batch_size = 128
240
  all_samples = []
241
  for i in range(0, num_gen, batch_size):
 
245
  generated = torch.cat(all_samples, dim=0)
246
 
247
  # ---- Evaluation ----
248
+ test_samples = data_loader.get_test_samples(num_gen, device="cpu")
 
 
 
249
  evaluator = Evaluation(config, device)
250
  metrics = evaluator.evaluate(generated, test_samples)
251
 
 
265
  save_path=f"results/nsgf_pp_{dataset_name}_samples.png",
266
  )
267
 
268
+ # Save final models
269
  torch.save(nsgf_model.state_dict(), f"results/nsgf_{dataset_name}_nsgf.pt")
270
  torch.save(nsf_model.state_dict(), f"results/nsgf_{dataset_name}_nsf.pt")
271
  torch.save(phase_predictor.state_dict(), f"results/nsgf_{dataset_name}_predictor.pt")
 
285
  parser.add_argument("--steps", type=int, default=None, help="Number of flow steps")
286
  parser.add_argument("--pool-batches", type=int, default=None, help="Pool building batches")
287
  parser.add_argument("--train-iters", type=int, default=None, help="Training iterations")
288
+ parser.add_argument("--sinkhorn-batch", type=int, default=None,
289
+ help="Sinkhorn batch size for pool building (reduce for OOM)")
290
  parser.add_argument("--config", type=str, default="config.yaml", help="Config file path")
291
  parser.add_argument("--seed", type=int, default=42, help="Random seed")
292
  parser.add_argument("--device", type=str, default=None,
293
  choices=["cpu", "cuda"],
294
  help="Force device (default: auto-detect)")
295
+ parser.add_argument("--checkpoint-dir", type=str, default="checkpoints",
296
+ help="Directory for saving/loading checkpoints")
297
+ parser.add_argument("--checkpoint-every", type=int, default=5000,
298
+ help="Save checkpoint every N training steps")
299
+ parser.add_argument("--resume-phase", type=int, default=1,
300
+ choices=[1, 2, 3],
301
+ help="Resume from phase N (loads phase N-1 checkpoint)")
302
 
303
  args = parser.parse_args()
304