# NOTE(review): the three lines below are VCS residue (author, commit
# message, short hash) pasted into the source; kept as comments so the
# module remains importable.
# Lomesh2000
# FIX: group update, env changes
# e6a02dd
# salespath_env/server/reward.py
"""
SalesPath reward computation.
Composes five OpenEnv `Rubric` components into one `WeightedSum`.
Each sub-rubric scores the (action, observation_like_payload) pair on
[-1, 1] (or [0, 1] where indicated).
Design notes
------------
* Outcome reward: terminal-only, distinguishes honest close-failure
from rule-violation termination (per arXiv:2601.19100 §3.1 — proxy
rewards must differentiate failure modes).
* Compliance reward: per-turn, dense (the headline training signal).
* Ordering reward: **potential-based shaping** — only the *delta* in
workflow progress is paid out per turn. This is the construction
from arXiv:2408.10215 §4.2 that does not change the optimal policy
while killing the "stall after early correct steps" reward-hack.
* Efficiency: terminal-only, mild penalty for turn overhead.
* Format: explicit `format_ok` flag from the parser — rejects silent
fallbacks where a malformed completion is silently coerced to a
valid action_type.
The legacy procedural `compute_reward(...)` function is kept as a
thin wrapper so existing call sites (tests, environment, training)
keep working unchanged.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple
from openenv.core.rubrics import Rubric, WeightedSum
from ..models import SalesPathAction, SalesPathState
# Difficulty level -> turn budget for an efficient episode. EfficiencyRubric
# penalises turns beyond this at termination; difficulties not listed here
# fall back to a default budget of 10 at the lookup site.
DIFFICULTY_OPTIMAL_TURNS: Dict[int, int] = {
    1: 5,
    2: 8,
    3: 12,
    4: 14,
}
# ---------------------------------------------------------------------------
# RewardContext: small struct passed to every Rubric
# ---------------------------------------------------------------------------
@dataclass
class RewardContext:
    """
    Carries everything a sub-rubric needs.
    Used as the `observation` argument to each `Rubric.__call__`.
    """
    # Environment state; its `steps_completed` reflects the current turn
    # (it is diffed against `prev_steps_completed` by OrderingRubric).
    state: SalesPathState
    # Prospect response token, e.g. "accept:close_success" / "reject:close_failed".
    response_token: str
    # Rule IDs newly violated this turn (strings such as "R08").
    new_violations: list
    # True only on the terminal turn of the episode.
    episode_done: bool
    # Snapshot of `state.steps_completed` from before this action.
    prev_steps_completed: list
    # Parser flag: the raw model output was a well-formed ACTION/CONTENT block.
    format_ok: bool
# ---------------------------------------------------------------------------
# Sub-rubrics
# ---------------------------------------------------------------------------
class OutcomeRubric(Rubric):
    """
    Terminal-only outcome reward; every non-terminal turn scores 0.0.
        +1.0  successful CLOSE
        +0.5  correct DISQUALIFY (R08 not violated)
        -0.5  invalid DISQUALIFY (R08 violated)
        -0.7  episode ended by accumulated rule violations (>=3)
        -0.3  honest close-failure, turn-limit, or any other termination
    A terminal DISQUALIFY is classified before the violation-count check,
    so an invalid DISQUALIFY scores -0.5 even when it is the third
    violation.
    """
    def forward(self, action: SalesPathAction, ctx: RewardContext) -> float:
        # Outcome is paid out exactly once, on the terminal turn.
        if not ctx.episode_done:
            return 0.0
        if ctx.response_token == "accept:close_success":
            return 1.0
        if action.action_type == "DISQUALIFY":
            # R08 flags a disqualification that broke the rules.
            return -0.5 if "R08" in ctx.new_violations else 0.5
        if ctx.response_token == "reject:close_failed":
            return -0.3
        if len(ctx.state.constraints_violated) >= 3:
            return -0.7
        # Turn-limit reached, or any other terminal condition: mild penalty.
        # NOTE(review): the 20-turn limit and >=3 threshold are assumed to
        # match the environment's config — confirm against the env.
        return -0.3
class ComplianceRubric(Rubric):
    """
    Dense per-turn rule-compliance signal — the headline training reward.
    Each *new* violation this turn costs 0.2, with the per-turn penalty
    capped at 1.0. A clean turn (the expected steady state for a trained
    agent) scores 0.0.
    """
    def forward(self, action: SalesPathAction, ctx: RewardContext) -> float:
        penalty = 0.2 * len(ctx.new_violations)
        return -min(1.0, penalty)
class OrderingRubric(Rubric):
    """
    Potential-based workflow-progress shaping (arXiv:2408.10215 §4.2).
    Pays out only the per-turn *delta* in workflow progress, so the
    episode total equals a monotonic fraction-correct reward while
    stalling after a few early correct steps earns nothing extra.

    Progress is measured as the longest prefix of `required_workflow`
    that appears, in order, as a subsequence of `steps_completed`.
    A position-by-position comparison would mis-align as soon as a
    mandatory-but-unlisted step (PROSPECT, required by R06 but absent
    from `DIFFICULTY_WORKFLOW`) lands in `steps_completed`; the
    subsequence match tolerates such interleavings, stays monotonic,
    and keeps the per-turn delta at 0 or 1 step (potential-based).
    """
    @staticmethod
    def _correct_prefix(required: list, completed: list) -> int:
        # Longest prefix of `required` occurring in order within `completed`.
        matched = 0
        target = len(required)
        for done in completed:
            if matched == target:
                break
            if done == required[matched]:
                matched += 1
        return matched

    def forward(self, action: SalesPathAction, ctx: RewardContext) -> float:
        workflow = ctx.state.required_workflow
        if not workflow:
            return 0.0
        before = self._correct_prefix(workflow, ctx.prev_steps_completed)
        after = self._correct_prefix(workflow, ctx.state.steps_completed)
        # Normalised potential difference: 0 on a stall, 1/len on advance.
        return (after - before) / len(workflow)
class EfficiencyRubric(Rubric):
    """
    Mild terminal-only penalty for turn overhead; 0.0 on every
    non-terminal turn. Each turn past the difficulty-specific optimum
    costs 0.05, floored at -0.3.
    """
    def forward(self, action: SalesPathAction, ctx: RewardContext) -> float:
        if not ctx.episode_done:
            return 0.0
        # Unknown difficulty levels fall back to a 10-turn budget.
        budget = DIFFICULTY_OPTIMAL_TURNS.get(ctx.state.difficulty, 10)
        overshoot = ctx.state.turn_number - budget
        if overshoot <= 0:
            return 0.0
        return max(-0.3, -0.05 * overshoot)
class FormatRubric(Rubric):
    """
    Strict format gate. Full credit (1.0) only when both hold:
      1. the raw completion parsed as a valid ACTION/CONTENT block
         (`format_ok` is True), and
      2. the resulting action_type is in VALID_ACTIONS.
    Any failure scores -0.3 with no partial credit (proposal §5.2).
    """
    def forward(self, action: SalesPathAction, ctx: RewardContext) -> float:
        # Short-circuit keeps is_valid() unevaluated on a parse failure.
        if ctx.format_ok and action.is_valid():
            return 1.0
        return -0.3
# ---------------------------------------------------------------------------
# Composed rubric
# ---------------------------------------------------------------------------
class SalesPathRubric(WeightedSum):
    """
    The full SalesPath reward, composed of five sub-rubrics.
    Weights — re-balanced per arXiv:2601.19100's recommendation that
    process-level signals dominate sparse-outcome signals when episodes
    are long and credit assignment is hard:
        compliance  0.40  (headline training signal)
        outcome     0.20
        ordering    0.20
        efficiency  0.10
        format      0.10
    Access individual scores:
        rubric.last_score           # composite
        rubric.outcome.last_score   # per-component
        for n, r in rubric.named_rubrics():
            print(n, r.last_score)
    """
    def __init__(self):
        # (attribute name, rubric instance, weight) — list order defines
        # the rubric/weight pairing passed to WeightedSum.
        components = [
            ("outcome", OutcomeRubric(), 0.20),
            ("compliance", ComplianceRubric(), 0.40),
            ("ordering", OrderingRubric(), 0.20),
            ("efficiency", EfficiencyRubric(), 0.10),
            ("format", FormatRubric(), 0.10),
        ]
        # WeightedSum.__init__ (via Rubric.__init__) initialises
        # _rubric_children, so it must run before any attribute binding.
        super().__init__(
            rubrics=[rubric for _, rubric, _ in components],
            weights=[weight for _, _, weight in components],
        )
        # Semantic aliases for ergonomic access, e.g.
        # rubric.compliance.last_score.
        for attr_name, rubric, _ in components:
            setattr(self, attr_name, rubric)
# ---------------------------------------------------------------------------
# Procedural wrapper kept for backward compatibility
# ---------------------------------------------------------------------------
# Singleton — cheap, stateless aside from `last_score` introspection.
# Shared by every compute_reward() call; per-component scores are read
# back from it immediately after invocation.
_DEFAULT_RUBRIC = SalesPathRubric()
def compute_reward(
    state: SalesPathState,
    action: SalesPathAction,
    response_token: str,
    new_violations: list,
    episode_done: bool,
    prev_steps_completed: Optional[list] = None,
    format_ok: bool = True,
) -> Tuple[float, dict]:
    """
    Backward-compatible wrapper around `SalesPathRubric`.

    Parameters
    ----------
    state : SalesPathState
        Environment state after the action was applied.
    action : SalesPathAction
        The action the agent took this turn.
    response_token : str
        Prospect response, e.g. "accept:close_success".
    new_violations : list
        Rule IDs newly violated this turn.
    episode_done : bool
        True on the terminal turn.
    prev_steps_completed : list, optional
        `steps_completed` as it was before this action. When omitted it
        is reconstructed from `state.steps_completed` (see below).
    format_ok : bool
        Parser flag: the raw completion was a well-formed block.

    Returns
    -------
    (total_reward, components)
        components: dict with keys
        r_outcome, r_compliance, r_ordering, r_efficiency, r_format, total
    """
    if prev_steps_completed is None:
        # Reconstruct the pre-action step list: if this action was just
        # appended it is the *last* entry, so strip only that trailing
        # occurrence. (A blanket `s != action.action_type` filter would
        # drop every occurrence and mis-compute the ordering delta when
        # the same action type legitimately repeats in an episode.)
        steps = list(state.steps_completed)
        if steps and steps[-1] == action.action_type:
            steps.pop()
        prev_steps_completed = steps
    ctx = RewardContext(
        state=state,
        response_token=response_token,
        new_violations=new_violations,
        episode_done=episode_done,
        prev_steps_completed=prev_steps_completed,
        format_ok=format_ok,
    )
    total = _DEFAULT_RUBRIC(action, ctx)
    components = {
        "r_outcome": _DEFAULT_RUBRIC.outcome.last_score,
        "r_compliance": _DEFAULT_RUBRIC.compliance.last_score,
        "r_ordering": _DEFAULT_RUBRIC.ordering.last_score,
        "r_efficiency": _DEFAULT_RUBRIC.efficiency.last_score,
        "r_format": _DEFAULT_RUBRIC.format.last_score,
        "total": total,
    }
    return total, components