repomind-api / localisation /deberta_ranker.py
SouravNath's picture
Initial commit
dc71cad
"""
localisation/deberta_ranker.py
───────────────────────────────
Stage 2 — DeBERTa-v3-small cross-encoder ranker.
Given a set of candidate files from Stage 1 (RRF fusion), this
re-ranks them using a fine-tuned DeBERTa-v3-small cross-encoder that
classifies (issue_text, file_summary) → relevant/not_relevant.
Cross-encoders are much more precise than bi-encoders because they see
both the query AND the document together — allowing full attention
across both. The trade-off is they can't be pre-indexed (must run at
query time), so we only apply them to the top-20 candidates from Stage 1.
Training data (for fine-tuning):
- Positive: (issue_text, gold_file_summary) → label=1
- Negative: (issue_text, random_file_summary) → label=0
- Hard negatives: BM25 top-20 files that are NOT the gold file → label=0
- Dataset built from SWE-bench Lite instances
This module has two modes:
1. inference_only: loads a pre-trained checkpoint and scores candidates
2. training: fine-tunes DeBERTa-v3-small on the SWE-bench training set
For Phase 3 we implement the inference path + training scaffold.
Fine-tuning happens in Phase 7 (after trajectory data is collected).
"""
from __future__ import annotations
import logging
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Default model — can be swapped for a fine-tuned checkpoint
DEFAULT_MODEL = "microsoft/deberta-v3-small"
# Max token lengths for cross-encoder input
# NOTE(review): in the code visible here only MAX_TOTAL_LEN is referenced;
# MAX_QUERY_LEN and MAX_DOC_LEN are not used (text is truncated by character
# count before tokenisation instead) — confirm whether they are still needed.
MAX_QUERY_LEN = 256 # issue text tokens
MAX_DOC_LEN = 256 # file summary tokens
MAX_TOTAL_LEN = 512 # total cross-encoder input length
@dataclass
class RankedFile:
    """
    One candidate file together with its position in the final ranking.

    Attributes:
        file_path: path of the candidate file.
        relevance_score: 0–1 probability of relevance.
        rank: final rank (1-indexed).
        stage1_rank: rank before re-ranking.
    """

    file_path: str
    relevance_score: float
    rank: int
    stage1_rank: int
class DeBERTaRanker:
    """
    Cross-encoder re-ranker using DeBERTa-v3-small.

    Scores each (issue, file_summary) pair and re-orders Stage 1 candidates.
    Loading is attempted once, in the constructor; if it fails for any reason
    the instance stays usable and ``rerank`` returns the Stage 1 ordering.
    """

    def __init__(
        self,
        model_name_or_path: str = DEFAULT_MODEL,
        device: str = "auto",
        max_length: int = MAX_TOTAL_LEN,
    ):
        """
        Args:
            model_name_or_path: model id or checkpoint path handed to
                ``from_pretrained``.
            device: ``"auto"`` picks cuda, then mps, then cpu; any other
                value is used verbatim.
            max_length: maximum tokenised length of one (query, document) pair.
        """
        self.model_name_or_path = model_name_or_path
        self.max_length = max_length
        self._model = None
        self._tokenizer = None
        self._device = self._resolve_device(device)
        self._available = False
        self._try_load()

    def _resolve_device(self, device: str) -> str:
        """Return the explicit device, or the best one available for "auto"."""
        if device != "auto":
            return device
        try:
            import torch

            if torch.cuda.is_available():
                return "cuda"
            if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
                return "mps"
        except ImportError:
            # Without torch the model cannot load either; "cpu" is a harmless default.
            pass
        return "cpu"

    def _try_load(self) -> None:
        """Attempt to load the model — log a warning if unavailable."""
        try:
            from transformers import AutoTokenizer, AutoModelForSequenceClassification
            import torch  # noqa: F401  (fail here, not at inference time, if torch is missing)

            logger.info(
                "Loading DeBERTa ranker: %s on %s", self.model_name_or_path, self._device
            )
            self._tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)
            self._model = AutoModelForSequenceClassification.from_pretrained(
                self.model_name_or_path, num_labels=2
            )
            self._model.to(self._device)
            self._model.eval()
            self._available = True
            logger.info("DeBERTa ranker loaded successfully")
        except Exception as e:
            # Deliberate best-effort: any load failure downgrades to Stage 1
            # ordering. Drop partially loaded pieces so state stays consistent.
            self._model = None
            self._tokenizer = None
            self._available = False
            logger.warning(
                "DeBERTa ranker not available (%s) β€” will use Stage 1 ordering as-is", e
            )

    def rerank(
        self,
        issue_text: str,
        candidates: list[tuple[str, str]],  # list of (file_path, file_summary)
        top_k: int = 10,
        batch_size: int = 16,
    ) -> list[RankedFile]:
        """
        Re-rank candidates by relevance to issue_text.

        Args:
            issue_text: the GitHub issue description
            candidates: list of (file_path, file_summary) from Stage 1,
                in Stage 1 order
            top_k: number of results to return (values below 0 are treated as 0)
            batch_size: inference batch size (must be >= 1 when the model is used)

        Returns:
            List of RankedFile sorted by relevance_score descending. Ties keep
            their Stage 1 order. ``stage1_rank`` is the candidate's 1-indexed
            position in ``candidates``.
        """
        if not candidates:
            return []

        # A negative top_k would slice from the end of the list instead of
        # returning "nothing", so clamp it.
        top_k = max(top_k, 0)

        if not self._available:
            logger.debug("DeBERTa unavailable β€” returning Stage 1 ordering")
            return [
                RankedFile(
                    file_path=fp,
                    relevance_score=1.0 / (i + 1),  # inverse rank as score
                    rank=i + 1,
                    stage1_rank=i + 1,
                )
                for i, (fp, _) in enumerate(candidates[:top_k])
            ]

        scores = self._score_batch(issue_text, candidates, batch_size)

        # Sort candidate *indices* so the Stage 1 position travels with each
        # score: no per-result path lookup, and duplicate paths keep their own
        # rank. sorted() is stable, also with reverse=True.
        order = sorted(range(len(scores)), key=lambda idx: scores[idx], reverse=True)

        return [
            RankedFile(
                file_path=candidates[idx][0],
                relevance_score=float(scores[idx]),
                rank=new_rank,
                stage1_rank=idx + 1,
            )
            for new_rank, idx in enumerate(order[:top_k], start=1)
        ]

    def _score_batch(
        self,
        issue_text: str,
        candidates: list[tuple[str, str]],
        batch_size: int,
    ) -> list[float]:
        """
        Run cross-encoder inference on all candidates in batches.

        Returns one class-1 ("relevant") softmax probability per candidate,
        in the same order as ``candidates``.

        Raises:
            ValueError: if batch_size is not a positive integer.
        """
        import torch
        import torch.nn.functional as F

        if batch_size < 1:
            # range() would raise on 0 and silently yield nothing on negatives.
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        truncated_query = issue_text[:500]  # characters (tokenizer handles tokens)
        scores: list[float] = []
        for start in range(0, len(candidates), batch_size):
            batch = candidates[start: start + batch_size]
            texts_a = [truncated_query] * len(batch)
            texts_b = [summary[:500] for _, summary in batch]
            encoded = self._tokenizer(
                texts_a,
                texts_b,
                max_length=self.max_length,
                padding=True,
                truncation=True,
                return_tensors="pt",
            )
            encoded = {k: v.to(self._device) for k, v in encoded.items()}
            with torch.no_grad():
                logits = self._model(**encoded).logits
            probs = F.softmax(logits, dim=-1)
            # Class 1 = relevant
            scores.extend(probs[:, 1].cpu().tolist())
        return scores
# ── Training scaffold ─────────────────────────────────────────────────────────
class DeBERTaTrainer:
    """
    Fine-tunes DeBERTa-v3-small on (issue, file_summary) pairs.

    Training data format (JSONL):
        {"query": "<issue text>", "document": "<file summary>", "label": 0|1}

    Called in Phase 7 after collecting trajectory data from SWE-bench runs.
    """

    def __init__(
        self,
        base_model: str = DEFAULT_MODEL,
        output_dir: Path = Path("models/deberta_ranker"),
        num_epochs: int = 3,
        learning_rate: float = 2e-5,
        batch_size: int = 16,
    ):
        """
        Args:
            base_model: model id or checkpoint path to fine-tune from.
            output_dir: directory the fine-tuned model and tokenizer are saved to.
            num_epochs: number of training epochs.
            learning_rate: optimiser learning rate.
            batch_size: per-device batch size for both training and evaluation.
        """
        self.base_model = base_model
        self.output_dir = Path(output_dir)
        self.num_epochs = num_epochs
        self.learning_rate = learning_rate
        self.batch_size = batch_size

    def prepare_training_data(
        self,
        swe_instances,  # list of SWEInstance
        file_symbols_map,  # {instance_id: list[FileSymbols]}
        hard_negatives_k: int = 5,  # BM25 top-k non-gold as hard negatives
    ) -> list[dict]:
        """
        Build training pairs from SWE-bench instances.

        Strategy:
            Positive: (issue, gold_file_summary) → label=1, one per gold file
                that has a summary in ``file_symbols_map``.
            Hard-neg: the first ``hard_negatives_k`` BM25 hits that are NOT in
                the gold patch → label=0, emitted once per instance.

        Random negatives are not generated here. Instances with no file
        symbols, or with no gold file among them, contribute nothing.

        Returns:
            List of dicts with keys "query", "document", "label", "instance_id".
        """
        from localisation.bm25_retriever import BM25Retriever

        hard_negatives_k = max(hard_negatives_k, 0)
        training_pairs: list[dict] = []

        for inst in swe_instances:
            file_symbols = file_symbols_map.get(inst.instance_id, [])
            if not file_symbols:
                continue

            fs_map = {fs.file_path: fs for fs in file_symbols}

            # Extract gold file paths from the patch; a set gives O(1)
            # membership tests, and sorting makes the pair order deterministic.
            gold_files = set(_extract_files_from_patch(inst.patch))
            known_gold = sorted(fp for fp in gold_files if fp in fs_map)
            if not known_gold:
                # No positive can be built, so skip the instance (and the
                # BM25 indexing work) rather than emit negatives alone.
                continue

            query_text = inst.problem_statement[:500]

            # Positive pairs
            for gold_fp in known_gold:
                training_pairs.append({
                    "query": query_text,
                    "document": fs_map[gold_fp].summary_text[:500],
                    "label": 1,
                    "instance_id": inst.instance_id,
                })

            if hard_negatives_k == 0:
                continue

            # Build BM25 index for this repo. Over-fetch so that k negatives
            # remain after the gold files are filtered out of the hits.
            retriever = BM25Retriever()
            retriever.index(file_symbols)
            bm25_hits = retriever.query(
                inst.problem_statement,
                top_k=hard_negatives_k + max(5, len(gold_files)),
            )

            # Hard negatives: filter first, then keep the top k.
            negative_paths = [
                hit.file_path
                for hit in bm25_hits
                if hit.file_path not in gold_files and hit.file_path in fs_map
            ][:hard_negatives_k]
            for neg_fp in negative_paths:
                training_pairs.append({
                    "query": query_text,
                    "document": fs_map[neg_fp].summary_text[:500],
                    "label": 0,
                    "instance_id": inst.instance_id,
                })

        logger.info(
            "Training data: %d pairs (%d positive, %d negative)",
            len(training_pairs),
            sum(1 for p in training_pairs if p["label"] == 1),
            sum(1 for p in training_pairs if p["label"] == 0),
        )
        return training_pairs

    def train(self, training_data: list[dict]) -> None:
        """
        Fine-tune DeBERTa on the prepared training data.

        The data is split 90/10 into train/validation, the model is trained
        with the Hugging Face ``Trainer``, and the model and tokenizer are
        saved to ``self.output_dir``.

        Args:
            training_data: list of {"query", "document", "label"} dicts.

        Raises:
            ImportError: if transformers or torch is not installed.
            ValueError: if fewer than 2 examples are given (the split needs
                at least one training and one validation example).
        """
        try:
            from transformers import (
                AutoTokenizer, AutoModelForSequenceClassification,
                TrainingArguments, Trainer
            )
            import torch
            from torch.utils.data import Dataset
        except ImportError as e:
            raise ImportError("Install transformers + torch for fine-tuning") from e

        if len(training_data) < 2:
            raise ValueError(
                f"Need at least 2 training examples, got {len(training_data)}"
            )

        class PairDataset(Dataset):
            """Tokenises one (query, document) pair per item, padded to max_length."""

            def __init__(self, data, tokenizer, max_length):
                self.data = data
                self.tokenizer = tokenizer
                self.max_length = max_length

            def __len__(self):
                return len(self.data)

            def __getitem__(self, idx):
                item = self.data[idx]
                enc = self.tokenizer(
                    item["query"], item["document"],
                    max_length=self.max_length,
                    padding="max_length", truncation=True,
                    return_tensors="pt",
                )
                return {
                    # squeeze(0) drops only the batch dimension added by
                    # return_tensors="pt".
                    "input_ids": enc["input_ids"].squeeze(0),
                    "attention_mask": enc["attention_mask"].squeeze(0),
                    "labels": torch.tensor(item["label"], dtype=torch.long),
                }

        tokenizer = AutoTokenizer.from_pretrained(self.base_model)
        model = AutoModelForSequenceClassification.from_pretrained(
            self.base_model, num_labels=2
        )

        args = TrainingArguments(
            output_dir=str(self.output_dir),
            num_train_epochs=self.num_epochs,
            per_device_train_batch_size=self.batch_size,
            per_device_eval_batch_size=self.batch_size,
            learning_rate=self.learning_rate,
            eval_strategy="epoch",
            save_strategy="epoch",
            load_best_model_at_end=True,
            metric_for_best_model="eval_loss",
            logging_steps=10,
            warmup_ratio=0.1,
        )

        dataset = PairDataset(training_data, tokenizer, MAX_TOTAL_LEN)
        train_size = int(0.9 * len(dataset))
        val_size = len(dataset) - train_size
        # Seed the split from the Trainer seed so the train/validation
        # partition is reproducible between runs.
        train_ds, val_ds = torch.utils.data.random_split(
            dataset,
            [train_size, val_size],
            generator=torch.Generator().manual_seed(args.seed),
        )

        trainer = Trainer(
            model=model, args=args,
            train_dataset=train_ds, eval_dataset=val_ds,
        )
        trainer.train()
        trainer.save_model(str(self.output_dir))
        tokenizer.save_pretrained(str(self.output_dir))
        logger.info("DeBERTa ranker saved to %s", self.output_dir)
# ── Metric helpers ────────────────────────────────────────────────────────────
def recall_at_k(
    predictions: list[str],
    gold_files: list[str],
    k: int,
) -> float:
    """
    Compute recall@k: fraction of gold files in top-k predictions.

    Args:
        predictions: predicted file paths, best first.
        gold_files: ground-truth file paths; duplicates are counted once.
        k: cut-off; only the first k predictions are considered.

    Returns:
        A value in [0, 1]. Returns 0.0 when there are no gold files or k <= 0.
    """
    unique_gold = set(gold_files)
    # k <= 0 must mean "no predictions considered"; a negative k would
    # otherwise slice from the end of the list.
    if not unique_gold or k <= 0:
        return 0.0
    top_k_set = set(predictions[:k])
    return len(unique_gold & top_k_set) / len(unique_gold)
def _extract_files_from_patch(patch: str) -> list[str]:
"""Extract list of files modified in a unified diff."""
import re
# Match '--- a/path/to/file.py' or '+++ b/path/to/file.py'
pattern = re.compile(r"^(?:\+\+\+|---)\s+(?:a/|b/)(.+?)(?:\s|$)", re.MULTILINE)
files = list(set(pattern.findall(patch)))
return [f for f in files if f and f != "/dev/null"]