aamrinder committed (verified)
Commit a008aa6 · Parent: 8f39abd

Upload folder using huggingface_hub

Files changed (2):
  1. train/curate_pivot_set.py +29 -16
  2. train/sft_warmup.py +329 -0
train/curate_pivot_set.py CHANGED
@@ -63,12 +63,19 @@ def cmd_baseline(args: argparse.Namespace) -> None:
     )
     model.eval()
 
+    # Neutral prompt: avoid the word "sarcasm" before the model has answered,
+    # list "sincere" before "sarcastic" to fight the prefix bias of Qwen2.5-3B.
     PROMPT = (
-        "You are an expert at detecting sarcasm in dialogue.\n"
-        "Read the conversational context and the target line.\n"
-        "Answer with EXACTLY one of: sarcastic | sincere\n"
-        "Then on the next line, your confidence as a float 0..1.\n\n"
-        "{ctx}\n\nTarget:\n[{spk}] {utt}\n\nAnswer:"
+        "You will read a line of TV dialogue with its conversational context.\n"
+        "Decide whether the speaker is being sincere (means what they say) "
+        "or sarcastic (means the opposite of what they say).\n\n"
+        "{ctx}\n\n"
+        "Target line:\n[{spk}] {utt}\n\n"
+        "Output exactly two lines, in this format:\n"
+        "Label: sincere\n"
+        "Confidence: 0.7\n\n"
+        "Now classify the target line above.\n"
+        "Output:\n"
     )
 
     out: Dict[str, Dict] = {}
@@ -95,19 +102,25 @@ def cmd_baseline(args: argparse.Namespace) -> None:
         )
         text = tok.decode(o[0][ids.input_ids.shape[1]:], skip_special_tokens=True).strip()
 
-        first_word = text.lower().split()[0] if text.strip() else ""
-        pred_label = "sarcastic" if "sarc" in first_word else "sincere"
-        # Try to extract confidence
+        # Parse "Label: X\nConfidence: Y" format
+        pred_label = "sincere"  # default to sincere if parsing fails (less biased)
         conf = 0.5
         for line in text.splitlines():
-            line = line.strip()
-            try:
-                v = float(line)
-                if 0.0 <= v <= 1.0:
-                    conf = v
-                    break
-            except ValueError:
-                continue
+            stripped = line.strip().lower()
+            if stripped.startswith("label:"):
+                value = stripped[len("label:"):].strip()
+                if "sarc" in value:
+                    pred_label = "sarcastic"
+                elif "sinc" in value:
+                    pred_label = "sincere"
+            elif stripped.startswith("confidence:"):
+                value = stripped[len("confidence:"):].strip()
+                try:
+                    v = float(value)
+                    if 0.0 <= v <= 1.0:
+                        conf = v
+                except ValueError:
+                    pass
         out[clip_id] = {
             "predicted": pred_label,
            "confidence": conf,
train/sft_warmup.py ADDED
@@ -0,0 +1,329 @@
+"""SFT warmup: bootstrap the format + reasoning skeleton before GRPO.
+
+Why: GRPO from a base model spends ~50-100 steps just learning to emit
+`<think>...</think><final>{...}</final>`. SFT on ~100 ideal completions
+gets the format perfect upfront so all GRPO steps focus on improving
+correctness.
+
+How: we synthesize "ideal" completions deterministically from the gold label
++ actual prosody features. No API call. The reasoning text references the
+real prosody numbers. The final tag uses the gold label.
+
+This is a min-cost run that:
+- Generates 100 (prompt, ideal) pairs locally
+- SFTs Qwen2.5-3B + LoRA for 1 epoch (~5-10 min on L4)
+- Saves the LoRA checkpoint to HF Hub
+- Prints 3 before/after sample completions for visual inspection
+
+Output: LoRA adapter pushed to HF Hub at aamrinder/subtext-arena-sft
+(then GRPO Run #1 starts FROM this adapter, not from vanilla Qwen)
+"""
+from __future__ import annotations
+
+import json
+import random
+import sys
+from pathlib import Path
+from typing import Dict
+
+ROOT = Path(__file__).resolve().parent.parent
+if str(ROOT) not in sys.path:
+    sys.path.insert(0, str(ROOT))
+
+try:
+    from subtext_arena.server.scenarios import load_scenarios
+    from subtext_arena.train.train_grpo import (
+        SYSTEM_PROMPT, build_full_observation,
+    )
+except ImportError:
+    from server.scenarios import load_scenarios
+    from train.train_grpo import SYSTEM_PROMPT, build_full_observation
+
+
+# ---------------------------------------------------------------------------
+# Generate ideal completions deterministically from gold + prosody
+# ---------------------------------------------------------------------------
+
+POSITIVE_WORDS = {
+    "great", "wonderful", "amazing", "perfect", "love", "brilliant",
+    "privilege", "lovely", "fantastic", "excellent", "fabulous", "delighted",
+    "thrilled", "grateful", "honored", "blessed",
+}
+
+
+def synth_ideal_completion(scenario: Dict) -> str:
+    """Build a high-quality (think + final) target completion for SFT.
+
+    Reasoning is grounded in the actual prosody numbers and gold label.
+    """
+    gold = "sarcastic" if scenario["sarcasm"] else "sincere"
+    prosody = scenario.get("prosody") or {}
+    utt = scenario.get("utterance", "").lower()
+
+    pitch_var = float(prosody.get("pitch_var_hz", 0.0))
+    pre_pause = int(prosody.get("pre_pause_ms", 0))
+    pauses = prosody.get("pauses", []) or []
+    voiced_ratio = float(prosody.get("voiced_ratio", 0.0))
+    n_internal = len(pauses)
+
+    parts = []
+
+    # 1) Lexical observation
+    matched = [w for w in utt.split() if w.strip(".,!?'\"") in POSITIVE_WORDS]
+    pos_count = len(matched)
+    if pos_count > 0:
+        parts.append(
+            f"The literal lexical content is positive ({pos_count} positive word"
+            f"{'s' if pos_count > 1 else ''}: {', '.join(matched)[:120]})."
+        )
+    elif "?" in utt:
+        parts.append("The line is phrased as a question.")
+    else:
+        parts.append("The lexical content is neutral or descriptive.")
+
+    # 2) Prosody observation (only if features are reliable)
+    if voiced_ratio < 0.1:
+        parts.append(
+            "Prosody is unreliable for this clip (low voiced-frame ratio). "
+            "Lexical and contextual cues should dominate."
+        )
+        prosody_evidence = "weak"
+    else:
+        if pitch_var > 45:
+            parts.append(
+                f"Pitch variability is HIGH ({pitch_var:.0f} Hz over voiced frames), "
+                "suggesting exaggerated melodic delivery."
+            )
+            prosody_evidence = "exaggerated"
+        elif pitch_var < 25:
+            parts.append(
+                f"Pitch variability is LOW ({pitch_var:.0f} Hz), suggesting "
+                "flat or minimally inflected delivery."
+            )
+            prosody_evidence = "flat"
+        else:
+            parts.append(
+                f"Pitch variability is moderate ({pitch_var:.0f} Hz), neither "
+                "flat nor exaggerated."
+            )
+            prosody_evidence = "moderate"
+
+    if pre_pause >= 250:
+        parts.append(
+            f"There is a {pre_pause}ms pre-utterance pause — speakers often "
+            "use such setup pauses for ironic or emphatic delivery."
+        )
+    if n_internal >= 1:
+        parts.append(
+            f"There {'is' if n_internal == 1 else 'are'} {n_internal} internal pause"
+            f"{'' if n_internal == 1 else 's'} >150ms within the utterance."
+        )
+
+    # 3) Conclusion grounded in the evidence
+    if gold == "sarcastic":
+        if pos_count > 0 and prosody_evidence == "exaggerated":
+            parts.append(
+                "Positive lexical content combined with exaggerated melodic "
+                "delivery is the signature pattern of sarcastic delivery — "
+                "the words say one thing, the prosody says the opposite."
+            )
+        elif prosody_evidence == "exaggerated":
+            parts.append(
+                "Exaggerated prosodic shape on otherwise non-emphatic content "
+                "is consistent with mock or ironic delivery."
+            )
+        else:
+            parts.append(
+                "Subtle cues taken together (delivery, emphasis pause, "
+                "context) suggest the speaker is being ironic rather than literal."
+            )
+    else:
+        if prosody_evidence == "flat":
+            parts.append(
+                "Flat prosodic delivery on neutral or genuine content "
+                "indicates the speaker means what they say."
+            )
+        elif pos_count > 0 and prosody_evidence != "exaggerated":
+            parts.append(
+                "Positive lexical content paired with non-exaggerated delivery "
+                "indicates sincere expression."
+            )
+        else:
+            parts.append(
+                "Lacking strong markers of irony, the speaker appears to be "
+                "expressing genuine intent."
+            )
+
+    think_text = " ".join(parts)
+    final_json = json.dumps({"label": gold, "confidence": 0.85}, separators=(",", ":"))
+    return f"<think>\n{think_text}\n</think>\n<final>{final_json}</final>"
+
+
+def build_sft_dataset(scenarios, n_rows: int = 100, seed: int = 0):
+    """Pick n_rows clips, build (prompt, ideal_completion) pairs."""
+    from datasets import Dataset
+
+    rng = random.Random(seed)
+    # Balance classes in the SFT set
+    sarc_ids = [k for k, v in scenarios.items() if v["sarcasm"]]
+    sinc_ids = [k for k, v in scenarios.items() if not v["sarcasm"]]
+    rng.shuffle(sarc_ids)
+    rng.shuffle(sinc_ids)
+    chosen = sarc_ids[: n_rows // 2] + sinc_ids[: n_rows - n_rows // 2]
+    rng.shuffle(chosen)
+
+    rows = []
+    for cid in chosen:
+        sc = scenarios[cid]
+        user_text = build_full_observation(cid, scenarios)
+        ideal = synth_ideal_completion(sc)
+        # Use the chat-completion format — Qwen2.5-Instruct expects this
+        rows.append({
+            "messages": [
+                {"role": "system", "content": SYSTEM_PROMPT},
+                {"role": "user", "content": user_text},
+                {"role": "assistant", "content": ideal},
+            ],
+            "clip_id": cid,
+            "gold": "sarcastic" if sc["sarcasm"] else "sincere",
+        })
+    return Dataset.from_list(rows)
+
+
+def sample_before_after(model, tokenizer, scenarios, sample_clip_ids, label_for_log: str):
+    """Generate completions on a few held-out clips for visual inspection."""
+    print(f"\n----- Sample completions ({label_for_log}) -----")
+    for cid in sample_clip_ids:
+        sc = scenarios[cid]
+        gold = "sarcastic" if sc["sarcasm"] else "sincere"
+        messages = [
+            {"role": "system", "content": SYSTEM_PROMPT},
+            {"role": "user", "content": build_full_observation(cid, scenarios)},
+        ]
+        inputs = tokenizer.apply_chat_template(
+            messages, return_tensors="pt", add_generation_prompt=True
+        ).to(model.device)
+        out = model.generate(
+            inputs, max_new_tokens=350, do_sample=True, temperature=0.7,
+            pad_token_id=tokenizer.eos_token_id,
+        )
+        text = tokenizer.decode(out[0][inputs.shape[1]:], skip_special_tokens=True)
+        print(f"\nClip {cid} (gold={gold}, speaker={sc.get('speaker')}):")
+        print(text[:1000])
+        print("---")
+
+
+# ---------------------------------------------------------------------------
+# Main
+# ---------------------------------------------------------------------------
+
+def main():
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", default="Qwen/Qwen2.5-3B-Instruct")
+    parser.add_argument("--n-rows", type=int, default=100)
+    parser.add_argument("--epochs", type=int, default=1)
+    parser.add_argument("--lora-r", type=int, default=8)
+    parser.add_argument("--learning-rate", type=float, default=2e-4)
+    parser.add_argument("--output-dir", default="/tmp/sft_warmup")
+    parser.add_argument("--push-to-hub", default=None,
+                        help="If set, e.g. 'aamrinder/subtext-arena-sft', push the LoRA there")
+    parser.add_argument("--n-sample-clips", type=int, default=3,
+                        help="How many clips to generate before/after samples on")
+    args = parser.parse_args()
+
+    print("[load-scenarios]")
+    scenarios = load_scenarios()
+    print(f"  {len(scenarios)} clips")
+
+    print(f"[build-sft-dataset] n_rows={args.n_rows}")
+    ds = build_sft_dataset(scenarios, n_rows=args.n_rows)
+    print(f"  {len(ds)} (prompt, ideal-completion) pairs")
+    print("  first ideal completion preview:")
+    first_msgs = ds[0]["messages"]
+    print("    " + first_msgs[-1]["content"].replace("\n", "\n    ")[:400])
+
+    # Pick held-out clips for before/after sampling
+    sample_ids = [k for k in scenarios.keys() if k not in {r["clip_id"] for r in ds}][: args.n_sample_clips]
+
+    # Load model
+    print(f"\n[load-model] {args.model}, 4-bit + LoRA")
+    import torch as _t
+    from transformers import (
+        AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig,
+    )
+    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+    from trl import SFTTrainer, SFTConfig
+
+    bnb = BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_compute_dtype=_t.bfloat16,
+        bnb_4bit_quant_type="nf4",
+        bnb_4bit_use_double_quant=True,
+    )
+    tokenizer = AutoTokenizer.from_pretrained(args.model)
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+    base = AutoModelForCausalLM.from_pretrained(
+        args.model, quantization_config=bnb, dtype=_t.bfloat16, device_map="auto",
+    )
+    base = prepare_model_for_kbit_training(base, use_gradient_checkpointing=True)
+    peft_config = LoraConfig(
+        r=args.lora_r, lora_alpha=args.lora_r, lora_dropout=0.0, bias="none",
+        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
+        task_type="CAUSAL_LM",
+    )
+    model = get_peft_model(base, peft_config)
+
+    # Sample BEFORE training
+    sample_before_after(model, tokenizer, scenarios, sample_ids, "BEFORE SFT")
+
+    # SFT training
+    print(f"\n[sft-train] {args.epochs} epoch(s), lr={args.learning_rate}")
+    config = SFTConfig(
+        output_dir=args.output_dir,
+        num_train_epochs=args.epochs,
+        per_device_train_batch_size=2,
+        gradient_accumulation_steps=4,
+        learning_rate=args.learning_rate,
+        bf16=True,
+        gradient_checkpointing=True,
+        logging_steps=2,
+        save_strategy="no",
+        report_to="none",
+        max_length=4096,
+    )
+    trainer = SFTTrainer(
+        model=model,
+        args=config,
+        train_dataset=ds,
+        processing_class=tokenizer,
+    )
+    trainer.train()
+    trainer.save_model(args.output_dir)
+    print(f"\n[done] LoRA adapter saved to {args.output_dir}")
+
+    # Sample AFTER training
+    sample_before_after(model, tokenizer, scenarios, sample_ids, "AFTER SFT")
+
+    # Optional: push to HF Hub
+    if args.push_to_hub:
+        from huggingface_hub import HfApi
+        api = HfApi()
+        # Create the repo first (idempotent)
+        try:
+            api.create_repo(repo_id=args.push_to_hub, repo_type="model", exist_ok=True)
+        except Exception as e:
+            print(f"[warn] create_repo: {e}")
+        # Upload the LoRA adapter directory
+        api.upload_folder(
+            folder_path=args.output_dir,
+            repo_id=args.push_to_hub,
+            repo_type="model",
+            commit_message=f"SFT warmup checkpoint ({args.n_rows} examples, {args.epochs} epoch)",
+        )
+        print(f"[done] LoRA adapter pushed to https://huggingface.co/{args.push_to_hub}")
+
+
+if __name__ == "__main__":
+    main()
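
For reference, a minimal sketch (not part of this commit) of how GRPO Run #1 could start from the pushed adapter rather than vanilla Qwen, assuming the defaults above (base model Qwen/Qwen2.5-3B-Instruct, adapter repo aamrinder/subtext-arena-sft); the loading pattern is generic PEFT usage, not code from this repo:

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the base model, then layer the SFT warmup adapter on top of it.
base = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-3B-Instruct", torch_dtype=torch.bfloat16, device_map="auto",
)
model = PeftModel.from_pretrained(base, "aamrinder/subtext-arena-sft")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B-Instruct")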