iter2: fix mode collapse + 3 deeper bugs from code review
ITER 1 FAILURE: mode collapse to 'EXERCISE 5 5 5' at step 100. Root cause:
format_valid (+1.0) and action_legal (+0.5) saturated for any valid output,
contributing zero to GRPO advantage. Once 4/4 completions per prompt were
identical, reward_std=0 and the policy froze.
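For intuition, a minimal sketch of the mechanism (illustrative numbers, not the trainer's actual code): a reward layer that is constant across a prompt's completions cancels out of the group-relative advantage, and once all completions score identically the advantage is zero everywhere.

```python
# Minimal sketch of GRPO's group-relative advantage: (r - mean(group)) / std(group).
# Numbers are illustrative; this is not the trainer code.
import statistics

def group_advantages(rewards):
    mean = statistics.mean(rewards)
    std = statistics.pstdev(rewards)
    if std == 0:
        # every completion scored the same -> no learning signal for this prompt
        return [0.0] * len(rewards)
    return [(r - mean) / std for r in rewards]

# 4 completions per prompt; each gets the constant +1.0 (format) and +0.5 (legal).
constant_part = 1.0 + 0.5
env_part = [0.2, -0.4, 0.9, 0.1]  # only this layer varies across completions

print(group_advantages([constant_part + e for e in env_part]))
print(group_advantages(env_part))  # identical: the constant layers cancel out

# Iter-1 failure mode: 4/4 completions collapse to the same output, so reward_std=0.
print(group_advantages([constant_part + 0.2] * 4))  # [0.0, 0.0, 0.0, 0.0]
```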
8 fixes applied (4 from initial diagnosis + 3 from deep review subagent + 1 preset bump):
[1] sampling temperature 1.0 -> 1.5
Forces diverse rollouts per prompt, breaking the mode-collapse mechanism.
[2] reward_weights [0.3, 0.3, 1.0, 1.0] -> [0.05, 0.05, 1.5, 3.0]
Aggressively suppress saturated layers, amplify variable signals.
Belief weight bumped to 3.0 (CRITICAL-1) so emitting a belief is clearly
rewarded vs the no-belief penalty (see the weighted-sum sketch after this list).
[3] action_legal: drop +0.5 for valid, return 0 instead
Removes another constant-reward source. Layer is now pure penalty for
malformed outputs.
[4] explicit repetition penalty in env_reward (-0.3 if an action is seen 3+ times in a row)
Direct training signal against mode collapse. Env already does effect
dampening but doesn't add explicit negative reward for repetition.
[5] _grade_episode late_quality normalization fix (CRITICAL-2)
Per-step rewards are CLAMPED to [-3, +3] in step(), not [-1, +1] as the
old normalization assumed. Old grader saturated late_quality=1.0 for any
mean_late >= +1, blind to good vs excellent late-half. Fixed to use the
actual reward range.
[6] hint_fraction default 0.15 -> 0.0 (MAJOR-3)
Eliminates train-eval distribution mismatch. Eval never shows hints, so
training with hints creates a fraction of training examples whose lessons
don't transfer. Set to 0 by default; can be re-enabled if eval also adds
hint visibility.
[7] env_reward seed fallback hardening (MAJOR-1)
Replace 'i % 50' fallback with '(i * 17) ^ 0xBEEF' to break deterministic
seed clusters. Avoids worst-case where 4 completions in same position-class
get identical env_reward (zero GRPO advantage).
[8] FAST_MODE preset: MAX_STEPS 200 -> 400, samples 800 -> 1200
Iter 1 collapsed at step 100 leaving 100 stuck steps. Give iter 2 more
room to recover from any local optima.
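Weighted-sum sketch for [2]: the per-layer scores below are hypothetical, only the two weight vectors come from this commit. It shows how the new weights shrink the constant format/legal contribution and widen the gap driven by the belief layer.

```python
# Hypothetical layer scores for two completions that differ only in the belief layer.
# Order: format_valid, action_legal, env_reward, belief_accuracy.
with_belief    = [1.0, 0.5, 0.3, 1.0]
without_belief = [1.0, 0.5, 0.3, -0.5]  # assumed no-belief penalty, for illustration

old_w = [0.3, 0.3, 1.0, 1.0]
new_w = [0.05, 0.05, 1.5, 3.0]

def total(scores, weights):
    return sum(s * w for s, w in zip(scores, weights))

# Old weights: format/legal add a constant 0.45 to every completion, and the belief
# layer separates the two totals by only 1.5 (1.75 vs 0.25).
print(total(with_belief, old_w), total(without_belief, old_w))
# New weights: the constant shrinks to 0.075 and the belief gap grows to 4.5
# (3.525 vs -0.975), so the variable layers dominate the reward differences.
print(total(with_belief, new_w), total(without_belief, new_w))
```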
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- scripts/train_on_hf.py +2 -2
- server/rhythm_environment.py +6 -4
- training/reward_functions.py +16 -4
- training/train.py +10 -6
--- a/scripts/train_on_hf.py
+++ b/scripts/train_on_hf.py
@@ -57,8 +57,8 @@ PLOTS_DIR = "/tmp/rhythm_env/plots"
 FAST_MODE = os.environ.get("FAST_MODE", "0") == "1"
 
 if FAST_MODE:
-    # Smoke-train preset:
-    DEFAULTS = dict(MAX_STEPS=200, NUM_EPISODES=120, MAX_SAMPLES=800,
+    # Smoke-train preset: 400 steps gives the policy room to escape iter-1 mode collapse
+    DEFAULTS = dict(MAX_STEPS=400, NUM_EPISODES=120, MAX_SAMPLES=1200,
                     NUM_GENERATIONS=4, LORA_RANK=8, BETA=0.04,
                     LEARNING_RATE=5e-5, EVAL_EPISODES=2)
 else:
--- a/server/rhythm_environment.py
+++ b/server/rhythm_environment.py
@@ -739,11 +739,13 @@ class RhythmEnvironment(Environment):
         if early and late:
             mean_early = sum(early) / len(early)
             mean_late = sum(late) / len(late)
-            # normalize mean_late from [-1, +1] to [0, 1]
-            late_quality = max(0.0, min(1.0, (mean_late + 1.0) / 2.0))
+            # Iter 2 fix: per-step rewards are CLAMPED to [-3, +3] in step(), not [-1, +1].
+            # Old normalization saturated late_quality at 1.0 for any mean_late >= +1,
+            # making the grader unable to distinguish good from excellent late-half.
+            late_quality = max(0.0, min(1.0, (mean_late + 3.0) / 6.0))
             gain = mean_late - mean_early
-            # gain
-            gain_norm = max(0.0, min(1.0, gain))
+            # gain in [-6, +6]; normalize to [0, 1] (only positive gain counts)
+            gain_norm = max(0.0, min(1.0, gain / 3.0))
             adaptation_score = gain_norm * late_quality
         else:
             adaptation_score = 0.0
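Quick numeric check of the new normalization above (illustrative values; per the commit message, the old grader returned late_quality = 1.0 for any mean_late >= +1, so both cases below used to look identical):

```python
# With per-step rewards clamped to [-3, +3], the new grading distinguishes a good
# late half from an excellent one instead of saturating at 1.0 for both.
def late_quality(mean_late):
    return max(0.0, min(1.0, (mean_late + 3.0) / 6.0))

def gain_norm(mean_early, mean_late):
    return max(0.0, min(1.0, (mean_late - mean_early) / 3.0))

print(late_quality(1.0))    # ~0.67 (good late half)
print(late_quality(2.5))    # ~0.92 (excellent late half)
print(gain_norm(0.5, 2.5))  # ~0.67; adaptation_score = 0.67 * 0.92 ~= 0.61
```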
--- a/training/reward_functions.py
+++ b/training/reward_functions.py
@@ -142,13 +142,15 @@ def action_legal(completions, **kwargs) -> list[float]:
     Layer 2: Is the parsed action one of the 10 valid actions?
 
     All 10 actions are always legal in this env (no state-dependent validity).
-    +0.5 for any parsed action.
+    Iter 2 fix: returns 0 for valid (was +0.5) so this layer becomes pure penalty
+    for malformed outputs. The +0.5 was a constant reward that contributed zero
+    to GRPO advantage and helped trigger mode collapse in iter 1.
     """
     scores = []
     for completion in completions:
         response = completion[0]["content"] if isinstance(completion, list) else completion
         action = extract_action(response)
-        scores.append(0.5 if action is not None else -1.0)
+        scores.append(0.0 if action is not None else -1.0)
     return scores
 
 
@@ -201,14 +203,24 @@ def env_reward(
             ep_history = prompt_data.get("action_history", [])
             ep_mode = prompt_data.get("profile_mode", "continuous")
         else:
-            ep_seed = i % 50
+            # Iter 2 fix: mix index with prime to break deterministic seed clusters
+            # (avoids all completions in a position-class getting identical env_reward)
+            ep_seed = (i * 17) ^ 0xBEEF
             ep_history = []
             ep_mode = "continuous"
 
         try:
             env = _replay_env(ep_seed, ep_history, ep_mode)
             obs = env.step(RhythmAction(action_type=action_type))
-            scores.append(obs.reward)
+            reward = obs.reward
+            # Iter 2 fix: explicit repetition penalty as TRAINING signal (env already
+            # dampens repeats but doesn't add explicit negative reward). Discourages
+            # the lazy "spam one action" mode collapse pattern.
+            if ep_history and len(ep_history) >= 2:
+                recent = ep_history[-3:]  # last 3 actions
+                if recent.count(action_type.value) >= 2:  # this would make 3+ in a row
+                    reward -= 0.3
+            scores.append(reward)
         except Exception:
             scores.append(-3.0)
 
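For [7], a quick check of the two fallback seed schemes (indices are arbitrary examples):

```python
# Old fallback: indices 50 apart collapse onto the same seed, so completions in the
# same position-class can replay the same episode and receive identical env_reward.
print([i % 50 for i in (3, 53, 103, 153)])               # [3, 3, 3, 3]

# New fallback: multiply by a prime and XOR with a constant. i * 17 is injective and
# XOR with 0xBEEF is a bijection, so distinct indices always yield distinct seeds.
seeds = [(i * 17) ^ 0xBEEF for i in (3, 53, 103, 153)]
print(seeds)
print(len(set(seeds)) == len(seeds))                     # True
```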
--- a/training/train.py
+++ b/training/train.py
@@ -41,8 +41,10 @@ def main():
                         help="KL penalty (TRL/DeepSeek default; raise to 0.1+ if KL diverges)")
     parser.add_argument("--lora_rank", type=int, default=8,
                         help="LoRA rank (8 = more capacity than original 4 for meta-RL)")
-    parser.add_argument("--hint_fraction", type=float, default=0.15,
-                        help="Fraction of dataset with profile hint visible (...)")
+    parser.add_argument("--hint_fraction", type=float, default=0.0,
+                        help="Fraction of dataset with profile hint visible. Default 0.0 (no hints) "
+                             "to eliminate train-eval distribution mismatch. Set >0 only if you ALSO "
+                             "show hints during eval.")
     parser.add_argument("--profile_mode", type=str, default="continuous",
                         choices=["continuous", "discrete"],
                         help="continuous = sampled per-episode (meta-RL); discrete = 3 hardcoded profiles")
@@ -144,13 +146,15 @@ def main():
     max_prompt_length = 600  # history + hint room
     max_completion_length = 32  # bumped from 20 to prevent silent truncation of belief digits
 
-    # reward_weights:
-    #
+    # reward_weights: aggressively suppress saturated format/legal layers, amplify
+    # the variable signals (env + belief). Iter 1 collapsed because format/legal at
+    # +1.0/+0.5 contributed zero to GRPO advantage; iter 2 makes belief the dominant
+    # learning signal.
     # Order MUST match reward_funcs in main(): format_valid, action_legal, env_reward, belief_accuracy
-    reward_weights = [0.3, 0.3, 1.0, 1.0]
+    reward_weights = [0.05, 0.05, 1.5, 3.0]
 
     training_args_kwargs = dict(
-        temperature=1.0,
+        temperature=1.5,  # bumped from 1.0 to force diverse rollouts and break mode collapse
         learning_rate=args.learning_rate,
         beta=args.beta,
         max_grad_norm=0.5,