Algorithm Distillation: grader v2 with belief_accuracy + SFT pipeline

Grader changes:
- Add belief_accuracy term (weight 0.20) to _grade_episode
- env.record_belief() for callers to register the agent's emitted belief (sketched below)
- Redistribute weights: crash 0.20→0.15, progress 0.25→0.20, conn 0.15→0.10, adapt 0.30→0.25
- Heuristic + random baselines still emit no belief, so they score 0 on this component
- Validates: gpt-5.4 teacher beats heuristic 0.617 vs 0.449 on 30/30 episodes
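A minimal sketch of the new grading hook, not the repo's exact code: attribute and argument names (`_last_belief`, `true_profile_vec`) are hypothetical; the real logic lives in `_grade_episode` in server/rhythm_environment.py.

```python
# Sketch only: illustrates the record_belief contract and the 1 - MAE term.
def record_belief(env, belief):
    """Caller side: register the agent's latest emitted belief (values in [0, 1])."""
    env._last_belief = [max(0.0, min(1.0, float(b))) for b in belief]

def belief_accuracy(last_belief, true_profile_vec):
    """Grader side: 1 - mean absolute error vs the true profile; 0 if never emitted."""
    if last_belief is None:
        return 0.0  # heuristic / random baselines never call record_belief
    mae = sum(abs(b - t) for b, t in zip(last_belief, true_profile_vec)) / len(true_profile_vec)
    return 1.0 - mae

# _grade_episode then folds this in with weight 0.20:
#   final_score += 0.20 * belief_accuracy(env._last_belief, true_vec)
```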
Distillation pipeline:
- scripts/generate_teacher_trajectories.py: Azure OpenAI teacher rollouts
- scripts/reeval_teacher_trajectories.py: offline re-score under v2 grader
- scripts/upload_teacher_data.py: push trajectories to HF Hub dataset
- training/sft_prime.py: TRL SFTTrainer fine-tune of Qwen 2.5-3B on teacher data (sketched below)
- scripts/sft_on_hf.py: HF Jobs orchestrator for SFT prime stage
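The SFT-prime stage is not reproduced here; this is a minimal TRL sketch of the same shape. The dataset id is illustrative (standing in for whatever repo scripts/upload_teacher_data.py pushes to), and the config is deliberately bare.

```python
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

# Hypothetical dataset id; the real one is created by upload_teacher_data.py.
dataset = load_dataset("InosLihka/teacher-trajectories", split="train")

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-3B-Instruct",   # student base model
    train_dataset=dataset,              # teacher rollouts, re-scored under grader v2
    args=SFTConfig(output_dir="outputs/sft_prime"),
)
trainer.train()
```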
Parser + prompt:
- extract_action_and_belief now handles both single-line and CoT-prefixed formats
  (takes the LAST "<digit> <digit> <digit> <ACTION>" match; see the regex sketch below)
- SYSTEM_PROMPT in dataset.py unified with teacher's CoT-asking prompt
- inference_eval.py records belief during model rollouts
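How the last-match rule behaves, as a sketch: the action vocabulary and any scaling of the three belief digits are repo-specific, so the regex and return shape here are illustrative only.

```python
import re

# Matches "<digit> <digit> <digit> <ACTION>" anywhere in the completion.
LINE_RE = re.compile(r"(\d)\s+(\d)\s+(\d)\s+([A-Z_]+)")

def extract_action_and_belief(completion: str):
    matches = LINE_RE.findall(completion)
    if not matches:
        return None, None            # format-invalid completion
    s, m, w, action = matches[-1]    # CoT may contain earlier candidates; keep the LAST
    return action, [int(s), int(m), int(w)]

# Works for both "3 7 5 DEEP_WORK" and a CoT block that ends in that line.
```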
Cleanup (separate task, bundled here):
- Removed env_reward_simple, --use_simple_reward flag, --profile_mode flag
- Removed 'discrete' profile_mode random branch (was unreachable)
- Deduped heuristic_action (now lives in dataset.py only)
- Promoted pipeline_dryrun.py → tests/test_pipeline_smoke.py
- Deleted: docs/logdump.txt, scripts/analyze_logdump.py, diagnostic_replay.py,
legacy blog_post.md, eval_results_v1.json, references/ subdir
- 50 tests pass (45 existing + 5 new belief grader tests)

Files changed:
- .env.example +11 -0
- .gitignore +7 -0
- blog_post.md +0 -90
- docs/architecture.md +26 -21
- docs/entity_definitions.md +23 -4
- docs/iterations.md +113 -5
- docs/logdump.txt +0 -0
- docs/references/React Orchestrator Linkedin/V1_ReACT_based_Orchestrator.ipynb +0 -0
- docs/references/React Orchestrator Linkedin/V2_ReACT_Based_Orchestrator.ipynb +0 -0
- docs/references/React Orchestrator Linkedin/o3_mini_V1_ReACT_Based_Orchestrator.ipynb +0 -0
- docs/references/React Orchestrator Linkedin/o3_mini_v2_improved_react_orchestrator.ipynb +0 -0
- eval_baselines_v2.json +12 -0
- eval_results_v1.json +0 -758
- models.py +2 -1
- scripts/analyze_logdump.py +0 -124
- scripts/diagnostic_replay.py +0 -90
- scripts/generate_teacher_trajectories.py +506 -0
- scripts/pipeline_dryrun.py +0 -121
- scripts/reeval_teacher_trajectories.py +154 -0
- scripts/sft_on_hf.py +168 -0
- scripts/upload_teacher_data.py +105 -0
- server/rhythm_environment.py +91 -57
- tests/test_pipeline_smoke.py +235 -0
- tests/test_rhythm_env.py +98 -6
- training/dataset.py +66 -49
- training/inference_eval.py +15 -35
- training/reward_functions.py +86 -196
- training/sft_prime.py +230 -0
- training/train.py +10 -22
.env.example
@@ -0,0 +1,11 @@
+# Azure OpenAI credentials for teacher trajectory generation.
+# Copy this file to .env and fill in your real values:
+#   cp .env.example .env
+#   then edit .env to put in your actual key
+#
+# .env is in .gitignore so it never gets committed.
+
+AZURE_OPENAI_ENDPOINT=https://metahackathon-resource.cognitiveservices.azure.com/
+AZURE_OPENAI_API_KEY=PASTE_YOUR_KEY_HERE
+AZURE_OPENAI_DEPLOYMENT=gpt-5.4
+AZURE_OPENAI_API_VERSION=2024-12-01-preview
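For context, a sketch of how these variables would be consumed on the script side, assuming python-dotenv plus the openai SDK's AzureOpenAI client; the actual wiring lives in scripts/generate_teacher_trajectories.py.

```python
import os
from dotenv import load_dotenv
from openai import AzureOpenAI

load_dotenv()  # reads .env from the working directory

client = AzureOpenAI(
    azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
    api_key=os.environ["AZURE_OPENAI_API_KEY"],
    api_version=os.environ["AZURE_OPENAI_API_VERSION"],
)

# The deployment name selects the teacher model, e.g. "gpt-5.4".
completion = client.chat.completions.create(
    model=os.environ["AZURE_OPENAI_DEPLOYMENT"],
    messages=[{"role": "user", "content": "ping"}],
)
```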
.gitignore
@@ -7,3 +7,10 @@ __pycache__/
 *.egg-info/
 dist/
 build/
+
+# Local-only artifacts (not committed; uploaded to HF Hub when needed)
+data/
+iter1_results/
+iter2_results/
+iter5_results/
+outputs/
blog_post.md
@@ -1,90 +0,0 @@
-# Teaching an AI to Know You (Without Asking)
-
-Ask someone how they'd build a personal AI assistant, and they'll say: give it a personality quiz. A preferences form. Maybe a settings page where you pick "introvert" or "morning person" from a dropdown.
-
-Sounds reasonable. It's the wrong approach entirely.
-
-Think about the people who actually know you well – a close friend, a partner, a sibling. None of them sat you down with a questionnaire. They figured you out by *watching*. They noticed that you get irritable after too many social events. That you do your best thinking before noon. That skipping exercise makes you anxious by Wednesday.
-
-They learned your hidden patterns through trial, error, and feedback. RhythmEnv is an experiment in doing that with an RL agent.
-
-## Why personality can't be captured in a settings page
-
-I work on AI at Microsoft. One thing I kept running into building assistant features was the gap between what users *say* they want and what actually helps them. People are bad at introspecting their own patterns. The introvert who says "I don't mind meetings" because they've normalized the drain. The workaholic who checks "I value work-life balance" because they know they should.
-
-Preference forms capture what people believe about themselves. Behavior reveals what's actually true.
-
-So the right question isn't "can we ask better questions?" It's "can we learn without asking at all?"
-
-## What "knowing a person" actually means
-
-Here's how I decomposed it. Every person has a set of hidden traits – atomic behavioral properties that describe how they *respond* to activities, not just what they like:
-
-How much does socializing physically drain you? When does your brain work best – morning or evening? Does leisure make you feel guilty, or does it recharge you? Does progress at work give you inner peace, or just tire you out?
-
-No single trait defines a person. It's the combination. An introvert who peaks in the morning has high social drain, early cognitive peak, and solo time as their recharge mechanism. An extrovert night owl has the opposite: socializing barely costs vitality, evening is when they come alive, and being alone doesn't restore them. Same list of traits. Completely different values. Completely different person.
-
-But traits are only half of it.
-
-## The thing I couldn't solve with traits alone
-
-Two people can do the exact same activities and have completely different days. Not because the activities are different – because they *define a good day differently*.
-
-This is the second layer: hidden reward weights. A definition of what a good week means to each person.
-
-The introvert values serenity above everything else (60% of their score). A week where they maintained inner peace and made some progress is a great week. Connection barely registers. The extrovert values connection above all (75%). A week full of meaningful social interactions is a great week, even if they didn't make much career progress. The workaholic values progress above all (70%). Deep productive work is the whole point. Everything else is secondary.
-
-The agent sees the same five meters. Takes the same ten actions. Gets a scalar reward. But the reward is secretly computed using these hidden weights. Same action, same meter changes, completely different reward signal depending on who you're helping.
-
-## The environment itself
-
-RhythmEnv simulates one week in a person's life – seven days, four time slots each, 28 decisions. Each decision is an activity: deep work, exercise, sleep, meditation, family time, socializing. Ten options total.
-
-Five meters track the person's state. Picture them like fuel gauges on a dashboard. Vitality is physical energy – sleep fills it, sustained work drains it. Cognition is mental sharpness, highest in the morning for some people, evening for others. Progress is career momentum, the only meter that only goes up through work. Serenity is inner calm – meditation and rest help, overwork kills it. And Connection, the most interesting one: it decays passively every single time slot. If you don't actively socialise, it drops on its own. The agent can't ignore it and come back to it later.
-
-The hidden profile changes what these meters *mean*. Tell the introvert to socialise: their vitality drops three times faster than the base rate. Their body physically rejects it. Tell the extrovert the same: barely any drain. They could socialise all day.
-
-Tell the introvert to meditate: they get a +0.10 serenity bonus on top of the base effect. Alone time is their recharge. Tell the workaholic the same thing: their serenity *drops* by 0.10, because idle activities make them anxious.
-
-Tell the workaholic to do deep work: they recover +0.06 vitality – productive work energises them. Tell the introvert to do deep work in the morning: their progress and cognition gains are doubled. Same action, completely different physiological response.
-
-## What the agent must figure out, without being told
-
-The agent sees meters, time of day, and reward. It doesn't see which person type it's helping, or the trait values, or how the reward is being computed.
-
-After a few actions, the patterns start showing. "I socialised and my vitality crashed – this person drains from socialising." "I meditated and got a huge reward – serenity must be heavily weighted for them." "Deep work in the morning gave double progress – this person peaks early."
-
-A good agent should probe in the first few steps, infer the person type from the unexpected meter changes, then adapt its strategy for the rest of the week. An agent that discovers it's helping an introvert should meditate more and socialise less. One that discovers it's helping a workaholic should maximise productive hours and cut idle time.
-
-## The training signal
-
-Here's what makes this tractable for RL. At the same starting state – Monday morning, all meters at 0.7 – the best action is completely different per profile:
-
-| Profile | Best action | Reward | Worst action | Reward |
-|---|---|---|---|---|
-| Introvert | MEDITATE | +1.76 | SOCIALIZE | +0.03 |
-| Extrovert | FAMILY_TIME | +2.63 | ME_TIME | -0.42 |
-| Workaholic | DEEP_WORK | +1.57 | ME_TIME | -0.27 |
-
-GRPO – Group Relative Policy Optimization – generates multiple candidate actions for each state, scores them all against the real environment, then updates the model to prefer the higher-scoring ones. Think of it as the model getting to observe "if I had done X instead of Y here, the outcome would have been this" – and slowly building intuition for which choices work for which person.
-
-The model is Qwen 2.5-3B with 4-bit quantization and LoRA. Small enough to train on a free Colab T4.
-
-## What I'm watching for
-
-The rule-based heuristic baseline – fixed logic, no profile adaptation, treats everyone the same – scores around 0.76–0.82 depending on the profile. It works *despite* the hidden dynamics, not because it understands them. Sleep when vitality is low. Meditate when serenity is low. Socialise when connection drops. Reasonable advice for anyone.
-
-The goal for the trained agent isn't just higher scores. It's qualitatively different action sequences per person. The introvert's week should look nothing like the extrovert's week. The workaholic's Monday should look nothing like the introvert's Monday. If the agent is just scoring higher by exploiting a pattern that works across all profiles, that's not discovery – that's luck.
-
-No questionnaire. No settings page. Just attention, inference, and adjustment.
-
-That's what I think personal AI should actually feel like.
-
----
-
-**Links:**
-- [Live Environment (HF Space)](https://huggingface.co/spaces/InosLihka/rhythm_env)
-- [Training Notebook (Colab)](training/RhythmEnv_GRPO_Training.ipynb)
-- [Source Code](https://huggingface.co/spaces/InosLihka/rhythm_env)
-
-*Built for the Meta PyTorch OpenEnv Hackathon Grand Finale, Bangalore, April 2026.*
docs/architecture.md
@@ -10,7 +10,7 @@ profile, real numbers from the reward calculation).

 ```
 ┌──────────────────────────────────────────────┐
-│ AGENT (Qwen 2.5-3B + LoRA r=
+│ AGENT (Qwen 2.5-3B + LoRA r=8, 4-bit)        │
 │                                              │
 │ Input: prompt (state + history)              │
 │ Output: "3 7 5 DEEP_WORK"                    │
@@ -400,25 +400,30 @@ DATASET (~3000 rows, generated ONCE before training)
 │  final_score  │
 │   ∈ [0, 1]    │
 └───────┬───────┘
-        ┬───────────────
-        │
-        ▼
- ┌────────────┐
- │ crash_free │
- │  × 0.
-
-
- │
- │
- │
- │
-
-
-
-
-
-
-
+┌──────────────┬──────────────┼──────────────┬──────────────┬──────────────┐
+│              │              │              │              │              │
+▼              ▼              ▼              ▼              ▼              ▼
+┌────────────┐ ┌────────────┐ ┌────────────┐ ┌────────────┐ ┌────────────┐ ┌────────────┐
+│ crash_free │ │  progress  │ │ connection │ │ adaptation │ │ efficiency │ │   belief   │
+│   × 0.15   │ │   × 0.20   │ │   × 0.10   │ │   × 0.25   │ │   × 0.10   │ │  accuracy  │
+│            │ │            │ │            │ │            │ │            │ │   × 0.20   │
+├────────────┤ ├────────────┤ ├────────────┤ ├────────────┤ ├────────────┤ ├────────────┤
+│ 1 - crashes│ │  final P   │ │  final Cn  │ │ late-half  │ │ avg_reward │ │  1 - MAE   │
+│ /total_ck  │ │   value    │ │   value    │ │ mean reward│ │ normalized │ │  vs true   │
+│            │ │            │ │            │ │  - early   │ │ to [0,1]   │ │  profile   │
+│ e.g. 0.95  │ │ e.g. 0.42  │ │ e.g. 0.51  │ │ e.g. +0.18 │ │ e.g. 0.55  │ │ e.g. 0.80  │
+│ ×0.15=0.14 │ │ ×0.20=0.084│ │ ×0.10=0.05 │ │ ×0.25=0.045│ │ ×0.10=0.055│ │ ×0.20=0.16 │
+└────────────┘ └────────────┘ └────────────┘ └────────────┘ └────────────┘ └────────────┘
+
+Σ = 0.14 + 0.084 + 0.05 + 0.045 + 0.055 + 0.16
+  = 0.534 → final_score (with inference)
+
+Heuristic / random baselines never call env.record_belief(), so the belief
+component scores 0 for them – by design: the meta-RL skill is INFERENCE,
+and only agents that actually try get credit on this axis.
+
+Plus a sparse terminal reward (added to step 27's per-step reward):
+  terminal_bonus = (final_score - 0.5) × 5   → e.g. (0.534 - 0.5) × 5 = +0.17
 
 This means: at step 27, agent gets last per-step reward + bonus from grader.
 This is the only direct gradient signal pointing at the actual episode quality.
@@ -433,7 +438,7 @@ This is the only direct gradient signal pointing at the actual episode quality.
 
 ┌────────────────────────────┬────────────────────────────┬────────────────────────────┐
 │ discrete-3-profiles        │ continuous-in-distribution │ continuous-OOD             │
-│ (
+│ (3 reference profiles)     │ (was the agent able to    │ (does meta-policy          │
 │                            │  learn the meta-policy?)  │  generalize?)              │
 ├────────────────────────────┼────────────────────────────┼────────────────────────────┤
 │ env.reset(seed=N,          │ env.reset(seed=N)         │ env.reset(seed=10000+N)    │
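The worked example in the diagram above sums rounded per-component products; plugging the same example values in unrounded gives a slightly different total. This is just an arithmetic check, using only numbers from the diagram.

```python
# Back-of-envelope check of the diagram's worked example (grader v2 weights).
weights    = {"crash_free": 0.15, "progress": 0.20, "connection": 0.10,
              "adaptation": 0.25, "efficiency": 0.10, "belief": 0.20}
components = {"crash_free": 0.95, "progress": 0.42, "connection": 0.51,
              "adaptation": 0.18, "efficiency": 0.55, "belief": 0.80}

final_score = sum(weights[k] * components[k] for k in weights)
print(final_score)              # 0.5375 (the diagram's 0.534 sums rounded terms)
print((final_score - 0.5) * 5)  # terminal bonus ~= +0.19 (+0.17 when computed from 0.534)
```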
docs/entity_definitions.md
@@ -197,13 +197,31 @@ extrovert_night_owl: -0.39 (connection weight = 75%; deep work gives 0 connection)
 Score in [0.0, 1.0]:
 
 ```
-score = 0.
-+ 0.25 × crash_free_ratio  (1 - crash_count / total_possible_crashes)
+score = 0.15 × crash_free_ratio  (1 - crash_count / total_possible_crashes)
 + 0.20 × progress          (final progress meter value)
-+ 0.
++ 0.10 × connection        (final connection meter value)
++ 0.25 × adaptation_score  (late-half mean per-step reward minus
+                            early-half mean – gated by absolute
+                            late-half quality so a "terrible-then-
+                            mediocre" exploit cannot win)
 + 0.10 × efficiency_score  (avg step reward normalised to [0, 1])
++ 0.20 × belief_accuracy   (1 - MAE between agent's last-emitted
+                            belief vector and the true profile
+                            vector; 0 if the agent never emitted a
+                            belief – heuristic / random baselines)
 ```
 
+Two meta-RL signals: `adaptation_score` is implicit (rewards getting better
+over time, since per-step rewards are profile-weighted), and `belief_accuracy`
+is explicit (rewards INFERRING the profile correctly). Without the explicit
+term, agents that play heuristic-style "keep meters healthy" score the same
+as agents that actually do inference, since the other components don't
+differentiate inference from reflex.
+
+To emit a belief, the agent calls `env.record_belief([s, m, w])` once per
+step (typically right after parsing its own completion). The grader uses the
+LAST recorded belief.
+
 ---
 
 ## Internal Tracking Variables
@@ -216,5 +234,6 @@ Not in the observation. Used by the environment to compute rewards and grade.
 | `_rng` | Seeded random instance for event rolls and profile selection |
 | `_crash_count` | Steps where any meter fell below 0.10 |
 | `_total_reward` | Running sum of step rewards for efficiency score |
-| `
+| `_step_history` | Rolling window of completed steps (action, reward, deltas, anomalies). Used both as the agent-visible history and to compute repetition dampening. |
+| `_step_rewards` | Per-step reward list for adaptation_score in the grader |
 | `_timestep` | Current step index (0–27) |
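The adaptation term's gating, as a sketch: the hunk above specifies only the shape (late-half minus early-half, gated by absolute late-half quality), so the gate threshold and function name here are assumptions.

```python
def adaptation_score(step_rewards, late_quality_gate=0.0):
    """Late-half mean minus early-half mean, zeroed when the late half is itself poor."""
    half = len(step_rewards) // 2
    early = sum(step_rewards[:half]) / max(half, 1)
    late = sum(step_rewards[half:]) / max(len(step_rewards) - half, 1)
    if late <= late_quality_gate:   # blocks the "terrible-then-mediocre" exploit
        return 0.0
    return late - early
```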
docs/iterations.md
@@ -197,7 +197,7 @@ just not what we wanted it to do.
 
 ---
 
-## Iter 3: Align reward + restructure format (
+## Iter 3: Align reward + restructure format (CANCELLED before run – stale code, $0)
 
 **5 architectural fixes**:
 
@@ -237,7 +237,112 @@ penalty + exploration bonus + terminal supervision, the agent should:
 - Beat random in 2/3 conditions on final_score
 - Show positive (or less-negative) adaptation than baselines
 
-**Result**:
+**Result**: Iter 3 was never actually launched. Pre-flight inspection of the
+HF Space confirmed the cloned snapshot still had stale code, and a re-launched
+external review surfaced 7 deeper bugs (see Round 2 below) that needed to
+land before any further GPU spend was justified.
+
+---
+
+## Round 2 fixes (applied for iter 4+, after external bug review)
+
+External agent surfaced 7 issues that survived all prior reviews. All landed
+on `round2` branch and on the HF Space `main` before iter 4 launched:
+
+1. **Anomalies surfaced in prompt** (`StepRecord` + `format_observation_prompt`
+   + `inference.py`): per-meter anomaly signals were computed each step but
+   never made visible to the agent. Agent was supposed to learn from them.
+2. **Belief baseline subtraction** in `belief_accuracy`: reward is now
+   `similarity - constant_baseline_similarity`. The constant `5 5 5` belief
+   no longer earns a free +1/step floor.
+3. **Profile weight cap 0.80 → 0.45** in `sample_profile`. Forces every
+   sampled profile to weight 3+ meters meaningfully (originally to kill the
+   "single-meter dominant → SLEEP-spam optimal" exploit).
+4. **Scaled-down shaping** in `_compute_reward`: -0.10 / -0.15 / +0.07
+   (was -0.30 / -0.40 / +0.20). Reduces noise-floor of shaping vs. the
+   real signal layers.
+5. **Step-0 belief reward = 0**: agent has no information at step 0, so
+   penalizing belief-vs-target there just punishes initialization.
+6. **Belief-action coupling reward** (±0.15): rewards if the chosen action
+   matches the agent's emitted belief, penalizes if it contradicts. Forces
+   the belief to be *causally useful*, not decorative.
+7. **`grader_bias` moved out of `_compute_reward` into `env_reward`**:
+   keeps per-step env reward pure for inference-signal analysis. The
+   progress/connection bias still lands in the GRPO advantage, just via
+   the env-reward layer.
+
+---
+
+## Iter 4: Round 2 fixes – partial run, mistakenly cancelled (2026-04-26, ~$2.10, 235/800 steps)
+
+**Config**: a10g-large, LoRA rank 16, num_generations 8, 800 steps, all
+Round 1 + Iter 3 architectural fixes + Round 2 (above).
+
+**Hypothesis**: With anomalies in the prompt, baseline subtraction killing
+the belief-spam floor, belief-action coupling forcing causal use of belief,
+and grader_bias keeping env-reward pure, the agent should show monotonic
+belief_accuracy growth without hitting a 2-cycle hack.
+
+**What we got** (from 235-step partial – see `docs/iter4_partial_analysis.txt`):
+
+Working:
+- Total reward: -3.4 → +0.39 (climbing)
+- format_valid: -1.20 → +0.44 (slow but climbing)
+- env_reward: -2.01 → +0.44 (climbing)
+- grad_norm normalized to ~10 by step 60 from initial 36+
+- No catastrophic mode collapse
+
+Broken – the unsolved core:
+- **`belief_accuracy/mean` flat at -0.10 throughout 235 steps**
+- Linear slope: +0.0007 per 100 steps (essentially zero, well under noise)
+- Agent emits beliefs SLIGHTLY WORSE than constant baseline
+
+**Why the run ended at 235**: I cancelled the job based on stale HF API
+log output that suggested the run was stuck. The HF UI showed it was
+healthy. ~$2.10 wasted. Lesson banked: **trust the live UI over the
+`/logs` API endpoint**, which lags severely.
+
+**Root-cause hypothesis** (post-mortem analysis):
+
+The profile cap (0.80 → 0.45) and the baseline subtraction interact
+negatively. With weights clamped to ≤0.45, sampled profiles cluster
+toward balanced; `profile_to_belief_vector` (whose `work_pref` axis is
+30%-weighted on the progress reward weight) consequently lands closer to
+[0.5, 0.5, 0.5]. The constant `5 5 5` belief already has high cosine
+similarity with that target, so after baseline subtraction there is
+almost no headroom for the agent to "win" against it.
+
+**Why we missed it**:
+- The Round 2 fixes were treated as independent, but #2 (baseline
+  subtraction) and #3 (profile cap) share the same denominator – the
+  spread of the belief target distribution. An analytical check on
+  belief-target stddev under the new cap would have caught it before
+  spending compute.
+- The `grader_bias` term (#7) was the original justification for
+  needing a tighter profile cap (kill the SLEEP-spam exploit). Once
+  grader_bias was in env_reward, the cap could have been reverted.
+  We applied both fixes simultaneously.
+
+---
+
+## Iter 5: Identical fixes, smaller config (2026-04-26 05:18 UTC, RUNNING)
+
+**Config**: a10g-large, **LoRA rank 8**, **num_generations 4**, **500 steps**.
+Same fix set as iter 4 – Round 1 + Iter 3 architectural + Round 2.
+
+**Hypothesis**: With a smaller config, validate that iter 4's partial-run
+trajectory was real (climbing total reward, flat belief_accuracy) rather
+than a fluke of the cancelled-mid-run snapshot.
+
+**Expected outcome** (informed by iter 4 partial): same flat belief_accuracy
+because the underlying cap × baseline interaction is unchanged. This run
+exists to confirm the hypothesis cheaply before spending on the iter 6
+profile-cap revert.
+
+**Job**: `69eda027d70108f37acdf9a7` →
+`https://huggingface.co/jobs/InosLihka/69eda027d70108f37acdf9a7`
+
+**Result**: TBD – currently running.
 
 ---
 
@@ -247,9 +352,12 @@ penalty + exploration bonus + terminal supervision, the agent should:
 |---|---|---|---|
 | 1 | ~$0.50 | 200 | Mode collapse to single action |
 | 2 | ~$1.50 | 400 | Mode collapse to 2-cycle |
-| 3 |
-|
-|
+| 3 | $0 | – | Cancelled pre-run (stale code) |
+| 4 (a100/l40s/h200 attempts) | ~$1.50 | – | Capacity-cancelled or hardware-incompat |
+| 4 (a10g) | ~$2.10 | 235/800 | Cancelled by mistake; partial data shows flat belief_accuracy |
+| 5 (a10g) | TBD | 500 (running) | TBD |
+| **Subtotal** | **~$5.60** | | |
+| Budget | $30 | | ~$24.40 remaining |
 
 ---
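The cap × baseline interaction flagged in the post-mortem can be checked analytically, which is exactly the "Why we missed it" point. A sketch under stated assumptions: cosine similarity as the metric and a 3-dim belief layout are assumptions, not the repo's confirmed implementation.

```python
import math

def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

BASELINE = [0.5, 0.5, 0.5]                      # the constant "5 5 5" belief

def belief_reward(belief, target):
    # Round 2 fix #2: similarity minus constant-baseline similarity.
    return cosine(belief, target) - cosine(BASELINE, target)

# With profile weights capped at 0.45, targets cluster near [0.5, 0.5, 0.5],
# so even a *perfect* belief barely beats the baseline:
target = [0.45, 0.35, 0.45]
print(belief_reward(target, target))            # ~0.006 of headroom
```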
(The diffs for docs/logdump.txt and the four notebooks under docs/references/React Orchestrator Linkedin/ were too large to render – see the raw diffs.)
eval_baselines_v2.json
@@ -0,0 +1,12 @@
+{
+  "grader_version": "v2_with_belief_accuracy",
+  "note": "baselines emit no belief; belief_accuracy component is 0 for them by design",
+  "indist": {
+    "heuristic_mean": 0.4488,
+    "random_mean": 0.402
+  },
+  "ood": {
+    "heuristic_mean": 0.4539,
+    "random_mean": 0.3974
+  }
+}
eval_results_v1.json
@@ -1,758 +0,0 @@
-[
- {"profile": "introvert_morning", "strategy": "heuristic", "seed": 0, "final_score": 0.7696, "total_reward": 5.84, "vitality": 0.76, "cognition": 0.41, "progress": 1.0, "serenity": 1.0, "connection": 0.43,
-  "actions": ["deep_work", "admin_work", "meditate", "sleep", "deep_work", "admin_work", "meditate", "sleep", "deep_work", "admin_work", "meditate", "sleep", "deep_work", "admin_work", "socialize", "sleep", "deep_work", "admin_work", "meditate", "sleep", "deep_work", "admin_work", "socialize", "sleep", "deep_work", "admin_work", "meditate", "sleep"]},
- {"profile": "introvert_morning", "strategy": "heuristic", "seed": 1, "final_score": 0.7526, "total_reward": 6.88, "vitality": 0.64, "cognition": 0.33, "progress": 1.0, "serenity": 0.87, "connection": 0.43,
-  "actions": ["deep_work", "admin_work", "meditate", "sleep", "deep_work", "admin_work", "meditate", "sleep", "deep_work", "admin_work", "meditate", "sleep", "deep_work", "admin_work", "meditate", "sleep", "deep_work", "admin_work", "meditate", "sleep", "deep_work", "admin_work", "socialize", "sleep", "deep_work", "admin_work", "socialize", "sleep"]},
- {"profile": "introvert_morning", "strategy": "heuristic", "seed": 2, "final_score": 0.7723, "total_reward": 8.97, "vitality": 0.84, "cognition": 0.44, "progress": 1.0, "serenity": 1.0, "connection": 0.39,
-  "actions": ["deep_work", "admin_work", "meditate", "sleep", "deep_work", "admin_work", "meditate", "sleep", "deep_work", "admin_work", "socialize", "sleep", "deep_work", "admin_work", "socialize", "sleep", "exercise", "admin_work", "socialize", "sleep", "deep_work", "admin_work", "meditate", "sleep", "exercise", "admin_work", "meditate", "sleep"]},
- {"profile": "extrovert_night_owl", "strategy": "heuristic", "seed": 0, "final_score": 0.8197, "total_reward": 4.21, "vitality": 1.0, "cognition": 1.0, "progress": 1.0, "serenity": 1.0, "connection": 0.43,
-  "actions": ["deep_work", "learn", "meditate", "sleep", "deep_work", "learn", "meditate", "sleep", "deep_work", "admin_work", "meditate", "sleep", "deep_work", "admin_work", "socialize", "sleep", "deep_work", "admin_work", "meditate", "sleep", "deep_work", "admin_work", "meditate", "sleep", "deep_work", "admin_work", "meditate", "sleep"]},
- {"profile": "extrovert_night_owl", "strategy": "heuristic", "seed": 1, "final_score": 0.8209, "total_reward": 4.59, "vitality": 1.0, "cognition": 1.0, "progress": 0.97, "serenity": 1.0, "connection": 0.46,
-  "actions": ["deep_work", "learn", "meditate", "sleep", "deep_work", "learn", "meditate", "sleep", "deep_work", "admin_work", "meditate", "sleep", "deep_work", "admin_work", "meditate", "sleep", "deep_work", "admin_work", "meditate", "sleep", "deep_work", "admin_work", "socialize", "sleep", "deep_work", "admin_work", "meditate", "sleep"]},
- {"profile": "extrovert_night_owl", "strategy": "heuristic", "seed": 2, "final_score": 0.8164, "total_reward": 5.85, "vitality": 0.97, "cognition": 0.82, "progress": 0.93, "serenity": 1.0, "connection": 0.53,
-  "actions": ["deep_work", "learn", "meditate", "sleep", "deep_work", "learn", "meditate", "sleep", "deep_work", "learn", "socialize", "sleep", "deep_work", "admin_work", "meditate", "sleep", "deep_work", "admin_work", "meditate", "sleep", "deep_work", "admin_work", "socialize", "sleep", "deep_work", "admin_work", "meditate", "sleep"]},
- {"profile": "workaholic_stoic", "strategy": "heuristic", "seed": 0, "final_score": 0.7461, "total_reward": 11.98, "vitality": 0.59, "cognition": 0.23, "progress": 1.0, "serenity": 0.95, "connection": 0.41,
-  "actions": ["deep_work", "learn", "meditate", "sleep", "deep_work", "admin_work", "socialize", "sleep", "deep_work", "admin_work", "socialize", "sleep", "deep_work", "admin_work", "meditate", "sleep", "deep_work", "admin_work", "socialize", "sleep", "exercise", "admin_work", "socialize", "sleep", "deep_work", "admin_work", "socialize", "sleep"]},
- {"profile": "workaholic_stoic", "strategy": "heuristic", "seed": 1, "final_score": 0.7585, "total_reward": 12.67, "vitality": 0.72, "cognition": 0.32, "progress": 1.0, "serenity": 0.95, "connection": 0.38,
-  "actions": ["deep_work", "learn", "meditate", "sleep", "deep_work", "admin_work", "socialize", "sleep", "deep_work", "admin_work", "meditate", "sleep", "deep_work", "admin_work", "meditate", "sleep", "deep_work", "admin_work", "socialize", "sleep", "deep_work", "admin_work", "socialize", "sleep", "exercise", "admin_work", "socialize", "sleep"]},
- {"profile": "workaholic_stoic", "strategy": "heuristic", "seed": 2, "final_score": 0.7782, "total_reward": 13.16, "vitality": 0.74, "cognition": 0.34, "progress": 1.0, "serenity": 0.95, "connection": 0.44,
-  "actions": ["deep_work", "learn", "meditate", "sleep", "deep_work", "admin_work", "socialize", "sleep", "deep_work", "admin_work", "socialize", "sleep", "deep_work", "admin_work", "socialize", "sleep", "exercise", "admin_work", "socialize", "sleep", "exercise", "admin_work", "socialize", "sleep", "exercise", "admin_work", "socialize", "sleep"]},
- {"profile": "introvert_morning", "strategy": "random", "seed": 0, "final_score": 0.7141, "total_reward": 1.82, "vitality": 0.45, "cognition": 0.56, "progress": 0.66, "serenity": 1.0, "connection": 0.71,
-  "actions": ["socialize", "me_time", "binge_watch", "socialize", "exercise", "family_time", "sleep", "admin_work", "meditate", "binge_watch", "admin_work", "deep_work", "sleep", "meditate", "sleep", "family_time", "exercise", "deep_work", "admin_work", "meditate", "socialize", "binge_watch", "exercise", "meditate", "learn", "socialize", "admin_work", "sleep"]},
- {"profile": "introvert_morning", "strategy": "random", "seed": 1, "final_score": 0.6924, "total_reward": 4.02, "vitality": 0.08, "cognition": 0.5, "progress": 0.76, "serenity": 0.97, "connection": 0.68,
-  "actions": ["exercise", "meditate", "sleep", "meditate", "meditate", "me_time", "meditate", "learn", "meditate", "meditate", "socialize", "socialize", "deep_work", "meditate", "socialize", "deep_work", "meditate", "sleep", "learn", "socialize", "deep_work", "socialize", "learn", "sleep", "family_time", "meditate", "meditate", "admin_work"]},
- {"profile": "introvert_morning", "strategy": "random", "seed": 2, "final_score": 0.6715, "total_reward": 6.12, "vitality": 0.61, "cognition": 0.22, "progress": 1.0, "serenity": 0.86, "connection": 0.17,
-  "actions": ["me_time", "meditate", "learn", "meditate", "learn", "family_time", "deep_work", "family_time", "me_time", "admin_work", "sleep", "meditate", "sleep", "admin_work", "meditate", "me_time", "sleep", "sleep", "binge_watch", "admin_work", "deep_work", "admin_work", "admin_work", "binge_watch", "learn", "sleep", "me_time", "deep_work"]},
- {"profile": "extrovert_night_owl", "strategy": "random", "seed": 0, "final_score": 0.9368, "total_reward": 8.5, "vitality": 1.0, "cognition": 0.83, "progress": 1.0, "serenity": 1.0, "connection": 0.98,
-  "actions": ["socialize", "me_time", "binge_watch", "socialize", "exercise", "family_time", "sleep", "admin_work", "meditate", "binge_watch", "admin_work", "deep_work", "sleep", "meditate", "sleep", "family_time", "exercise", "deep_work", "admin_work", "meditate", "socialize", "binge_watch", "exercise", "meditate", "learn", "socialize", "admin_work", "sleep"]},
- {"profile": "extrovert_night_owl", "strategy": "random", "seed": 1, "final_score": 0.9054, "total_reward": 8.2, "vitality": 0.75, "cognition": 0.69, "progress": 1.0, "serenity": 0.97, "connection": 0.97,
-  "actions": ["exercise", "meditate", "sleep", "meditate", "meditate", "me_time", "meditate", "learn", "meditate", "meditate", "socialize", "socialize", "deep_work", "meditate", "socialize", "deep_work", "meditate", "sleep", "learn", "socialize", "deep_work", "socialize", "learn", "sleep", "family_time", "meditate", "meditate", "admin_work"]},
- {"profile": "extrovert_night_owl", "strategy": "random", "seed": 2, "final_score": 0.7462, "total_reward": 4.12, "vitality": 0.75, "cognition": 0.32, "progress": 1.0, "serenity": 0.95, "connection": 0.4,
-  "actions": ["me_time", "meditate", "learn", "meditate", "learn", "family_time", "deep_work", "family_time", "me_time", "admin_work", "sleep", "meditate", "sleep", "admin_work", "meditate", "me_time", "sleep", "sleep", "binge_watch", "admin_work", "deep_work", "admin_work", "admin_work", "binge_watch", "learn", "sleep", "me_time", "deep_work"]},
- {"profile": "workaholic_stoic", "strategy": "random", "seed": 0, "final_score": 0.6185, "total_reward": 4.07, "vitality": 0.4, "cognition": 0.52, "progress": 0.55, "serenity": 0.95, "connection": 0.41,
-  "actions": ["socialize", "me_time", "binge_watch", "socialize", "exercise", "family_time", "sleep", "admin_work", "meditate", "binge_watch", "admin_work", "deep_work", "sleep", "meditate", "sleep", "family_time", "exercise", "deep_work", "admin_work", "meditate", "socialize", "binge_watch", "exercise", "meditate", "learn", "socialize", "admin_work", "sleep"]},
- {"profile": "workaholic_stoic", "strategy": "random", "seed": 1, "final_score": 0.6094, "total_reward": 5.39, "vitality": 0.04, "cognition": 0.55, "progress": 0.6, "serenity": 1.0, "connection": 0.44,
-  "actions": ["exercise", "meditate", "sleep", "meditate", "meditate", "me_time", "meditate", "learn", "meditate", "meditate", "socialize", "socialize", "deep_work", "meditate", "socialize", "deep_work", "meditate", "sleep", "learn", "socialize", "deep_work", "socialize", "learn", "sleep", "family_time", "meditate", "meditate", "admin_work"]},
- {"profile": "workaholic_stoic", "strategy": "random", "seed": 2, "final_score": 0.5782, "total_reward": 7.33, "vitality": 0.4, "cognition": 0.23, "progress": 0.88, "serenity": 0.99, "connection": 0.0,
-  "actions": ["me_time", "meditate", "learn", "meditate", "learn", "family_time", "deep_work", "family_time", "me_time", "admin_work", "sleep", "meditate", "sleep", "admin_work", "meditate", "me_time", "sleep", "sleep", "binge_watch", "admin_work", "deep_work", "admin_work", "admin_work", "binge_watch", "learn", "sleep", "me_time", "deep_work"]}
-]
models.py
@@ -66,7 +66,8 @@ class StepRecord(BaseModel):
     progress_delta: float = 0.0
     serenity_delta: float = 0.0
    connection_delta: float = 0.0
-    #
+    # Per-meter anomalies: actual_delta minus expected_delta_under_neutral_profile.
+    # Surfaced to the agent in the prompt – the cleanest profile-inference signal.
     vitality_anomaly: float = 0.0
     cognition_anomaly: float = 0.0
     progress_anomaly: float = 0.0
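What the new comment describes, reduced to one line of arithmetic; the function and argument names here are hypothetical stand-ins for the env's internals.

```python
# A per-meter anomaly is the observed meter change minus what a neutral
# (trait-free) profile would have produced for the same action and slot.
def meter_anomaly(actual_delta: float, expected_delta_neutral: float) -> float:
    return actual_delta - expected_delta_neutral

# e.g. an introvert socializing: vitality drops ~3x the base rate, so
# vitality_anomaly is strongly negative – exactly the signal the prompt surfaces.
```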
scripts/analyze_logdump.py
@@ -1,124 +0,0 @@
-"""Parse logdump.txt (HF Jobs UI export) and produce a training trajectory analysis."""
-
-import json
-import re
-from pathlib import Path
-
-LOG_PATH = Path("docs/logdump.txt")
-
-# Each metric line looks like a Python dict literal – parse with eval-ish via JSON
-# (HF prints them as Python repr, so single quotes – need to handle)
-DICT_RE = re.compile(r"^\{'loss':.*\}$")
-
-
-def parse_dict_line(line: str) -> dict | None:
-    line = line.strip()
-    if not line.startswith("{") or "'loss'" not in line:
-        return None
-    # Convert Python single-quoted dict to JSON: replace ' with " (naive but works for our shape)
-    try:
-        py_dict = eval(line, {"__builtins__": {}}, {})
-        if isinstance(py_dict, dict):
-            return py_dict
-    except Exception:
-        return None
-    return None
-
-
-def main():
-    with open(LOG_PATH, encoding="utf-8") as f:
-        lines = f.readlines()
-
-    rows = []
-    for ln in lines:
-        d = parse_dict_line(ln)
-        if d is not None:
-            rows.append(d)
-
-    print(f"Parsed {len(rows)} metric rows")
-    if not rows:
-        return
-
-    # Snapshots at percentiles
-    n = len(rows)
-    snaps = [0, n // 8, n // 4, n // 2, 3 * n // 4, n - 1]
-    snaps = sorted(set(snaps))
-
-    metrics = [
-        ("loss", lambda r: r.get("loss")),
-        ("reward", lambda r: r.get("reward")),
-        ("reward_std", lambda r: r.get("reward_std")),
-        ("frac_zero_std", lambda r: r.get("frac_reward_zero_std")),
-        ("format_valid", lambda r: r.get("rewards/format_valid/mean")),
-        ("action_legal", lambda r: r.get("rewards/action_legal/mean")),
-        ("env_reward", lambda r: r.get("rewards/env_reward/mean")),
-        ("belief_accuracy", lambda r: r.get("rewards/belief_accuracy/mean")),
-        ("kl", lambda r: r.get("kl")),
-        ("compl_length", lambda r: r.get("completion_length")),
-        ("grad_norm", lambda r: r.get("grad_norm")),
-    ]
-
-    print()
-    header = f"{'metric':<18} " + " ".join(f"step~{s:>5}" for s in snaps)
-    print(header)
-    print("-" * len(header))
-    for label, getter in metrics:
-        vals = []
-        for s in snaps:
-            v = getter(rows[s])
-            vals.append(f"{v:+.3f}" if isinstance(v, (int, float)) else "-")
-        print(f"{label:<18} " + " ".join(f"{v:>10}" for v in vals))
-
-    # Compute trends: linear-fit slope per metric (eyeball trend direction)
-    import statistics
-
-    print()
-    print("=== Linear trend (slope) over the run – units per 100 steps ===")
-    n = len(rows)
-    xs = list(range(n))
-    for label, getter in metrics:
-        ys = [getter(r) for r in rows]
-        ys_clean = [y for y in ys if isinstance(y, (int, float))]
-        xs_clean = [x for x, y in zip(xs, ys) if isinstance(y, (int, float))]
-        if len(ys_clean) < 5:
-            continue
-        # Simple least-squares
-        x_mean = statistics.mean(xs_clean)
-        y_mean = statistics.mean(ys_clean)
-        num = sum((x - x_mean) * (y - y_mean) for x, y in zip(xs_clean, ys_clean))
-        den = sum((x - x_mean) ** 2 for x in xs_clean)
-        slope = (num / den) * 100 if den > 0 else 0
-        # Mean of last 20 vs mean of first 20 (more robust signal)
-        first_mean = statistics.mean(ys_clean[:20]) if len(ys_clean) >= 20 else float("nan")
-        last_mean = statistics.mean(ys_clean[-20:]) if len(ys_clean) >= 20 else float("nan")
-        delta = last_mean - first_mean
-        direction = "UP" if delta > 0.01 else ("DOWN" if delta < -0.01 else "FLAT")
-        print(f"  {label:<18} slope/100steps={slope:+.4f} first20-mean={first_mean:+.3f} last20-mean={last_mean:+.3f} delta={delta:+.3f} [{direction}]")
-
-    # Look at action-distribution implicit signal: completion_length stability
-    print()
-    print("=== Mode-collapse warning signs ===")
-    last_50 = rows[-50:] if len(rows) >= 50 else rows
-    avg_zero_std = statistics.mean(r.get("frac_reward_zero_std", 0) for r in last_50 if isinstance(r.get("frac_reward_zero_std"), (int, float)))
-    avg_reward_std = statistics.mean(r.get("reward_std", 0) for r in last_50 if isinstance(r.get("reward_std"), (int, float)))
-    print(f"Last-50 mean frac_reward_zero_std: {avg_zero_std:.2f} (1.0 = full collapse)")
-    print(f"Last-50 mean reward_std: {avg_reward_std:.3f} (≥0.3 = healthy variance)")
-
-    print()
-    print(f"Final step in log: {n} (iter 4 was canceled at ~step 235)")
-    if n >= 1:
-        last = rows[-1]
-        print()
-        print("=== Final-step metrics ===")
-        for k in [
-            "loss", "reward", "reward_std", "frac_reward_zero_std",
-            "rewards/format_valid/mean", "rewards/action_legal/mean",
-            "rewards/env_reward/mean", "rewards/belief_accuracy/mean",
-            "kl", "completion_length", "grad_norm",
-        ]:
-            v = last.get(k)
-            print(f"  {k:<35} {v}")
-
-
-if __name__ == "__main__":
-    main()

scripts/diagnostic_replay.py
@@ -1,90 +0,0 @@
-"""
-Diagnostic: verify env_reward replay matches live env.
-
-For 10 seeds, plays a random episode while recording each (action, reward).
-Then for each step independently, replays the prefix and asserts the reward
-from the replay matches the recorded reward within 1e-6.
-
-Run from rhythm_env root:
-    python scripts/diagnostic_replay.py
-"""
-
-import os
-import random
-import sys
-
-sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-
-from models import ActionType, RhythmAction
-from server.rhythm_environment import RhythmEnvironment, MAX_STEPS
-from training.reward_functions import env_reward
-
-
-def play_and_record(seed: int):
-    """Play one random episode, return list of (action_history_at_step, action_taken, reward_received)."""
-    env = RhythmEnvironment()
-    env.reset(seed=seed)
-    rng = random.Random(seed + 7777)
-    actions_so_far = []
-    records = []
-    for _ in range(MAX_STEPS):
-        action_type = rng.choice(list(ActionType))
-        history_snapshot = list(actions_so_far)
-        obs = env.step(RhythmAction(action_type=action_type))
-        records.append((history_snapshot, action_type, obs.reward))
-        actions_so_far.append(action_type.value)
-        if obs.done:
-            break
-    return records
-
-
-def main():
-    print("=" * 70)
-    print("env_reward replay diagnostic")
-    print("=" * 70)
-
-    total_steps = 0
-    mismatches = 0
-    max_diff = 0.0
-
-    for seed in range(10):
-        records = play_and_record(seed)
-        for history, action_type, recorded_reward in records:
-            # Build a fake completion that yields this exact action.
-            completions = [[{"content": action_type.value.upper()}]]
-
-            scores = env_reward(
-                completions,
-                seed=[seed],
-                step_index=[len(history)],
-                action_history=[history],
-            )
-            replay_reward = scores[0]
-            diff = abs(replay_reward - recorded_reward)
-            max_diff = max(max_diff, diff)
-            if diff > 1e-6:
-                mismatches += 1
-                if mismatches <= 3:
-                    print(
-                        f"  MISMATCH seed={seed} step={len(history)} "
-                        f"action={action_type.value} recorded={recorded_reward:.6f} "
-                        f"replay={replay_reward:.6f} diff={diff:.6f}"
-                    )
-            total_steps += 1
-
-    print()
-    print(f"Total steps checked: {total_steps}")
-    print(f"Mismatches (>1e-6): {mismatches}")
-    print(f"Max diff: {max_diff:.6e}")
-    print()
-    if mismatches == 0:
-        print("PASS: env_reward replay is deterministic and matches live env.")
-        return 0
-    else:
-        print("FAIL: env_reward replay diverges from live env.")
-        print("Likely cause: non-determinism in env (RNG state, profile selection, etc.)")
-        return 1
-
-
-if __name__ == "__main__":
-    sys.exit(main())

scripts/generate_teacher_trajectories.py
@@ -0,0 +1,506 @@
+"""
+Generate teacher trajectories for Algorithm Distillation.
+
+For each seed, plays one full RhythmEnv episode where the action at each step
+is chosen by a teacher LLM (gpt-5.4 via Azure OpenAI). The teacher is prompted
+to emit `<reasoning>...</reasoning>` followed by `S M W ACTION_NAME` on a final
+line. We parse the answer line, step the env, save the full (prompt, response,
+action, reward) tuple to JSONL, and aggregate per-episode metrics for gating.
+
+Required env vars (no secrets in code):
+    AZURE_OPENAI_ENDPOINT     e.g. https://metahackathon-resource.cognitiveservices.azure.com/
+    AZURE_OPENAI_API_KEY      your Azure OpenAI key (do NOT paste in chat)
+    AZURE_OPENAI_DEPLOYMENT   the deployment name you chose, e.g. gpt-5.4
+    AZURE_OPENAI_API_VERSION  e.g. 2024-12-01-preview (default if unset)
+
+Usage from rhythm_env root:
+
+    # Stage 1a: 30-episode validation (~$3-5)
+    python scripts/generate_teacher_trajectories.py \
+        --seeds 0-29 \
+        --output data/teacher_30ep_validation.jsonl \
+        --concurrency 3
+
+    # Stage 1b: scale to 150 episodes (~$15-20)
+    python scripts/generate_teacher_trajectories.py \
+        --seeds 0-99 \
+        --output data/teacher_150ep_indist.jsonl \
+        --concurrency 5
+    python scripts/generate_teacher_trajectories.py \
+        --seeds 10000-10049 \
+        --output data/teacher_150ep_ood.jsonl \
+        --concurrency 5
+
+The script prints PASS/FAIL gate verdicts at the end so you can decide whether
+to scale or fix the teacher prompt before spending more.
+"""
+
+import argparse
+import asyncio
+import json
+import os
+import re
+import sys
+import time
+from collections import Counter
+from pathlib import Path
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+# Load .env (repo root) before reading os.environ so credentials don't have
+# to be exported in the shell. The .env file is in .gitignore.
+try:
+    from dotenv import load_dotenv
+    _ENV_PATH = Path(__file__).resolve().parent.parent / ".env"
+    if _ENV_PATH.exists():
+        load_dotenv(_ENV_PATH)
+except ImportError:
+    pass  # dotenv not installed – fall back to whatever's in the shell
+
+from openai import AsyncAzureOpenAI
+from openai import APIError, RateLimitError, APIConnectionError, APITimeoutError
+
+from models import ActionType, RhythmAction
+from server.rhythm_environment import (
+    MAX_STEPS,
+    RhythmEnvironment,
+)
+from training.dataset import format_observation_prompt
+
+
+# ---------------------------------------------------------------------------
+# Teacher system prompt
+# ---------------------------------------------------------------------------
+# The student will eventually be SFT'd to match this contract: emit a
+# <reasoning>...</reasoning> block then a final answer line `S M W ACTION_NAME`.
+# Keep this in sync with whatever SYSTEM_PROMPT the SFT'd student will use.
+TEACHER_SYSTEM_PROMPT = """You are a life-management agent helping a person whose preferences are HIDDEN.
+You see 5 life meters and a rolling history of recent steps. The same action
+affects different people differently – you must INFER who you're helping from
+rewards, meter changes, and per-meter ANOMALY signals.
+
+Each step, do TWO things:
+
+1. Reason briefly about what the observations imply about the person.
+   Focus on:
+   - Anomalies (actual delta vs neutral-profile expectation): big positive
+     social_serenity / connection responses -> high S; big morning cognition
+     gains -> high M; productive work giving vitality back -> high W
+   - Current meter state: any meter under 0.15 needs urgent recovery
+   - What action best fits BOTH the inferred profile and the current state
+
+2. Output your final answer on the LAST line in this exact format:
+   S M W ACTION_NAME
+   where S, M, W are belief digits 0-9 (0=low, 9=high) representing your best
+   estimate of social_pref, morning_pref, work_pref. ACTION_NAME is one of:
+   DEEP_WORK, ADMIN_WORK, LEARN, SLEEP, EXERCISE, MEDITATE, FAMILY_TIME,
+   SOCIALIZE, ME_TIME, BINGE_WATCH
+
+Wrap your reasoning in <reasoning>...</reasoning> tags. Keep reasoning under
+120 tokens. The final answer line MUST be the last line of your response.
+
+Belief->action quick reference:
+- High S (extrovert): SOCIALIZE, FAMILY_TIME boost connection cheaply
+- High M (morning person): DEEP_WORK / LEARN in early slots gets bonus cognition
+- High W (workaholic): DEEP_WORK, LEARN drive progress and may energize
+- Low S (introvert): MEDITATE, ME_TIME for solo recharge; avoid SOCIALIZE
+- Low M (night owl): DEEP_WORK / LEARN in evening/night slots
+- Watch crashes: any meter under 0.10 = -0.30 penalty per crashed meter
+- Connection decays passively – actively maintain via SOCIALIZE/FAMILY_TIME
+- Don't repeat the same action 3+ times in a row – repetition penalty applies
+
+Strategy: probe varied actions in the first ~5 steps to gather profile evidence,
+then exploit your sharpened belief by picking actions that match the inferred
+profile + current meter state.
+
+Example output:
+<reasoning>
+Last step's socialize gave V-0.12 (anom -0.06, much worse than neutral) -> high
+social drain, suggests low S. Morning DEEP_WORK earlier gave bonus cognition
+(anom +0.04) -> high M. Vitality at 0.6 still ok, serenity dropping. With low S +
+high M, MEDITATE is the recovery play that fits.
+</reasoning>
+2 8 5 MEDITATE"""
+
+
+# ---------------------------------------------------------------------------
+# Answer parsing – find the LAST `S M W ACTION_NAME` pattern in the response
+# ---------------------------------------------------------------------------
+VALID_ACTIONS = [at.value.upper() for at in ActionType]
+ANSWER_PATTERN = re.compile(
+    r'(\d)\s+(\d)\s+(\d)\s+(' + '|'.join(VALID_ACTIONS) + r')\b',
+    re.IGNORECASE,
+)
+
+
+def parse_teacher_response(text: str):
+    """Extract (action_type, belief_vector, raw_match) from teacher output.
+
+    Returns (None, None, None) if no answer line is parseable.
+    """
+    if not text:
+        return None, None, None
+    matches = list(ANSWER_PATTERN.finditer(text))
+    if not matches:
+        return None, None, None
+    last = matches[-1]
+    s, m, w, action_name = last.groups()
+    try:
+        belief = [int(s) / 9.0, int(m) / 9.0, int(w) / 9.0]
+        action = ActionType(action_name.lower())
+        return action, belief, last.group(0)
+    except (ValueError, KeyError):
+        return None, None, None
+
+
+# ---------------------------------------------------------------------------
+# Async API calls with retry
+# ---------------------------------------------------------------------------
+async def call_teacher(
+    client: AsyncAzureOpenAI,
+    deployment: str,
+    user_prompt: str,
+    temperature: float = 0.5,
+    max_completion_tokens: int = 400,
+    max_retries: int = 4,
+) -> str:
+    """Call the teacher with retries on transient errors. Returns response text."""
+    last_err: Exception | None = None
+    for attempt in range(max_retries):
+        try:
+            resp = await client.chat.completions.create(
+                model=deployment,
+                messages=[
+                    {"role": "system", "content": TEACHER_SYSTEM_PROMPT},
+                    {"role": "user", "content": user_prompt},
+                ],
+                temperature=temperature,
+                max_completion_tokens=max_completion_tokens,
+            )
+            return resp.choices[0].message.content or ""
+        except (RateLimitError, APIConnectionError, APITimeoutError) as e:
+            last_err = e
+            wait = min(60, 2 ** attempt)
+            await asyncio.sleep(wait)
+        except APIError as e:
+            # Non-transient API error – log and bail (don't waste retries)
+            last_err = e
+            break
+    raise RuntimeError(f"Teacher call failed after {max_retries} retries: {last_err}")
+
+
+# ---------------------------------------------------------------------------
+# Episode rollout
+# ---------------------------------------------------------------------------
+async def play_episode(
+    client: AsyncAzureOpenAI,
+    deployment: str,
+    seed: int,
+    temperature: float = 0.5,
+) -> tuple[list[dict], dict]:
+    """Run a full episode with the teacher. Returns (per-step rows, summary)."""
+    env = RhythmEnvironment()
+    obs = env.reset(seed=seed)
+    true_belief = env.get_belief_target()
+    profile_name = env.state.profile_name
+
+    step_rows: list[dict] = []
+    actions_taken: list[str] = []
+    rewards: list[float] = []
+    final_belief: list[float] | None = None
+
+    for step_idx in range(MAX_STEPS):
+        if obs.done:
+            break
+
+        user_prompt = format_observation_prompt(obs)
+        try:
+            teacher_resp = await call_teacher(client, deployment, user_prompt,
+                                              temperature=temperature)
+        except RuntimeError as e:
+            # Hard failure – abort this episode rather than corrupt the dataset
+            return step_rows, {
+                "seed": seed,
+                "profile_name": profile_name,
+                "true_belief": [round(x, 3) for x in true_belief],
+                "final_belief": [round(x, 3) for x in final_belief] if final_belief else None,
+                "belief_mae": None,
+                "final_score": 0.0,
+                "total_reward": round(sum(rewards), 2),
+                "n_steps": len(step_rows),
+                "actions": actions_taken,
+                "action_distribution": dict(Counter(actions_taken)),
+                "n_parse_failures": sum(1 for r in step_rows if r["parse_failed"]),
+                "aborted": True,
+                "error": str(e),
+            }
+
+        action, belief, raw_match = parse_teacher_response(teacher_resp)
+        parse_failed = action is None
+        if parse_failed:
+            # Fallback: SLEEP keeps the episode alive without skewing exploration
+            action = ActionType.SLEEP
+            belief = [0.5, 0.5, 0.5]
+        else:
+            final_belief = belief
+
+        # Tell the env about the emitted belief so the grader's belief_accuracy
+        # component scores it. Without this call, final_score logged below is
+        # artificially low (belief component scores 0 even when the teacher
+        # actually emitted a belief).
+        env.record_belief(belief)
+
+        rhythm_action = RhythmAction(action_type=action)
+        actions_taken.append(action.value)
+        next_obs = env.step(rhythm_action)
+        rewards.append(next_obs.reward)
+
+        step_rows.append({
+            "seed": seed,
+            "step": step_idx,
+            "profile_name": profile_name,
+            "user_prompt": user_prompt,
+            "teacher_response": teacher_resp,
+            "parsed_action": action.value,
+            "parsed_belief": belief,
+            "answer_match": raw_match,
+            "env_reward": round(next_obs.reward, 4),
+            "parse_failed": parse_failed,
+            "true_belief": [round(x, 3) for x in true_belief],
+        })
+
+        obs = next_obs
+
+    final_score = obs.reward_breakdown.get("final_score", 0.0)
+    belief_mae = (
+        sum(abs(b - t) for b, t in zip(final_belief, true_belief)) / 3.0
+        if final_belief is not None else None
+    )
+
+    return step_rows, {
+        "seed": seed,
+        "profile_name": profile_name,
+        "true_belief": [round(x, 3) for x in true_belief],
+        "final_belief": [round(x, 3) for x in final_belief] if final_belief else None,
+        "belief_mae": round(belief_mae, 4) if belief_mae is not None else None,
+        "final_score": round(final_score, 4),
+        "total_reward": round(sum(rewards), 2),
+        "n_steps": len(step_rows),
+        "actions": actions_taken,
+        "action_distribution": dict(Counter(actions_taken)),
+        "n_parse_failures": sum(1 for r in step_rows if r["parse_failed"]),
+        "aborted": False,
+    }
+
+
+# ---------------------------------------------------------------------------
+# Resume helpers
+# ---------------------------------------------------------------------------
+def already_completed_seeds(jsonl_path: Path) -> set[int]:
+    """Seeds whose final step (MAX_STEPS - 1 = 27) is already in the file."""
+    if not jsonl_path.exists():
+        return set()
+    seed_max_step: dict[int, int] = {}
+    with open(jsonl_path) as f:
+        for line in f:
+            try:
+                row = json.loads(line)
+            except json.JSONDecodeError:
+                continue
+            sd = row.get("seed")
+            st = row.get("step", -1)
+            if sd is None:
+                continue
+            seed_max_step[sd] = max(seed_max_step.get(sd, -1), st)
+    return {s for s, mx in seed_max_step.items() if mx >= MAX_STEPS - 1}
+
+
+# ---------------------------------------------------------------------------
+# Main
+# ---------------------------------------------------------------------------
+def parse_seed_arg(seed_str: str) -> list[int]:
+    if "-" in seed_str and "," not in seed_str:
+        lo, hi = seed_str.split("-")
+        return list(range(int(lo), int(hi) + 1))
+    return [int(s.strip()) for s in seed_str.split(",") if s.strip()]
+
+
+async def main() -> None:
+    parser = argparse.ArgumentParser(description=__doc__.split("\n\n")[0])
+    parser.add_argument("--seeds", type=str, required=True,
+                        help="Seed range '0-29' or comma list '0,1,5'")
+    parser.add_argument("--output", type=str, required=True,
+                        help="Output JSONL path for per-step trajectories")
+    parser.add_argument("--summary", type=str, default=None,
+                        help="Output JSON path for episode summaries (default: <output>.summary.json)")
+    parser.add_argument("--concurrency", type=int, default=3,
+                        help="Episodes to run concurrently (default 3; 1500 RPM allows up to ~5)")
+    parser.add_argument("--temperature", type=float, default=0.5,
+                        help="Teacher sampling temperature (default 0.5; lower = more consistent)")
+    parser.add_argument("--no-resume", action="store_true",
+                        help="Do not skip seeds already in the output file")
+    args = parser.parse_args()
+
+    seeds = parse_seed_arg(args.seeds)
+    output_path = Path(args.output)
+    summary_path = Path(args.summary) if args.summary else output_path.with_suffix(".summary.json")
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+
+    if not args.no_resume:
+        completed = already_completed_seeds(output_path)
+        if completed:
+            print(f"Resume: {len(completed)} seeds already complete; "
+                  f"{len(seeds) - len(completed & set(seeds))} remaining of {len(seeds)}")
+            seeds = [s for s in seeds if s not in completed]
+
+    # Azure config (read from env so secrets never touch the repo)
+    try:
+        endpoint = os.environ["AZURE_OPENAI_ENDPOINT"]
+        api_key = os.environ["AZURE_OPENAI_API_KEY"]
+        deployment = os.environ["AZURE_OPENAI_DEPLOYMENT"]
+    except KeyError as e:
+        sys.exit(f"ERROR: missing env var {e}. Set AZURE_OPENAI_ENDPOINT, "
+                 f"AZURE_OPENAI_API_KEY, AZURE_OPENAI_DEPLOYMENT.")
+    api_version = os.environ.get("AZURE_OPENAI_API_VERSION", "2024-12-01-preview")
+
+    print(f"Endpoint: {endpoint}")
+    print(f"Deployment: {deployment}")
+    print(f"API version: {api_version}")
+    print(f"Seeds: {len(seeds)} (concurrency={args.concurrency}, temp={args.temperature})")
+    print(f"Output: {output_path}")
+    print(f"Summary: {summary_path}")
+    print()
+
+    if not seeds:
+        print("No seeds to process. Exiting.")
+        return
+
+    client = AsyncAzureOpenAI(
+        azure_endpoint=endpoint,
+        api_key=api_key,
+        api_version=api_version,
+    )
+
+    sem = asyncio.Semaphore(args.concurrency)
+    file_lock = asyncio.Lock()
+    summaries: list[dict] = []
+
+    async def run_one(seed: int) -> dict | None:
+        async with sem:
+            t0 = time.time()
+            print(f"  [seed {seed}] starting", flush=True)
+            try:
+                step_rows, summary = await play_episode(client, deployment, seed,
+                                                        temperature=args.temperature)
+            except Exception as e:
+                print(f"  [seed {seed}] CRASHED: {e}", flush=True)
+                return None
+            # Append per-step rows atomically (prevents interleaved writes)
+            async with file_lock:
+                with open(output_path, "a") as f:
+                    for row in step_rows:
+                        f.write(json.dumps(row) + "\n")
+            dt = time.time() - t0
+            mae_str = f"{summary['belief_mae']:.3f}" if summary['belief_mae'] is not None else "n/a"
+            print(f"  [seed {seed}] done in {dt:.1f}s: "
+                  f"final={summary['final_score']:.3f} mae={mae_str} "
+                  f"unique_actions={len(summary['action_distribution'])} "
+                  f"parse_fails={summary['n_parse_failures']}", flush=True)
+            return summary
+
+    tasks = [run_one(s) for s in seeds]
+    results = await asyncio.gather(*tasks)
+    summaries = [r for r in results if r is not None]
+
+    # Merge with any prior summaries (for resume)
+    prior_summaries: list[dict] = []
+    if summary_path.exists() and not args.no_resume:
+        try:
+            with open(summary_path) as f:
+                prior_summaries = json.load(f).get("episodes", [])
+        except (json.JSONDecodeError, KeyError):
+            prior_summaries = []
+    seen = {s["seed"] for s in summaries}
+    summaries = summaries + [s for s in prior_summaries if s["seed"] not in seen]
+
+    # Aggregate
+    n = len(summaries)
+    if n == 0:
+        print("No episodes completed.")
+        return
+
+    valid = [s for s in summaries if not s.get("aborted")]
+    avg_score = sum(s["final_score"] for s in valid) / max(len(valid), 1)
+    valid_mae = [s["belief_mae"] for s in valid if s["belief_mae"] is not None]
+    avg_mae = sum(valid_mae) / len(valid_mae) if valid_mae else None
+    all_actions: Counter = Counter()
+    for s in valid:
+        all_actions.update(s["action_distribution"])
+    n_unique = len(all_actions)
+    n_parse_fails = sum(s["n_parse_failures"] for s in valid)
+    n_aborted = sum(1 for s in summaries if s.get("aborted"))
+
+    summary_blob = {
+        "n_episodes": n,
+        "n_aborted": n_aborted,
+        "avg_final_score": round(avg_score, 4),
+        "avg_belief_mae": round(avg_mae, 4) if avg_mae is not None else None,
+        "n_unique_actions_overall": n_unique,
+        "action_distribution_overall": dict(all_actions),
+        "n_parse_failures_total": n_parse_fails,
+        "deployment": deployment,
+        "api_version": api_version,
+        "episodes": summaries,
+    }
+    with open(summary_path, "w") as f:
+        json.dump(summary_blob, f, indent=2)
+
+    # Gates
+    BAR_HEURISTIC = 0.587
+    BAR_GATE_SCORE = 0.65
+    BAR_GATE_MAE = 0.20
+    BAR_GATE_ACTIONS = 6
+
+    print()
+    print("=" * 72)
+    print("BATCH SUMMARY")
+    print("=" * 72)
+    print(f"Episodes completed: {n} (aborted: {n_aborted})")
+    print(f"Avg final_score: {avg_score:.4f} "
+          f"(heuristic baseline: {BAR_HEURISTIC}, random: 0.516)")
+    if avg_mae is not None:
+        print(f"Avg belief MAE: {avg_mae:.4f} (lower is better)")
+    print(f"Unique actions: {n_unique} of 10")
+    print(f"Parse failures: {n_parse_fails} (across all step calls)")
+    print()
+    print("VALIDATION GATES:")
+    g_score = avg_score >= BAR_GATE_SCORE
+    g_mae = avg_mae is not None and avg_mae < BAR_GATE_MAE
+    g_actions = n_unique >= BAR_GATE_ACTIONS
+    g_parse = n_parse_fails < 0.05 * n * MAX_STEPS  # < 5% parse failure rate
+    print(f"  [{'PASS' if g_score else 'FAIL'}] avg_final_score >= {BAR_GATE_SCORE}: "
+          f"{avg_score:.3f}")
+    mae_disp = f"{avg_mae:.3f}" if avg_mae is not None else "n/a"
+    print(f"  [{'PASS' if g_mae else 'FAIL'}] avg_belief_mae < {BAR_GATE_MAE}: {mae_disp}")
+    print(f"  [{'PASS' if g_actions else 'FAIL'}] unique_actions >= {BAR_GATE_ACTIONS}: "
+          f"{n_unique}")
+    print(f"  [{'PASS' if g_parse else 'FAIL'}] parse_failures < 5% of calls: "
+          f"{n_parse_fails}/{n * MAX_STEPS}")
+    print()
+    if g_score and g_mae and g_actions and g_parse:
+        print("ALL GATES PASS – safe to scale to production batch.")
+    else:
+        print("ONE OR MORE GATES FAILED – investigate before scaling.")
+        if not g_score:
+            print("  -> Teacher quality too low. Consider escalating model "
+                  "(e.g. gpt-5-pro) or refining the prompt.")
+        if not g_mae:
+            print("  -> Teacher's beliefs aren't tracking the true profile. "
+                  "Check anomaly visibility in observation prompt.")
+        if not g_actions:
+            print("  -> Teacher converged on a narrow action set. Encourage "
+                  "exploration in the prompt.")
+        if not g_parse:
+            print("  -> Many responses didn't end with the answer pattern. "
+                  "Strengthen format instruction in the system prompt.")
+
+
+if __name__ == "__main__":
+    asyncio.run(main())

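One property of the answer parser above worth calling out: ANSWER_PATTERN takes the LAST match, so digits that happen to appear inside the <reasoning> block cannot shadow the final answer line. A quick illustration using the file's own parse_teacher_response, with an invented response:

    resp = (
        "<reasoning>anom +0.04 on morning cognition suggests high M; "
        "serenity is dropping.</reasoning>\n"
        "2 8 5 MEDITATE"
    )
    action, belief, raw = parse_teacher_response(resp)
    # action == ActionType.MEDITATE
    # belief == [2/9, 8/9, 5/9], i.e. roughly [0.222, 0.889, 0.556]
    # raw == "2 8 5 MEDITATE"
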
scripts/pipeline_dryrun.py
@@ -1,121 +0,0 @@
-"""
-Pipeline dry-run: validate the full meta-RL reward stack end-to-end without GPU.
-
-Generates a small dataset, synthesizes completions of varying quality
-(random valid, perfect, garbage, action-only, action+belief, etc.), and
-runs all 4 reward functions. Reports score distributions and prompt sizes.
-
-This is the local Gate 2 smoke check – proves the dataset, parser, and
-reward stack are internally consistent before kicking off real training.
-
-Run from rhythm_env root:
-    python scripts/pipeline_dryrun.py
-"""
-
-import os
-import random
-import sys
-
-sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-
-from training.dataset import generate_dataset
-from training.reward_functions import (
-    extract_action_and_belief,
-    format_valid,
-    action_legal,
-    env_reward,
-    belief_accuracy,
-)
-from models import ActionType
-from server.rhythm_environment import sample_profile, profile_to_belief_vector
-
-
-def synth_completion(prompt_seed: int, kind: str) -> str:
-    """Synthesize a completion of a given quality."""
-    rng = random.Random(prompt_seed)
-    actions = list(ActionType)
-    action_str = rng.choice(actions).value.upper()
-    s, m, w = rng.randint(0, 9), rng.randint(0, 9), rng.randint(0, 9)
-    if kind == "perfect":
-        # Perfect belief means matching the profile
-        true = profile_to_belief_vector(sample_profile(prompt_seed))
-        s = round(true[0] * 9)
-        m = round(true[1] * 9)
-        w = round(true[2] * 9)
-        return f"{action_str} {s} {m} {w}"
-    if kind == "good":
-        return f"{action_str} {s} {m} {w}"
-    if kind == "action_only":
-        return action_str
-    if kind == "garbage":
-        return "I don't know what to do here"
-    if kind == "verbose":
-        return f"My choice is {action_str} with belief {s} {m} {w} based on the rewards I see."
-    if kind == "wrong_belief":
-        # Output opposite of true belief
-        true = profile_to_belief_vector(sample_profile(prompt_seed))
-        s = round((1 - true[0]) * 9)
-        m = round((1 - true[1]) * 9)
-        w = round((1 - true[2]) * 9)
-        return f"{action_str} {s} {m} {w}"
-    return action_str
-
-
-def main():
-    print("=" * 70)
-    print("Pipeline Dry-Run")
-    print("=" * 70)
-
-    # Generate a small dataset (continuous profiles, 10% hint)
-    samples = generate_dataset(
-        num_episodes=20,
-        strategy="mixed",
-        max_samples=80,
-        profile_mode="continuous",
-        hint_fraction=0.1,
-    )
-
-    # Prompt size analysis
-    sizes = [len(s["prompt"][0]["content"]) + len(s["prompt"][1]["content"]) for s in samples]
-    print(f"\nPrompt sizes (chars): min={min(sizes)}, max={max(sizes)}, mean={sum(sizes)/len(sizes):.0f}")
-    # Rough token estimate: ~4 chars per token
-    print(f"Estimated tokens: min={min(sizes)//4}, max={max(sizes)//4}, mean={sum(sizes)//len(sizes)//4}")
-
-    print("\n" + "=" * 70)
-    print("Reward distributions across completion kinds")
-    print("=" * 70)
-
-    # For each completion kind, generate completions for first N samples
-    kinds = ["perfect", "good", "action_only", "garbage", "verbose", "wrong_belief"]
-    n = 30
-
-    sub = samples[:n]
-    seeds_col = [s["seed"] for s in sub]
-    history_col = [s["action_history"] for s in sub]
-    mode_col = [s["profile_mode"] for s in sub]
-
-    print(f"\n{'kind':<14} | {'fmt':>6} {'leg':>6} {'env':>6} {'bel':>6} | {'TOTAL':>6}")
-    print("-" * 60)
-    for kind in kinds:
-        completions = [[{"content": synth_completion(s["seed"], kind)}] for s in sub]
-        f_scores = format_valid(completions)
-        l_scores = action_legal(completions)
-        e_scores = env_reward(completions, seed=seeds_col, action_history=history_col, profile_mode=mode_col)
-        b_scores = belief_accuracy(completions, seed=seeds_col, action_history=history_col, profile_mode=mode_col)
-        f_avg = sum(f_scores) / len(f_scores)
-        l_avg = sum(l_scores) / len(l_scores)
-        e_avg = sum(e_scores) / len(e_scores)
-        b_avg = sum(b_scores) / len(b_scores)
-        total = f_avg + l_avg + e_avg + b_avg
-        print(f"{kind:<14} | {f_avg:+6.2f} {l_avg:+6.2f} {e_avg:+6.2f} {b_avg:+6.2f} | {total:+6.2f}")
-
-    print()
-    print("Expected ordering (best -> worst total):")
-    print("  perfect > good > wrong_belief, action_only > verbose > garbage")
-    print("If `perfect > wrong_belief`, the belief signal is gradient-providing.")
-    print("If `good > action_only`, format_valid pushes toward emitting belief.")
-    print("If `garbage` is most negative, format penalty is doing its job.")
-
-
-if __name__ == "__main__":
-    main()

scripts/reeval_teacher_trajectories.py
@@ -0,0 +1,154 @@
+"""
+Re-evaluate existing teacher trajectories under the NEW grader (with belief_accuracy term).
+
+For each episode in the JSONL:
+  1. Replay the env from the recorded seed
+  2. Step with the recorded action sequence
+  3. Call env.record_belief(parsed_belief) at each step (using the LAST step's
+     belief for the grader)
+  4. Read final_score (now under new grader)
+
+Also runs heuristic + random baselines on the same seeds for comparison, so
+we can directly answer: does the teacher beat heuristic+random under the new
+grader, and by how much?
+
+Usage:
+    python scripts/reeval_teacher_trajectories.py --jsonl data/teacher_30ep_validation.jsonl
+"""
+
+import argparse
+import json
+import os
+import random
+import sys
+from collections import defaultdict
+from pathlib import Path
+from statistics import mean, median, stdev
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from models import ActionType, RhythmAction
+from server.rhythm_environment import RhythmEnvironment, MAX_STEPS
+from training.dataset import heuristic_action
+
+
+def replay_with_beliefs(seed: int, action_seq: list[str], belief_seq: list[list[float]]) -> dict:
+    """Replay an episode with given actions+beliefs, return final_score breakdown."""
+    env = RhythmEnvironment()
+    obs = env.reset(seed=seed)
+    for action_name, belief in zip(action_seq, belief_seq):
+        if obs.done:
+            break
+        env.record_belief(belief)
+        obs = env.step(RhythmAction(action_type=ActionType(action_name)))
+    final_score = obs.reward_breakdown.get("final_score", 0.0)
+    return {
+        "final_score": final_score,
+        "vitality": obs.vitality,
+        "cognition": obs.cognition,
+        "progress": obs.progress,
+        "serenity": obs.serenity,
+        "connection": obs.connection,
+    }
+
+
+def play_heuristic(seed: int) -> float:
+    env = RhythmEnvironment()
+    obs = env.reset(seed=seed)
+    for _ in range(MAX_STEPS):
+        if obs.done:
+            break
+        action = heuristic_action(obs)
+        obs = env.step(RhythmAction(action_type=action))
+    return obs.reward_breakdown.get("final_score", 0.0)
+
+
+def play_random(seed: int) -> float:
+    env = RhythmEnvironment()
+    obs = env.reset(seed=seed)
+    rng = random.Random(seed + 12345)
+    actions = list(ActionType)
+    for _ in range(MAX_STEPS):
+        if obs.done:
+            break
+        action = rng.choice(actions)
+        obs = env.step(RhythmAction(action_type=action))
+    return obs.reward_breakdown.get("final_score", 0.0)
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--jsonl", type=str, required=True,
+                        help="Teacher trajectories JSONL")
+    args = parser.parse_args()
+
+    # Group rows by seed
+    by_seed: dict[int, list[dict]] = defaultdict(list)
+    with open(args.jsonl) as f:
+        for line in f:
+            row = json.loads(line)
+            by_seed[row["seed"]].append(row)
+    for seed in by_seed:
+        by_seed[seed].sort(key=lambda r: r["step"])
+
+    seeds = sorted(by_seed.keys())
+    print(f"Loaded {len(seeds)} episodes from {args.jsonl}\n")
+
+    teacher_scores = []
+    heuristic_scores = []
+    random_scores = []
+
+    for seed in seeds:
+        rows = by_seed[seed]
+        action_seq = [r["parsed_action"] for r in rows]
+        belief_seq = [r["parsed_belief"] for r in rows]
+
+        result = replay_with_beliefs(seed, action_seq, belief_seq)
+        teacher_scores.append(result["final_score"])
+        heuristic_scores.append(play_heuristic(seed))
+        random_scores.append(play_random(seed))
+
+    def stats(vs, label):
+        return f"{label:<10} mean={mean(vs):.4f} median={median(vs):.4f} std={stdev(vs):.4f} min={min(vs):.4f} max={max(vs):.4f}"
+
+    print("=" * 78)
+    print("FINAL_SCORE UNDER NEW GRADER (belief_accuracy weighted 0.20)")
+    print("=" * 78)
+    print(stats(teacher_scores, "teacher:"))
+    print(stats(heuristic_scores, "heuristic:"))
+    print(stats(random_scores, "random:"))
+
+    teacher_avg = mean(teacher_scores)
+    heur_avg = mean(heuristic_scores)
+    rand_avg = mean(random_scores)
+    margin_h = teacher_avg - heur_avg
+    margin_r = teacher_avg - rand_avg
+
+    print()
+    print(f"teacher - heuristic: {margin_h:+.4f}")
+    print(f"teacher - random: {margin_r:+.4f}")
+    print()
+
+    # Per-episode comparison
+    teacher_beats_heur = sum(1 for t, h in zip(teacher_scores, heuristic_scores) if t > h)
+    print(f"Episodes where teacher > heuristic: {teacher_beats_heur} / {len(seeds)}")
+
+    # Verdict
+    print()
+    print("=" * 78)
+    print("VERDICT")
+    print("=" * 78)
+    GATE = 0.05  # teacher must beat heuristic by at least 5pts on average
+    if margin_h >= GATE:
+        print(f"PASS – teacher beats heuristic by {margin_h:.3f} (>= {GATE} required)")
+        print("  New grader differentiates inference from reflex. Proceed to scale + SFT.")
+    elif margin_h > 0:
+        print(f"WEAK – teacher beats heuristic by only {margin_h:.3f} (need {GATE}+)")
+        print("  Consider raising belief_accuracy weight or improving teacher prompt.")
+    else:
+        print(f"FAIL – teacher does NOT beat heuristic ({margin_h:.3f})")
+        print("  Either grader weights are wrong or teacher is genuinely weak.")
+
+
+if __name__ == "__main__":
+    main()

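The offline re-score above leans on the same determinism the deleted diagnostic_replay.py used to verify: resetting with the same seed and replaying the same actions and beliefs must reproduce the same final_score. A spot check using this file's own helper (the action and belief values are invented):

    r1 = replay_with_beliefs(0, ["sleep", "socialize"], [[0.5, 0.5, 0.5]] * 2)
    r2 = replay_with_beliefs(0, ["sleep", "socialize"], [[0.5, 0.5, 0.5]] * 2)
    assert r1["final_score"] == r2["final_score"]  # same seed -> same score
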
scripts/sft_on_hf.py
@@ -0,0 +1,168 @@
+# /// script
+# requires-python = ">=3.10"
+# dependencies = [
+#     "torch",
+#     "transformers==4.56.2",
+#     "trl==0.22.2",
+#     "datasets",
+#     "peft",
+#     "accelerate",
+#     "bitsandbytes",
+#     "unsloth",
+#     "openenv-core",
+#     "fastapi",
+#     "uvicorn",
+#     "pydantic",
+#     "huggingface_hub",
+# ]
+# ///
+"""
+HF Jobs orchestrator for SFT prime stage.
+
+Submits the SFT prime training as an HF Jobs run. Clones the rhythm_env
+HF Space, downloads the teacher trajectory JSONL files from a HF dataset
+or model repo, runs training/sft_prime.py, and uploads the SFT'd model.
+
+Submit from local with:
+    hf jobs uv run --flavor a10g-large --secrets HF_TOKEN \\
+        -e TEACHER_DATA_REPO=InosLihka/rhythm-env-teacher-trajectories \\
+        -e MODEL_REPO_SUFFIX=sft-primed \\
+        -e EPOCHS=2 \\
+        -d scripts/sft_on_hf.py
+
+Cost on a10g-large at $1.50/hr: ~$2-3 for ~30-45 min training.
+"""
+
+import json
+import os
+import shutil
+import subprocess
+import sys
+from pathlib import Path
+
+REPO_URL = os.environ.get("REPO_URL", "https://huggingface.co/spaces/InosLihka/rhythm_env")
+WORK_DIR = "/tmp/rhythm_env"
+OUTPUT_DIR = "/tmp/rhythm_env/outputs/rhythm-env-sft-primed"
+
+# Teacher trajectory data must be uploaded to a HF dataset/model repo before
+# this job runs (HF Jobs containers don't have access to local files). The
+# repo should contain the teacher_*.jsonl files at its root.
+TEACHER_DATA_REPO = os.environ.get(
+    "TEACHER_DATA_REPO",
+    "InosLihka/rhythm-env-teacher-trajectories",
+)
+TEACHER_FILES = os.environ.get(
+    "TEACHER_FILES",
+    "teacher_30ep_validation.jsonl,teacher_indist_30_99.jsonl,teacher_ood_10000_10049.jsonl",
+).split(",")
+
+EPOCHS = int(os.environ.get("EPOCHS", "2"))
+MAX_STEPS = int(os.environ.get("MAX_STEPS", "-1"))  # -1 = use epochs
+LORA_RANK = int(os.environ.get("LORA_RANK", "16"))
+LEARNING_RATE = float(os.environ.get("LEARNING_RATE", "2e-4"))
+MAX_SEQ_LENGTH = int(os.environ.get("MAX_SEQ_LENGTH", "2048"))
+
+SUFFIX = os.environ.get("MODEL_REPO_SUFFIX", "sft-primed")
+DEFAULT_REPO = f"InosLihka/rhythm-env-meta-trained-{SUFFIX}"
+MODEL_REPO = os.environ.get("MODEL_REPO", DEFAULT_REPO)
+
+print("=== SFT prime config ===")
+print(f"  TEACHER_DATA_REPO: {TEACHER_DATA_REPO}")
+print(f"  TEACHER_FILES: {TEACHER_FILES}")
+print(f"  EPOCHS={EPOCHS}, MAX_STEPS={MAX_STEPS}, LORA_RANK={LORA_RANK}")
+print(f"  LR={LEARNING_RATE}, MAX_SEQ_LENGTH={MAX_SEQ_LENGTH}")
+print(f"  MODEL_REPO={MODEL_REPO}")
+print()
+
+
+def run(cmd):
+    print(f"\n>>> {' '.join(cmd) if isinstance(cmd, list) else cmd}", flush=True)
+    subprocess.run(cmd, check=True)
+
+
+def main():
+    # 1. Clone repo
+    if Path(WORK_DIR).exists():
+        shutil.rmtree(WORK_DIR)
+    run(["git", "clone", REPO_URL, WORK_DIR])
+    os.chdir(WORK_DIR)
+    sys.path.insert(0, WORK_DIR)
+    sys.path.insert(0, os.path.join(WORK_DIR, "training"))
+
+    # 2. Download teacher trajectories from HF Hub
+    from huggingface_hub import hf_hub_download
+
+    Path("data").mkdir(exist_ok=True)
+    local_paths = []
+    for fn in TEACHER_FILES:
+        fn = fn.strip()
+        if not fn:
+            continue
+        print(f"Downloading {fn} from {TEACHER_DATA_REPO}...")
+        local = hf_hub_download(
+            repo_id=TEACHER_DATA_REPO,
+            filename=fn,
+            repo_type="dataset",
+            local_dir="data",
+        )
+        local_paths.append(local)
+    print(f"Downloaded {len(local_paths)} JSONL files")
+
+    # 3. Run SFT
+    sft_args = [
+        "python", "training/sft_prime.py",
+        "--teacher_jsonls", *local_paths,
+        "--output_dir", OUTPUT_DIR,
+        "--lora_rank", str(LORA_RANK),
+        "--learning_rate", str(LEARNING_RATE),
+        "--max_seq_length", str(MAX_SEQ_LENGTH),
+        "--epochs", str(EPOCHS),
+    ]
+    if MAX_STEPS > 0:
+        sft_args.extend(["--max_steps", str(MAX_STEPS)])
+    run(sft_args)
+
+    # 4. Eval (3 conditions: discrete-3 / in-dist / OOD)
+    eval_args = [
+        "python", "training/inference_eval.py",
+        "--model_path", OUTPUT_DIR,
+        "--num_episodes", "5",
+        "--output_file", "eval_results.json",
+    ]
+    run(eval_args)
+
+    # 5. Upload to HF Hub
+    token = os.environ.get("HF_TOKEN")
+    if not token:
+        print("WARNING: HF_TOKEN not set, skipping upload")
+        print(f"Outputs at: {OUTPUT_DIR}")
+        return
+
+    from huggingface_hub import HfApi, login
+    login(token=token)
+    api = HfApi()
+    api.create_repo(MODEL_REPO, exist_ok=True, repo_type="model")
+
+    api.upload_folder(
+        folder_path=OUTPUT_DIR,
+        repo_id=MODEL_REPO,
+        repo_type="model",
+        commit_message=f"SFT prime ({EPOCHS} epochs, lora r={LORA_RANK}) on teacher trajectories",
+    )
+    api.upload_file(
+        path_or_fileobj="eval_results.json",
+        path_in_repo="eval_results.json",
+        repo_id=MODEL_REPO,
+        repo_type="model",
+    )
+
+    print()
+    print("=" * 60)
+    print("DONE")
+    print(f"  SFT'd model: https://huggingface.co/{MODEL_REPO}")
+    print(f"  Eval JSON: https://huggingface.co/{MODEL_REPO}/blob/main/eval_results.json")
+    print("=" * 60)
+
+
+if __name__ == "__main__":
+    main()

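Because every knob in the orchestrator resolves from an env var with a default, a cheap smoke run needs no code edits; the same submit command shown in the file's docstring works with overridden values (the MAX_STEPS=20 and sft-smoke values here are illustrative, not a documented recipe):

    hf jobs uv run --flavor a10g-large --secrets HF_TOKEN \
        -e TEACHER_DATA_REPO=InosLihka/rhythm-env-teacher-trajectories \
        -e MODEL_REPO_SUFFIX=sft-smoke -e MAX_STEPS=20 \
        -d scripts/sft_on_hf.py
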
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+"""
+Upload teacher trajectory JSONL files to a HF Hub dataset repo so the
+SFT-on-HF-Jobs orchestrator can download them.
+
+HF Jobs containers don't have access to local files, so the teacher data
+has to live on HF Hub first.
+
+Usage:
+    python scripts/upload_teacher_data.py \\
+        --files data/teacher_30ep_validation.jsonl \\
+                data/teacher_indist_30_99.jsonl \\
+                data/teacher_ood_10000_10049.jsonl \\
+        --repo InosLihka/rhythm-env-teacher-trajectories
+
+Requires HF_TOKEN env var (or `hf auth login` already done).
+"""
+
+import argparse
+import os
+import sys
+from pathlib import Path
+
+from huggingface_hub import HfApi, login
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--files", nargs="+", required=True,
+                        help="Local JSONL files to upload")
+    parser.add_argument("--repo", type=str, required=True,
+                        help="HF Hub dataset repo (e.g. InosLihka/rhythm-env-teacher-trajectories)")
+    parser.add_argument("--commit_message", type=str,
+                        default="Add teacher trajectories from gpt-5.4 + grader v2")
+    args = parser.parse_args()
+
+    token = os.environ.get("HF_TOKEN")
+    if token:
+        login(token=token)
+
+    api = HfApi()
+    api.create_repo(args.repo, exist_ok=True, repo_type="dataset", private=False)
+    print(f"Repo: https://huggingface.co/datasets/{args.repo}")
+
+    for path in args.files:
+        p = Path(path)
+        if not p.exists():
+            print(f"SKIP missing: {path}")
+            continue
+        print(f"Uploading {p.name} ({p.stat().st_size / 1024:.1f} KB)...")
+        api.upload_file(
+            path_or_fileobj=str(p),
+            path_in_repo=p.name,
+            repo_id=args.repo,
+            repo_type="dataset",
+            commit_message=args.commit_message,
+        )
+
+    # Add a small README documenting the dataset format
+    readme = """# RhythmEnv teacher trajectories
+
+Per-step (state, prompt, teacher_response, action, belief, reward) tuples
+collected by replaying RhythmEnv with gpt-5.4 (Azure AI Foundry) as the
+acting agent. Used as the SFT corpus for Algorithm Distillation.
+
+## Files
+
+Each JSONL row is one step. Schema:
+
+```
+{
+  "seed": int,                # episode seed (also determines hidden profile)
+  "step": int,                # step index 0..27
+  "profile_name": str,        # 'sampled_<seed>' for continuous-mode profiles
+  "user_prompt": str,         # observation prompt the student will see at inference
+  "teacher_response": str,    # full teacher output: "<reasoning>...</reasoning>\\nS M W ACTION_NAME"
+  "parsed_action": str,       # action name (e.g. "deep_work")
+  "parsed_belief": [s, m, w], # 3-dim belief in [0, 1]
+  "answer_match": str,        # raw matched substring of the answer line
+  "env_reward": float,        # per-step env reward
+  "parse_failed": bool,       # True if response couldn't be parsed into action+belief
+  "true_belief": [s, m, w]    # ground-truth belief vector for the active profile
+}
+```
+
+## Generation
+
+Generated using `scripts/generate_teacher_trajectories.py` from the
+[InosLihka/rhythm_env Space](https://huggingface.co/spaces/InosLihka/rhythm_env).
+Teacher: `gpt-5.4` (Azure AI Foundry, version 2026-03-05). Sampling
+temperature 0.5. ~840 (state, response) pairs per 30-episode batch.
+"""
+    api.upload_file(
+        path_or_fileobj=readme.encode("utf-8"),
+        path_in_repo="README.md",
+        repo_id=args.repo,
+        repo_type="dataset",
+        commit_message="Add dataset README",
+    )
+
+    print()
+    print(f"Done. Dataset: https://huggingface.co/datasets/{args.repo}")
+
+
+if __name__ == "__main__":
+    main()
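For orientation, a row from one of these shards round-trips with nothing but the standard library. The sketch below follows the README schema above; the file name is a stand-in for whichever shard was downloaded:

```python
import json

# Hypothetical local copy of one uploaded shard.
with open("data/teacher_30ep_validation.jsonl") as f:
    row = json.loads(next(f))

# Compare the teacher's emitted belief against the hidden profile's truth.
mae = sum(abs(b - t) for b, t in zip(row["parsed_belief"], row["true_belief"])) / 3.0
print(row["seed"], row["step"], row["parsed_action"], f"belief MAE={mae:.3f}")
```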
@@ -19,8 +19,9 @@ Key design principles for learnability:
   in every observation so the agent can detect personality anomalies
 - *_anomaly fields: per-meter deviation from neutral-profile expectation,
   giving a direct fingerprint of the hidden profile each step
+- adaptation_score: 25% of final grade; late-half mean per-step reward
+  minus early-half mean (gated by absolute late-half quality). Rewards
+  the agent for getting better as it learns the user.
 - Profile assignment uses a scrambled seed to prevent memorization
   of seed → profile mappings during training
 """
@@ -215,9 +216,11 @@ def sample_profile(seed: int) -> Dict[str, Any]:
     raw = [rng.gammavariate(a, 1.0) for a in alphas]
     total = sum(raw)
     weights = [w / total for w in raw]
-    #
-    #
-    #
+    # Cap each weight at 0.45 so every sampled profile weights 3+ meters
+    # meaningfully. With a 0.80 cap, single-meter-dominant profiles let
+    # SLEEP-spam (or any single recovery action) be optimal; the env wasn't
+    # lying, the agent was right to spam. Forcing balance makes belief
+    # inference matter for action selection.
     weights = [max(0.05, min(0.45, w)) for w in weights]
     total = sum(weights)
     weights = [w / total for w in weights]
@@ -303,8 +306,8 @@ class RhythmEnvironment(Environment):
     - Anomaly signals: actual delta minus neutral-profile expectation
     - Rolling step_history (last 7 steps) with actions, rewards, deltas
 
-    The final grade rewards profile-appropriate strategy
-    of
+    The final grade rewards profile-appropriate strategy via adaptation_score
+    (25% of grade): late-half mean per-step reward minus early-half mean.
     """
 
     SUPPORTS_CONCURRENT_SESSIONS: bool = True
@@ -324,9 +327,12 @@ class RhythmEnvironment(Environment):
         self._timestep: int = 0
         self._crash_count: int = 0
         self._total_reward: float = 0.0
-        self._recent_actions: list = []
         self._step_history: list = []
         self._step_rewards: list = []  # per-step rewards (for adaptation_score in grader)
+        # Latest emitted belief vector: set by callers via record_belief() and
+        # consumed by _grade_episode. Stays None if the agent never emits a belief
+        # (e.g. heuristic baseline); that case scores 0 on the belief component.
+        self._final_belief: Optional[List[float]] = None
 
     def get_metadata(self) -> EnvironmentMetadata:
         return EnvironmentMetadata(
@@ -357,18 +363,13 @@ class RhythmEnvironment(Environment):
 
         self._rng = random.Random(effective_seed)
 
-        # Profile selection
-        # 1. Explicit hardcoded profile name
-        #
-        #
+        # Profile selection, two modes:
+        # 1. Explicit hardcoded profile name: one of the 3 reference profiles
+        #    (used by tests + the legacy 3-profile eval condition)
+        # 2. Default: sampled continuous profile (meta-RL training distribution)
         profile_name = kwargs.get("profile")
-        profile_mode = kwargs.get("profile_mode", "continuous")
         if profile_name and profile_name in PROFILE_MAP:
             self._profile = deepcopy(PROFILE_MAP[profile_name])
-        elif profile_mode == "discrete":
-            profile_rng = random.Random(effective_seed ^ 0xA3C5F729)
-            profile_index = profile_rng.randint(0, len(PROFILES) - 1)
-            self._profile = deepcopy(PROFILES[profile_index])
         else:
             self._profile = sample_profile(effective_seed)
 
@@ -384,9 +385,9 @@ class RhythmEnvironment(Environment):
         self._timestep = 0
         self._crash_count = 0
         self._total_reward = 0.0
-        self._recent_actions = []
         self._step_history = []
         self._step_rewards = []
+        self._final_belief = None
 
         self._state = RhythmState(
             episode_id=episode_id or str(uuid4()),
@@ -430,7 +431,8 @@ class RhythmEnvironment(Environment):
         effects = dict(ACTION_EFFECTS[action_name])
 
         # --- 2b. Repetition dampening ---
-
+        recent3 = [h["action"] for h in self._step_history[-3:]]
+        repeat_count = recent3.count(action_name)
         if repeat_count > 0:
             dampening = 1.0 - 0.25 * repeat_count  # 0.75, 0.50, 0.25
             for meter in METERS:
@@ -486,7 +488,6 @@ class RhythmEnvironment(Environment):
 
         # --- 10. Advance timestep ---
         self._timestep += 1
-        self._recent_actions.append(action_name)
         new_day = self._timestep // SLOTS_PER_DAY
         new_slot = self._timestep % SLOTS_PER_DAY
 
@@ -509,9 +510,10 @@ class RhythmEnvironment(Environment):
         if done:
             final_score = self._grade_episode()
             reward_breakdown["final_score"] = round(final_score, 4)
-            #
-            #
-            #
+            # Sparse terminal reward: directly supervise on grader final_score.
+            # Centered on 0.5 (the "average" episode), scaled by 5x to give a
+            # range of [-2.5, +2.5], strong enough to dominate any local
+            # reward-hack the agent might find on per-step shaping alone.
             terminal_bonus = (final_score - 0.5) * 5.0
             reward = max(-3.0, min(3.0, reward + terminal_bonus))
             self._total_reward += terminal_bonus  # update tracking too
@@ -530,8 +532,9 @@ class RhythmEnvironment(Environment):
         self._state.active_event = active_event
 
         # --- 15. Append completed step to rolling history ---
-        #
-        #
+        # History entries carry per-meter anomalies (actual - expected_under_neutral).
+        # The prompt builder reads these directly to surface the agent's clearest
+        # profile-inference signal.
         self._step_history.append({
             "step": current_step,
             "action": action_name,
@@ -575,12 +578,26 @@ class RhythmEnvironment(Environment):
         """
         return profile_to_belief_vector(self._profile)
 
+    def record_belief(self, belief: List[float]) -> None:
+        """Record the agent's emitted belief for the current step.
+
+        The grader (`_grade_episode`) uses the LAST recorded belief to compute
+        the belief_accuracy component of final_score. Callers should invoke
+        this once per step after parsing the agent's completion. Heuristic /
+        random baselines that don't emit a belief never call this, and the
+        belief component scores 0 for them. That's intentional: the meta-RL
+        skill is INFERENCE, and only agents that actually try get credit.
+        """
+        if len(belief) != 3:
+            raise ValueError(f"belief must have 3 elements, got {len(belief)}")
+        self._final_belief = [max(0.0, min(1.0, float(b))) for b in belief]
+
     def get_profile_hint(self) -> Dict[str, float]:
         """Return a coarse profile hint usable in observation during curriculum.
 
-        Returns the 3-dim belief vector with descriptive keys. The
+        Returns the 3-dim belief vector with descriptive keys. The dataset
+        generator passes this into the prompt for the fraction of samples
+        with show_profile_hint=True (the curriculum's "visible" warmup phase).
         """
         b = profile_to_belief_vector(self._profile)
         return {"social_pref": round(b[0], 3), "morning_pref": round(b[1], 3), "work_pref": round(b[2], 3)}
@@ -711,15 +728,15 @@ class RhythmEnvironment(Environment):
         self._vitality = max(0.0, self._vitality - vd)
 
     def _compute_reward(self, deltas: Dict[str, float]) -> float:
-        """
-        per-step reward
+        """Pure profile-weighted per-step reward.
+
+        Deliberately uncontaminated: the grader-aligned bias (progress +
+        connection deltas) lives in the TRAINING reward function in
+        reward_functions.py, not here. Keeping the env's per-step reward
+        pure means (1) the agent's inference signal stays a clean function
+        of the hidden profile_weights, (2) the grader's adaptation_score
+        isn't computed on biased rewards, and (3) the env's reward matches
+        what an honest deployment would surface to the agent.
         """
         weights = self._profile["reward_weights"]
         return sum(deltas[m] * weights[m] for m in METERS) * REWARD_SCALE
@@ -729,32 +746,37 @@ class RhythmEnvironment(Environment):
         Compute final episode score in [0, 1].
 
         Components (meta-learning aligned):
-        0.20 → crash_free_ratio
-        0.25 → progress
-        0.15 → connection
-        0.30 → adaptation_score
+        0.15 → crash_free_ratio: no critical meter drops
+        0.20 → progress: career/skill growth
+        0.10 → connection: relationship maintained
+        0.25 → adaptation_score: agent got better as it learned the user
         0.10 → efficiency: bounded normalized average reward
+        0.20 → belief_accuracy: how close last-emitted belief is to true profile
 
+        belief_accuracy is the explicit meta-RL inference signal: an agent
+        that doesn't emit a belief scores 0 here, and an agent that emits
+        a belief close to the hidden profile vector scores up to 1. Without
+        this term, agents that play heuristic-style "keep meters healthy"
+        score the same as agents that actually infer the profile, since the
+        other components don't differentiate inference from reflex.
 
+        adaptation_score remains the implicit signal: late-half mean per-step
+        reward minus early-half mean, gated by absolute late-half quality.
         Per-step reward is already profile-weighted via _compute_reward(), so
-        a high late-half mean
-        optimized for THIS profile's preferences.
+        a high late-half mean still means the agent figured out the profile.
         """
         steps = max(self._timestep, 1)
 
-        # 1. Crash-free ratio (0.20)
+        # 1. Crash-free ratio (0.15)
         crash_free_ratio = 1.0 - (self._crash_count / (steps * len(METERS)))
 
-        # 2. Progress (0.25)
+        # 2. Progress (0.20)
         progress_score = self._progress
 
-        # 3. Connection (0.15)
+        # 3. Connection (0.10)
         connection_score = self._connection
 
-        # 4. Adaptation score (0.30)
+        # 4. Adaptation score (0.25): implicit inference signal.
         # Split rewards in halves; positive only if late half is non-negative
         # AND late > early. Normalized to [0, 1].
         half = max(steps // 2, 1)
@@ -763,9 +785,10 @@ class RhythmEnvironment(Environment):
         if early and late:
             mean_early = sum(early) / len(early)
             mean_late = sum(late) / len(late)
-            #
-            #
-            #
+            # Per-step rewards are clamped to [-3, +3] in step(), so normalize
+            # late_quality with the [-3, +3] range (NOT [-1, +1]); otherwise
+            # the gate saturates at 1.0 for any mean_late ≥ 1 and the grader
+            # cannot distinguish good from excellent late-half quality.
             late_quality = max(0.0, min(1.0, (mean_late + 3.0) / 6.0))
             gain = mean_late - mean_early
             # gain in [-6, +6]; normalize to [0, 1] (only positive gain counts)
@@ -778,12 +801,23 @@ class RhythmEnvironment(Environment):
         avg_reward = self._total_reward / steps
         efficiency_score = max(0.0, min(1.0, (avg_reward + 1.0) / 2.0))
 
+        # 6. Belief accuracy (0.20): explicit inference signal.
+        # Score = 1 - mean_absolute_error against the true belief vector.
+        # If no belief was recorded (heuristic / random baselines), score = 0.
+        if self._final_belief is not None:
+            true_belief = profile_to_belief_vector(self._profile)
+            mae = sum(abs(b - t) for b, t in zip(self._final_belief, true_belief)) / 3.0
+            belief_accuracy_score = max(0.0, 1.0 - mae)
+        else:
+            belief_accuracy_score = 0.0
+
         score = (
-            0.20 * crash_free_ratio
-            + 0.25 * progress_score
-            + 0.15 * connection_score
-            + 0.30 * adaptation_score
+            0.15 * crash_free_ratio
+            + 0.20 * progress_score
+            + 0.10 * connection_score
+            + 0.25 * adaptation_score
             + 0.10 * efficiency_score
+            + 0.20 * belief_accuracy_score
        )
        return max(0.0, min(1.0, score))
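To make the v2 weighting concrete, here is a small worked check of `_grade_episode`'s arithmetic on made-up component values (the numbers are illustrative, not from a real episode):

```python
# Illustrative component values for one episode (not real eval output).
components = {
    "crash_free_ratio": 0.90, "progress": 0.60, "connection": 0.50,
    "adaptation": 0.40, "efficiency": 0.55, "belief_accuracy": 0.85,
}
weights = {
    "crash_free_ratio": 0.15, "progress": 0.20, "connection": 0.10,
    "adaptation": 0.25, "efficiency": 0.10, "belief_accuracy": 0.20,
}
final_score = sum(weights[k] * components[k] for k in weights)  # = 0.63
# The terminal shaping then adds (final_score - 0.5) * 5.0 = +0.65 to the last step.
print(round(final_score, 4), round((final_score - 0.5) * 5.0, 2))
```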
@@ -0,0 +1,235 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""End-to-end smoke tests for the meta-RL training pipeline (no GPU).
+
+Validates that dataset generation, the LLM-output parser, and all four
+reward functions agree with each other before any GPU training spend.
+The most important check is reward_variance_across_completion_kinds:
+that's what catches the iter-1 mode-collapse class of bug (a reward
+layer returns the same value for every completion in a GRPO group,
+contributing zero to advantage).
+"""
+
+import random
+
+import pytest
+
+from models import ActionType
+from server.rhythm_environment import sample_profile, profile_to_belief_vector
+from training.dataset import generate_dataset
+from training.reward_functions import (
+    action_legal,
+    belief_accuracy,
+    env_reward,
+    extract_action_and_belief,
+    format_valid,
+)
+
+
+@pytest.fixture(scope="module")
+def small_dataset():
+    """20 episodes, ~80 samples max: enough variety for reward checks."""
+    return generate_dataset(
+        num_episodes=20,
+        strategy="mixed",
+        max_samples=80,
+        hint_fraction=0.1,
+    )
+
+
+def _synth_completion(seed: int, kind: str) -> str:
+    """Synthesize a completion of a given quality."""
+    rng = random.Random(seed)
+    action_str = rng.choice(list(ActionType)).value.upper()
+    s, m, w = rng.randint(0, 9), rng.randint(0, 9), rng.randint(0, 9)
+
+    if kind == "perfect":
+        true = profile_to_belief_vector(sample_profile(seed))
+        s, m, w = round(true[0] * 9), round(true[1] * 9), round(true[2] * 9)
+        return f"{s} {m} {w} {action_str}"
+    if kind == "good":
+        return f"{s} {m} {w} {action_str}"
+    if kind == "action_only":
+        return action_str
+    if kind == "wrong_belief":
+        true = profile_to_belief_vector(sample_profile(seed))
+        s = round((1 - true[0]) * 9)
+        m = round((1 - true[1]) * 9)
+        w = round((1 - true[2]) * 9)
+        return f"{s} {m} {w} {action_str}"
+    if kind == "garbage":
+        return "I don't know what to do here"
+    return action_str
+
+
+# ---------------------------------------------------------------------------
+# Dataset shape
+# ---------------------------------------------------------------------------
+
+
+def test_dataset_is_non_empty(small_dataset):
+    assert len(small_dataset) > 0
+
+
+def test_dataset_rows_have_required_replay_columns(small_dataset):
+    expected = {"prompt", "seed", "step_index", "action_history", "profile_mode"}
+    for row in small_dataset:
+        missing = expected - row.keys()
+        assert not missing, f"row missing columns: {missing}"
+
+
+def test_dataset_prompts_are_chat_messages(small_dataset):
+    for row in small_dataset[:5]:
+        msgs = row["prompt"]
+        assert isinstance(msgs, list) and len(msgs) == 2
+        assert msgs[0]["role"] == "system"
+        assert msgs[1]["role"] == "user"
+
+
+# ---------------------------------------------------------------------------
+# Parser
+# ---------------------------------------------------------------------------
+
+
+def test_parser_belief_first_format():
+    action, belief, provided = extract_action_and_belief("3 8 7 DEEP_WORK")
+    assert action == ActionType.DEEP_WORK
+    assert belief == pytest.approx([3 / 9, 8 / 9, 7 / 9], abs=1e-3)
+    assert provided is True
+
+
+def test_parser_action_only_returns_default_belief():
+    action, belief, provided = extract_action_and_belief("DEEP_WORK")
+    assert action == ActionType.DEEP_WORK
+    assert belief == [0.5, 0.5, 0.5]
+    assert provided is False
+
+
+def test_parser_garbage_returns_none():
+    action, _, _ = extract_action_and_belief("I don't know what to do here")
+    assert action is None
+
+
+# ---------------------------------------------------------------------------
+# Reward layers run end-to-end on synth completions
+# ---------------------------------------------------------------------------
+
+
+@pytest.fixture(scope="module")
+def replay_columns(small_dataset):
+    # Skip step_index=0 samples: belief_accuracy intentionally returns 0
+    # at step 0 (no info yet), which would mask gradient checks.
+    sub = [s for s in small_dataset if s["step_index"] > 0][:30]
+    return {
+        "samples": sub,
+        "seed": [s["seed"] for s in sub],
+        "history": [s["action_history"] for s in sub],
+        "mode": [s["profile_mode"] for s in sub],
+        "step_index": [s["step_index"] for s in sub],
+    }
+
+
+@pytest.mark.parametrize("kind", ["perfect", "good", "action_only", "wrong_belief", "garbage"])
+def test_all_reward_layers_return_floats(kind, replay_columns):
+    sub = replay_columns["samples"]
+    completions = [[{"content": _synth_completion(s["seed"], kind)}] for s in sub]
+
+    f = format_valid(completions)
+    l = action_legal(completions)
+    e = env_reward(
+        completions,
+        seed=replay_columns["seed"],
+        action_history=replay_columns["history"],
+        profile_mode=replay_columns["mode"],
+        step_index=replay_columns["step_index"],
+    )
+    b = belief_accuracy(
+        completions,
+        seed=replay_columns["seed"],
+        action_history=replay_columns["history"],
+        profile_mode=replay_columns["mode"],
+        step_index=replay_columns["step_index"],
+    )
+
+    for layer_name, scores in [("format_valid", f), ("action_legal", l), ("env_reward", e), ("belief_accuracy", b)]:
+        assert len(scores) == len(completions), f"{layer_name} length mismatch"
+        for s in scores:
+            assert isinstance(s, float), f"{layer_name} returned non-float: {type(s)}"
+
+
+# ---------------------------------------------------------------------------
+# Reward layers DISCRIMINATE between completion qualities (anti-mode-collapse)
+# ---------------------------------------------------------------------------
+
+
+def test_format_valid_discriminates_belief_vs_action_only(replay_columns):
+    """format_valid must reward belief+action higher than action-only."""
+    sub = replay_columns["samples"]
+    good = [[{"content": _synth_completion(s["seed"], "good")}] for s in sub]
+    action_only = [[{"content": _synth_completion(s["seed"], "action_only")}] for s in sub]
+
+    good_avg = sum(format_valid(good)) / len(good)
+    action_only_avg = sum(format_valid(action_only)) / len(action_only)
+    assert good_avg > action_only_avg, (
+        f"format_valid did not push toward belief output: "
+        f"good={good_avg:.3f} action_only={action_only_avg:.3f}"
+    )
+
+
+def test_belief_accuracy_discriminates_perfect_vs_wrong(replay_columns):
+    """The whole point of the meta-RL signal: better belief → higher reward."""
+    sub = replay_columns["samples"]
+    perfect = [[{"content": _synth_completion(s["seed"], "perfect")}] for s in sub]
+    wrong = [[{"content": _synth_completion(s["seed"], "wrong_belief")}] for s in sub]
+
+    perfect_avg = sum(belief_accuracy(
+        perfect,
+        seed=replay_columns["seed"],
+        action_history=replay_columns["history"],
+        profile_mode=replay_columns["mode"],
+        step_index=replay_columns["step_index"],
+    )) / len(perfect)
+    wrong_avg = sum(belief_accuracy(
+        wrong,
+        seed=replay_columns["seed"],
+        action_history=replay_columns["history"],
+        profile_mode=replay_columns["mode"],
+        step_index=replay_columns["step_index"],
+    )) / len(wrong)
+
+    assert perfect_avg > wrong_avg, (
+        f"belief_accuracy gave no gradient: perfect={perfect_avg:.3f} wrong={wrong_avg:.3f}"
+    )
+
+
+def test_no_reward_layer_is_constant_across_kinds(replay_columns):
+    """The iter-1 collapse trap: a reward layer returning the same value for every
+    completion in a GRPO group contributes zero to advantage. At least one of
+    {good, action_only, garbage, wrong_belief} must produce a different mean
+    score from the others for each layer that's supposed to be a learning signal.
+    """
+    sub = replay_columns["samples"]
+    kinds = ["good", "action_only", "garbage", "wrong_belief"]
+    completions_by_kind = {
+        kind: [[{"content": _synth_completion(s["seed"], kind)}] for s in sub] for kind in kinds
+    }
+
+    # format_valid and belief_accuracy MUST discriminate.
+    f_means = [sum(format_valid(completions_by_kind[k])) / len(sub) for k in kinds]
+    assert max(f_means) - min(f_means) > 0.1, f"format_valid is near-constant: {f_means}"
+
+    b_means = [
+        sum(belief_accuracy(
+            completions_by_kind[k],
+            seed=replay_columns["seed"],
+            action_history=replay_columns["history"],
+            profile_mode=replay_columns["mode"],
+            step_index=replay_columns["step_index"],
+        )) / len(sub)
+        for k in kinds
+    ]
+    assert max(b_means) - min(b_means) > 0.05, f"belief_accuracy is near-constant: {b_means}"
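The anti-mode-collapse check rests on simple arithmetic: GRPO normalizes each reward within its completion group, so a constant reward layer washes out entirely. A toy illustration of that cancellation (group-mean baseline only, ignoring the std term):

```python
def group_advantages(rewards: list[float]) -> list[float]:
    # GRPO-style advantage: each reward minus the group mean.
    mean = sum(rewards) / len(rewards)
    return [r - mean for r in rewards]

print(group_advantages([0.7, 0.7, 0.7, 0.7]))  # [0.0, 0.0, 0.0, 0.0] -> no learning signal
print(group_advantages([1.0, 0.4, 0.0, 0.6]))  # nonzero spread -> gradient flows
```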
@@ -356,18 +356,15 @@ class TestEdgeCases:
 
     def test_state_exposes_profile(self, env):
         """State should include profile_name for debugging."""
-        # Default
+        # Default: continuous profile (name like 'sampled_0')
         env.reset(seed=0)
         assert env.state.profile_name != ""
         assert env.state.profile_name.startswith("sampled_")
 
-        #
-        env.reset(seed=0, profile_mode="discrete")
-        assert env.state.profile_name in [p["name"] for p in PROFILES]
-
-        # Explicit profile: name matches the requested profile
+        # Explicit profile: name matches the requested reference profile
         env.reset(seed=0, profile="workaholic_stoic")
         assert env.state.profile_name == "workaholic_stoic"
+        assert env.state.profile_name in [p["name"] for p in PROFILES]
 
     def test_all_action_types_valid(self, env):
         """Every ActionType should be processable without error."""
@@ -377,3 +374,98 @@ class TestEdgeCases:
             e.reset(seed=0)
             obs = e.step(make_action(action_type))
             assert isinstance(obs, RhythmObservation)
+
+
+# ---------------------------------------------------------------------------
+# Belief-accuracy grader component
+# ---------------------------------------------------------------------------
+
+
+class TestBeliefAccuracyGrader:
+    """The grader awards 0.20 weight to belief_accuracy. Agents that don't
+    emit beliefs get 0 on this component; agents whose final belief matches
+    the true profile vector get up to 0.20 added to final_score.
+    """
+
+    def _run_episode_with_belief(self, seed, belief, profile=None):
+        env = RhythmEnvironment()
+        if profile:
+            obs = env.reset(seed=seed, profile=profile)
+        else:
+            obs = env.reset(seed=seed)
+        for _ in range(MAX_STEPS):
+            if obs.done:
+                break
+            env.record_belief(belief)
+            obs = env.step(make_action(ActionType.SLEEP))
+        return obs.reward_breakdown.get("final_score", 0.0)
+
+    def _run_episode_no_belief(self, seed, profile=None):
+        env = RhythmEnvironment()
+        if profile:
+            obs = env.reset(seed=seed, profile=profile)
+        else:
+            obs = env.reset(seed=seed)
+        for _ in range(MAX_STEPS):
+            if obs.done:
+                break
+            obs = env.step(make_action(ActionType.SLEEP))
+        return obs.reward_breakdown.get("final_score", 0.0)
+
+    def test_no_belief_means_zero_belief_component(self, env):
+        """Agent that never calls record_belief gets 0 on the belief component."""
+        score = self._run_episode_no_belief(seed=42)
+        # Without belief, max possible score is 0.80 (all weights ex belief).
+        # Realistic ceiling is much lower since SLEEP-only doesn't max meters.
+        assert score <= 0.80
+
+    def test_perfect_belief_lifts_score(self, env):
+        """An agent that emits the TRUE belief vector should score higher
+        than the same actions with no belief, by up to +0.20."""
+        # Use a known reference profile so we can hand-pick the perfect belief.
+        from server.rhythm_environment import (
+            PROFILE_MAP,
+            profile_to_belief_vector,
+        )
+        profile_name = "workaholic_stoic"
+        true_belief = profile_to_belief_vector(PROFILE_MAP[profile_name])
+
+        no_belief_score = self._run_episode_no_belief(seed=7, profile=profile_name)
+        perfect_score = self._run_episode_with_belief(
+            seed=7, belief=true_belief, profile=profile_name
+        )
+        # Perfect belief contributes 0.20 to final_score
+        assert perfect_score > no_belief_score
+        assert (perfect_score - no_belief_score) == pytest.approx(0.20, abs=0.01)
+
+    def test_wrong_belief_scores_less_than_perfect(self, env):
+        """Wrong belief still counts (0 ≤ score ≤ 1) but less than perfect."""
+        from server.rhythm_environment import (
+            PROFILE_MAP,
+            profile_to_belief_vector,
+        )
+        profile_name = "introvert_morning"
+        true_belief = profile_to_belief_vector(PROFILE_MAP[profile_name])
+        wrong_belief = [1.0 - b for b in true_belief]  # opposite
+
+        perfect_score = self._run_episode_with_belief(
+            seed=7, belief=true_belief, profile=profile_name
+        )
+        wrong_score = self._run_episode_with_belief(
+            seed=7, belief=wrong_belief, profile=profile_name
+        )
+        assert perfect_score > wrong_score
+
+    def test_record_belief_validates_length(self, env):
+        env.reset(seed=0)
+        with pytest.raises(ValueError):
+            env.record_belief([0.5, 0.5])  # wrong length
+        with pytest.raises(ValueError):
+            env.record_belief([0.5, 0.5, 0.5, 0.5])  # too long
+
+    def test_record_belief_clamps_to_unit_interval(self, env):
+        """Beliefs outside [0, 1] should be clamped, not rejected."""
+        env.reset(seed=0)
+        env.record_belief([-0.5, 1.5, 0.5])
+        # Internal state should be clamped
+        assert env._final_belief == [0.0, 1.0, 0.5]
@@ -1,17 +1,14 @@
 """
 Dataset generator for RhythmEnv GRPO training (meta-RL version).
 
-   of using profile signals before forcing inference)
-- Dataset rows include seed, step_index, action_history, profile_mode
-  so env_reward and belief_accuracy can replay deterministically
+Plays episodes under a continuously-sampled profile per seed and emits
+observation prompts at each step, paired with the replay metadata
+(seed, step_index, action_history) the reward functions need to
+reconstruct env state deterministically.
+
+The system prompt asks for "S M W ACTION_NAME": three belief digits then
+the action. A `hint_fraction` slice of episodes carries a true-belief hint
+in the prompt as a curriculum warmup; the rest force pure inference.
 """
 
 import sys
@@ -26,37 +23,53 @@ from server.rhythm_environment import RhythmEnvironment, MAX_STEPS, METERS
 SLOT_NAMES = ["Morning", "Afternoon", "Evening", "Night"]
 DAY_NAMES = ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"]
 
-SYSTEM_PROMPT = (
-)
+SYSTEM_PROMPT = """You are a life-management agent helping a person whose preferences are HIDDEN.
+You see 5 life meters and a rolling history of recent steps. The same action
+affects different people differently; you must INFER who you're helping from
+rewards, meter changes, and per-meter ANOMALY signals.
+
+Each step, do TWO things:
+
+1. Reason briefly about what the observations imply about the person.
+   Focus on:
+   - Anomalies (actual delta vs neutral-profile expectation): big positive
+     social_serenity / connection responses → high S; big morning cognition
+     gains → high M; productive work giving vitality back → high W
+   - Current meter state: any meter under 0.15 needs urgent recovery
+   - What action best fits BOTH the inferred profile and the current state
+
+2. Output your final answer on the LAST line in this exact format:
+   S M W ACTION_NAME
+   where S, M, W are belief digits 0-9 (0=low, 9=high) representing your best
+   estimate of social_pref, morning_pref, work_pref. ACTION_NAME is one of:
+   DEEP_WORK, ADMIN_WORK, LEARN, SLEEP, EXERCISE, MEDITATE, FAMILY_TIME,
+   SOCIALIZE, ME_TIME, BINGE_WATCH
+
+Wrap your reasoning in <reasoning>...</reasoning> tags. Keep reasoning under
+120 tokens. The final answer line MUST be the last line of your response.
+
+Belief→action quick reference:
+- High S (extrovert): SOCIALIZE, FAMILY_TIME boost connection cheaply
+- High M (morning person): DEEP_WORK / LEARN in early slots gets bonus cognition
+- High W (workaholic): DEEP_WORK, LEARN drive progress and may energize
+- Low S (introvert): MEDITATE, ME_TIME for solo recharge; avoid SOCIALIZE
+- Low M (night owl): DEEP_WORK / LEARN in evening/night slots
+- Watch crashes: any meter under 0.10 = -0.30 penalty per crashed meter
+- Connection decays passively; actively maintain via SOCIALIZE/FAMILY_TIME
+- Don't repeat the same action 3+ times in a row; a repetition penalty applies
+
+Strategy: probe varied actions in the first ~5 steps to gather profile evidence,
+then exploit your sharpened belief by picking actions that match the inferred
+profile + current meter state.
+
+Example output:
+<reasoning>
+Last step's socialize gave V-0.12 (anom -0.06, much worse than neutral) → high
+social drain, suggests low S. Morning DEEP_WORK earlier gave bonus cognition
+(anom +0.04) → high M. Vitality at 0.6 still ok, serenity dropping. With low S +
+high M, MEDITATE is the recovery play that fits.
+</reasoning>
+2 8 5 MEDITATE"""
 
 
 def format_observation_prompt(obs, profile_hint: dict | None = None) -> str:
@@ -72,10 +85,10 @@ def format_observation_prompt(obs, profile_hint: dict | None = None) -> str:
 
     history_lines = []
     for h in (obs.step_history or [])[-5:]:  # last 5 only to fit prompt budget
-        #
-        #
-        #
-        #
+        # Per-meter anomalies (actual_delta - expected_under_neutral_profile)
+        # are the cleanest profile-inference signal: they show how this person's
+        # response DEVIATES from the average person. Surfacing them here in the
+        # prompt is what gives the agent a fingerprint to learn from.
         anom_str = (
             f" [anom V{h.vitality_anomaly:+.2f} C{h.cognition_anomaly:+.2f} "
             f"P{h.progress_anomaly:+.2f} S{h.serenity_anomaly:+.2f} "
@@ -170,7 +183,7 @@ def generate_episode_samples(
         if strategy == "random":
             action_type = rng.choice(all_actions)
         elif strategy == "heuristic":
-            action_type =
+            action_type = heuristic_action(obs)
         else:
             action_type = rng.choice(all_actions)
 
@@ -181,8 +194,12 @@
     return samples
 
 
-def
-    """
+def heuristic_action(obs) -> ActionType:
+    """Priority-based heuristic baseline (profile-blind).
+
+    Used both during dataset generation (to roll out diverse states) and
+    by inference_eval as the heuristic baseline strategy.
+    """
     slot = obs.slot
     v, c, p, s, cn = obs.vitality, obs.cognition, obs.progress, obs.serenity, obs.connection
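Both the teacher data and the student rollouts hinge on that "S M W ACTION_NAME" answer line. A toy sketch of a last-match extraction consistent with the format above (the regex and scaling here are illustrative; the real parser is `extract_action_and_belief` in training/reward_functions.py):

```python
import re

def parse_last_answer(text: str):
    # Take the LAST "<digit> <digit> <digit> <ACTION>" match so that any
    # CoT prefix (e.g. a <reasoning>...</reasoning> block) is skipped.
    matches = re.findall(r"(\d)\s+(\d)\s+(\d)\s+([A-Z_]+)", text)
    if not matches:
        return None, [0.5, 0.5, 0.5]
    s, m, w, action = matches[-1]
    return action, [int(s) / 9, int(m) / 9, int(w) / 9]

print(parse_last_answer("<reasoning>low S, high M</reasoning>\n2 8 5 MEDITATE"))
# ('MEDITATE', [0.222, 0.888, 0.555] approximately)
```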
@@ -2,8 +2,8 @@
 RhythmEnv Inference Evaluation: Baseline vs Trained, with meta-RL eval suite.
 
 Three evaluation conditions:
-1. discrete-3-profiles:
+1. discrete-3-profiles: 3 hardcoded reference profiles. A sanity check
+   that the meta-trained agent still handles the original named profiles.
 2. continuous-in-distribution: Sampled profiles from the training distribution
    (was the agent able to learn the meta-policy?)
 3. continuous-OOD: Profiles from a held-out region of the parameter space
@@ -28,6 +28,7 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
 
 from models import ActionType, RhythmAction
 from server.rhythm_environment import RhythmEnvironment, MAX_STEPS, sample_profile, profile_to_belief_vector
+from training.dataset import heuristic_action
 
 DISCRETE_PROFILES = ["introvert_morning", "extrovert_night_owl", "workaholic_stoic"]
 SLOT_NAMES = ["Morning", "Afternoon", "Evening", "Night"]
@@ -39,40 +40,14 @@ IN_DIST_SEEDS_DEFAULT = list(range(100, 110))  # 10 unseen-by-training in-distribution seeds
 OOD_SEEDS_DEFAULT = list(range(10000, 10010))  # 10 OOD seeds
 
 
-def heuristic_action(obs) -> ActionType:
-    """Priority-based heuristic baseline (profile-blind)."""
-    slot = obs.slot
-    v, c, p, s, cn = obs.vitality, obs.cognition, obs.progress, obs.serenity, obs.connection
-
-    if v < 0.15:
-        return ActionType.SLEEP
-    if s < 0.15:
-        return ActionType.MEDITATE
-    if cn < 0.15:
-        return ActionType.FAMILY_TIME
-    if slot == 3:
-        return ActionType.SLEEP
-    if slot == 0:
-        return ActionType.DEEP_WORK if (v > 0.4 and c > 0.3) else ActionType.EXERCISE
-    if slot == 1:
-        if cn < 0.3:
-            return ActionType.FAMILY_TIME
-        if p < 0.3 and v > 0.3:
-            return ActionType.LEARN
-        return ActionType.ADMIN_WORK
-    if cn < 0.4:
-        return ActionType.SOCIALIZE
-    if s < 0.5:
-        return ActionType.ME_TIME
-    return ActionType.MEDITATE
-
-
 def random_action(rng) -> ActionType:
     return rng.choice(list(ActionType))
 
 
 def model_action(obs, model, tokenizer, return_belief: bool = False):
     """Get action (and optionally belief) from trained model."""
+    # Lazy imports: keep the heavy training-stack imports out of module load
+    # so this script can run in baseline-only mode without unsloth/transformers.
     from training.dataset import format_observation_prompt, SYSTEM_PROMPT
     from training.reward_functions import extract_action_and_belief
 
@@ -129,6 +104,10 @@ def run_episode(
         elif strategy == "model" and model is not None:
             action_type, belief = model_action(obs, model, tokenizer, return_belief=True)
             beliefs_seen.append(belief)
+            # Tell the env about the emitted belief so the grader can score
+            # belief_accuracy. Heuristic / random skip this; they get 0 on
+            # the belief component, by design.
+            env.record_belief(belief)
         else:
             action_type = random_action(rng)
 
@@ -249,13 +228,14 @@ def main():
 
     all_results = []
 
-    # Condition 1:
+    # Condition 1: 3 hardcoded reference profiles (sanity-check the agent
+    # still handles the named profiles; no longer the primary eval signal).
     discrete_runs = [
-        {"seed": ep, "profile": p
+        {"seed": ep, "profile": p}
         for p in DISCRETE_PROFILES for ep in range(args.num_episodes)
     ]
     all_results += eval_condition(
-        "discrete-3-profiles
+        "discrete-3-profiles",
         strategies, discrete_runs,
         model=model, tokenizer=tokenizer,
     )
@@ -276,7 +256,7 @@ def main():
         model=model, tokenizer=tokenizer,
     )
 
-    # Per-profile breakdown for
+    # Per-profile breakdown for the 3 reference profiles
     print(f"\n{'=' * 70}")
     print("DISCRETE-3-PROFILE BREAKDOWN")
     print(f"{'=' * 70}")
@@ -285,7 +265,7 @@ def main():
         print(f"{s:>10}", end="")
     print()
     print("-" * 70)
-    discrete = [r for r in all_results if r["condition"] == "discrete-3-profiles
+    discrete = [r for r in all_results if r["condition"] == "discrete-3-profiles"]
     for profile in DISCRETE_PROFILES:
         row = f"{profile:<25} "
         for s in strategies:
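Condensed, the model-strategy path through `run_episode` amounts to the loop below. This is a sketch, not a verbatim excerpt; it assumes `model`, `tokenizer`, and `make_action` are in scope, and that reset/step return observations with a `done` flag and `reward_breakdown`, as the tests above use:

```python
# Sketch of the model-strategy rollout loop (illustrative, not verbatim).
env = RhythmEnvironment()
obs = env.reset(seed=100)  # one of the in-distribution eval seeds
while not obs.done:
    action_type, belief = model_action(obs, model, tokenizer, return_belief=True)
    env.record_belief(belief)  # grader scores the LAST recorded belief
    obs = env.step(make_action(action_type))
final_score = obs.reward_breakdown.get("final_score", 0.0)
```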
@@ -4,32 +4,27 @@ Reward functions for RhythmEnv GRPO training (meta-RL version).
 Four-layer reward stack:
 1. format_valid — does the LLM output have a parseable belief + action format?
 2. action_legal — is the action one of the 10 valid actions?
-3. env_reward — actual environment reward (seed-replay) for the chosen action
-   plus diversity/exploration shaping
 4. belief_accuracy — how close is the belief vector to the hidden profile's true vector?

-…
 - S, M, W: single digits 0-9 representing the agent's belief about the user
   S = social preference (0=hates social, 9=loves social)
   M = morning preference (0=night owl, 9=morning person)
   W = work preference (0=avoids work, 9=workaholic)
 - ACTION_NAME: one of 10 valid actions

-Example: "3 8 7 DEEP_WORK"
-  → belief=[0.33, 0.89, 0.78], action=DEEP_WORK

-… the … post-hoc afterthought that didn't influence behavior.
-The parser ALSO accepts the old "ACTION S M W" format for backward compatibility.

 Each function returns a list of floats (one per completion).
 """

-import math
 import os
 import re
 import sys
@@ -46,95 +41,57 @@ VALID_ACTIONS = {at.value.upper(): at for at in ActionType}
 DEFAULT_BELIEF = [0.5, 0.5, 0.5]


 def extract_action_and_belief(text: str) -> tuple[ActionType | None, list[float], bool]:
-    """Parse '…

     Returns:
         (action, belief, belief_provided):
-            action: parsed ActionType or None
-            belief: 3-dim vector in [0, 1], DEFAULT_BELIEF if …
-            belief_provided: True iff …
     """
     if not text:
         return None, list(DEFAULT_BELIEF), False

-    …
-    parts = line.upper().replace(",", " ").split()
-    if not parts:
-        return None, list(DEFAULT_BELIEF), False
-
-    # Find action and its index in parts
-    action: ActionType | None = None
-    action_idx = -1
-    for idx, p in enumerate(parts):
-        if p in VALID_ACTIONS:
-            action = VALID_ACTIONS[p]
-            action_idx = idx
-            break
-    if action is None:
-        for idx, p in enumerate(parts):
-            for name, at in VALID_ACTIONS.items():
-                if name in p:
-                    action = at
-                    action_idx = idx
-                    break
-            if action is not None:
-                break
-
-    # Iter 3: parse belief from BEFORE the action (belief-first format).
-    # Falls back to AFTER the action (legacy format) if no digits found before.
-    def _parse_digit(token: str) -> float | None:
-        token = token.strip().rstrip(".")
-        if not token:
-            return None
         try:
-            …
-    belief_provided = True
-
-    # If belief-first didn't work, try legacy after-action format
-    if not belief_provided:
-        after_belief: list[float] = []
-        after_provided = False
-        for i in range(3):
-            j = action_idx + 1 + i
-            if j < len(parts):
-                d = _parse_digit(parts[j])
-                if d is not None:
-                    after_belief.append(d)
-                    after_provided = True
-                else:
-                    after_belief.append(0.5)
-            else:
-                after_belief.append(0.5)
-        if after_provided:
-            belief = after_belief
-            belief_provided = True
-
-    if not belief or len(belief) != 3:
-        belief = list(DEFAULT_BELIEF)
-
-    return action, belief, belief_provided


 def extract_action(text: str) -> ActionType | None:
@@ -168,10 +125,11 @@ def action_legal(completions, **kwargs) -> list[float]:
     """
     Layer 2: Is the parsed action one of the 10 valid actions?

-    All 10 actions are always legal in this env (no state-dependent validity)
-    …
     """
     scores = []
     for completion in completions:
@@ -230,8 +188,9 @@ def env_reward(
         ep_history = prompt_data.get("action_history", [])
         ep_mode = prompt_data.get("profile_mode", "continuous")
     else:
-        # …
         ep_seed = (i * 17) ^ 0xBEEF
         ep_history = []
         ep_mode = "continuous"
@@ -245,36 +204,38 @@ def env_reward(
     reward = obs.reward
     chosen = action_type.value

-    # …
     progress_delta = env._progress - pre_progress
     connection_delta = env._connection - pre_connection
     reward += 0.5 * progress_delta + 0.4 * connection_delta

-    # …
     if ep_history and len(ep_history) >= 2:
         recent3 = ep_history[-3:]
         if recent3.count(chosen) >= 2:
-            reward -= 0.10

     if ep_history and len(ep_history) >= 5:
         last6 = ep_history[-5:] + [chosen]
         if len(set(last6)) <= 2:
-            reward -= 0.15

     if ep_history is not None:
         seen = set(ep_history)
         if chosen not in seen and len(seen) < 6:
-            reward += 0.07

-    # … belief …
-    # … an explicit gradient signal.
     _, b, b_provided = extract_action_and_belief(response)
     if b_provided:
         s_pref, m_pref, w_pref = b
@@ -308,7 +269,6 @@ def env_reward(

 def belief_accuracy(
     completions,
-    prompts=None,
     seed=None,
     step_index=None,
     action_history=None,
@@ -316,21 +276,17 @@ def belief_accuracy(
     **kwargs,
 ) -> list[float]:
     """
-    Layer 4: Belief-vector accuracy reward (…

-    …
-    - Constant emission → reward ≈ 0 (no free reward)
-    - Better-than-baseline belief → positive
-    - Worse-than-baseline belief → negative
-
-    Plus iter 4 (Issue 9): no belief reward at step 0 (no information available
-    to commit a belief) — prevents pulling the policy toward a constant prior.
     """
     scores = []
     for i, completion in enumerate(completions):
@@ -351,8 +307,9 @@ def belief_accuracy(
             scores.append(0.0)
             continue

-        # …
         if ep_step == 0:
             scores.append(0.0)
             continue
@@ -360,7 +317,6 @@ def belief_accuracy(
         try:
             env = _replay_env(ep_seed, ep_history, ep_mode)
             true_belief = env.get_belief_target()
-            # Iter 4 fix (Issue 4): subtract the constant-baseline reward
             mae = sum(abs(b - t) for b, t in zip(belief, true_belief)) / 3.0
             similarity = 1.0 - mae
             baseline_mae = sum(abs(0.5 - t) for t in true_belief) / 3.0
@@ -371,69 +327,3 @@ def belief_accuracy(
             scores.append(0.0)

     return scores
-
-
-def env_reward_simple(completions, prompts=None, **kwargs) -> list[float]:
-    """
-    State-aware heuristic reward (no env replay), used for fast smoke-training.
-    Identical to original simple reward — the format change is handled by extract_action.
-    """
-    scores = []
-    for i, completion in enumerate(completions):
-        response = completion[0]["content"] if isinstance(completion, list) else completion
-        action_type = extract_action(response)
-
-        if action_type is None:
-            scores.append(-2.0)
-            continue
-
-        prompt_text = ""
-        if prompts and i < len(prompts):
-            p = prompts[i]
-            if isinstance(p, list):
-                prompt_text = p[-1].get("content", "") if p else ""
-            elif isinstance(p, dict):
-                prompt_text = p.get("content", "")
-            else:
-                prompt_text = str(p)
-
-        meters = {}
-        for meter_name in ["Vitality", "Cognition", "Progress", "Serenity", "Connection"]:
-            match = re.search(rf"{meter_name}:\s*([\d.]+)", prompt_text)
-            if match:
-                meters[meter_name.lower()] = float(match.group(1))
-
-        is_morning = "Morning" in prompt_text
-        is_night = "Night" in prompt_text
-
-        score = 0.0
-        v = meters.get("vitality", 0.5)
-        s = meters.get("serenity", 0.5)
-        cn = meters.get("connection", 0.5)
-
-        if v < 0.2 and action_type in (ActionType.SLEEP, ActionType.EXERCISE):
-            score += 0.5
-        elif v < 0.2:
-            score -= 0.3
-
-        if s < 0.2 and action_type in (ActionType.MEDITATE, ActionType.ME_TIME):
-            score += 0.5
-        elif s < 0.2 and action_type != ActionType.SLEEP:
-            score -= 0.2
-
-        if cn < 0.2 and action_type in (ActionType.FAMILY_TIME, ActionType.SOCIALIZE):
-            score += 0.5
-        elif cn < 0.2:
-            score -= 0.1
-
-        if is_morning and action_type in (ActionType.DEEP_WORK, ActionType.LEARN):
-            score += 0.2
-        if is_night and action_type == ActionType.SLEEP:
-            score += 0.3
-
-        if action_type == ActionType.BINGE_WATCH:
-            score -= 0.3
-
-        scores.append(score)
-
-    return scores

 Four-layer reward stack:
 1. format_valid — does the LLM output have a parseable belief + action format?
 2. action_legal — is the action one of the 10 valid actions?
+3. env_reward — actual environment reward (seed-replay) for the chosen action,
+   plus grader-aligned bias and diversity/exploration shaping
 4. belief_accuracy — how close is the belief vector to the hidden profile's true vector?

+Output format: "S M W ACTION_NAME" (belief first)
 - S, M, W: single digits 0-9 representing the agent's belief about the user
   S = social preference (0=hates social, 9=loves social)
   M = morning preference (0=night owl, 9=morning person)
   W = work preference (0=avoids work, 9=workaholic)
 - ACTION_NAME: one of 10 valid actions

+Example: "3 8 7 DEEP_WORK" → belief=[0.33, 0.89, 0.78], action=DEEP_WORK

+Belief-first matters because tokens generated earlier condition tokens generated
+later in causal LMs — the action ends up causally conditioned on the belief, so
+the belief is functionally useful for action selection rather than a post-hoc
+afterthought. The parser also accepts a legacy "ACTION S M W" ordering as fallback.

 Each function returns a list of floats (one per completion).
 """

 import os
 import re
 import sys
 ...
 DEFAULT_BELIEF = [0.5, 0.5, 0.5]


+_ANSWER_PATTERN = re.compile(
+    r"(\d)\s+(\d)\s+(\d)\s+(" + "|".join(at.value.upper() for at in ActionType) + r")\b",
+    re.IGNORECASE,
+)


 def extract_action_and_belief(text: str) -> tuple[ActionType | None, list[float], bool]:
+    """Parse the agent's output and extract (action, belief, belief_provided).
+
+    Supports both formats:
+      - Plain answer line: "S M W ACTION_NAME"
+      - With CoT prefix: "<reasoning>...</reasoning>\nS M W ACTION_NAME"
+
+    Strategy: search the entire response for the LAST occurrence of the
+    "<digit> <digit> <digit> <ACTION>" pattern. Taking the last match handles
+    cases where the model mentions an example mid-reasoning then commits to
+    a different answer at the end.

     Returns:
         (action, belief, belief_provided):
+            action: parsed ActionType or None if no valid action anywhere
+            belief: 3-dim vector in [0, 1], DEFAULT_BELIEF if no belief parsed
+            belief_provided: True iff a belief was extracted (full S M W ACTION pattern)
     """
     if not text:
         return None, list(DEFAULT_BELIEF), False

+    # Primary path: full belief+action pattern, take the LAST occurrence.
+    matches = list(_ANSWER_PATTERN.finditer(text))
+    if matches:
+        last = matches[-1]
+        s, m, w, action_name = last.groups()
         try:
+            belief = [int(s) / 9.0, int(m) / 9.0, int(w) / 9.0]
+            action = ActionType(action_name.lower())
+            return action, belief, True
+        except (ValueError, KeyError):
+            pass  # fall through to action-only fallback
+
+    # Fallback: action-only output (no belief digits). Search for any valid
+    # ACTION_NAME token in the response and return that with default belief.
+    for line in text.strip().split("\n"):
+        for token in line.upper().replace(",", " ").split():
+            token = token.strip(".")
+            if token in VALID_ACTIONS:
+                return VALID_ACTIONS[token], list(DEFAULT_BELIEF), False
+            for name, at in VALID_ACTIONS.items():
+                if name in token:
+                    return at, list(DEFAULT_BELIEF), False
+
+    return None, list(DEFAULT_BELIEF), False
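A quick usage sketch of the parser on a CoT-prefixed response. The function, pattern, and last-match rule are from the code above; the reasoning text and the assumption that ActionType.DEEP_WORK has value "deep_work" are illustrative:

    text = (
        "<reasoning>Morning slot; the user looks like a morning person. "
        "An example answer would be 5 5 5 SOCIALIZE, but the work signal "
        "is stronger here.</reasoning>\n"
        "3 8 7 DEEP_WORK"
    )
    action, belief, provided = extract_action_and_belief(text)
    # The LAST "<digit> <digit> <digit> <ACTION>" match wins, so the
    # mid-reasoning "5 5 5 SOCIALIZE" is ignored:
    # action   -> ActionType.DEEP_WORK   (assuming value "deep_work")
    # belief   -> [3/9, 8/9, 7/9] ≈ [0.33, 0.89, 0.78]
    # provided -> True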
 def extract_action(text: str) -> ActionType | None:
 ...
     """
     Layer 2: Is the parsed action one of the 10 valid actions?

+    All 10 actions are always legal in this env (no state-dependent validity),
+    so this layer is a pure penalty: 0 for any valid action, -1 for unparseable
+    output. Returning 0 (rather than a positive constant) is intentional —
+    a constant reward across all completions in a GRPO group contributes
+    exactly zero to the advantage and was a major contributor to mode collapse.
     """
     scores = []
     for completion in completions:
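The zero-advantage point is easy to verify numerically. A minimal sketch of GRPO-style group normalization (simplified to advantage = reward minus group mean; the std division some implementations add changes nothing here):

    def group_advantages(rewards):
        # GRPO-style normalization within one prompt's group of completions.
        mean = sum(rewards) / len(rewards)
        return [r - mean for r in rewards]

    # A constant layer (e.g. the same format bonus for every completion)
    # shifts every reward equally, so the advantages are unchanged:
    base = [1.0, 2.0, 3.0, 4.0]
    shifted = [r + 1.0 for r in base]
    assert group_advantages(base) == group_advantages(shifted)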
 ...
         ep_history = prompt_data.get("action_history", [])
         ep_mode = prompt_data.get("profile_mode", "continuous")
     else:
+        # Fallback seed mixes index with a prime to break deterministic
+        # seed clusters — without it, all completions in a position-class
+        # land on the same env_reward and contribute zero to GRPO advantage.
         ep_seed = (i * 17) ^ 0xBEEF
         ep_history = []
         ep_mode = "continuous"
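What that fallback expression produces for the first few completion indices (plain Python, nothing assumed beyond the line above):

    for i in range(4):
        print(i, (i * 17) ^ 0xBEEF)
    # 0 48879
    # 1 48894
    # 2 48845
    # 3 48860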
 ...
     reward = obs.reward
     chosen = action_type.value

+    # Grader-aligned bias (progress + connection deltas): shapes only
+    # the GRPO-visible training reward, not the env's per-step reward.
+    # Lives here rather than in env.step() so that the env's per-step
+    # reward (used by adaptation_score in the grader) stays pure.
     progress_delta = env._progress - pre_progress
     connection_delta = env._connection - pre_connection
     reward += 0.5 * progress_delta + 0.4 * connection_delta

+    # Diversity shaping: small nudges (~1/3 the magnitude of the env signal)
+    # so they don't dominate it. Three terms:
+    #   - repetition penalty (action ≥2× in last 3)
+    #   - low-entropy window penalty (last 6 actions ≤2 unique)
+    #   - new-action exploration bonus (until 6 distinct actions tried)
     if ep_history and len(ep_history) >= 2:
         recent3 = ep_history[-3:]
         if recent3.count(chosen) >= 2:
+            reward -= 0.10

     if ep_history and len(ep_history) >= 5:
         last6 = ep_history[-5:] + [chosen]
         if len(set(last6)) <= 2:
+            reward -= 0.15

     if ep_history is not None:
         seen = set(ep_history)
         if chosen not in seen and len(seen) < 6:
+            reward += 0.07

+    # Belief-action coupling: rewards consistency between the agent's
+    # emitted belief and its chosen action. Without this term, the
+    # belief-first format only enforces consistency via causal attention
+    # (weak signal); this provides an explicit gradient.
     _, b, b_provided = extract_action_and_belief(response)
     if b_provided:
         s_pref, m_pref, w_pref = b
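A worked trace of the three diversity terms on a hypothetical history (uppercase names for readability; the real `chosen` values are ActionType values):

    ep_history = ["DEEP_WORK", "SLEEP", "DEEP_WORK", "SLEEP", "DEEP_WORK"]
    chosen = "DEEP_WORK"
    # repetition: last 3 = [DEEP_WORK, SLEEP, DEEP_WORK], chosen appears 2x -> -0.10
    # low entropy: last 5 + chosen = 6 actions with only 2 unique           -> -0.15
    # exploration: chosen already seen, so no new-action bonus              -> +0.00
    # net shaping: -0.25 on top of env reward + grader-aligned bias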
 ...
 def belief_accuracy(
     completions,
     seed=None,
     step_index=None,
     action_history=None,
 ...
     **kwargs,
 ) -> list[float]:
     """
+    Layer 4: Belief-vector accuracy reward (the meta-learning signal).

+    Reward = (1 − belief_mae) − constant_baseline_similarity, where the
+    constant baseline is what a "5 5 5" emission would score for THIS profile.
+    Subtracting the baseline matters: without it, a constant neutral emission
+    earns positive reward on every step (~+0.34 × the layer weight) for zero
+    actual learning, which silently re-creates the iter-1 mode-collapse pull.
+    With it: constant emission → 0, better-than-baseline > 0, worse < 0.

+    Belief reward is also skipped at step 0 — the agent has no information yet,
+    so penalizing belief-vs-target there just biases toward a constant prior.
     """
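Worked numbers for the baseline subtraction. The profile vector is made up, and the exact score expression in reward_functions.py may differ (e.g. clipping); this only illustrates why a constant emission nets zero:

    true_belief = [0.2, 0.9, 0.7]        # hidden profile (hypothetical)
    belief_constant = [0.5, 0.5, 0.5]    # neutral constant emission
    belief_learned = [0.3, 0.8, 0.7]     # closer-than-baseline guess

    def score(belief, target):
        mae = sum(abs(b - t) for b, t in zip(belief, target)) / 3.0
        baseline_mae = sum(abs(0.5 - t) for t in target) / 3.0
        return (1.0 - mae) - (1.0 - baseline_mae)  # = baseline_mae - mae

    print(score(belief_constant, true_belief))  # 0.0: no free reward
    print(score(belief_learned, true_belief))   # 0.3 - 0.0667 ≈ 0.233: positive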
     scores = []
     for i, completion in enumerate(completions):
 ...
             scores.append(0.0)
             continue

+        # Step 0: agent has no information yet — committing a belief here
+        # would just pull the policy toward whatever constant prior the
+        # base model emits. Skip the reward for that step.
         if ep_step == 0:
             scores.append(0.0)
             continue

         try:
             env = _replay_env(ep_seed, ep_history, ep_mode)
             true_belief = env.get_belief_target()
             mae = sum(abs(b - t) for b, t in zip(belief, true_belief)) / 3.0
             similarity = 1.0 - mae
             baseline_mae = sum(abs(0.5 - t) for t in true_belief) / 3.0
 ...
             scores.append(0.0)

     return scores
@@ -0,0 +1,230 @@
+"""
+SFT prime: teach Qwen 2.5-3B the teacher's CoT-then-answer format.
+
+This is Stage 2 of Algorithm Distillation. We've already collected
+teacher trajectories (Stage 1). Here we fine-tune the student on the
+teacher's full responses — `<reasoning>...</reasoning>\nS M W ACTION_NAME` —
+so the student learns BOTH the format and the reasoning pattern that
+produced each answer.
+
+After this stage, the student should beat heuristic baselines on the
+v2 grader (which awards 0.20 for belief_accuracy). GRPO refinement is
+optional — only if the SFT'd model regresses on something.
+
+Usage (from rhythm_env root):
+    python training/sft_prime.py \
+        --teacher_jsonls data/teacher_30ep_validation.jsonl \
+                         data/teacher_indist_30_99.jsonl \
+                         data/teacher_ood_10000_10049.jsonl \
+        --output_dir outputs/rhythm-env-sft-primed \
+        --max_steps 600 \
+        --epochs 2
+
+Designed to run on HF Jobs with a10g-large flavor.
+"""
+
+import argparse
+import json
+import os
+import sys
+from pathlib import Path
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
+
+# The teacher's system prompt is the canonical contract — student must learn
+# to respond to this exact prompt. Imported from the teacher script for SSOT.
+from scripts.generate_teacher_trajectories import TEACHER_SYSTEM_PROMPT
+
+
+def load_teacher_dataset(jsonl_paths: list[str], drop_parse_fails: bool = True) -> list[dict]:
+    """Read teacher JSONL files and return list of {prompt, response} pairs.
+
+    Each input row is one step from one teacher episode. We turn it into a
+    chat-format SFT example: messages=[system, user] → completion=response.
+    Steps where the teacher's response failed to parse are dropped (we
+    don't want to teach the student bad outputs).
+    """
+    pairs: list[dict] = []
+    n_total = 0
+    n_dropped = 0
+    for path in jsonl_paths:
+        with open(path) as f:
+            for line in f:
+                row = json.loads(line)
+                n_total += 1
+                if drop_parse_fails and row.get("parse_failed"):
+                    n_dropped += 1
+                    continue
+                resp = row.get("teacher_response", "")
+                if not resp or not resp.strip():
+                    n_dropped += 1
+                    continue
+                pairs.append({
+                    "messages": [
+                        {"role": "system", "content": TEACHER_SYSTEM_PROMPT},
+                        {"role": "user", "content": row["user_prompt"]},
+                        {"role": "assistant", "content": resp},
+                    ],
+                })
+    print(f"Loaded {len(pairs)} SFT examples ({n_dropped}/{n_total} dropped: "
+          f"parse-failed or empty)")
+    return pairs
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--teacher_jsonls", nargs="+", required=True,
+                        help="One or more teacher trajectory JSONL files")
+    parser.add_argument("--output_dir", type=str, default="outputs/rhythm-env-sft-primed")
+    parser.add_argument("--model_name", type=str, default="unsloth/Qwen2.5-3B-Instruct")
+    parser.add_argument("--epochs", type=int, default=2,
+                        help="SFT epochs over the dataset (2 is plenty for ~3000 examples)")
+    parser.add_argument("--max_steps", type=int, default=-1,
+                        help="Override epochs with a step count (-1 = use epochs)")
+    parser.add_argument("--lora_rank", type=int, default=16)
+    parser.add_argument("--learning_rate", type=float, default=2e-4)
+    parser.add_argument("--max_seq_length", type=int, default=2048,
+                        help="Must fit system + user + CoT response. ~600 user + ~120 CoT + ~10 ans + slack")
+    parser.add_argument("--per_device_batch_size", type=int, default=1)
+    parser.add_argument("--grad_accum", type=int, default=8,
+                        help="Effective batch size = per_device * grad_accum")
+    parser.add_argument("--warmup_ratio", type=float, default=0.1)
+    parser.add_argument("--save_method", type=str, default="merged_16bit",
+                        choices=["lora", "merged_16bit", "merged_4bit"])
+    args = parser.parse_args()
+
+    # ---- 1. Load + format the dataset ----
+    print("=" * 60)
+    print("Step 1: Loading teacher dataset")
+    print("=" * 60)
+    pairs = load_teacher_dataset(args.teacher_jsonls)
+    if not pairs:
+        sys.exit("ERROR: no SFT examples loaded — check JSONL paths")
+
+    from datasets import Dataset
+    raw_ds = Dataset.from_list(pairs)
+    print(f"Dataset size: {len(raw_ds)} examples")
+
+    # ---- 2. Load Qwen base via Unsloth ----
+    print("\n" + "=" * 60)
+    print(f"Step 2: Loading base model {args.model_name}")
+    print("=" * 60)
+    from unsloth import FastLanguageModel
+
+    model, tokenizer = FastLanguageModel.from_pretrained(
+        model_name=args.model_name,
+        load_in_4bit=True,
+        max_seq_length=args.max_seq_length,
+    )
+    model = FastLanguageModel.get_peft_model(
+        model,
+        r=args.lora_rank,
+        target_modules=[
+            "q_proj", "k_proj", "v_proj", "o_proj",
+            "gate_proj", "up_proj", "down_proj",
+        ],
+        lora_alpha=args.lora_rank * 2,
+        use_gradient_checkpointing="unsloth",
+        random_state=3407,
+    )
+    print(f"LoRA rank {args.lora_rank}, alpha {args.lora_rank * 2}")
+
+    # ---- 3. Map to chat-template strings + tokenize ----
+    print("\n" + "=" * 60)
+    print("Step 3: Preparing dataset")
+    print("=" * 60)
+
+    def format_example(ex):
+        text = tokenizer.apply_chat_template(
+            ex["messages"],
+            tokenize=False,
+            add_generation_prompt=False,
+        )
+        return {"text": text}
+
+    ds = raw_ds.map(format_example, remove_columns=raw_ds.column_names)
+    print("Sample formatted text (first 800 chars):")
+    print(ds[0]["text"][:800])
+    print("...")
+
+    # ---- 4. SFTTrainer ----
+    print("\n" + "=" * 60)
+    print("Step 4: Configuring SFTTrainer")
+    print("=" * 60)
+    from trl import SFTConfig, SFTTrainer
+
+    sft_kwargs = dict(
+        per_device_train_batch_size=args.per_device_batch_size,
+        gradient_accumulation_steps=args.grad_accum,
+        learning_rate=args.learning_rate,
+        warmup_ratio=args.warmup_ratio,
+        lr_scheduler_type="cosine",
+        optim="adamw_8bit",
+        weight_decay=0.001,
+        logging_steps=5,
+        save_strategy="no",
+        report_to="none",
+        output_dir=args.output_dir,
+        max_seq_length=args.max_seq_length,
+        dataset_text_field="text",
+        packing=False,
+    )
+    if args.max_steps > 0:
+        sft_kwargs["max_steps"] = args.max_steps
+    else:
+        sft_kwargs["num_train_epochs"] = args.epochs
+
+    sft_config = SFTConfig(**sft_kwargs)
+
+    trainer = SFTTrainer(
+        model=model,
+        tokenizer=tokenizer,
+        train_dataset=ds,
+        args=sft_config,
+    )
+    print(f"Effective batch size: {args.per_device_batch_size * args.grad_accum}")
+    if args.max_steps > 0:
+        print(f"max_steps: {args.max_steps}")
+    else:
+        print(f"epochs: {args.epochs} → ~{len(ds) * args.epochs // (args.per_device_batch_size * args.grad_accum)} steps")
+
+    # ---- 5. Train ----
+    print("\n" + "=" * 60)
+    print("Step 5: Training")
+    print("=" * 60)
+    trainer.train()
+
+    # ---- 6. Save ----
+    print("\n" + "=" * 60)
+    print("Step 6: Saving model")
+    print("=" * 60)
+    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+    if args.save_method == "lora":
+        model.save_pretrained(args.output_dir)
+        tokenizer.save_pretrained(args.output_dir)
+    else:
+        model.save_pretrained_merged(
+            args.output_dir,
+            tokenizer,
+            save_method=args.save_method,
+        )
+
+    # Save log_history for plot_from_log.py
+    log_path = os.path.join(args.output_dir, "log_history.json")
+    with open(log_path, "w") as f:
+        json.dump(trainer.state.log_history, f, indent=2)
+
+    # Save training config
+    config_path = os.path.join(args.output_dir, "training_config.json")
+    with open(config_path, "w") as f:
+        json.dump(vars(args), f, indent=2)
+
+    print(f"\nSaved SFT-primed model to: {args.output_dir}")
+    print(f"Log history: {log_path}")
+    print(f"Training config: {config_path}")
+    print()
+    print("Next: python training/inference_eval.py --model_path " + args.output_dir)
+
+
+if __name__ == "__main__":
+    main()
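For reference, the shape of a teacher JSONL row that load_teacher_dataset consumes. Only the three fields read above (user_prompt, teacher_response, parse_failed) are confirmed by this script; the values and any other fields are illustrative:

    {"user_prompt": "Step 4/30. Slot: Morning. Vitality: 0.62 ...",
     "teacher_response": "<reasoning>User accepted DEEP_WORK twice in the morning ...</reasoning>\n2 8 7 DEEP_WORK",
     "parse_failed": false}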
@@ -13,8 +13,8 @@ Usage (Colab T4):
 !pip install unsloth transformers trl datasets
 !python training/train.py --max_steps 1500

-…
-    python …
 """

 import argparse

@@ -45,12 +45,7 @@ def main():
         help="Fraction of dataset with profile hint visible. Default 0.0 (no hints) "
              "to eliminate train-eval distribution mismatch. Set >0 only if you ALSO "
              "show hints during eval.")
-    parser.add_argument("--profile_mode", type=str, default="continuous",
-                        choices=["continuous", "discrete"],
-                        help="continuous = sampled per-episode (meta-RL); discrete = 3 hardcoded profiles")
     parser.add_argument("--output_dir", type=str, default="outputs/rhythmenv_meta_trained")
-    parser.add_argument("--use_simple_reward", action="store_true",
-                        help="Use heuristic reward instead of env-replay (smoke testing)")
     parser.add_argument("--report_to", type=str, default="none")
     args = parser.parse_args()

@@ -68,7 +63,6 @@ def main():
         num_episodes=args.num_episodes,
         strategy="mixed",
         max_samples=args.max_samples,
-        profile_mode=args.profile_mode,
         hint_fraction=args.hint_fraction,
     )

@@ -123,16 +117,10 @@ def main():
     print("Step 3: Setting up reward functions")
     print("=" * 60)

-    from reward_functions import (
-        format_valid, action_legal, env_reward, env_reward_simple, belief_accuracy
-    )
-
-    if args.use_simple_reward:
-        reward_funcs = [format_valid, action_legal, env_reward_simple, belief_accuracy]
-        print("Using: format_valid + action_legal + env_reward_simple + belief_accuracy")
-    else:
-        reward_funcs = [format_valid, action_legal, env_reward, belief_accuracy]
-        print("Using: format_valid + action_legal + env_reward + belief_accuracy")

     # ---------------------------------------------------------------
     # 4. GRPO trainer config

@@ -146,11 +134,11 @@ def main():
     max_prompt_length = 600  # history + hint room
     max_completion_length = 32  # bumped from 20 to prevent silent truncation of belief digits

-    # reward_weights: …
-    # … learning signal.
-    # Order MUST match reward_funcs
     reward_weights = [0.05, 0.05, 1.5, 3.0]

     training_args_kwargs = dict(
 !pip install unsloth transformers trl datasets
 !python training/train.py --max_steps 1500

+Setup-check (no GPU): run the smoke tests instead of starting a real run:
+    python -m pytest tests/test_pipeline_smoke.py -q
 """

 import argparse
 ...
         help="Fraction of dataset with profile hint visible. Default 0.0 (no hints) "
              "to eliminate train-eval distribution mismatch. Set >0 only if you ALSO "
              "show hints during eval.")
     parser.add_argument("--output_dir", type=str, default="outputs/rhythmenv_meta_trained")
     parser.add_argument("--report_to", type=str, default="none")
     args = parser.parse_args()
 ...
         num_episodes=args.num_episodes,
         strategy="mixed",
         max_samples=args.max_samples,
         hint_fraction=args.hint_fraction,
     )
 ...
     print("Step 3: Setting up reward functions")
     print("=" * 60)

+    from reward_functions import format_valid, action_legal, env_reward, belief_accuracy

+    reward_funcs = [format_valid, action_legal, env_reward, belief_accuracy]
+    print("Using: format_valid + action_legal + env_reward + belief_accuracy")

     # ---------------------------------------------------------------
     # 4. GRPO trainer config
 ...
     max_prompt_length = 600  # history + hint room
     max_completion_length = 32  # bumped from 20 to prevent silent truncation of belief digits

+    # reward_weights: suppress the format/action_legal layers (small, low-variance
+    # signals — too constant across a GRPO group to contribute meaningful advantage)
+    # and amplify the variable signals env_reward and belief_accuracy. belief_accuracy
+    # at 3.0 is the dominant learning signal.
+    # Order MUST match reward_funcs above: format_valid, action_legal, env_reward, belief_accuracy
     reward_weights = [0.05, 0.05, 1.5, 3.0]

     training_args_kwargs = dict(
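How the four layers combine under these weights, assuming the standard per-function weighted sum over reward functions (a sketch of the idea, not TRL's actual internals):

    weights = [0.05, 0.05, 1.5, 3.0]  # format_valid, action_legal, env_reward, belief_accuracy

    def combined_reward(layer_scores):
        # layer_scores: one float per reward function for a single completion.
        return sum(w * s for w, s in zip(weights, layer_scores))

    # e.g. a well-formatted completion with decent env reward and a good belief:
    print(combined_reward([1.0, 0.0, 0.4, 0.23]))  # 0.05 + 0.0 + 0.6 + 0.69 = 1.34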