Dwootton committed on
Commit
b4cf41e
·
verified ·
1 Parent(s): d2dafa6

Add react_loop.py

Browse files
Files changed (1) hide show
  1. pipeline/react_loop.py +93 -0
pipeline/react_loop.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ReAct reasoning loop for StableToolBench evaluation.
2
+
3
+ Implements the iterative Thought -> Action -> Observation loop,
4
+ making tool calls via the virtual API server and managing conversation state.
5
+ """
6
+ import json, time, requests
7
+ from typing import Dict, List, Any, Optional, Tuple
8
+ from copy import deepcopy
9
+ from llm_client import LLMClient
10
+ from config import MAX_STEPS, MAX_OBSERVATION_LENGTH, API_SERVER_URL, API_SERVER_PORT
11
+
12
+
13
class ReActRunner:
    """Runs a single ReAct episode (Thought -> Action -> Observation) for one query.

    The runner repeatedly asks the LLM for the next step, executes any tool
    calls against the virtual API server, feeds the observations back into the
    conversation, and stops when the model calls ``Finish`` (either giving a
    final answer or giving up) or when ``max_steps`` is exhausted.
    """

    # Maps server-reported error strings to ToolBench status codes.
    # Hoisted to a class constant so it is not rebuilt on every API call.
    _ERROR_STATUS = {
        "API not working error...": 6,
        "Unauthorized error...": 7,
        "Unsubscribed error...": 8,
        "Too many requests error...": 9,
        "Message error...": 11,
    }

    def __init__(self, llm, functions, tool_descriptions, api_name_reflect, tool_names, cate_names, service_url=None, max_steps=MAX_STEPS, max_observation_length=MAX_OBSERVATION_LENGTH):
        """Set up a runner for one episode.

        Args:
            llm: LLM client exposing ``change_messages`` and ``parse``.
            functions: OpenAI-style tool/function schemas passed to the LLM.
            tool_descriptions: human-readable tool docs (kept for interface
                compatibility; not used by the loop itself).
            api_name_reflect: maps a (possibly prefixed) function name back to
                the pure API name expected by the server.
            tool_names: per-function tool names, parallel to ``functions``.
            cate_names: per-function category names, parallel to ``functions``.
            service_url: virtual API endpoint; defaults to the configured
                host:port with the ``/virtual`` path.
            max_steps: maximum number of LLM turns per episode.
            max_observation_length: observations longer than this are truncated.
        """
        self.llm = llm
        self.functions = functions
        self.tool_descriptions = tool_descriptions  # unused for now, retained for callers
        self.api_name_reflect = api_name_reflect
        self.tool_names = tool_names
        self.cate_names = cate_names
        self.service_url = service_url or f"{API_SERVER_URL}:{API_SERVER_PORT}/virtual"
        self.max_steps = max_steps
        self.max_observation_length = max_observation_length
        # Episode state, populated by run().
        self.success = False
        self.final_answer = ""
        self.trajectory = []  # list of (kind, text) tuples: thought/action/observation/error
        self.total_tokens = 0
        self.query_count = 0

    def run(self, initial_messages):
        """Run the ReAct loop starting from ``initial_messages``.

        Returns a dict with keys: success, final_answer, trajectory, give_up,
        total_tokens, query_count, steps, messages.
        """
        messages = list(initial_messages)
        give_up = False
        step = -1  # so "steps" reports 0 if max_steps <= 0 (loop body never runs)
        for step in range(self.max_steps):
            self.llm.change_messages(messages)
            response, error_code, tokens = self.llm.parse(tools=self.functions, process_id=0)
            self.total_tokens += tokens
            self.query_count += 1
            if error_code != 0:
                self.trajectory.append(("error", "LLM generation failed"))
                break
            content = response.get("content", "")
            if content:
                self.trajectory.append(("thought", content))
            tool_calls = response.get("tool_calls", [])
            if not tool_calls:
                # Text-only reply: record it. Allow one retry on the very first
                # step (the model may still produce a tool call); otherwise stop.
                messages.append(response)
                if step > 0:
                    break
                continue
            # Execute the tool calls, collecting observations in order so the
            # tool messages below are paired with the right results even when
            # the loop breaks early on Finish. (The previous index arithmetic
            # over self.trajectory was wrong after an early break.)
            observations = []
            for tc in tool_calls:
                func_name = tc["function"]["name"]
                func_args = tc["function"]["arguments"]
                self.trajectory.append(("action", f"{func_name}({func_args})"))
                observation, status = self._execute_tool(func_name, func_args)
                if len(observation) > self.max_observation_length:
                    observation = observation[:self.max_observation_length] + "..."
                self.trajectory.append(("observation", observation))
                observations.append(observation)
                if func_name == "Finish":
                    try:
                        args = json.loads(func_args) if isinstance(func_args, str) else func_args
                    except (json.JSONDecodeError, TypeError):
                        args = {}
                    if not isinstance(args, dict):
                        args = {}
                    if args.get("return_type") == "give_answer":
                        self.success = True
                        self.final_answer = args.get("final_answer", "")
                    elif args.get("return_type") == "give_up_and_restart":
                        give_up = True
                    # Finish ends the tool-call batch; remaining calls are skipped.
                    break
                if status == 1:
                    # Hallucinated function name: rename it so the transcript
                    # flags the invalid call for downstream evaluation.
                    tc["function"]["name"] = "invalid_hallucination_function_name"
            # Append the assistant turn, then one tool message per tool call
            # (calls skipped after an early Finish get an empty observation so
            # every tool_call_id still receives a response).
            messages.append(response)
            for i, tc in enumerate(tool_calls):
                obs = observations[i] if i < len(observations) else ""
                messages.append({"role": "tool", "name": tc["function"]["name"], "content": obs, "tool_call_id": tc["id"]})
            if self.success or give_up:
                break
        return {"success": self.success, "final_answer": self.final_answer, "trajectory": self.trajectory, "give_up": give_up, "total_tokens": self.total_tokens, "query_count": self.query_count, "steps": step + 1, "messages": messages}

    def _execute_tool(self, action_name, action_input):
        """Execute one tool call; return (observation JSON string, status code).

        Status codes follow ToolBench conventions: 0 ok, 1 hallucinated name,
        2 invalid Finish args, 3 final answer given, 4 give up, 5 timeout,
        6-11 server-reported API errors, 12 request/server failure.
        """
        if action_name == "Finish":
            return self._handle_finish(action_input)
        return self._call_virtual_api(action_name, action_input)

    def _handle_finish(self, action_input):
        """Validate a Finish call's arguments and return its observation/status."""
        try:
            json_data = json.loads(action_input) if isinstance(action_input, str) else action_input
        except (json.JSONDecodeError, TypeError):
            # Arguments were not valid JSON: fall back to crude substring
            # extraction so a slightly-malformed Finish still counts.
            json_data = {}
            text = str(action_input)
            if '"return_type": "give_answer"' in text:
                json_data["return_type"] = "give_answer"
            elif '"return_type": "give_up_and_restart"' in text:
                json_data["return_type"] = "give_up_and_restart"
            if '"final_answer": "' in text:
                start = text.find('"final_answer": "') + len('"final_answer": "')
                json_data["final_answer"] = text[start:].rstrip('"} ')
        if not isinstance(json_data, dict):
            # json.loads may legitimately return a non-dict (list, string, ...)
            json_data = {}
        if "return_type" not in json_data:
            return '{"error":"must have return_type"}', 2
        if json_data["return_type"] == "give_up_and_restart":
            return '{"response":"chose to give up and restart"}', 4
        if json_data["return_type"] == "give_answer":
            if "final_answer" not in json_data:
                return '{"error":"must have final_answer"}', 2
            return '{"response":"successfully giving the final answer."}', 3
        return '{"error":"return_type is not a valid choice"}', 2

    def _call_virtual_api(self, action_name, action_input):
        """Resolve ``action_name`` against the registered functions and POST the
        call to the virtual API server; map server errors to status codes."""
        for k, func_dict in enumerate(self.functions):
            func = func_dict["function"]
            # endswith() matches prefixed registrations (e.g. "tool_api" vs "api").
            if func["name"].endswith(action_name) or func["name"] == action_name:
                pure_api_name = self.api_name_reflect.get(func["name"], action_name)
                payload = {
                    "category": self.cate_names[k] if k < len(self.cate_names) else "",
                    "tool_name": self.tool_names[k] if k < len(self.tool_names) else "",
                    "api_name": pure_api_name,
                    "tool_input": action_input,
                    "strip": "",
                    "toolbench_key": "",
                }
                try:
                    resp = requests.post(self.service_url, json=payload, timeout=30)
                    if resp.status_code != 200:
                        return json.dumps({"error": f"Server error: {resp.status_code}", "response": ""}), 12
                    response = resp.json()
                    error = response.get("error", "")
                    return json.dumps(response), self._ERROR_STATUS.get(error, 0)
                except requests.exceptions.Timeout:
                    return json.dumps({"error": "Timeout error...", "response": ""}), 5
                except Exception as e:
                    # Broad catch is deliberate: any transport/decoding failure
                    # becomes an observation rather than crashing the episode.
                    return json.dumps({"error": f"Request error: {str(e)}", "response": ""}), 12
        # No registered function matched: flag as a hallucinated name.
        return json.dumps({"error": f"No such function name: {action_name}", "response": ""}), 1