Add pool.finalize() call after building trajectory pool for O(1) sampling
Browse files- trainer.py +3 -0
trainer.py
CHANGED
|
@@ -74,6 +74,9 @@ class NSGFTrainer:
|
|
| 74 |
if (batch_idx + 1) % max(1, num_batches // 10) == 0:
|
| 75 |
logger.info(f" Pool building: {batch_idx + 1}/{num_batches}, pool size: {len(self.pool)}")
|
| 76 |
logger.info(f"Trajectory pool built. Total entries: {len(self.pool)}")
|
|
|
|
|
|
|
|
|
|
| 77 |
|
| 78 |
def train(self) -> Dict[str, list]:
|
| 79 |
self.model.train()
|
|
|
|
| 74 |
if (batch_idx + 1) % max(1, num_batches // 10) == 0:
|
| 75 |
logger.info(f" Pool building: {batch_idx + 1}/{num_batches}, pool size: {len(self.pool)}")
|
| 76 |
logger.info(f"Trajectory pool built. Total entries: {len(self.pool)}")
|
| 77 |
+
# Pre-concatenate for O(1) sampling during training
|
| 78 |
+
self.pool.finalize()
|
| 79 |
+
logger.info("Trajectory pool finalized (pre-concatenated for fast sampling).")
|
| 80 |
|
| 81 |
def train(self) -> Dict[str, list]:
|
| 82 |
self.model.train()
|