YashashMathur commited on
Commit
7a19964
Β·
verified Β·
1 Parent(s): 9a90a52

Fix: Use Level 1 scenarios only for training

Browse files
Files changed (1) hide show
  1. train.py +8 -3
train.py CHANGED
@@ -48,7 +48,7 @@ SFT_STEPS = 40 # More warmup for JSON format
48
  GRPO_STEPS = 250
49
  GRPO_K = 2
50
  GRPO_LR = 1e-5
51
- CURRICULUM_SWITCH = 150
52
  GRAD_CLIP = 1.0
53
  SAVE_EVERY = 50
54
 
@@ -100,9 +100,14 @@ def sig(s):
100
 
101
 
102
  dataset = list({sig(s): s for s in raw}.values())
 
 
 
 
 
103
  rng = random.Random(42)
104
  families = defaultdict(list)
105
- for s in dataset:
106
  families[(s.get("decision"), s.get("violation_type"))].append(s)
107
 
108
  train_set, eval_set = [], []
@@ -111,7 +116,7 @@ for items in families.values():
111
  k = int(len(items) * 0.8)
112
  train_set.extend(items[:k])
113
  eval_set.extend(items[k:])
114
- print(f"Dataset: {len(train_set)} train / {len(eval_set)} eval")
115
 
116
  # ─── Policy Rules + Reward ────────────────────────────────────────────────────
117
  POLICY_RULES = {
 
48
  GRPO_STEPS = 250
49
  GRPO_K = 2
50
  GRPO_LR = 1e-5
51
+ CURRICULUM_SWITCH = 0 # Start with Level 1, advance early
52
  GRAD_CLIP = 1.0
53
  SAVE_EVERY = 50
54
 
 
100
 
101
 
102
  dataset = list({sig(s): s for s in raw}.values())
103
+
104
+ # Filter for Level 1 scenarios only (for early training)
105
+ level1_data = [s for s in dataset if s.get("level", 1) == 1]
106
+ print(f"Level 1 scenarios: {len(level1_data)} / {len(dataset)}")
107
+
108
  rng = random.Random(42)
109
  families = defaultdict(list)
110
+ for s in level1_data:
111
  families[(s.get("decision"), s.get("violation_type"))].append(s)
112
 
113
  train_set, eval_set = [], []
 
116
  k = int(len(items) * 0.8)
117
  train_set.extend(items[:k])
118
  eval_set.extend(items[k:])
119
+ print(f"Dataset: {len(train_set)} train / {len(eval_set)} eval (Level 1 only)")
120
 
121
  # ─── Policy Rules + Reward ────────────────────────────────────────────────────
122
  POLICY_RULES = {