| |
from types import SimpleNamespace

import pytest

from inference import chat_completion, stream_chat_completion
|
|
class DummyStream:
    """Minimal stand-in for a streaming response: an iterable over fixed chunks.

    Each iteration walks the stored chunk sequence from the start, so the
    stream can be consumed more than once.
    """

    def __init__(self, chunks):
        # Chunk objects handed back verbatim, in order, during iteration.
        self._chunks = chunks

    def __iter__(self):
        yield from self._chunks
|
|
class DummyClient:
    """Fake inference client exposing ``client.chat.completions.create(...)``.

    Non-streaming calls return an object whose
    ``choices[0].message.content`` is the ``response`` given at construction.
    Calls with a truthy ``stream`` keyword return a ``DummyStream`` of two
    chunks whose ``choices[0].delta.content`` values are ``"h"`` and ``"i"``.
    """

    def __init__(self, response):
        # Canned text returned by non-streaming ``create`` calls.
        self.response = response
        # Self-reference so ``client.chat.completions`` resolves back to this object.
        self.chat = self

    @property
    def completions(self):
        # Must be an attribute, not a plain method: callers go through
        # ``client.chat.completions.create``, and a method would produce a
        # bound method that has no ``create`` attribute.
        return self

    def __call__(self, **kwargs):
        # Keeps the legacy ``client.completions(**kwargs)`` call form working;
        # it previously returned the client itself.
        return self

    def create(self, **kwargs):
        """Return a canned completion, or a two-chunk stream when ``stream`` is truthy.

        All keyword arguments other than ``stream`` are accepted and ignored.
        """
        if kwargs.get("stream"):
            chunks = [
                SimpleNamespace(choices=[SimpleNamespace(delta=SimpleNamespace(content=text))])
                for text in ("h", "i")
            ]
            return DummyStream(chunks)
        message = SimpleNamespace(content=self.response)
        return SimpleNamespace(choices=[SimpleNamespace(message=message)])
|
|
@pytest.fixture(autouse=True)
def patch_client(monkeypatch):
    """Route every test in this module to a ``DummyClient("hello")``.

    Replaces ``get_inference_client`` with a fake factory for the duration of
    each test; monkeypatch restores the originals on teardown.
    """

    def fake(model_id, provider):
        return DummyClient("hello")

    monkeypatch.setattr("hf_client.get_inference_client", fake)
    # If ``inference`` did ``from hf_client import get_inference_client``, it
    # holds its own reference that the patch above does not reach, so patch
    # that name as well. ``raising=False`` avoids an error when ``inference``
    # has no such attribute (monkeypatch then adds it temporarily and removes
    # it on teardown). NOTE(review): confirm how ``inference`` looks up the
    # factory.
    monkeypatch.setattr("inference.get_inference_client", fake, raising=False)
|
|
def test_chat_completion():
    """A non-streaming completion returns the dummy client's canned reply."""
    messages = [{"role": "user", "content": "hi"}]
    result = chat_completion("any-model", messages)
    assert result == "hello"
|
|
def test_stream_chat_completion():
    """Streamed pieces concatenate to the dummy client's chunk texts."""
    messages = [{"role": "user", "content": "stream"}]
    pieces = stream_chat_completion("any-model", messages)
    assert "".join(pieces) == "hi"
|
|