"""Live training visualization callbacks for GRPO/SFT notebooks.""" from __future__ import annotations import html as _html import logging from typing import Any _logger = logging.getLogger(__name__) try: from transformers import TrainerCallback except ImportError: TrainerCallback = object # type: ignore[assignment,misc] class LiveVisualizationCallback(TrainerCallback): """TrainerCallback that plots reward and loss in place during training. Updates a single plot via IPython display handle without clearing the cell output. """ def __init__(self, **kwargs: Any) -> None: # Accept and ignore extra kwargs for backward compat _ = kwargs self.log_steps: list[int] = [] self.log_rewards: list[float] = [] self.log_losses: list[float] = [] self._plot_handle = None def on_train_begin(self, args, state, control, **kwargs): try: from IPython.display import HTML, display self._plot_handle = display( HTML("Waiting for first log..."), display_id="viz_plot", ) except Exception: pass def on_log(self, args, state, control, logs=None, **kwargs): if not logs: return step = state.global_step # Find reward (prefer mean) reward = None for key in sorted(logs.keys()): if "reward" in key and "mean" in key: reward = logs[key] break if reward is None: for key in ("reward", "rewards/mean"): if key in logs: reward = logs[key] break loss = logs.get("loss") has_data = False if reward is not None: self.log_rewards.append(float(reward)) has_data = True if loss is not None: self.log_losses.append(float(loss)) has_data = True if has_data: self.log_steps.append(step) self._update_plot() def _update_plot(self) -> None: if self._plot_handle is None: return try: import base64 import io import matplotlib.pyplot as plt from IPython.display import HTML fig, ax = plt.subplots(1, 1, figsize=(8, 3.5)) if self.log_rewards: ax.plot( self.log_steps[: len(self.log_rewards)], self.log_rewards, "b-o", markersize=3, label="Reward", ) ax.set_ylabel("Reward") ax.legend(loc="upper left") if self.log_losses: # SFT-only: plot loss on primary axis # GRPO: plot loss on secondary axis if self.log_rewards: ax2 = ax.twinx() ax2.plot( self.log_steps[: len(self.log_losses)], self.log_losses, "r-", alpha=0.4, label="Loss", ) ax2.set_ylabel("Loss", color="r", alpha=0.6) ax2.legend(loc="upper right") else: ax.plot( self.log_steps[: len(self.log_losses)], self.log_losses, "r-o", markersize=3, label="Loss", ) ax.set_ylabel("Loss") ax.legend(loc="upper right") ax.set_xlabel("Step") latest = self.log_steps[-1] if self.log_steps else 0 ax.set_title(f"Training Progress (step {latest})") ax.grid(True, alpha=0.3) plt.tight_layout() buf = io.BytesIO() fig.savefig(buf, format="png", dpi=100, bbox_inches="tight") plt.close(fig) buf.seek(0) img = base64.b64encode(buf.read()).decode("utf-8") self._plot_handle.update(HTML(f'')) except Exception as exc: _logger.debug("Plot update failed: %s", exc) class SFTMonitorCallback(TrainerCallback): """Show sample completions and optional eval accuracy during SFT. Every ``eval_every_steps`` training steps the callback generates first-turn completions for a handful of prompts so the user can watch the model learn tool-calling patterns in real time. """ def __init__( self, tokenizer: Any, sample_prompts: list[list[dict[str, str]]], *, tools: list[dict] | None = None, train_dataset: Any = None, eval_every_steps: int = 50, max_new_tokens: int = 200, ) -> None: self.tokenizer = tokenizer self.sample_prompts = sample_prompts[:3] self.tools = tools self.train_dataset = train_dataset self.eval_every_steps = eval_every_steps self.max_new_tokens = max_new_tokens self._model: Any = None self._display_handle: Any = None # ------------------------------------------------------------------ def on_train_begin(self, args, state, control, model=None, **kwargs): self._model = model try: from IPython.display import HTML, display # Always use canonical tools (avoid Dataset serialization artifacts) tpl_tools = self.tools # 1) Inference prompt — what the model sees at generation time if self.sample_prompts: tpl_kwargs: dict[str, Any] = { "tokenize": False, "add_generation_prompt": True, } if tpl_tools: tpl_kwargs["tools"] = tpl_tools preview = self.tokenizer.apply_chat_template( self.sample_prompts[0], **tpl_kwargs, ) n_tok = len(self.tokenizer.encode(preview)) display( HTML( "
" f"Inference prompt ({n_tok} tok)" " — system + tools + question, " "as seen by model during GRPO generation" "" "
"
                        f"{_html.escape(preview)}
" ) ) # 2) Training example — one per-turn example from the dataset if self.train_dataset is not None and len(self.train_dataset) > 0: row = self.train_dataset[0] msgs = row.get("messages", []) if msgs: ex_kwargs: dict[str, Any] = {"tokenize": False} if tpl_tools: ex_kwargs["tools"] = tpl_tools rendered_ex = self.tokenizer.apply_chat_template( msgs, **ex_kwargs, ) n_ex_tok = len(self.tokenizer.encode(rendered_ex)) n_turns = sum(1 for m in msgs if m.get("role") == "assistant") last_role = msgs[-1].get("role", "?") display( HTML( "
" f"SFT training example" f" ({n_ex_tok} tok, {n_turns} asst turn)" " — history + one assistant tool_call, " "exactly what the model learns to predict" "" "
"
                            f"{_html.escape(rendered_ex)}
" f"

" f"Last message role: {last_role} " f"| Loss is on this turn only

" "
" ) ) self._display_handle = display( HTML("SFT samples: waiting for first checkpoint..."), display_id="sft_samples", ) except Exception: pass def on_log(self, args, state, control, logs=None, **kwargs): step = state.global_step if step == 0 or step % self.eval_every_steps != 0: return if self._model is None: return self._generate_and_display(step) def on_train_end(self, args, state, control, **kwargs): if self._model is not None: self._generate_and_display(state.global_step, final=True) # ------------------------------------------------------------------ def _generate_and_display(self, step: int, final: bool = False) -> None: try: import torch was_training = self._model.training self._model.eval() header = "SFT Final Samples" if final else f"SFT Samples (step {step})" parts = [f"

{header}

"] with torch.no_grad(): for messages in self.sample_prompts: question = messages[-1]["content"][:100] tpl_kwargs: dict[str, Any] = { "tokenize": False, "add_generation_prompt": True, } if self.tools: tpl_kwargs["tools"] = self.tools rendered = self.tokenizer.apply_chat_template( messages, **tpl_kwargs, ) inputs = self.tokenizer(rendered, return_tensors="pt") device = next(self._model.parameters()).device inputs = {k: v.to(device) for k, v in inputs.items()} has_tools = "" in rendered prompt_len = inputs["input_ids"].shape[1] out = self._model.generate( **inputs, max_new_tokens=self.max_new_tokens, do_sample=False, ) new_tokens = out[0][prompt_len:] raw = self.tokenizer.decode(new_tokens, skip_special_tokens=False) # Show first tool call only (stop at ) end = raw.find("") if end != -1: raw = raw[: end + len("")] text = self.tokenizer.decode( self.tokenizer.encode(raw), skip_special_tokens=True, ).strip() badge = ( "✓ tools" if has_tools else "✗ no tools" ) parts.append( "
"
                        f"Q: "
                        f"{_html.escape(question)}"
                        f" [{badge}, {prompt_len} tok]\n"
                        f" "
                        f"{_html.escape(text)}
" ) if was_training: self._model.train() if self._display_handle is not None: from IPython.display import HTML self._display_handle.update(HTML("\n".join(parts))) except Exception as exc: _logger.debug("SFT sample generation failed: %s", exc)