File size: 5,435 Bytes
dc71cad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
"""
agent/failure_categoriser.py
──────────────────────────────
Rule-based + regex failure categoriser.

After each failed attempt, the agent parses pytest output and classifies
the failure into one of these categories:

  syntax_error        — the patch introduced a SyntaxError
  hallucinated_api    — agent called a function/attribute that doesn't exist
  wrong_file_edit     — agent edited the wrong file (tests in different module fail)
  incomplete_patch    — partial fix: some tests pass but not all FAIL_TO_PASS
  flaky_test          — test is non-deterministic (passes on retry)
  import_error        — missing import or circular import introduced
  type_error          — wrong argument type passed
  assertion_error     — logic bug remains, assertion fails with unexpected value
  unknown             — can't categorise

An attempt in which every FAIL_TO_PASS test passes and no PASS_TO_PASS test
regresses is labelled "success" instead of a failure category.

The category is logged to MLflow and stored in trajectory JSONL.
This taxonomy directly drives which trajectories we select for fine-tuning
(Phase 7 filters on known-category failures).
"""
from __future__ import annotations

import re
from typing import Literal

# Closed set of labels returned by categorise_failure(). Apart from the failure
# modes, "success" marks an attempt whose FAIL_TO_PASS tests all pass with no
# PASS_TO_PASS regression, and "unknown" is the fallback when no rule fires.
FailureCategory = Literal[
    "syntax_error", "hallucinated_api", "wrong_file_edit",
    "incomplete_patch", "flaky_test", "import_error",
    "type_error", "assertion_error", "success", "unknown",
]

# ── Regex patterns ────────────────────────────────────────────────────────────

# Ordered (category, regex) pairs scanned by categorise_failure(). The first
# pattern that matches anywhere in the pytest output wins, so the more specific
# hallucinated_api entry must stay ahead of the catch-all "TypeError:" rule.
_PATTERNS: list[tuple[FailureCategory, re.Pattern]] = [
    ("syntax_error",     re.compile(r"SyntaxError|IndentationError|TabError", re.I)),
    ("import_error",     re.compile(r"ImportError|ModuleNotFoundError|cannot import name", re.I)),
    ("hallucinated_api", re.compile(
        # Any missing attribute: on instances ("'Foo' object has no attribute"),
        # classes ("type object 'Foo' has no attribute") and modules
        # ("module 'os' has no attribute") alike.
        r"AttributeError: .+ has no attribute|"
        # Calling a real callable with a signature it does not have: wrong
        # positional arity (fixed or ranged) or a non-existent keyword.
        r"TypeError: .+ takes (?:from \d+ to )?\d+ positional argument|"
        r"TypeError: .+ got an unexpected keyword argument|"
        # Referencing a name that does not exist at all.
        r"NameError: name .+ is not defined",
        re.I
    )),
    ("type_error",       re.compile(r"TypeError:", re.I)),
    ("assertion_error",  re.compile(r"AssertionError", re.I)),
]

# Case-insensitive hints of non-determinism (resource warnings, randomness,
# races, network hiccups). categorise_failure() consults this only when earlier
# attempts were assigned differing categories.
_FLAKY_PATTERNS = re.compile(
    "|".join((
        r"ResourceWarning",
        r"random",
        r"race condition",
        r"flaky",
        r"connection refused",
        r"socket\.timeout",
    )),
    re.IGNORECASE,
)


def categorise_failure(
    test_stdout: str,
    patch_apply_success: bool,
    fail_to_pass_results: dict[str, bool],
    pass_to_pass_results: dict[str, bool],
    attempt_num: int = 1,
    previous_categories: list[FailureCategory] | None = None,
) -> FailureCategory:
    """
    Map the outcome of one repair attempt onto a FailureCategory.

    Rules are checked in order and the first one that fires decides:
      1. the patch could not be applied                  -> "syntax_error"
      2. every FAIL_TO_PASS test passes (and there is at
         least one) and no PASS_TO_PASS test fails       -> "success"
      3. an entry of _PATTERNS matches the output        -> that entry's category
                                                            (first match wins)
      4. earlier attempts disagree on the category and
         the output contains a flakiness hint            -> "flaky_test"
      5. some, but not all, FAIL_TO_PASS tests pass      -> "incomplete_patch"
      6. all FAIL_TO_PASS tests pass (vacuously true when
         there are none) but a PASS_TO_PASS test fails   -> "wrong_file_edit"
      7. anything else                                   -> "unknown"

    Args:
        test_stdout:           raw pytest output to scan.
        patch_apply_success:   whether `git apply` succeeded.
        fail_to_pass_results:  {test_id: passed} for FAIL_TO_PASS tests; an
                               empty mapping never counts as success.
        pass_to_pass_results:  {test_id: still_passing} for PASS_TO_PASS tests;
                               an empty mapping counts as "no regression".
        attempt_num:           current attempt number (1-indexed); accepted for
                               interface compatibility, not read by this function.
        previous_categories:   categories from earlier attempts, consulted only
                               by the flaky-test rule.

    Returns:
        The FailureCategory label for this attempt.
    """
    # Rule 1: a diff that git refuses to apply is treated as malformed.
    if not patch_apply_success:
        return "syntax_error"

    # Rule 2: the mappings are only dereferenced when non-empty.
    ftp_green = bool(fail_to_pass_results) and all(fail_to_pass_results.values())
    ptp_green = not pass_to_pass_results or all(pass_to_pass_results.values())
    if ftp_green and ptp_green:
        return "success"

    # Rule 3: first regex (in _PATTERNS order) found in the output decides.
    matched = next(
        (label for label, regex in _PATTERNS if regex.search(test_stdout)),
        None,
    )
    if matched is not None:
        return matched

    # Rule 4: diverging categories across attempts plus a flakiness hint.
    categories_disagree = (
        bool(previous_categories) and len(set(previous_categories)) > 1
    )
    if categories_disagree and _FLAKY_PATTERNS.search(test_stdout):
        return "flaky_test"

    # Rule 5: partial progress on the target tests.
    ftp_values = list(fail_to_pass_results.values())
    passed_ftp = sum(map(bool, ftp_values))
    total_ftp = len(ftp_values)
    if 0 < passed_ftp < total_ftp:
        return "incomplete_patch"

    # Rule 6: the target tests are fine but previously passing tests broke.
    ptp_regressed = any(not still_ok for still_ok in pass_to_pass_results.values())
    if ptp_regressed and passed_ftp == total_ftp:
        return "wrong_file_edit"

    return "unknown"


def extract_first_error_context(test_stdout: str, max_lines: int = 20) -> str:
    """
    Extract the most relevant error lines from pytest output.

    Used to build the reflection prompt, giving the LLM targeted failure info.

    The first line that contains "FAILED" or "ERROR", or "assert" in any
    letter case, is used as an anchor. The result is the two lines before the
    anchor plus up to ``max_lines`` lines starting at the anchor. If no line
    matches, the last ``max_lines`` lines are returned instead, because pytest
    prints its summary at the end.

    Args:
        test_stdout: raw pytest output.
        max_lines:   number of lines to keep from the anchor onwards (or from
                     the end in the fallback). Values <= 0 yield "".

    Returns:
        The selected lines joined with newlines; "" for empty output.
    """
    if max_lines <= 0:
        # Without this guard the fallback slice lines[-0:] equals lines[0:]
        # and would return the entire output instead of nothing.
        return ""

    lines = test_stdout.splitlines()

    for i, line in enumerate(lines):
        if "FAILED" in line or "ERROR" in line or "assert" in line.lower():
            start = max(0, i - 2)
            # Slicing clamps the upper bound, so no explicit min() is needed.
            return "\n".join(lines[start:i + max_lines])

    # Fallback: last N lines (pytest puts its summary at the end).
    return "\n".join(lines[-max_lines:])