Pratyush-01 committed on
Commit 7886d89 · verified · 1 Parent(s): d190a3b

Add blog.md: accurate systems table, SFT+GRPO plots

Files changed (1): docs/blog.md ADDED (+139 −0)

# PhysiX — Equation Discovery from Noisy Trajectories via RLVR

**OpenEnv India Hackathon 2026 submission** · [Live Space](https://huggingface.co/spaces/Pratyush-01/physix-live) · [Trained Model](https://huggingface.co/Pratyush-01/physix-3b-rl) · [W&B Runs](https://wandb.ai/pratyush01/physix-live)

---

## The Problem

Given a short, noisy trajectory of a physical system — positions and velocities over time — can a language model discover the underlying equation of motion?

This is symbolic regression meets agentic RL. The task is hard because:
- The equation space is combinatorially large
- Noise means no trajectory perfectly matches any ODE
- The agent has to learn to iterate — propose, simulate, compare residuals, refine

Classical symbolic regression tools (genetic programming, sparse regression) can do this, but they aren't language models and they can't use natural-language hints. We train a 3B LLM to do it iteratively using RLVR.

---

## The Environment

**PhysiXEnvironment** presents the agent with a noisy trajectory from a physical system and asks it to output an ODE that reproduces the motion. All reward comes from `scipy.odeint` — no LLM-as-judge.

### Systems trained on (GRPO)

Three systems were used for GRPO training:

| System | Ground-truth ODE | Difficulty |
|--------|-----------------|------------|
| Free Fall | `d2y/dt2 = -g` | 1 param (`g`) |
| Simple Pendulum | `d2theta/dt2 = -(g/L)*sin(theta)` | transcendental, 2 params (`g`, `L`) |
| Damped Spring | `d2x/dt2 = -(k/m)*x - (b/m)*dx` | damped oscillation, 3 params (`k`, `m`, `b`) |

Parameters and initial conditions are randomised per episode.
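
For concreteness, here is a minimal sketch of how one such noisy trajectory can be generated with `scipy`. The parameter ranges, time grid, and noise scale are illustrative assumptions, not the environment's actual values:

```python
import numpy as np
from scipy.integrate import odeint

def damped_spring(state, t, k, m, b):
    # d2x/dt2 = -(k/m)*x - (b/m)*dx, written as a first-order system
    x, v = state
    return [v, -(k / m) * x - (b / m) * v]

rng = np.random.default_rng(0)
k, m, b = rng.uniform(0.5, 5.0, size=3)    # randomised parameters (illustrative range)
x0, v0 = rng.uniform(-1.0, 1.0, size=2)    # randomised initial conditions
t = np.linspace(0.0, 10.0, 200)            # illustrative time grid

clean = odeint(damped_spring, [x0, v0], t, args=(k, m, b))
noisy = clean + rng.normal(scale=0.02, size=clean.shape)  # observation noise
```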

The SFT warm-start used all six Tier 1–2 systems (the three above plus Free Fall with Drag, Damped Pendulum, and Spring-Mass) to give the model broad format coverage before RL narrowed focus to the three GRPO systems.

### Episode flow

Each episode runs as follows (a minimal loop sketch follows the list):
1. `reset()` → agent receives a noisy trajectory + a one-sentence hint
2. Agent proposes an ODE in JSON: `{"equation": "...", "params": {...}}`
3. Environment simulates the hypothesis via `scipy.odeint` and computes R²
4. Agent receives a residual summary in English + a numeric reward breakdown
5. Repeat up to 8 turns, or until R² ≥ 0.93
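
Sketched as code, the loop looks roughly like this. `PhysiXEnvironment`, `agent`, and the exact `reset()`/`step()` signatures are hypothetical stand-ins inferred from the description above, not the repo's actual API:

```python
import json

env = PhysiXEnvironment(system="damped_spring")  # hypothetical constructor
obs = env.reset()                                # noisy trajectory + one-sentence hint

for turn in range(8):                            # at most 8 turns
    completion = agent.generate(obs)             # LLM proposes an ODE hypothesis
    hypothesis = json.loads(completion)          # {"equation": "...", "params": {...}}
    obs, reward, done = env.step(hypothesis)     # odeint simulation + R² inside
    if done:                                     # solved (R² >= 0.93) or episode over
        break
```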

---

## Reward Design

All reward is computed from `scipy.odeint` — no model-in-the-loop scoring. Five components:

**R²** (coefficient of determination): R² = 1 is a perfect match, R² = 0 means no better than predicting the mean, R² < 0 is actively wrong.
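
For reference, this is the standard definition (whether the environment clips negative values is not stated here):

$$
R^2 = 1 - \frac{\sum_i \left(y_i - \hat{y}_i\right)^2}{\sum_i \left(y_i - \bar{y}\right)^2}
$$

where $y_i$ are the observed trajectory values, $\hat{y}_i$ the simulated ones, and $\bar{y}$ the observed mean.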

| Component | Formula | What it rewards | Why it's needed |
|-----------|---------|-----------------|-----------------|
| `match` | R² | Continuous fit quality | Primary learning signal |
| `match_dense` | √R² | Same signal, stretched | R² ≈ 0 early on; √R² gives a non-zero gradient (√0.05 ≈ 0.22) so GRPO isn't blind in early steps |
| `correctness` | 1 if R² ≥ 0.70 else 0 | Binary "good enough" | Creates a cliff the policy climbs; helps escape plateaus where the dense signal flattens |
| `simplicity` | 1 − (operator count)/12, gated on R² ≥ 0.10 | Shorter equations | Without the gate, `d2y/dt2 = 0` scores simplicity = 1 for free despite being completely wrong |
| `format` | 1 if the output parses **and** `odeint` succeeds without NaN | Valid, simulatable output | Without the NaN check, explosive equations like `exp(vy**10)` claim the format reward |
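
Putting the table together, a runnable sketch of the verifier is below. The thresholds (0.70, 0.10, the /12 normaliser) come from the table; the NaN handling details and the clipping of √R² for negative R² are illustrative assumptions:

```python
import numpy as np

def reward_components(y_true: np.ndarray, y_pred: np.ndarray,
                      n_operators: int, parsed_ok: bool) -> dict:
    """Sketch of the five verifier components from the table above."""
    # format: must parse AND simulate without NaN (blocks the parse-but-crash hack)
    if not parsed_ok or np.isnan(y_pred).any():
        return {k: 0.0 for k in ("format", "match", "match_dense",
                                 "correctness", "simplicity")}

    ss_res = float(np.sum((y_true - y_pred) ** 2))
    ss_tot = float(np.sum((y_true - y_true.mean()) ** 2))
    r2 = 1.0 - ss_res / ss_tot

    return {
        "format": 1.0,
        "match": r2,                                  # raw R², primary signal
        "match_dense": float(np.sqrt(max(r2, 0.0))),  # stretched early-gradient signal
        "correctness": 1.0 if r2 >= 0.70 else 0.0,    # binary "good enough" cliff
        # gate blocks the trivial-equation hack (e.g. d2y/dt2 = 0)
        "simplicity": (1.0 - n_operators / 12.0) if r2 >= 0.10 else 0.0,
    }
```

The point of the design is that every component is computed from the simulation itself, so there is no judge model to fool.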

---

## Training: SFT → GRPO

### Why SFT first

Qwen2.5-3B is a small, cold model: out of the box, ~80% of its completions are LaTeX, prose, or malformed JSON that the verifier can't parse. GRPO needs *variance in reward* across rollouts to estimate advantages; if every rollout scores ~0 because nothing parses, the gradient is zero and nothing learns.
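
Concretely, the group-relative advantage GRPO assigns to rollout $i$ among $G$ generations of the same prompt is (standard formulation):

$$
\hat{A}_i = \frac{r_i - \operatorname{mean}(r_1, \dots, r_G)}{\operatorname{std}(r_1, \dots, r_G)}
$$

If all $G$ rollouts score the same — for instance, everything fails to parse — every $\hat{A}_i$ is zero and the policy update vanishes.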

SFT on 384 synthetic `(prompt, ground_truth)` pairs from the environment (64 per system × 6 systems) teaches the model the output format before RL begins: 4 epochs, ~5 min on an L40S (a minimal training sketch follows).
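
A minimal warm-start sketch with TRL, assuming recent `trl`/`datasets` versions; `pairs` is a hypothetical list holding the 384 exported examples, and the column names are illustrative:

```python
from datasets import Dataset
from trl import SFTConfig, SFTTrainer

# pairs: hypothetical list of {"prompt": ..., "ground_truth": ...} dicts
# exported from the environment (64 per system × 6 systems = 384).
dataset = Dataset.from_list(
    [{"text": ex["prompt"] + ex["ground_truth"]} for ex in pairs]
)

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-3B-Instruct",
    train_dataset=dataset,
    args=SFTConfig(num_train_epochs=4, output_dir="sft-warmstart"),
)
trainer.train()
```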

After SFT, >90% of completions parse and simulate successfully — GRPO now has a signal to work with.

### SFT results

| SFT Loss (↓) |
|:---:|
| ![SFT loss](plots/sft_loss.png) |

Loss drops from 2.66 → 1.68 over 4 epochs. The smooth, monotonic descent confirms the format is being learned reliably, with no signs of overfitting.

### GRPO

- **Model:** Qwen/Qwen2.5-3B-Instruct + LoRA (rank 32)
- **Systems:** free_fall, simple_pendulum, damped_spring
- **Steps:** 200 (with early stopping on reward convergence)
- **LR:** 1e-5
- **Generations:** 4 per prompt
- **Framework:** Unsloth + TRL `GRPOTrainer` (a minimal configuration sketch follows the list)
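
A configuration sketch mirroring the hyperparameters above, assuming a recent TRL version; `reward_fn` stands in for the verifier rewards from the Reward Design section, and `verifier_score` and `prompt_dataset` are hypothetical names:

```python
from trl import GRPOConfig, GRPOTrainer

def reward_fn(completions, **kwargs):
    # Placeholder: score each completion with the scipy.odeint verifier.
    return [verifier_score(c) for c in completions]  # verifier_score: hypothetical

config = GRPOConfig(
    output_dir="physix-grpo",
    learning_rate=1e-5,     # the LR that gave a smooth rising curve
    max_steps=200,
    num_generations=4,      # rollouts per prompt for advantage estimation
)

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-3B-Instruct",
    reward_funcs=reward_fn,
    args=config,
    train_dataset=prompt_dataset,  # prompts exported from the environment
)
trainer.train()
```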

Learning-rate selection:
- `3e-6` → near-zero gradient for 67+ steps
- `3e-5` → hit its ceiling too fast (~250 steps)
- `1e-5` → smooth, steadily rising curve ✓

---

## Results

| SFT Loss (↓) | GRPO Loss (↓) |
|:---:|:---:|
| ![SFT loss](plots/sft_loss.png) | ![GRPO loss](plots/loss.png) |

| GRPO Total Reward (↑) |
|:---:|
| ![reward](plots/reward.png) |

| Per-component reward breakdown |
|:---:|
| ![reward components](plots/reward_components.png) |
113
+ Key observations:
114
+ - `reward_format` jumps to ~0.9 in the first 10 steps — the SFT warm-start establishes output format immediately
115
+ - `reward_match_dense` (√R²) and `reward_correctness` climb from ~0.6 → ~0.95 over 200 steps
116
+ - `reward_match` (raw R²) converges to ~0.95+ by step 150
117
+ - Total mean reward rises from ~3.3 → ~4.8 (+45%) with ±1σ variance shrinking
118
+ - GRPO loss near zero is **expected** — it's the KL regularisation term; the real signal is the reward curves
119
+
120
+ Full runs: [wandb.ai/pratyush01/physix-live](https://wandb.ai/pratyush01/physix-live)

---

## What's Novel

1. **Verifiable reward without a judge** — R² from `scipy.odeint` is ground truth, not a proxy
2. **Iterative refinement loop** — the environment returns residual summaries in English so the agent can reason about what went wrong
3. **Reward hacking case study** — three exploits found and patched during development (parse-but-crash, trivial-equation simplicity, progress-signal duplication)
4. **SFT → GRPO pipeline** — shows how a cold 3B model can be made RL-trainable in under 10 minutes of SFT

---

## Links

- **Live demo:** https://huggingface.co/spaces/Pratyush-01/physix-live
- **Trained model:** https://huggingface.co/Pratyush-01/physix-3b-rl
- **Checkpoint repo:** https://huggingface.co/Pratyush-01/physix-3b-rl-ckpt
- **Training notebook:** https://huggingface.co/spaces/Pratyush-01/physix-live/blob/main/train/physix_train_colab.ipynb
- **W&B project:** https://wandb.ai/pratyush01/physix-live