File size: 15,754 Bytes
0e24aff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
"""Hugging Face Router-focused provider tests.

The judges will overwhelmingly run the demo through HF Router, so this
module is the deepest coverage of any single provider. We:

  1. Pin the HF Router base URL constant the frontend depends on.
  2. Verify the OpenAI client is constructed with the right base URL,
     api key, and headers.
  3. Confirm the `response_format` quirk — providers that reject
     ``json_object`` cause us to retry without it, transparently.
  4. Exercise every error class the SDK can raise (401 / 404 /
     connection / timeout) and pin the human-readable hint copy
     surfaced to the UI.
  5. Drive a full episode end-to-end through ``/llm-step`` with a
     stubbed OpenAI client to prove the request shape that lands on
     the wire matches what the visitor configured in the panel.

The real-network smoke test lives in ``scripts/verify_hf_router.py``;
that one needs an HF_TOKEN and is intentionally not part of CI.
"""

from __future__ import annotations

import json
from typing import Any
from unittest.mock import MagicMock, patch

import openai
import pytest
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.testclient import TestClient
from openenv.core.env_server import create_fastapi_app

from physix.models import PhysiXAction, PhysiXObservation
from physix.server.environment import PhysiXEnvironment
from physix.server.interactive import build_interactive_router
from physix.server.providers import (
    HF_ROUTER_BASE_URL,
    LlmStepRequest,
    default_openai_compat_policy_factory,
)


# --- Sanity: the URL constant must be exactly what HF documents -----------


def test_hf_router_base_url_is_canonical() -> None:
    """Guard the single shared HF Router URL constant.

    The frontend, README, and provider all rely on this value, so a
    future move away from router.huggingface.co should be fixed at the
    constant itself rather than patched in several places."""
    expected_url = "https://router.huggingface.co/v1"
    assert HF_ROUTER_BASE_URL == expected_url


# --- Helpers ---------------------------------------------------------------


def _make_completion(content: str) -> Any:
    """Return a mock shaped like an `openai.types.chat.ChatCompletion`.

    Only ``choices[0].message.content`` is populated explicitly, which
    keeps the helper independent of the real Pydantic schema."""
    message = MagicMock(content=content)
    choice = MagicMock(message=message)
    fake_completion = MagicMock()
    fake_completion.choices = [choice]
    return fake_completion


def _request(
    api_key: str | None = "hf_test_token",
    model: str = "Qwen/Qwen2.5-3B-Instruct",
) -> LlmStepRequest:
    """Build an ``LlmStepRequest`` pointed at the HF Router base URL.

    Args:
        api_key: Token to place on the request; ``None`` omits a
            visitor-supplied key.
        model: Model identifier to place on the request.
    """
    fields: dict[str, Any] = {
        "base_url": HF_ROUTER_BASE_URL,
        "model": model,
        "api_key": api_key,
    }
    return LlmStepRequest(**fields)


# --- Client construction ---------------------------------------------------


def test_hf_router_client_is_constructed_with_visitor_token() -> None:
    """The OpenAI SDK must be initialised with exactly the visitor's
    token, the HF Router base URL, and our identifying User-Agent —
    those values are what HF authenticates against."""
    stub_client = MagicMock()
    stub_client.chat.completions.create.return_value = _make_completion("{}")

    # Extra keyword arguments to `patch` configure the replacement mock,
    # so instantiating the patched class hands back `stub_client`.
    with patch("openai.OpenAI", return_value=stub_client) as openai_cls:
        visitor_request = _request(api_key="hf_visitor_secret")
        run_policy = default_openai_compat_policy_factory(visitor_request)
        run_policy([{"role": "user", "content": "hello"}])

    openai_cls.assert_called_once()
    init_kwargs = openai_cls.call_args.kwargs
    assert init_kwargs["base_url"] == HF_ROUTER_BASE_URL
    assert init_kwargs["api_key"] == "hf_visitor_secret"
    # The User-Agent lets HF rate-limit us cleanly. Read the comment in
    # providers.py before removing it.
    user_agent = init_kwargs["default_headers"]["User-Agent"]
    assert user_agent.startswith("physix-live-demo/")


def test_hf_router_client_uses_hf_token_env_when_visitor_omits_key(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """With no visitor key on the request, the client must be built
    with the Space's `HF_TOKEN` secret. The SDK is patched, so nothing
    touches the network."""
    monkeypatch.delenv("HUGGINGFACE_API_KEY", raising=False)
    monkeypatch.setenv("HF_TOKEN", "hf_space_secret")

    stub_client = MagicMock()
    stub_client.chat.completions.create.return_value = _make_completion("{}")

    with patch("openai.OpenAI", return_value=stub_client) as openai_cls:
        keyless_request = _request(api_key=None)
        run_policy = default_openai_compat_policy_factory(keyless_request)
        run_policy([{"role": "user", "content": "hello"}])

    init_kwargs = openai_cls.call_args.kwargs
    assert init_kwargs["api_key"] == "hf_space_secret"


# --- response_format retry quirk ------------------------------------------


def test_hf_router_retries_without_response_format_on_bad_request() -> None:
    """Some providers (Featherless, Together, others) answer 400 to
    `response_format={'type':'json_object'}`. The policy has to retry
    without that argument on its own; otherwise visitors routed to
    those providers get a one-turn 502 episode."""
    seen_kwargs: list[dict[str, Any]] = []
    good_payload = json.dumps({"equation": "d2y/dt2 = -9.81", "rationale": "g"})

    def fake_create(**kwargs: Any) -> Any:
        seen_kwargs.append(kwargs)
        if "response_format" not in kwargs:
            return _make_completion(good_payload)
        # Mimic a provider 400 with a genuine BadRequestError, since
        # that is the exception type the policy catches.
        raise openai.BadRequestError(
            message="response_format is not supported by provider",
            response=MagicMock(status_code=400),
            body=None,
        )

    stub_client = MagicMock()
    stub_client.chat.completions.create.side_effect = fake_create

    with patch("openai.OpenAI", return_value=stub_client):
        run_policy = default_openai_compat_policy_factory(_request())
        raw_reply = run_policy([{"role": "user", "content": "hello"}])

    parsed_reply = json.loads(raw_reply)
    assert parsed_reply["equation"] == "d2y/dt2 = -9.81"
    assert len(seen_kwargs) == 2, "must have retried"
    first_call, second_call = seen_kwargs
    assert "response_format" in first_call
    assert "response_format" not in second_call


def test_hf_router_succeeds_on_first_try_when_provider_supports_json() -> None:
    """A provider that accepts `response_format` must be called exactly
    once — a needless second request only adds latency to the demo."""
    reply = _make_completion(json.dumps({"equation": "d2y/dt2 = -9.81"}))
    stub_client = MagicMock()
    stub_client.chat.completions.create.return_value = reply

    with patch("openai.OpenAI", return_value=stub_client):
        run_policy = default_openai_compat_policy_factory(_request())
        run_policy([{"role": "user", "content": "hello"}])

    create_mock = stub_client.chat.completions.create
    assert create_mock.call_count == 1


# --- Error mapping ---------------------------------------------------------


def _make_err_response(status: int) -> MagicMock:
    """Return a mock HTTP response whose ``status_code`` is *status*.

    Used as the ``response`` argument when constructing the SDK error
    classes in the error-mapping tests."""
    # Keyword arguments to MagicMock become attributes on the new mock.
    return MagicMock(status_code=status)


def test_hf_router_401_surfaces_inference_providers_permission_hint() -> None:
    """The most common visitor failure is an HF token lacking the
    'Make calls to Inference Providers' fine-grained scope. The error
    surfaced for a 401 must name that scope and the settings page
    where it is fixed."""
    auth_error = openai.AuthenticationError(
        message="Invalid credentials in Authorization header",
        response=_make_err_response(401),
        body=None,
    )
    stub_client = MagicMock()
    stub_client.chat.completions.create.side_effect = auth_error

    with patch("openai.OpenAI", return_value=stub_client):
        run_policy = default_openai_compat_policy_factory(_request())
        with pytest.raises(Exception) as excinfo:
            run_policy([{"role": "user", "content": "hello"}])

    # Prefer the exception's `detail` attribute when it has one; fall
    # back to the exception's own string form otherwise.
    raised = excinfo.value
    detail = str(getattr(raised, "detail", raised))
    assert "Make calls to Inference Providers" in detail
    assert "huggingface.co/settings/tokens" in detail


def test_hf_router_404_surfaces_warm_provider_hint() -> None:
    """A 404 means no provider on the router currently serves the
    model. The hint has to make clear this is a model-card
    configuration problem rather than a problem with the token."""
    model_id = "Pratyush-01/physix-3b-rl"
    not_found = openai.NotFoundError(
        message="Model Pratyush-01/physix-3b-rl is not currently served",
        response=_make_err_response(404),
        body=None,
    )
    stub_client = MagicMock()
    stub_client.chat.completions.create.side_effect = not_found

    with patch("openai.OpenAI", return_value=stub_client):
        run_policy = default_openai_compat_policy_factory(_request(model=model_id))
        with pytest.raises(Exception) as excinfo:
            run_policy([{"role": "user", "content": "hello"}])

    # Prefer the exception's `detail` attribute when it has one; fall
    # back to the exception's own string form otherwise.
    raised = excinfo.value
    detail = str(getattr(raised, "detail", raised))
    for expected_fragment in (model_id, ":fastest", "huggingface.co/"):
        assert expected_fragment in detail


def test_hf_router_connection_failure_surfaces_network_hint() -> None:
    """A connection failure should reach the judges as a readable hint
    that names the base URL, not as a raw stack trace."""
    connection_error = openai.APIConnectionError(request=MagicMock())
    stub_client = MagicMock()
    stub_client.chat.completions.create.side_effect = connection_error

    with patch("openai.OpenAI", return_value=stub_client):
        run_policy = default_openai_compat_policy_factory(_request())
        with pytest.raises(Exception) as excinfo:
            run_policy([{"role": "user", "content": "hello"}])

    # Prefer the exception's `detail` attribute when it has one; fall
    # back to the exception's own string form otherwise.
    raised = excinfo.value
    detail = str(getattr(raised, "detail", raised))
    assert HF_ROUTER_BASE_URL in detail
    assert "Check the URL" in detail


def test_hf_router_timeout_surfaces_network_hint() -> None:
    """Cold providers can be slow enough on the first call to hit our
    default 120s timeout. The resulting error must name the base URL so
    the failure explains itself."""
    timeout_error = openai.APITimeoutError(request=MagicMock())
    stub_client = MagicMock()
    stub_client.chat.completions.create.side_effect = timeout_error

    with patch("openai.OpenAI", return_value=stub_client):
        run_policy = default_openai_compat_policy_factory(_request())
        with pytest.raises(Exception) as excinfo:
            run_policy([{"role": "user", "content": "hello"}])

    # Prefer the exception's `detail` attribute when it has one; fall
    # back to the exception's own string form otherwise.
    raised = excinfo.value
    detail = str(getattr(raised, "detail", raised))
    assert HF_ROUTER_BASE_URL in detail


# --- End-to-end through /interactive/sessions/{id}/llm-step ---------------


def test_full_episode_flows_visitor_config_to_openai_client_unchanged() -> None:
    """Round-trip a session through /interactive/sessions and /llm-step
    and check the visitor's panel settings reach the OpenAI SDK as typed.

    ``openai.OpenAI.__init__`` is replaced with a recorder, so both the
    constructor kwargs and the chat-completion kwargs can be inspected
    once the request has finished."""
    init_kwargs_seen: dict[str, Any] = {}
    create_kwargs_seen: list[dict[str, Any]] = []
    canned_reply = json.dumps(
        {"equation": "d2y/dt2 = -9.81", "rationale": "gravity"}
    )

    def record_create(**create_kwargs: Any) -> Any:
        create_kwargs_seen.append(create_kwargs)
        return _make_completion(canned_reply)

    def recording_init(self: Any, **init_kwargs: Any) -> None:  # noqa: ANN001
        init_kwargs_seen.update(init_kwargs)
        # The calling code looks up `.chat.completions.create`, so hang
        # a minimal stub chain off the client in place of the real one.
        completions_stub = MagicMock()
        completions_stub.create = record_create
        chat_stub = MagicMock()
        chat_stub.completions = completions_stub
        self.chat = chat_stub

    app = create_fastapi_app(
        env=PhysiXEnvironment,
        action_cls=PhysiXAction,
        observation_cls=PhysiXObservation,
    )
    cors_options: dict[str, Any] = {
        "allow_origins": ["http://localhost:5173"],
        "allow_credentials": True,
        "allow_methods": ["*"],
        "allow_headers": ["*"],
    }
    app.add_middleware(CORSMiddleware, **cors_options)
    app.include_router(build_interactive_router())

    session_payload = {"system_id": "free_fall", "seed": 7, "max_turns": 4}
    step_payload = {
        "base_url": HF_ROUTER_BASE_URL,
        "model": "Pratyush-01/physix-3b-rl",
        "api_key": "hf_visitor_token",
        "temperature": 0.3,
        "max_tokens": 512,
    }

    with patch.object(openai.OpenAI, "__init__", recording_init):
        with TestClient(app) as http:
            created = http.post("/interactive/sessions", json=session_payload)
            session_id = created.json()["session_id"]
            response = http.post(
                f"/interactive/sessions/{session_id}/llm-step",
                json=step_payload,
            )

    assert response.status_code == 200, response.text
    body = response.json()
    assert body["action"]["equation"] == "d2y/dt2 = -9.81"
    assert body["model"] == "Pratyush-01/physix-3b-rl"

    # 1) The OpenAI client was built with the HF Router configuration.
    assert init_kwargs_seen["base_url"] == HF_ROUTER_BASE_URL
    assert init_kwargs_seen["api_key"] == "hf_visitor_token"
    assert "User-Agent" in init_kwargs_seen["default_headers"]

    # 2) The single chat-completion call carried the visitor's model and
    #    sampling settings plus the system+user prompt the verifier expects.
    assert len(create_kwargs_seen) == 1
    (create_kwargs,) = create_kwargs_seen
    assert create_kwargs["model"] == "Pratyush-01/physix-3b-rl"
    assert create_kwargs["temperature"] == 0.3
    assert create_kwargs["max_tokens"] == 512
    messages = create_kwargs["messages"]
    assert isinstance(messages, list)
    assert messages[0]["role"] == "system"
    user_messages = [m for m in messages if m["role"] == "user"]
    assert user_messages
    # The prompt has to include the trajectory the env emitted, so the
    # judge sees the same scoring pipeline the training run used.
    prompt_text = user_messages[0]["content"]
    assert "TRAJECTORY" in prompt_text.upper() or "t=" in prompt_text


def test_full_episode_recovers_from_first_call_response_format_400() -> None:
    """End-to-end counterpart of the response_format retry test: the
    recovery must leave the /llm-step contract intact (HTTP 200 after
    exactly two upstream calls)."""
    attempts: list[dict[str, Any]] = []
    good_payload = json.dumps({"equation": "d2y/dt2 = -9.81"})

    def flaky_create(**create_kwargs: Any) -> Any:
        attempts.append(create_kwargs)
        if "response_format" not in create_kwargs:
            return _make_completion(good_payload)
        raise openai.BadRequestError(
            message="provider does not support response_format",
            response=MagicMock(status_code=400),
            body=None,
        )

    def stub_init(self: Any, **kwargs: Any) -> None:  # noqa: ANN001
        # Give the client just enough of `.chat.completions.create` for
        # the calling code to find it.
        completions_stub = MagicMock()
        completions_stub.create = flaky_create
        chat_stub = MagicMock()
        chat_stub.completions = completions_stub
        self.chat = chat_stub

    app = create_fastapi_app(
        env=PhysiXEnvironment,
        action_cls=PhysiXAction,
        observation_cls=PhysiXObservation,
    )
    cors_options: dict[str, Any] = {
        "allow_origins": ["http://localhost:5173"],
        "allow_credentials": True,
        "allow_methods": ["*"],
        "allow_headers": ["*"],
    }
    app.add_middleware(CORSMiddleware, **cors_options)
    app.include_router(build_interactive_router())

    session_payload = {"system_id": "free_fall", "seed": 0, "max_turns": 2}
    step_payload = {
        "base_url": HF_ROUTER_BASE_URL,
        "model": "Qwen/Qwen2.5-3B-Instruct",
        "api_key": "hf_x",
    }

    with patch.object(openai.OpenAI, "__init__", stub_init):
        with TestClient(app) as http:
            created = http.post("/interactive/sessions", json=session_payload)
            session_id = created.json()["session_id"]
            resp = http.post(
                f"/interactive/sessions/{session_id}/llm-step",
                json=step_payload,
            )

    assert resp.status_code == 200, resp.text
    assert len(attempts) == 2  # first call rejected, second accepted