# physix-live — physix/server/providers.py
# Last change: "cleanup: strip verbose comments from physix/server/providers.py"
# (commit a88dae7, verified; author: Pratyush-01)
"""LLM provider abstraction for the interactive demo.
The demo points at any OpenAI-compatible ``/v1/chat/completions`` endpoint:
local Ollama, Hugging Face's Inference Providers router, OpenAI itself,
vLLM, OpenRouter, etc. Everything funnels through one factory so the UI
only has to learn one shape.
The browser passes ``base_url``, ``model``, and (optionally) ``api_key``
on every request. If ``api_key`` is missing we fall back to a per-provider
env var so a Hugging Face Space can ship a default working config without
hard-coding secrets in client bundles.
"""
from __future__ import annotations
import logging
import os
from collections.abc import Callable
from typing import Optional
from fastapi import HTTPException
from pydantic import BaseModel, ConfigDict, Field
# Module-level logger; used for the response_format retry notice below.
_log = logging.getLogger(__name__)
# Well-known OpenAI-compatible ``/v1`` endpoints the demo can point at.
HF_ROUTER_BASE_URL = "https://router.huggingface.co/v1"
OPENAI_BASE_URL = "https://api.openai.com/v1"
# Ollama's OpenAI-compatibility shim on its default local port (11434).
OLLAMA_OPENAI_BASE_URL = "http://localhost:11434/v1"
# Sister Space serving models with no auth (see resolve_api_key below).
PHYSIX_INFER_BASE_URL = "https://pratyush-01-physix-infer.hf.space/v1"
class LlmStepRequest(BaseModel):
    """Provider-agnostic payload for one LLM step.

    The browser supplies a base URL, a model id, and an optional bearer
    token; the server turns them into an ``openai.OpenAI`` client.
    ``base_url`` deliberately has no default: silently falling back to
    the wrong endpoint after a mid-session provider swap would be worse
    than a validation error.
    """

    # Unknown keys are a client bug — surface them as validation errors
    # rather than dropping them silently.
    model_config = ConfigDict(extra="forbid")

    base_url: str = Field(
        description="OpenAI-compatible /v1 base URL. E.g. http://localhost:11434/v1, https://router.huggingface.co/v1, https://api.openai.com/v1.",
    )
    model: str = Field(
        description="Model id understood by the chosen base URL. For HF this is the repo id (optionally suffixed with :provider, e.g. ':fastest'); for Ollama it's the local tag; for OpenAI it's the model name.",
    )
    api_key: Optional[str] = Field(
        default=None,
        description="Bearer token forwarded as Authorization header. Falls back to HF_TOKEN / OPENAI_API_KEY / OLLAMA_API_KEY env vars on the server based on `base_url` if omitted.",
    )
    # Sampling and transport knobs, bounded to sane demo ranges.
    temperature: float = Field(default=0.4, ge=0.0, le=2.0)
    max_tokens: int = Field(default=2048, ge=64, le=8192)
    request_timeout_s: float = Field(default=120.0, ge=5.0, le=600.0)
# A policy is "give me prompt messages, get back the assistant content".
# A factory builds one policy per request, so per-request settings
# (endpoint, key, sampling) live inside the returned closure.
LlmPolicy = Callable[[list[dict[str, str]]], str]
LlmPolicyFactory = Callable[[LlmStepRequest], LlmPolicy]
def resolve_api_key(request: LlmStepRequest) -> Optional[str]:
    """Pick the bearer token to use for this request.

    A key supplied by the browser always wins. Otherwise we choose a
    server-side env var keyed off the base URL, so a public Hugging
    Face Space can ship a working default (``HF_TOKEN`` as a Space
    secret) while power users still bring their own token.
    """
    if request.api_key:
        return request.api_key

    url = (request.base_url or "").lower()

    # Ordered (substring-needles, token-resolver) table; first match wins.
    resolvers: tuple[tuple[tuple[str, ...], Callable[[], Optional[str]]], ...] = (
        # The PhysiX-Infer sister Space serves Qwen + the trained 3B with
        # no auth — open-access by design, so any placeholder token works.
        (("physix-infer",), lambda: "physix-infer"),
        (("huggingface",), lambda: os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_API_KEY")),
        (("openai.com",), lambda: os.environ.get("OPENAI_API_KEY")),
        (("openrouter",), lambda: os.environ.get("OPENROUTER_API_KEY")),
        # Ollama doesn't require a key; some SDK versions still want a
        # truthy value, so fall back to a placeholder.
        (("localhost", "127.0.0.1"), lambda: os.environ.get("OLLAMA_API_KEY", "ollama")),
    )
    for needles, resolve in resolvers:
        if any(needle in url for needle in needles):
            return resolve()
    return None
def default_openai_compat_policy_factory(request: LlmStepRequest) -> LlmPolicy:
    """Build a chat policy for any OpenAI-compatible endpoint.

    Args:
        request: Connection + sampling settings supplied by the browser.

    Returns:
        An ``LlmPolicy`` mapping prompt messages to assistant content.

    Raises:
        HTTPException: 503 when the ``openai`` SDK is missing; 502
            (raised from the returned policy) on provider errors.
    """
    try:
        from openai import (  # type: ignore[import-not-found]
            APIConnectionError,
            APITimeoutError,
            AuthenticationError,
            BadRequestError,
            NotFoundError,
            OpenAI,
        )
    except ImportError as exc:  # pragma: no cover
        raise HTTPException(
            status_code=503,
            detail=(
                "The 'openai' Python package is not installed on the server. "
                "Install with: pip install -e '.[demo]'"
            ),
        ) from exc

    api_key = resolve_api_key(request)
    client = OpenAI(
        base_url=request.base_url,
        api_key=api_key or "missing",
        timeout=request.request_timeout_s,
        # Identifies us to providers that rate-limit by UA. Cheap to
        # send; helps the demo not look like a generic SDK probe.
        default_headers={"User-Agent": "physix-live-demo/0.1"},
    )

    # The prompt is passed explicitly instead of being smuggled through
    # a shared mutable holder — same behavior, no hidden closure state.
    def _create(prompt: list[dict[str, str]], *, with_json: bool):
        """Issue one chat.completions call for *prompt*."""
        kwargs: dict[str, object] = {
            "model": request.model,
            "messages": prompt,
            "temperature": request.temperature,
            "max_tokens": request.max_tokens,
        }
        if with_json:
            # Encourages JSON output where supported (OpenAI, vLLM,
            # Ollama-OpenAI). HF router silently ignores this on
            # providers that don't support it; ones that do reject
            # land in the BadRequestError fallback below.
            kwargs["response_format"] = {"type": "json_object"}
        return client.chat.completions.create(**kwargs)  # type: ignore[arg-type]

    def _policy(prompt: list[dict[str, str]]) -> str:
        try:
            try:
                response = _create(prompt, with_json=True)
            except (BadRequestError, TypeError) as exc:
                # Provider rejected response_format (or the SDK shape).
                # Retry without it; if it still fails, that's a real
                # error and we let it bubble to the outer handler.
                _log.info(
                    "Retrying without response_format for %s: %s",
                    request.base_url,
                    exc,
                )
                response = _create(prompt, with_json=False)
        except AuthenticationError as exc:
            raise HTTPException(
                status_code=502,
                detail=_format_auth_error(request, exc),
            ) from exc
        except NotFoundError as exc:
            raise HTTPException(
                status_code=502,
                detail=_format_not_found_error(request, exc),
            ) from exc
        except (APIConnectionError, APITimeoutError) as exc:
            raise HTTPException(
                status_code=502,
                detail=_format_connection_error(request, exc),
            ) from exc
        except Exception as exc:  # noqa: BLE001 — last-resort UI surface
            raise HTTPException(
                status_code=502,
                detail=_format_provider_error(request, exc),
            ) from exc

        # Defensive unpack: providers may return empty choices or a
        # null message content on filtered/empty completions.
        choice = response.choices[0] if response.choices else None
        content = (choice.message.content if choice and choice.message else "") or ""
        return str(content)

    return _policy
def _format_provider_error(request: LlmStepRequest, exc: Exception) -> str:
"""Last-resort error formatter for unclassified provider failures.
Most failures should land in one of the typed handlers below
(`_format_auth_error`, `_format_not_found_error`,
`_format_connection_error`). This is the catch-all when an
OpenAI-compatible endpoint returns something we don't recognise.
The string-matching here exists only because the test suite
exercises this path directly without going through the full
SDK exception hierarchy.
"""
base_msg = f"Chat completion failed via {request.base_url} for model {request.model!r}: {exc}"
text = str(exc).lower()
# HF Router 400s for unservable models land here (they're
# ``BadRequestError`` from the SDK, so they bypass NotFoundError).
# Detect both wordings the router emits.
if "model_not_supported" in text or "is not supported by any provider" in text or "not supported by provider" in text:
return _format_model_not_supported_error(request, exc)
if "401" in text or "unauthorized" in text or "invalid api key" in text:
return _format_auth_error(request, exc)
if "404" in text or "not found" in text or "no such model" in text:
return _format_not_found_error(request, exc)
if "connection" in text or "refused" in text or "timeout" in text:
return _format_connection_error(request, exc)
return base_msg
def _format_model_not_supported_error(
request: LlmStepRequest, exc: Exception
) -> str:
"""The HF Router accepted the request but no enabled provider serves
this model. The user-actionable fix depends on whether they own the
model card or not, so we offer both paths."""
base = (
f"HF Router can't serve {request.model!r}: no inference provider "
f"is enabled for this model. ({exc})"
)
return (
f"{base}\n\n"
"Hint: open the model card at "
f"https://huggingface.co/{request.model.split(':')[0]} → "
"'Inference Providers' panel. If it lists no providers, the model "
"isn't routable yet. Two ways to fix:\n"
" • Pick a model that already serves through the router — e.g. "
"Qwen/Qwen2.5-7B-Instruct, meta-llama/Llama-3.3-70B-Instruct, "
"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B.\n"
" • For your own models, deploy them via the model card's "
"'Deploy → Inference Endpoints' button (paid) or run the weights "
"locally with Ollama / vLLM and switch this connection to that "
"endpoint."
)
def _format_auth_error(request: LlmStepRequest, exc: Exception) -> str:
base = (
f"Authentication failed at {request.base_url} for model "
f"{request.model!r}: {exc}"
)
if "huggingface" in request.base_url.lower():
return (
f"{base}\n\n"
"Hint: HF Router needs a token with the 'Make calls to "
"Inference Providers' fine-grained permission. Re-create the "
"token at https://huggingface.co/settings/tokens with that "
"scope checked, then paste it into the API key field."
)
return (
f"{base}\n\n"
"Hint: the API key is missing or rejected. Open the connection "
"panel and paste a valid token, or set the matching env var on "
"the server (HF_TOKEN, OPENAI_API_KEY, etc.)."
)
def _format_not_found_error(request: LlmStepRequest, exc: Exception) -> str:
base = (
f"Model not found at {request.base_url}: {request.model!r} "
f"({exc})"
)
if "huggingface" in request.base_url.lower():
return (
f"{base}\n\n"
"Hint: the model isn't being served by any HF Inference "
"Provider right now. Check the model card at "
f"https://huggingface.co/{request.model.split(':')[0]} — "
"the 'Deploy → Inference API' panel must list at least one "
"provider as 'warm'. You can also append ':fastest' to the "
"model id to let HF auto-pick a provider."
)
if "11434" in request.base_url:
return (
f"{base}\n\n"
"Hint: that Ollama tag isn't pulled. Run "
f"'ollama pull {request.model}' and retry."
)
return base
def _format_connection_error(request: LlmStepRequest, exc: Exception) -> str:
base = f"Could not reach {request.base_url}: {exc}"
if "11434" in request.base_url:
return (
f"{base}\n\n"
"Hint: 'ollama serve' isn't running on this machine. Start "
"it in another terminal and retry."
)
return (
f"{base}\n\n"
"Hint: the endpoint isn't reachable from the server. Check the "
"URL, your network, and any firewall in front of it."
)
# -----------------------------------------------------------------------
# Ollama-only model lister (kept for the local-dev convenience dropdown).
# -----------------------------------------------------------------------
class LlmModelInfo(BaseModel):
    """A single locally-pulled Ollama model tag."""
    # Frozen: entries are immutable snapshots of the daemon's listing.
    model_config = ConfigDict(frozen=True)
    # Tag name as reported by the Ollama daemon.
    name: str
    # Size on disk in bytes, when the daemon reports one.
    size_bytes: Optional[int] = None
    # Human-readable parameter count from the tag's details, if present.
    parameter_size: Optional[str] = None
    # Model family from the tag's details, if present.
    family: Optional[str] = None
class LlmModelsResponse(BaseModel):
    """Listing result; ``error`` carries a human-readable failure reason."""
    models: list[LlmModelInfo] = Field(default_factory=list)
    # Set instead of raising — the lister is best-effort by design.
    error: Optional[str] = None
# Signature of a model-listing hook (see default_ollama_models_lister).
LlmModelsLister = Callable[[], LlmModelsResponse]
def default_ollama_models_lister() -> LlmModelsResponse:
    """Enumerate locally-pulled Ollama tags. Best-effort.

    Failures (missing package, unreachable daemon) are reported via
    the ``error`` field rather than raised.
    """
    try:
        import ollama  # type: ignore[import-not-found]
    except ImportError:
        return LlmModelsResponse(
            models=[],
            error=(
                "The 'ollama' Python package is not installed on the server. "
                "Install with: pip install -e '.[demo]'"
            ),
        )

    try:
        listing = ollama.Client().list()
    except Exception as exc:  # noqa: BLE001 — surfaced in the response body
        return LlmModelsResponse(
            models=[],
            error=(
                f"Could not reach the local Ollama daemon ({exc}). "
                "Is 'ollama serve' running?"
            ),
        )

    # Newer clients return an object with a .models attribute; older
    # ones return a plain dict.
    entries = getattr(listing, "models", None)
    if entries is None and isinstance(listing, dict):
        entries = listing.get("models", [])

    infos: list[LlmModelInfo] = []
    for entry in entries or []:
        tag = _model_attr(entry, "model") or _model_attr(entry, "name")
        if not isinstance(tag, str) or not tag:
            # Skip malformed entries rather than failing the whole list.
            continue
        details = _model_attr(entry, "details")
        infos.append(
            LlmModelInfo(
                name=tag,
                size_bytes=_coerce_int(_model_attr(entry, "size")),
                parameter_size=_model_attr(details, "parameter_size"),
                family=_model_attr(details, "family"),
            )
        )
    return LlmModelsResponse(models=sorted(infos, key=lambda info: info.name))
def _model_attr(obj: object, key: str) -> object:
if obj is None:
return None
if isinstance(obj, dict):
return obj.get(key)
return getattr(obj, key, None)
def _coerce_int(value: object) -> Optional[int]:
if value is None:
return None
try:
return int(value)
except (TypeError, ValueError):
return None