specimba committed on
Commit
a07f2c0
·
verified ·
1 Parent(s): f9d9a60

Add OllamaRelayClient to hf_inference_client.py for intelligent router

Browse files
Files changed (1) hide show
  1. nexus_os_v2/hf_inference_client.py +51 -0
nexus_os_v2/hf_inference_client.py CHANGED
@@ -147,3 +147,54 @@ class MockInferenceClient:
147
 
148
  def list_models(self) -> List[str]:
149
  return ["mock-model"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
 
148
  def list_models(self) -> List[str]:
149
  return ["mock-model"]
150
+
151
+
152
class OllamaRelayClient:
    """
    Client for an Ollama server reached through a relay URL.

    The relay URL is taken from the ``relay_url`` argument, then from the
    ``OLLAMA_RELAY_URL`` environment variable, and finally defaults to
    ``http://localhost:11434``.  The user typically exposes a local Ollama
    via ngrok, localtunnel, or Cloudflare Tunnel and sets
    ``OLLAMA_RELAY_URL`` to the public tunnel endpoint.
    """

    def __init__(self, relay_url: Optional[str] = None):
        """Resolve the relay base URL and start with an empty model cache.

        Args:
            relay_url: Base URL of the Ollama endpoint.  When falsy, the
                ``OLLAMA_RELAY_URL`` environment variable is used, then
                ``http://localhost:11434``.
        """
        resolved = relay_url or os.environ.get("OLLAMA_RELAY_URL", "")
        if not resolved:
            resolved = "http://localhost:11434"
        # Trailing slashes are stripped so f"{relay_url}/api/..." never
        # produces a double slash.
        self.relay_url = resolved.rstrip("/")
        # Model names cached by the last successful is_connected() call.
        self._available_models: List[str] = []

    def is_connected(self) -> bool:
        """Probe ``/api/tags`` and refresh the cached model list.

        Returns:
            True when the endpoint answered with parseable JSON, False on
            any failure.  On failure the previously cached model list is
            left untouched.
        """
        try:
            import urllib.request

            req = urllib.request.Request(
                f"{self.relay_url}/api/tags",
                headers={"Content-Type": "application/json"},
                method="GET",
            )
            with urllib.request.urlopen(req, timeout=10) as resp:
                data = json.loads(resp.read().decode("utf-8"))
            names = (
                m.get("name", m.get("model", ""))
                for m in data.get("models", [])
            )
            # Drop entries without a usable name instead of caching "".
            self._available_models = [name for name in names if name]
            return True
        except Exception:
            # Deliberately broad: this is a best-effort health probe and
            # callers only expect a boolean, never an exception.
            return False

    @staticmethod
    def _parse_chat_body(body: str, model_tag: str) -> tuple:
        """Extract ``(text, model)`` from an Ollama response body.

        Handles both a single JSON object (``stream=False``) and
        newline-delimited JSON chunks (``stream=True``), concatenating the
        text of every chunk in order.

        Raises:
            ValueError: (``json.JSONDecodeError``) when the body is empty
                or is not valid JSON / NDJSON.
        """
        try:
            chunks = [json.loads(body)]
        except ValueError:
            # A streamed response is one JSON object per line, which is not
            # valid as a single document; parse it line by line instead.
            chunks = [
                json.loads(line) for line in body.splitlines() if line.strip()
            ]
            if not chunks:
                raise

        parts = []
        for chunk in chunks:
            if "message" in chunk:
                # Chat-style payload; tolerate a null message/content.
                piece = (chunk.get("message") or {}).get("content", "")
            else:
                # Fallback for generate-style payloads with a "response" key.
                piece = chunk.get("response", "")
            parts.append(piece or "")

        model = next((c["model"] for c in chunks if "model" in c), model_tag)
        return "".join(parts), model

    def generate(self, model_tag: str, prompt: str, system: Optional[str] = None,
                 temperature: float = 0.7, max_tokens: int = 2048, stream: bool = False):
        """Send a chat request to ``/api/chat`` and return the reply text.

        Args:
            model_tag: Model name sent as ``model`` in the request.
            prompt: User message content.
            system: Optional system message placed before the user message.
            temperature: Sent as ``options.temperature``.
            max_tokens: Sent as ``options.num_predict``.
            stream: Forwarded to the server.  When True the streamed chunks
                are read to completion and joined, so the return shape is
                the same as for a non-streamed call.

        Returns:
            ``(text, meta)`` where ``meta`` has the keys ``model`` and
            ``latency_ms`` (wall-clock milliseconds for the whole request).

        Raises:
            urllib.error.URLError: On connection or HTTP failure.
            ValueError: When the response body is not valid JSON.
        """
        # Imported here because the module-level imports are not guaranteed
        # to include urllib.request (is_connected imports it locally too).
        import urllib.request

        messages = []
        if system:
            messages.append({"role": "system", "content": system})
        messages.append({"role": "user", "content": prompt})

        payload = json.dumps({
            "model": model_tag,
            "messages": messages,
            "stream": stream,
            "options": {"temperature": temperature, "num_predict": max_tokens},
        }).encode("utf-8")
        req = urllib.request.Request(
            f"{self.relay_url}/api/chat",
            data=payload,
            headers={"Content-Type": "application/json"},
            method="POST",
        )

        t0 = time.time()
        with urllib.request.urlopen(req, timeout=300) as resp:
            body = resp.read().decode("utf-8")
        text, model = self._parse_chat_body(body, model_tag)
        elapsed = (time.time() - t0) * 1000
        return text, {"model": model, "latency_ms": elapsed}

    def list_models(self) -> List[str]:
        """Return the cached model names, probing the server if none are cached.

        A copy is returned so callers cannot mutate the internal cache.
        """
        if not self._available_models:
            self.is_connected()
        return list(self._available_models)