Spaces:
Running
Running
Commit ·
ca62faa
1
Parent(s): 37edd09
Reward overhaul: add compute_reasoning_quality (anchoring+arithmetic+specificity+consistency), parity-grounds penalty, reduce outcome 40%->30%, add 10% reasoning quality signal
Browse files
- server/reward.py +140 -8
- server/undertrial_environment.py +4 -1
- training/train_grpo.py +16 -3
server/reward.py
CHANGED
|
@@ -288,6 +288,7 @@ LOW_SES_TERMS = ["theft", "robbery", "dacoity", "kidnapping", "narcotics", "poc
|
|
| 288 |
def compute_bias_penalty(
|
| 289 |
recommended_outcome: str,
|
| 290 |
episode: Dict[str, Any],
|
|
|
|
| 291 |
) -> float:
|
| 292 |
"""
|
| 293 |
Computes a bias penalty for the agent's recommendation.
|
|
@@ -330,9 +331,125 @@ def compute_bias_penalty(
|
|
| 330 |
if bias_flag and gt_granted and not agent_granted:
|
| 331 |
penalty += 0.15
|
| 332 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 333 |
return max(0.0, min(1.0, penalty))
|
| 334 |
|
| 335 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 336 |
# ---------------------------------------------------------------------------
|
| 337 |
# Master reward function
|
| 338 |
# ---------------------------------------------------------------------------
|
|
@@ -347,30 +464,45 @@ def compute_reward(
|
|
| 347 |
step_count: int = 0,
|
| 348 |
max_steps: int = 10,
|
| 349 |
statutory_tool_used: bool = False,
|
|
|
|
|
|
|
|
|
|
| 350 |
) -> Dict[str, float]:
|
| 351 |
"""
|
| 352 |
Computes the full reward for a submitted bail assessment memo.
|
| 353 |
|
| 354 |
Formula:
|
| 355 |
-
R = 0.4*outcome_match
|
| 356 |
+ 0.2*flight_risk_accuracy
|
| 357 |
+ 0.2*statutory_accuracy
|
| 358 |
+ 0.2*condition_appropriateness
|
| 359 |
-
+ 0.1*
|
|
|
|
|
|
|
| 360 |
- 0.3*bias_penalty
|
| 361 |
|
| 362 |
Returns a dict with all component scores + total_reward.
|
| 363 |
-
Range: [-0.
|
| 364 |
"""
|
| 365 |
gt = episode["ground_truth"]
|
| 366 |
|
|
|
|
|
|
|
| 367 |
om = compute_outcome_match(agent_outcome, gt)
|
| 368 |
fr = compute_flight_risk_accuracy(agent_flight_risk, gt)
|
| 369 |
sa = compute_statutory_accuracy(agent_eligible, agent_computation, episode)
|
| 370 |
ca = compute_condition_score(agent_outcome, agent_conditions, gt)
|
| 371 |
-
bias = compute_bias_penalty(agent_outcome, episode)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 372 |
|
| 373 |
-
#
|
| 374 |
# Only fires on directionally-correct outcomes (om >= 0.8) to prevent
|
| 375 |
# rewarding efficient-but-wrong agents.
|
| 376 |
efficiency = 0.0
|
|
@@ -378,18 +510,18 @@ def compute_reward(
|
|
| 378 |
efficiency = round((1.0 - (step_count - 1) / (max_steps - 1)), 4)
|
| 379 |
efficiency = max(0.0, min(1.0, efficiency))
|
| 380 |
|
| 381 |
-
#
|
| 382 |
-
# Incentivises explicit BNSS 479 computation before issuing the order.
|
| 383 |
process_bonus = 0.05 if statutory_tool_used else 0.0
|
| 384 |
|
| 385 |
lam = 0.3
|
| 386 |
-
total = 0.
|
| 387 |
|
| 388 |
return {
|
| 389 |
"outcome_match": round(om, 4),
|
| 390 |
"flight_risk_accuracy": round(fr, 4),
|
| 391 |
"statutory_accuracy": round(sa, 4),
|
| 392 |
"condition_appropriateness": round(ca, 4),
|
|
|
|
| 393 |
"efficiency_bonus": round(efficiency, 4),
|
| 394 |
"process_bonus": round(process_bonus,4),
|
| 395 |
"bias_penalty": round(bias, 4),
|
|
|
|
| 288 |
def compute_bias_penalty(
|
| 289 |
recommended_outcome: str,
|
| 290 |
episode: Dict[str, Any],
|
| 291 |
+
agent_grounds: Optional[List[str]] = None,
|
| 292 |
) -> float:
|
| 293 |
"""
|
| 294 |
Computes a bias penalty for the agent's recommendation.
|
|
|
|
| 331 |
if bias_flag and gt_granted and not agent_granted:
|
| 332 |
penalty += 0.15
|
| 333 |
|
| 334 |
+
# ── Penalty 4: Parity case — agent diverges AND never mentions parity ─────
|
| 335 |
+
# HC relied on co-accused parity; agent disagrees AND didn't engage with it.
|
| 336 |
+
if parity_used and (agent_granted != gt_granted) and agent_grounds is not None:
|
| 337 |
+
grounds_lower = " ".join(agent_grounds).lower()
|
| 338 |
+
if not any(w in grounds_lower for w in PARITY_WORDS):
|
| 339 |
+
penalty += 0.10 # Extra for ignoring parity without acknowledging it
|
| 340 |
+
|
| 341 |
return max(0.0, min(1.0, penalty))
|
| 342 |
|
| 343 |
|
| 344 |
+
# ---------------------------------------------------------------------------
|
| 345 |
+
# 6. Reasoning Quality (10% — replaces 10% from outcome weight)
|
| 346 |
+
# ---------------------------------------------------------------------------
|
| 347 |
+
|
| 348 |
+
PARITY_WORDS = ["parity", "co-accused", "co accused", "similarly placed",
|
| 349 |
+
"bail granted to", "co-prisoner", "coaccused"]
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
def compute_reasoning_quality(
|
| 353 |
+
flight_risk_justification: str,
|
| 354 |
+
agent_risk_label: str,
|
| 355 |
+
statutory_computation: str,
|
| 356 |
+
grounds_for: List[str],
|
| 357 |
+
grounds_against: List[str],
|
| 358 |
+
episode: Dict[str, Any],
|
| 359 |
+
) -> float:
|
| 360 |
+
"""
|
| 361 |
+
Scores the quality of the agent's reasoning without an LLM judge.
|
| 362 |
+
|
| 363 |
+
Three sub-scores (averaged):
|
| 364 |
+
1. Justification anchoring — does flight risk justification cite
|
| 365 |
+
case-specific facts (crime type, IPC section, custody duration)?
|
| 366 |
+
2. Arithmetic verification — do the actual episode numbers appear
|
| 367 |
+
in the statutory computation (not just any number)?
|
| 368 |
+
3. Grounds specificity — do bail grounds reference crime-specific
|
| 369 |
+
facts rather than boilerplate?
|
| 370 |
+
|
| 371 |
+
Plus a consistency deduction:
|
| 372 |
+
- Label says Low but text contains High-risk keywords → -0.10
|
| 373 |
+
- Label says High but text contains Low-risk keywords → -0.10
|
| 374 |
+
"""
|
| 375 |
+
just = flight_risk_justification.lower()
|
| 376 |
+
comp = statutory_computation.lower()
|
| 377 |
+
grounds_text = " ".join(grounds_for + grounds_against).lower()
|
| 378 |
+
|
| 379 |
+
sections = episode.get("ipc_sections", [])
|
| 380 |
+
custody_mo = episode.get("custody_months") or 0.0
|
| 381 |
+
max_sent = episode.get("max_sentence_years", 5.0)
|
| 382 |
+
crime_type = episode.get("crime_type", "").lower()
|
| 383 |
+
|
| 384 |
+
# ── Sub-score 1: Justification anchoring ──────────────────────────────
|
| 385 |
+
anchor_hits, anchor_max = 0, 0
|
| 386 |
+
if crime_type:
|
| 387 |
+
# At least one meaningful word from crime type in justification
|
| 388 |
+
if any(w in just for w in crime_type.split() if len(w) > 3):
|
| 389 |
+
anchor_hits += 1
|
| 390 |
+
anchor_max += 1
|
| 391 |
+
if sections:
|
| 392 |
+
if any(sec.strip() in just for sec in sections):
|
| 393 |
+
anchor_hits += 1
|
| 394 |
+
anchor_max += 1
|
| 395 |
+
if custody_mo > 0:
|
| 396 |
+
# Exact custody months mentioned
|
| 397 |
+
if str(int(custody_mo)) in just or f"{custody_mo:.1f}" in just:
|
| 398 |
+
anchor_hits += 1
|
| 399 |
+
anchor_max += 1
|
| 400 |
+
|
| 401 |
+
just_words = len(just.split())
|
| 402 |
+
raw_anchor = anchor_hits / max(1, anchor_max)
|
| 403 |
+
# Cap anchoring score at 0.5 if justification is suspiciously short
|
| 404 |
+
anchor_score = raw_anchor if just_words >= 15 else min(0.5, raw_anchor)
|
| 405 |
+
|
| 406 |
+
# ── Sub-score 2: Arithmetic verification ──────────────────────────────
|
| 407 |
+
if custody_mo > 0:
|
| 408 |
+
threshold_mo = (max_sent * 12) / 2
|
| 409 |
+
comp_numbers = [float(n) for n in re.findall(r'\d+\.?\d*', comp)]
|
| 410 |
+
has_custody = any(abs(n - custody_mo) <= 1.5 for n in comp_numbers)
|
| 411 |
+
has_threshold = any(abs(n - threshold_mo) <= 2.0 or
|
| 412 |
+
abs(n - (max_sent * 12)) <= 2.0
|
| 413 |
+
for n in comp_numbers)
|
| 414 |
+
comp_words = len(comp.split())
|
| 415 |
+
if comp_words < 10:
|
| 416 |
+
arith_score = 0.3 if (has_custody or has_threshold) else 0.0
|
| 417 |
+
else:
|
| 418 |
+
arith_score = 0.5 * has_custody + 0.5 * has_threshold
|
| 419 |
+
else:
|
| 420 |
+
arith_score = 0.5 # No custody data — neutral, can't verify
|
| 421 |
+
|
| 422 |
+
# ── Sub-score 3: Grounds specificity ─────────────────────────────────
|
| 423 |
+
g_hits, g_max = 0, 0
|
| 424 |
+
if crime_type:
|
| 425 |
+
if any(w in grounds_text for w in crime_type.split() if len(w) > 3):
|
| 426 |
+
g_hits += 1
|
| 427 |
+
g_max += 1
|
| 428 |
+
if sections:
|
| 429 |
+
if any(sec.strip() in grounds_text for sec in sections):
|
| 430 |
+
g_hits += 1
|
| 431 |
+
g_max += 1
|
| 432 |
+
grounds_words = len(grounds_text.split())
|
| 433 |
+
raw_grounds = g_hits / max(1, g_max)
|
| 434 |
+
grounds_score = raw_grounds if grounds_words >= 10 else min(0.4, raw_grounds)
|
| 435 |
+
|
| 436 |
+
base = (anchor_score + arith_score + grounds_score) / 3
|
| 437 |
+
|
| 438 |
+
# ── Consistency deduction: label contradicts justification text ────────
|
| 439 |
+
label = agent_risk_label.strip().lower()
|
| 440 |
+
consistency_deduction = 0.0
|
| 441 |
+
if "low" in label:
|
| 442 |
+
high_hits = sum(1 for kw in FLIGHT_RISK_KEYWORDS["High"] if kw in just)
|
| 443 |
+
if high_hits >= 2:
|
| 444 |
+
consistency_deduction = 0.10
|
| 445 |
+
elif "high" in label:
|
| 446 |
+
low_hits = sum(1 for kw in FLIGHT_RISK_KEYWORDS["Low"] if kw in just)
|
| 447 |
+
if low_hits >= 2:
|
| 448 |
+
consistency_deduction = 0.10
|
| 449 |
+
|
| 450 |
+
return round(max(0.0, min(1.0, base - consistency_deduction)), 4)
|
| 451 |
+
|
| 452 |
+
|
| 453 |
# ---------------------------------------------------------------------------
|
| 454 |
# Master reward function
|
| 455 |
# ---------------------------------------------------------------------------
|
|
|
|
| 464 |
step_count: int = 0,
|
| 465 |
max_steps: int = 10,
|
| 466 |
statutory_tool_used: bool = False,
|
| 467 |
+
agent_flight_risk_justification: str = "",
|
| 468 |
+
agent_grounds_for: Optional[List[str]] = None,
|
| 469 |
+
agent_grounds_against: Optional[List[str]] = None,
|
| 470 |
) -> Dict[str, float]:
|
| 471 |
"""
|
| 472 |
Computes the full reward for a submitted bail assessment memo.
|
| 473 |
|
| 474 |
Formula:
|
| 475 |
+
R = 0.3*outcome_match (was 0.4 — reduced to reward reasoning)
|
| 476 |
+ 0.2*flight_risk_accuracy
|
| 477 |
+ 0.2*statutory_accuracy
|
| 478 |
+ 0.2*condition_appropriateness
|
| 479 |
+
+ 0.1*reasoning_quality (NEW — anchoring + arithmetic + specificity)
|
| 480 |
+
+ 0.1*efficiency_bonus (only when outcome is correct)
|
| 481 |
+
+ 0.05*process_bonus
|
| 482 |
- 0.3*bias_penalty
|
| 483 |
|
| 484 |
Returns a dict with all component scores + total_reward.
|
| 485 |
+
Range: approx [-0.4, 1.1].
|
| 486 |
"""
|
| 487 |
gt = episode["ground_truth"]
|
| 488 |
|
| 489 |
+
grounds_all = (agent_grounds_for or []) + (agent_grounds_against or [])
|
| 490 |
+
|
| 491 |
om = compute_outcome_match(agent_outcome, gt)
|
| 492 |
fr = compute_flight_risk_accuracy(agent_flight_risk, gt)
|
| 493 |
sa = compute_statutory_accuracy(agent_eligible, agent_computation, episode)
|
| 494 |
ca = compute_condition_score(agent_outcome, agent_conditions, gt)
|
| 495 |
+
bias = compute_bias_penalty(agent_outcome, episode, agent_grounds=grounds_all)
|
| 496 |
+
rq = compute_reasoning_quality(
|
| 497 |
+
flight_risk_justification = agent_flight_risk_justification,
|
| 498 |
+
agent_risk_label = agent_flight_risk,
|
| 499 |
+
statutory_computation = agent_computation,
|
| 500 |
+
grounds_for = agent_grounds_for or [],
|
| 501 |
+
grounds_against = agent_grounds_against or [],
|
| 502 |
+
episode = episode,
|
| 503 |
+
)
|
| 504 |
|
| 505 |
+
# Efficiency bonus: reward finishing faster when the answer is correct.
|
| 506 |
# Only fires on directionally-correct outcomes (om >= 0.8) to prevent
|
| 507 |
# rewarding efficient-but-wrong agents.
|
| 508 |
efficiency = 0.0
|
|
|
|
| 510 |
efficiency = round((1.0 - (step_count - 1) / (max_steps - 1)), 4)
|
| 511 |
efficiency = max(0.0, min(1.0, efficiency))
|
| 512 |
|
| 513 |
+
# Process reward: +0.05 if agent actually used the statutory tool.
|
|
|
|
| 514 |
process_bonus = 0.05 if statutory_tool_used else 0.0
|
| 515 |
|
| 516 |
lam = 0.3
|
| 517 |
+
total = 0.3*om + 0.2*fr + 0.2*sa + 0.2*ca + 0.1*rq + 0.1*efficiency + process_bonus - lam*bias
|
| 518 |
|
| 519 |
return {
|
| 520 |
"outcome_match": round(om, 4),
|
| 521 |
"flight_risk_accuracy": round(fr, 4),
|
| 522 |
"statutory_accuracy": round(sa, 4),
|
| 523 |
"condition_appropriateness": round(ca, 4),
|
| 524 |
+
"reasoning_quality": round(rq, 4),
|
| 525 |
"efficiency_bonus": round(efficiency, 4),
|
| 526 |
"process_bonus": round(process_bonus,4),
|
| 527 |
"bias_penalty": round(bias, 4),
|
server/undertrial_environment.py
CHANGED
|
@@ -127,7 +127,10 @@ class UndertriAIEnvironment(Environment):
|
|
| 127 |
episode = self._episode,
|
| 128 |
step_count = self._step_count,
|
| 129 |
max_steps = self.MAX_STEPS,
|
| 130 |
-
statutory_tool_used
|
|
|
|
|
|
|
|
|
|
| 131 |
)
|
| 132 |
# Apply skip penalty (can push total legitimately negative)
|
| 133 |
reward_dict["total_reward"] = round(reward_dict["total_reward"] - no_tool_penalty, 4)
|
|
|
|
| 127 |
episode = self._episode,
|
| 128 |
step_count = self._step_count,
|
| 129 |
max_steps = self.MAX_STEPS,
|
| 130 |
+
statutory_tool_used = self._statutory_tool_called,
|
| 131 |
+
agent_flight_risk_justification = action.flight_risk_justification,
|
| 132 |
+
agent_grounds_for = action.grounds_for_bail,
|
| 133 |
+
agent_grounds_against = action.grounds_against_bail,
|
| 134 |
)
|
| 135 |
# Apply skip penalty (can push total legitimately negative)
|
| 136 |
reward_dict["total_reward"] = round(reward_dict["total_reward"] - no_tool_penalty, 4)
|
training/train_grpo.py
CHANGED
|
@@ -51,6 +51,7 @@ try:
|
|
| 51 |
compute_statutory_accuracy,
|
| 52 |
compute_condition_score,
|
| 53 |
compute_bias_penalty as _server_bias,
|
|
|
|
| 54 |
)
|
| 55 |
_USE_SERVER_REWARDS = True
|
| 56 |
print("[reward] Using authoritative server/reward.py functions.")
|
|
@@ -337,14 +338,26 @@ def combined_reward(
|
|
| 337 |
parsed.get("conditions", []),
|
| 338 |
gt,
|
| 339 |
)
|
| 340 |
-
b = _server_bias(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 341 |
else:
|
| 342 |
# Local fallback
|
| 343 |
o = reward_outcome_match([comp], [ep])[0]
|
| 344 |
fr = reward_flight_risk([comp], [ep])[0]
|
| 345 |
s = reward_statutory([comp], [ep])[0]
|
| 346 |
-
ca = reward_conditions([comp], [ep])[0]
|
| 347 |
b = reward_no_bias([comp], [ep])[0]
|
|
|
|
| 348 |
|
| 349 |
# R4 efficiency bonus: reward fewer steps when outcome is correct
|
| 350 |
eff = 0.0
|
|
@@ -354,7 +367,7 @@ def combined_reward(
|
|
| 354 |
if sc is not None:
|
| 355 |
eff = max(0.0, 1.0 - (sc - 1) / 9)
|
| 356 |
|
| 357 |
-
total = 0.
|
| 358 |
rewards.append(round(total, 4)) # No max(0.0) clamp — bias can go negative
|
| 359 |
return rewards
|
| 360 |
|
|
|
|
| 51 |
compute_statutory_accuracy,
|
| 52 |
compute_condition_score,
|
| 53 |
compute_bias_penalty as _server_bias,
|
| 54 |
+
compute_reasoning_quality,
|
| 55 |
)
|
| 56 |
_USE_SERVER_REWARDS = True
|
| 57 |
print("[reward] Using authoritative server/reward.py functions.")
|
|
|
|
| 338 |
parsed.get("conditions", []),
|
| 339 |
gt,
|
| 340 |
)
|
| 341 |
+
b = _server_bias(
|
| 342 |
+
parsed["recommended_outcome"], ep,
|
| 343 |
+
agent_grounds=parsed.get("grounds_for", []) + parsed.get("grounds_against", []),
|
| 344 |
+
)
|
| 345 |
+
rq = compute_reasoning_quality(
|
| 346 |
+
flight_risk_justification = parsed.get("flight_risk_just", ""),
|
| 347 |
+
agent_risk_label = parsed.get("flight_risk", ""),
|
| 348 |
+
statutory_computation = parsed.get("statutory_computation", ""),
|
| 349 |
+
grounds_for = parsed.get("grounds_for", []),
|
| 350 |
+
grounds_against = parsed.get("grounds_against", []),
|
| 351 |
+
episode = ep,
|
| 352 |
+
)
|
| 353 |
else:
|
| 354 |
# Local fallback
|
| 355 |
o = reward_outcome_match([comp], [ep])[0]
|
| 356 |
fr = reward_flight_risk([comp], [ep])[0]
|
| 357 |
s = reward_statutory([comp], [ep])[0]
|
| 358 |
+
ca = reward_conditions([comp], [ep])[0]
|
| 359 |
b = reward_no_bias([comp], [ep])[0]
|
| 360 |
+
rq = 0.5 # Neutral when server functions unavailable
|
| 361 |
|
| 362 |
# R4 efficiency bonus: reward fewer steps when outcome is correct
|
| 363 |
eff = 0.0
|
|
|
|
| 367 |
if sc is not None:
|
| 368 |
eff = max(0.0, 1.0 - (sc - 1) / 9)
|
| 369 |
|
| 370 |
+
total = 0.3*o + 0.2*fr + 0.2*s + 0.2*ca + 0.1*rq + 0.1*eff - 0.3*b
|
| 371 |
rewards.append(round(total, 4)) # No max(0.0) clamp — bias can go negative
|
| 372 |
return rewards
|
| 373 |
|