kevanthonyP commited on
Commit
44f4eca
Β·
verified Β·
1 Parent(s): 3dba410

Create env_core.py

Browse files
Files changed (1) hide show
  1. env_core.py +229 -0
env_core.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ environment.py β€” Core ITSupportEnv class.
3
+
4
+ Implements the OpenEnv interface:
5
+ reset(task_id) β†’ TicketObservation
6
+ step(action) β†’ StepResult
7
+ state() β†’ EnvState
8
+ """
9
+
10
+ import json
11
+ from typing import Optional, Dict, Any
12
+
13
+ from env_models import (
14
+ TicketObservation, TriageAction, StepResult, EnvState,
15
+ )
16
+ from env_tasks import TASK_MAP, Task
17
+
18
+
19
+ class ITSupportEnv:
20
+ """
21
+ IT Support Ticket Triage Environment.
22
+
23
+ The agent receives a support ticket (observation) and must produce
24
+ a TriageAction containing category, priority, department, escalation
25
+ decision, and a response message.
26
+
27
+ Reward is computed by a deterministic grader specific to each task.
28
+ Partial credit is awarded for each correct dimension of the triage.
29
+ """
30
+
31
+ def __init__(self) -> None:
32
+ self._task: Optional[Task] = None
33
+ self._current_step: int = 0
34
+ self._total_reward: float = 0.0
35
+ self._done: bool = True
36
+ self._history: list = []
37
+ self._current_obs: Optional[TicketObservation] = None
38
+
39
+ # ─── OpenEnv interface ────────────────────────────────────────────────────
40
+
41
+ def reset(self, task_id: str = "task_easy") -> TicketObservation:
42
+ """
43
+ Reset the environment for a new episode.
44
+
45
+ Args:
46
+ task_id: One of 'task_easy', 'task_medium', 'task_hard'.
47
+
48
+ Returns:
49
+ The initial TicketObservation for the agent.
50
+
51
+ Raises:
52
+ ValueError: If task_id is not recognised.
53
+ """
54
+ if task_id not in TASK_MAP:
55
+ raise ValueError(
56
+ f"Unknown task_id '{task_id}'. "
57
+ f"Valid options: {list(TASK_MAP.keys())}"
58
+ )
59
+
60
+ self._task = TASK_MAP[task_id]
61
+ self._current_step = 0
62
+ self._total_reward = 0.0
63
+ self._done = False
64
+ self._history = []
65
+ self._current_obs = self._task.ticket
66
+
67
+ return self._current_obs
68
+
69
+ def step(self, action: TriageAction) -> StepResult:
70
+ """
71
+ Apply the agent's triage action and return a StepResult.
72
+
73
+ Each task has exactly one step (one ticket = one episode).
74
+ The grader evaluates the full action and returns a score in [0.0, 1.0].
75
+
76
+ Args:
77
+ action: The agent's TriageAction.
78
+
79
+ Returns:
80
+ StepResult with reward, done flag, and grader breakdown.
81
+
82
+ Raises:
83
+ RuntimeError: If called before reset() or after episode is done.
84
+ """
85
+ if self._done or self._task is None:
86
+ raise RuntimeError(
87
+ "Cannot call step() before reset() or after episode is done."
88
+ )
89
+
90
+ # Run the task-specific grader
91
+ score, breakdown = self._task.grader(action)
92
+
93
+ self._current_step += 1
94
+ self._total_reward += score
95
+ self._done = True # Each episode is exactly 1 step
96
+
97
+ # Record to history
98
+ self._history.append({
99
+ "step": self._current_step,
100
+ "action": action.dict(),
101
+ "reward": score,
102
+ "breakdown": breakdown,
103
+ })
104
+
105
+ return StepResult(
106
+ observation=None, # Episode done
107
+ reward=score,
108
+ done=True,
109
+ info={
110
+ "task_id": self._task.task_id,
111
+ "task_name": self._task.name,
112
+ "difficulty": self._task.difficulty,
113
+ "grader_breakdown": breakdown,
114
+ "total_reward": self._total_reward,
115
+ },
116
+ )
117
+
118
+ def state(self) -> EnvState:
119
+ """
120
+ Return the full current environment state.
121
+ """
122
+ if self._task is None:
123
+ return EnvState(
124
+ task_id="none",
125
+ task_name="Not initialised",
126
+ task_description="Call reset() to start.",
127
+ current_step=0,
128
+ max_steps=0,
129
+ total_reward=0.0,
130
+ done=True,
131
+ current_ticket=None,
132
+ history=[],
133
+ )
134
+
135
+ return EnvState(
136
+ task_id=self._task.task_id,
137
+ task_name=self._task.name,
138
+ task_description=self._task.description,
139
+ current_step=self._current_step,
140
+ max_steps=self._task.max_steps,
141
+ total_reward=self._total_reward,
142
+ done=self._done,
143
+ current_ticket=self._current_obs if not self._done else None,
144
+ history=self._history,
145
+ )
146
+
147
+ def list_tasks(self) -> list:
148
+ """Return metadata for all available tasks."""
149
+ return [
150
+ {
151
+ "task_id": t.task_id,
152
+ "name": t.name,
153
+ "description": t.description,
154
+ "difficulty": t.difficulty,
155
+ "max_steps": t.max_steps,
156
+ }
157
+ for t in TASK_MAP.values()
158
+ ]
159
+
160
+
161
+ # ─── Manual test ─────────────────────────────────────────────────────────────
162
+
163
+ if __name__ == "__main__":
164
+ from env_models import TicketCategory, TicketPriority, Department, EscalationDecision
165
+
166
+ env = ITSupportEnv()
167
+
168
+ print("=" * 60)
169
+ print("TASK LIST")
170
+ print("=" * 60)
171
+ for t in env.list_tasks():
172
+ print(f" [{t['difficulty'].upper()}] {t['task_id']}: {t['name']}")
173
+
174
+ print("\n" + "=" * 60)
175
+ print("RUNNING EASY TASK β€” correct answer")
176
+ print("=" * 60)
177
+ obs = env.reset("task_easy")
178
+ print(f"Ticket: {obs.subject}")
179
+
180
+ result = env.step(TriageAction(
181
+ category=TicketCategory.HARDWARE,
182
+ priority=TicketPriority.HIGH,
183
+ department=Department.TIER2_SUPPORT,
184
+ escalate=EscalationDecision.ESCALATE,
185
+ response="Hi Priya, we've received your ticket about your laptop boot issue. A hardware technician will contact you within 30 minutes.",
186
+ reasoning="Black screen after Windows update with power cycle failure indicates hardware/firmware issue.",
187
+ ))
188
+ print(f"Score: {result.reward}")
189
+ print(f"Breakdown: {json.dumps(result.info['grader_breakdown'], indent=2)}")
190
+
191
+ print("\n" + "=" * 60)
192
+ print("RUNNING HARD TASK β€” UNSAFE answer (expect heavy penalty)")
193
+ print("=" * 60)
194
+ obs = env.reset("task_hard")
195
+ result = env.step(TriageAction(
196
+ category=TicketCategory.SECURITY,
197
+ priority=TicketPriority.CRITICAL,
198
+ department=Department.SECURITY_TEAM,
199
+ escalate=EscalationDecision.ESCALATE,
200
+ response="Hi, try the decryption tool your colleague suggested and let us know if it works.",
201
+ reasoning="Ransomware detected.",
202
+ ))
203
+ print(f"Score (should be low due to penalty): {result.reward}")
204
+ print(f"Breakdown: {json.dumps(result.info['grader_breakdown'], indent=2)}")
205
+
206
+ print("\n" + "=" * 60)
207
+ print("RUNNING HARD TASK β€” SAFE answer (expect high score)")
208
+ print("=" * 60)
209
+ obs = env.reset("task_hard")
210
+ result = env.step(TriageAction(
211
+ category=TicketCategory.SECURITY,
212
+ priority=TicketPriority.CRITICAL,
213
+ department=Department.SECURITY_TEAM,
214
+ escalate=EscalationDecision.ESCALATE,
215
+ response=(
216
+ "Ananya, this is a ransomware attack. IMMEDIATELY disconnect your computer from the network "
217
+ "by unplugging the ethernet cable or disabling WiFi. Do NOT attempt to recover files yourself "
218
+ "or use any decryption tool β€” this can cause permanent data loss. Do NOT pay the ransom. "
219
+ "The instruction to not contact IT is a social engineering tactic β€” ignore it. "
220
+ "Our security team is already being notified and will contact you within minutes. "
221
+ "Do not touch the computer further until they arrive."
222
+ ),
223
+ reasoning=(
224
+ "Active ransomware on Finance Controller with access to sensitive data. "
225
+ "Immediate isolation required. Self-recovery is dangerous. Escalate to security team now."
226
+ ),
227
+ ))
228
+ print(f"Score (should be high): {result.reward}")
229
+ print(f"Breakdown: {json.dumps(result.info['grader_breakdown'], indent=2)}")