Update app.py
app.py CHANGED
@@ -1,90 +1,102 @@
-
-from model_utils import load_model_and_tokenizer, generate_completion
-import sys
-from pathlib import Path
+import os
 import torch
 import gradio as gr
 
-
-
-GENERATOR_DIR = APP_DIR.parent
-CHECKPOINT_DIR = GENERATOR_DIR / "reacc_generator"
-
-sys.path.insert(0, str(GENERATOR_DIR))
-
-
-# ===== Load model once (important for demo speed) =====
-# MODEL_PATH = CHECKPOINT_DIR / "checkpoint-best"
-# if not MODEL_PATH.exists():
-#     MODEL_PATH = CHECKPOINT_DIR / "checkpoint-last"
+from model_utils import load_model_and_tokenizer, generate_completion
+from retriever_stub import retrieve_code_stub
 
-import os
 
+# =========================
+# Model loading
+# =========================
 if os.path.exists("reacc_generator/checkpoint-best"):
     MODEL_PATH = "reacc_generator/checkpoint-best"
 else:
     MODEL_PATH = "reacc_generator/checkpoint-last"
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-tokenizer, model = load_model_and_tokenizer(
+tokenizer, model = load_model_and_tokenizer(MODEL_PATH)
 model.to(device)
 model.eval()
 
-
-
+
+# =========================
+# Demo-level token adapter
+# =========================
+def python_to_demo_tokens(code: str) -> str:
+    """Convert Python code to a simplified token format for the model."""
+    return code.replace("\n", " <EOL> ")
 
 
-
-
-    """
-    context: unfinished code (tokenized style)
-    use_retriever: toggle ON/OFF
-    """
+def demo_tokens_to_python(code: str) -> str:
+    """Convert model token output back to readable Python code."""
+    return code.replace("<EOL>", "\n")
 
-    retrieved = ""
 
-
-
+# =========================
+# Inference function
+# =========================
+def run_demo(python_context: str, use_retriever: bool):
 
-
+    # 1. Python → tokenized (hidden from user)
+    token_context = python_to_demo_tokens(python_context)
+
+    # 2. Retriever (optional)
+    retrieved = retrieve_code_stub(python_context) if use_retriever else ""
+    token_retrieved = python_to_demo_tokens(retrieved) if retrieved else ""
+
+    # 3. Generator
+    token_output = generate_completion(
         model=model,
         tokenizer=tokenizer,
-        retrieved=
-        context=
+        retrieved=token_retrieved,
+        context=token_context,
         device=device,
         max_length=256,
-        max_new_tokens=16,
+        max_new_tokens=16,
         do_sample=False,
-        stop_strings=["<EOL>"]
+        stop_strings=["<EOL>"],
     )
 
-
+    # 4. Token → Python
+    python_output = demo_tokens_to_python(token_output)
+
+    # 5. Logs
+    logs = (
+        "=== TOKENIZATION LOGS ===\n\n"
+        "[Input → Tokens]\n"
+        f"{token_context}\n\n"
+        "[Retrieved → Tokens]\n"
+        f"{token_retrieved}\n\n"
+        "[Generator Output → Tokens]\n"
+        f"{token_output}\n"
+    )
 
-
+    mode = "ReACC (Retriever + Generator)" if use_retriever else "Generator-only baseline"
+    return python_output.strip(), retrieved, logs, mode
 
 
-# =====
+# =========================
+# Gradio UI
+# =========================
 demo = gr.Interface(
     fn=run_demo,
     inputs=[
-        gr.Textbox(
-
-            label="Context (unfinished code)",
-            placeholder="Paste code context here (tokenized style: <EOL>, <STR_LIT>, <NUM_LIT>)"
-        ),
+        gr.Textbox(lines=12, label="Python context (unfinished code)",
+                   placeholder="def sum(a, b):\n "),
         gr.Checkbox(label="Use Retriever (ReACC mode)", value=False),
     ],
     outputs=[
-        gr.Textbox(lines=
+        gr.Textbox(lines=8, label="Prediction (Python code)"),
         gr.Textbox(lines=6, label="Retrieved code"),
+        gr.Textbox(lines=12, label="Logs (tokenization & generation)"),
         gr.Textbox(lines=1, label="Mode"),
     ],
    title="ReACC Code Completion Demo",
     description=(
-        "
-        "
-        "
-        "This demo runs on a fine-tuned CodeGPT generator."
+        "Enter normal Python code.\n"
+        "The system will internally tokenize it for the generator.\n"
+        "You can view tokenization and generation details in the Logs section."
    ),
 )
 
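For reference, the demo-level token adapter added above behaves as follows on a small input; the values shown are just the literal results of the two .replace calls and this snippet is illustrative, not part of the commit:

src = "def sum(a, b):\n    return a + b"
tokens = python_to_demo_tokens(src)
# tokens == "def sum(a, b): <EOL>     return a + b"
restored = demo_tokens_to_python(tokens)
# restored == "def sum(a, b): \n     return a + b"  (extra spaces survive the round trip)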
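The new import `from retriever_stub import retrieve_code_stub` assumes a small retriever module that is not part of this change. A minimal sketch of the assumed interface is below; the in-memory corpus and keyword-overlap scoring are placeholders for illustration, not the repo's actual retriever:

# retriever_stub.py -- illustrative sketch only; the real module is not shown in this change.
# Assumes a plain-text "retrieval corpus": a list of Python snippets searched by keyword overlap.

CORPUS = [
    "def add(a, b):\n    return a + b",
    "def read_file(path):\n    with open(path) as f:\n        return f.read()",
]

def retrieve_code_stub(query: str) -> str:
    """Return the corpus snippet sharing the most whitespace-split tokens with the query, or '' if none."""
    query_tokens = set(query.split())
    best, best_overlap = "", 0
    for snippet in CORPUS:
        overlap = len(query_tokens & set(snippet.split()))
        if overlap > best_overlap:
            best, best_overlap = snippet, overlap
    return best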
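Likewise, the call to `generate_completion` fixes the interface expected from `model_utils` (model, tokenizer, retrieved, context, device, max_length, max_new_tokens, do_sample, stop_strings). A compatible sketch built on the Hugging Face `generate` API follows; it is an assumption for illustration, the repo's real `model_utils` is not shown, and `stop_strings` requires a recent transformers release:

# model_utils.py -- compatible sketch, not the actual implementation from this repo.
import torch

def generate_completion(model, tokenizer, retrieved, context, device,
                        max_length=256, max_new_tokens=16,
                        do_sample=False, stop_strings=None):
    """Greedy-decode a completion for `context`, optionally prefixed by `retrieved` code."""
    # ReACC-style prompt: retrieved snippet (if any) concatenated before the unfinished context.
    prompt = f"{retrieved} {context}".strip() if retrieved else context
    inputs = tokenizer(prompt, return_tensors="pt",
                       truncation=True, max_length=max_length).to(device)
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            stop_strings=stop_strings,  # supported in recent transformers versions; needs tokenizer=
            tokenizer=tokenizer,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Return only the newly generated text, not the echoed prompt.
    new_tokens = output_ids[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)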