TranTruongMMCII committed
Commit 91e23bb · 1 Parent(s): 8daa49c

refactor structure

app.py CHANGED
@@ -1,61 +1,166 @@
 import os
 import re
+import gc
+from pathlib import Path
+
 import torch
 import gradio as gr
 
 from model_utils import load_model_and_tokenizer, generate_completion
-from retriever_stub import retrieve_code_stub
 
-# =========================
-# Model loading (HF Spaces compatible)
-# =========================
-if os.path.exists("reacc_generator/checkpoint-best"):
-    MODEL_PATH = "reacc_generator/checkpoint-best"
-else:
-    MODEL_PATH = "reacc_generator/checkpoint-last"
+
+# ============================================================
+# Path config
+# ============================================================
+
+BASE_DIR = Path(__file__).parent
+
+MODEL_PATHS = {
+    "Generator - Baseline": BASE_DIR / "checkpoint-best" / "baseline",
+    "Generator - EOL": BASE_DIR / "checkpoint-best" / "eol",
+}
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-tokenizer, model = load_model_and_tokenizer(MODEL_PATH)
-model.to(device)
-model.eval()
 
-# =========================
+_current_model_name = None
+_current_tokenizer = None
+_current_model = None
+
+
+# ============================================================
+# Model loading
+# ============================================================
+
+def get_model(model_name: str):
+    """
+    Lazily load the selected model.
+    Only one model is kept in memory at a time.
+    """
+
+    global _current_model_name, _current_tokenizer, _current_model
+
+    if model_name not in MODEL_PATHS:
+        raise ValueError(f"Unknown model option: {model_name}")
+
+    model_path = MODEL_PATHS[model_name]
+
+    if not model_path.exists():
+        raise FileNotFoundError(
+            f"Model path not found: {model_path}\n"
+            f"Expected structure: checkpoint-best/baseline and checkpoint-best/eol"
+        )
+
+    # Reuse the currently loaded model
+    if _current_model_name == model_name and _current_model is not None:
+        return _current_tokenizer, _current_model, model_path
+
+    # Unload the old model if switching
+    if _current_model is not None:
+        del _current_model
+        del _current_tokenizer
+        _current_model = None
+        _current_tokenizer = None
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+    print(f"Loading model: {model_name} from {model_path}")
+
+    tokenizer, model = load_model_and_tokenizer(str(model_path))
+    model.to(device)
+    model.eval()
+
+    _current_model_name = model_name
+    _current_tokenizer = tokenizer
+    _current_model = model
+
+    return tokenizer, model, model_path
+
+
+# ============================================================
 # Soft normalization adapters
-# =========================
+# ============================================================
+
+def normalize_line(line: str) -> str:
+    """
+    Soft-normalize one line to be closer to the training token style.
+    Example:
+        def add(a, b):
+    becomes:
+        def add ( a , b ) :
+    """
+
+    # Put spaces around common Python punctuation/operators
+    line = re.sub(r"([()\[\]{}:,.=+\-*/<>])", r" \1 ", line)
+
+    # Collapse spaces
+    line = re.sub(r"\s+", " ", line)
+
+    return line.strip()
 
 
-def python_to_tokens(code: str) -> str:
+def context_to_tokens(code: str) -> str:
     """
-    Soft-normalize Python code into a format closer to the training
-    distribution (spacing + <EOL>), without faking STR_LIT / NUM_LIT.
+    Convert normal-looking code into training-style token text.
+
+    Important:
+    - Preserve line boundaries as <EOL>
+    - Do not fake <STR_LIT> / <NUM_LIT>
     """
+
     code = code.replace("\t", " ")
-    code = re.sub(r"([():,=+*/-])", r" \1 ", code)
-    code = re.sub(r"\s+", " ", code)
-    code = code.replace("\n", " <EOL> ")
-    return code.strip()
+    lines = code.splitlines()
+
+    normalized_lines = []
+    for line in lines:
+        norm = normalize_line(line)
+        if norm:
+            normalized_lines.append(norm)
+
+    return " <EOL> ".join(normalized_lines).strip()
 
 
-def tokens_to_python(code: str) -> str:
-    """Convert model token output back to readable Python."""
+def tokens_to_readable(code: str) -> str:
+    """
+    Convert generated token text back to readable form.
+    This is demo-level detokenization, not a perfect Python formatter.
+    """
+
     code = code.replace("<EOL>", "\n")
-    code = re.sub(r"\s*([():,=+*/-])\s*", r"\1", code)
+
+    # Remove spaces before closing punctuation
+    code = re.sub(r"\s+([)\]\}:,])", r"\1", code)
+
+    # Remove spaces after opening punctuation
+    code = re.sub(r"([(\[\{])\s+", r"\1", code)
+
+    # Compact common binary operators mildly
+    code = re.sub(r"\s*=\s*", " = ", code)
+    code = re.sub(r"\s*\+\s*", " + ", code)
+    code = re.sub(r"\s*-\s*", " - ", code)
+    code = re.sub(r"\s*\*\s*", " * ", code)
+    code = re.sub(r"\s*/\s*", " / ", code)
+    code = re.sub(r"\s*<\s*", " < ", code)
+    code = re.sub(r"\s*>\s*", " > ", code)
+
+    # Clean repeated spaces
+    code = re.sub(r"[ \t]+", " ", code)
+
     return code.strip()
 
-# =========================
+
+# ============================================================
 # Inference
-# =========================
+# ============================================================
 
-def run_demo(context: str, use_retriever: bool):
-    # Raw -> normalized tokens
-    token_context = python_to_tokens(context)
-
-    # Retriever (optional)
-    retrieved_raw = retrieve_code_stub(context) if use_retriever else ""
-    token_retrieved = python_to_tokens(retrieved_raw) if retrieved_raw else ""
-
-    # Generator
+def run_demo(model_name: str, context: str):
+    tokenizer, model, model_path = get_model(model_name)
+
+    token_context = context_to_tokens(context)
+
+    # No retriever for now
+    token_retrieved = ""
+
     token_output = generate_completion(
         model=model,
         tokenizer=tokenizer,
@@ -68,12 +173,12 @@ def run_demo(context: str, use_retriever: bool):
         stop_strings=["<EOL>"],
     )
 
-    # Tokens -> Python
-    python_output = tokens_to_python(token_output)
+    prediction = tokens_to_readable(token_output)
 
-    # Logs for explanation
     logs = (
-        "=== NORMALIZATION & GENERATION LOGS ===\n\n"
+        "=== DEMO LOGS ===\n\n"
+        f"[Selected model]\n{model_name}\n\n"
+        f"[Model path]\n{model_path}\n\n"
         "[Raw Context]\n"
         f"{context}\n\n"
         "[Context → Tokens]\n"
@@ -81,34 +186,40 @@
         "[Retrieved → Tokens]\n"
         f"{token_retrieved}\n\n"
         "[Generator Output → Tokens]\n"
-        f"{token_output}\n"
+        f"{token_output}\n\n"
+        "[Prediction]\n"
+        f"{prediction}\n"
     )
 
-    mode = "ReACC (Retriever + Generator)" if use_retriever else "Generator-only baseline"
-    return python_output, retrieved_raw, logs, mode
+    return prediction, logs
 
 
-# =========================
+# ============================================================
 # Gradio UI
-# =========================
+# ============================================================
+
 demo = gr.Interface(
     fn=run_demo,
     inputs=[
-        gr.Textbox(lines=12, label="Context (unfinished code)",
-                   placeholder="def sum(a, b):\n    "),
-        gr.Checkbox(label="Use Retriever (ReACC mode)", value=False),
+        gr.Dropdown(
+            choices=["Generator - Baseline", "Generator - EOL"],
+            value="Generator - Baseline",
+            label="Model",
+        ),
+        gr.Textbox(
+            lines=12,
+            label="Context",
+            placeholder="def add(a, b):\n    return",
+        ),
     ],
     outputs=[
         gr.Textbox(lines=8, label="Prediction"),
-        gr.Textbox(lines=6, label="Retrieved code"),
-        gr.Textbox(lines=12, label="Logs"),
-        gr.Textbox(lines=1, label="Mode"),
+        gr.Textbox(lines=14, label="Logs"),
     ],
-    title="ReACC Code Completion Demo",
+    title="ReACC Generator Demo",
     description=(
-        "User-friendly demo: input normal Python code.\n"
-        "The system softly normalizes input to the training-style tokens.\n"
-        "Logs show internal transformations for explanation."
+        "Compare Generator baseline and Generator + EOL. "
+        "Retriever integration will be added later."
    ),
 )
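For reference, a minimal round-trip sketch of the new normalization adapters, assuming the app.py above is importable (importing it constructs the Gradio interface as a side effect but does not launch it):

# Round-trip sketch for the soft normalization adapters in app.py.
from app import context_to_tokens, tokens_to_readable

context = "def add(a, b):\n    return a + b"

tokens = context_to_tokens(context)
# Punctuation is space-separated and line breaks become <EOL>:
# "def add ( a , b ) : <EOL> return a + b"
print(tokens)

readable = tokens_to_readable(tokens)
# Detokenization is approximate: indentation is not restored, so the
# result is readable text rather than valid Python.
print(readable)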
{reacc_generator/checkpoint-best → checkpoint-best/baseline}/config.json RENAMED
File without changes
{reacc_generator/checkpoint-best → checkpoint-best/baseline}/generation_config.json RENAMED
File without changes
{reacc_generator/checkpoint-best → checkpoint-best/baseline}/model.safetensors RENAMED
File without changes
{reacc_generator/checkpoint-best → checkpoint-best/baseline}/tokenizer.json RENAMED
File without changes
{reacc_generator/checkpoint-best → checkpoint-best/baseline}/tokenizer_config.json RENAMED
File without changes
checkpoint-best/eol/config.json ADDED
@@ -0,0 +1,37 @@
+{
+  "_num_labels": 2,
+  "activation_function": "gelu_new",
+  "add_cross_attention": false,
+  "architectures": [
+    "GPT2LMHeadModel"
+  ],
+  "attn_pdrop": 0.1,
+  "bos_token_id": 0,
+  "dtype": "float32",
+  "embd_pdrop": 0.1,
+  "eos_token_id": 2,
+  "initializer_range": 0.02,
+  "layer_norm_epsilon": 1e-05,
+  "model_type": "gpt2",
+  "n_ctx": 1024,
+  "n_embd": 768,
+  "n_head": 12,
+  "n_inner": null,
+  "n_layer": 12,
+  "n_positions": 1024,
+  "output_past": true,
+  "pad_token_id": 1,
+  "reorder_and_upcast_attn": false,
+  "resid_pdrop": 0.1,
+  "scale_attn_by_inverse_layer_idx": false,
+  "scale_attn_weights": true,
+  "summary_activation": null,
+  "summary_first_dropout": 0.1,
+  "summary_proj_to_labels": true,
+  "summary_type": "cls_index",
+  "summary_use_proj": true,
+  "tie_word_embeddings": true,
+  "transformers_version": "5.0.0",
+  "use_cache": true,
+  "vocab_size": 50007
+}
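The config above declares a standard GPT2LMHeadModel, so the checkpoint should also load with plain transformers Auto classes; a sketch, independent of model_utils.load_model_and_tokenizer (whose implementation is not part of this commit), with paths relative to the repo root:

# Sanity-load sketch for the new eol checkpoint.
from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt = "checkpoint-best/eol"
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForCausalLM.from_pretrained(ckpt)
model.eval()

# Values from the config above: 12 layers, vocab_size 50007
print(model.config.n_layer, model.config.vocab_size)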
checkpoint-best/eol/generation_config.json ADDED
@@ -0,0 +1,7 @@
+{
+  "_from_model_config": true,
+  "bos_token_id": 0,
+  "eos_token_id": 2,
+  "pad_token_id": 1,
+  "transformers_version": "5.0.0"
+}
checkpoint-best/eol/model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a300ae726608e2d94265138f1ccce63e08122e979e90cdb728ec798f19f54c38
+size 497006208
checkpoint-best/eol/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
checkpoint-best/eol/tokenizer_config.json ADDED
@@ -0,0 +1,21 @@
+{
+  "add_prefix_space": false,
+  "backend": "tokenizers",
+  "bos_token": "<s>",
+  "eos_token": "</s>",
+  "errors": "replace",
+  "extra_special_tokens": [
+    "<RET>",
+    "</RET>",
+    "<CTX>",
+    "</CTX>",
+    "<GEN>"
+  ],
+  "full_tokenizer_file": null,
+  "is_local": false,
+  "model_max_length": 1000000000000000019884624838656,
+  "pad_token": "<pad>",
+  "sep_token": "<EOL>",
+  "tokenizer_class": "GPT2Tokenizer",
+  "unk_token": "<|UNKNOWN|>"
+}
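The extra special tokens suggest a retrieval-augmented prompt layout; a small inspection sketch, where the prompt template at the end is only a guess from the token names, not something this commit specifies:

# Sketch: inspect the ReACC special tokens on the eol tokenizer.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("checkpoint-best/eol")

print(tokenizer.sep_token)  # <EOL>
for tok in ["<RET>", "</RET>", "<CTX>", "</CTX>", "<GEN>"]:
    # Each should map to a single id rather than being split into pieces
    print(tok, tokenizer.convert_tokens_to_ids(tok))

# Hypothetical retrieval-augmented prompt layout inferred from the token names:
prompt = "<RET> def mul ( a , b ) : <EOL> return a * b </RET> <CTX> def add ( a , b ) : </CTX> <GEN>"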
reacc_generator/checkpoint-best/.gitkeep DELETED
@@ -1 +0,0 @@
-.gitkeep