| import pytest |
| import time |
| from utils import * |
|
|
# Module-level handle used by every test in this file. It is assigned here so
# the name exists at import time, and rebound by the fixture below.
server = ServerPreset.tinyllama2()


@pytest.fixture(autouse=True)
def create_server():
    """Rebind the module-level ``server`` to a fresh tinyllama2 preset.

    The fixture is function-scoped on purpose: tests in this file assign
    ``server.n_slots``, ``server.n_batch`` and ``server.temperature`` before
    calling ``server.start()``. With a module-scoped preset those settings
    would carry over into every test that runs afterwards.

    NOTE(review): this assumes the test harness stops the started server
    between tests (not visible in this file) -- confirm.
    """
    global server
    server = ServerPreset.tinyllama2()
|
|
@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [
    ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False),
    ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False),
])
def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool):
    """Check a non-streaming ``/completion`` request.

    Verifies the HTTP status, the prompt/predicted token counts reported under
    ``timings``, the ``truncated`` flag, and that the generated content matches
    ``re_content``.
    """
    global server
    server.start()
    payload = {
        "n_predict": n_predict,
        "prompt": prompt,
    }
    response = server.make_request("POST", "/completion", data=payload)
    assert response.status_code == 200
    body = response.body
    timings = body["timings"]
    assert timings["prompt_n"] == n_prompt
    assert timings["predicted_n"] == n_predicted
    assert body["truncated"] == truncated
    assert match_regex(re_content, body["content"])
|
|
|
|
@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [
    ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False),
    ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False),
])
def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool):
    """Check a streaming ``/completion`` request.

    Content from every chunk whose ``"stop"`` field is falsy is accumulated.
    When a chunk with a truthy ``"stop"`` arrives, its ``timings`` and
    ``truncated`` fields and the content accumulated so far are checked.
    The test fails if no such chunk is ever received.
    """
    global server
    server.start()
    res = server.make_stream_request("POST", "/completion", data={
        "n_predict": n_predict,
        "prompt": prompt,
        "stream": True,
    })
    pieces = []
    saw_stop = False
    for data in res:
        if data["stop"]:
            saw_stop = True
            assert data["timings"]["prompt_n"] == n_prompt
            assert data["timings"]["predicted_n"] == n_predicted
            assert data["truncated"] == truncated
            assert match_regex(re_content, "".join(pieces))
        else:
            pieces.append(data["content"])
    # Without this check a stream that ends before any "stop" chunk would
    # pass without a single assertion having run.
    assert saw_stop, "stream ended without a chunk whose 'stop' field is set"
|
|
|
|
@pytest.mark.parametrize("n_slots", [1, 2])
def test_consistent_result_same_seed(n_slots: int):
    """Send the same seeded request four times and expect identical content.

    Each response is compared with the one received immediately before it.
    """
    global server
    server.n_slots = n_slots
    server.start()
    previous = None
    for attempt in range(4):
        current = server.make_request("POST", "/completion", data={
            "prompt": "I believe the meaning of life is",
            "seed": 42,
            "temperature": 1.0,
            "cache_prompt": False,
        })
        # The first response has nothing to be compared against.
        if attempt > 0:
            assert current.body["content"] == previous.body["content"]
        previous = current
|
|
|
|
@pytest.mark.parametrize("n_slots", [1, 2])
def test_different_result_different_seed(n_slots: int):
    """Send the same prompt with seeds 0..3 and expect the content to change.

    Each response is compared with the one received immediately before it.
    """
    global server
    server.n_slots = n_slots
    server.start()
    previous = None
    for seed in range(4):
        current = server.make_request("POST", "/completion", data={
            "prompt": "I believe the meaning of life is",
            "seed": seed,
            "temperature": 1.0,
            "cache_prompt": False,
        })
        if previous is None:
            # Nothing to compare the very first response with.
            previous = current
            continue
        assert current.body["content"] != previous.body["content"]
        previous = current
|
|
|
|
@pytest.mark.parametrize("n_batch", [16, 32])
@pytest.mark.parametrize("temperature", [0.0, 1.0])
def test_consistent_result_different_batch_size(n_batch: int, temperature: float):
    """Start the server with ``n_batch`` and expect four identical completions.

    The same seeded request is issued four times; each response is compared
    with the one received immediately before it.
    """
    global server
    server.n_batch = n_batch
    server.start()

    def request_completion():
        # One seeded, non-cached completion at the parametrized temperature.
        return server.make_request("POST", "/completion", data={
            "prompt": "I believe the meaning of life is",
            "seed": 42,
            "temperature": temperature,
            "cache_prompt": False,
        })

    previous = None
    for attempt in range(4):
        current = request_completion()
        if attempt > 0:
            assert current.body["content"] == previous.body["content"]
        previous = current
|
|
|
|
@pytest.mark.skip(reason="This test fails on linux, need to be fixed")
def test_cache_vs_nocache_prompt():
    """Expect the same content with ``cache_prompt`` enabled and disabled.

    The cached request is sent first, then the non-cached one, both with the
    same prompt, seed and temperature.
    """
    global server
    server.start()

    def complete(cache_prompt: bool):
        # Identical request apart from the cache_prompt flag.
        return server.make_request("POST", "/completion", data={
            "prompt": "I believe the meaning of life is",
            "seed": 42,
            "temperature": 1.0,
            "cache_prompt": cache_prompt,
        })

    with_cache = complete(True)
    without_cache = complete(False)
    assert with_cache.body["content"] == without_cache.body["content"]
|
|
|
|
def test_completion_with_tokens_input():
    """Exercise ``/completion`` with token-id prompts and mixed prompts.

    The prompt string is first tokenized through ``/tokenize``; the resulting
    token list is then sent to ``/completion`` in several shapes.
    """
    global server
    server.temperature = 0.0
    server.start()
    prompt_str = "I believe the meaning of life is"
    res = server.make_request("POST", "/tokenize", data={
        "content": prompt_str,
        "add_special": True,
    })
    assert res.status_code == 200
    tokens = res.body["tokens"]

    # Prompt given as a single list of tokens: one completion object.
    res = server.make_request("POST", "/completion", data={
        "prompt": tokens,
    })
    assert res.status_code == 200
    assert isinstance(res.body["content"], str)

    # Prompt given as two token lists: a list of two completions that must agree.
    res = server.make_request("POST", "/completion", data={
        "prompt": [tokens, tokens],
    })
    assert res.status_code == 200
    assert isinstance(res.body, list)
    assert len(res.body) == 2
    assert res.body[0]["content"] == res.body[1]["content"]

    # Prompt given as a token list plus the equivalent string: both
    # completions must agree.
    res = server.make_request("POST", "/completion", data={
        "prompt": [tokens, prompt_str],
    })
    assert res.status_code == 200
    assert isinstance(res.body, list)
    assert len(res.body) == 2
    assert res.body[0]["content"] == res.body[1]["content"]

    # Token ids and strings interleaved in one flat list: one completion object.
    res = server.make_request("POST", "/completion", data={
        "prompt": [1, 2, 3, 4, 5, 6, prompt_str, 7, 8, 9, 10, prompt_str],
    })
    assert res.status_code == 200
    assert isinstance(res.body["content"], str)
|
|
|
|
@pytest.mark.parametrize("n_slots,n_requests", [
    (1, 3),
    (2, 2),
    (2, 4),
    (4, 2),
    (4, 6),
])
def test_completion_parallel_slots(n_slots: int, n_requests: int):
    """Issue ``n_requests`` completions in parallel against ``n_slots`` slots.

    Alongside the completion requests, one extra task queries ``/slots`` after
    a short delay and checks how many slots report ``is_processing``. Every
    completion must then return status 200 with string content longer than
    10 characters.
    """
    global server
    server.n_slots = n_slots
    server.temperature = 0.0
    server.start()

    # (prompt, regex) pairs. Only the prompts are sent; the regex patterns are
    # not asserted anywhere in this test.
    PROMPTS = [
        ("Write a very long book.", "(very|special|big)+"),
        ("Write another a poem.", "(small|house)+"),
        ("What is LLM?", "(Dad|said)+"),
        ("The sky is blue and I love it.", "(climb|leaf)+"),
        ("Write another very long music lyrics.", "(friends|step|sky)+"),
        ("Write a very long joke.", "(cat|Whiskers)+"),
    ]

    def check_slots_status():
        """Check the number of busy slots shortly after the requests start."""
        should_all_slots_busy = n_requests >= n_slots
        # Give the completion requests a moment to be picked up by the server.
        time.sleep(0.1)
        res = server.make_request("GET", "/slots")
        n_busy = sum(1 for slot in res.body if slot["is_processing"])
        if should_all_slots_busy:
            assert n_busy == n_slots
        else:
            assert n_busy <= n_slots

    tasks = []
    for i in range(n_requests):
        # Cycle through the prompts when there are more requests than prompts.
        prompt = PROMPTS[i % len(PROMPTS)][0]
        tasks.append((server.make_request, ("POST", "/completion", {
            "prompt": prompt,
            "seed": 42,
            "temperature": 1.0,
        })))
    tasks.append((check_slots_status, ()))
    results = parallel_function_calls(tasks)

    # Only the first n_requests entries are completion responses; the trailing
    # entry belongs to check_slots_status.
    # NOTE(review): assumes parallel_function_calls returns results in task
    # order -- confirm against the helper.
    for i in range(n_requests):
        res = results[i]
        assert res.status_code == 200
        assert isinstance(res.body["content"], str)
        assert len(res.body["content"]) > 10
| |
| |
|
|