Spaces:
Running
Running
File size: 18,854 Bytes
dc71cad | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 | """
tests/test_phase3_localisation.py
──────────────────────────────────
Unit tests for Phase 3: BM25, RRF fusion, DeBERTa ranker, and pipeline.
All tests work without OpenAI API key or GPU — components degrade gracefully.
Run with: pytest tests/test_phase3_localisation.py -v
"""
from __future__ import annotations
import textwrap
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
# ── Helpers ───────────────────────────────────────────────────────────────────
def make_file_symbols(file_path: str, summary: str = ""):
    """Build a minimal ``FileSymbols`` fixture for the tests in this module.

    Args:
        file_path: Path recorded on the returned ``FileSymbols``.
        summary: Text stored as the module docstring. When non-empty it is
            also used as the docstring of one synthetic function, named after
            the first whitespace-separated word of ``summary`` (or
            ``"placeholder"`` when the summary contains only whitespace).

    Returns:
        A ``FileSymbols`` instance created with the fixed hash ``"aaa111"``.
    """
    # Imported inside the function, like the per-test imports in this file.
    from ast_parser.python_parser import FileSymbols, FunctionInfo

    symbols = FileSymbols(file_path=file_path, file_hash="aaa111")
    symbols.module_docstring = summary
    if not summary:
        return symbols

    # Add a fake function carrying the summary words; the original author
    # notes this is what makes ``summary_text`` fully populated.
    words = summary.split()
    func_name = words[0] if words else "placeholder"
    symbols.functions = [
        FunctionInfo(
            name=func_name,
            qualified_name=func_name,
            args=[],
            decorators=[],
            docstring=summary,
            start_line=1,
            end_line=5,
        )
    ]
    return symbols
# ── BM25 tokeniser ────────────────────────────────────────────────────────────
class TestTokeniser:
    """Checks on ``_tokenise`` from ``localisation.bm25_retriever``."""

    def test_lowercase(self):
        """Every token from mixed-case input equals its lower-cased form."""
        from localisation.bm25_retriever import _tokenise

        tokens = _tokenise("Hello World")
        assert all(token == token.lower() for token in tokens)

    def test_camel_case_split(self):
        """A CamelCase identifier yields both of its component words."""
        from localisation.bm25_retriever import _tokenise

        tokens = _tokenise("QuerySet")
        for expected in ("query", "set"):
            assert expected in tokens

    def test_snake_case_split(self):
        """A snake_case identifier yields the parts around the underscore."""
        from localisation.bm25_retriever import _tokenise

        tokens = _tokenise("get_queryset")
        for expected in ("get", "queryset"):
            assert expected in tokens

    def test_short_tokens_filtered(self):
        """Single-character tokens are dropped; a three-letter one is kept."""
        from localisation.bm25_retriever import _tokenise

        tokens = _tokenise("a b c def")
        for dropped in ("a", "b"):
            assert dropped not in tokens
        assert "def" in tokens

    def test_path_tokenisation(self):
        """Directory and file-stem components of a path become tokens."""
        from localisation.bm25_retriever import _tokenise

        tokens = _tokenise("django/db/models/query.py")
        for expected in ("django", "models", "query"):
            assert expected in tokens
# ── BM25 Retriever ────────────────────────────────────────────────────────────
class TestBM25Retriever:
    """Behavioural tests for ``BM25Retriever`` indexing and querying."""

    @staticmethod
    def _indexed(symbols):
        """Return a new ``BM25Retriever`` after calling ``index(symbols)``."""
        from localisation.bm25_retriever import BM25Retriever

        engine = BM25Retriever()
        engine.index(symbols)
        return engine

    def test_index_and_query_basic(self):
        """The file whose summary matches the query is ranked first."""
        corpus = [
            make_file_symbols("models/user.py", "User authentication login password"),
            make_file_symbols("views/dashboard.py", "Dashboard render template"),
            make_file_symbols("utils/email.py", "Email sending SMTP"),
        ]
        results = self._indexed(corpus).query("user authentication login", top_k=3)
        assert len(results) >= 1
        assert results[0].file_path == "models/user.py"

    def test_query_returns_positive_scores_only(self):
        """Every returned hit carries a strictly positive score."""
        corpus = [make_file_symbols(f"file{i}.py", f"content {i}") for i in range(5)]
        results = self._indexed(corpus).query("content 3", top_k=5)
        for hit in results:
            assert hit.score > 0

    def test_ranks_are_sequential(self):
        """Hit ranks count up from 1 without gaps, in result order."""
        corpus = [make_file_symbols(f"f{i}.py", f"word{i} text") for i in range(3)]
        results = self._indexed(corpus).query("word0 text", top_k=3)
        expected_ranks = list(range(1, len(results) + 1))
        assert [hit.rank for hit in results] == expected_ranks

    def test_empty_query_returns_empty(self):
        """An empty query string produces an empty list of hits."""
        engine = self._indexed([make_file_symbols("a.py", "content")])
        assert engine.query("", top_k=5) == []

    def test_corpus_size(self):
        """``corpus_size`` equals the number of indexed files."""
        corpus = [make_file_symbols(f"f{i}.py", "text") for i in range(7)]
        assert self._indexed(corpus).corpus_size == 7

    def test_file_path_tokens_boost(self):
        """A query term found only in a file's path still retrieves that file."""
        # In this fixture 'models' occurs only in the path of
        # django/db/models.py (no summary contains it), so a hit shows that
        # path tokens are indexed. The original author notes that a larger
        # corpus is needed for BM25 to give this term a positive score.
        corpus = [
            make_file_symbols("django/db/models.py", "handles database records"),
            make_file_symbols("utils/helper.py", "general utilities helper"),
            make_file_symbols("views/base.py", "base view rendering"),
            make_file_symbols("core/app.py", "application entry point"),
            make_file_symbols("api/serializers.py", "rest framework serializers"),
        ]
        results = self._indexed(corpus).query("models", top_k=5)
        returned_paths = [hit.file_path for hit in results]
        assert "django/db/models.py" in returned_paths

    def test_not_indexed_raises(self):
        """Querying before ``index`` raises ``RuntimeError`` ("not indexed")."""
        from localisation.bm25_retriever import BM25Retriever

        engine = BM25Retriever()
        with pytest.raises(RuntimeError, match="not indexed"):
            engine.query("test", top_k=5)

    def test_skips_parse_error_files(self):
        """Files with ``parse_error`` set are excluded from the corpus count."""
        from ast_parser.python_parser import FileSymbols

        parsed_ok = make_file_symbols("good.py", "good content")
        unparsable = FileSymbols(
            file_path="bad.py", file_hash="bbb", parse_error="SyntaxError"
        )
        engine = self._indexed([parsed_ok, unparsable])
        assert engine.corpus_size == 1
# ── RRF Fusion ────────────────────────────────────────────────────────────────
class TestRRFFusion:
    """Tests for ``reciprocal_rank_fusion`` and ``ablate`` in ``rrf_fusion``.

    The BM25 and embedding inputs are ``(path, score, rank)`` tuples; the PPR
    input is a mapping from path to score.
    """

    def test_basic_fusion(self):
        """A file present in all three inputs lands in the top two."""
        from localisation.rrf_fusion import reciprocal_rank_fusion

        lexical = [("a.py", 1.0, 1), ("b.py", 0.8, 2), ("c.py", 0.5, 3)]
        dense = [("b.py", 0.9, 1), ("a.py", 0.7, 2), ("d.py", 0.6, 3)]
        graph = {"a.py": 0.5, "b.py": 0.3}

        fused = reciprocal_rank_fusion(lexical, dense, graph, top_k=4)

        assert len(fused) <= 4
        # a.py is present in all three inputs, so it should rank high.
        leading_paths = [hit.file_path for hit in fused][:2]
        assert "a.py" in leading_paths

    def test_top_k_respected(self):
        """Ten candidates with ``top_k=3`` give exactly three results."""
        from localisation.rrf_fusion import reciprocal_rank_fusion

        lexical = [(f"f{i}.py", 1.0, i + 1) for i in range(10)]
        fused = reciprocal_rank_fusion(lexical, [], {}, top_k=3)
        assert len(fused) == 3

    def test_empty_inputs(self):
        """Fusing three empty inputs returns an empty list."""
        from localisation.rrf_fusion import reciprocal_rank_fusion

        assert reciprocal_rank_fusion([], [], {}, top_k=5) == []

    def test_ranks_sequential(self):
        """Fused ranks count up from 1 without gaps, in result order."""
        from localisation.rrf_fusion import reciprocal_rank_fusion

        lexical = [("a.py", 1.0, 1), ("b.py", 0.5, 2)]
        fused = reciprocal_rank_fusion(lexical, [], {}, top_k=5)
        expected_ranks = list(range(1, len(fused) + 1))
        assert [hit.rank for hit in fused] == expected_ranks

    def test_all_sources_tracked(self):
        """A hit records its rank in each of the three sources."""
        from localisation.rrf_fusion import reciprocal_rank_fusion

        lexical = [("a.py", 1.0, 1)]
        dense = [("a.py", 0.9, 1)]
        graph = {"a.py": 0.5}

        only_hit = reciprocal_rank_fusion(lexical, dense, graph, top_k=1)[0]

        assert only_hit.bm25_rank == 1
        assert only_hit.embed_rank == 1
        assert only_hit.ppr_rank == 1

    def test_ablation_no_ppr(self):
        """With ``use_ppr=False`` a PPR-only file is absent from the result."""
        from localisation.rrf_fusion import ablate

        lexical = [("a.py", 1.0, 1)]
        graph = {"b.py": 99.0}  # b.py appears only here, with a huge PPR score

        fused = ablate(lexical, [], graph, use_ppr=False, top_k=5)

        returned_paths = [hit.file_path for hit in fused]
        assert "b.py" not in returned_paths

    def test_scores_descending(self):
        """Fused scores are in non-increasing order."""
        from localisation.rrf_fusion import reciprocal_rank_fusion

        lexical = [("a.py", 1.0, 1), ("b.py", 0.5, 2), ("c.py", 0.1, 3)]
        fused = reciprocal_rank_fusion(lexical, [], {}, top_k=3)
        fused_scores = [hit.fused_score for hit in fused]
        assert fused_scores == sorted(fused_scores, reverse=True)

    def test_union_of_all_lists(self):
        """File appearing only in PPR should still be in results."""
        from localisation.rrf_fusion import reciprocal_rank_fusion

        lexical = [("a.py", 1.0, 1)]
        graph = {"z.py": 1.0}  # only in PPR

        fused = reciprocal_rank_fusion(lexical, [], graph, top_k=10)

        returned_paths = [hit.file_path for hit in fused]
        assert "z.py" in returned_paths
# ── DeBERTa Ranker — without GPU ──────────────────────────────────────────────
class TestDeBERTaRankerFallback:
    """Tests for graceful fallback when model is not loaded."""

    @staticmethod
    def _unloaded_ranker():
        """Return a ``DeBERTaRanker`` stub that never loaded a model.

        ``__new__`` is used so ``__init__`` is not run. All three private
        attributes are set explicitly in every test, so a fallback path that
        reads ``_model`` or ``_tokenizer`` sees ``None`` instead of raising
        ``AttributeError`` on a half-initialised instance.
        """
        from localisation.deberta_ranker import DeBERTaRanker

        ranker = DeBERTaRanker.__new__(DeBERTaRanker)
        ranker._available = False
        ranker._model = None
        ranker._tokenizer = None
        return ranker

    def test_rerank_fallback_returns_stage1_order(self):
        """Without a model, the first candidate stays first with rank 1."""
        ranker = self._unloaded_ranker()
        candidates = [("a.py", "summary a"), ("b.py", "summary b"), ("c.py", "summary c")]
        result = ranker.rerank("fix the bug", candidates, top_k=3)
        assert len(result) == 3
        assert result[0].file_path == "a.py"
        assert result[0].rank == 1

    def test_rerank_empty_candidates(self):
        """Reranking an empty candidate list returns an empty list."""
        ranker = self._unloaded_ranker()
        result = ranker.rerank("query", [], top_k=5)
        assert result == []

    def test_ranked_file_scores_are_positive(self):
        """Fallback results all carry a strictly positive relevance score."""
        ranker = self._unloaded_ranker()
        candidates = [(f"f{i}.py", f"text {i}") for i in range(5)]
        result = ranker.rerank("test query", candidates, top_k=5)
        assert all(r.relevance_score > 0 for r in result)
# ── Recall metric ─────────────────────────────────────────────────────────────
class TestRecallMetric:
    """Tests for ``recall_at_k`` from ``localisation.deberta_ranker``."""

    def test_perfect_recall(self):
        """All gold files in the predictions gives a recall of 1.0."""
        from localisation.deberta_ranker import recall_at_k

        predicted = ["a.py", "b.py", "c.py"]
        expected = ["a.py", "b.py"]
        score = recall_at_k(predicted, expected, k=5)
        assert score == 1.0

    def test_zero_recall(self):
        """No gold file in the predictions gives a recall of 0.0."""
        from localisation.deberta_ranker import recall_at_k

        predicted = ["x.py", "y.py"]
        expected = ["a.py"]
        score = recall_at_k(predicted, expected, k=5)
        assert score == 0.0

    def test_partial_recall(self):
        """One of two gold files in the predictions gives a recall of 0.5."""
        from localisation.deberta_ranker import recall_at_k

        predicted = ["a.py", "b.py", "c.py"]
        expected = ["a.py", "z.py"]
        score = recall_at_k(predicted, expected, k=5)
        assert score == 0.5

    def test_recall_at_k_respects_k(self):
        """Only the first ``k`` predictions count towards recall."""
        from localisation.deberta_ranker import recall_at_k

        predicted = ["x.py", "a.py"]  # the gold file sits at position 2
        expected = ["a.py"]
        top_one = recall_at_k(predicted, expected, k=1)  # only looking at top-1
        assert top_one == 0.0
        top_two = recall_at_k(predicted, expected, k=2)
        assert top_two == 1.0

    def test_empty_gold(self):
        """An empty gold list gives a recall of 0.0."""
        from localisation.deberta_ranker import recall_at_k

        score = recall_at_k(["a.py"], [], k=5)
        assert score == 0.0
# ── Patch file extraction ─────────────────────────────────────────────────────
class TestExtractFilesFromPatch:
    """Tests for ``_extract_files_from_patch`` in ``deberta_ranker``."""

    def test_basic_unified_diff(self):
        """The path from a one-file git diff is extracted."""
        from localisation.deberta_ranker import _extract_files_from_patch

        # Named ``diff_text`` so the local does not shadow ``patch`` imported
        # from ``unittest.mock`` at the top of the module.
        diff_text = textwrap.dedent("""
            diff --git a/django/db/models/query.py b/django/db/models/query.py
            --- a/django/db/models/query.py
            +++ b/django/db/models/query.py
            @@ -1 +1 @@
            -old
            +new
        """)
        extracted = _extract_files_from_patch(diff_text)
        assert "django/db/models/query.py" in extracted

    def test_multiple_files(self):
        """Both paths are extracted from a diff that touches two files."""
        from localisation.deberta_ranker import _extract_files_from_patch

        diff_text = textwrap.dedent("""
            --- a/foo.py
            +++ b/foo.py
            @@ -1 +1 @@ fix
            --- a/bar.py
            +++ b/bar.py
            @@ -1 +1 @@ fix
        """)
        extracted = _extract_files_from_patch(diff_text)
        for expected in ("foo.py", "bar.py"):
            assert expected in extracted

    def test_dev_null_excluded(self):
        """``/dev/null`` (the old side of a new file) is never reported."""
        from localisation.deberta_ranker import _extract_files_from_patch

        diff_text = "--- /dev/null\n+++ b/new_file.py\n"
        extracted = _extract_files_from_patch(diff_text)
        assert "/dev/null" not in extracted

    def test_empty_patch(self):
        """An empty patch string yields an empty list."""
        from localisation.deberta_ranker import _extract_files_from_patch

        extracted = _extract_files_from_patch("")
        assert extracted == []
# ── Failure categorisation ────────────────────────────────────────────────────
class TestFailureCategorisation:
    """Tests for ``categorise_localisation_failure`` in ``pipeline``.

    Each call passes the predicted files, the gold files and the issue text,
    in that order.
    """

    def test_success(self):
        """The only gold file being predicted is a ``"success"``."""
        from localisation.pipeline import categorise_localisation_failure

        predicted = ["a.py", "b.py"]
        expected = ["a.py"]
        issue = "good long detailed issue text here"
        category = categorise_localisation_failure(predicted, expected, issue)
        assert category == "success"

    def test_wrong_file(self):
        """A long issue with no gold file found is ``"wrong_file"``."""
        from localisation.pipeline import categorise_localisation_failure

        # Long issue text (>10 words) + no gold file found → wrong_file
        issue = "there is a null pointer exception in the query filter method"
        category = categorise_localisation_failure(["x.py", "y.py"], ["z.py"], issue)
        assert category == "wrong_file"

    def test_partial_file(self):
        """Finding one of two gold files is ``"partial_file"``."""
        from localisation.pipeline import categorise_localisation_failure

        predicted = ["a.py"]
        expected = ["a.py", "b.py"]
        issue = "long enough issue text to be valid"
        category = categorise_localisation_failure(predicted, expected, issue)
        assert category == "partial_file"

    def test_ambiguous_issue(self):
        """A miss on a very short issue text is ``"ambiguous_issue"``."""
        from localisation.pipeline import categorise_localisation_failure

        issue = "fix bug"  # very short
        category = categorise_localisation_failure(["x.py"], ["z.py"], issue)
        assert category == "ambiguous_issue"
# ── Pipeline integration (no API required) ────────────────────────────────────
class TestLocalisationPipeline:
    """Integration tests for ``LocalisationPipeline`` with optional stages off."""

    @staticmethod
    def _indexed_pipeline(symbols, **flags):
        """Build ``LocalisationPipeline(**flags)`` and index ``symbols`` into it."""
        from localisation.pipeline import LocalisationPipeline

        pipe = LocalisationPipeline(**flags)
        pipe.index_repo(symbols)
        return pipe

    def test_pipeline_bm25_only(self):
        """With only BM25 enabled, the best-matching file is ranked first."""
        corpus = [
            make_file_symbols("auth/models.py", "User model authentication password hash"),
            make_file_symbols("views/login.py", "Login view render form"),
            make_file_symbols("utils/email.py", "Email SMTP send message"),
        ]
        pipe = self._indexed_pipeline(
            corpus,
            use_embeddings=False,
            use_deberta=False,
            use_ppr=False,
        )
        outcome = pipe.localise("user authentication login password", top_k=3)
        assert len(outcome.hits) >= 1
        assert outcome.hits[0].file_path == "auth/models.py"

    def test_pipeline_empty_query(self):
        """An empty query is categorised ``"empty_query"`` with no hits."""
        pipe = self._indexed_pipeline(
            [make_file_symbols("a.py", "content")],
            use_embeddings=False,
            use_deberta=False,
        )
        outcome = pipe.localise("")
        assert outcome.failure_category == "empty_query"
        assert outcome.hits == []

    def test_pipeline_with_gold_files_computes_recall(self):
        """Passing ``gold_files`` populates recall@5 and recall@10."""
        # A larger corpus is used so BM25 gives positive scores (original
        # author's note). 'queryset' is in both the path and the summary of
        # db/queryset.py, which the test expects to reach the top five.
        corpus = [
            make_file_symbols("db/queryset.py", "queryset filter method database orm"),
            make_file_symbols("views/generic.py", "generic view rendering template"),
            make_file_symbols("utils/helper.py", "utility functions general purpose"),
            make_file_symbols("api/serializer.py", "rest framework serializer fields"),
            make_file_symbols("forms/widget.py", "html form widget rendering input"),
        ]
        pipe = self._indexed_pipeline(
            corpus, use_embeddings=False, use_deberta=False, use_ppr=False
        )
        outcome = pipe.localise(
            "fix null pointer exception in queryset filter",
            top_k=5,
            gold_files=["db/queryset.py"],
        )
        assert outcome.recall_at_5 is not None
        assert outcome.recall_at_10 is not None
        assert outcome.recall_at_5 == 1.0

    def test_top_k_paths_property(self):
        """``top_k_paths`` has exactly one entry per hit."""
        corpus = [make_file_symbols(f"f{i}.py", f"content {i}") for i in range(5)]
        pipe = self._indexed_pipeline(
            corpus, use_embeddings=False, use_deberta=False, use_ppr=False
        )
        outcome = pipe.localise("content 1", top_k=3)
        assert len(outcome.top_k_paths) == len(outcome.hits)

    def test_hit_diagnostic_flags(self):
        """A returned hit is flagged as having come from BM25."""
        pipe = self._indexed_pipeline(
            [make_file_symbols("a.py", "special word")],
            use_embeddings=False,
            use_deberta=False,
            use_ppr=False,
        )
        outcome = pipe.localise("special word", top_k=1)
        # The flag is only checked when a hit is returned; an empty result
        # passes vacuously.
        if not outcome.hits:
            return
        assert outcome.hits[0].in_bm25 is True
|