"""
Typed models for the Invoice Exception Handler OpenEnv environment.

Every object the agent sees or produces is defined here as a Pydantic model.
This is the single source of truth for the data contract between the
environment simulation and the agent.
"""

from __future__ import annotations

import time
from enum import Enum
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Field

# NOTE: class docstrings and Field descriptions below are part of the data
# contract (Pydantic can surface them in generated JSON schemas), so they are
# treated as runtime text rather than free-form commentary.


# ===========================================================================
# Enumerations
# ===========================================================================
# Each enum mixes in ``str`` so that members compare equal to (and serialise
# as) their plain string values.


class ActionType(str, Enum):
    """The nine action types an agent can take during an episode."""

    INSPECT_FIELD = "inspect_field"
    CROSS_CHECK = "cross_check"
    RUN_CHECK = "run_check"
    QUERY_SUPPLIER = "query_supplier"
    QUERY_INTERNAL = "query_internal"
    APPLY_RULE = "apply_rule"
    MAKE_DECISION = "make_decision"
    ROUTE_TO = "route_to"
    CLOSE_CASE = "close_case"


class DecisionType(str, Enum):
    """Possible decisions the agent can make on a flagged invoice."""

    APPROVE = "approve"
    REJECT = "reject"
    HOLD = "hold"
    PARTIAL_APPROVE = "partial_approve"


class CaseStatus(str, Enum):
    """Lifecycle status of an invoice exception case."""

    OPEN = "open"
    IN_REVIEW = "in_review"
    DECIDED = "decided"
    ROUTED = "routed"
    CLOSED = "closed"


# ===========================================================================
# Document models: read-only context handed to the agent
# ===========================================================================
# ``Field(...)`` marks a required field; any other first argument is the
# default value. List defaults go through ``default_factory`` so every
# instance gets its own fresh container.


class LineItem(BaseModel):
    """One line on an invoice or purchase order."""

    description: str = Field(..., description="Item description")
    quantity: int = Field(..., description="Number of units")
    unit_price: float = Field(..., description="Price per unit in INR")
    # No validator enforces total == quantity * unit_price in this module.
    total: float = Field(..., description="Line total in INR (quantity × unit_price)")
    tax_rate: Optional[float] = Field(
        None,
        description="Tax rate as a percentage, if applicable",
    )


class PurchaseOrder(BaseModel):
    """What was agreed to be purchased."""

    po_number: str = Field(..., description="Unique PO identifier")
    vendor_name: str = Field(..., description="Supplier name on the PO")
    # Dates are carried as plain strings; the format is documented, not parsed.
    po_date: str = Field(..., description="Date the PO was raised (YYYY-MM-DD)")
    line_items: List[LineItem] = Field(
        default_factory=list,
        description="Items on the PO",
    )
    total_amount: float = Field(..., description="Total PO value in INR")
    payment_terms: str = Field("Net-30", description="Payment terms")
    currency: str = Field("INR", description="Currency code")


class Invoice(BaseModel):
    """What the supplier is claiming — the document under exception review."""

    # --- identity and dates ---
    invoice_number: str = Field(..., description="Unique invoice identifier")
    supplier_name: str = Field(..., description="Supplier name on the invoice")
    invoice_date: str = Field(..., description="Date of the invoice (YYYY-MM-DD)")
    due_date: str = Field(..., description="Payment due date (YYYY-MM-DD)")
    po_reference: str = Field(..., description="PO number referenced by this invoice")

    # --- amounts ---
    line_items: List[LineItem] = Field(
        default_factory=list,
        description="Items invoiced",
    )
    subtotal: float = Field(..., description="Pre-tax total in INR")
    tax_amount: float = Field(..., description="Total tax amount in INR")
    tax_rate: float = Field(..., description="Applied tax rate as a percentage")
    total_amount: float = Field(..., description="Grand total including tax in INR")

    # --- supplier-provided payment and contact details ---
    bank_account: str = Field(..., description="Supplier bank account on the invoice")
    bank_name: str = Field("", description="Bank name")
    ifsc_code: str = Field("", description="IFSC / routing code")
    supplier_gstin: str = Field(
        "",
        description="GST Identification Number on the invoice",
    )
    supplier_email: str = Field("", description="Email address on the invoice")
    currency: str = Field("INR", description="Currency code")


class GoodsReceiptNote(BaseModel):
    """What actually arrived at the warehouse (or service confirmation)."""

    grn_number: str = Field(..., description="Unique GRN identifier")
    po_reference: str = Field(..., description="PO number this receipt is against")
    receipt_date: str = Field(
        ...,
        description="Date goods/services were received (YYYY-MM-DD)",
    )
    # Untyped dicts: the expected keys are listed in the description only and
    # are not validated here.
    items_received: List[Dict[str, Any]] = Field(
        default_factory=list,
        description="List of received item dicts with description, quantity_received, quantity_pending, quantity_rejected",
    )
    receiving_officer: str = Field("", description="Person who signed the receipt")
    notes: str = Field("", description="Any delivery notes or discrepancies observed")


class SupplierMaster(BaseModel):
    """The verified, registered supplier record in the company's ERP system."""

    supplier_id: str = Field(..., description="Internal supplier code")
    supplier_name: str = Field(..., description="Registered legal name")
    registered_address: str = Field("", description="Registered business address")
    gstin: str = Field(..., description="Verified GST Identification Number")
    bank_account: str = Field(..., description="Verified bank account number")
    bank_name: str = Field("", description="Bank name")
    ifsc_code: str = Field("", description="Verified IFSC / routing code")
    contact_email: str = Field("", description="Registered email address")
    contact_phone: str = Field("", description="Registered phone number")
    registered_domain: str = Field(
        "",
        description="Verified email domain for the supplier",
    )
    pan_number: str = Field("", description="PAN (tax ID)")
    # Free-form string; the allowed values are documented but not enforced.
    status: str = Field(
        "active",
        description="Supplier status: active, suspended, blacklisted",
    )


class ExceptionFlag(BaseModel):
    """Why the AP system flagged this invoice for manual review."""

    flag_code: str = Field(
        ...,
        description="Machine-readable code, e.g. PRICE_MISMATCH",
    )
    flag_description: str = Field(
        ...,
        description="Human-readable explanation of the flag",
    )
    auto_hold: bool = Field(
        False,
        description="Whether the system placed an automatic payment hold",
    )
    flagged_date: str = Field("", description="Date the flag was raised (YYYY-MM-DD)")
    # Free-form string; the allowed levels are documented but not enforced.
    severity: str = Field("medium", description="low / medium / high / critical")


# ===========================================================================
# Action model
# ===========================================================================


class Action(BaseModel):
    """
    An action the agent wants to take.

    Use the classmethod constructors for convenience:
        Action.run_check("tolerance_rule")
        Action.make_decision("approve", "reason here")
    """

    type: ActionType = Field(..., description="Which action type to execute")
    params: Dict[str, Any] = Field(
        default_factory=dict,
        description="Parameters for the action",
    )

    # -- Alternate constructors, one per ActionType member ------------------
    # Each one only packs its arguments into ``params`` under fixed keys; no
    # argument validation happens at construction time.

    @classmethod
    def inspect_field(cls, document: str, field: str) -> Action:
        """Build an INSPECT_FIELD action for one field of one document."""
        payload: Dict[str, Any] = {"document": document, "field": field}
        return cls(type=ActionType.INSPECT_FIELD, params=payload)

    @classmethod
    def cross_check(cls, field: str, doc_a: str, doc_b: str) -> Action:
        """Build a CROSS_CHECK action comparing ``field`` across two documents."""
        payload: Dict[str, Any] = {"field": field, "doc_a": doc_a, "doc_b": doc_b}
        return cls(type=ActionType.CROSS_CHECK, params=payload)

    @classmethod
    def run_check(cls, check_name: str) -> Action:
        """Build a RUN_CHECK action for the validation check ``check_name``."""
        payload: Dict[str, Any] = {"check_name": check_name}
        return cls(type=ActionType.RUN_CHECK, params=payload)

    @classmethod
    def query_supplier(cls, question: str, channel: str = "email") -> Action:
        """Build a QUERY_SUPPLIER action; ``channel`` defaults to ``"email"``."""
        payload: Dict[str, Any] = {"question": question, "channel": channel}
        return cls(type=ActionType.QUERY_SUPPLIER, params=payload)

    @classmethod
    def query_internal(cls, department: str, question: str) -> Action:
        """Build a QUERY_INTERNAL action addressed to ``department``."""
        payload: Dict[str, Any] = {"department": department, "question": question}
        return cls(type=ActionType.QUERY_INTERNAL, params=payload)

    @classmethod
    def apply_rule(cls, rule_id: str) -> Action:
        """Build an APPLY_RULE action for the policy rule ``rule_id``."""
        payload: Dict[str, Any] = {"rule_id": rule_id}
        return cls(type=ActionType.APPLY_RULE, params=payload)

    @classmethod
    def make_decision(cls, decision: str, reason: str) -> Action:
        """Build a MAKE_DECISION action carrying the decision and its reason."""
        payload: Dict[str, Any] = {"decision": decision, "reason": reason}
        return cls(type=ActionType.MAKE_DECISION, params=payload)

    @classmethod
    def route_to(cls, team: str, notes: str = "") -> Action:
        """Build a ROUTE_TO action; ``notes`` defaults to an empty string."""
        payload: Dict[str, Any] = {"team": team, "notes": notes}
        return cls(type=ActionType.ROUTE_TO, params=payload)

    @classmethod
    def close_case(cls, summary: str) -> Action:
        """Build a CLOSE_CASE action with the given closing ``summary``."""
        payload: Dict[str, Any] = {"summary": summary}
        return cls(type=ActionType.CLOSE_CASE, params=payload)


# ===========================================================================
# Result models: returned by simulators
# ===========================================================================
# Every result stamps itself with ``time.time()`` (seconds since the epoch)
# at construction unless a timestamp is supplied explicitly.


class InspectionResult(BaseModel):
    """What came back from inspecting a specific field in a document."""

    document: str = Field(..., description="Which document was inspected")
    field: str = Field(..., description="Which field was inspected")
    value: Any = Field(..., description="The value found in that field")
    note: str = Field("", description="Any contextual note about the value")
    timestamp: float = Field(
        default_factory=time.time,
        description="When the inspection happened",
    )


class CheckResult(BaseModel):
    """What came back from running a validation check or cross-check."""

    check_name: str = Field(..., description="Name of the check that was run")
    passed: bool = Field(
        ...,
        description="Whether the check passed (True) or failed (False)",
    )
    detail: str = Field("", description="Human-readable detail of what was found")
    timestamp: float = Field(
        default_factory=time.time,
        description="When the check was run",
    )


class QueryResult(BaseModel):
    """What came back from querying a supplier or internal department."""

    target: str = Field(
        ...,
        description="Who was queried (supplier, procurement, finance, etc.)",
    )
    question: str = Field("", description="The question that was asked")
    response: str = Field(..., description="The response received")
    channel: str = Field(
        "email",
        description="Communication channel used (email, phone, etc.)",
    )
    timestamp: float = Field(
        default_factory=time.time,
        description="When the query was made",
    )


# ===========================================================================
# State models
# ===========================================================================


class EnvironmentState(BaseModel):
    """
    The full observable state returned by reset() and step().

    This is what the agent sees at every turn — all documents, all history,
    and all available actions/checks/rules for the current task.
    """

    task_id: str = Field(..., description="Which task is currently running")
    step_number: int = Field(0, description="Current step number in the episode")
    case_status: CaseStatus = Field(
        CaseStatus.OPEN,
        description="Current lifecycle status",
    )

    # -- The five documents under review (all required) ---------------------
    purchase_order: PurchaseOrder = Field(..., description="The purchase order")
    invoice: Invoice = Field(..., description="The invoice under review")
    grn: GoodsReceiptNote = Field(..., description="The goods receipt note")
    supplier_master: SupplierMaster = Field(
        ...,
        description="The verified supplier record",
    )
    exception_flag: ExceptionFlag = Field(
        ...,
        description="Why this invoice was flagged",
    )

    # -- Agent history: everything done so far in the episode ---------------
    inspections: List[InspectionResult] = Field(
        default_factory=list,
        description="Fields inspected",
    )
    checks_run: List[CheckResult] = Field(
        default_factory=list,
        description="Checks completed",
    )
    queries: List[QueryResult] = Field(
        default_factory=list,
        description="Queries made",
    )
    rules_applied: List[str] = Field(
        default_factory=list,
        description="Rules applied",
    )

    # -- Decision state ------------------------------------------------------
    # ``decision`` is a plain optional string here, not a DecisionType.
    decision: Optional[str] = Field(
        None,
        description="Current decision if one has been made",
    )
    decision_reason: Optional[str] = Field(
        None,
        description="Reason for the decision",
    )
    routed_to: List[str] = Field(
        default_factory=list,
        description="Teams case has been routed to",
    )
    case_closed: bool = Field(False, description="Whether the case has been closed")
    close_summary: Optional[str] = Field(
        None,
        description="Closure summary if case is closed",
    )

    # -- Action hints: what the agent may do next ---------------------------
    available_actions: List[str] = Field(
        default_factory=list,
        description="All valid action types",
    )
    available_checks: List[str] = Field(
        default_factory=list,
        description="Check names for this task",
    )
    available_rules: List[str] = Field(
        default_factory=list,
        description="Rule IDs for this task",
    )
    knowledge_base: List[str] = Field(
        default_factory=list,
        description="Policy entries for this task",
    )

    # -- Running totals ------------------------------------------------------
    cumulative_reward: float = Field(
        0.0,
        description="Sum of all rewards received so far",
    )


class StepResult(BaseModel):
    """What step() returns — the observation, reward, done flag, and info dict."""

    observation: EnvironmentState = Field(
        ...,
        description="Updated environment state after the action",
    )
    reward: float = Field(..., description="Reward for this specific action")
    done: bool = Field(False, description="Whether the episode is over")
    info: Dict[str, Any] = Field(
        default_factory=dict,
        description="Extra info about the step",
    )