# it-support-triage / env_core.py
# Committed by kevanthonyP — "Create env_core.py" (commit 44f4eca, verified)
"""
environment.py β€” Core ITSupportEnv class.
Implements the OpenEnv interface:
reset(task_id) β†’ TicketObservation
step(action) β†’ StepResult
state() β†’ EnvState
"""
import json
from typing import Optional, Dict, Any
from env_models import (
TicketObservation, TriageAction, StepResult, EnvState,
)
from env_tasks import TASK_MAP, Task
class ITSupportEnv:
    """
    IT Support Ticket Triage Environment.

    The agent receives a support ticket (observation) and must produce
    a TriageAction containing category, priority, department, escalation
    decision, and a response message.

    Reward is computed by a deterministic grader specific to each task.
    Partial credit is awarded for each correct dimension of the triage.
    """

    def __init__(self) -> None:
        # No task is loaded until reset() is called. _done starts True so that
        # step() refuses to run before the first reset().
        self._task: Optional[Task] = None
        self._current_step: int = 0
        self._total_reward: float = 0.0
        self._done: bool = True
        # One dict per step with keys: step, action, reward, breakdown.
        self._history: list = []
        self._current_obs: Optional[TicketObservation] = None

    # ─── OpenEnv interface ────────────────────────────────────────────────────

    def reset(self, task_id: str = "task_easy") -> TicketObservation:
        """
        Reset the environment for a new episode.

        Args:
            task_id: A key of TASK_MAP (e.g. 'task_easy', 'task_medium',
                'task_hard').

        Returns:
            The initial TicketObservation (the task's ticket) for the agent.

        Raises:
            ValueError: If task_id is not a key of TASK_MAP.
        """
        if task_id not in TASK_MAP:
            raise ValueError(
                f"Unknown task_id '{task_id}'. "
                f"Valid options: {list(TASK_MAP.keys())}"
            )
        self._task = TASK_MAP[task_id]
        self._current_step = 0
        self._total_reward = 0.0
        self._done = False
        # Rebind rather than clear, so history lists handed out earlier are
        # left untouched.
        self._history = []
        self._current_obs = self._task.ticket
        return self._current_obs

    def step(self, action: TriageAction) -> StepResult:
        """
        Grade the agent's triage action and end the episode.

        Each task has exactly one step (one ticket = one episode). The
        task-specific grader returns a (score, breakdown) pair. The score
        becomes the step reward, and the breakdown is exposed in ``info``.

        Args:
            action: The agent's TriageAction.

        Returns:
            StepResult with ``observation=None``, the grader score as
            ``reward``, ``done=True``, and an ``info`` dict holding task
            metadata, the grader breakdown and the running total reward.

        Raises:
            RuntimeError: If called before reset() or after the episode is done.
        """
        if self._done or self._task is None:
            raise RuntimeError(
                "Cannot call step() before reset() or after episode is done."
            )

        score, breakdown = self._task.grader(action)

        self._current_step += 1
        self._total_reward += score
        self._done = True  # Each episode is exactly 1 step.

        self._history.append({
            "step": self._current_step,
            "action": self._serialize_action(action),
            "reward": score,
            "breakdown": breakdown,
        })

        return StepResult(
            observation=None,  # Episode done: there is no next ticket.
            reward=score,
            done=True,
            info={
                "task_id": self._task.task_id,
                "task_name": self._task.name,
                "difficulty": self._task.difficulty,
                "grader_breakdown": breakdown,
                "total_reward": self._total_reward,
            },
        )

    def state(self) -> EnvState:
        """
        Return a snapshot of the current environment state.

        Before the first reset() this is a placeholder state with
        ``task_id="none"``. ``current_ticket`` is None once the episode is
        done. ``history`` is a copy of the internal list, so a snapshot does
        not change when step() is called later.
        """
        if self._task is None:
            return EnvState(
                task_id="none",
                task_name="Not initialised",
                task_description="Call reset() to start.",
                current_step=0,
                max_steps=0,
                total_reward=0.0,
                done=True,
                current_ticket=None,
                history=[],
            )
        return EnvState(
            task_id=self._task.task_id,
            task_name=self._task.name,
            task_description=self._task.description,
            current_step=self._current_step,
            max_steps=self._task.max_steps,
            total_reward=self._total_reward,
            done=self._done,
            current_ticket=self._current_obs if not self._done else None,
            # Copy: do not leak the live list that step() appends to.
            history=list(self._history),
        )

    def list_tasks(self) -> list:
        """Return a metadata dict (id, name, description, difficulty,
        max_steps) for every task in TASK_MAP."""
        return [
            {
                "task_id": t.task_id,
                "name": t.name,
                "description": t.description,
                "difficulty": t.difficulty,
                "max_steps": t.max_steps,
            }
            for t in TASK_MAP.values()
        ]

    # ─── Internal helpers ─────────────────────────────────────────────────────

    @staticmethod
    def _serialize_action(action: TriageAction) -> Dict[str, Any]:
        """Return *action* as a plain dict for the history log."""
        # Assumes TriageAction is a Pydantic model. v2 deprecates .dict() in
        # favour of .model_dump(), so prefer the new API when it exists and
        # fall back to .dict() for v1.
        model_dump = getattr(action, "model_dump", None)
        if callable(model_dump):
            return model_dump()
        return action.dict()
# ─── Manual test ─────────────────────────────────────────────────────────────
if __name__ == "__main__":
from env_models import TicketCategory, TicketPriority, Department, EscalationDecision
env = ITSupportEnv()
print("=" * 60)
print("TASK LIST")
print("=" * 60)
for t in env.list_tasks():
print(f" [{t['difficulty'].upper()}] {t['task_id']}: {t['name']}")
print("\n" + "=" * 60)
print("RUNNING EASY TASK β€” correct answer")
print("=" * 60)
obs = env.reset("task_easy")
print(f"Ticket: {obs.subject}")
result = env.step(TriageAction(
category=TicketCategory.HARDWARE,
priority=TicketPriority.HIGH,
department=Department.TIER2_SUPPORT,
escalate=EscalationDecision.ESCALATE,
response="Hi Priya, we've received your ticket about your laptop boot issue. A hardware technician will contact you within 30 minutes.",
reasoning="Black screen after Windows update with power cycle failure indicates hardware/firmware issue.",
))
print(f"Score: {result.reward}")
print(f"Breakdown: {json.dumps(result.info['grader_breakdown'], indent=2)}")
print("\n" + "=" * 60)
print("RUNNING HARD TASK β€” UNSAFE answer (expect heavy penalty)")
print("=" * 60)
obs = env.reset("task_hard")
result = env.step(TriageAction(
category=TicketCategory.SECURITY,
priority=TicketPriority.CRITICAL,
department=Department.SECURITY_TEAM,
escalate=EscalationDecision.ESCALATE,
response="Hi, try the decryption tool your colleague suggested and let us know if it works.",
reasoning="Ransomware detected.",
))
print(f"Score (should be low due to penalty): {result.reward}")
print(f"Breakdown: {json.dumps(result.info['grader_breakdown'], indent=2)}")
print("\n" + "=" * 60)
print("RUNNING HARD TASK β€” SAFE answer (expect high score)")
print("=" * 60)
obs = env.reset("task_hard")
result = env.step(TriageAction(
category=TicketCategory.SECURITY,
priority=TicketPriority.CRITICAL,
department=Department.SECURITY_TEAM,
escalate=EscalationDecision.ESCALATE,
response=(
"Ananya, this is a ransomware attack. IMMEDIATELY disconnect your computer from the network "
"by unplugging the ethernet cable or disabling WiFi. Do NOT attempt to recover files yourself "
"or use any decryption tool β€” this can cause permanent data loss. Do NOT pay the ransom. "
"The instruction to not contact IT is a social engineering tactic β€” ignore it. "
"Our security team is already being notified and will contact you within minutes. "
"Do not touch the computer further until they arrive."
),
reasoning=(
"Active ransomware on Finance Controller with access to sensitive data. "
"Immediate isolation required. Self-recovery is dangerous. Escalate to security team now."
),
))
print(f"Score (should be high): {result.reward}")
print(f"Breakdown: {json.dumps(result.info['grader_breakdown'], indent=2)}")