""" InvoiceExceptionEnv — the main environment class. This is the only class external code needs to import. It wraps the task registry, dispatches actions, manages episode state, and provides the OpenEnv-compatible API: reset(), step(), state(), grade(). """ from __future__ import annotations import random from typing import Any, Dict, List, Optional, Union from .models import ( Action, ActionType, CaseStatus, EnvironmentState, StepResult, ) from .tasks import ALL_TASKS, BaseTask, EpisodeData, make_task class InvoiceExceptionEnv: """ OpenEnv-compatible Invoice Exception Handler environment. Usage: env = InvoiceExceptionEnv(seed=42) obs = env.reset("task1_price_variance") result = env.step(Action.run_check("tolerance_rule")) scores = env.grade() """ def __init__(self, seed: Optional[int] = None) -> None: """Initialise with an optional seed for reproducibility.""" self._rng = random.Random(seed) self._task: Optional[BaseTask] = None self._ep: Optional[EpisodeData] = None self._state_cache: Optional[EnvironmentState] = None self._done: bool = False # ------------------------------------------------------------------ # Public API # ------------------------------------------------------------------ def reset(self, task_id: Optional[str] = None) -> EnvironmentState: """ Start a new episode. If task_id is None, picks one at random. Returns the initial EnvironmentState showing all documents and available actions. """ if task_id is None: task_id = self._rng.choice(ALL_TASKS) self._task = make_task(task_id) self._ep = EpisodeData() self._done = False self._state_cache = self._build_state() return self._state_cache def step(self, action: Union[Action, Dict[str, Any]]) -> StepResult: """ Execute one action. Returns observation, reward, done flag, and info dict. Raises RuntimeError if called before reset() or after the episode is done. """ if self._task is None or self._ep is None: raise RuntimeError("Call reset() before step().") if self._done: raise RuntimeError("Episode is done. Call reset() to start a new one.") # Convert dict to Action if needed if isinstance(action, dict): action = Action( type=ActionType(action.get("type", action.get("action_type", ""))), params=action.get("params", {}), ) # Dispatch the action reward, info = self._dispatch(action) # Update episode self._ep.step_count += 1 self._ep.cumulative_reward += reward # Check SLA breach sla_penalty = 0.0 if self._ep.step_count >= self._task.max_steps: sla_penalty = -0.10 self._done = True info["sla_breach"] = True # Check done conditions if self._ep.closed: self._done = True total_reward = reward + sla_penalty self._ep.cumulative_reward += sla_penalty # add SLA penalty separately # Rebuild state self._state_cache = self._build_state() return StepResult( observation=self._state_cache, reward=round(total_reward, 4), done=self._done, info=info, ) def state(self) -> EnvironmentState: """Return the current state without advancing the episode.""" if self._state_cache is None: raise RuntimeError("Call reset() before state().") return self._state_cache def grade(self) -> Dict[str, float]: """Run the task grader on the current episode and return scores.""" if self._task is None or self._ep is None: raise RuntimeError("Call reset() before grade().") return self._task.grade(self._ep) def action_space_sample(self) -> Action: """Return a random valid action for baseline/testing purposes.""" if self._task is None: raise RuntimeError("Call reset() before action_space_sample().") action_type = self._rng.choice(list(ActionType)) if action_type == ActionType.INSPECT_FIELD: doc = self._rng.choice(["invoice", "po", "grn", "supplier_master"]) field = self._rng.choice(["line_items", "total_amount", "bank_account", "supplier_gstin", "items_received"]) return Action.inspect_field(doc, field) elif action_type == ActionType.CROSS_CHECK: field = self._rng.choice(["unit_price", "total_amount", "bank_account", "gstin", "quantity"]) doc_a = self._rng.choice(["invoice", "po"]) doc_b = self._rng.choice(["po", "grn", "supplier_master"]) return Action.cross_check(field, doc_a, doc_b) elif action_type == ActionType.RUN_CHECK: check = self._rng.choice(self._task.available_checks) return Action.run_check(check) elif action_type == ActionType.QUERY_SUPPLIER: channel = self._rng.choice(["email", "phone"]) return Action.query_supplier("What is the status?", channel) elif action_type == ActionType.QUERY_INTERNAL: dept = self._rng.choice(["procurement", "finance", "legal", "security"]) return Action.query_internal(dept, "Can you provide information?") elif action_type == ActionType.APPLY_RULE: rule = self._rng.choice(self._task.available_rules) return Action.apply_rule(rule) elif action_type == ActionType.MAKE_DECISION: decision = self._rng.choice(["approve", "reject", "hold", "partial_approve"]) return Action.make_decision(decision, "Random baseline decision.") elif action_type == ActionType.ROUTE_TO: team = self._rng.choice(["procurement", "finance", "legal", "security"]) return Action.route_to(team, "Random baseline routing.") elif action_type == ActionType.CLOSE_CASE: return Action.close_case("Random baseline closure.") # Fallback return Action.run_check(self._task.available_checks[0]) # ------------------------------------------------------------------ # Internal methods # ------------------------------------------------------------------ def _dispatch(self, action: Action) -> tuple: """ Route an action to the appropriate task simulator. Returns (reward, info dict). Handles repeat-action penalties. """ params = action.params info: Dict[str, Any] = {"action_type": action.type.value} if action.type == ActionType.INSPECT_FIELD: doc = params.get("document", "") field = params.get("field", "") # Repeat penalty if self._ep.has_inspected(doc, field): info["repeat"] = True return -0.02, info result, reward = self._task.simulate_inspect(doc, field) self._ep.inspections.append(result) info["result"] = result.model_dump() return reward, info elif action.type == ActionType.CROSS_CHECK: field = params.get("field", "") doc_a = params.get("doc_a", "") doc_b = params.get("doc_b", "") check_key = f"cross_{field}_{doc_a}_{doc_b}" if self._ep.has_checked(check_key): info["repeat"] = True return -0.03, info result, reward = self._task.simulate_cross_check(field, doc_a, doc_b) self._ep.checks.append(result) info["result"] = result.model_dump() return reward, info elif action.type == ActionType.RUN_CHECK: check_name = params.get("check_name", "") if self._ep.has_checked(check_name): info["repeat"] = True return -0.03, info result, reward = self._task.simulate_run_check(check_name) self._ep.checks.append(result) info["result"] = result.model_dump() return reward, info elif action.type == ActionType.QUERY_SUPPLIER: question = params.get("question", "") channel = params.get("channel", "email") if self._ep.has_queried("supplier"): info["repeat"] = True return -0.05, info result, reward = self._task.simulate_query_supplier(question, channel) self._ep.queries.append(result) info["result"] = result.model_dump() return reward, info elif action.type == ActionType.QUERY_INTERNAL: department = params.get("department", "") question = params.get("question", "") if self._ep.has_queried(department.lower()): info["repeat"] = True return -0.03, info result, reward = self._task.simulate_query_internal(department, question) self._ep.queries.append(result) info["result"] = result.model_dump() return reward, info elif action.type == ActionType.APPLY_RULE: rule_id = params.get("rule_id", "") if rule_id in self._ep.rules_applied: info["repeat"] = True return -0.03, info detail, reward = self._task.simulate_apply_rule(rule_id) self._ep.rules_applied.append(rule_id) info["detail"] = detail return reward, info elif action.type == ActionType.MAKE_DECISION: decision = params.get("decision", "") reason = params.get("reason", "") if self._ep.decision is not None: info["repeat"] = True return -0.05, info reward = self._task.simulate_make_decision(decision, reason, self._ep) self._ep.decision = decision self._ep.decision_reason = reason info["decision"] = decision return reward, info elif action.type == ActionType.ROUTE_TO: team = params.get("team", "") notes = params.get("notes", "") if team.lower() in self._ep.routed_to: info["repeat"] = True return -0.02, info reward = self._task.simulate_route_to(team, notes, self._ep) self._ep.routed_to.append(team.lower()) info["routed_to"] = team return reward, info elif action.type == ActionType.CLOSE_CASE: summary = params.get("summary", "") if self._ep.closed: info["repeat"] = True return -0.05, info reward = self._task.simulate_close(summary, self._ep) self._ep.closed = True self._ep.close_summary = summary info["closed"] = True return reward, info # Unknown action type return 0.0, {"error": f"Unknown action type: {action.type}"} def _build_state(self) -> EnvironmentState: """Construct an EnvironmentState from current task and episode data.""" # Determine case status if self._ep.closed: status = CaseStatus.CLOSED elif self._ep.routed_to: status = CaseStatus.ROUTED elif self._ep.decision is not None: status = CaseStatus.DECIDED elif self._ep.step_count > 0: status = CaseStatus.IN_REVIEW else: status = CaseStatus.OPEN return EnvironmentState( task_id=self._task.task_id, step_number=self._ep.step_count, case_status=status, purchase_order=self._task.get_purchase_order(), invoice=self._task.get_invoice(), grn=self._task.get_grn(), supplier_master=self._task.get_supplier_master(), exception_flag=self._task.get_exception_flag(), inspections=list(self._ep.inspections), checks_run=list(self._ep.checks), queries=list(self._ep.queries), rules_applied=list(self._ep.rules_applied), decision=self._ep.decision, decision_reason=self._ep.decision_reason, routed_to=list(self._ep.routed_to), case_closed=self._ep.closed, close_summary=self._ep.close_summary, available_actions=[at.value for at in ActionType], available_checks=self._task.available_checks, available_rules=self._task.available_rules, knowledge_base=self._task.knowledge_base, cumulative_reward=round(self._ep.cumulative_reward, 4), )