Spaces:
Runtime error
Runtime error
Fix: Use Level 1 scenarios only for training
Browse files
train.py
CHANGED
|
@@ -48,7 +48,7 @@ SFT_STEPS = 40 # More warmup for JSON format
|
|
| 48 |
GRPO_STEPS = 250
|
| 49 |
GRPO_K = 2
|
| 50 |
GRPO_LR = 1e-5
|
| 51 |
-
CURRICULUM_SWITCH =
|
| 52 |
GRAD_CLIP = 1.0
|
| 53 |
SAVE_EVERY = 50
|
| 54 |
|
|
@@ -100,9 +100,14 @@ def sig(s):
|
|
| 100 |
|
| 101 |
|
| 102 |
dataset = list({sig(s): s for s in raw}.values())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
rng = random.Random(42)
|
| 104 |
families = defaultdict(list)
|
| 105 |
-
for s in
|
| 106 |
families[(s.get("decision"), s.get("violation_type"))].append(s)
|
| 107 |
|
| 108 |
train_set, eval_set = [], []
|
|
@@ -111,7 +116,7 @@ for items in families.values():
|
|
| 111 |
k = int(len(items) * 0.8)
|
| 112 |
train_set.extend(items[:k])
|
| 113 |
eval_set.extend(items[k:])
|
| 114 |
-
print(f"Dataset: {len(train_set)} train / {len(eval_set)} eval")
|
| 115 |
|
| 116 |
# βββ Policy Rules + Reward ββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 117 |
POLICY_RULES = {
|
|
|
|
| 48 |
GRPO_STEPS = 250
|
| 49 |
GRPO_K = 2
|
| 50 |
GRPO_LR = 1e-5
|
| 51 |
+
CURRICULUM_SWITCH = 0 # Start with Level 1, advance early
|
| 52 |
GRAD_CLIP = 1.0
|
| 53 |
SAVE_EVERY = 50
|
| 54 |
|
|
|
|
| 100 |
|
| 101 |
|
| 102 |
dataset = list({sig(s): s for s in raw}.values())
|
| 103 |
+
|
| 104 |
+
# Filter for Level 1 scenarios only (for early training)
|
| 105 |
+
level1_data = [s for s in dataset if s.get("level", 1) == 1]
|
| 106 |
+
print(f"Level 1 scenarios: {len(level1_data)} / {len(dataset)}")
|
| 107 |
+
|
| 108 |
rng = random.Random(42)
|
| 109 |
families = defaultdict(list)
|
| 110 |
+
for s in level1_data:
|
| 111 |
families[(s.get("decision"), s.get("violation_type"))].append(s)
|
| 112 |
|
| 113 |
train_set, eval_set = [], []
|
|
|
|
| 116 |
k = int(len(items) * 0.8)
|
| 117 |
train_set.extend(items[:k])
|
| 118 |
eval_set.extend(items[k:])
|
| 119 |
+
print(f"Dataset: {len(train_set)} train / {len(eval_set)} eval (Level 1 only)")
|
| 120 |
|
| 121 |
# βββ Policy Rules + Reward ββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 122 |
POLICY_RULES = {
|