TranTruongMMCII committed
Commit 8b5d27e · 1 Parent(s): b11a23b
app.py CHANGED
@@ -1,70 +1,139 @@
-import os
 import re
 import gc
+import hashlib
 from pathlib import Path
 
 import torch
 import gradio as gr
+from huggingface_hub import snapshot_download
 
 from model_utils import load_model_and_tokenizer, generate_completion
 
 
 # ============================================================
-# Path config
+# CONFIG
 # ============================================================
 
-BASE_DIR = Path(__file__).parent
+REMOTE_MODEL_REPO = "TranTruongMMCII/UIT.CS2229.Generator"
 
-MODEL_PATHS = {
-    "Generator - Baseline": BASE_DIR / "checkpoint-best" / "baseline",
-    "Generator - EOL": BASE_DIR / "checkpoint-best" / "eol",
+# Mapping: dropdown label → folder in the model repo
+MODEL_VARIANTS = {
+    "Generator - Baseline": "baseline",
+    "Generator - EOL": "eol",
 }
 
+# Behaviour at app startup
+PRE_DOWNLOAD_MODELS = True   # download the models into the cache right at startup
+WARMUP_DEFAULT_MODEL = True  # preload the baseline model into RAM
+DEFAULT_MODEL_NAME = "Generator - Baseline"
+
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
+
+# ============================================================
+# GLOBAL CACHE (SESSION-LIFETIME)
+# ============================================================
+
+_model_paths_cache = {}  # cache of already-downloaded model paths
 _current_model_name = None
 _current_tokenizer = None
 _current_model = None
+_current_model_path = None
 
 
 # ============================================================
-# Model loading
+# UTILS
 # ============================================================
 
-def get_model(model_name: str):
+def file_fingerprint(path: Path) -> str:
+    """Short SHA256 fingerprint to verify model identity."""
+    if not path.exists():
+        return "missing"
+
+    h = hashlib.sha256()
+    with open(path, "rb") as f:
+        for chunk in iter(lambda: f.read(1024 * 1024), b""):
+            h.update(chunk)
+    return h.hexdigest()[:16]
+
+
+def resolve_remote_model_path(model_name: str) -> Path:
     """
-    Lazily load selected model.
-    Only one model is kept in memory at a time.
+    Download model folder from remote HF model repo.
+    Download happens once per runtime and is cached.
     """
 
-    global _current_model_name, _current_tokenizer, _current_model
+    if model_name in _model_paths_cache:
+        return _model_paths_cache[model_name]
 
-    if model_name not in MODEL_PATHS:
+    if model_name not in MODEL_VARIANTS:
         raise ValueError(f"Unknown model option: {model_name}")
 
-    model_path = MODEL_PATHS[model_name]
+    variant = MODEL_VARIANTS[model_name]
+    remote_subdir = f"checkpoint-best/{variant}"
+
+    local_repo_dir = snapshot_download(
+        repo_id=REMOTE_MODEL_REPO,
+        repo_type="model",
+        allow_patterns=[f"{remote_subdir}/*"],
+    )
+
+    model_path = Path(local_repo_dir) / remote_subdir
 
     if not model_path.exists():
+        raise FileNotFoundError(f"Missing model folder: {model_path}")
+
+    if not (model_path / "model.safetensors").exists():
         raise FileNotFoundError(
-            f"Model path not found: {model_path}\n"
-            f"Expected structure: checkpoint-best/baseline and checkpoint-best/eol"
+            f"model.safetensors not found in {model_path}"
         )
 
-    # Reuse current loaded model
+    _model_paths_cache[model_name] = model_path
+    return model_path
+
+
+def preload_model_folders():
+    """Download all model folders into HF cache (no RAM load)."""
+    print("Pre-downloading model folders...")
+    for name in MODEL_VARIANTS:
+        try:
+            path = resolve_remote_model_path(name)
+            print(f"✔ Cached {name}: {path}")
+        except Exception as e:
+            print(f"⚠ Failed to preload {name}: {e}")
+
+
+# ============================================================
+# MODEL LOADING (RAM)
+# ============================================================
+
+def get_model(model_name: str):
+    """
+    Load selected model into RAM.
+    Only ONE model is kept in memory at a time.
+    """
+
+    global _current_model_name, _current_tokenizer, _current_model, _current_model_path
+
     if _current_model_name == model_name and _current_model is not None:
-        return _current_tokenizer, _current_model, model_path
+        return _current_tokenizer, _current_model, _current_model_path
 
-    # Unload old model if switching
+    # unload old model
     if _current_model is not None:
         del _current_model
         del _current_tokenizer
         _current_model = None
         _current_tokenizer = None
+        _current_model_path = None
         gc.collect()
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
 
-    print(f"Loading model: {model_name} from {model_path}")
+    model_path = resolve_remote_model_path(model_name)
+
+    print(f"Loading model: {model_name}")
+    print(f"Path: {model_path}")
+    print(f"SHA: {file_fingerprint(model_path / 'model.safetensors')}")
 
     tokenizer, model = load_model_and_tokenizer(str(model_path))
     model.to(device)
@@ -73,84 +142,37 @@ def get_model(model_name: str):
     _current_model_name = model_name
     _current_tokenizer = tokenizer
     _current_model = model
+    _current_model_path = model_path
 
     return tokenizer, model, model_path
 
 
 # ============================================================
-# Soft normalization adapters
+# SOFT NORMALIZATION
 # ============================================================
 
 def normalize_line(line: str) -> str:
-    """
-    Soft-normalize one line to be closer to training token style.
-    Example:
-        def add(a, b):
-    becomes:
-        def add ( a , b ) :
-    """
-
-    # Put spaces around common Python punctuation/operators
     line = re.sub(r"([()\[\]{}:,.=+\-*/<>])", r" \1 ", line)
-
-    # Collapse spaces
    line = re.sub(r"\s+", " ", line)
-
     return line.strip()
 
 
 def context_to_tokens(code: str) -> str:
-    """
-    Convert normal-looking code into training-style token text.
-
-    Important:
-    - Preserve line boundaries as <EOL>
-    - Do not fake <STR_LIT> / <NUM_LIT>
-    """
-
-    code = code.replace("\t", " ")
-    lines = code.splitlines()
-
-    normalized_lines = []
-    for line in lines:
-        norm = normalize_line(line)
-        if norm:
-            normalized_lines.append(norm)
-
-    return " <EOL> ".join(normalized_lines).strip()
+    lines = code.replace("\t", " ").splitlines()
+    tokens = [normalize_line(l) for l in lines if l.strip()]
+    return " <EOL> ".join(tokens)
 
 
 def tokens_to_readable(code: str) -> str:
-    """
-    Convert generated token text back to readable form.
-    This is demo-level detokenization, not a perfect Python formatter.
-    """
-
     code = code.replace("<EOL>", "\n")
-
-    # Remove spaces before punctuation
     code = re.sub(r"\s+([)\]\}:,])", r"\1", code)
-
-    # Remove spaces after opening punctuation
     code = re.sub(r"([(\[\{])\s+", r"\1", code)
-
-    # Compact common binary operators mildly
-    code = re.sub(r"\s*=\s*", " = ", code)
-    code = re.sub(r"\s*\+\s*", " + ", code)
-    code = re.sub(r"\s*-\s*", " - ", code)
-    code = re.sub(r"\s*\*\s*", " * ", code)
-    code = re.sub(r"\s*/\s*", " / ", code)
-    code = re.sub(r"\s*<\s*", " < ", code)
-    code = re.sub(r"\s*>\s*", " > ", code)
-
-    # Clean repeated spaces
-    code = re.sub(r"[ \t]+", " ", code)
-
+    code = re.sub(r"\s+", " ", code)
     return code.strip()
 
 
 # ============================================================
-# Inference
+# INFERENCE
 # ============================================================
 
 def run_demo(model_name: str, context: str):
@@ -158,13 +180,10 @@ def run_demo(model_name: str, context: str):
 
     token_context = context_to_tokens(context)
 
-    # No retriever for now
-    token_retrieved = ""
-
     token_output = generate_completion(
         model=model,
         tokenizer=tokenizer,
-        retrieved=token_retrieved,
+        retrieved="",
         context=token_context,
         device=device,
         max_length=256,
@@ -178,50 +197,56 @@
     logs = (
         "=== DEMO LOGS ===\n\n"
         f"[Selected model]\n{model_name}\n\n"
-        f"[Model path]\n{model_path}\n\n"
-        "[Raw Context]\n"
-        f"{context}\n\n"
+        f"[Model repo]\n{REMOTE_MODEL_REPO}\n\n"
+        f"[Local cache path]\n{model_path}\n\n"
+        f"[Model fingerprint]\n{file_fingerprint(model_path / 'model.safetensors')}\n\n"
+        f"[Device]\n{device}\n\n"
         "[Context → Tokens]\n"
         f"{token_context}\n\n"
-        "[Retrieved → Tokens]\n"
-        f"{token_retrieved}\n\n"
-        "[Generator Output → Tokens]\n"
-        f"{token_output}\n\n"
-        "[Prediction]\n"
-        f"{prediction}\n"
+        "[Output → Tokens]\n"
+        f"{token_output}\n"
     )
 
     return prediction, logs
 
 
 # ============================================================
-# Gradio UI
+# GRADIO UI
 # ============================================================
 
 demo = gr.Interface(
     fn=run_demo,
     inputs=[
         gr.Dropdown(
-            choices=["Generator - Baseline", "Generator - EOL"],
-            value="Generator - Baseline",
+            choices=list(MODEL_VARIANTS.keys()),
+            value=DEFAULT_MODEL_NAME,
             label="Model",
         ),
         gr.Textbox(
-            lines=12,
+            lines=10,
             label="Context",
-            placeholder="def add(a, b):\n return",
+            placeholder="def sum(a, b):\n return",
         ),
     ],
    outputs=[
-        gr.Textbox(lines=8, label="Prediction"),
-        gr.Textbox(lines=14, label="Logs"),
+        gr.Textbox(lines=6, label="Prediction"),
+        gr.Textbox(lines=16, label="Logs"),
    ],
     title="ReACC Generator Demo",
-    description=(
-        "Compare Generator baseline and Generator + EOL. "
-        "Retriever integration will be added later."
-    ),
+    description="Compare Generator Baseline vs Generator + EOL (model loaded from external HF repo).",
 )
 
+
+# ============================================================
+# STARTUP
+# ============================================================
+
+if PRE_DOWNLOAD_MODELS:
+    preload_model_folders()
+
+if WARMUP_DEFAULT_MODEL:
+    print(f"Warming up default model: {DEFAULT_MODEL_NAME}")
+    get_model(DEFAULT_MODEL_NAME)
+
 if __name__ == "__main__":
     demo.launch()
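Note: app.py imports load_model_and_tokenizer and generate_completion from model_utils, which this commit does not touch. Below is a hedged, minimal sketch of the interface those call sites imply; the parameter names (retrieved, context, device, max_length) come from app.py, but the AutoTokenizer/AutoModelForCausalLM bodies and the greedy decode are illustrative assumptions, not the repository's actual implementation.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def load_model_and_tokenizer(model_path: str):
    """Load a tokenizer and causal LM from a local checkpoint folder (hypothetical body)."""
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path)
    model.eval()
    return tokenizer, model


def generate_completion(model, tokenizer, retrieved, context, device, max_length=256):
    """Concatenate optional retrieved code with the context, then decode a completion (hypothetical body)."""
    prompt = f"{retrieved} {context}".strip() if retrieved else context
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(device)
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_length=max_length,
            pad_token_id=tokenizer.pad_token_id,
        )
    # Return only the newly generated tokens as token-style text (still contains <EOL>).
    new_tokens = output_ids[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=False).strip()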
checkpoint-best/baseline/config.json DELETED
@@ -1,37 +0,0 @@
-{
-  "_num_labels": 2,
-  "activation_function": "gelu_new",
-  "add_cross_attention": false,
-  "architectures": [
-    "GPT2LMHeadModel"
-  ],
-  "attn_pdrop": 0.1,
-  "bos_token_id": 0,
-  "dtype": "float32",
-  "embd_pdrop": 0.1,
-  "eos_token_id": 2,
-  "initializer_range": 0.02,
-  "layer_norm_epsilon": 1e-05,
-  "model_type": "gpt2",
-  "n_ctx": 1024,
-  "n_embd": 768,
-  "n_head": 12,
-  "n_inner": null,
-  "n_layer": 12,
-  "n_positions": 1024,
-  "output_past": true,
-  "pad_token_id": 1,
-  "reorder_and_upcast_attn": false,
-  "resid_pdrop": 0.1,
-  "scale_attn_by_inverse_layer_idx": false,
-  "scale_attn_weights": true,
-  "summary_activation": null,
-  "summary_first_dropout": 0.1,
-  "summary_proj_to_labels": true,
-  "summary_type": "cls_index",
-  "summary_use_proj": true,
-  "tie_word_embeddings": true,
-  "transformers_version": "5.0.0",
-  "use_cache": true,
-  "vocab_size": 50007
-}
checkpoint-best/baseline/generation_config.json DELETED
@@ -1,7 +0,0 @@
-{
-  "_from_model_config": true,
-  "bos_token_id": 0,
-  "eos_token_id": 2,
-  "pad_token_id": 1,
-  "transformers_version": "5.0.0"
-}
checkpoint-best/baseline/tokenizer.json DELETED
The diff for this file is too large to render. See raw diff
 
checkpoint-best/baseline/tokenizer_config.json DELETED
@@ -1,21 +0,0 @@
-{
-  "add_prefix_space": false,
-  "backend": "tokenizers",
-  "bos_token": "<s>",
-  "eos_token": "</s>",
-  "errors": "replace",
-  "extra_special_tokens": [
-    "<RET>",
-    "</RET>",
-    "<CTX>",
-    "</CTX>",
-    "<GEN>"
-  ],
-  "full_tokenizer_file": null,
-  "is_local": false,
-  "model_max_length": 1000000000000000019884624838656,
-  "pad_token": "<pad>",
-  "sep_token": "<EOL>",
-  "tokenizer_class": "GPT2Tokenizer",
-  "unk_token": "<|UNKNOWN|>"
-}
checkpoint-best/eol/config.json DELETED
@@ -1,37 +0,0 @@
-{
-  "_num_labels": 2,
-  "activation_function": "gelu_new",
-  "add_cross_attention": false,
-  "architectures": [
-    "GPT2LMHeadModel"
-  ],
-  "attn_pdrop": 0.1,
-  "bos_token_id": 0,
-  "dtype": "float32",
-  "embd_pdrop": 0.1,
-  "eos_token_id": 2,
-  "initializer_range": 0.02,
-  "layer_norm_epsilon": 1e-05,
-  "model_type": "gpt2",
-  "n_ctx": 1024,
-  "n_embd": 768,
-  "n_head": 12,
-  "n_inner": null,
-  "n_layer": 12,
-  "n_positions": 1024,
-  "output_past": true,
-  "pad_token_id": 1,
-  "reorder_and_upcast_attn": false,
-  "resid_pdrop": 0.1,
-  "scale_attn_by_inverse_layer_idx": false,
-  "scale_attn_weights": true,
-  "summary_activation": null,
-  "summary_first_dropout": 0.1,
-  "summary_proj_to_labels": true,
-  "summary_type": "cls_index",
-  "summary_use_proj": true,
-  "tie_word_embeddings": true,
-  "transformers_version": "5.0.0",
-  "use_cache": true,
-  "vocab_size": 50007
-}
checkpoint-best/eol/generation_config.json DELETED
@@ -1,7 +0,0 @@
-{
-  "_from_model_config": true,
-  "bos_token_id": 0,
-  "eos_token_id": 2,
-  "pad_token_id": 1,
-  "transformers_version": "5.0.0"
-}
checkpoint-best/eol/tokenizer.json DELETED
The diff for this file is too large to render. See raw diff
 
checkpoint-best/eol/tokenizer_config.json DELETED
@@ -1,21 +0,0 @@
-{
-  "add_prefix_space": false,
-  "backend": "tokenizers",
-  "bos_token": "<s>",
-  "eos_token": "</s>",
-  "errors": "replace",
-  "extra_special_tokens": [
-    "<RET>",
-    "</RET>",
-    "<CTX>",
-    "</CTX>",
-    "<GEN>"
-  ],
-  "full_tokenizer_file": null,
-  "is_local": false,
-  "model_max_length": 1000000000000000019884624838656,
-  "pad_token": "<pad>",
-  "sep_token": "<EOL>",
-  "tokenizer_class": "GPT2Tokenizer",
-  "unk_token": "<|UNKNOWN|>"
-}
requirements.txt CHANGED
@@ -1,5 +1,6 @@
-torch
-transformers
-gradio
-tqdm
-numpy
+torch
+transformers
+gradio
+tqdm
+numpy
+huggingface_hub
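huggingface_hub is the only new dependency; app.py uses its snapshot_download with allow_patterns to fetch a single checkpoint folder instead of the deleted local files. A minimal standalone sketch of that call, reusing the repo id and folder layout from app.py (the print is illustrative):

from pathlib import Path

from huggingface_hub import snapshot_download

# Download only the baseline checkpoint folder into the local HF cache.
local_repo_dir = snapshot_download(
    repo_id="TranTruongMMCII/UIT.CS2229.Generator",
    repo_type="model",
    allow_patterns=["checkpoint-best/baseline/*"],
)
print(Path(local_repo_dir) / "checkpoint-best" / "baseline")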
retriever_stub.py DELETED
@@ -1,19 +0,0 @@
-def retrieve_code_stub(context: str) -> str:
-    """
-    Mock retriever for demo purposes.
-    Later, replace this with real retriever logic.
-    """
-
-    # Simple heuristic demo (hardcoded or rule-based)
-    if "pytest" in context:
-        return (
-            "def data(): <EOL>"
-            " tmpdir = py.test.ensuretemp('<STR_LIT>') <EOL>"
-            " return tmpdir"
-        )
-
-    if "def add" in context:
-        return "def add(a, b): <EOL> return a + b"
-
-    # default fallback
-    return ""