import re
import gc
import html
import json
import hashlib
import traceback
from pathlib import Path
from typing import Dict, Any, Tuple, List
import torch
import gradio as gr
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, AutoModel
from model_utils import load_model_and_tokenizer, generate_completion
# ============================================================
# CONFIG
# ============================================================
REMOTE_MODEL_REPO = "TranTruongMMCII/UIT.CS2229.Generator"
# One checkpoint can be exposed in multiple inference modes.
MODEL_VARIANTS = {
"Generator - Baseline": {
"checkpoint": "baseline",
"use_online_retriever": False,
"description": "Baseline generator, context only.",
},
"Generator - EOL": {
"checkpoint": "eol",
"use_online_retriever": False,
"description": "EOL-trained generator, context only.",
},
"Generator - Retriever-trained + EOL (Retrieval OFF)": {
"checkpoint": "retriever_eol",
"use_online_retriever": False,
"description": "Ablation: retriever-trained checkpoint, but retrieved input is empty.",
},
"Generator - Online Retriever + EOL": {
"checkpoint": "retriever_eol",
"use_online_retriever": True,
"description": "Full demo mode: online dense retrieval top-5 + rerank + retriever-trained EOL generator.",
},
}
DEFAULT_MODEL_NAME = "Generator - Baseline"
# Remote retriever artifacts in the same HF model repo.
RETRIEVER_INDEX_IN_REPO = "retriever/py150_train_index.pt"
RETRIEVER_CHUNKS_IN_REPO = "retriever/py150_train_chunked.jsonl"
RETRIEVER_MODEL_NAME = "microsoft/graphcodebert-base"
RETRIEVER_BLOCK_SIZE = 512
# Retrieve top-5 dense candidates, then rerank by simple keyword overlap.
RETRIEVER_TOP_K = 5
RERANK_KEYWORD_WEIGHT = 0.03
RETRIEVED_MAX_TOKENS = 180
# Safer for HF Spaces deploys:
# - False: the app starts fast; models/retriever load lazily on first use.
# - True: download/load everything at startup; the deploy can fail if any remote artifact is broken.
PRE_DOWNLOAD_MODELS = False
WARMUP_DEFAULT_MODEL = False
# Keep one generator model in memory at a time.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ============================================================
# GLOBAL CACHE
# ============================================================
_model_paths_cache: Dict[str, Path] = {}
_current_checkpoint_name = None
_current_tokenizer = None
_current_model = None
_current_model_path = None
_retriever = None
# ============================================================
# MODEL UTILS
# ============================================================
def file_fingerprint(path: Path) -> str:
"""Return short sha256 fingerprint for debugging model identity."""
if not path.exists():
return "missing"
h = hashlib.sha256()
with open(path, "rb") as f:
for chunk in iter(lambda: f.read(1024 * 1024), b""):
h.update(chunk)
return h.hexdigest()[:16]
def get_variant_config(model_name: str) -> Dict[str, Any]:
if model_name not in MODEL_VARIANTS:
raise ValueError(f"Unknown model option: {model_name}")
return MODEL_VARIANTS[model_name]
def resolve_remote_model_path(model_name: str) -> Path:
"""Download selected generator checkpoint folder from remote HF model repo."""
variant = get_variant_config(model_name)
checkpoint_name = variant["checkpoint"]
if checkpoint_name in _model_paths_cache:
return _model_paths_cache[checkpoint_name]
remote_subdir = f"checkpoint-best/{checkpoint_name}"
local_repo_dir = snapshot_download(
repo_id=REMOTE_MODEL_REPO,
repo_type="model",
allow_patterns=[f"{remote_subdir}/*"],
)
model_path = Path(local_repo_dir) / remote_subdir
required_files = [
"config.json",
"generation_config.json",
"model.safetensors",
"tokenizer.json",
"tokenizer_config.json",
]
missing = [f for f in required_files if not (model_path / f).exists()]
if missing:
raise FileNotFoundError(
f"Missing required files in {model_path}: {missing}")
_model_paths_cache[checkpoint_name] = model_path
return model_path
def unload_current_model():
global _current_checkpoint_name, _current_tokenizer, _current_model, _current_model_path
if _current_model is not None:
del _current_model
del _current_tokenizer
_current_checkpoint_name = None
_current_tokenizer = None
_current_model = None
_current_model_path = None
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
def get_model(model_name: str):
"""Load selected generator checkpoint into memory. Only one generator model is kept in RAM."""
global _current_checkpoint_name, _current_tokenizer, _current_model, _current_model_path
variant = get_variant_config(model_name)
checkpoint_name = variant["checkpoint"]
if _current_checkpoint_name == checkpoint_name and _current_model is not None:
return _current_tokenizer, _current_model, _current_model_path
unload_current_model()
model_path = resolve_remote_model_path(model_name)
print(f"Loading generator checkpoint: {checkpoint_name}")
print(f"Selected mode: {model_name}")
print(f"Path: {model_path}")
print(f"SHA: {file_fingerprint(model_path / 'model.safetensors')}")
tokenizer, model = load_model_and_tokenizer(str(model_path))
model.to(device)
model.eval()
_current_checkpoint_name = checkpoint_name
_current_tokenizer = tokenizer
_current_model = model
_current_model_path = model_path
return tokenizer, model, model_path
def preload_model_folders():
"""Download all generator model folders to Hugging Face cache. Does not load models into RAM."""
print("Pre-downloading generator model folders...")
for name in MODEL_VARIANTS:
try:
path = resolve_remote_model_path(name)
print(f"Cached {name}: {path}")
except Exception as e:
print(f"[WARN] Failed to preload {name}: {e}")
# ============================================================
# INPUT NORMALIZATION
# ============================================================
def normalize_line(line: str) -> str:
"""Soft-normalize one line to be closer to train-time token style."""
line = re.sub(r"([()\[\]{}:,.=+\-*/<>])", r" \1 ", line)
line = re.sub(r"\s+", " ", line)
return line.strip()
def context_to_tokens(code: str) -> str:
"""
Convert normal-looking code into training-style token text.
If code is already tokenized with <EOL>, keep it as-is.
"""
code = str(code or "").strip()
if "<EOL>" in code:
return code
code = code.replace("\t", " ")
lines = code.splitlines()
tokens = [normalize_line(line) for line in lines if line.strip()]
return " <EOL> ".join(tokens).strip()
def trim_token_text(text: str, max_tokens: int) -> str:
"""Trim tokenized text to a maximum number of whitespace-separated tokens."""
toks = str(text or "").split()
if len(toks) <= max_tokens:
return str(text or "").strip()
return " ".join(toks[:max_tokens]).strip()
# ============================================================
# RETRIEVER
# ============================================================
LIT_PATTERN = re.compile(r"<(STR|NUM|CHAR)_LIT:(.*?)>", re.S)
KEYWORD_PATTERN = re.compile(r"[A-Za-z_]\w+")
STOPWORDS_FOR_RERANK = {
"from", "import", "class", "def", "return", "self", "true", "false", "none",
"if", "else", "elif", "for", "while", "try", "except", "with", "as", "in",
"and", "or", "not", "is", "str", "int", "list", "dict", "object",
}
def convert_cxg_format_to_normal(code: str) -> str:
"""Convert tokenized CodeXGLUE/ReACC code to Python-like text for GraphCodeBERT."""
code = str(code or "").strip()
code = code.replace("<s>", "").replace("</s>", "")
code = code.replace("<EOL>", "\n")
code = code.replace("<NUM_LIT>", "0")
code = code.replace("<STR_LIT>", '"str"')
code = code.replace("<CHAR_LIT>", '"c"')
for lit_type, lit_value in LIT_PATTERN.findall(code):
code = code.replace(f"<{lit_type}_LIT:{lit_value}>", lit_value)
return code
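# Worked example of the conversion above: a hypothetical helper, never called by
# the app; the token string is an assumed sample, not retriever data.
def _example_convert_cxg() -> None:
    normal = convert_cxg_format_to_normal("x = <NUM_LIT:5> <EOL> print ( <STR_LIT> )")
    # Literal placeholders are inlined and <EOL> becomes a newline.
    assert normal == 'x = 5 \n print ( "str" )'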
def keyword_set(text: str) -> set:
"""Extract lightweight keywords for reranking dense retrieval candidates."""
text = html.unescape(str(text or "")).replace("<EOL>", " ")
kws = set()
for tok in KEYWORD_PATTERN.findall(text):
low = tok.lower()
if len(low) < 3 or low in STOPWORDS_FOR_RERANK:
continue
kws.add(low)
return kws
def rerank_retrieval_results(token_context: str, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Rerank dense top-k results by adding a small keyword-overlap bonus.
    Dense score still dominates; the keyword bonus helps avoid top-1 chunks with the wrong local pattern.
"""
query_keywords = keyword_set(token_context)
for r in results:
content_keywords = keyword_set(r.get("content", ""))
overlap = query_keywords & content_keywords
denom = max(1, min(len(query_keywords), 20))
overlap_score = len(overlap) / denom
r["keyword_overlap_count"] = len(overlap)
r["keyword_overlap_preview"] = ", ".join(sorted(list(overlap))[:20])
r["keyword_overlap_score"] = float(overlap_score)
r["rerank_score"] = float(
r["score"] + RERANK_KEYWORD_WEIGHT * overlap_score)
return sorted(results, key=lambda x: x["rerank_score"], reverse=True)
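# Worked example of the rerank arithmetic above, with assumed values: a candidate
# with dense score 0.8123, 4 overlapping keywords, and 10 query keywords gets
#   overlap_score = 4 / max(1, min(10, 20)) = 0.4
#   rerank_score  = 0.8123 + 0.03 * 0.4 = 0.8243
# so keyword overlap only breaks near-ties; it cannot overturn a clear dense win.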
class OnlineDenseRetriever:
"""
Online dense retriever.
    JSONL chunks are accessed via precomputed byte offsets, so the 463 MB JSONL
    file is never loaded into Python dictionaries all at once.
"""
def __init__(self):
self.device = device
self.index_path, self.chunks_path = self._download_retriever_artifacts()
print("Loading retriever index:", self.index_path)
index_tensor = torch.load(self.index_path, map_location="cpu")
if not isinstance(index_tensor, torch.Tensor):
index_tensor = torch.tensor(index_tensor)
index_tensor = index_tensor.float().contiguous()
index_tensor = torch.nn.functional.normalize(index_tensor, p=2, dim=1)
self.index = index_tensor
self.vector_dim = int(self.index.shape[1])
self.num_chunks = int(self.index.shape[0])
print("Retriever index shape:", tuple(self.index.shape))
print("Building JSONL byte offsets:", self.chunks_path)
self.offsets = self._build_offsets(self.chunks_path)
if len(self.offsets) != self.num_chunks:
raise ValueError(
f"Index/chunk mismatch: index={self.num_chunks}, chunks={len(self.offsets)}"
)
print("Chunk offsets:", len(self.offsets))
print("Loading retriever tokenizer:", RETRIEVER_MODEL_NAME)
self.tokenizer = AutoTokenizer.from_pretrained(
RETRIEVER_MODEL_NAME, use_fast=True)
print("Loading retriever encoder:", RETRIEVER_MODEL_NAME)
self.encoder = AutoModel.from_pretrained(
RETRIEVER_MODEL_NAME).to(self.device)
self.encoder.eval()
print("Online retriever ready.")
@staticmethod
def _download_retriever_artifacts() -> Tuple[Path, Path]:
local_repo_dir = snapshot_download(
repo_id=REMOTE_MODEL_REPO,
repo_type="model",
allow_patterns=[RETRIEVER_INDEX_IN_REPO, RETRIEVER_CHUNKS_IN_REPO],
)
root = Path(local_repo_dir)
index_path = root / RETRIEVER_INDEX_IN_REPO
chunks_path = root / RETRIEVER_CHUNKS_IN_REPO
missing = [str(p) for p in [index_path, chunks_path] if not p.exists()]
if missing:
raise FileNotFoundError(f"Missing retriever artifacts: {missing}")
return index_path, chunks_path
@staticmethod
def _build_offsets(path: Path) -> List[int]:
offsets = []
with open(path, "rb") as f:
while True:
pos = f.tell()
line = f.readline()
if not line:
break
if line.strip():
offsets.append(pos)
return offsets
def _read_chunk(self, idx: int) -> Dict[str, Any]:
with open(self.chunks_path, "rb") as f:
f.seek(self.offsets[idx])
line = f.readline().decode("utf-8")
return json.loads(line)
@torch.no_grad()
def encode_query(self, token_context: str) -> torch.Tensor:
clean_code = convert_cxg_format_to_normal(token_context)
text = f"{self.tokenizer.cls_token} {clean_code} {self.tokenizer.sep_token}"
encoded = self.tokenizer(
text,
padding="max_length",
truncation=True,
max_length=RETRIEVER_BLOCK_SIZE,
return_tensors="pt",
).to(self.device)
outputs = self.encoder(**encoded)
query_vec = outputs.last_hidden_state[:, 0, :]
query_vec = torch.nn.functional.normalize(query_vec, p=2, dim=1)
return query_vec.squeeze(0).detach().cpu().float()
    def search(self, token_context: str, top_k: int = RETRIEVER_TOP_K) -> List[Dict[str, Any]]:
        query_vec = self.encode_query(token_context)
        scores = torch.mv(self.index, query_vec)
        # Guard: never ask topk for more candidates than the index holds.
        top_scores, top_indices = torch.topk(scores, k=min(top_k, self.num_chunks))
results = []
for score, idx in zip(top_scores.tolist(), top_indices.tolist()):
row = self._read_chunk(int(idx))
content = row.get("content", "")
results.append({
"score": float(score),
"index": int(idx),
"original_file_id": row.get("original_file_id"),
"fragment_sequence_id": row.get("fragment_sequence_id"),
"content": content,
})
return rerank_retrieval_results(token_context, results)
def get_retriever() -> OnlineDenseRetriever:
global _retriever
if _retriever is None:
_retriever = OnlineDenseRetriever()
return _retriever
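# Minimal self-contained sketch of the byte-offset JSONL pattern used by
# OnlineDenseRetriever: a hypothetical helper, never called by the app; the temp
# file and rows are assumptions, not retriever artifacts.
def _example_jsonl_random_access() -> None:
    import tempfile
    with tempfile.NamedTemporaryFile("wb", suffix=".jsonl", delete=False) as f:
        offsets = []
        for i in range(3):
            offsets.append(f.tell())  # remember where each record starts
            f.write((json.dumps({"id": i}) + "\n").encode("utf-8"))
        path = f.name
    with open(path, "rb") as f:
        f.seek(offsets[2])  # jump straight to record 2, skipping records 0-1
        assert json.loads(f.readline()) == {"id": 2}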
# ============================================================
# OUTPUT CLEANUP
# ============================================================
STOP_MARKERS = ["<EOL>", "</s>", "<s>",
"<pad>", "<unk>", "<mask>", "<|endoftext|>"]
SPECIAL_TOKEN_PATTERNS = [
"<EOL>", "</s>", "<s>", "<pad>", "<unk>", "<mask>", "<|endoftext|>",
"<STR_LIT>", "<STR_LIT:...>", "<NUM_LIT>", "<NUM_LIT:...>",
"<CHAR_LIT>", "<CHAR_LIT:...>", "<BOOL_LIT>", "<BOOL_LIT:...>",
"<NULL_LIT>", "<INDENT>", "<DEDENT>",
]
TOKEN_PATTERN = re.compile(
r"\"[^\"\n]*\"|'[^'\n]*'|[A-Za-z_]\w*|\d+(?:\.\d+)?|==|!=|<=|>=|\+=|-=|\*=|/=|//|->|[(){}\[\]:,.;=+\-*/<>]"
)
def normalize_special_spacing(text: str) -> str:
"""Normalize weird spaced special tokens that may appear after decoding."""
text = html.unescape(str(text))
text = re.sub(r"<\s*/\s*s\s*>", "</s>", text)
text = re.sub(r"<\s*s\s*>", "<s>", text)
text = re.sub(r"<\s*pad\s*>", "<pad>", text)
text = re.sub(r"<\s*unk\s*>", "<unk>", text)
text = re.sub(r"<\s*mask\s*>", "<mask>", text)
text = re.sub(r"<\s*\|\s*endoftext\s*\|\s*>", "<|endoftext|>", text)
for name in ["EOL", "STR_LIT", "NUM_LIT", "CHAR_LIT", "BOOL_LIT", "NULL_LIT", "INDENT", "DEDENT"]:
text = re.sub(rf"<\s*{name}\s*>", f"<{name}>", text)
for name in ["STR_LIT", "NUM_LIT", "CHAR_LIT", "BOOL_LIT"]:
text = re.sub(
rf"<\s*{name}\s*:\s*([^>]+?)\s*>",
lambda m, n=name: f"<{n}:{m.group(1).strip()}>",
text,
)
return text
def cut_at_stop_marker(text: str):
"""Cut text at earliest stop marker. Returns: cleaned_prefix, detected_marker."""
earliest = None
detected = None
for marker in STOP_MARKERS:
pos = text.find(marker)
if pos >= 0 and (earliest is None or pos < earliest):
earliest = pos
detected = marker
if earliest is None:
return text, None
return text[:earliest], detected
def replace_dataset_placeholders(text: str) -> str:
"""Convert train-time placeholders to readable Python-ish code."""
def repl_str_payload(m):
value = m.group(1).strip()
return json.dumps(value)
text = re.sub(r"<STR_LIT:([^>]+)>", repl_str_payload, text)
text = text.replace("<STR_LIT>", json.dumps("str"))
text = re.sub(r"<NUM_LIT:([^>]+)>", lambda m: m.group(1).strip(), text)
text = text.replace("<NUM_LIT>", "0")
text = re.sub(r"<CHAR_LIT:([^>]+)>",
lambda m: json.dumps(m.group(1).strip()), text)
text = text.replace("<CHAR_LIT>", json.dumps("c"))
text = re.sub(r"<BOOL_LIT:True>", "True", text)
text = re.sub(r"<BOOL_LIT:False>", "False", text)
text = text.replace("<BOOL_LIT>", "True")
text = text.replace("<NULL_LIT>", "None")
text = text.replace("<INDENT>", "\n ")
text = text.replace("<DEDENT>", "\n")
return text
def cleanup_prediction(raw_text: str):
"""
Clean raw generated token text for UI prediction.
Returns: prediction_text, detected_stop_marker, normalized_raw_text
"""
normalized = normalize_special_spacing(raw_text)
cut_text, stop_marker = cut_at_stop_marker(normalized)
for marker in STOP_MARKERS:
cut_text = cut_text.replace(marker, "")
cut_text = replace_dataset_placeholders(cut_text)
cut_text = cut_text.replace("<EOL>", "\n")
cut_text = re.sub(r"\s+([)\]\}:,])", r"\1", cut_text)
cut_text = re.sub(r"([(\[{])\s+", r"\1", cut_text)
cut_text = re.sub(r"\s*=\s*", " = ", cut_text)
cut_text = re.sub(r"\s*\+\s*", " + ", cut_text)
cut_text = re.sub(r"\s*-\s*", " - ", cut_text)
cut_text = re.sub(r"\s*\*\s*", " * ", cut_text)
cut_text = re.sub(r"\s*/\s*", " / ", cut_text)
cut_text = re.sub(r"\s*<\s*", " < ", cut_text)
cut_text = re.sub(r"\s*>\s*", " > ", cut_text)
cut_text = re.sub(r"[ \t]+", " ", cut_text)
cut_text = re.sub(r"\n\s+", "\n ", cut_text)
return cut_text.strip(), stop_marker, normalized
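# Worked example of the cleanup pipeline above: a hypothetical helper, never
# called by the app, with an assumed raw generator output.
def _example_cleanup_prediction() -> None:
    pred, marker, _ = cleanup_prediction("a = <NUM_LIT:42> <EOL> pass")
    # Cut at the first stop marker, then inline the NUM_LIT payload.
    assert (pred, marker) == ("a = 42", "<EOL>")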
def token_spans(text: str):
"""Return normalized tokens and char spans for overlap trimming."""
text = str(text or "")
text = html.unescape(text)
text = text.replace("<EOL>", "\n")
toks = []
spans = []
for m in TOKEN_PATTERN.finditer(text):
toks.append(m.group(0))
spans.append(m.span())
return toks, spans
def trim_overlapping_prefix(context_text: str, prediction: str, token_context: str = ""):
"""
    Remove a duplicated prefix from the prediction when it begins with tokens
    that already appear at the end of the user context.
"""
pred = str(prediction or "").strip()
if not pred:
return pred, "No prediction to align."
ctx_source = token_context if token_context else context_text
ctx_tokens, _ = token_spans(ctx_source)
pred_tokens, pred_spans = token_spans(pred)
if not ctx_tokens or not pred_tokens:
return pred, "No token overlap check applied."
max_k = min(len(ctx_tokens), len(pred_tokens), 24)
best_k = 0
for k in range(1, max_k + 1):
if ctx_tokens[-k:] == pred_tokens[:k]:
best_k = k
if best_k <= 0:
return pred, "No duplicated prefix found."
cut_char = pred_spans[best_k - 1][1]
aligned = pred[cut_char:].lstrip(" \t\n,.;:")
if not aligned:
return pred, f"Detected {best_k} duplicated prefix token(s), but kept prediction to avoid empty output."
removed = pred[:cut_char].strip()
return aligned, f"Trimmed duplicated prefix from prediction: {removed!r} ({best_k} token(s))."
# ============================================================
# INFERENCE
# ============================================================
def run_demo(model_name: str, context: str):
try:
tokenizer, model, model_path = get_model(model_name)
variant = get_variant_config(model_name)
token_context = context_to_tokens(context)
retriever_mode = "Disabled"
retrieval_results = []
token_retrieved = ""
retrieved_raw = ""
selected_retrieval = None
if variant["use_online_retriever"]:
retriever_mode = "Online dense retrieval top-5 + keyword rerank"
retriever = get_retriever()
retrieval_results = retriever.search(
token_context, top_k=RETRIEVER_TOP_K)
if retrieval_results:
selected_retrieval = retrieval_results[0]
retrieved_raw = selected_retrieval.get("content", "")
token_retrieved = trim_token_text(
retrieved_raw, RETRIEVED_MAX_TOKENS)
max_length = 384 if variant["use_online_retriever"] else 256
raw_token_output = generate_completion(
model=model,
tokenizer=tokenizer,
retrieved=token_retrieved,
context=token_context,
device=device,
max_length=max_length,
max_new_tokens=16,
do_sample=False,
stop_strings=None,
)
prediction_before_align, stop_marker, normalized_output = cleanup_prediction(
raw_token_output)
prediction, align_note = trim_overlapping_prefix(
context_text=context,
prediction=prediction_before_align,
token_context=token_context,
)
if variant["use_online_retriever"]:
retriever_note = (
"Online retriever retrieved dense top-5 candidates, then reranked them using a small "
"keyword-overlap bonus. The selected retrieved chunk is injected before the context."
)
elif variant["checkpoint"] == "retriever_eol":
retriever_note = (
"Ablation mode: this checkpoint was trained with retrieved code, but retrieval is OFF. "
"The model receives typed context only."
)
else:
retriever_note = "Retriever is not used for this model."
if retrieval_results:
retrieval_log = ""
for rank, r in enumerate(retrieval_results, start=1):
selected_flag = " <-- SELECTED" if r is selected_retrieval else ""
retrieval_log += (
f"Rank {rank}{selected_flag}\n"
f"dense_score: {r['score']:.6f}\n"
f"keyword_overlap_count: {r.get('keyword_overlap_count', 0)}\n"
f"keyword_overlap_score: {r.get('keyword_overlap_score', 0.0):.6f}\n"
f"rerank_score: {r.get('rerank_score', r['score']):.6f}\n"
f"overlap_keywords: {r.get('keyword_overlap_preview', '')}\n"
f"index: {r['index']}\n"
f"original_file_id: {r.get('original_file_id')}\n"
f"fragment_sequence_id: {r.get('fragment_sequence_id')}\n"
f"content preview: {r.get('content', '')[:1200]}\n\n"
)
else:
retrieval_log = "No retrieval result."
logs = (
"=== DEMO LOGS ===\n\n"
f"[Selected model]\n{model_name}\n\n"
f"[Mode description]\n{variant['description']}\n\n"
f"[Model repo]\n{REMOTE_MODEL_REPO}\n\n"
f"[Local cache path]\n{model_path}\n\n"
f"[Model fingerprint]\n{file_fingerprint(model_path / 'model.safetensors')}\n\n"
f"[Device]\n{device}\n\n"
f"[Retriever mode]\n{retriever_mode}\n\n"
f"[Retriever note]\n{retriever_note}\n\n"
f"[Retriever model]\n{RETRIEVER_MODEL_NAME if variant['use_online_retriever'] else 'N/A'}\n\n"
f"[Retriever artifacts]\n{RETRIEVER_INDEX_IN_REPO}\n{RETRIEVER_CHUNKS_IN_REPO}\n\n"
f"[Retriever top_k]\n{RETRIEVER_TOP_K}\n\n"
f"[Rerank keyword weight]\n{RERANK_KEYWORD_WEIGHT}\n\n"
"[Retriever results]\n"
f"{retrieval_log}\n"
"[Known token patterns cleaned in Prediction]\n"
+ "\n".join(f"- {p}" for p in SPECIAL_TOKEN_PATTERNS)
+ "\n\n"
"[Raw Context]\n"
f"{context}\n\n"
"[Context → Tokens]\n"
f"{token_context}\n\n"
"[Selected retrieved raw]\n"
f"{retrieved_raw}\n\n"
"[Selected retrieved → Tokens used by generator]\n"
f"{token_retrieved}\n\n"
"[Raw Generator Output → Tokens]\n"
f"{raw_token_output}\n\n"
"[Normalized Generator Output → Tokens]\n"
f"{normalized_output}\n\n"
f"[Detected stop marker]\n{stop_marker}\n\n"
"[Prediction before overlap trim]\n"
f"{prediction_before_align}\n\n"
"[Overlap trim note]\n"
f"{align_note}\n\n"
"[Prediction]\n"
f"{prediction}\n"
)
return prediction, logs
except Exception:
err = traceback.format_exc()
return (
"ERROR: failed to load/generate.",
"=== ERROR LOGS ===\n\n" + err,
)
# ============================================================
# GRADIO UI
# ============================================================
demo = gr.Interface(
fn=run_demo,
inputs=[
gr.Dropdown(
choices=list(MODEL_VARIANTS.keys()),
value=DEFAULT_MODEL_NAME,
label="Model",
),
gr.Textbox(
lines=10,
label="Context",
placeholder="def sum(a, b):\n return",
),
],
outputs=[
gr.Textbox(lines=6, label="Prediction"),
gr.Textbox(lines=34, label="Logs"),
],
title="ReACC Code Completion Demo",
description=(
"Type Python code and compare Baseline, EOL, Retriever-trained EOL with retrieval OFF, "
"and Online Retriever + EOL. Online retriever mode retrieves dense top-5 candidates, "
"reranks them with keyword overlap, and injects the selected chunk into the generator."
),
)
# ============================================================
# STARTUP
# ============================================================
if PRE_DOWNLOAD_MODELS:
preload_model_folders()
if WARMUP_DEFAULT_MODEL:
print(f"Warming up default model: {DEFAULT_MODEL_NAME}")
get_model(DEFAULT_MODEL_NAME)
if __name__ == "__main__":
demo.launch()