aamrinder commited on
Commit
37818f1
·
verified ·
1 Parent(s): 7f60dea

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. README.md +2 -0
  2. pyproject.toml +10 -1
  3. train/side_by_side.py +206 -0
README.md CHANGED
@@ -185,6 +185,8 @@ GRPOTrainer(
185
 
186
  **Plots** (`train/plot_reward_decomp.py`): generates the 3-line reward decomposition chart (correctness / reasoning_length / format) from the training log. Saves to `docs/plots/reward_decomposition.png`.
187
 
 
 
188
  ---
189
 
190
  ## Results
 
185
 
186
  **Plots** (`train/plot_reward_decomp.py`): generates the 3-line reward decomposition chart (correctness / reasoning_length / format) from the training log. Saves to `docs/plots/reward_decomposition.png`.
187
 
188
+ **Side-by-side demo** (`train/side_by_side.py`): runs both the base Qwen2.5-3B and the trained checkpoint on hand-picked Pivot clips, dumps an HTML page with their reasoning traces side-by-side. This is the demo artifact judges read.
189
+
190
  ---
191
 
192
  ## Results
pyproject.toml CHANGED
@@ -36,4 +36,13 @@ server = "subtext_arena.server.app:main"
36
  [tool.setuptools]
37
  include-package-data = true
38
  packages = ["subtext_arena", "subtext_arena.server", "subtext_arena.train"]
39
- package-dir = { "subtext_arena" = ".", "subtext_arena.server" = "server", "subtext_arena.train" = "train" }
 
 
 
 
 
 
 
 
 
 
36
  [tool.setuptools]
37
  include-package-data = true
38
  packages = ["subtext_arena", "subtext_arena.server", "subtext_arena.train"]
39
+ package-dir = { "subtext_arena" = ".", "subtext_arena.server" = "server", "subtext_arena.train" = "train" }
40
+
41
+ [tool.setuptools.package-data]
42
+ subtext_arena = [
43
+ "data/sarcasm_data.json",
44
+ "data/pivot_set.json",
45
+ "data/prosody_cache/utterances/*.json",
46
+ "data/prosody_cache/context/*.json",
47
+ "openenv.yaml",
48
+ ]
train/side_by_side.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Generate side-by-side baseline-vs-trained reasoning for hand-picked clips.
2
+
3
+ This is the demo artifact: judges look at it and read what the model learned.
4
+ Output: an HTML table that can be embedded in the README + a JSON dump.
5
+
6
+ For each clip:
7
+ - Run the BASE Qwen2.5-3B (no LoRA) and dump <think> + <final>
8
+ - Run the TRAINED checkpoint and dump <think> + <final>
9
+ - Show gold label, both predictions, and which (if either) was right
10
+
11
+ Usage:
12
+ python -m subtext_arena.train.side_by_side \\
13
+ --trained-checkpoint ./checkpoints/run1 \\
14
+ --clip-ids 1_70 2_190 1_8826 2_236 2_300 \\
15
+ --out docs/plots/side_by_side.html
16
+ """
17
+ from __future__ import annotations
18
+
19
+ import argparse
20
+ import json
21
+ import sys
22
+ from pathlib import Path
23
+ from typing import List
24
+
25
+ try:
26
+ from subtext_arena.server.scenarios import load_scenarios
27
+ from subtext_arena.train.train_grpo import (
28
+ SYSTEM_PROMPT, build_full_observation, parse_final, reward_decomposition,
29
+ )
30
+ except ImportError:
31
+ ROOT = Path(__file__).resolve().parent.parent
32
+ if str(ROOT) not in sys.path:
33
+ sys.path.insert(0, str(ROOT))
34
+ from server.scenarios import load_scenarios # type: ignore[no-redef]
35
+ from train.train_grpo import ( # type: ignore[no-redef]
36
+ SYSTEM_PROMPT, build_full_observation, parse_final, reward_decomposition,
37
+ )
38
+
39
+
40
+ HTML_TEMPLATE = """<!DOCTYPE html>
41
+ <html><head>
42
+ <meta charset="utf-8">
43
+ <title>Subtext Arena — baseline vs trained, hand-picked clips</title>
44
+ <style>
45
+ body {{ font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif;
46
+ max-width: 1200px; margin: 40px auto; padding: 0 20px; color: #222; }}
47
+ h1 {{ font-size: 24px; }}
48
+ .clip {{ border: 1px solid #ddd; border-radius: 8px; padding: 16px;
49
+ margin-bottom: 24px; background: #fafafa; }}
50
+ .clip h2 {{ font-size: 18px; margin-top: 0; }}
51
+ .gold-sarcastic {{ color: #b3274d; }}
52
+ .gold-sincere {{ color: #1d7a4a; }}
53
+ .columns {{ display: grid; grid-template-columns: 1fr 1fr; gap: 16px; }}
54
+ .col {{ padding: 12px; border-radius: 6px; }}
55
+ .baseline {{ background: #fff5f5; border: 1px solid #f8c4c4; }}
56
+ .trained {{ background: #effaf3; border: 1px solid #b6e2c1; }}
57
+ .col h3 {{ margin-top: 0; font-size: 14px; text-transform: uppercase;
58
+ letter-spacing: 0.05em; color: #666; }}
59
+ .verdict-correct {{ color: #1d7a4a; font-weight: bold; }}
60
+ .verdict-wrong {{ color: #b3274d; font-weight: bold; }}
61
+ pre {{ white-space: pre-wrap; word-wrap: break-word; font-size: 13px;
62
+ line-height: 1.4; background: white; padding: 8px; border-radius: 4px;
63
+ border: 1px solid #eee; }}
64
+ .transcript {{ font-style: italic; color: #555; margin-bottom: 12px; }}
65
+ </style>
66
+ </head><body>
67
+ <h1>Subtext Arena — baseline vs trained</h1>
68
+ <p>Same prompt fed to the base Qwen2.5-3B-Instruct (left) and to the GRPO-trained
69
+ checkpoint (right). Each shows the model's reasoning trace and final answer.</p>
70
+ {clips_html}
71
+ </body></html>
72
+ """
73
+
74
+ CLIP_BLOCK = """<div class="clip">
75
+ <h2>Clip {clip_id} — speaker: {speaker}, gold: <span class="gold-{gold}">{gold}</span></h2>
76
+ <div class="transcript">"{utterance}"</div>
77
+ <div class="columns">
78
+ <div class="col baseline">
79
+ <h3>Baseline (no training)</h3>
80
+ <p>predicted: <span class="verdict-{baseline_verdict}">{baseline_label}</span> (conf {baseline_conf:.2f})</p>
81
+ <pre>{baseline_text}</pre>
82
+ </div>
83
+ <div class="col trained">
84
+ <h3>Trained checkpoint</h3>
85
+ <p>predicted: <span class="verdict-{trained_verdict}">{trained_label}</span> (conf {trained_conf:.2f})</p>
86
+ <pre>{trained_text}</pre>
87
+ </div>
88
+ </div>
89
+ </div>
90
+ """
91
+
92
+
93
+ def generate_completion(model, tokenizer, prompt_user_msg, max_tokens=600, temperature=0.7):
94
+ messages = [
95
+ {"role": "system", "content": SYSTEM_PROMPT},
96
+ {"role": "user", "content": prompt_user_msg},
97
+ ]
98
+ inputs = tokenizer.apply_chat_template(
99
+ messages, return_tensors="pt", add_generation_prompt=True
100
+ ).to(model.device)
101
+ out = model.generate(
102
+ inputs, max_new_tokens=max_tokens, do_sample=True,
103
+ temperature=temperature, pad_token_id=tokenizer.eos_token_id,
104
+ )
105
+ return tokenizer.decode(out[0][inputs.shape[1]:], skip_special_tokens=True)
106
+
107
+
108
+ def main():
109
+ parser = argparse.ArgumentParser()
110
+ parser.add_argument("--trained-checkpoint", required=True)
111
+ parser.add_argument("--base-model", default="unsloth/Qwen2.5-3B-Instruct")
112
+ parser.add_argument("--clip-ids", nargs="+", required=True,
113
+ help="Hand-picked clip IDs for the side-by-side")
114
+ parser.add_argument("--out", required=True, help="Output HTML path")
115
+ parser.add_argument("--out-json", default=None, help="Optional JSON dump")
116
+ args = parser.parse_args()
117
+
118
+ scenarios = load_scenarios()
119
+ from unsloth import FastLanguageModel
120
+
121
+ rows = []
122
+
123
+ # Run baseline (no LoRA)
124
+ print(f"[load] base model: {args.base_model}")
125
+ base_model, base_tok = FastLanguageModel.from_pretrained(
126
+ model_name=args.base_model, max_seq_length=4096, load_in_4bit=True,
127
+ )
128
+ FastLanguageModel.for_inference(base_model)
129
+
130
+ for clip_id in args.clip_ids:
131
+ gold = "sarcastic" if scenarios[clip_id]["sarcasm"] else "sincere"
132
+ prompt_user = build_full_observation(clip_id, scenarios)
133
+ text = generate_completion(base_model, base_tok, prompt_user)
134
+ d = reward_decomposition(text, gold)
135
+ rows.append({
136
+ "clip_id": clip_id,
137
+ "speaker": scenarios[clip_id]["speaker"],
138
+ "utterance": scenarios[clip_id]["utterance"],
139
+ "gold": gold,
140
+ "baseline": {
141
+ "label": d["_predicted"], "confidence": d["_confidence"],
142
+ "correct": d["_correct"], "text": text[:1200],
143
+ },
144
+ })
145
+ print(f" baseline {clip_id}: pred={d['_predicted']} (conf={d['_confidence']:.2f}) correct={d['_correct']}")
146
+
147
+ # Free the base model to make room for the trained one
148
+ del base_model
149
+ import torch
150
+ torch.cuda.empty_cache()
151
+
152
+ # Run trained checkpoint
153
+ print(f"[load] trained checkpoint: {args.trained_checkpoint}")
154
+ trained_model, trained_tok = FastLanguageModel.from_pretrained(
155
+ model_name=args.trained_checkpoint, max_seq_length=4096, load_in_4bit=True,
156
+ )
157
+ FastLanguageModel.for_inference(trained_model)
158
+
159
+ for row in rows:
160
+ clip_id = row["clip_id"]
161
+ prompt_user = build_full_observation(clip_id, scenarios)
162
+ text = generate_completion(trained_model, trained_tok, prompt_user)
163
+ d = reward_decomposition(text, row["gold"])
164
+ row["trained"] = {
165
+ "label": d["_predicted"], "confidence": d["_confidence"],
166
+ "correct": d["_correct"], "text": text[:1200],
167
+ }
168
+ print(f" trained {clip_id}: pred={d['_predicted']} (conf={d['_confidence']:.2f}) correct={d['_correct']}")
169
+
170
+ # Render HTML
171
+ clips_html_parts = []
172
+ for row in rows:
173
+ b = row["baseline"]; t = row["trained"]
174
+ clips_html_parts.append(CLIP_BLOCK.format(
175
+ clip_id=row["clip_id"], speaker=row["speaker"],
176
+ utterance=row["utterance"].replace('"', '&quot;'),
177
+ gold=row["gold"],
178
+ baseline_label=b["label"] or "—",
179
+ baseline_conf=b["confidence"],
180
+ baseline_verdict=("correct" if b["correct"] else "wrong"),
181
+ baseline_text=(b["text"] or "(no output)").replace("<", "&lt;").replace(">", "&gt;"),
182
+ trained_label=t["label"] or "—",
183
+ trained_conf=t["confidence"],
184
+ trained_verdict=("correct" if t["correct"] else "wrong"),
185
+ trained_text=(t["text"] or "(no output)").replace("<", "&lt;").replace(">", "&gt;"),
186
+ ))
187
+ html = HTML_TEMPLATE.format(clips_html="\n".join(clips_html_parts))
188
+ Path(args.out).parent.mkdir(parents=True, exist_ok=True)
189
+ Path(args.out).write_text(html)
190
+ print(f"[done] wrote {args.out}")
191
+
192
+ if args.out_json:
193
+ Path(args.out_json).write_text(json.dumps(rows, indent=2))
194
+ print(f"[done] wrote {args.out_json}")
195
+
196
+ # Tally
197
+ n_baseline_correct = sum(1 for r in rows if r["baseline"]["correct"])
198
+ n_trained_correct = sum(1 for r in rows if r["trained"]["correct"])
199
+ print()
200
+ print(f"Tally on {len(rows)} hand-picked clips:")
201
+ print(f" baseline: {n_baseline_correct}/{len(rows)} correct")
202
+ print(f" trained: {n_trained_correct}/{len(rows)} correct")
203
+
204
+
205
+ if __name__ == "__main__":
206
+ main()