import asyncio import json import os import textwrap from typing import Any, List, Optional from openai import OpenAI from tool_use_env.client import ToolUseEnv from tool_use_env.models import ToolUseAction from tool_use_env.tasks import TASK_SEQUENCE LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME") API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1") MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Llama-3.1-8B-Instruct") API_KEY = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY") or os.getenv("API_KEY") ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://127.0.0.1:8000") BENCHMARK = os.getenv("MY_ENV_V4_BENCHMARK", "support_ops_env") MAX_STEPS = 6 TEMPERATURE = 0.0 MAX_TOKENS = 220 SYSTEM_PROMPT = textwrap.dedent( """ You are operating a customer-support workflow environment. Your job is to gather the minimum necessary evidence, draft a short customer reply, and submit the correct final resolution code. Reply with JSON only using this schema: { "action_type": "review_ticket|inspect_artifact|search_policy|draft_reply|submit_resolution", "artifact_id": "optional string", "query": "optional string", "message": "optional string", "resolution_code": "optional string" } Use concise messages. Prefer exact artifact ids and exact resolution codes shown in the observation. """ ).strip() def log_start(task: str, env: str, model: str) -> None: print(f"[START] task={task} env={env} model={model}", flush=True) def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None: error_val = error if error else "null" print( f"[STEP] step={step} action={action} reward={reward:.2f} " f"done={str(done).lower()} error={error_val}", flush=True, ) def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None: rewards_str = ",".join(f"{reward:.2f}" for reward in rewards) print( f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}", flush=True, ) def _serialize_action(action: ToolUseAction) -> str: payload = {"action_type": action.action_type} if action.artifact_id: payload["artifact_id"] = action.artifact_id if action.query: payload["query"] = action.query if action.message: payload["message"] = action.message.replace("\n", " ").strip() if action.resolution_code: payload["resolution_code"] = action.resolution_code return json.dumps(payload, ensure_ascii=True, separators=(",", ":")) def _fallback_action(observation: Any) -> ToolUseAction: evidence = set(observation.collected_evidence) task_id = observation.task_id if "ticket" not in evidence: return ToolUseAction(action_type="review_ticket") task_plans = { "damaged-mug-replacement": [ ToolUseAction(action_type="inspect_artifact", artifact_id="order"), ToolUseAction(action_type="search_policy", query="damaged_items"), ToolUseAction( action_type="draft_reply", message=( "We are sending a replacement within 48 hours. " "There is no need to return the broken mug." ), ), ToolUseAction(action_type="submit_resolution", resolution_code="send_replacement"), ], "duplicate-charge-refund": [ ToolUseAction(action_type="inspect_artifact", artifact_id="order"), ToolUseAction(action_type="inspect_artifact", artifact_id="payment"), ToolUseAction(action_type="search_policy", query="duplicate_charge"), ToolUseAction( action_type="draft_reply", message=( "We confirmed the duplicate charge and issued a refund. " "You should see the refund in 3-5 business days." ), ), ToolUseAction( action_type="submit_resolution", resolution_code="refund_duplicate_charge", ), ], "account-takeover-fraud": [ ToolUseAction(action_type="inspect_artifact", artifact_id="account"), ToolUseAction(action_type="inspect_artifact", artifact_id="risk_log"), ToolUseAction(action_type="search_policy", query="account_takeover"), ToolUseAction( action_type="draft_reply", message=( "We locked your account immediately and escalated this to our fraud team. " "You will receive an update within 24 hours." ), ), ToolUseAction( action_type="submit_resolution", resolution_code="lock_account_and_escalate_fraud", ), ], } plan = task_plans[task_id] for candidate in plan: if candidate.action_type == "inspect_artifact": if f"artifact:{candidate.artifact_id}" not in evidence: return candidate elif candidate.action_type == "search_policy": if f"policy:{candidate.query}" not in evidence: return candidate elif candidate.action_type == "draft_reply" and not observation.last_tool_result.startswith("Draft saved"): return candidate elif candidate.action_type == "submit_resolution": return candidate return ToolUseAction(action_type="submit_resolution", resolution_code=observation.available_resolution_codes[0]) def _prompt_for_observation(step: int, observation: Any) -> str: return textwrap.dedent( f""" Step: {step} Task ID: {observation.task_id} Difficulty: {observation.difficulty} Objective: {observation.objective} Customer message: {observation.customer_message} Workspace summary: {observation.workspace_summary} Collected evidence: {observation.collected_evidence} Available resolution codes: {observation.available_resolution_codes} Last tool result: {observation.last_tool_result} Last action error: {observation.last_action_error} Remaining steps: {observation.remaining_steps} Return the single best next action as JSON. """ ).strip() def _model_action(client: OpenAI, step: int, observation: Any) -> ToolUseAction: fallback = _fallback_action(observation) if not API_KEY: return fallback try: completion = client.chat.completions.create( model=MODEL_NAME, messages=[ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": _prompt_for_observation(step, observation)}, ], temperature=TEMPERATURE, max_tokens=MAX_TOKENS, response_format={"type": "json_object"}, ) raw = (completion.choices[0].message.content or "").strip() data = json.loads(raw) return ToolUseAction( action_type=data.get("action_type", fallback.action_type), artifact_id=data.get("artifact_id"), query=data.get("query"), message=data.get("message"), resolution_code=data.get("resolution_code"), ) except Exception: return fallback async def _connect_env() -> ToolUseEnv: if LOCAL_IMAGE_NAME: return await ToolUseEnv.from_docker_image(LOCAL_IMAGE_NAME) env = ToolUseEnv(base_url=ENV_BASE_URL) await env.connect() return env async def run_task(client: OpenAI, env: ToolUseEnv, task_id: str) -> float: rewards: List[float] = [] steps_taken = 0 score = 0.0 success = False log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME) try: result = await env.reset(task_id=task_id, seed=7) observation = result.observation for step in range(1, MAX_STEPS + 1): if result.done: break action = _model_action(client, step, observation) action_str = _serialize_action(action) result = await env.step(action) observation = result.observation reward = float(result.reward or 0.0) done = bool(result.done) error = observation.last_action_error rewards.append(reward) steps_taken = step log_step(step=step, action=action_str, reward=reward, done=done, error=error) if done: break state = await env.state() score = float(state.final_score) success = score >= 0.8 finally: log_end(success=success, steps=steps_taken, score=score, rewards=rewards) return score async def main() -> None: client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY or "missing") env = await _connect_env() try: scores = [] for task_id in TASK_SEQUENCE: score = await run_task(client, env, task_id) scores.append(score) finally: await env.close() if __name__ == "__main__": asyncio.run(main())