Spaces:
Running
Running
| """ | |
| localisation/pipeline.py | |
| ──────────────────────── | |
| Full two-stage localisation pipeline. | |
| Stage 1: BM25 + Embeddings (coarse ranking) → RRF fusion | |
| Stage 2: DeBERTa cross-encoder (precision re-ranking) | |
| Also handles: | |
| - Failure categorisation (wrong-file, partial-file, missing-dependency, ambiguous-issue) | |
| - MLflow cost tracking per retrieval call | |
| - Context budget enforcement (top-K files only) | |
| Usage: | |
| pipeline = LocalisationPipeline(cache_dir=Path(".cache")) | |
| pipeline.index_repo(file_symbols, dependency_graph) | |
| result = pipeline.localise( | |
| issue_text="Fix null pointer in QuerySet.filter()", | |
| top_k=5, | |
| ) | |
| for hit in result.hits: | |
| print(hit.file_path, hit.relevance_score) | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| import time | |
| from dataclasses import dataclass, field | |
| from pathlib import Path | |
| from typing import Literal, Optional | |
| logger = logging.getLogger(__name__) | |
| # ── Result types ────────────────────────────────────────────────────────────── | |
@dataclass
class LocalisationHit:
    """A single ranked file produced by the localisation pipeline.

    The three required fields identify the file and its final position; the
    remaining fields are diagnostics recording which stage-1 retrievers
    surfaced the file and at what rank.
    """

    file_path: str
    relevance_score: float
    rank: int
    # Diagnostic: which stages contributed
    in_bm25: bool = False
    in_embed: bool = False
    in_ppr: bool = False
    # Rank within each stage-1 source; None when that source did not return the file.
    bm25_rank: Optional[int] = None
    embed_rank: Optional[int] = None
    ppr_rank: Optional[int] = None
@dataclass
class LocalisationResult:
    """Outcome of one localisation call.

    Holds the ranked hits, the wall-clock time spent, a failure category
    (defaults to "success") and optional recall metrics that are only
    populated when gold files are supplied for evaluation.
    """

    hits: list[LocalisationHit]
    elapsed_seconds: float
    failure_category: Literal[
        "success",
        "wrong_file",
        "partial_file",
        "missing_dependency",
        "ambiguous_issue",
        "empty_query",
        "index_error",
    ] = "success"
    # For evaluation
    recall_at_5: Optional[float] = None
    recall_at_10: Optional[float] = None

    @property
    def top_k_paths(self) -> list[str]:
        """File paths of the hits, in ranked order.

        Exposed as a property because the pipeline reads it as an attribute
        (it slices ``result.top_k_paths[:5]``).
        """
        return [h.file_path for h in self.hits]
| # ── Failure categorisation ──────────────────────────────────────────────────── | |
def categorise_localisation_failure(
    predicted_files: list[str],
    gold_files: list[str],
    issue_text: str,
) -> Literal["wrong_file", "partial_file", "missing_dependency", "ambiguous_issue", "success"]:
    """
    Classify WHY localisation failed - generates signal for fine-tuning.

    Categories:
        success:         every gold file is among the predictions
                         (this includes the case of an empty ``gold_files``)
        partial_file:    some, but not all, gold files were predicted
        ambiguous_issue: no gold file was predicted and the issue text has
                         fewer than 10 whitespace-separated words
        wrong_file:      no gold file was predicted and the issue text is
                         10 words or longer

    "missing_dependency" stays in the return annotation for compatibility,
    but this function never returns it: deciding that a gold file had no
    BM25/embedding match needs retrieval information that is not among the
    inputs.

    Args:
        predicted_files: predicted file paths (duplicates are ignored).
        gold_files: ground-truth file paths (duplicates are ignored).
        issue_text: the issue description; only its word count is used.

    Returns:
        One of the category strings described above.
    """
    gold_set = set(gold_files)
    found = gold_set.intersection(predicted_files)

    if found == gold_set:
        return "success"
    if found:
        # found is a strict, non-empty subset of the gold set.
        return "partial_file"
    # Nothing was found: blame a vague issue before blaming the ranking.
    if len(issue_text.split()) < 10:
        return "ambiguous_issue"
    return "wrong_file"
| # ── Main pipeline ───────────────────────────────────────────────────────────── | |
class LocalisationPipeline:
    """
    End-to-end file localisation pipeline:
        BM25 + Embeddings + PPR graph propagation → RRF fusion → DeBERTa re-rank

    PPR is seeded from the BM25 results and its scores are fused together with
    the BM25 and embedding hits; the optional DeBERTa ranker then re-orders the
    fused candidates.

    The pipeline is stateful: index_repo() must be called before localise().
    """

    def __init__(
        self,
        cache_dir: Path = Path(".cache"),
        embedding_model: str = "text-embedding-3-small",
        deberta_model: str = "microsoft/deberta-v3-small",
        alpha_bm25: float = 0.4,
        alpha_embed: float = 0.4,
        alpha_ppr: float = 0.2,
        bm25_top_k: int = 20,
        embed_top_k: int = 20,
        ppr_top_k: int = 20,
        final_top_k: int = 10,
        use_deberta: bool = True,
        use_ppr: bool = True,
        use_embeddings: bool = True,
        track_mlflow: bool = False,
    ):
        """
        Store the configuration and construct the retrieval components.

        Args:
            cache_dir: base cache directory; embeddings use ``cache_dir / "embeddings"``.
            embedding_model: model name passed to the embedding retriever.
            deberta_model: model name or path passed to the DeBERTa ranker.
            alpha_bm25, alpha_embed, alpha_ppr: weights forwarded to RRF fusion.
            bm25_top_k, embed_top_k, ppr_top_k: candidates requested per stage-1 source.
            final_top_k: default number of hits returned by localise().
            use_deberta: build the stage-2 re-ranker.
            use_ppr: run PPR when a dependency graph has been indexed.
            use_embeddings: build the embedding retriever.
            track_mlflow: log metrics to MLflow after each localise() call.
        """
        self.alpha_bm25 = alpha_bm25
        self.alpha_embed = alpha_embed
        self.alpha_ppr = alpha_ppr
        self.bm25_top_k = bm25_top_k
        self.embed_top_k = embed_top_k
        self.ppr_top_k = ppr_top_k
        self.final_top_k = final_top_k
        self.use_deberta = use_deberta
        self.use_ppr = use_ppr
        self.use_embeddings = use_embeddings
        self.track_mlflow = track_mlflow

        # Component slots; a slot left as None means that stage is disabled.
        self._bm25: Optional[object] = None
        self._embed: Optional[object] = None
        self._graph: Optional[object] = None
        self._ranker: Optional[object] = None
        self._file_symbols: list = []

        # Build components. The imports are local so that a disabled stage
        # never imports its module.
        from localisation.bm25_retriever import BM25Retriever

        self._bm25 = BM25Retriever()
        if use_embeddings:
            from localisation.embedding_retriever import EmbeddingRetriever

            self._embed = EmbeddingRetriever(
                model=embedding_model,
                # Path() lets callers pass a plain string as cache_dir.
                cache_dir=Path(cache_dir) / "embeddings",
            )
        if use_deberta:
            from localisation.deberta_ranker import DeBERTaRanker

            self._ranker = DeBERTaRanker(model_name_or_path=deberta_model)

    def index_repo(
        self,
        file_symbols: list,
        dependency_graph=None,
        show_progress: bool = False,
    ) -> dict:
        """
        Index a repository for retrieval.

        Args:
            file_symbols: list of FileSymbols from ast_parser
            dependency_graph: RepoDependencyGraph (optional, enables PPR)
            show_progress: forwarded to the embedding retriever's index()

        Returns:
            dict with "elapsed" (seconds), "bm25_docs" (BM25 corpus size) and
            whatever stats the embedding retriever's index() returned.
        """
        self._file_symbols = file_symbols
        self._graph = dependency_graph
        start = time.monotonic()

        # BM25 index (always runs)
        self._bm25.index(file_symbols)

        # Embedding index (only when the embedding retriever is enabled)
        embed_stats = {}
        if self._embed:
            embed_stats = self._embed.index(file_symbols, show_progress=show_progress)

        elapsed = time.monotonic() - start
        logger.info(
            "Repo indexed in %.1fs — BM25: %d docs | Embed: %s",
            elapsed, self._bm25.corpus_size, embed_stats
        )
        return {"elapsed": elapsed, "bm25_docs": self._bm25.corpus_size, **embed_stats}

    @staticmethod
    def _stage1_lookups(bm25_results, embed_hits, ppr_scores):
        """
        Build per-path lookup tables for the stage-1 diagnostics.

        Args:
            bm25_results: objects exposing ``file_path`` and ``rank``.
            embed_hits: sequences whose first element is the file path.
            ppr_scores: mapping of file path to PPR score.

        Returns:
            (bm25_rank_by_path, embed_paths, ppr_rank_by_path) where PPR ranks
            are 1-based positions after sorting by descending score.
        """
        bm25_rank_by_path: dict = {}
        for h in bm25_results:
            # First occurrence wins if a path appears more than once.
            bm25_rank_by_path.setdefault(h.file_path, h.rank)

        embed_paths = {h[0] for h in embed_hits}

        ppr_order = sorted(ppr_scores.items(), key=lambda item: -item[1])
        ppr_rank_by_path = {
            path: position for position, (path, _) in enumerate(ppr_order, start=1)
        }
        return bm25_rank_by_path, embed_paths, ppr_rank_by_path

    def localise(
        self,
        issue_text: str,
        top_k: Optional[int] = None,
        gold_files: Optional[list[str]] = None,  # for evaluation only
    ) -> LocalisationResult:
        """
        Localise relevant files for a given issue.

        Args:
            issue_text: the GitHub issue description
            top_k: override final top-k (default: self.final_top_k)
            gold_files: if provided, compute recall metrics and a failure category

        Returns:
            LocalisationResult with ranked hits. Blank issue text returns an
            empty result with failure_category="empty_query" without querying.
        """
        if not issue_text.strip():
            return LocalisationResult(hits=[], elapsed_seconds=0.0, failure_category="empty_query")

        top_k = top_k or self.final_top_k
        start = time.monotonic()

        # ── Stage 1a: BM25 ────────────────────────────────────────────────
        bm25_results = self._bm25.query(issue_text, top_k=self.bm25_top_k)
        bm25_hits_for_rrf = [(h.file_path, h.score, h.rank) for h in bm25_results]

        # ── Stage 1b: Embeddings ──────────────────────────────────────────
        # NOTE(review): embedding hits are only indexed with [0] (file path)
        # here; presumably they are tuples shaped like the BM25 ones — confirm.
        embed_hits_for_rrf = []
        if self._embed:
            embed_hits_for_rrf = self._embed.query(issue_text, top_k=self.embed_top_k)

        # ── Stage 1c: PPR graph propagation ───────────────────────────────
        ppr_scores = {}
        if self.use_ppr and self._graph:
            # Seed with the top 10 BM25 hits, weighted by reciprocal rank.
            # NOTE(review): assumes BM25 ranks start at 1; a rank of 0 would
            # raise ZeroDivisionError — confirm against BM25Retriever.
            seed_scores = {h.file_path: 1.0 / h.rank for h in bm25_results[:10]}
            ppr_scores = self._graph.personalized_pagerank(
                seed_scores, top_k=self.ppr_top_k
            )

        # ── RRF fusion ────────────────────────────────────────────────────
        from localisation.rrf_fusion import reciprocal_rank_fusion

        fused = reciprocal_rank_fusion(
            bm25_hits=bm25_hits_for_rrf,
            embed_hits=embed_hits_for_rrf,
            ppr_scores=ppr_scores,
            alpha_bm25=self.alpha_bm25,
            alpha_embed=self.alpha_embed,
            alpha_ppr=self.alpha_ppr,
            top_k=top_k * 2,  # overshoot for Stage 2 input
        )

        # ── Stage 2: DeBERTa re-ranking ───────────────────────────────────
        fs_summary_map = {fs.file_path: fs.summary_text for fs in self._file_symbols}
        stage2_candidates = [
            (hit.file_path, fs_summary_map.get(hit.file_path, ""))
            for hit in fused
        ]

        if self._ranker and stage2_candidates:
            ranked_files = self._ranker.rerank(
                issue_text, stage2_candidates, top_k=top_k
            )
            # Build the diagnostic lookups once instead of rescanning the
            # stage-1 lists (and re-sorting the PPR scores) for every hit.
            bm25_rank_by_path, embed_paths, ppr_rank_by_path = self._stage1_lookups(
                bm25_results, embed_hits_for_rrf, ppr_scores
            )
            # The fused hits carry the embedding rank, so the re-ranked path
            # reports it just like the stage-1-only path does.
            embed_rank_by_path = {h.file_path: h.embed_rank for h in fused}
            hits = [
                LocalisationHit(
                    file_path=r.file_path,
                    relevance_score=r.relevance_score,
                    rank=r.rank,
                    in_bm25=r.file_path in bm25_rank_by_path,
                    in_embed=r.file_path in embed_paths,
                    in_ppr=r.file_path in ppr_scores,
                    bm25_rank=bm25_rank_by_path.get(r.file_path),
                    embed_rank=embed_rank_by_path.get(r.file_path),
                    ppr_rank=ppr_rank_by_path.get(r.file_path),
                )
                for r in ranked_files
            ]
        else:
            # Stage 1 output (no DeBERTa re-ranking)
            hits = [
                LocalisationHit(
                    file_path=h.file_path,
                    relevance_score=h.fused_score,
                    rank=h.rank,
                    in_bm25=h.bm25_rank is not None,
                    in_embed=h.embed_rank is not None,
                    in_ppr=h.ppr_rank is not None,
                    bm25_rank=h.bm25_rank,
                    embed_rank=h.embed_rank,
                    ppr_rank=h.ppr_rank,
                )
                for h in fused[:top_k]
            ]

        elapsed = time.monotonic() - start

        # ── Evaluation metrics ────────────────────────────────────────────
        result = LocalisationResult(hits=hits, elapsed_seconds=elapsed)
        if gold_files:
            from localisation.deberta_ranker import recall_at_k

            # Derive the paths from the hits directly so this works whether
            # top_k_paths is exposed as a method or as a property.
            predicted_paths = [h.file_path for h in hits]
            result.recall_at_5 = recall_at_k(predicted_paths, gold_files, k=5)
            result.recall_at_10 = recall_at_k(predicted_paths, gold_files, k=10)
            # The failure category is judged on the top 5 predictions only.
            result.failure_category = categorise_localisation_failure(
                predicted_paths[:5], gold_files, issue_text
            )

        # ── MLflow tracking ───────────────────────────────────────────────
        if self.track_mlflow:
            self._log_to_mlflow(result)

        logger.debug(
            "Localised in %.2fs | top-%d files | recall@5=%.2f",
            elapsed, len(hits), result.recall_at_5 or 0.0
        )
        return result

    def _log_to_mlflow(self, result: LocalisationResult) -> None:
        """
        Best-effort logging of timing and recall metrics to MLflow.

        Never raises: tracking problems must not fail a localisation call.
        Failures are reported through the module logger instead of being
        silently discarded. Missing recall values are logged as 0.0.
        """
        try:
            import mlflow
        except ImportError:
            logger.warning("track_mlflow is enabled but mlflow is not installed")
            return

        metrics = {
            "localisation_elapsed": result.elapsed_seconds,
            "recall_at_5": result.recall_at_5 or 0.0,
            "recall_at_10": result.recall_at_10 or 0.0,
        }
        try:
            mlflow.log_metrics(metrics)
        except Exception:
            # Deliberately broad: this is an optional side channel.
            logger.warning("Failed to log localisation metrics to MLflow", exc_info=True)