# Spaces:
# Running
# Running
"""
tests/test_phase4_reflection.py
────────────────────────────────
Unit tests for Phase 4: tools, failure categoriser, trajectory logger,
and the reflection agent loop (mocked LLM, no real API calls).
Run with: pytest tests/test_phase4_reflection.py -v
"""
from __future__ import annotations

import json
import textwrap
from pathlib import Path
from unittest.mock import MagicMock, patch

import pytest
# ── AgentTools ────────────────────────────────────────────────────────────────
class TestAgentTools:
    """Exercise the ``AgentTools`` file helpers and ``ToolResult`` rendering.

    ``agent.tools`` is imported inside each test instead of at module level,
    so the import is only attempted when that particular test runs.
    """

    def test_read_file_success(self, tmp_path):
        """Reading an existing file succeeds and the output holds its text."""
        from agent.tools import AgentTools

        (tmp_path / "foo.py").write_text("x = 1\ny = 2\n")
        tools = AgentTools(tmp_path)

        result = tools.read_file("foo.py")

        assert result.success
        assert "x = 1" in result.output

    def test_read_file_not_found(self, tmp_path):
        """Reading a missing file fails with a 'not found' error message."""
        from agent.tools import AgentTools

        tools = AgentTools(tmp_path)

        result = tools.read_file("nonexistent.py")

        assert not result.success
        assert "not found" in result.error.lower()

    def test_read_file_path_traversal_rejected(self, tmp_path):
        """A ``..``-relative path is refused with a 'traversal' error."""
        from agent.tools import AgentTools

        tools = AgentTools(tmp_path)

        result = tools.read_file("../../etc/passwd")

        assert not result.success
        assert "traversal" in result.error.lower()

    def test_read_file_truncation(self, tmp_path):
        """A 300-line file read with ``max_lines=10`` is marked as truncated."""
        from agent.tools import AgentTools

        content = "\n".join(f"line {i}" for i in range(300))
        (tmp_path / "big.py").write_text(content)
        tools = AgentTools(tmp_path)

        result = tools.read_file("big.py", max_lines=10)

        assert result.success
        assert "truncated" in result.output

    def test_write_patch_success(self, tmp_path):
        """A unified diff is accepted and ``_agent_patch.diff`` is created."""
        from agent.tools import AgentTools

        tools = AgentTools(tmp_path)
        diff = "--- a/foo.py\n+++ b/foo.py\n@@ -1 +1 @@\n-old\n+new\n"

        result = tools.write_patch(diff)

        assert result.success
        assert (tmp_path / "_agent_patch.diff").exists()

    def test_write_patch_empty_rejected(self, tmp_path):
        """An empty patch string is rejected with an 'empty' error message."""
        from agent.tools import AgentTools

        tools = AgentTools(tmp_path)

        result = tools.write_patch("")

        assert not result.success
        # Compared case-insensitively, like the other error-message checks in
        # this class, so a change in capitalisation does not break the test.
        assert "empty" in result.error.lower()

    def test_write_patch_invalid_format_rejected(self, tmp_path):
        """Text without diff headers is not accepted as a patch."""
        from agent.tools import AgentTools

        tools = AgentTools(tmp_path)

        result = tools.write_patch("just some text without diff header")

        assert not result.success

    def test_list_files(self, tmp_path):
        """Listing ``**/*.py`` reports both source files and no ``__pycache__``."""
        from agent.tools import AgentTools

        (tmp_path / "a.py").write_text("x=1")
        (tmp_path / "b.py").write_text("y=2")
        # Created empty: the assertion below only checks that the directory
        # name never shows up in the listing.
        (tmp_path / "__pycache__").mkdir()
        tools = AgentTools(tmp_path)

        result = tools.list_files("**/*.py")

        assert result.success
        assert "a.py" in result.output
        assert "b.py" in result.output
        assert "__pycache__" not in result.output

    def test_tool_result_to_prompt_str(self):
        """A successful result renders its tool name, SUCCESS and its output."""
        from agent.tools import ToolResult

        tr = ToolResult("read_file", True, "x = 1\n")

        prompt = tr.to_prompt_str()

        assert "read_file" in prompt
        assert "SUCCESS" in prompt
        assert "x = 1" in prompt

    def test_tool_result_error_in_prompt(self):
        """A failed result renders ERROR together with its error text."""
        from agent.tools import ToolResult

        tr = ToolResult("run_tests", False, "", "Timeout after 60s")

        prompt = tr.to_prompt_str()

        assert "ERROR" in prompt
        assert "Timeout" in prompt
# ── Failure Categoriser ───────────────────────────────────────────────────────
class TestFailureCategoriser:
    """Check the category string ``categorise_failure`` returns per input."""

    def _categorise(self, stdout, apply_ok=True, ftp=None, ptp=None, attempt=1, prev=None):
        """Call ``categorise_failure`` using short keyword names.

        Args:
            stdout: Passed as ``test_stdout``.
            apply_ok: Passed as ``patch_apply_success``.
            ftp: Passed as ``fail_to_pass_results``; a falsy value becomes ``{}``.
            ptp: Passed as ``pass_to_pass_results``; a falsy value becomes ``{}``.
            attempt: Passed as ``attempt_num``.
            prev: Passed unchanged as ``previous_categories``.

        Returns:
            Whatever ``categorise_failure`` returns (compared against category
            strings by the tests below).
        """
        from agent.failure_categoriser import categorise_failure

        return categorise_failure(
            test_stdout=stdout,
            patch_apply_success=apply_ok,
            fail_to_pass_results=ftp or {},
            pass_to_pass_results=ptp or {},
            attempt_num=attempt,
            previous_categories=prev,
        )

    def test_success(self):
        """All fail-to-pass and pass-to-pass tests passing yields 'success'."""
        cat = self._categorise(
            "1 passed",
            apply_ok=True,
            ftp={"t::test_x": True},
            ptp={"t::test_y": True},
        )
        assert cat == "success"

    def test_patch_apply_failure_is_syntax_error(self):
        """A patch that failed to apply is reported as 'syntax_error'."""
        cat = self._categorise("", apply_ok=False)
        assert cat == "syntax_error"

    def test_syntax_error_in_output(self):
        """A SyntaxError line in the test output yields 'syntax_error'."""
        cat = self._categorise("SyntaxError: invalid syntax (foo.py, line 5)")
        assert cat == "syntax_error"

    def test_import_error(self):
        """A ModuleNotFoundError line yields 'import_error'."""
        cat = self._categorise("ModuleNotFoundError: No module named 'nonexistent'")
        assert cat == "import_error"

    def test_hallucinated_api_attribute_error(self):
        """An AttributeError on a missing attribute yields 'hallucinated_api'."""
        cat = self._categorise("AttributeError: 'QuerySet' object has no attribute 'bulk_filer'")
        assert cat == "hallucinated_api"

    def test_hallucinated_api_name_error(self):
        """A NameError on an undefined name yields 'hallucinated_api'."""
        cat = self._categorise("NameError: name 'nonexistent_func' is not defined")
        assert cat == "hallucinated_api"

    def test_type_error(self):
        """A TypeError line yields 'type_error'."""
        cat = self._categorise("TypeError: unsupported operand type(s) for +")
        assert cat == "type_error"

    def test_assertion_error(self):
        """An AssertionError line yields 'assertion_error'."""
        cat = self._categorise("AssertionError: expected True but got False")
        assert cat == "assertion_error"

    def test_incomplete_patch(self):
        """Mixed fail-to-pass results yield 'incomplete_patch'."""
        cat = self._categorise(
            "2 failed",
            apply_ok=True,
            ftp={"t::a": True, "t::b": False},  # one passed, one failed
            ptp={},
        )
        assert cat == "incomplete_patch"

    def test_unknown_fallback(self):
        """Output matching no known pattern falls back to 'unknown'."""
        cat = self._categorise("some unexpected output with no pattern")
        assert cat == "unknown"

    def test_extract_first_error_context(self):
        """The extracted context mentions the failing test or its error."""
        from agent.failure_categoriser import extract_first_error_context

        output = textwrap.dedent("""
            tests/test_foo.py::test_bar FAILED
            AssertionError: expected 1, got 2
            tests/test_foo.py::test_baz PASSED
        """)

        context = extract_first_error_context(output)

        assert "FAILED" in context or "AssertionError" in context
# ── Trajectory Logger ─────────────────────────────────────────────────────────
class TestTrajectoryLogger:
    """Cover ``TrajectoryLogger`` persistence, stats and fine-tuning export."""

    @staticmethod
    def _make_entry(*, instance_id, resolved, failure_category, **overrides):
        """Build a ``TrajectoryEntry`` from shared defaults plus overrides.

        Args:
            instance_id: Value for the entry's ``instance_id``.
            resolved: Value for the entry's ``resolved`` flag.
            failure_category: Value for the entry's ``failure_category``.
            **overrides: Any further ``TrajectoryEntry`` keyword arguments;
                these replace the defaults below or add optional fields
                such as ``problem_statement``.

        Returns:
            The constructed ``TrajectoryEntry``.
        """
        from agent.trajectory_logger import TrajectoryEntry

        fields = {
            "instance_id": instance_id,
            "repo": "r",
            "attempt": 1,
            "patch": "",
            "test_stdout": "",
            # Fresh dicts on every call so entries never share mutable state.
            "fail_to_pass_results": {},
            "pass_to_pass_results": {},
            "resolved": resolved,
            "failure_category": failure_category,
            "elapsed_seconds": 1.0,
        }
        fields.update(overrides)
        return TrajectoryEntry(**fields)

    def test_log_and_load(self, tmp_path):
        """One logged entry is read back with the same id and category."""
        from agent.trajectory_logger import TrajectoryLogger

        logger = TrajectoryLogger(tmp_path / "traj.jsonl")
        entry = self._make_entry(
            instance_id="test__repo-1",
            repo="test/repo",
            patch="--- a/foo.py\n+++ b/foo.py\n",
            test_stdout="1 failed",
            fail_to_pass_results={"t::test_x": False},
            resolved=False,
            failure_category="assertion_error",
            elapsed_seconds=5.2,
        )

        logger.log(entry)
        loaded = logger.load_all()

        assert len(loaded) == 1
        assert loaded[0].instance_id == "test__repo-1"
        assert loaded[0].failure_category == "assertion_error"

    def test_multiple_entries(self, tmp_path):
        """Five logged entries are counted and all five are loaded back."""
        from agent.trajectory_logger import TrajectoryLogger

        logger = TrajectoryLogger(tmp_path / "traj.jsonl")
        for i in range(5):
            is_even = i % 2 == 0
            logger.log(
                self._make_entry(
                    instance_id=f"inst-{i}",
                    repo="test/repo",
                    resolved=is_even,
                    failure_category="success" if is_even else "wrong_file_edit",
                )
            )

        assert logger.total_logged == 5
        loaded = logger.load_all()
        assert len(loaded) == 5

    def test_stats(self, tmp_path):
        """Two resolved entries out of four give a resolved rate of 0.5."""
        from agent.trajectory_logger import TrajectoryLogger

        logger = TrajectoryLogger(tmp_path / "traj.jsonl")
        for i in range(4):
            is_resolved = i < 2
            logger.log(
                self._make_entry(
                    instance_id=f"inst-{i}",
                    resolved=is_resolved,
                    failure_category="success" if is_resolved else "assertion_error",
                )
            )

        stats = logger.stats()

        assert stats["total"] == 4
        assert stats["resolved"] == 2
        assert stats["resolved_rate"] == pytest.approx(0.5, abs=1e-6)

    def test_export_for_finetuning(self, tmp_path):
        """Export writes one JSON record with system/user/assistant keys."""
        from agent.trajectory_logger import TrajectoryLogger

        logger = TrajectoryLogger(tmp_path / "traj.jsonl")
        logger.log(
            self._make_entry(
                instance_id="inst-1",
                patch="--- a/f.py\n+++ b/f.py\n@@ -1 +1 @@\n-bug\n+fix\n",
                resolved=True,
                failure_category="success",
                problem_statement="Fix the null pointer bug",
            )
        )
        out_path = tmp_path / "ft_data.jsonl"

        count = logger.export_for_finetuning(out_path)

        assert count == 1
        line = json.loads(out_path.read_text().strip())
        assert "system" in line
        assert "user" in line
        assert "assistant" in line

    def test_filter_by_category(self, tmp_path):
        """Only entries whose category is in ``filter_categories`` are exported."""
        from agent.trajectory_logger import TrajectoryLogger

        logger = TrajectoryLogger(tmp_path / "traj.jsonl")
        for cat in ["success", "assertion_error", "syntax_error", "unknown"]:
            logger.log(
                self._make_entry(
                    instance_id=cat,
                    patch="--- a/f.py\n+++ b/f.py\n",
                    resolved=(cat == "success"),
                    failure_category=cat,
                    problem_statement="test issue",
                )
            )
        out = tmp_path / "filtered.jsonl"

        count = logger.export_for_finetuning(
            out, filter_categories=["assertion_error", "syntax_error"]
        )

        assert count == 2

    def test_instruction_pair_format(self, tmp_path):
        """``to_instruction_pair`` exposes the issue, category, patch and attempt."""
        entry = self._make_entry(
            instance_id="test-1",
            attempt=2,
            patch="--- a/f.py\n+++ b/f.py\n@@ -1 +1 @@\n-x\n+y\n",
            test_stdout="AssertionError: expected 1, got 2",
            fail_to_pass_results={"t::test_x": False},
            resolved=False,
            failure_category="assertion_error",
            elapsed_seconds=3.0,
            problem_statement="Fix the assertion in the filter method",
            localised_files=["models/query.py"],
        )

        pair = entry.to_instruction_pair()

        assert "Fix the assertion" in pair["user"]
        assert "assertion_error" in pair["user"]
        assert pair["assistant"] == entry.patch
        assert pair["metadata"]["attempt"] == 2
# ── Reflection Agent (mocked LLM) ─────────────────────────────────────────────
class TestReflectionAgent:
    """Tests for the agent loop — LLM calls are mocked."""

    def _make_agent(self, tmp_path, trajectory_logger=None):
        """Return a ``ReflectionAgent`` with no sandbox or localisation pipeline.

        Args:
            tmp_path: Accepted for call-site symmetry; not used by this helper.
            trajectory_logger: Forwarded to the agent's ``trajectory_logger``.

        Returns:
            A ``ReflectionAgent`` built with ``model="gpt-4o"`` and
            ``max_attempts=3``.
        """
        from agent.reflection_agent import ReflectionAgent

        return ReflectionAgent(
            model="gpt-4o",
            max_attempts=3,
            sandbox=None,
            localisation_pipeline=None,
            trajectory_logger=trajectory_logger,
        )

    def _mock_llm_patch(self, monkeypatch, patch_text: str, tokens: int = 100):
        """Mock _call_llm to return a fixed patch without API calls.

        The replacement ignores its arguments and returns ``patch_text``
        paired with a usage dict whose ``total_tokens`` is ``tokens``.
        """
        import agent.reflection_agent as ra

        usage = {
            "total_tokens": tokens,
            "prompt_tokens": 80,
            "completion_tokens": 20,
        }
        monkeypatch.setattr(ra, "_call_llm", lambda *args, **kwargs: (patch_text, usage))

    @staticmethod
    def _make_state(workspace_dir, **overrides):
        """Build an ``AgentState`` from minimal placeholders plus overrides.

        Args:
            workspace_dir: Value for the state's ``workspace_dir``.
            **overrides: Any further ``AgentState`` keyword arguments; these
                replace the placeholder values below.

        Returns:
            The constructed ``AgentState``.
        """
        from agent.reflection_agent import AgentState

        fields = {
            "instance_id": "t",
            "repo": "r",
            "problem_statement": "p",
            "base_commit": "a",
            # Fresh lists on every call so states never share mutable state.
            "fail_to_pass": [],
            "pass_to_pass": [],
            "workspace_dir": workspace_dir,
        }
        fields.update(overrides)
        return AgentState(**fields)

    def test_agent_state_initialisation(self, tmp_path):
        """A new state starts at attempt 0, unresolved, with zero tokens."""
        from agent.reflection_agent import AgentState

        state = AgentState(
            instance_id="test-1",
            repo="test/repo",
            problem_statement="Fix bug",
            base_commit="abc123",
            fail_to_pass=["tests::test_x"],
            pass_to_pass=[],
            workspace_dir=tmp_path,
        )

        assert state.current_attempt == 0
        assert state.resolved is False
        assert state.total_tokens == 0

    def test_should_retry_when_not_resolved(self, tmp_path):
        """Unresolved with attempts remaining returns 'retry'."""
        from agent.reflection_agent import should_retry

        state = self._make_state(tmp_path, resolved=False, current_attempt=1)

        assert should_retry(state, max_attempts=3) == "retry"

    def test_should_done_when_resolved(self, tmp_path):
        """A resolved state returns 'done'."""
        from agent.reflection_agent import should_retry

        state = self._make_state(tmp_path, resolved=True, current_attempt=1)

        assert should_retry(state, max_attempts=3) == "done"

    def test_should_done_when_max_attempts_reached(self, tmp_path):
        """Unresolved at the attempt limit returns 'done'."""
        from agent.reflection_agent import should_retry

        state = self._make_state(tmp_path, resolved=False, current_attempt=3)

        assert should_retry(state, max_attempts=3) == "done"

    def test_node_generate_patch_increments_attempt(self, tmp_path, monkeypatch):
        """Generating a patch bumps the attempt counter and stores the patch."""
        from agent.reflection_agent import node_generate_patch

        self._mock_llm_patch(monkeypatch, "--- a/foo.py\n+++ b/foo.py\n@@ -1 +1 @@\n-x\n+y\n")
        state = self._make_state(
            tmp_path,
            problem_statement="fix the bug please",
            base_commit="abc",
        )

        state = node_generate_patch(state)

        assert state.current_attempt == 1
        assert "--- a/foo.py" in state.last_patch

    def test_node_generate_patch_uses_reflection_on_retry(self, tmp_path, monkeypatch):
        """On a retry the prompt sent to the LLM mentions the previous attempt."""
        import agent.reflection_agent as ra
        from agent.reflection_agent import node_generate_patch

        prompts_seen = []

        def mock_call_llm(user_prompt, *args, **kwargs):
            # Record every prompt so the test can inspect the last one sent.
            prompts_seen.append(user_prompt)
            usage = {"total_tokens": 50, "prompt_tokens": 40, "completion_tokens": 10}
            return ("--- a/f.py\n+++ b/f.py\n", usage)

        monkeypatch.setattr(ra, "_call_llm", mock_call_llm)
        state = self._make_state(
            tmp_path,
            problem_statement="fix the long detailed issue description here",
            base_commit="abc",
            current_attempt=1,  # simulate already one attempt
            last_test_stdout="AssertionError: expected 1",
            last_failure_category="assertion_error",
            last_patch="--- a/wrong.py\n+++ b/wrong.py\n",
            attempts=[{"attempt_num": 1}],
        )

        node_generate_patch(state)

        # Fail with a readable message, not an IndexError, if nothing was sent.
        assert prompts_seen, "node_generate_patch never called _call_llm"
        # Should use reflection prompt (contains "Previous Attempt")
        assert "Previous Attempt" in prompts_seen[-1]

    def test_agent_logs_trajectories(self, tmp_path, monkeypatch):
        """A run that resolves on its first attempt logs at least one entry."""
        import agent.reflection_agent as ra
        from agent.trajectory_logger import TrajectoryLogger

        traj_path = tmp_path / "traj.jsonl"
        traj_logger = TrajectoryLogger(traj_path)

        # Mock node_apply_and_test to mark as resolved immediately
        def mock_apply(state, sandbox=None):
            state.resolved = True
            state.last_test_stdout = "1 passed"
            state.last_failure_category = "success"
            state.attempts.append({
                "attempt_num": state.current_attempt,
                "patch": state.last_patch,
                "test_stdout": "1 passed",
                "fail_to_pass_results": {},
                "pass_to_pass_results": {},
                "resolved": True,
                "failure_category": "success",
            })
            return state

        llm_usage = {"total_tokens": 10, "prompt_tokens": 8, "completion_tokens": 2}
        monkeypatch.setattr(ra, "node_apply_and_test", mock_apply)
        monkeypatch.setattr(
            ra,
            "_call_llm",
            lambda *a, **kw: ("--- a/f.py\n+++ b/f.py\n", llm_usage),
        )
        agent = self._make_agent(tmp_path, trajectory_logger=traj_logger)

        state = agent.run(
            instance_id="test-1",
            repo="test/repo",
            problem_statement="fix the bug",
            base_commit="abc",
            fail_to_pass=[],
            pass_to_pass=[],
            workspace_dir=tmp_path,
        )

        assert state.resolved
        assert traj_logger.total_logged >= 1

    def test_strip_code_fences(self):
        """Markdown code fences are removed while the diff text is kept."""
        from agent.reflection_agent import _strip_code_fences

        raw = "```diff\n--- a/f.py\n+++ b/f.py\n```"

        cleaned = _strip_code_fences(raw)

        assert "```" not in cleaned
        assert "--- a/f.py" in cleaned

    def test_build_file_context(self):
        """The built context names every file and includes file contents."""
        from agent.reflection_agent import _build_file_context

        contents = {
            "a.py": "def foo(): pass",
            "b.py": "class Bar: pass",
        }

        ctx = _build_file_context(contents)

        assert "a.py" in ctx
        assert "b.py" in ctx
        assert "def foo" in ctx