| from types import SimpleNamespace |
|
|
| import pytest |
|
|
| from astrbot.core.provider.sources.openai_source import ProviderOpenAIOfficial |
|
|
|
|
class _ErrorWithBody(Exception):
    """Test double: an exception that exposes a structured ``body`` attribute.

    Args:
        message: Text passed to ``Exception`` (becomes ``str(err)``).
        body: Arbitrary dict stored unchanged on ``self.body``.
    """

    def __init__(self, message: str, body: dict):
        # Attach the payload first; the base initializer only records the message.
        self.body = body
        Exception.__init__(self, message)
|
|
|
|
class _ErrorWithResponse(Exception):
    """Test double: an exception that carries a ``response`` object with ``text``.

    Args:
        message: Text passed to ``Exception`` (becomes ``str(err)``).
        response_text: Stored as ``self.response.text`` on a ``SimpleNamespace``.
    """

    def __init__(self, message: str, response_text: str):
        # Only the ``text`` attribute is provided on the stand-in response.
        fake_response = SimpleNamespace(text=response_text)
        self.response = fake_response
        Exception.__init__(self, message)
|
|
|
|
def _make_provider(overrides: dict | None = None) -> ProviderOpenAIOfficial:
    """Build a ``ProviderOpenAIOfficial`` from a minimal test configuration.

    Args:
        overrides: Optional entries merged over the base config; keys given
            here replace the base values. ``None`` or an empty dict leaves
            the base config as is.

    Returns:
        A provider constructed with the merged config and empty settings.
    """
    merged_config = {
        "id": "test-openai",
        "type": "openai_chat_completion",
        "model": "gpt-4o-mini",
        "key": ["test-key"],
    }
    # A falsy ``overrides`` turns into an empty update, i.e. a no-op.
    merged_config.update(overrides or ())
    return ProviderOpenAIOfficial(
        provider_config=merged_config,
        provider_settings={},
    )
|
|
|
|
@pytest.mark.asyncio
async def test_handle_api_error_content_moderated_removes_images():
    """Check that a moderation error matching a configured pattern drops images.

    The provider is configured with the pattern ``file:content-moderated`` and
    the error text contains it. The test expects ``success is False`` and only
    the text part left in the first message of ``payloads["messages"]``.
    """
    provider = _make_provider(
        {"image_moderation_error_patterns": ["file:content-moderated"]}
    )
    try:
        user_message = {
            "role": "user",
            "content": [
                {"type": "text", "text": "hello"},
                {
                    "type": "image_url",
                    "image_url": {"url": "data:image/jpeg;base64,abcd"},
                },
            ],
        }
        messages = [user_message]
        payloads = {"messages": messages}
        moderation_error = Exception(
            "Content is moderated [WKE=file:content-moderated]"
        )

        # ``context_query`` is the very same list object as payloads["messages"].
        success, *_ = await provider._handle_api_error(
            moderation_error,
            payloads=payloads,
            context_query=messages,
            func_tool=None,
            chosen_key="test-key",
            available_api_keys=["test-key"],
            retry_cnt=0,
            max_retries=10,
        )

        assert success is False
        # Re-read from the payload: the handler may have replaced the list.
        remaining = payloads["messages"]
        assert isinstance(remaining, list)
        assert remaining[0]["content"] == [{"type": "text", "text": "hello"}]
    finally:
        await provider.terminate()
|
|
|
|
@pytest.mark.asyncio
async def test_handle_api_error_model_not_vlm_removes_images_and_retries_text_only():
    """Check that a "not a VLM" error on the first attempt strips image parts.

    Uses the default provider config (no overrides). The test expects
    ``success is False`` and a text-only first message afterwards; the retry
    named in the test title is not observed directly here.
    """
    provider = _make_provider()
    try:
        image_part = {
            "type": "image_url",
            "image_url": {"url": "data:image/jpeg;base64,abcd"},
        }
        messages = [
            {
                "role": "user",
                "content": [{"type": "text", "text": "hello"}, image_part],
            }
        ]
        payloads = {"messages": messages}
        handler_kwargs = dict(
            payloads=payloads,
            context_query=messages,
            func_tool=None,
            chosen_key="test-key",
            available_api_keys=["test-key"],
            retry_cnt=0,
            max_retries=10,
        )

        success, *_ = await provider._handle_api_error(
            Exception("The model is not a VLM and cannot process images"),
            **handler_kwargs,
        )

        assert success is False
        # Re-read from the payload: the handler may have replaced the list.
        remaining = payloads["messages"]
        assert isinstance(remaining, list)
        assert remaining[0]["content"] == [{"type": "text", "text": "hello"}]
    finally:
        await provider.terminate()
|
|
|
|
@pytest.mark.asyncio
async def test_handle_api_error_model_not_vlm_after_fallback_raises():
    """Check that a "not a VLM" error raises once the image fallback was used.

    The handler is called with ``retry_cnt=1`` and ``image_fallback_used=True``;
    the test expects an exception whose message matches ``not a VLM``.
    """
    provider = _make_provider()
    try:
        user_message = {
            "role": "user",
            "content": [
                {"type": "text", "text": "hello"},
                {
                    "type": "image_url",
                    "image_url": {"url": "data:image/jpeg;base64,abcd"},
                },
            ],
        }
        payloads = {"messages": [user_message]}
        not_vlm_error = Exception(
            "The model is not a VLM and cannot process images"
        )

        with pytest.raises(Exception, match="not a VLM"):
            await provider._handle_api_error(
                not_vlm_error,
                payloads=payloads,
                # Same list object as the payload's messages.
                context_query=payloads["messages"],
                func_tool=None,
                chosen_key="test-key",
                available_api_keys=["test-key"],
                retry_cnt=1,
                max_retries=10,
                image_fallback_used=True,
            )
    finally:
        await provider.terminate()
|
|
|
|
@pytest.mark.asyncio
async def test_handle_api_error_content_moderated_with_unserializable_body():
    """Check pattern detection when the error body holds a non-JSON value.

    The error's ``body`` contains a bare ``object()`` next to a message of
    ``blocked``, which is also the configured moderation pattern. The test
    expects ``success is False`` and a text-only first message afterwards.
    """
    provider = _make_provider({"image_moderation_error_patterns": ["blocked"]})
    try:
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "hello"},
                    {
                        "type": "image_url",
                        "image_url": {"url": "data:image/jpeg;base64,abcd"},
                    },
                ],
            }
        ]
        payloads = {"messages": messages}
        # ``object()`` is the part of the body that cannot be serialized.
        error_body = {"error": {"message": "blocked"}, "raw": object()}
        upstream_error = _ErrorWithBody("upstream error", error_body)

        success, *_ = await provider._handle_api_error(
            upstream_error,
            payloads=payloads,
            context_query=messages,
            func_tool=None,
            chosen_key="test-key",
            available_api_keys=["test-key"],
            retry_cnt=0,
            max_retries=10,
        )

        assert success is False
        first_message = payloads["messages"][0]
        assert first_message["content"] == [{"type": "text", "text": "hello"}]
    finally:
        await provider.terminate()
|
|
|
|
def test_extract_error_text_candidates_truncates_long_response_text():
    """Check that extracted error-text candidates respect the length cap.

    An error whose ``response.text`` is 20000 characters long is passed to
    ``_extract_error_text_candidates``; every returned candidate must be no
    longer than ``_ERROR_TEXT_CANDIDATE_MAX_CHARS``.
    """
    oversized_error = _ErrorWithResponse("upstream error", "x" * 20000)

    extracted = ProviderOpenAIOfficial._extract_error_text_candidates(
        oversized_error
    )

    # At least one candidate must come back, otherwise the cap check is moot.
    assert extracted
    longest = max(map(len, extracted))
    assert longest <= ProviderOpenAIOfficial._ERROR_TEXT_CANDIDATE_MAX_CHARS
|
|
|
|
@pytest.mark.asyncio
async def test_handle_api_error_content_moderated_without_images_raises():
    """Check that a moderation error raises when the payload has no images.

    The configured pattern matches the error text, but the only message
    contains a single text part. The test expects an exception whose message
    matches ``content-moderated``.
    """
    provider = _make_provider(
        {"image_moderation_error_patterns": ["file:content-moderated"]}
    )
    try:
        text_only_message = {
            "role": "user",
            "content": [{"type": "text", "text": "hello"}],
        }
        messages = [text_only_message]
        payloads = {"messages": messages}
        moderation_error = Exception(
            "Content is moderated [WKE=file:content-moderated]"
        )
        handler_kwargs = dict(
            payloads=payloads,
            context_query=messages,
            func_tool=None,
            chosen_key="test-key",
            available_api_keys=["test-key"],
            retry_cnt=0,
            max_retries=10,
        )

        with pytest.raises(Exception, match="content-moderated"):
            await provider._handle_api_error(moderation_error, **handler_kwargs)
    finally:
        await provider.terminate()
|
|
|
|
@pytest.mark.asyncio
async def test_handle_api_error_content_moderated_detects_structured_body():
    """Check that a pattern found only in the structured error body is detected.

    The exception message is just ``upstream error``; the configured pattern
    ``content_moderated`` appears only as ``body["error"]["code"]``. The test
    expects ``success is False`` and a text-only first message afterwards.
    """
    provider = _make_provider(
        {"image_moderation_error_patterns": ["content_moderated"]}
    )
    try:
        image_part = {
            "type": "image_url",
            "image_url": {"url": "data:image/jpeg;base64,abcd"},
        }
        messages = [
            {
                "role": "user",
                "content": [{"type": "text", "text": "hello"}, image_part],
            }
        ]
        payloads = {"messages": messages}
        structured_error = _ErrorWithBody(
            "upstream error",
            {"error": {"code": "content_moderated", "message": "blocked"}},
        )

        success, *_ = await provider._handle_api_error(
            structured_error,
            payloads=payloads,
            context_query=messages,
            func_tool=None,
            chosen_key="test-key",
            available_api_keys=["test-key"],
            retry_cnt=0,
            max_retries=10,
        )

        assert success is False
        first_message = payloads["messages"][0]
        assert first_message["content"] == [{"type": "text", "text": "hello"}]
    finally:
        await provider.terminate()
|
|
|
|
@pytest.mark.asyncio
async def test_handle_api_error_content_moderated_supports_custom_patterns():
    """Check that a user-defined moderation pattern triggers image removal.

    The provider is configured with ``blocked_by_policy_code_123`` and the
    error text contains that string. The test expects ``success is False``
    and a text-only first message afterwards.
    """
    custom_pattern_config = {
        "image_moderation_error_patterns": ["blocked_by_policy_code_123"]
    }
    provider = _make_provider(custom_pattern_config)
    try:
        user_message = {
            "role": "user",
            "content": [
                {"type": "text", "text": "hello"},
                {
                    "type": "image_url",
                    "image_url": {"url": "data:image/jpeg;base64,abcd"},
                },
            ],
        }
        messages = [user_message]
        payloads = {"messages": messages}
        handler_kwargs = dict(
            payloads=payloads,
            context_query=messages,
            func_tool=None,
            chosen_key="test-key",
            available_api_keys=["test-key"],
            retry_cnt=0,
            max_retries=10,
        )

        success, *_ = await provider._handle_api_error(
            Exception("upstream: blocked_by_policy_code_123"),
            **handler_kwargs,
        )

        assert success is False
        first_message = payloads["messages"][0]
        assert first_message["content"] == [{"type": "text", "text": "hello"}]
    finally:
        await provider.terminate()
|
|
|
|
@pytest.mark.asyncio
async def test_handle_api_error_content_moderated_without_patterns_raises():
    """Check that a moderation error raises when no patterns are configured.

    The provider uses the default test config (no
    ``image_moderation_error_patterns`` override) and the payload does contain
    an image. The test expects an exception matching ``content-moderated``.
    """
    provider = _make_provider()
    try:
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "hello"},
                    {
                        "type": "image_url",
                        "image_url": {"url": "data:image/jpeg;base64,abcd"},
                    },
                ],
            }
        ]
        payloads = {"messages": messages}
        moderation_error = Exception(
            "Content is moderated [WKE=file:content-moderated]"
        )

        with pytest.raises(Exception, match="content-moderated"):
            await provider._handle_api_error(
                moderation_error,
                payloads=payloads,
                context_query=messages,
                func_tool=None,
                chosen_key="test-key",
                available_api_keys=["test-key"],
                retry_cnt=0,
                max_retries=10,
            )
    finally:
        await provider.terminate()
|
|
|
|
@pytest.mark.asyncio
async def test_handle_api_error_unknown_image_error_raises():
    """Check that an unrecognized image-related error is raised, not handled.

    Uses the default provider config and an image-bearing payload. The test
    expects an exception whose message matches
    ``unknown provider image upload error``.
    """
    provider = _make_provider()
    try:
        image_part = {
            "type": "image_url",
            "image_url": {"url": "data:image/jpeg;base64,abcd"},
        }
        user_message = {
            "role": "user",
            "content": [{"type": "text", "text": "hello"}, image_part],
        }
        payloads = {"messages": [user_message]}
        unknown_error = Exception("some unknown provider image upload error")

        with pytest.raises(Exception, match="unknown provider image upload error"):
            await provider._handle_api_error(
                unknown_error,
                payloads=payloads,
                # Same list object as the payload's messages.
                context_query=payloads["messages"],
                func_tool=None,
                chosen_key="test-key",
                available_api_keys=["test-key"],
                retry_cnt=0,
                max_retries=10,
            )
    finally:
        await provider.terminate()
|
|