iter3: align reward with grader + belief-first format + exploration shaping
ITER 2 FAILURE: the agent escaped single-action collapse only to find a 2-cycle
loop (MEDITATE+EXERCISE alternation). final_score 0.22, worse than random.
Belief learning (+0.36) works but doesn't influence action choice: the belief
was emitted as a string AFTER the action, so the action couldn't condition on it.
Five runtime-level fixes for iter 3:
[1] Per-step reward grader-alignment (server/rhythm_environment.py:_compute_reward)
Add profile-INDEPENDENT bias: +0.5*progress_delta + 0.4*connection_delta.
Profile-weighted reward still drives belief inference (varies by profile),
but agent now ALWAYS gets penalized for ignoring progress and connection.
Closes the proxy/goal misalignment that let M+E alternation win.
[2] Belief-first output format (training/reward_functions.py, dataset.py, inference.py)
Was: 'ACTION_NAME S M W' (belief was a post-hoc afterthought that didn't condition the action)
Now: 'S M W ACTION_NAME' (the causal LM generates the belief tokens FIRST, so the
action conditions on those belief tokens via attention).
The parser accepts both for backward compat. System prompt updated to
explain the belief-action coupling so the model learns to use the format.
[3] N-cycle penalty (training/reward_functions.py:env_reward)
Iter 2 had 3-in-a-row penalty but the agent found a 2-cycle loophole.
New rule: if the last 6 actions contain <=2 unique values, -0.4. This closes the
M+E alternation specifically, and any other loop built from at most two actions.
[4] New-action exploration bonus (training/reward_functions.py:env_reward)
+0.2 reward for taking an action that hasn't appeared in the current
episode yet (until 6 unique actions tried). Pushes the agent to PROBE
early in episodes - the canonical meta-RL exploration signal.
[5] Sparse terminal reward (server/rhythm_environment.py:step)
At step 28, add (final_score - 0.5) * 5 to per-step reward. Direct
supervision on the actual grader. Range [-2.5, +2.5] gives strong
end-of-episode signal that overwhelms any local reward-hack.
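For scale: iter 2's final_score of 0.22 maps to a terminal bonus of
(0.22 - 0.5) * 5 = -1.4, while a 0.90 episode would earn +2.0.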
Plus training config bumps:
- MAX_STEPS 400 -> 800 (iter 2 was still rising at step 400)
- NUM_GENERATIONS 4 -> 8 (lower advantage variance; see the sketch after this list)
- LORA_RANK 8 -> 16 (more capacity for policy + belief integration)
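The NUM_GENERATIONS bump follows the usual GRPO logic: each completion's advantage is
its reward relative to its group's mean (and std), so the group statistics act as the
baseline, and a larger group gives a less noisy baseline. A minimal illustrative sketch,
not project code - the function name and exact normalization here are illustrative:

    # Illustration only: group-relative advantages as used in GRPO-style training.
    # With G completions per prompt, the group mean/std form the baseline; the noise
    # in that baseline shrinks roughly as 1/sqrt(G), so G=8 is steadier than G=4.
    def group_relative_advantages(rewards: list[float], eps: float = 1e-6) -> list[float]:
        g = len(rewards)
        mean = sum(rewards) / g
        std = (sum((r - mean) ** 2 for r in rewards) / g) ** 0.5
        return [(r - mean) / (std + eps) for r in rewards]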
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- inference.py +20 -13
- scripts/train_on_hf.py +5 -5
- server/rhythm_environment.py +20 -3
- training/dataset.py +21 -15
- training/reward_functions.py +80 -36
inference.py

@@ -80,24 +80,31 @@ people differently - you must INFER who you're helping from the rewards and
 meter changes you observe.

 Each step, output ONE LINE in this exact format:
+    S M W ACTION_NAME

+First write your BELIEF as 3 digits 0-9, then the ACTION that fits:
+    S = social preference (0=hates social, 9=loves social)
+    M = morning preference (0=night owl, 9=morning person)
+    W = work preference (0=avoids work, 9=workaholic)
+
+ACTION choices:
     DEEP_WORK, ADMIN_WORK, LEARN, SLEEP, EXERCISE, MEDITATE,
     FAMILY_TIME, SOCIALIZE, ME_TIME, BINGE_WATCH

+Example: 3 8 7 DEEP_WORK

+Belief-action coupling guide:
+- High S: SOCIALIZE, FAMILY_TIME (extrovert boosts)
+- High M: DEEP_WORK in morning slots (morning-person bonus)
+- High W: DEEP_WORK, LEARN (workaholic energy)
+- Low S: MEDITATE, ME_TIME (introvert recharge)
+- Low M: DEEP_WORK in evening/night (night-owl bonus)

+Tactics:
+- Early week: PROBE varied actions to gather information.
+- Late week: EXPLOIT - pick actions matching your sharpened belief.
+- Don't repeat the same action; you'll get a repetition penalty.
+- Watch for crashes: any meter under 0.1 = big penalty.
 - Connection decays passively - actively maintain it.
 Respond with ONLY the format line, no other text.""")

@@ -212,7 +219,7 @@ Meters:
 Serenity: {obs.serenity:.2f}
 Connection: {obs.connection:.2f}{event_str}{history_str}

+Output belief then action (format: S M W ACTION_NAME):""")

 try:
     completion = llm_client.chat.completions.create(
scripts/train_on_hf.py

@@ -57,13 +57,13 @@ PLOTS_DIR = "/tmp/rhythm_env/plots"
 FAST_MODE = os.environ.get("FAST_MODE", "0") == "1"

 if FAST_MODE:
+    # Iter 3 preset: 800 steps + 8 generations + LoRA 16 to escape mode collapse for real
+    DEFAULTS = dict(MAX_STEPS=800, NUM_EPISODES=200, MAX_SAMPLES=2000,
+                    NUM_GENERATIONS=8, LORA_RANK=16, BETA=0.04,
                     LEARNING_RATE=5e-5, EVAL_EPISODES=2)
 else:
+    DEFAULTS = dict(MAX_STEPS=2000, NUM_EPISODES=400, MAX_SAMPLES=4000,
+                    NUM_GENERATIONS=8, LORA_RANK=16, BETA=0.04,
                     LEARNING_RATE=5e-5, EVAL_EPISODES=5)

 MAX_STEPS = int(os.environ.get("MAX_STEPS", str(DEFAULTS["MAX_STEPS"])))
server/rhythm_environment.py

@@ -507,6 +507,13 @@ class RhythmEnvironment(Environment):
         if done:
             final_score = self._grade_episode()
             reward_breakdown["final_score"] = round(final_score, 4)
+            # Iter 3 fix: sparse terminal reward - direct supervision on the grader.
+            # final_score is in [0, 1]; baseline-relative bonus gives strong signal
+            # for ending the episode well. Range: -2.5 (terrible) to +2.5 (perfect).
+            terminal_bonus = (final_score - 0.5) * 5.0
+            reward = max(-3.0, min(3.0, reward + terminal_bonus))
+            self._total_reward += terminal_bonus  # update tracking too
+            reward_breakdown["terminal_bonus"] = round(terminal_bonus, 4)

         # --- 14. Update state ---
         self._state.step_count = self._timestep

@@ -695,10 +702,20 @@ class RhythmEnvironment(Environment):
         self._vitality = max(0.0, self._vitality - vd)

     def _compute_reward(self, deltas: Dict[str, float]) -> float:
+        """Compute reward as hidden-weighted sum + grader-aligned bias.
+
+        Iter 3 fix: Add a profile-INDEPENDENT bias term for progress and
+        connection. The original profile-weighted reward drives belief inference
+        (varies by profile), but allowed agents to game it by spamming recovery
+        actions if the sampled profile didn't weight progress/connection. The
+        bias makes the per-step reward correlate with the FINAL grader (which
+        weights progress 0.25 and connection 0.15).
+        """
         weights = self._profile["reward_weights"]
+        profile_reward = sum(deltas[m] * weights[m] for m in METERS) * REWARD_SCALE
+        # Grader-aligned bias: scaled so max bonus is ~0.1/step (manageable vs profile_reward)
+        grader_bias = 0.5 * deltas["progress"] + 0.4 * deltas["connection"]
+        return profile_reward + grader_bias

     def _grade_episode(self) -> float:
         """
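Note that _grade_episode itself is untouched here; only its output is reused. For
orientation, a minimal sketch of a grader consistent with the weights cited in the
new docstring (progress 0.25, connection 0.15) - everything else in the sketch
(the remaining meters and their weights, the function name) is an assumption, not
the actual implementation:

    # Hypothetical sketch, NOT the real _grade_episode: a weighted average of
    # end-of-episode meter levels, clamped to [0, 1]. Only the progress (0.25)
    # and connection (0.15) weights come from the docstring above; the rest of
    # the split is invented for illustration.
    def grade_episode_sketch(meters: dict[str, float]) -> float:
        weights = {"progress": 0.25, "connection": 0.15,
                   "serenity": 0.30, "vitality": 0.30}  # assumed remainder
        score = sum(w * meters.get(name, 0.0) for name, w in weights.items())
        return max(0.0, min(1.0, score))

Under any grader of this shape, the per-step grader_bias and the terminal_bonus both
move with the final score, which is the point of fixes [1] and [5].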
training/dataset.py

@@ -32,23 +32,29 @@ SYSTEM_PROMPT = (
     "different people differently - you must INFER who you're helping from the\n"
     "rewards and meter changes you observe.\n\n"
     "Each step, output ONE LINE in this exact format:\n"
+    "  S M W ACTION_NAME\n\n"
+    "First, write your belief about the person as 3 digits (0-9):\n"
     "  S = social preference (0=hates being social, 9=loves being social)\n"
     "  M = morning preference (0=night owl, 9=morning person)\n"
     "  W = work preference (0=avoids work, 9=workaholic)\n\n"
+    "Then choose the action that BEST FITS that belief, from:\n"
+    "  DEEP_WORK, ADMIN_WORK, LEARN, SLEEP, EXERCISE, MEDITATE,\n"
+    "  FAMILY_TIME, SOCIALIZE, ME_TIME, BINGE_WATCH\n\n"
+    "Example: 3 8 7 DEEP_WORK (this person is moderately introverted, strongly\n"
+    "morning-oriented, fairly work-driven - so deep work in the morning fits)\n\n"
+    "Belief-action coupling guide:\n"
+    "- High S (extrovert): SOCIALIZE, FAMILY_TIME boost connection cheaply\n"
+    "- High M (morning person): DEEP_WORK in early slots gets bonus cognition\n"
+    "- High W (workaholic): DEEP_WORK, LEARN drive progress AND may energize\n"
+    "- Low S (introvert): MEDITATE, ME_TIME for solo recharge; avoid SOCIALIZE\n"
+    "- Low M (night owl): DEEP_WORK in evening/night slots\n\n"
+    "Tactics:\n"
+    "- Early in the week, PROBE varied actions to gather information.\n"
+    "- Update your belief from the rewards you see.\n"
+    "- Late in the week, EXPLOIT - pick actions matching your sharpened belief.\n"
+    "- Don't repeat the same action excessively; you'll get a repetition penalty.\n"
+    "- Watch for crashes: any meter under 0.1 = big penalty.\n"
+    "- Connection decays passively - actively maintain it via SOCIALIZE/FAMILY_TIME.\n"
     "Respond with ONLY the format line, no other text."
 )

@@ -96,7 +102,7 @@ def format_observation_prompt(obs, profile_hint: dict | None = None) -> str:
         f"{event_str}"
         f"{history_str}"
         f"{hint_str}\n\n"
+        f"Output your belief, then your action (format: S M W ACTION_NAME):"
     )
training/reward_functions.py

@@ -2,20 +2,29 @@
 Reward functions for RhythmEnv GRPO training (meta-RL version).

 Four-layer reward stack:
+  1. format_valid - does the LLM output have a parseable belief + action format?
   2. action_legal - is the action one of the 10 valid actions?
   3. env_reward - actual environment reward (seed-replay) for the chosen action
+     plus diversity/exploration shaping
   4. belief_accuracy - how close is the belief vector to the hidden profile's true vector?

+ITER 3 FORMAT (belief-first): "S M W ACTION_NAME"
 - S, M, W: single digits 0-9 representing the agent's belief about the user
     S = social preference (0=hates social, 9=loves social)
     M = morning preference (0=night owl, 9=morning person)
     W = work preference (0=avoids work, 9=workaholic)
+- ACTION_NAME: one of 10 valid actions
+
+Example: "3 8 7 DEEP_WORK"
+    -> belief=[0.33, 0.89, 0.78], action=DEEP_WORK
+
+Why belief-first: in causal LM generation, tokens generated EARLIER condition
+LATER tokens. With "S M W ACTION", the model commits to a belief first, and
+the action is then conditioned on that belief - making the belief functionally
+useful for action selection. The previous "ACTION S M W" order made belief a
+post-hoc afterthought that didn't influence behavior.

+The parser ALSO accepts the old "ACTION S M W" format for backward compatibility.

 Each function returns a list of floats (one per completion).
 """

@@ -58,13 +67,7 @@ def extract_action_and_belief(text: str) -> tuple[ActionType | None, list[float]
     if not parts:
         return None, list(DEFAULT_BELIEF), False

+    # Find action and its index in parts
     action: ActionType | None = None
     action_idx = -1
     for idx, p in enumerate(parts):

@@ -76,35 +85,53 @@ def extract_action_and_belief(text: str) -> tuple[ActionType | None, list[float]
         if action is not None:
             break

+    # Iter 3: parse belief from BEFORE the action (belief-first format).
+    # Falls back to AFTER the action (legacy format) if no digits found before.
+    def _parse_digit(token: str) -> float | None:
+        token = token.strip().rstrip(".")
+        if not token:
+            return None
+        try:
+            if len(token) == 1 and token.isdigit():
+                return int(token) / 9.0
+            val = float(token)
+            if val > 1.0:
+                val = val / 9.0
+            return max(0.0, min(1.0, val))
+        except (ValueError, IndexError):
+            return None
+
     belief: list[float] = []
     belief_provided = False
+
     if action_idx >= 0:
+        # Try belief-first: 3 digits BEFORE the action
+        if action_idx >= 3:
+            cand = [_parse_digit(parts[action_idx - 3 + i]) for i in range(3)]
+            if all(c is not None for c in cand):
+                belief = cand  # type: ignore[assignment]
+                belief_provided = True
+
+        # If belief-first didn't work, try legacy after-action format
+        if not belief_provided:
+            after_belief: list[float] = []
+            after_provided = False
+            for i in range(3):
+                j = action_idx + 1 + i
+                if j < len(parts):
+                    d = _parse_digit(parts[j])
+                    if d is not None:
+                        after_belief.append(d)
+                        after_provided = True
                     else:
+                        after_belief.append(0.5)
+                else:
+                    after_belief.append(0.5)
+            if after_provided:
+                belief = after_belief
+                belief_provided = True
+
+    if not belief or len(belief) != 3:
         belief = list(DEFAULT_BELIEF)

     return action, belief, belief_provided

@@ -213,13 +240,30 @@ def env_reward(
            env = _replay_env(ep_seed, ep_history, ep_mode)
            obs = env.step(RhythmAction(action_type=action_type))
            reward = obs.reward
+           chosen = action_type.value
+
+           # Iter 2 fix: explicit 3-in-a-row repetition penalty
            if ep_history and len(ep_history) >= 2:
+               recent3 = ep_history[-3:]
+               if recent3.count(chosen) >= 2:  # this action would make 3+ in a row
                    reward -= 0.3
+
+           # Iter 3 fix: N-CYCLE penalty (catches the M-E-M-E-... loop iter 2 fell into).
+           # If the last 6 actions (including this one) have <=2 unique values, apply the penalty.
+           if ep_history and len(ep_history) >= 5:
+               last6 = ep_history[-5:] + [chosen]
+               if len(set(last6)) <= 2:
+                   reward -= 0.4
+
+           # Iter 3 fix: NEW-ACTION exploration bonus.
+           # If this action hasn't appeared yet in the current episode, +0.2.
+           # Strong incentive in early steps to TRY varied actions, fading as
+           # the action set grows. Stops once 6+ different actions tried.
+           if ep_history is not None:
+               seen = set(ep_history)
+               if chosen not in seen and len(seen) < 6:
+                   reward += 0.2
+
            scores.append(reward)
        except Exception:
            scores.append(-3.0)
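As a quick sanity check of the belief-first parser above, the expected behavior on the
docstring's canonical example - a sketch of a test, assuming only the names shown in this
diff; the ActionType matching rules aren't shown here, so treat the asserts as expected
behavior rather than a guaranteed test:

    # Sketch of a round-trip check for extract_action_and_belief.
    # Digits map to belief via d / 9.0, so "3 8 7" -> [0.33, 0.89, 0.78].
    action, belief, provided = extract_action_and_belief("3 8 7 DEEP_WORK")
    assert action is not None and provided
    assert [round(b, 2) for b in belief] == [0.33, 0.89, 0.78]

    # The legacy "ACTION S M W" order should still parse (backward compatibility).
    action, belief, provided = extract_action_and_belief("DEEP_WORK 3 8 7")
    assert action is not None and provided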