| """ | |
| Competition inference script for the Invoice Exception Handler environment. | |
| Uses the OpenAI client to call an LLM that acts as an AP analyst. | |
| Reads API_BASE_URL, MODEL_NAME, HF_TOKEN from environment variables. | |
| Emits [START], [STEP], [END] lines to stdout as required by the spec. | |
| Usage: | |
| export API_BASE_URL="https://router.huggingface.co/v1" | |
| export MODEL_NAME="Qwen/Qwen2.5-72B-Instruct" | |
| export HF_TOKEN="your-token" | |
| python inference.py | |
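
Example stdout (illustrative values; the exact line formats are the ones
emitted by run_task below):
    [START] task=<task_id> env=invoice-exception-handler model=Qwen/Qwen2.5-72B-Instruct
    [STEP] step=1 action={"type": "run_check", "params": {"check_name": "po_match"}} reward=0.10 done=false error=null
    [END] success=true steps=12 score=0.750 rewards=0.10,...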
| """ | |
from __future__ import annotations

import json
import os
import re
import sys

from openai import OpenAI

from env import InvoiceExceptionEnv, ALL_TASKS

# ---------------------------------------------------------------------------
# Configuration — read from environment variables exactly as the spec requires
# ---------------------------------------------------------------------------
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
HF_TOKEN = os.getenv("HF_TOKEN")  # no default — spec requirement

# ---------------------------------------------------------------------------
# System prompt
# ---------------------------------------------------------------------------
SYSTEM_PROMPT = """You are an expert Accounts Payable (AP) analyst handling flagged invoice exceptions.

You receive a full document packet: Purchase Order (PO), Invoice, Goods Receipt Note (GRN),
Supplier Master record, and an Exception Flag explaining why the invoice was flagged.
Your job: investigate the root cause, apply business rules, make a decision, and close the case.

CRITICAL RULE: If there is ANY suspicion of bank account fraud or a BEC attack, contact the
supplier via PHONE only — never via email. Emailing may reach the fraudster.

Your action space — respond with exactly ONE JSON object per turn:
1. {"type": "inspect_field", "params": {"document": "invoice|po|grn|supplier_master", "field": "field_name"}}
2. {"type": "cross_check", "params": {"field": "field_name", "doc_a": "doc1", "doc_b": "doc2"}}
3. {"type": "run_check", "params": {"check_name": "check_name"}}
4. {"type": "query_supplier", "params": {"question": "your question", "channel": "phone|email"}}
5. {"type": "query_internal", "params": {"department": "dept_name", "question": "your question"}}
6. {"type": "apply_rule", "params": {"rule_id": "rule_id"}}
7. {"type": "make_decision", "params": {"decision": "approve|reject|hold|partial_approve", "reason": "explanation"}}
8. {"type": "route_to", "params": {"team": "team_name", "notes": "routing notes"}}
9. {"type": "close_case", "params": {"summary": "audit trail summary"}}

Rules:
- Always run checks BEFORE making a decision
- Never approve without verifying the root cause
- Use phone (not email) if fraud is suspected
- Respond with ONLY a JSON object, no explanation, no markdown fences
"""

# ---------------------------------------------------------------------------
# Prompt builder — shows the LLM the actual document data
# ---------------------------------------------------------------------------
def build_prompt(obs, step: int, max_steps: int, history: list) -> str:
    """Build the user prompt from the current observation state."""
    po = obs.purchase_order
    inv = obs.invoice
    grn = obs.grn
    sm = obs.supplier_master
    lines = [
        f"Step {step} of {max_steps}.",
        "",
        f"EXCEPTION FLAG: {obs.exception_flag.flag_code}",
        f"{obs.exception_flag.flag_description}",
        "",
        "=== DOCUMENT DATA ===",
        f"PO #{po.po_number} | Supplier: {po.vendor_name} | Total: {po.total_amount} | Terms: {po.payment_terms}",
        f"PO lines: {[(i.description[:30], 'qty='+str(i.quantity), 'unit='+str(i.unit_price)) for i in po.line_items]}",
        "",
        f"Invoice #{inv.invoice_number} | Date: {inv.invoice_date} | Subtotal: {inv.subtotal} | Tax: {inv.tax_amount} | Total: {inv.total_amount}",
        f"Invoice GSTIN: {inv.supplier_gstin} | Bank: {inv.bank_account} {inv.ifsc_code}",
        f"Invoice lines: {[(i.description[:30], 'qty='+str(i.quantity), 'unit='+str(i.unit_price)) for i in inv.line_items]}",
        "",
        f"GRN: received={sum(i.get('quantity_received', 0) for i in grn.items_received)} units | pending={sum(i.get('quantity_pending', 0) for i in grn.items_received)} units",
        "",
        f"Supplier Master: GSTIN={sm.gstin} | Bank={sm.bank_account} {sm.ifsc_code} | Domain={sm.registered_domain}",
        "",
        "=== AVAILABLE ACTIONS ===",
        f"Checks you can run: {', '.join(obs.available_checks)}",
        f"Rules you can apply: {', '.join(obs.available_rules)}",
        "",
        "Knowledge base (company policies):",
    ]
    for entry in obs.knowledge_base:
        lines.append(f" - {entry}")
    lines.append("")
    lines.append(f"Cumulative reward: {obs.cumulative_reward:.2f} | Status: {obs.case_status}")
    if obs.checks_run:
        lines.append(f"Checks already run: {', '.join(c.check_name for c in obs.checks_run)}")
    if obs.queries:
        lines.append(f"Queries already made: {', '.join(q.target for q in obs.queries)}")
    if obs.inspections:
        lines.append(f"Fields already inspected: {', '.join(f'{i.document}.{i.field}' for i in obs.inspections)}")
    if obs.rules_applied:
        lines.append(f"Rules already applied: {', '.join(obs.rules_applied)}")
    if obs.decision:
        lines.append(f"Decision already made: {obs.decision}")
    if obs.routed_to:
        lines.append(f"Already routed to: {', '.join(obs.routed_to)}")
    if history:
        lines.append("")
        lines.append("Recent steps:")
        for h in history[-5:]:
            lines.append(f" {h}")
    lines.append("")
    lines.append("What is your next action? Respond with a single JSON object only.")
    return "\n".join(lines)

# ---------------------------------------------------------------------------
# LLM caller
# ---------------------------------------------------------------------------
def call_llm(client: OpenAI, user_prompt: str) -> str:
    """Call the LLM and return its raw text response."""
    try:
        response = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt},
            ],
            temperature=0.1,
            max_tokens=256,
        )
        return response.choices[0].message.content or ""
    except Exception as e:
        print(f"LLM call failed: {e}", file=sys.stderr)
        return '{"type": "run_check", "params": {"check_name": "po_match"}}'

# ---------------------------------------------------------------------------
# Action parser
# ---------------------------------------------------------------------------
def parse_action(raw_text: str) -> dict:
    """
    Parse the model response into an action dict.
    Strips markdown fences, handles whitespace, falls back on parse failure.
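
    Examples (checkable with doctest; the last one shows the safe fallback
    returned when no JSON can be recovered):

    >>> parse_action('{"type": "apply_rule", "params": {"rule_id": "r1"}}')
    {'type': 'apply_rule', 'params': {'rule_id': 'r1'}}
    >>> parse_action('Action: {"type": "close_case", "params": {"summary": "done"}}')
    {'type': 'close_case', 'params': {'summary': 'done'}}
    >>> parse_action('no json here')
    {'type': 'run_check', 'params': {'check_name': 'po_match'}}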
| """ | |
| text = raw_text.strip() | |
| # Strip ```json ... ``` or ``` ... ``` fences | |
| if text.startswith("```"): | |
| parts = text.split("\n") | |
| text = "\n".join(parts[1:-1] if parts[-1].strip() == "```" else parts[1:]) | |
| try: | |
| return json.loads(text.strip()) | |
| except json.JSONDecodeError: | |
| pass | |
| # Try to find JSON anywhere in the text | |
| match = re.search(r'\{.*\}', text, re.DOTALL) | |
| if match: | |
| try: | |
| return json.loads(match.group()) | |
| except json.JSONDecodeError: | |
| pass | |
| # Safe fallback — never crash | |
| return {"type": "run_check", "params": {"check_name": "po_match"}} | |

# ---------------------------------------------------------------------------
# Task runner — one full episode
# ---------------------------------------------------------------------------
def run_task(client: OpenAI, env: InvoiceExceptionEnv, task_id: str) -> tuple[int, float, list[float]]:
    """Run one task episode. Returns (steps_taken, score, rewards)."""
    rewards: list[float] = []
    print(f"[START] task={task_id} env=invoice-exception-handler model={MODEL_NAME}", flush=True)
    obs = env.reset(task_id)
    max_steps = env._task.max_steps  # reads the correct limit per task: 18 / 20 / 25
    history: list[str] = []
    for step in range(1, max_steps + 1):
        user_prompt = build_prompt(obs, step, max_steps, history)
        raw = call_llm(client, user_prompt)
        action_dict = parse_action(raw)
        try:
            result = env.step(action_dict)
            reward = result.reward
            done = result.done
            error = None
        except Exception as exc:
            reward = 0.0
            done = False
            error = str(exc)
            result = None
        rewards.append(reward)
        action_str = json.dumps(action_dict)
        print(
            f"[STEP] step={step} action={action_str} "
            f"reward={reward:.2f} done={str(done).lower()} "
            f"error={error or 'null'}",
            flush=True,
        )
        history.append(f"Step {step}: {action_str} -> reward {reward:+.2f}")
        if result is not None:
            obs = result.observation
        if done:
            break
    score = env.grade()["score"]
    success = score >= 0.5
    steps_taken = min(step, max_steps)
    rewards_str = ",".join(f"{r:.2f}" for r in rewards)
    print(
        f"[END] success={str(success).lower()} steps={steps_taken} "
        f"score={score:.3f} rewards={rewards_str}",
        flush=True,
    )
    return steps_taken, score, rewards

# ---------------------------------------------------------------------------
# Main — run all three tasks in sequence
# ---------------------------------------------------------------------------
def main() -> None:
    """Entry point — runs inference on all tasks and prints average score."""
    client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
    env = InvoiceExceptionEnv(seed=42)
    all_scores: list[float] = []
    for task_id in ALL_TASKS:
        _, score, _ = run_task(client, env, task_id)
        all_scores.append(score)
    avg = sum(all_scores) / len(all_scores) if all_scores else 0.0
    print(f"\nAverage score across all tasks: {avg:.3f}", flush=True)


if __name__ == "__main__":
    main()