File size: 2,708 Bytes
ba54ea9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import httpx
import pytest

import recap.inference.mi300x_client as client


class _FakeResp:
    """Minimal stand-in for an ``httpx.Response`` returned by a stubbed ``httpx.post``.

    Exposes only ``status_code``, ``raise_for_status()`` and ``json()``.
    """

    def __init__(self, status_code=200, json_data=None):
        self.status_code = status_code
        # Default only when no payload was supplied, so falsy payloads such as
        # ``[]`` or ``""`` are returned unchanged instead of becoming ``{}``.
        self._json = {} if json_data is None else json_data

    def raise_for_status(self):
        """Raise ``httpx.HTTPStatusError`` when ``status_code`` is 400 or above."""
        if self.status_code >= 400:
            req = httpx.Request("POST", "http://x")
            # Bind the request to the response so ``err.response.request`` is
            # usable, as it is on an error raised by real httpx.
            resp = httpx.Response(self.status_code, request=req)
            raise httpx.HTTPStatusError("err", request=req, response=resp)

    def json(self):
        """Return the canned JSON payload."""
        return self._json


def test_raises_when_url_unset(monkeypatch):
    """``_post`` refuses to run when no MI300X base URL is configured."""
    # Remove the variable whether or not the developer's shell defines it.
    monkeypatch.delenv("RECAP_MI300X_URL", raising=False)
    expectation = pytest.raises(RuntimeError, match="RECAP_MI300X_URL")
    with expectation:
        client._post("medgemma", "sys", "user")


def test_posts_to_correct_url(monkeypatch):
    """``_post`` targets ``<base>/<model>`` and sends a system/user JSON body."""
    monkeypatch.setenv("RECAP_MI300X_URL", "https://example.test")
    captured = {}

    def fake_post(url, json, timeout):
        # Record the outgoing request so it can be inspected after the call.
        captured.update(url=url, json=json)
        return _FakeResp(200, {"text": "hello"})

    monkeypatch.setattr(client.httpx, "post", fake_post)
    result = client._post("qwen", "sys-prompt", "user-prompt")

    assert result == "hello"
    assert captured["url"] == "https://example.test/qwen"
    assert captured["json"] == {"system": "sys-prompt", "user": "user-prompt"}


def test_retries_on_transport_errors(monkeypatch):
    """Transport-level failures are retried until a request finally succeeds."""
    monkeypatch.setenv("RECAP_MI300X_URL", "https://example.test")
    # Make any backoff sleep a no-op so the test runs instantly.
    monkeypatch.setattr(client.time, "sleep", lambda *_: None)
    attempts = []

    def flaky_post(url, json, timeout):
        attempts.append(url)
        # The first two attempts fail to connect; the third one succeeds.
        if len(attempts) >= 3:
            return _FakeResp(200, {"text": "ok"})
        raise httpx.ConnectError("boom")

    monkeypatch.setattr(client.httpx, "post", flaky_post)
    result = client._post("medgemma", "s", "u")

    assert result == "ok"
    assert len(attempts) == 3


def test_gives_up_after_three_attempts(monkeypatch):
    """Persistent transport failures surface as ``RuntimeError`` after 3 tries."""
    monkeypatch.setenv("RECAP_MI300X_URL", "https://example.test")
    # Make any backoff sleep a no-op so the test runs instantly.
    monkeypatch.setattr(client.time, "sleep", lambda *_: None)
    attempts = []

    def always_fail(url, json, timeout):
        attempts.append(url)
        raise httpx.ConnectError("down")

    monkeypatch.setattr(client.httpx, "post", always_fail)
    expectation = pytest.raises(RuntimeError, match="failed after 3 attempts")
    with expectation:
        client._post("medgemma", "s", "u")

    assert len(attempts) == 3


def test_strips_trailing_slash_from_url(monkeypatch):
    """A trailing ``/`` on the base URL must not produce ``//`` in the target."""
    monkeypatch.setenv("RECAP_MI300X_URL", "https://example.test/")
    requested_urls = []

    def fake_post(url, json, timeout):
        requested_urls.append(url)
        return _FakeResp(200, {"text": ""})

    monkeypatch.setattr(client.httpx, "post", fake_post)
    client._post("qwen", "s", "u")

    # Inspect the most recent request, i.e. the one that produced the result.
    assert requested_urls[-1] == "https://example.test/qwen"