Fix DataLoader batch size mismatch across training phases + --train-iters now overrides all phases
- DatasetLoader.sample_target() recreates DataLoader when batch size changes
(Phase 1 pool building uses 256, Phase 2 NSF uses 128 — caused RuntimeError)
- --train-iters CLI flag now also overrides time_predictor iterations
- dataset_loader.py +4 -1
dataset_loader.py
CHANGED
|
@@ -140,7 +140,10 @@ class DatasetLoader:
|
|
| 140 |
|
| 141 |
def sample_target(self, n: int, device: str = "cpu") -> torch.Tensor:
|
| 142 |
if self.is_image:
|
| 143 |
-
if not hasattr(self, "_image_loader"):
|
|
|
|
|
|
|
|
|
|
| 144 |
self._image_loader = get_image_dataloader(
|
| 145 |
self.dataset_name, batch_size=n, train=True
|
| 146 |
)
|
|
|
|
| 140 |
|
| 141 |
def sample_target(self, n: int, device: str = "cpu") -> torch.Tensor:
|
| 142 |
if self.is_image:
|
| 143 |
+
# Recreate DataLoader if batch size changed (different training phases
|
| 144 |
+
# use different batch sizes, e.g. 256 for pool building, 128 for NSF)
|
| 145 |
+
if not hasattr(self, "_image_loader") or self._image_batch_size != n:
|
| 146 |
+
self._image_batch_size = n
|
| 147 |
self._image_loader = get_image_dataloader(
|
| 148 |
self.dataset_name, batch_size=n, train=True
|
| 149 |
)
|