YUS200619 commited on
Commit
6ed2433
·
1 Parent(s): d8c9b01

fix: correct field names in build_prompt to match Pydantic models

Browse files
Files changed (1) hide show
  1. inference.py +95 -137
inference.py CHANGED
@@ -11,6 +11,7 @@ Usage:
11
  export HF_TOKEN="your-token"
12
  python inference.py
13
  """
 
14
  from __future__ import annotations
15
 
16
  import json
@@ -20,152 +21,115 @@ import sys
20
 
21
  from openai import OpenAI
22
 
23
- from env import InvoiceExceptionEnv, Action, ALL_TASKS
24
 
25
  # ---------------------------------------------------------------------------
26
- # Configuration from environment variables
27
  # ---------------------------------------------------------------------------
28
 
29
  API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
30
- MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
31
- HF_TOKEN = os.getenv("HF_TOKEN")
32
-
33
 
34
  # ---------------------------------------------------------------------------
35
- # System prompt — tells the LLM how to act
36
  # ---------------------------------------------------------------------------
37
 
38
  SYSTEM_PROMPT = """You are an expert Accounts Payable (AP) analyst handling flagged invoice exceptions.
39
 
40
- ⚠️ CRITICAL RULE: If there is ANY suspicion of bank account fraud, BEC attack, or
41
- supplier impersonation, you MUST contact the supplier via PHONE (channel="phone"),
42
- NEVER via email. Emailing a potentially compromised account will contact the fraudster
43
- and incur a severe penalty.
44
-
45
- You have access to a document packet: Purchase Order (PO), Invoice, Goods Receipt Note
46
- (GRN), Supplier Master, and an Exception Flag explaining why this invoice was flagged.
47
- The actual document values are provided in each prompt — use them to reason.
48
-
49
- You must investigate the root cause, apply business rules, make a decision, and close the case.
50
-
51
- **Your action space** (respond with exactly ONE JSON action per turn):
52
-
53
- 1. inspect_field: {"type": "inspect_field", "params": {"document": "invoice|po|grn|supplier_master", "field": "field_name"}}
54
- 2. cross_check: {"type": "cross_check", "params": {"field": "field_name", "doc_a": "doc1", "doc_b": "doc2"}}
55
- 3. run_check: {"type": "run_check", "params": {"check_name": "check_name"}}
56
- 4. query_supplier: {"type": "query_supplier", "params": {"question": "your question", "channel": "phone|email"}}
57
- 5. query_internal: {"type": "query_internal", "params": {"department": "dept_name", "question": "your question"}}
58
- 6. apply_rule: {"type": "apply_rule", "params": {"rule_id": "rule_id"}}
59
- 7. make_decision: {"type": "make_decision", "params": {"decision": "approve|reject|hold|partial_approve", "reason": "explanation"}}
60
- 8. route_to: {"type": "route_to", "params": {"team": "team_name", "notes": "routing notes"}}
61
- 9. close_case: {"type": "close_case", "params": {"summary": "audit trail summary"}}
62
-
63
- **Rules:**
64
- - Always investigate before making a decision
65
- - Never approve without running checks first
66
- - Compare document values carefully — look for mismatches between PO, Invoice, GRN, and Supplier Master
67
- - If bank account or email domain looks suspicious, use phone channel for supplier queries
68
- - Respond with ONLY a JSON object, no extra text
69
- """
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
  # ---------------------------------------------------------------------------
73
- # Prompt builder
74
  # ---------------------------------------------------------------------------
75
 
76
  def build_prompt(obs, step: int, max_steps: int, history: list) -> str:
77
- """Build the user prompt from the current observation state, including document data."""
78
-
79
- # Build GRN summary safely from the dict-based items_received
80
- grn_items = obs.grn.items_received
81
- grn_received = sum(item.get("quantity_received", 0) for item in grn_items)
82
- grn_pending = sum(item.get("quantity_pending", 0) for item in grn_items)
83
- grn_details = "; ".join(
84
- f"{item.get('description', 'item')}: {item.get('quantity_received', '?')} received, {item.get('quantity_pending', 0)} pending"
85
- for item in grn_items
86
- )
87
 
88
  lines = [
89
  f"Step {step} of {max_steps}.",
90
- f"",
91
  f"EXCEPTION FLAG: {obs.exception_flag.flag_code}",
92
  f"{obs.exception_flag.flag_description}",
93
- f"",
94
- f"=== DOCUMENT SUMMARY ===",
95
- f"PO #{obs.purchase_order.po_number} | Total: INR {obs.purchase_order.total_amount:,.2f} | Terms: {obs.purchase_order.payment_terms}",
96
- f"PO Line Items:",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  ]
98
- for item in obs.purchase_order.line_items:
99
- lines.append(f" - {item.description}: qty={item.quantity}, unit_price=INR {item.unit_price:,.2f}, total=INR {item.total:,.2f}")
100
-
101
- lines.extend([
102
- f"",
103
- f"Invoice #{obs.invoice.invoice_number} | Date: {obs.invoice.invoice_date} | Total: INR {obs.invoice.total_amount:,.2f}",
104
- f"Invoice Subtotal: INR {obs.invoice.subtotal:,.2f} | Tax ({obs.invoice.tax_rate}%): INR {obs.invoice.tax_amount:,.2f}",
105
- f"Invoice Bank Account: {obs.invoice.bank_account} ({obs.invoice.bank_name})",
106
- f"Invoice GSTIN: {obs.invoice.supplier_gstin}",
107
- f"Invoice Email: {obs.invoice.supplier_email}",
108
- f"Invoice Line Items:",
109
- ])
110
- for item in obs.invoice.line_items:
111
- lines.append(f" - {item.description}: qty={item.quantity}, unit_price=INR {item.unit_price:,.2f}, total=INR {item.total:,.2f}")
112
-
113
- lines.extend([
114
- f"",
115
- f"GRN #{obs.grn.grn_number} | Date: {obs.grn.receipt_date}",
116
- f"GRN Items: {grn_details}",
117
- f"GRN Total received: {grn_received}, pending: {grn_pending}",
118
- f"",
119
- f"Supplier Master: {obs.supplier_master.supplier_name} ({obs.supplier_master.supplier_id})",
120
- f"Supplier Bank Account: {obs.supplier_master.bank_account} ({obs.supplier_master.bank_name})",
121
- f"Supplier GSTIN: {obs.supplier_master.gstin}",
122
- f"Supplier Email Domain: {obs.supplier_master.registered_domain}",
123
- f"Supplier Phone: {obs.supplier_master.contact_phone}",
124
- f"",
125
- f"=== AVAILABLE ACTIONS ===",
126
- f"Available checks: {', '.join(obs.available_checks)}",
127
- f"Available rules: {', '.join(obs.available_rules)}",
128
- f"",
129
- f"Knowledge base:",
130
- ])
131
  for entry in obs.knowledge_base:
132
  lines.append(f" - {entry}")
133
 
134
  lines.append("")
135
- lines.append(f"Cumulative reward so far: {obs.cumulative_reward:.2f}")
136
- lines.append(f"Case status: {obs.case_status}")
137
 
138
  if obs.checks_run:
139
- lines.append(f"Checks already run:")
140
- for c in obs.checks_run:
141
- lines.append(f" - {c.check_name}: {'PASSED' if c.passed else 'FAILED'} — {c.detail[:100]}")
142
  if obs.queries:
143
- lines.append(f"Queries made:")
144
- for q in obs.queries:
145
- lines.append(f" - {q.target} (via {q.channel}): {q.response[:100]}...")
146
  if obs.inspections:
147
- lines.append(f"Fields inspected:")
148
- for i in obs.inspections:
149
- lines.append(f" - {i.document}.{i.field}: {str(i.value)[:100]}")
150
  if obs.rules_applied:
151
- lines.append(f"Rules applied: {', '.join(obs.rules_applied)}")
152
  if obs.decision:
153
- lines.append(f"Decision made: {obs.decision}")
154
  if obs.routed_to:
155
- lines.append(f"Routed to: {', '.join(obs.routed_to)}")
156
 
157
  if history:
158
  lines.append("")
159
- lines.append("Recent history:")
160
  for h in history[-5:]:
161
  lines.append(f" {h}")
162
 
163
  lines.append("")
164
- lines.append("What is your next action? Respond with a single JSON object.")
165
 
166
  return "\n".join(lines)
167
 
168
-
169
  # ---------------------------------------------------------------------------
170
  # LLM caller
171
  # ---------------------------------------------------------------------------
@@ -177,7 +141,7 @@ def call_llm(client: OpenAI, user_prompt: str) -> str:
177
  model=MODEL_NAME,
178
  messages=[
179
  {"role": "system", "content": SYSTEM_PROMPT},
180
- {"role": "user", "content": user_prompt},
181
  ],
182
  temperature=0.1,
183
  max_tokens=256,
@@ -187,30 +151,28 @@ def call_llm(client: OpenAI, user_prompt: str) -> str:
187
  print(f"LLM call failed: {e}", file=sys.stderr)
188
  return '{"type": "run_check", "params": {"check_name": "po_match"}}'
189
 
190
-
191
  # ---------------------------------------------------------------------------
192
  # Action parser
193
  # ---------------------------------------------------------------------------
194
 
195
  def parse_action(raw_text: str) -> dict:
196
  """
197
- Parse the model's response into an action dict.
198
- Handles markdown code fences, extra whitespace, and minor formatting errors.
199
- Falls back to run_check(po_match) if parsing fails.
200
  """
201
  text = raw_text.strip()
202
 
203
- # Remove ```json or ``` fences if present
204
  if text.startswith("```"):
205
- lines = text.split("\n")
206
- text = "\n".join(lines[1:-1] if lines[-1].strip() == "```" else lines[1:])
207
 
208
  try:
209
  return json.loads(text.strip())
210
  except json.JSONDecodeError:
211
  pass
212
 
213
- # Try to find JSON within the text
214
  match = re.search(r'\{.*\}', text, re.DOTALL)
215
  if match:
216
  try:
@@ -218,42 +180,38 @@ def parse_action(raw_text: str) -> dict:
218
  except json.JSONDecodeError:
219
  pass
220
 
221
- # Safe fallback
222
  return {"type": "run_check", "params": {"check_name": "po_match"}}
223
 
224
-
225
  # ---------------------------------------------------------------------------
226
- # Task runner
227
  # ---------------------------------------------------------------------------
228
 
229
  def run_task(client: OpenAI, env: InvoiceExceptionEnv, task_id: str) -> tuple:
230
- """Run one task episode and return (steps_taken, score, rewards)."""
231
- rewards = []
232
 
233
  print(f"[START] task={task_id} env=invoice-exception-handler model={MODEL_NAME}", flush=True)
234
 
235
  obs = env.reset(task_id)
236
- max_steps = env._task.max_steps # read from the task itself
237
- history = []
238
 
239
  for step in range(1, max_steps + 1):
240
- # Build prompt from observation
241
  user_prompt = build_prompt(obs, step, max_steps, history)
242
 
243
- # Call LLM
244
- raw = call_llm(client, user_prompt)
245
  action_dict = parse_action(raw)
246
 
247
- # Execute
248
  try:
249
  result = env.step(action_dict)
250
  reward = result.reward
251
- done = result.done
252
- error = None
253
- except Exception as e:
254
  reward = 0.0
255
- done = False
256
- error = str(e)
257
  result = None
258
 
259
  rewards.append(reward)
@@ -268,14 +226,14 @@ def run_task(client: OpenAI, env: InvoiceExceptionEnv, task_id: str) -> tuple:
268
 
269
  history.append(f"Step {step}: {action_str} -> reward {reward:+.2f}")
270
 
271
- if result:
272
  obs = result.observation
273
 
274
  if done:
275
  break
276
 
277
- score = env.grade()["score"]
278
- success = score >= 0.5
279
  steps_taken = min(step, max_steps)
280
  rewards_str = ",".join(f"{r:.2f}" for r in rewards)
281
 
@@ -287,24 +245,24 @@ def run_task(client: OpenAI, env: InvoiceExceptionEnv, task_id: str) -> tuple:
287
 
288
  return steps_taken, score, rewards
289
 
290
-
291
  # ---------------------------------------------------------------------------
292
- # Main
293
  # ---------------------------------------------------------------------------
294
 
295
  def main() -> None:
296
- """Run inference on all tasks."""
297
  client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
298
- env = InvoiceExceptionEnv(seed=42)
 
 
299
 
300
- all_scores = []
301
  for task_id in ALL_TASKS:
302
  _, score, _ = run_task(client, env, task_id)
303
  all_scores.append(score)
304
 
305
  avg = sum(all_scores) / len(all_scores) if all_scores else 0.0
306
- print(f"\nAverage score: {avg:.3f}", flush=True)
307
 
308
 
309
  if __name__ == "__main__":
310
- main()
 
11
  export HF_TOKEN="your-token"
12
  python inference.py
13
  """
14
+
15
  from __future__ import annotations
16
 
17
  import json
 
21
 
22
  from openai import OpenAI
23
 
24
+ from env import InvoiceExceptionEnv, ALL_TASKS
25
 
26
  # ---------------------------------------------------------------------------
27
+ # Configuration — read from environment variables exactly as the spec requires
28
  # ---------------------------------------------------------------------------
29
 
30
  API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
31
+ MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
32
+ HF_TOKEN = os.getenv("HF_TOKEN") # no default — spec requirement
 
33
 
34
  # ---------------------------------------------------------------------------
35
+ # System prompt
36
  # ---------------------------------------------------------------------------
37
 
38
  SYSTEM_PROMPT = """You are an expert Accounts Payable (AP) analyst handling flagged invoice exceptions.
39
 
40
+ You receive a full document packet: Purchase Order (PO), Invoice, Goods Receipt Note (GRN),
41
+ Supplier Master record, and an Exception Flag explaining why the invoice was flagged.
42
+
43
+ Your job: investigate the root cause, apply business rules, make a decision, and close the case.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
+ CRITICAL RULE: If there is ANY suspicion of bank account fraud or BEC attack, contact the
46
+ supplier via PHONE only — never via email. Emailing may reach the fraudster.
47
+
48
+ Your action space — respond with exactly ONE JSON object per turn:
49
+
50
+ 1. {"type": "inspect_field", "params": {"document": "invoice|po|grn|supplier_master", "field": "field_name"}}
51
+ 2. {"type": "cross_check", "params": {"field": "field_name", "doc_a": "doc1", "doc_b": "doc2"}}
52
+ 3. {"type": "run_check", "params": {"check_name": "check_name"}}
53
+ 4. {"type": "query_supplier", "params": {"question": "your question", "channel": "phone|email"}}
54
+ 5. {"type": "query_internal", "params": {"department": "dept_name", "question": "your question"}}
55
+ 6. {"type": "apply_rule", "params": {"rule_id": "rule_id"}}
56
+ 7. {"type": "make_decision", "params": {"decision": "approve|reject|hold|partial_approve", "reason": "explanation"}}
57
+ 8. {"type": "route_to", "params": {"team": "team_name", "notes": "routing notes"}}
58
+ 9. {"type": "close_case", "params": {"summary": "audit trail summary"}}
59
+
60
+ Rules:
61
+ - Always run checks BEFORE making a decision
62
+ - Never approve without verifying the root cause
63
+ - Use phone (not email) if fraud is suspected
64
+ - Respond with ONLY a JSON object, no explanation, no markdown fences
65
+ """
66
 
67
  # ---------------------------------------------------------------------------
68
+ # Prompt builder — shows the LLM the actual document data
69
  # ---------------------------------------------------------------------------
70
 
71
  def build_prompt(obs, step: int, max_steps: int, history: list) -> str:
72
+ """Build the user prompt from the current observation state."""
73
+ po = obs.purchase_order
74
+ inv = obs.invoice
75
+ grn = obs.grn
76
+ sm = obs.supplier_master
 
 
 
 
 
77
 
78
  lines = [
79
  f"Step {step} of {max_steps}.",
80
+ "",
81
  f"EXCEPTION FLAG: {obs.exception_flag.flag_code}",
82
  f"{obs.exception_flag.flag_description}",
83
+ "",
84
+ "=== DOCUMENT DATA ===",
85
+ f"PO #{po.po_number} | Supplier: {po.vendor_name} | Total: {po.total_amount} | Terms: {po.payment_terms}",
86
+ f"PO lines: {[(i.description[:30], 'qty='+str(i.quantity), 'unit='+str(i.unit_price)) for i in po.line_items]}",
87
+ "",
88
+ f"Invoice #{inv.invoice_number} | Date: {inv.invoice_date} | Subtotal: {inv.subtotal} | Tax: {inv.tax_amount} | Total: {inv.total_amount}",
89
+ f"Invoice GSTIN: {inv.supplier_gstin} | Bank: {inv.bank_account} {inv.ifsc_code}",
90
+ f"Invoice lines: {[(i.description[:30], 'qty='+str(i.quantity), 'unit='+str(i.unit_price)) for i in inv.line_items]}",
91
+ "",
92
+ f"GRN: received={sum(i.get('quantity_received', 0) for i in grn.items_received)} units | pending={sum(i.get('quantity_pending', 0) for i in grn.items_received)} units",
93
+ "",
94
+ f"Supplier Master: GSTIN={sm.gstin} | Bank={sm.bank_account} {sm.ifsc_code} | Domain={sm.registered_domain}",
95
+ "",
96
+ "=== AVAILABLE ACTIONS ===",
97
+ f"Checks you can run: {', '.join(obs.available_checks)}",
98
+ f"Rules you can apply: {', '.join(obs.available_rules)}",
99
+ "",
100
+ "Knowledge base (company policies):",
101
  ]
102
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  for entry in obs.knowledge_base:
104
  lines.append(f" - {entry}")
105
 
106
  lines.append("")
107
+ lines.append(f"Cumulative reward: {obs.cumulative_reward:.2f} | Status: {obs.case_status}")
 
108
 
109
  if obs.checks_run:
110
+ lines.append(f"Checks already run: {', '.join(c.check_name for c in obs.checks_run)}")
 
 
111
  if obs.queries:
112
+ lines.append(f"Queries already made: {', '.join(q.target for q in obs.queries)}")
 
 
113
  if obs.inspections:
114
+ lines.append(f"Fields already inspected: {', '.join(f'{i.document}.{i.field}' for i in obs.inspections)}")
 
 
115
  if obs.rules_applied:
116
+ lines.append(f"Rules already applied: {', '.join(obs.rules_applied)}")
117
  if obs.decision:
118
+ lines.append(f"Decision already made: {obs.decision}")
119
  if obs.routed_to:
120
+ lines.append(f"Already routed to: {', '.join(obs.routed_to)}")
121
 
122
  if history:
123
  lines.append("")
124
+ lines.append("Recent steps:")
125
  for h in history[-5:]:
126
  lines.append(f" {h}")
127
 
128
  lines.append("")
129
+ lines.append("What is your next action? Respond with a single JSON object only.")
130
 
131
  return "\n".join(lines)
132
 
 
133
  # ---------------------------------------------------------------------------
134
  # LLM caller
135
  # ---------------------------------------------------------------------------
 
141
  model=MODEL_NAME,
142
  messages=[
143
  {"role": "system", "content": SYSTEM_PROMPT},
144
+ {"role": "user", "content": user_prompt},
145
  ],
146
  temperature=0.1,
147
  max_tokens=256,
 
151
  print(f"LLM call failed: {e}", file=sys.stderr)
152
  return '{"type": "run_check", "params": {"check_name": "po_match"}}'
153
 
 
154
  # ---------------------------------------------------------------------------
155
  # Action parser
156
  # ---------------------------------------------------------------------------
157
 
158
  def parse_action(raw_text: str) -> dict:
159
  """
160
+ Parse the model response into an action dict.
161
+ Strips markdown fences, handles whitespace, falls back on parse failure.
 
162
  """
163
  text = raw_text.strip()
164
 
165
+ # Strip ```json ... ``` or ``` ... ``` fences
166
  if text.startswith("```"):
167
+ parts = text.split("\n")
168
+ text = "\n".join(parts[1:-1] if parts[-1].strip() == "```" else parts[1:])
169
 
170
  try:
171
  return json.loads(text.strip())
172
  except json.JSONDecodeError:
173
  pass
174
 
175
+ # Try to find JSON anywhere in the text
176
  match = re.search(r'\{.*\}', text, re.DOTALL)
177
  if match:
178
  try:
 
180
  except json.JSONDecodeError:
181
  pass
182
 
183
+ # Safe fallback — never crash
184
  return {"type": "run_check", "params": {"check_name": "po_match"}}
185
 
 
186
  # ---------------------------------------------------------------------------
187
+ # Task runner — one full episode
188
  # ---------------------------------------------------------------------------
189
 
190
  def run_task(client: OpenAI, env: InvoiceExceptionEnv, task_id: str) -> tuple:
191
+ """Run one task episode. Returns (steps_taken, score, rewards)."""
192
+ rewards: list[float] = []
193
 
194
  print(f"[START] task={task_id} env=invoice-exception-handler model={MODEL_NAME}", flush=True)
195
 
196
  obs = env.reset(task_id)
197
+ max_steps = env._task.max_steps # reads the correct limit per task: 18 / 20 / 25
198
+ history: list[str] = []
199
 
200
  for step in range(1, max_steps + 1):
 
201
  user_prompt = build_prompt(obs, step, max_steps, history)
202
 
203
+ raw = call_llm(client, user_prompt)
 
204
  action_dict = parse_action(raw)
205
 
 
206
  try:
207
  result = env.step(action_dict)
208
  reward = result.reward
209
+ done = result.done
210
+ error = None
211
+ except Exception as exc:
212
  reward = 0.0
213
+ done = False
214
+ error = str(exc)
215
  result = None
216
 
217
  rewards.append(reward)
 
226
 
227
  history.append(f"Step {step}: {action_str} -> reward {reward:+.2f}")
228
 
229
+ if result is not None:
230
  obs = result.observation
231
 
232
  if done:
233
  break
234
 
235
+ score = env.grade()["score"]
236
+ success = score >= 0.5
237
  steps_taken = min(step, max_steps)
238
  rewards_str = ",".join(f"{r:.2f}" for r in rewards)
239
 
 
245
 
246
  return steps_taken, score, rewards
247
 
 
248
  # ---------------------------------------------------------------------------
249
+ # Main — run all three tasks in sequence
250
  # ---------------------------------------------------------------------------
251
 
252
  def main() -> None:
253
+ """Entry point — runs inference on all tasks and prints average score."""
254
  client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
255
+ env = InvoiceExceptionEnv(seed=42)
256
+
257
+ all_scores: list[float] = []
258
 
 
259
  for task_id in ALL_TASKS:
260
  _, score, _ = run_task(client, env, task_id)
261
  all_scores.append(score)
262
 
263
  avg = sum(all_scores) / len(all_scores) if all_scores else 0.0
264
+ print(f"\nAverage score across all tasks: {avg:.3f}", flush=True)
265
 
266
 
267
  if __name__ == "__main__":
268
+ main()