TranTruongMMCII committed on
Commit
080fc68
·
verified ·
1 Parent(s): d6fd696

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -47
app.py CHANGED
@@ -1,90 +1,102 @@
1
- from retriever_stub import retrieve_code_stub
2
- from model_utils import load_model_and_tokenizer, generate_completion
3
- import sys
4
- from pathlib import Path
5
  import torch
6
  import gradio as gr
7
 
8
- # ===== Path setup =====
9
- APP_DIR = Path(__file__).parent
10
- GENERATOR_DIR = APP_DIR.parent
11
- CHECKPOINT_DIR = GENERATOR_DIR / "reacc_generator"
12
-
13
- sys.path.insert(0, str(GENERATOR_DIR))
14
-
15
-
16
- # ===== Load model once (important for demo speed) =====
17
- # MODEL_PATH = CHECKPOINT_DIR / "checkpoint-best"
18
- # if not MODEL_PATH.exists():
19
- # MODEL_PATH = CHECKPOINT_DIR / "checkpoint-last"
20
 
21
- import os
22
 
 
 
 
23
  if os.path.exists("reacc_generator/checkpoint-best"):
24
  MODEL_PATH = "reacc_generator/checkpoint-best"
25
  else:
26
  MODEL_PATH = "reacc_generator/checkpoint-last"
27
 
28
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
- tokenizer, model = load_model_and_tokenizer(str(MODEL_PATH))
30
  model.to(device)
31
  model.eval()
32
 
33
- print("Model loaded from:", MODEL_PATH)
34
- print("Device:", device)
 
 
 
 
 
35
 
36
 
37
- # ===== Core inference logic =====
38
- def run_demo(context: str, use_retriever: bool):
39
- """
40
- context: unfinished code (tokenized style)
41
- use_retriever: toggle ON/OFF
42
- """
43
 
44
- retrieved = ""
45
 
46
- if use_retriever:
47
- retrieved = retrieve_code_stub(context)
 
 
48
 
49
- prediction = generate_completion(
 
 
 
 
 
 
 
 
50
  model=model,
51
  tokenizer=tokenizer,
52
- retrieved=retrieved,
53
- context=context,
54
  device=device,
55
  max_length=256,
56
- max_new_tokens=16, # keep short for demo
57
  do_sample=False,
58
- stop_strings=["<EOL>"]
59
  )
60
 
61
- mode = "Retriever + Generator (ReACC)" if use_retriever else "Generator-only baseline"
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
- return prediction, retrieved, mode
 
64
 
65
 
66
- # ===== Gradio UI =====
 
 
67
  demo = gr.Interface(
68
  fn=run_demo,
69
  inputs=[
70
- gr.Textbox(
71
- lines=12,
72
- label="Context (unfinished code)",
73
- placeholder="Paste code context here (tokenized style: <EOL>, <STR_LIT>, <NUM_LIT>)"
74
- ),
75
  gr.Checkbox(label="Use Retriever (ReACC mode)", value=False),
76
  ],
77
  outputs=[
78
- gr.Textbox(lines=6, label="Prediction"),
79
  gr.Textbox(lines=6, label="Retrieved code"),
 
80
  gr.Textbox(lines=1, label="Mode"),
81
  ],
82
  title="ReACC Code Completion Demo",
83
  description=(
84
- "Toggle Retriever ON/OFF to compare:\n"
85
- "- Generator-only baseline\n"
86
- "- Retriever-augmented generation (ReACC)\n\n"
87
- "This demo runs on a fine-tuned CodeGPT generator."
88
  ),
89
  )
90
 
 
1
+ import os
 
 
 
2
  import torch
3
  import gradio as gr
4
 
5
+ from model_utils import load_model_and_tokenizer, generate_completion
6
+ from retriever_stub import retrieve_code_stub
 
 
 
 
 
 
 
 
 
 
7
 
 
8
 
9
# =========================
# Model loading
# =========================
# Prefer the best checkpoint; fall back to the last one if "best" was
# never saved. Paths are relative to the current working directory —
# NOTE(review): assumes the app is launched from the repo root; confirm.
if os.path.exists("reacc_generator/checkpoint-best"):
    MODEL_PATH = "reacc_generator/checkpoint-best"
else:
    MODEL_PATH = "reacc_generator/checkpoint-last"

# Load the fine-tuned generator once at import time so every request
# reuses the same weights (keeps per-request demo latency low).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer, model = load_model_and_tokenizer(MODEL_PATH)
model.to(device)
model.eval()  # inference only — disables dropout etc.
21
 
22
+
23
+ # =========================
24
+ # Demo-level token adapter
25
+ # =========================
26
def python_to_demo_tokens(code: str) -> str:
    """Convert Python code to a simplified token format for the model."""
    # The generator consumes one-line sequences where each line break is
    # represented by a literal "<EOL>" marker padded with single spaces.
    return " <EOL> ".join(code.split("\n"))
29
 
30
 
31
def demo_tokens_to_python(code: str) -> str:
    """Convert model token output back to readable Python code.

    Inverse of ``python_to_demo_tokens``. That function inserts the
    marker padded with spaces (``" <EOL> "``), so the padding must be
    stripped here too — a bare ``"<EOL>" -> "\\n"`` replacement leaves a
    stray space at the end of one line and the start of the next, which
    corrupts the indentation of the reconstructed Python. Bare markers
    (model output emitted without padding) are still handled.
    """
    # Longest pattern first so padded markers are consumed before the
    # shorter fallbacks can match inside them.
    for marker in (" <EOL> ", " <EOL>", "<EOL> ", "<EOL>"):
        code = code.replace(marker, "\n")
    return code
 
 
 
34
 
 
35
 
36
+ # =========================
37
+ # Inference function
38
+ # =========================
39
def run_demo(python_context: str, use_retriever: bool):
    """Run one completion request and report intermediate token views.

    python_context: unfinished Python code typed by the user.
    use_retriever: when True, prepend stub-retrieved code (ReACC mode).
    Returns a 4-tuple: (prediction, retrieved code, tokenization logs,
    mode label) — order must match the Gradio outputs list.
    """
    # Tokenize the user's Python for the generator (hidden from the UI).
    token_context = python_to_demo_tokens(python_context)

    # Optionally fetch similar code, then tokenize it the same way.
    if use_retriever:
        retrieved = retrieve_code_stub(python_context)
    else:
        retrieved = ""
    token_retrieved = python_to_demo_tokens(retrieved) if retrieved else ""

    # Greedy decoding with a short budget — this is an interactive demo.
    token_output = generate_completion(
        model=model,
        tokenizer=tokenizer,
        retrieved=token_retrieved,
        context=token_context,
        device=device,
        max_length=256,
        max_new_tokens=16,
        do_sample=False,
        stop_strings=["<EOL>"],
    )

    # Map the token-level output back to readable Python.
    python_output = demo_tokens_to_python(token_output)

    # Assemble a transcript of every tokenized artifact for the Logs pane.
    log_sections = [
        "=== TOKENIZATION LOGS ===\n\n",
        "[Input → Tokens]\n",
        f"{token_context}\n\n",
        "[Retrieved → Tokens]\n",
        f"{token_retrieved}\n\n",
        "[Generator Output → Tokens]\n",
        f"{token_output}\n",
    ]
    logs = "".join(log_sections)

    if use_retriever:
        mode = "ReACC (Retriever + Generator)"
    else:
        mode = "Generator-only baseline"

    return python_output.strip(), retrieved, logs, mode
77
 
78
 
79
# =========================
# Gradio UI
# =========================
# NOTE(review): the four output boxes must stay aligned, in order, with
# run_demo's returned 4-tuple (prediction, retrieved, logs, mode).
demo = gr.Interface(
    fn=run_demo,
    inputs=[
        # Users type plain Python; tokenization happens inside run_demo.
        gr.Textbox(lines=12, label="Python context (unfinished code)",
                   placeholder="def sum(a, b):\n "),
        gr.Checkbox(label="Use Retriever (ReACC mode)", value=False),
    ],
    outputs=[
        gr.Textbox(lines=8, label="Prediction (Python code)"),
        gr.Textbox(lines=6, label="Retrieved code"),
        gr.Textbox(lines=12, label="Logs (tokenization & generation)"),
        gr.Textbox(lines=1, label="Mode"),
    ],
    title="ReACC Code Completion Demo",
    description=(
        "Enter normal Python code.\n"
        "The system will internally tokenize it for the generator.\n"
        "You can view tokenization and generation details in the Logs section."
    ),
)
# NOTE(review): no demo.launch() visible in this chunk — presumably called
# further down or by the hosting platform; confirm.
102