Spaces:
Sleeping
Sleeping
fix: correct field names in build_prompt to match Pydantic models
Browse files- inference.py +95 -137
inference.py
CHANGED
|
@@ -11,6 +11,7 @@ Usage:
|
|
| 11 |
export HF_TOKEN="your-token"
|
| 12 |
python inference.py
|
| 13 |
"""
|
|
|
|
| 14 |
from __future__ import annotations
|
| 15 |
|
| 16 |
import json
|
|
@@ -20,152 +21,115 @@ import sys
|
|
| 20 |
|
| 21 |
from openai import OpenAI
|
| 22 |
|
| 23 |
-
from env import InvoiceExceptionEnv,
|
| 24 |
|
| 25 |
# ---------------------------------------------------------------------------
|
| 26 |
-
# Configuration from environment variables
|
| 27 |
# ---------------------------------------------------------------------------
|
| 28 |
|
| 29 |
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
|
| 30 |
-
MODEL_NAME
|
| 31 |
-
HF_TOKEN
|
| 32 |
-
|
| 33 |
|
| 34 |
# ---------------------------------------------------------------------------
|
| 35 |
-
# System prompt
|
| 36 |
# ---------------------------------------------------------------------------
|
| 37 |
|
| 38 |
SYSTEM_PROMPT = """You are an expert Accounts Payable (AP) analyst handling flagged invoice exceptions.
|
| 39 |
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
You have access to a document packet: Purchase Order (PO), Invoice, Goods Receipt Note
|
| 46 |
-
(GRN), Supplier Master, and an Exception Flag explaining why this invoice was flagged.
|
| 47 |
-
The actual document values are provided in each prompt — use them to reason.
|
| 48 |
-
|
| 49 |
-
You must investigate the root cause, apply business rules, make a decision, and close the case.
|
| 50 |
-
|
| 51 |
-
**Your action space** (respond with exactly ONE JSON action per turn):
|
| 52 |
-
|
| 53 |
-
1. inspect_field: {"type": "inspect_field", "params": {"document": "invoice|po|grn|supplier_master", "field": "field_name"}}
|
| 54 |
-
2. cross_check: {"type": "cross_check", "params": {"field": "field_name", "doc_a": "doc1", "doc_b": "doc2"}}
|
| 55 |
-
3. run_check: {"type": "run_check", "params": {"check_name": "check_name"}}
|
| 56 |
-
4. query_supplier: {"type": "query_supplier", "params": {"question": "your question", "channel": "phone|email"}}
|
| 57 |
-
5. query_internal: {"type": "query_internal", "params": {"department": "dept_name", "question": "your question"}}
|
| 58 |
-
6. apply_rule: {"type": "apply_rule", "params": {"rule_id": "rule_id"}}
|
| 59 |
-
7. make_decision: {"type": "make_decision", "params": {"decision": "approve|reject|hold|partial_approve", "reason": "explanation"}}
|
| 60 |
-
8. route_to: {"type": "route_to", "params": {"team": "team_name", "notes": "routing notes"}}
|
| 61 |
-
9. close_case: {"type": "close_case", "params": {"summary": "audit trail summary"}}
|
| 62 |
-
|
| 63 |
-
**Rules:**
|
| 64 |
-
- Always investigate before making a decision
|
| 65 |
-
- Never approve without running checks first
|
| 66 |
-
- Compare document values carefully — look for mismatches between PO, Invoice, GRN, and Supplier Master
|
| 67 |
-
- If bank account or email domain looks suspicious, use phone channel for supplier queries
|
| 68 |
-
- Respond with ONLY a JSON object, no extra text
|
| 69 |
-
"""
|
| 70 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
|
| 72 |
# ---------------------------------------------------------------------------
|
| 73 |
-
# Prompt builder
|
| 74 |
# ---------------------------------------------------------------------------
|
| 75 |
|
| 76 |
def build_prompt(obs, step: int, max_steps: int, history: list) -> str:
|
| 77 |
-
"""Build the user prompt from the current observation state
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
grn_pending = sum(item.get("quantity_pending", 0) for item in grn_items)
|
| 83 |
-
grn_details = "; ".join(
|
| 84 |
-
f"{item.get('description', 'item')}: {item.get('quantity_received', '?')} received, {item.get('quantity_pending', 0)} pending"
|
| 85 |
-
for item in grn_items
|
| 86 |
-
)
|
| 87 |
|
| 88 |
lines = [
|
| 89 |
f"Step {step} of {max_steps}.",
|
| 90 |
-
|
| 91 |
f"EXCEPTION FLAG: {obs.exception_flag.flag_code}",
|
| 92 |
f"{obs.exception_flag.flag_description}",
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
f"PO #{
|
| 96 |
-
f"PO
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
]
|
| 98 |
-
|
| 99 |
-
lines.append(f" - {item.description}: qty={item.quantity}, unit_price=INR {item.unit_price:,.2f}, total=INR {item.total:,.2f}")
|
| 100 |
-
|
| 101 |
-
lines.extend([
|
| 102 |
-
f"",
|
| 103 |
-
f"Invoice #{obs.invoice.invoice_number} | Date: {obs.invoice.invoice_date} | Total: INR {obs.invoice.total_amount:,.2f}",
|
| 104 |
-
f"Invoice Subtotal: INR {obs.invoice.subtotal:,.2f} | Tax ({obs.invoice.tax_rate}%): INR {obs.invoice.tax_amount:,.2f}",
|
| 105 |
-
f"Invoice Bank Account: {obs.invoice.bank_account} ({obs.invoice.bank_name})",
|
| 106 |
-
f"Invoice GSTIN: {obs.invoice.supplier_gstin}",
|
| 107 |
-
f"Invoice Email: {obs.invoice.supplier_email}",
|
| 108 |
-
f"Invoice Line Items:",
|
| 109 |
-
])
|
| 110 |
-
for item in obs.invoice.line_items:
|
| 111 |
-
lines.append(f" - {item.description}: qty={item.quantity}, unit_price=INR {item.unit_price:,.2f}, total=INR {item.total:,.2f}")
|
| 112 |
-
|
| 113 |
-
lines.extend([
|
| 114 |
-
f"",
|
| 115 |
-
f"GRN #{obs.grn.grn_number} | Date: {obs.grn.receipt_date}",
|
| 116 |
-
f"GRN Items: {grn_details}",
|
| 117 |
-
f"GRN Total received: {grn_received}, pending: {grn_pending}",
|
| 118 |
-
f"",
|
| 119 |
-
f"Supplier Master: {obs.supplier_master.supplier_name} ({obs.supplier_master.supplier_id})",
|
| 120 |
-
f"Supplier Bank Account: {obs.supplier_master.bank_account} ({obs.supplier_master.bank_name})",
|
| 121 |
-
f"Supplier GSTIN: {obs.supplier_master.gstin}",
|
| 122 |
-
f"Supplier Email Domain: {obs.supplier_master.registered_domain}",
|
| 123 |
-
f"Supplier Phone: {obs.supplier_master.contact_phone}",
|
| 124 |
-
f"",
|
| 125 |
-
f"=== AVAILABLE ACTIONS ===",
|
| 126 |
-
f"Available checks: {', '.join(obs.available_checks)}",
|
| 127 |
-
f"Available rules: {', '.join(obs.available_rules)}",
|
| 128 |
-
f"",
|
| 129 |
-
f"Knowledge base:",
|
| 130 |
-
])
|
| 131 |
for entry in obs.knowledge_base:
|
| 132 |
lines.append(f" - {entry}")
|
| 133 |
|
| 134 |
lines.append("")
|
| 135 |
-
lines.append(f"Cumulative reward
|
| 136 |
-
lines.append(f"Case status: {obs.case_status}")
|
| 137 |
|
| 138 |
if obs.checks_run:
|
| 139 |
-
lines.append(f"Checks already run:")
|
| 140 |
-
for c in obs.checks_run:
|
| 141 |
-
lines.append(f" - {c.check_name}: {'PASSED' if c.passed else 'FAILED'} — {c.detail[:100]}")
|
| 142 |
if obs.queries:
|
| 143 |
-
lines.append(f"Queries made:")
|
| 144 |
-
for q in obs.queries:
|
| 145 |
-
lines.append(f" - {q.target} (via {q.channel}): {q.response[:100]}...")
|
| 146 |
if obs.inspections:
|
| 147 |
-
lines.append(f"Fields inspected:")
|
| 148 |
-
for i in obs.inspections:
|
| 149 |
-
lines.append(f" - {i.document}.{i.field}: {str(i.value)[:100]}")
|
| 150 |
if obs.rules_applied:
|
| 151 |
-
lines.append(f"Rules applied: {', '.join(obs.rules_applied)}")
|
| 152 |
if obs.decision:
|
| 153 |
-
lines.append(f"Decision made: {obs.decision}")
|
| 154 |
if obs.routed_to:
|
| 155 |
-
lines.append(f"
|
| 156 |
|
| 157 |
if history:
|
| 158 |
lines.append("")
|
| 159 |
-
lines.append("Recent
|
| 160 |
for h in history[-5:]:
|
| 161 |
lines.append(f" {h}")
|
| 162 |
|
| 163 |
lines.append("")
|
| 164 |
-
lines.append("What is your next action? Respond with a single JSON object.")
|
| 165 |
|
| 166 |
return "\n".join(lines)
|
| 167 |
|
| 168 |
-
|
| 169 |
# ---------------------------------------------------------------------------
|
| 170 |
# LLM caller
|
| 171 |
# ---------------------------------------------------------------------------
|
|
@@ -177,7 +141,7 @@ def call_llm(client: OpenAI, user_prompt: str) -> str:
|
|
| 177 |
model=MODEL_NAME,
|
| 178 |
messages=[
|
| 179 |
{"role": "system", "content": SYSTEM_PROMPT},
|
| 180 |
-
{"role": "user",
|
| 181 |
],
|
| 182 |
temperature=0.1,
|
| 183 |
max_tokens=256,
|
|
@@ -187,30 +151,28 @@ def call_llm(client: OpenAI, user_prompt: str) -> str:
|
|
| 187 |
print(f"LLM call failed: {e}", file=sys.stderr)
|
| 188 |
return '{"type": "run_check", "params": {"check_name": "po_match"}}'
|
| 189 |
|
| 190 |
-
|
| 191 |
# ---------------------------------------------------------------------------
|
| 192 |
# Action parser
|
| 193 |
# ---------------------------------------------------------------------------
|
| 194 |
|
| 195 |
def parse_action(raw_text: str) -> dict:
|
| 196 |
"""
|
| 197 |
-
Parse the model
|
| 198 |
-
|
| 199 |
-
Falls back to run_check(po_match) if parsing fails.
|
| 200 |
"""
|
| 201 |
text = raw_text.strip()
|
| 202 |
|
| 203 |
-
#
|
| 204 |
if text.startswith("```"):
|
| 205 |
-
|
| 206 |
-
text = "\n".join(
|
| 207 |
|
| 208 |
try:
|
| 209 |
return json.loads(text.strip())
|
| 210 |
except json.JSONDecodeError:
|
| 211 |
pass
|
| 212 |
|
| 213 |
-
# Try to find JSON
|
| 214 |
match = re.search(r'\{.*\}', text, re.DOTALL)
|
| 215 |
if match:
|
| 216 |
try:
|
|
@@ -218,42 +180,38 @@ def parse_action(raw_text: str) -> dict:
|
|
| 218 |
except json.JSONDecodeError:
|
| 219 |
pass
|
| 220 |
|
| 221 |
-
# Safe fallback
|
| 222 |
return {"type": "run_check", "params": {"check_name": "po_match"}}
|
| 223 |
|
| 224 |
-
|
| 225 |
# ---------------------------------------------------------------------------
|
| 226 |
-
# Task runner
|
| 227 |
# ---------------------------------------------------------------------------
|
| 228 |
|
| 229 |
def run_task(client: OpenAI, env: InvoiceExceptionEnv, task_id: str) -> tuple:
|
| 230 |
-
"""Run one task episode
|
| 231 |
-
rewards = []
|
| 232 |
|
| 233 |
print(f"[START] task={task_id} env=invoice-exception-handler model={MODEL_NAME}", flush=True)
|
| 234 |
|
| 235 |
obs = env.reset(task_id)
|
| 236 |
-
max_steps = env._task.max_steps
|
| 237 |
-
history = []
|
| 238 |
|
| 239 |
for step in range(1, max_steps + 1):
|
| 240 |
-
# Build prompt from observation
|
| 241 |
user_prompt = build_prompt(obs, step, max_steps, history)
|
| 242 |
|
| 243 |
-
|
| 244 |
-
raw = call_llm(client, user_prompt)
|
| 245 |
action_dict = parse_action(raw)
|
| 246 |
|
| 247 |
-
# Execute
|
| 248 |
try:
|
| 249 |
result = env.step(action_dict)
|
| 250 |
reward = result.reward
|
| 251 |
-
done
|
| 252 |
-
error
|
| 253 |
-
except Exception as
|
| 254 |
reward = 0.0
|
| 255 |
-
done
|
| 256 |
-
error
|
| 257 |
result = None
|
| 258 |
|
| 259 |
rewards.append(reward)
|
|
@@ -268,14 +226,14 @@ def run_task(client: OpenAI, env: InvoiceExceptionEnv, task_id: str) -> tuple:
|
|
| 268 |
|
| 269 |
history.append(f"Step {step}: {action_str} -> reward {reward:+.2f}")
|
| 270 |
|
| 271 |
-
if result:
|
| 272 |
obs = result.observation
|
| 273 |
|
| 274 |
if done:
|
| 275 |
break
|
| 276 |
|
| 277 |
-
score
|
| 278 |
-
success
|
| 279 |
steps_taken = min(step, max_steps)
|
| 280 |
rewards_str = ",".join(f"{r:.2f}" for r in rewards)
|
| 281 |
|
|
@@ -287,24 +245,24 @@ def run_task(client: OpenAI, env: InvoiceExceptionEnv, task_id: str) -> tuple:
|
|
| 287 |
|
| 288 |
return steps_taken, score, rewards
|
| 289 |
|
| 290 |
-
|
| 291 |
# ---------------------------------------------------------------------------
|
| 292 |
-
# Main
|
| 293 |
# ---------------------------------------------------------------------------
|
| 294 |
|
| 295 |
def main() -> None:
|
| 296 |
-
"""
|
| 297 |
client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
|
| 298 |
-
env
|
|
|
|
|
|
|
| 299 |
|
| 300 |
-
all_scores = []
|
| 301 |
for task_id in ALL_TASKS:
|
| 302 |
_, score, _ = run_task(client, env, task_id)
|
| 303 |
all_scores.append(score)
|
| 304 |
|
| 305 |
avg = sum(all_scores) / len(all_scores) if all_scores else 0.0
|
| 306 |
-
print(f"\nAverage score: {avg:.3f}", flush=True)
|
| 307 |
|
| 308 |
|
| 309 |
if __name__ == "__main__":
|
| 310 |
-
main()
|
|
|
|
| 11 |
export HF_TOKEN="your-token"
|
| 12 |
python inference.py
|
| 13 |
"""
|
| 14 |
+
|
| 15 |
from __future__ import annotations
|
| 16 |
|
| 17 |
import json
|
|
|
|
| 21 |
|
| 22 |
from openai import OpenAI
|
| 23 |
|
| 24 |
+
from env import InvoiceExceptionEnv, ALL_TASKS
|
| 25 |
|
| 26 |
# ---------------------------------------------------------------------------
|
| 27 |
+
# Configuration — read from environment variables exactly as the spec requires
|
| 28 |
# ---------------------------------------------------------------------------
|
| 29 |
|
| 30 |
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
|
| 31 |
+
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
|
| 32 |
+
HF_TOKEN = os.getenv("HF_TOKEN") # no default — spec requirement
|
|
|
|
| 33 |
|
| 34 |
# ---------------------------------------------------------------------------
|
| 35 |
+
# System prompt
|
| 36 |
# ---------------------------------------------------------------------------
|
| 37 |
|
| 38 |
SYSTEM_PROMPT = """You are an expert Accounts Payable (AP) analyst handling flagged invoice exceptions.
|
| 39 |
|
| 40 |
+
You receive a full document packet: Purchase Order (PO), Invoice, Goods Receipt Note (GRN),
|
| 41 |
+
Supplier Master record, and an Exception Flag explaining why the invoice was flagged.
|
| 42 |
+
|
| 43 |
+
Your job: investigate the root cause, apply business rules, make a decision, and close the case.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
+
CRITICAL RULE: If there is ANY suspicion of bank account fraud or BEC attack, contact the
|
| 46 |
+
supplier via PHONE only — never via email. Emailing may reach the fraudster.
|
| 47 |
+
|
| 48 |
+
Your action space — respond with exactly ONE JSON object per turn:
|
| 49 |
+
|
| 50 |
+
1. {"type": "inspect_field", "params": {"document": "invoice|po|grn|supplier_master", "field": "field_name"}}
|
| 51 |
+
2. {"type": "cross_check", "params": {"field": "field_name", "doc_a": "doc1", "doc_b": "doc2"}}
|
| 52 |
+
3. {"type": "run_check", "params": {"check_name": "check_name"}}
|
| 53 |
+
4. {"type": "query_supplier", "params": {"question": "your question", "channel": "phone|email"}}
|
| 54 |
+
5. {"type": "query_internal", "params": {"department": "dept_name", "question": "your question"}}
|
| 55 |
+
6. {"type": "apply_rule", "params": {"rule_id": "rule_id"}}
|
| 56 |
+
7. {"type": "make_decision", "params": {"decision": "approve|reject|hold|partial_approve", "reason": "explanation"}}
|
| 57 |
+
8. {"type": "route_to", "params": {"team": "team_name", "notes": "routing notes"}}
|
| 58 |
+
9. {"type": "close_case", "params": {"summary": "audit trail summary"}}
|
| 59 |
+
|
| 60 |
+
Rules:
|
| 61 |
+
- Always run checks BEFORE making a decision
|
| 62 |
+
- Never approve without verifying the root cause
|
| 63 |
+
- Use phone (not email) if fraud is suspected
|
| 64 |
+
- Respond with ONLY a JSON object, no explanation, no markdown fences
|
| 65 |
+
"""
|
| 66 |
|
| 67 |
# ---------------------------------------------------------------------------
|
| 68 |
+
# Prompt builder — shows the LLM the actual document data
|
| 69 |
# ---------------------------------------------------------------------------
|
| 70 |
|
| 71 |
def _summarize_line_items(items) -> list:
    """Return compact (description, qty, unit price) tuples for a list of line items.

    Descriptions are truncated to 30 characters to keep the prompt short.
    """
    return [
        (item.description[:30], "qty=" + str(item.quantity), "unit=" + str(item.unit_price))
        for item in items
    ]


def _grn_totals(items_received) -> tuple:
    """Return (total received, total pending) over the GRN item dicts.

    A missing key and an explicit ``None`` value both count as 0, so a
    partially filled GRN entry cannot break the sum.
    """
    received = 0
    pending = 0
    for item in items_received:
        received += item.get("quantity_received") or 0
        pending += item.get("quantity_pending") or 0
    return received, pending


def build_prompt(obs, step: int, max_steps: int, history: list) -> str:
    """Build the user prompt from the current observation state.

    Args:
        obs: Observation exposing ``purchase_order``, ``invoice``, ``grn``,
            ``supplier_master``, ``exception_flag``, ``available_checks``,
            ``available_rules``, ``knowledge_base``, ``cumulative_reward``,
            ``case_status`` and the progress collections read below.
            NOTE(review): attribute names are taken from this function's own
            accesses; confirm them against the model definitions.
        step: 1-based index of the step about to be taken.
        max_steps: Step budget for the episode.
        history: Human-readable summaries of earlier steps; only the last
            five are included.

    Returns:
        The prompt text, one section per line, joined with newlines.
    """
    po = obs.purchase_order
    inv = obs.invoice
    grn = obs.grn
    sm = obs.supplier_master

    # Single pass over the GRN items instead of two sums inside an f-string.
    grn_received, grn_pending = _grn_totals(grn.items_received)

    lines = [
        f"Step {step} of {max_steps}.",
        "",
        f"EXCEPTION FLAG: {obs.exception_flag.flag_code}",
        f"{obs.exception_flag.flag_description}",
        "",
        "=== DOCUMENT DATA ===",
        f"PO #{po.po_number} | Supplier: {po.vendor_name} | Total: {po.total_amount} | Terms: {po.payment_terms}",
        f"PO lines: {_summarize_line_items(po.line_items)}",
        "",
        f"Invoice #{inv.invoice_number} | Date: {inv.invoice_date} | Subtotal: {inv.subtotal} | Tax: {inv.tax_amount} | Total: {inv.total_amount}",
        f"Invoice GSTIN: {inv.supplier_gstin} | Bank: {inv.bank_account} {inv.ifsc_code}",
        f"Invoice lines: {_summarize_line_items(inv.line_items)}",
        "",
        f"GRN: received={grn_received} units | pending={grn_pending} units",
        "",
        f"Supplier Master: GSTIN={sm.gstin} | Bank={sm.bank_account} {sm.ifsc_code} | Domain={sm.registered_domain}",
        "",
        "=== AVAILABLE ACTIONS ===",
        f"Checks you can run: {', '.join(obs.available_checks)}",
        f"Rules you can apply: {', '.join(obs.available_rules)}",
        "",
        "Knowledge base (company policies):",
    ]

    for entry in obs.knowledge_base:
        lines.append(f" - {entry}")

    lines.append("")
    lines.append(f"Cumulative reward: {obs.cumulative_reward:.2f} | Status: {obs.case_status}")

    # Progress so far, so the model does not repeat work it has already done.
    if obs.checks_run:
        lines.append(f"Checks already run: {', '.join(c.check_name for c in obs.checks_run)}")
    if obs.queries:
        lines.append(f"Queries already made: {', '.join(q.target for q in obs.queries)}")
    if obs.inspections:
        lines.append(f"Fields already inspected: {', '.join(f'{i.document}.{i.field}' for i in obs.inspections)}")
    if obs.rules_applied:
        lines.append(f"Rules already applied: {', '.join(obs.rules_applied)}")
    if obs.decision:
        lines.append(f"Decision already made: {obs.decision}")
    if obs.routed_to:
        routed = obs.routed_to
        # A bare string would otherwise be joined character by character.
        if isinstance(routed, str):
            routed = [routed]
        lines.append(f"Already routed to: {', '.join(routed)}")

    if history:
        lines.append("")
        lines.append("Recent steps:")
        for h in history[-5:]:
            lines.append(f" {h}")

    lines.append("")
    lines.append("What is your next action? Respond with a single JSON object only.")

    return "\n".join(lines)
|
| 132 |
|
|
|
|
| 133 |
# ---------------------------------------------------------------------------
|
| 134 |
# LLM caller
|
| 135 |
# ---------------------------------------------------------------------------
|
|
|
|
| 141 |
model=MODEL_NAME,
|
| 142 |
messages=[
|
| 143 |
{"role": "system", "content": SYSTEM_PROMPT},
|
| 144 |
+
{"role": "user", "content": user_prompt},
|
| 145 |
],
|
| 146 |
temperature=0.1,
|
| 147 |
max_tokens=256,
|
|
|
|
| 151 |
print(f"LLM call failed: {e}", file=sys.stderr)
|
| 152 |
return '{"type": "run_check", "params": {"check_name": "po_match"}}'
|
| 153 |
|
|
|
|
| 154 |
# ---------------------------------------------------------------------------
|
| 155 |
# Action parser
|
| 156 |
# ---------------------------------------------------------------------------
|
| 157 |
|
| 158 |
def parse_action(raw_text: str) -> dict:
|
| 159 |
"""
|
| 160 |
+
Parse the model response into an action dict.
|
| 161 |
+
Strips markdown fences, handles whitespace, falls back on parse failure.
|
|
|
|
| 162 |
"""
|
| 163 |
text = raw_text.strip()
|
| 164 |
|
| 165 |
+
# Strip ```json ... ``` or ``` ... ``` fences
|
| 166 |
if text.startswith("```"):
|
| 167 |
+
parts = text.split("\n")
|
| 168 |
+
text = "\n".join(parts[1:-1] if parts[-1].strip() == "```" else parts[1:])
|
| 169 |
|
| 170 |
try:
|
| 171 |
return json.loads(text.strip())
|
| 172 |
except json.JSONDecodeError:
|
| 173 |
pass
|
| 174 |
|
| 175 |
+
# Try to find JSON anywhere in the text
|
| 176 |
match = re.search(r'\{.*\}', text, re.DOTALL)
|
| 177 |
if match:
|
| 178 |
try:
|
|
|
|
| 180 |
except json.JSONDecodeError:
|
| 181 |
pass
|
| 182 |
|
| 183 |
+
# Safe fallback — never crash
|
| 184 |
return {"type": "run_check", "params": {"check_name": "po_match"}}
|
| 185 |
|
|
|
|
| 186 |
# ---------------------------------------------------------------------------
|
| 187 |
+
# Task runner — one full episode
|
| 188 |
# ---------------------------------------------------------------------------
|
| 189 |
|
| 190 |
def run_task(client: OpenAI, env: InvoiceExceptionEnv, task_id: str) -> tuple:
|
| 191 |
+
"""Run one task episode. Returns (steps_taken, score, rewards)."""
|
| 192 |
+
rewards: list[float] = []
|
| 193 |
|
| 194 |
print(f"[START] task={task_id} env=invoice-exception-handler model={MODEL_NAME}", flush=True)
|
| 195 |
|
| 196 |
obs = env.reset(task_id)
|
| 197 |
+
max_steps = env._task.max_steps # reads the correct limit per task: 18 / 20 / 25
|
| 198 |
+
history: list[str] = []
|
| 199 |
|
| 200 |
for step in range(1, max_steps + 1):
|
|
|
|
| 201 |
user_prompt = build_prompt(obs, step, max_steps, history)
|
| 202 |
|
| 203 |
+
raw = call_llm(client, user_prompt)
|
|
|
|
| 204 |
action_dict = parse_action(raw)
|
| 205 |
|
|
|
|
| 206 |
try:
|
| 207 |
result = env.step(action_dict)
|
| 208 |
reward = result.reward
|
| 209 |
+
done = result.done
|
| 210 |
+
error = None
|
| 211 |
+
except Exception as exc:
|
| 212 |
reward = 0.0
|
| 213 |
+
done = False
|
| 214 |
+
error = str(exc)
|
| 215 |
result = None
|
| 216 |
|
| 217 |
rewards.append(reward)
|
|
|
|
| 226 |
|
| 227 |
history.append(f"Step {step}: {action_str} -> reward {reward:+.2f}")
|
| 228 |
|
| 229 |
+
if result is not None:
|
| 230 |
obs = result.observation
|
| 231 |
|
| 232 |
if done:
|
| 233 |
break
|
| 234 |
|
| 235 |
+
score = env.grade()["score"]
|
| 236 |
+
success = score >= 0.5
|
| 237 |
steps_taken = min(step, max_steps)
|
| 238 |
rewards_str = ",".join(f"{r:.2f}" for r in rewards)
|
| 239 |
|
|
|
|
| 245 |
|
| 246 |
return steps_taken, score, rewards
|
| 247 |
|
|
|
|
| 248 |
# ---------------------------------------------------------------------------
|
| 249 |
+
# Main — run all three tasks in sequence
|
| 250 |
# ---------------------------------------------------------------------------
|
| 251 |
|
| 252 |
def main() -> None:
    """Entry point — runs inference on all tasks and prints average score.

    Creates one OpenAI-compatible client and one environment (seed 42), runs
    every task in ``ALL_TASKS`` through ``run_task``, and prints the mean of
    the returned scores (0.0 when there are no tasks).

    Raises:
        SystemExit: If ``HF_TOKEN`` is unset or empty.
    """
    # HF_TOKEN deliberately has no default. Without this guard the OpenAI
    # client would fall back to OPENAI_API_KEY (sending an unrelated credential
    # to API_BASE_URL) or fail with an SDK error that never mentions HF_TOKEN.
    if not HF_TOKEN:
        raise SystemExit(
            "HF_TOKEN environment variable is not set; export it before running inference.py"
        )

    client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
    env = InvoiceExceptionEnv(seed=42)

    all_scores: list[float] = []

    for task_id in ALL_TASKS:
        _, score, _ = run_task(client, env, task_id)
        all_scores.append(score)

    avg = sum(all_scores) / len(all_scores) if all_scores else 0.0
    print(f"\nAverage score across all tasks: {avg:.3f}", flush=True)


if __name__ == "__main__":
    main()
|