InosLihka Claude Opus 4.7 (1M context) committed on
Commit 64d24b3
Parent: e21a960

iter3: align reward with grader + belief-first format + exploration shaping


ITER 2 FAILURE: the agent escaped single-action collapse only to settle into a
2-cycle loop (MEDITATE/EXERCISE alternation). final_score 0.22, worse than random.
Belief learning (+0.36) works but doesn't influence action choice: the belief was
emitted as a string AFTER the action, so the action couldn't condition on it.

Five runtime-level fixes for iter 3:

[1] Per-step reward grader-alignment (server/rhythm_environment.py:_compute_reward)
Add a profile-INDEPENDENT bias: +0.5*progress_delta + 0.4*connection_delta.
The profile-weighted reward still drives belief inference (it varies by profile),
but the agent is now ALWAYS penalized for ignoring progress and connection.
Closes the proxy/goal misalignment that let the M+E alternation win.
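A minimal sketch of the shaping term (it mirrors the _compute_reward change in the
diff below; the delta values here are invented for illustration, not from a run):

    # illustrative deltas only; real ones come from the env's meter updates
    deltas = {"progress": 0.10, "connection": 0.05}
    grader_bias = 0.5 * deltas["progress"] + 0.4 * deltas["connection"]  # = 0.07
    # added on top of the profile-weighted reward, whichever profile was sampled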

[2] Belief-first output format (training/reward_functions.py, dataset.py, inference.py)
Was: 'ACTION_NAME S M W' (belief was a post-hoc afterthought and didn't condition the action)
Now: 'S M W ACTION_NAME' (the causal LM generates the belief tokens FIRST, then the
action conditions on those belief tokens via attention).
The parser accepts both formats for backward compat. The system prompt is updated to
explain the belief-action coupling so the model learns to use the format.
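Expected parser behavior after this change, per the extract_action_and_belief docstring
in the diff below (belief digits map to d/9.0; values shown rounded):

    extract_action_and_belief("3 8 7 DEEP_WORK")  # -> action=DEEP_WORK, belief=[0.33, 0.89, 0.78], True
    extract_action_and_belief("DEEP_WORK 3 8 7")  # -> same result via the legacy after-action fallback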

[3] N-cycle penalty (training/reward_functions.py:env_reward)
Iter 2 had a 3-in-a-row penalty, but the agent found a 2-cycle loophole.
New rule: if the last 6 actions have <=2 unique values, -0.4. Closes the M-E
alternation specifically, plus any loop that cycles through at most two actions.
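The check itself, pulled out of env_reward and run on an invented iter-2-style history:

    history = ["MEDITATE", "EXERCISE", "MEDITATE", "EXERCISE", "MEDITATE"]
    chosen = "EXERCISE"
    last6 = history[-5:] + [chosen]
    penalized = len(set(last6)) <= 2   # True: only {MEDITATE, EXERCISE} in the window, so -0.4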

[4] New-action exploration bonus (training/reward_functions.py:env_reward)
+0.2 reward for taking an action that hasn't yet appeared in the current
episode (until 6 unique actions have been tried). Pushes the agent to PROBE
early in episodes, the canonical meta-RL exploration signal.
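Standalone version of the bonus rule, with an invented episode prefix:

    history = ["SLEEP", "DEEP_WORK"]
    chosen = "SOCIALIZE"
    seen = set(history)
    bonus = 0.2 if chosen not in seen and len(seen) < 6 else 0.0   # 0.2: first SOCIALIZE this episode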

[5] Sparse terminal reward (server/rhythm_environment.py:step)
At step 28, add (final_score - 0.5) * 5 to the per-step reward. Direct
supervision on the actual grader. The [-2.5, +2.5] range gives a strong
end-of-episode signal that overwhelms any local reward hack.
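For scale, plugging in iter 2's own grade versus a good episode:

    terminal_bonus = (0.22 - 0.5) * 5.0   # -1.4 (iter 2's actual final_score)
    terminal_bonus = (0.90 - 0.5) * 5.0   # +2.0 (a strong episode)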

Plus training config bumps:
- MAX_STEPS 400 -> 800 (iter 2 was still rising at step 400)
- NUM_GENERATIONS 4 -> 8 (lower advantage variance)
- LORA_RANK 8 -> 16 (more capacity for policy + belief integration)
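These are the FAST_MODE preset values; the diff shows the existing env-var override
pattern for MAX_STEPS (the other knobs presumably read the same way), so individual
runs can still bump them further:

    import os
    DEFAULTS = dict(MAX_STEPS=800, NUM_GENERATIONS=8, LORA_RANK=16)  # iter-3 FAST_MODE values
    MAX_STEPS = int(os.environ.get("MAX_STEPS", str(DEFAULTS["MAX_STEPS"])))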

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

inference.py CHANGED
@@ -80,24 +80,31 @@ people differently — you must INFER who you're helping from the rewards and
 meter changes you observe.
 
 Each step, output ONE LINE in this exact format:
-  ACTION_NAME S M W
+  S M W ACTION_NAME
 
-where ACTION_NAME is one of:
+First write your BELIEF as 3 digits 0-9, then the ACTION that fits:
+  S = social preference (0=hates social, 9=loves social)
+  M = morning preference (0=night owl, 9=morning person)
+  W = work preference (0=avoids work, 9=workaholic)
+
+ACTION choices:
   DEEP_WORK, ADMIN_WORK, LEARN, SLEEP, EXERCISE, MEDITATE,
   FAMILY_TIME, SOCIALIZE, ME_TIME, BINGE_WATCH
 
-and S, M, W are single digits (0-9) representing your current belief:
-  S = social preference (0=hates being social, 9=loves being social)
-  M = morning preference (0=night owl, 9=morning person)
-  W = work preference (0=avoids work, 9=workaholic)
-
-Example: DEEP_WORK 3 8 7
+Example: 3 8 7 DEEP_WORK
 
-Tips:
-- Update your belief from rewards: SOCIALIZE giving big reward → raise S.
-- Early in the week, PROBE different actions to learn the person.
-- Late in the week, EXPLOIT what you've learned.
-- Watch for crashes: any meter under 0.1 → big penalty.
+Belief-action coupling guide:
+- High S: SOCIALIZE, FAMILY_TIME (extrovert boosts)
+- High M: DEEP_WORK in morning slots (morning-person bonus)
+- High W: DEEP_WORK, LEARN (workaholic energy)
+- Low S: MEDITATE, ME_TIME (introvert recharge)
+- Low M: DEEP_WORK in evening/night (night-owl bonus)
+
+Tactics:
+- Early week: PROBE varied actions to gather information.
+- Late week: EXPLOIT — pick actions matching your sharpened belief.
+- Don't repeat the same action; you'll get a repetition penalty.
+- Watch for crashes: any meter under 0.1 = big penalty.
 - Connection decays passively — actively maintain it.
 Respond with ONLY the format line, no other text.""")
 
@@ -212,7 +219,7 @@ Meters:
   Serenity: {obs.serenity:.2f}
   Connection: {obs.connection:.2f}{event_str}{history_str}
 
-Choose your action (format: ACTION S M W):""")
+Output belief then action (format: S M W ACTION_NAME):""")
 
     try:
         completion = llm_client.chat.completions.create(
scripts/train_on_hf.py CHANGED
@@ -57,13 +57,13 @@ PLOTS_DIR = "/tmp/rhythm_env/plots"
 FAST_MODE = os.environ.get("FAST_MODE", "0") == "1"
 
 if FAST_MODE:
-    # Smoke-train preset: 400 steps gives the policy room to escape iter-1 mode collapse
-    DEFAULTS = dict(MAX_STEPS=400, NUM_EPISODES=120, MAX_SAMPLES=1200,
-                    NUM_GENERATIONS=4, LORA_RANK=8, BETA=0.04,
+    # Iter 3 preset: 800 steps + 8 generations + LoRA 16 to escape mode collapse for real
+    DEFAULTS = dict(MAX_STEPS=800, NUM_EPISODES=200, MAX_SAMPLES=2000,
+                    NUM_GENERATIONS=8, LORA_RANK=16, BETA=0.04,
                     LEARNING_RATE=5e-5, EVAL_EPISODES=2)
 else:
-    DEFAULTS = dict(MAX_STEPS=1500, NUM_EPISODES=300, MAX_SAMPLES=3000,
-                    NUM_GENERATIONS=8, LORA_RANK=8, BETA=0.04,
+    DEFAULTS = dict(MAX_STEPS=2000, NUM_EPISODES=400, MAX_SAMPLES=4000,
+                    NUM_GENERATIONS=8, LORA_RANK=16, BETA=0.04,
                     LEARNING_RATE=5e-5, EVAL_EPISODES=5)
 
 MAX_STEPS = int(os.environ.get("MAX_STEPS", str(DEFAULTS["MAX_STEPS"])))
server/rhythm_environment.py CHANGED
@@ -507,6 +507,13 @@ class RhythmEnvironment(Environment):
         if done:
             final_score = self._grade_episode()
             reward_breakdown["final_score"] = round(final_score, 4)
+            # Iter 3 fix: sparse terminal reward — direct supervision on grader
+            # final_score is in [0, 1]; baseline-relative bonus gives strong signal
+            # for ending the episode well. Range: -2.5 (terrible) to +2.5 (perfect).
+            terminal_bonus = (final_score - 0.5) * 5.0
+            reward = max(-3.0, min(3.0, reward + terminal_bonus))
+            self._total_reward += terminal_bonus  # update tracking too
+            reward_breakdown["terminal_bonus"] = round(terminal_bonus, 4)
 
         # --- 14. Update state ---
         self._state.step_count = self._timestep
@@ -695,10 +702,20 @@ class RhythmEnvironment(Environment):
         self._vitality = max(0.0, self._vitality - vd)
 
     def _compute_reward(self, deltas: Dict[str, float]) -> float:
-        """Compute reward as hidden-weighted sum of meter deltas."""
+        """Compute reward as hidden-weighted sum + grader-aligned bias.
+
+        Iter 3 fix: Add a profile-INDEPENDENT bias term for progress and
+        connection. The original profile-weighted reward drives belief inference
+        (varies by profile), but allowed agents to game it by spamming recovery
+        actions if the sampled profile didn't weight progress/connection. The
+        bias makes the per-step reward correlate with the FINAL grader (which
+        weights progress 0.25 and connection 0.15).
+        """
         weights = self._profile["reward_weights"]
-        reward = sum(deltas[m] * weights[m] for m in METERS)
-        return reward * REWARD_SCALE
+        profile_reward = sum(deltas[m] * weights[m] for m in METERS) * REWARD_SCALE
+        # Grader-aligned bias: scaled so max bonus is ~0.1/step (manageable vs profile_reward)
+        grader_bias = 0.5 * deltas["progress"] + 0.4 * deltas["connection"]
+        return profile_reward + grader_bias
 
     def _grade_episode(self) -> float:
         """
training/dataset.py CHANGED
@@ -32,23 +32,29 @@ SYSTEM_PROMPT = (
     "different people differently — you must INFER who you're helping from the\n"
    "rewards and meter changes you observe.\n\n"
    "Each step, output ONE LINE in this exact format:\n"
-    "    ACTION_NAME S M W\n\n"
-    "where ACTION_NAME is one of:\n"
-    "    DEEP_WORK, ADMIN_WORK, LEARN, SLEEP, EXERCISE, MEDITATE,\n"
-    "    FAMILY_TIME, SOCIALIZE, ME_TIME, BINGE_WATCH\n\n"
-    "and S, M, W are single digits (0-9) representing your current belief:\n"
+    "    S M W ACTION_NAME\n\n"
+    "First, write your belief about the person as 3 digits (0-9):\n"
     "  S = social preference (0=hates being social, 9=loves being social)\n"
     "  M = morning preference (0=night owl, 9=morning person)\n"
     "  W = work preference (0=avoids work, 9=workaholic)\n\n"
-    "Example: DEEP_WORK 3 8 7\n\n"
-    "Tips:\n"
-    "- Update your belief based on rewards: if SOCIALIZE gave a big positive reward,\n"
-    "  raise S; if it tanked vitality without a reward, lower S.\n"
-    "- Early in the week, PROBE different actions to learn the person.\n"
-    "- Late in the week, EXPLOIT what you've learned — pick actions matching\n"
-    "  the person's preferences (use your belief to guide).\n"
-    "- Watch for crashes: any meter under 0.1 → big penalty.\n"
-    "- Connection decays passively — actively maintain it.\n"
+    "Then choose the action that BEST FITS that belief, from:\n"
+    "    DEEP_WORK, ADMIN_WORK, LEARN, SLEEP, EXERCISE, MEDITATE,\n"
+    "    FAMILY_TIME, SOCIALIZE, ME_TIME, BINGE_WATCH\n\n"
+    "Example: 3 8 7 DEEP_WORK (this person is moderately introverted, strongly\n"
+    "morning-oriented, fairly work-driven — so deep work in the morning fits)\n\n"
+    "Belief-action coupling guide:\n"
+    "- High S (extrovert): SOCIALIZE, FAMILY_TIME boost connection cheaply\n"
+    "- High M (morning person): DEEP_WORK in early slots gets bonus cognition\n"
+    "- High W (workaholic): DEEP_WORK, LEARN drive progress AND may energize\n"
+    "- Low S (introvert): MEDITATE, ME_TIME for solo recharge; avoid SOCIALIZE\n"
+    "- Low M (night owl): DEEP_WORK in evening/night slots\n\n"
+    "Tactics:\n"
+    "- Early in the week, PROBE varied actions to gather information.\n"
+    "- Update your belief from the rewards you see.\n"
+    "- Late in the week, EXPLOIT — pick actions matching your sharpened belief.\n"
+    "- Don't repeat the same action excessively; you'll get a repetition penalty.\n"
+    "- Watch for crashes: any meter under 0.1 = big penalty.\n"
+    "- Connection decays passively — actively maintain it via SOCIALIZE/FAMILY_TIME.\n"
     "Respond with ONLY the format line, no other text."
 )
 
@@ -96,7 +102,7 @@ def format_observation_prompt(obs, profile_hint: dict | None = None) -> str:
         f"{event_str}"
         f"{history_str}"
         f"{hint_str}\n\n"
-        f"Choose your action (format: ACTION S M W):"
+        f"Output your belief, then your action (format: S M W ACTION_NAME):"
     )
 
 
training/reward_functions.py CHANGED
@@ -2,20 +2,29 @@
 Reward functions for RhythmEnv GRPO training (meta-RL version).
 
 Four-layer reward stack:
-  1. format_valid — does the LLM output have a parseable action + belief format?
+  1. format_valid — does the LLM output have a parseable belief + action format?
   2. action_legal — is the action one of the 10 valid actions?
   3. env_reward — actual environment reward (seed-replay) for the chosen action
+     plus diversity/exploration shaping
   4. belief_accuracy — how close is the belief vector to the hidden profile's true vector?
 
-Action output format: "ACTION_NAME S M W"
-- ACTION_NAME: one of 10 valid actions
+ITER 3 FORMAT (belief-first): "S M W ACTION_NAME"
 - S, M, W: single digits 0-9 representing the agent's belief about the user
   S = social preference (0=hates social, 9=loves social)
   M = morning preference (0=night owl, 9=morning person)
   W = work preference (0=avoids work, 9=workaholic)
+- ACTION_NAME: one of 10 valid actions
 
-Example: "DEEP_WORK 3 8 7"
-  → action=DEEP_WORK, belief=[0.33, 0.89, 0.78]
+Example: "3 8 7 DEEP_WORK"
+  → belief=[0.33, 0.89, 0.78], action=DEEP_WORK
+
+Why belief-first: in causal LM generation, tokens generated EARLIER condition
+LATER tokens. With "S M W ACTION", the model commits to a belief first, and
+the action is then conditioned on that belief — making the belief functionally
+useful for action selection. The previous "ACTION S M W" order made belief a
+post-hoc afterthought that didn't influence behavior.
 
+The parser ALSO accepts the old "ACTION S M W" format for backward compatibility.
+
 Each function returns a list of floats (one per completion).
 """
@@ -58,7 +67,7 @@ def extract_action_and_belief(text: str) -> tuple[ActionType | None, list[float]
     if not parts:
         return None, list(DEFAULT_BELIEF), False
 
-    # Find action and its index in parts (try first token, then any token)
+    # Find action and its index in parts
     action: ActionType | None = None
     action_idx = -1
     for idx, p in enumerate(parts):
@@ -76,35 +85,53 @@ def extract_action_and_belief(text: str) -> tuple[ActionType | None, list[float]
         if action is not None:
             break
 
-    # Parse next 3 tokens AFTER the action as belief digits/floats
+    # Iter 3: parse belief from BEFORE the action (belief-first format).
+    # Falls back to AFTER the action (legacy format) if no digits found before.
+    def _parse_digit(token: str) -> float | None:
+        token = token.strip().rstrip(".")
+        if not token:
+            return None
+        try:
+            if len(token) == 1 and token.isdigit():
+                return int(token) / 9.0
+            val = float(token)
+            if val > 1.0:
+                val = val / 9.0
+            return max(0.0, min(1.0, val))
+        except (ValueError, IndexError):
+            return None
+
     belief: list[float] = []
     belief_provided = False
+
     if action_idx >= 0:
-        for i in range(3):
-            j = action_idx + 1 + i
-            if j < len(parts):
-                p = parts[j].strip().rstrip(".")
-                if not p:
-                    belief.append(0.5)
-                    continue
-                try:
-                    if len(p) == 1 and p.isdigit():
-                        belief.append(int(p) / 9.0)
-                        belief_provided = True
-                    else:
-                        val = float(p)
-                        if val > 1.0:
-                            val = val / 9.0  # interpret as 0-9 scale
-                        belief.append(max(0.0, min(1.0, val)))
-                        belief_provided = True
-                except (ValueError, IndexError):
-                    belief.append(0.5)
-            else:
-                belief.append(0.5)
-    else:
-        belief = list(DEFAULT_BELIEF)
-
-    if not belief:
+        # Try belief-first: 3 digits BEFORE the action
+        if action_idx >= 3:
+            cand = [_parse_digit(parts[action_idx - 3 + i]) for i in range(3)]
+            if all(c is not None for c in cand):
+                belief = cand  # type: ignore[assignment]
+                belief_provided = True
+
+        # If belief-first didn't work, try legacy after-action format
+        if not belief_provided:
+            after_belief: list[float] = []
+            after_provided = False
+            for i in range(3):
+                j = action_idx + 1 + i
+                if j < len(parts):
+                    d = _parse_digit(parts[j])
+                    if d is not None:
+                        after_belief.append(d)
+                        after_provided = True
+                    else:
+                        after_belief.append(0.5)
+                else:
+                    after_belief.append(0.5)
+            if after_provided:
+                belief = after_belief
+                belief_provided = True
+
+    if not belief or len(belief) != 3:
         belief = list(DEFAULT_BELIEF)
 
     return action, belief, belief_provided
@@ -213,13 +240,30 @@ def env_reward(
             env = _replay_env(ep_seed, ep_history, ep_mode)
             obs = env.step(RhythmAction(action_type=action_type))
             reward = obs.reward
-            # Iter 2 fix: explicit repetition penalty as TRAINING signal (env already
-            # dampens repeats but doesn't add explicit negative reward). Discourages
-            # the lazy "spam one action" mode collapse pattern.
+            chosen = action_type.value
+
+            # Iter 2 fix: explicit 3-in-a-row repetition penalty
             if ep_history and len(ep_history) >= 2:
-                recent = ep_history[-3:]  # last 3 actions
-                if recent.count(action_type.value) >= 2:  # this would make 3+ in a row
+                recent3 = ep_history[-3:]
+                if recent3.count(chosen) >= 2:  # this action would make 3+ in a row
                     reward -= 0.3
+
+            # Iter 3 fix: N-CYCLE penalty (catches the M-E-M-E-... loop iter 2 fell into)
+            # If last 6 actions (including this one) have <=2 unique values, apply penalty
+            if ep_history and len(ep_history) >= 5:
+                last6 = ep_history[-5:] + [chosen]
+                if len(set(last6)) <= 2:
+                    reward -= 0.4
+
+            # Iter 3 fix: NEW-ACTION exploration bonus
+            # If this action hasn't appeared yet in the current episode, +0.2.
+            # Strong incentive in early steps to TRY varied actions, fading as
+            # the action set grows. Stops once 6+ different actions tried.
+            if ep_history is not None:
+                seen = set(ep_history)
+                if chosen not in seen and len(seen) < 6:
+                    reward += 0.2
+
             scores.append(reward)
         except Exception:
             scores.append(-3.0)