---
title: PhysiX
emoji: ⚛️
colorFrom: blue
colorTo: purple
sdk: docker
app_port: 7860
pinned: false
license: mit
short_description: Equation discovery from noisy trajectories (RLVR)
tags:
  - openenv
  - rlvr
  - physics
  - equation-discovery
  - ode
---

# PhysiX — Equation Discovery via RLVR

An [OpenEnv](https://github.com/openenv-hackathon/openenv) hackathon submission (Apr 2026).

Given a noisy trajectory and a one-sentence hint, a language model iteratively proposes and refines an ODE that reproduces the observed motion. Reward comes entirely from `scipy.integrate.odeint` + per-step R² — no LLM-as-judge.

---

## Links

| | |
|---|---|
| **Live demo (HF Space)** | https://huggingface.co/spaces/Pratyush-01/physix-live |
| **Trained model** | https://huggingface.co/Pratyush-01/physix-3b-rl |
| **Colab training notebook** | https://huggingface.co/spaces/Pratyush-01/physix-live/blob/main/train/physix_train_colab.ipynb |
| **W&B training runs** | https://wandb.ai/pratyush01/physix-live |
| **Blog post / writeup** | https://huggingface.co/spaces/Pratyush-01/physix-live/blob/main/docs/blog.md |
| **Checkpoint repo** | https://huggingface.co/Pratyush-01/physix-3b-rl-ckpt |

---

## Environment

### Physical Systems

Three physical systems, each with randomised parameters and initial conditions per episode. These are the systems the published model is trained and evaluated on.

| System | Ground-truth equation | Notes |
|--------|-----------------------|-------|
| Free Fall | `d2y/dt2 = -g` | 1 parameter |
| Simple Pendulum | `d2theta/dt2 = -(g/L)*sin(theta)` | transcendental |
| Damped Spring | `d2x/dt2 = -(k/m)*x - (b/m)*dx` | damped oscillation |

The verifier and DSL are designed to extend cleanly to richer dynamics (coupled oscillators, Lorenz, reaction–diffusion, N-body). Adding a system is a matter of writing a `PhysicalSystem` subclass and registering it — no changes to the reward, parser, or training loop. See [docs/blog.md → Future Scope](docs/blog.md) for the extensibility plan.
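As a rough illustration of what "writing a `PhysicalSystem` subclass" involves, here is a hedged sketch of a hypothetical driven pendulum. The attribute and method names (`state_vars`, `rhs`, `sample_params`, etc.) are assumptions inferred from the repo layout, not the real ABC interface:

```python
import numpy as np

class DrivenPendulum:  # would subclass physix.systems.base.PhysicalSystem
    """Hypothetical new system: pendulum with a sinusoidal driving torque."""
    state_vars = ("theta", "dtheta")
    param_names = ("g", "L", "A", "w")
    equation = "d2theta/dt2 = -(g/L)*sin(theta) + A*cos(w*t)"

    def rhs(self, y, t, p):
        # First-order form consumed by scipy.odeint
        theta, dtheta = y
        return [dtheta,
                -(p["g"] / p["L"]) * np.sin(theta) + p["A"] * np.cos(p["w"] * t)]

    def sample_params(self, rng):
        # Randomised per-episode parameters
        return {"g": 9.81,
                "L": rng.uniform(0.5, 2.0),
                "A": rng.uniform(0.0, 1.0),
                "w": rng.uniform(0.5, 3.0)}
```

Because the reward only sees the simulated trajectory, nothing else needs to change once the subclass is registered.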

### Example Task

```
HINT: Mass on a spring, displaced 0.40 m and released.
      Visible amplitude decay over a few seconds.

TRAJECTORY (t, x, dx):
  t=0.00  x= 0.40  dx= 0.00
  t=0.50  x= 0.27  dx=-0.69
  t=1.00  x=-0.06  dx=-0.84
  t=1.50  x=-0.30  dx=-0.32
  t=2.00  x=-0.30  dx= 0.34
  t=3.00  x= 0.10  dx= 0.41
  t=5.00  x=-0.04  dx=-0.10   ← amplitude visibly shrinking

STATS: x_range=[-0.31, 0.40]  decay_envelope≈e^(-0.18 t)
```

Target output:

```json
{
  "equation": "d2x/dt2 = -(k/m) * x - (b/m) * dx",
  "params": {"k": 4.0, "m": 1.0, "b": 0.36},
  "rationale": "Linear restoring force plus velocity-proportional damping; envelope decay matches b/(2m)."
}
```

**Grammar:** operators `+ - * / **`, functions `sin cos tan exp log sqrt abs`, declared state variables and parameter names. Anything outside this scores `format = 0`.
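A whitelisted grammar like this can be enforced with SymPy by parsing against an explicit symbol table and rejecting anything undeclared. The sketch below is illustrative; the real parser lives in `physix/verifier/parser.py` and may differ:

```python
import sympy as sp

# Only the functions named in the grammar are resolvable.
ALLOWED_FUNCS = {name: getattr(sp, name)
                 for name in ("sin", "cos", "tan", "exp", "log", "sqrt")}
ALLOWED_FUNCS["abs"] = sp.Abs

def parse_rhs(expr_str, state_vars, param_names):
    """Parse an ODE right-hand side; raise if it leaves the grammar."""
    local_dict = dict(ALLOWED_FUNCS)
    for name in list(state_vars) + list(param_names):
        local_dict[name] = sp.Symbol(name)
    expr = sp.parse_expr(expr_str, local_dict=local_dict)
    # parse_expr auto-creates symbols for unknown names; catch them here.
    unknown = {str(s) for s in expr.free_symbols} - set(state_vars) - set(param_names)
    if unknown:
        raise ValueError(f"undeclared symbols: {sorted(unknown)}")
    return expr

rhs = parse_rhs("-(k/m) * x - (b/m) * dx", ["x", "dx"], ["k", "m", "b"])
```

Anything that fails this check would score `format = 0` before simulation is even attempted.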

### Episode Flow

```mermaid
sequenceDiagram
    participant Agent
    participant Env as PhysiXEnvironment
    participant Sim as scipy.odeint
    participant Verifier

    Env->>Agent: reset() → trajectory + hint
    loop up to 8 turns
        Agent->>Env: step(equation + params + rationale)
        Env->>Sim: simulate hypothesis from t=0
        Sim-->>Verifier: predicted trajectory
        Verifier-->>Env: r_match + r_progress + r_simplicity + r_format
        Env->>Agent: mismatch summary + reward breakdown + history
        alt r_match > 0.93 or budget exhausted
            Env-->>Agent: done=True
        end
    end
```

After each step the agent receives an English mismatch summary (e.g. *"predicted vy diverges after t=2 s; residual consistently negative"*) alongside the numeric reward breakdown, so it has something to act on in the next turn.
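A minimal sketch of how such a summary could be produced from residuals (illustrative only; the real implementation is `physix/verifier/mismatch.py`):

```python
import numpy as np

def mismatch_summary(t, observed, predicted, var="x"):
    """Turn residuals into a short English hint for the next turn."""
    if not np.all(np.isfinite(predicted)):
        return f"predicted {var} becomes non-finite"
    resid = predicted - observed
    tol = 0.10 * (observed.max() - observed.min())  # 10% of observed range
    msgs = []
    bad = np.flatnonzero(np.abs(resid) > tol)
    if bad.size:
        msgs.append(f"predicted {var} diverges after t={t[bad[0]]:.1f} s")
    if resid.mean() < -tol / 2:
        msgs.append("residual consistently negative")
    elif resid.mean() > tol / 2:
        msgs.append("residual consistently positive")
    return "; ".join(msgs) or f"{var} matches within tolerance"
```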

---

## Reward

All reward is computed from `scipy.odeint` output — no model-in-the-loop scoring.

### Step-wise (live env + GRPO)

| Component | Weight | Formula | Purpose |
|-----------|:------:|---------|---------|
| `match` | 0.50 | R² (observed vs. predicted) | primary accuracy signal |
| `progress` | 0.20 | `max(0, r_match − r_match_prev)` | per-turn improvement shaping |
| `simplicity` | 0.20 | `1 − (operator_count / 12)` | prefer shorter equations |
| `format` | 0.10 | 1 if parsed **and** simulated successfully | syntactic + numerical validity |
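The weighted sum can be sketched as follows (weights taken from the table; note the live verifier additionally applies the reward-hacking gates described later in this README):

```python
def step_reward(r2, r2_prev, operator_count, parsed_and_simulated):
    """Compose the step-wise reward from its four components."""
    r_match = max(0.0, r2)                          # R², clipped at 0
    r_progress = max(0.0, r_match - r2_prev)        # per-turn improvement
    r_simplicity = max(0.0, 1 - operator_count / 12)
    r_format = 1.0 if parsed_and_simulated else 0.0
    return (0.50 * r_match + 0.20 * r_progress
            + 0.20 * r_simplicity + 0.10 * r_format)
```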

### GRPO-only additions

Two extra signals are added during training but not used in the live env:

- **`match_dense = sqrt(R²)`** — gives a non-trivial gradient when raw R² is near zero (e.g. `sqrt(0.05) ≈ 0.22`).
- **`correctness` = 1 if R² ≥ 0.70, else 0** — a binary bonus that helps push past R² plateaus where the dense signal flattens.
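Both signals are simple transforms of R²; a sketch:

```python
import math

def match_dense(r2):
    # sqrt stretches near-zero R² into a usable gradient signal
    return math.sqrt(max(0.0, r2))

def correctness(r2):
    # binary bonus once the fit crosses the plateau threshold
    return 1.0 if r2 >= 0.70 else 0.0

match_dense(0.05)  # ≈ 0.22, vs. a raw signal of 0.05
```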

### Reward-hacking mitigations

Three failure modes found during development and how they were closed:

**1. Parse-but-crash exploit.**
A valid-but-explosive equation (e.g. `d2y/dt2 = exp(vy**10)`) parses but makes `odeint` produce NaN. Without a fix, it earns `format = 1`.  
→ `format = 1` only if integration completes without NaN/inf.

**2. Trivial-equation exploit.**
`d2y/dt2 = 0` has zero operators, so `simplicity = 1`, earning 20% reward for a completely wrong trajectory.  
→ `simplicity = 0` unless `r_match ≥ 0.10`.

**3. Progress signal in single-turn GRPO.**
Every GRPO training row starts with `previous_r_match = 0`, so `progress = r_match` — a redundant copy of the match signal that dilutes advantage estimates.  
→ `progress` is excluded from the GRPO reward function set; it is only used in multi-turn live episodes.
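The fix for mitigation 1 amounts to checking the integrator's output for finiteness, not just that parsing succeeded. A hedged sketch (the real check lives in `physix/verifier/reward.py`):

```python
import numpy as np
from scipy.integrate import odeint

def format_reward(rhs, y0, t):
    """Return 1.0 only if the candidate ODE integrates to a finite trajectory."""
    try:
        with np.errstate(all="ignore"):  # silence overflow warnings from bad RHS
            pred = odeint(rhs, y0, t)
    except Exception:
        return 0.0
    return 1.0 if np.all(np.isfinite(pred)) else 0.0

t = np.linspace(0.0, 5.0, 50)
damped = lambda y, t: [y[1], -4.0 * y[0] - 0.36 * y[1]]
format_reward(damped, [0.4, 0.0], t)  # well-behaved → 1.0
```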

---

## Training: SFT → GRPO

### Why SFT first

GRPO relies on reward variance across rollouts to estimate advantages. With a cold base model, ~80% of completions are unparseable (LaTeX, prose, malformed JSON) and most parseable ones crash the integrator, leaving near-zero variance and no useful gradient. The model needs to produce the right output format before RL can do anything meaningful with the physics signal.

SFT runs for 3 epochs on synthetic `(prompt, ground_truth_equation)` pairs generated from the environment. After SFT:
- >90% of completions parse and simulate successfully (up from ~20%).
- Equations are in the ASCII ODE grammar the verifier expects.
- The model has seen the right equation family for each system at least once.

SFT only establishes format. Parameter values are still wrong — that is what GRPO refines.
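Pair generation could look roughly like this. The observation field names (`prompt`, `ground_truth_equation`, `ground_truth_params`) are hypothetical placeholders, not the real API:

```python
import json

def build_sft_pairs(env, system_id, n, seed0=0):
    """Build n (prompt, target-JSON) pairs from environment resets."""
    pairs = []
    for i in range(n):
        obs = env.reset(system_id=system_id, seed=seed0 + i)
        target = json.dumps({
            "equation": obs.ground_truth_equation,   # hypothetical field
            "params": obs.ground_truth_params,       # hypothetical field
            "rationale": "ground-truth dynamics for this system",
        })
        pairs.append({"prompt": obs.prompt, "completion": target})
    return pairs
```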

### Step 1 — SFT warm-start

```bash
python -m physix.training.sft \
  --model Qwen/Qwen2.5-3B-Instruct \
  --output-dir runs/physix-3b-sft \
  --epochs 3 \
  --lora-r 32 \
  --instances-per-system 32 \
  --system-ids damped_spring
```

Runtime: ~5 min on L40S.

### Step 2 — GRPO

```bash
python -m physix.training.loop \
  --model Qwen/Qwen2.5-3B-Instruct \
  --output-dir runs/physix-3b-rl \
  --num-steps 200 \
  --num-generations 4 \
  --lora-r 32 \
  --sft-checkpoint runs/physix-3b-sft/merged \
  --system-ids damped_spring \
  --push-to-hub \
  --hub-repo-id Pratyush-01/physix-3b-rl
```

Runtime: ~45 min on L40S.

### Full cloud job

```bash
hf jobs uv run train/job_train_single.py \
    --image unsloth/unsloth:2026.3.8-pt2.9.0-vllm-0.16.0-cu12.8-studio-release \
    --flavor l40sx1 \
    --secrets HF_TOKEN \
    --secrets WANDB_API_KEY \
    -v hf://datasets/Pratyush-01/physix-live-src:/physix-live \
    --timeout 2h
```

---

## Training Results

| GRPO Loss (↓) | Total Reward (↑) |
|:---:|:---:|
| ![loss](docs/plots/loss.png) | ![reward](docs/plots/reward.png) |

W&B runs: [pratyush01/physix-live](https://wandb.ai/pratyush01/physix-live)

Key observations from the run:
- Total mean reward rises from ~3.3 to ~4.8 (+45%) over 200 steps with ±1σ variance shrinking — the policy is both improving and becoming more consistent.
- The SFT warm-start gets format compliance high from step 1, so GRPO spends its budget improving R² rather than relearning JSON syntax.

---

## Repository Layout

```
physix-live/
├── physix/
│   ├── models.py                 # Pydantic Action / Observation / State
│   ├── client.py                 # OpenEnv WebSocket client
│   ├── systems/                  # physical systems (3 trained, exposed via SUPPORTED_SYSTEMS)
│   │   ├── base.py               # PhysicalSystem ABC
│   │   ├── tier1.py              # FreeFall, SimplePendulum (+ extras for future work)
│   │   ├── tier2.py              # DampedSpring (+ extras for future work)
│   │   ├── tier3.py              # placeholders for future extensions, not exposed
│   │   └── registry.py
│   ├── verifier/
│   │   ├── parser.py             # SymPy whitelisted grammar
│   │   ├── simulator.py          # scipy.odeint forward simulation
│   │   ├── metrics.py            # per-step R²
│   │   ├── mismatch.py           # English residual summary
│   │   └── reward.py             # reward composition + hacking mitigations
│   ├── server/
│   │   ├── environment.py        # PhysiXEnvironment (OpenEnv subclass)
│   │   ├── interactive.py        # session-based REST router for the UI
│   │   └── app.py
│   └── training/
│       ├── prompt.py             # observation → prompt
│       ├── scorer.py             # cached single-completion scorer
│       ├── reward_fns.py         # TRL-compatible reward callables
│       ├── dataset.py            # GRPO dataset builder
│       ├── sft.py                # SFT warm-start
│       └── loop.py               # Unsloth + TRL GRPO loop
├── frontend/                     # React + TS + Tailwind demo UI
├── train/                        # HF Jobs launcher + Colab notebook
│   ├── submit.py                 # submit job via HfApi.run_uv_job
│   ├── job_train.py              # multi-system driver (in-container)
│   ├── job_train_single.py       # single-system driver (in-container)
│   ├── physix_train_colab.ipynb  # SFT → GRPO end-to-end notebook
│   └── sync-plots.sh             # mirror plots from model repo
├── tests/                        # ~30 tests
├── docs/
│   ├── plots/                    # committed loss / reward / per-component PNGs
│   └── writeup.md
├── Dockerfile                    # env Space build (FastAPI + built React UI)
├── openenv.yaml                  # OpenEnv manifest (name, runtime, app entrypoint)
└── pyproject.toml
```

---

## Quick Start

One command from the repo root:

```bash
make dev
```

This starts the FastAPI backend on `:8000` (deps auto-resolved by `uv`) and the Vite frontend on `:5173`. Open [http://localhost:5173](http://localhost:5173).

### Connecting an LLM

The demo speaks to **any OpenAI-compatible `/v1/chat/completions` endpoint** — local Ollama, Hugging Face Inference Providers, OpenAI, vLLM, OpenRouter, etc. The "Connect an LLM" panel exposes:

| Field | Purpose |
|-------|---------|
| **Endpoint** | Preset dropdown. Picks `base_url` + a default model id. |
| **Model** | Provider-native id (HF repo, Ollama tag, OpenAI name). Free-form. |
| **Custom base URL** | Shown when `Custom` is selected. Anything ending in `/v1`. |
| **API key** | Bearer token. Persisted per `base_url` in `localStorage`, never sent unless an episode runs. |

Server-side env-var fallback (lets a deployed Space ship a sensible default without leaking secrets in the bundle):

| URL family | Env var |
|---|---|
| `*huggingface*` | `HF_TOKEN`, then `HUGGINGFACE_API_KEY` |
| `*openai.com*` | `OPENAI_API_KEY` |
| `*openrouter*` | `OPENROUTER_API_KEY` |
| `localhost` / `127.0.0.1` | none (Ollama needs no key) |
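The fallback logic from the table amounts to matching the `base_url` against a few URL families; a sketch (illustrative, the server-side code may differ):

```python
import os

def resolve_api_key(base_url):
    """Pick a server-side env-var key based on the endpoint's URL family."""
    url = base_url.lower()
    if "huggingface" in url:
        return os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_API_KEY")
    if "openai.com" in url:
        return os.environ.get("OPENAI_API_KEY")
    if "openrouter" in url:
        return os.environ.get("OPENROUTER_API_KEY")
    # localhost / 127.0.0.1 (Ollama) and unknown hosts: no key
    return None
```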

### Side-by-side comparison

The default page is a **two-column comparison**: same trajectory, same hint, same seed, same verifier — two different models. The presets are wired to make the headline story self-evident:

- **A** = `Pratyush-01/physix-3b-rl` via HF Inference Providers (the GRPO-trained model)
- **B** = `Qwen/Qwen2.5-3B-Instruct` via HF Inference Providers (untrained baseline)

Drop in `gpt-4o-mini` on either side as a frontier reference, or swap to local Ollama for offline dev. The reward delta between the two columns is exactly what GRPO bought — no benchmark prose necessary.

> **For the trained model on HF Inference Providers**: weights are public, but the repo card needs `inference: true` and a serving provider (Featherless/Together/etc.) to have it loaded. If a visitor sees a 404 from the trained side, they can either bring up `ollama serve` locally and pull a quantised version, or fall back to `Qwen/Qwen2.5-3B-Instruct` on both sides.

### Programmatic use

```python
import asyncio
from physix import PhysiXEnv, PhysiXAction

async def main():
    async with PhysiXEnv(base_url="http://127.0.0.1:8000") as env:
        obs = await env.reset(system_id="damped_spring", seed=42)
        result = await env.step(
            PhysiXAction(
                equation="d2x/dt2 = -(k/m) * x - (b/m) * dx",
                params={"k": 4.0, "m": 1.0, "b": 0.36},
            )
        )
        print(result.observation.reward_breakdown)

asyncio.run(main())
```

Run the test suite with:

```bash
pytest tests/
```

---

## License

MIT.