| """LLM provider abstraction for the interactive demo. | |
| The demo points at any OpenAI-compatible ``/v1/chat/completions`` endpoint: | |
| local Ollama, Hugging Face's Inference Providers router, OpenAI itself, | |
| vLLM, OpenRouter, etc. Everything funnels through one factory so the UI | |
| only has to learn one shape. | |
| The browser passes ``base_url``, ``model``, and (optionally) ``api_key`` | |
| on every request. If ``api_key`` is missing we fall back to a per-provider | |
| env var so a Hugging Face Space can ship a default working config without | |
| hard-coding secrets in client bundles. | |
| """ | |
from __future__ import annotations

import logging
import os
from collections.abc import Callable
from typing import Optional

from fastapi import HTTPException
from pydantic import BaseModel, ConfigDict, Field

_log = logging.getLogger(__name__)

HF_ROUTER_BASE_URL = "https://router.huggingface.co/v1"
OPENAI_BASE_URL = "https://api.openai.com/v1"
OLLAMA_OPENAI_BASE_URL = "http://localhost:11434/v1"
PHYSIX_INFER_BASE_URL = "https://pratyush-01-physix-infer.hf.space/v1"


class LlmStepRequest(BaseModel):
    """Provider-agnostic step payload.

    The browser names a base URL + model + (optional) key. The server
    fans these into an ``openai.OpenAI`` client. ``base_url`` is required
    so we never silently default to the wrong endpoint when the user
    swaps providers mid-session.
    """

    model_config = ConfigDict(extra="forbid")

    base_url: str = Field(
        description=(
            "OpenAI-compatible /v1 base URL. E.g. http://localhost:11434/v1, "
            "https://router.huggingface.co/v1, https://api.openai.com/v1."
        ),
    )
    model: str = Field(
        description=(
            "Model id understood by the chosen base URL. For HF this is the "
            "repo id (optionally suffixed with :provider, e.g. ':fastest'); "
            "for Ollama it's the local tag; for OpenAI it's the model name."
        ),
    )
    api_key: Optional[str] = Field(
        default=None,
        description=(
            "Bearer token forwarded as the Authorization header. Falls back "
            "to HF_TOKEN / OPENAI_API_KEY / OLLAMA_API_KEY env vars on the "
            "server based on `base_url` if omitted."
        ),
    )
    temperature: float = Field(default=0.4, ge=0.0, le=2.0)
    max_tokens: int = Field(default=2048, ge=64, le=8192)
    request_timeout_s: float = Field(default=120.0, ge=5.0, le=600.0)
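
# For reference, one request body the browser might POST (a sketch with
# illustrative values, not defaults shipped anywhere):
#
#     {
#         "base_url": "https://router.huggingface.co/v1",
#         "model": "Qwen/Qwen2.5-7B-Instruct:fastest",
#         "api_key": "hf_...",  # optional; see resolve_api_key()
#         "temperature": 0.4,
#         "max_tokens": 2048
#     }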

# A policy is "give me prompt messages, get back the assistant content".
LlmPolicy = Callable[[list[dict[str, str]]], str]
LlmPolicyFactory = Callable[[LlmStepRequest], LlmPolicy]
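
# A minimal conforming policy, for illustration only (hypothetical; handy
# as a no-network test double):
#
#     def canned_policy(prompt: list[dict[str, str]]) -> str:
#         # Parrot the last user message back as the assistant content.
#         return prompt[-1]["content"] if prompt else ""
#
#     policy: LlmPolicy = canned_policy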


def resolve_api_key(request: LlmStepRequest) -> Optional[str]:
    """Pick the bearer token to use for this request.

    Browser-supplied keys win. When the browser sends nothing we fall
    back to a server-side env var picked from the URL — this lets a
    public Hugging Face Space ship a usable default by setting
    ``HF_TOKEN`` as a Space secret while still letting power users
    bring their own.
    """
    if request.api_key:
        return request.api_key
    base_url = (request.base_url or "").lower()
    # The PhysiX-Infer sister Space serves Qwen + the trained 3B with no
    # auth — it's open-access by design (rate-limited only by sleep).
    if "physix-infer" in base_url:
        return "physix-infer"
    if "huggingface" in base_url:
        return os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_API_KEY")
    if "openai.com" in base_url:
        return os.environ.get("OPENAI_API_KEY")
    if "openrouter" in base_url:
        return os.environ.get("OPENROUTER_API_KEY")
    if "localhost" in base_url or "127.0.0.1" in base_url:
        # Ollama doesn't require a key; the SDK still wants something
        # truthy in some versions, so we hand it a placeholder.
        return os.environ.get("OLLAMA_API_KEY", "ollama")
    return None
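
# Resolution order at a glance (sketch; assumes only HF_TOKEN is set in the
# server environment):
#
#     req = LlmStepRequest(base_url=HF_ROUTER_BASE_URL, model="Qwen/Qwen2.5-7B-Instruct")
#     resolve_api_key(req)  # -> value of HF_TOKEN
#     resolve_api_key(req.model_copy(update={"api_key": "hf_abc"}))  # -> "hf_abc"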


def default_openai_compat_policy_factory(request: LlmStepRequest) -> LlmPolicy:
    """Build a chat policy for any OpenAI-compatible endpoint.

    Raises HTTPException(502) on provider errors.
    """
    try:
        from openai import (  # type: ignore[import-not-found]
            APIConnectionError,
            APITimeoutError,
            AuthenticationError,
            BadRequestError,
            NotFoundError,
            OpenAI,
        )
    except ImportError as exc:  # pragma: no cover
        raise HTTPException(
            status_code=503,
            detail=(
                "The 'openai' Python package is not installed on the server. "
                "Install with: pip install -e '.[demo]'"
            ),
        ) from exc

    api_key = resolve_api_key(request)
    client = OpenAI(
        base_url=request.base_url,
        api_key=api_key or "missing",
        timeout=request.request_timeout_s,
        # Identifies us to providers that rate-limit by UA. Cheap to
        # send; helps the demo not look like a generic SDK probe.
        default_headers={"User-Agent": "physix-live-demo/0.1"},
    )

    # Captures the prompt for `_create` without re-parameterising it.
    prompt_holder: dict[str, list[dict[str, str]]] = {"prompt": []}

    def _create(*, with_json: bool):
        kwargs: dict[str, object] = {
            "model": request.model,
            "messages": prompt_holder["prompt"],
            "temperature": request.temperature,
            "max_tokens": request.max_tokens,
        }
        if with_json:
            # Encourages JSON output where supported (OpenAI, vLLM,
            # Ollama-OpenAI). The HF router silently ignores this on
            # providers that don't support it; providers that reject it
            # outright land in the BadRequestError fallback below.
            kwargs["response_format"] = {"type": "json_object"}
        return client.chat.completions.create(**kwargs)  # type: ignore[arg-type]

    def _policy(prompt: list[dict[str, str]]) -> str:
        prompt_holder["prompt"] = prompt
        try:
            try:
                response = _create(with_json=True)
            except (BadRequestError, TypeError) as exc:
                # Provider rejected response_format (or the SDK shape).
                # Retry without it; if it still fails, that's a real
                # error and we let it bubble to the outer handler.
                _log.info(
                    "Retrying without response_format for %s: %s",
                    request.base_url,
                    exc,
                )
                response = _create(with_json=False)
        except AuthenticationError as exc:
            raise HTTPException(
                status_code=502,
                detail=_format_auth_error(request, exc),
            ) from exc
        except NotFoundError as exc:
            raise HTTPException(
                status_code=502,
                detail=_format_not_found_error(request, exc),
            ) from exc
        except (APIConnectionError, APITimeoutError) as exc:
            raise HTTPException(
                status_code=502,
                detail=_format_connection_error(request, exc),
            ) from exc
        except Exception as exc:  # noqa: BLE001 — last-resort UI surface
            raise HTTPException(
                status_code=502,
                detail=_format_provider_error(request, exc),
            ) from exc
        choice = response.choices[0] if response.choices else None
        content = (choice.message.content if choice and choice.message else "") or ""
        return str(content)

    return _policy
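
# End-to-end sketch (endpoint and tag are illustrative; assumes a local
# Ollama serving with that model pulled):
#
#     request = LlmStepRequest(base_url=OLLAMA_OPENAI_BASE_URL, model="qwen2.5:7b")
#     policy = default_openai_compat_policy_factory(request)
#     reply = policy([{"role": "user", "content": 'Reply with {"ok": true}.'}])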


def _format_provider_error(request: LlmStepRequest, exc: Exception) -> str:
    """Last-resort error formatter for unclassified provider failures.

    Most failures should land in one of the typed handlers below
    (`_format_auth_error`, `_format_not_found_error`,
    `_format_connection_error`). This is the catch-all when an
    OpenAI-compatible endpoint returns something we don't recognise.
    The string matching here exists only because the test suite
    exercises this path directly without going through the full
    SDK exception hierarchy.
    """
    base_msg = (
        f"Chat completion failed via {request.base_url} "
        f"for model {request.model!r}: {exc}"
    )
    text = str(exc).lower()
    # HF Router 400s for unservable models land here (they're
    # ``BadRequestError`` from the SDK, so they bypass NotFoundError).
    # Detect both wordings the router emits.
    if (
        "model_not_supported" in text
        or "is not supported by any provider" in text
        or "not supported by provider" in text
    ):
        return _format_model_not_supported_error(request, exc)
    if "401" in text or "unauthorized" in text or "invalid api key" in text:
        return _format_auth_error(request, exc)
    if "404" in text or "not found" in text or "no such model" in text:
        return _format_not_found_error(request, exc)
    if "connection" in text or "refused" in text or "timeout" in text:
        return _format_connection_error(request, exc)
    return base_msg
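
# Classification sketch (hypothetical exception, mirroring how the tests
# drive this path without constructing real SDK errors):
#
#     _format_provider_error(request, RuntimeError("404: no such model"))
#     # -> delegates to _format_not_found_error via the substring checks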


def _format_model_not_supported_error(
    request: LlmStepRequest, exc: Exception
) -> str:
    """The HF Router accepted the request but no enabled provider serves
    this model. The user-actionable fix depends on whether they own the
    model card or not, so we offer both paths."""
    base = (
        f"HF Router can't serve {request.model!r}: no inference provider "
        f"is enabled for this model. ({exc})"
    )
    return (
        f"{base}\n\n"
        "Hint: open the model card at "
        f"https://huggingface.co/{request.model.split(':')[0]} → "
        "'Inference Providers' panel. If it lists no providers, the model "
        "isn't routable yet. Two ways to fix:\n"
        "  • Pick a model that already serves through the router — e.g. "
        "Qwen/Qwen2.5-7B-Instruct, meta-llama/Llama-3.3-70B-Instruct, "
        "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B.\n"
        "  • For your own models, deploy them via the model card's "
        "'Deploy → Inference Endpoints' button (paid) or run the weights "
        "locally with Ollama / vLLM and switch this connection to that "
        "endpoint."
    )


def _format_auth_error(request: LlmStepRequest, exc: Exception) -> str:
    base = (
        f"Authentication failed at {request.base_url} for model "
        f"{request.model!r}: {exc}"
    )
    if "huggingface" in request.base_url.lower():
        return (
            f"{base}\n\n"
            "Hint: HF Router needs a token with the 'Make calls to "
            "Inference Providers' fine-grained permission. Re-create the "
            "token at https://huggingface.co/settings/tokens with that "
            "scope checked, then paste it into the API key field."
        )
    return (
        f"{base}\n\n"
        "Hint: the API key is missing or rejected. Open the connection "
        "panel and paste a valid token, or set the matching env var on "
        "the server (HF_TOKEN, OPENAI_API_KEY, etc.)."
    )


def _format_not_found_error(request: LlmStepRequest, exc: Exception) -> str:
    base = (
        f"Model not found at {request.base_url}: {request.model!r} "
        f"({exc})"
    )
    if "huggingface" in request.base_url.lower():
        return (
            f"{base}\n\n"
            "Hint: the model isn't being served by any HF Inference "
            "Provider right now. Check the model card at "
            f"https://huggingface.co/{request.model.split(':')[0]} — "
            "the 'Deploy → Inference API' panel must list at least one "
            "provider as 'warm'. You can also append ':fastest' to the "
            "model id to let HF auto-pick a provider."
        )
    if "11434" in request.base_url:
        return (
            f"{base}\n\n"
            "Hint: that Ollama tag isn't pulled. Run "
            f"'ollama pull {request.model}' and retry."
        )
    return base


def _format_connection_error(request: LlmStepRequest, exc: Exception) -> str:
    base = f"Could not reach {request.base_url}: {exc}"
    if "11434" in request.base_url:
        return (
            f"{base}\n\n"
            "Hint: 'ollama serve' isn't running on this machine. Start "
            "it in another terminal and retry."
        )
    return (
        f"{base}\n\n"
        "Hint: the endpoint isn't reachable from the server. Check the "
        "URL, your network, and any firewall in front of it."
    )


# -----------------------------------------------------------------------
# Ollama-only model lister (kept for the local-dev convenience dropdown).
# -----------------------------------------------------------------------


class LlmModelInfo(BaseModel):
    """A single locally-pulled Ollama model tag."""

    model_config = ConfigDict(frozen=True)

    name: str
    size_bytes: Optional[int] = None
    parameter_size: Optional[str] = None
    family: Optional[str] = None


class LlmModelsResponse(BaseModel):
    models: list[LlmModelInfo] = Field(default_factory=list)
    error: Optional[str] = None


LlmModelsLister = Callable[[], LlmModelsResponse]


def default_ollama_models_lister() -> LlmModelsResponse:
    """Enumerate locally-pulled Ollama tags. Best-effort."""
    try:
        import ollama  # type: ignore[import-not-found]
    except ImportError:
        return LlmModelsResponse(
            models=[],
            error=(
                "The 'ollama' Python package is not installed on the server. "
                "Install with: pip install -e '.[demo]'"
            ),
        )
    try:
        response = ollama.Client().list()
    except Exception as exc:  # noqa: BLE001 — surfaced in the response body
        return LlmModelsResponse(
            models=[],
            error=(
                f"Could not reach the local Ollama daemon ({exc}). "
                "Is 'ollama serve' running?"
            ),
        )
    raw_models = getattr(response, "models", None)
    if raw_models is None and isinstance(response, dict):
        raw_models = response.get("models", [])
    raw_models = raw_models or []
    out: list[LlmModelInfo] = []
    for entry in raw_models:
        name = _model_attr(entry, "model") or _model_attr(entry, "name")
        if not isinstance(name, str) or not name:
            continue
        details = _model_attr(entry, "details")
        out.append(
            LlmModelInfo(
                name=name,
                size_bytes=_coerce_int(_model_attr(entry, "size")),
                parameter_size=_model_attr(details, "parameter_size"),
                family=_model_attr(details, "family"),
            )
        )
    out.sort(key=lambda m: m.name)
    return LlmModelsResponse(models=out)


def _model_attr(obj: object, key: str) -> object:
    if obj is None:
        return None
    if isinstance(obj, dict):
        return obj.get(key)
    return getattr(obj, key, None)


def _coerce_int(value: object) -> Optional[int]:
    if value is None:
        return None
    try:
        return int(value)
    except (TypeError, ValueError):
        return None
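

if __name__ == "__main__":
    # Hedged smoke test, not part of the demo server: print whatever Ollama
    # tags are pulled locally. Needs no API key and degrades to a printed
    # error when the daemon (or the 'ollama' package) is missing.
    logging.basicConfig(level=logging.INFO)
    listing = default_ollama_models_lister()
    if listing.error:
        print(listing.error)
    else:
        for info in listing.models:
            print(f"{info.name}  {info.parameter_size or ''}".rstrip())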