Spaces:
Running
Running
"""
localisation/deberta_ranker.py
───────────────────────────────
Stage 2 — DeBERTa-v3-small cross-encoder ranker.
Given a set of candidate files from Stage 1 (RRF fusion), this
re-ranks them using a fine-tuned DeBERTa-v3-small cross-encoder that
classifies (issue_text, file_summary) → relevant/not_relevant.
Cross-encoders are much more precise than bi-encoders because they see
both the query AND the document together — allowing full attention
across both. The trade-off is they can't be pre-indexed (must run at
query time), so we only apply them to the top-20 candidates from Stage 1.
Training data (for fine-tuning):
  - Positive: (issue_text, gold_file_summary) → label=1
  - Negative: (issue_text, random_file_summary) → label=0
  - Hard negatives: BM25 top-20 files that are NOT the gold file → label=0
  - Dataset built from SWE-bench Lite instances
This module has two modes:
  1. inference_only: loads a pre-trained checkpoint and scores candidates
  2. training: fine-tunes DeBERTa-v3-small on the SWE-bench training set
For Phase 3 we implement the inference path + training scaffold.
Fine-tuning happens in Phase 7 (after trajectory data is collected).
"""
| from __future__ import annotations | |
| import logging | |
| import os | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Optional | |
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Checkpoint loaded when the caller does not supply one; a fine-tuned
# checkpoint path can be passed instead.
DEFAULT_MODEL = "microsoft/deberta-v3-small"

# Token budgets for the cross-encoder input.
MAX_QUERY_LEN = 256    # issue text tokens
MAX_DOC_LEN = 256      # file summary tokens
MAX_TOTAL_LEN = 512    # total cross-encoder input length
@dataclass
class RankedFile:
    """One candidate file together with its position before and after re-ranking.

    Declared as a dataclass so it can be built with keyword arguments
    (``RankedFile(file_path=..., relevance_score=..., rank=..., stage1_rank=...)``)
    and gets ``__repr__``/``__eq__`` for free.
    """

    file_path: str
    relevance_score: float  # relevance score; higher means more relevant
    rank: int               # final rank (1-indexed)
    stage1_rank: int        # rank before re-ranking (1-indexed)
class DeBERTaRanker:
    """
    Cross-encoder re-ranker using DeBERTa-v3-small.

    Scores each (issue, file_summary) pair and re-orders Stage 1 candidates.
    Falls back gracefully to Stage 1 ordering if the model is unavailable.
    """

    def __init__(
        self,
        model_name_or_path: str = DEFAULT_MODEL,
        device: str = "auto",
        max_length: int = MAX_TOTAL_LEN,
    ):
        """
        Args:
            model_name_or_path: model id or checkpoint path handed to
                ``from_pretrained``.
            device: ``"auto"`` to pick cuda/mps/cpu, or an explicit device string.
            max_length: maximum tokenised length of one (query, document) pair.
        """
        self.model_name_or_path = model_name_or_path
        self.max_length = max_length
        self._model = None
        self._tokenizer = None
        self._device = self._resolve_device(device)
        self._available = False
        # Loading is attempted eagerly; failure only flips `_available` off.
        self._try_load()

    def _resolve_device(self, device: str) -> str:
        """Return `device` unchanged unless it is "auto"; then prefer cuda, mps, cpu."""
        if device != "auto":
            return device
        try:
            import torch

            if torch.cuda.is_available():
                return "cuda"
            if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
                return "mps"
        except ImportError:
            # torch is optional here: without it we can only report "cpu".
            pass
        return "cpu"

    def _try_load(self) -> None:
        """Attempt to load the model — log a warning if unavailable."""
        try:
            from transformers import AutoTokenizer, AutoModelForSequenceClassification
            import torch  # noqa: F401 - imported so a missing torch disables the ranker

            logger.info(
                "Loading DeBERTa ranker: %s on %s", self.model_name_or_path, self._device
            )
            self._tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)
            # NOTE(review): loading the base checkpoint with num_labels=2 presumably
            # yields an untrained classification head — confirm that a fine-tuned
            # checkpoint is passed wherever the scores are relied upon.
            self._model = AutoModelForSequenceClassification.from_pretrained(
                self.model_name_or_path, num_labels=2
            )
            self._model.to(self._device)
            self._model.eval()
            self._available = True
            logger.info("DeBERTa ranker loaded successfully")
        except Exception as e:
            # Deliberate best-effort: any load failure degrades to Stage 1 ordering.
            logger.warning(
                "DeBERTa ranker not available (%s) — will use Stage 1 ordering as-is", e
            )

    def rerank(
        self,
        issue_text: str,
        candidates: list[tuple[str, str]],  # list of (file_path, file_summary)
        top_k: int = 10,
        batch_size: int = 16,
    ) -> list[RankedFile]:
        """
        Re-rank candidates by relevance to issue_text.

        Args:
            issue_text: the GitHub issue description
            candidates: list of (file_path, file_summary) from Stage 1
            top_k: number of results to return
            batch_size: inference batch size

        Returns:
            List of RankedFile sorted by relevance_score descending. When the
            model is unavailable the Stage 1 order is kept and the score is
            the inverse rank.
        """
        if not candidates:
            return []

        if not self._available:
            logger.debug("DeBERTa unavailable — returning Stage 1 ordering")
            return [
                RankedFile(
                    file_path=fp,
                    relevance_score=1.0 / (i + 1),  # inverse rank as score
                    rank=i + 1,
                    stage1_rank=i + 1,
                )
                for i, (fp, _) in enumerate(candidates[:top_k])
            ]

        scores = self._score_batch(issue_text, candidates, batch_size)

        # Carry each candidate's Stage 1 position through the sort instead of
        # searching for its path afterwards: that search was O(n) per result
        # and reported the first occurrence when a path appeared twice.
        scored = [
            (stage1_rank, file_path, score)
            for stage1_rank, ((file_path, _), score) in enumerate(
                zip(candidates, scores), start=1
            )
        ]
        # Stable sort: ties keep their Stage 1 order.
        scored.sort(key=lambda entry: -entry[2])

        return [
            RankedFile(
                file_path=file_path,
                relevance_score=float(score),
                rank=rank,
                stage1_rank=stage1_rank,
            )
            for rank, (stage1_rank, file_path, score) in enumerate(
                scored[:top_k], start=1
            )
        ]

    def _score_batch(
        self,
        issue_text: str,
        candidates: list[tuple[str, str]],
        batch_size: int,
    ) -> list[float]:
        """Run cross-encoder inference on all candidates in batches.

        Returns one score per candidate, in input order: the softmax
        probability assigned to class index 1.
        """
        import torch
        import torch.nn.functional as F

        # Cheap character-level cap; the tokenizer enforces the token limit.
        truncated_query = issue_text[:500]
        scores: list[float] = []
        for start in range(0, len(candidates), batch_size):
            batch = candidates[start: start + batch_size]
            texts_a = [truncated_query] * len(batch)
            texts_b = [summary[:500] for _, summary in batch]
            encoded = self._tokenizer(
                texts_a,
                texts_b,
                max_length=self.max_length,
                padding=True,
                truncation=True,
                return_tensors="pt",
            )
            encoded = {k: v.to(self._device) for k, v in encoded.items()}
            with torch.no_grad():
                logits = self._model(**encoded).logits
                probs = F.softmax(logits, dim=-1)
            # Class 1 = relevant
            scores.extend(probs[:, 1].cpu().tolist())
        return scores
# ── Training scaffold ─────────────────────────────────────────────────────────
class DeBERTaTrainer:
    """
    Fine-tunes DeBERTa-v3-small on (issue, file_summary) pairs.

    Training data format (JSONL):
        {"query": "<issue text>", "document": "<file summary>", "label": 0|1}

    Called in Phase 7 after collecting trajectory data from SWE-bench runs.
    """

    def __init__(
        self,
        base_model: str = DEFAULT_MODEL,
        output_dir: Path = Path("models/deberta_ranker"),
        num_epochs: int = 3,
        learning_rate: float = 2e-5,
        batch_size: int = 16,
    ):
        """
        Args:
            base_model: model id or checkpoint path to start fine-tuning from.
            output_dir: directory where checkpoints and the final model are written.
            num_epochs: number of training epochs.
            learning_rate: optimiser learning rate.
            batch_size: per-device batch size for both training and evaluation.
        """
        self.base_model = base_model
        self.output_dir = Path(output_dir)
        self.num_epochs = num_epochs
        self.learning_rate = learning_rate
        self.batch_size = batch_size

    def prepare_training_data(
        self,
        swe_instances,      # list of SWEInstance
        file_symbols_map,   # {instance_id: list[FileSymbols]}
        hard_negatives_k: int = 5,  # BM25 top-k non-gold as hard negatives
    ) -> list[dict]:
        """
        Build training pairs from SWE-bench instances.

        Strategy:
            Positive: (issue, gold_file_summary) → label=1
            Hard-neg: up to `hard_negatives_k` top BM25 hits that are NOT in
                      the gold patch → label=0, emitted once per positive.

        NOTE: random negatives are not generated by this method; only
        positives and BM25 hard negatives are produced.

        Args:
            swe_instances: instances exposing `instance_id`, `patch` and
                `problem_statement`.
            file_symbols_map: maps instance id to that repo's file symbols
                (objects exposing `file_path` and `summary_text`).
            hard_negatives_k: maximum hard negatives attached to each positive.

        Returns:
            List of dicts with keys `query`, `document`, `label`, `instance_id`.
        """
        from localisation.bm25_retriever import BM25Retriever

        training_pairs: list[dict] = []
        for inst in swe_instances:
            file_symbols = file_symbols_map.get(inst.instance_id, [])
            if not file_symbols:
                continue

            # Extract gold file paths from the patch; the set gives O(1) membership.
            gold_files = _extract_files_from_patch(inst.patch)
            gold_set = set(gold_files)

            # Build BM25 index for this repo. A few extra hits are requested so
            # that gold files appearing among them do not shrink the negative pool.
            retriever = BM25Retriever()
            retriever.index(file_symbols)
            bm25_hits = retriever.query(
                inst.problem_statement, top_k=hard_negatives_k + 5
            )
            fs_map = {fs.file_path: fs for fs in file_symbols}

            # Filter out gold/unknown files BEFORE truncating to k; slicing first
            # would waste the buffer requested above.
            hard_negative_paths = [
                hit.file_path
                for hit in bm25_hits
                if hit.file_path not in gold_set and hit.file_path in fs_map
            ][:hard_negatives_k]

            query = inst.problem_statement[:500]
            for gold_fp in gold_files:
                if gold_fp not in fs_map:
                    continue
                # Positive pair
                training_pairs.append({
                    "query": query,
                    "document": fs_map[gold_fp].summary_text[:500],
                    "label": 1,
                    "instance_id": inst.instance_id,
                })
                # Hard negatives
                for neg_fp in hard_negative_paths:
                    training_pairs.append({
                        "query": query,
                        "document": fs_map[neg_fp].summary_text[:500],
                        "label": 0,
                        "instance_id": inst.instance_id,
                    })

        logger.info(
            "Training data: %d pairs (%d positive, %d negative)",
            len(training_pairs),
            sum(1 for p in training_pairs if p["label"] == 1),
            sum(1 for p in training_pairs if p["label"] == 0),
        )
        return training_pairs

    def train(self, training_data: list[dict]) -> None:
        """Fine-tune DeBERTa on the prepared training data.

        Splits the data 90/10 into train/validation, trains with the Hugging
        Face ``Trainer`` and saves model and tokenizer to `output_dir`.

        Raises:
            ValueError: if fewer than two examples are supplied (the 90/10
                split would leave the training set empty).
            ImportError: if transformers or torch is not installed.
        """
        if len(training_data) < 2:
            raise ValueError(
                "Need at least 2 training examples to build train/validation splits"
            )

        try:
            from transformers import (
                AutoTokenizer, AutoModelForSequenceClassification,
                TrainingArguments, Trainer
            )
            import torch
            from torch.utils.data import Dataset
        except ImportError as e:
            raise ImportError("Install transformers + torch for fine-tuning") from e

        class PairDataset(Dataset):
            """Tokenises one (query, document) pair per item, padded to max_length."""

            def __init__(self, data, tokenizer, max_length):
                self.data = data
                self.tokenizer = tokenizer
                self.max_length = max_length

            def __len__(self):
                return len(self.data)

            def __getitem__(self, idx):
                item = self.data[idx]
                enc = self.tokenizer(
                    item["query"], item["document"],
                    max_length=self.max_length,
                    padding="max_length", truncation=True,
                    return_tensors="pt",
                )
                # Drop only the leading batch axis added by return_tensors="pt".
                return {
                    "input_ids": enc["input_ids"].squeeze(0),
                    "attention_mask": enc["attention_mask"].squeeze(0),
                    "labels": torch.tensor(item["label"], dtype=torch.long),
                }

        tokenizer = AutoTokenizer.from_pretrained(self.base_model)
        model = AutoModelForSequenceClassification.from_pretrained(
            self.base_model, num_labels=2
        )

        dataset = PairDataset(training_data, tokenizer, MAX_TOTAL_LEN)
        train_size = int(0.9 * len(dataset))
        val_size = len(dataset) - train_size
        train_ds, val_ds = torch.utils.data.random_split(dataset, [train_size, val_size])

        # NOTE(review): `eval_strategy` is presumably the argument name of recent
        # transformers releases (older ones used `evaluation_strategy`) — confirm
        # against the pinned version.
        args = TrainingArguments(
            output_dir=str(self.output_dir),
            num_train_epochs=self.num_epochs,
            per_device_train_batch_size=self.batch_size,
            per_device_eval_batch_size=self.batch_size,
            learning_rate=self.learning_rate,
            eval_strategy="epoch",
            save_strategy="epoch",
            load_best_model_at_end=True,
            metric_for_best_model="eval_loss",
            logging_steps=10,
            warmup_ratio=0.1,
        )
        trainer = Trainer(
            model=model, args=args,
            train_dataset=train_ds, eval_dataset=val_ds,
        )
        trainer.train()
        trainer.save_model(str(self.output_dir))
        tokenizer.save_pretrained(str(self.output_dir))
        logger.info("DeBERTa ranker saved to %s", self.output_dir)
# ── Metric helpers ────────────────────────────────────────────────────────────
def recall_at_k(
    predictions: list[str],
    gold_files: list[str],
    k: int,
) -> float:
    """Return recall@k: the share of `gold_files` found among the first k predictions.

    An empty `gold_files` list yields 0.0.
    """
    total_gold = len(gold_files)
    if total_gold == 0:
        return 0.0
    retrieved = set(predictions[:k])
    # Count per gold entry (not per unique path) to match the denominator.
    found = len([path for path in gold_files if path in retrieved])
    return found / total_gold
| def _extract_files_from_patch(patch: str) -> list[str]: | |
| """Extract list of files modified in a unified diff.""" | |
| import re | |
| # Match '--- a/path/to/file.py' or '+++ b/path/to/file.py' | |
| pattern = re.compile(r"^(?:\+\+\+|---)\s+(?:a/|b/)(.+?)(?:\s|$)", re.MULTILINE) | |
| files = list(set(pattern.findall(patch))) | |
| return [f for f in files if f and f != "/dev/null"] | |