import pytest
from unittest.mock import patch, MagicMock, ANY, AsyncMock
import json
import tenacity
import asyncio
from openai import (
    OpenAIError,
    APIError,
    RateLimitError,
    APIConnectionError,
    AsyncOpenAI,
)
from openai.types.chat import ChatCompletion
from openai.types.chat.chat_completion import Choice as ChatCompletionChoice
from openai.types.chat.chat_completion_message import ChatCompletionMessage
|
|
from ankigen_core.llm_interface import (
    OpenAIClientManager,
    structured_output_completion,
    process_crawled_page,
    process_crawled_pages,
)
from ankigen_core.utils import ResponseCache
from ankigen_core.models import CrawledPage, AnkiCardData
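# Tests for ankigen_core.llm_interface: the OpenAIClientManager lifecycle,
# structured_output_completion (caching, retries, malformed responses), and
# the crawled-page processing helpers.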
|
|
|
|
|
|
@pytest.mark.anyio
async def test_client_manager_init():
    """Test initial state of the client manager."""
    manager = OpenAIClientManager()
    assert manager._client is None
    assert manager._api_key is None
|
|
|
|
@pytest.mark.anyio
async def test_client_manager_initialize_success():
    """Test successful client initialization."""
    manager = OpenAIClientManager()
    valid_key = "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"

    with patch(
        "ankigen_core.llm_interface.AsyncOpenAI"
    ) as mock_async_openai_constructor:
        await manager.initialize_client(valid_key)
        mock_async_openai_constructor.assert_called_once_with(api_key=valid_key)
        assert manager.get_client() is not None
|
|
|
|
@pytest.mark.anyio
async def test_client_manager_initialize_invalid_key_format():
    """Test initialization failure with an invalid API key format."""
    manager = OpenAIClientManager()
    invalid_key = "invalid-key-format"
    with pytest.raises(ValueError, match="Invalid OpenAI API key format."):
        await manager.initialize_client(invalid_key)
    assert manager._client is None
    assert manager._api_key is None
|
|
|
|
@pytest.mark.anyio
async def test_client_manager_initialize_openai_error():
    """Test handling of OpenAIError during client initialization."""
    manager = OpenAIClientManager()
    valid_key = "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
    error_message = "Test OpenAI Init Error"

    with patch(
        "ankigen_core.llm_interface.AsyncOpenAI", side_effect=OpenAIError(error_message)
    ) as mock_async_openai_constructor:
        with pytest.raises(OpenAIError, match=error_message):
            await manager.initialize_client(valid_key)
        mock_async_openai_constructor.assert_called_once_with(api_key=valid_key)
|
|
|
|
@pytest.mark.anyio
async def test_client_manager_get_client_success():
    """Test getting the client after successful initialization."""
    manager = OpenAIClientManager()
    valid_key = "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
    with patch(
        "ankigen_core.llm_interface.AsyncOpenAI"
    ) as mock_async_openai_constructor:
        mock_instance = mock_async_openai_constructor.return_value
        await manager.initialize_client(valid_key)
        assert manager.get_client() == mock_instance
|
|
|
|
def test_client_manager_get_client_not_initialized():
    """Test getting the client before initialization."""
    manager = OpenAIClientManager()
    with pytest.raises(RuntimeError, match="OpenAI client is not initialized."):
        manager.get_client()
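# Usage pattern exercised above (the assumed manager API, inferred from the
# assertions rather than from the implementation itself):
#
#     manager = OpenAIClientManager()
#     await manager.initialize_client("sk-...")  # validates the key format
#     client = manager.get_client()  # raises RuntimeError before initialization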
|
|
|
|
@pytest.fixture
def mock_openai_client():
    client = MagicMock(spec=AsyncOpenAI)
    client.chat = AsyncMock()
    client.chat.completions = AsyncMock()
    client.chat.completions.create = AsyncMock()
    mock_chat_completion_response = create_mock_chat_completion(
        json.dumps([{"data": "mocked success"}])
    )
    client.chat.completions.create.return_value = mock_chat_completion_response
    return client
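# Note: create is an AsyncMock so the tests below can use await-aware
# assertions such as assert_awaited_once() and assert_not_awaited().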
|
|
|
|
@pytest.fixture
def mock_response_cache():
    cache = MagicMock(spec=ResponseCache)
    return cache
|
|
|
|
@pytest.mark.anyio
async def test_structured_output_completion_cache_hit(
    mock_openai_client, mock_response_cache
):
    """Test behavior when the response is found in the cache."""
    system_prompt = "System prompt"
    user_prompt = "User prompt"
    model = "test-model"
    cached_result = {"data": "cached result"}

    # Simulate a cache hit
    mock_response_cache.get.return_value = cached_result

    result = await structured_output_completion(
        openai_client=mock_openai_client,
        model=model,
        response_format={"type": "json_object"},
        system_prompt=system_prompt,
        user_prompt=user_prompt,
        cache=mock_response_cache,
    )

    # The cached value should be returned without calling the API
    mock_response_cache.get.assert_called_once_with(
        f"{system_prompt}:{user_prompt}", model
    )
    mock_openai_client.chat.completions.create.assert_not_called()
    mock_response_cache.set.assert_not_called()
    assert result == cached_result
|
|
|
|
@pytest.mark.anyio
async def test_structured_output_completion_cache_miss_success(
    mock_openai_client, mock_response_cache
):
    """Test behavior on cache miss with a successful API call."""
    system_prompt = "System prompt for success"
    user_prompt = "User prompt for success"
    model = "test-model-success"
    expected_result = {"data": "successful API result"}

    # Simulate a cache miss
    mock_response_cache.get.return_value = None

    # Mock a successful API response carrying the expected JSON payload
    mock_completion = MagicMock()
    mock_message = MagicMock()
    mock_message.content = json.dumps(expected_result)
    mock_choice = MagicMock()
    mock_choice.message = mock_message
    mock_completion.choices = [mock_choice]
    mock_openai_client.chat.completions.create.return_value = mock_completion

    result = await structured_output_completion(
        openai_client=mock_openai_client,
        model=model,
        response_format={"type": "json_object"},
        system_prompt=system_prompt,
        user_prompt=user_prompt,
        cache=mock_response_cache,
    )

    # The API result should be parsed, cached, and returned
    mock_response_cache.get.assert_called_once_with(
        f"{system_prompt}:{user_prompt}", model
    )
    mock_openai_client.chat.completions.create.assert_called_once_with(
        model=model,
        messages=[
            {
                "role": "system",
                "content": ANY,  # the implementation augments the system prompt
            },
            {"role": "user", "content": user_prompt},
        ],
        response_format={"type": "json_object"},
        temperature=0.7,
    )
    mock_response_cache.set.assert_called_once_with(
        f"{system_prompt}:{user_prompt}", model, expected_result
    )
    assert result == expected_result
|
|
|
|
@pytest.mark.anyio
async def test_structured_output_completion_api_error(
    mock_openai_client, mock_response_cache
):
    """Test behavior when the OpenAI API call raises an error."""
    system_prompt = "System prompt for error"
    user_prompt = "User prompt for error"
    model = "test-model-error"
    error_message = "Test API Error"

    # Simulate a cache miss
    mock_response_cache.get.return_value = None

    # Make every API attempt raise, so all retries are exhausted
    mock_openai_client.chat.completions.create.side_effect = OpenAIError(error_message)

    with pytest.raises(tenacity.RetryError):
        await structured_output_completion(
            openai_client=mock_openai_client,
            model=model,
            response_format={"type": "json_object"},
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            cache=mock_response_cache,
        )

    # Each retry attempt checks the cache and calls the API once
    assert mock_response_cache.get.call_count == 3, (
        f"Expected cache.get to be called 3 times due to retries, "
        f"but was {mock_response_cache.get.call_count}"
    )
    assert mock_openai_client.chat.completions.create.call_count == 3, (
        f"Expected create to be called 3 times due to retries, "
        f"but was {mock_openai_client.chat.completions.create.call_count}"
    )
    mock_response_cache.set.assert_not_called()
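# Note: the three-attempt counts above assume structured_output_completion is
# wrapped with a tenacity policy roughly like the sketch below (an assumption
# about the implementation, not a copy of it). Because the cache lookup happens
# inside the retried coroutine, cache.get is also invoked once per attempt.
#
#     @tenacity.retry(
#         stop=tenacity.stop_after_attempt(3),
#         retry=tenacity.retry_if_exception_type(OpenAIError),
#     )
#     async def structured_output_completion(...): ...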
|
|
|
|
@pytest.mark.anyio
async def test_structured_output_completion_invalid_json(
    mock_openai_client, mock_response_cache
):
    """Test behavior when the API returns invalid JSON."""
    system_prompt = "System prompt for invalid json"
    user_prompt = "User prompt for invalid json"
    model = "test-model-invalid-json"
    invalid_json_content = "this is not json"

    # Simulate a cache miss
    mock_response_cache.get.return_value = None

    # Mock an API response whose content cannot be parsed as JSON
    mock_completion = MagicMock()
    mock_message = MagicMock()
    mock_message.content = invalid_json_content
    mock_choice = MagicMock()
    mock_choice.message = mock_message
    mock_completion.choices = [mock_choice]
    mock_openai_client.chat.completions.create.return_value = mock_completion

    with pytest.raises(tenacity.RetryError):
        await structured_output_completion(
            openai_client=mock_openai_client,
            model=model,
            response_format={"type": "json_object"},
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            cache=mock_response_cache,
        )

    # JSON decode failures are retried just like API errors
    assert mock_response_cache.get.call_count == 3, (
        f"Expected cache.get to be called 3 times due to retries, "
        f"but was {mock_response_cache.get.call_count}"
    )
    assert mock_openai_client.chat.completions.create.call_count == 3, (
        f"Expected create to be called 3 times due to retries, "
        f"but was {mock_openai_client.chat.completions.create.call_count}"
    )
    mock_response_cache.set.assert_not_called()
|
|
|
|
@pytest.mark.anyio
async def test_structured_output_completion_no_choices(
    mock_openai_client, mock_response_cache
):
    """Test behavior when the API completion has no choices."""
    system_prompt = "System prompt no choices"
    user_prompt = "User prompt no choices"
    model = "test-model-no-choices"

    mock_response_cache.get.return_value = None
    mock_completion = MagicMock()
    mock_completion.choices = []
    mock_openai_client.chat.completions.create.return_value = mock_completion

    # An empty choices list should be handled gracefully, not raised
    result = await structured_output_completion(
        openai_client=mock_openai_client,
        model=model,
        response_format={"type": "json_object"},
        system_prompt=system_prompt,
        user_prompt=user_prompt,
        cache=mock_response_cache,
    )
    assert result is None
    mock_response_cache.set.assert_not_called()
|
|
|
|
@pytest.mark.anyio
async def test_structured_output_completion_no_message_content(
    mock_openai_client, mock_response_cache
):
    """Test behavior when the API choice has no message content."""
    system_prompt = "System prompt no content"
    user_prompt = "User prompt no content"
    model = "test-model-no-content"

    mock_response_cache.get.return_value = None
    mock_completion = MagicMock()
    mock_message = MagicMock()
    mock_message.content = None
    mock_choice = MagicMock()
    mock_choice.message = mock_message
    mock_completion.choices = [mock_choice]
    mock_openai_client.chat.completions.create.return_value = mock_completion

    # Missing message content should also yield None rather than an error
    result = await structured_output_completion(
        openai_client=mock_openai_client,
        model=model,
        response_format={"type": "json_object"},
        system_prompt=system_prompt,
        user_prompt=user_prompt,
        cache=mock_response_cache,
    )
    assert result is None
    mock_response_cache.set.assert_not_called()
|
|
|
|
def create_mock_chat_completion(content: str) -> ChatCompletion:
    return ChatCompletion(
        id="chatcmpl-test123",
        choices=[
            ChatCompletionChoice(
                finish_reason="stop",
                index=0,
                message=ChatCompletionMessage(content=content, role="assistant"),
                logprobs=None,
            )
        ],
        created=1677652288,
        model="gpt-4o",
        object="chat.completion",
        system_fingerprint="fp_test",
        usage=None,
    )
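# Typical usage in the tests below: wrap a JSON payload so the fake client
# returns it as the assistant message content, e.g.
#     create_mock_chat_completion(json.dumps([{"front": "Q", "back": "A", "tags": []}]))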
|
|
|
|
@pytest.mark.anyio
async def test_process_crawled_page_success(mock_openai_client, sample_crawled_page):
    mock_response_content = json.dumps(
        [
            {"front": "Q1", "back": "A1", "tags": ["tag1"]},
            {"front": "Q2", "back": "A2", "tags": ["tag2", "python"]},
        ]
    )
    mock_openai_client.chat.completions.create.return_value = (
        create_mock_chat_completion(mock_response_content)
    )

    result_cards = await process_crawled_page(mock_openai_client, sample_crawled_page)

    assert len(result_cards) == 2
    assert result_cards[0].front == "Q1"
    assert result_cards[0].source_url == sample_crawled_page.url
    assert result_cards[1].tags == ["tag2", "python"]
    mock_openai_client.chat.completions.create.assert_awaited_once()
|
|
|
|
@pytest.mark.anyio
async def test_process_crawled_page_empty_llm_response_content(
    mock_openai_client, sample_crawled_page
):
    mock_openai_client.chat.completions.create.return_value = (
        create_mock_chat_completion("")
    )

    result_cards = await process_crawled_page(mock_openai_client, sample_crawled_page)
    assert len(result_cards) == 0
|
|
|
|
@pytest.mark.anyio
async def test_process_crawled_page_llm_returns_not_a_list(
    mock_openai_client, sample_crawled_page
):
    mock_response_content = json.dumps({"error": "not a list as expected"})
    mock_openai_client.chat.completions.create.return_value = (
        create_mock_chat_completion(mock_response_content)
    )

    result_cards = await process_crawled_page(mock_openai_client, sample_crawled_page)
    assert len(result_cards) == 0
|
|
|
|
@pytest.mark.anyio
async def test_process_crawled_page_llm_returns_dict_with_cards_key(
    mock_openai_client, sample_crawled_page
):
    mock_response_content = json.dumps(
        {"cards": [{"front": "Q1", "back": "A1", "tags": []}]}
    )
    mock_openai_client.chat.completions.create.return_value = (
        create_mock_chat_completion(mock_response_content)
    )

    result_cards = await process_crawled_page(mock_openai_client, sample_crawled_page)
    assert len(result_cards) == 1
    assert result_cards[0].front == "Q1"
|
|
|
|
@pytest.mark.anyio
async def test_process_crawled_page_json_decode_error(
    mock_openai_client, sample_crawled_page
):
    mock_openai_client.chat.completions.create.return_value = (
        create_mock_chat_completion("this is not valid json")
    )

    result_cards = await process_crawled_page(mock_openai_client, sample_crawled_page)
    assert len(result_cards) == 0
|
|
|
|
@pytest.mark.anyio
async def test_process_crawled_page_empty_text_content(mock_openai_client):
    empty_content_page = CrawledPage(
        url="http://example.com/empty",
        html_content="",
        text_content="   ",
        title="Empty",
    )
    result_cards = await process_crawled_page(mock_openai_client, empty_content_page)
    assert len(result_cards) == 0
    mock_openai_client.chat.completions.create.assert_not_awaited()
|
|
|
|
@pytest.mark.anyio
async def test_process_crawled_page_openai_api_error_retry(
    mock_openai_client, sample_crawled_page, caplog
):
    # Fail twice with a retryable error, then return a valid completion
    errors_to_raise = [
        RateLimitError("rate limited", response=MagicMock(), body=None)
    ] * 2 + [
        create_mock_chat_completion(
            json.dumps([{"front": "Q1", "back": "A1", "tags": []}])
        )
    ]

    mock_openai_client.chat.completions.create.side_effect = errors_to_raise

    result_cards = await process_crawled_page(mock_openai_client, sample_crawled_page)

    assert len(result_cards) == 1
    assert result_cards[0].front == "Q1"
    assert mock_openai_client.chat.completions.create.await_count == 3
    assert "Retrying OpenAI call (attempt 1)" in caplog.text
    assert "Retrying OpenAI call (attempt 2)" in caplog.text
|
|
|
|
@pytest.mark.anyio
async def test_process_crawled_page_openai_persistent_api_error(
    mock_openai_client, sample_crawled_page, caplog
):
    # Every attempt fails, so retries are exhausted and no cards are produced
    mock_openai_client.chat.completions.create.side_effect = APIConnectionError(
        request=MagicMock()
    )

    result_cards = await process_crawled_page(mock_openai_client, sample_crawled_page)

    assert len(result_cards) == 0
    assert mock_openai_client.chat.completions.create.await_count == 3
    assert "OpenAI API error after retries" in caplog.text
|
|
|
|
@pytest.mark.anyio
async def test_process_crawled_page_tiktoken_truncation(
    mock_openai_client, sample_crawled_page
):
    # Build text content long enough to require truncation
    long_text = "word " * 8000
    sample_crawled_page.text_content = long_text

    mock_response_content = json.dumps(
        [{"front": "TruncatedQ", "back": "TruncatedA", "tags": []}]
    )
    mock_openai_client.chat.completions.create.return_value = (
        create_mock_chat_completion(mock_response_content)
    )

    await process_crawled_page(mock_openai_client, sample_crawled_page)

    # Inspect the user prompt that was actually sent to the API
    call_args = mock_openai_client.chat.completions.create.call_args
    user_prompt_message_content = next(
        m["content"] for m in call_args.kwargs["messages"] if m["role"] == "user"
    )

    # Extract the page content embedded between the CONTENT header and the
    # trailing instructions
    assert "CONTENT:\n" in user_prompt_message_content
    content_part = user_prompt_message_content.split("CONTENT:\n")[1].split(
        "\n\nReturn a JSON array"
    )[0]

    import tiktoken

    encoding = tiktoken.get_encoding("cl100k_base")
    num_tokens = len(encoding.encode(content_part))

    # The truncated content should land near the token limit
    assert 5900 < num_tokens < 6100
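# The 5900-6100 band assumes process_crawled_page truncates content to roughly
# 6000 tokens by default; adjust the band if that default changes.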
|
|
|
|
|
|
|
|
@pytest.mark.anyio
async def test_process_crawled_pages_success(mock_openai_client, sample_crawled_page):
    pages_to_process = [
        sample_crawled_page,
        CrawledPage(
            url="http://example.com/page2",
            html_content="",
            text_content="Content for page 2",
            title="Page 2",
        ),
    ]

    # Stub the per-page processor so no real LLM calls are made
    async def mock_single_page_processor(client, page, model, max_tokens):
        if page.url == pages_to_process[0].url:
            return [AnkiCardData(front="P1Q1", back="P1A1", source_url=page.url)]
        elif page.url == pages_to_process[1].url:
            return [
                AnkiCardData(front="P2Q1", back="P2A1", source_url=page.url),
                AnkiCardData(front="P2Q2", back="P2A2", source_url=page.url),
            ]
        return []

    with patch(
        "ankigen_core.llm_interface.process_crawled_page",
        side_effect=mock_single_page_processor,
    ) as mock_processor:
        result_cards = await process_crawled_pages(
            mock_openai_client, pages_to_process, max_concurrent_requests=1
        )

    assert len(result_cards) == 3
    assert result_cards[0].front == "P1Q1"
    assert result_cards[1].front == "P2Q1"
    assert result_cards[2].front == "P2Q2"
    assert mock_processor.call_count == 2
|
|
|
|
@pytest.mark.anyio
async def test_process_crawled_pages_partial_failure(
    mock_openai_client, sample_crawled_page
):
    pages_to_process = [
        sample_crawled_page,
        CrawledPage(
            url="http://example.com/page_fail",
            html_content="",
            text_content="Content for page fail",
            title="Page Fail",
        ),
        CrawledPage(
            url="http://example.com/page3",
            html_content="",
            text_content="Content for page 3",
            title="Page 3",
        ),
    ]

    async def mock_single_page_processor_with_failure(client, page, model, max_tokens):
        if page.url == pages_to_process[0].url:
            return [AnkiCardData(front="P1Q1", back="P1A1", source_url=page.url)]
        elif page.url == pages_to_process[1].url:
            raise APIConnectionError(request=MagicMock())
        elif page.url == pages_to_process[2].url:
            return [AnkiCardData(front="P3Q1", back="P3A1", source_url=page.url)]
        return []

    with patch(
        "ankigen_core.llm_interface.process_crawled_page",
        side_effect=mock_single_page_processor_with_failure,
    ) as mock_processor:
        result_cards = await process_crawled_pages(
            mock_openai_client, pages_to_process, max_concurrent_requests=2
        )

    assert len(result_cards) == 2
    successful_urls = [card.source_url for card in result_cards]
    assert pages_to_process[0].url in successful_urls
    assert pages_to_process[2].url in successful_urls
    assert pages_to_process[1].url not in successful_urls
    assert mock_processor.call_count == 3
|
|
|
|
@pytest.mark.anyio
async def test_process_crawled_pages_progress_callback(
    mock_openai_client, sample_crawled_page
):
    pages_to_process = [sample_crawled_page] * 3
    progress_log = []

    def callback(completed_count, total_count):
        progress_log.append((completed_count, total_count))

    async def mock_simple_processor(client, page, model, max_tokens):
        await asyncio.sleep(0.01)
        return [AnkiCardData(front=f"{page.url}-Q", back="A", source_url=page.url)]

    with patch(
        "ankigen_core.llm_interface.process_crawled_page",
        side_effect=mock_simple_processor,
    ):
        await process_crawled_pages(
            mock_openai_client,
            pages_to_process,
            progress_callback=callback,
            max_concurrent_requests=1,
        )

    assert len(progress_log) == 3
    assert progress_log[0] == (1, 3)
    assert progress_log[1] == (2, 3)
    assert progress_log[2] == (3, 3)
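# With max_concurrent_requests=1 the pages are processed strictly sequentially,
# so the (completed, total) tuples are recorded in a deterministic order.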
|
|
|
|
TEST_API_KEY = "sk-testkey1234567890abcdefghijklmnopqrstuvwxyz"
|
|
|
|
@pytest.fixture
def client_manager():
    """Fixture for OpenAIClientManager."""
    return OpenAIClientManager()
|
|
|
|
@pytest.fixture
def mock_async_openai_client():
    """Mocks an AsyncOpenAI client instance."""
    mock_client = AsyncMock()
    mock_client.chat = AsyncMock()
    mock_client.chat.completions = AsyncMock()
    mock_client.chat.completions.create = AsyncMock()

    # Default successful response with a parsable JSON payload and usage info
    mock_response = MagicMock()
    mock_response.choices = [MagicMock()]
    mock_response.choices[0].message = MagicMock()
    mock_response.choices[0].message.content = '{"question": "Q1", "answer": "A1"}'
    mock_response.usage = MagicMock()
    mock_response.usage.total_tokens = 100

    mock_client.chat.completions.create.return_value = mock_response
    return mock_client
|
|
|
|
@pytest.fixture
def sample_crawled_page():
    """Fixture for a sample CrawledPage object."""
    return CrawledPage(
        url="http://example.com",
        html_content="<html><body>This is some test content for the page.</body></html>",
        text_content="This is some test content for the page.",
        title="Test Page",
        meta_description="A test page.",
        meta_keywords=["test", "page"],
        crawl_depth=0,
    )
|
|
|
|
@pytest.mark.anyio
async def test_process_crawled_page_returns_card_and_tokens(
    client_manager, mock_async_openai_client, sample_crawled_page
):
    """Test successful processing of a single crawled page."""
    with patch.object(
        client_manager, "get_client", return_value=mock_async_openai_client
    ):
        result, tokens = await process_crawled_page(
            mock_async_openai_client,
            sample_crawled_page,
            "gpt-4o",
            max_prompt_content_tokens=1000,
        )
        assert isinstance(result, AnkiCardData)
        assert result.front == "Q1"
        assert result.back == "A1"
        assert tokens == 100
        mock_async_openai_client.chat.completions.create.assert_called_once()
|
|
|
|
@pytest.mark.anyio
async def test_process_crawled_page_json_error(
    client_manager, mock_async_openai_client, sample_crawled_page
):
    """Test handling of invalid JSON response from LLM."""
    mock_async_openai_client.chat.completions.create.return_value.choices[
        0
    ].message.content = "This is not JSON"

    with patch.object(
        client_manager, "get_client", return_value=mock_async_openai_client
    ):
        mock_async_openai_client.chat.completions.create.reset_mock()

        result, tokens = await process_crawled_page(
            mock_async_openai_client,
            sample_crawled_page,
            "gpt-4o",
            max_prompt_content_tokens=1000,
        )
        assert result is None
        # The response arrived, so its token usage is still reported even
        # though parsing failed
        assert tokens == 100
        # Parsing failures may be retried; at least one attempt must be made
        assert mock_async_openai_client.chat.completions.create.call_count >= 1
|
|
|
|
@pytest.mark.anyio
async def test_process_crawled_page_api_error(
    client_manager, mock_async_openai_client, sample_crawled_page
):
    """Test handling of API error during LLM call."""
    # APIError requires an httpx.Request; a MagicMock stands in for it here
    mock_request = MagicMock()
    mock_async_openai_client.chat.completions.create.side_effect = APIError(
        message="Test API Error", request=mock_request, body=None
    )

    with patch.object(
        client_manager, "get_client", return_value=mock_async_openai_client
    ):
        mock_async_openai_client.chat.completions.create.reset_mock()

        result, tokens = await process_crawled_page(
            mock_async_openai_client,
            sample_crawled_page,
            "gpt-4o",
            max_prompt_content_tokens=1000,
        )
        assert result is None
        assert tokens == 0
        # Retries should have produced more than one attempt
        assert mock_async_openai_client.chat.completions.create.call_count > 1
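# Contrast with test_process_crawled_page_json_error above: there the response
# arrived (so the mocked usage of 100 tokens is reported), while here the call
# itself fails and the token count falls back to 0.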
|
|
|
|
@pytest.mark.anyio
async def test_process_crawled_page_content_truncation(
    client_manager, mock_async_openai_client, sample_crawled_page
):
    """Test content truncation based on max_prompt_content_tokens."""
    long_content_piece = "This is a word. "
    repetitions = 10
    sample_crawled_page.text_content = long_content_piece * repetitions

    with (
        patch.object(
            client_manager, "get_client", return_value=mock_async_openai_client
        ),
        patch("tiktoken.get_encoding") as mock_get_encoding,
    ):
        mock_encoding = MagicMock()

        # Pretend each repetition of the piece encodes to four tokens
        original_tokens = []
        for i in range(repetitions):
            original_tokens.extend([i * 4, i * 4 + 1, i * 4 + 2, i * 4 + 3])

        mock_encoding.encode.return_value = original_tokens

        def mock_decode_side_effect(token_ids):
            # Reverse the fake encoding: four tokens per full piece, with any
            # partial piece decoded word by word
            num_tokens_to_decode = len(token_ids)
            num_full_pieces = num_tokens_to_decode // 4
            partial_piece_tokens = num_tokens_to_decode % 4
            decoded_str = long_content_piece * num_full_pieces
            if partial_piece_tokens > 0:
                words_in_piece = long_content_piece.strip().split(" ")
                num_words_to_take = min(partial_piece_tokens, len(words_in_piece))
                decoded_str += " ".join(words_in_piece[:num_words_to_take])
            return decoded_str.strip()

        mock_encoding.decode.side_effect = mock_decode_side_effect
        mock_get_encoding.return_value = mock_encoding

        mock_async_openai_client.chat.completions.create.reset_mock()

        await process_crawled_page(
            mock_async_openai_client,
            sample_crawled_page,
            "gpt-4o",
            max_prompt_content_tokens=5,
        )

        mock_get_encoding.assert_called_once_with("cl100k_base")
        mock_encoding.encode.assert_called_once_with(
            sample_crawled_page.text_content, disallowed_special=()
        )
        mock_encoding.decode.assert_called_once_with(original_tokens[:5])

        call_args = mock_async_openai_client.chat.completions.create.call_args
        assert call_args is not None
        messages = call_args.kwargs["messages"]
        user_prompt_content = messages[1]["content"]

        expected_truncated_content = mock_decode_side_effect(original_tokens[:5])
        assert f"Content: {expected_truncated_content}" in user_prompt_content
|
|
|
|
|
|
|
|
@pytest.mark.anyio
async def test_openai_client_manager_get_client(
    client_manager, mock_async_openai_client
):
    """Test that get_client returns the same AsyncOpenAI instance after a single initialization."""
    with patch(
        "ankigen_core.llm_interface.AsyncOpenAI", return_value=mock_async_openai_client
    ) as mock_constructor:
        await client_manager.initialize_client(TEST_API_KEY)
        client1 = client_manager.get_client()
        client2 = client_manager.get_client()

        assert client1 is mock_async_openai_client
        assert client2 is mock_async_openai_client
        mock_constructor.assert_called_once_with(api_key=TEST_API_KEY)
|
|
|
|
|
|