Spaces:
Sleeping
Sleeping
| """Train a simple tabular Q-learning agent against the local SupportDesk env. | |
| This is an extra playground script for local experimentation. It is not part of | |
| the hackathon submission baseline and intentionally uses a compact, hand-built | |
| discrete action library so that plain Python Q-learning can train quickly. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import random | |
| import sys | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
# `parents[2]` is the directory two levels above the folder containing this
# file. It is pushed to the front of sys.path (once) before the package imports
# that follow — presumably so `supportdesk_env` resolves when this script is
# run directly from a checkout; confirm against the repository layout.
REPO_ROOT = Path(__file__).resolve().parents[2]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
| from supportdesk_env import ( | |
| SupportDeskAction, | |
| get_task, | |
| grade_case, | |
| list_task_ids, | |
| ) | |
| from supportdesk_env.policies import default_note, default_reply | |
| from supportdesk_env.server.supportdesk_environment import SupportDeskEnvironment | |
@dataclass
class EvalResult:
    """Compact report for a greedy evaluation episode.

    The ``@dataclass`` decorator is required: ``evaluate_policy`` constructs
    this record with keyword arguments, which a plain class holding only bare
    annotations would reject with ``TypeError``.

    Attributes:
        task_id: Identifier of the task the episode was run on.
        score: ``total_score`` of the grade computed for the final case.
        reward: Reward read from the environment state after the episode.
        steps: Step count read from the environment state after the episode.
        actions: Human-readable labels of the actions taken, in order.
    """

    task_id: str
    score: float
    reward: float
    steps: int
    actions: list[str]
def build_action_library(task_id: str) -> list[SupportDeskAction]:
    """Assemble the fixed ten-entry discrete action menu for ``task_id``.

    The menu pairs a gold-label action with a deliberately weaker alternative
    for each operation, in this order: classify (gold, wrong), request_info
    (full, partial), draft_reply (default, generic), add_internal_note
    (default, generic), submit (gold, generic). Callers index into the list,
    so the order is part of the contract.
    """
    task = get_task(task_id)

    def first_other(candidates, gold):
        # First candidate that is not the gold label for this task.
        return next(candidate for candidate in candidates if candidate != gold)

    wrong_queue = first_other(
        ("general_support", "billing_ops", "trust_and_safety", "platform_engineering"),
        task.gold_queue,
    )
    wrong_priority = first_other(("low", "normal", "high", "urgent"), task.gold_priority)
    wrong_issue = first_other(
        ("general_question", "duplicate_charge", "account_compromise", "production_incident"),
        task.gold_issue_type,
    )

    # The "partial" request asks for only the first required field, falling
    # back to a billing email when the task requires no fields at all.
    partial_fields = list(task.required_requested_fields[:1]) or ["billing_email"]

    if not task.required_requested_fields:
        # Nothing is required for this task: use a generic billing-email ask.
        good_request = SupportDeskAction(
            operation="request_info",
            requested_fields=["billing_email"],
            status="waiting_on_customer",
            reply="Please confirm the billing email on the account so we can continue.",
        )
    else:
        good_request = SupportDeskAction(
            operation="request_info",
            requested_fields=list(task.required_requested_fields),
            status=task.gold_status,
            reply=default_reply(task_id),
        )

    partial_request = SupportDeskAction(
        operation="request_info",
        requested_fields=partial_fields,
        status="waiting_on_customer",
        reply="Please share more details so we can investigate.",
    )

    gold_classify = SupportDeskAction(
        operation="classify",
        queue=task.gold_queue,
        priority=task.gold_priority,
        issue_type=task.gold_issue_type,
    )
    wrong_classify = SupportDeskAction(
        operation="classify",
        queue=wrong_queue,
        priority=wrong_priority,
        issue_type=wrong_issue,
    )
    default_draft = SupportDeskAction(operation="draft_reply", reply=default_reply(task_id))
    generic_draft = SupportDeskAction(
        operation="draft_reply",
        reply="Thanks for reaching out. We are checking this now.",
    )
    default_internal_note = SupportDeskAction(
        operation="add_internal_note",
        internal_note=default_note(task_id),
    )
    generic_internal_note = SupportDeskAction(
        operation="add_internal_note",
        internal_note="Customer contacted support with a problem.",
    )
    gold_submit = SupportDeskAction(
        operation="submit",
        status=task.gold_status,
        resolution_code=task.gold_resolution_code,
    )
    generic_submit = SupportDeskAction(
        operation="submit",
        status="resolved",
        resolution_code="closed_generic",
    )

    return [
        gold_classify,
        wrong_classify,
        good_request,
        partial_request,
        default_draft,
        generic_draft,
        default_internal_note,
        generic_internal_note,
        gold_submit,
        generic_submit,
    ]
def state_key(task_id: str, observation) -> tuple:
    """Build the hashable Q-table key for ``observation`` under ``task_id``.

    The key holds the task id, the case's queue / priority / issue type
    (``"_"`` when unset), its status, its resolution code (``"_"`` when
    unset), the requested fields as a tuple, whether a reply and an internal
    note are present, and the remaining step budget.
    """
    case = observation.case
    # Falsy classification labels collapse to a single placeholder.
    labels = tuple(
        getattr(case, name) or "_"
        for name in ("queue", "priority", "issue_type")
    )
    return (
        task_id,
        *labels,
        case.status,
        case.resolution_code or "_",
        tuple(case.requested_fields),
        bool(case.reply),
        bool(case.internal_note),
        observation.remaining_steps,
    )
def action_label(action: SupportDeskAction) -> str:
    """Render ``action`` as a ``" | "``-separated debug string.

    The operation always comes first. Each non-empty queue, priority, issue
    type, status and resolution code follows in that order, then the
    comma-joined requested fields, then the literal markers ``reply`` and
    ``note`` when those texts are set.
    """
    pieces = [action.operation]

    # Plain string attributes are shown verbatim when truthy.
    for attribute in ("queue", "priority", "issue_type", "status", "resolution_code"):
        value = getattr(action, attribute)
        if value:
            pieces.append(value)

    if action.requested_fields:
        pieces.append(",".join(action.requested_fields))

    # Free-text payloads are only flagged, never printed in full.
    pieces.extend(
        marker
        for marker, payload in (("reply", action.reply), ("note", action.internal_note))
        if payload
    )
    return " | ".join(pieces)
def choose_action(q_values: dict[tuple, list[float]], state: tuple, num_actions: int, epsilon: float) -> int:
    """Pick an action index for ``state`` with an epsilon-greedy rule.

    An unseen state gets a zero-initialised row in ``q_values`` first. With
    probability ``epsilon`` a uniformly random index is returned; otherwise
    one of the indices holding the highest Q-value is returned, with ties
    broken at random.
    """
    row = q_values.setdefault(state, [0.0] * num_actions)

    explore = random.random() < epsilon
    if explore:
        return random.randrange(num_actions)

    top = max(row)
    candidates = []
    for position, estimate in enumerate(row):
        if estimate == top:
            candidates.append(position)
    return random.choice(candidates)
def train_q_agent(
    episodes_per_task: int,
    alpha: float,
    gamma: float,
    epsilon: float,
    epsilon_decay: float,
    min_epsilon: float,
    seed: int,
) -> tuple[dict[tuple, list[float]], dict[str, list[SupportDeskAction]]]:
    """Run tabular Q-learning over every task and return what was learned.

    The global ``random`` generator is seeded with ``seed``. Each outer pass
    plays one episode per task id. After each pass ``epsilon`` is multiplied
    by ``epsilon_decay`` and clamped from below at ``min_epsilon``.

    Returns:
        A pair ``(q_values, action_libraries)``: the Q-table keyed by
        ``state_key`` tuples, and the per-task action lists whose indices the
        Q-table rows refer to.
    """
    random.seed(seed)
    q_values: dict[tuple, list[float]] = {}
    action_libraries = {task_id: build_action_library(task_id) for task_id in list_task_ids()}

    for _ in range(episodes_per_task):
        for task_id in list_task_ids():
            env = SupportDeskEnvironment(task_id=task_id)
            observation = env.reset()
            library = action_libraries[task_id]
            try:
                while not observation.done:
                    current = state_key(task_id, observation)
                    chosen = choose_action(q_values, current, len(library), epsilon)
                    outcome = env.step(library[chosen])

                    successor = state_key(task_id, outcome)
                    successor_row = q_values.setdefault(successor, [0.0] * len(library))

                    # Terminal transitions bootstrap from nothing.
                    if outcome.done:
                        future_value = 0.0
                    else:
                        future_value = max(successor_row)
                    target = outcome.reward + gamma * future_value

                    current_row = q_values[current]
                    error = target - current_row[chosen]
                    current_row[chosen] = current_row[chosen] + alpha * error

                    observation = outcome
            finally:
                env.close()
        # Anneal exploration once per pass over the tasks.
        epsilon = max(min_epsilon, epsilon * epsilon_decay)

    return q_values, action_libraries
def evaluate_policy(
    q_values: dict[tuple, list[float]],
    action_libraries: dict[str, list[SupportDeskAction]],
) -> list[EvalResult]:
    """Play one purely greedy episode per task and report the outcome.

    At each step the lowest-index action with the highest Q-value for the
    current state is taken. States missing from ``q_values`` are inserted with
    an all-zero row, so the table may grow during evaluation.

    Returns:
        One ``EvalResult`` per task id, in ``list_task_ids()`` order.
    """
    reports: list[EvalResult] = []

    for task_id in list_task_ids():
        env = SupportDeskEnvironment(task_id=task_id)
        observation = env.reset()
        library = action_libraries[task_id]
        trace: list[str] = []
        try:
            while not observation.done:
                key = state_key(task_id, observation)
                row = q_values.setdefault(key, [0.0] * len(library))
                # max() keeps the first maximal index, i.e. ties go to the
                # earliest action in the library.
                greedy_index = max(range(len(library)), key=row.__getitem__)
                picked = library[greedy_index]
                trace.append(action_label(picked))
                observation = env.step(picked)

            final_score = grade_case(get_task(task_id), env.state.case).total_score
            report = EvalResult(
                task_id=task_id,
                score=final_score,
                reward=env.state.reward,
                steps=env.state.step_count,
                actions=trace,
            )
            reports.append(report)
        finally:
            env.close()

    return reports
def main() -> None:
    """Parse CLI hyperparameters, train the agent, and print a greedy evaluation."""
    parser = argparse.ArgumentParser(description="Train a simple tabular Q-learning agent on SupportDesk.")
    # (flag, type, default) for every tunable hyperparameter.
    options = (
        ("--episodes-per-task", int, 250),
        ("--alpha", float, 0.45),
        ("--gamma", float, 0.92),
        ("--epsilon", float, 0.35),
        ("--epsilon-decay", float, 0.99),
        ("--min-epsilon", float, 0.03),
        ("--seed", int, 7),
    )
    for flag, kind, fallback in options:
        parser.add_argument(flag, type=kind, default=fallback)
    args = parser.parse_args()

    q_values, action_libraries = train_q_agent(
        episodes_per_task=args.episodes_per_task,
        alpha=args.alpha,
        gamma=args.gamma,
        epsilon=args.epsilon,
        epsilon_decay=args.epsilon_decay,
        min_epsilon=args.min_epsilon,
        seed=args.seed,
    )
    results = evaluate_policy(q_values, action_libraries)

    total_score = sum(result.score for result in results)
    average_score = total_score / len(results)

    print("Tabular Q-learning evaluation")
    print("============================")
    for result in results:
        summary = f"{result.task_id}: score={result.score:.2f} reward={result.reward:.2f} "
        summary += f"steps={result.steps}"
        print(summary)
        trail = " -> ".join(result.actions)
        print(f" actions: {trail}")
    print(f"average_score={average_score:.3f}")
# Run the train-then-evaluate CLI only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()