File size: 16,938 Bytes
cd61817
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
"""Tests for Phase B3 — Failure Clustering (classifier + archetype extractor)."""

from __future__ import annotations

import json
from pathlib import Path
from unittest.mock import MagicMock, patch

from ci_triage_env.data.clustering import (
    FAMILIES,
    Archetype,
    ArchetypeExtractor,
    LLMClassifier,
    RuleBasedClassifier,
    classify_all,
)
from ci_triage_env.data.datasets._base import FailureRecord

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _record(log_text: str, record_id: str = "test-001") -> FailureRecord:
    """Wrap *log_text* in a minimal ``FailureRecord`` for use in these tests.

    Every record shares the same source dataset (``deflaker``) and project
    (``test/project``); only the log text and the record id vary.
    """
    fields = {
        "record_id": record_id,
        "source_dataset": "deflaker",
        "project": "test/project",
        "log_text": log_text,
    }
    return FailureRecord(**fields)


# ---------------------------------------------------------------------------
# RuleBasedClassifier
# ---------------------------------------------------------------------------

class TestRuleBasedClassifier:
    """Pattern-based mapping of raw CI log text onto failure families."""

    def setup_method(self) -> None:
        self.clf = RuleBasedClassifier()

    def _expect(self, log_text: str, expected_family: str) -> None:
        """Classify *log_text*; require *expected_family* with positive confidence."""
        got_family, got_conf = self.clf.classify(_record(log_text))
        assert got_family == expected_family
        assert got_conf > 0

    def _expect_unknown(self, log_text: str) -> None:
        """Classify *log_text*; require the ``unknown`` label with zero confidence."""
        got_family, got_conf = self.clf.classify(_record(log_text))
        assert got_family == "unknown"
        assert got_conf == 0.0

    # -- infra_resource ------------------------------------------------------

    def test_rule_based_oom(self) -> None:
        self._expect("kernel: Out of memory: Killed process 123 (pytest)", "infra_resource")

    def test_rule_based_enospc(self) -> None:
        self._expect("write /var/lib/docker: no space left on device", "infra_resource")

    def test_rule_based_emfile(self) -> None:
        self._expect("Error: EMFILE: too many open files, open '/tmp/test'", "infra_resource")

    # -- infra_network -------------------------------------------------------

    def test_rule_based_connection_refused(self) -> None:
        self._expect("dial tcp: connection refused localhost:5432", "infra_network")

    def test_rule_based_dns(self) -> None:
        self._expect("getaddrinfo failed: No such host is known", "infra_network")

    # -- race_flake ----------------------------------------------------------

    def test_rule_based_race(self) -> None:
        self._expect("WARNING: DATA RACE detected in goroutine 42", "race_flake")

    def test_rule_based_concurrent_map(self) -> None:
        self._expect("fatal error: concurrent map writes", "race_flake")

    # -- timing_flake --------------------------------------------------------

    def test_rule_based_timeout(self) -> None:
        self._expect("context canceled: deadline exceeded after 30s", "timing_flake")

    def test_rule_based_test_timed_out(self) -> None:
        self._expect("Test timed out after 120 seconds", "timing_flake")

    # -- dependency_drift ----------------------------------------------------

    def test_rule_based_dependency_version_conflict(self) -> None:
        self._expect(
            "npm ERR! peer dep missing: react@^17.0.0, version conflict",
            "dependency_drift",
        )

    def test_rule_based_cargo_lock(self) -> None:
        self._expect(
            "error: failed to select a version for the requirement\nCargo.lock conflict",
            "dependency_drift",
        )

    # -- no signal -----------------------------------------------------------

    def test_rule_based_unknown_returns_unknown(self) -> None:
        self._expect_unknown("generic test failure with no specific signal")

    def test_rule_based_empty_log(self) -> None:
        self._expect_unknown("")

    # -- conflicting / misc --------------------------------------------------

    def test_ambiguous_when_multiple_families_match(self) -> None:
        # The log carries an OOM signal (infra_resource) *and* a data-race
        # signal (race_flake), so no single family should win.
        mixed_log = "Out of memory: Killed process 999\nWARNING: DATA RACE in goroutine 1"
        got_family, _ = self.clf.classify(_record(mixed_log))
        assert got_family == "ambiguous"

    def test_confidence_in_unit_range(self) -> None:
        _, got_conf = self.clf.classify(_record("kernel: Out of memory: Killed process 1 (pytest)"))
        assert 0.0 <= got_conf <= 1.0

    def test_real_bug_assertion_error(self) -> None:
        self._expect("AssertionError: expected 42, got 0\nassert result == expected", "real_bug")


# ---------------------------------------------------------------------------
# Archetype round-trip and extraction
# ---------------------------------------------------------------------------

class TestArchetype:
    """Serialization round-trips of the ``Archetype`` model."""

    def test_archetype_round_trip(self) -> None:
        original = Archetype(
            archetype_id="infra_resource_001",
            family="infra_resource",
            pattern_summary="OOM-killer terminated test process",
            log_template="[{TIMESTAMP}] Out of memory: Killed process {PID}",
            slot_distributions={"TIMESTAMP": ["2024-01-01T00:00:00"], "PID": ["1234"]},
            informative_tools_hint=["read_logs:kernel"],
            minimal_evidence_hint=["cluster_metrics:5m"],
        )
        # dict -> model must reproduce the identifying fields unchanged.
        restored = Archetype.model_validate(original.model_dump())
        for attr in ("archetype_id", "family", "log_template"):
            assert getattr(restored, attr) == getattr(original, attr)

    def test_archetype_json_round_trip(self) -> None:
        payload = Archetype(
            archetype_id="race_flake_001",
            family="race_flake",
            pattern_summary="Concurrent map write panic",
            log_template="fatal error: concurrent map writes at {TIMESTAMP}",
            slot_distributions={"TIMESTAMP": ["2024-06-01T12:00:00"]},
            informative_tools_hint=["read_logs:full"],
            minimal_evidence_hint=["read_logs:full"],
        ).model_dump_json()
        assert Archetype.model_validate_json(payload).family == "race_flake"


class TestArchetypeExtractor:
    """Template and slot extraction from groups of same-family failure records."""

    def setup_method(self) -> None:
        self.extractor = ArchetypeExtractor()

    def _make_records(self, log_texts: list[str]) -> list[FailureRecord]:
        """Wrap each log text in a record whose id is ``rec-000``, ``rec-001``, ..."""
        return [
            _record(text, record_id=f"rec-{idx:03d}")
            for idx, text in enumerate(log_texts)
        ]

    def test_extract_returns_archetypes(self) -> None:
        family = "infra_resource"
        oom_logs = [
            "kernel: Out of memory: Killed process 123 (pytest) 2024-01-01T10:00:00",
            "kernel: Out of memory: Killed process 456 (cargo) 2024-01-02T11:00:00",
            "Out of memory: cannot allocate 4096kB",
            "OOMKilled: container exceeded memory limit 2048MB",
            "killed by OS: out of memory at 2024-01-03T09:00:00",
        ]
        archetypes = self.extractor.extract(self._make_records(oom_logs), family, n_archetypes=4)
        assert len(archetypes) >= 1
        for arch in archetypes:
            # Every archetype is tagged with, and id-prefixed by, its family.
            assert arch.family == family
            assert arch.archetype_id.startswith(f"{family}_")
            assert arch.log_template
            assert arch.informative_tools_hint
            assert arch.minimal_evidence_hint

    def test_extract_with_single_record(self) -> None:
        single = self._make_records(["context canceled: deadline exceeded"])
        archetypes = self.extractor.extract(single, "timing_flake", n_archetypes=4)
        assert len(archetypes) == 1
        assert archetypes[0].family == "timing_flake"

    def test_extract_empty_records_returns_empty(self) -> None:
        assert self.extractor.extract([], "real_bug", n_archetypes=4) == []

    def test_extract_slots_present_in_template(self) -> None:
        single = self._make_records([
            "2024-03-15T08:30:00 OOMKilled process 99999 after 120.5s",
        ])
        archetypes = self.extractor.extract(single, "infra_resource", n_archetypes=1)
        assert len(archetypes) == 1
        # Timestamp / pid / duration should turn into at least one {SLOT}.
        assert "{" in archetypes[0].log_template

    def test_slot_distributions_are_populated(self) -> None:
        pair = self._make_records([
            "2024-01-01T10:00:00 Out of memory: Killed process 1234",
            "2024-02-01T11:00:00 Out of memory: Killed process 5678",
        ])
        archetypes = self.extractor.extract(pair, "infra_resource", n_archetypes=1)
        assert len(archetypes) >= 1
        # Count the sampled slot values across every archetype and slot name.
        n_slot_values = sum(
            len(values)
            for arch in archetypes
            for values in arch.slot_distributions.values()
        )
        assert n_slot_values > 0

    def test_archetype_extraction_from_fixture(self) -> None:
        """Five inline race-flake records yield at least one slotted archetype."""
        race_logs = [
            "fatal error: concurrent map writes at goroutine 12",
            "WARNING: DATA RACE in goroutine 42 on address 0xc000deadbeef",
            "data race detected: read at 0x00c0001234ab by goroutine 7",
            "deadlock detected: all goroutines are asleep",
            "fatal error: concurrent map iteration and map write",
        ]
        archetypes = self.extractor.extract(self._make_records(race_logs), "race_flake", n_archetypes=4)
        assert len(archetypes) >= 1
        assert any("{" in arch.log_template for arch in archetypes)

    def test_n_archetypes_cap(self) -> None:
        """Asking for more archetypes than records never yields more than the record count."""
        pair = self._make_records(["timeout exceeded", "test timed out after 60s"])
        archetypes = self.extractor.extract(pair, "timing_flake", n_archetypes=10)
        assert len(archetypes) <= len(pair)


# ---------------------------------------------------------------------------
# classify_all
# ---------------------------------------------------------------------------

class TestClassifyAll:
    """Routing of records into per-family buckets by ``classify_all``."""

    def test_classify_all_no_records(self) -> None:
        by_family = classify_all([])
        assert set(by_family.keys()) == set(FAMILIES)
        assert all(len(v) == 0 for v in by_family.values())

    def test_classify_all_routes_correctly(self) -> None:
        records = [
            _record("Out of memory: Killed process 1", "oom-1"),
            _record("WARNING: DATA RACE in goroutine 99", "race-1"),
            _record("context canceled: deadline exceeded", "timeout-1"),
        ]
        by_family = classify_all(records)
        assert len(by_family["infra_resource"]) >= 1
        assert len(by_family["race_flake"]) >= 1
        assert len(by_family["timing_flake"]) >= 1

    def test_classify_all_unknown_falls_into_real_bug(self) -> None:
        records = [_record("some completely generic text with no signal", "unknown-1")]
        by_family = classify_all(records, openai_api_key=None)
        # Without an LLM key, records the rules cannot place end up in real_bug.
        assert len(by_family["real_bug"]) == 1

    def test_classify_all_returns_all_families(self) -> None:
        by_family = classify_all([_record("OOM error")])
        assert set(by_family.keys()) == set(FAMILIES)

    def test_classify_all_writes_per_family_files(self, tmp_path: Path) -> None:
        """Classify, extract archetypes, and persist one JSON file per family.

        The files are written by this test, not by ``classify_all``. The test
        checks that every non-empty family gets exactly one directory holding
        a non-empty JSON list of archetypes.
        """
        records = [
            _record("Out of memory: Killed process 1", "oom-1"),
            _record("connection refused to db:5432", "net-1"),
            _record("deadline exceeded in test", "timeout-1"),
        ]
        by_family = classify_all(records)
        extractor = ArchetypeExtractor()
        non_empty_families = {family for family, recs in by_family.items() if recs}
        for family in non_empty_families:
            archetypes = extractor.extract(by_family[family], family, n_archetypes=2)
            family_dir = tmp_path / family
            family_dir.mkdir(parents=True, exist_ok=True)
            (family_dir / "archetypes.json").write_text(
                json.dumps([a.model_dump() for a in archetypes], indent=2)
            )

        # Consider directories only, both when listing and when checking contents.
        family_dirs = [d for d in tmp_path.iterdir() if d.is_dir()]
        assert len(family_dirs) >= 1
        # Exactly the non-empty families got a directory -- no more, no fewer.
        assert {d.name for d in family_dirs} == non_empty_families
        for d in family_dirs:
            arch_file = d / "archetypes.json"
            assert arch_file.exists()
            data = json.loads(arch_file.read_text())
            assert isinstance(data, list)
            assert len(data) >= 1


# ---------------------------------------------------------------------------
# LLMClassifier (mocked)
# ---------------------------------------------------------------------------

class TestLLMClassifier:
    """``LLMClassifier`` behaviour with the chat-completions client replaced by mocks."""

    @staticmethod
    def _make_mock_response(label: str = "infra_resource") -> MagicMock:
        """Build a fake chat-completion response whose first choice's content is *label*.

        Token-usage fields are filled in with fixed numbers; presumably the
        classifier reads them to update ``spent`` (TODO confirm).
        """
        response = MagicMock()
        response.choices[0].message.content = label
        response.usage.prompt_tokens = 100
        response.usage.completion_tokens = 5
        return response

    @staticmethod
    def _make_classifier(
        budget: float, response: MagicMock
    ) -> tuple[LLMClassifier, MagicMock]:
        """Build an ``LLMClassifier`` without running its ``__init__``.

        Calling ``__new__`` directly never invokes ``__init__``, so no API key
        or real client is needed. The attributes the tests rely on are set by
        hand, and the client is a mock whose ``chat.completions.create``
        returns *response*.

        Returns the classifier together with its mock client so tests can
        inspect the calls made to it.
        """
        clf = LLMClassifier.__new__(LLMClassifier)
        clf.budget = budget
        clf.spent = 0.0
        clf.model = "gpt-4o-mini"
        client = MagicMock()
        client.chat.completions.create.return_value = response
        clf.client = client
        return clf, client

    def test_llm_classifier_respects_budget(self) -> None:
        """LLMClassifier stops calling after budget exhausted."""
        records = [
            _record(f"generic unknown failure {i}", f"rec-{i}")
            for i in range(5)
        ]
        # A zero budget means the classifier is already over budget.
        clf, client = self._make_classifier(0.0, self._make_mock_response("real_bug"))

        results = clf.classify_batch(records)

        # Still one result per record; without this, an empty result list
        # would make the all(...) check below pass vacuously.
        assert len(results) == len(records)
        # Every result falls back to "unknown" because the budget is exhausted.
        assert all(family == "unknown" for family, _ in results)
        # The client must not have been called at all.
        client.chat.completions.create.assert_not_called()

    def test_llm_classifier_classifies_records(self) -> None:
        """LLMClassifier calls the API and returns valid labels."""
        records = [_record("some unknown log text", "unk-1")]
        clf, _client = self._make_classifier(5.0, self._make_mock_response("real_bug"))

        results = clf.classify_batch(records)

        assert len(results) == 1
        family, conf = results[0]
        assert family == "real_bug"
        assert conf == 0.7

    def test_llm_classifier_invalid_label_falls_back_to_unknown(self) -> None:
        records = [_record("something", "x-1")]
        clf, _client = self._make_classifier(
            5.0, self._make_mock_response("not_a_valid_label")
        )

        results = clf.classify_batch(records)

        family, _ = results[0]
        assert family == "unknown"

    def test_classify_all_uses_llm_for_unknowns(self) -> None:
        """classify_all routes unknown records through LLM when api_key provided."""
        records = [_record("no signal here at all", "unk-1")]

        with patch("ci_triage_env.data.clustering.classifier.LLMClassifier") as mock_llm_cls:
            mock_llm = MagicMock()
            mock_llm.classify_batch.return_value = [("timing_flake", 0.7)]
            mock_llm.spent = 0.0001
            mock_llm_cls.return_value = mock_llm

            by_family = classify_all(records, openai_api_key="sk-fake-key")

        assert len(by_family["timing_flake"]) == 1
        mock_llm_cls.assert_called_once_with("sk-fake-key")