# ToolUseEnv: tool_use_env/server/tool_use_env_environment.py
# Uploaded by Clove25 ("Upload 53 files", commit 18feac5).
from __future__ import annotations
import random
import uuid
from typing import Any
from openenv.core.env_server import Environment
from tool_use_env.grader import grade_task
from tool_use_env.models import ToolUseAction, ToolUseObservation, ToolUseState
from tool_use_env.tasks import TASKS, TASK_SEQUENCE
class ToolUseEnvironment(Environment):
    """Customer-support environment driven by tool-use actions.

    Each episode loads one task definition from ``TASKS``. A task provides a
    customer message, artifacts, policies, reply keywords and resolution
    codes. The agent can:

    * gather evidence with ``review_ticket``, ``inspect_artifact`` and
      ``search_policy``;
    * optionally save a reply with ``draft_reply``;
    * end the episode with ``submit_resolution``.

    The episode also ends automatically once ``MAX_STEPS`` actions have been
    taken. Non-terminal steps return small shaping rewards. The terminal step
    returns the ``final_score`` computed by ``grade_task``.
    """

    SUPPORTS_CONCURRENT_SESSIONS = True
    MAX_STEPS = 6

    # Action names advertised to the agent in every observation.
    _AVAILABLE_ACTIONS = (
        "review_ticket",
        "inspect_artifact",
        "search_policy",
        "draft_reply",
        "submit_resolution",
    )

    # Free-form artifact names mapped onto canonical artifact ids.
    _ARTIFACT_ALIASES = {
        "payments": "payment",
        "billing": "payment",
        "risk": "risk_log",
        "risklog": "risk_log",
        "profile": "account",
    }

    # Free-form policy queries mapped onto policy keys. A mapping is only used
    # if the target key exists in the active task's policies.
    _POLICY_ALIASES = {
        "damaged": "damaged_items",
        "damage": "damaged_items",
        "replacement": "damaged_items",
        "duplicate": "duplicate_charge",
        "duplicate_charge": "duplicate_charge",
        "billing": "duplicate_charge",
        "fraud": "account_takeover",
        "takeover": "account_takeover",
        "account_takeover": "account_takeover",
        "security": "account_takeover",
    }

    def __init__(self) -> None:
        super().__init__()
        self._state = ToolUseState()
        self._active_task: dict[str, Any] | None = None
        # Round-robin position used when reset() gets neither task_id nor seed.
        self._task_cursor = 0
        # Explicit termination flag. A finished episode may score 0.0 and may
        # have no resolution code (step-limit ending), so neither of those
        # fields can signal "done" on its own.
        self._done = False
        # Whether review_ticket has been used this episode. "ticket" is
        # pre-seeded into collected_evidence at reset, so that list cannot
        # distinguish a first review from a repeated one.
        self._ticket_reviewed = False

    def _select_task(self, seed: int | None = None, task_id: str | None = None) -> dict[str, Any]:
        """Pick the task definition for a new episode.

        Selection priority:

        1. An explicit ``task_id``.
        2. A choice from ``TASK_SEQUENCE`` determined only by ``seed``.
        3. Round-robin over ``TASK_SEQUENCE`` (advances the internal cursor).

        Raises:
            ValueError: if ``task_id`` is given but is not a key of ``TASKS``.
        """
        if task_id:
            if task_id not in TASKS:
                raise ValueError(f"Unknown task_id '{task_id}'")
            return TASKS[task_id]
        if seed is not None:
            # A fresh Random per call makes the choice depend only on the seed.
            rng = random.Random(seed)
            return TASKS[TASK_SEQUENCE[rng.randrange(len(TASK_SEQUENCE))]]
        selected = TASKS[TASK_SEQUENCE[self._task_cursor % len(TASK_SEQUENCE)]]
        self._task_cursor += 1
        return selected

    def reset(
        self,
        seed: int | None = None,
        episode_id: str | None = None,
        **kwargs: Any,
    ) -> ToolUseObservation:
        """Start a new episode and return its initial observation.

        Args:
            seed: Optional seed for a reproducible task choice.
            episode_id: Optional episode id. A random UUID4 string is used
                when omitted.
            **kwargs: ``task_id`` (optional) selects a specific task.

        Raises:
            ValueError: if an unknown ``task_id`` is requested.
        """
        task = self._select_task(seed=seed, task_id=kwargs.get("task_id"))
        self._active_task = task
        self._done = False
        self._ticket_reviewed = False
        self._state = ToolUseState(
            episode_id=episode_id or str(uuid.uuid4()),
            step_count=0,
            task_id=task["task_id"],
            task_name=task["task_name"],
            difficulty=task["difficulty"],
            objective=task["objective"],
            cumulative_reward=0.0,
            final_score=0.0,
            drafted_reply=None,
            resolution_code=None,
            expected_resolution_code=task["expected_resolution_code"],
            required_evidence=list(task["required_evidence"]),
            # The ticket text is part of every observation, so it counts as
            # known evidence from the start.
            collected_evidence=["ticket"],
            action_history=[],
            repeat_action_count=0,
            last_action_error=None,
            known_artifacts={},
            known_policies={},
        )
        return self._build_observation(
            reward=0.0,
            done=False,
            last_tool_result=(
                "Ticket loaded. Start by reviewing the ticket, then inspect the most relevant "
                "artifacts and policy before submitting a resolution."
            ),
        )

    def _normalize_artifact_id(self, artifact_id: str | None) -> str | None:
        """Canonicalize a free-form artifact id.

        The id is stripped, lower-cased, has spaces replaced by underscores,
        and is then mapped through the alias table.

        Returns:
            The canonical id, or None for a missing or blank id.
        """
        if not artifact_id:
            return None
        normalized = artifact_id.strip().lower().replace(" ", "_")
        if not normalized:
            return None
        return self._ARTIFACT_ALIASES.get(normalized, normalized)

    def _resolve_policy_key(self, query: str | None) -> str | None:
        """Map a free-form policy query to a policy key of the active task.

        Matching order:

        1. Exact key match.
        2. Alias table.
        3. First key containing the query as a substring.

        Returns:
            The matched key, or None when nothing matches, the query is blank,
            or no task is active.
        """
        if not query or not self._active_task:
            return None
        normalized = query.strip().lower().replace(" ", "_")
        if not normalized:
            # "" is a substring of every key, so a whitespace-only query would
            # otherwise match an arbitrary policy.
            return None
        policies = self._active_task["policies"]
        if normalized in policies:
            return normalized
        mapped = self._POLICY_ALIASES.get(normalized)
        if mapped in policies:
            return mapped
        return next((key for key in policies if normalized in key), None)

    def _record_repeat_if_needed(self, evidence_key: str) -> bool:
        """Return True, and count a repeat action, if this evidence is already collected."""
        if evidence_key in self._state.collected_evidence:
            self._state.repeat_action_count += 1
            return True
        return False

    def _grade(self) -> dict[str, Any]:
        """Grade the current episode state.

        Returns:
            The ``grade_task`` breakdown dict, which includes
            ``final_score``.

        Callers must ensure a task is active.
        """
        return grade_task(
            self._active_task,
            self._state.collected_evidence,
            self._state.drafted_reply,
            self._state.resolution_code,
            self._state.step_count,
            self._state.repeat_action_count,
        )

    def _partial_score(self) -> float:
        """Return the score the episode would get if graded now (0.0 with no active task)."""
        if not self._active_task:
            return 0.0
        return self._grade()["final_score"]

    def _append_history(self, action: ToolUseAction) -> None:
        """Append a compact " | "-joined description of ``action`` to the action history."""
        parts = [action.action_type]
        if action.artifact_id:
            parts.append(f"artifact={action.artifact_id}")
        if action.query:
            parts.append(f"query={action.query}")
        if action.resolution_code:
            parts.append(f"resolution={action.resolution_code}")
        self._state.action_history.append(" | ".join(parts))

    def _build_observation(
        self,
        reward: float,
        done: bool,
        last_tool_result: str | None,
        last_action_error: str | None = None,
    ) -> ToolUseObservation:
        """Assemble the observation for the current state.

        ``reward`` is clamped to [0, 1] and rounded to 3 decimals. When
        ``done`` is set, ``current_score`` is the stored final score;
        otherwise it is a fresh partial grade.

        Raises:
            RuntimeError: if no task is active.
        """
        task = self._active_task
        if not task:
            raise RuntimeError("Environment has no active task.")
        score = self._state.final_score if done else self._partial_score()
        remaining_steps = max(0, self.MAX_STEPS - self._state.step_count)
        known_items = self._state.collected_evidence or ["ticket"]
        draft_status = "present" if self._state.drafted_reply else "missing"
        resolution_status = self._state.resolution_code or "not submitted"
        summary = (
            f"Known evidence: {', '.join(known_items)}. "
            f"Draft reply: {draft_status}. "
            f"Resolution: {resolution_status}. "
            "Submit the best supported resolution before steps run out."
        )
        return ToolUseObservation(
            done=done,
            reward=round(min(max(reward, 0.0), 1.0), 3),
            task_id=task["task_id"],
            difficulty=task["difficulty"],
            objective=task["objective"],
            customer_message=task["customer_message"],
            workspace_summary=summary,
            available_actions=list(self._AVAILABLE_ACTIONS),
            available_resolution_codes=list(task["available_resolution_codes"]),
            collected_evidence=list(self._state.collected_evidence),
            last_tool_result=last_tool_result,
            last_action_error=last_action_error,
            remaining_steps=remaining_steps,
            current_score=round(score, 3),
            metadata={
                "task_name": task["task_name"],
                "action_history": list(self._state.action_history),
            },
        )

    def _finish_episode(
        self,
        resolution_code: str | None,
        feedback: str,
        last_action_error: str | None = None,
    ) -> ToolUseObservation:
        """Grade the episode, mark it done and return the terminal observation.

        Args:
            resolution_code: Resolution to record. None when the episode ends
                without a submission.
            feedback: Prefix of the result text shown to the agent.
            last_action_error: Error produced by the action that ended the
                episode, if any. It is reported instead of being discarded.

        Raises:
            RuntimeError: if no task is active.
        """
        if not self._active_task:
            raise RuntimeError("Environment has no active task.")
        self._state.resolution_code = resolution_code
        breakdown = self._grade()
        final_score = breakdown["final_score"]
        self._state.final_score = final_score
        self._state.last_action_error = last_action_error
        self._done = True
        # The terminal reward is emitted to the agent like any other reward,
        # so it is included in the running total (clamped as in the observation).
        self._state.cumulative_reward = round(
            self._state.cumulative_reward + min(max(final_score, 0.0), 1.0), 3
        )
        result_text = (
            f"{feedback} | final_score={breakdown['final_score']:.3f} | "
            f"resolution_score={breakdown['resolution_score']:.3f} | "
            f"evidence_score={breakdown['evidence_score']:.3f} | "
            f"reply_score={breakdown['reply_score']:.3f} | "
            f"efficiency_score={breakdown['efficiency_score']:.3f}"
        )
        return self._build_observation(
            reward=final_score,
            done=True,
            last_tool_result=result_text,
            last_action_error=last_action_error,
        )

    def _collect_evidence(
        self,
        evidence_key: str,
        name: str,
        content: Any,
        known: dict[str, Any],
    ) -> float:
        """Record a piece of evidence and return its shaping reward.

        Rewards:

        * 0.01 for a repeat (the repeat counter is also incremented);
        * 0.14 for new required evidence;
        * 0.04 for other new evidence.

        On first sight the content is stored in ``known[name]``.
        """
        if self._record_repeat_if_needed(evidence_key):
            return 0.01
        self._state.collected_evidence.append(evidence_key)
        known[name] = content
        return 0.14 if evidence_key in self._state.required_evidence else 0.04

    def _handle_review_ticket(
        self, _action: ToolUseAction
    ) -> tuple[float, str | None, str | None]:
        """Show the customer message.

        Returns ``(reward, tool_result, error)``. The first review earns 0.10.
        Later reviews earn 0.02 and count as repeats.
        """
        if self._ticket_reviewed:
            self._state.repeat_action_count += 1
            reward = 0.02
        else:
            self._ticket_reviewed = True
            reward = 0.10
        return reward, self._active_task["customer_message"], None

    def _handle_inspect_artifact(
        self, action: ToolUseAction
    ) -> tuple[float, str | None, str | None]:
        """Reveal an artifact of the active task.

        Returns ``(reward, tool_result, error)``.
        """
        artifacts = self._active_task["artifacts"]
        artifact_id = self._normalize_artifact_id(action.artifact_id)
        if not artifact_id or artifact_id not in artifacts:
            message = "Unknown artifact. Valid artifacts: " + ", ".join(sorted(artifacts.keys()))
            return 0.0, message, "invalid_artifact_id"
        reward = self._collect_evidence(
            f"artifact:{artifact_id}",
            artifact_id,
            artifacts[artifact_id],
            self._state.known_artifacts,
        )
        return reward, artifacts[artifact_id], None

    def _handle_search_policy(
        self, action: ToolUseAction
    ) -> tuple[float, str | None, str | None]:
        """Reveal the policy matching ``action.query``.

        Returns ``(reward, tool_result, error)``.
        """
        policies = self._active_task["policies"]
        policy_key = self._resolve_policy_key(action.query)
        if not policy_key:
            message = "No matching policy found. Available policies: " + ", ".join(
                sorted(policies.keys())
            )
            return 0.0, message, "policy_not_found"
        reward = self._collect_evidence(
            f"policy:{policy_key}",
            policy_key,
            policies[policy_key],
            self._state.known_policies,
        )
        return reward, policies[policy_key], None

    def _handle_draft_reply(
        self, action: ToolUseAction
    ) -> tuple[float, str | None, str | None]:
        """Save a draft reply.

        Returns ``(reward, tool_result, error)``. The reward is
        ``0.05 + 0.15 * keyword_coverage``. Coverage is the fraction of the
        task's reply keywords that appear in the draft (case-insensitive).
        A task with no keywords counts as fully covered instead of raising
        ZeroDivisionError.
        """
        message = (action.message or "").strip()
        if not message:
            return 0.0, "Draft reply cannot be empty.", "empty_reply"
        self._state.drafted_reply = message
        keywords = self._active_task["reply_keywords"]
        lowered = message.lower()
        hits = sum(1 for keyword in keywords if keyword.lower() in lowered)
        coverage = hits / len(keywords) if keywords else 1.0
        reward = round(0.05 + (0.15 * coverage), 3)
        return reward, f"Draft saved. Included {hits}/{len(keywords)} required reply cues.", None

    def step(
        self,
        action: ToolUseAction,
        timeout_s: float | None = None,
        **kwargs: Any,
    ) -> ToolUseObservation:
        """Apply one agent action and return the resulting observation.

        A valid ``submit_resolution`` ends the episode with a graded score.
        Reaching ``MAX_STEPS`` also ends it, graded with whatever evidence
        and draft were collected. Once the episode is done, further calls
        return a done observation with error ``"episode_already_done"``.

        Raises:
            RuntimeError: if ``reset()`` has not been called.
        """
        if not self._active_task:
            raise RuntimeError("Call reset() before step().")
        if self._done:
            return self._build_observation(
                reward=0.0,
                done=True,
                last_tool_result="Episode already finished.",
                last_action_error="episode_already_done",
            )
        self._state.step_count += 1
        self._append_history(action)

        reward = 0.0
        error: str | None = None
        last_tool_result: str | None = None
        if action.action_type == "submit_resolution":
            code = action.resolution_code
            if not code:
                error = "missing_resolution_code"
                last_tool_result = "submit_resolution requires a resolution_code."
            elif code not in self._active_task["available_resolution_codes"]:
                error = "invalid_resolution_code"
                last_tool_result = "Unsupported resolution code. Valid codes: " + ", ".join(
                    self._active_task["available_resolution_codes"]
                )
            else:
                return self._finish_episode(
                    resolution_code=code,
                    feedback=f"Resolution submitted: {code}",
                )
        else:
            handlers = {
                "review_ticket": self._handle_review_ticket,
                "inspect_artifact": self._handle_inspect_artifact,
                "search_policy": self._handle_search_policy,
                "draft_reply": self._handle_draft_reply,
            }
            handler = handlers.get(action.action_type)
            if handler is None:
                error = "invalid_action_type"
                last_tool_result = "Unsupported action_type."
            else:
                reward, last_tool_result, error = handler(action)

        self._state.last_action_error = error
        if self._state.step_count >= self.MAX_STEPS:
            return self._finish_episode(
                resolution_code=self._state.resolution_code,
                feedback="Episode ended because the step limit was reached.",
                last_action_error=error,
            )
        self._state.cumulative_reward = round(self._state.cumulative_reward + reward, 3)
        return self._build_observation(
            reward=reward,
            done=False,
            last_tool_result=last_tool_result,
            last_action_error=error,
        )

    @property
    def state(self) -> ToolUseState:
        """Current episode state (the live object, not a copy)."""
        return self._state