Add react_loop.py
Browse files- pipeline/react_loop.py +93 -0
pipeline/react_loop.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ReAct reasoning loop for StableToolBench evaluation.
|
| 2 |
+
|
| 3 |
+
Implements the iterative Thought -> Action -> Observation loop,
|
| 4 |
+
making tool calls via the virtual API server and managing conversation state.
|
| 5 |
+
"""
|
| 6 |
+
import json, time, requests
|
| 7 |
+
from typing import Dict, List, Any, Optional, Tuple
|
| 8 |
+
from copy import deepcopy
|
| 9 |
+
from llm_client import LLMClient
|
| 10 |
+
from config import MAX_STEPS, MAX_OBSERVATION_LENGTH, API_SERVER_URL, API_SERVER_PORT
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class ReActRunner:
    """Runs a single ReAct episode (Thought -> Action -> Observation) for one query.

    The runner repeatedly asks the LLM for the next step, executes any
    requested tool calls against the virtual API server, feeds the results
    back as ``role="tool"`` messages, and stops when the model calls the
    special ``Finish`` function (or the step budget runs out).
    """

    def __init__(self, llm, functions, tool_descriptions, api_name_reflect,
                 tool_names, cate_names, service_url=None, max_steps=MAX_STEPS,
                 max_observation_length=MAX_OBSERVATION_LENGTH):
        """Create a runner for one episode.

        Args:
            llm: Client exposing ``change_messages(messages)`` and
                ``parse(tools=..., process_id=...)`` (see ``LLMClient``).
            functions: OpenAI-style tool schemas, each ``{"function": {...}}``.
            tool_descriptions: Human-readable tool docs (stored for callers;
                not used by the loop itself).
            api_name_reflect: Maps full function name -> pure API name sent to
                the virtual server.
            tool_names: Tool name per entry of ``functions`` (parallel list).
            cate_names: Category per entry of ``functions`` (parallel list).
            service_url: Virtual-API endpoint; defaults to the configured
                ``API_SERVER_URL:API_SERVER_PORT/virtual``.
            max_steps: Maximum LLM turns per episode.
            max_observation_length: Observations longer than this are
                truncated (with ``"..."``) before being fed back to the model.
        """
        self.llm = llm
        self.functions = functions
        self.tool_descriptions = tool_descriptions
        self.api_name_reflect = api_name_reflect
        self.tool_names = tool_names
        self.cate_names = cate_names
        self.service_url = service_url or f"{API_SERVER_URL}:{API_SERVER_PORT}/virtual"
        self.max_steps = max_steps
        self.max_observation_length = max_observation_length
        # Episode state, filled in by run().
        self.success = False
        self.final_answer = ""
        self.trajectory = []  # (kind, text) tuples: thought/action/observation/error
        self.total_tokens = 0
        self.query_count = 0

    def run(self, initial_messages):
        """Drive the ReAct loop until Finish, give-up, error, or step budget.

        Args:
            initial_messages: Seed conversation (system + user messages).
                Copied, so the caller's list is not mutated.

        Returns:
            dict with keys ``success``, ``final_answer``, ``trajectory``,
            ``give_up``, ``total_tokens``, ``query_count``, ``steps``, and
            the final ``messages`` list.
        """
        messages = list(initial_messages)
        give_up = False
        step = -1  # so "steps" is 0 when max_steps == 0 (original raised NameError)
        for step in range(self.max_steps):
            self.llm.change_messages(messages)
            response, error_code, tokens = self.llm.parse(tools=self.functions, process_id=0)
            self.total_tokens += tokens
            self.query_count += 1
            if error_code != 0:
                self.trajectory.append(("error", "LLM generation failed"))
                break
            content = response.get("content", "")
            if content:
                self.trajectory.append(("thought", content))
            tool_calls = response.get("tool_calls", [])
            if not tool_calls:
                messages.append(response)
                # A bare thought on the very first turn gets one retry;
                # afterwards, a turn without tool calls ends the episode.
                if step > 0:
                    break
                continue
            # Execute the requested calls, remembering each observation so the
            # tool messages below pair with the right call even when we stop
            # early at Finish. (The original recomputed trajectory indices as
            # len(trajectory) - (len(tool_calls)-i)*2 + 1, which mispaired
            # observations whenever the loop broke before executing them all.)
            observations = []
            for tc in tool_calls:
                func_name = tc["function"]["name"]
                func_args = tc["function"]["arguments"]
                self.trajectory.append(("action", f"{func_name}({func_args})"))
                observation, status = self._execute_tool(func_name, func_args)
                if len(observation) > self.max_observation_length:
                    observation = observation[:self.max_observation_length] + "..."
                self.trajectory.append(("observation", observation))
                observations.append(observation)
                if func_name == "Finish":
                    try:
                        args = json.loads(func_args) if isinstance(func_args, str) else func_args
                    except (json.JSONDecodeError, TypeError):
                        args = {}
                    if not isinstance(args, dict):  # e.g. model emitted a bare string
                        args = {}
                    if args.get("return_type") == "give_answer":
                        self.success = True
                        self.final_answer = args.get("final_answer", "")
                    elif args.get("return_type") == "give_up_and_restart":
                        give_up = True
                    break  # nothing after Finish is executed
                if status == 1:
                    # Unknown function: rename the echoed call so the model
                    # can see its own hallucination in context.
                    tc["function"]["name"] = "invalid_hallucination_function_name"
            # Echo the assistant turn, then one tool message per requested call
            # (calls skipped after an early Finish get an empty observation so
            # every tool_call_id still receives a reply).
            messages.append(response)
            for i, tc in enumerate(tool_calls):
                obs = observations[i] if i < len(observations) else ""
                messages.append({"role": "tool", "name": tc["function"]["name"],
                                 "content": obs, "tool_call_id": tc["id"]})
            if self.success or give_up:
                break
        return {
            "success": self.success,
            "final_answer": self.final_answer,
            "trajectory": self.trajectory,
            "give_up": give_up,
            "total_tokens": self.total_tokens,
            "query_count": self.query_count,
            "steps": step + 1,
            "messages": messages,
        }

    def _execute_tool(self, action_name, action_input):
        """Execute one tool call and return ``(observation_json, status)``.

        Status codes follow the ToolBench convention: 0 ok, 1 unknown
        function name, 2 malformed Finish args, 3 final answer given,
        4 give up, 5 timeout, 6-11 server-reported API errors, 12 request
        failure.
        """
        if action_name == "Finish":
            try:
                json_data = json.loads(action_input) if isinstance(action_input, str) else action_input
            except json.JSONDecodeError:
                # Model produced invalid JSON: salvage return_type/final_answer
                # by plain substring sniffing.
                text = str(action_input)
                json_data = {}
                if '"return_type": "give_answer"' in text:
                    json_data["return_type"] = "give_answer"
                elif '"return_type": "give_up_and_restart"' in text:
                    json_data["return_type"] = "give_up_and_restart"
                if '"final_answer": "' in text:
                    marker = '"final_answer": "'
                    start = text.find(marker) + len(marker)
                    json_data["final_answer"] = text[start:].rstrip('"} ')
            if not isinstance(json_data, dict):  # valid JSON but not an object
                json_data = {}
            if "return_type" not in json_data:
                return '{"error":"must have return_type"}', 2
            if json_data["return_type"] == "give_up_and_restart":
                return '{"response":"chose to give up and restart"}', 4
            if json_data["return_type"] == "give_answer":
                if "final_answer" not in json_data:
                    return '{"error":"must have final_answer"}', 2
                return '{"response":"successfully giving the final answer."}', 3
            return '{"error":"return_type is not a valid choice"}', 2

        # Regular tool call: find the matching schema (suffix match, because
        # registered names are prefixed) and proxy it to the virtual server.
        for k, func_dict in enumerate(self.functions):
            func = func_dict["function"]
            if func["name"].endswith(action_name) or func["name"] == action_name:
                pure_api_name = self.api_name_reflect.get(func["name"], action_name)
                payload = {
                    "category": self.cate_names[k] if k < len(self.cate_names) else "",
                    "tool_name": self.tool_names[k] if k < len(self.tool_names) else "",
                    "api_name": pure_api_name,
                    "tool_input": action_input,
                    "strip": "",
                    "toolbench_key": "",
                }
                try:
                    resp = requests.post(self.service_url, json=payload, timeout=30)
                    if resp.status_code != 200:
                        return json.dumps({"error": f"Server error: {resp.status_code}", "response": ""}), 12
                    response = resp.json()
                    error = response.get("error", "")
                    # Map known server error strings to ToolBench status codes.
                    status_map = {
                        "API not working error...": 6,
                        "Unauthorized error...": 7,
                        "Unsubscribed error...": 8,
                        "Too many requests error...": 9,
                        "Message error...": 11,
                    }
                    return json.dumps(response), status_map.get(error, 0)
                except requests.exceptions.Timeout:
                    return json.dumps({"error": "Timeout error...", "response": ""}), 5
                except Exception as e:
                    return json.dumps({"error": f"Request error: {str(e)}", "response": ""}), 12
        # No schema matched: report a hallucinated function name (status 1
        # makes run() flag the call for the model).
        return json.dumps({"error": f"No such function name: {action_name}", "response": ""}), 1
|