from __future__ import annotations

import math
from dataclasses import dataclass


@dataclass(slots=True)
class PARLRewardBreakdown:
    """Per-term decomposition of the shaped PARL reward."""

    total: float
    auxiliary: float
    parallel: float
    finish: float
    latency: float
    breadth_bonus: float
    depth_penalty: float


def critical_steps(main_steps: list[int], parallel_subagent_steps: list[list[int]]) -> int:
    """Compute the critical-step latency proxy used in Kimi-style PARL shaping.

    For each stage t, we add:

        S_main(t) + max_i S_sub,i(t)

    where S_sub,i(t) is the step count of the i-th sub-agent at that stage:
    the main agent's steps plus the longest parallel sub-agent branch.
    """
    if len(main_steps) != len(parallel_subagent_steps):
        raise ValueError("main_steps and parallel_subagent_steps must have the same length")
    total = 0
    for stage_main, stage_sub in zip(main_steps, parallel_subagent_steps):
        # Negative counts are clamped; a stage with no sub-agents adds nothing extra.
        main = max(0, int(stage_main))
        longest_sub = max((max(0, int(v)) for v in stage_sub), default=0)
        total += main + longest_sub
    return total
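
# Worked example sketch (values assumed for illustration, not from the original
# module): with main-agent steps [3, 2] and per-stage sub-agent steps
# [[5, 2], [1, 4]], each stage waits only on its longest branch, so
#
#     critical_steps([3, 2], [[5, 2], [1, 4]]) == (3 + 5) + (2 + 4) == 14
#
# versus 17 total steps if the same sub-agents had run serially.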


def parl_style_spawn_reward(
    task_outcome_reward: float,
    spawn_count: int,
    finished_subtasks: int,
    critical_steps: int,
    lambda_parallel: float = 0.15,
    lambda_finish: float = 0.20,
    anneal: float = 1.0,
    breadth: int | None = None,
    depth: int | None = None,
    max_parallel_hint: int | None = None,
) -> float:
    """Kimi K2.5-inspired PARL reward utility for future multi-agent branches.

    This helper intentionally does not orchestrate agents. It only exposes the
    reward shape:

        r_parl = r_perf + anneal * (lambda_parallel * r_parallel
                                    + lambda_finish * r_finish
                                    + r_latency)

    where:
    - r_parallel encourages non-zero agent spawning (avoids serial collapse)
    - r_finish rewards meaningful completion, preventing spawn-only reward hacking
    - r_latency favors execution paths with fewer critical steps

    The optional breadth/depth controls are small shaping terms for future
    branches where orchestration state includes tree-shape telemetry.
    """
    # Input sanitization happens inside parl_reward_breakdown, so delegate directly.
    return parl_reward_breakdown(
        task_outcome_reward=task_outcome_reward,
        spawn_count=spawn_count,
        finished_subtasks=finished_subtasks,
        critical_steps=critical_steps,
        lambda_parallel=lambda_parallel,
        lambda_finish=lambda_finish,
        anneal=anneal,
        breadth=breadth,
        depth=depth,
        max_parallel_hint=max_parallel_hint,
    ).total
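
# Usage sketch (assumed inputs, not from the original module): a successful
# rollout (r_perf = 1.0) that spawns 3 sub-agents, finishes 2 of them, and has
# a 12-step critical path scores roughly
#
#     1.0 + 1.0 * (0.15 * tanh(3 / 4) + 0.20 * (2 / 3) + 0.05 / 12) ~= 1.233
#
# at anneal=1.0 with no breadth/depth telemetry supplied:
#
#     reward = parl_style_spawn_reward(
#         task_outcome_reward=1.0,
#         spawn_count=3,
#         finished_subtasks=2,
#         critical_steps=12,
#     )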


def parl_reward_breakdown(
    task_outcome_reward: float,
    spawn_count: int,
    finished_subtasks: int,
    critical_steps: int,
    lambda_parallel: float = 0.15,
    lambda_finish: float = 0.20,
    anneal: float = 1.0,
    breadth: int | None = None,
    depth: int | None = None,
    max_parallel_hint: int | None = None,
) -> PARLRewardBreakdown:
    """Return the per-term decomposition behind parl_style_spawn_reward."""
    # Sanitize inputs so every shaping term stays bounded and well-defined.
    spawn_count = max(0, int(spawn_count))
    finished_subtasks = max(0, int(finished_subtasks))
    critical_steps = max(1, int(critical_steps))
    anneal = max(0.0, min(1.0, anneal))
    lambda_parallel = max(0.0, float(lambda_parallel))
    lambda_finish = max(0.0, float(lambda_finish))
    breadth = max(0, int(breadth or 0))
    depth = max(0, int(depth or 0))
    max_parallel_hint = max(0, int(max_parallel_hint or 0))

    if spawn_count == 0:
        r_parallel = 0.0
        r_finish = 0.0
    else:
        # Saturating incentive for parallelism so reward cannot grow unbounded with spawns.
        r_parallel = math.tanh(spawn_count / 4.0)
        if max_parallel_hint > 0:
            # Scale by how much of the allowed parallelism budget is used.
            utilization = min(1.0, spawn_count / max_parallel_hint)
            r_parallel *= 0.7 + (0.3 * utilization)
        # Fraction of spawned subtasks that actually completed.
        r_finish = min(1.0, finished_subtasks / spawn_count)

    if breadth > 0:
        # Small bonus for wider decomposition trees; saturates via tanh.
        breadth_bonus = 0.04 * math.tanh(breadth / 6.0)
    else:
        breadth_bonus = 0.0

    if depth > 0:
        # Mild depth penalty discourages brittle over-decomposition chains.
        depth_penalty = -0.03 * math.tanh(max(0, depth - 1) / 4.0)
    else:
        depth_penalty = 0.0

    # Optional latency shaping hook using critical steps (higher is worse).
    r_latency = 0.05 * (1.0 / critical_steps)

    auxiliary = (
        (lambda_parallel * r_parallel)
        + (lambda_finish * r_finish)
        + r_latency
        + breadth_bonus
        + depth_penalty
    )
    total = float(task_outcome_reward) + (anneal * auxiliary)
    return PARLRewardBreakdown(
        total=total,
        auxiliary=anneal * auxiliary,
        parallel=r_parallel,
        finish=r_finish,
        latency=r_latency,
        breadth_bonus=breadth_bonus,
        depth_penalty=depth_penalty,
    )
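
# Minimal smoke-test sketch (assumed values, not part of the original module).
# It exercises the latency proxy and checks the anti-hacking property: with no
# finished subtasks, r_finish stays at zero, so spawn-only behavior cannot
# out-score honest completion.
if __name__ == "__main__":
    assert critical_steps([3, 2], [[5, 2], [1, 4]]) == 14
    honest = parl_reward_breakdown(
        task_outcome_reward=1.0, spawn_count=3, finished_subtasks=3, critical_steps=10
    )
    spawn_only = parl_reward_breakdown(
        task_outcome_reward=1.0, spawn_count=12, finished_subtasks=0, critical_steps=10
    )
    assert honest.total > spawn_only.total
    print(f"honest={honest.total:.4f} spawn_only={spawn_only.total:.4f}")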