"""ReAct reasoning loop for StableToolBench evaluation.
Implements the iterative Thought -> Action -> Observation loop,
making tool calls via the virtual API server and managing conversation state.
"""
import json, time, requests
from typing import Dict, List, Any, Optional, Tuple
from copy import deepcopy
from llm_client import LLMClient
from config import MAX_STEPS, MAX_OBSERVATION_LENGTH, API_SERVER_URL, API_SERVER_PORT
class ReActRunner:
    """Runs a single ReAct episode (Thought -> Action -> Observation) for one query.

    Each turn, the LLM proposes zero or more tool calls; every call is executed
    against the virtual API server — or handled locally for the special
    ``Finish`` action — and the observation is appended to the conversation.
    The loop ends when the model gives a final answer, gives up, stops calling
    tools, or ``max_steps`` is exhausted.
    """

    def __init__(self, llm, functions, tool_descriptions, api_name_reflect,
                 tool_names, cate_names, service_url=None,
                 max_steps=MAX_STEPS, max_observation_length=MAX_OBSERVATION_LENGTH):
        """Configure one episode.

        llm: chat client exposing ``change_messages()`` and ``parse()``.
        functions: OpenAI-style tool schemas sent to the model each turn.
        tool_descriptions: kept for interface compatibility (unused here).
        api_name_reflect: maps a (possibly prefixed) function name back to the
            pure API name the virtual server expects.
        tool_names / cate_names: metadata lists indexed in parallel with
            ``functions`` — presumably aligned by the caller; verify upstream.
        service_url: override for the virtual API endpoint.
        """
        self.llm, self.functions = llm, functions
        self.api_name_reflect, self.tool_names, self.cate_names = api_name_reflect, tool_names, cate_names
        self.service_url = service_url or f"{API_SERVER_URL}:{API_SERVER_PORT}/virtual"
        self.max_steps, self.max_observation_length = max_steps, max_observation_length
        # Episode outcome state, filled in by run().
        self.success, self.final_answer, self.trajectory = False, "", []
        self.total_tokens, self.query_count = 0, 0

    def run(self, initial_messages):
        """Drive the ReAct loop and return a summary dict.

        Returns keys: success, final_answer, trajectory, give_up,
        total_tokens, query_count, steps, messages.
        """
        messages, give_up = list(initial_messages), False
        step = -1  # so "steps" reports 0 if max_steps == 0 and the loop never runs
        for step in range(self.max_steps):
            self.llm.change_messages(messages)
            response, error_code, tokens = self.llm.parse(tools=self.functions, process_id=0)
            self.total_tokens += tokens
            self.query_count += 1
            if error_code != 0:
                self.trajectory.append(("error", "LLM generation failed"))
                break
            content = response.get("content", "")
            if content:
                self.trajectory.append(("thought", content))
            tool_calls = response.get("tool_calls", [])
            if not tool_calls:
                # No action proposed. On the very first step give the model
                # another chance; afterwards treat it as the end of the episode.
                messages.append(response)
                if step > 0:
                    break
                continue
            for tc in tool_calls:
                func_name, func_args = tc["function"]["name"], tc["function"]["arguments"]
                self.trajectory.append(("action", f"{func_name}({func_args})"))
                observation, status = self._execute_tool(func_name, func_args)
                if len(observation) > self.max_observation_length:
                    # Truncate huge observations so the context stays bounded.
                    observation = observation[:self.max_observation_length] + "..."
                self.trajectory.append(("observation", observation))
                if func_name == "Finish":
                    try:
                        args = json.loads(func_args) if isinstance(func_args, str) else func_args
                    except (json.JSONDecodeError, TypeError):
                        args = {}
                    if not isinstance(args, dict):
                        # e.g. the model emitted a JSON list/scalar; treat as empty.
                        args = {}
                    if args.get("return_type") == "give_answer":
                        self.success, self.final_answer = True, args.get("final_answer", "")
                    elif args.get("return_type") == "give_up_and_restart":
                        give_up = True
                    break  # Finish ends the step; any remaining calls are skipped.
                if status == 1:
                    # Model invented a nonexistent function; rename so the tool
                    # message below makes the hallucination explicit.
                    tc["function"]["name"] = "invalid_hallucination_function_name"
            # Record the assistant turn plus one tool message per call, pairing
            # each call with its observation from the trajectory tail (each
            # executed call appended exactly 2 entries: action, observation).
            messages.append(response)
            for i, tc in enumerate(tool_calls):
                obs_idx = len(self.trajectory) - (len(tool_calls) - i) * 2 + 1
                obs = self.trajectory[obs_idx][1] if 0 <= obs_idx < len(self.trajectory) else ""
                messages.append({"role": "tool", "name": tc["function"]["name"],
                                 "content": obs, "tool_call_id": tc["id"]})
            if self.success or give_up:
                break
        return {"success": self.success, "final_answer": self.final_answer,
                "trajectory": self.trajectory, "give_up": give_up,
                "total_tokens": self.total_tokens, "query_count": self.query_count,
                "steps": step + 1, "messages": messages}

    def _execute_tool(self, action_name, action_input):
        """Execute one tool call and return ``(observation_json, status_code)``.

        Status codes: 0 ok; 1 hallucinated function name; 2 invalid Finish
        args; 3 final answer given; 4 give-up; 5 timeout; 6-9/11 server-side
        API errors (mapped from the server's ``error`` string); 12 transport
        or HTTP failure.
        """
        if action_name == "Finish":
            try:
                json_data = json.loads(action_input) if isinstance(action_input, str) else action_input
            except (json.JSONDecodeError, TypeError):
                # Malformed JSON from the model: salvage the two fields we
                # care about with plain substring matching.
                text = str(action_input)
                json_data = {}
                if '"return_type": "give_answer"' in text:
                    json_data["return_type"] = "give_answer"
                elif '"return_type": "give_up_and_restart"' in text:
                    json_data["return_type"] = "give_up_and_restart"
                if '"final_answer": "' in text:
                    start = text.find('"final_answer": "') + len('"final_answer": "')
                    json_data["final_answer"] = text[start:].rstrip('"} ')
            if not isinstance(json_data, dict):
                # Valid JSON but not an object (list/scalar/None): report the
                # missing-field error instead of crashing below.
                json_data = {}
            if "return_type" not in json_data:
                return '{"error":"must have return_type"}', 2
            if json_data["return_type"] == "give_up_and_restart":
                return '{"response":"chose to give up and restart"}', 4
            if json_data["return_type"] == "give_answer":
                if "final_answer" not in json_data:
                    return '{"error":"must have final_answer"}', 2
                return '{"response":"successfully giving the final answer."}', 3
            return '{"error":"return_type is not a valid choice"}', 2
        for k, func_dict in enumerate(self.functions):
            func = func_dict["function"]
            # endswith() also covers exact equality, so a single test suffices;
            # suffix match tolerates category/tool prefixes on the schema name.
            if func["name"].endswith(action_name):
                pure_api_name = self.api_name_reflect.get(func["name"], action_name)
                payload = {
                    "category": self.cate_names[k] if k < len(self.cate_names) else "",
                    "tool_name": self.tool_names[k] if k < len(self.tool_names) else "",
                    "api_name": pure_api_name,
                    "tool_input": action_input,
                    "strip": "",
                    "toolbench_key": "",
                }
                try:
                    resp = requests.post(self.service_url, json=payload, timeout=30)
                    if resp.status_code != 200:
                        return json.dumps({"error": f"Server error: {resp.status_code}", "response": ""}), 12
                    response = resp.json()
                    error = response.get("error", "")
                    # Map the server's canonical error strings to status codes.
                    status_map = {"API not working error...": 6, "Unauthorized error...": 7,
                                  "Unsubscribed error...": 8, "Too many requests error...": 9,
                                  "Message error...": 11}
                    return json.dumps(response), status_map.get(error, 0)
                except requests.exceptions.Timeout:
                    return json.dumps({"error": "Timeout error...", "response": ""}), 5
                except Exception as e:
                    return json.dumps({"error": f"Request error: {str(e)}", "response": ""}), 12
        return json.dumps({"error": f"No such function name: {action_name}", "response": ""}), 1