Spaces:
Sleeping
Sleeping
Commit ·
898bc18
1
Parent(s): ca62faa
Fix 8 compliance gaps: repeat-action dedup+cache, min-steps hard block, criminal history tool (12th action), efficiency removed from training formula, circular import cleaned, yaml formula synced
Browse files- models.py +7 -0
- openenv.yaml +5 -2
- server/__init__.py +5 -6
- server/app.py +2 -1
- server/undertrial_environment.py +61 -4
- training/train_grpo.py +4 -7
models.py
CHANGED
|
@@ -135,6 +135,13 @@ class ApplyProportionalityAction(Action):
|
|
| 135 |
)
|
| 136 |
|
| 137 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
class SubmitMemoAction(Action):
|
| 139 |
"""
|
| 140 |
TERMINAL ACTION — Submit the structured bail assessment memo.
|
|
|
|
| 135 |
)
|
| 136 |
|
| 137 |
|
| 138 |
+
class PullCriminalHistoryAction(Action):
|
| 139 |
+
"""Pull the accused's prior criminal record, bail history, and conviction status."""
|
| 140 |
+
tool_name: Literal["pull_criminal_history"] = "pull_criminal_history"
|
| 141 |
+
include_bail_history: bool = Field(
|
| 142 |
+
default=True, description="Whether to include prior bail applications and outcomes"
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
class SubmitMemoAction(Action):
|
| 146 |
"""
|
| 147 |
TERMINAL ACTION — Submit the structured bail assessment memo.
|
openenv.yaml
CHANGED
|
@@ -46,20 +46,23 @@ actions:
|
|
| 46 |
description: Examine specific case factors (parity, evidence tampering, victim vulnerability)
|
| 47 |
- name: apply_proportionality
|
| 48 |
description: Apply BNSS 479 proportionality — custody vs. max sentence vs. trial timeline
|
|
|
|
|
|
|
| 49 |
- name: submit_memo
|
| 50 |
description: "TERMINAL — Submit structured bail assessment memo"
|
| 51 |
|
| 52 |
reward:
|
| 53 |
-
formula: "0.
|
| 54 |
range: [-0.7, 1.15]
|
| 55 |
terminal_action: submit_memo
|
| 56 |
deterministic: true
|
| 57 |
llm_as_judge: false
|
| 58 |
components:
|
| 59 |
-
- outcome_match: "Agreement with real High Court decision (
|
| 60 |
- flight_risk_accuracy: "Flight risk classification accuracy (20%)"
|
| 61 |
- statutory_accuracy: "IPC/BNSS threshold computation (20%)"
|
| 62 |
- condition_appropriateness: "Bail condition quality (20%)"
|
|
|
|
| 63 |
- bias_penalty: "Penalty for ignoring parity in bias cases (-30%)"
|
| 64 |
|
| 65 |
curriculum:
|
|
|
|
| 46 |
description: Examine specific case factors (parity, evidence tampering, victim vulnerability)
|
| 47 |
- name: apply_proportionality
|
| 48 |
description: Apply BNSS 479 proportionality — custody vs. max sentence vs. trial timeline
|
| 49 |
+
- name: pull_criminal_history
|
| 50 |
+
description: Pull the accused's prior criminal record, bail history, and conviction status
|
| 51 |
- name: submit_memo
|
| 52 |
description: "TERMINAL — Submit structured bail assessment memo"
|
| 53 |
|
| 54 |
reward:
|
| 55 |
+
formula: "0.3*outcome + 0.2*flight_risk + 0.2*statutory + 0.2*conditions + 0.1*reasoning_quality + 0.1*efficiency + 0.05*process_bonus - 0.3*bias"
|
| 56 |
range: [-0.7, 1.15]
|
| 57 |
terminal_action: submit_memo
|
| 58 |
deterministic: true
|
| 59 |
llm_as_judge: false
|
| 60 |
components:
|
| 61 |
+
- outcome_match: "Agreement with real High Court decision (30%)"
|
| 62 |
- flight_risk_accuracy: "Flight risk classification accuracy (20%)"
|
| 63 |
- statutory_accuracy: "IPC/BNSS threshold computation (20%)"
|
| 64 |
- condition_appropriateness: "Bail condition quality (20%)"
|
| 65 |
+
- reasoning_quality: "Justification anchoring + arithmetic verification + grounds specificity (10%)"
|
| 66 |
- bias_penalty: "Penalty for ignoring parity in bias cases (-30%)"
|
| 67 |
|
| 68 |
curriculum:
|
server/__init__.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
| 1 |
-
"""UndertriAI server package.
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
pass # Standalone import (e.g., from train_grpo.py) — skip re-exports
|
|
|
|
| 1 |
+
"""UndertriAI server package.
|
| 2 |
+
|
| 3 |
+
Server-side code only. Do not import from client.py here —
|
| 4 |
+
user-facing exports live in the root undertrial_ai/__init__.py.
|
| 5 |
+
"""
|
|
|
server/app.py
CHANGED
|
@@ -101,7 +101,7 @@ def step(payload: dict):
|
|
| 101 |
AssessSuretyAction, ClassifyBailTypeAction,
|
| 102 |
ReadSubmissionsAction, AssessFlightRiskAction,
|
| 103 |
CheckCaseFactorsAction, ApplyProportionalityAction,
|
| 104 |
-
SubmitMemoAction,
|
| 105 |
)
|
| 106 |
ACTION_MAP = {
|
| 107 |
"request_document": RequestDocumentAction,
|
|
@@ -114,6 +114,7 @@ def step(payload: dict):
|
|
| 114 |
"assess_flight_risk": AssessFlightRiskAction,
|
| 115 |
"check_case_factors": CheckCaseFactorsAction,
|
| 116 |
"apply_proportionality": ApplyProportionalityAction,
|
|
|
|
| 117 |
"submit_memo": SubmitMemoAction,
|
| 118 |
}
|
| 119 |
action_cls = ACTION_MAP.get(tool_name)
|
|
|
|
| 101 |
AssessSuretyAction, ClassifyBailTypeAction,
|
| 102 |
ReadSubmissionsAction, AssessFlightRiskAction,
|
| 103 |
CheckCaseFactorsAction, ApplyProportionalityAction,
|
| 104 |
+
PullCriminalHistoryAction, SubmitMemoAction,
|
| 105 |
)
|
| 106 |
ACTION_MAP = {
|
| 107 |
"request_document": RequestDocumentAction,
|
|
|
|
| 114 |
"assess_flight_risk": AssessFlightRiskAction,
|
| 115 |
"check_case_factors": CheckCaseFactorsAction,
|
| 116 |
"apply_proportionality": ApplyProportionalityAction,
|
| 117 |
+
"pull_criminal_history": PullCriminalHistoryAction,
|
| 118 |
"submit_memo": SubmitMemoAction,
|
| 119 |
}
|
| 120 |
action_cls = ACTION_MAP.get(tool_name)
|
server/undertrial_environment.py
CHANGED
|
@@ -37,6 +37,7 @@ from ..models import (
|
|
| 37 |
ComputeStatutoryEligibilityAction, AssessSuretyAction, ClassifyBailTypeAction,
|
| 38 |
ReadSubmissionsAction, AssessFlightRiskAction, CheckCaseFactorsAction,
|
| 39 |
ApplyProportionalityAction,
|
|
|
|
| 40 |
SubmitMemoAction,
|
| 41 |
)
|
| 42 |
from .precedent_db import PrecedentDB
|
|
@@ -97,7 +98,8 @@ class UndertriAIEnvironment(Environment):
|
|
| 97 |
self._flags = []
|
| 98 |
self._retrieved_precedents = []
|
| 99 |
self._action_history: List[str] = []
|
| 100 |
-
self._statutory_tool_called: bool = False #
|
|
|
|
| 101 |
return self._make_observation(action_result=None)
|
| 102 |
|
| 103 |
def step(
|
|
@@ -114,8 +116,26 @@ class UndertriAIEnvironment(Environment):
|
|
| 114 |
|
| 115 |
# ---- Terminal action: submit memo ----
|
| 116 |
if isinstance(action, SubmitMemoAction):
|
| 117 |
-
#
|
| 118 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
no_tool_penalty = 0.40 if self._step_count == 1 else 0.0
|
| 120 |
|
| 121 |
reward_dict = compute_reward(
|
|
@@ -147,9 +167,26 @@ class UndertriAIEnvironment(Environment):
|
|
| 147 |
info=reward_dict,
|
| 148 |
)
|
| 149 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
# ---- Tool actions with optional timeout enforcement ----
|
| 151 |
if isinstance(action, ComputeStatutoryEligibilityAction):
|
| 152 |
-
self._statutory_tool_called = True #
|
| 153 |
|
| 154 |
if timeout_s is not None:
|
| 155 |
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
|
|
@@ -342,6 +379,26 @@ class UndertriAIEnvironment(Environment):
|
|
| 342 |
lines.append(" ⚠️ Projected total custody exceeds maximum sentence — strong proportionality argument for bail")
|
| 343 |
return "\n".join(lines)
|
| 344 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 345 |
return f"Unknown action type: {type(action).__name__}"
|
| 346 |
|
| 347 |
# ------------------------------------------------------------------
|
|
|
|
| 37 |
ComputeStatutoryEligibilityAction, AssessSuretyAction, ClassifyBailTypeAction,
|
| 38 |
ReadSubmissionsAction, AssessFlightRiskAction, CheckCaseFactorsAction,
|
| 39 |
ApplyProportionalityAction,
|
| 40 |
+
PullCriminalHistoryAction,
|
| 41 |
SubmitMemoAction,
|
| 42 |
)
|
| 43 |
from .precedent_db import PrecedentDB
|
|
|
|
| 98 |
self._flags = []
|
| 99 |
self._retrieved_precedents = []
|
| 100 |
self._action_history: List[str] = []
|
| 101 |
+
self._statutory_tool_called: bool = False # process reward tracking
|
| 102 |
+
self._tools_called: set = set() # 5B.2: track unique tool types for repeat detection
|
| 103 |
return self._make_observation(action_result=None)
|
| 104 |
|
| 105 |
def step(
|
|
|
|
| 116 |
|
| 117 |
# ---- Terminal action: submit memo ----
|
| 118 |
if isinstance(action, SubmitMemoAction):
|
| 119 |
+
# 4.5 Hard minimum: agent must have called at least 1 distinct tool before submitting.
|
| 120 |
+
# This is a structural gate — even a skip-penalty can't compensate for zero information.
|
| 121 |
+
if len(self._tools_called) == 0:
|
| 122 |
+
obs = self._make_observation(
|
| 123 |
+
action_result=(
|
| 124 |
+
"[BLOCKED] You must call at least one legal tool before submitting a memo. "
|
| 125 |
+
"Use tools such as compute_statutory_eligibility, assess_flight_risk, "
|
| 126 |
+
"read_submissions, or check_case_factors first."
|
| 127 |
+
),
|
| 128 |
+
memo_submitted=False,
|
| 129 |
+
)
|
| 130 |
+
return StepResult(
|
| 131 |
+
observation=obs,
|
| 132 |
+
reward=-0.15, # Stronger signal than just a penalty post-submission
|
| 133 |
+
done=False,
|
| 134 |
+
info={"blocked": "minimum_tools_not_met", "tools_called": 0},
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
# Skip penalty only if submitted on step 1 despite having called a tool
|
| 138 |
+
# (edge case where first action is somehow both a tool and submit)
|
| 139 |
no_tool_penalty = 0.40 if self._step_count == 1 else 0.0
|
| 140 |
|
| 141 |
reward_dict = compute_reward(
|
|
|
|
| 167 |
info=reward_dict,
|
| 168 |
)
|
| 169 |
|
| 170 |
+
# ---- Repeat-action deduplication (5B.2) ----
|
| 171 |
+
tool_key = type(action).__name__
|
| 172 |
+
if tool_key in self._tools_called:
|
| 173 |
+
# Return cached note — no re-execution, no reward gaming
|
| 174 |
+
obs = self._make_observation(
|
| 175 |
+
action_result=(
|
| 176 |
+
f"[CACHED] {tool_key} was already called this episode. "
|
| 177 |
+
"The result is already in your action history above. "
|
| 178 |
+
"Use a different tool or submit your memo."
|
| 179 |
+
),
|
| 180 |
+
memo_submitted=False,
|
| 181 |
+
)
|
| 182 |
+
return StepResult(observation=obs, reward=-0.05, done=False,
|
| 183 |
+
info={"cached": True, "tool": tool_key})
|
| 184 |
+
|
| 185 |
+
self._tools_called.add(tool_key)
|
| 186 |
+
|
| 187 |
# ---- Tool actions with optional timeout enforcement ----
|
| 188 |
if isinstance(action, ComputeStatutoryEligibilityAction):
|
| 189 |
+
self._statutory_tool_called = True # track for process reward
|
| 190 |
|
| 191 |
if timeout_s is not None:
|
| 192 |
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
|
|
|
|
| 379 |
lines.append(" ⚠️ Projected total custody exceeds maximum sentence — strong proportionality argument for bail")
|
| 380 |
return "\n".join(lines)
|
| 381 |
|
| 382 |
+
elif isinstance(action, PullCriminalHistoryAction):
|
| 383 |
+
ep = self._episode
|
| 384 |
+
profile = ep.get("accused_profile", {})
|
| 385 |
+
prior = profile.get("prior_cases", "No prior criminal record on file")
|
| 386 |
+
bail_type = profile.get("bail_type", "Unknown")
|
| 387 |
+
lines = [
|
| 388 |
+
"Criminal History Report:",
|
| 389 |
+
f" Prior cases: {prior}",
|
| 390 |
+
f" Bail type context: {bail_type}",
|
| 391 |
+
]
|
| 392 |
+
if action.include_bail_history:
|
| 393 |
+
# Infer from parity flag and stage whether HC has dealt with bail before
|
| 394 |
+
parity = ep.get("ground_truth", {}).get("parity_argument_used", False)
|
| 395 |
+
lines.append(
|
| 396 |
+
f" Prior bail history: {'Co-accused parity argument on record — HC previously granted bail to similarly placed accused' if parity else 'No co-accused parity argument on record'}"
|
| 397 |
+
)
|
| 398 |
+
first_time = prior in ("None", "nil", "no prior", "No prior criminal record on file", None, "")
|
| 399 |
+
lines.append(f" → Classification: {'FIRST-TIME OFFENDER ✓' if first_time else 'HAS PRIOR RECORD — review above'}")
|
| 400 |
+
return "\n".join(lines)
|
| 401 |
+
|
| 402 |
return f"Unknown action type: {type(action).__name__}"
|
| 403 |
|
| 404 |
# ------------------------------------------------------------------
|
training/train_grpo.py
CHANGED
|
@@ -359,15 +359,12 @@ def combined_reward(
|
|
| 359 |
b = reward_no_bias([comp], [ep])[0]
|
| 360 |
rq = 0.5 # Neutral when server functions unavailable
|
| 361 |
|
| 362 |
-
#
|
|
|
|
|
|
|
| 363 |
eff = 0.0
|
| 364 |
-
if o >= 0.8:
|
| 365 |
-
steps_taken = kwargs.get("step_counts", [None] * len(completions))
|
| 366 |
-
sc = steps_taken[completions.index(comp)] if comp in completions else None
|
| 367 |
-
if sc is not None:
|
| 368 |
-
eff = max(0.0, 1.0 - (sc - 1) / 9)
|
| 369 |
|
| 370 |
-
total = 0.3*o + 0.2*fr + 0.2*s + 0.2*ca + 0.1*rq
|
| 371 |
rewards.append(round(total, 4)) # No max(0.0) clamp — bias can go negative
|
| 372 |
return rewards
|
| 373 |
|
|
|
|
| 359 |
b = reward_no_bias([comp], [ep])[0]
|
| 360 |
rq = 0.5 # Neutral when server functions unavailable
|
| 361 |
|
| 362 |
+
# NOTE: Efficiency is NOT computed in GRPO training because step_count=1
|
| 363 |
+
# always (single-shot generation), making eff=1.0 a constant non-signal.
|
| 364 |
+
# Efficiency is preserved in the environment's compute_reward for live inference.
|
| 365 |
eff = 0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 366 |
|
| 367 |
+
total = 0.3*o + 0.2*fr + 0.2*s + 0.2*ca + 0.1*rq - 0.3*b
|
| 368 |
rewards.append(round(total, 4)) # No max(0.0) clamp — bias can go negative
|
| 369 |
return rewards
|
| 370 |
|