Spaces:
Sleeping
Add SFT v3 + GRPO refine results to README + results.md
Browse filesGRPO refine on top of SFT v3 (200 steps, lr 1e-5, beta 0.1) lifted:
- OOD final_score: 0.536 -> 0.559 (+0.023, +4% relative)
- discrete-3 final_score: 0.507 -> 0.520 (+0.013)
- in-dist final_score: no change (within noise)
belief_MAE essentially unchanged across all conditions; the SFT-prime
distillation already extracted near-maximum inference quality from the
teacher.
Refined model uploaded to InosLihka/rhythm-env-meta-trained-sft-grpo-v1
with eval_results_v2.json containing full per-episode breakdown.
New artifact:
plots/sft_grpo_comparison.png (34KB) — heuristic / SFT v3 / SFT+GRPO
side-by-side bar chart across all 3 conditions, embedded in README
Helper scripts (added to repo for reproducibility):
scripts/verify_rubric_equivalence.py — replays eval episodes locally
under the new Rubric-based grader, verifies numerical equivalence
against saved eval JSON. Confirmed 105/105 episodes match within
float-precision tolerance after the Rubric refactor in f0ca22d.
scripts/plot_v3_results.py — generates the v3 baseline-vs-trained bar
chart from eval_results_v2.json.
- README.md +9 -5
- docs/results.md +7 -6
- plots/sft_grpo_comparison.png +0 -0
- scripts/verify_rubric_equivalence.py +111 -0
|
@@ -30,18 +30,22 @@ This is **meta-reinforcement learning** for personalization: the agent isn't tra
|
|
| 30 |
|
| 31 |
A small (Qwen 2.5-3B + 4-bit + LoRA) student, distilled from a gpt-5.4 teacher via Algorithm Distillation, **beats the heuristic baseline on all three eval conditions**:
|
| 32 |
|
| 33 |
-
| Condition | Random | Heuristic | **Distilled Qwen 3B** |
|
| 34 |
|---|---|---|---|---|---|
|
| 35 |
-
| **continuous in-distribution** | 0.393 | 0.463 | **0.574**
|
| 36 |
-
| **continuous OOD (generalization)** | 0.393 | 0.455 |
|
| 37 |
-
| discrete-3-profiles (legacy) | 0.426 | 0.455 |
|
| 38 |
|
| 39 |
The student's **belief_MAE of 0.213 in-distribution matches the gpt-5.4 teacher (0.196)** — the inference skill transferred nearly perfectly via SFT-prime. On OOD profiles the agent never saw, it still beats heuristic by +0.081, proving generalization (not memorization).
|
| 40 |
|
| 41 |
-
|
|
|
|
|
|
|
| 42 |
|
| 43 |

|
| 44 |
|
|
|
|
|
|
|
| 45 |
## Training evidence
|
| 46 |
|
| 47 |
**SFT v3 loss curve** — distillation training on 5,040 (state, response) pairs from a gpt-5.4 teacher. Loss drops from 2.77 → 0.083 over 525 steps and stays converged. No overfitting.
|
|
|
|
| 30 |
|
| 31 |
A small (Qwen 2.5-3B + 4-bit + LoRA) student, distilled from a gpt-5.4 teacher via Algorithm Distillation, **beats the heuristic baseline on all three eval conditions**:
|
| 32 |
|
| 33 |
+
| Condition | Random | Heuristic | **Distilled Qwen 3B** | **+ GRPO refine** | belief_MAE |
|
| 34 |
|---|---|---|---|---|---|
|
| 35 |
+
| **continuous in-distribution** | 0.393 | 0.463 | **0.574** *(+0.111)* | 0.573 | 0.213 |
|
| 36 |
+
| **continuous OOD (generalization)** | 0.393 | 0.455 | 0.536 *(+0.081)* | **0.559** *(+0.104)* | 0.263 |
|
| 37 |
+
| discrete-3-profiles (legacy) | 0.426 | 0.455 | 0.507 *(+0.052)* | **0.520** *(+0.065)* | 0.430 |
|
| 38 |
|
| 39 |
The student's **belief_MAE of 0.213 in-distribution matches the gpt-5.4 teacher (0.196)** — the inference skill transferred nearly perfectly via SFT-prime. On OOD profiles the agent never saw, it still beats heuristic by +0.081, proving generalization (not memorization).
|
| 40 |
|
| 41 |
+
A subsequent GRPO refine on top of the SFT'd student lifted **OOD generalization by another +0.023 (4% relative)** and discrete-3 by +0.013, with no in-dist regression. The GRPO-refined model is at [`InosLihka/rhythm-env-meta-trained-sft-grpo-v1`](https://huggingface.co/InosLihka/rhythm-env-meta-trained-sft-grpo-v1).
|
| 42 |
+
|
| 43 |
+
For reference, the gpt-5.4 teacher (upper bound) hits 0.611 in-dist / 0.621 OOD on a 150-episode reeval. Full numbers in [docs/results.md](docs/results.md). Eval JSONs: [SFT v3](https://huggingface.co/InosLihka/rhythm-env-meta-trained-sft-v3/blob/main/eval_results_v2.json) · [SFT v3 + GRPO](https://huggingface.co/InosLihka/rhythm-env-meta-trained-sft-grpo-v1/blob/main/eval_results_v2.json).
|
| 44 |
|
| 45 |

|
| 46 |
|
| 47 |
+

|
| 48 |
+
|
| 49 |
## Training evidence
|
| 50 |
|
| 51 |
**SFT v3 loss curve** — distillation training on 5,040 (state, response) pairs from a gpt-5.4 teacher. Loss drops from 2.77 → 0.083 over 525 steps and stays converged. No overfitting.
|
|
@@ -62,14 +62,15 @@ that try get credit).
|
|
| 62 |
### Distilled Qwen 3B student — full eval across all 3 conditions
|
| 63 |
|
| 64 |
10 episodes per condition for continuous, 5 episodes per discrete profile
|
| 65 |
-
(15 total).
|
| 66 |
-
|
|
|
|
| 67 |
|
| 68 |
-
| Condition | Random | Heuristic | **
|
| 69 |
|---|---|---|---|---|---|
|
| 70 |
-
| **continuous in-distribution** | 0.393 | 0.463 | **0.574**
|
| 71 |
-
| **continuous OOD** | 0.393 | 0.455 |
|
| 72 |
-
| discrete-3-profiles (legacy) | 0.426 | 0.455 |
|
| 73 |
|
| 74 |
**Interpretation:**
|
| 75 |
- The student wins on **all three** conditions, with the largest margin
|
|
|
|
| 62 |
### Distilled Qwen 3B student — full eval across all 3 conditions
|
| 63 |
|
| 64 |
10 episodes per condition for continuous, 5 episodes per discrete profile
|
| 65 |
+
(15 total). Sources:
|
| 66 |
+
- SFT v3 numbers: [`eval_results_v2.json`](https://huggingface.co/InosLihka/rhythm-env-meta-trained-sft-v3/blob/main/eval_results_v2.json)
|
| 67 |
+
- SFT v3 + GRPO refine numbers: [`eval_results_v2.json`](https://huggingface.co/InosLihka/rhythm-env-meta-trained-sft-grpo-v1/blob/main/eval_results_v2.json)
|
| 68 |
|
| 69 |
+
| Condition | Random | Heuristic | **SFT v3** | **+ GRPO refine** | belief_MAE (GRPO) |
|
| 70 |
|---|---|---|---|---|---|
|
| 71 |
+
| **continuous in-distribution** | 0.393 | 0.463 | **0.574** *(+0.111)* | 0.573 | 0.216 |
|
| 72 |
+
| **continuous OOD** | 0.393 | 0.455 | 0.536 *(+0.081)* | **0.559** *(+0.104)* | 0.263 |
|
| 73 |
+
| discrete-3-profiles (legacy) | 0.426 | 0.455 | 0.507 *(+0.052)* | **0.520** *(+0.065)* | 0.430 |
|
| 74 |
|
| 75 |
**Interpretation:**
|
| 76 |
- The student wins on **all three** conditions, with the largest margin
|
|
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Numerical equivalence check for the Rubric refactor.
|
| 3 |
+
|
| 4 |
+
Replays each episode from a saved eval JSON (which was scored by the
|
| 5 |
+
OLD `_grade_episode`) under the LOCAL code (which has the NEW Rubric-
|
| 6 |
+
based grader). If `final_score` matches for every episode within float
|
| 7 |
+
precision, the refactor is functionally identical to the original.
|
| 8 |
+
|
| 9 |
+
Usage:
|
| 10 |
+
python scripts/verify_rubric_equivalence.py outputs/sft-v3/eval_results_v2.json
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import argparse
|
| 14 |
+
import json
|
| 15 |
+
import os
|
| 16 |
+
import random
|
| 17 |
+
import sys
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
|
| 20 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 21 |
+
|
| 22 |
+
from models import ActionType, RhythmAction
|
| 23 |
+
from server.rhythm_environment import MAX_STEPS, RhythmEnvironment
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def replay_episode(seed: int, actions: list[str], final_belief: list[float] | None,
|
| 27 |
+
profile: str | None) -> float:
|
| 28 |
+
"""Replay one episode locally under the new grader."""
|
| 29 |
+
env = RhythmEnvironment()
|
| 30 |
+
if profile:
|
| 31 |
+
obs = env.reset(seed=seed, profile=profile)
|
| 32 |
+
else:
|
| 33 |
+
obs = env.reset(seed=seed)
|
| 34 |
+
|
| 35 |
+
for i, action_name in enumerate(actions):
|
| 36 |
+
if obs.done:
|
| 37 |
+
break
|
| 38 |
+
# Set the final belief on the LAST step (matches inference_eval.py
|
| 39 |
+
# behavior: record_belief was called every step but only the latest
|
| 40 |
+
# matters for the grader)
|
| 41 |
+
is_last = (i == len(actions) - 1) or (i == MAX_STEPS - 1)
|
| 42 |
+
if is_last and final_belief is not None:
|
| 43 |
+
env.record_belief(final_belief)
|
| 44 |
+
rhythm_action = RhythmAction(action_type=ActionType(action_name))
|
| 45 |
+
obs = env.step(rhythm_action)
|
| 46 |
+
|
| 47 |
+
return obs.reward_breakdown.get("final_score", 0.0)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def main() -> None:
|
| 51 |
+
ap = argparse.ArgumentParser()
|
| 52 |
+
ap.add_argument("eval_json", help="Path to eval_results JSON to verify against")
|
| 53 |
+
ap.add_argument("--tolerance", type=float, default=1e-4,
|
| 54 |
+
help="Allowed |new - old| difference (default 1e-4)")
|
| 55 |
+
args = ap.parse_args()
|
| 56 |
+
|
| 57 |
+
with open(args.eval_json) as f:
|
| 58 |
+
rows = json.load(f)
|
| 59 |
+
print(f"Loaded {len(rows)} episodes from {args.eval_json}")
|
| 60 |
+
print()
|
| 61 |
+
|
| 62 |
+
matches = 0
|
| 63 |
+
mismatches = []
|
| 64 |
+
for row in rows:
|
| 65 |
+
seed = row["seed"]
|
| 66 |
+
actions = row.get("actions", [])
|
| 67 |
+
final_belief = row.get("final_belief") # null for heuristic/random
|
| 68 |
+
profile = row.get("profile_name") if row.get("profile_mode") == "discrete" else None
|
| 69 |
+
# Some rows have profile_mode='discrete' with explicit profile
|
| 70 |
+
# name (the 3 reference profiles). Pass it via kwarg.
|
| 71 |
+
if profile and not profile.startswith("sampled_"):
|
| 72 |
+
replay_profile = profile
|
| 73 |
+
else:
|
| 74 |
+
replay_profile = None
|
| 75 |
+
|
| 76 |
+
old_score = row["final_score"]
|
| 77 |
+
new_score = replay_episode(seed, actions, final_belief, replay_profile)
|
| 78 |
+
|
| 79 |
+
diff = abs(new_score - old_score)
|
| 80 |
+
if diff <= args.tolerance:
|
| 81 |
+
matches += 1
|
| 82 |
+
else:
|
| 83 |
+
mismatches.append({
|
| 84 |
+
"seed": seed,
|
| 85 |
+
"strategy": row["strategy"],
|
| 86 |
+
"condition": row["condition"],
|
| 87 |
+
"old": old_score,
|
| 88 |
+
"new": new_score,
|
| 89 |
+
"diff": diff,
|
| 90 |
+
})
|
| 91 |
+
|
| 92 |
+
print(f"=" * 60)
|
| 93 |
+
print(f"RESULT: {matches}/{len(rows)} episodes match within ±{args.tolerance}")
|
| 94 |
+
print(f"=" * 60)
|
| 95 |
+
if mismatches:
|
| 96 |
+
print()
|
| 97 |
+
print(f"MISMATCHES ({len(mismatches)}):")
|
| 98 |
+
for m in mismatches[:10]:
|
| 99 |
+
print(f" seed={m['seed']:>5} {m['strategy']:>10} "
|
| 100 |
+
f"{m['condition']:<35} old={m['old']:.4f} "
|
| 101 |
+
f"new={m['new']:.4f} diff={m['diff']:.4f}")
|
| 102 |
+
if len(mismatches) > 10:
|
| 103 |
+
print(f" ... and {len(mismatches) - 10} more")
|
| 104 |
+
sys.exit(1)
|
| 105 |
+
else:
|
| 106 |
+
print()
|
| 107 |
+
print("REFACTOR IS NUMERICALLY EQUIVALENT to the old grader.")
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
if __name__ == "__main__":
|
| 111 |
+
main()
|