from __future__ import annotations

import random
import uuid
from typing import Any

from openenv.core.env_server import Environment

from tool_use_env.grader import grade_task
from tool_use_env.models import ToolUseAction, ToolUseObservation, ToolUseState
from tool_use_env.tasks import TASKS, TASK_SEQUENCE
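
# Serves short customer-support episodes: the agent reviews a ticket, inspects
# artifacts, searches policies, drafts a reply, and submits a resolution code,
# all within a budget of MAX_STEPS steps per episode.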
class ToolUseEnvironment(Environment):
    SUPPORTS_CONCURRENT_SESSIONS = True
    MAX_STEPS = 6

    def __init__(self) -> None:
        super().__init__()
        self._state = ToolUseState()
        self._active_task: dict[str, Any] | None = None
        self._task_cursor = 0
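
    # Task selection precedence: an explicit task_id wins, then a seeded
    # random draw, otherwise tasks rotate deterministically through
    # TASK_SEQUENCE across resets.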
    def _select_task(self, seed: int | None = None, task_id: str | None = None) -> dict[str, Any]:
        if task_id:
            if task_id not in TASKS:
                raise ValueError(f"Unknown task_id '{task_id}'")
            return TASKS[task_id]
        if seed is not None:
            rng = random.Random(seed)
            return TASKS[TASK_SEQUENCE[rng.randrange(len(TASK_SEQUENCE))]]
        selected = TASKS[TASK_SEQUENCE[self._task_cursor % len(TASK_SEQUENCE)]]
        self._task_cursor += 1
        return selected

    def reset(
        self,
        seed: int | None = None,
        episode_id: str | None = None,
        **kwargs: Any,
    ) -> ToolUseObservation:
        task = self._select_task(seed=seed, task_id=kwargs.get("task_id"))
        self._active_task = task
        self._state = ToolUseState(
            episode_id=episode_id or str(uuid.uuid4()),
            step_count=0,
            task_id=task["task_id"],
            task_name=task["task_name"],
            difficulty=task["difficulty"],
            objective=task["objective"],
            cumulative_reward=0.0,
            final_score=0.0,
            drafted_reply=None,
            resolution_code=None,
            expected_resolution_code=task["expected_resolution_code"],
            required_evidence=list(task["required_evidence"]),
            # Start with no evidence so the first review_ticket is credited as
            # new evidence instead of being counted as a repeat action.
            collected_evidence=[],
            action_history=[],
            repeat_action_count=0,
            last_action_error=None,
            known_artifacts={},
            known_policies={},
        )
        return self._build_observation(
            reward=0.0,
            done=False,
            last_tool_result=(
                "Ticket loaded. Start by reviewing the ticket, then inspect the most relevant "
                "artifacts and policy before submitting a resolution."
            ),
        )
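
    # Artifact ids arrive as free text; normalize case and spacing, then map a
    # few common synonyms onto canonical artifact keys.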
    def _normalize_artifact_id(self, artifact_id: str | None) -> str | None:
        if not artifact_id:
            return None
        normalized = artifact_id.strip().lower().replace(" ", "_")
        aliases = {
            "payments": "payment",
            "billing": "payment",
            "risk": "risk_log",
            "risklog": "risk_log",
            "profile": "account",
        }
        return aliases.get(normalized, normalized)
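
    # Policy lookup tries, in order: an exact key match, a shared alias table,
    # and finally a substring match against the task's policy keys.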
    def _resolve_policy_key(self, query: str | None) -> str | None:
        if not query or not self._active_task:
            return None
        normalized = query.strip().lower().replace(" ", "_")
        policies = self._active_task["policies"]
        if normalized in policies:
            return normalized
        alias_map = {
            "damaged": "damaged_items",
            "damage": "damaged_items",
            "replacement": "damaged_items",
            "duplicate": "duplicate_charge",
            "duplicate_charge": "duplicate_charge",
            "billing": "duplicate_charge",
            "fraud": "account_takeover",
            "takeover": "account_takeover",
            "account_takeover": "account_takeover",
            "security": "account_takeover",
        }
        mapped = alias_map.get(normalized)
        if mapped in policies:
            return mapped
        for key in policies:
            if normalized in key:
                return key
        return None

    def _record_repeat_if_needed(self, evidence_key: str) -> bool:
        if evidence_key in self._state.collected_evidence:
            self._state.repeat_action_count += 1
            return True
        return False
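
    # Grade progress as if the episode ended now; lets observations report a
    # provisional current_score while the episode is still running.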
    def _partial_score(self) -> float:
        if not self._active_task:
            return 0.0
        return grade_task(
            self._active_task,
            self._state.collected_evidence,
            self._state.drafted_reply,
            self._state.resolution_code,
            self._state.step_count,
            self._state.repeat_action_count,
        )["final_score"]

    def _append_history(self, action: ToolUseAction) -> None:
        parts = [action.action_type]
        if action.artifact_id:
            parts.append(f"artifact={action.artifact_id}")
        if action.query:
            parts.append(f"query={action.query}")
        if action.resolution_code:
            parts.append(f"resolution={action.resolution_code}")
        self._state.action_history.append(" | ".join(parts))
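
    # Observations bundle the task framing, a short workspace summary, the
    # remaining step budget, and a reward clamped to [0, 1].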
    def _build_observation(
        self,
        reward: float,
        done: bool,
        last_tool_result: str | None,
        last_action_error: str | None = None,
    ) -> ToolUseObservation:
        task = self._active_task
        if not task:
            raise RuntimeError("Environment has no active task.")
        score = self._state.final_score if done else self._partial_score()
        remaining_steps = max(0, self.MAX_STEPS - self._state.step_count)
        known_items = self._state.collected_evidence or ["ticket"]
        draft_status = "present" if self._state.drafted_reply else "missing"
        resolution_status = self._state.resolution_code or "not submitted"
        summary = (
            f"Known evidence: {', '.join(known_items)}. "
            f"Draft reply: {draft_status}. "
            f"Resolution: {resolution_status}. "
            "Submit the best supported resolution before steps run out."
        )
        return ToolUseObservation(
            done=done,
            reward=round(min(max(reward, 0.0), 1.0), 3),
            task_id=task["task_id"],
            difficulty=task["difficulty"],
            objective=task["objective"],
            customer_message=task["customer_message"],
            workspace_summary=summary,
            available_actions=[
                "review_ticket",
                "inspect_artifact",
                "search_policy",
                "draft_reply",
                "submit_resolution",
            ],
            available_resolution_codes=list(task["available_resolution_codes"]),
            collected_evidence=list(self._state.collected_evidence),
            last_tool_result=last_tool_result,
            last_action_error=last_action_error,
            remaining_steps=remaining_steps,
            current_score=round(score, 3),
            metadata={
                "task_name": task["task_name"],
                "action_history": list(self._state.action_history),
            },
        )
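
    # Terminal grading: the grade_task breakdown is echoed component by
    # component in the final tool result, and its final_score becomes the
    # terminal reward.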
    def _finish_episode(self, resolution_code: str | None, feedback: str) -> ToolUseObservation:
        if not self._active_task:
            raise RuntimeError("Environment has no active task.")
        self._state.resolution_code = resolution_code
        breakdown = grade_task(
            self._active_task,
            self._state.collected_evidence,
            self._state.drafted_reply,
            self._state.resolution_code,
            self._state.step_count,
            self._state.repeat_action_count,
        )
        self._state.final_score = breakdown["final_score"]
        self._state.last_action_error = None
        result_text = (
            f"{feedback} | final_score={breakdown['final_score']:.3f} | "
            f"resolution_score={breakdown['resolution_score']:.3f} | "
            f"evidence_score={breakdown['evidence_score']:.3f} | "
            f"reply_score={breakdown['reply_score']:.3f} | "
            f"efficiency_score={breakdown['efficiency_score']:.3f}"
        )
        return self._build_observation(
            reward=breakdown["final_score"],
            done=True,
            last_tool_result=result_text,
        )
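
    # One environment step: dispatch on action_type, pay small shaping rewards
    # for newly collected evidence (more when it is listed as required), and
    # finish the episode on submit_resolution or when the step budget is spent.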
    def step(
        self,
        action: ToolUseAction,
        timeout_s: float | None = None,
        **kwargs: Any,
    ) -> ToolUseObservation:
        if not self._active_task:
            raise RuntimeError("Call reset() before step().")
        # Treat the episode as done once a resolution was submitted or the
        # step budget is spent; final_score alone can be zero for a finished
        # episode, so it is not a reliable done signal.
        if self._state.resolution_code is not None or self._state.step_count >= self.MAX_STEPS:
            return self._build_observation(
                reward=0.0,
                done=True,
                last_tool_result="Episode already finished.",
                last_action_error="episode_already_done",
            )
        self._state.step_count += 1
        self._append_history(action)
        reward = 0.0
        last_tool_result = None
        error = None
        if action.action_type == "review_ticket":
            repeated = self._record_repeat_if_needed("ticket")
            if not repeated:
                self._state.collected_evidence.append("ticket")
            reward = 0.02 if repeated else 0.10
            last_tool_result = self._active_task["customer_message"]
        elif action.action_type == "inspect_artifact":
            artifact_id = self._normalize_artifact_id(action.artifact_id)
            artifacts = self._active_task["artifacts"]
            if not artifact_id or artifact_id not in artifacts:
                error = "invalid_artifact_id"
                last_tool_result = (
                    "Unknown artifact. Valid artifacts: "
                    + ", ".join(sorted(artifacts.keys()))
                )
            else:
                evidence_key = f"artifact:{artifact_id}"
                repeated = self._record_repeat_if_needed(evidence_key)
                if not repeated:
                    self._state.collected_evidence.append(evidence_key)
                    self._state.known_artifacts[artifact_id] = artifacts[artifact_id]
                    reward = 0.14 if evidence_key in self._state.required_evidence else 0.04
                else:
                    reward = 0.01
                last_tool_result = artifacts[artifact_id]
        elif action.action_type == "search_policy":
            policy_key = self._resolve_policy_key(action.query)
            policies = self._active_task["policies"]
            if not policy_key:
                error = "policy_not_found"
                last_tool_result = (
                    "No matching policy found. Available policies: "
                    + ", ".join(sorted(policies.keys()))
                )
            else:
                evidence_key = f"policy:{policy_key}"
                repeated = self._record_repeat_if_needed(evidence_key)
                if not repeated:
                    self._state.collected_evidence.append(evidence_key)
                    self._state.known_policies[policy_key] = policies[policy_key]
                    reward = 0.14 if evidence_key in self._state.required_evidence else 0.04
                else:
                    reward = 0.01
                last_tool_result = policies[policy_key]
        elif action.action_type == "draft_reply":
            if not action.message or not action.message.strip():
                error = "empty_reply"
                last_tool_result = "Draft reply cannot be empty."
            else:
                self._state.drafted_reply = action.message.strip()
                keywords = self._active_task["reply_keywords"]
                hits = sum(
                    1 for keyword in keywords if keyword.lower() in self._state.drafted_reply.lower()
                )
                # Guard against a task with an empty keyword list.
                reward = round(0.05 + (0.15 * (hits / max(len(keywords), 1))), 3)
                last_tool_result = (
                    f"Draft saved. Included {hits}/{len(keywords)} required reply cues."
                )
        elif action.action_type == "submit_resolution":
            if not action.resolution_code:
                error = "missing_resolution_code"
                last_tool_result = "submit_resolution requires a resolution_code."
            elif action.resolution_code not in self._active_task["available_resolution_codes"]:
                error = "invalid_resolution_code"
                last_tool_result = (
                    "Unsupported resolution code. Valid codes: "
                    + ", ".join(self._active_task["available_resolution_codes"])
                )
            else:
                return self._finish_episode(
                    resolution_code=action.resolution_code,
                    feedback=f"Resolution submitted: {action.resolution_code}",
                )
        else:
            error = "invalid_action_type"
            last_tool_result = "Unsupported action_type."
        self._state.last_action_error = error
        if self._state.step_count >= self.MAX_STEPS:
            return self._finish_episode(
                resolution_code=self._state.resolution_code,
                feedback="Episode ended because the step limit was reached.",
            )
        self._state.cumulative_reward = round(self._state.cumulative_reward + reward, 3)
        return self._build_observation(
            reward=reward,
            done=False,
            last_tool_result=last_tool_result,
            last_action_error=error,
        )

    def state(self) -> ToolUseState:
        return self._state
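
# A minimal driver sketch, not part of the environment itself. It assumes
# ToolUseAction can be constructed with the keyword fields referenced above
# (action_type, artifact_id, resolution_code); the concrete artifact id is
# task-dependent, so a real agent should pick it from the observation.
if __name__ == "__main__":
    env = ToolUseEnvironment()
    obs = env.reset(seed=0)
    print(obs.objective)
    # Review the ticket, then inspect an artifact; "account" here is only an
    # illustrative guess, and an invalid id returns the list of valid ones.
    obs = env.step(ToolUseAction(action_type="review_ticket"))
    obs = env.step(ToolUseAction(action_type="inspect_artifact", artifact_id="account"))
    obs = env.step(
        ToolUseAction(
            action_type="submit_resolution",
            resolution_code=obs.available_resolution_codes[0],
        )
    )
    print(obs.last_tool_result)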