Spaces:
Sleeping
Sleeping
"""
Typed models for the Invoice Exception Handler OpenEnv environment.
Every object the agent sees or produces is defined here as a Pydantic model.
This is the single source of truth for the data contract between the
environment simulation and the agent.
"""
| from __future__ import annotations | |
| import time | |
| from enum import Enum | |
| from typing import Any, Dict, List, Optional | |
| from pydantic import BaseModel, Field | |
# ---------------------------------------------------------------------------
# Enumerations
# ---------------------------------------------------------------------------
class ActionType(str, Enum):
    """The nine action types an agent can take during an episode."""

    # Evidence gathering: look at, compare, or validate document data.
    INSPECT_FIELD = "inspect_field"
    CROSS_CHECK = "cross_check"
    RUN_CHECK = "run_check"

    # Asking questions of the supplier or an internal department.
    QUERY_SUPPLIER = "query_supplier"
    QUERY_INTERNAL = "query_internal"

    # Resolving the case: policy, decision, escalation, closure.
    APPLY_RULE = "apply_rule"
    MAKE_DECISION = "make_decision"
    ROUTE_TO = "route_to"
    CLOSE_CASE = "close_case"
class DecisionType(str, Enum):
    """Possible decisions the agent can make on a flagged invoice."""

    # Outright outcomes.
    APPROVE = "approve"
    REJECT = "reject"
    # Deferred or partial outcomes.
    HOLD = "hold"
    PARTIAL_APPROVE = "partial_approve"
class CaseStatus(str, Enum):
    """Lifecycle status of an invoice exception case."""

    # Members are listed in lifecycle order, from a fresh case to a closed one.
    OPEN = "open"
    IN_REVIEW = "in_review"
    DECIDED = "decided"
    ROUTED = "routed"
    CLOSED = "closed"
# ---------------------------------------------------------------------------
# Document models — read-only context given to the agent
# ---------------------------------------------------------------------------
class LineItem(BaseModel):
    """One line on an invoice or purchase order."""

    # What is being bought, and how many units of it.
    description: str = Field(default=..., description="Item description")
    quantity: int = Field(default=..., description="Number of units")

    # Monetary amounts (the field descriptions state INR).
    unit_price: float = Field(default=..., description="Price per unit in INR")
    # NOTE(review): `total` is not validated against quantity * unit_price here.
    total: float = Field(
        default=...,
        description="Line total in INR (quantity × unit_price)",
    )

    # Optional per-line tax rate; defaults to None when not supplied.
    tax_rate: Optional[float] = Field(
        default=None,
        description="Tax rate as a percentage, if applicable",
    )
class PurchaseOrder(BaseModel):
    """What was agreed to be purchased."""

    # Identity of the order. Dates are plain strings; the format is not validated here.
    po_number: str = Field(default=..., description="Unique PO identifier")
    vendor_name: str = Field(default=..., description="Supplier name on the PO")
    po_date: str = Field(default=..., description="Date the PO was raised (YYYY-MM-DD)")

    # Contents and overall value of the order.
    line_items: List[LineItem] = Field(default_factory=list, description="Items on the PO")
    total_amount: float = Field(default=..., description="Total PO value in INR")

    # Commercial terms, with their defaults.
    payment_terms: str = Field(default="Net-30", description="Payment terms")
    currency: str = Field(default="INR", description="Currency code")
class Invoice(BaseModel):
    """What the supplier is claiming — the document under exception review."""

    # Header: identity, dates, and the PO this invoice claims against.
    invoice_number: str = Field(default=..., description="Unique invoice identifier")
    supplier_name: str = Field(default=..., description="Supplier name on the invoice")
    invoice_date: str = Field(default=..., description="Date of the invoice (YYYY-MM-DD)")
    due_date: str = Field(default=..., description="Payment due date (YYYY-MM-DD)")
    po_reference: str = Field(default=..., description="PO number referenced by this invoice")

    # Billed lines and the resulting money totals.
    line_items: List[LineItem] = Field(default_factory=list, description="Items invoiced")
    subtotal: float = Field(default=..., description="Pre-tax total in INR")
    tax_amount: float = Field(default=..., description="Total tax amount in INR")
    tax_rate: float = Field(default=..., description="Applied tax rate as a percentage")
    total_amount: float = Field(default=..., description="Grand total including tax in INR")

    # Payment and identity details exactly as printed on the invoice; only
    # bank_account is mandatory, the rest default to empty strings.
    bank_account: str = Field(default=..., description="Supplier bank account on the invoice")
    bank_name: str = Field(default="", description="Bank name")
    ifsc_code: str = Field(default="", description="IFSC / routing code")
    supplier_gstin: str = Field(default="", description="GST Identification Number on the invoice")
    supplier_email: str = Field(default="", description="Email address on the invoice")
    currency: str = Field(default="INR", description="Currency code")
class GoodsReceiptNote(BaseModel):
    """What actually arrived at the warehouse (or service confirmation)."""

    # Identity of the receipt and the PO it is recorded against.
    grn_number: str = Field(default=..., description="Unique GRN identifier")
    po_reference: str = Field(default=..., description="PO number this receipt is against")
    receipt_date: str = Field(
        default=...,
        description="Date goods/services were received (YYYY-MM-DD)",
    )

    # Received items are untyped dicts; the expected keys are only documented
    # in the description below and are not validated by this model.
    items_received: List[Dict[str, Any]] = Field(
        default_factory=list,
        description="List of received item dicts with description, quantity_received, quantity_pending, quantity_rejected",
    )

    # Free-text sign-off and remarks.
    receiving_officer: str = Field(default="", description="Person who signed the receipt")
    notes: str = Field(default="", description="Any delivery notes or discrepancies observed")
class SupplierMaster(BaseModel):
    """The verified, registered supplier record in the company's ERP system."""

    # Registered identity.
    supplier_id: str = Field(default=..., description="Internal supplier code")
    supplier_name: str = Field(default=..., description="Registered legal name")
    registered_address: str = Field(default="", description="Registered business address")
    gstin: str = Field(default=..., description="Verified GST Identification Number")

    # Verified banking details (bank_account is mandatory; others default to "").
    bank_account: str = Field(default=..., description="Verified bank account number")
    bank_name: str = Field(default="", description="Bank name")
    ifsc_code: str = Field(default="", description="Verified IFSC / routing code")

    # Registered contact channels.
    contact_email: str = Field(default="", description="Registered email address")
    contact_phone: str = Field(default="", description="Registered phone number")
    registered_domain: str = Field(
        default="",
        description="Verified email domain for the supplier",
    )

    # Tax ID and account standing. `status` is a free string; the allowed
    # values are listed only in its description and are not enforced here.
    pan_number: str = Field(default="", description="PAN (tax ID)")
    status: str = Field(
        default="active",
        description="Supplier status: active, suspended, blacklisted",
    )
class ExceptionFlag(BaseModel):
    """Why the AP system flagged this invoice for manual review."""

    # The flag itself: a machine code plus a human-readable explanation.
    flag_code: str = Field(
        default=...,
        description="Machine-readable code, e.g. PRICE_MISMATCH",
    )
    flag_description: str = Field(
        default=...,
        description="Human-readable explanation of the flag",
    )

    # Metadata about how and when it was raised. `severity` is a free string;
    # the levels are listed only in its description.
    auto_hold: bool = Field(
        default=False,
        description="Whether the system placed an automatic payment hold",
    )
    flagged_date: str = Field(default="", description="Date the flag was raised (YYYY-MM-DD)")
    severity: str = Field(default="medium", description="low / medium / high / critical")
# ---------------------------------------------------------------------------
# Action model
# ---------------------------------------------------------------------------
class Action(BaseModel):
    """
    An action the agent wants to take.
    Use the classmethod constructors for convenience:
        Action.run_check("tolerance_rule")
        Action.make_decision("approve", "reason here")
    """

    type: ActionType = Field(..., description="Which action type to execute")
    params: Dict[str, Any] = Field(default_factory=dict, description="Parameters for the action")

    # --- Classmethod constructors for each action type ---
    # Each constructor MUST be a @classmethod: the documented usage calls them
    # on the class (``Action.run_check(...)``). Without the decorator the first
    # positional argument is bound to ``cls`` and the call fails.

    @classmethod
    def inspect_field(cls, document: str, field: str) -> Action:
        """Look at a specific field in a document.

        Returns an INSPECT_FIELD action with params ``{"document", "field"}``.
        """
        return cls(
            type=ActionType.INSPECT_FIELD,
            params={"document": document, "field": field},
        )

    @classmethod
    def cross_check(cls, field: str, doc_a: str, doc_b: str) -> Action:
        """Compare a field between two documents.

        Returns a CROSS_CHECK action with params ``{"field", "doc_a", "doc_b"}``.
        """
        return cls(
            type=ActionType.CROSS_CHECK,
            params={"field": field, "doc_a": doc_a, "doc_b": doc_b},
        )

    @classmethod
    def run_check(cls, check_name: str) -> Action:
        """Run a named validation check.

        Returns a RUN_CHECK action with params ``{"check_name"}``.
        """
        return cls(type=ActionType.RUN_CHECK, params={"check_name": check_name})

    @classmethod
    def query_supplier(cls, question: str, channel: str = "email") -> Action:
        """Ask the supplier a question via a specific channel.

        ``channel`` defaults to ``"email"``. Returns a QUERY_SUPPLIER action
        with params ``{"question", "channel"}``.
        """
        return cls(
            type=ActionType.QUERY_SUPPLIER,
            params={"question": question, "channel": channel},
        )

    @classmethod
    def query_internal(cls, department: str, question: str) -> Action:
        """Ask an internal department a question.

        Returns a QUERY_INTERNAL action with params ``{"department", "question"}``.
        """
        return cls(
            type=ActionType.QUERY_INTERNAL,
            params={"department": department, "question": question},
        )

    @classmethod
    def apply_rule(cls, rule_id: str) -> Action:
        """Apply a named business policy rule.

        Returns an APPLY_RULE action with params ``{"rule_id"}``.
        """
        return cls(type=ActionType.APPLY_RULE, params={"rule_id": rule_id})

    @classmethod
    def make_decision(cls, decision: str, reason: str) -> Action:
        """Make a case decision with a documented reason.

        ``decision`` is passed through unchanged (presumably one of the
        ``DecisionType`` values — not validated here). Returns a MAKE_DECISION
        action with params ``{"decision", "reason"}``.
        """
        return cls(
            type=ActionType.MAKE_DECISION,
            params={"decision": decision, "reason": reason},
        )

    @classmethod
    def route_to(cls, team: str, notes: str = "") -> Action:
        """Escalate the case to a specific team.

        ``notes`` defaults to an empty string. Returns a ROUTE_TO action with
        params ``{"team", "notes"}``.
        """
        return cls(type=ActionType.ROUTE_TO, params={"team": team, "notes": notes})

    @classmethod
    def close_case(cls, summary: str) -> Action:
        """Close the case with an audit trail summary.

        Returns a CLOSE_CASE action with params ``{"summary"}``.
        """
        return cls(type=ActionType.CLOSE_CASE, params={"summary": summary})
# ---------------------------------------------------------------------------
# Result models — returned by simulators
# ---------------------------------------------------------------------------
class InspectionResult(BaseModel):
    """What came back from inspecting a specific field in a document."""

    # Where the agent looked.
    document: str = Field(default=..., description="Which document was inspected")
    field: str = Field(default=..., description="Which field was inspected")

    # What it found; `value` is required but may be of any type.
    value: Any = Field(default=..., description="The value found in that field")
    note: str = Field(default="", description="Any contextual note about the value")

    # Wall-clock time (time.time()) captured when the model is instantiated.
    timestamp: float = Field(
        default_factory=time.time,
        description="When the inspection happened",
    )
class CheckResult(BaseModel):
    """What came back from running a validation check or cross-check."""

    # Which check ran and its pass/fail outcome.
    check_name: str = Field(default=..., description="Name of the check that was run")
    passed: bool = Field(
        default=...,
        description="Whether the check passed (True) or failed (False)",
    )

    # Optional explanation of the outcome.
    detail: str = Field(default="", description="Human-readable detail of what was found")

    # Wall-clock time (time.time()) captured when the model is instantiated.
    timestamp: float = Field(default_factory=time.time, description="When the check was run")
class QueryResult(BaseModel):
    """What came back from querying a supplier or internal department."""

    # The exchange: who was asked, what was asked, and the answer.
    target: str = Field(
        default=...,
        description="Who was queried (supplier, procurement, finance, etc.)",
    )
    question: str = Field(default="", description="The question that was asked")
    response: str = Field(default=..., description="The response received")

    # How the exchange happened; channel defaults to "email".
    channel: str = Field(
        default="email",
        description="Communication channel used (email, phone, etc.)",
    )

    # Wall-clock time (time.time()) captured when the model is instantiated.
    timestamp: float = Field(default_factory=time.time, description="When the query was made")
# ---------------------------------------------------------------------------
# State models
# ---------------------------------------------------------------------------
class EnvironmentState(BaseModel):
    """
    The full observable state returned by reset() and step().
    This is what the agent sees at every turn — all documents, all history,
    and all available actions/checks/rules for the current task.
    """

    # Episode bookkeeping.
    task_id: str = Field(default=..., description="Which task is currently running")
    step_number: int = Field(default=0, description="Current step number in the episode")
    case_status: CaseStatus = Field(
        default=CaseStatus.OPEN,
        description="Current lifecycle status",
    )

    # The five source documents (all required).
    purchase_order: PurchaseOrder = Field(default=..., description="The purchase order")
    invoice: Invoice = Field(default=..., description="The invoice under review")
    grn: GoodsReceiptNote = Field(default=..., description="The goods receipt note")
    supplier_master: SupplierMaster = Field(
        default=...,
        description="The verified supplier record",
    )
    exception_flag: ExceptionFlag = Field(
        default=...,
        description="Why this invoice was flagged",
    )

    # History of everything the agent has done so far; each starts empty.
    inspections: List[InspectionResult] = Field(
        default_factory=list,
        description="Fields inspected",
    )
    checks_run: List[CheckResult] = Field(default_factory=list, description="Checks completed")
    queries: List[QueryResult] = Field(default_factory=list, description="Queries made")
    rules_applied: List[str] = Field(default_factory=list, description="Rules applied")

    # Outcome of the case so far. `decision` is a plain string (presumably a
    # DecisionType value — not enforced here); None until a decision exists.
    decision: Optional[str] = Field(
        default=None,
        description="Current decision if one has been made",
    )
    decision_reason: Optional[str] = Field(default=None, description="Reason for the decision")
    routed_to: List[str] = Field(
        default_factory=list,
        description="Teams case has been routed to",
    )
    case_closed: bool = Field(default=False, description="Whether the case has been closed")
    close_summary: Optional[str] = Field(
        default=None,
        description="Closure summary if case is closed",
    )

    # Hints about what the agent may do next; all default to empty lists, so
    # the environment is expected to populate them.
    available_actions: List[str] = Field(
        default_factory=list,
        description="All valid action types",
    )
    available_checks: List[str] = Field(
        default_factory=list,
        description="Check names for this task",
    )
    available_rules: List[str] = Field(default_factory=list, description="Rule IDs for this task")
    knowledge_base: List[str] = Field(
        default_factory=list,
        description="Policy entries for this task",
    )

    # Running reward total.
    cumulative_reward: float = Field(
        default=0.0,
        description="Sum of all rewards received so far",
    )
class StepResult(BaseModel):
    """What step() returns — the observation, reward, done flag, and info dict."""

    # Required: the new state and the reward earned by this single action.
    observation: EnvironmentState = Field(
        default=...,
        description="Updated environment state after the action",
    )
    reward: float = Field(default=..., description="Reward for this specific action")

    # Optional: termination flag (defaults to False) and free-form extras.
    done: bool = Field(default=False, description="Whether the episode is over")
    info: Dict[str, Any] = Field(default_factory=dict, description="Extra info about the step")