invoice-exception-handler / env /environment.py
YUS200619's picture
feat: complete invoice exception handler v1.0.0
562f58d
"""
InvoiceExceptionEnv — the main environment class.
This is the only class external code needs to import. It wraps the task
registry, dispatches actions, manages episode state, and provides the
OpenEnv-compatible API: reset(), step(), state(), grade().
"""
from __future__ import annotations
import random
from typing import Any, Dict, List, Optional, Union
from .models import (
Action, ActionType, CaseStatus, EnvironmentState, StepResult,
)
from .tasks import ALL_TASKS, BaseTask, EpisodeData, make_task
class InvoiceExceptionEnv:
    """
    OpenEnv-compatible Invoice Exception Handler environment.

    Wraps a single task instance and the mutable data of the current
    episode. All interaction goes through reset(), step(), state() and
    grade(); action_space_sample() is a helper for random baselines.

    Usage:
        env = InvoiceExceptionEnv(seed=42)
        obs = env.reset("task1_price_variance")
        result = env.step(Action.run_check("tolerance_rule"))
        scores = env.grade()
    """

    def __init__(self, seed: Optional[int] = None) -> None:
        """Initialise with an optional seed for reproducibility.

        Args:
            seed: Seed for the private random generator used by reset()
                (random task choice) and action_space_sample().
        """
        self._rng = random.Random(seed)
        self._task: Optional[BaseTask] = None
        self._ep: Optional[EpisodeData] = None
        self._state_cache: Optional[EnvironmentState] = None
        self._done: bool = False

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def reset(self, task_id: Optional[str] = None) -> EnvironmentState:
        """
        Start a new episode. If task_id is None, picks one at random
        from ALL_TASKS using the environment's seeded generator.

        Args:
            task_id: Identifier handed to make_task(); None means random.

        Returns:
            The initial EnvironmentState showing all documents and
            available actions.
        """
        if task_id is None:
            task_id = self._rng.choice(ALL_TASKS)
        self._task = make_task(task_id)
        self._ep = EpisodeData()
        self._done = False
        self._state_cache = self._build_state()
        return self._state_cache

    def step(self, action: Union[Action, Dict[str, Any]]) -> StepResult:
        """
        Execute one action and advance the episode by one step.

        Args:
            action: An Action, or a dict with a "type" (or "action_type")
                key and an optional "params" mapping.

        Returns:
            StepResult with the rebuilt observation, the step reward
            (action reward plus any SLA penalty, rounded to 4 places),
            the done flag, and the info dict.

        Raises:
            RuntimeError: if called before reset() or after the episode
                is done.
            ValueError: if a dict action names a type that ActionType
                does not accept.
        """
        if self._task is None or self._ep is None:
            raise RuntimeError("Call reset() before step().")
        if self._done:
            raise RuntimeError("Episode is done. Call reset() to start a new one.")

        # Convert dict to Action if needed. A missing or explicit-None
        # "params" becomes an empty dict so handlers can always call .get().
        if isinstance(action, dict):
            raw_type = action.get("type", action.get("action_type", ""))
            action = Action(
                type=ActionType(raw_type),
                params=action.get("params") or {},
            )

        # Dispatch the action
        reward, info = self._dispatch(action)

        # Update episode
        self._ep.step_count += 1
        self._ep.cumulative_reward += reward

        # Check SLA breach: reaching the task's step budget ends the episode.
        # NOTE(review): the penalty also applies when the case is closed on
        # exactly the final allowed step — confirm this is intended.
        sla_penalty = 0.0
        if self._ep.step_count >= self._task.max_steps:
            sla_penalty = -0.10
            self._done = True
            info["sla_breach"] = True

        # Check done conditions
        if self._ep.closed:
            self._done = True

        total_reward = reward + sla_penalty
        self._ep.cumulative_reward += sla_penalty  # add SLA penalty separately

        # Rebuild state
        self._state_cache = self._build_state()
        return StepResult(
            observation=self._state_cache,
            reward=round(total_reward, 4),
            done=self._done,
            info=info,
        )

    def state(self) -> EnvironmentState:
        """Return the current state without advancing the episode.

        Raises:
            RuntimeError: if reset() has not been called yet.
        """
        if self._state_cache is None:
            raise RuntimeError("Call reset() before state().")
        return self._state_cache

    def grade(self) -> Dict[str, float]:
        """Run the task grader on the current episode and return scores.

        Raises:
            RuntimeError: if reset() has not been called yet.
        """
        if self._task is None or self._ep is None:
            raise RuntimeError("Call reset() before grade().")
        return self._task.grade(self._ep)

    def action_space_sample(self) -> Action:
        """Return a random valid action for baseline/testing purposes.

        Action types that cannot be instantiated for the current task
        (RUN_CHECK with no available checks, APPLY_RULE with no available
        rules) are excluded before sampling, so an empty list no longer
        raises IndexError. When both lists are non-empty the sequence of
        random draws is the same as sampling over every ActionType.

        Raises:
            RuntimeError: if reset() has not been called yet.
        """
        if self._task is None:
            raise RuntimeError("Call reset() before action_space_sample().")

        checks = self._task.available_checks
        rules = self._task.available_rules
        candidates = [
            at for at in ActionType
            if not (at == ActionType.RUN_CHECK and not checks)
            and not (at == ActionType.APPLY_RULE and not rules)
        ]
        action_type = self._rng.choice(candidates)

        if action_type == ActionType.INSPECT_FIELD:
            doc = self._rng.choice(["invoice", "po", "grn", "supplier_master"])
            field = self._rng.choice(["line_items", "total_amount", "bank_account",
                                      "supplier_gstin", "items_received"])
            return Action.inspect_field(doc, field)
        elif action_type == ActionType.CROSS_CHECK:
            field = self._rng.choice(["unit_price", "total_amount", "bank_account",
                                      "gstin", "quantity"])
            doc_a = self._rng.choice(["invoice", "po"])
            doc_b = self._rng.choice(["po", "grn", "supplier_master"])
            return Action.cross_check(field, doc_a, doc_b)
        elif action_type == ActionType.RUN_CHECK:
            check = self._rng.choice(checks)
            return Action.run_check(check)
        elif action_type == ActionType.QUERY_SUPPLIER:
            channel = self._rng.choice(["email", "phone"])
            return Action.query_supplier("What is the status?", channel)
        elif action_type == ActionType.QUERY_INTERNAL:
            dept = self._rng.choice(["procurement", "finance", "legal", "security"])
            return Action.query_internal(dept, "Can you provide information?")
        elif action_type == ActionType.APPLY_RULE:
            rule = self._rng.choice(rules)
            return Action.apply_rule(rule)
        elif action_type == ActionType.MAKE_DECISION:
            decision = self._rng.choice(["approve", "reject", "hold", "partial_approve"])
            return Action.make_decision(decision, "Random baseline decision.")
        elif action_type == ActionType.ROUTE_TO:
            team = self._rng.choice(["procurement", "finance", "legal", "security"])
            return Action.route_to(team, "Random baseline routing.")
        elif action_type == ActionType.CLOSE_CASE:
            return Action.close_case("Random baseline closure.")

        # Fallback for an ActionType member not handled above. Prefer the
        # first available check; with no checks, fall back to closing.
        if checks:
            return Action.run_check(checks[0])
        return Action.close_case("Random baseline closure.")

    # ------------------------------------------------------------------
    # Internal methods
    # ------------------------------------------------------------------
    def _dispatch(self, action: Action) -> tuple[float, Dict[str, Any]]:
        """
        Route an action to the handler for its type.

        Each handler applies a repeat-action penalty when the episode has
        already recorded the same action, otherwise it calls the matching
        task simulator and records the outcome on the episode.

        Returns:
            (reward, info dict). info always carries "action_type"; an
            unrecognised type additionally gets an "error" entry and a
            reward of 0.0.
        """
        params = action.params or {}
        info: Dict[str, Any] = {"action_type": action.type.value}

        # Built per call so bound methods capture this instance.
        handlers = {
            ActionType.INSPECT_FIELD: self._handle_inspect_field,
            ActionType.CROSS_CHECK: self._handle_cross_check,
            ActionType.RUN_CHECK: self._handle_run_check,
            ActionType.QUERY_SUPPLIER: self._handle_query_supplier,
            ActionType.QUERY_INTERNAL: self._handle_query_internal,
            ActionType.APPLY_RULE: self._handle_apply_rule,
            ActionType.MAKE_DECISION: self._handle_make_decision,
            ActionType.ROUTE_TO: self._handle_route_to,
            ActionType.CLOSE_CASE: self._handle_close_case,
        }
        handler = handlers.get(action.type)
        if handler is None:
            # Unknown action type
            info["error"] = f"Unknown action type: {action.type}"
            return 0.0, info
        return handler(params, info)

    @staticmethod
    def _repeat(info: Dict[str, Any], penalty: float) -> tuple[float, Dict[str, Any]]:
        """Mark info as a repeated action and return the given penalty."""
        info["repeat"] = True
        return penalty, info

    def _handle_inspect_field(
        self, params: Dict[str, Any], info: Dict[str, Any]
    ) -> tuple[float, Dict[str, Any]]:
        """Inspect one field of one document; repeats cost -0.02."""
        doc = params.get("document", "")
        field = params.get("field", "")
        if self._ep.has_inspected(doc, field):
            return self._repeat(info, -0.02)
        result, reward = self._task.simulate_inspect(doc, field)
        self._ep.inspections.append(result)
        info["result"] = result.model_dump()
        return reward, info

    def _handle_cross_check(
        self, params: Dict[str, Any], info: Dict[str, Any]
    ) -> tuple[float, Dict[str, Any]]:
        """Compare a field across two documents; repeats cost -0.03."""
        field = params.get("field", "")
        doc_a = params.get("doc_a", "")
        doc_b = params.get("doc_b", "")
        check_key = f"cross_{field}_{doc_a}_{doc_b}"
        if self._ep.has_checked(check_key):
            return self._repeat(info, -0.03)
        result, reward = self._task.simulate_cross_check(field, doc_a, doc_b)
        self._ep.checks.append(result)
        info["result"] = result.model_dump()
        return reward, info

    def _handle_run_check(
        self, params: Dict[str, Any], info: Dict[str, Any]
    ) -> tuple[float, Dict[str, Any]]:
        """Run a named check; repeats cost -0.03."""
        check_name = params.get("check_name", "")
        if self._ep.has_checked(check_name):
            return self._repeat(info, -0.03)
        result, reward = self._task.simulate_run_check(check_name)
        self._ep.checks.append(result)
        info["result"] = result.model_dump()
        return reward, info

    def _handle_query_supplier(
        self, params: Dict[str, Any], info: Dict[str, Any]
    ) -> tuple[float, Dict[str, Any]]:
        """Query the supplier; any second supplier query costs -0.05."""
        question = params.get("question", "")
        channel = params.get("channel", "email")
        if self._ep.has_queried("supplier"):
            return self._repeat(info, -0.05)
        result, reward = self._task.simulate_query_supplier(question, channel)
        self._ep.queries.append(result)
        info["result"] = result.model_dump()
        return reward, info

    def _handle_query_internal(
        self, params: Dict[str, Any], info: Dict[str, Any]
    ) -> tuple[float, Dict[str, Any]]:
        """Query an internal department; repeats per department cost -0.03."""
        department = params.get("department", "")
        question = params.get("question", "")
        if self._ep.has_queried(department.lower()):
            return self._repeat(info, -0.03)
        result, reward = self._task.simulate_query_internal(department, question)
        self._ep.queries.append(result)
        info["result"] = result.model_dump()
        return reward, info

    def _handle_apply_rule(
        self, params: Dict[str, Any], info: Dict[str, Any]
    ) -> tuple[float, Dict[str, Any]]:
        """Apply a rule by id; re-applying the same rule costs -0.03."""
        rule_id = params.get("rule_id", "")
        if rule_id in self._ep.rules_applied:
            return self._repeat(info, -0.03)
        detail, reward = self._task.simulate_apply_rule(rule_id)
        self._ep.rules_applied.append(rule_id)
        info["detail"] = detail
        return reward, info

    def _handle_make_decision(
        self, params: Dict[str, Any], info: Dict[str, Any]
    ) -> tuple[float, Dict[str, Any]]:
        """Record the decision; a second decision attempt costs -0.05."""
        decision = params.get("decision", "")
        reason = params.get("reason", "")
        if self._ep.decision is not None:
            return self._repeat(info, -0.05)
        # The simulator sees the episode before the decision is recorded.
        reward = self._task.simulate_make_decision(decision, reason, self._ep)
        self._ep.decision = decision
        self._ep.decision_reason = reason
        info["decision"] = decision
        return reward, info

    def _handle_route_to(
        self, params: Dict[str, Any], info: Dict[str, Any]
    ) -> tuple[float, Dict[str, Any]]:
        """Route the case to a team; repeats per team cost -0.02."""
        team = params.get("team", "")
        notes = params.get("notes", "")
        if team.lower() in self._ep.routed_to:
            return self._repeat(info, -0.02)
        reward = self._task.simulate_route_to(team, notes, self._ep)
        # Stored lower-cased so the repeat check is case-insensitive.
        self._ep.routed_to.append(team.lower())
        info["routed_to"] = team
        return reward, info

    def _handle_close_case(
        self, params: Dict[str, Any], info: Dict[str, Any]
    ) -> tuple[float, Dict[str, Any]]:
        """Close the case; closing an already-closed case costs -0.05."""
        summary = params.get("summary", "")
        if self._ep.closed:
            return self._repeat(info, -0.05)
        reward = self._task.simulate_close(summary, self._ep)
        self._ep.closed = True
        self._ep.close_summary = summary
        info["closed"] = True
        return reward, info

    def _build_state(self) -> EnvironmentState:
        """Construct an EnvironmentState from current task and episode data.

        Case status precedence: closed, then routed, then decided, then
        in-review (at least one step taken), otherwise open. List fields
        are copied so the snapshot does not alias episode or task lists.
        """
        # Determine case status
        if self._ep.closed:
            status = CaseStatus.CLOSED
        elif self._ep.routed_to:
            status = CaseStatus.ROUTED
        elif self._ep.decision is not None:
            status = CaseStatus.DECIDED
        elif self._ep.step_count > 0:
            status = CaseStatus.IN_REVIEW
        else:
            status = CaseStatus.OPEN

        return EnvironmentState(
            task_id=self._task.task_id,
            step_number=self._ep.step_count,
            case_status=status,
            purchase_order=self._task.get_purchase_order(),
            invoice=self._task.get_invoice(),
            grn=self._task.get_grn(),
            supplier_master=self._task.get_supplier_master(),
            exception_flag=self._task.get_exception_flag(),
            inspections=list(self._ep.inspections),
            checks_run=list(self._ep.checks),
            queries=list(self._ep.queries),
            rules_applied=list(self._ep.rules_applied),
            decision=self._ep.decision,
            decision_reason=self._ep.decision_reason,
            routed_to=list(self._ep.routed_to),
            case_closed=self._ep.closed,
            close_summary=self._ep.close_summary,
            available_actions=[at.value for at in ActionType],
            available_checks=list(self._task.available_checks),
            available_rules=list(self._task.available_rules),
            knowledge_base=self._task.knowledge_base,
            cumulative_reward=round(self._ep.cumulative_reward, 4),
        )