# physix / tests / test_providers_hf.py
# Pratyush-01's picture
# Upload folder using huggingface_hub
# 0e24aff verified
"""Hugging Face Router-focused provider tests.

The judges will overwhelmingly run the demo through HF Router, so this
module is the deepest coverage of any single provider. We:

1. Pin the HF Router base URL constant the frontend depends on.
2. Verify the OpenAI client is constructed with the right base URL,
   API key, and headers.
3. Confirm the ``response_format`` quirk — providers that reject
   ``json_object`` cause us to retry without it, transparently.
4. Exercise the SDK error classes we map to hints (401 / 404 /
   connection / timeout) and pin the human-readable hint copy
   surfaced to the UI.
5. Drive a full episode end-to-end through ``/llm-step`` with a
   stubbed OpenAI client to prove the request shape that lands on
   the wire matches what the visitor configured in the panel.

The real-network smoke test lives in ``scripts/verify_hf_router.py``;
that one needs an HF_TOKEN and is intentionally not part of CI.
"""
from __future__ import annotations
import json
from typing import Any
from unittest.mock import MagicMock, patch
import openai
import pytest
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.testclient import TestClient
from openenv.core.env_server import create_fastapi_app
from physix.models import PhysiXAction, PhysiXObservation
from physix.server.environment import PhysiXEnvironment
from physix.server.interactive import build_interactive_router
from physix.server.providers import (
HF_ROUTER_BASE_URL,
LlmStepRequest,
default_openai_compat_policy_factory,
)
# --- Sanity: the URL constant must be exactly what HF documents -----------
def test_hf_router_base_url_is_canonical() -> None:
    """Guard the shared HF Router URL constant against accidental drift.

    The frontend, README, and provider all rely on this one value. If HF
    ever moves off router.huggingface.co, this is the place to update it
    rather than hunting for copies across the codebase.
    """
    expected_url = "https://router.huggingface.co/v1"
    assert HF_ROUTER_BASE_URL == expected_url
# --- Helpers ---------------------------------------------------------------
def _make_completion(content: str) -> Any:
"""Build a stand-in for an `openai.types.chat.ChatCompletion` that
only exposes the bits our code reads. Simpler and faster than the
real Pydantic schema and avoids depending on its internals."""
completion = MagicMock()
completion.choices = [MagicMock(message=MagicMock(content=content))]
return completion
def _request(
    api_key: str | None = "hf_test_token",
    model: str = "Qwen/Qwen2.5-3B-Instruct",
) -> LlmStepRequest:
    """Build an ``LlmStepRequest`` that targets the HF Router base URL.

    Args:
        api_key: Token to place on the request; ``None`` omits the
            visitor-supplied key.
        model: Model identifier to place on the request.

    Returns:
        A request whose ``base_url`` is ``HF_ROUTER_BASE_URL``.
    """
    request_fields: dict[str, Any] = {
        "base_url": HF_ROUTER_BASE_URL,
        "model": model,
        "api_key": api_key,
    }
    return LlmStepRequest(**request_fields)
# --- Client construction ---------------------------------------------------
def test_hf_router_client_is_constructed_with_visitor_token() -> None:
    """Check the exact constructor arguments handed to the OpenAI SDK.

    These values are load-bearing: they are what HF authenticates
    against, so the visitor's token and the router URL must arrive
    untouched.
    """
    visitor_token = "hf_visitor_secret"
    chat_messages = [{"role": "user", "content": "hello"}]

    with patch("openai.OpenAI") as openai_cls:
        stub_client = MagicMock()
        stub_client.chat.completions.create.return_value = _make_completion("{}")
        openai_cls.return_value = stub_client

        run_policy = default_openai_compat_policy_factory(
            _request(api_key=visitor_token)
        )
        run_policy(chat_messages)

    openai_cls.assert_called_once()
    init_kwargs = openai_cls.call_args.kwargs
    assert init_kwargs["base_url"] == HF_ROUTER_BASE_URL
    assert init_kwargs["api_key"] == visitor_token
    # User-Agent helps HF rate-limit us cleanly. Don't drop it without
    # reading the comment in providers.py.
    user_agent = init_kwargs["default_headers"]["User-Agent"]
    assert user_agent.startswith("physix-live-demo/")
def test_hf_router_client_uses_hf_token_env_when_visitor_omits_key(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Check the ``HF_TOKEN`` environment fallback without network access.

    The Space's ``HF_TOKEN`` secret is the safety net when the visitor
    sends no key of their own. ``HUGGINGFACE_API_KEY`` is removed so only
    ``HF_TOKEN`` can supply the value asserted below.
    """
    space_token = "hf_space_secret"
    monkeypatch.setenv("HF_TOKEN", space_token)
    monkeypatch.delenv("HUGGINGFACE_API_KEY", raising=False)

    with patch("openai.OpenAI") as openai_cls:
        stub_client = MagicMock()
        stub_client.chat.completions.create.return_value = _make_completion("{}")
        openai_cls.return_value = stub_client

        run_policy = default_openai_compat_policy_factory(_request(api_key=None))
        run_policy([{"role": "user", "content": "hello"}])

    used_key = openai_cls.call_args.kwargs["api_key"]
    assert used_key == space_token
# --- response_format retry quirk ------------------------------------------
def test_hf_router_retries_without_response_format_on_bad_request() -> None:
    """Check the transparent retry when ``response_format`` is rejected.

    Featherless / Together / others sometimes answer 400 to
    ``response_format={'type':'json_object'}``. The policy must retry
    without it; otherwise every visitor pointed at those providers gets
    a one-turn 502 episode.
    """
    seen_kwargs: list[dict[str, Any]] = []
    success_payload = json.dumps({"equation": "d2y/dt2 = -9.81", "rationale": "g"})

    def fake_create(**kwargs: Any) -> Any:
        seen_kwargs.append(kwargs)
        if "response_format" not in kwargs:
            return _make_completion(success_payload)
        # Simulate a real provider 400. A genuine BadRequestError is
        # raised because that's what the policy catches.
        rejected = MagicMock()
        rejected.status_code = 400
        raise openai.BadRequestError(
            message="response_format is not supported by provider",
            response=rejected,
            body=None,
        )

    with patch("openai.OpenAI") as openai_cls:
        stub_client = MagicMock()
        stub_client.chat.completions.create.side_effect = fake_create
        openai_cls.return_value = stub_client

        run_policy = default_openai_compat_policy_factory(_request())
        raw_output = run_policy([{"role": "user", "content": "hello"}])

    parsed = json.loads(raw_output)
    assert parsed["equation"] == "d2y/dt2 = -9.81"
    assert len(seen_kwargs) == 2, "must have retried"
    first_attempt, second_attempt = seen_kwargs
    assert "response_format" in first_attempt
    assert "response_format" not in second_attempt
def test_hf_router_succeeds_on_first_try_when_provider_supports_json() -> None:
    """Check that no second call is made when ``response_format`` is accepted.

    A needless retry is wasted latency, which makes for an ugly demo.
    """
    accepted_payload = json.dumps({"equation": "d2y/dt2 = -9.81"})

    with patch("openai.OpenAI") as openai_cls:
        stub_client = MagicMock()
        create_mock = stub_client.chat.completions.create
        create_mock.return_value = _make_completion(accepted_payload)
        openai_cls.return_value = stub_client

        run_policy = default_openai_compat_policy_factory(_request())
        run_policy([{"role": "user", "content": "hello"}])

    assert create_mock.call_count == 1
# --- Error mapping ---------------------------------------------------------
def _make_err_response(status: int) -> MagicMock:
response = MagicMock()
response.status_code = status
return response
def test_hf_router_401_surfaces_inference_providers_permission_hint() -> None:
    """Check the hint shown when the router rejects the token with a 401.

    This is the single most common visitor failure: an HF token without
    the 'Make calls to Inference Providers' fine-grained scope. The hint
    must point at the exact remediation.
    """
    auth_failure = openai.AuthenticationError(
        message="Invalid credentials in Authorization header",
        response=_make_err_response(401),
        body=None,
    )

    with patch("openai.OpenAI") as openai_cls:
        stub_client = MagicMock()
        stub_client.chat.completions.create.side_effect = auth_failure
        openai_cls.return_value = stub_client

        run_policy = default_openai_compat_policy_factory(_request())
        with pytest.raises(Exception) as excinfo:
            run_policy([{"role": "user", "content": "hello"}])

    # Prefer the exception's ``detail`` attribute when it has one;
    # otherwise fall back to the exception's own string form.
    raised = excinfo.value
    detail = str(getattr(raised, "detail", raised))
    assert "Make calls to Inference Providers" in detail
    assert "huggingface.co/settings/tokens" in detail
def test_hf_router_404_surfaces_warm_provider_hint() -> None:
    """Check the hint shown when no provider currently serves the model.

    The visitor needs to learn that a 404 from the router is a
    model-card-config issue, not a problem with their token.
    """
    model_id = "Pratyush-01/physix-3b-rl"
    not_served = openai.NotFoundError(
        message="Model Pratyush-01/physix-3b-rl is not currently served",
        response=_make_err_response(404),
        body=None,
    )

    with patch("openai.OpenAI") as openai_cls:
        stub_client = MagicMock()
        stub_client.chat.completions.create.side_effect = not_served
        openai_cls.return_value = stub_client

        run_policy = default_openai_compat_policy_factory(_request(model=model_id))
        with pytest.raises(Exception) as excinfo:
            run_policy([{"role": "user", "content": "hello"}])

    # Prefer the exception's ``detail`` attribute when it has one;
    # otherwise fall back to the exception's own string form.
    raised = excinfo.value
    detail = str(getattr(raised, "detail", raised))
    for expected_fragment in (model_id, ":fastest", "huggingface.co/"):
        assert expected_fragment in detail
def test_hf_router_connection_failure_surfaces_network_hint() -> None:
    """Check the hint shown when the SDK cannot reach the router.

    Network blips happen — judges deserve a clear hint instead of a raw
    stack trace.
    """
    connection_failure = openai.APIConnectionError(request=MagicMock())

    with patch("openai.OpenAI") as openai_cls:
        stub_client = MagicMock()
        stub_client.chat.completions.create.side_effect = connection_failure
        openai_cls.return_value = stub_client

        run_policy = default_openai_compat_policy_factory(_request())
        with pytest.raises(Exception) as excinfo:
            run_policy([{"role": "user", "content": "hello"}])

    # Prefer the exception's ``detail`` attribute when it has one;
    # otherwise fall back to the exception's own string form.
    raised = excinfo.value
    detail = str(getattr(raised, "detail", raised))
    assert HF_ROUTER_BASE_URL in detail
    assert "Check the URL" in detail
def test_hf_router_timeout_surfaces_network_hint() -> None:
    """Check that a timeout error names the base URL that was being called.

    First-call latency on cold providers can hit our default 120s
    timeout, so the failure should be self-diagnosing.
    """
    timeout_failure = openai.APITimeoutError(request=MagicMock())

    with patch("openai.OpenAI") as openai_cls:
        stub_client = MagicMock()
        stub_client.chat.completions.create.side_effect = timeout_failure
        openai_cls.return_value = stub_client

        run_policy = default_openai_compat_policy_factory(_request())
        with pytest.raises(Exception) as excinfo:
            run_policy([{"role": "user", "content": "hello"}])

    # Prefer the exception's ``detail`` attribute when it has one;
    # otherwise fall back to the exception's own string form.
    raised = excinfo.value
    detail = str(getattr(raised, "detail", raised))
    assert HF_ROUTER_BASE_URL in detail
# --- End-to-end through /interactive/sessions/{id}/llm-step ---------------
def test_full_episode_flows_visitor_config_to_openai_client_unchanged() -> None:
    """Drive a session through ``/llm-step`` and inspect what reaches OpenAI.

    The whole point of the abstraction: whatever the visitor types in the
    panel must arrive at the OpenAI SDK unchanged. The test creates a
    session, posts one ``/llm-step`` request, and then asserts on both
    the ``OpenAI()`` constructor kwargs and the chat-completion kwargs
    that were captured by the stubbed client.
    """
    captured_init: dict[str, Any] = {}
    captured_calls: list[dict[str, Any]] = []

    def fake_init(self: Any, **kwargs: Any) -> None:  # noqa: ANN001
        """Stand in for ``OpenAI.__init__`` and record its kwargs."""
        captured_init.update(kwargs)
        # We still need a real-ish object so the calling code finds
        # `.chat.completions.create`. We monkey-patch a minimal stub.
        self.chat = MagicMock()
        self.chat.completions = MagicMock()

        def _create(**ckwargs: Any) -> Any:
            captured_calls.append(ckwargs)
            return _make_completion(
                json.dumps({"equation": "d2y/dt2 = -9.81", "rationale": "gravity"})
            )

        self.chat.completions.create = _create

    app = create_fastapi_app(
        env=PhysiXEnvironment,
        action_cls=PhysiXAction,
        observation_cls=PhysiXObservation,
    )
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["http://localhost:5173"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    app.include_router(build_interactive_router())

    with patch.object(openai.OpenAI, "__init__", fake_init):
        with TestClient(app) as client:
            created = client.post(
                "/interactive/sessions",
                json={"system_id": "free_fall", "seed": 7, "max_turns": 4},
            )
            # Fail with the server's own message if the session could not
            # be created, instead of an opaque KeyError on "session_id".
            assert 200 <= created.status_code < 300, created.text
            session_id = created.json()["session_id"]

            response = client.post(
                f"/interactive/sessions/{session_id}/llm-step",
                json={
                    "base_url": HF_ROUTER_BASE_URL,
                    "model": "Pratyush-01/physix-3b-rl",
                    "api_key": "hf_visitor_token",
                    "temperature": 0.3,
                    "max_tokens": 512,
                },
            )

    assert response.status_code == 200, response.text
    body = response.json()
    assert body["action"]["equation"] == "d2y/dt2 = -9.81"
    assert body["model"] == "Pratyush-01/physix-3b-rl"

    # 1) OpenAI client was built with HF Router config.
    assert captured_init["base_url"] == HF_ROUTER_BASE_URL
    assert captured_init["api_key"] == "hf_visitor_token"
    assert "User-Agent" in captured_init["default_headers"]

    # 2) The chat completion call carried the visitor's model + the
    #    full system+user prompt the verifier expects.
    assert len(captured_calls) == 1
    call = captured_calls[0]
    assert call["model"] == "Pratyush-01/physix-3b-rl"
    assert call["temperature"] == 0.3
    assert call["max_tokens"] == 512

    messages = call["messages"]
    assert isinstance(messages, list)
    assert messages[0]["role"] == "system"
    assert any(m["role"] == "user" for m in messages)

    # Prompt must contain the trajectory the env emitted, so the
    # judge sees the same scoring pipeline the training run used.
    user_msg = next(m for m in messages if m["role"] == "user")
    assert "TRAJECTORY" in user_msg["content"].upper() or "t=" in user_msg["content"]
def test_full_episode_recovers_from_first_call_response_format_400() -> None:
    """End-to-end version of the ``response_format`` retry test.

    The stubbed client raises a 400 ``BadRequestError`` whenever
    ``response_format`` is passed and succeeds otherwise. The test
    confirms that this recovery still yields a 200 from ``/llm-step``
    and that exactly two completion calls were made.
    """
    call_count = {"n": 0}

    def fake_init(self: Any, **kwargs: Any) -> None:  # noqa: ANN001
        """Stand in for ``OpenAI.__init__`` and install the stub client."""
        self.chat = MagicMock()
        self.chat.completions = MagicMock()

        def _create(**ckwargs: Any) -> Any:
            call_count["n"] += 1
            if "response_format" in ckwargs:
                err_response = MagicMock()
                err_response.status_code = 400
                raise openai.BadRequestError(
                    message="provider does not support response_format",
                    response=err_response,
                    body=None,
                )
            return _make_completion(json.dumps({"equation": "d2y/dt2 = -9.81"}))

        self.chat.completions.create = _create

    app = create_fastapi_app(
        env=PhysiXEnvironment,
        action_cls=PhysiXAction,
        observation_cls=PhysiXObservation,
    )
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["http://localhost:5173"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    app.include_router(build_interactive_router())

    with patch.object(openai.OpenAI, "__init__", fake_init):
        with TestClient(app) as client:
            created = client.post(
                "/interactive/sessions",
                json={"system_id": "free_fall", "seed": 0, "max_turns": 2},
            )
            # Fail with the server's own message if the session could not
            # be created, instead of an opaque KeyError on "session_id".
            assert 200 <= created.status_code < 300, created.text
            session_id = created.json()["session_id"]

            resp = client.post(
                f"/interactive/sessions/{session_id}/llm-step",
                json={
                    "base_url": HF_ROUTER_BASE_URL,
                    "model": "Qwen/Qwen2.5-3B-Instruct",
                    "api_key": "hf_x",
                },
            )

    assert resp.status_code == 200, resp.text
    assert call_count["n"] == 2  # first call rejected, second accepted