Spaces:
Running
Running
Add OllamaRelayClient to hf_inference_client.py for intelligent router
Browse files
nexus_os_v2/hf_inference_client.py
CHANGED
|
@@ -147,3 +147,54 @@ class MockInferenceClient:
|
|
| 147 |
|
| 148 |
def list_models(self) -> List[str]:
|
| 149 |
return ["mock-model"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
|
| 148 |
def list_models(self) -> List[str]:
|
| 149 |
return ["mock-model"]
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class OllamaRelayClient:
    """
    Connects to user's local Ollama via relay URL.

    The user exposes their local Ollama via ngrok, localtunnel, or Cloudflare Tunnel.
    Set OLLAMA_RELAY_URL env var to the public tunnel endpoint. When neither the
    constructor argument nor the env var provides a usable URL, the client falls
    back to http://localhost:11434.
    """

    # Used when no relay URL is supplied by the caller or the environment.
    _DEFAULT_URL = "http://localhost:11434"

    def __init__(self, relay_url: Optional[str] = None):
        """Resolve the base URL of the relay.

        Args:
            relay_url: Explicit base URL. When empty or None, the
                OLLAMA_RELAY_URL environment variable is consulted, then the
                localhost default.
        """
        chosen = ""
        for candidate in (relay_url, os.environ.get("OLLAMA_RELAY_URL", "")):
            # Drop stray whitespace and trailing slashes so that f"{base}/api/..."
            # never yields a malformed URL or a double slash.
            chosen = (candidate or "").strip().rstrip("/")
            if chosen:
                break
        self.relay_url = chosen or self._DEFAULT_URL
        # Cache of model names, filled by the last successful is_connected().
        self._available_models: List[str] = []

    def _fetch(self, path: str, payload: Optional[bytes] = None, timeout: float = 10) -> bytes:
        """Send one HTTP request to the relay and return the raw response body.

        A GET is issued when ``payload`` is None, otherwise a POST carrying
        ``payload`` as a JSON body. Errors raised by urllib propagate.
        """
        # Imported locally so every caller has the name in scope even if the
        # module never imports urllib.request at top level.
        import urllib.request

        req = urllib.request.Request(
            f"{self.relay_url}{path}",
            data=payload,
            headers={"Content-Type": "application/json"},
            method="GET" if payload is None else "POST",
        )
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            return resp.read()

    @staticmethod
    def _parse_chat_body(raw: bytes, model_tag: str) -> tuple:
        """Turn a response body into ``(text, model)``.

        Accepts either a single JSON object (non-streaming reply) or
        newline-delimited JSON objects (streaming reply); in the latter case
        the content pieces are concatenated in order.

        Raises:
            json.JSONDecodeError: if the body is not valid JSON in either form.
        """
        body = raw.decode("utf-8")
        try:
            chunks = [json.loads(body)]
        except json.JSONDecodeError:
            lines = [ln for ln in body.splitlines() if ln.strip()]
            if len(lines) < 2:
                # Not a multi-object stream: the body is simply invalid.
                raise
            chunks = [json.loads(ln) for ln in lines]

        parts = []
        model = model_tag
        for chunk in chunks:
            if "message" in chunk:
                piece = (chunk.get("message") or {}).get("content", "")
            else:
                piece = chunk.get("response", "")
            parts.append(piece or "")
            model = chunk.get("model", model)
        return "".join(parts), model

    def is_connected(self) -> bool:
        """Probe ``/api/tags`` and refresh the cached model list.

        Returns:
            True when the relay answered with a parseable model listing,
            False on any failure (the previous cache is left untouched).
        """
        try:
            data = json.loads(self._fetch("/api/tags", timeout=10).decode("utf-8"))
            self._available_models = [
                m.get("name", m.get("model", "")) for m in data.get("models", [])
            ]
            return True
        except Exception:
            # Deliberate best-effort probe: network, HTTP, decoding and
            # unexpected-shape errors all mean "not usable right now".
            return False

    def generate(self, model_tag: str, prompt: str, system: Optional[str] = None,
                 temperature: float = 0.7, max_tokens: int = 2048, stream: bool = False):
        """Run one chat completion through ``/api/chat``.

        Args:
            model_tag: Model name sent to the server.
            prompt: User message content.
            system: Optional system message placed before the user message.
            temperature: Sampling temperature forwarded in ``options``.
            max_tokens: Forwarded as ``options.num_predict``.
            stream: Forwarded to the server. The full reply is still collected
                and returned as one string; streamed chunks are joined.

        Returns:
            ``(text, info)`` where ``info`` has the keys ``"model"`` and
            ``"latency_ms"`` (elapsed milliseconds for request plus parsing).

        Raises:
            urllib.error.URLError: on connection or HTTP failures.
            json.JSONDecodeError: if the reply body is not valid JSON.
        """
        messages = []
        if system:
            messages.append({"role": "system", "content": system})
        messages.append({"role": "user", "content": prompt})
        payload = json.dumps({
            "model": model_tag,
            "messages": messages,
            "stream": stream,
            "options": {"temperature": temperature, "num_predict": max_tokens},
        }).encode("utf-8")

        # Monotonic clock: immune to wall-clock adjustments during long calls.
        t0 = time.perf_counter()
        raw = self._fetch("/api/chat", payload=payload, timeout=300)
        text, model = self._parse_chat_body(raw, model_tag)
        elapsed = (time.perf_counter() - t0) * 1000
        return text, {"model": model, "latency_ms": elapsed}

    def list_models(self) -> List[str]:
        """Return the known model names, probing the relay if none are cached.

        A copy is returned so callers cannot mutate the internal cache.
        """
        if not self._available_models:
            self.is_connected()
        return list(self._available_models)
|