import requests
from typing import List

from tenacity import retry, stop_after_attempt, wait_random_exponential
from openai import OpenAI

from models.Base import BaseModel


class VLLMModel(BaseModel):
    def __init__(self,
                 model_id="local-vllm",
                 base_url="http://localhost:8041/v1",
                 api_key=None):
        """
        model_id: model name or identifier (usually unimportant; it is
            overwritten below with the id the vLLM server reports)
        base_url: vLLM API address, e.g. http://localhost:8000/v1
        api_key: optional (vLLM does not verify it by default)
        """
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key

        # Discover the served model's id from the server itself; a default
        # vLLM deployment serves a single model, so take the first entry.
        client = OpenAI(
            api_key=api_key or "EMPTY",  # vLLM accepts any non-empty key
            base_url=self.base_url,
        )
        models = client.models.list()
        self.model_id = models.data[0].id

        self.headers = {
            "Content-Type": "application/json"
        }
        if api_key:
            self.headers["Authorization"] = f"Bearer {api_key}"

        self.endpoint = f"{self.base_url}/chat/completions"

    @retry(wait=wait_random_exponential(min=1, max=10), stop=stop_after_attempt(3))
    def generate(self,
                 messages: List,
                 temperature=0,
                 presence_penalty=0,
                 frequency_penalty=0,
                 max_tokens=5000) -> str:
        """
        Call the local vLLM inference endpoint and return the generated text.
        """
        payload = {
            "model": self.model_id,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "presence_penalty": presence_penalty,
            "frequency_penalty": frequency_penalty,
        }

        # Generous timeout: long generations on a local server can be slow.
        response = requests.post(self.endpoint, headers=self.headers, json=payload, timeout=3000)

        if response.status_code != 200:
            raise RuntimeError(f"vLLM request failed: {response.status_code}, {response.text}")

        data = response.json()
        if "choices" not in data or len(data["choices"]) == 0:
            raise ValueError("No response choices returned from vLLM API.")

        return data["choices"][0]["message"]["content"]
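

# Usage sketch, assuming a vLLM server with an OpenAI-compatible API is already
# running at http://localhost:8041/v1 (the constructor's default). The messages
# below are illustrative, not part of the original module.
if __name__ == "__main__":
    model = VLLMModel()
    reply = model.generate(
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Summarize what vLLM is in one sentence."},
        ],
        temperature=0,
        max_tokens=128,
    )
    print(reply)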