# p2p-stabletoolbench / pipeline / llm_client.py
# Author: Dwootton
# Commit d2dafa6 (verified): Add llm_client.py, react_loop.py, run_eval.py, virtual_api_server.py
"""LLM client using vLLM for Llama-3.1-8B-Instruct inference with tool calling."""
import json, os, time, traceback
from typing import Dict, List, Optional, Any
from openai import OpenAI
class LLMClient:
    """OpenAI-compatible client for a vLLM-served Llama model.

    Holds a mutable conversation history (replaced wholesale via
    ``change_messages``) and turns it into chat-completion requests with
    optional tool calling via ``parse``.
    """

    def __init__(self, model="meta-llama/Llama-3.1-8B-Instruct", base_url="http://localhost:8000/v1", api_key="dummy", temperature=0.001, max_tokens=1024):
        """Create a client against an OpenAI-compatible endpoint.

        Args:
            model: Model identifier passed on every request.
            base_url: OpenAI-compatible server URL (vLLM default port 8000).
            api_key: Ignored by vLLM, but the OpenAI client requires a value.
            temperature: Near-zero default for effectively greedy decoding.
            max_tokens: Completion-length cap per request.
        """
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.client = OpenAI(base_url=base_url, api_key=api_key)
        # Messages are stored by reference; callers own the list.
        self.conversation_history: List[Dict[str, Any]] = []

    def change_messages(self, messages: List[Dict[str, Any]]) -> None:
        """Replace the conversation history with *messages* (kept by reference)."""
        self.conversation_history = messages

    def _prepare_messages(self) -> List[Dict[str, Any]]:
        """Return history entries not marked invalid, stripped of bookkeeping keys.

        A message is dropped only when its "valid" value compares equal to
        False.  The ``!= False`` test is deliberate and preserved from the
        original: values that are *equal* to False (e.g. ``0``) are dropped,
        while other falsy values (``""``, ``None``) are kept.
        """
        prepared: List[Dict[str, Any]] = []
        for msg in self.conversation_history:
            if msg.get("valid", True) != False:  # noqa: E712 -- equality (not identity) is intended here
                clean_msg = {k: v for k, v in msg.items() if k != "valid"}
                # Strip the legacy "function_call" field; tool calls are
                # carried in "tool_calls" in the current API shape.
                clean_msg.pop("function_call", None)
                prepared.append(clean_msg)
        return prepared

    def parse(self, tools, process_id: int = 0, max_retries: int = 3):
        """Request one assistant turn for the current conversation.

        Args:
            tools: OpenAI-format tool specs, or a falsy value for none.
                Tools are only sent when at least one tool other than
                "Finish" is present (presumably "Finish" is a synthetic
                terminal action handled outside the model -- confirm
                against the ReAct loop).
            process_id: Worker id; token usage is printed only for id 0.
            max_retries: Attempts before giving up.

        Returns:
            ``(message_dict, status, total_tokens)`` -- status 0 on success,
            or ``({"role": "assistant", "content": "Error generating
            response"}, -1, 0)`` after all retries fail.
        """
        use_messages = self._prepare_messages()
        for attempt in range(max_retries):
            try:
                kwargs: Dict[str, Any] = {
                    "model": self.model,
                    "messages": use_messages,
                    "max_tokens": self.max_tokens,
                    "temperature": self.temperature,
                    "frequency_penalty": 0,
                    "presence_penalty": 0,
                }
                if tools and any(t["function"]["name"] != "Finish" for t in tools):
                    kwargs["tools"] = tools
                    # One tool call per turn keeps the ReAct loop sequential.
                    kwargs["parallel_tool_calls"] = False
                response = self.client.chat.completions.create(**kwargs)
                message = response.choices[0].message
                total_tokens = response.usage.total_tokens if response.usage else 0
                msg_dict: Dict[str, Any] = {"role": "assistant", "content": message.content}
                if message.tool_calls:
                    # Re-serialize tool calls into plain dicts so the history
                    # stays JSON-friendly.
                    msg_dict["tool_calls"] = [
                        {
                            "id": tc.id,
                            "type": "function",
                            "function": {"name": tc.function.name, "arguments": tc.function.arguments},
                        }
                        for tc in message.tool_calls
                    ]
                if process_id == 0:
                    print(f"[process({process_id})] tokens: {total_tokens}")
                return msg_dict, 0, total_tokens
            except Exception as e:  # broad by design: any endpoint/network failure triggers a retry
                print(f"[process({process_id})] Attempt {attempt+1} error: {repr(e)}")
                traceback.print_exc()
                if attempt < max_retries - 1:
                    time.sleep(2)
        # All retries exhausted: return a sentinel so the caller can detect failure.
        return {"role": "assistant", "content": "Error generating response"}, -1, 0