| |
| ```python |
| import pytest |
| from models import AVAILABLE_MODELS, find_model, ModelInfo |
|
|
@pytest.mark.parametrize("identifier, expected_id", [
    # A non-id identifier that should resolve to a canonical model id.
    ("Moonshot Kimi-K2", "moonshotai/Kimi-K2-Instruct"),
    # Canonical ids should resolve to themselves.
    ("moonshotai/Kimi-K2-Instruct", "moonshotai/Kimi-K2-Instruct"),
    ("openai/gpt-4", "openai/gpt-4"),
])
def test_find_model(identifier, expected_id):
    """``find_model`` returns the ``ModelInfo`` whose ``id`` matches the lookup.

    Args:
        identifier: The string passed to ``find_model``.
        expected_id: The ``id`` the returned ``ModelInfo`` must carry.
    """
    model = find_model(identifier)
    assert isinstance(model, ModelInfo)
    assert model.id == expected_id
|
|
|
|
def test_find_model_not_found():
    """An identifier matching no known model makes ``find_model`` return ``None``."""
    result = find_model("nonexistent-model")
    assert result is None
|
|
|
|
def test_available_models_have_unique_ids():
    """No two entries in ``AVAILABLE_MODELS`` may share the same ``id``."""
    all_ids = [entry.id for entry in AVAILABLE_MODELS]
    distinct_ids = set(all_ids)
    # Any duplicate collapses in the set, making it shorter than the list.
    assert len(all_ids) == len(distinct_ids)
| ``` |
|
|
| |
| ```python |
| import pytest |
| from inference import chat_completion, stream_chat_completion |
| from models import ModelInfo |
|
|
class DummyClient:
    """Fake inference client that always answers with a fixed message.

    ``client.chat`` points back at the client itself, so a call of the form
    ``client.chat.completions(**kwargs)`` returns an object where
    ``resp.choices[0].message.content`` is the ``response`` given at
    construction.
    """

    def __init__(self, response):
        self.response = response
        # Self-reference so ``client.chat.completions`` resolves on this object.
        self.chat = self

    def completions(self, **kwargs):
        """Return a one-choice response carrying ``self.response``.

        All keyword arguments (model, messages, ...) are accepted and ignored.
        """
        message = type('M', (), {'content': self.response})
        choice = type('C', (), {'message': message})
        return type('R', (), {'choices': [choice]})
|
|
@pytest.fixture(autouse=True)
def patch_client(monkeypatch):
    """Swap the inference client factory for one that returns ``DummyClient``.

    Autouse, so every test in this module receives the fake whose completion
    text is ``"hello world"``.
    """
    def fake_client(model_id, provider=None):
        # ``DummyClient`` already wires ``chat`` to itself and exposes a callable
        # ``completions``; overwriting ``chat.completions`` with the client
        # instance (as before) left nothing callable and no ``create`` method.
        return DummyClient("hello world")

    monkeypatch.setattr('hf_client.get_inference_client', fake_client)
    # NOTE(review): if ``inference`` does ``from hf_client import
    # get_inference_client`` it keeps its own reference that the patch above
    # does not reach, so patch that name too. ``raising=False`` makes this a
    # harmless no-op-style assignment when ``inference`` has no such attribute.
    monkeypatch.setattr('inference.get_inference_client', fake_client, raising=False)
|
|
|
|
def test_chat_completion_returns_text():
    """``chat_completion`` hands back the fake client's canned text as a ``str``."""
    messages = [{'role': 'user', 'content': 'test'}]
    text = chat_completion('any-model', messages)
    assert isinstance(text, str)
    assert text == 'hello world'
|
|
|
|
def test_stream_chat_completion_yields_chunks(monkeypatch):
    """``stream_chat_completion`` yields each streamed delta's text, in order.

    ``monkeypatch`` must be requested as a fixture parameter; the previous
    version referenced it as an undefined global and raised ``NameError``.
    """

    def make_chunk(text):
        # Shape read by the code under test (assumed): chunk.choices[0].delta.content
        delta = type('D', (), {'content': text})
        choice = type('Ch', (), {'delta': delta})
        return type('C', (), {'choices': [choice]})

    class StreamClient(DummyClient):
        """Fake client whose completion call returns a two-chunk stream."""

        def completions(self, **kwargs):
            return iter([make_chunk('h'), make_chunk('i')])

    def fake_client(model_id, provider=None):
        return StreamClient(None)

    monkeypatch.setattr('hf_client.get_inference_client', fake_client)
    # NOTE(review): also patch the name as seen from ``inference`` in case it
    # imported the function directly; ``raising=False`` tolerates its absence.
    monkeypatch.setattr('inference.get_inference_client', fake_client, raising=False)

    msg = [{'role': 'user', 'content': 'stream'}]
    chunks = list(stream_chat_completion('any-model', msg))
    assert ''.join(chunks) == 'hi'
|
|