# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. """Tests for LLMClient abstraction, OpenAIClient, AnthropicClient, and helpers.""" import json from unittest.mock import AsyncMock, MagicMock, patch import pytest from openenv.core.llm_client import ( _clean_mcp_schema, _mcp_tools_to_anthropic, _mcp_tools_to_openai, _openai_msgs_to_anthropic, AnthropicClient, create_llm_client, LLMClient, LLMResponse, OpenAIClient, ToolCall, ) class TestLLMClientABC: """Test the abstract base class.""" def test_cannot_instantiate_directly(self): """LLMClient is abstract and cannot be instantiated.""" with pytest.raises(TypeError): LLMClient("http://localhost", 8000) def test_concrete_subclass(self): """A concrete subclass can be instantiated.""" class StubClient(LLMClient): async def complete(self, prompt: str, **kwargs) -> str: return "stub" client = StubClient("http://localhost", 8000) assert client.endpoint == "http://localhost" assert client.port == 8000 def test_base_url_property(self): """base_url combines endpoint and port.""" class StubClient(LLMClient): async def complete(self, prompt: str, **kwargs) -> str: return "stub" client = StubClient("http://localhost", 8000) assert client.base_url == "http://localhost:8000" def test_base_url_custom_endpoint(self): """base_url works with custom endpoints.""" class StubClient(LLMClient): async def complete(self, prompt: str, **kwargs) -> str: return "stub" client = StubClient("https://api.example.com", 443) assert client.base_url == "https://api.example.com:443" @pytest.mark.asyncio async def test_complete_with_tools_not_implemented(self): """Default complete_with_tools raises NotImplementedError.""" class StubClient(LLMClient): async def complete(self, prompt: str, **kwargs) -> str: return "stub" client = StubClient("http://localhost", 8000) with pytest.raises(NotImplementedError, match="StubClient"): await client.complete_with_tools([], []) class TestOpenAIClientConstruction: """Test OpenAIClient initialization.""" @patch("openenv.core.llm_client.AsyncOpenAI") def test_basic_construction(self, mock_openai_cls): """OpenAIClient stores params and creates AsyncOpenAI.""" client = OpenAIClient("http://localhost", 8000, model="gpt-4") assert client.endpoint == "http://localhost" assert client.port == 8000 assert client.model == "gpt-4" assert client.temperature == 0.0 assert client.max_tokens == 256 assert client.system_prompt is None mock_openai_cls.assert_called_once_with( base_url="http://localhost:8000/v1", api_key="not-needed", ) @patch("openenv.core.llm_client.AsyncOpenAI") def test_custom_api_key(self, mock_openai_cls): """API key is passed through to AsyncOpenAI.""" OpenAIClient("http://localhost", 8000, model="gpt-4", api_key="sk-test-123") mock_openai_cls.assert_called_once_with( base_url="http://localhost:8000/v1", api_key="sk-test-123", ) @patch("openenv.core.llm_client.AsyncOpenAI") def test_default_api_key_when_none(self, mock_openai_cls): """api_key=None defaults to 'not-needed'.""" OpenAIClient("http://localhost", 8000, model="gpt-4", api_key=None) mock_openai_cls.assert_called_once_with( base_url="http://localhost:8000/v1", api_key="not-needed", ) @patch("openenv.core.llm_client.AsyncOpenAI") def test_system_prompt_stored(self, mock_openai_cls): """System prompt is stored for use in complete().""" client = OpenAIClient( "http://localhost", 8000, model="gpt-4", system_prompt="You are a judge.", ) assert client.system_prompt == "You are a judge." @patch("openenv.core.llm_client.AsyncOpenAI") def test_custom_temperature_and_max_tokens(self, mock_openai_cls): """Custom temperature and max_tokens are stored.""" client = OpenAIClient( "http://localhost", 8000, model="gpt-4", temperature=0.7, max_tokens=512, ) assert client.temperature == 0.7 assert client.max_tokens == 512 class TestOpenAIClientComplete: """Test the complete() method.""" @pytest.mark.asyncio async def test_complete_without_system_prompt(self): """complete() sends user message only when no system prompt.""" mock_openai = MagicMock() mock_response = MagicMock() mock_response.choices = [MagicMock()] mock_response.choices[0].message.content = "42" mock_openai.chat.completions.create = AsyncMock(return_value=mock_response) with patch("openenv.core.llm_client.AsyncOpenAI", return_value=mock_openai): client = OpenAIClient("http://localhost", 8000, model="gpt-4") result = await client.complete("What is 2+2?") assert result == "42" mock_openai.chat.completions.create.assert_called_once_with( model="gpt-4", messages=[{"role": "user", "content": "What is 2+2?"}], temperature=0.0, max_tokens=256, ) @pytest.mark.asyncio async def test_complete_with_system_prompt(self): """complete() includes system message when system_prompt is set.""" mock_openai = MagicMock() mock_response = MagicMock() mock_response.choices = [MagicMock()] mock_response.choices[0].message.content = "0.8" mock_openai.chat.completions.create = AsyncMock(return_value=mock_response) with patch("openenv.core.llm_client.AsyncOpenAI", return_value=mock_openai): client = OpenAIClient( "http://localhost", 8000, model="gpt-4", system_prompt="You are a judge.", ) result = await client.complete("Rate this code.") assert result == "0.8" mock_openai.chat.completions.create.assert_called_once_with( model="gpt-4", messages=[ {"role": "system", "content": "You are a judge."}, {"role": "user", "content": "Rate this code."}, ], temperature=0.0, max_tokens=256, ) @pytest.mark.asyncio async def test_complete_kwargs_override(self): """Keyword arguments override default temperature and max_tokens.""" mock_openai = MagicMock() mock_response = MagicMock() mock_response.choices = [MagicMock()] mock_response.choices[0].message.content = "ok" mock_openai.chat.completions.create = AsyncMock(return_value=mock_response) with patch("openenv.core.llm_client.AsyncOpenAI", return_value=mock_openai): client = OpenAIClient("http://localhost", 8000, model="gpt-4") await client.complete("hi", temperature=0.9, max_tokens=100) mock_openai.chat.completions.create.assert_called_once_with( model="gpt-4", messages=[{"role": "user", "content": "hi"}], temperature=0.9, max_tokens=100, ) class TestOpenAIClientCompleteWithTools: """Test complete_with_tools() on OpenAIClient.""" @pytest.mark.asyncio async def test_no_tool_calls(self): """Response without tool calls returns empty tool_calls list.""" mock_openai = MagicMock() mock_msg = MagicMock() mock_msg.content = "Hello there" mock_msg.tool_calls = None mock_response = MagicMock() mock_response.choices = [MagicMock()] mock_response.choices[0].message = mock_msg mock_openai.chat.completions.create = AsyncMock(return_value=mock_response) with patch("openenv.core.llm_client.AsyncOpenAI", return_value=mock_openai): client = OpenAIClient("http://localhost", 8000, model="gpt-4") result = await client.complete_with_tools( [{"role": "user", "content": "hi"}], [] ) assert isinstance(result, LLMResponse) assert result.content == "Hello there" assert result.tool_calls == [] @pytest.mark.asyncio async def test_with_tool_calls(self): """Response with tool calls are parsed into ToolCall objects.""" mock_openai = MagicMock() mock_tc = MagicMock() mock_tc.id = "call_123" mock_tc.function.name = "get_weather" mock_tc.function.arguments = '{"city": "SF"}' mock_msg = MagicMock() mock_msg.content = "" mock_msg.tool_calls = [mock_tc] mock_response = MagicMock() mock_response.choices = [MagicMock()] mock_response.choices[0].message = mock_msg mock_openai.chat.completions.create = AsyncMock(return_value=mock_response) with patch("openenv.core.llm_client.AsyncOpenAI", return_value=mock_openai): client = OpenAIClient("http://localhost", 8000, model="gpt-4") result = await client.complete_with_tools( [{"role": "user", "content": "weather?"}], [ { "name": "get_weather", "description": "Get weather", "inputSchema": { "type": "object", "properties": {"city": {"type": "string"}}, }, } ], ) assert len(result.tool_calls) == 1 assert result.tool_calls[0].id == "call_123" assert result.tool_calls[0].name == "get_weather" assert result.tool_calls[0].args == {"city": "SF"} class TestAnthropicClientConstruction: """Test AnthropicClient initialization.""" def test_missing_anthropic_package(self): """Raises ImportError with helpful message when anthropic is missing.""" with patch.dict("sys.modules", {"anthropic": None}): with pytest.raises(ImportError, match="anthropic"): AnthropicClient( "https://api.anthropic.com", 443, model="claude-sonnet-4-20250514" ) @patch("openenv.core.llm_client.AnthropicClient.__init__", return_value=None) def test_is_llm_client_subclass(self, mock_init): """AnthropicClient is a proper LLMClient subclass.""" assert issubclass(AnthropicClient, LLMClient) class TestAnthropicClientComplete: """Test the complete() method on AnthropicClient.""" @pytest.mark.asyncio async def test_complete_basic(self): """complete() calls the Anthropic messages API and returns text.""" mock_anthropic = MagicMock() mock_text_block = MagicMock() mock_text_block.type = "text" mock_text_block.text = "4" mock_response = MagicMock() mock_response.content = [mock_text_block] mock_anthropic.messages.create = AsyncMock(return_value=mock_response) with patch( "openenv.core.llm_client.AnthropicClient.__init__", return_value=None ): client = AnthropicClient.__new__(AnthropicClient) client.endpoint = "https://api.anthropic.com" client.port = 443 client.model = "claude-sonnet-4-20250514" client.system_prompt = None client.temperature = 0.0 client.max_tokens = 256 client._client = mock_anthropic result = await client.complete("What is 2+2?") assert result == "4" mock_anthropic.messages.create.assert_called_once() class TestAnthropicClientCompleteWithTools: """Test complete_with_tools() on AnthropicClient.""" @pytest.mark.asyncio async def test_with_tool_use_response(self): """Tool use blocks are parsed into ToolCall objects.""" mock_anthropic = MagicMock() mock_text_block = MagicMock() mock_text_block.type = "text" mock_text_block.text = "Let me check" mock_tool_block = MagicMock() mock_tool_block.type = "tool_use" mock_tool_block.id = "toolu_abc" mock_tool_block.name = "get_weather" mock_tool_block.input = {"city": "SF"} mock_response = MagicMock() mock_response.content = [mock_text_block, mock_tool_block] mock_anthropic.messages.create = AsyncMock(return_value=mock_response) client = AnthropicClient.__new__(AnthropicClient) client.endpoint = "https://api.anthropic.com" client.port = 443 client.model = "claude-sonnet-4-20250514" client.system_prompt = None client.temperature = 0.0 client.max_tokens = 256 client._client = mock_anthropic result = await client.complete_with_tools( [{"role": "user", "content": "weather?"}], [ { "name": "get_weather", "description": "Get weather", "inputSchema": { "type": "object", "properties": {"city": {"type": "string"}}, }, } ], ) assert result.content == "Let me check" assert len(result.tool_calls) == 1 assert result.tool_calls[0].id == "toolu_abc" assert result.tool_calls[0].name == "get_weather" assert result.tool_calls[0].args == {"city": "SF"} # --------------------------------------------------------------------------- # LLMResponse / ToolCall # --------------------------------------------------------------------------- class TestLLMResponse: """Test LLMResponse dataclass.""" def test_to_message_dict_no_tools(self): """to_message_dict without tool calls is a plain assistant message.""" resp = LLMResponse(content="hello") msg = resp.to_message_dict() assert msg == {"role": "assistant", "content": "hello"} assert "tool_calls" not in msg def test_to_message_dict_with_tools(self): """to_message_dict includes tool_calls in OpenAI format.""" resp = LLMResponse( content="", tool_calls=[ToolCall(id="c1", name="foo", args={"x": 1})], ) msg = resp.to_message_dict() assert msg["role"] == "assistant" assert len(msg["tool_calls"]) == 1 tc = msg["tool_calls"][0] assert tc["id"] == "c1" assert tc["type"] == "function" assert tc["function"]["name"] == "foo" assert json.loads(tc["function"]["arguments"]) == {"x": 1} # --------------------------------------------------------------------------- # Factory # --------------------------------------------------------------------------- class TestCreateLLMClient: """Test the create_llm_client factory.""" @patch("openenv.core.llm_client.AsyncOpenAI") def test_openai_provider(self, mock_openai_cls): """'openai' creates an OpenAIClient.""" client = create_llm_client("openai", "gpt-4", "sk-key") assert isinstance(client, OpenAIClient) assert client.model == "gpt-4" def test_anthropic_provider(self): """'anthropic' creates an AnthropicClient.""" mock_async_anthropic = MagicMock() mock_module = MagicMock() mock_module.AsyncAnthropic = MagicMock(return_value=mock_async_anthropic) with patch.dict("sys.modules", {"anthropic": mock_module}): client = create_llm_client( "anthropic", "claude-sonnet-4-20250514", "sk-ant" ) assert isinstance(client, AnthropicClient) assert client.model == "claude-sonnet-4-20250514" def test_unsupported_provider(self): """Unsupported provider raises ValueError.""" with pytest.raises(ValueError, match="google"): create_llm_client("google", "gemini-pro", "key") @patch("openenv.core.llm_client.AsyncOpenAI") def test_case_insensitive(self, mock_openai_cls): """Provider name is case-insensitive.""" client = create_llm_client("OpenAI", "gpt-4", "sk-key") assert isinstance(client, OpenAIClient) @patch("openenv.core.llm_client.AsyncOpenAI") def test_custom_params(self, mock_openai_cls): """Temperature and max_tokens are forwarded.""" client = create_llm_client( "openai", "gpt-4", "sk-key", temperature=0.5, max_tokens=1024 ) assert client.temperature == 0.5 assert client.max_tokens == 1024 @patch("openenv.core.llm_client.AsyncOpenAI") def test_system_prompt_forwarded(self, mock_openai_cls): """system_prompt is forwarded to the client.""" client = create_llm_client( "openai", "gpt-4", "sk-key", system_prompt="You are a judge." ) assert client.system_prompt == "You are a judge." # --------------------------------------------------------------------------- # MCP schema helpers # --------------------------------------------------------------------------- class TestCleanMCPSchema: """Test _clean_mcp_schema helper.""" def test_non_dict_returns_empty(self): assert _clean_mcp_schema("not a dict") == { "type": "object", "properties": {}, "required": [], } def test_passthrough_simple_object(self): schema = {"type": "object", "properties": {"x": {"type": "string"}}} result = _clean_mcp_schema(schema) assert result["properties"]["x"]["type"] == "string" def test_oneOf_selects_object(self): schema = { "oneOf": [ {"type": "string"}, {"type": "object", "properties": {"a": {"type": "int"}}}, ] } result = _clean_mcp_schema(schema) assert "a" in result["properties"] def test_allOf_merges(self): schema = { "allOf": [ {"properties": {"a": {"type": "string"}}, "required": ["a"]}, {"properties": {"b": {"type": "int"}}, "required": ["b"]}, ] } result = _clean_mcp_schema(schema) assert "a" in result["properties"] assert "b" in result["properties"] assert result["required"] == ["a", "b"] def test_anyOf_selects_object(self): schema = { "anyOf": [ {"type": "null"}, {"type": "object", "properties": {"x": {"type": "string"}}}, ] } result = _clean_mcp_schema(schema) assert "x" in result["properties"] def test_sets_default_type(self): result = _clean_mcp_schema({"properties": {"a": {"type": "string"}}}) assert result["type"] == "object" def test_does_not_mutate_input(self): """_clean_mcp_schema must not modify the caller's dict.""" original = {"type": "object"} _clean_mcp_schema(original) # Should not have added "properties" to the original dict. assert "properties" not in original class TestMCPToolsToOpenAI: """Test _mcp_tools_to_openai conversion.""" def test_basic_conversion(self): mcp_tools = [ { "name": "get_weather", "description": "Get weather", "inputSchema": { "type": "object", "properties": {"city": {"type": "string"}}, }, } ] result = _mcp_tools_to_openai(mcp_tools) assert len(result) == 1 assert result[0]["type"] == "function" assert result[0]["function"]["name"] == "get_weather" assert "city" in result[0]["function"]["parameters"]["properties"] def test_empty_list(self): assert _mcp_tools_to_openai([]) == [] class TestMCPToolsToAnthropic: """Test _mcp_tools_to_anthropic conversion.""" def test_basic_conversion(self): mcp_tools = [ { "name": "get_weather", "description": "Get weather", "inputSchema": { "type": "object", "properties": {"city": {"type": "string"}}, }, } ] result = _mcp_tools_to_anthropic(mcp_tools) assert len(result) == 1 assert result[0]["name"] == "get_weather" assert "input_schema" in result[0] def test_empty_list(self): assert _mcp_tools_to_anthropic([]) == [] # --------------------------------------------------------------------------- # Message format conversion # --------------------------------------------------------------------------- class TestOpenAIMsgsToAnthropic: """Test _openai_msgs_to_anthropic conversion.""" def test_system_extracted(self): msgs = [ {"role": "system", "content": "You are helpful."}, {"role": "user", "content": "Hi"}, ] system, result = _openai_msgs_to_anthropic(msgs) assert system == "You are helpful." assert len(result) == 1 assert result[0]["role"] == "user" def test_tool_calls_converted(self): msgs = [ {"role": "user", "content": "weather?"}, { "role": "assistant", "content": "Let me check", "tool_calls": [ { "id": "c1", "type": "function", "function": { "name": "get_weather", "arguments": '{"city": "SF"}', }, } ], }, ] system, result = _openai_msgs_to_anthropic(msgs) assert system == "" assert len(result) == 2 assistant_msg = result[1] assert assistant_msg["role"] == "assistant" assert isinstance(assistant_msg["content"], list) assert assistant_msg["content"][0]["type"] == "text" assert assistant_msg["content"][1]["type"] == "tool_use" assert assistant_msg["content"][1]["name"] == "get_weather" assert assistant_msg["content"][1]["input"] == {"city": "SF"} def test_tool_result_becomes_user_turn(self): msgs = [ {"role": "user", "content": "hi"}, { "role": "assistant", "content": "", "tool_calls": [ { "id": "c1", "type": "function", "function": { "name": "foo", "arguments": "{}", }, } ], }, {"role": "tool", "tool_call_id": "c1", "content": '{"result": 42}'}, ] _, result = _openai_msgs_to_anthropic(msgs) # tool result should be a user message with tool_result content block tool_turn = result[2] assert tool_turn["role"] == "user" assert isinstance(tool_turn["content"], list) assert tool_turn["content"][0]["type"] == "tool_result" assert tool_turn["content"][0]["tool_use_id"] == "c1" def test_multiple_system_messages_concatenated(self): msgs = [ {"role": "system", "content": "Rule 1."}, {"role": "system", "content": "Rule 2."}, {"role": "user", "content": "Hi"}, ] system, _ = _openai_msgs_to_anthropic(msgs) assert system == "Rule 1.\n\nRule 2."