"""A multi-step customer-support tool-use environment.

The agent reviews a ticket, inspects artifacts, searches policies, drafts a
reply, and submits a resolution code within a fixed step budget.
"""

from __future__ import annotations

import random
import uuid
from typing import Any

from openenv.core.env_server import Environment
from tool_use_env.grader import grade_task
from tool_use_env.models import ToolUseAction, ToolUseObservation, ToolUseState
from tool_use_env.tasks import TASKS, TASK_SEQUENCE


class ToolUseEnvironment(Environment):
    SUPPORTS_CONCURRENT_SESSIONS = True
    MAX_STEPS = 6

    def __init__(self) -> None:
        super().__init__()
        self._state = ToolUseState()
        self._active_task: dict[str, Any] | None = None
        self._task_cursor = 0
        # Tracks episode completion explicitly; final_score alone cannot, since
        # a finished episode may legitimately score 0.0.
        self._episode_done = False

    def _select_task(self, seed: int | None = None, task_id: str | None = None) -> dict[str, Any]:
        """Pick a task by explicit id, by seed, or round-robin as a fallback."""
        if task_id:
            if task_id not in TASKS:
                raise ValueError(f"Unknown task_id '{task_id}'")
            return TASKS[task_id]
        if seed is not None:
            rng = random.Random(seed)
            return TASKS[TASK_SEQUENCE[rng.randrange(len(TASK_SEQUENCE))]]
        selected = TASKS[TASK_SEQUENCE[self._task_cursor % len(TASK_SEQUENCE)]]
        self._task_cursor += 1
        return selected

    def reset(
        self,
        seed: int | None = None,
        episode_id: str | None = None,
        **kwargs: Any,
    ) -> ToolUseObservation:
        task = self._select_task(seed=seed, task_id=kwargs.get("task_id"))
        self._active_task = task
        self._episode_done = False
        self._state = ToolUseState(
            episode_id=episode_id or str(uuid.uuid4()),
            step_count=0,
            task_id=task["task_id"],
            task_name=task["task_name"],
            difficulty=task["difficulty"],
            objective=task["objective"],
            cumulative_reward=0.0,
            final_score=0.0,
            drafted_reply=None,
            resolution_code=None,
            expected_resolution_code=task["expected_resolution_code"],
            required_evidence=list(task["required_evidence"]),
            # The ticket counts as starting evidence, so a later review_ticket
            # action is treated as a repeat.
            collected_evidence=["ticket"],
            action_history=[],
            repeat_action_count=0,
            last_action_error=None,
            known_artifacts={},
            known_policies={},
        )
        return self._build_observation(
            reward=0.0,
            done=False,
            last_tool_result=(
                "Ticket loaded. Start by reviewing the ticket, then inspect the most relevant "
                "artifacts and policy before submitting a resolution."
            ),
        )

    def _normalize_artifact_id(self, artifact_id: str | None) -> str | None:
        """Map common aliases (e.g. 'billing') onto canonical artifact ids."""
        if not artifact_id:
            return None
        normalized = artifact_id.strip().lower().replace(" ", "_")
        aliases = {
            "payments": "payment",
            "billing": "payment",
            "risk": "risk_log",
            "risklog": "risk_log",
            "profile": "account",
        }
        return aliases.get(normalized, normalized)

    def _resolve_policy_key(self, query: str | None) -> str | None:
        """Resolve a free-text query to a policy key: exact match, then alias, then substring."""
        if not query or not self._active_task:
            return None
        normalized = query.strip().lower().replace(" ", "_")
        policies = self._active_task["policies"]
        if normalized in policies:
            return normalized
        alias_map = {
            "damaged": "damaged_items",
            "damage": "damaged_items",
            "replacement": "damaged_items",
            "duplicate": "duplicate_charge",
            "duplicate_charge": "duplicate_charge",
            "billing": "duplicate_charge",
            "fraud": "account_takeover",
            "takeover": "account_takeover",
            "account_takeover": "account_takeover",
            "security": "account_takeover",
        }
        mapped = alias_map.get(normalized)
        if mapped in policies:
            return mapped
        for key in policies:
            if normalized in key:
                return key
        return None

    def _record_repeat_if_needed(self, evidence_key: str) -> bool:
        """Return True (and count a repeat) if the evidence was already collected."""
        if evidence_key in self._state.collected_evidence:
            self._state.repeat_action_count += 1
            return True
        return False

    def _partial_score(self) -> float:
        """Grade the episode as if it ended now, for progress reporting."""
        if not self._active_task:
            return 0.0
        return grade_task(
            self._active_task,
            self._state.collected_evidence,
            self._state.drafted_reply,
            self._state.resolution_code,
            self._state.step_count,
            self._state.repeat_action_count,
        )["final_score"]

    def _append_history(self, action: ToolUseAction) -> None:
        parts = [action.action_type]
        if action.artifact_id:
            parts.append(f"artifact={action.artifact_id}")
        if action.query:
            parts.append(f"query={action.query}")
        if action.resolution_code:
            parts.append(f"resolution={action.resolution_code}")
        self._state.action_history.append(" | ".join(parts))

    def _build_observation(
        self,
        reward: float,
        done: bool,
        last_tool_result: str | None,
        last_action_error: str | None = None,
    ) -> ToolUseObservation:
        task = self._active_task
        if not task:
            raise RuntimeError("Environment has no active task.")
        score = self._state.final_score if done else self._partial_score()
        remaining_steps = max(0, self.MAX_STEPS - self._state.step_count)
        known_items = self._state.collected_evidence or ["ticket"]
        draft_status = "present" if self._state.drafted_reply else "missing"
        resolution_status = self._state.resolution_code or "not submitted"
        summary = (
            f"Known evidence: {', '.join(known_items)}. "
            f"Draft reply: {draft_status}. "
            f"Resolution: {resolution_status}. "
            "Submit the best supported resolution before steps run out."
        )
        return ToolUseObservation(
            done=done,
            reward=round(min(max(reward, 0.0), 1.0), 3),
            task_id=task["task_id"],
            difficulty=task["difficulty"],
            objective=task["objective"],
            customer_message=task["customer_message"],
            workspace_summary=summary,
            available_actions=[
                "review_ticket",
                "inspect_artifact",
                "search_policy",
                "draft_reply",
                "submit_resolution",
            ],
            available_resolution_codes=list(task["available_resolution_codes"]),
            collected_evidence=list(self._state.collected_evidence),
            last_tool_result=last_tool_result,
            last_action_error=last_action_error,
            remaining_steps=remaining_steps,
            current_score=round(score, 3),
            metadata={
                "task_name": task["task_name"],
                "action_history": list(self._state.action_history),
            },
        )

    def _finish_episode(self, resolution_code: str | None, feedback: str) -> ToolUseObservation:
        if not self._active_task:
            raise RuntimeError("Environment has no active task.")
        self._state.resolution_code = resolution_code
        breakdown = grade_task(
            self._active_task,
            self._state.collected_evidence,
            self._state.drafted_reply,
            self._state.resolution_code,
            self._state.step_count,
            self._state.repeat_action_count,
        )
        self._state.final_score = breakdown["final_score"]
        self._state.last_action_error = None
        self._episode_done = True
        result_text = (
            f"{feedback} | final_score={breakdown['final_score']:.3f} | "
            f"resolution_score={breakdown['resolution_score']:.3f} | "
            f"evidence_score={breakdown['evidence_score']:.3f} | "
            f"reply_score={breakdown['reply_score']:.3f} | "
            f"efficiency_score={breakdown['efficiency_score']:.3f}"
        )
        return self._build_observation(
            reward=breakdown["final_score"],
            done=True,
            last_tool_result=result_text,
        )

    def step(
        self,
        action: ToolUseAction,
        timeout_s: float | None = None,
        **kwargs: Any,
    ) -> ToolUseObservation:
        if not self._active_task:
            raise RuntimeError("Call reset() before step().")
        # Check the explicit flag rather than final_score/resolution_code: a
        # finished episode can have a zero score and no submitted resolution
        # (e.g. when the step limit is hit), which the old check missed.
        if self._episode_done:
            return self._build_observation(
                reward=0.0,
                done=True,
                last_tool_result="Episode already finished.",
                last_action_error="episode_already_done",
            )

        self._state.step_count += 1
        self._append_history(action)

        reward = 0.0
        last_tool_result = None
        error = None

        if action.action_type == "review_ticket":
            repeated = self._record_repeat_if_needed("ticket")
            reward = 0.02 if repeated else 0.10
            last_tool_result = self._active_task["customer_message"]
        elif action.action_type == "inspect_artifact":
            artifact_id = self._normalize_artifact_id(action.artifact_id)
            artifacts = self._active_task["artifacts"]
            if not artifact_id or artifact_id not in artifacts:
                error = "invalid_artifact_id"
                last_tool_result = (
                    "Unknown artifact. Valid artifacts: "
                    + ", ".join(sorted(artifacts.keys()))
                )
            else:
                evidence_key = f"artifact:{artifact_id}"
                repeated = self._record_repeat_if_needed(evidence_key)
                if not repeated:
                    self._state.collected_evidence.append(evidence_key)
                    self._state.known_artifacts[artifact_id] = artifacts[artifact_id]
                    reward = 0.14 if evidence_key in self._state.required_evidence else 0.04
                else:
                    reward = 0.01
                last_tool_result = artifacts[artifact_id]
        elif action.action_type == "search_policy":
            policy_key = self._resolve_policy_key(action.query)
            policies = self._active_task["policies"]
            if not policy_key:
                error = "policy_not_found"
                last_tool_result = (
                    "No matching policy found. Available policies: "
                    + ", ".join(sorted(policies.keys()))
                )
            else:
                evidence_key = f"policy:{policy_key}"
                repeated = self._record_repeat_if_needed(evidence_key)
                if not repeated:
                    self._state.collected_evidence.append(evidence_key)
                    self._state.known_policies[policy_key] = policies[policy_key]
                    reward = 0.14 if evidence_key in self._state.required_evidence else 0.04
                else:
                    reward = 0.01
                last_tool_result = policies[policy_key]
        elif action.action_type == "draft_reply":
            if not action.message or not action.message.strip():
                error = "empty_reply"
                last_tool_result = "Draft reply cannot be empty."
            else:
                self._state.drafted_reply = action.message.strip()
                keywords = self._active_task["reply_keywords"]
                hits = sum(
                    1
                    for keyword in keywords
                    if keyword.lower() in self._state.drafted_reply.lower()
                )
                # Guard against a task that defines no reply keywords.
                reward = round(0.05 + (0.15 * (hits / max(len(keywords), 1))), 3)
                last_tool_result = (
                    f"Draft saved. Included {hits}/{len(keywords)} required reply cues."
                )
        elif action.action_type == "submit_resolution":
            if not action.resolution_code:
                error = "missing_resolution_code"
                last_tool_result = "submit_resolution requires a resolution_code."
            elif action.resolution_code not in self._active_task["available_resolution_codes"]:
                error = "invalid_resolution_code"
                last_tool_result = (
                    "Unsupported resolution code. Valid codes: "
                    + ", ".join(self._active_task["available_resolution_codes"])
                )
            else:
                return self._finish_episode(
                    resolution_code=action.resolution_code,
                    feedback=f"Resolution submitted: {action.resolution_code}",
                )
        else:
            error = "invalid_action_type"
            last_tool_result = "Unsupported action_type."

        self._state.last_action_error = error
        if self._state.step_count >= self.MAX_STEPS:
            return self._finish_episode(
                resolution_code=self._state.resolution_code,
                feedback="Episode ended because the step limit was reached.",
            )
        self._state.cumulative_reward = round(self._state.cumulative_reward + reward, 3)
        return self._build_observation(
            reward=reward,
            done=False,
            last_tool_result=last_tool_result,
            last_action_error=error,
        )

    @property
    def state(self) -> ToolUseState:
        return self._state
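
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). It assumes ToolUseAction is a
# keyword-constructible model exposing the fields read in step() above
# (action_type, artifact_id, query, message, resolution_code); the artifact
# id, policy query, and chosen resolution code below are hypothetical and may
# not fit a given task. Invalid values surface as error observations rather
# than exceptions, so the loop still runs end-to-end within MAX_STEPS.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    env = ToolUseEnvironment()
    obs = env.reset(seed=0)
    print(f"objective: {obs.objective}")

    # Gather evidence, draft a reply, then submit a resolution code.
    obs = env.step(ToolUseAction(action_type="review_ticket"))
    obs = env.step(ToolUseAction(action_type="inspect_artifact", artifact_id="payment"))
    obs = env.step(ToolUseAction(action_type="search_policy", query="duplicate charge"))
    obs = env.step(
        ToolUseAction(
            action_type="draft_reply",
            message="We confirmed the duplicate charge and issued a refund.",
        )
    )
    obs = env.step(
        ToolUseAction(
            action_type="submit_resolution",
            resolution_code=obs.available_resolution_codes[0],
        )
    )
    print(f"done={obs.done} reward={obs.reward} result={obs.last_tool_result}")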