OSINT / tests /test_spawn_reward_hooks.py
siddeshwar-kagatikar
fix(rewards): never crash GRPO on malformed completions
d814291
import math

from osint_env.env.spawn_reward_hooks import critical_steps, parl_style_spawn_reward
def test_critical_steps_matches_parallel_path_length():
    """Check that critical_steps counts only the longest path through each parallel group.

    The expected value adds every main step plus only the largest entry of
    the matching parallel sub-agent group:
    1 + max(3, 2) + 1 + max(0) + 1 + max(4, 1, 2).
    """
    main = [1, 1, 1]
    parallel = [[3, 2], [0], [4, 1, 2]]
    expected = 1 + 3 + 1 + 0 + 1 + 4

    assert critical_steps(main_steps=main, parallel_subagent_steps=parallel) == expected
def test_parl_reward_prefers_finished_parallel_work():
    """Check that finishing more sub-tasks on a shorter, wider plan yields a higher reward.

    Both calls share the task outcome, spawn count, lambda weights and
    anneal factor; only finished_subtasks, critical_steps, breadth and
    depth differ between them.
    """
    shared = {
        "task_outcome_reward": 0.2,
        "spawn_count": 4,
        "lambda_parallel": 0.2,
        "lambda_finish": 0.25,
        "anneal": 1.0,
    }

    # Narrow, deep plan: only 1 of the 4 spawned sub-tasks finished.
    base = parl_style_spawn_reward(
        finished_subtasks=1,
        critical_steps=12,
        breadth=2,
        depth=3,
        **shared,
    )
    # Wide, shallow plan: all 4 sub-tasks finished on a shorter critical path.
    better = parl_style_spawn_reward(
        finished_subtasks=4,
        critical_steps=8,
        breadth=4,
        depth=2,
        **shared,
    )

    assert better > base
def test_parl_auxiliary_can_be_annealed_out():
    """Check that with anneal=0.0 the reward reduces to task_outcome_reward alone.

    The spawn-related inputs are non-trivial (8 spawned, all 8 finished,
    critical path of 5). Any auxiliary term left un-annealed would therefore
    show up as a deviation from the 0.7 task outcome.

    The comparison uses math.isclose rather than ``==``. This keeps the test
    independent of the exact floating-point operation order inside
    parl_style_spawn_reward.
    """
    frozen = parl_style_spawn_reward(
        task_outcome_reward=0.7,
        spawn_count=8,
        finished_subtasks=8,
        critical_steps=5,
        anneal=0.0,
    )
    assert math.isclose(frozen, 0.7)