import re
import gc
import html
import json
import hashlib
import traceback
from pathlib import Path
from typing import Dict, Any, Tuple, List

import torch
import gradio as gr
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, AutoModel

from model_utils import load_model_and_tokenizer, generate_completion


# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

REMOTE_MODEL_REPO = "TranTruongMMCII/UIT.CS2229.Generator"

# Generator variants exposed in the UI. "checkpoint" selects the remote folder;
# "use_online_retriever" toggles dense retrieval at inference time. Note that two
# variants share the retriever_eol checkpoint and differ only in retrieval.
MODEL_VARIANTS = {
    "Generator - Baseline": {
        "checkpoint": "baseline",
        "use_online_retriever": False,
        "description": "Baseline generator, context only.",
    },
    "Generator - EOL": {
        "checkpoint": "eol",
        "use_online_retriever": False,
        "description": "EOL-trained generator, context only.",
    },
    "Generator - Retriever-trained + EOL (Retrieval OFF)": {
        "checkpoint": "retriever_eol",
        "use_online_retriever": False,
        "description": "Ablation: retriever-trained checkpoint, but retrieved input is empty.",
    },
    "Generator - Online Retriever + EOL": {
        "checkpoint": "retriever_eol",
        "use_online_retriever": True,
        "description": "Full demo mode: online dense retrieval top-5 + rerank + retriever-trained EOL generator.",
    },
}

DEFAULT_MODEL_NAME = "Generator - Baseline"

# Retriever artifacts stored in the same HF repo, plus the encoder used to
# embed queries online.
RETRIEVER_INDEX_IN_REPO = "retriever/py150_train_index.pt"
RETRIEVER_CHUNKS_IN_REPO = "retriever/py150_train_chunked.jsonl"
RETRIEVER_MODEL_NAME = "microsoft/graphcodebert-base"
RETRIEVER_BLOCK_SIZE = 512

# Retrieval and rerank hyperparameters.
RETRIEVER_TOP_K = 5
RERANK_KEYWORD_WEIGHT = 0.03
RETRIEVED_MAX_TOKENS = 180

# Startup flags: optionally pre-download every checkpoint folder and/or load
# the default generator before the UI starts serving.
PRE_DOWNLOAD_MODELS = False
WARMUP_DEFAULT_MODEL = False

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# ---------------------------------------------------------------------------
# Global state: one generator checkpoint is kept in memory at a time
# ---------------------------------------------------------------------------

_model_paths_cache: Dict[str, Path] = {}
_current_checkpoint_name = None
_current_tokenizer = None
_current_model = None
_current_model_path = None

_retriever = None


# ---------------------------------------------------------------------------
# Generator loading
# ---------------------------------------------------------------------------

def file_fingerprint(path: Path) -> str:
    """Return short sha256 fingerprint for debugging model identity."""
    if not path.exists():
        return "missing"

    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1024 * 1024), b""):
            h.update(chunk)

    return h.hexdigest()[:16]


def get_variant_config(model_name: str) -> Dict[str, Any]:
    if model_name not in MODEL_VARIANTS:
        raise ValueError(f"Unknown model option: {model_name}")
    return MODEL_VARIANTS[model_name]


def resolve_remote_model_path(model_name: str) -> Path:
    """Download selected generator checkpoint folder from remote HF model repo."""
    variant = get_variant_config(model_name)
    checkpoint_name = variant["checkpoint"]

    if checkpoint_name in _model_paths_cache:
        return _model_paths_cache[checkpoint_name]

    remote_subdir = f"checkpoint-best/{checkpoint_name}"

    local_repo_dir = snapshot_download(
        repo_id=REMOTE_MODEL_REPO,
        repo_type="model",
        allow_patterns=[f"{remote_subdir}/*"],
    )

    model_path = Path(local_repo_dir) / remote_subdir

    required_files = [
        "config.json",
        "generation_config.json",
        "model.safetensors",
        "tokenizer.json",
        "tokenizer_config.json",
    ]

    missing = [f for f in required_files if not (model_path / f).exists()]
    if missing:
        raise FileNotFoundError(f"Missing required files in {model_path}: {missing}")

    _model_paths_cache[checkpoint_name] = model_path
    return model_path
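

# Usage sketch (illustrative): resolving "Generator - Baseline" downloads only
# "checkpoint-best/baseline/*" into the local Hugging Face cache and returns
# that folder. snapshot_download caches per repo revision, so repeated calls
# are served from disk without re-downloading.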


def unload_current_model():
    global _current_checkpoint_name, _current_tokenizer, _current_model, _current_model_path

    if _current_model is not None:
        del _current_model
        del _current_tokenizer

    _current_checkpoint_name = None
    _current_tokenizer = None
    _current_model = None
    _current_model_path = None

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def get_model(model_name: str):
    """Load selected generator checkpoint into memory. Only one generator model is kept in RAM."""
    global _current_checkpoint_name, _current_tokenizer, _current_model, _current_model_path

    variant = get_variant_config(model_name)
    checkpoint_name = variant["checkpoint"]

    if _current_checkpoint_name == checkpoint_name and _current_model is not None:
        return _current_tokenizer, _current_model, _current_model_path

    unload_current_model()

    model_path = resolve_remote_model_path(model_name)

    print(f"Loading generator checkpoint: {checkpoint_name}")
    print(f"Selected mode: {model_name}")
    print(f"Path: {model_path}")
    print(f"SHA: {file_fingerprint(model_path / 'model.safetensors')}")

    tokenizer, model = load_model_and_tokenizer(str(model_path))
    model.to(device)
    model.eval()

    _current_checkpoint_name = checkpoint_name
    _current_tokenizer = tokenizer
    _current_model = model
    _current_model_path = model_path

    return tokenizer, model, model_path


def preload_model_folders():
    """Download all generator model folders to Hugging Face cache. Does not load models into RAM."""
    print("Pre-downloading generator model folders...")
    for name in MODEL_VARIANTS:
        try:
            path = resolve_remote_model_path(name)
            print(f"Cached {name}: {path}")
        except Exception as e:
            print(f"[WARN] Failed to preload {name}: {e}")


# ---------------------------------------------------------------------------
# Context tokenization
# ---------------------------------------------------------------------------

def normalize_line(line: str) -> str:
    """Soft-normalize one line to be closer to train-time token style."""
    line = re.sub(r"([()\[\]{}:,.=+\-*/<>])", r" \1 ", line)
    line = re.sub(r"\s+", " ", line)
    return line.strip()


def context_to_tokens(code: str) -> str:
    """
    Convert normal-looking code into training-style token text.
    If code is already tokenized with <EOL>, keep it as-is.
    """
    code = str(code or "").strip()

    if "<EOL>" in code:
        return code

    code = code.replace("\t", " ")
    lines = code.splitlines()
    tokens = [normalize_line(line) for line in lines if line.strip()]
    return " <EOL> ".join(tokens).strip()
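

# Example (illustrative): the two-line snippet
#     def add(a, b):
#         return a + b
# becomes the single token string
#     "def add ( a , b ) : <EOL> return a + b"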


def trim_token_text(text: str, max_tokens: int) -> str:
    """Trim tokenized text to a maximum number of whitespace-separated tokens."""
    toks = str(text or "").split()
    if len(toks) <= max_tokens:
        return str(text or "").strip()
    return " ".join(toks[:max_tokens]).strip()


# ---------------------------------------------------------------------------
# Online dense retrieval and keyword rerank
# ---------------------------------------------------------------------------

LIT_PATTERN = re.compile(r"<(STR|NUM|CHAR)_LIT:(.*?)>", re.S)
KEYWORD_PATTERN = re.compile(r"[A-Za-z_]\w+")
STOPWORDS_FOR_RERANK = {
    "from", "import", "class", "def", "return", "self", "true", "false", "none",
    "if", "else", "elif", "for", "while", "try", "except", "with", "as", "in",
    "and", "or", "not", "is", "str", "int", "list", "dict", "object",
}


def convert_cxg_format_to_normal(code: str) -> str:
    """Convert tokenized CodeXGLUE/ReACC code to Python-like text for GraphCodeBERT."""
    code = str(code or "").strip()
    code = code.replace("<s>", "").replace("</s>", "")
    code = code.replace("<EOL>", "\n")
    code = code.replace("<NUM_LIT>", "0")
    code = code.replace("<STR_LIT>", '"str"')
    code = code.replace("<CHAR_LIT>", '"c"')

    for lit_type, lit_value in LIT_PATTERN.findall(code):
        code = code.replace(f"<{lit_type}_LIT:{lit_value}>", lit_value)

    return code
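

# Example (illustrative):
#     "for i in range ( <NUM_LIT:10> ) : <EOL> print ( <STR_LIT> )"
# is rewritten to
#     'for i in range ( 10 ) : \n print ( "str" )'
# which GraphCodeBERT can tokenize like ordinary Python.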


def keyword_set(text: str) -> set:
    """Extract lightweight keywords for reranking dense retrieval candidates."""
    text = html.unescape(str(text or "")).replace("<EOL>", " ")
    kws = set()
    for tok in KEYWORD_PATTERN.findall(text):
        low = tok.lower()
        if len(low) < 3 or low in STOPWORDS_FOR_RERANK:
            continue
        kws.add(low)
    return kws
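

# Example (illustrative): keyword_set("import numpy as np <EOL> arr = np.zeros(10)")
# -> {"numpy", "arr", "zeros"}; "import" is a stopword, "as" and "np" are too short.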


def rerank_retrieval_results(token_context: str, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Rerank dense top-k results by adding a small keyword-overlap bonus.
    The dense score still dominates; keyword overlap helps avoid top-1 chunks
    with the wrong local pattern.
    """
    query_keywords = keyword_set(token_context)

    for r in results:
        content_keywords = keyword_set(r.get("content", ""))
        overlap = query_keywords & content_keywords
        denom = max(1, min(len(query_keywords), 20))
        overlap_score = len(overlap) / denom
        r["keyword_overlap_count"] = len(overlap)
        r["keyword_overlap_preview"] = ", ".join(sorted(overlap)[:20])
        r["keyword_overlap_score"] = float(overlap_score)
        r["rerank_score"] = float(r["score"] + RERANK_KEYWORD_WEIGHT * overlap_score)

    return sorted(results, key=lambda x: x["rerank_score"], reverse=True)
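

# Worked example (illustrative numbers): with dense score 0.820, 6 overlapping
# keywords, and 15 query keywords, the bonus is 0.03 * 6 / 15 = 0.012, giving
# rerank_score = 0.832. The bonus is capped at RERANK_KEYWORD_WEIGHT, so the
# dense score dominates, as the docstring notes.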


class OnlineDenseRetriever:
    """
    Online dense retriever.

    JSONL chunks are accessed via byte offsets, so the 463MB JSONL file is not
    loaded into Python dictionaries all at once.
    """

    def __init__(self):
        self.device = device
        self.index_path, self.chunks_path = self._download_retriever_artifacts()

        print("Loading retriever index:", self.index_path)
        index_tensor = torch.load(self.index_path, map_location="cpu")
        if not isinstance(index_tensor, torch.Tensor):
            index_tensor = torch.tensor(index_tensor)
        index_tensor = index_tensor.float().contiguous()
        index_tensor = torch.nn.functional.normalize(index_tensor, p=2, dim=1)
        self.index = index_tensor
        self.vector_dim = int(self.index.shape[1])
        self.num_chunks = int(self.index.shape[0])
        print("Retriever index shape:", tuple(self.index.shape))

        print("Building JSONL byte offsets:", self.chunks_path)
        self.offsets = self._build_offsets(self.chunks_path)
        if len(self.offsets) != self.num_chunks:
            raise ValueError(
                f"Index/chunk mismatch: index={self.num_chunks}, chunks={len(self.offsets)}"
            )
        print("Chunk offsets:", len(self.offsets))

        print("Loading retriever tokenizer:", RETRIEVER_MODEL_NAME)
        self.tokenizer = AutoTokenizer.from_pretrained(RETRIEVER_MODEL_NAME, use_fast=True)

        print("Loading retriever encoder:", RETRIEVER_MODEL_NAME)
        self.encoder = AutoModel.from_pretrained(RETRIEVER_MODEL_NAME).to(self.device)
        self.encoder.eval()

        print("Online retriever ready.")

    @staticmethod
    def _download_retriever_artifacts() -> Tuple[Path, Path]:
        local_repo_dir = snapshot_download(
            repo_id=REMOTE_MODEL_REPO,
            repo_type="model",
            allow_patterns=[RETRIEVER_INDEX_IN_REPO, RETRIEVER_CHUNKS_IN_REPO],
        )
        root = Path(local_repo_dir)
        index_path = root / RETRIEVER_INDEX_IN_REPO
        chunks_path = root / RETRIEVER_CHUNKS_IN_REPO

        missing = [str(p) for p in [index_path, chunks_path] if not p.exists()]
        if missing:
            raise FileNotFoundError(f"Missing retriever artifacts: {missing}")

        return index_path, chunks_path

    @staticmethod
    def _build_offsets(path: Path) -> List[int]:
        offsets = []
        with open(path, "rb") as f:
            while True:
                pos = f.tell()
                line = f.readline()
                if not line:
                    break
                if line.strip():
                    offsets.append(pos)
        return offsets

    def _read_chunk(self, idx: int) -> Dict[str, Any]:
        with open(self.chunks_path, "rb") as f:
            f.seek(self.offsets[idx])
            line = f.readline().decode("utf-8")
        return json.loads(line)
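
    # _build_offsets records the byte start of every non-empty line once at
    # startup, so _read_chunk can seek straight to chunk idx and parse a single
    # JSON line instead of holding the whole JSONL file in memory.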

    @torch.no_grad()
    def encode_query(self, token_context: str) -> torch.Tensor:
        clean_code = convert_cxg_format_to_normal(token_context)
        text = f"{self.tokenizer.cls_token} {clean_code} {self.tokenizer.sep_token}"

        encoded = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=RETRIEVER_BLOCK_SIZE,
            return_tensors="pt",
        ).to(self.device)

        outputs = self.encoder(**encoded)
        query_vec = outputs.last_hidden_state[:, 0, :]
        query_vec = torch.nn.functional.normalize(query_vec, p=2, dim=1)
        return query_vec.squeeze(0).detach().cpu().float()
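
    # The [CLS] query vector and the index rows are both L2-normalized, so the
    # dot products computed in search() below are cosine similarities.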

    def search(self, token_context: str, top_k: int = RETRIEVER_TOP_K) -> List[Dict[str, Any]]:
        query_vec = self.encode_query(token_context)
        scores = torch.mv(self.index, query_vec)
        top_scores, top_indices = torch.topk(scores, k=top_k)

        results = []
        for score, idx in zip(top_scores.tolist(), top_indices.tolist()):
            row = self._read_chunk(int(idx))
            content = row.get("content", "")
            results.append({
                "score": float(score),
                "index": int(idx),
                "original_file_id": row.get("original_file_id"),
                "fragment_sequence_id": row.get("fragment_sequence_id"),
                "content": content,
            })

        return rerank_retrieval_results(token_context, results)


def get_retriever() -> OnlineDenseRetriever:
    global _retriever
    if _retriever is None:
        _retriever = OnlineDenseRetriever()
    return _retriever


# ---------------------------------------------------------------------------
# Generator output cleaning
# ---------------------------------------------------------------------------

STOP_MARKERS = ["<EOL>", "</s>", "<s>", "<pad>", "<unk>", "<mask>", "<|endoftext|>"]

SPECIAL_TOKEN_PATTERNS = [
    "<EOL>", "</s>", "<s>", "<pad>", "<unk>", "<mask>", "<|endoftext|>",
    "<STR_LIT>", "<STR_LIT:...>", "<NUM_LIT>", "<NUM_LIT:...>",
    "<CHAR_LIT>", "<CHAR_LIT:...>", "<BOOL_LIT>", "<BOOL_LIT:...>",
    "<NULL_LIT>", "<INDENT>", "<DEDENT>",
]

TOKEN_PATTERN = re.compile(
    r"\"[^\"\n]*\"|'[^'\n]*'|[A-Za-z_]\w*|\d+(?:\.\d+)?|==|!=|<=|>=|\+=|-=|\*=|/=|//|->|[(){}\[\]:,.;=+\-*/<>]"
)


def normalize_special_spacing(text: str) -> str:
    """Normalize oddly spaced special tokens that may appear after decoding."""
    text = html.unescape(str(text))

    text = re.sub(r"<\s*/\s*s\s*>", "</s>", text)
    text = re.sub(r"<\s*s\s*>", "<s>", text)
    text = re.sub(r"<\s*pad\s*>", "<pad>", text)
    text = re.sub(r"<\s*unk\s*>", "<unk>", text)
    text = re.sub(r"<\s*mask\s*>", "<mask>", text)
    text = re.sub(r"<\s*\|\s*endoftext\s*\|\s*>", "<|endoftext|>", text)

    for name in ["EOL", "STR_LIT", "NUM_LIT", "CHAR_LIT", "BOOL_LIT", "NULL_LIT", "INDENT", "DEDENT"]:
        text = re.sub(rf"<\s*{name}\s*>", f"<{name}>", text)

    for name in ["STR_LIT", "NUM_LIT", "CHAR_LIT", "BOOL_LIT"]:
        text = re.sub(
            rf"<\s*{name}\s*:\s*([^>]+?)\s*>",
            lambda m, n=name: f"<{n}:{m.group(1).strip()}>",
            text,
        )

    return text
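

# Example (illustrative): "< EOL > x = < NUM_LIT : 1 >" is normalized to
# "<EOL> x = <NUM_LIT:1>" before stop-marker detection and placeholder cleanup.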


def cut_at_stop_marker(text: str):
    """Cut text at earliest stop marker. Returns: cleaned_prefix, detected_marker."""
    earliest = None
    detected = None

    for marker in STOP_MARKERS:
        pos = text.find(marker)
        if pos >= 0 and (earliest is None or pos < earliest):
            earliest = pos
            detected = marker

    if earliest is None:
        return text, None

    return text[:earliest], detected
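

# Example (illustrative): cut_at_stop_marker("x = 1 <EOL> y = 2")
# -> ("x = 1 ", "<EOL>"); generation is truncated at the first full line.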


def replace_dataset_placeholders(text: str) -> str:
    """Convert train-time placeholders to readable Python-ish code."""

    def repl_str_payload(m):
        value = m.group(1).strip()
        return json.dumps(value)

    text = re.sub(r"<STR_LIT:([^>]+)>", repl_str_payload, text)
    text = text.replace("<STR_LIT>", json.dumps("str"))

    text = re.sub(r"<NUM_LIT:([^>]+)>", lambda m: m.group(1).strip(), text)
    text = text.replace("<NUM_LIT>", "0")

    text = re.sub(r"<CHAR_LIT:([^>]+)>", lambda m: json.dumps(m.group(1).strip()), text)
    text = text.replace("<CHAR_LIT>", json.dumps("c"))

    text = text.replace("<BOOL_LIT:True>", "True")
    text = text.replace("<BOOL_LIT:False>", "False")
    text = text.replace("<BOOL_LIT>", "True")
    text = text.replace("<NULL_LIT>", "None")

    text = text.replace("<INDENT>", "\n ")
    text = text.replace("<DEDENT>", "\n")

    return text
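

# Example (illustrative): "name = <STR_LIT:hello> <EOL> n = <NUM_LIT:3>"
# becomes 'name = "hello" <EOL> n = 3'; the <EOL> itself is handled later
# by cleanup_prediction.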


def cleanup_prediction(raw_text: str):
    """
    Clean raw generated token text for the UI prediction.
    Returns: prediction_text, detected_stop_marker, normalized_raw_text
    """
    normalized = normalize_special_spacing(raw_text)
    cut_text, stop_marker = cut_at_stop_marker(normalized)

    for marker in STOP_MARKERS:
        cut_text = cut_text.replace(marker, "")

    cut_text = replace_dataset_placeholders(cut_text)
    cut_text = cut_text.replace("<EOL>", "\n")

    # Tighten spacing around brackets and punctuation.
    cut_text = re.sub(r"\s+([)\]\}:,])", r"\1", cut_text)
    cut_text = re.sub(r"([(\[{])\s+", r"\1", cut_text)

    # Re-space common binary operators (cosmetic; also affects unary +/-).
    cut_text = re.sub(r"\s*=\s*", " = ", cut_text)
    cut_text = re.sub(r"\s*\+\s*", " + ", cut_text)
    cut_text = re.sub(r"\s*-\s*", " - ", cut_text)
    cut_text = re.sub(r"\s*\*\s*", " * ", cut_text)
    cut_text = re.sub(r"\s*/\s*", " / ", cut_text)
    cut_text = re.sub(r"\s*<\s*", " < ", cut_text)
    cut_text = re.sub(r"\s*>\s*", " > ", cut_text)

    cut_text = re.sub(r"[ \t]+", " ", cut_text)
    cut_text = re.sub(r"\n\s+", "\n ", cut_text)

    return cut_text.strip(), stop_marker, normalized
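

# Example (illustrative): the raw output " a = <NUM_LIT:1> <EOL> b = 2" is cut
# at <EOL>, placeholder-expanded, and re-spaced, yielding ("a = 1", "<EOL>", ...).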


def token_spans(text: str):
    """Return normalized tokens and char spans for overlap trimming."""
    text = str(text or "")
    text = html.unescape(text)
    text = text.replace("<EOL>", "\n")
    toks = []
    spans = []
    for m in TOKEN_PATTERN.finditer(text):
        toks.append(m.group(0))
        spans.append(m.span())
    return toks, spans
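

# Example (illustrative): token_spans("x = f(1)") returns
# (["x", "=", "f", "(", "1", ")"], [(0, 1), (2, 3), (4, 5), (5, 6), (6, 7), (7, 8)]).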


def trim_overlapping_prefix(context_text: str, prediction: str, token_context: str = ""):
    """
    Remove a duplicated prefix from the prediction when it begins with tokens
    that already appear at the end of the user context.
    """
    pred = str(prediction or "").strip()
    if not pred:
        return pred, "No prediction to align."

    ctx_source = token_context if token_context else context_text
    ctx_tokens, _ = token_spans(ctx_source)
    pred_tokens, pred_spans = token_spans(pred)

    if not ctx_tokens or not pred_tokens:
        return pred, "No token overlap check applied."

    max_k = min(len(ctx_tokens), len(pred_tokens), 24)
    best_k = 0

    for k in range(1, max_k + 1):
        if ctx_tokens[-k:] == pred_tokens[:k]:
            best_k = k

    if best_k <= 0:
        return pred, "No duplicated prefix found."

    cut_char = pred_spans[best_k - 1][1]
    aligned = pred[cut_char:].lstrip(" \t\n,.;:")

    if not aligned:
        return pred, f"Detected {best_k} duplicated prefix token(s), but kept prediction to avoid empty output."

    removed = pred[:cut_char].strip()
    return aligned, f"Trimmed duplicated prefix from prediction: {removed!r} ({best_k} token(s))."
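

# Example (illustrative): if the context ends with "return" and the model emits
# "return a + b", the shared prefix is dropped and the prediction becomes "a + b".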


# ---------------------------------------------------------------------------
# Demo pipeline
# ---------------------------------------------------------------------------

def run_demo(model_name: str, context: str):
    try:
        tokenizer, model, model_path = get_model(model_name)
        variant = get_variant_config(model_name)

        token_context = context_to_tokens(context)

        retriever_mode = "Disabled"
        retrieval_results = []
        token_retrieved = ""
        retrieved_raw = ""
        selected_retrieval = None

        if variant["use_online_retriever"]:
            retriever_mode = "Online dense retrieval top-5 + keyword rerank"
            retriever = get_retriever()
            retrieval_results = retriever.search(token_context, top_k=RETRIEVER_TOP_K)
            if retrieval_results:
                selected_retrieval = retrieval_results[0]
                retrieved_raw = selected_retrieval.get("content", "")
                token_retrieved = trim_token_text(retrieved_raw, RETRIEVED_MAX_TOKENS)

        # Allow a longer input window when a retrieved chunk is prepended.
        max_length = 384 if variant["use_online_retriever"] else 256

        raw_token_output = generate_completion(
            model=model,
            tokenizer=tokenizer,
            retrieved=token_retrieved,
            context=token_context,
            device=device,
            max_length=max_length,
            max_new_tokens=16,
            do_sample=False,
            stop_strings=None,
        )

        prediction_before_align, stop_marker, normalized_output = cleanup_prediction(raw_token_output)
        prediction, align_note = trim_overlapping_prefix(
            context_text=context,
            prediction=prediction_before_align,
            token_context=token_context,
        )

        if variant["use_online_retriever"]:
            retriever_note = (
                "Online retriever retrieved dense top-5 candidates, then reranked them using a small "
                "keyword-overlap bonus. The selected retrieved chunk is injected before the context."
            )
        elif variant["checkpoint"] == "retriever_eol":
            retriever_note = (
                "Ablation mode: this checkpoint was trained with retrieved code, but retrieval is OFF. "
                "The model receives the typed context only."
            )
        else:
            retriever_note = "The retriever is not used for this model."

        if retrieval_results:
            retrieval_log = ""
            for rank, r in enumerate(retrieval_results, start=1):
                selected_flag = " <-- SELECTED" if r is selected_retrieval else ""
                retrieval_log += (
                    f"Rank {rank}{selected_flag}\n"
                    f"dense_score: {r['score']:.6f}\n"
                    f"keyword_overlap_count: {r.get('keyword_overlap_count', 0)}\n"
                    f"keyword_overlap_score: {r.get('keyword_overlap_score', 0.0):.6f}\n"
                    f"rerank_score: {r.get('rerank_score', r['score']):.6f}\n"
                    f"overlap_keywords: {r.get('keyword_overlap_preview', '')}\n"
                    f"index: {r['index']}\n"
                    f"original_file_id: {r.get('original_file_id')}\n"
                    f"fragment_sequence_id: {r.get('fragment_sequence_id')}\n"
                    f"content preview: {r.get('content', '')[:1200]}\n\n"
                )
        else:
            retrieval_log = "No retrieval result."

        logs = (
            "=== DEMO LOGS ===\n\n"
            f"[Selected model]\n{model_name}\n\n"
            f"[Mode description]\n{variant['description']}\n\n"
            f"[Model repo]\n{REMOTE_MODEL_REPO}\n\n"
            f"[Local cache path]\n{model_path}\n\n"
            f"[Model fingerprint]\n{file_fingerprint(model_path / 'model.safetensors')}\n\n"
            f"[Device]\n{device}\n\n"
            f"[Retriever mode]\n{retriever_mode}\n\n"
            f"[Retriever note]\n{retriever_note}\n\n"
            f"[Retriever model]\n{RETRIEVER_MODEL_NAME if variant['use_online_retriever'] else 'N/A'}\n\n"
            f"[Retriever artifacts]\n{RETRIEVER_INDEX_IN_REPO}\n{RETRIEVER_CHUNKS_IN_REPO}\n\n"
            f"[Retriever top_k]\n{RETRIEVER_TOP_K}\n\n"
            f"[Rerank keyword weight]\n{RERANK_KEYWORD_WEIGHT}\n\n"
            "[Retriever results]\n"
            f"{retrieval_log}\n"
            "[Known token patterns cleaned in Prediction]\n"
            + "\n".join(f"- {p}" for p in SPECIAL_TOKEN_PATTERNS)
            + "\n\n"
            "[Raw Context]\n"
            f"{context}\n\n"
            "[Context → Tokens]\n"
            f"{token_context}\n\n"
            "[Selected retrieved raw]\n"
            f"{retrieved_raw}\n\n"
            "[Selected retrieved → Tokens used by generator]\n"
            f"{token_retrieved}\n\n"
            "[Raw Generator Output → Tokens]\n"
            f"{raw_token_output}\n\n"
            "[Normalized Generator Output → Tokens]\n"
            f"{normalized_output}\n\n"
            f"[Detected stop marker]\n{stop_marker}\n\n"
            "[Prediction before overlap trim]\n"
            f"{prediction_before_align}\n\n"
            "[Overlap trim note]\n"
            f"{align_note}\n\n"
            "[Prediction]\n"
            f"{prediction}\n"
        )

        return prediction, logs

    except Exception:
        err = traceback.format_exc()
        return (
            "ERROR: failed to load/generate.",
            "=== ERROR LOGS ===\n\n" + err,
        )
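

# Local smoke test (illustrative; the first call downloads the checkpoint):
#     prediction, logs = run_demo(DEFAULT_MODEL_NAME, "def add(a, b):\n    return")
#     print(prediction)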


# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------

demo = gr.Interface(
    fn=run_demo,
    inputs=[
        gr.Dropdown(
            choices=list(MODEL_VARIANTS.keys()),
            value=DEFAULT_MODEL_NAME,
            label="Model",
        ),
        gr.Textbox(
            lines=10,
            label="Context",
            placeholder="def sum(a, b):\n return",
        ),
    ],
    outputs=[
        gr.Textbox(lines=6, label="Prediction"),
        gr.Textbox(lines=34, label="Logs"),
    ],
    title="ReACC Code Completion Demo",
    description=(
        "Type Python code and compare Baseline, EOL, Retriever-trained EOL with retrieval OFF, "
        "and Online Retriever + EOL. Online retriever mode retrieves dense top-5 candidates, "
        "reranks them with keyword overlap, and injects the selected chunk into the generator."
    ),
)


# ---------------------------------------------------------------------------
# Startup
# ---------------------------------------------------------------------------

if PRE_DOWNLOAD_MODELS:
    preload_model_folders()

if WARMUP_DEFAULT_MODEL:
    print(f"Warming up default model: {DEFAULT_MODEL_NAME}")
    get_model(DEFAULT_MODEL_NAME)

if __name__ == "__main__":
    demo.launch()