Spaces:
Running
Running
Upload train.py
Browse files
train.py
CHANGED
|
@@ -22,6 +22,7 @@ import json
|
|
| 22 |
import re
|
| 23 |
import time
|
| 24 |
import os
|
|
|
|
| 25 |
from dataclasses import dataclass, field
|
| 26 |
from typing import Optional, List
|
| 27 |
|
|
@@ -672,10 +673,21 @@ def main():
|
|
| 672 |
CHECKPOINT_DIR = "./phase2_checkpoints"
|
| 673 |
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
|
| 674 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 675 |
for task_id in tasks:
|
| 676 |
print(f"\n{'='*60}")
|
| 677 |
print(f"[TRAIN] Training on task: {task_id}")
|
| 678 |
-
print(f"{'='*60}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 679 |
|
| 680 |
task_rewards = []
|
| 681 |
|
|
@@ -717,11 +729,18 @@ def main():
|
|
| 717 |
f"Rolling avg (10): {rolling_avg:.3f}"
|
| 718 |
)
|
| 719 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 720 |
# Small delay to avoid hammering the env
|
| 721 |
time.sleep(0.1)
|
| 722 |
|
| 723 |
reward_history[task_id] = task_rewards
|
| 724 |
|
|
|
|
|
|
|
|
|
|
| 725 |
# Summary for this task
|
| 726 |
if task_rewards:
|
| 727 |
first_10 = sum(task_rewards[:10]) / min(10, len(task_rewards))
|
|
|
|
| 22 |
import re
|
| 23 |
import time
|
| 24 |
import os
|
| 25 |
+
import csv
|
| 26 |
from dataclasses import dataclass, field
|
| 27 |
from typing import Optional, List
|
| 28 |
|
|
|
|
| 673 |
CHECKPOINT_DIR = "./phase2_checkpoints"
|
| 674 |
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
|
| 675 |
|
| 676 |
+
# CSV logging dir
|
| 677 |
+
CSV_LOG_DIR = "./logs"
|
| 678 |
+
os.makedirs(CSV_LOG_DIR, exist_ok=True)
|
| 679 |
+
|
| 680 |
for task_id in tasks:
|
| 681 |
print(f"\n{'='*60}")
|
| 682 |
print(f"[TRAIN] Training on task: {task_id}")
|
| 683 |
+
print(f"{'='*60}\n")
|
| 684 |
+
|
| 685 |
+
# Initialize CSV file for this task
|
| 686 |
+
csv_path = os.path.join(CSV_LOG_DIR, f"{task_id}_results.csv")
|
| 687 |
+
csv_file = open(csv_path, "w", newline="")
|
| 688 |
+
csv_writer = csv.writer(csv_file)
|
| 689 |
+
csv_writer.writerow(["episode", "reward", "steps"]) # Header
|
| 690 |
+
print(f"[LOG] Tracking results -> {csv_path}\n")
|
| 691 |
|
| 692 |
task_rewards = []
|
| 693 |
|
|
|
|
| 729 |
f"Rolling avg (10): {rolling_avg:.3f}"
|
| 730 |
)
|
| 731 |
|
| 732 |
+
# Log to CSV
|
| 733 |
+
csv_writer.writerow([ep, f"{total_reward:.4f}", steps])
|
| 734 |
+
csv_file.flush()
|
| 735 |
+
|
| 736 |
# Small delay to avoid hammering the env
|
| 737 |
time.sleep(0.1)
|
| 738 |
|
| 739 |
reward_history[task_id] = task_rewards
|
| 740 |
|
| 741 |
+
# Close CSV file for this task
|
| 742 |
+
csv_file.close()
|
| 743 |
+
|
| 744 |
# Summary for this task
|
| 745 |
if task_rewards:
|
| 746 |
first_10 = sum(task_rewards[:10]) / min(10, len(task_rewards))
|