"""
InvoiceExceptionEnv — the main environment class.

This is the only class external code needs to import. It wraps the task
registry, dispatches actions, manages episode state, and provides the
OpenEnv-compatible API: reset(), step(), state(), grade().
"""
from __future__ import annotations

import random
from typing import Any, Dict, List, Optional, Tuple, Union

from .models import (
    Action, ActionType, CaseStatus, EnvironmentState, StepResult,
)
from .tasks import ALL_TASKS, BaseTask, EpisodeData, make_task


class InvoiceExceptionEnv:
    """
    OpenEnv-compatible Invoice Exception Handler environment.

    Usage:
        env = InvoiceExceptionEnv(seed=42)
        obs = env.reset("task1_price_variance")
        result = env.step(Action.run_check("tolerance_rule"))
        scores = env.grade()
    """

    def __init__(self, seed: Optional[int] = None) -> None:
        """
        Initialise with an optional seed for reproducibility.

        Args:
            seed: Seed for the private ``random.Random`` used by ``reset()``
                (random task selection) and ``action_space_sample()``.
        """
        self._rng = random.Random(seed)
        self._task: Optional[BaseTask] = None
        self._ep: Optional[EpisodeData] = None
        self._state_cache: Optional[EnvironmentState] = None
        self._done: bool = False

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def reset(self, task_id: Optional[str] = None) -> EnvironmentState:
        """
        Start a new episode.

        Args:
            task_id: Task to load. If None, one of ``ALL_TASKS`` is picked
                with the environment's seeded RNG.

        Returns:
            The initial EnvironmentState showing all documents and
            available actions.
        """
        if task_id is None:
            task_id = self._rng.choice(ALL_TASKS)
        self._task = make_task(task_id)
        self._ep = EpisodeData()
        self._done = False
        self._state_cache = self._build_state()
        return self._state_cache

    def step(self, action: Union[Action, Dict[str, Any]]) -> StepResult:
        """
        Execute one action and advance the episode by one step.

        Args:
            action: An ``Action``, or a dict with a ``"type"`` key (falling
                back to ``"action_type"``) and an optional ``"params"`` dict.

        Returns:
            StepResult with the new observation, this step's reward
            (including any SLA penalty, rounded to 4 places), the done flag
            and an info dict (``info["sla_breach"]`` is set when the step
            budget ran out with the case still open).

        Raises:
            RuntimeError: If called before ``reset()`` or after the episode
                is done.
            ValueError: If a dict action's type is not a valid ActionType
                value.
        """
        if self._task is None or self._ep is None:
            raise RuntimeError("Call reset() before step().")
        if self._done:
            raise RuntimeError("Episode is done. Call reset() to start a new one.")

        # Convert dict to Action if needed.
        if isinstance(action, dict):
            action = Action(
                type=ActionType(action.get("type", action.get("action_type", ""))),
                # JSON clients may send "params": null; treat it as no params.
                params=action.get("params") or {},
            )

        reward, info = self._dispatch(action)

        self._ep.step_count += 1
        self._ep.cumulative_reward += reward

        # The SLA is breached only when the step budget is exhausted while the
        # case is still open; closing on the final allowed step is on time.
        sla_penalty = 0.0
        if self._ep.closed:
            self._done = True
        elif self._ep.step_count >= self._task.max_steps:
            sla_penalty = -0.10
            self._done = True
            info["sla_breach"] = True

        total_reward = reward + sla_penalty
        self._ep.cumulative_reward += sla_penalty  # add SLA penalty separately

        self._state_cache = self._build_state()
        return StepResult(
            observation=self._state_cache,
            reward=round(total_reward, 4),
            done=self._done,
            info=info,
        )

    def state(self) -> EnvironmentState:
        """
        Return the current state without advancing the episode.

        Raises:
            RuntimeError: If called before ``reset()``.
        """
        if self._state_cache is None:
            raise RuntimeError("Call reset() before state().")
        return self._state_cache

    def grade(self) -> Dict[str, float]:
        """
        Run the task grader on the current episode and return its scores.

        Raises:
            RuntimeError: If called before ``reset()``.
        """
        if self._task is None or self._ep is None:
            raise RuntimeError("Call reset() before grade().")
        return self._task.grade(self._ep)

    def action_space_sample(self) -> Action:
        """
        Return a random valid action for baseline/testing purposes.

        Action types that need a non-empty catalogue (``RUN_CHECK`` needs
        ``available_checks``, ``APPLY_RULE`` needs ``available_rules``) are
        skipped when that catalogue is empty.

        Raises:
            RuntimeError: If called before ``reset()``.
        """
        if self._task is None:
            raise RuntimeError("Call reset() before action_space_sample().")

        action_types = list(ActionType)
        # When both catalogues are populated this list is unchanged, so
        # seeded draws stay identical to sampling over the full enum.
        if not self._task.available_checks:
            action_types.remove(ActionType.RUN_CHECK)
        if not self._task.available_rules:
            action_types.remove(ActionType.APPLY_RULE)
        action_type = self._rng.choice(action_types)

        if action_type == ActionType.INSPECT_FIELD:
            doc = self._rng.choice(["invoice", "po", "grn", "supplier_master"])
            field = self._rng.choice(["line_items", "total_amount", "bank_account",
                                      "supplier_gstin", "items_received"])
            return Action.inspect_field(doc, field)
        elif action_type == ActionType.CROSS_CHECK:
            field = self._rng.choice(["unit_price", "total_amount", "bank_account",
                                      "gstin", "quantity"])
            doc_a = self._rng.choice(["invoice", "po"])
            doc_b = self._rng.choice(["po", "grn", "supplier_master"])
            return Action.cross_check(field, doc_a, doc_b)
        elif action_type == ActionType.RUN_CHECK:
            check = self._rng.choice(self._task.available_checks)
            return Action.run_check(check)
        elif action_type == ActionType.QUERY_SUPPLIER:
            channel = self._rng.choice(["email", "phone"])
            return Action.query_supplier("What is the status?", channel)
        elif action_type == ActionType.QUERY_INTERNAL:
            dept = self._rng.choice(["procurement", "finance", "legal", "security"])
            return Action.query_internal(dept, "Can you provide information?")
        elif action_type == ActionType.APPLY_RULE:
            rule = self._rng.choice(self._task.available_rules)
            return Action.apply_rule(rule)
        elif action_type == ActionType.MAKE_DECISION:
            decision = self._rng.choice(["approve", "reject", "hold", "partial_approve"])
            return Action.make_decision(decision, "Random baseline decision.")
        elif action_type == ActionType.ROUTE_TO:
            team = self._rng.choice(["procurement", "finance", "legal", "security"])
            return Action.route_to(team, "Random baseline routing.")
        elif action_type == ActionType.CLOSE_CASE:
            return Action.close_case("Random baseline closure.")

        # Fallback for any ActionType member not handled above.
        if self._task.available_checks:
            return Action.run_check(self._task.available_checks[0])
        return Action.close_case("Random baseline closure.")

    # ------------------------------------------------------------------
    # Internal methods
    # ------------------------------------------------------------------

    def _dispatch(self, action: Action) -> Tuple[float, Dict[str, Any]]:
        """
        Route an action to the matching task simulator.

        An action that repeats earlier work is not re-simulated: it returns a
        small negative reward and sets ``info["repeat"] = True``.

        Args:
            action: The action to execute; ``reset()`` must have been called.

        Returns:
            ``(reward, info)``. ``info`` always carries ``"action_type"`` plus
            action-specific keys (``"result"``, ``"detail"``, ``"decision"``,
            ``"routed_to"``, ``"closed"`` or ``"error"``).
        """
        params = action.params
        info: Dict[str, Any] = {"action_type": action.type.value}

        if action.type == ActionType.INSPECT_FIELD:
            doc = params.get("document", "")
            field = params.get("field", "")
            if self._ep.has_inspected(doc, field):
                info["repeat"] = True
                return -0.02, info
            result, reward = self._task.simulate_inspect(doc, field)
            self._ep.inspections.append(result)
            info["result"] = result.model_dump()
            return reward, info

        elif action.type == ActionType.CROSS_CHECK:
            field = params.get("field", "")
            doc_a = params.get("doc_a", "")
            doc_b = params.get("doc_b", "")
            # NOTE(review): repeat detection assumes simulate_cross_check
            # records its result under this same key — confirm in tasks.
            check_key = f"cross_{field}_{doc_a}_{doc_b}"
            if self._ep.has_checked(check_key):
                info["repeat"] = True
                return -0.03, info
            result, reward = self._task.simulate_cross_check(field, doc_a, doc_b)
            self._ep.checks.append(result)
            info["result"] = result.model_dump()
            return reward, info

        elif action.type == ActionType.RUN_CHECK:
            check_name = params.get("check_name", "")
            if self._ep.has_checked(check_name):
                info["repeat"] = True
                return -0.03, info
            result, reward = self._task.simulate_run_check(check_name)
            self._ep.checks.append(result)
            info["result"] = result.model_dump()
            return reward, info

        elif action.type == ActionType.QUERY_SUPPLIER:
            question = params.get("question", "")
            channel = params.get("channel", "email")
            if self._ep.has_queried("supplier"):
                info["repeat"] = True
                return -0.05, info
            result, reward = self._task.simulate_query_supplier(question, channel)
            self._ep.queries.append(result)
            info["result"] = result.model_dump()
            return reward, info

        elif action.type == ActionType.QUERY_INTERNAL:
            department = params.get("department", "")
            question = params.get("question", "")
            if self._ep.has_queried(department.lower()):
                info["repeat"] = True
                return -0.03, info
            result, reward = self._task.simulate_query_internal(department, question)
            self._ep.queries.append(result)
            info["result"] = result.model_dump()
            return reward, info

        elif action.type == ActionType.APPLY_RULE:
            rule_id = params.get("rule_id", "")
            if rule_id in self._ep.rules_applied:
                info["repeat"] = True
                return -0.03, info
            detail, reward = self._task.simulate_apply_rule(rule_id)
            self._ep.rules_applied.append(rule_id)
            info["detail"] = detail
            return reward, info

        elif action.type == ActionType.MAKE_DECISION:
            decision = params.get("decision", "")
            reason = params.get("reason", "")
            # Only the first decision counts; later ones are penalised repeats.
            if self._ep.decision is not None:
                info["repeat"] = True
                return -0.05, info
            reward = self._task.simulate_make_decision(decision, reason, self._ep)
            self._ep.decision = decision
            self._ep.decision_reason = reason
            info["decision"] = decision
            return reward, info

        elif action.type == ActionType.ROUTE_TO:
            team = params.get("team", "")
            notes = params.get("notes", "")
            # Teams are compared and stored lower-cased.
            if team.lower() in self._ep.routed_to:
                info["repeat"] = True
                return -0.02, info
            reward = self._task.simulate_route_to(team, notes, self._ep)
            self._ep.routed_to.append(team.lower())
            info["routed_to"] = team
            return reward, info

        elif action.type == ActionType.CLOSE_CASE:
            summary = params.get("summary", "")
            if self._ep.closed:
                info["repeat"] = True
                return -0.05, info
            reward = self._task.simulate_close(summary, self._ep)
            self._ep.closed = True
            self._ep.close_summary = summary
            info["closed"] = True
            return reward, info

        # Unknown action type: keep the same info shape as every other branch.
        info["error"] = f"Unknown action type: {action.type}"
        return 0.0, info

    def _build_state(self) -> EnvironmentState:
        """
        Build an EnvironmentState snapshot from the current task and episode.

        Case status precedence: closed > routed > decided > in review (at
        least one step taken) > open. List fields are passed as shallow
        copies, so later appends to the episode or edits to the task's
        catalogues are not reflected in an already-returned snapshot.
        """
        if self._ep.closed:
            status = CaseStatus.CLOSED
        elif self._ep.routed_to:
            status = CaseStatus.ROUTED
        elif self._ep.decision is not None:
            status = CaseStatus.DECIDED
        elif self._ep.step_count > 0:
            status = CaseStatus.IN_REVIEW
        else:
            status = CaseStatus.OPEN

        return EnvironmentState(
            task_id=self._task.task_id,
            step_number=self._ep.step_count,
            case_status=status,
            purchase_order=self._task.get_purchase_order(),
            invoice=self._task.get_invoice(),
            grn=self._task.get_grn(),
            supplier_master=self._task.get_supplier_master(),
            exception_flag=self._task.get_exception_flag(),
            inspections=list(self._ep.inspections),
            checks_run=list(self._ep.checks),
            queries=list(self._ep.queries),
            rules_applied=list(self._ep.rules_applied),
            decision=self._ep.decision,
            decision_reason=self._ep.decision_reason,
            routed_to=list(self._ep.routed_to),
            case_closed=self._ep.closed,
            close_summary=self._ep.close_summary,
            available_actions=[at.value for at in ActionType],
            available_checks=list(self._task.available_checks),
            available_rules=list(self._task.available_rules),
            knowledge_base=self._task.knowledge_base,
            cumulative_reward=round(self._ep.cumulative_reward, 4),
        )