"""Train a simple tabular Q-learning agent against the local SupportDesk env.

This is an extra playground script for local experimentation. It is not part
of the hackathon submission baseline and intentionally uses a compact,
hand-built discrete action library so that plain Python Q-learning can train
quickly.
"""

from __future__ import annotations

import argparse
import random
import sys
from dataclasses import dataclass
from pathlib import Path

# Make the repository importable when this script is run directly: the repo
# root is taken to be the grandparent of this script's own directory.
REPO_ROOT = Path(__file__).resolve().parents[2]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

from supportdesk_env import (  # noqa: E402  (must follow the sys.path tweak)
    SupportDeskAction,
    get_task,
    grade_case,
    list_task_ids,
)
from supportdesk_env.policies import default_note, default_reply  # noqa: E402
from supportdesk_env.server.supportdesk_environment import SupportDeskEnvironment  # noqa: E402

# Candidate labels used to build one deliberately wrong classification per
# task: the first entry that differs from the task's gold label is used.
_QUEUE_CHOICES = ("general_support", "billing_ops", "trust_and_safety", "platform_engineering")
_PRIORITY_CHOICES = ("low", "normal", "high", "urgent")
_ISSUE_TYPE_CHOICES = ("general_question", "duplicate_charge", "account_compromise", "production_incident")


@dataclass
class EvalResult:
    """Compact report for a greedy evaluation episode."""

    task_id: str  # task the episode was run on
    score: float  # grade_case(...).total_score of the final case state
    reward: float  # env.state.reward at the end of the episode
    steps: int  # env.state.step_count at the end of the episode
    actions: list[str]  # action_label() of every action taken, in order


def _first_other(choices: tuple[str, ...], gold: str) -> str:
    """Return the first entry of *choices* that differs from *gold*."""
    return next(choice for choice in choices if choice != gold)


def _reward_value(reward: object) -> float:
    """Coerce an env reward to ``float``, treating a missing (``None``) reward as 0.0.

    NOTE(review): guards against observations/states that report ``reward=None``;
    confirm whether SupportDeskEnvironment can actually do that.
    """
    return 0.0 if reward is None else float(reward)


def build_action_library(task_id: str) -> list[SupportDeskAction]:
    """Return a small discrete action set for a task.

    Each operation gets a variant built from the task's gold labels / default
    texts and a contrasting variant (wrong labels, partial fields, generic
    text). The returned order is fixed; index 0 is the gold ``classify``.
    """
    task = get_task(task_id)
    wrong_queue = _first_other(_QUEUE_CHOICES, task.gold_queue)
    wrong_priority = _first_other(_PRIORITY_CHOICES, task.gold_priority)
    wrong_issue = _first_other(_ISSUE_TYPE_CHOICES, task.gold_issue_type)

    # Only the first required field (or a stand-in when the task requires none)
    # gives a partially-correct request_info option.
    partial_fields = list(task.required_requested_fields[:1]) or ["billing_email"]

    if task.required_requested_fields:
        good_request = SupportDeskAction(
            operation="request_info",
            requested_fields=list(task.required_requested_fields),
            status=task.gold_status,
            reply=default_reply(task_id),
        )
    else:
        good_request = SupportDeskAction(
            operation="request_info",
            requested_fields=["billing_email"],
            status="waiting_on_customer",
            reply="Please confirm the billing email on the account so we can continue.",
        )

    partial_request = SupportDeskAction(
        operation="request_info",
        requested_fields=partial_fields,
        status="waiting_on_customer",
        reply="Please share more details so we can investigate.",
    )

    return [
        SupportDeskAction(
            operation="classify",
            queue=task.gold_queue,
            priority=task.gold_priority,
            issue_type=task.gold_issue_type,
        ),
        SupportDeskAction(
            operation="classify",
            queue=wrong_queue,
            priority=wrong_priority,
            issue_type=wrong_issue,
        ),
        good_request,
        partial_request,
        SupportDeskAction(operation="draft_reply", reply=default_reply(task_id)),
        SupportDeskAction(operation="draft_reply", reply="Thanks for reaching out. We are checking this now."),
        SupportDeskAction(operation="add_internal_note", internal_note=default_note(task_id)),
        SupportDeskAction(operation="add_internal_note", internal_note="Customer contacted support with a problem."),
        SupportDeskAction(
            operation="submit",
            status=task.gold_status,
            resolution_code=task.gold_resolution_code,
        ),
        SupportDeskAction(
            operation="submit",
            status="resolved",
            resolution_code="closed_generic",
        ),
    ]


def state_key(task_id: str, observation) -> tuple:
    """Compress the observation into a hashable tabular Q-learning state.

    ``task_id`` is part of the key so a single Q-table can be shared across
    tasks. Unset optional case fields are mapped to ``"_"``; free-text fields
    are reduced to "present / absent".
    """
    case = observation.case
    return (
        task_id,
        case.queue or "_",
        case.priority or "_",
        case.issue_type or "_",
        case.status,
        case.resolution_code or "_",
        tuple(case.requested_fields),
        bool(case.reply),
        bool(case.internal_note),
        observation.remaining_steps,
    )


def action_label(action: SupportDeskAction) -> str:
    """Human-readable action label for debug output.

    Joins the operation and every set field with ``" | "``; reply and note
    bodies are abbreviated to the words ``reply`` / ``note``.
    """
    parts = [action.operation]
    if action.queue:
        parts.append(action.queue)
    if action.priority:
        parts.append(action.priority)
    if action.issue_type:
        parts.append(action.issue_type)
    if action.status:
        parts.append(action.status)
    if action.resolution_code:
        parts.append(action.resolution_code)
    if action.requested_fields:
        parts.append(",".join(action.requested_fields))
    if action.reply:
        parts.append("reply")
    if action.internal_note:
        parts.append("note")
    return " | ".join(parts)


def choose_action(q_values: dict[tuple, list[float]], state: tuple, num_actions: int, epsilon: float) -> int:
    """Epsilon-greedy action selection.

    Side effect: registers ``state`` in ``q_values`` with an all-zero row if it
    has not been seen yet. Ties between equally-valued best actions are broken
    uniformly at random using the module-level ``random`` generator.
    """
    if state not in q_values:
        q_values[state] = [0.0] * num_actions
    if random.random() < epsilon:
        return random.randrange(num_actions)
    best_value = max(q_values[state])
    best_indices = [index for index, value in enumerate(q_values[state]) if value == best_value]
    return random.choice(best_indices)


def train_q_agent(
    episodes_per_task: int,
    alpha: float,
    gamma: float,
    epsilon: float,
    epsilon_decay: float,
    min_epsilon: float,
    seed: int,
) -> tuple[dict[tuple, list[float]], dict[str, list[SupportDeskAction]]]:
    """Train a small tabular Q-learning agent over all tasks.

    Each outer iteration runs one episode per task; epsilon decays once per
    outer iteration and never drops below ``min_epsilon``. The global
    ``random`` generator is seeded with ``seed``.

    Returns:
        The shared Q-table (state key -> one value per library action) and the
        per-task action libraries whose indices the Q-values refer to.
    """
    random.seed(seed)
    task_ids = list(list_task_ids())
    q_values: dict[tuple, list[float]] = {}
    action_libraries = {task_id: build_action_library(task_id) for task_id in task_ids}

    for _ in range(episodes_per_task):
        for task_id in task_ids:
            env = SupportDeskEnvironment(task_id=task_id)
            observation = env.reset()
            actions = action_libraries[task_id]
            num_actions = len(actions)
            try:
                while not observation.done:
                    state = state_key(task_id, observation)
                    action_index = choose_action(q_values, state, num_actions, epsilon)
                    next_observation = env.step(actions[action_index])
                    next_state = state_key(task_id, next_observation)
                    next_values = q_values.setdefault(next_state, [0.0] * num_actions)
                    # Terminal transitions do not bootstrap from the next state.
                    bootstrap = 0.0 if next_observation.done else max(next_values)
                    td_target = _reward_value(next_observation.reward) + gamma * bootstrap
                    td_error = td_target - q_values[state][action_index]
                    q_values[state][action_index] += alpha * td_error
                    observation = next_observation
            finally:
                env.close()
        epsilon = max(min_epsilon, epsilon * epsilon_decay)

    return q_values, action_libraries


def _greedy_action_index(q_values: dict[tuple, list[float]], state: tuple, num_actions: int) -> int:
    """Return the highest-valued action index, breaking ties toward the lowest index.

    Unseen states behave like an all-zero row (index 0) without being inserted
    into ``q_values``, so evaluation never mutates the trained table.
    """
    row = q_values.get(state)
    if row is None:
        return 0
    return max(range(num_actions), key=row.__getitem__)


def evaluate_policy(
    q_values: dict[tuple, list[float]],
    action_libraries: dict[str, list[SupportDeskAction]],
) -> list[EvalResult]:
    """Run one greedy (epsilon = 0) evaluation episode for each task.

    ``q_values`` is only read, never modified.
    """
    results: list[EvalResult] = []
    for task_id in list_task_ids():
        env = SupportDeskEnvironment(task_id=task_id)
        observation = env.reset()
        actions = action_libraries[task_id]
        chosen_actions: list[str] = []
        try:
            while not observation.done:
                state = state_key(task_id, observation)
                action = actions[_greedy_action_index(q_values, state, len(actions))]
                chosen_actions.append(action_label(action))
                observation = env.step(action)
            results.append(
                EvalResult(
                    task_id=task_id,
                    score=grade_case(get_task(task_id), env.state.case).total_score,
                    reward=_reward_value(env.state.reward),
                    steps=env.state.step_count,
                    actions=chosen_actions,
                )
            )
        finally:
            env.close()
    return results


def _validate_hyperparameters(parser: argparse.ArgumentParser, args: argparse.Namespace) -> None:
    """Reject hyperparameters outside their meaningful ranges (exits via ``parser.error``)."""
    if args.episodes_per_task < 0:
        parser.error(f"--episodes-per-task must be >= 0 (got {args.episodes_per_task})")
    for flag, value in (
        ("--alpha", args.alpha),
        ("--gamma", args.gamma),
        ("--epsilon", args.epsilon),
        ("--epsilon-decay", args.epsilon_decay),
        ("--min-epsilon", args.min_epsilon),
    ):
        # The negated chained comparison also rejects NaN.
        if not 0.0 <= value <= 1.0:
            parser.error(f"{flag} must be between 0 and 1 (got {value})")


def main() -> None:
    """Parse CLI hyperparameters, train the agent, and print a greedy evaluation report."""
    parser = argparse.ArgumentParser(description="Train a simple tabular Q-learning agent on SupportDesk.")
    parser.add_argument("--episodes-per-task", type=int, default=250)
    parser.add_argument("--alpha", type=float, default=0.45)
    parser.add_argument("--gamma", type=float, default=0.92)
    parser.add_argument("--epsilon", type=float, default=0.35)
    parser.add_argument("--epsilon-decay", type=float, default=0.99)
    parser.add_argument("--min-epsilon", type=float, default=0.03)
    parser.add_argument("--seed", type=int, default=7)
    args = parser.parse_args()
    _validate_hyperparameters(parser, args)

    q_values, action_libraries = train_q_agent(
        episodes_per_task=args.episodes_per_task,
        alpha=args.alpha,
        gamma=args.gamma,
        epsilon=args.epsilon,
        epsilon_decay=args.epsilon_decay,
        min_epsilon=args.min_epsilon,
        seed=args.seed,
    )
    results = evaluate_policy(q_values, action_libraries)
    # Avoid ZeroDivisionError when no tasks are registered.
    average_score = sum(result.score for result in results) / len(results) if results else float("nan")

    print("Tabular Q-learning evaluation")
    print("============================")
    for result in results:
        print(
            f"{result.task_id}: score={result.score:.2f} reward={result.reward:.2f} "
            f"steps={result.steps}"
        )
        print(f" actions: {' -> '.join(result.actions)}")
    print(f"average_score={average_score:.3f}")


if __name__ == "__main__":
    main()