| """Live training visualization callbacks for GRPO/SFT notebooks.""" |
|
|
| from __future__ import annotations |
|
|
| import html as _html |
| import logging |
| from typing import Any |
|
|
| _logger = logging.getLogger(__name__) |
|
|
| try: |
| from transformers import TrainerCallback |
| except ImportError: |
| TrainerCallback = object |
|
|
|
|
class LiveVisualizationCallback(TrainerCallback):
    """TrainerCallback that plots reward and loss in place during training.

    Updates a single plot via an IPython display handle without clearing
    the cell output. Outside a notebook (no IPython available) every hook
    degrades to a silent no-op.
    """

    def __init__(self, **kwargs: Any) -> None:
        # Accept (and ignore) arbitrary kwargs so this callback can be
        # constructed from the same config mapping as sibling callbacks.
        _ = kwargs
        # Public history lists, kept for backward compatibility:
        # ``log_steps`` records every step at which *any* metric arrived.
        self.log_steps: list[int] = []
        self.log_rewards: list[float] = []
        self.log_losses: list[float] = []
        # Per-metric step tracking: rewards and losses may be logged at
        # different cadences, so each series needs its own x-coordinates.
        # (Previously the plot sliced ``log_steps`` by series length,
        # which misaligned points whenever only one metric was logged.)
        self._reward_steps: list[int] = []
        self._loss_steps: list[int] = []
        self._plot_handle = None  # IPython DisplayHandle, set on train begin

    def on_train_begin(self, args, state, control, **kwargs):
        """Create the placeholder display handle (no-op outside IPython)."""
        try:
            from IPython.display import HTML, display

            self._plot_handle = display(
                HTML("<em>Waiting for first log...</em>"),
                display_id="viz_plot",
            )
        except Exception:
            # Not running in a notebook; leave _plot_handle as None so
            # _update_plot stays a no-op.
            pass

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Extract reward/loss from ``logs`` and refresh the live plot."""
        if not logs:
            return
        step = state.global_step

        # Trainer-specific reward keys vary; prefer any "*reward*mean*"
        # key (sorted for determinism), then fall back to common names.
        reward = None
        for key in sorted(logs.keys()):
            if "reward" in key and "mean" in key:
                reward = logs[key]
                break
        if reward is None:
            for key in ("reward", "rewards/mean"):
                if key in logs:
                    reward = logs[key]
                    break

        loss = logs.get("loss")

        has_data = False
        if reward is not None:
            self.log_rewards.append(float(reward))
            self._reward_steps.append(step)
            has_data = True
        if loss is not None:
            self.log_losses.append(float(loss))
            self._loss_steps.append(step)
            has_data = True
        if has_data:
            self.log_steps.append(step)

        self._update_plot()

    def _update_plot(self) -> None:
        """Render accumulated history as a PNG and push it to the handle.

        Any failure (matplotlib missing, display torn down, ...) is
        swallowed and logged at debug level so plotting can never break
        the training loop.
        """
        if self._plot_handle is None:
            return
        try:
            import base64
            import io

            import matplotlib.pyplot as plt
            from IPython.display import HTML

            fig, ax = plt.subplots(1, 1, figsize=(8, 3.5))

            if self.log_rewards:
                ax.plot(
                    self._reward_steps,
                    self.log_rewards,
                    "b-o",
                    markersize=3,
                    label="Reward",
                )
                ax.set_ylabel("Reward")
                ax.legend(loc="upper left")

            if self.log_losses:
                # If rewards occupy the primary axis, draw loss on a
                # twinned secondary axis so both scales stay readable.
                if self.log_rewards:
                    ax2 = ax.twinx()
                    ax2.plot(
                        self._loss_steps,
                        self.log_losses,
                        "r-",
                        alpha=0.4,
                        label="Loss",
                    )
                    ax2.set_ylabel("Loss", color="r", alpha=0.6)
                    ax2.legend(loc="upper right")
                else:
                    ax.plot(
                        self._loss_steps,
                        self.log_losses,
                        "r-o",
                        markersize=3,
                        label="Loss",
                    )
                    ax.set_ylabel("Loss")
                    ax.legend(loc="upper right")

            ax.set_xlabel("Step")
            latest = self.log_steps[-1] if self.log_steps else 0
            ax.set_title(f"Training Progress (step {latest})")
            ax.grid(True, alpha=0.3)
            plt.tight_layout()

            # Encode the figure as an inline base64 <img> so the display
            # handle can be updated without re-running the cell.
            buf = io.BytesIO()
            fig.savefig(buf, format="png", dpi=100, bbox_inches="tight")
            plt.close(fig)
            buf.seek(0)
            img = base64.b64encode(buf.read()).decode("utf-8")

            self._plot_handle.update(HTML(f'<img src="data:image/png;base64,{img}">'))
        except Exception as exc:
            _logger.debug("Plot update failed: %s", exc)
|
|
|
class SFTMonitorCallback(TrainerCallback):
    """Show sample completions and optional eval accuracy during SFT.

    Every ``eval_every_steps`` training steps the callback generates
    first-turn completions for a handful of prompts so the user can
    watch the model learn tool-calling patterns in real time.

    Args:
        tokenizer: HF tokenizer providing ``apply_chat_template``/``encode``.
        sample_prompts: Chat-message lists to probe; only the first 3 are kept.
        tools: Optional tool schemas passed to the chat template.
        train_dataset: Optional dataset; row 0 is rendered as a preview.
        eval_every_steps: Generation cadence, in global steps.
        max_new_tokens: Generation budget per sample prompt.
    """

    def __init__(
        self,
        tokenizer: Any,
        sample_prompts: list[list[dict[str, str]]],
        *,
        tools: list[dict] | None = None,
        train_dataset: Any = None,
        eval_every_steps: int = 50,
        max_new_tokens: int = 200,
    ) -> None:
        self.tokenizer = tokenizer
        # Cap at 3 prompts to keep per-checkpoint generation cheap.
        self.sample_prompts = sample_prompts[:3]
        self.tools = tools
        self.train_dataset = train_dataset
        self.eval_every_steps = eval_every_steps
        self.max_new_tokens = max_new_tokens
        self._model: Any = None  # captured on train begin
        self._display_handle: Any = None  # IPython handle for sample output

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        """Capture the model and render prompt/example previews.

        Everything here is best-effort notebook cosmetics; any failure
        (no IPython, template errors) is swallowed.
        """
        self._model = model
        try:
            from IPython.display import HTML, display

            tpl_tools = self.tools

            # Preview 1: the rendered inference prompt the model will see.
            if self.sample_prompts:
                tpl_kwargs: dict[str, Any] = {
                    "tokenize": False,
                    "add_generation_prompt": True,
                }
                if tpl_tools:
                    tpl_kwargs["tools"] = tpl_tools
                preview = self.tokenizer.apply_chat_template(
                    self.sample_prompts[0],
                    **tpl_kwargs,
                )
                n_tok = len(self.tokenizer.encode(preview))
                display(
                    HTML(
                        "<details><summary>"
                        f"<b>Inference prompt</b> ({n_tok} tok)"
                        " — system + tools + question, "
                        "as seen by model during GRPO generation"
                        "</summary>"
                        "<pre style='background:#2d2d2d;color:#e0e0e0;"
                        "padding:8px;border-radius:4px;font-size:12px;"
                        "white-space:pre-wrap;max-height:600px;"
                        "overflow-y:auto;'>"
                        f"{_html.escape(preview)}</pre></details>"
                    )
                )

            # Preview 2: one fully-rendered SFT training example.
            if self.train_dataset is not None and len(self.train_dataset) > 0:
                row = self.train_dataset[0]
                msgs = row.get("messages", [])
                if msgs:
                    ex_kwargs: dict[str, Any] = {"tokenize": False}
                    if tpl_tools:
                        ex_kwargs["tools"] = tpl_tools
                    rendered_ex = self.tokenizer.apply_chat_template(
                        msgs,
                        **ex_kwargs,
                    )
                    n_ex_tok = len(self.tokenizer.encode(rendered_ex))
                    n_turns = sum(1 for m in msgs if m.get("role") == "assistant")
                    last_role = msgs[-1].get("role", "?")
                    display(
                        HTML(
                            "<details><summary>"
                            f"<b>SFT training example</b>"
                            f" ({n_ex_tok} tok, {n_turns} asst turn)"
                            " — history + one assistant tool_call, "
                            "exactly what the model learns to predict"
                            "</summary>"
                            "<pre style='background:#1a1a2e;color:#e0e0e0;"
                            "padding:8px;border-radius:4px;font-size:12px;"
                            "white-space:pre-wrap;max-height:600px;"
                            "overflow-y:auto;'>"
                            f"{_html.escape(rendered_ex)}</pre>"
                            f"<p style='color:#888;font-size:11px;'>"
                            f"Last message role: <b>{last_role}</b> "
                            f"| Loss is on this turn only</p>"
                            "</details>"
                        )
                    )

            self._display_handle = display(
                HTML("<em>SFT samples: waiting for first checkpoint...</em>"),
                display_id="sft_samples",
            )
        except Exception:
            # Not in a notebook (or preview rendering failed); sampling
            # output is simply disabled.
            pass

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Generate samples every ``eval_every_steps`` logged steps."""
        step = state.global_step
        if step == 0 or step % self.eval_every_steps != 0:
            return
        if self._model is None:
            return
        self._generate_and_display(step)

    def on_train_end(self, args, state, control, **kwargs):
        """Emit one final set of samples after training finishes."""
        if self._model is not None:
            self._generate_and_display(state.global_step, final=True)

    def _generate_and_display(self, step: int, final: bool = False) -> None:
        """Greedily generate one completion per sample prompt and display.

        The model is switched to eval mode for generation and restored
        afterwards. Restoration happens in a ``finally`` so a failure
        mid-generation can never leave the model stuck in eval mode
        (previously an exception skipped the ``.train()`` restore).
        """
        try:
            import torch

            was_training = self._model.training
            self._model.eval()

            header = "SFT Final Samples" if final else f"SFT Samples (step {step})"
            parts = [f"<h4 style='color:#e0e0e0;'>{header}</h4>"]

            try:
                with torch.no_grad():
                    for messages in self.sample_prompts:
                        question = messages[-1]["content"][:100]
                        tpl_kwargs: dict[str, Any] = {
                            "tokenize": False,
                            "add_generation_prompt": True,
                        }
                        if self.tools:
                            tpl_kwargs["tools"] = self.tools
                        rendered = self.tokenizer.apply_chat_template(
                            messages,
                            **tpl_kwargs,
                        )
                        inputs = self.tokenizer(rendered, return_tensors="pt")
                        device = next(self._model.parameters()).device
                        inputs = {k: v.to(device) for k, v in inputs.items()}

                        # Sanity badge: did the chat template actually
                        # inject the tool schemas into the prompt?
                        has_tools = "<tools>" in rendered
                        prompt_len = inputs["input_ids"].shape[1]

                        # Greedy decode so samples are comparable across steps.
                        out = self._model.generate(
                            **inputs,
                            max_new_tokens=self.max_new_tokens,
                            do_sample=False,
                        )
                        new_tokens = out[0][prompt_len:]
                        raw = self.tokenizer.decode(new_tokens, skip_special_tokens=False)
                        # Truncate at the first closed tool call; a re-encode/
                        # decode round trip then strips special tokens.
                        end = raw.find("</tool_call>")
                        if end != -1:
                            raw = raw[: end + len("</tool_call>")]
                        text = self.tokenizer.decode(
                            self.tokenizer.encode(raw),
                            skip_special_tokens=True,
                        ).strip()

                        badge = (
                            "<span style='color:#4caf50;'>✓ tools</span>"
                            if has_tools
                            else "<span style='color:#f44336;'>✗ no tools</span>"
                        )
                        parts.append(
                            "<pre style='background:#2d2d2d;color:#e0e0e0;"
                            "padding:8px;margin:4px 0;border-radius:4px;"
                            "font-size:13px;line-height:1.4;'>"
                            f"<b style='color:#82aaff;'>Q:</b> "
                            f"{_html.escape(question)}"
                            f" [{badge}, {prompt_len} tok]\n"
                            f"<b style='color:#c3e88d;'>→</b> "
                            f"{_html.escape(text)}</pre>"
                        )
            finally:
                # Always restore the pre-generation mode, even on error.
                if was_training:
                    self._model.train()

            if self._display_handle is not None:
                from IPython.display import HTML

                self._display_handle.update(HTML("\n".join(parts)))
        except Exception as exc:
            _logger.debug("SFT sample generation failed: %s", exc)
|