File size: 6,446 Bytes
b4cf41e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
"""ReAct reasoning loop for StableToolBench evaluation.

Implements the iterative Thought -> Action -> Observation loop,
making tool calls via the virtual API server and managing conversation state.
"""
import json, time, requests
from typing import Dict, List, Any, Optional, Tuple
from copy import deepcopy
from llm_client import LLMClient
from config import MAX_STEPS, MAX_OBSERVATION_LENGTH, API_SERVER_URL, API_SERVER_PORT


class ReActRunner:
    """Runs a single ReAct episode for one query.

    Each step asks ``llm`` for the next assistant turn and executes the tool
    calls it requests. Ordinary tools are POSTed to the virtual API server; the
    ``Finish`` tool is handled locally. Observations are appended to the
    conversation as ``"tool"`` messages.

    The episode ends when any of these happens:
        * the model gives a valid final answer or gives up;
        * it replies without tool calls after the first step;
        * the LLM call fails;
        * ``max_steps`` turns have been used.

    Episode state (``success``, ``final_answer``, ``trajectory``, token and
    query counters) lives on the instance and is not reset by :meth:`run`, so
    use one runner per episode.

    Status codes returned by :meth:`_execute_tool`:
        0        tool call succeeded (or the server reported an unrecognised error)
        1        no function with that name
        2        malformed ``Finish`` arguments
        3        ``Finish`` with a final answer
        4        ``Finish`` with give-up-and-restart
        5        request to the API server timed out
        6-9, 11  server-reported errors, see ``_SERVER_ERROR_STATUS``
        12       HTTP, transport or response-decoding error
    """

    # Server ``error`` strings that map to a dedicated status code; any other
    # value (including an empty string) maps to status 0.
    _SERVER_ERROR_STATUS = {
        "API not working error...": 6,
        "Unauthorized error...": 7,
        "Unsubscribed error...": 8,
        "Too many requests error...": 9,
        "Message error...": 11,
    }
    # Seconds to wait for the virtual API server before reporting status 5.
    _REQUEST_TIMEOUT = 30

    def __init__(self, llm, functions, tool_descriptions, api_name_reflect, tool_names, cate_names,
                 service_url=None, max_steps=MAX_STEPS, max_observation_length=MAX_OBSERVATION_LENGTH):
        """Store the episode configuration and initialise episode state.

        Args:
            llm: Client with ``change_messages(messages)`` and
                ``parse(tools=..., process_id=...)``; the latter returns
                ``(response_dict, error_code, token_count)``.
            functions: Tool schemas handed to the LLM; each entry has a
                ``"function"`` dict with at least a ``"name"``.
            tool_descriptions: Accepted for interface compatibility; not stored or used.
            api_name_reflect: Maps a function name to the API name sent to the server.
            tool_names: Per-function tool names, index-aligned with ``functions``.
            cate_names: Per-function category names, index-aligned with ``functions``.
            service_url: Tool endpoint; defaults to
                ``f"{API_SERVER_URL}:{API_SERVER_PORT}/virtual"``.
            max_steps: Maximum number of LLM turns per episode.
            max_observation_length: Observations longer than this many
                characters are truncated and suffixed with ``"..."``.
        """
        self.llm, self.functions = llm, functions
        self.api_name_reflect, self.tool_names, self.cate_names = api_name_reflect, tool_names, cate_names
        self.service_url = service_url or f"{API_SERVER_URL}:{API_SERVER_PORT}/virtual"
        self.max_steps, self.max_observation_length = max_steps, max_observation_length
        self.success, self.final_answer, self.trajectory = False, "", []
        self.total_tokens, self.query_count = 0, 0

    def run(self, initial_messages) -> Dict[str, Any]:
        """Run the Thought -> Action -> Observation loop for up to ``max_steps`` turns.

        Args:
            initial_messages: Starting conversation. It is shallow-copied, so
                the caller's list is not extended.

        Returns:
            A dict with these keys:
                ``success``, ``final_answer``, ``trajectory``, ``give_up``,
                ``total_tokens``, ``query_count``;
                ``steps``: LLM turns actually taken (0 when ``max_steps`` <= 0);
                ``messages``: the full conversation, including assistant and
                tool messages.
        """
        messages = list(initial_messages)
        give_up = False
        steps = 0  # defined even when the loop body never runs
        for step in range(self.max_steps):
            steps = step + 1
            self.llm.change_messages(messages)
            response, error_code, tokens = self.llm.parse(tools=self.functions, process_id=0)
            self.total_tokens += tokens
            self.query_count += 1
            if error_code != 0:
                self.trajectory.append(("error", "LLM generation failed"))
                break

            content = response.get("content", "")
            if content:
                self.trajectory.append(("thought", content))

            tool_calls = response.get("tool_calls", [])
            if not tool_calls:
                messages.append(response)
                # A reply without tool calls is tolerated on the first turn
                # (the model is queried again); afterwards it ends the episode.
                if step > 0:
                    break
                continue

            observations = []  # one entry per executed tool call, in call order
            for tc in tool_calls:
                func_name, func_args = tc["function"]["name"], tc["function"]["arguments"]
                self.trajectory.append(("action", f"{func_name}({func_args})"))
                observation, status = self._execute_tool(func_name, func_args)
                if len(observation) > self.max_observation_length:
                    observation = observation[:self.max_observation_length] + "..."
                self.trajectory.append(("observation", observation))
                observations.append(observation)
                if func_name == "Finish":
                    # Rely on _execute_tool's validation so the recorded outcome
                    # matches the observation the model was shown:
                    # 3 = valid final answer, 4 = give up. Status 2 (malformed)
                    # keeps the loop going so the model can see the error.
                    if status == 3:
                        self.success = True
                        self.final_answer = self._parse_finish_args(func_args).get("final_answer", "")
                    elif status == 4:
                        give_up = True
                    # Tool calls after Finish in the same turn are not executed.
                    break
                if status == 1:
                    # Rename unknown tools in the recorded assistant turn.
                    tc["function"]["name"] = "invalid_hallucination_function_name"

            messages.append(response)
            # Every tool call in the assistant turn gets a tool message. Calls
            # skipped after Finish get empty content instead of a mismatched
            # observation.
            for i, tc in enumerate(tool_calls):
                obs = observations[i] if i < len(observations) else ""
                messages.append({"role": "tool", "name": tc["function"]["name"], "content": obs, "tool_call_id": tc["id"]})
            if self.success or give_up:
                break
        return {"success": self.success, "final_answer": self.final_answer, "trajectory": self.trajectory, "give_up": give_up, "total_tokens": self.total_tokens, "query_count": self.query_count, "steps": steps, "messages": messages}

    @staticmethod
    def _parse_finish_args(action_input) -> Dict[str, Any]:
        """Parse the arguments of a ``Finish`` call into a dict (never raises).

        Dict inputs are returned unchanged. Strings are decoded as JSON; valid
        JSON that is not an object yields ``{}``. When decoding fails (or the
        input is neither str nor dict), ``return_type`` and ``final_answer`` are
        recovered by a best-effort scan of the raw text.
        """
        if isinstance(action_input, dict):
            return action_input
        if isinstance(action_input, str):
            try:
                parsed = json.loads(action_input)
            except ValueError:  # json.JSONDecodeError subclasses ValueError
                pass
            else:
                return parsed if isinstance(parsed, dict) else {}
        # Best-effort recovery from malformed JSON emitted by the model.
        text = str(action_input)
        json_data = {}
        if '"return_type": "give_answer"' in text:
            json_data["return_type"] = "give_answer"
        elif '"return_type": "give_up_and_restart"' in text:
            json_data["return_type"] = "give_up_and_restart"
        marker = '"final_answer": "'
        if marker in text:
            start = text.find(marker) + len(marker)
            # Heuristic: drop trailing quote/brace/space characters of the
            # unterminated JSON object.
            json_data["final_answer"] = text[start:].rstrip('"} ')
        return json_data

    def _find_function_index(self, action_name) -> Optional[int]:
        """Return the index in ``self.functions`` of the tool named ``action_name``.

        An exact name match wins. Otherwise the first function whose name ends
        with ``action_name`` is used. Returns ``None`` when nothing matches or
        ``action_name`` is empty.
        """
        if not action_name:
            return None
        names = [func_dict["function"]["name"] for func_dict in self.functions]
        if action_name in names:
            return names.index(action_name)
        for k, name in enumerate(names):
            if name.endswith(action_name):
                return k
        return None

    def _execute_tool(self, action_name, action_input) -> Tuple[str, int]:
        """Execute one tool call and return ``(observation, status)``.

        ``Finish`` is validated locally. Any other name is resolved against
        ``self.functions`` and POSTed to ``self.service_url``. ``observation``
        is a JSON string; see the class docstring for the status codes.
        """
        if action_name == "Finish":
            json_data = self._parse_finish_args(action_input)
            if "return_type" not in json_data:
                return '{"error":"must have return_type"}', 2
            if json_data["return_type"] == "give_up_and_restart":
                return '{"response":"chose to give up and restart"}', 4
            if json_data["return_type"] == "give_answer":
                if "final_answer" not in json_data:
                    return '{"error":"must have final_answer"}', 2
                return '{"response":"successfully giving the final answer."}', 3
            return '{"error":"return_type is not a valid choice"}', 2

        k = self._find_function_index(action_name)
        if k is None:
            return json.dumps({"error": f"No such function name: {action_name}", "response": ""}), 1

        func_name = self.functions[k]["function"]["name"]
        payload = {
            "category": self.cate_names[k] if k < len(self.cate_names) else "",
            "tool_name": self.tool_names[k] if k < len(self.tool_names) else "",
            "api_name": self.api_name_reflect.get(func_name, action_name),
            "tool_input": action_input,
            "strip": "",
            "toolbench_key": "",
        }
        try:
            resp = requests.post(self.service_url, json=payload, timeout=self._REQUEST_TIMEOUT)
            if resp.status_code != 200:
                return json.dumps({"error": f"Server error: {resp.status_code}", "response": ""}), 12
            response = resp.json()
            error = response.get("error", "")
            return json.dumps(response), self._SERVER_ERROR_STATUS.get(error, 0)
        except requests.exceptions.Timeout:
            return json.dumps({"error": "Timeout error...", "response": ""}), 5
        except Exception as e:
            # Deliberate boundary: any transport/decoding failure is reported to
            # the model as an observation rather than aborting the episode.
            return json.dumps({"error": f"Request error: {str(e)}", "response": ""}), 12