Don Rishabh committed
Commit 7d8d47c · 1 Parent(s): 7ca042f

demo: apply chat template to target (fix rambling completion-mode outputs)

Files changed (2)
  1. space-demo/app.py +19 -1
  2. ui/demo_app.py +190 -7
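
The core of the fix: stop feeding the target `f"{prompt}\n\n{test_input}"` as a raw completion and wrap the pair in the tokenizer's chat template instead. A minimal sketch of the two input shapes, assuming the demo's default target checkpoint is available (it is gated on the Hub; any chat-tuned tokenizer behaves the same way):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")

    prompt, test_input = "Reverse each word.", "hello world"

    # Before: a bare document. A chat-tuned model tends to keep
    # "completing" it (e.g. inventing more few-shot pairs) rather
    # than answering the test input.
    before = f"{prompt}\n\n{test_input}"

    # After: role-tagged turns plus the assistant header, so decoding
    # starts exactly where an answer belongs.
    after = tok.apply_chat_template(
        [{"role": "system", "content": prompt},
         {"role": "user", "content": test_input}],
        tokenize=False,
        add_generation_prompt=True,
    )
    print(after)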
space-demo/app.py CHANGED
@@ -132,6 +132,24 @@ def load_target() -> None:
           flush=True)
 
 
+def _build_target_chat(prompt: str, test_input: str) -> str:
+    """Apply the target's chat template: prompt as system, test_input as user.
+
+    Llama-3.2-3B-Instruct (and any chat-tuned target) needs this — feeding
+    raw `prompt\\n\\ntest_input` makes it ramble in completion mode.
+    """
+    messages = [
+        {"role": "system", "content": prompt},
+        {"role": "user", "content": test_input},
+    ]
+    if getattr(_TOK, "chat_template", None):
+        return _TOK.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True,
+        )
+    # Fallback for non-chat tokenizers
+    return f"{prompt}\n\n{test_input}\n\nAssistant:"
+
+
 @torch.inference_mode()
 def run_target_batch(prompts: List[str], test_input: str) -> List[str]:
     load_target()
@@ -139,7 +157,7 @@ def run_target_batch(prompts: List[str], test_input: str) -> List[str]:
     keep_idx = []
     for i, p in enumerate(prompts):
         if p and p.strip():
-            full_texts.append(f"{p}\n\n{test_input}".strip())
+            full_texts.append(_build_target_chat(p, test_input))
             keep_idx.append(i)
     if not full_texts:
         return ["" for _ in prompts]
ui/demo_app.py CHANGED
@@ -47,17 +47,31 @@ DEFAULTS = {
     "target_model": os.environ.get(
         "DEMO_TARGET_MODEL", "meta-llama/Llama-3.2-3B-Instruct"
     ),
+    # Agent model the trained adapter is built on.
+    "agent_model": os.environ.get(
+        "DEMO_AGENT_MODEL", "Qwen/Qwen3-1.7B"
+    ),
+    # Trained LoRA adapter (HF repo or local path). If empty, the
+    # "regenerate live" feature stays disabled.
+    "agent_adapter": os.environ.get(
+        "DEMO_AGENT_ADAPTER",
+        "rishabh16196/prompt-golf-qwen-to-llama-nothink",
+    ),
     # CSV produced by training/build_before_after_csv.py
     "demo_csv": os.environ.get(
-        "DEMO_CSV", str(_REPO_ROOT / "outputs" / "qwen_to_qwen_demo.csv")
+        "DEMO_CSV",
+        str(_REPO_ROOT / "outputs" / "qwen_to_llama_demo.csv"),
     ),
     # If the CSV isn't local, you can pull it from the hub:
     "fallback_csv_url": (
-        "https://huggingface.co/rishabh16196/prompt-golf-grpo-1.5b/"
-        "resolve/main/evals/qwen_to_qwen_demo.csv"
+        "https://huggingface.co/rishabh16196/prompt-golf-qwen-to-llama-nothink/"
+        "resolve/main/evals/qwen_to_llama_demo.csv"
     ),
-    "max_new_tokens": 64,
+    "max_new_tokens": 64,  # target output cap
+    "agent_max_new_tokens": 256,  # agent generation cap (no thinking)
     "temperature": 0.0,
+    # Match the chat template used when the adapter was trained.
+    "enable_thinking": False,
 }
 
 
@@ -104,6 +118,11 @@ _TOK = None
 _MODEL = None
 _DEVICE = None
 
+# --- Agent (untrained base + trained-adapter) singletons ---
+_AGENT_TOK = None
+_AGENT_BASE = None  # raw Qwen3-1.7B
+_AGENT_TRAINED = None  # PeftModel(Qwen3-1.7B, LoRA)
+
 
 def _device() -> str:
     if torch.cuda.is_available():
@@ -141,6 +160,24 @@ def load_target() -> None:
           flush=True)
 
 
+def _build_target_chat(prompt: str, test_input: str) -> str:
+    """Apply the target's chat template: prompt as system, test_input as user.
+
+    Chat-tuned targets (Llama-3.2-3B-Instruct, Qwen3-1.7B chat, etc.)
+    will ramble in completion mode if you feed them raw text — they try
+    to continue the few-shot pattern in the prompt instead of answering.
+    """
+    messages = [
+        {"role": "system", "content": prompt},
+        {"role": "user", "content": test_input},
+    ]
+    if getattr(_TOK, "chat_template", None):
+        return _TOK.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True,
+        )
+    return f"{prompt}\n\n{test_input}\n\nAssistant:"
+
+
 @torch.inference_mode()
 def run_target_batch(prompts: List[str], test_input: str) -> List[str]:
     """Run the target on (prompt[i] + test_input) for all i, in one batched
@@ -150,12 +187,11 @@ def run_target_batch(prompts: List[str], test_input: str) -> List[str]:
     incurring inference cost.
     """
     load_target()
-    # Build the full prompts; track which positions are non-empty.
     full_texts = []
     keep_idx = []
     for i, p in enumerate(prompts):
         if p and p.strip():
-            full_texts.append(f"{p}\n\n{test_input}".strip())
+            full_texts.append(_build_target_chat(p, test_input))
             keep_idx.append(i)
     if not full_texts:
         return ["" for _ in prompts]
@@ -197,6 +233,140 @@ def count_tokens(text: str) -> int:
     return len(_TOK.encode(text or "", add_special_tokens=False))
 
 
+# ---------------------------------------------------------------------------
+# Agent loader (lazy — only loaded if the user clicks "Regenerate live")
+# ---------------------------------------------------------------------------
+
+def load_agents() -> bool:
+    """Load Qwen3-1.7B base + LoRA-adapted variant. Returns True on success."""
+    global _AGENT_TOK, _AGENT_BASE, _AGENT_TRAINED
+    if _AGENT_TRAINED is not None:
+        return True
+    if not DEFAULTS.get("agent_adapter"):
+        return False
+    name = DEFAULTS["agent_model"]
+    adapter = DEFAULTS["agent_adapter"]
+    print(f"[demo] loading agent {name} + adapter {adapter}...", flush=True)
+    t0 = time.time()
+    _AGENT_TOK = AutoTokenizer.from_pretrained(name)
+    _AGENT_TOK.padding_side = "left"
+    if _AGENT_TOK.pad_token is None:
+        _AGENT_TOK.pad_token = _AGENT_TOK.eos_token
+    dev = _device()
+    dtype = torch.bfloat16 if dev in ("cuda", "mps") else torch.float32
+    _AGENT_BASE = AutoModelForCausalLM.from_pretrained(
+        name, torch_dtype=dtype,
+        device_map="auto" if dev == "cuda" else None,
+    )
+    if dev != "cuda":
+        _AGENT_BASE = _AGENT_BASE.to(dev)
+    _AGENT_BASE.eval()
+
+    from peft import PeftModel
+    # Load adapter on TOP of a SECOND copy of the base (so we keep the raw
+    # base for "untrained" generations).
+    base_for_adapter = AutoModelForCausalLM.from_pretrained(
+        name, torch_dtype=dtype,
+        device_map="auto" if dev == "cuda" else None,
+    )
+    if dev != "cuda":
+        base_for_adapter = base_for_adapter.to(dev)
+    _AGENT_TRAINED = PeftModel.from_pretrained(base_for_adapter, adapter)
+    _AGENT_TRAINED.eval()
+    print(f"[demo] agents loaded in {time.time()-t0:.1f}s", flush=True)
+    return True
+
+
+def _build_synthetic_obs(task_id: str):
+    """Look up task spec from the bank and return an obs-like SimpleNamespace
+    that build_agent_user_message can format."""
+    from types import SimpleNamespace
+    from prompt_golf_env.server.tasks import TASKS
+    from prompt_golf_env.server.tasks_v2 import TASKS_V2
+    from prompt_golf_env.server.tasks_tough import TASKS_TOUGH
+    from prompt_golf_env.server.tasks_policy import TASKS_POLICY
+    bank = {**TASKS, **TASKS_V2, **TASKS_TOUGH, **TASKS_POLICY}
+    spec = bank.get(task_id)
+    if spec is None:
+        return None
+    # Use first 3 train_examples as the visible block (matches env default)
+    train_ex = [
+        {"input": x, "expected": y} for (x, y) in spec.train_examples[:3]
+    ]
+    return SimpleNamespace(
+        task_id=task_id,
+        task_category=spec.category,
+        task_description=spec.description,
+        target_model_id=DEFAULTS["target_model"],
+        prompt_budget_tokens=spec.budget_tokens,
+        baseline_zero_shot_score=0.0,  # unknown without env.reset
+        train_examples=train_ex,
+        prior_attempts=[],
+    )
+
+
+@torch.inference_mode()
+def _agent_generate(model, tok, chat_str: str, max_new_tokens: int) -> str:
+    enc = tok(chat_str, return_tensors="pt").to(_device())
+    out = model.generate(
+        **enc,
+        max_new_tokens=max_new_tokens,
+        do_sample=False,
+        temperature=1.0,
+        pad_token_id=tok.pad_token_id,
+    )
+    new_ids = out[0][enc["input_ids"].shape[1]:]
+    return tok.decode(new_ids, skip_special_tokens=True).strip()
+
+
+def regenerate_live(label_with_tag: str):
+    """For the currently-selected task, ask both agents to write a fresh
+    prompt. Returns (base_prompt, trained_prompt, status_msg).
+    """
+    if not label_with_tag:
+        return "", "", "(no task selected)"
+    if not load_agents():
+        return "", "", ("agent loading disabled — set DEMO_AGENT_ADAPTER "
+                        "env var to enable live regeneration")
+
+    # Lazy import — these live in the training/ subdir
+    from training.train_grpo import (
+        build_chat_prompt, extract_prompt,
+    )
+
+    tid = label_with_tag.split()[0]
+    obs = _build_synthetic_obs(tid)
+    if obs is None:
+        return "", "", f"unknown task: {tid}"
+
+    chat_str = build_chat_prompt(
+        _AGENT_TOK, obs,
+        enable_thinking=DEFAULTS["enable_thinking"],
+    )
+
+    t0 = time.time()
+    raw_base = _agent_generate(
+        _AGENT_BASE, _AGENT_TOK, chat_str,
+        max_new_tokens=DEFAULTS["agent_max_new_tokens"],
+    )
+    t1 = time.time()
+    raw_trained = _agent_generate(
+        _AGENT_TRAINED, _AGENT_TOK, chat_str,
+        max_new_tokens=DEFAULTS["agent_max_new_tokens"],
+    )
+    t2 = time.time()
+
+    base_prompt = extract_prompt(raw_base)
+    trained_prompt = extract_prompt(raw_trained)
+    msg = (
+        f"agents regenerated in {t2-t0:.1f}s "
+        f"(base {t1-t0:.1f}s, trained {t2-t1:.1f}s) | "
+        f"base: {count_tokens(base_prompt)} tok, "
+        f"trained: {count_tokens(trained_prompt)} tok"
+    )
+    return base_prompt, trained_prompt, msg
+
+
 # ---------------------------------------------------------------------------
 # Gradio handlers
 # ---------------------------------------------------------------------------
@@ -335,7 +505,15 @@ def build_app() -> gr.Blocks:
                      "should be applied to."),
         )
 
-        run_btn = gr.Button("Run target with all three prompts", variant="primary")
+        with gr.Row():
+            regen_btn = gr.Button(
+                "Regenerate prompts live (loads agent + LoRA on first click)",
+                variant="secondary",
+            )
+            run_btn = gr.Button(
+                "Run target with all three prompts", variant="primary"
+            )
+        regen_status = gr.Textbox(label="agent status", interactive=False)
 
         with gr.Row():
            with gr.Column():
@@ -366,6 +544,11 @@
             v_tok, b_tok, t_tok, v_acc, b_acc, t_acc, test_input,
         ]
         task_dd.change(select_task, inputs=[task_dd], outputs=select_outputs)
+        regen_btn.click(
+            regenerate_live,
+            inputs=[task_dd],
+            outputs=[base_box, trained_box, regen_status],
+        )
         run_btn.click(
             generate_three,
             inputs=[verbose_box, base_box, trained_box, test_input],
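
One design note on `load_agents`: it holds two full copies of Qwen3-1.7B in memory so the untrained baseline is never touched by the adapter. Since LoRA leaves the base weights intact unless explicitly merged, peft's `disable_adapter()` context manager could serve both generations from a single copy. A sketch of that alternative, reusing the commit's default identifiers:

    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer

    name = "Qwen/Qwen3-1.7B"
    adapter = "rishabh16196/prompt-golf-qwen-to-llama-nothink"

    tok = AutoTokenizer.from_pretrained(name)
    base = AutoModelForCausalLM.from_pretrained(name, torch_dtype=torch.bfloat16)
    agent = PeftModel.from_pretrained(base, adapter).eval()

    enc = tok("write a prompt for the selected task", return_tensors="pt")
    with torch.inference_mode():
        # Adapter active: the trained agent.
        trained_ids = agent.generate(**enc, max_new_tokens=64, do_sample=False)
        # Adapter bypassed: behaves as the raw base model.
        with agent.disable_adapter():
            base_ids = agent.generate(**enc, max_new_tokens=64, do_sample=False)

The trade-off: one copy halves memory but forces the base and trained generations to share a model and run strictly in sequence, whereas the commit's two-copy approach keeps the two singletons fully independent.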