Commit: Add llm_client.py, react_loop.py, run_eval.py, virtual_api_server.py
Changed file: pipeline/llm_client.py (ADDED, +42 -0)
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""LLM client using vLLM for Llama-3.1-8B-Instruct inference with tool calling."""
|
| 2 |
+
import json, os, time, traceback
|
| 3 |
+
from typing import Dict, List, Optional, Any
|
| 4 |
+
from openai import OpenAI
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class LLMClient:
    """OpenAI-compatible client for a vLLM-served Llama model.

    Wraps the `openai` SDK pointed at a local vLLM endpoint and keeps a
    mutable conversation history that `parse` replays on every request.
    """

    def __init__(self, model="meta-llama/Llama-3.1-8B-Instruct",
                 base_url="http://localhost:8000/v1", api_key="dummy",
                 temperature=0.001, max_tokens=1024):
        """Create a client bound to one model/server.

        Args:
            model: model name passed through to the server.
            base_url: OpenAI-compatible endpoint (vLLM default port 8000).
            api_key: placeholder key; vLLM does not validate it by default.
            temperature: sampling temperature (near-greedy by default).
            max_tokens: completion-length cap per request.
        """
        self.model, self.temperature, self.max_tokens = model, temperature, max_tokens
        self.client = OpenAI(base_url=base_url, api_key=api_key)
        self.conversation_history = []

    def change_messages(self, messages):
        """Replace the conversation history (stored by reference, not copied)."""
        self.conversation_history = messages

    def _sanitized_messages(self):
        """Return history entries that are safe to send to the API.

        Drops entries explicitly flagged ``valid=False`` (an absent flag means
        valid) and strips the local bookkeeping keys ``valid`` and
        ``function_call`` from each message before sending.
        """
        cleaned = []
        for msg in self.conversation_history:
            # Skip only when the flag is literally False; "valid" is expected
            # to be a bool when present — TODO confirm against callers.
            if msg.get("valid", True) is False:
                continue
            keep = {k: v for k, v in msg.items() if k != "valid"}
            keep.pop("function_call", None)
            cleaned.append(keep)
        return cleaned

    def parse(self, tools, process_id=0, max_retries=3):
        """Send the current history to the model and return its reply.

        Args:
            tools: OpenAI-format tool specs. Forwarded only when at least one
                tool other than the synthetic "Finish" action is present.
            process_id: worker index; token usage is printed only for worker 0.
            max_retries: request attempts before giving up.

        Returns:
            Tuple ``(message_dict, error_code, total_tokens)`` — ``error_code``
            is 0 on success and -1 after exhausting all retries.
        """
        use_messages = self._sanitized_messages()
        for attempt in range(max_retries):
            try:
                kwargs = {
                    "model": self.model,
                    "messages": use_messages,
                    "max_tokens": self.max_tokens,
                    "temperature": self.temperature,
                    "frequency_penalty": 0,
                    "presence_penalty": 0,
                }
                # Only attach tools when a real (non-"Finish") tool exists;
                # parallel calls are disabled so the ReAct loop stays one
                # action per turn.
                if tools and any(t["function"]["name"] != "Finish" for t in tools):
                    kwargs["tools"] = tools
                    kwargs["parallel_tool_calls"] = False
                response = self.client.chat.completions.create(**kwargs)
                message = response.choices[0].message
                total_tokens = response.usage.total_tokens if response.usage else 0
                msg_dict = {"role": "assistant", "content": message.content}
                if message.tool_calls:
                    # Re-serialize SDK objects into plain dicts so the history
                    # stays JSON-friendly.
                    msg_dict["tool_calls"] = [
                        {
                            "id": tc.id,
                            "type": "function",
                            "function": {
                                "name": tc.function.name,
                                "arguments": tc.function.arguments,
                            },
                        }
                        for tc in message.tool_calls
                    ]
                if process_id == 0:
                    print(f"[process({process_id})] tokens: {total_tokens}")
                return msg_dict, 0, total_tokens
            except Exception as e:
                # Broad catch is deliberate: network/server errors from the
                # SDK vary, and we retry with a short fixed backoff.
                print(f"[process({process_id})] Attempt {attempt+1} error: {repr(e)}")
                traceback.print_exc()
                if attempt < max_retries - 1:
                    time.sleep(2)
        return {"role": "assistant", "content": "Error generating response"}, -1, 0
|