InosLihka Claude Opus 4.7 (1M context) committed
Commit e21a960 · 1 parent: dc0186f

iter2: fix mode collapse + 3 deeper bugs from code review


ITER 1 FAILURE: mode collapse to 'EXERCISE 5 5 5' at step 100. Root cause:
format_valid (+1.0) and action_legal (+0.5) saturated for any valid output,
contributing zero to GRPO advantage. Once 4/4 completions per prompt were
identical, reward_std=0 and the policy froze.
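
For context, a minimal sketch of the group-relative advantage (standard GRPO
formulation, not TRL's exact code) showing why constant reward layers and
identical completions stall learning:

    import statistics

    # Each group holds the rewards of the 4 completions sampled for one prompt.
    def group_advantages(rewards: list[float]) -> list[float]:
        mean = statistics.mean(rewards)
        std = statistics.pstdev(rewards)
        if std == 0.0:
            # Iter-1 failure mode: identical completions -> zero advantage everywhere.
            return [0.0] * len(rewards)
        return [(r - mean) / std for r in rewards]

    # Saturated layers (format +1.0, legal +0.5) add the same constant to every
    # completion, so they cancel against the group mean and contribute nothing.
    print(group_advantages([1.5, 1.5, 1.5, 1.5]))  # [0.0, 0.0, 0.0, 0.0] -> frozen policy
    print(group_advantages([1.5, 1.7, 2.1, 0.9]))  # nonzero only because env/belief vary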

8 changes applied (4 fixes from the initial diagnosis, 3 from the deep-review subagent, plus a FAST_MODE preset bump):

[1] sampling temperature 1.0 -> 1.5
Forces diverse rollouts per prompt and breaks the mode-collapse mechanism.
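
A toy illustration (logits are made up, not from the model) of how the higher
temperature flattens the per-token sampling distribution:

    import math

    def softmax(logits, temperature):
        scaled = [l / temperature for l in logits]
        m = max(scaled)
        exps = [math.exp(s - m) for s in scaled]
        total = sum(exps)
        return [e / total for e in exps]

    logits = [3.0, 1.5, 1.0]          # hypothetical action logits
    print(softmax(logits, 1.0))       # ~[0.74, 0.16, 0.10] -> near-greedy, repeats likely
    print(softmax(logits, 1.5))       # ~[0.61, 0.23, 0.16] -> flatter, rollouts diverge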

[2] reward_weights [0.3, 0.3, 1.0, 1.0] -> [0.05, 0.05, 1.5, 3.0]
Aggressively suppress saturated layers, amplify variable signals.
Belief weight bumped to 3.0 (CRITICAL-1) so emitting a belief is clearly
rewarded relative to the no-belief penalty.
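
A back-of-envelope check (per-layer scores below are made up) of how the new
weights shift the variance of the combined reward onto env_reward and
belief_accuracy:

    # Order matches reward_funcs: format_valid, action_legal, env_reward, belief_accuracy.
    old_w = [0.3, 0.3, 1.0, 1.0]
    new_w = [0.05, 0.05, 1.5, 3.0]

    # Four completions for one prompt: format/legal saturated, env/belief varying.
    layer_scores = [
        [1.0, 0.0, 0.8, 0.6],
        [1.0, 0.0, -0.5, 0.2],
        [1.0, 0.0, 1.2, 0.9],
        [1.0, 0.0, 0.1, 0.0],
    ]

    def totals(weights):
        return [sum(w * s for w, s in zip(weights, scores)) for scores in layer_scores]

    print(totals(old_w))  # constant format/legal offset cancels in the group mean anyway
    print(totals(new_w))  # spread between completions roughly doubles, driven by belief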

[3] action_legal: drop +0.5 for valid, return 0 instead
Removes another constant-reward source. Layer is now pure penalty for
malformed outputs.

[4] explicit repetition penalty in env_reward (-0.3 if the same action is seen 3+ times in a row)
Direct training signal against mode collapse. Env already does effect
dampening but doesn't add explicit negative reward for repetition.

[5] _grade_episode late_quality normalization fix (CRITICAL-2)
Per-step rewards are CLAMPED to [-3, +3] in step(), not [-1, +1] as the
old normalization assumed. The old grader saturated late_quality at 1.0 for any
mean_late >= +1, leaving it blind to the difference between a good and an
excellent late half. Fixed to use the actual reward range.
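
A quick worked check (mean_late values are illustrative) of the saturation the
fix removes:

    def old_late_quality(mean_late):  # assumed per-step rewards in [-1, +1]
        return max(0.0, min(1.0, (mean_late + 1.0) / 2.0))

    def new_late_quality(mean_late):  # actual [-3, +3] clamp range from step()
        return max(0.0, min(1.0, (mean_late + 3.0) / 6.0))

    for mean_late in (1.0, 2.0, 2.8):
        print(mean_late, old_late_quality(mean_late), new_late_quality(mean_late))
    # 1.0 -> old 1.00, new ~0.67   (old grader already saturated here)
    # 2.0 -> old 1.00, new ~0.83
    # 2.8 -> old 1.00, new ~0.97   (excellent late half finally scores higher)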

[6] hint_fraction default 0.15 -> 0.0 (MAJOR-3)
Eliminates train-eval distribution mismatch. Eval never shows hints, so
training with hints creates a fraction of training examples whose lessons
don't transfer. Set to 0 by default; can be re-enabled if eval also adds
hint visibility.
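
For reference, a hypothetical sketch (field names assumed, not the repo's
actual dataset builder) of what a nonzero hint_fraction does during dataset
construction, and why it mismatches a hint-free eval:

    import random

    def build_prompt(episode: dict, hint_fraction: float, rng: random.Random) -> str:
        # 'observation_text' and 'profile_hint' are assumed field names for illustration.
        prompt = episode["observation_text"]
        if rng.random() < hint_fraction:
            # Training-only information; eval never appends this line.
            prompt += "\nHINT: profile is " + str(episode["profile_hint"])
        return prompt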

[7] env_reward seed fallback hardening (MAJOR-1)
Replace 'i % 50' fallback with '(i * 17) ^ 0xBEEF' to break deterministic
seed clusters. Avoids the worst case where 4 completions in the same
position class get identical env_reward (and hence zero GRPO advantage).
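
A small check of the two fallbacks (indices chosen arbitrarily): any two prompt
indices 50 apart collide under the old fallback but not under the mixed one:

    for i in (3, 53, 103, 153):
        print(i, i % 50, (i * 17) ^ 0xBEEF)
    # i=3:   old 3, new 48860
    # i=53:  old 3, new 48490
    # i=103: old 3, new 47160
    # i=153: old 3, new 46278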

[8] FAST_MODE preset: MAX_STEPS 200 -> 400, samples 800 -> 1200
Iter 1 collapsed at step 100, leaving 100 stuck steps. Gives iter 2 more
room to recover from local optima.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

scripts/train_on_hf.py CHANGED
@@ -57,8 +57,8 @@ PLOTS_DIR = "/tmp/rhythm_env/plots"
 FAST_MODE = os.environ.get("FAST_MODE", "0") == "1"
 
 if FAST_MODE:
-    # Smoke-train preset: enough signal in 200 steps to decide go/no-go
-    DEFAULTS = dict(MAX_STEPS=200, NUM_EPISODES=80, MAX_SAMPLES=800,
+    # Smoke-train preset: 400 steps gives the policy room to escape iter-1 mode collapse
+    DEFAULTS = dict(MAX_STEPS=400, NUM_EPISODES=120, MAX_SAMPLES=1200,
                     NUM_GENERATIONS=4, LORA_RANK=8, BETA=0.04,
                     LEARNING_RATE=5e-5, EVAL_EPISODES=2)
 else:
server/rhythm_environment.py CHANGED
@@ -739,11 +739,13 @@ class RhythmEnvironment(Environment):
         if early and late:
             mean_early = sum(early) / len(early)
             mean_late = sum(late) / len(late)
-            # late_quality: rewards typically in [-1, 1] per step, normalize
-            late_quality = max(0.0, min(1.0, (mean_late + 1.0) / 2.0))
+            # Iter 2 fix: per-step rewards are CLAMPED to [-3, +3] in step(), not [-1, +1].
+            # Old normalization saturated late_quality at 1.0 for any mean_late >= +1,
+            # making the grader unable to distinguish good from excellent late-half.
+            late_quality = max(0.0, min(1.0, (mean_late + 3.0) / 6.0))
             gain = mean_late - mean_early
-            # gain typically in [-1, 1]; clip to [0, 1] (only positive counts)
-            gain_norm = max(0.0, min(1.0, gain))
+            # gain in [-6, +6]; normalize to [0, 1] (only positive gain counts)
+            gain_norm = max(0.0, min(1.0, gain / 3.0))
             adaptation_score = gain_norm * late_quality
         else:
             adaptation_score = 0.0
training/reward_functions.py CHANGED
@@ -142,13 +142,15 @@ def action_legal(completions, **kwargs) -> list[float]:
     Layer 2: Is the parsed action one of the 10 valid actions?
 
     All 10 actions are always legal in this env (no state-dependent validity).
-    +0.5 if legal, -1.0 if not parseable.
+    Iter 2 fix: returns 0 for valid (was +0.5) so this layer becomes pure penalty
+    for malformed outputs. The +0.5 was a constant reward that contributed zero
+    to GRPO advantage and helped trigger mode collapse in iter 1.
     """
     scores = []
     for completion in completions:
         response = completion[0]["content"] if isinstance(completion, list) else completion
         action = extract_action(response)
-        scores.append(0.5 if action is not None else -1.0)
+        scores.append(0.0 if action is not None else -1.0)
     return scores
 
 
@@ -201,14 +203,24 @@ def env_reward(
             ep_history = prompt_data.get("action_history", [])
             ep_mode = prompt_data.get("profile_mode", "continuous")
         else:
-            ep_seed = i % 50
+            # Iter 2 fix: mix index with prime to break deterministic seed clusters
+            # (avoids all completions in a position-class getting identical env_reward)
+            ep_seed = (i * 17) ^ 0xBEEF
            ep_history = []
            ep_mode = "continuous"
 
         try:
             env = _replay_env(ep_seed, ep_history, ep_mode)
             obs = env.step(RhythmAction(action_type=action_type))
-            scores.append(obs.reward)
+            reward = obs.reward
+            # Iter 2 fix: explicit repetition penalty as TRAINING signal (env already
+            # dampens repeats but doesn't add explicit negative reward). Discourages
+            # the lazy "spam one action" mode collapse pattern.
+            if ep_history and len(ep_history) >= 2:
+                recent = ep_history[-3:]  # last 3 actions
+                if recent.count(action_type.value) >= 2:  # this would make 3+ in a row
+                    reward -= 0.3
+            scores.append(reward)
         except Exception:
             scores.append(-3.0)
 
training/train.py CHANGED
@@ -41,8 +41,10 @@ def main():
                         help="KL penalty (TRL/DeepSeek default; raise to 0.1+ if KL diverges)")
     parser.add_argument("--lora_rank", type=int, default=8,
                         help="LoRA rank (8 = more capacity than original 4 for meta-RL)")
-    parser.add_argument("--hint_fraction", type=float, default=0.15,
-                        help="Fraction of dataset with profile hint visible (curriculum warmup)")
+    parser.add_argument("--hint_fraction", type=float, default=0.0,
+                        help="Fraction of dataset with profile hint visible. Default 0.0 (no hints) "
+                             "to eliminate train-eval distribution mismatch. Set >0 only if you ALSO "
+                             "show hints during eval.")
     parser.add_argument("--profile_mode", type=str, default="continuous",
                         choices=["continuous", "discrete"],
                         help="continuous = sampled per-episode (meta-RL); discrete = 3 hardcoded profiles")
@@ -144,13 +146,15 @@ def main():
     max_prompt_length = 600  # history + hint room
     max_completion_length = 32  # bumped from 20 to prevent silent truncation of belief digits
 
-    # reward_weights: scale per-layer to prevent format/legal (saturated near +1) from
-    # drowning out env_reward and belief_accuracy (the actual learning signals).
+    # reward_weights: aggressively suppress saturated format/legal layers, amplify
+    # the variable signals (env + belief). Iter 1 collapsed because format/legal at
+    # +1.0/+0.5 contributed zero to GRPO advantage; iter 2 makes belief the dominant
+    # learning signal.
     # Order MUST match reward_funcs in main(): format_valid, action_legal, env_reward, belief_accuracy
-    reward_weights = [0.3, 0.3, 1.0, 1.0]
+    reward_weights = [0.05, 0.05, 1.5, 3.0]
 
     training_args_kwargs = dict(
-        temperature=1.0,
+        temperature=1.5,  # bumped from 1.0 to force diverse rollouts and break mode collapse
         learning_rate=args.learning_rate,
         beta=args.beta,
         max_grad_norm=0.5,