| import unittest |
| from abc import ABC |
| from unittest import mock |
|
|
| from src.agent import BaseAgent |
| from src.exceptions.exceptions import InputErrorException |
| from src.llm import BaseLLM |
| from src.prompt import PromptTemplate |
| from src.schemas import AgentType, AgentOutput |
| from src.tools import BaseTool |
|
|
|
|
class SampleBaseAgent(BaseAgent):
    """Minimal concrete ``BaseAgent`` so the base-class behaviour can be tested.

    The only abstract piece it fills in is ``run``, implemented as a stub.
    """

    def run(self, *args, **kwargs) -> AgentOutput:
        """Stub implementation: accepts any arguments and returns ``None``."""
        return None
|
|
|
|
class SampleBaseTool(BaseTool, ABC):
    """No-op ``BaseTool`` used as a plugin fixture for the agent tests.

    Both the synchronous and asynchronous entry points ignore their request
    and produce ``None``.
    """

    def run(self, req):
        """Synchronous entry point; ignores *req* and returns ``None``."""
        return None

    async def async_run(self, req):
        """Asynchronous entry point; ignores *req* and resolves to ``None``."""
        return None
|
|
|
|
class TestBaseAgent(unittest.TestCase):
    """Unit tests for the behaviour ``BaseAgent`` provides to subclasses.

    ``SampleBaseAgent`` stands in as the concrete agent. Its LLM and prompt
    template collaborators are autospecced mocks, and a ``SampleBaseTool`` is
    pre-registered under the plugin name ``'test_tool'``.
    """

    def setUp(self):
        # The agent receives LLM / prompt-template *instances*, so mock
        # instances (instance=True) rather than the classes themselves.
        self.mock_llm = mock.create_autospec(BaseLLM, instance=True)
        self.mock_prompt_template = mock.create_autospec(PromptTemplate, instance=True)
        self.agent = SampleBaseAgent(
            name='TestAgent',
            type=AgentType.react,
            version='1.0',
            description='Test Description',
            prompt_template=self.mock_prompt_template,
        )
        self.tool = SampleBaseTool("test_tool", "test_tool")
        self.agent.add_plugin('test_tool', self.tool)

    def test_properties(self):
        """Constructor arguments are exposed through the agent's properties."""
        self.assertEqual(self.agent.name, 'TestAgent')
        self.assertEqual(self.agent.type, AgentType.react)
        self.assertEqual(self.agent.version, '1.0')
        self.assertEqual(self.agent.description, 'Test Description')
        self.assertIs(self.agent.prompt_template, self.mock_prompt_template)

    def test_llm_setter_happy_path(self):
        """Assigning a ``BaseLLM`` (mock) stores that exact object."""
        self.agent.llm = self.mock_llm
        self.assertIs(self.agent.llm, self.mock_llm)

    def test_llm_setter_input_error(self):
        """Assigning a plain string as the LLM raises InputErrorException."""
        with self.assertRaises(InputErrorException):
            self.agent.llm = 'invalid_llm'

    def test_add_plugin_happy_path(self):
        """add_plugin registers a new tool under the given name."""
        # Use a name setUp has NOT registered: re-adding 'test_tool' would let
        # the membership check pass even if add_plugin did nothing.
        new_tool = SampleBaseTool("another_tool", "another_tool")
        self.assertNotIn('another_tool', self.agent.plugins_map)

        self.agent.add_plugin('another_tool', new_tool)

        self.assertIn('another_tool', self.agent.plugins_map)
        # Registering a second plugin must leave the first one in place.
        self.assertIn('test_tool', self.agent.plugins_map)
        # The registered plugin is the tool instance we passed in.
        self.assertEqual(self.agent.get_plugin_tool_function()['another_tool'], new_tool.run)

    def test_add_plugin_input_error(self):
        """An empty name with a ``None`` tool is rejected."""
        with self.assertRaises(InputErrorException):
            self.agent.add_plugin('', None)

    def test_get_prompt_template_dict(self):
        """Values of a dict input are produced by _parse_prompt_template."""
        # Use a sentinel distinct from self.mock_prompt_template so the result
        # can only match if it really came from the (patched) parser, not from
        # the agent's own prompt_template.
        parsed = mock.sentinel.parsed_template
        with mock.patch.object(BaseAgent, '_parse_prompt_template', return_value=parsed) as parse_mock:
            result = self.agent._get_prompt_template({'test_key': 'dict'})
        parse_mock.assert_called()
        self.assertEqual(result, {'test_key': parsed})

    def test_get_prompt_template_instance(self):
        """A ready-made PromptTemplate instance is returned as-is."""
        prompt_instance = PromptTemplate(input_variables=["foo"], template="Say {foo}")
        result = self.agent._get_prompt_template(prompt_instance)
        self.assertEqual(result, prompt_instance)

    def test_clear(self):
        """Smoke test: clear() runs without raising on a configured agent."""
        # NOTE(review): no post-condition is asserted because the effect of
        # clear() is not visible from this file; add one when it is pinned down.
        self.agent.clear()

    def test_get_plugin_tool_function(self):
        """Each plugin name maps to its tool's synchronous ``run`` method."""
        function_map = self.agent.get_plugin_tool_function()
        self.assertIn('test_tool', function_map)
        self.assertEqual(function_map['test_tool'], self.tool.run)

    def test_get_plugin_tool_async_function(self):
        """Each plugin name maps to its tool's ``async_run`` method."""
        function_map = self.agent.get_plugin_tool_async_function()
        self.assertIn('test_tool', function_map)
        self.assertEqual(function_map['test_tool'], self.tool.async_run)
|
|
|
|
if __name__ == '__main__':
    # Support running this module directly as a script; unittest.main()
    # collects and runs the TestCase classes defined in this module.
    unittest.main()
|
|