| """LLM client using vLLM for Llama-3.1-8B-Instruct inference with tool calling.""" |
| import json, os, time, traceback |
| from typing import Dict, List, Optional, Any |
| from openai import OpenAI |
|
|
|
|
class LLMClient:
    """OpenAI-compatible client for a vLLM-served Llama model with tool calling.

    Wraps the ``openai`` SDK pointed at a local vLLM endpoint and keeps a
    mutable conversation history that callers install via
    :meth:`change_messages` before each :meth:`parse` call.
    """

    def __init__(
        self,
        model: str = "meta-llama/Llama-3.1-8B-Instruct",
        base_url: str = "http://localhost:8000/v1",
        api_key: str = "dummy",
        temperature: float = 0.001,
        max_tokens: int = 1024,
    ) -> None:
        """Create the client.

        Args:
            model: Model name exactly as served by vLLM.
            base_url: OpenAI-compatible endpoint of the vLLM server.
            api_key: Placeholder key (vLLM does not validate it by default).
            temperature: Near-zero for effectively deterministic sampling.
            max_tokens: Per-request completion length cap.
        """
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.client = OpenAI(base_url=base_url, api_key=api_key)
        self.conversation_history: List[Dict[str, Any]] = []

    def change_messages(self, messages: List[Dict[str, Any]]) -> None:
        """Replace the stored conversation history with *messages*."""
        self.conversation_history = messages

    def _prepare_messages(self) -> List[Dict[str, Any]]:
        """Return the history with invalid entries dropped and local
        bookkeeping keys (``valid``, legacy ``function_call``) stripped,
        so the payload matches what the OpenAI-compatible API accepts."""
        prepared: List[Dict[str, Any]] = []
        for msg in self.conversation_history:
            # A missing "valid" flag means the message is valid; only an
            # explicit False flag excludes it from the request.
            if msg.get("valid", True) is False:
                continue
            clean_msg = {k: v for k, v in msg.items() if k != "valid"}
            clean_msg.pop("function_call", None)  # legacy field; not accepted by the API
            prepared.append(clean_msg)
        return prepared

    def parse(self, tools, process_id: int = 0, max_retries: int = 3):
        """Run one chat completion over the stored conversation history.

        Args:
            tools: OpenAI-format tool specs. Tools are omitted from the
                request when the list is empty or contains only the terminal
                "Finish" pseudo-tool.
            process_id: Worker id; only worker 0 prints token usage.
            max_retries: Number of attempts before giving up.

        Returns:
            Tuple ``(message_dict, error_code, total_tokens)`` where
            ``error_code`` is 0 on success and -1 after all retries fail
            (with a placeholder assistant message and 0 tokens).
        """
        use_messages = self._prepare_messages()
        for attempt in range(max_retries):
            try:
                kwargs: Dict[str, Any] = {
                    "model": self.model,
                    "messages": use_messages,
                    "max_tokens": self.max_tokens,
                    "temperature": self.temperature,
                    "frequency_penalty": 0,
                    "presence_penalty": 0,
                }
                # Only attach tools when at least one real tool besides the
                # "Finish" sentinel exists; otherwise a plain completion is
                # requested.
                if tools and any(t["function"]["name"] != "Finish" for t in tools):
                    kwargs["tools"] = tools
                    kwargs["parallel_tool_calls"] = False
                response = self.client.chat.completions.create(**kwargs)
                message = response.choices[0].message
                total_tokens = response.usage.total_tokens if response.usage else 0
                msg_dict: Dict[str, Any] = {"role": "assistant", "content": message.content}
                if message.tool_calls:
                    # Re-serialize tool calls into plain dicts so the history
                    # stays JSON-friendly.
                    msg_dict["tool_calls"] = [
                        {
                            "id": tc.id,
                            "type": "function",
                            "function": {
                                "name": tc.function.name,
                                "arguments": tc.function.arguments,
                            },
                        }
                        for tc in message.tool_calls
                    ]
                if process_id == 0:
                    print(f"[process({process_id})] tokens: {total_tokens}")
                return msg_dict, 0, total_tokens
            except Exception as e:  # broad by design: any network/API failure is retried
                print(f"[process({process_id})] Attempt {attempt+1} error: {repr(e)}")
                traceback.print_exc()
                if attempt < max_retries - 1:
                    time.sleep(2)  # brief fixed pause between retries
        return {"role": "assistant", "content": "Error generating response"}, -1, 0
|