import re
import gc
import html
import json
import hashlib
import traceback
from pathlib import Path
from typing import Dict, Any, Tuple, List

import torch
import gradio as gr
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, AutoModel

from model_utils import load_model_and_tokenizer, generate_completion

# ============================================================
# CONFIG
# ============================================================
REMOTE_MODEL_REPO = "TranTruongMMCII/UIT.CS2229.Generator"

# One checkpoint can be exposed in multiple inference modes.
# Keys are the UI labels; "checkpoint" selects the remote folder and
# "use_online_retriever" toggles dense retrieval at inference time.
MODEL_VARIANTS = {
    "Generator - Baseline": {
        "checkpoint": "baseline",
        "use_online_retriever": False,
        "description": "Baseline generator, context only.",
    },
    "Generator - EOL": {
        "checkpoint": "eol",
        "use_online_retriever": False,
        "description": "EOL-trained generator, context only.",
    },
    "Generator - Retriever-trained + EOL (Retrieval OFF)": {
        "checkpoint": "retriever_eol",
        "use_online_retriever": False,
        "description": "Ablation: retriever-trained checkpoint, but retrieved input is empty.",
    },
    "Generator - Online Retriever + EOL": {
        "checkpoint": "retriever_eol",
        "use_online_retriever": True,
        "description": "Full demo mode: online dense retrieval top-5 + rerank + retriever-trained EOL generator.",
    },
}

DEFAULT_MODEL_NAME = "Generator - Baseline"

# Remote retriever artifacts in the same HF model repo.
RETRIEVER_INDEX_IN_REPO = "retriever/py150_train_index.pt"
RETRIEVER_CHUNKS_IN_REPO = "retriever/py150_train_chunked.jsonl"
RETRIEVER_MODEL_NAME = "microsoft/graphcodebert-base"
RETRIEVER_BLOCK_SIZE = 512

# Retrieve top-5 dense candidates, then rerank by simple keyword overlap.
RETRIEVER_TOP_K = 5
RERANK_KEYWORD_WEIGHT = 0.03
RETRIEVED_MAX_TOKENS = 180

# Safer for HF Spaces deploy:
# - False: app starts fast; model/retriever loads lazily on first use.
# - True: download/load at startup; can fail deploy if remote artifacts have issue.
PRE_DOWNLOAD_MODELS = False WARMUP_DEFAULT_MODEL = False # Keep one generator model in memory at a time. device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # ============================================================ # GLOBAL CACHE # ============================================================ _model_paths_cache: Dict[str, Path] = {} _current_checkpoint_name = None _current_tokenizer = None _current_model = None _current_model_path = None _retriever = None # ============================================================ # MODEL UTILS # ============================================================ def file_fingerprint(path: Path) -> str: """Return short sha256 fingerprint for debugging model identity.""" if not path.exists(): return "missing" h = hashlib.sha256() with open(path, "rb") as f: for chunk in iter(lambda: f.read(1024 * 1024), b""): h.update(chunk) return h.hexdigest()[:16] def get_variant_config(model_name: str) -> Dict[str, Any]: if model_name not in MODEL_VARIANTS: raise ValueError(f"Unknown model option: {model_name}") return MODEL_VARIANTS[model_name] def resolve_remote_model_path(model_name: str) -> Path: """Download selected generator checkpoint folder from remote HF model repo.""" variant = get_variant_config(model_name) checkpoint_name = variant["checkpoint"] if checkpoint_name in _model_paths_cache: return _model_paths_cache[checkpoint_name] remote_subdir = f"checkpoint-best/{checkpoint_name}" local_repo_dir = snapshot_download( repo_id=REMOTE_MODEL_REPO, repo_type="model", allow_patterns=[f"{remote_subdir}/*"], ) model_path = Path(local_repo_dir) / remote_subdir required_files = [ "config.json", "generation_config.json", "model.safetensors", "tokenizer.json", "tokenizer_config.json", ] missing = [f for f in required_files if not (model_path / f).exists()] if missing: raise FileNotFoundError( f"Missing required files in {model_path}: {missing}") _model_paths_cache[checkpoint_name] = model_path return model_path def 
unload_current_model(): global _current_checkpoint_name, _current_tokenizer, _current_model, _current_model_path if _current_model is not None: del _current_model del _current_tokenizer _current_checkpoint_name = None _current_tokenizer = None _current_model = None _current_model_path = None gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() def get_model(model_name: str): """Load selected generator checkpoint into memory. Only one generator model is kept in RAM.""" global _current_checkpoint_name, _current_tokenizer, _current_model, _current_model_path variant = get_variant_config(model_name) checkpoint_name = variant["checkpoint"] if _current_checkpoint_name == checkpoint_name and _current_model is not None: return _current_tokenizer, _current_model, _current_model_path unload_current_model() model_path = resolve_remote_model_path(model_name) print(f"Loading generator checkpoint: {checkpoint_name}") print(f"Selected mode: {model_name}") print(f"Path: {model_path}") print(f"SHA: {file_fingerprint(model_path / 'model.safetensors')}") tokenizer, model = load_model_and_tokenizer(str(model_path)) model.to(device) model.eval() _current_checkpoint_name = checkpoint_name _current_tokenizer = tokenizer _current_model = model _current_model_path = model_path return tokenizer, model, model_path def preload_model_folders(): """Download all generator model folders to Hugging Face cache. 
Does not load models into RAM.""" print("Pre-downloading generator model folders...") for name in MODEL_VARIANTS: try: path = resolve_remote_model_path(name) print(f"Cached {name}: {path}") except Exception as e: print(f"[WARN] Failed to preload {name}: {e}") # ============================================================ # INPUT NORMALIZATION # ============================================================ def normalize_line(line: str) -> str: """Soft-normalize one line to be closer to train-time token style.""" line = re.sub(r"([()\[\]{}:,.=+\-*/<>])", r" \1 ", line) line = re.sub(r"\s+", " ", line) return line.strip() def context_to_tokens(code: str) -> str: """ Convert normal-looking code into training-style token text. If code is already tokenized with , keep it as-is. """ code = str(code or "").strip() if "" in code: return code code = code.replace("\t", " ") lines = code.splitlines() tokens = [normalize_line(line) for line in lines if line.strip()] return " ".join(tokens).strip() def trim_token_text(text: str, max_tokens: int) -> str: """Trim tokenized text to a maximum number of whitespace-separated tokens.""" toks = str(text or "").split() if len(toks) <= max_tokens: return str(text or "").strip() return " ".join(toks[:max_tokens]).strip() # ============================================================ # RETRIEVER # ============================================================ LIT_PATTERN = re.compile(r"<(STR|NUM|CHAR)_LIT:(.*?)>", re.S) KEYWORD_PATTERN = re.compile(r"[A-Za-z_]\w+") STOPWORDS_FOR_RERANK = { "from", "import", "class", "def", "return", "self", "true", "false", "none", "if", "else", "elif", "for", "while", "try", "except", "with", "as", "in", "and", "or", "not", "is", "str", "int", "list", "dict", "object", } def convert_cxg_format_to_normal(code: str) -> str: """Convert tokenized CodeXGLUE/ReACC code to Python-like text for GraphCodeBERT.""" code = str(code or "").strip() code = code.replace("", "").replace("", "") code = code.replace("", "\n") 
def keyword_set(text: str) -> set:
    """Extract lightweight keywords for reranking dense retrieval candidates."""
    # NOTE(review): "<EOL>" reconstructed (stripped by an HTML renderer).
    text = html.unescape(str(text or "")).replace("<EOL>", " ")
    kws = set()
    for tok in KEYWORD_PATTERN.findall(text):
        low = tok.lower()
        # Very short identifiers and common keywords carry no retrieval
        # signal and would dominate the overlap count.
        if len(low) < 3 or low in STOPWORDS_FOR_RERANK:
            continue
        kws.add(low)
    return kws


def rerank_retrieval_results(token_context: str, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Rerank dense top-k results by adding a small keyword-overlap bonus.
    Dense score still dominates; keyword overlap helps avoid top-1 chunks
    with wrong local pattern.
    """
    query_keywords = keyword_set(token_context)
    for r in results:
        content_keywords = keyword_set(r.get("content", ""))
        overlap = query_keywords & content_keywords
        # Cap the denominator at 20 so long queries do not dilute the bonus.
        denom = max(1, min(len(query_keywords), 20))
        overlap_score = len(overlap) / denom
        r["keyword_overlap_count"] = len(overlap)
        r["keyword_overlap_preview"] = ", ".join(sorted(list(overlap))[:20])
        r["keyword_overlap_score"] = float(overlap_score)
        r["rerank_score"] = float(
            r["score"] + RERANK_KEYWORD_WEIGHT * overlap_score)
    return sorted(results, key=lambda x: x["rerank_score"], reverse=True)


class OnlineDenseRetriever:
    """
    Online dense retriever.
    JSONL chunks are accessed via byte offsets, so the 463MB JSONL file
    is not loaded into Python dictionaries all at once.
    """

    def __init__(self):
        self.device = device
        self.index_path, self.chunks_path = self._download_retriever_artifacts()

        print("Loading retriever index:", self.index_path)
        index_tensor = torch.load(self.index_path, map_location="cpu")
        if not isinstance(index_tensor, torch.Tensor):
            index_tensor = torch.tensor(index_tensor)
        # L2-normalize rows once so a dot product equals cosine similarity.
        index_tensor = index_tensor.float().contiguous()
        index_tensor = torch.nn.functional.normalize(index_tensor, p=2, dim=1)
        self.index = index_tensor
        self.vector_dim = int(self.index.shape[1])
        self.num_chunks = int(self.index.shape[0])
        print("Retriever index shape:", tuple(self.index.shape))

        print("Building JSONL byte offsets:", self.chunks_path)
        self.offsets = self._build_offsets(self.chunks_path)
        if len(self.offsets) != self.num_chunks:
            raise ValueError(
                f"Index/chunk mismatch: index={self.num_chunks}, chunks={len(self.offsets)}"
            )
        print("Chunk offsets:", len(self.offsets))

        print("Loading retriever tokenizer:", RETRIEVER_MODEL_NAME)
        self.tokenizer = AutoTokenizer.from_pretrained(
            RETRIEVER_MODEL_NAME, use_fast=True)
        print("Loading retriever encoder:", RETRIEVER_MODEL_NAME)
        self.encoder = AutoModel.from_pretrained(
            RETRIEVER_MODEL_NAME).to(self.device)
        self.encoder.eval()
        print("Online retriever ready.")

    @staticmethod
    def _download_retriever_artifacts() -> Tuple[Path, Path]:
        """Fetch index tensor + chunk JSONL from the HF repo; return local paths."""
        local_repo_dir = snapshot_download(
            repo_id=REMOTE_MODEL_REPO,
            repo_type="model",
            allow_patterns=[RETRIEVER_INDEX_IN_REPO, RETRIEVER_CHUNKS_IN_REPO],
        )
        root = Path(local_repo_dir)
        index_path = root / RETRIEVER_INDEX_IN_REPO
        chunks_path = root / RETRIEVER_CHUNKS_IN_REPO
        missing = [str(p) for p in [index_path, chunks_path] if not p.exists()]
        if missing:
            raise FileNotFoundError(f"Missing retriever artifacts: {missing}")
        return index_path, chunks_path

    @staticmethod
    def _build_offsets(path: Path) -> List[int]:
        """Return the byte offset of every non-blank line in *path*."""
        offsets = []
        with open(path, "rb") as f:
            while True:
                pos = f.tell()
                line = f.readline()
                if not line:
                    break
                if line.strip():
                    offsets.append(pos)
        return offsets

    def _read_chunk(self, idx: int) -> Dict[str, Any]:
        """Load one JSONL record by chunk index via its byte offset."""
        with open(self.chunks_path, "rb") as f:
            f.seek(self.offsets[idx])
            line = f.readline().decode("utf-8")
        return json.loads(line)

    @torch.no_grad()
    def encode_query(self, token_context: str) -> torch.Tensor:
        """Encode tokenized context into a normalized CLS vector on CPU."""
        clean_code = convert_cxg_format_to_normal(token_context)
        text = f"{self.tokenizer.cls_token} {clean_code} {self.tokenizer.sep_token}"
        encoded = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=RETRIEVER_BLOCK_SIZE,
            return_tensors="pt",
        ).to(self.device)
        outputs = self.encoder(**encoded)
        # CLS embedding, L2-normalized to match the normalized index rows.
        query_vec = outputs.last_hidden_state[:, 0, :]
        query_vec = torch.nn.functional.normalize(query_vec, p=2, dim=1)
        return query_vec.squeeze(0).detach().cpu().float()

    def search(self, token_context: str, top_k: int = RETRIEVER_TOP_K) -> List[Dict[str, Any]]:
        """Return the reranked top-k chunks (dense score + keyword bonus)."""
        query_vec = self.encode_query(token_context)
        scores = torch.mv(self.index, query_vec)
        top_scores, top_indices = torch.topk(scores, k=top_k)

        results = []
        for score, idx in zip(top_scores.tolist(), top_indices.tolist()):
            row = self._read_chunk(int(idx))
            content = row.get("content", "")
            results.append({
                "score": float(score),
                "index": int(idx),
                "original_file_id": row.get("original_file_id"),
                "fragment_sequence_id": row.get("fragment_sequence_id"),
                "content": content,
            })
        return rerank_retrieval_results(token_context, results)


def get_retriever() -> OnlineDenseRetriever:
    """Return the process-wide retriever, constructing it on first use."""
    global _retriever
    if _retriever is None:
        _retriever = OnlineDenseRetriever()
    return _retriever


# ============================================================
# OUTPUT CLEANUP
# ============================================================
# NOTE(review): every entry except "<|endoftext|>" in the two lists below was
# stripped by an HTML renderer in the mangled source. Reconstructed as the
# special tokens this file normalizes elsewhere (counts match the originals) —
# confirm against the generator's tokenizer vocabulary.
STOP_MARKERS = ["</s>", "<s>", "<pad>", "<unk>", "<mask>", "<EOL>", "<|endoftext|>"]

SPECIAL_TOKEN_PATTERNS = [
    "<s>", "</s>", "<pad>", "<unk>", "<mask>", "<EOL>", "<|endoftext|>",
    "<INDENT>", "<DEDENT>", "<STR_LIT>", "<NUM_LIT>", "<CHAR_LIT>",
    "<BOOL_LIT>", "<NULL_LIT>", "<STR_LIT:...>", "<NUM_LIT:...>",
    "<CHAR_LIT:...>", "<BOOL_LIT:...>",
]

# One token = string literal, identifier, number, multi-char operator, or punctuation.
TOKEN_PATTERN = re.compile(
    r"\"[^\"\n]*\"|'[^'\n]*'|[A-Za-z_]\w*|\d+(?:\.\d+)?|==|!=|<=|>=|\+=|-=|\*=|/=|//|->|[(){}\[\]:,.;=+\-*/<>]"
)


def normalize_special_spacing(text: str) -> str:
    """Normalize weird spaced special tokens that may appear after decoding."""
    text = html.unescape(str(text))
    text = re.sub(r"<\s*/\s*s\s*>", "", text)
    text = re.sub(r"<\s*s\s*>", "", text)
    text = re.sub(r"<\s*pad\s*>", "", text)
    text = re.sub(r"<\s*unk\s*>", "", text)
    text = re.sub(r"<\s*mask\s*>", "", text)
    text = re.sub(r"<\s*\|\s*endoftext\s*\|\s*>", "<|endoftext|>", text)
    # Collapse e.g. "< EOL >" back to "<EOL>".
    for name in ["EOL", "STR_LIT", "NUM_LIT", "CHAR_LIT", "BOOL_LIT", "NULL_LIT", "INDENT", "DEDENT"]:
        text = re.sub(rf"<\s*{name}\s*>", f"<{name}>", text)
    # Same for payload-carrying literals: "< STR_LIT : x >" -> "<STR_LIT:x>".
    for name in ["STR_LIT", "NUM_LIT", "CHAR_LIT", "BOOL_LIT"]:
        text = re.sub(
            rf"<\s*{name}\s*:\s*([^>]+?)\s*>",
            lambda m, n=name: f"<{n}:{m.group(1).strip()}>",
            text,
        )
    return text
special tokens that may appear after decoding.""" text = html.unescape(str(text)) text = re.sub(r"<\s*/\s*s\s*>", "", text) text = re.sub(r"<\s*s\s*>", "", text) text = re.sub(r"<\s*pad\s*>", "", text) text = re.sub(r"<\s*unk\s*>", "", text) text = re.sub(r"<\s*mask\s*>", "", text) text = re.sub(r"<\s*\|\s*endoftext\s*\|\s*>", "<|endoftext|>", text) for name in ["EOL", "STR_LIT", "NUM_LIT", "CHAR_LIT", "BOOL_LIT", "NULL_LIT", "INDENT", "DEDENT"]: text = re.sub(rf"<\s*{name}\s*>", f"<{name}>", text) for name in ["STR_LIT", "NUM_LIT", "CHAR_LIT", "BOOL_LIT"]: text = re.sub( rf"<\s*{name}\s*:\s*([^>]+?)\s*>", lambda m, n=name: f"<{n}:{m.group(1).strip()}>", text, ) return text def cut_at_stop_marker(text: str): """Cut text at earliest stop marker. Returns: cleaned_prefix, detected_marker.""" earliest = None detected = None for marker in STOP_MARKERS: pos = text.find(marker) if pos >= 0 and (earliest is None or pos < earliest): earliest = pos detected = marker if earliest is None: return text, None return text[:earliest], detected def replace_dataset_placeholders(text: str) -> str: """Convert train-time placeholders to readable Python-ish code.""" def repl_str_payload(m): value = m.group(1).strip() return json.dumps(value) text = re.sub(r"]+)>", repl_str_payload, text) text = text.replace("", json.dumps("str")) text = re.sub(r"]+)>", lambda m: m.group(1).strip(), text) text = text.replace("", "0") text = re.sub(r"]+)>", lambda m: json.dumps(m.group(1).strip()), text) text = text.replace("", json.dumps("c")) text = re.sub(r"", "True", text) text = re.sub(r"", "False", text) text = text.replace("", "True") text = text.replace("", "None") text = text.replace("", "\n ") text = text.replace("", "\n") return text def cleanup_prediction(raw_text: str): """ Clean raw generated token text for UI prediction. 
Returns: prediction_text, detected_stop_marker, normalized_raw_text """ normalized = normalize_special_spacing(raw_text) cut_text, stop_marker = cut_at_stop_marker(normalized) for marker in STOP_MARKERS: cut_text = cut_text.replace(marker, "") cut_text = replace_dataset_placeholders(cut_text) cut_text = cut_text.replace("", "\n") cut_text = re.sub(r"\s+([)\]\}:,])", r"\1", cut_text) cut_text = re.sub(r"([(\[{])\s+", r"\1", cut_text) cut_text = re.sub(r"\s*=\s*", " = ", cut_text) cut_text = re.sub(r"\s*\+\s*", " + ", cut_text) cut_text = re.sub(r"\s*-\s*", " - ", cut_text) cut_text = re.sub(r"\s*\*\s*", " * ", cut_text) cut_text = re.sub(r"\s*/\s*", " / ", cut_text) cut_text = re.sub(r"\s*<\s*", " < ", cut_text) cut_text = re.sub(r"\s*>\s*", " > ", cut_text) cut_text = re.sub(r"[ \t]+", " ", cut_text) cut_text = re.sub(r"\n\s+", "\n ", cut_text) return cut_text.strip(), stop_marker, normalized def token_spans(text: str): """Return normalized tokens and char spans for overlap trimming.""" text = str(text or "") text = html.unescape(text) text = text.replace("", "\n") toks = [] spans = [] for m in TOKEN_PATTERN.finditer(text): toks.append(m.group(0)) spans.append(m.span()) return toks, spans def trim_overlapping_prefix(context_text: str, prediction: str, token_context: str = ""): """ Remove duplicated prefix from prediction when prediction begins with tokens that already appear at the end of the user context. """ pred = str(prediction or "").strip() if not pred: return pred, "No prediction to align." ctx_source = token_context if token_context else context_text ctx_tokens, _ = token_spans(ctx_source) pred_tokens, pred_spans = token_spans(pred) if not ctx_tokens or not pred_tokens: return pred, "No token overlap check applied." max_k = min(len(ctx_tokens), len(pred_tokens), 24) best_k = 0 for k in range(1, max_k + 1): if ctx_tokens[-k:] == pred_tokens[:k]: best_k = k if best_k <= 0: return pred, "No duplicated prefix found." 
# ============================================================
# INFERENCE
# ============================================================
def run_demo(model_name: str, context: str):
    """Gradio callback: run one completion and return (prediction, logs).

    Never raises: any failure is caught and reported in the logs pane so the
    UI stays responsive.
    """
    try:
        tokenizer, model, model_path = get_model(model_name)
        variant = get_variant_config(model_name)
        token_context = context_to_tokens(context)

        retriever_mode = "Disabled"
        retrieval_results = []
        token_retrieved = ""
        retrieved_raw = ""
        selected_retrieval = None

        if variant["use_online_retriever"]:
            retriever_mode = "Online dense retrieval top-5 + keyword rerank"
            retriever = get_retriever()
            retrieval_results = retriever.search(
                token_context, top_k=RETRIEVER_TOP_K)
            if retrieval_results:
                # Rank-1 after rerank is injected before the context.
                selected_retrieval = retrieval_results[0]
                retrieved_raw = selected_retrieval.get("content", "")
                token_retrieved = trim_token_text(
                    retrieved_raw, RETRIEVED_MAX_TOKENS)

        # Larger input budget when a retrieved chunk is prepended.
        max_length = 384 if variant["use_online_retriever"] else 256
        raw_token_output = generate_completion(
            model=model,
            tokenizer=tokenizer,
            retrieved=token_retrieved,
            context=token_context,
            device=device,
            max_length=max_length,
            max_new_tokens=16,
            do_sample=False,  # deterministic (greedy) decoding for the demo
            stop_strings=None,
        )

        prediction_before_align, stop_marker, normalized_output = cleanup_prediction(
            raw_token_output)
        prediction, align_note = trim_overlapping_prefix(
            context_text=context,
            prediction=prediction_before_align,
            token_context=token_context,
        )

        if variant["use_online_retriever"]:
            retriever_note = (
                "Online retriever retrieved dense top-5 candidates, then reranked them using a small "
                "keyword-overlap bonus. The selected retrieved chunk is injected before the context."
            )
        elif variant["checkpoint"] == "retriever_eol":
            retriever_note = (
                "Ablation mode: this checkpoint was trained with retrieved code, but retrieval is OFF. "
                "The model receives typed context only."
            )
        else:
            retriever_note = "Retriever is not used for this model."

        if retrieval_results:
            retrieval_log = ""
            for rank, r in enumerate(retrieval_results, start=1):
                selected_flag = " <-- SELECTED" if r is selected_retrieval else ""
                retrieval_log += (
                    f"Rank {rank}{selected_flag}\n"
                    f"dense_score: {r['score']:.6f}\n"
                    f"keyword_overlap_count: {r.get('keyword_overlap_count', 0)}\n"
                    f"keyword_overlap_score: {r.get('keyword_overlap_score', 0.0):.6f}\n"
                    f"rerank_score: {r.get('rerank_score', r['score']):.6f}\n"
                    f"overlap_keywords: {r.get('keyword_overlap_preview', '')}\n"
                    f"index: {r['index']}\n"
                    f"original_file_id: {r.get('original_file_id')}\n"
                    f"fragment_sequence_id: {r.get('fragment_sequence_id')}\n"
                    f"content preview: {r.get('content', '')[:1200]}\n\n"
                )
        else:
            retrieval_log = "No retrieval result."

        logs = (
            "=== DEMO LOGS ===\n\n"
            f"[Selected model]\n{model_name}\n\n"
            f"[Mode description]\n{variant['description']}\n\n"
            f"[Model repo]\n{REMOTE_MODEL_REPO}\n\n"
            f"[Local cache path]\n{model_path}\n\n"
            f"[Model fingerprint]\n{file_fingerprint(model_path / 'model.safetensors')}\n\n"
            f"[Device]\n{device}\n\n"
            f"[Retriever mode]\n{retriever_mode}\n\n"
            f"[Retriever note]\n{retriever_note}\n\n"
            f"[Retriever model]\n{RETRIEVER_MODEL_NAME if variant['use_online_retriever'] else 'N/A'}\n\n"
            f"[Retriever artifacts]\n{RETRIEVER_INDEX_IN_REPO}\n{RETRIEVER_CHUNKS_IN_REPO}\n\n"
            f"[Retriever top_k]\n{RETRIEVER_TOP_K}\n\n"
            f"[Rerank keyword weight]\n{RERANK_KEYWORD_WEIGHT}\n\n"
            "[Retriever results]\n"
            f"{retrieval_log}\n"
            "[Known token patterns cleaned in Prediction]\n"
            + "\n".join(f"- {p}" for p in SPECIAL_TOKEN_PATTERNS)
            + "\n\n"
            "[Raw Context]\n"
            f"{context}\n\n"
            "[Context → Tokens]\n"
            f"{token_context}\n\n"
            "[Selected retrieved raw]\n"
            f"{retrieved_raw}\n\n"
            "[Selected retrieved → Tokens used by generator]\n"
            f"{token_retrieved}\n\n"
            "[Raw Generator Output → Tokens]\n"
            f"{raw_token_output}\n\n"
            "[Normalized Generator Output → Tokens]\n"
            f"{normalized_output}\n\n"
            f"[Detected stop marker]\n{stop_marker}\n\n"
            "[Prediction before overlap trim]\n"
            f"{prediction_before_align}\n\n"
            "[Overlap trim note]\n"
            f"{align_note}\n\n"
            "[Prediction]\n"
            f"{prediction}\n"
        )
        return prediction, logs
    except Exception:
        err = traceback.format_exc()
        return (
            "ERROR: failed to load/generate.",
            "=== ERROR LOGS ===\n\n" + err,
        )


# ============================================================
# GRADIO UI
# ============================================================
demo = gr.Interface(
    fn=run_demo,
    inputs=[
        gr.Dropdown(
            choices=list(MODEL_VARIANTS.keys()),
            value=DEFAULT_MODEL_NAME,
            label="Model",
        ),
        gr.Textbox(
            lines=10,
            label="Context",
            placeholder="def sum(a, b):\n    return",
        ),
    ],
    outputs=[
        gr.Textbox(lines=6, label="Prediction"),
        gr.Textbox(lines=34, label="Logs"),
    ],
    title="ReACC Code Completion Demo",
    description=(
        "Type Python code and compare Baseline, EOL, Retriever-trained EOL with retrieval OFF, "
        "and Online Retriever + EOL. Online retriever mode retrieves dense top-5 candidates, "
        "reranks them with keyword overlap, and injects the selected chunk into the generator."
    ),
)

# ============================================================
# STARTUP
# ============================================================
if PRE_DOWNLOAD_MODELS:
    preload_model_folders()
if WARMUP_DEFAULT_MODEL:
    print(f"Warming up default model: {DEFAULT_MODEL_NAME}")
    get_model(DEFAULT_MODEL_NAME)

if __name__ == "__main__":
    demo.launch()