Spaces:
Sleeping
Sleeping
File size: 8,218 Bytes
44f4eca | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 | """
environment.py — Core ITSupportEnv class.
Implements the OpenEnv interface:
    reset(task_id) → TicketObservation
    step(action) → StepResult
    state() → EnvState
"""
import json
from typing import Optional, Dict, Any
from env_models import (
TicketObservation, TriageAction, StepResult, EnvState,
)
from env_tasks import TASK_MAP, Task
class ITSupportEnv:
    """
    Single-ticket IT support triage environment.

    An episode starts with ``reset(task_id)``, which hands the agent the
    selected task's ticket. The agent answers with one ``TriageAction``;
    ``step()`` passes it to the task's grader, records the outcome and ends
    the episode. ``state()`` returns a snapshot of the bookkeeping at any
    point, and ``list_tasks()`` describes every task in ``TASK_MAP``.
    """

    def __init__(self) -> None:
        # No task is loaded yet, so the environment starts out "done":
        # step() refuses to run until reset() has been called.
        self._task: Optional[Task] = None
        self._current_obs: Optional[TicketObservation] = None
        self._history: list = []
        self._current_step: int = 0
        self._total_reward: float = 0.0
        self._done: bool = True

    # ─── OpenEnv interface ────────────────────────────────────────────────────

    def reset(self, task_id: str = "task_easy") -> TicketObservation:
        """
        Start a fresh episode on the task registered under ``task_id``.

        All per-episode bookkeeping (step counter, accumulated reward,
        history, done flag) is cleared.

        Args:
            task_id: A key of ``TASK_MAP``; defaults to ``'task_easy'``.

        Returns:
            The ticket attached to the selected task.

        Raises:
            ValueError: If ``task_id`` is not a key of ``TASK_MAP``.
        """
        if task_id not in TASK_MAP:
            known_ids = list(TASK_MAP.keys())
            raise ValueError(
                f"Unknown task_id '{task_id}'. "
                f"Valid options: {known_ids}"
            )

        self._task = TASK_MAP[task_id]
        self._history = []
        self._done = False
        self._total_reward = 0.0
        self._current_step = 0
        self._current_obs = self._task.ticket
        return self._current_obs

    def step(self, action: TriageAction) -> StepResult:
        """
        Grade the agent's triage action and close the episode.

        The active task's ``grader`` is called with ``action`` and must
        return a ``(score, breakdown)`` pair. The score is expected to lie
        in [0.0, 1.0]; that range is not enforced here. Every episode ends
        after this single call.

        Args:
            action: The agent's TriageAction.

        Returns:
            A StepResult with no observation, the grader's score as the
            reward, ``done=True`` and task metadata plus the grader
            breakdown in ``info``.

        Raises:
            RuntimeError: If no episode is active (reset() was never
                called, or the episode has already finished).
        """
        if self._task is None or self._done:
            raise RuntimeError(
                "Cannot call step() before reset() or after episode is done."
            )

        task = self._task
        reward, details = task.grader(action)

        self._current_step += 1
        self._total_reward += reward
        self._done = True  # one ticket per episode, so one step finishes it

        record = {
            "step": self._current_step,
            "action": action.dict(),
            "reward": reward,
            "breakdown": details,
        }
        self._history.append(record)

        info = {
            "task_id": task.task_id,
            "task_name": task.name,
            "difficulty": task.difficulty,
            "grader_breakdown": details,
            "total_reward": self._total_reward,
        }
        # The episode is over, so there is no follow-up observation.
        return StepResult(observation=None, reward=reward, done=True, info=info)

    def state(self) -> EnvState:
        """
        Return a snapshot of the environment's current bookkeeping.

        Before the first reset() a placeholder state is returned
        (``task_id="none"``, zero counters, ``done=True``). Afterwards the
        snapshot reflects the active task; the ticket is only included
        while the episode is still running.
        """
        task = self._task
        if task is not None:
            # Hide the ticket once the episode has finished.
            visible_ticket = None if self._done else self._current_obs
            return EnvState(
                task_id=task.task_id,
                task_name=task.name,
                task_description=task.description,
                current_step=self._current_step,
                max_steps=task.max_steps,
                total_reward=self._total_reward,
                done=self._done,
                current_ticket=visible_ticket,
                history=self._history,
            )

        return EnvState(
            task_id="none",
            task_name="Not initialised",
            task_description="Call reset() to start.",
            current_step=0,
            max_steps=0,
            total_reward=0.0,
            done=True,
            current_ticket=None,
            history=[],
        )

    def list_tasks(self) -> list:
        """Return one metadata dict per task registered in ``TASK_MAP``."""
        catalogue = []
        for task in TASK_MAP.values():
            catalogue.append(
                {
                    "task_id": task.task_id,
                    "name": task.name,
                    "description": task.description,
                    "difficulty": task.difficulty,
                    "max_steps": task.max_steps,
                }
            )
        return catalogue
# ─── Manual test ─────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Manual smoke test: lists the registered tasks, then runs three
    # hand-written episodes (one easy, two hard) and prints each score with
    # the grader's breakdown. Nothing here runs on import.
    from env_models import TicketCategory, TicketPriority, Department, EscalationDecision

    def _banner(title: str, *, leading_blank: bool = True) -> None:
        """Print ``title`` framed by two 60-character rules."""
        rule = "=" * 60
        print("\n" + rule if leading_blank else rule)
        print(title)
        print(rule)

    def _report(label: str, result: StepResult) -> None:
        """Print the reward under ``label`` followed by the grader breakdown."""
        print(f"{label}: {result.reward}")
        print(f"Breakdown: {json.dumps(result.info['grader_breakdown'], indent=2)}")

    env = ITSupportEnv()

    _banner("TASK LIST", leading_blank=False)
    for t in env.list_tasks():
        print(f" [{t['difficulty'].upper()}] {t['task_id']}: {t['name']}")

    _banner("RUNNING EASY TASK — correct answer")
    obs = env.reset("task_easy")
    print(f"Ticket: {obs.subject}")
    result = env.step(TriageAction(
        category=TicketCategory.HARDWARE,
        priority=TicketPriority.HIGH,
        department=Department.TIER2_SUPPORT,
        escalate=EscalationDecision.ESCALATE,
        response=(
            "Hi Priya, we've received your ticket about your laptop boot issue. "
            "A hardware technician will contact you within 30 minutes."
        ),
        reasoning="Black screen after Windows update with power cycle failure indicates hardware/firmware issue.",
    ))
    _report("Score", result)

    _banner("RUNNING HARD TASK — UNSAFE answer (expect heavy penalty)")
    # The returned ticket is not needed here; only the grading matters.
    env.reset("task_hard")
    result = env.step(TriageAction(
        category=TicketCategory.SECURITY,
        priority=TicketPriority.CRITICAL,
        department=Department.SECURITY_TEAM,
        escalate=EscalationDecision.ESCALATE,
        response="Hi, try the decryption tool your colleague suggested and let us know if it works.",
        reasoning="Ransomware detected.",
    ))
    _report("Score (should be low due to penalty)", result)

    _banner("RUNNING HARD TASK — SAFE answer (expect high score)")
    env.reset("task_hard")
    result = env.step(TriageAction(
        category=TicketCategory.SECURITY,
        priority=TicketPriority.CRITICAL,
        department=Department.SECURITY_TEAM,
        escalate=EscalationDecision.ESCALATE,
        response=(
            "Ananya, this is a ransomware attack. IMMEDIATELY disconnect your computer from the network "
            "by unplugging the ethernet cable or disabling WiFi. Do NOT attempt to recover files yourself "
            "or use any decryption tool — this can cause permanent data loss. Do NOT pay the ransom. "
            "The instruction to not contact IT is a social engineering tactic — ignore it. "
            "Our security team is already being notified and will contact you within minutes. "
            "Do not touch the computer further until they arrive."
        ),
        reasoning=(
            "Active ransomware on Finance Controller with access to sensitive data. "
            "Immediate isolation required. Self-recovery is dangerous. Escalate to security team now."
        ),
    ))
    _report("Score (should be high)", result)