# File: astrbbbb/tests/test_openai_source.py
# Repository page residue: qa1145's picture; "Upload 1245 files"; 8ede856 (verified)
from types import SimpleNamespace
import pytest
from astrbot.core.provider.sources.openai_source import ProviderOpenAIOfficial
class _ErrorWithBody(Exception):
    """Test double: an exception that exposes a structured ``body`` attribute.

    The tests in this module use it to hand ``_handle_api_error`` an error
    whose details live in ``body`` rather than in the exception message.

    Args:
        message: Text passed to ``Exception.__init__`` (what ``str(err)`` returns).
        body: Arbitrary mapping stored unchanged on ``self.body``.
    """

    def __init__(self, message: str, body: dict):
        super().__init__(message)
        # Stored by reference; no copy or validation is performed.
        self.body = body
class _ErrorWithResponse(Exception):
    """Test double: an exception that exposes ``response.text``.

    Args:
        message: Text passed to ``Exception.__init__`` (what ``str(err)`` returns).
        response_text: String made available as ``self.response.text``.
    """

    def __init__(self, message: str, response_text: str):
        super().__init__(message)
        # A SimpleNamespace is enough here: only the ``text`` attribute is set.
        self.response = SimpleNamespace(text=response_text)
def _make_provider(overrides: dict | None = None) -> ProviderOpenAIOfficial:
    """Build a ``ProviderOpenAIOfficial`` with a minimal test configuration.

    Args:
        overrides: Optional keys merged on top of the base provider config
            (existing keys are replaced). ``None`` or an empty dict leaves
            the base config untouched.

    Returns:
        A provider constructed with the merged config and empty
        ``provider_settings``. Every test in this module awaits
        ``provider.terminate()`` on the returned instance when done.
    """
    provider_config = {
        "id": "test-openai",
        "type": "openai_chat_completion",
        "model": "gpt-4o-mini",
        "key": ["test-key"],
    }
    if overrides:
        provider_config.update(overrides)
    return ProviderOpenAIOfficial(
        provider_config=provider_config,
        provider_settings={},
    )
@pytest.mark.asyncio
async def test_handle_api_error_content_moderated_removes_images():
    """A moderation error matching a configured pattern strips image parts.

    The provider is configured with the pattern ``file:content-moderated``,
    which appears in the raised error text. The test expects the first
    returned value to be ``False`` and the user message to keep only its
    text part.
    """
    provider = _make_provider(
        {"image_moderation_error_patterns": ["file:content-moderated"]}
    )
    try:
        payloads = {
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "hello"},
                        {
                            "type": "image_url",
                            "image_url": {"url": "data:image/jpeg;base64,abcd"},
                        },
                    ],
                }
            ]
        }
        # Same list object as payloads["messages"], passed under both names.
        context_query = payloads["messages"]

        success, *_rest = await provider._handle_api_error(
            Exception("Content is moderated [WKE=file:content-moderated]"),
            payloads=payloads,
            context_query=context_query,
            func_tool=None,
            chosen_key="test-key",
            available_api_keys=["test-key"],
            retry_cnt=0,
            max_retries=10,
        )

        assert success is False
        updated_context = payloads["messages"]
        assert isinstance(updated_context, list)
        assert updated_context[0]["content"] == [{"type": "text", "text": "hello"}]
    finally:
        # Always release the provider, even when an assertion fails.
        await provider.terminate()
@pytest.mark.asyncio
async def test_handle_api_error_model_not_vlm_removes_images_and_retries_text_only():
    """A "not a VLM" error on the first attempt strips image parts.

    No moderation patterns are configured. With ``retry_cnt=0`` the test
    expects the first returned value to be ``False`` and the user message
    to be reduced to its text part.
    """
    provider = _make_provider()
    try:
        payloads = {
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "hello"},
                        {
                            "type": "image_url",
                            "image_url": {"url": "data:image/jpeg;base64,abcd"},
                        },
                    ],
                }
            ]
        }
        # Same list object as payloads["messages"], passed under both names.
        context_query = payloads["messages"]

        success, *_rest = await provider._handle_api_error(
            Exception("The model is not a VLM and cannot process images"),
            payloads=payloads,
            context_query=context_query,
            func_tool=None,
            chosen_key="test-key",
            available_api_keys=["test-key"],
            retry_cnt=0,
            max_retries=10,
        )

        assert success is False
        updated_context = payloads["messages"]
        assert isinstance(updated_context, list)
        assert updated_context[0]["content"] == [{"type": "text", "text": "hello"}]
    finally:
        # Always release the provider, even when an assertion fails.
        await provider.terminate()
@pytest.mark.asyncio
async def test_handle_api_error_model_not_vlm_after_fallback_raises():
    """A "not a VLM" error after the image fallback was used is raised.

    The call passes ``retry_cnt=1`` and ``image_fallback_used=True``; the
    test expects an exception whose message matches ``"not a VLM"``.
    """
    provider = _make_provider()
    try:
        payloads = {
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "hello"},
                        {
                            "type": "image_url",
                            "image_url": {"url": "data:image/jpeg;base64,abcd"},
                        },
                    ],
                }
            ]
        }
        context_query = payloads["messages"]

        with pytest.raises(Exception, match="not a VLM"):
            await provider._handle_api_error(
                Exception("The model is not a VLM and cannot process images"),
                payloads=payloads,
                context_query=context_query,
                func_tool=None,
                chosen_key="test-key",
                available_api_keys=["test-key"],
                retry_cnt=1,
                max_retries=10,
                image_fallback_used=True,
            )
    finally:
        # Always release the provider, even when the expectation fails.
        await provider.terminate()
@pytest.mark.asyncio
async def test_handle_api_error_content_moderated_with_unserializable_body():
    """Pattern matching still works when the error body holds a raw object.

    The error body contains the configured pattern ``blocked`` in
    ``error.message`` alongside a bare ``object()`` value (which
    ``json.dumps`` cannot serialize by default). The test expects the call
    to return normally with a first value of ``False`` and the image part
    removed from the user message.
    """
    provider = _make_provider({"image_moderation_error_patterns": ["blocked"]})
    try:
        payloads = {
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "hello"},
                        {
                            "type": "image_url",
                            "image_url": {"url": "data:image/jpeg;base64,abcd"},
                        },
                    ],
                }
            ]
        }
        context_query = payloads["messages"]
        # The exception message itself ("upstream error") does not contain
        # the pattern; only the body does.
        err = _ErrorWithBody(
            "upstream error",
            {"error": {"message": "blocked"}, "raw": object()},
        )

        success, *_rest = await provider._handle_api_error(
            err,
            payloads=payloads,
            context_query=context_query,
            func_tool=None,
            chosen_key="test-key",
            available_api_keys=["test-key"],
            retry_cnt=0,
            max_retries=10,
        )

        assert success is False
        assert payloads["messages"][0]["content"] == [{"type": "text", "text": "hello"}]
    finally:
        # Always release the provider, even when an assertion fails.
        await provider.terminate()
def test_extract_error_text_candidates_truncates_long_response_text():
    """Extracted error text candidates are capped in length.

    An error whose ``response.text`` is 20000 characters long must yield a
    non-empty candidate collection in which no candidate is longer than
    ``ProviderOpenAIOfficial._ERROR_TEXT_CANDIDATE_MAX_CHARS``.
    """
    long_text = "x" * 20000
    err = _ErrorWithResponse("upstream error", long_text)

    candidates = ProviderOpenAIOfficial._extract_error_text_candidates(err)

    # Guard first: max() on an empty iterable would raise ValueError.
    assert candidates
    assert max(len(candidate) for candidate in candidates) <= (
        ProviderOpenAIOfficial._ERROR_TEXT_CANDIDATE_MAX_CHARS
    )
@pytest.mark.asyncio
async def test_handle_api_error_content_moderated_without_images_raises():
    """A matching moderation error is raised when there is no image to drop.

    The configured pattern matches the error text, but the user message
    holds only a text part. The test expects an exception whose message
    matches ``"content-moderated"``.
    """
    provider = _make_provider(
        {"image_moderation_error_patterns": ["file:content-moderated"]}
    )
    try:
        payloads = {
            "messages": [
                {
                    "role": "user",
                    "content": [{"type": "text", "text": "hello"}],
                }
            ]
        }
        context_query = payloads["messages"]
        err = Exception("Content is moderated [WKE=file:content-moderated]")

        with pytest.raises(Exception, match="content-moderated"):
            await provider._handle_api_error(
                err,
                payloads=payloads,
                context_query=context_query,
                func_tool=None,
                chosen_key="test-key",
                available_api_keys=["test-key"],
                retry_cnt=0,
                max_retries=10,
            )
    finally:
        # Always release the provider, even when the expectation fails.
        await provider.terminate()
@pytest.mark.asyncio
async def test_handle_api_error_content_moderated_detects_structured_body():
    """A pattern found only in the structured error body is still detected.

    The configured pattern ``content_moderated`` appears in the body's
    ``error.code`` and not in the exception message. The test expects a
    first return value of ``False`` and the image part removed from the
    user message.
    """
    provider = _make_provider(
        {"image_moderation_error_patterns": ["content_moderated"]}
    )
    try:
        payloads = {
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "hello"},
                        {
                            "type": "image_url",
                            "image_url": {"url": "data:image/jpeg;base64,abcd"},
                        },
                    ],
                }
            ]
        }
        context_query = payloads["messages"]
        err = _ErrorWithBody(
            "upstream error",
            {"error": {"code": "content_moderated", "message": "blocked"}},
        )

        success, *_rest = await provider._handle_api_error(
            err,
            payloads=payloads,
            context_query=context_query,
            func_tool=None,
            chosen_key="test-key",
            available_api_keys=["test-key"],
            retry_cnt=0,
            max_retries=10,
        )

        assert success is False
        assert payloads["messages"][0]["content"] == [{"type": "text", "text": "hello"}]
    finally:
        # Always release the provider, even when an assertion fails.
        await provider.terminate()
@pytest.mark.asyncio
async def test_handle_api_error_content_moderated_supports_custom_patterns():
    """A user-supplied moderation pattern triggers image removal.

    The provider is configured with the custom pattern
    ``blocked_by_policy_code_123``, which appears in the error text. The
    test expects a first return value of ``False`` and the image part
    removed from the user message.
    """
    provider = _make_provider(
        {"image_moderation_error_patterns": ["blocked_by_policy_code_123"]}
    )
    try:
        payloads = {
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "hello"},
                        {
                            "type": "image_url",
                            "image_url": {"url": "data:image/jpeg;base64,abcd"},
                        },
                    ],
                }
            ]
        }
        context_query = payloads["messages"]
        err = Exception("upstream: blocked_by_policy_code_123")

        success, *_rest = await provider._handle_api_error(
            err,
            payloads=payloads,
            context_query=context_query,
            func_tool=None,
            chosen_key="test-key",
            available_api_keys=["test-key"],
            retry_cnt=0,
            max_retries=10,
        )

        assert success is False
        assert payloads["messages"][0]["content"] == [{"type": "text", "text": "hello"}]
    finally:
        # Always release the provider, even when an assertion fails.
        await provider.terminate()
@pytest.mark.asyncio
async def test_handle_api_error_content_moderated_without_patterns_raises():
    """Without configured patterns, a moderation-style error is raised.

    The provider is built with no ``image_moderation_error_patterns``
    override, and the message does contain an image. The test expects an
    exception whose message matches ``"content-moderated"``.
    """
    provider = _make_provider()
    try:
        payloads = {
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "hello"},
                        {
                            "type": "image_url",
                            "image_url": {"url": "data:image/jpeg;base64,abcd"},
                        },
                    ],
                }
            ]
        }
        context_query = payloads["messages"]
        err = Exception("Content is moderated [WKE=file:content-moderated]")

        with pytest.raises(Exception, match="content-moderated"):
            await provider._handle_api_error(
                err,
                payloads=payloads,
                context_query=context_query,
                func_tool=None,
                chosen_key="test-key",
                available_api_keys=["test-key"],
                retry_cnt=0,
                max_retries=10,
            )
    finally:
        # Always release the provider, even when the expectation fails.
        await provider.terminate()
@pytest.mark.asyncio
async def test_handle_api_error_unknown_image_error_raises():
    """An unrecognized image-related error is raised rather than handled.

    No moderation patterns are configured and the error text is an
    arbitrary upload failure. The test expects an exception whose message
    matches ``"unknown provider image upload error"``.
    """
    provider = _make_provider()
    try:
        payloads = {
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "hello"},
                        {
                            "type": "image_url",
                            "image_url": {"url": "data:image/jpeg;base64,abcd"},
                        },
                    ],
                }
            ]
        }
        context_query = payloads["messages"]

        with pytest.raises(Exception, match="unknown provider image upload error"):
            await provider._handle_api_error(
                Exception("some unknown provider image upload error"),
                payloads=payloads,
                context_query=context_query,
                func_tool=None,
                chosen_key="test-key",
                available_api_keys=["test-key"],
                retry_cnt=0,
                max_retries=10,
            )
    finally:
        # Always release the provider, even when the expectation fails.
        await provider.terminate()