| """ReAct reasoning loop for StableToolBench evaluation. |
| |
| Implements the iterative Thought -> Action -> Observation loop, |
| making tool calls via the virtual API server and managing conversation state. |
| """ |
| import json, time, requests |
| from typing import Dict, List, Any, Optional, Tuple |
| from copy import deepcopy |
| from llm_client import LLMClient |
| from config import MAX_STEPS, MAX_OBSERVATION_LENGTH, API_SERVER_URL, API_SERVER_PORT |
|
|
|
|
class ReActRunner:
    """Runs a single ReAct episode (Thought -> Action -> Observation) for one query.

    Each iteration asks the LLM for the next step, executes any requested tool
    calls against the virtual API server, feeds the truncated observations back
    as ``role="tool"`` messages, and stops on a ``Finish`` call, a give-up, an
    LLM failure, or after ``max_steps`` iterations.
    """

    def __init__(self, llm, functions, tool_descriptions, api_name_reflect, tool_names, cate_names, service_url=None, max_steps=MAX_STEPS, max_observation_length=MAX_OBSERVATION_LENGTH):
        """Set up one episode.

        Args:
            llm: client exposing ``change_messages(messages)`` and
                ``parse(tools=..., process_id=...)``.
            functions: OpenAI-style function/tool schemas passed to the LLM.
            tool_descriptions: accepted for interface compatibility; not used here.
            api_name_reflect: maps the (possibly prefixed) registered function
                name back to the pure API name for the server payload.
            tool_names: per-function tool names, indexed in step with ``functions``.
            cate_names: per-function category names, indexed in step with ``functions``.
            service_url: override for the virtual-API endpoint; defaults to the
                configured host/port with the ``/virtual`` path.
            max_steps: maximum number of LLM turns per episode.
            max_observation_length: observations longer than this are truncated.
        """
        self.llm = llm
        self.functions = functions
        self.api_name_reflect = api_name_reflect
        self.tool_names = tool_names
        self.cate_names = cate_names
        self.service_url = service_url or f"{API_SERVER_URL}:{API_SERVER_PORT}/virtual"
        self.max_steps = max_steps
        self.max_observation_length = max_observation_length
        # Episode state, populated by run().
        self.success = False
        self.final_answer = ""
        self.trajectory = []  # list of (kind, text) tuples: thought/action/observation/error
        self.total_tokens = 0
        self.query_count = 0

    def run(self, initial_messages):
        """Drive the ReAct loop starting from ``initial_messages``.

        Returns a summary dict with success flag, final answer, trajectory,
        give-up flag, token/query counts, number of steps taken, and the full
        message history.
        """
        messages = list(initial_messages)
        give_up = False
        step = -1  # so "steps" is 0 if max_steps == 0 and the loop never runs
        for step in range(self.max_steps):
            self.llm.change_messages(messages)
            response, error_code, tokens = self.llm.parse(tools=self.functions, process_id=0)
            self.total_tokens += tokens
            self.query_count += 1
            if error_code != 0:
                self.trajectory.append(("error", "LLM generation failed"))
                break
            content = response.get("content", "")
            if content:
                self.trajectory.append(("thought", content))
            tool_calls = response.get("tool_calls", [])
            if not tool_calls:
                # Content-only reply: tolerated once at step 0 (the model may
                # open with a plain thought); afterwards it means the model has
                # stopped calling tools, so end the episode.
                messages.append(response)
                if step > 0:
                    break
                continue
            # Observations aligned index-for-index with tool_calls. Calls left
            # unprocessed after a Finish break get an empty observation rather
            # than a misaligned back-index into the trajectory.
            observations = []
            for tc in tool_calls:
                func_name = tc["function"]["name"]
                func_args = tc["function"]["arguments"]
                self.trajectory.append(("action", f"{func_name}({func_args})"))
                observation, status = self._execute_tool(func_name, func_args)
                if len(observation) > self.max_observation_length:
                    observation = observation[:self.max_observation_length] + "..."
                self.trajectory.append(("observation", observation))
                observations.append(observation)
                if func_name == "Finish":
                    try:
                        args = json.loads(func_args) if isinstance(func_args, str) else func_args
                    except (json.JSONDecodeError, TypeError):
                        args = {}
                    if args.get("return_type") == "give_answer":
                        self.success = True
                        self.final_answer = args.get("final_answer", "")
                    elif args.get("return_type") == "give_up_and_restart":
                        give_up = True
                    break  # Finish terminates the tool-call batch
                if status == 1:
                    # Hallucinated function name: rename in place so the model
                    # sees the invalid call reflected in the history.
                    tc["function"]["name"] = "invalid_hallucination_function_name"
            messages.append(response)
            for i, tc in enumerate(tool_calls):
                obs = observations[i] if i < len(observations) else ""
                messages.append({"role": "tool", "name": tc["function"]["name"], "content": obs, "tool_call_id": tc["id"]})
            if self.success or give_up:
                break
        return {"success": self.success, "final_answer": self.final_answer, "trajectory": self.trajectory, "give_up": give_up, "total_tokens": self.total_tokens, "query_count": self.query_count, "steps": step + 1, "messages": messages}

    def _execute_tool(self, action_name, action_input):
        """Execute one tool call and return ``(observation_json_str, status)``.

        Status codes: 0 ok, 1 hallucinated function name, 2 bad Finish args,
        3 final answer given, 4 give-up, 5 timeout, 6-11 server-reported API
        errors, 12 transport/server failure.
        """
        if action_name == "Finish":
            return self._finish_observation(action_input)
        for k, func_dict in enumerate(self.functions):
            func = func_dict["function"]
            # Registered names may carry a tool-name prefix, so also match on suffix.
            if func["name"].endswith(action_name) or func["name"] == action_name:
                pure_api_name = self.api_name_reflect.get(func["name"], action_name)
                payload = {
                    "category": self.cate_names[k] if k < len(self.cate_names) else "",
                    "tool_name": self.tool_names[k] if k < len(self.tool_names) else "",
                    "api_name": pure_api_name,
                    "tool_input": action_input,
                    "strip": "",
                    "toolbench_key": "",
                }
                try:
                    resp = requests.post(self.service_url, json=payload, timeout=30)
                    if resp.status_code != 200:
                        return json.dumps({"error": f"Server error: {resp.status_code}", "response": ""}), 12
                    response = resp.json()
                    error = response.get("error", "")
                    # Server signals these failure classes via sentinel strings.
                    status_map = {"API not working error...": 6, "Unauthorized error...": 7, "Unsubscribed error...": 8, "Too many requests error...": 9, "Message error...": 11}
                    return json.dumps(response), status_map.get(error, 0)
                except requests.exceptions.Timeout:
                    return json.dumps({"error": "Timeout error...", "response": ""}), 5
                except Exception as e:
                    return json.dumps({"error": f"Request error: {str(e)}", "response": ""}), 12
        return json.dumps({"error": f"No such function name: {action_name}", "response": ""}), 1

    def _finish_observation(self, action_input):
        """Validate Finish arguments and map them to ``(observation, status)``."""
        try:
            json_data = json.loads(action_input) if isinstance(action_input, str) else action_input
        except (json.JSONDecodeError, TypeError):
            # Malformed JSON from the model: salvage the fields by substring
            # search so a slightly broken Finish call still terminates cleanly.
            json_data = {}
            text = str(action_input)
            if '"return_type": "give_answer"' in text:
                json_data["return_type"] = "give_answer"
            elif '"return_type": "give_up_and_restart"' in text:
                json_data["return_type"] = "give_up_and_restart"
            if '"final_answer": "' in text:
                start = text.find('"final_answer": "') + len('"final_answer": "')
                json_data["final_answer"] = text[start:].rstrip('"} ')
        if "return_type" not in json_data:
            return '{"error":"must have return_type"}', 2
        if json_data["return_type"] == "give_up_and_restart":
            return '{"response":"chose to give up and restart"}', 4
        if json_data["return_type"] == "give_answer":
            if "final_answer" not in json_data:
                return '{"error":"must have final_answer"}', 2
            return '{"response":"successfully giving the final answer."}', 3
        return '{"error":"return_type is not a valid choice"}', 2
|
|