# physix/tests/test_interactive_api.py
"""End-to-end tests for the ``/interactive/*`` router."""
from __future__ import annotations
import json
from collections import deque
from collections.abc import Iterable
import pytest
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.testclient import TestClient
from openenv.core.env_server import create_fastapi_app
from physix.models import PhysiXAction, PhysiXObservation
from physix.server.app import build_app
from physix.server.environment import PhysiXEnvironment
from physix.server.interactive import (
LlmModelInfo,
LlmModelsResponse,
LlmStepRequest,
build_interactive_router,
)
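# The tests below exercise two app flavours: the production app from
# ``build_app()`` (via the ``client`` fixture) and a clone whose LLM policy
# is stubbed out (via ``_build_app_with_stubbed_llm``).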
@pytest.fixture
def client() -> TestClient:
return TestClient(build_app())
def _build_app_with_stubbed_llm(
completions: Iterable[str],
*,
models_response: LlmModelsResponse | None = None,
) -> FastAPI:
"""Build a clone of the production app whose LLM policy returns
pre-canned completion strings in order.
Each call to the policy pops the next completion off the deque, so a
test that wants three turns supplies three strings. Optionally
overrides the model lister so the ``/interactive/models`` route can
be exercised without touching the real Ollama daemon.
"""
    queue = deque(completions)
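    # The factory ignores the request payload: every policy instance pops
    # from the one shared deque, so completions are consumed in order
    # across the whole app.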
def _stub_policy(_payload: LlmStepRequest):
def _policy(_prompt: list[dict[str, str]]) -> str:
if not queue:
raise AssertionError("Stubbed LLM ran out of canned completions.")
            return queue.popleft()
return _policy
def _stub_lister() -> LlmModelsResponse:
return models_response or LlmModelsResponse(models=[])
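    # Mirror the production wiring: the core env app, CORS for the Vite dev
    # origin, then the interactive router with the stubs injected.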
app = create_fastapi_app(
env=PhysiXEnvironment,
action_cls=PhysiXAction,
observation_cls=PhysiXObservation,
)
app.add_middleware(
CORSMiddleware,
allow_origins=["http://localhost:5173"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.include_router(
build_interactive_router(
policy_factory=_stub_policy,
models_lister=_stub_lister,
)
)
return app
# --- Catalogue ---
def test_systems_endpoint_returns_supported_systems_in_order(
client: TestClient,
) -> None:
from physix.systems.registry import SUPPORTED_SYSTEMS
response = client.get("/interactive/systems")
assert response.status_code == 200
catalogue = response.json()
returned_ids = [row["system_id"] for row in catalogue]
assert returned_ids == list(SUPPORTED_SYSTEMS)
    # The demo intentionally exposes all registered systems, including
    # tier-3 (``projectile_drag``, ``charged_b_field``), so visitors can
    # stress-test the verifier on systems the model never trained on;
    # that's the generalisation showcase.
system_ids = set(returned_ids)
assert "free_fall" in system_ids
assert "damped_spring" in system_ids
assert "projectile_drag" in system_ids
assert "charged_b_field" in system_ids
# --- Local model catalogue ---
def test_models_endpoint_returns_injected_list() -> None:
"""Frontend reads installed model tags from the server, not a hardcoded
list. The route must surface whatever the lister reports."""
canned = LlmModelsResponse(
models=[
LlmModelInfo(name="qwen2.5:7b", size_bytes=4_700_000_000, parameter_size="7.6B"),
LlmModelInfo(name="qwen2.5:1.5b-instruct", size_bytes=986_000_000),
]
)
app = _build_app_with_stubbed_llm([], models_response=canned)
with TestClient(app) as client:
response = client.get("/interactive/models")
assert response.status_code == 200, response.text
body = response.json()
assert body["error"] is None
assert [m["name"] for m in body["models"]] == [
"qwen2.5:7b",
"qwen2.5:1.5b-instruct",
]
assert body["models"][0]["parameter_size"] == "7.6B"
def test_models_endpoint_returns_empty_with_error_when_daemon_unavailable() -> None:
"""When Ollama is unreachable the route degrades to an empty list and
surfaces a human-readable hint, instead of 5xx-ing the page."""
canned = LlmModelsResponse(
models=[],
error="Could not reach the local Ollama daemon (test). Is 'ollama serve' running?",
)
app = _build_app_with_stubbed_llm([], models_response=canned)
with TestClient(app) as client:
response = client.get("/interactive/models")
assert response.status_code == 200
body = response.json()
assert body["models"] == []
assert "Ollama" in body["error"]
# --- Session lifecycle ---
def test_session_lifecycle_create_summary_delete(client: TestClient) -> None:
"""Create → reset observation → summary → delete → 404. The actual
advancing of turn counter / format scoring / predicted overlay
lives in the ``/llm-step`` tests below; this is the lifecycle
skeleton (the only flow the UI actually exercises now that the
manual ``/step`` route is gone)."""
create = client.post(
"/interactive/sessions",
json={"system_id": "free_fall", "seed": 42, "max_turns": 4},
)
assert create.status_code == 200, create.text
body = create.json()
session_id = body["session_id"]
assert isinstance(session_id, str) and session_id
assert body["system"]["system_id"] == "free_fall"
assert "tier" not in body["system"] # tier is dropped from the public schema
assert body["max_turns"] == 4
assert body["observation"]["turn"] == 0
assert body["observation"]["done"] is False
assert len(body["observation"]["trajectory"]) == 100
summary = client.get(f"/interactive/sessions/{session_id}").json()
assert summary["turn"] == 0
assert summary["max_turns"] == 4
assert summary["done"] is False
end = client.delete(f"/interactive/sessions/{session_id}")
assert end.status_code == 204
assert client.get(f"/interactive/sessions/{session_id}").status_code == 404
def test_unknown_system_id_returns_400(client: TestClient) -> None:
response = client.post(
"/interactive/sessions",
json={"system_id": "no_such_system"},
)
assert response.status_code == 400
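
# Minimal ``/llm-step`` request body. The stubbed policy never contacts the
# given ``base_url``/``model``; they only need to satisfy the request schema.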
_STUB_REQ = {"base_url": "http://stub/v1", "model": "stub"}
def test_unknown_session_id_returns_404() -> None:
"""Session-scoped routes return 404 for unknown ids, not 500."""
app = _build_app_with_stubbed_llm([])
with TestClient(app) as client:
response = client.post(
"/interactive/sessions/does-not-exist/llm-step",
json=_STUB_REQ,
)
assert response.status_code == 404
# --- LLM-step endpoint (with stubbed policy) ---
def test_llm_step_drives_a_turn_using_injected_policy() -> None:
"""The endpoint must call the policy, parse, step, and surface the raw."""
app = _build_app_with_stubbed_llm(
[json.dumps({"equation": "d2y/dt2 = -9.81", "rationale": "gravity"})]
)
with TestClient(app) as client:
create = client.post(
"/interactive/sessions",
json={"system_id": "free_fall", "seed": 0, "max_turns": 4},
).json()
session_id = create["session_id"]
response = client.post(
f"/interactive/sessions/{session_id}/llm-step",
json={
"base_url": "http://stub/v1",
"model": "stub:1.5b",
"temperature": 0.1,
"max_tokens": 64,
},
)
assert response.status_code == 200, response.text
body = response.json()
assert body["model"] == "stub:1.5b"
assert body["action"]["equation"] == "d2y/dt2 = -9.81"
assert body["action"]["rationale"] == "gravity"
assert body["observation"]["turn"] == 1
assert body["observation"]["reward_breakdown"]["match"] >= 0.9
assert body["predicted_trajectory"]
assert body["latency_s"] >= 0.0
assert "d2y/dt2" in body["raw_completion"]
def test_llm_step_runs_full_episode_with_three_canned_turns() -> None:
"""Multi-turn drive: each call pops the next completion, history grows."""
completions = [
json.dumps({"equation": "d2y/dt2 = -9.81", "rationale": "pure gravity"}),
json.dumps({
"equation": "d2y/dt2 = -9.81 + 0.1 * vy",
"rationale": "linear drag",
}),
json.dumps({
"equation": "d2y/dt2 = -9.81 + 0.05 * vy**2",
"rationale": "quadratic drag",
}),
]
app = _build_app_with_stubbed_llm(completions)
with TestClient(app) as client:
session_id = client.post(
"/interactive/sessions",
json={"system_id": "free_fall_drag", "seed": 42, "max_turns": 8},
).json()["session_id"]
bodies = []
for _ in range(3):
response = client.post(
f"/interactive/sessions/{session_id}/llm-step",
json=_STUB_REQ,
)
assert response.status_code == 200, response.text
bodies.append(response.json())
assert [b["action"]["equation"] for b in bodies] == [
"d2y/dt2 = -9.81",
"d2y/dt2 = -9.81 + 0.1 * vy",
"d2y/dt2 = -9.81 + 0.05 * vy**2",
]
assert [b["observation"]["turn"] for b in bodies] == [1, 2, 3]
# History accumulates across turns.
assert len(bodies[-1]["observation"]["history"]) == 3
def test_llm_step_handles_unparseable_completion_as_format_zero() -> None:
"""If the model emits junk, the env scores it format=0, no 500."""
app = _build_app_with_stubbed_llm(["I refuse to answer."])
with TestClient(app) as client:
session_id = client.post(
"/interactive/sessions",
json={"system_id": "simple_pendulum", "seed": 0, "max_turns": 4},
).json()["session_id"]
response = client.post(
f"/interactive/sessions/{session_id}/llm-step",
json=_STUB_REQ,
)
assert response.status_code == 200, response.text
body = response.json()
assert body["observation"]["reward_breakdown"]["format"] == 0.0
assert body["predicted_trajectory"] == []
assert body["raw_completion"] == "I refuse to answer."
def test_llm_step_after_budget_exhaustion_returns_409() -> None:
"""Once the env has consumed its budget, llm-step is rejected too."""
canned = [
json.dumps({"equation": "d2theta/dt2 = 0"}),
json.dumps({"equation": "d2theta/dt2 = 0"}),
]
app = _build_app_with_stubbed_llm(canned)
with TestClient(app) as client:
session_id = client.post(
"/interactive/sessions",
json={"system_id": "simple_pendulum", "seed": 1, "max_turns": 2},
).json()["session_id"]
for _ in range(2):
assert client.post(
f"/interactive/sessions/{session_id}/llm-step",
json=_STUB_REQ,
).status_code == 200
overflow = client.post(
f"/interactive/sessions/{session_id}/llm-step",
json=_STUB_REQ,
)
assert overflow.status_code == 409
# --- CORS ---
def test_cors_preflight_for_dev_origin(client: TestClient) -> None:
"""OPTIONS preflight from the Vite dev server is allowed."""
response = client.options(
"/interactive/sessions",
headers={
"Origin": "http://localhost:5173",
"Access-Control-Request-Method": "POST",
"Access-Control-Request-Headers": "content-type",
},
)
assert response.status_code in (200, 204), response.text
assert response.headers["access-control-allow-origin"] == "http://localhost:5173"