InosLihka committed · Commit 666b4ce · 1 Parent(s): f0ca22d

Add SFT v3 + GRPO refine results to README + results.md

GRPO refine on top of SFT v3 (200 steps, lr 1e-5, beta 0.1) lifted:
- OOD final_score: 0.536 -> 0.559 (+0.023, +4% relative)
- discrete-3 final_score: 0.507 -> 0.520 (+0.013)
- in-dist final_score: no change (within noise)
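For reference, these hyperparameters map onto a TRL-style GRPO config roughly as follows. This is a sketch only — the actual training script, dataset, and reward wiring are not part of this commit, and the `GRPOConfig` field names assume TRL's implementation:

```python
from trl import GRPOConfig  # assumption: TRL's GRPO implementation was used

# Sketch of the refine hyperparameters named in this commit message;
# everything not listed (batch size, generations per prompt, reward fn)
# is an open choice left at defaults.
config = GRPOConfig(
    max_steps=200,       # "200 steps"
    learning_rate=1e-5,  # "lr 1e-5"
    beta=0.1,            # "beta 0.1" — KL penalty toward the SFT v3 reference
)
```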

belief_MAE essentially unchanged across all conditions; the SFT-prime
distillation already extracted near-maximum inference quality from the
teacher.
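(belief_MAE here is presumably the mean absolute error between the agent's final recorded belief vector and the hidden profile parameters — a minimal sketch under that assumption; the grader's exact definition lives in the environment code:)

```python
# Hedged sketch: belief_MAE as mean absolute error between the final
# recorded belief vector and the true (hidden) profile parameters.
# The actual definition lives in the environment's grader.
def belief_mae(belief: list[float], true_profile: list[float]) -> float:
    return sum(abs(b - t) for b, t in zip(belief, true_profile)) / len(true_profile)

print(belief_mae([0.5, 0.7, 0.2], [0.4, 0.6, 0.4]))  # ≈ 0.133
```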

Refined model uploaded to InosLihka/rhythm-env-meta-trained-sft-grpo-v1
with eval_results_v2.json containing full per-episode breakdown.

New artifact:
plots/sft_grpo_comparison.png (34KB) — heuristic / SFT v3 / SFT+GRPO
side-by-side bar chart across all 3 conditions, embedded in README

Helper scripts (added to repo for reproducibility):
scripts/verify_rubric_equivalence.py — replays eval episodes locally
under the new Rubric-based grader, verifies numerical equivalence
against saved eval JSON. Confirmed 105/105 episodes match within
float-precision tolerance after the Rubric refactor in f0ca22d.
scripts/plot_v3_results.py — generates the v3 baseline-vs-trained bar
chart from eval_results_v2.json.

README.md CHANGED
@@ -30,18 +30,22 @@ This is **meta-reinforcement learning** for personalization: the agent isn't tra

 A small (Qwen 2.5-3B + 4-bit + LoRA) student, distilled from a gpt-5.4 teacher via Algorithm Distillation, **beats the heuristic baseline on all three eval conditions**:

- | Condition | Random | Heuristic | **Distilled Qwen 3B** | Δ vs Heuristic | belief_MAE |
+ | Condition | Random | Heuristic | **Distilled Qwen 3B** | **+ GRPO refine** | belief_MAE (GRPO) |
  |---|---|---|---|---|---|
- | **continuous in-distribution** | 0.393 | 0.463 | **0.574** | **+0.111** | 0.213 |
- | **continuous OOD (generalization)** | 0.393 | 0.455 | **0.536** | **+0.081** | 0.265 |
- | discrete-3-profiles (legacy) | 0.426 | 0.455 | **0.507** | +0.052 | 0.415 |
+ | **continuous in-distribution** | 0.393 | 0.463 | **0.574** *(+0.111)* | 0.573 | 0.216 |
+ | **continuous OOD (generalization)** | 0.393 | 0.455 | 0.536 *(+0.081)* | **0.559** *(+0.104)* | 0.263 |
+ | discrete-3-profiles (legacy) | 0.426 | 0.455 | 0.507 *(+0.052)* | **0.520** *(+0.065)* | 0.430 |

 The student's **belief_MAE of 0.213 in-distribution matches the gpt-5.4 teacher (0.196)** — the inference skill transferred nearly perfectly via SFT-prime. On OOD profiles the agent never saw, it still beats heuristic by +0.081, proving generalization (not memorization).

+ A subsequent GRPO refine on top of the SFT'd student lifted **OOD generalization by another +0.023 (4% relative)** and discrete-3 by +0.013, with no in-dist regression. The GRPO-refined model is at [`InosLihka/rhythm-env-meta-trained-sft-grpo-v1`](https://huggingface.co/InosLihka/rhythm-env-meta-trained-sft-grpo-v1).
+
- For reference, the gpt-5.4 teacher (upper bound) hits 0.611 in-dist / 0.621 OOD on a 150-episode reeval. Full numbers in [docs/results.md](docs/results.md). Eval JSON: [eval_results_v2.json](https://huggingface.co/InosLihka/rhythm-env-meta-trained-sft-v3/blob/main/eval_results_v2.json).
+ For reference, the gpt-5.4 teacher (upper bound) hits 0.611 in-dist / 0.621 OOD on a 150-episode reeval. Full numbers in [docs/results.md](docs/results.md). Eval JSONs: [SFT v3](https://huggingface.co/InosLihka/rhythm-env-meta-trained-sft-v3/blob/main/eval_results_v2.json) · [SFT v3 + GRPO](https://huggingface.co/InosLihka/rhythm-env-meta-trained-sft-grpo-v1/blob/main/eval_results_v2.json).

 ![v3 baseline vs trained across conditions](plots/sft_v3_baseline_vs_trained.png)

+ ![SFT v3 vs SFT+GRPO comparison](plots/sft_grpo_comparison.png)
+
 ## Training evidence

 **SFT v3 loss curve** — distillation training on 5,040 (state, response) pairs from a gpt-5.4 teacher. Loss drops from 2.77 → 0.083 over 525 steps and stays converged. No overfitting.
docs/results.md CHANGED
@@ -62,14 +62,15 @@ that try to get credit).

 ### Distilled Qwen 3B student — full eval across all 3 conditions

 10 episodes per condition for continuous, 5 episodes per discrete profile
- (15 total). Source: `eval_results_v2.json` on the
- [trained-model repo](https://huggingface.co/InosLihka/rhythm-env-meta-trained-sft-v3).
+ (15 total). Sources:
+ - SFT v3 numbers: [`eval_results_v2.json`](https://huggingface.co/InosLihka/rhythm-env-meta-trained-sft-v3/blob/main/eval_results_v2.json)
+ - SFT v3 + GRPO refine numbers: [`eval_results_v2.json`](https://huggingface.co/InosLihka/rhythm-env-meta-trained-sft-grpo-v1/blob/main/eval_results_v2.json)

- | Condition | Random | Heuristic | **Distilled Qwen 3B** | Δ vs Heuristic | belief_MAE |
+ | Condition | Random | Heuristic | **SFT v3** | **+ GRPO refine** | belief_MAE (GRPO) |
  |---|---|---|---|---|---|
- | **continuous in-distribution** | 0.393 | 0.463 | **0.574** | **+0.111** | **0.213** |
- | **continuous OOD** | 0.393 | 0.455 | **0.536** | **+0.081** | 0.265 |
- | discrete-3-profiles (legacy) | 0.426 | 0.455 | **0.507** | +0.052 | 0.415 |
+ | **continuous in-distribution** | 0.393 | 0.463 | **0.574** *(+0.111)* | 0.573 | 0.216 |
+ | **continuous OOD** | 0.393 | 0.455 | 0.536 *(+0.081)* | **0.559** *(+0.104)* | 0.263 |
+ | discrete-3-profiles (legacy) | 0.426 | 0.455 | 0.507 *(+0.052)* | **0.520** *(+0.065)* | 0.430 |

 **Interpretation:**
 - The student wins on **all three** conditions, with the largest margin
plots/sft_grpo_comparison.png ADDED
scripts/verify_rubric_equivalence.py ADDED
@@ -0,0 +1,111 @@
"""
Numerical equivalence check for the Rubric refactor.

Replays each episode from a saved eval JSON (which was scored by the
OLD `_grade_episode`) under the LOCAL code (which has the NEW Rubric-
based grader). If `final_score` matches for every episode within float
precision, the refactor is functionally identical to the original.

Usage:
    python scripts/verify_rubric_equivalence.py outputs/sft-v3/eval_results_v2.json
"""

import argparse
import json
import os
import sys

# Make repo-root imports resolve when run as `python scripts/...`.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from models import ActionType, RhythmAction
from server.rhythm_environment import MAX_STEPS, RhythmEnvironment


def replay_episode(seed: int, actions: list[str], final_belief: list[float] | None,
                   profile: str | None) -> float:
    """Replay one episode locally under the new grader."""
    env = RhythmEnvironment()
    if profile:
        obs = env.reset(seed=seed, profile=profile)
    else:
        obs = env.reset(seed=seed)

    for i, action_name in enumerate(actions):
        if obs.done:
            break
        # Set the final belief on the LAST step (matches inference_eval.py
        # behavior: record_belief was called every step but only the latest
        # matters for the grader).
        is_last = (i == len(actions) - 1) or (i == MAX_STEPS - 1)
        if is_last and final_belief is not None:
            env.record_belief(final_belief)
        rhythm_action = RhythmAction(action_type=ActionType(action_name))
        obs = env.step(rhythm_action)

    return obs.reward_breakdown.get("final_score", 0.0)


def main() -> None:
    ap = argparse.ArgumentParser()
    ap.add_argument("eval_json", help="Path to eval_results JSON to verify against")
    ap.add_argument("--tolerance", type=float, default=1e-4,
                    help="Allowed |new - old| difference (default 1e-4)")
    args = ap.parse_args()

    with open(args.eval_json) as f:
        rows = json.load(f)
    print(f"Loaded {len(rows)} episodes from {args.eval_json}")
    print()

    matches = 0
    mismatches = []
    for row in rows:
        seed = row["seed"]
        actions = row.get("actions", [])
        final_belief = row.get("final_belief")  # null for heuristic/random
        profile = row.get("profile_name") if row.get("profile_mode") == "discrete" else None
        # Some rows have profile_mode='discrete' with an explicit profile
        # name (the 3 reference profiles). Pass it via kwarg.
        if profile and not profile.startswith("sampled_"):
            replay_profile = profile
        else:
            replay_profile = None

        old_score = row["final_score"]
        new_score = replay_episode(seed, actions, final_belief, replay_profile)

        diff = abs(new_score - old_score)
        if diff <= args.tolerance:
            matches += 1
        else:
            mismatches.append({
                "seed": seed,
                "strategy": row["strategy"],
                "condition": row["condition"],
                "old": old_score,
                "new": new_score,
                "diff": diff,
            })

    print("=" * 60)
    print(f"RESULT: {matches}/{len(rows)} episodes match within ±{args.tolerance}")
    print("=" * 60)
    if mismatches:
        print()
        print(f"MISMATCHES ({len(mismatches)}):")
        for m in mismatches[:10]:
            print(f"  seed={m['seed']:>5} {m['strategy']:>10} "
                  f"{m['condition']:<35} old={m['old']:.4f} "
                  f"new={m['new']:.4f} diff={m['diff']:.4f}")
        if len(mismatches) > 10:
            print(f"  ... and {len(mismatches) - 10} more")
        sys.exit(1)
    else:
        print()
        print("REFACTOR IS NUMERICALLY EQUIVALENT to the old grader.")


if __name__ == "__main__":
    main()
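The per-episode rows the verifier consumes look roughly like this — the field names are taken from the accessors in the script above, but the values and action names are illustrative placeholders, not real data:

```python
# Synthetic row using exactly the fields verify_rubric_equivalence.py
# reads. Values and action names are illustrative placeholders; the
# real ActionType names live in models.py.
row = {
    "seed": 1234,
    "strategy": "distilled",        # illustrative
    "condition": "continuous_ood",  # illustrative
    "profile_mode": "continuous",
    "actions": ["observe", "act"],  # placeholder action names
    "final_belief": [0.4, 0.7],
    "final_score": 0.559,
}

# The per-episode pass/fail decision is just a tolerance comparison:
def scores_match(old: float, new: float, tol: float = 1e-4) -> bool:
    return abs(new - old) <= tol

print(scores_match(row["final_score"], 0.55903))  # True: |diff| = 3e-5 <= 1e-4
```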