Dwootton committed on
Commit
d2dafa6
·
verified ·
1 Parent(s): a5eff80

Add llm_client.py, react_loop.py, run_eval.py, virtual_api_server.py

Browse files
Files changed (1) hide show
  1. pipeline/llm_client.py +42 -0
pipeline/llm_client.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """LLM client using vLLM for Llama-3.1-8B-Instruct inference with tool calling."""
2
+ import json, os, time, traceback
3
+ from typing import Dict, List, Optional, Any
4
+ from openai import OpenAI
5
+
6
+
7
class LLMClient:
    """OpenAI-compatible client for a vLLM-served Llama model.

    Wraps the OpenAI chat-completions API exposed by a local vLLM server,
    keeping a mutable ``conversation_history`` that callers set via
    :meth:`change_messages` and consume via :meth:`parse`.
    """

    def __init__(self, model="meta-llama/Llama-3.1-8B-Instruct",
                 base_url="http://localhost:8000/v1", api_key="dummy",
                 temperature=0.001, max_tokens=1024):
        """Create the underlying OpenAI client pointed at the vLLM endpoint.

        Args:
            model: Model identifier the server was launched with.
            base_url: vLLM OpenAI-compatible endpoint URL.
            api_key: Placeholder key (vLLM does not validate it by default).
            temperature: Near-zero default for near-deterministic decoding.
            max_tokens: Per-response completion token cap.
        """
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.client = OpenAI(base_url=base_url, api_key=api_key)
        self.conversation_history = []

    def change_messages(self, messages):
        """Replace the conversation history wholesale with *messages*."""
        self.conversation_history = messages

    def _clean_history(self):
        """Return history entries fit to send to the API.

        Drops messages explicitly marked invalid (``valid`` key set to the
        literal ``False``; absent or any other value counts as valid) and
        strips the bookkeeping ``valid`` and legacy ``function_call`` keys,
        which the chat-completions endpoint would reject or misinterpret.
        """
        cleaned = []
        for msg in self.conversation_history:
            if msg.get("valid", True) is False:
                continue
            clean_msg = {k: v for k, v in msg.items() if k != "valid"}
            clean_msg.pop("function_call", None)
            cleaned.append(clean_msg)
        return cleaned

    def parse(self, tools, process_id=0, max_retries=3):
        """Run one chat completion over the stored history, with retries.

        Args:
            tools: OpenAI-format tool specs. Only forwarded when at least one
                tool other than the sentinel "Finish" tool is present;
                parallel tool calls are disabled in that case.
            process_id: Worker index; token usage is printed only for worker 0.
            max_retries: Attempts before giving up; sleeps 2s between tries.

        Returns:
            Tuple ``(message_dict, status, total_tokens)`` where ``status`` is
            0 on success and -1 after all retries fail (with a placeholder
            assistant message and 0 tokens).
        """
        use_messages = self._clean_history()
        for attempt in range(max_retries):
            try:
                kwargs = {
                    "model": self.model,
                    "messages": use_messages,
                    "max_tokens": self.max_tokens,
                    "temperature": self.temperature,
                    "frequency_penalty": 0,
                    "presence_penalty": 0,
                }
                if tools and any(t["function"]["name"] != "Finish" for t in tools):
                    kwargs["tools"] = tools
                    kwargs["parallel_tool_calls"] = False
                response = self.client.chat.completions.create(**kwargs)
                message = response.choices[0].message
                # usage can be None on some server configurations.
                total_tokens = response.usage.total_tokens if response.usage else 0
                msg_dict = {"role": "assistant", "content": message.content}
                if message.tool_calls:
                    msg_dict["tool_calls"] = [
                        {
                            "id": tc.id,
                            "type": "function",
                            "function": {
                                "name": tc.function.name,
                                "arguments": tc.function.arguments,
                            },
                        }
                        for tc in message.tool_calls
                    ]
                if process_id == 0:
                    print(f"[process({process_id})] tokens: {total_tokens}")
                return msg_dict, 0, total_tokens
            except Exception as e:
                # Best-effort retry loop: log, back off, try again.
                print(f"[process({process_id})] Attempt {attempt+1} error: {repr(e)}")
                traceback.print_exc()
                if attempt < max_retries - 1:
                    time.sleep(2)
        return {"role": "assistant", "content": "Error generating response"}, -1, 0