# YUS200619's picture
# feat: complete invoice exception handler v1.0.0
# 562f58d
"""
Typed models for the Invoice Exception Handler OpenEnv environment.
Every object the agent sees or produces is defined here as a Pydantic model.
This is the single source of truth for the data contract between the
environment simulation and the agent.
"""
from __future__ import annotations
import time
from enum import Enum
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
# ---------------------------------------------------------------------------
# Enumerations
# ---------------------------------------------------------------------------
class ActionType(str, Enum):
    """The nine action types an agent can take during an episode.

    Members are ``str``-backed, so each one compares equal to its wire
    value (for example ``ActionType.RUN_CHECK == "run_check"``).
    """

    # Investigation actions.
    INSPECT_FIELD = "inspect_field"    # look at one field of one document
    CROSS_CHECK = "cross_check"        # compare a field between two documents
    RUN_CHECK = "run_check"            # run a named validation check
    QUERY_SUPPLIER = "query_supplier"  # ask the supplier a question
    QUERY_INTERNAL = "query_internal"  # ask an internal department a question
    APPLY_RULE = "apply_rule"          # apply a named business policy rule

    # Resolution actions.
    MAKE_DECISION = "make_decision"    # record a decision with a reason
    ROUTE_TO = "route_to"              # escalate the case to a team
    CLOSE_CASE = "close_case"          # close the case with a summary
class DecisionType(str, Enum):
    """Possible decisions the agent can make on a flagged invoice.

    Members are ``str``-backed, so they compare equal to their plain
    string values.
    """

    # NOTE(review): Action.make_decision takes the decision as a plain
    # ``str``; presumably these values are the accepted strings -- confirm
    # against the environment that consumes the action.
    APPROVE = "approve"
    REJECT = "reject"
    HOLD = "hold"
    PARTIAL_APPROVE = "partial_approve"
class CaseStatus(str, Enum):
    """Lifecycle status of an invoice exception case.

    Members are ``str``-backed, so they compare equal to their plain
    string values.
    """

    OPEN = "open"            # default status on EnvironmentState.case_status
    IN_REVIEW = "in_review"
    DECIDED = "decided"
    ROUTED = "routed"
    CLOSED = "closed"
# ---------------------------------------------------------------------------
# Document models — read-only context given to the agent
# ---------------------------------------------------------------------------
class LineItem(BaseModel):
    """One line on an invoice or purchase order.

    All fields except ``tax_rate`` are required.
    """

    description: str = Field(..., description="Item description")
    quantity: int = Field(..., description="Number of units")
    unit_price: float = Field(..., description="Price per unit in INR")
    # The description states total = quantity x unit_price, but this model
    # declares no validator for it, so the value is stored as supplied.
    total: float = Field(..., description="Line total in INR (quantity × unit_price)")
    tax_rate: Optional[float] = Field(
        None, description="Tax rate as a percentage, if applicable"
    )
class PurchaseOrder(BaseModel):
    """What was agreed to be purchased.

    ``line_items`` defaults to an empty list, ``payment_terms`` to
    ``"Net-30"`` and ``currency`` to ``"INR"``; all other fields are required.
    """

    po_number: str = Field(..., description="Unique PO identifier")
    vendor_name: str = Field(..., description="Supplier name on the PO")
    # Dates are carried as plain strings; the YYYY-MM-DD format is documented
    # in the description but not validated by this model.
    po_date: str = Field(..., description="Date the PO was raised (YYYY-MM-DD)")
    line_items: List[LineItem] = Field(default_factory=list, description="Items on the PO")
    total_amount: float = Field(..., description="Total PO value in INR")
    payment_terms: str = Field("Net-30", description="Payment terms")
    currency: str = Field("INR", description="Currency code")
class Invoice(BaseModel):
    """What the supplier is claiming — the document under exception review.

    Identification, date, amount and ``bank_account`` fields are required.
    The remaining bank/contact fields default to an empty string and
    ``currency`` defaults to ``"INR"``.
    """

    # Identification and dates (dates are plain strings, format not validated).
    invoice_number: str = Field(..., description="Unique invoice identifier")
    supplier_name: str = Field(..., description="Supplier name on the invoice")
    invoice_date: str = Field(..., description="Date of the invoice (YYYY-MM-DD)")
    due_date: str = Field(..., description="Payment due date (YYYY-MM-DD)")
    po_reference: str = Field(..., description="PO number referenced by this invoice")

    # Invoiced items and amounts.
    line_items: List[LineItem] = Field(default_factory=list, description="Items invoiced")
    subtotal: float = Field(..., description="Pre-tax total in INR")
    tax_amount: float = Field(..., description="Total tax amount in INR")
    tax_rate: float = Field(..., description="Applied tax rate as a percentage")
    total_amount: float = Field(..., description="Grand total including tax in INR")

    # Payment and contact details as printed on the invoice. SupplierMaster
    # carries the verified counterparts of several of these fields.
    bank_account: str = Field(..., description="Supplier bank account on the invoice")
    bank_name: str = Field("", description="Bank name")
    ifsc_code: str = Field("", description="IFSC / routing code")
    supplier_gstin: str = Field("", description="GST Identification Number on the invoice")
    supplier_email: str = Field("", description="Email address on the invoice")
    currency: str = Field("INR", description="Currency code")
class GoodsReceiptNote(BaseModel):
    """What actually arrived at the warehouse (or service confirmation).

    ``grn_number``, ``po_reference`` and ``receipt_date`` are required;
    the other fields default to an empty list or empty string.
    """

    grn_number: str = Field(..., description="Unique GRN identifier")
    po_reference: str = Field(..., description="PO number this receipt is against")
    receipt_date: str = Field(
        ..., description="Date goods/services were received (YYYY-MM-DD)"
    )
    # Items are untyped dicts: the expected keys are only listed in the
    # description and are not enforced by this model.
    items_received: List[Dict[str, Any]] = Field(
        default_factory=list,
        description="List of received item dicts with description, quantity_received, quantity_pending, quantity_rejected",
    )
    receiving_officer: str = Field("", description="Person who signed the receipt")
    notes: str = Field("", description="Any delivery notes or discrepancies observed")
class SupplierMaster(BaseModel):
    """The verified, registered supplier record in the company's ERP system.

    ``supplier_id``, ``supplier_name``, ``gstin`` and ``bank_account`` are
    required. ``status`` defaults to ``"active"`` and every other field
    defaults to an empty string.
    """

    supplier_id: str = Field(..., description="Internal supplier code")
    supplier_name: str = Field(..., description="Registered legal name")
    registered_address: str = Field("", description="Registered business address")

    # Verified tax and banking details.
    gstin: str = Field(..., description="Verified GST Identification Number")
    bank_account: str = Field(..., description="Verified bank account number")
    bank_name: str = Field("", description="Bank name")
    ifsc_code: str = Field("", description="Verified IFSC / routing code")

    # Registered contact details.
    contact_email: str = Field("", description="Registered email address")
    contact_phone: str = Field("", description="Registered phone number")
    registered_domain: str = Field("", description="Verified email domain for the supplier")
    pan_number: str = Field("", description="PAN (tax ID)")

    # Free-form string: the allowed values are listed in the description but
    # are not enforced by this model.
    status: str = Field("active", description="Supplier status: active, suspended, blacklisted")
class ExceptionFlag(BaseModel):
    """Why the AP system flagged this invoice for manual review.

    ``flag_code`` and ``flag_description`` are required; the remaining
    fields have defaults.
    """

    flag_code: str = Field(..., description="Machine-readable code, e.g. PRICE_MISMATCH")
    flag_description: str = Field(..., description="Human-readable explanation of the flag")
    auto_hold: bool = Field(
        False, description="Whether the system placed an automatic payment hold"
    )
    flagged_date: str = Field("", description="Date the flag was raised (YYYY-MM-DD)")
    # Free-form string defaulting to "medium": the levels are listed in the
    # description but are not enforced by this model.
    severity: str = Field("medium", description="low / medium / high / critical")
# ---------------------------------------------------------------------------
# Action model
# ---------------------------------------------------------------------------
class Action(BaseModel):
    """An action the agent wants to take.

    An action is an ``ActionType`` plus a free-form ``params`` dict. Use the
    classmethod constructors for convenience; each one fills ``params`` with
    the keys for its action type::

        Action.run_check("tolerance_rule")
        Action.make_decision("approve", "reason here")
    """

    type: ActionType = Field(..., description="Which action type to execute")
    params: Dict[str, Any] = Field(default_factory=dict, description="Parameters for the action")

    # --- Classmethod constructors for each action type ---

    @classmethod
    def inspect_field(cls, document: str, field: str) -> Action:
        """Look at a specific field in a document.

        Args:
            document: Stored under the ``"document"`` params key.
            field: Stored under the ``"field"`` params key.

        Returns:
            An ``INSPECT_FIELD`` action.
        """
        return cls(
            type=ActionType.INSPECT_FIELD,
            params={"document": document, "field": field},
        )

    @classmethod
    def cross_check(cls, field: str, doc_a: str, doc_b: str) -> Action:
        """Compare a field between two documents.

        Args:
            field: Stored under the ``"field"`` params key.
            doc_a: Stored under the ``"doc_a"`` params key.
            doc_b: Stored under the ``"doc_b"`` params key.

        Returns:
            A ``CROSS_CHECK`` action.
        """
        return cls(
            type=ActionType.CROSS_CHECK,
            params={"field": field, "doc_a": doc_a, "doc_b": doc_b},
        )

    @classmethod
    def run_check(cls, check_name: str) -> Action:
        """Run a named validation check.

        Args:
            check_name: Stored under the ``"check_name"`` params key.

        Returns:
            A ``RUN_CHECK`` action.
        """
        return cls(type=ActionType.RUN_CHECK, params={"check_name": check_name})

    @classmethod
    def query_supplier(cls, question: str, channel: str = "email") -> Action:
        """Ask the supplier a question via a specific channel.

        Args:
            question: Stored under the ``"question"`` params key.
            channel: Stored under the ``"channel"`` params key; defaults
                to ``"email"``.

        Returns:
            A ``QUERY_SUPPLIER`` action.
        """
        return cls(
            type=ActionType.QUERY_SUPPLIER,
            params={"question": question, "channel": channel},
        )

    @classmethod
    def query_internal(cls, department: str, question: str) -> Action:
        """Ask an internal department a question.

        Args:
            department: Stored under the ``"department"`` params key.
            question: Stored under the ``"question"`` params key.

        Returns:
            A ``QUERY_INTERNAL`` action.
        """
        return cls(
            type=ActionType.QUERY_INTERNAL,
            params={"department": department, "question": question},
        )

    @classmethod
    def apply_rule(cls, rule_id: str) -> Action:
        """Apply a named business policy rule.

        Args:
            rule_id: Stored under the ``"rule_id"`` params key.

        Returns:
            An ``APPLY_RULE`` action.
        """
        return cls(type=ActionType.APPLY_RULE, params={"rule_id": rule_id})

    @classmethod
    def make_decision(cls, decision: str, reason: str) -> Action:
        """Make a case decision with a documented reason.

        Args:
            decision: Stored as given under the ``"decision"`` params key;
                it is not validated here.
            reason: Stored under the ``"reason"`` params key.

        Returns:
            A ``MAKE_DECISION`` action.
        """
        return cls(
            type=ActionType.MAKE_DECISION,
            params={"decision": decision, "reason": reason},
        )

    @classmethod
    def route_to(cls, team: str, notes: str = "") -> Action:
        """Escalate the case to a specific team.

        Args:
            team: Stored under the ``"team"`` params key.
            notes: Stored under the ``"notes"`` params key; defaults to an
                empty string.

        Returns:
            A ``ROUTE_TO`` action.
        """
        return cls(type=ActionType.ROUTE_TO, params={"team": team, "notes": notes})

    @classmethod
    def close_case(cls, summary: str) -> Action:
        """Close the case with an audit trail summary.

        Args:
            summary: Stored under the ``"summary"`` params key.

        Returns:
            A ``CLOSE_CASE`` action.
        """
        return cls(type=ActionType.CLOSE_CASE, params={"summary": summary})
# ---------------------------------------------------------------------------
# Result models — returned by simulators
# ---------------------------------------------------------------------------
class InspectionResult(BaseModel):
    """What came back from inspecting a specific field in a document."""

    document: str = Field(..., description="Which document was inspected")
    field: str = Field(..., description="Which field was inspected")
    # Required but typed ``Any``: the model places no constraint on the value.
    value: Any = Field(..., description="The value found in that field")
    note: str = Field("", description="Any contextual note about the value")
    # Defaults to ``time.time()`` evaluated when the model instance is created.
    timestamp: float = Field(
        default_factory=time.time, description="When the inspection happened"
    )
class CheckResult(BaseModel):
    """What came back from running a validation check or cross-check."""

    check_name: str = Field(..., description="Name of the check that was run")
    passed: bool = Field(
        ..., description="Whether the check passed (True) or failed (False)"
    )
    detail: str = Field("", description="Human-readable detail of what was found")
    # Defaults to ``time.time()`` evaluated when the model instance is created.
    timestamp: float = Field(default_factory=time.time, description="When the check was run")
class QueryResult(BaseModel):
    """What came back from querying a supplier or internal department.

    ``target`` and ``response`` are required; ``question`` defaults to an
    empty string and ``channel`` to ``"email"``.
    """

    target: str = Field(
        ..., description="Who was queried (supplier, procurement, finance, etc.)"
    )
    question: str = Field("", description="The question that was asked")
    response: str = Field(..., description="The response received")
    channel: str = Field(
        "email", description="Communication channel used (email, phone, etc.)"
    )
    # Defaults to ``time.time()`` evaluated when the model instance is created.
    timestamp: float = Field(default_factory=time.time, description="When the query was made")
# ---------------------------------------------------------------------------
# State models
# ---------------------------------------------------------------------------
class EnvironmentState(BaseModel):
    """The full observable state returned by reset() and step().

    This is what the agent sees at every turn — all documents, all history,
    and all available actions/checks/rules for the current task.

    ``task_id`` and the five document fields are required; every other
    field has a default.
    """

    task_id: str = Field(..., description="Which task is currently running")
    step_number: int = Field(0, description="Current step number in the episode")
    case_status: CaseStatus = Field(CaseStatus.OPEN, description="Current lifecycle status")

    # The five documents
    purchase_order: PurchaseOrder = Field(..., description="The purchase order")
    invoice: Invoice = Field(..., description="The invoice under review")
    grn: GoodsReceiptNote = Field(..., description="The goods receipt note")
    supplier_master: SupplierMaster = Field(..., description="The verified supplier record")
    exception_flag: ExceptionFlag = Field(..., description="Why this invoice was flagged")

    # Agent history — what has been done so far
    inspections: List[InspectionResult] = Field(default_factory=list, description="Fields inspected")
    checks_run: List[CheckResult] = Field(default_factory=list, description="Checks completed")
    queries: List[QueryResult] = Field(default_factory=list, description="Queries made")
    rules_applied: List[str] = Field(default_factory=list, description="Rules applied")

    # Decision state. ``decision`` is a plain optional string here, not a
    # DecisionType member.
    decision: Optional[str] = Field(None, description="Current decision if one has been made")
    decision_reason: Optional[str] = Field(None, description="Reason for the decision")
    routed_to: List[str] = Field(default_factory=list, description="Teams case has been routed to")
    case_closed: bool = Field(False, description="Whether the case has been closed")
    close_summary: Optional[str] = Field(None, description="Closure summary if case is closed")

    # Action hints — what the agent can do
    available_actions: List[str] = Field(default_factory=list, description="All valid action types")
    available_checks: List[str] = Field(default_factory=list, description="Check names for this task")
    available_rules: List[str] = Field(default_factory=list, description="Rule IDs for this task")
    knowledge_base: List[str] = Field(default_factory=list, description="Policy entries for this task")

    # Running totals
    cumulative_reward: float = Field(0.0, description="Sum of all rewards received so far")
class StepResult(BaseModel):
    """What step() returns — the observation, reward, done flag, and info dict.

    ``observation`` and ``reward`` are required; ``done`` defaults to
    ``False`` and ``info`` to an empty dict.
    """

    observation: EnvironmentState = Field(
        ..., description="Updated environment state after the action"
    )
    reward: float = Field(..., description="Reward for this specific action")
    done: bool = Field(False, description="Whether the episode is over")
    info: Dict[str, Any] = Field(default_factory=dict, description="Extra info about the step")