"""ReAct reasoning loop for StableToolBench evaluation.

Implements the iterative Thought -> Action -> Observation loop, making tool
calls via the virtual API server and managing conversation state.
"""

import json
import time
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple

import requests

from llm_client import LLMClient
from config import MAX_STEPS, MAX_OBSERVATION_LENGTH, API_SERVER_URL, API_SERVER_PORT


class ReActRunner:
    """Runs a single ReAct episode for one query.

    Each step sends the conversation to the LLM. Any tool calls it returns are
    executed (against the virtual API server, or locally for the special
    ``Finish`` function), recorded in ``self.trajectory``, and appended to the
    conversation as ``"tool"`` messages.

    Status codes produced by :meth:`_execute_tool`:
        0       API call answered (server reported no recognised error)
        1       no such function name
        2       malformed ``Finish`` call
        3       ``Finish`` with a final answer
        4       ``Finish`` giving up and restarting
        5       request timed out
        6-9, 11 error string reported by the server (see ``_SERVER_ERROR_STATUS``)
        12      non-200 HTTP status or any other request/response failure
    """

    # Maps the virtual API server's ``error`` string to a status code; any
    # other value (including an empty string) maps to 0.
    _SERVER_ERROR_STATUS = {
        "API not working error...": 6,
        "Unauthorized error...": 7,
        "Unsubscribed error...": 8,
        "Too many requests error...": 9,
        "Message error...": 11,
    }

    def __init__(
        self,
        llm: LLMClient,
        functions: List[Dict[str, Any]],
        tool_descriptions: Any,
        api_name_reflect: Dict[str, Any],
        tool_names: List[str],
        cate_names: List[str],
        service_url: Optional[str] = None,
        max_steps: int = MAX_STEPS,
        max_observation_length: int = MAX_OBSERVATION_LENGTH,
        request_timeout: float = 30,
    ):
        """Store episode configuration and reset the episode state.

        Args:
            llm: client exposing ``change_messages(messages)`` and
                ``parse(tools=..., process_id=...)`` returning
                ``(response, error_code, tokens)``.
            functions: tool schemas; each entry has ``["function"]["name"]``.
            tool_descriptions: accepted for interface compatibility; this
                class only stores it.
            api_name_reflect: maps a function name to the API name sent to
                the server (falls back to the name the LLM used).
            tool_names: tool name per entry of ``functions`` (by index);
                missing entries are sent as "".
            cate_names: category per entry of ``functions`` (by index);
                missing entries are sent as "".
            service_url: virtual API endpoint; defaults to
                ``f"{API_SERVER_URL}:{API_SERVER_PORT}/virtual"``.
            max_steps: maximum number of LLM queries per episode.
            max_observation_length: observations longer than this are
                truncated and suffixed with "...".
            request_timeout: timeout in seconds for each API server request.
        """
        self.llm = llm
        self.functions = functions
        self.tool_descriptions = tool_descriptions
        self.api_name_reflect = api_name_reflect
        self.tool_names = tool_names
        self.cate_names = cate_names
        self.service_url = service_url or f"{API_SERVER_URL}:{API_SERVER_PORT}/virtual"
        self.max_steps = max_steps
        self.max_observation_length = max_observation_length
        self.request_timeout = request_timeout

        self.success = False
        self.final_answer = ""
        # Ordered (kind, text) entries: "thought", "action", "observation", "error".
        self.trajectory: List[Tuple[str, Any]] = []
        self.total_tokens = 0
        self.query_count = 0

    def run(self, initial_messages: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Run the ReAct loop until an answer, a give-up, an LLM error or ``max_steps``.

        Args:
            initial_messages: starting conversation; it is shallow-copied, so
                the caller's list is not appended to.

        Returns:
            A dict with ``success``, ``final_answer``, ``trajectory``,
            ``give_up``, ``total_tokens``, ``query_count``, ``steps`` (number
            of loop iterations actually started; 0 when ``max_steps <= 0``)
            and ``messages`` (the full conversation).
        """
        messages = list(initial_messages)
        give_up = False
        steps_taken = 0
        for step in range(self.max_steps):
            steps_taken = step + 1
            self.llm.change_messages(messages)
            response, error_code, tokens = self.llm.parse(tools=self.functions, process_id=0)
            self.total_tokens += tokens
            self.query_count += 1
            if error_code != 0:
                self.trajectory.append(("error", "LLM generation failed"))
                break

            content = response.get("content", "")
            if content:
                self.trajectory.append(("thought", content))

            tool_calls = response.get("tool_calls", [])
            if not tool_calls:
                messages.append(response)
                # A reply without tool calls on the first step is kept and the
                # LLM is queried again; on any later step it ends the episode.
                if step > 0:
                    break
                continue

            observations, give_up = self._handle_tool_calls(tool_calls)
            messages.append(response)
            for idx, tc in enumerate(tool_calls):
                # Calls skipped after a give-up Finish get an empty observation
                # so every tool_call_id still has a matching tool message.
                obs = observations[idx] if idx < len(observations) else ""
                messages.append({
                    "role": "tool",
                    # Possibly renamed to "invalid_hallucination_function_name".
                    "name": tc["function"]["name"],
                    "content": obs,
                    "tool_call_id": tc["id"],
                })
            if self.success or give_up:
                break

        return {
            "success": self.success,
            "final_answer": self.final_answer,
            "trajectory": self.trajectory,
            "give_up": give_up,
            "total_tokens": self.total_tokens,
            "query_count": self.query_count,
            "steps": steps_taken,
            "messages": messages,
        }

    def _handle_tool_calls(self, tool_calls: List[Dict[str, Any]]) -> Tuple[List[str], bool]:
        """Execute tool calls in order, recording actions/observations in the trajectory.

        Sets ``self.success``/``self.final_answer`` when a ``Finish`` call is
        accepted as a final answer (status 3), and renames the function of any
        call whose name was not found (status 1) in place.

        Returns:
            ``(observations, give_up)``: ``observations[i]`` is the (possibly
            truncated) observation for ``tool_calls[i]``. Processing stops right
            after a ``Finish`` that gives up (status 4), so the list may be
            shorter than ``tool_calls``.
        """
        observations: List[str] = []
        for tc in tool_calls:
            func_name = tc["function"]["name"]
            func_args = tc["function"]["arguments"]
            self.trajectory.append(("action", f"{func_name}({func_args})"))

            observation, status = self._execute_tool(func_name, func_args)
            if len(observation) > self.max_observation_length:
                observation = observation[:self.max_observation_length] + "..."
            self.trajectory.append(("observation", observation))
            observations.append(observation)

            if func_name == "Finish":
                # Decide from _execute_tool's status so the runner's verdict
                # always agrees with what the LLM was told.
                if status == 3:
                    self.success = True
                    self.final_answer = self._parse_finish_args(func_args).get("final_answer", "")
                elif status == 4:
                    return observations, True
            if status == 1:
                tc["function"]["name"] = "invalid_hallucination_function_name"
        return observations, False

    @staticmethod
    def _parse_finish_args(action_input: Any) -> Dict[str, Any]:
        """Parse the arguments of a ``Finish`` call into a new dict.

        ``action_input`` may be a JSON string or an already-decoded dict; the
        input is never mutated. Keys missing after JSON decoding (including
        when the JSON is malformed or not an object) are recovered from the
        raw text: ``"return_type": "give_answer"`` (checked first) or
        ``"return_type": "give_up_and_restart"``, and ``"final_answer": "``
        followed by the rest of the text with trailing ``"``, ``}`` and spaces
        stripped.
        """
        raw = str(action_input)
        parsed: Any = action_input
        if isinstance(action_input, str):
            try:
                parsed = json.loads(action_input)
            except ValueError:  # json.JSONDecodeError subclasses ValueError
                parsed = None
        args = dict(parsed) if isinstance(parsed, dict) else {}

        if "return_type" not in args:
            if '"return_type": "give_answer"' in raw:
                args["return_type"] = "give_answer"
            elif '"return_type": "give_up_and_restart"' in raw:
                args["return_type"] = "give_up_and_restart"

        if "final_answer" not in args:
            marker = '"final_answer": "'
            start = raw.find(marker)
            if start != -1:
                args["final_answer"] = raw[start + len(marker):].rstrip('"} ')
        return args

    def _finish_observation(self, action_input: Any) -> Tuple[str, int]:
        """Validate a ``Finish`` call and return ``(observation_json, status)``.

        Status is 3 for a final answer, 4 for giving up, and 2 when
        ``return_type`` is missing/invalid or ``final_answer`` is missing.
        """
        args = self._parse_finish_args(action_input)
        if "return_type" not in args:
            return '{"error":"must have return_type"}', 2
        return_type = args["return_type"]
        if return_type == "give_up_and_restart":
            return '{"response":"chose to give up and restart"}', 4
        if return_type == "give_answer":
            if "final_answer" not in args:
                return '{"error":"must have final_answer"}', 2
            return '{"response":"successfully giving the final answer."}', 3
        return '{"error":"return_type is not a valid choice"}', 2

    def _find_function_index(self, action_name: str) -> Optional[int]:
        """Return the index in ``self.functions`` of the function to call.

        An exact name match wins; otherwise the first function whose name
        ends with ``action_name`` is used. Returns None for an empty name or
        when nothing matches.
        """
        if not action_name:
            return None
        names = [func_dict["function"]["name"] for func_dict in self.functions]
        if action_name in names:
            return names.index(action_name)
        return next((k for k, name in enumerate(names) if name.endswith(action_name)), None)

    def _execute_tool(self, action_name: str, action_input: Any) -> Tuple[str, int]:
        """Execute one tool call and return ``(observation_json, status)``.

        ``Finish`` is handled locally; any other name is resolved against
        ``self.functions`` and POSTed to ``self.service_url``. See the class
        docstring for the meaning of each status code.
        """
        if action_name == "Finish":
            return self._finish_observation(action_input)

        k = self._find_function_index(action_name)
        if k is None:
            return json.dumps({"error": f"No such function name: {action_name}", "response": ""}), 1

        func_name = self.functions[k]["function"]["name"]
        payload = {
            "category": self.cate_names[k] if k < len(self.cate_names) else "",
            "tool_name": self.tool_names[k] if k < len(self.tool_names) else "",
            "api_name": self.api_name_reflect.get(func_name, action_name),
            "tool_input": action_input,
            "strip": "",
            "toolbench_key": "",
        }
        # Every failure is folded into a status code rather than raised, so a
        # single bad tool call never aborts the episode.
        try:
            resp = requests.post(self.service_url, json=payload, timeout=self.request_timeout)
            if resp.status_code != 200:
                return json.dumps({"error": f"Server error: {resp.status_code}", "response": ""}), 12
            response = resp.json()
            error = response.get("error", "")
            return json.dumps(response), self._SERVER_ERROR_STATUS.get(error, 0)
        except requests.exceptions.Timeout:
            return json.dumps({"error": "Timeout error...", "response": ""}), 5
        except Exception as e:
            return json.dumps({"error": f"Request error: {str(e)}", "response": ""}), 12