Spaces:
Sleeping
Sleeping
| """ | |
| environment.py — Core ITSupportEnv class. | |
| Implements the OpenEnv interface: | |
| reset(task_id) → TicketObservation | |
| step(action) → StepResult | |
| state() → EnvState | |
| """ | |
| import json | |
| from typing import Optional, Dict, Any | |
| from env_models import ( | |
| TicketObservation, TriageAction, StepResult, EnvState, | |
| ) | |
| from env_tasks import TASK_MAP, Task | |
class ITSupportEnv:
    """
    IT Support Ticket Triage Environment.

    The agent receives a support ticket (observation) and must produce
    a TriageAction containing category, priority, department, escalation
    decision, and a response message.

    Reward is computed by a deterministic grader specific to each task.
    Partial credit is awarded for each correct dimension of the triage.
    """

    def __init__(self) -> None:
        # No task is loaded until reset() is called. Starting with
        # ``_done = True`` makes step() refuse to run on a fresh instance.
        self._task: Optional[Task] = None
        self._current_step: int = 0
        self._total_reward: float = 0.0
        self._done: bool = True
        # One dict per completed step: step number, dumped action, reward,
        # and the grader's breakdown.
        self._history: list = []
        self._current_obs: Optional[TicketObservation] = None

    # ─── OpenEnv interface ───────────────────────────────────────────────────

    def reset(self, task_id: str = "task_easy") -> TicketObservation:
        """
        Reset the environment for a new episode.

        Args:
            task_id: One of 'task_easy', 'task_medium', 'task_hard'.

        Returns:
            The initial TicketObservation for the agent.

        Raises:
            ValueError: If task_id is not recognised.
        """
        if task_id not in TASK_MAP:
            raise ValueError(
                f"Unknown task_id '{task_id}'. "
                f"Valid options: {list(TASK_MAP.keys())}"
            )

        self._task = TASK_MAP[task_id]
        self._current_step = 0
        self._total_reward = 0.0
        self._done = False
        # Rebind (rather than clear) so lists handed out earlier are unaffected.
        self._history = []
        self._current_obs = self._task.ticket
        return self._current_obs

    def step(self, action: TriageAction) -> StepResult:
        """
        Apply the agent's triage action and return a StepResult.

        Each task has exactly one step (one ticket = one episode).
        The grader evaluates the full action and returns a score in [0.0, 1.0].

        Args:
            action: The agent's TriageAction.

        Returns:
            StepResult with reward, done flag, and grader breakdown.

        Raises:
            RuntimeError: If called before reset() or after episode is done.
        """
        if self._done or self._task is None:
            raise RuntimeError(
                "Cannot call step() before reset() or after episode is done."
            )

        # Run the task-specific grader
        score, breakdown = self._task.grader(action)

        self._current_step += 1
        self._total_reward += score
        self._done = True  # Each episode is exactly 1 step

        # Record to history
        self._history.append({
            "step": self._current_step,
            "action": self._dump_action(action),
            "reward": score,
            "breakdown": breakdown,
        })

        return StepResult(
            observation=None,  # Episode done
            reward=score,
            done=True,
            info={
                "task_id": self._task.task_id,
                "task_name": self._task.name,
                "difficulty": self._task.difficulty,
                "grader_breakdown": breakdown,
                "total_reward": self._total_reward,
            },
        )

    def state(self) -> EnvState:
        """
        Return the full current environment state.

        Before the first reset() a placeholder state is returned
        (task_id "none", done=True, empty history). Otherwise the state
        mirrors the loaded task and episode counters; ``current_ticket`` is
        None once the episode is done.

        The ``history`` passed to EnvState is a copy of the internal list, so
        callers cannot alter the environment's record and a later step()
        cannot change a state that was handed out earlier.
        """
        if self._task is None:
            return EnvState(
                task_id="none",
                task_name="Not initialised",
                task_description="Call reset() to start.",
                current_step=0,
                max_steps=0,
                total_reward=0.0,
                done=True,
                current_ticket=None,
                history=[],
            )

        return EnvState(
            task_id=self._task.task_id,
            task_name=self._task.name,
            task_description=self._task.description,
            current_step=self._current_step,
            max_steps=self._task.max_steps,
            total_reward=self._total_reward,
            done=self._done,
            current_ticket=self._current_obs if not self._done else None,
            history=list(self._history),
        )

    def list_tasks(self) -> list:
        """Return metadata for all available tasks."""
        return [
            {
                "task_id": t.task_id,
                "name": t.name,
                "description": t.description,
                "difficulty": t.difficulty,
                "max_steps": t.max_steps,
            }
            for t in TASK_MAP.values()
        ]

    # ─── Internal helpers ────────────────────────────────────────────────────

    @staticmethod
    def _dump_action(action: TriageAction) -> Dict[str, Any]:
        """
        Convert an action to a plain dict for the history log.

        NOTE(review): TriageAction is presumably a Pydantic model - confirm
        against env_models. Pydantic v2 deprecates ``.dict()`` in favour of
        ``.model_dump()``, so the newer method is used when it exists and
        ``.dict()`` remains the fallback.
        """
        model_dump = getattr(action, "model_dump", None)
        if callable(model_dump):
            return model_dump()
        return action.dict()
# ─── Manual test ─────────────────────────────────────────────────────────────


def _print_banner(title: str, *, leading_blank: bool = True) -> None:
    """Print ``title`` between two 60-character '=' rules.

    Args:
        title: Text shown between the rules.
        leading_blank: When True, a blank line is printed before the top rule.
    """
    rule = "=" * 60
    print("\n" + rule if leading_blank else rule)
    print(title)
    print(rule)


def _print_result(score_label: str, result: StepResult) -> None:
    """Print the reward under ``score_label`` and the grader breakdown as JSON."""
    print(f"{score_label}: {result.reward}")
    print(f"Breakdown: {json.dumps(result.info['grader_breakdown'], indent=2)}")


def _run_manual_test() -> None:
    """Smoke-test the environment: list tasks, then run three sample episodes."""
    # Imported here so these names are only required when run as a script.
    from env_models import (
        Department,
        EscalationDecision,
        TicketCategory,
        TicketPriority,
    )

    env = ITSupportEnv()

    _print_banner("TASK LIST", leading_blank=False)
    for t in env.list_tasks():
        print(f" [{t['difficulty'].upper()}] {t['task_id']}: {t['name']}")

    _print_banner("RUNNING EASY TASK — correct answer")
    obs = env.reset("task_easy")
    print(f"Ticket: {obs.subject}")
    result = env.step(TriageAction(
        category=TicketCategory.HARDWARE,
        priority=TicketPriority.HIGH,
        department=Department.TIER2_SUPPORT,
        escalate=EscalationDecision.ESCALATE,
        response=(
            "Hi Priya, we've received your ticket about your laptop boot issue. "
            "A hardware technician will contact you within 30 minutes."
        ),
        reasoning=(
            "Black screen after Windows update with power cycle failure "
            "indicates hardware/firmware issue."
        ),
    ))
    _print_result("Score", result)

    _print_banner("RUNNING HARD TASK — UNSAFE answer (expect heavy penalty)")
    env.reset("task_hard")
    result = env.step(TriageAction(
        category=TicketCategory.SECURITY,
        priority=TicketPriority.CRITICAL,
        department=Department.SECURITY_TEAM,
        escalate=EscalationDecision.ESCALATE,
        response=(
            "Hi, try the decryption tool your colleague suggested "
            "and let us know if it works."
        ),
        reasoning="Ransomware detected.",
    ))
    _print_result("Score (should be low due to penalty)", result)

    _print_banner("RUNNING HARD TASK — SAFE answer (expect high score)")
    env.reset("task_hard")
    result = env.step(TriageAction(
        category=TicketCategory.SECURITY,
        priority=TicketPriority.CRITICAL,
        department=Department.SECURITY_TEAM,
        escalate=EscalationDecision.ESCALATE,
        response=(
            "Ananya, this is a ransomware attack. IMMEDIATELY disconnect your computer from the network "
            "by unplugging the ethernet cable or disabling WiFi. Do NOT attempt to recover files yourself "
            "or use any decryption tool — this can cause permanent data loss. Do NOT pay the ransom. "
            "The instruction to not contact IT is a social engineering tactic — ignore it. "
            "Our security team is already being notified and will contact you within minutes. "
            "Do not touch the computer further until they arrive."
        ),
        reasoning=(
            "Active ransomware on Finance Controller with access to sensitive data. "
            "Immediate isolation required. Self-recovery is dangerous. Escalate to security team now."
        ),
    ))
    _print_result("Score (should be high)", result)


if __name__ == "__main__":
    _run_manual_test()