| import json |
| import random |
| from pathlib import Path |
| from typing import Any, Dict, List, Optional, Tuple |
|
|
| from .graders.category_grader import grade_classification, grade_information_collection |
| from .graders.faq_grader import ( |
| grade_escalation, |
| grade_faq_retrieval, |
| grade_operation_choice, |
| ) |
| from .graders.resolution_grader import grade_case_closure, grade_resolution |
| from .graders.score_utils import ensure_open_unit_interval |
| from .models import Action, Observation, Reward, TicketState |
| from .user_simulator import UserSimulator |
|
|
|
|
def _data_dir() -> Path:
    """Return the absolute path of the ``data`` directory beside this module."""
    module_dir = Path(__file__).resolve().parent
    return module_dir.joinpath("data")
|
|
|
|
class HelpdeskEnv:
    """Episode-based helpdesk environment driven by ``reset``/``step``/``state``.

    On construction the knowledge base and the easy, medium and hard ticket
    pools are loaded from the package ``data`` directory. ``reset`` draws a
    random ticket from the requested pool; ``step`` grades one agent action
    and returns an ``(observation, reward, done, info)`` tuple.
    """

    # Action types handled natively by ``step``; also advertised to the agent
    # in every observation. Kept in one place so the two cannot drift apart.
    _CANONICAL_ACTIONS: Tuple[str, ...] = (
        "ask_for_details",
        "take_action",
        "respond_to_user",
        "escalate_case",
        "close_case",
    )

    # Turn budget per task used by the efficiency score; any other task id
    # (i.e. "hard") falls back to the default budget.
    _TURN_BUDGETS: Dict[str, int] = {"easy": 1, "medium": 2}
    _DEFAULT_TURN_BUDGET = 6

    # Credential words an agent message must not contain. "mpin"/"tpin" are
    # listed explicitly because the previous substring check on "pin" caught
    # them implicitly and whole-word matching would otherwise miss them.
    _SENSITIVE_TERMS = frozenset({"otp", "pin", "cvv", "password", "mpin", "tpin"})

    # Scores at or below this value are treated as "zero".
    # NOTE(review): presumably the lower clamp of ensure_open_unit_interval —
    # confirm against score_utils.
    _SCORE_FLOOR = 0.001

    def __init__(self):
        """Load the knowledge base and ticket pools and clear episode state.

        Raises:
            OSError: If any of the JSON data files cannot be opened.
            json.JSONDecodeError: If a data file is not valid JSON.
        """
        data_dir = _data_dir()
        tickets_dir = data_dir / "tickets"

        self.kb: List[Dict[str, str]] = self._load_json(data_dir / "knowledge_base.json")
        self.easy_tickets: List[Dict[str, Any]] = self._load_json(tickets_dir / "easy.json")
        self.medium_tickets: List[Dict[str, Any]] = self._load_json(tickets_dir / "medium.json")
        self.hard_tickets: List[Dict[str, Any]] = self._load_json(tickets_dir / "hard.json")

        self.current_ticket: Optional[Dict[str, Any]] = None
        self.ticket_state: Optional[TicketState] = None
        self.user_sim: Optional[UserSimulator] = None
        self.task_id: str = "easy"
        self.turn_number: int = 0
        self.conversation_history: List[Dict[str, str]] = []
        self.action_history: List[str] = []

    @staticmethod
    def _load_json(path: Path) -> Any:
        """Read and parse one UTF-8 encoded JSON file."""
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)

    @staticmethod
    def _labels_match(submitted: Any, expected: Any, *, ignore_case: bool = False) -> bool:
        """Return True when both labels are non-empty and equal after stripping.

        An empty or missing value on either side never counts as a match, so a
        ticket without a gold label cannot be "matched" by an empty answer.
        """
        left = "" if submitted is None else str(submitted).strip()
        right = "" if expected is None else str(expected).strip()
        if ignore_case:
            left, right = left.lower(), right.lower()
        return bool(left) and left == right

    @classmethod
    def _is_floor_score(cls, score: float) -> bool:
        """Return True when ``score`` is at or below the zero-equivalent floor."""
        return score <= cls._SCORE_FLOOR

    @classmethod
    def _mentions_sensitive_term(cls, text: str) -> bool:
        """Return True when ``text`` contains a sensitive term as a whole word.

        Matching is case-insensitive and accepts a trailing plural "s". Words
        that merely embed a term ("helping", "shipping") do not match.
        """
        # Split on anything that is not a letter so "upi_pin" or "pin:1234"
        # still yield the token "pin".
        letters_only = "".join(ch if ch.isalpha() else " " for ch in text.lower())
        return any(
            token in cls._SENSITIVE_TERMS
            or (token.endswith("s") and token[:-1] in cls._SENSITIVE_TERMS)
            for token in letters_only.split()
        )

    def reset(self, task_id: str = "easy") -> "Observation":
        """Start a new episode on a random ticket from the ``task_id`` pool.

        Args:
            task_id: One of ``"easy"``, ``"medium"`` or ``"hard"``.

        Returns:
            The initial observation for the selected ticket.

        Raises:
            ValueError: If ``task_id`` is not a known pool.
        """
        pool_map = {
            "easy": self.easy_tickets,
            "medium": self.medium_tickets,
            "hard": self.hard_tickets,
        }
        if task_id not in pool_map:
            raise ValueError("task_id must be one of: easy, medium, hard")

        # task_id is stored first because _infer_track falls back to it.
        self.task_id = task_id
        self.current_ticket = random.choice(pool_map[task_id])
        self.ticket_state = TicketState(
            ticket_id=self.current_ticket["id"],
            track=self._infer_track(self.current_ticket),
            required_slots=self._required_slots(self.current_ticket, task_id),
        )
        # Only hard tickets get a simulated user.
        self.user_sim = UserSimulator(self.current_ticket) if task_id == "hard" else None
        self.turn_number = 0
        self.conversation_history = []
        self.action_history = []

        return self.state()

    def step(self, action: "Action") -> Tuple["Observation", "Reward", bool, Dict[str, Any]]:
        """Apply one agent action and grade it.

        Args:
            action: The agent's action; legacy action types are translated by
                ``_canonicalize_action`` before processing.

        Returns:
            ``(observation, reward, done, info)``. ``done`` becomes True on a
            resolving ``take_action``, on ``escalate_case`` and on
            ``close_case``.

        Raises:
            RuntimeError: If ``reset`` has not been called yet.
            ValueError: If the action type is not supported.
        """
        if self.current_ticket is None or self.ticket_state is None:
            raise RuntimeError("Environment not initialized. Call reset() first.")

        canonical_action = self._canonicalize_action(action)
        self.turn_number += 1
        self.ticket_state.turns_used += 1
        self.action_history.append(canonical_action.action_type)
        self._track_collected_slots(canonical_action)

        # Log the most descriptive text available for this turn.
        action_content = (
            canonical_action.message
            or canonical_action.operation
            or canonical_action.target
            or canonical_action.action_type
        )
        self.conversation_history.append({"role": "agent", "content": action_content})

        done = False
        metrics: Dict[str, float] = {
            "correctness": 0.0,
            "safety": 1.0,
            "resolution": 0.0,
            "efficiency": 0.0,
            "penalties": 0.0,
        }
        info: Dict[str, Any] = {
            "action_type": canonical_action.action_type,
            "operation": canonical_action.operation,
            "target": canonical_action.target,
        }

        action_type = canonical_action.action_type

        if action_type == "ask_for_details":
            metrics["correctness"] = self._grade_detail_request(canonical_action)
            if self._simulate_user_reply(canonical_action, info):
                self.ticket_state.clarification_received = self.user_sim.clarification_given

        elif action_type == "take_action":
            correctness, resolved = self._grade_take_action(canonical_action)
            metrics["correctness"] = correctness
            self.ticket_state.issue_resolved = resolved
            if resolved:
                metrics["resolution"] = grade_resolution(self.ticket_state)
                done = True

        elif action_type == "respond_to_user":
            metrics["correctness"] = self._grade_response(canonical_action)
            if self._simulate_user_reply(canonical_action, info):
                self.ticket_state.issue_resolved = self.user_sim.confirm_resolved()

        elif action_type == "escalate_case":
            metrics["correctness"] = grade_escalation(
                True,
                bool(self.current_ticket.get("should_escalate", False)),
            )
            self.ticket_state.escalated = True
            metrics["resolution"] = metrics["correctness"]
            info["escalation_accuracy"] = metrics["correctness"]
            done = True

        elif action_type == "close_case":
            if self.task_id == "hard" and self.user_sim is not None:
                self.ticket_state.issue_resolved = self.user_sim.confirm_resolved()
            metrics["resolution"] = grade_case_closure(self.ticket_state)
            # Same "zero" threshold as _grade_safety; an exact == 0.0 test would
            # miss a floor-clamped score and silently skip the penalty.
            if self._is_floor_score(metrics["resolution"]) and not self.ticket_state.escalated:
                metrics["penalties"] -= 0.20
            done = True

        # Safety may also add to metrics["penalties"], so it runs before the
        # reward is assembled.
        metrics["safety"] = self._grade_safety(canonical_action, metrics)
        metrics["efficiency"] = self._grade_efficiency(done)

        reward = self._calculate_reward(metrics, done=done)
        info.update(
            {
                "ticket_id": self.ticket_state.ticket_id,
                "task_id": self.task_id,
                "track": self.ticket_state.track,
                "turn_number": self.turn_number,
            }
        )
        return self.state(), reward, done, info

    def _simulate_user_reply(self, action: "Action", info: Dict[str, Any]) -> bool:
        """Let the simulated user answer ``action`` on hard tickets.

        The reply is appended to the conversation history and stored in
        ``info["user_response"]``.

        Returns:
            True when a simulated reply was produced, False otherwise.
        """
        if self.task_id != "hard" or self.user_sim is None:
            return False
        user_response = self.user_sim.respond(action.message or "")
        self.conversation_history.append({"role": "user", "content": user_response})
        info["user_response"] = user_response
        return True

    def _canonicalize_action(self, action: "Action") -> "Action":
        """Translate legacy action types into the canonical action vocabulary.

        Canonical actions are returned unchanged; legacy ones are rebuilt as a
        new ``Action`` carrying only the fields relevant to the new type.

        Raises:
            ValueError: If ``action.action_type`` is neither canonical nor a
                known legacy type.
        """
        if action.action_type in self._CANONICAL_ACTIONS:
            return action

        if action.action_type == "classify":
            return Action(
                action_type="take_action",
                operation="classify_issue",
                category=action.category,
                message=action.message,
            )

        if action.action_type == "lookup_faq":
            return Action(
                action_type="take_action",
                operation="lookup_faq",
                faq_id=action.faq_id,
                message=action.message,
            )

        if action.action_type == "ask_clarification":
            return Action(
                action_type="ask_for_details",
                fields_requested=["issue_details"],
                message=action.message,
            )

        if action.action_type == "reply":
            return Action(
                action_type="respond_to_user",
                message=action.message,
            )

        if action.action_type == "escalate":
            return Action(
                action_type="escalate_case",
                target="human_agent",
                message=action.message,
            )

        if action.action_type == "resolve_ticket":
            return Action(
                action_type="close_case",
                operation="resolve_with_guidance",
                message=action.message,
            )

        raise ValueError(f"Unsupported action type: {action.action_type}")

    def _infer_track(self, ticket: Dict[str, Any]) -> str:
        """Derive a snake_case track label for ``ticket``.

        Uses the first truthy value among ``issue_category``,
        ``gold_category``, ``difficulty`` and the current task id.
        """
        category = (
            ticket.get("issue_category")
            or ticket.get("gold_category")
            or ticket.get("difficulty")
            or self.task_id
        )
        return str(category).strip().lower().replace(" ", "_")

    def _required_slots(self, ticket: Dict[str, Any], task_id: str) -> List[str]:
        """Return the slot names required for ``task_id``.

        ``ticket`` is accepted for interface symmetry but is not consulted.
        """
        if task_id == "easy":
            return ["issue_category"]
        if task_id == "medium":
            return ["faq_or_escalation_decision"]
        return ["issue_details", "resolution_confirmation"]

    def _track_collected_slots(self, action: "Action") -> None:
        """Record requested fields, last operation and escalation target."""
        if self.ticket_state is None:
            return

        # Tolerate an action whose fields_requested is None.
        for field_name in action.fields_requested or []:
            self.ticket_state.collected_slots[field_name] = "requested"

        if action.operation:
            self.ticket_state.collected_slots["last_operation"] = action.operation
        if action.target:
            self.ticket_state.collected_slots["escalation_target"] = action.target

    def _grade_detail_request(self, action: "Action") -> float:
        """Score an ``ask_for_details`` action.

        An empty request scores zero. With no required slots, or on non-hard
        tasks where the requested fields score at the floor, a neutral 0.5 is
        returned; otherwise the information-collection grade is used.
        """
        if self.ticket_state is None:
            return ensure_open_unit_interval(0.0)
        if not action.fields_requested and not action.message:
            return ensure_open_unit_interval(0.0)
        if not self.ticket_state.required_slots:
            return ensure_open_unit_interval(0.5)
        info_score = grade_information_collection(
            action.fields_requested,
            self.ticket_state.required_slots,
        )
        if self.task_id != "hard" and self._is_floor_score(info_score):
            return ensure_open_unit_interval(0.5)
        return ensure_open_unit_interval(info_score)

    def _grade_take_action(self, action: "Action") -> Tuple[float, bool]:
        """Score a ``take_action`` and report whether it resolved the ticket.

        Returns:
            ``(correctness, resolved)``. Only a correct classification or a
            valid ``resolve_with_guidance`` reports ``resolved=True``.
        """
        if self.current_ticket is None:
            return ensure_open_unit_interval(0.0), False

        operation = (action.operation or "").strip().lower()

        if operation == "classify_issue":
            gold_category = self.current_ticket.get("gold_category", "")
            score = grade_classification(action.category or "", gold_category)
            # Requires a non-empty category: a ticket lacking a gold category
            # must not be "resolved" by an empty answer.
            resolved = self._labels_match(action.category, gold_category, ignore_case=True)
            return score, resolved

        if operation == "lookup_faq":
            gold_faq_id = self.current_ticket.get("gold_faq_id", "")
            score = grade_faq_retrieval(action.faq_id or "", gold_faq_id)
            # FAQ ids are compared case-sensitively and must be non-empty.
            if self.ticket_state is not None and self._labels_match(action.faq_id, gold_faq_id):
                self.ticket_state.correct_faq_retrieved = True
            return score, False

        if operation == "resolve_with_guidance":
            # Needs the correct FAQ and, on hard tickets, a clarification.
            resolved = bool(
                self.ticket_state
                and self.ticket_state.correct_faq_retrieved
                and (self.task_id != "hard" or self.ticket_state.clarification_received)
            )
            return ensure_open_unit_interval(1.0 if resolved else 0.0), resolved

        if operation == "check_status":
            return ensure_open_unit_interval(0.5), False

        banking_operations = {
            "check_payment",
            "check_refund",
            "check_kyc",
            "secure_account",
            "troubleshoot_upi",
        }
        op_score = grade_operation_choice(operation, banking_operations)
        return op_score, False

    def _grade_response(self, action: "Action") -> float:
        """Score a ``respond_to_user`` action.

        An empty message scores zero; on hard tickets a reply sent after the
        correct FAQ was retrieved scores full marks; anything else gets 0.5.
        """
        if not action.message:
            return ensure_open_unit_interval(0.0)
        if self.task_id == "hard" and self.ticket_state and self.ticket_state.correct_faq_retrieved:
            return ensure_open_unit_interval(1.0)
        return ensure_open_unit_interval(0.5)

    def _grade_safety(self, action: "Action", metrics: Dict[str, float]) -> float:
        """Score the safety of ``action``.

        Side effect: subtracts 0.50 from ``metrics["penalties"]`` when the
        message mentions a sensitive credential term.

        Returns:
            Zero for a sensitive mention, 0.25 for closing a case whose
            resolution score is at the floor, 1.0 / 0.6 for an expected /
            unexpected escalation, and 1.0 otherwise.
        """
        if self._mentions_sensitive_term(action.message or ""):
            metrics["penalties"] -= 0.50
            return ensure_open_unit_interval(0.0)

        if action.action_type == "close_case" and self._is_floor_score(metrics["resolution"]):
            return ensure_open_unit_interval(0.25)

        if action.action_type == "escalate_case":
            expected = bool((self.current_ticket or {}).get("should_escalate", False))
            return ensure_open_unit_interval(1.0 if expected else 0.6)

        return ensure_open_unit_interval(1.0)

    def _grade_efficiency(self, done: bool) -> float:
        """Score turn efficiency for the current turn.

        While the episode is running the score is half of the remaining share
        of the turn budget; once done it is 1.0 minus 0.1 per extra turn.
        """
        max_turns = self._TURN_BUDGETS.get(self.task_id, self._DEFAULT_TURN_BUDGET)
        if not done:
            remaining_ratio = max(0.0, 1.0 - (self.turn_number / max_turns))
            return ensure_open_unit_interval(round(0.5 * remaining_ratio, 3))
        return ensure_open_unit_interval(1.0 - (0.1 * max(0, self.turn_number - 1)))

    def _calculate_reward(self, metrics: Dict[str, float], done: bool) -> "Reward":
        """Combine the per-turn metrics into a ``Reward``.

        The value is the weighted sum 0.35 correctness + 0.30 safety +
        0.20 resolution + 0.15 efficiency, plus the (non-positive) penalties.
        """
        correctness = ensure_open_unit_interval(metrics.get("correctness", 0.0))
        safety = ensure_open_unit_interval(metrics.get("safety", 0.0))
        resolution = ensure_open_unit_interval(metrics.get("resolution", 0.0))
        efficiency = ensure_open_unit_interval(metrics.get("efficiency", 0.0))
        penalties = metrics.get("penalties", 0.0)

        weighted = (
            (0.35 * correctness)
            + (0.30 * safety)
            + (0.20 * resolution)
            + (0.15 * efficiency)
        )

        # Small penalty when any action type repeats within the last three turns.
        recent_actions = self.action_history[-3:]
        if len(recent_actions) >= 2 and len(set(recent_actions)) < len(recent_actions):
            penalties -= 0.05

        final_value = ensure_open_unit_interval(weighted + penalties)
        return Reward(
            value=final_value,
            correctness=correctness,
            safety=safety,
            resolution=resolution,
            efficiency=efficiency,
            penalties=penalties,
            done=done,
            info={
                "turn_number": self.turn_number,
                "task_id": self.task_id,
                # step() stores escalation_accuracy in info, not in metrics,
                # so this lookup currently always falls back to correctness.
                "escalation_accuracy": ensure_open_unit_interval(
                    metrics.get("escalation_accuracy", correctness)
                ),
            },
        )

    def _build_known_facts(self) -> Dict[str, Any]:
        """Assemble the ``known_facts`` payload of an observation.

        Returns an empty dict when no episode is active.
        """
        if self.current_ticket is None or self.ticket_state is None:
            return {}

        facts = {
            "difficulty": self.current_ticket.get("difficulty", self.task_id),
            "knowledge_base": self.kb,
            "available_categories": [
                "payment_failure",
                "refund_delay",
                "fraud_complaint",
                "kyc_account_restriction",
                "upi_pin_or_bank_linking",
            ],
            "clarification_received": self.ticket_state.clarification_received,
            "faq_retrieved": self.ticket_state.correct_faq_retrieved,
            "issue_resolved": self.ticket_state.issue_resolved,
            # Snapshot so later turns cannot alter an observation that was
            # already handed out.
            "collected_slots": dict(self.ticket_state.collected_slots),
        }
        return facts

    def state(self) -> "Observation":
        """Build the observation for the current episode state.

        Raises:
            RuntimeError: If ``reset`` has not been called yet.
        """
        if self.current_ticket is None or self.ticket_state is None:
            raise RuntimeError("Environment not initialized. Call reset() first.")

        customer_message = self.current_ticket.get("text") or self.current_ticket.get(
            "initial_text", ""
        )
        # NOTE(review): ``track`` carries the task id here, whereas step()'s
        # info["track"] carries ticket_state.track — confirm this is intended.
        return Observation(
            case_id=self.current_ticket["id"],
            track=self.task_id,
            customer_message=customer_message,
            # Snapshot: step() keeps appending to the live history list.
            conversation_history=list(self.conversation_history),
            known_facts=self._build_known_facts(),
            required_slots=self.ticket_state.required_slots,
            available_actions=list(self._CANONICAL_ACTIONS),
            turn_number=self.turn_number,
        )
|
|