OGrohit commited on
Commit
922e5a7
·
verified ·
1 Parent(s): 2fee0ff

Upload train.py

Browse files
Files changed (1) hide show
  1. train.py +20 -1
train.py CHANGED
@@ -22,6 +22,7 @@ import json
22
  import re
23
  import time
24
  import os
 
25
  from dataclasses import dataclass, field
26
  from typing import Optional, List
27
 
@@ -672,10 +673,21 @@ def main():
672
  CHECKPOINT_DIR = "./phase2_checkpoints"
673
  os.makedirs(CHECKPOINT_DIR, exist_ok=True)
674
 
 
 
 
 
675
  for task_id in tasks:
676
  print(f"\n{'='*60}")
677
  print(f"[TRAIN] Training on task: {task_id}")
678
- print(f"{'='*60}")
 
 
 
 
 
 
 
679
 
680
  task_rewards = []
681
 
@@ -717,11 +729,18 @@ def main():
717
  f"Rolling avg (10): {rolling_avg:.3f}"
718
  )
719
 
 
 
 
 
720
  # Small delay to avoid hammering the env
721
  time.sleep(0.1)
722
 
723
  reward_history[task_id] = task_rewards
724
 
 
 
 
725
  # Summary for this task
726
  if task_rewards:
727
  first_10 = sum(task_rewards[:10]) / min(10, len(task_rewards))
 
22
  import re
23
  import time
24
  import os
25
+ import csv
26
  from dataclasses import dataclass, field
27
  from typing import Optional, List
28
 
 
673
  CHECKPOINT_DIR = "./phase2_checkpoints"
674
  os.makedirs(CHECKPOINT_DIR, exist_ok=True)
675
 
676
+ # CSV logging dir
677
+ CSV_LOG_DIR = "./logs"
678
+ os.makedirs(CSV_LOG_DIR, exist_ok=True)
679
+
680
  for task_id in tasks:
681
  print(f"\n{'='*60}")
682
  print(f"[TRAIN] Training on task: {task_id}")
683
+ print(f"{'='*60}\n")
684
+
685
+ # Initialize CSV file for this task
686
+ csv_path = os.path.join(CSV_LOG_DIR, f"{task_id}_results.csv")
687
+ csv_file = open(csv_path, "w", newline="")
688
+ csv_writer = csv.writer(csv_file)
689
+ csv_writer.writerow(["episode", "reward", "steps"]) # Header
690
+ print(f"[LOG] Tracking results -> {csv_path}\n")
691
 
692
  task_rewards = []
693
 
 
729
  f"Rolling avg (10): {rolling_avg:.3f}"
730
  )
731
 
732
+ # Log to CSV
733
+ csv_writer.writerow([ep, f"{total_reward:.4f}", steps])
734
+ csv_file.flush()
735
+
736
  # Small delay to avoid hammering the env
737
  time.sleep(0.1)
738
 
739
  reward_history[task_id] = task_rewards
740
 
741
+ # Close CSV file for this task
742
+ csv_file.close()
743
+
744
  # Summary for this task
745
  if task_rewards:
746
  first_10 = sum(task_rewards[:10]) / min(10, len(task_rewards))