File size: 3,142 Bytes
6d9c72b
 
 
 
 
 
cf0a8ed
6d9c72b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
466cc3d
 
 
 
 
6d9c72b
466cc3d
6d9c72b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
466cc3d
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
"""Async HTTP client for vLLM OpenAI-compatible API."""
import logging
from typing import Any

import httpx

from apohara_context_forge.config import settings

logger = logging.getLogger(__name__)


class vLLMClient:
    """Async client for a vLLM OpenAI-compatible server.

    Usable either as an async context manager::

        async with vLLMClient() as client:
            result = await client.chat([...])

    or standalone, in which case the underlying HTTP client is created
    lazily on first request and should be released with :meth:`aclose`.
    """

    def __init__(self, base_url: str | None = None, api_key: str | None = None):
        """Store connection settings; no I/O happens here.

        Args:
            base_url: Server root URL; falls back to ``settings.vllm_base_url``.
            api_key: Bearer token; falls back to ``settings.vllm_api_key``.
        """
        self._base_url = base_url or settings.vllm_base_url
        self._api_key = api_key or settings.vllm_api_key
        self._client: httpx.AsyncClient | None = None

    def _ensure_client(self) -> httpx.AsyncClient:
        """Create the shared httpx client on first use (idempotent).

        Previously this construction was duplicated in ``__aenter__``,
        ``complete`` and ``chat``; centralizing it keeps the base URL,
        auth header and timeout in one place.
        """
        if self._client is None:
            self._client = httpx.AsyncClient(
                base_url=self._base_url,
                headers={"Authorization": f"Bearer {self._api_key}"},
                timeout=60.0,
            )
        return self._client

    async def __aenter__(self):
        self._ensure_client()
        return self

    async def __aexit__(self, *args):
        await self.aclose()

    async def aclose(self) -> None:
        """Close the underlying httpx client. Safe to call multiple times."""
        if self._client is not None:
            await self._client.aclose()
            self._client = None

    async def _post(
        self, path: str, payload: dict[str, Any], error_prefix: str
    ) -> dict[str, Any]:
        """POST *payload* to *path* and return the parsed JSON response.

        On any transport/HTTP error, logs ``"<error_prefix>: <exc>"`` and
        returns ``{"error": str(exc)}`` instead of raising, matching the
        best-effort contract of the public methods.
        """
        client = self._ensure_client()
        try:
            response = await client.post(path, json=payload)
            response.raise_for_status()
            return response.json()
        except httpx.HTTPError as e:
            # Lazy %-formatting: message is only built if the record is emitted.
            logger.error("%s: %s", error_prefix, e)
            return {"error": str(e)}

    async def complete(
        self,
        prompt: str,
        max_tokens: int = 256,
        temperature: float = 0.7,
        **kwargs,
    ) -> dict[str, Any]:
        """Send a text-completion request to ``/v1/completions``.

        Args:
            prompt: The raw prompt string.
            max_tokens: Generation cap forwarded to the server.
            temperature: Sampling temperature forwarded to the server.
            **kwargs: Extra OpenAI-compatible fields merged into the payload
                (may override the keys above, as in the original behavior).

        Returns:
            The server's JSON response, or ``{"error": ...}`` on failure.
        """
        payload = {
            "model": settings.vllm_model,
            "prompt": prompt,
            "max_tokens": max_tokens,
            "temperature": temperature,
            **kwargs,
        }
        return await self._post("/v1/completions", payload, "vLLM request failed")

    async def chat(
        self,
        messages: list[dict[str, str]],
        max_tokens: int = 256,
        temperature: float = 0.7,
        **kwargs,
    ) -> dict[str, Any]:
        """Send a chat-completion request to ``/v1/chat/completions``.

        Args:
            messages: OpenAI-style message dicts (``role``/``content`` keys
                presumed — TODO confirm against callers).
            max_tokens: Generation cap forwarded to the server.
            temperature: Sampling temperature forwarded to the server.
            **kwargs: Extra OpenAI-compatible fields merged into the payload.

        Returns:
            The server's JSON response, or ``{"error": ...}`` on failure.
        """
        payload = {
            "model": settings.vllm_model,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
            **kwargs,
        }
        return await self._post(
            "/v1/chat/completions", payload, "vLLM chat request failed"
        )


# Canonical PEP-8 alias. Tests and the MCP server import the upper-case form;
# the lower-case original stays for backward compatibility with older callers.
# NOTE: both names bind the same class object, so isinstance() checks and
# subclasses behave identically regardless of which name a caller imports.
VLLMClient = vLLMClient