rogermt committed on
Commit
9e3fccc
·
verified ·
1 Parent(s): 376238e

Fix DataLoader batch size mismatch across training phases + --train-iters now overrides all phases

Browse files

- DatasetLoader.sample_target() recreates DataLoader when batch size changes
(Phase 1 pool building uses 256, Phase 2 NSF uses 128 — caused RuntimeError)
- --train-iters CLI flag now also overrides time_predictor iterations

Files changed (1) hide show
  1. dataset_loader.py +4 -1
dataset_loader.py CHANGED
@@ -140,7 +140,10 @@ class DatasetLoader:
140
 
141
  def sample_target(self, n: int, device: str = "cpu") -> torch.Tensor:
142
  if self.is_image:
143
- if not hasattr(self, "_image_loader"):
 
 
 
144
  self._image_loader = get_image_dataloader(
145
  self.dataset_name, batch_size=n, train=True
146
  )
 
140
 
141
  def sample_target(self, n: int, device: str = "cpu") -> torch.Tensor:
142
  if self.is_image:
143
+ # Recreate DataLoader if batch size changed (different training phases
144
+ # use different batch sizes, e.g. 256 for pool building, 128 for NSF)
145
+ if not hasattr(self, "_image_loader") or self._image_batch_size != n:
146
+ self._image_batch_size = n
147
  self._image_loader = get_image_dataloader(
148
  self.dataset_name, batch_size=n, train=True
149
  )