Draken1606 committed
Commit ca62faa · 1 Parent(s): 37edd09

Reward overhaul: add compute_reasoning_quality (anchoring + arithmetic + specificity + consistency) and a parity-grounds penalty; reduce the outcome weight from 0.4 to 0.3 and add a 0.1 reasoning-quality signal

server/reward.py CHANGED
@@ -288,6 +288,7 @@ LOW_SES_TERMS = ["theft", "robbery", "dacoity", "kidnapping", "narcotics", "poc
 def compute_bias_penalty(
     recommended_outcome: str,
     episode: Dict[str, Any],
+    agent_grounds: Optional[List[str]] = None,
 ) -> float:
     """
     Computes a bias penalty for the agent's recommendation.
@@ -330,9 +331,125 @@ def compute_bias_penalty(
     if bias_flag and gt_granted and not agent_granted:
         penalty += 0.15
 
+    # ── Penalty 4: Parity case — agent diverges AND never mentions parity ────
+    # HC relied on co-accused parity; agent disagrees AND didn't engage with it.
+    if parity_used and (agent_granted != gt_granted) and agent_grounds is not None:
+        grounds_lower = " ".join(agent_grounds).lower()
+        if not any(w in grounds_lower for w in PARITY_WORDS):
+            penalty += 0.10  # Extra for ignoring parity without acknowledging it
+
     return max(0.0, min(1.0, penalty))
 
 
+# ---------------------------------------------------------------------------
+# 6. Reasoning Quality (10% — replaces 10% from outcome weight)
+# ---------------------------------------------------------------------------
+
+PARITY_WORDS = ["parity", "co-accused", "co accused", "similarly placed",
+                "bail granted to", "co-prisoner", "coaccused"]
+
+
+def compute_reasoning_quality(
+    flight_risk_justification: str,
+    agent_risk_label: str,
+    statutory_computation: str,
+    grounds_for: List[str],
+    grounds_against: List[str],
+    episode: Dict[str, Any],
+) -> float:
+    """
+    Scores the quality of the agent's reasoning without an LLM judge.
+
+    Three sub-scores (averaged):
+      1. Justification anchoring — does the flight-risk justification cite
+         case-specific facts (crime type, IPC section, custody duration)?
+      2. Arithmetic verification — do the actual episode numbers appear
+         in the statutory computation (not just any number)?
+      3. Grounds specificity — do bail grounds reference crime-specific
+         facts rather than boilerplate?
+
+    Plus a consistency deduction:
+      - Label says Low but text contains High-risk keywords → -0.10
+      - Label says High but text contains Low-risk keywords → -0.10
+    """
+    just = flight_risk_justification.lower()
+    comp = statutory_computation.lower()
+    grounds_text = " ".join(grounds_for + grounds_against).lower()
+
+    sections = episode.get("ipc_sections", [])
+    custody_mo = episode.get("custody_months") or 0.0
+    max_sent = episode.get("max_sentence_years", 5.0)
+    crime_type = episode.get("crime_type", "").lower()
+
+    # ── Sub-score 1: Justification anchoring ──────────────────────────────
+    anchor_hits, anchor_max = 0, 0
+    if crime_type:
+        # At least one meaningful word from crime type in justification
+        if any(w in just for w in crime_type.split() if len(w) > 3):
+            anchor_hits += 1
+        anchor_max += 1
+    if sections:
+        if any(sec.strip() in just for sec in sections):
+            anchor_hits += 1
+        anchor_max += 1
+    if custody_mo > 0:
+        # Exact custody months mentioned
+        if str(int(custody_mo)) in just or f"{custody_mo:.1f}" in just:
+            anchor_hits += 1
+        anchor_max += 1
+
+    just_words = len(just.split())
+    raw_anchor = anchor_hits / max(1, anchor_max)
+    # Cap anchoring score at 0.5 if justification is suspiciously short
+    anchor_score = raw_anchor if just_words >= 15 else min(0.5, raw_anchor)
+
+    # ── Sub-score 2: Arithmetic verification ──────────────────────────────
+    if custody_mo > 0:
+        threshold_mo = (max_sent * 12) / 2
+        comp_numbers = [float(n) for n in re.findall(r'\d+\.?\d*', comp)]
+        has_custody = any(abs(n - custody_mo) <= 1.5 for n in comp_numbers)
+        has_threshold = any(abs(n - threshold_mo) <= 2.0 or
+                            abs(n - (max_sent * 12)) <= 2.0
+                            for n in comp_numbers)
+        comp_words = len(comp.split())
+        if comp_words < 10:
+            arith_score = 0.3 if (has_custody or has_threshold) else 0.0
+        else:
+            arith_score = 0.5 * has_custody + 0.5 * has_threshold
+    else:
+        arith_score = 0.5  # No custody data — neutral, can't verify
+
+    # ── Sub-score 3: Grounds specificity ──────────────────────────────────
+    g_hits, g_max = 0, 0
+    if crime_type:
+        if any(w in grounds_text for w in crime_type.split() if len(w) > 3):
+            g_hits += 1
+        g_max += 1
+    if sections:
+        if any(sec.strip() in grounds_text for sec in sections):
+            g_hits += 1
+        g_max += 1
+    grounds_words = len(grounds_text.split())
+    raw_grounds = g_hits / max(1, g_max)
+    grounds_score = raw_grounds if grounds_words >= 10 else min(0.4, raw_grounds)
+
+    base = (anchor_score + arith_score + grounds_score) / 3
+
+    # ── Consistency deduction: label contradicts justification text ────────
+    label = agent_risk_label.strip().lower()
+    consistency_deduction = 0.0
+    if "low" in label:
+        high_hits = sum(1 for kw in FLIGHT_RISK_KEYWORDS["High"] if kw in just)
+        if high_hits >= 2:
+            consistency_deduction = 0.10
+    elif "high" in label:
+        low_hits = sum(1 for kw in FLIGHT_RISK_KEYWORDS["Low"] if kw in just)
+        if low_hits >= 2:
+            consistency_deduction = 0.10
+
+    return round(max(0.0, min(1.0, base - consistency_deduction)), 4)
+
+
 # ---------------------------------------------------------------------------
 # Master reward function
 # ---------------------------------------------------------------------------
@@ -347,30 +464,45 @@ def compute_reward(
     step_count: int = 0,
     max_steps: int = 10,
     statutory_tool_used: bool = False,
+    agent_flight_risk_justification: str = "",
+    agent_grounds_for: Optional[List[str]] = None,
+    agent_grounds_against: Optional[List[str]] = None,
 ) -> Dict[str, float]:
     """
     Computes the full reward for a submitted bail assessment memo.
 
     Formula:
-      R = 0.4*outcome_match
+      R = 0.3*outcome_match            (was 0.4 — reduced to reward reasoning)
        + 0.2*flight_risk_accuracy
        + 0.2*statutory_accuracy
        + 0.2*condition_appropriateness
-       + 0.1*efficiency_bonus (only when outcome is correct)
+       + 0.1*reasoning_quality         (NEW: anchoring + arithmetic + specificity)
+       + 0.1*efficiency_bonus          (only when outcome is correct)
+       + 0.05*process_bonus
        - 0.3*bias_penalty
 
     Returns a dict with all component scores + total_reward.
-    Range: [-0.3, 1.1] (efficiency can push above 1.0 slightly on perfect runs).
+    Range: approx [-0.4, 1.1].
     """
     gt = episode["ground_truth"]
 
+    grounds_all = (agent_grounds_for or []) + (agent_grounds_against or [])
+
     om = compute_outcome_match(agent_outcome, gt)
     fr = compute_flight_risk_accuracy(agent_flight_risk, gt)
     sa = compute_statutory_accuracy(agent_eligible, agent_computation, episode)
     ca = compute_condition_score(agent_outcome, agent_conditions, gt)
-    bias = compute_bias_penalty(agent_outcome, episode)
+    bias = compute_bias_penalty(agent_outcome, episode, agent_grounds=grounds_all)
+    rq = compute_reasoning_quality(
+        flight_risk_justification = agent_flight_risk_justification,
+        agent_risk_label = agent_flight_risk,
+        statutory_computation = agent_computation,
+        grounds_for = agent_grounds_for or [],
+        grounds_against = agent_grounds_against or [],
+        episode = episode,
+    )
 
-    # R4 — Efficiency bonus: reward finishing faster when the answer is correct.
+    # Efficiency bonus: reward finishing faster when the answer is correct.
     # Only fires on directionally-correct outcomes (om >= 0.8) to prevent
     # rewarding efficient-but-wrong agents.
     efficiency = 0.0
@@ -378,18 +510,18 @@ def compute_reward(
         efficiency = round((1.0 - (step_count - 1) / (max_steps - 1)), 4)
         efficiency = max(0.0, min(1.0, efficiency))
 
-    # M2 — Process reward: +0.05 if agent actually used the statutory tool.
-    # Incentivises explicit BNSS 479 computation before issuing the order.
+    # Process reward: +0.05 if agent actually used the statutory tool.
     process_bonus = 0.05 if statutory_tool_used else 0.0
 
     lam = 0.3
-    total = 0.4*om + 0.2*fr + 0.2*sa + 0.2*ca + 0.1*efficiency + process_bonus - lam*bias
+    total = 0.3*om + 0.2*fr + 0.2*sa + 0.2*ca + 0.1*rq + 0.1*efficiency + process_bonus - lam*bias
 
     return {
         "outcome_match": round(om, 4),
         "flight_risk_accuracy": round(fr, 4),
         "statutory_accuracy": round(sa, 4),
         "condition_appropriateness": round(ca, 4),
+        "reasoning_quality": round(rq, 4),
        "efficiency_bonus": round(efficiency, 4),
        "process_bonus": round(process_bonus, 4),
        "bias_penalty": round(bias, 4),
server/undertrial_environment.py CHANGED
@@ -127,7 +127,10 @@ class UndertriAIEnvironment(Environment):
             episode = self._episode,
             step_count = self._step_count,
             max_steps = self.MAX_STEPS,
-            statutory_tool_used = self._statutory_tool_called,  # M2
+            statutory_tool_used = self._statutory_tool_called,
+            agent_flight_risk_justification = action.flight_risk_justification,
+            agent_grounds_for = action.grounds_for_bail,
+            agent_grounds_against = action.grounds_against_bail,
         )
         # Apply skip penalty (can push total legitimately negative)
         reward_dict["total_reward"] = round(reward_dict["total_reward"] - no_tool_penalty, 4)
training/train_grpo.py CHANGED
@@ -51,6 +51,7 @@ try:
     compute_statutory_accuracy,
     compute_condition_score,
     compute_bias_penalty as _server_bias,
+    compute_reasoning_quality,
 )
 _USE_SERVER_REWARDS = True
 print("[reward] Using authoritative server/reward.py functions.")
@@ -337,14 +338,26 @@ def combined_reward(
             parsed.get("conditions", []),
             gt,
         )
-        b = _server_bias(parsed["recommended_outcome"], ep)
+        b = _server_bias(
+            parsed["recommended_outcome"], ep,
+            agent_grounds=parsed.get("grounds_for", []) + parsed.get("grounds_against", []),
+        )
+        rq = compute_reasoning_quality(
+            flight_risk_justification = parsed.get("flight_risk_just", ""),
+            agent_risk_label = parsed.get("flight_risk", ""),
+            statutory_computation = parsed.get("statutory_computation", ""),
+            grounds_for = parsed.get("grounds_for", []),
+            grounds_against = parsed.get("grounds_against", []),
+            episode = ep,
+        )
     else:
         # Local fallback
         o = reward_outcome_match([comp], [ep])[0]
         fr = reward_flight_risk([comp], [ep])[0]
         s = reward_statutory([comp], [ep])[0]
-        ca = reward_conditions([comp], [ep])[0]  # condition score, not format
+        ca = reward_conditions([comp], [ep])[0]
         b = reward_no_bias([comp], [ep])[0]
+        rq = 0.5  # Neutral when server functions unavailable
 
     # R4 efficiency bonus: reward fewer steps when outcome is correct
     eff = 0.0
@@ -354,7 +367,7 @@ def combined_reward(
         if sc is not None:
             eff = max(0.0, 1.0 - (sc - 1) / 9)
 
-        total = 0.4*o + 0.2*fr + 0.2*s + 0.2*ca + 0.1*eff - 0.3*b
+        total = 0.3*o + 0.2*fr + 0.2*s + 0.2*ca + 0.1*rq + 0.1*eff - 0.3*b
         rewards.append(round(total, 4))  # No max(0.0) clamp — bias can go negative
     return rewards
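
A quick sanity check on the re-weighting (a sketch with toy scores, not part of the commit): the perfect-run ceiling on the trainer side is unchanged, because the 0.1 removed from outcome_match reappears as the reasoning-quality term; only ungrounded reasoning now loses reward.

# Old vs new trainer-side totals at perfect component scores (toy values).
o = fr = s = ca = rq = eff = 1.0
b = 0.0
old_total = 0.4*o + 0.2*fr + 0.2*s + 0.2*ca + 0.1*eff - 0.3*b            # 1.1
new_total = 0.3*o + 0.2*fr + 0.2*s + 0.2*ca + 0.1*rq + 0.1*eff - 0.3*b   # 1.1
assert abs(old_total - new_total) < 1e-9
# With boilerplate reasoning (rq = 0.0) the same otherwise-perfect run now
# caps at 1.0: the 0.1 weight moved from outcome agreement to reasoning.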