"""LLM provider abstraction for the interactive demo.

The demo points at any OpenAI-compatible ``/v1/chat/completions`` endpoint:
local Ollama, Hugging Face's Inference Providers router, OpenAI itself,
vLLM, OpenRouter, etc. Everything funnels through one factory so the UI
only has to learn one shape.

The browser passes ``base_url``, ``model``, and (optionally) ``api_key``
on every request. If ``api_key`` is missing we fall back to a per-provider
env var so a Hugging Face Space can ship a default working config without
hard-coding secrets in client bundles.
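
A minimal sketch of driving the factory directly (the model tag and prompt
below are illustrative, not defaults shipped with the demo)::

    request = LlmStepRequest(
        base_url="http://localhost:11434/v1",
        model="qwen2.5:7b-instruct",
    )
    policy = default_openai_compat_policy_factory(request)
    reply = policy([{"role": "user", "content": "Reply with a JSON object."}])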
"""

from __future__ import annotations

import logging
import os
from collections.abc import Callable
from typing import Optional

from fastapi import HTTPException
from pydantic import BaseModel, ConfigDict, Field

_log = logging.getLogger(__name__)


HF_ROUTER_BASE_URL = "https://router.huggingface.co/v1"
OPENAI_BASE_URL = "https://api.openai.com/v1"
OLLAMA_OPENAI_BASE_URL = "http://localhost:11434/v1"
PHYSIX_INFER_BASE_URL = "https://pratyush-01-physix-infer.hf.space/v1"


class LlmStepRequest(BaseModel):
    """Provider-agnostic step payload.

    The browser names a base URL + model + (optional) key. The server
    fans these into an ``openai.OpenAI`` client. ``base_url`` is required
    so we never silently default to the wrong endpoint when the user
    swaps providers mid-session.
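
    Example request body (values are illustrative)::

        {
            "base_url": "http://localhost:11434/v1",
            "model": "qwen2.5:7b-instruct",
            "temperature": 0.4,
            "max_tokens": 2048
        }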
    """

    model_config = ConfigDict(extra="forbid")

    base_url: str = Field(
        description=(
            "OpenAI-compatible /v1 base URL. E.g. http://localhost:11434/v1, "
            "https://router.huggingface.co/v1, https://api.openai.com/v1."
        ),
    )
    model: str = Field(
        description=(
            "Model id understood by the chosen base URL. For HF this is the "
            "repo id (optionally suffixed with :provider, e.g. ':fastest'); "
            "for Ollama it's the local tag; for OpenAI it's the model name."
        ),
    )
    api_key: Optional[str] = Field(
        default=None,
        description=(
            "Bearer token forwarded as Authorization header. Falls back to "
            "HF_TOKEN / OPENAI_API_KEY / OLLAMA_API_KEY env vars on the "
            "server based on `base_url` if omitted."
        ),
    )
    temperature: float = Field(default=0.4, ge=0.0, le=2.0)
    max_tokens: int = Field(default=2048, ge=64, le=8192)
    request_timeout_s: float = Field(default=120.0, ge=5.0, le=600.0)


# A policy is "give me prompt messages, get back the assistant content".
LlmPolicy = Callable[[list[dict[str, str]]], str]
LlmPolicyFactory = Callable[[LlmStepRequest], LlmPolicy]


def resolve_api_key(request: LlmStepRequest) -> Optional[str]:
    """Pick the bearer token to use for this request.

    Browser-supplied keys win. When the browser sends nothing we fall
    back to a server-side env var picked from the URL — this lets a
    public Hugging Face Space ship a usable default by setting
    ``HF_TOKEN`` as a Space secret while still letting power users
    bring their own.
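
    A quick sketch (assumes ``HF_TOKEN`` is set on the server)::

        req = LlmStepRequest(
            base_url="https://router.huggingface.co/v1",
            model="Qwen/Qwen2.5-7B-Instruct",
        )
        token = resolve_api_key(req)  # -> os.environ["HF_TOKEN"]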
    """

    if request.api_key:
        return request.api_key

    base_url = (request.base_url or "").lower()
    # The PhysiX-Infer sister Space serves Qwen + the trained 3B with no
    # auth — it's open-access by design (rate-limited only by sleep).
    if "physix-infer" in base_url:
        return "physix-infer"
    if "huggingface" in base_url:
        return os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_API_KEY")
    if "openai.com" in base_url:
        return os.environ.get("OPENAI_API_KEY")
    if "openrouter" in base_url:
        return os.environ.get("OPENROUTER_API_KEY")
    if "localhost" in base_url or "127.0.0.1" in base_url:
        # Ollama doesn't require a key; the SDK still wants something
        # truthy in some versions, so we hand it a placeholder.
        return os.environ.get("OLLAMA_API_KEY", "ollama")
    return None


def default_openai_compat_policy_factory(request: LlmStepRequest) -> LlmPolicy:
    """Build a chat policy for any OpenAI-compatible endpoint. Raises HTTPException(502) on provider errors."""

    try:
        from openai import (  # type: ignore[import-not-found]
            APIConnectionError,
            APITimeoutError,
            AuthenticationError,
            BadRequestError,
            NotFoundError,
            OpenAI,
        )
    except ImportError as exc:  # pragma: no cover
        raise HTTPException(
            status_code=503,
            detail=(
                "The 'openai' Python package is not installed on the server. "
                "Install with: pip install -e '.[demo]'"
            ),
        ) from exc

    api_key = resolve_api_key(request)
    client = OpenAI(
        base_url=request.base_url,
        api_key=api_key or "missing",
        timeout=request.request_timeout_s,
        # Identifies us to providers that rate-limit by UA. Cheap to
        # send; helps the demo not look like a generic SDK probe.
        default_headers={"User-Agent": "physix-live-demo/0.1"},
    )

    # Captures the prompt for `_create` without re-parameterising it.
    prompt_holder: dict[str, list[dict[str, str]]] = {"prompt": []}

    def _create(*, with_json: bool):
        kwargs: dict[str, object] = {
            "model": request.model,
            "messages": prompt_holder["prompt"],
            "temperature": request.temperature,
            "max_tokens": request.max_tokens,
        }
        if with_json:
            # Encourages JSON output where supported (OpenAI, vLLM,
            # Ollama-OpenAI). The HF router silently ignores this on
            # providers that don't support it; providers that do reject
            # it land in the BadRequestError fallback below.
            kwargs["response_format"] = {"type": "json_object"}
        return client.chat.completions.create(**kwargs)  # type: ignore[arg-type]

    def _policy(prompt: list[dict[str, str]]) -> str:
        prompt_holder["prompt"] = prompt
        try:
            try:
                response = _create(with_json=True)
            except (BadRequestError, TypeError) as exc:
                # Provider rejected response_format (or the SDK shape).
                # Retry without it; if it still fails, that's a real
                # error and we let it bubble to the outer handler.
                _log.info(
                    "Retrying without response_format for %s: %s",
                    request.base_url,
                    exc,
                )
                response = _create(with_json=False)
        except AuthenticationError as exc:
            raise HTTPException(
                status_code=502,
                detail=_format_auth_error(request, exc),
            ) from exc
        except NotFoundError as exc:
            raise HTTPException(
                status_code=502,
                detail=_format_not_found_error(request, exc),
            ) from exc
        except (APIConnectionError, APITimeoutError) as exc:
            raise HTTPException(
                status_code=502,
                detail=_format_connection_error(request, exc),
            ) from exc
        except Exception as exc:  # noqa: BLE001 — last-resort UI surface
            raise HTTPException(
                status_code=502,
                detail=_format_provider_error(request, exc),
            ) from exc

        choice = response.choices[0] if response.choices else None
        content = (choice.message.content if choice and choice.message else "") or ""
        return str(content)

    return _policy
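

# A minimal sketch of wiring the factory into a FastAPI route (the route
# path, `app` object, and response shape are illustrative, not part of
# this module):
#
#     @app.post("/api/llm/step")
#     def llm_step(req: LlmStepRequest) -> dict[str, str]:
#         policy = default_openai_compat_policy_factory(req)
#         content = policy([{"role": "user", "content": "..."}])
#         return {"content": content}  # 502/503 HTTPExceptions propagate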


def _format_provider_error(request: LlmStepRequest, exc: Exception) -> str:
    """Last-resort error formatter for unclassified provider failures.

    Most failures should land in one of the typed handlers below
    (`_format_auth_error`, `_format_not_found_error`,
    `_format_connection_error`). This is the catch-all when an
    OpenAI-compatible endpoint returns something we don't recognise.
    The string-matching here exists only because the test suite
    exercises this path directly without going through the full
    SDK exception hierarchy.
    """

    base_msg = f"Chat completion failed via {request.base_url} for model {request.model!r}: {exc}"
    text = str(exc).lower()
    # HF Router 400s for unservable models land here (they're
    # ``BadRequestError`` from the SDK, so they bypass NotFoundError).
    # Detect both wordings the router emits.
    if "model_not_supported" in text or "is not supported by any provider" in text or "not supported by provider" in text:
        return _format_model_not_supported_error(request, exc)
    if "401" in text or "unauthorized" in text or "invalid api key" in text:
        return _format_auth_error(request, exc)
    if "404" in text or "not found" in text or "no such model" in text:
        return _format_not_found_error(request, exc)
    if "connection" in text or "refused" in text or "timeout" in text:
        return _format_connection_error(request, exc)
    return base_msg


def _format_model_not_supported_error(
    request: LlmStepRequest, exc: Exception
) -> str:
    """The HF Router accepted the request but no enabled provider serves
    this model. The user-actionable fix depends on whether they own the
    model card or not, so we offer both paths."""

    base = (
        f"HF Router can't serve {request.model!r}: no inference provider "
        f"is enabled for this model. ({exc})"
    )
    return (
        f"{base}\n\n"
        "Hint: open the model card at "
        f"https://huggingface.co/{request.model.split(':')[0]} → "
        "'Inference Providers' panel. If it lists no providers, the model "
        "isn't routable yet. Two ways to fix:\n"
        "  • Pick a model that already serves through the router — e.g. "
        "Qwen/Qwen2.5-7B-Instruct, meta-llama/Llama-3.3-70B-Instruct, "
        "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B.\n"
        "  • For your own models, deploy them via the model card's "
        "'Deploy → Inference Endpoints' button (paid) or run the weights "
        "locally with Ollama / vLLM and switch this connection to that "
        "endpoint."
    )


def _format_auth_error(request: LlmStepRequest, exc: Exception) -> str:
    base = (
        f"Authentication failed at {request.base_url} for model "
        f"{request.model!r}: {exc}"
    )
    if "huggingface" in request.base_url.lower():
        return (
            f"{base}\n\n"
            "Hint: HF Router needs a token with the 'Make calls to "
            "Inference Providers' fine-grained permission. Re-create the "
            "token at https://huggingface.co/settings/tokens with that "
            "scope checked, then paste it into the API key field."
        )
    return (
        f"{base}\n\n"
        "Hint: the API key is missing or rejected. Open the connection "
        "panel and paste a valid token, or set the matching env var on "
        "the server (HF_TOKEN, OPENAI_API_KEY, etc.)."
    )


def _format_not_found_error(request: LlmStepRequest, exc: Exception) -> str:
    base = (
        f"Model not found at {request.base_url}: {request.model!r} "
        f"({exc})"
    )
    if "huggingface" in request.base_url.lower():
        return (
            f"{base}\n\n"
            "Hint: the model isn't being served by any HF Inference "
            "Provider right now. Check the model card at "
            f"https://huggingface.co/{request.model.split(':')[0]} — "
            "the 'Deploy → Inference API' panel must list at least one "
            "provider as 'warm'. You can also append ':fastest' to the "
            "model id to let HF auto-pick a provider."
        )
    if "11434" in request.base_url:
        return (
            f"{base}\n\n"
            "Hint: that Ollama tag isn't pulled. Run "
            f"'ollama pull {request.model}' and retry."
        )
    return base


def _format_connection_error(request: LlmStepRequest, exc: Exception) -> str:
    base = f"Could not reach {request.base_url}: {exc}"
    if "11434" in request.base_url:
        return (
            f"{base}\n\n"
            "Hint: 'ollama serve' isn't running on this machine. Start "
            "it in another terminal and retry."
        )
    return (
        f"{base}\n\n"
        "Hint: the endpoint isn't reachable from the server. Check the "
        "URL, your network, and any firewall in front of it."
    )


# -----------------------------------------------------------------------
# Ollama-only model lister (kept for the local-dev convenience dropdown).
# -----------------------------------------------------------------------


class LlmModelInfo(BaseModel):
    """A single locally-pulled Ollama model tag."""

    model_config = ConfigDict(frozen=True)

    name: str
    size_bytes: Optional[int] = None
    parameter_size: Optional[str] = None
    family: Optional[str] = None


class LlmModelsResponse(BaseModel):
    models: list[LlmModelInfo] = Field(default_factory=list)
    error: Optional[str] = None


LlmModelsLister = Callable[[], LlmModelsResponse]


def default_ollama_models_lister() -> LlmModelsResponse:
    """Enumerate locally-pulled Ollama tags. Best-effort."""

    try:
        import ollama  # type: ignore[import-not-found]
    except ImportError:
        return LlmModelsResponse(
            models=[],
            error=(
                "The 'ollama' Python package is not installed on the server. "
                "Install with: pip install -e '.[demo]'"
            ),
        )

    try:
        response = ollama.Client().list()
    except Exception as exc:  # noqa: BLE001 — surfaced in the response body
        return LlmModelsResponse(
            models=[],
            error=(
                f"Could not reach the local Ollama daemon ({exc}). "
                "Is 'ollama serve' running?"
            ),
        )

    raw_models = getattr(response, "models", None)
    if raw_models is None and isinstance(response, dict):
        raw_models = response.get("models", [])
    raw_models = raw_models or []

    out: list[LlmModelInfo] = []
    for entry in raw_models:
        name = _model_attr(entry, "model") or _model_attr(entry, "name")
        if not isinstance(name, str) or not name:
            continue
        details = _model_attr(entry, "details")
        out.append(
            LlmModelInfo(
                name=name,
                size_bytes=_coerce_int(_model_attr(entry, "size")),
                parameter_size=_model_attr(details, "parameter_size"),
                family=_model_attr(details, "family"),
            )
        )

    out.sort(key=lambda m: m.name)
    return LlmModelsResponse(models=out)
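

# A minimal sketch of consuming the lister (illustrative; needs a running
# local Ollama daemon):
#
#     resp = default_ollama_models_lister()
#     if resp.error:
#         _log.warning("Ollama unavailable: %s", resp.error)
#     else:
#         names = [m.name for m in resp.models]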


def _model_attr(obj: object, key: str) -> object:
    if obj is None:
        return None
    if isinstance(obj, dict):
        return obj.get(key)
    return getattr(obj, key, None)


def _coerce_int(value: object) -> Optional[int]:
    if value is None:
        return None
    try:
        return int(value)
    except (TypeError, ValueError):
        return None