File size: 4,475 Bytes
dc71cad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
"""
localisation/bm25_retriever.py
───────────────────────────────
Stage 1a — BM25 retrieval over the repo file corpus.

Indexes per file:
  - File path tokens  (e.g. 'django/db/models/query.py' → ['django','db','models','query'])
  - Docstrings        (module + function + class docstrings)
  - Function names    (tokenised by snake_case and CamelCase splitting)
  - Class names
  - Import targets

All text is lowercased and tokenised. BM25 (Okapi BM25 via rank-bm25)
scores each file given the issue query text.

Outputs: list of (file_path, bm25_score) sorted descending.
"""
from __future__ import annotations

import heapq
import logging
import re
from dataclasses import dataclass
from typing import Sequence

logger = logging.getLogger(__name__)


@dataclass
class BM25Hit:
    """A single ranked result produced by ``BM25Retriever.query``."""

    file_path: str  # file path exactly as it was supplied at index time
    score: float    # BM25 score, stored as a plain Python float
    rank: int       # 1-based position of this hit in the BM25 ordering


def _tokenise(text: str) -> list[str]:
    """
    Tokenise text for BM25 indexing.
    - Lowercases
    - Splits on non-alphanumeric chars
    - Splits CamelCase: 'QuerySet' β†’ ['query', 'set']
    - Splits snake_case: 'get_queryset' β†’ ['get', 'queryset']
    - Removes tokens shorter than 2 chars
    """
    # Insert space before capital letters in CamelCase
    text = re.sub(r"(?<=[a-z0-9])(?=[A-Z])", " ", text)
    # Split on non-alphanumeric
    tokens = re.split(r"[^a-zA-Z0-9]+", text.lower())
    return [t for t in tokens if len(t) >= 2]


def _build_document(file_path: str, summary_text: str) -> list[str]:
    """
    Build the BM25 document token list for one file.

    File path tokens are added with 2x weight (repeated), followed by the
    tokens of summary_text. A trailing file extension (e.g. '.py') is removed
    before tokenising: in a corpus of same-language files it is identical for
    every document and would only add the same doubled noise token everywhere.

    Args:
        file_path: path of the file, e.g. 'django/db/models/query.py'
        summary_text: extracted text for the file (presumably docstrings and
            symbol names from ast_parser — confirm against FileSymbols)

    Returns:
        path tokens + path tokens + content tokens
    """
    # Strip '.ext' only when the dot is preceded by a non-separator character,
    # so dotfiles such as '.bashrc' keep their name. No manual replacement of
    # '/', '_' or '.' is needed: _tokenise already splits on them.
    stem = re.sub(r"(?<=[^/\\])\.[A-Za-z0-9]+$", "", file_path)
    path_tokens = _tokenise(stem)
    # Double-weight file path tokens — path relevance is a strong signal.
    return path_tokens * 2 + _tokenise(summary_text)


class BM25Retriever:
    """
    BM25 retriever over a corpus of Python files.

    Usage:
        retriever = BM25Retriever()
        retriever.index(file_symbols_list)
        hits = retriever.query("fix null pointer in QuerySet filter", top_k=20)

    Indexing an empty corpus (or one where every file is skipped) is allowed;
    queries against it simply return no hits.
    """

    def __init__(self):
        self._bm25 = None                   # BM25Okapi model; stays None for an empty index
        self._indexed = False               # set once index() has completed
        self._file_paths: list[str] = []    # file path of document i (parallel to _corpus)
        self._corpus: list[list[str]] = []  # token list of document i

    def index(self, file_symbols_list) -> None:
        """
        Build BM25 index from a list of FileSymbols, replacing any previous index.

        Files with a truthy ``parse_error`` and files that produce no tokens
        are skipped. The new index is assembled in locals and only swapped in
        once complete, so a failure leaves the previous index untouched.

        Args:
            file_symbols_list: list of FileSymbols from ast_parser; each item
                must expose ``file_path``, ``summary_text`` and ``parse_error``

        Raises:
            ImportError: if rank-bm25 is not installed
        """
        try:
            from rank_bm25 import BM25Okapi
        except ImportError as e:
            raise ImportError("Install rank-bm25: pip install rank-bm25") from e

        file_paths: list[str] = []
        corpus: list[list[str]] = []
        for fs in file_symbols_list:
            if fs.parse_error:
                continue
            doc_tokens = _build_document(fs.file_path, fs.summary_text)
            if doc_tokens:
                file_paths.append(fs.file_path)
                corpus.append(doc_tokens)

        # BM25Okapi divides by the number of documents when computing the
        # average document length, so an empty corpus would raise
        # ZeroDivisionError; represent it as "indexed, but no model".
        bm25 = BM25Okapi(corpus) if corpus else None
        if bm25 is None:
            logger.warning("BM25 index is empty: no indexable documents")

        self._file_paths = file_paths
        self._corpus = corpus
        self._bm25 = bm25
        self._indexed = True
        logger.info("BM25 index built: %d documents", len(self._file_paths))

    def query(self, query_text: str, top_k: int = 20) -> list[BM25Hit]:
        """
        Retrieve up to top_k files most relevant to query_text.

        Only files with a strictly positive BM25 score are returned, so the
        result may be shorter than top_k or empty. Files with equal scores
        keep their index order.

        Args:
            query_text: raw issue text or preprocessed query
            top_k: maximum number of results; values <= 0 yield an empty list

        Returns:
            List of BM25Hit sorted by score descending, ranks starting at 1

        Raises:
            RuntimeError: if index() has not been called yet
        """
        if not self._indexed:
            raise RuntimeError("BM25Retriever is not indexed. Call .index() first.")
        # A negative top_k would otherwise slice from the end and return
        # almost the whole corpus; an empty index has nothing to score.
        if top_k <= 0 or self._bm25 is None:
            return []

        query_tokens = _tokenise(query_text)
        if not query_tokens:
            logger.warning("Empty query tokens after tokenisation")
            return []

        scores = self._bm25.get_scores(query_tokens)

        # nlargest is equivalent to sorted(..., reverse=True)[:top_k] (ties keep
        # input order) without sorting the entire corpus.
        top = heapq.nlargest(top_k, zip(self._file_paths, scores), key=lambda pair: pair[1])

        # Positive scores form a prefix of `top`, so ranks stay contiguous.
        return [
            BM25Hit(file_path=fp, score=float(score), rank=rank)
            for rank, (fp, score) in enumerate(top, start=1)
            if score > 0
        ]

    def query_batch(self, queries: list[str], top_k: int = 20) -> list[list[BM25Hit]]:
        """Query multiple issues at once; one result list per query, in order."""
        return [self.query(q, top_k) for q in queries]

    @property
    def corpus_size(self) -> int:
        """Number of documents in the current index (0 before indexing)."""
        return len(self._file_paths)