import re
import gc
import html
import json
import hashlib
import traceback
from pathlib import Path
from typing import Dict, Any, Tuple, List

import torch
import gradio as gr
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, AutoModel

from model_utils import load_model_and_tokenizer, generate_completion


# ============================================================
# CONFIG
# ============================================================

REMOTE_MODEL_REPO = "TranTruongMMCII/UIT.CS2229.Generator"

# One checkpoint can be exposed in multiple inference modes.
MODEL_VARIANTS = {
    "Generator - Baseline": {
        "checkpoint": "baseline",
        "use_online_retriever": False,
        "description": "Baseline generator, context only.",
    },
    "Generator - EOL": {
        "checkpoint": "eol",
        "use_online_retriever": False,
        "description": "EOL-trained generator, context only.",
    },
    "Generator - Retriever-trained + EOL (Retrieval OFF)": {
        "checkpoint": "retriever_eol",
        "use_online_retriever": False,
        "description": "Ablation: retriever-trained checkpoint, but retrieved input is empty.",
    },
    "Generator - Online Retriever + EOL": {
        "checkpoint": "retriever_eol",
        "use_online_retriever": True,
        "description": "Full demo mode: online dense retrieval top-5 + rerank + retriever-trained EOL generator.",
    },
}

DEFAULT_MODEL_NAME = "Generator - Baseline"

# Remote retriever artifacts in the same HF model repo.
RETRIEVER_INDEX_IN_REPO = "retriever/py150_train_index.pt"
RETRIEVER_CHUNKS_IN_REPO = "retriever/py150_train_chunked.jsonl"
RETRIEVER_MODEL_NAME = "microsoft/graphcodebert-base"
RETRIEVER_BLOCK_SIZE = 512

# Retrieve top-5 dense candidates, then rerank by simple keyword overlap.
RETRIEVER_TOP_K = 5
RERANK_KEYWORD_WEIGHT = 0.03
RETRIEVED_MAX_TOKENS = 180

# Safer for HF Spaces deployment:
# - False: the app starts fast; the model/retriever loads lazily on first use.
# - True: download/load at startup; deployment can fail if the remote artifacts have issues.
PRE_DOWNLOAD_MODELS = False
WARMUP_DEFAULT_MODEL = False

# Run on GPU when available. Only one generator model is kept in memory at a time.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# ============================================================
# GLOBAL CACHE
# ============================================================

_model_paths_cache: Dict[str, Path] = {}
_current_checkpoint_name = None
_current_tokenizer = None
_current_model = None
_current_model_path = None

_retriever = None


# ============================================================
# MODEL UTILS
# ============================================================

def file_fingerprint(path: Path) -> str:
    """Return short sha256 fingerprint for debugging model identity."""
    if not path.exists():
        return "missing"

    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1024 * 1024), b""):
            h.update(chunk)

    return h.hexdigest()[:16]


def get_variant_config(model_name: str) -> Dict[str, Any]:
    if model_name not in MODEL_VARIANTS:
        raise ValueError(f"Unknown model option: {model_name}")
    return MODEL_VARIANTS[model_name]


def resolve_remote_model_path(model_name: str) -> Path:
    """Download selected generator checkpoint folder from remote HF model repo."""

    variant = get_variant_config(model_name)
    checkpoint_name = variant["checkpoint"]

    if checkpoint_name in _model_paths_cache:
        return _model_paths_cache[checkpoint_name]

    remote_subdir = f"checkpoint-best/{checkpoint_name}"

    local_repo_dir = snapshot_download(
        repo_id=REMOTE_MODEL_REPO,
        repo_type="model",
        allow_patterns=[f"{remote_subdir}/*"],
    )

    model_path = Path(local_repo_dir) / remote_subdir

    required_files = [
        "config.json",
        "generation_config.json",
        "model.safetensors",
        "tokenizer.json",
        "tokenizer_config.json",
    ]

    missing = [f for f in required_files if not (model_path / f).exists()]
    if missing:
        raise FileNotFoundError(
            f"Missing required files in {model_path}: {missing}")

    _model_paths_cache[checkpoint_name] = model_path
    return model_path


def unload_current_model():
    global _current_checkpoint_name, _current_tokenizer, _current_model, _current_model_path

    if _current_model is not None:
        del _current_model
        del _current_tokenizer

        _current_checkpoint_name = None
        _current_tokenizer = None
        _current_model = None
        _current_model_path = None

        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


def get_model(model_name: str):
    """Load selected generator checkpoint into memory. Only one generator model is kept in RAM."""

    global _current_checkpoint_name, _current_tokenizer, _current_model, _current_model_path

    variant = get_variant_config(model_name)
    checkpoint_name = variant["checkpoint"]

    if _current_checkpoint_name == checkpoint_name and _current_model is not None:
        return _current_tokenizer, _current_model, _current_model_path

    unload_current_model()

    model_path = resolve_remote_model_path(model_name)

    print(f"Loading generator checkpoint: {checkpoint_name}")
    print(f"Selected mode: {model_name}")
    print(f"Path: {model_path}")
    print(f"SHA: {file_fingerprint(model_path / 'model.safetensors')}")

    tokenizer, model = load_model_and_tokenizer(str(model_path))
    model.to(device)
    model.eval()

    _current_checkpoint_name = checkpoint_name
    _current_tokenizer = tokenizer
    _current_model = model
    _current_model_path = model_path

    return tokenizer, model, model_path


def preload_model_folders():
    """Download all generator model folders to Hugging Face cache. Does not load models into RAM."""
    print("Pre-downloading generator model folders...")
    for name in MODEL_VARIANTS:
        try:
            path = resolve_remote_model_path(name)
            print(f"Cached {name}: {path}")
        except Exception as e:
            print(f"[WARN] Failed to preload {name}: {e}")


# ============================================================
# INPUT NORMALIZATION
# ============================================================

def normalize_line(line: str) -> str:
    """Soft-normalize one line to be closer to train-time token style."""
    line = re.sub(r"([()\[\]{}:,.=+\-*/<>])", r" \1 ", line)
    line = re.sub(r"\s+", " ", line)
    return line.strip()


def context_to_tokens(code: str) -> str:
    """
    Convert normal-looking code into training-style token text.
    If code is already tokenized with <EOL>, keep it as-is.
    """
    code = str(code or "").strip()

    if "<EOL>" in code:
        return code

    code = code.replace("\t", "    ")
    lines = code.splitlines()
    tokens = [normalize_line(line) for line in lines if line.strip()]
    return " <EOL> ".join(tokens).strip()


def trim_token_text(text: str, max_tokens: int) -> str:
    """Trim tokenized text to a maximum number of whitespace-separated tokens."""
    toks = str(text or "").split()
    if len(toks) <= max_tokens:
        return str(text or "").strip()
    return " ".join(toks[:max_tokens]).strip()


# ============================================================
# RETRIEVER
# ============================================================

LIT_PATTERN = re.compile(r"<(STR|NUM|CHAR)_LIT:(.*?)>", re.S)
KEYWORD_PATTERN = re.compile(r"[A-Za-z_]\w+")
STOPWORDS_FOR_RERANK = {
    "from", "import", "class", "def", "return", "self", "true", "false", "none",
    "if", "else", "elif", "for", "while", "try", "except", "with", "as", "in",
    "and", "or", "not", "is", "str", "int", "list", "dict", "object",
}


def convert_cxg_format_to_normal(code: str) -> str:
    """Convert tokenized CodeXGLUE/ReACC code to Python-like text for GraphCodeBERT."""
    code = str(code or "").strip()
    code = code.replace("<s>", "").replace("</s>", "")
    code = code.replace("<EOL>", "\n")
    code = code.replace("<NUM_LIT>", "0")
    code = code.replace("<STR_LIT>", '"str"')
    code = code.replace("<CHAR_LIT>", '"c"')

    for lit_type, lit_value in LIT_PATTERN.findall(code):
        code = code.replace(f"<{lit_type}_LIT:{lit_value}>", lit_value)

    return code
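
# Illustrative example (not executed; shown for clarity):
#   convert_cxg_format_to_normal("x = <NUM_LIT:42> <EOL> y = <STR_LIT>")
#   -> 'x = 42 \n y = "str"'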


def keyword_set(text: str) -> set:
    """Extract lightweight keywords for reranking dense retrieval candidates."""
    text = html.unescape(str(text or "")).replace("<EOL>", " ")
    kws = set()
    for tok in KEYWORD_PATTERN.findall(text):
        low = tok.lower()
        if len(low) < 3 or low in STOPWORDS_FOR_RERANK:
            continue
        kws.add(low)
    return kws
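
# Illustrative example (not executed; shown for clarity):
#   keyword_set("import os <EOL> load_config ( path )")
#   -> {"load_config", "path"}   # stopwords and tokens shorter than 3 chars are dropped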


def rerank_retrieval_results(token_context: str, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Rerank dense top-k results by adding a small keyword-overlap bonus.
    The dense score still dominates; the keyword bonus helps avoid a top-1 chunk whose local pattern is wrong.
    """
    query_keywords = keyword_set(token_context)

    for r in results:
        content_keywords = keyword_set(r.get("content", ""))
        overlap = query_keywords & content_keywords
        denom = max(1, min(len(query_keywords), 20))
        overlap_score = len(overlap) / denom
        r["keyword_overlap_count"] = len(overlap)
        r["keyword_overlap_preview"] = ", ".join(sorted(list(overlap))[:20])
        r["keyword_overlap_score"] = float(overlap_score)
        r["rerank_score"] = float(
            r["score"] + RERANK_KEYWORD_WEIGHT * overlap_score)

    return sorted(results, key=lambda x: x["rerank_score"], reverse=True)
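
# For reference, the rerank score computed above is:
#   rerank_score = dense_score + RERANK_KEYWORD_WEIGHT * len(Q & C) / max(1, min(len(Q), 20))
# where Q and C are keyword_set() of the query context and the candidate content.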


class OnlineDenseRetriever:
    """
    Online dense retriever.

    JSONL chunks are accessed via byte offsets, so the 463 MB JSONL file is never
    loaded into Python dictionaries all at once.
    """

    def __init__(self):
        self.device = device
        self.index_path, self.chunks_path = self._download_retriever_artifacts()

        print("Loading retriever index:", self.index_path)
        index_tensor = torch.load(self.index_path, map_location="cpu")
        if not isinstance(index_tensor, torch.Tensor):
            index_tensor = torch.tensor(index_tensor)
        index_tensor = index_tensor.float().contiguous()
        index_tensor = torch.nn.functional.normalize(index_tensor, p=2, dim=1)
        self.index = index_tensor
        self.vector_dim = int(self.index.shape[1])
        self.num_chunks = int(self.index.shape[0])
        print("Retriever index shape:", tuple(self.index.shape))

        print("Building JSONL byte offsets:", self.chunks_path)
        self.offsets = self._build_offsets(self.chunks_path)
        if len(self.offsets) != self.num_chunks:
            raise ValueError(
                f"Index/chunk mismatch: index={self.num_chunks}, chunks={len(self.offsets)}"
            )
        print("Chunk offsets:", len(self.offsets))

        print("Loading retriever tokenizer:", RETRIEVER_MODEL_NAME)
        self.tokenizer = AutoTokenizer.from_pretrained(
            RETRIEVER_MODEL_NAME, use_fast=True)

        print("Loading retriever encoder:", RETRIEVER_MODEL_NAME)
        self.encoder = AutoModel.from_pretrained(
            RETRIEVER_MODEL_NAME).to(self.device)
        self.encoder.eval()

        print("Online retriever ready.")

    @staticmethod
    def _download_retriever_artifacts() -> Tuple[Path, Path]:
        local_repo_dir = snapshot_download(
            repo_id=REMOTE_MODEL_REPO,
            repo_type="model",
            allow_patterns=[RETRIEVER_INDEX_IN_REPO, RETRIEVER_CHUNKS_IN_REPO],
        )
        root = Path(local_repo_dir)
        index_path = root / RETRIEVER_INDEX_IN_REPO
        chunks_path = root / RETRIEVER_CHUNKS_IN_REPO

        missing = [str(p) for p in [index_path, chunks_path] if not p.exists()]
        if missing:
            raise FileNotFoundError(f"Missing retriever artifacts: {missing}")

        return index_path, chunks_path

    @staticmethod
    def _build_offsets(path: Path) -> List[int]:
        offsets = []
        with open(path, "rb") as f:
            while True:
                pos = f.tell()
                line = f.readline()
                if not line:
                    break
                if line.strip():
                    offsets.append(pos)
        return offsets

    def _read_chunk(self, idx: int) -> Dict[str, Any]:
        with open(self.chunks_path, "rb") as f:
            f.seek(self.offsets[idx])
            line = f.readline().decode("utf-8")
        return json.loads(line)

    @torch.no_grad()
    def encode_query(self, token_context: str) -> torch.Tensor:
        clean_code = convert_cxg_format_to_normal(token_context)
        text = f"{self.tokenizer.cls_token} {clean_code} {self.tokenizer.sep_token}"

        encoded = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=RETRIEVER_BLOCK_SIZE,
            return_tensors="pt",
        ).to(self.device)

        outputs = self.encoder(**encoded)
        query_vec = outputs.last_hidden_state[:, 0, :]
        query_vec = torch.nn.functional.normalize(query_vec, p=2, dim=1)
        return query_vec.squeeze(0).detach().cpu().float()

    def search(self, token_context: str, top_k: int = RETRIEVER_TOP_K) -> List[Dict[str, Any]]:
        query_vec = self.encode_query(token_context)
        scores = torch.mv(self.index, query_vec)
        top_scores, top_indices = torch.topk(scores, k=top_k)

        results = []
        for score, idx in zip(top_scores.tolist(), top_indices.tolist()):
            row = self._read_chunk(int(idx))
            content = row.get("content", "")
            results.append({
                "score": float(score),
                "index": int(idx),
                "original_file_id": row.get("original_file_id"),
                "fragment_sequence_id": row.get("fragment_sequence_id"),
                "content": content,
            })

        return rerank_retrieval_results(token_context, results)


def get_retriever() -> OnlineDenseRetriever:
    global _retriever
    if _retriever is None:
        _retriever = OnlineDenseRetriever()
    return _retriever
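
# Illustrative usage (assumes the remote index/chunk artifacts are reachable; the
# first call downloads them and loads the GraphCodeBERT encoder, so it is slow):
#   retriever = get_retriever()
#   hits = retriever.search("def read_json ( path ) : <EOL>", top_k=RETRIEVER_TOP_K)
#   best_chunk = hits[0]["content"]  # highest rerank_score after keyword rerank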


# ============================================================
# OUTPUT CLEANUP
# ============================================================

STOP_MARKERS = ["<EOL>", "</s>", "<s>",
                "<pad>", "<unk>", "<mask>", "<|endoftext|>"]

SPECIAL_TOKEN_PATTERNS = [
    "<EOL>", "</s>", "<s>", "<pad>", "<unk>", "<mask>", "<|endoftext|>",
    "<STR_LIT>", "<STR_LIT:...>", "<NUM_LIT>", "<NUM_LIT:...>",
    "<CHAR_LIT>", "<CHAR_LIT:...>", "<BOOL_LIT>", "<BOOL_LIT:...>",
    "<NULL_LIT>", "<INDENT>", "<DEDENT>",
]

TOKEN_PATTERN = re.compile(
    r"\"[^\"\n]*\"|'[^'\n]*'|[A-Za-z_]\w*|\d+(?:\.\d+)?|==|!=|<=|>=|\+=|-=|\*=|/=|//|->|[(){}\[\]:,.;=+\-*/<>]"
)


def normalize_special_spacing(text: str) -> str:
    """Normalize oddly spaced special tokens (e.g. "< EOL >" -> "<EOL>") that may appear after decoding."""
    text = html.unescape(str(text))

    text = re.sub(r"<\s*/\s*s\s*>", "</s>", text)
    text = re.sub(r"<\s*s\s*>", "<s>", text)
    text = re.sub(r"<\s*pad\s*>", "<pad>", text)
    text = re.sub(r"<\s*unk\s*>", "<unk>", text)
    text = re.sub(r"<\s*mask\s*>", "<mask>", text)
    text = re.sub(r"<\s*\|\s*endoftext\s*\|\s*>", "<|endoftext|>", text)

    for name in ["EOL", "STR_LIT", "NUM_LIT", "CHAR_LIT", "BOOL_LIT", "NULL_LIT", "INDENT", "DEDENT"]:
        text = re.sub(rf"<\s*{name}\s*>", f"<{name}>", text)

    for name in ["STR_LIT", "NUM_LIT", "CHAR_LIT", "BOOL_LIT"]:
        text = re.sub(
            rf"<\s*{name}\s*:\s*([^>]+?)\s*>",
            lambda m, n=name: f"<{n}:{m.group(1).strip()}>",
            text,
        )

    return text
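
# Illustrative example (not executed; shown for clarity):
#   normalize_special_spacing("x < EOL > y <STR_LIT: a >")
#   -> "x <EOL> y <STR_LIT:a>"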


def cut_at_stop_marker(text: str):
    """Cut text at the earliest stop marker. Returns (prefix, marker); marker is None if no marker is found."""
    earliest = None
    detected = None

    for marker in STOP_MARKERS:
        pos = text.find(marker)
        if pos >= 0 and (earliest is None or pos < earliest):
            earliest = pos
            detected = marker

    if earliest is None:
        return text, None

    return text[:earliest], detected
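
# Illustrative example (not executed; shown for clarity):
#   cut_at_stop_marker("return x <EOL> print ( y )")
#   -> ("return x ", "<EOL>")   # trailing space is kept; cleanup happens later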


def replace_dataset_placeholders(text: str) -> str:
    """Convert train-time placeholders to readable Python-ish code."""

    def repl_str_payload(m):
        value = m.group(1).strip()
        return json.dumps(value)

    text = re.sub(r"<STR_LIT:([^>]+)>", repl_str_payload, text)
    text = text.replace("<STR_LIT>", json.dumps("str"))

    text = re.sub(r"<NUM_LIT:([^>]+)>", lambda m: m.group(1).strip(), text)
    text = text.replace("<NUM_LIT>", "0")

    text = re.sub(r"<CHAR_LIT:([^>]+)>",
                  lambda m: json.dumps(m.group(1).strip()), text)
    text = text.replace("<CHAR_LIT>", json.dumps("c"))

    text = re.sub(r"<BOOL_LIT:True>", "True", text)
    text = re.sub(r"<BOOL_LIT:False>", "False", text)
    text = text.replace("<BOOL_LIT>", "True")
    text = text.replace("<NULL_LIT>", "None")

    text = text.replace("<INDENT>", "\n    ")
    text = text.replace("<DEDENT>", "\n")

    return text
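
# Illustrative example (not executed; shown for clarity):
#   replace_dataset_placeholders("<BOOL_LIT:True> , <NUM_LIT:3> , <STR_LIT:hi>")
#   -> 'True , 3 , "hi"'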


def cleanup_prediction(raw_text: str):
    """
    Clean raw generated token text for UI prediction.
    Returns: prediction_text, detected_stop_marker, normalized_raw_text
    """
    normalized = normalize_special_spacing(raw_text)
    cut_text, stop_marker = cut_at_stop_marker(normalized)

    for marker in STOP_MARKERS:
        cut_text = cut_text.replace(marker, "")

    cut_text = replace_dataset_placeholders(cut_text)
    cut_text = cut_text.replace("<EOL>", "\n")

    cut_text = re.sub(r"\s+([)\]\}:,])", r"\1", cut_text)
    cut_text = re.sub(r"([(\[{])\s+", r"\1", cut_text)

    cut_text = re.sub(r"\s*=\s*", " = ", cut_text)
    cut_text = re.sub(r"\s*\+\s*", " + ", cut_text)
    cut_text = re.sub(r"\s*-\s*", " - ", cut_text)
    cut_text = re.sub(r"\s*\*\s*", " * ", cut_text)
    cut_text = re.sub(r"\s*/\s*", " / ", cut_text)
    cut_text = re.sub(r"\s*<\s*", " < ", cut_text)
    cut_text = re.sub(r"\s*>\s*", " > ", cut_text)

    cut_text = re.sub(r"[ \t]+", " ", cut_text)
    cut_text = re.sub(r"\n\s+", "\n    ", cut_text)

    return cut_text.strip(), stop_marker, normalized
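
# Illustrative example (not executed; shown for clarity):
#   cleanup_prediction("return a + b <EOL> print")
#   -> ("return a + b", "<EOL>", "return a + b <EOL> print")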


def token_spans(text: str):
    """Return normalized tokens and char spans for overlap trimming."""
    text = str(text or "")
    text = html.unescape(text)
    text = text.replace("<EOL>", "\n")
    toks = []
    spans = []
    for m in TOKEN_PATTERN.finditer(text):
        toks.append(m.group(0))
        spans.append(m.span())
    return toks, spans


def trim_overlapping_prefix(context_text: str, prediction: str, token_context: str = ""):
    """
    Remove duplicated prefix from prediction when prediction begins with tokens that already
    appear at the end of the user context.
    """
    pred = str(prediction or "").strip()
    if not pred:
        return pred, "No prediction to align."

    ctx_source = token_context if token_context else context_text
    ctx_tokens, _ = token_spans(ctx_source)
    pred_tokens, pred_spans = token_spans(pred)

    if not ctx_tokens or not pred_tokens:
        return pred, "No token overlap check applied."

    max_k = min(len(ctx_tokens), len(pred_tokens), 24)
    best_k = 0

    for k in range(1, max_k + 1):
        if ctx_tokens[-k:] == pred_tokens[:k]:
            best_k = k

    if best_k <= 0:
        return pred, "No duplicated prefix found."

    cut_char = pred_spans[best_k - 1][1]
    aligned = pred[cut_char:].lstrip(" \t\n,.;:")

    if not aligned:
        return pred, f"Detected {best_k} duplicated prefix token(s), but kept prediction to avoid empty output."

    removed = pred[:cut_char].strip()
    return aligned, f"Trimmed duplicated prefix from prediction: {removed!r} ({best_k} token(s))."
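
# Illustrative example (not executed; shown for clarity):
#   trim_overlapping_prefix("result = foo(", "foo ( bar )")
#   -> ("bar )", "Trimmed duplicated prefix from prediction: 'foo (' (2 token(s)).")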


# ============================================================
# INFERENCE
# ============================================================

def run_demo(model_name: str, context: str):
    try:
        tokenizer, model, model_path = get_model(model_name)
        variant = get_variant_config(model_name)

        token_context = context_to_tokens(context)

        retriever_mode = "Disabled"
        retrieval_results = []
        token_retrieved = ""
        retrieved_raw = ""
        selected_retrieval = None

        if variant["use_online_retriever"]:
            retriever_mode = "Online dense retrieval top-5 + keyword rerank"
            retriever = get_retriever()
            retrieval_results = retriever.search(
                token_context, top_k=RETRIEVER_TOP_K)
            if retrieval_results:
                selected_retrieval = retrieval_results[0]
                retrieved_raw = selected_retrieval.get("content", "")
                token_retrieved = trim_token_text(
                    retrieved_raw, RETRIEVED_MAX_TOKENS)

        max_length = 384 if variant["use_online_retriever"] else 256

        raw_token_output = generate_completion(
            model=model,
            tokenizer=tokenizer,
            retrieved=token_retrieved,
            context=token_context,
            device=device,
            max_length=max_length,
            max_new_tokens=16,
            do_sample=False,
            stop_strings=None,
        )

        prediction_before_align, stop_marker, normalized_output = cleanup_prediction(
            raw_token_output)
        prediction, align_note = trim_overlapping_prefix(
            context_text=context,
            prediction=prediction_before_align,
            token_context=token_context,
        )

        if variant["use_online_retriever"]:
            retriever_note = (
                "The online retriever fetched the dense top-5 candidates, then reranked them with a small "
                "keyword-overlap bonus. The selected chunk is injected before the typed context."
            )
        elif variant["checkpoint"] == "retriever_eol":
            retriever_note = (
                "Ablation mode: this checkpoint was trained with retrieved code, but retrieval is OFF. "
                "The model receives typed context only."
            )
        else:
            retriever_note = "Retriever is not used for this model."

        if retrieval_results:
            retrieval_log = ""
            for rank, r in enumerate(retrieval_results, start=1):
                selected_flag = "  <-- SELECTED" if r is selected_retrieval else ""
                retrieval_log += (
                    f"Rank {rank}{selected_flag}\n"
                    f"dense_score: {r['score']:.6f}\n"
                    f"keyword_overlap_count: {r.get('keyword_overlap_count', 0)}\n"
                    f"keyword_overlap_score: {r.get('keyword_overlap_score', 0.0):.6f}\n"
                    f"rerank_score: {r.get('rerank_score', r['score']):.6f}\n"
                    f"overlap_keywords: {r.get('keyword_overlap_preview', '')}\n"
                    f"index: {r['index']}\n"
                    f"original_file_id: {r.get('original_file_id')}\n"
                    f"fragment_sequence_id: {r.get('fragment_sequence_id')}\n"
                    f"content preview: {r.get('content', '')[:1200]}\n\n"
                )
        else:
            retrieval_log = "No retrieval result."

        logs = (
            "=== DEMO LOGS ===\n\n"
            f"[Selected model]\n{model_name}\n\n"
            f"[Mode description]\n{variant['description']}\n\n"
            f"[Model repo]\n{REMOTE_MODEL_REPO}\n\n"
            f"[Local cache path]\n{model_path}\n\n"
            f"[Model fingerprint]\n{file_fingerprint(model_path / 'model.safetensors')}\n\n"
            f"[Device]\n{device}\n\n"
            f"[Retriever mode]\n{retriever_mode}\n\n"
            f"[Retriever note]\n{retriever_note}\n\n"
            f"[Retriever model]\n{RETRIEVER_MODEL_NAME if variant['use_online_retriever'] else 'N/A'}\n\n"
            f"[Retriever artifacts]\n{RETRIEVER_INDEX_IN_REPO}\n{RETRIEVER_CHUNKS_IN_REPO}\n\n"
            f"[Retriever top_k]\n{RETRIEVER_TOP_K}\n\n"
            f"[Rerank keyword weight]\n{RERANK_KEYWORD_WEIGHT}\n\n"
            "[Retriever results]\n"
            f"{retrieval_log}\n"
            "[Known token patterns cleaned in Prediction]\n"
            + "\n".join(f"- {p}" for p in SPECIAL_TOKEN_PATTERNS)
            + "\n\n"
            "[Raw Context]\n"
            f"{context}\n\n"
            "[Context → Tokens]\n"
            f"{token_context}\n\n"
            "[Selected retrieved raw]\n"
            f"{retrieved_raw}\n\n"
            "[Selected retrieved → Tokens used by generator]\n"
            f"{token_retrieved}\n\n"
            "[Raw Generator Output → Tokens]\n"
            f"{raw_token_output}\n\n"
            "[Normalized Generator Output → Tokens]\n"
            f"{normalized_output}\n\n"
            f"[Detected stop marker]\n{stop_marker}\n\n"
            "[Prediction before overlap trim]\n"
            f"{prediction_before_align}\n\n"
            "[Overlap trim note]\n"
            f"{align_note}\n\n"
            "[Prediction]\n"
            f"{prediction}\n"
        )

        return prediction, logs

    except Exception:
        err = traceback.format_exc()
        return (
            "ERROR: failed to load/generate.",
            "=== ERROR LOGS ===\n\n" + err,
        )
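
# Illustrative call (assumes the selected checkpoint downloads and loads successfully):
#   prediction, logs = run_demo(DEFAULT_MODEL_NAME, "def sum(a, b):\n    return")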


# ============================================================
# GRADIO UI
# ============================================================

demo = gr.Interface(
    fn=run_demo,
    inputs=[
        gr.Dropdown(
            choices=list(MODEL_VARIANTS.keys()),
            value=DEFAULT_MODEL_NAME,
            label="Model",
        ),
        gr.Textbox(
            lines=10,
            label="Context",
            placeholder="def sum(a, b):\n    return",
        ),
    ],
    outputs=[
        gr.Textbox(lines=6, label="Prediction"),
        gr.Textbox(lines=34, label="Logs"),
    ],
    title="ReACC Code Completion Demo",
    description=(
        "Type Python code and compare Baseline, EOL, Retriever-trained EOL with retrieval OFF, "
        "and Online Retriever + EOL. Online retriever mode retrieves dense top-5 candidates, "
        "reranks them with keyword overlap, and injects the selected chunk into the generator."
    ),
)


# ============================================================
# STARTUP
# ============================================================

if PRE_DOWNLOAD_MODELS:
    preload_model_folders()

if WARMUP_DEFAULT_MODEL:
    print(f"Warming up default model: {DEFAULT_MODEL_NAME}")
    get_model(DEFAULT_MODEL_NAME)

if __name__ == "__main__":
    demo.launch()