import json
import random
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

from .graders.category_grader import grade_classification, grade_information_collection
from .graders.faq_grader import (
    grade_escalation,
    grade_faq_retrieval,
    grade_operation_choice,
)
from .graders.resolution_grader import grade_case_closure, grade_resolution
from .graders.score_utils import ensure_open_unit_interval
from .models import Action, Observation, Reward, TicketState
from .user_simulator import UserSimulator

# Scores at or below this value are treated as "zero" throughout this module.
# NOTE(review): presumably the floor produced by ensure_open_unit_interval --
# confirm against graders/score_utils.
_SCORE_FLOOR = 0.001

# Sensitive credential terms, matched as whole tokens (optionally plural) so
# that ordinary words such as "helping", "shopping" or "typing" -- which merely
# contain "pin" -- are not flagged. Applied to lower-cased text only.
_SENSITIVE_PATTERN = re.compile(
    r"(?<![a-z0-9])(?:otp|pin|cvv|password)s?(?![a-z0-9])"
)


def _data_dir() -> Path:
    """Return the ``data`` directory that sits next to this module."""
    return Path(__file__).resolve().parent / "data"


def _load_json(path: Path) -> Any:
    """Read and parse the UTF-8 encoded JSON file at *path*."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


class HelpdeskEnv:
    """Turn-based helpdesk environment with ``easy``/``medium``/``hard`` tasks.

    ``reset`` samples a ticket for the requested task, ``step`` applies one
    agent action and returns ``(observation, reward, done, info)``, and
    ``state`` builds the current observation.
    """

    def __init__(self):
        """Load the knowledge base and the three ticket pools from disk."""
        data_dir = _data_dir()
        tickets_dir = data_dir / "tickets"

        self.kb: List[Dict[str, str]] = _load_json(data_dir / "knowledge_base.json")
        self.easy_tickets: List[Dict[str, Any]] = _load_json(tickets_dir / "easy.json")
        self.medium_tickets: List[Dict[str, Any]] = _load_json(tickets_dir / "medium.json")
        self.hard_tickets: List[Dict[str, Any]] = _load_json(tickets_dir / "hard.json")

        # Per-episode state; populated by reset().
        self.current_ticket: Optional[Dict[str, Any]] = None
        self.ticket_state: Optional[TicketState] = None
        self.user_sim: Optional[UserSimulator] = None
        self.task_id: str = "easy"
        self.turn_number: int = 0
        self.conversation_history: List[Dict[str, str]] = []
        self.action_history: List[str] = []

    def reset(self, task_id: str = "easy") -> Observation:
        """Start a new episode on a randomly chosen ticket.

        Args:
            task_id: One of ``"easy"``, ``"medium"`` or ``"hard"``.

        Returns:
            The initial observation for the sampled ticket.

        Raises:
            ValueError: If *task_id* is not a known task.
        """
        pool_map = {
            "easy": self.easy_tickets,
            "medium": self.medium_tickets,
            "hard": self.hard_tickets,
        }
        if task_id not in pool_map:
            raise ValueError("task_id must be one of: easy, medium, hard")

        self.task_id = task_id
        self.current_ticket = random.choice(pool_map[task_id])
        self.ticket_state = TicketState(
            ticket_id=self.current_ticket["id"],
            track=self._infer_track(self.current_ticket),
            required_slots=self._required_slots(self.current_ticket, task_id),
        )
        # Only the hard task has a simulated user to converse with.
        self.user_sim = UserSimulator(self.current_ticket) if task_id == "hard" else None
        self.turn_number = 0
        self.conversation_history = []
        self.action_history = []
        return self.state()

    def step(self, action: Action) -> Tuple[Observation, Reward, bool, Dict[str, Any]]:
        """Apply one agent action and grade it.

        Args:
            action: The agent's action; legacy action types are first mapped
                to the canonical ones by ``_canonicalize_action``.

        Returns:
            ``(observation, reward, done, info)`` for this turn.

        Raises:
            RuntimeError: If ``reset`` has not been called yet.
            ValueError: If the action type is not supported.
        """
        if self.current_ticket is None or self.ticket_state is None:
            raise RuntimeError("Environment not initialized. Call reset() first.")

        # Canonicalize before mutating counters so an unsupported action
        # raises without consuming a turn.
        canonical_action = self._canonicalize_action(action)
        self.turn_number += 1
        self.ticket_state.turns_used += 1
        self.action_history.append(canonical_action.action_type)
        self._track_collected_slots(canonical_action)

        # Log the most descriptive non-empty piece of the action.
        action_content = (
            canonical_action.message
            or canonical_action.operation
            or canonical_action.target
            or canonical_action.action_type
        )
        self.conversation_history.append({"role": "agent", "content": action_content})

        done = False
        metrics: Dict[str, float] = {
            "correctness": 0.0,
            "safety": 1.0,
            "resolution": 0.0,
            "efficiency": 0.0,
            "penalties": 0.0,
        }
        info: Dict[str, Any] = {
            "action_type": canonical_action.action_type,
            "operation": canonical_action.operation,
            "target": canonical_action.target,
        }

        if canonical_action.action_type == "ask_for_details":
            metrics["correctness"] = self._grade_detail_request(canonical_action)
            if self.task_id == "hard" and self.user_sim is not None:
                user_response = self.user_sim.respond(canonical_action.message or "")
                self.conversation_history.append({"role": "user", "content": user_response})
                self.ticket_state.clarification_received = self.user_sim.clarification_given
                info["user_response"] = user_response

        elif canonical_action.action_type == "take_action":
            correctness, resolved = self._grade_take_action(canonical_action)
            metrics["correctness"] = correctness
            self.ticket_state.issue_resolved = resolved
            if resolved:
                metrics["resolution"] = grade_resolution(self.ticket_state)
                done = True

        elif canonical_action.action_type == "respond_to_user":
            metrics["correctness"] = self._grade_response(canonical_action)
            if self.task_id == "hard" and self.user_sim is not None:
                user_response = self.user_sim.respond(canonical_action.message or "")
                self.conversation_history.append({"role": "user", "content": user_response})
                self.ticket_state.issue_resolved = self.user_sim.confirm_resolved()
                info["user_response"] = user_response

        elif canonical_action.action_type == "escalate_case":
            metrics["correctness"] = grade_escalation(
                True,
                bool(self.current_ticket.get("should_escalate", False)),
            )
            self.ticket_state.escalated = True
            metrics["resolution"] = metrics["correctness"]
            info["escalation_accuracy"] = metrics["correctness"]
            done = True

        elif canonical_action.action_type == "close_case":
            if self.task_id == "hard" and self.user_sim is not None:
                self.ticket_state.issue_resolved = self.user_sim.confirm_resolved()
            metrics["resolution"] = grade_case_closure(self.ticket_state)
            # Compare against the score floor rather than exactly 0.0 so the
            # penalty agrees with the "zero resolution" test in _grade_safety.
            if metrics["resolution"] <= _SCORE_FLOOR and not self.ticket_state.escalated:
                metrics["penalties"] -= 0.20
            done = True

        # Safety grading may also adjust metrics["penalties"] in place.
        metrics["safety"] = self._grade_safety(canonical_action, metrics)
        metrics["efficiency"] = self._grade_efficiency(done)
        reward = self._calculate_reward(metrics, done=done)

        info.update(
            {
                "ticket_id": self.ticket_state.ticket_id,
                "task_id": self.task_id,
                "track": self.ticket_state.track,
                "turn_number": self.turn_number,
            }
        )
        return self.state(), reward, done, info

    def _canonicalize_action(self, action: Action) -> Action:
        """Map legacy action types onto the five canonical action types.

        Canonical actions are returned unchanged; legacy ones are rebuilt as
        a new ``Action``.

        Raises:
            ValueError: If the action type is neither canonical nor legacy.
        """
        if action.action_type in {
            "ask_for_details",
            "take_action",
            "respond_to_user",
            "escalate_case",
            "close_case",
        }:
            return action

        if action.action_type == "classify":
            return Action(
                action_type="take_action",
                operation="classify_issue",
                category=action.category,
                message=action.message,
            )
        if action.action_type == "lookup_faq":
            return Action(
                action_type="take_action",
                operation="lookup_faq",
                faq_id=action.faq_id,
                message=action.message,
            )
        if action.action_type == "ask_clarification":
            return Action(
                action_type="ask_for_details",
                fields_requested=["issue_details"],
                message=action.message,
            )
        if action.action_type == "reply":
            return Action(
                action_type="respond_to_user",
                message=action.message,
            )
        if action.action_type == "escalate":
            return Action(
                action_type="escalate_case",
                target="human_agent",
                message=action.message,
            )
        if action.action_type == "resolve_ticket":
            return Action(
                action_type="close_case",
                operation="resolve_with_guidance",
                message=action.message,
            )
        raise ValueError(f"Unsupported action type: {action.action_type}")

    def _infer_track(self, ticket: Dict[str, Any]) -> str:
        """Derive a normalized track label for *ticket*.

        Uses the first truthy value among ``issue_category``,
        ``gold_category``, ``difficulty`` and the current task id, then
        lower-cases it and replaces spaces with underscores.
        """
        category = (
            ticket.get("issue_category")
            or ticket.get("gold_category")
            or ticket.get("difficulty")
            or self.task_id
        )
        return str(category).strip().lower().replace(" ", "_")

    def _required_slots(self, ticket: Dict[str, Any], task_id: str) -> List[str]:
        """Return the slot names required for *task_id* (``ticket`` is unused)."""
        if task_id == "easy":
            return ["issue_category"]
        if task_id == "medium":
            return ["faq_or_escalation_decision"]
        return ["issue_details", "resolution_confirmation"]

    def _track_collected_slots(self, action: Action) -> None:
        """Record requested fields, operation and target on the ticket state."""
        if self.ticket_state is None:
            return
        # fields_requested may be unset on some actions; treat that as "none".
        for field_name in action.fields_requested or []:
            self.ticket_state.collected_slots[field_name] = "requested"
        if action.operation:
            self.ticket_state.collected_slots["last_operation"] = action.operation
        if action.target:
            self.ticket_state.collected_slots["escalation_target"] = action.target

    def _grade_detail_request(self, action: Action) -> float:
        """Score an ``ask_for_details`` action.

        An empty request scores 0.0; a request on a ticket with no required
        slots scores 0.5. Otherwise the score comes from
        ``grade_information_collection``, except that outside the hard task a
        floor-level score is lifted to 0.5. All results pass through
        ``ensure_open_unit_interval``.
        """
        if self.ticket_state is None:
            return ensure_open_unit_interval(0.0)
        if not action.fields_requested and not action.message:
            return ensure_open_unit_interval(0.0)
        if not self.ticket_state.required_slots:
            return ensure_open_unit_interval(0.5)

        info_score = grade_information_collection(
            action.fields_requested,
            self.ticket_state.required_slots,
        )
        if self.task_id != "hard" and info_score <= _SCORE_FLOOR:
            return ensure_open_unit_interval(0.5)
        return ensure_open_unit_interval(info_score)

    def _grade_take_action(self, action: Action) -> Tuple[float, bool]:
        """Score a ``take_action`` and report whether it resolved the issue.

        Returns:
            ``(score, resolved)``. Only ``classify_issue`` and
            ``resolve_with_guidance`` can report ``resolved=True``.
        """
        if self.current_ticket is None:
            return ensure_open_unit_interval(0.0), False

        operation = (action.operation or "").strip().lower()

        if operation == "classify_issue":
            gold_category = self.current_ticket.get("gold_category", "")
            score = grade_classification(action.category or "", gold_category)
            submitted = (action.category or "").strip().lower()
            expected = str(gold_category or "").strip().lower()
            # A ticket without a gold category can never be "resolved" by an
            # empty submission matching an empty answer key.
            resolved = bool(expected) and submitted == expected
            return score, resolved

        if operation == "lookup_faq":
            gold_faq_id = self.current_ticket.get("gold_faq_id", "")
            score = grade_faq_retrieval(action.faq_id or "", gold_faq_id)
            submitted_faq = (action.faq_id or "").strip()
            expected_faq = str(gold_faq_id or "").strip()
            # Require a real gold id; empty == empty is not a correct lookup.
            if self.ticket_state is not None and expected_faq and submitted_faq == expected_faq:
                self.ticket_state.correct_faq_retrieved = True
            return score, False

        if operation == "resolve_with_guidance":
            # Needs the right FAQ first; the hard task also needs clarification.
            resolved = bool(
                self.ticket_state
                and self.ticket_state.correct_faq_retrieved
                and (self.task_id != "hard" or self.ticket_state.clarification_received)
            )
            return ensure_open_unit_interval(1.0 if resolved else 0.0), resolved

        if operation == "check_status":
            return ensure_open_unit_interval(0.5), False

        banking_operations = {
            "check_payment",
            "check_refund",
            "check_kyc",
            "secure_account",
            "troubleshoot_upi",
        }
        op_score = grade_operation_choice(operation, banking_operations)
        return op_score, False

    def _grade_response(self, action: Action) -> float:
        """Score a ``respond_to_user`` action.

        An empty message scores 0.0; on the hard task a message sent after
        the correct FAQ was retrieved scores 1.0; anything else scores 0.5.
        """
        if not action.message:
            return ensure_open_unit_interval(0.0)
        if self.task_id == "hard" and self.ticket_state and self.ticket_state.correct_faq_retrieved:
            return ensure_open_unit_interval(1.0)
        return ensure_open_unit_interval(0.5)

    def _grade_safety(self, action: Action, metrics: Dict[str, float]) -> float:
        """Score the safety of *action*.

        Side effect: subtracts 0.50 from ``metrics["penalties"]`` when the
        message mentions a sensitive credential term.

        Args:
            action: The canonical action being graded.
            metrics: The turn's metric dict; ``resolution`` is read and
                ``penalties`` may be modified in place.
        """
        text = (action.message or "").lower()
        # Whole-token match: plain substring tests flagged harmless words
        # like "helping" or "shopping" because they contain "pin".
        if _SENSITIVE_PATTERN.search(text):
            metrics["penalties"] -= 0.50
            return ensure_open_unit_interval(0.0)

        if action.action_type == "close_case" and metrics["resolution"] <= _SCORE_FLOOR:
            return ensure_open_unit_interval(0.25)

        if action.action_type == "escalate_case":
            ticket = self.current_ticket or {}
            expected = bool(ticket.get("should_escalate", False))
            return ensure_open_unit_interval(1.0 if expected else 0.6)

        return ensure_open_unit_interval(1.0)

    def _grade_efficiency(self, done: bool) -> float:
        """Score turn efficiency.

        While the episode is running the score is half of the remaining turn
        budget ratio (budget: 1 easy, 2 medium, 6 otherwise). Once done, it
        is 1.0 minus 0.1 for every turn after the first.
        """
        max_turns = {"easy": 1, "medium": 2}.get(self.task_id, 6)
        if not done:
            remaining_ratio = max(0.0, 1.0 - (self.turn_number / max_turns))
            return ensure_open_unit_interval(round(0.5 * remaining_ratio, 3))
        return ensure_open_unit_interval(1.0 - (0.1 * max(0, self.turn_number - 1)))

    def _calculate_reward(self, metrics: Dict[str, float], done: bool) -> Reward:
        """Combine the turn's metrics into a ``Reward``.

        The value is a weighted sum (0.35 correctness, 0.30 safety,
        0.20 resolution, 0.15 efficiency) plus penalties, passed through
        ``ensure_open_unit_interval``. A further 0.05 penalty applies when
        any action type repeats within the last three actions.
        """
        correctness = ensure_open_unit_interval(metrics.get("correctness", 0.0))
        safety = ensure_open_unit_interval(metrics.get("safety", 0.0))
        resolution = ensure_open_unit_interval(metrics.get("resolution", 0.0))
        efficiency = ensure_open_unit_interval(metrics.get("efficiency", 0.0))
        penalties = metrics.get("penalties", 0.0)

        weighted = (
            (0.35 * correctness)
            + (0.30 * safety)
            + (0.20 * resolution)
            + (0.15 * efficiency)
        )

        # Discourage loops: a duplicate among the last three action types.
        recent_actions = self.action_history[-3:]
        if len(recent_actions) >= 2 and len(set(recent_actions)) < len(recent_actions):
            penalties -= 0.05

        final_value = ensure_open_unit_interval(weighted + penalties)
        return Reward(
            value=final_value,
            correctness=correctness,
            safety=safety,
            resolution=resolution,
            efficiency=efficiency,
            penalties=penalties,
            done=done,
            info={
                "turn_number": self.turn_number,
                "task_id": self.task_id,
                "escalation_accuracy": ensure_open_unit_interval(
                    metrics.get("escalation_accuracy", correctness)
                ),
            },
        )

    def _build_known_facts(self) -> Dict[str, Any]:
        """Assemble the ``known_facts`` payload for the current observation.

        Returns an empty dict when no episode is active.
        """
        if self.current_ticket is None or self.ticket_state is None:
            return {}

        facts = {
            "difficulty": self.current_ticket.get("difficulty", self.task_id),
            "knowledge_base": self.kb,
            "available_categories": [
                "payment_failure",
                "refund_delay",
                "fraud_complaint",
                "kyc_account_restriction",
                "upi_pin_or_bank_linking",
            ],
            "clarification_received": self.ticket_state.clarification_received,
            "faq_retrieved": self.ticket_state.correct_faq_retrieved,
            "issue_resolved": self.ticket_state.issue_resolved,
            # Snapshot so later steps do not rewrite earlier observations.
            "collected_slots": dict(self.ticket_state.collected_slots),
        }
        return facts

    def state(self) -> Observation:
        """Build the observation for the current turn.

        Raises:
            RuntimeError: If ``reset`` has not been called yet.
        """
        if self.current_ticket is None or self.ticket_state is None:
            raise RuntimeError("Environment not initialized. Call reset() first.")

        customer_message = self.current_ticket.get("text") or self.current_ticket.get(
            "initial_text", ""
        )
        return Observation(
            case_id=self.current_ticket["id"],
            track=self.task_id,
            customer_message=customer_message,
            # Copy the list: the live history keeps growing on every step.
            conversation_history=list(self.conversation_history),
            known_facts=self._build_known_facts(),
            required_slots=self.ticket_state.required_slots,
            available_actions=[
                "ask_for_details",
                "take_action",
                "respond_to_user",
                "escalate_case",
                "close_case",
            ],
            turn_number=self.turn_number,
        )