"""Tests for ``_post`` in ``recap.inference.mi300x_client``.

``httpx.post`` is replaced with local fakes in every test that would
otherwise reach the network, so no real HTTP traffic is generated.
"""

import httpx
import pytest

import recap.inference.mi300x_client as client


class _FakeResp:
    """Small stand-in for the response object returned by ``httpx.post``.

    Only the pieces used by these tests are provided: ``status_code``,
    ``raise_for_status()`` and ``json()``.
    """

    def __init__(self, status_code=200, json_data=None):
        self.status_code = status_code
        # Any falsy payload (``None`` included) is stored as an empty dict.
        self._json = json_data if json_data else {}

    def raise_for_status(self):
        """Raise ``httpx.HTTPStatusError`` when the status code is 400 or above."""
        if self.status_code < 400:
            return
        request = httpx.Request("POST", "http://x")
        response = httpx.Response(self.status_code)
        raise httpx.HTTPStatusError("err", request=request, response=response)

    def json(self):
        """Return the payload supplied at construction time."""
        return self._json


def test_raises_when_url_unset(monkeypatch):
    """Without ``RECAP_MI300X_URL`` set, ``_post`` raises a ``RuntimeError`` naming it."""
    monkeypatch.delenv("RECAP_MI300X_URL", raising=False)

    with pytest.raises(RuntimeError, match="RECAP_MI300X_URL"):
        client._post("medgemma", "sys", "user")


def test_posts_to_correct_url(monkeypatch):
    """``_post`` targets ``<base>/<model>``, sends both prompts, and returns ``text``."""
    captured = {}

    # Parameter names are kept as-is: the client may pass them by keyword.
    def recording_post(url, json, timeout):
        captured.update(url=url, json=json)
        return _FakeResp(200, {"text": "hello"})

    monkeypatch.setenv("RECAP_MI300X_URL", "https://example.test")
    monkeypatch.setattr(client.httpx, "post", recording_post)

    result = client._post("qwen", "sys-prompt", "user-prompt")

    assert result == "hello"
    assert captured["url"] == "https://example.test/qwen"
    assert captured["json"] == {"system": "sys-prompt", "user": "user-prompt"}


def test_retries_on_transport_errors(monkeypatch):
    """Two ``ConnectError`` failures followed by a success yield the successful text."""
    attempts = []

    def fail_twice_then_succeed(url, json, timeout):
        attempts.append(url)
        if len(attempts) < 3:
            raise httpx.ConnectError("boom")
        return _FakeResp(200, {"text": "ok"})

    monkeypatch.setenv("RECAP_MI300X_URL", "https://example.test")
    # No-op sleep so any waiting done by the client does not slow the test.
    monkeypatch.setattr(client.time, "sleep", lambda *_: None)
    monkeypatch.setattr(client.httpx, "post", fail_twice_then_succeed)

    result = client._post("medgemma", "s", "u")

    assert result == "ok"
    assert len(attempts) == 3


def test_gives_up_after_three_attempts(monkeypatch):
    """When every call raises ``ConnectError``, ``_post`` fails after exactly three tries."""
    attempts = []

    def never_succeeds(url, json, timeout):
        attempts.append(url)
        raise httpx.ConnectError("down")

    monkeypatch.setenv("RECAP_MI300X_URL", "https://example.test")
    # No-op sleep so any waiting done by the client does not slow the test.
    monkeypatch.setattr(client.time, "sleep", lambda *_: None)
    monkeypatch.setattr(client.httpx, "post", never_succeeds)

    with pytest.raises(RuntimeError, match="failed after 3 attempts"):
        client._post("medgemma", "s", "u")

    # Checked outside the ``raises`` block so the assertion actually runs.
    assert len(attempts) == 3


def test_strips_trailing_slash_from_url(monkeypatch):
    """A trailing ``/`` on the configured base URL does not produce ``//`` in the request URL."""
    captured = {}

    def recording_post(url, json, timeout):
        captured["url"] = url
        return _FakeResp(200, {"text": ""})

    monkeypatch.setenv("RECAP_MI300X_URL", "https://example.test/")
    monkeypatch.setattr(client.httpx, "post", recording_post)

    client._post("qwen", "s", "u")

    assert captured["url"] == "https://example.test/qwen"