""" Competition inference script for the Invoice Exception Handler environment. Uses the OpenAI client to call an LLM that acts as an AP analyst. Reads API_BASE_URL, MODEL_NAME, HF_TOKEN from environment variables. Emits [START], [STEP], [END] lines to stdout as required by the spec. Usage: export API_BASE_URL="https://router.huggingface.co/v1" export MODEL_NAME="Qwen/Qwen2.5-72B-Instruct" export HF_TOKEN="your-token" python inference.py """ from __future__ import annotations import json import os import re import sys from openai import OpenAI from env import InvoiceExceptionEnv, ALL_TASKS # --------------------------------------------------------------------------- # Configuration — read from environment variables exactly as the spec requires # --------------------------------------------------------------------------- API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1") MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct") HF_TOKEN = os.getenv("HF_TOKEN") # no default — spec requirement # --------------------------------------------------------------------------- # System prompt # --------------------------------------------------------------------------- SYSTEM_PROMPT = """You are an expert Accounts Payable (AP) analyst handling flagged invoice exceptions. You receive a full document packet: Purchase Order (PO), Invoice, Goods Receipt Note (GRN), Supplier Master record, and an Exception Flag explaining why the invoice was flagged. Your job: investigate the root cause, apply business rules, make a decision, and close the case. CRITICAL RULE: If there is ANY suspicion of bank account fraud or BEC attack, contact the supplier via PHONE only — never via email. Emailing may reach the fraudster. Your action space — respond with exactly ONE JSON object per turn: 1. {"type": "inspect_field", "params": {"document": "invoice|po|grn|supplier_master", "field": "field_name"}} 2. {"type": "cross_check", "params": {"field": "field_name", "doc_a": "doc1", "doc_b": "doc2"}} 3. {"type": "run_check", "params": {"check_name": "check_name"}} 4. {"type": "query_supplier", "params": {"question": "your question", "channel": "phone|email"}} 5. {"type": "query_internal", "params": {"department": "dept_name", "question": "your question"}} 6. {"type": "apply_rule", "params": {"rule_id": "rule_id"}} 7. {"type": "make_decision", "params": {"decision": "approve|reject|hold|partial_approve", "reason": "explanation"}} 8. {"type": "route_to", "params": {"team": "team_name", "notes": "routing notes"}} 9. {"type": "close_case", "params": {"summary": "audit trail summary"}} Rules: - Always run checks BEFORE making a decision - Never approve without verifying the root cause - Use phone (not email) if fraud is suspected - Respond with ONLY a JSON object, no explanation, no markdown fences """ # --------------------------------------------------------------------------- # Prompt builder — shows the LLM the actual document data # --------------------------------------------------------------------------- def build_prompt(obs, step: int, max_steps: int, history: list) -> str: """Build the user prompt from the current observation state.""" po = obs.purchase_order inv = obs.invoice grn = obs.grn sm = obs.supplier_master lines = [ f"Step {step} of {max_steps}.", "", f"EXCEPTION FLAG: {obs.exception_flag.flag_code}", f"{obs.exception_flag.flag_description}", "", "=== DOCUMENT DATA ===", f"PO #{po.po_number} | Supplier: {po.vendor_name} | Total: {po.total_amount} | Terms: {po.payment_terms}", f"PO lines: {[(i.description[:30], 'qty='+str(i.quantity), 'unit='+str(i.unit_price)) for i in po.line_items]}", "", f"Invoice #{inv.invoice_number} | Date: {inv.invoice_date} | Subtotal: {inv.subtotal} | Tax: {inv.tax_amount} | Total: {inv.total_amount}", f"Invoice GSTIN: {inv.supplier_gstin} | Bank: {inv.bank_account} {inv.ifsc_code}", f"Invoice lines: {[(i.description[:30], 'qty='+str(i.quantity), 'unit='+str(i.unit_price)) for i in inv.line_items]}", "", f"GRN: received={sum(i.get('quantity_received', 0) for i in grn.items_received)} units | pending={sum(i.get('quantity_pending', 0) for i in grn.items_received)} units", "", f"Supplier Master: GSTIN={sm.gstin} | Bank={sm.bank_account} {sm.ifsc_code} | Domain={sm.registered_domain}", "", "=== AVAILABLE ACTIONS ===", f"Checks you can run: {', '.join(obs.available_checks)}", f"Rules you can apply: {', '.join(obs.available_rules)}", "", "Knowledge base (company policies):", ] for entry in obs.knowledge_base: lines.append(f" - {entry}") lines.append("") lines.append(f"Cumulative reward: {obs.cumulative_reward:.2f} | Status: {obs.case_status}") if obs.checks_run: lines.append(f"Checks already run: {', '.join(c.check_name for c in obs.checks_run)}") if obs.queries: lines.append(f"Queries already made: {', '.join(q.target for q in obs.queries)}") if obs.inspections: lines.append(f"Fields already inspected: {', '.join(f'{i.document}.{i.field}' for i in obs.inspections)}") if obs.rules_applied: lines.append(f"Rules already applied: {', '.join(obs.rules_applied)}") if obs.decision: lines.append(f"Decision already made: {obs.decision}") if obs.routed_to: lines.append(f"Already routed to: {', '.join(obs.routed_to)}") if history: lines.append("") lines.append("Recent steps:") for h in history[-5:]: lines.append(f" {h}") lines.append("") lines.append("What is your next action? Respond with a single JSON object only.") return "\n".join(lines) # --------------------------------------------------------------------------- # LLM caller # --------------------------------------------------------------------------- def call_llm(client: OpenAI, user_prompt: str) -> str: """Call the LLM and return its raw text response.""" try: response = client.chat.completions.create( model=MODEL_NAME, messages=[ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": user_prompt}, ], temperature=0.1, max_tokens=256, ) return response.choices[0].message.content or "" except Exception as e: print(f"LLM call failed: {e}", file=sys.stderr) return '{"type": "run_check", "params": {"check_name": "po_match"}}' # --------------------------------------------------------------------------- # Action parser # --------------------------------------------------------------------------- def parse_action(raw_text: str) -> dict: """ Parse the model response into an action dict. Strips markdown fences, handles whitespace, falls back on parse failure. """ text = raw_text.strip() # Strip ```json ... ``` or ``` ... ``` fences if text.startswith("```"): parts = text.split("\n") text = "\n".join(parts[1:-1] if parts[-1].strip() == "```" else parts[1:]) try: return json.loads(text.strip()) except json.JSONDecodeError: pass # Try to find JSON anywhere in the text match = re.search(r'\{.*\}', text, re.DOTALL) if match: try: return json.loads(match.group()) except json.JSONDecodeError: pass # Safe fallback — never crash return {"type": "run_check", "params": {"check_name": "po_match"}} # --------------------------------------------------------------------------- # Task runner — one full episode # --------------------------------------------------------------------------- def run_task(client: OpenAI, env: InvoiceExceptionEnv, task_id: str) -> tuple: """Run one task episode. Returns (steps_taken, score, rewards).""" rewards: list[float] = [] print(f"[START] task={task_id} env=invoice-exception-handler model={MODEL_NAME}", flush=True) obs = env.reset(task_id) max_steps = env._task.max_steps # reads the correct limit per task: 18 / 20 / 25 history: list[str] = [] for step in range(1, max_steps + 1): user_prompt = build_prompt(obs, step, max_steps, history) raw = call_llm(client, user_prompt) action_dict = parse_action(raw) try: result = env.step(action_dict) reward = result.reward done = result.done error = None except Exception as exc: reward = 0.0 done = False error = str(exc) result = None rewards.append(reward) action_str = json.dumps(action_dict) print( f"[STEP] step={step} action={action_str} " f"reward={reward:.2f} done={str(done).lower()} " f"error={error or 'null'}", flush=True, ) history.append(f"Step {step}: {action_str} -> reward {reward:+.2f}") if result is not None: obs = result.observation if done: break score = env.grade()["score"] success = score >= 0.5 steps_taken = min(step, max_steps) rewards_str = ",".join(f"{r:.2f}" for r in rewards) print( f"[END] success={str(success).lower()} steps={steps_taken} " f"score={score:.3f} rewards={rewards_str}", flush=True, ) return steps_taken, score, rewards # --------------------------------------------------------------------------- # Main — run all three tasks in sequence # --------------------------------------------------------------------------- def main() -> None: """Entry point — runs inference on all tasks and prints average score.""" client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN) env = InvoiceExceptionEnv(seed=42) all_scores: list[float] = [] for task_id in ALL_TASKS: _, score, _ = run_task(client, env, task_id) all_scores.append(score) avg = sum(all_scores) / len(all_scores) if all_scores else 0.0 print(f"\nAverage score across all tasks: {avg:.3f}", flush=True) if __name__ == "__main__": main()