File size: 12,553 Bytes
9e64e71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
"""Live training visualization callbacks for GRPO/SFT notebooks."""

from __future__ import annotations

import html as _html
import logging
from typing import Any

_logger = logging.getLogger(__name__)

try:
    from transformers import TrainerCallback
except ImportError:
    TrainerCallback = object  # type: ignore[assignment,misc]


class LiveVisualizationCallback(TrainerCallback):
    """TrainerCallback that plots reward and loss in place during training.

    Updates a single plot via IPython display handle without clearing
    the cell output. All display/plotting failures are swallowed (logged
    at DEBUG) so visualization problems never interrupt training.
    """

    def __init__(self, **kwargs: Any) -> None:
        # Accept and ignore extra kwargs for backward compat
        _ = kwargs
        # Steps at which *any* metric was recorded (kept for backward compat).
        self.log_steps: list[int] = []
        self.log_rewards: list[float] = []
        self.log_losses: list[float] = []
        # Per-series step lists: reward and loss may appear in different log
        # entries, so each series needs its own x-values. Slicing a single
        # shared step list (the old approach) misaligns whichever series is
        # missing from an entry.
        self._reward_steps: list[int] = []
        self._loss_steps: list[int] = []
        self._plot_handle = None

    def on_train_begin(self, args, state, control, **kwargs):
        """Create the placeholder display handle that later updates refresh."""
        try:
            from IPython.display import HTML, display

            self._plot_handle = display(
                HTML("<em>Waiting for first log...</em>"),
                display_id="viz_plot",
            )
        except Exception as exc:
            # Not running under IPython (or display failed): plotting stays off
            # because _plot_handle remains None.
            _logger.debug("Display init failed: %s", exc)

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Record reward/loss from a trainer log entry and refresh the plot."""
        if not logs:
            return
        step = state.global_step

        # Find reward (prefer mean); sorted() makes key selection deterministic.
        reward = None
        for key in sorted(logs.keys()):
            if "reward" in key and "mean" in key:
                reward = logs[key]
                break
        if reward is None:
            for key in ("reward", "rewards/mean"):
                if key in logs:
                    reward = logs[key]
                    break

        loss = logs.get("loss")

        has_data = False
        if reward is not None:
            self.log_rewards.append(float(reward))
            self._reward_steps.append(step)
            has_data = True
        if loss is not None:
            self.log_losses.append(float(loss))
            self._loss_steps.append(step)
            has_data = True
        if has_data:
            self.log_steps.append(step)
            # Only re-render when something new was recorded.
            self._update_plot()

    def _update_plot(self) -> None:
        """Render the current curves to a PNG and swap it into the handle."""
        if self._plot_handle is None:
            return
        try:
            import base64
            import io

            import matplotlib.pyplot as plt
            from IPython.display import HTML

            fig, ax = plt.subplots(1, 1, figsize=(8, 3.5))

            if self.log_rewards:
                ax.plot(
                    self._reward_steps,
                    self.log_rewards,
                    "b-o",
                    markersize=3,
                    label="Reward",
                )
                ax.set_ylabel("Reward")
                ax.legend(loc="upper left")

            if self.log_losses:
                # GRPO (reward present): loss goes on a secondary axis.
                # SFT-only (no reward): loss takes the primary axis.
                if self.log_rewards:
                    ax2 = ax.twinx()
                    ax2.plot(
                        self._loss_steps,
                        self.log_losses,
                        "r-",
                        alpha=0.4,
                        label="Loss",
                    )
                    ax2.set_ylabel("Loss", color="r", alpha=0.6)
                    ax2.legend(loc="upper right")
                else:
                    ax.plot(
                        self._loss_steps,
                        self.log_losses,
                        "r-o",
                        markersize=3,
                        label="Loss",
                    )
                    ax.set_ylabel("Loss")
                    ax.legend(loc="upper right")

            ax.set_xlabel("Step")
            latest = self.log_steps[-1] if self.log_steps else 0
            ax.set_title(f"Training Progress (step {latest})")
            ax.grid(True, alpha=0.3)
            plt.tight_layout()

            # Embed the figure as base64 so the handle update is one
            # self-contained <img> tag (no temp files in the notebook dir).
            buf = io.BytesIO()
            fig.savefig(buf, format="png", dpi=100, bbox_inches="tight")
            plt.close(fig)
            buf.seek(0)
            img = base64.b64encode(buf.read()).decode("utf-8")

            self._plot_handle.update(HTML(f'<img src="data:image/png;base64,{img}">'))
        except Exception as exc:
            _logger.debug("Plot update failed: %s", exc)


class SFTMonitorCallback(TrainerCallback):
    """Show sample completions and optional eval accuracy during SFT.

    Every ``eval_every_steps`` training steps the callback generates
    first-turn completions for a handful of prompts so the user can
    watch the model learn tool-calling patterns in real time.
    """

    def __init__(
        self,
        tokenizer: Any,
        sample_prompts: list[list[dict[str, str]]],
        *,
        tools: list[dict] | None = None,
        train_dataset: Any = None,
        eval_every_steps: int = 50,
        max_new_tokens: int = 200,
        max_samples: int = 3,
    ) -> None:
        """Configure the monitor.

        Args:
            tokenizer: Tokenizer providing ``apply_chat_template``/``encode``.
            sample_prompts: Chat-format message lists to probe the model with.
            tools: Optional tool schemas passed to the chat template.
            train_dataset: Optional dataset; its first row (``messages``) is
                previewed at train start.
            eval_every_steps: Generate samples every this many global steps.
            max_new_tokens: Generation budget per sample.
            max_samples: Upper bound on probed prompts (keeps each
                checkpoint's generation pass cheap); was previously a
                hard-coded cap of 3.
        """
        self.tokenizer = tokenizer
        self.sample_prompts = sample_prompts[:max_samples]
        self.tools = tools
        self.train_dataset = train_dataset
        self.eval_every_steps = eval_every_steps
        self.max_new_tokens = max_new_tokens
        self._model: Any = None  # captured from the trainer at train begin
        self._display_handle: Any = None  # IPython display handle for samples

    # ------------------------------------------------------------------
    def on_train_begin(self, args, state, control, model=None, **kwargs):
        """Capture the model and show prompt/dataset previews plus a handle."""
        self._model = model
        try:
            from IPython.display import HTML, display

            # Always use canonical tools (avoid Dataset serialization artifacts)
            tpl_tools = self.tools

            # 1) Inference prompt — what the model sees at generation time
            if self.sample_prompts:
                tpl_kwargs: dict[str, Any] = {
                    "tokenize": False,
                    "add_generation_prompt": True,
                }
                if tpl_tools:
                    tpl_kwargs["tools"] = tpl_tools
                preview = self.tokenizer.apply_chat_template(
                    self.sample_prompts[0],
                    **tpl_kwargs,
                )
                n_tok = len(self.tokenizer.encode(preview))
                display(
                    HTML(
                        "<details><summary>"
                        f"<b>Inference prompt</b> ({n_tok} tok)"
                        " — system + tools + question, "
                        "as seen by model during GRPO generation"
                        "</summary>"
                        "<pre style='background:#2d2d2d;color:#e0e0e0;"
                        "padding:8px;border-radius:4px;font-size:12px;"
                        "white-space:pre-wrap;max-height:600px;"
                        "overflow-y:auto;'>"
                        f"{_html.escape(preview)}</pre></details>"
                    )
                )

            # 2) Training example — one per-turn example from the dataset
            if self.train_dataset is not None and len(self.train_dataset) > 0:
                row = self.train_dataset[0]
                msgs = row.get("messages", [])
                if msgs:
                    ex_kwargs: dict[str, Any] = {"tokenize": False}
                    if tpl_tools:
                        ex_kwargs["tools"] = tpl_tools
                    rendered_ex = self.tokenizer.apply_chat_template(
                        msgs,
                        **ex_kwargs,
                    )
                    n_ex_tok = len(self.tokenizer.encode(rendered_ex))
                    n_turns = sum(1 for m in msgs if m.get("role") == "assistant")
                    last_role = msgs[-1].get("role", "?")
                    display(
                        HTML(
                            "<details><summary>"
                            f"<b>SFT training example</b>"
                            f" ({n_ex_tok} tok, {n_turns} asst turn)"
                            " — history + one assistant tool_call, "
                            "exactly what the model learns to predict"
                            "</summary>"
                            "<pre style='background:#1a1a2e;color:#e0e0e0;"
                            "padding:8px;border-radius:4px;font-size:12px;"
                            "white-space:pre-wrap;max-height:600px;"
                            "overflow-y:auto;'>"
                            f"{_html.escape(rendered_ex)}</pre>"
                            f"<p style='color:#888;font-size:11px;'>"
                            f"Last message role: <b>{last_role}</b> "
                            f"| Loss is on this turn only</p>"
                            "</details>"
                        )
                    )

            self._display_handle = display(
                HTML("<em>SFT samples: waiting for first checkpoint...</em>"),
                display_id="sft_samples",
            )
        except Exception as exc:
            # Best-effort: outside IPython the previews are simply skipped.
            _logger.debug("SFT monitor display init failed: %s", exc)

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Regenerate samples on every ``eval_every_steps``-th global step."""
        step = state.global_step
        if step == 0 or step % self.eval_every_steps != 0:
            return
        if self._model is None:
            return
        self._generate_and_display(step)

    def on_train_end(self, args, state, control, **kwargs):
        """Render one final set of samples after training completes."""
        if self._model is not None:
            self._generate_and_display(state.global_step, final=True)

    # ------------------------------------------------------------------
    def _generate_and_display(self, step: int, final: bool = False) -> None:
        """Greedily generate one completion per sample prompt and display.

        Temporarily switches the model to eval mode; the original mode is
        restored in a ``finally`` so an exception during generation cannot
        leave the trainer's model stuck in eval.
        """
        try:
            import torch

            was_training = self._model.training
            self._model.eval()
            try:
                header = "SFT Final Samples" if final else f"SFT Samples (step {step})"
                parts = [f"<h4 style='color:#e0e0e0;'>{header}</h4>"]

                with torch.no_grad():
                    for messages in self.sample_prompts:
                        question = messages[-1]["content"][:100]
                        tpl_kwargs: dict[str, Any] = {
                            "tokenize": False,
                            "add_generation_prompt": True,
                        }
                        if self.tools:
                            tpl_kwargs["tools"] = self.tools
                        rendered = self.tokenizer.apply_chat_template(
                            messages,
                            **tpl_kwargs,
                        )
                        inputs = self.tokenizer(rendered, return_tensors="pt")
                        device = next(self._model.parameters()).device
                        inputs = {k: v.to(device) for k, v in inputs.items()}

                        # Badge: did the chat template actually inject tools?
                        has_tools = "<tools>" in rendered
                        prompt_len = inputs["input_ids"].shape[1]

                        out = self._model.generate(
                            **inputs,
                            max_new_tokens=self.max_new_tokens,
                            do_sample=False,
                        )
                        new_tokens = out[0][prompt_len:]
                        raw = self.tokenizer.decode(new_tokens, skip_special_tokens=False)
                        # Show first tool call only (stop at </tool_call>)
                        end = raw.find("</tool_call>")
                        if end != -1:
                            raw = raw[: end + len("</tool_call>")]
                        # Round-trip through the tokenizer to drop special
                        # tokens from the truncated text.
                        text = self.tokenizer.decode(
                            self.tokenizer.encode(raw),
                            skip_special_tokens=True,
                        ).strip()

                        badge = (
                            "<span style='color:#4caf50;'>✓ tools</span>"
                            if has_tools
                            else "<span style='color:#f44336;'>✗ no tools</span>"
                        )
                        parts.append(
                            "<pre style='background:#2d2d2d;color:#e0e0e0;"
                            "padding:8px;margin:4px 0;border-radius:4px;"
                            "font-size:13px;line-height:1.4;'>"
                            f"<b style='color:#82aaff;'>Q:</b> "
                            f"{_html.escape(question)}"
                            f" [{badge}, {prompt_len} tok]\n"
                            f"<b style='color:#c3e88d;'>→</b> "
                            f"{_html.escape(text)}</pre>"
                        )
            finally:
                # Restore the trainer's mode even if generation raised.
                if was_training:
                    self._model.train()

            if self._display_handle is not None:
                from IPython.display import HTML

                self._display_handle.update(HTML("\n".join(parts)))
        except Exception as exc:
            _logger.debug("SFT sample generation failed: %s", exc)