"""Live training visualization callbacks for GRPO/SFT notebooks."""
from __future__ import annotations
import html as _html
import logging
from typing import Any
_logger = logging.getLogger(__name__)
try:
from transformers import TrainerCallback
except ImportError:
TrainerCallback = object # type: ignore[assignment,misc]
class LiveVisualizationCallback(TrainerCallback):
    """TrainerCallback that plots reward and loss in place during training.

    Updates a single plot via an IPython display handle without clearing
    the cell output, so the chart refreshes in place on every log event.
    All display/plotting work is best-effort: failures are swallowed (and
    debug-logged) so visualization problems never interrupt training.
    """

    def __init__(self, **kwargs: Any) -> None:
        # Accept and ignore extra kwargs for backward compat.
        _ = kwargs
        self.log_steps: list[int] = []      # steps at which any metric was logged
        self.log_rewards: list[float] = []  # reward series (GRPO runs)
        self.log_losses: list[float] = []   # loss series
        # Per-series step lists: reward and loss may be logged at different
        # frequencies, so plotting both against a single shared step list
        # (as `log_steps[:len(series)]` did) misaligns the x axis.
        self._reward_steps: list[int] = []
        self._loss_steps: list[int] = []
        self._plot_handle = None  # IPython DisplayHandle, set in on_train_begin

    def on_train_begin(self, args, state, control, **kwargs):
        """Create the display handle that the plot will later update in place."""
        try:
            from IPython.display import HTML, display

            self._plot_handle = display(
                HTML("Waiting for first log..."),
                display_id="viz_plot",
            )
        except Exception:
            # Not running under IPython/Jupyter; visualization becomes a no-op.
            pass

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Record reward/loss values from ``logs`` and refresh the live plot."""
        if not logs:
            return
        step = state.global_step
        # Find reward (prefer a key containing both "reward" and "mean";
        # sorted() makes the pick deterministic when several candidates exist).
        reward = None
        for key in sorted(logs.keys()):
            if "reward" in key and "mean" in key:
                reward = logs[key]
                break
        if reward is None:
            for key in ("reward", "rewards/mean"):
                if key in logs:
                    reward = logs[key]
                    break
        loss = logs.get("loss")
        has_data = False
        if reward is not None:
            self.log_rewards.append(float(reward))
            self._reward_steps.append(step)
            has_data = True
        if loss is not None:
            self.log_losses.append(float(loss))
            self._loss_steps.append(step)
            has_data = True
        if has_data:
            self.log_steps.append(step)
            self._update_plot()

    def _update_plot(self) -> None:
        """Render the reward/loss series to a PNG and push it to the handle."""
        if self._plot_handle is None:
            return
        try:
            import base64
            import io

            import matplotlib.pyplot as plt
            from IPython.display import HTML

            fig, ax = plt.subplots(1, 1, figsize=(8, 3.5))
            if self.log_rewards:
                ax.plot(
                    self._reward_steps,
                    self.log_rewards,
                    "b-o",
                    markersize=3,
                    label="Reward",
                )
                ax.set_ylabel("Reward")
                ax.legend(loc="upper left")
            if self.log_losses:
                if self.log_rewards:
                    # GRPO: reward owns the primary axis; loss goes on a twin
                    # y-axis so both trends are visible on one chart.
                    ax2 = ax.twinx()
                    ax2.plot(
                        self._loss_steps,
                        self.log_losses,
                        "r-",
                        alpha=0.4,
                        label="Loss",
                    )
                    ax2.set_ylabel("Loss", color="r", alpha=0.6)
                    ax2.legend(loc="upper right")
                else:
                    # SFT-only: no reward series, loss takes the primary axis.
                    ax.plot(
                        self._loss_steps,
                        self.log_losses,
                        "r-o",
                        markersize=3,
                        label="Loss",
                    )
                    ax.set_ylabel("Loss")
                    ax.legend(loc="upper right")
            ax.set_xlabel("Step")
            latest = self.log_steps[-1] if self.log_steps else 0
            ax.set_title(f"Training Progress (step {latest})")
            ax.grid(True, alpha=0.3)
            plt.tight_layout()
            buf = io.BytesIO()
            fig.savefig(buf, format="png", dpi=100, bbox_inches="tight")
            plt.close(fig)
            buf.seek(0)
            img = base64.b64encode(buf.read()).decode("utf-8")
            # Embed the PNG inline; DisplayHandle.update() redraws in place
            # instead of appending new output to the notebook cell.
            # (Original line here was corrupted — it contained spilled text
            # referencing an undefined `last_role` and never used `img`.)
            self._plot_handle.update(
                HTML(f'<img src="data:image/png;base64,{img}"/>')
            )
        except Exception as exc:
            _logger.debug("Plot update failed: %s", exc)
class SFTMonitorCallback(TrainerCallback):
"""Show sample completions and optional eval accuracy during SFT.
Every ``eval_every_steps`` training steps the callback generates
first-turn completions for a handful of prompts so the user can
watch the model learn tool-calling patterns in real time.
"""
def __init__(
    self,
    tokenizer: Any,
    sample_prompts: list[list[dict[str, str]]],
    *,
    tools: list[dict] | None = None,
    train_dataset: Any = None,
    eval_every_steps: int = 50,
    max_new_tokens: int = 200,
) -> None:
    """Configure the SFT monitor.

    Args:
        tokenizer: Tokenizer used to render chat templates and encode text.
        sample_prompts: Chat-format prompts; only the first three are kept.
        tools: Optional tool schemas passed to the chat template.
        train_dataset: Optional training dataset, stored as-is.
        eval_every_steps: Training-step interval between sample generations.
        max_new_tokens: Generation budget for each sample completion.
    """
    # Internal state filled in later by trainer callbacks.
    self._model: Any = None
    self._display_handle: Any = None
    # Keep at most three prompts so periodic sampling stays cheap.
    self.sample_prompts = sample_prompts[:3]
    self.tokenizer = tokenizer
    self.tools = tools
    self.train_dataset = train_dataset
    self.eval_every_steps = eval_every_steps
    self.max_new_tokens = max_new_tokens
# ------------------------------------------------------------------
def on_train_begin(self, args, state, control, model=None, **kwargs):
# Purpose (from the visible code): capture the model reference and display a
# one-time HTML preview of the rendered inference prompt (and, apparently,
# an SFT training example) so the user can sanity-check the chat template.
#
# NOTE(review): this method body is corrupted. Several HTML string literals
# below are truncated/unterminated (syntax errors), and the code references
# names that are never defined in this scope (n_ex_tok, n_turns, rendered_ex,
# question, badge, prompt_len, text, was_training, parts). The tail
# (was_training / _display_handle.update / "SFT sample generation failed")
# looks spliced in from a sibling sample-generation method. Restore this
# method from version control rather than editing it in place.
self._model = model
try:
from IPython.display import HTML, display
# Always use canonical tools (avoid Dataset serialization artifacts)
tpl_tools = self.tools
# 1) Inference prompt — what the model sees at generation time
if self.sample_prompts:
tpl_kwargs: dict[str, Any] = {
"tokenize": False,
"add_generation_prompt": True,
}
if tpl_tools:
tpl_kwargs["tools"] = tpl_tools
preview = self.tokenizer.apply_chat_template(
self.sample_prompts[0],
**tpl_kwargs,
)
n_tok = len(self.tokenizer.encode(preview))
display(
HTML(
# NOTE(review): garbled region begins — unterminated string literals and
# undefined names from here down to the closing paren.
"
"
f"Inference prompt ({n_tok} tok)"
" — system + tools + question, "
"as seen by model during GRPO generation"
"
"
""
f"{_html.escape(preview)}"
f"SFT training example"
f" ({n_ex_tok} tok, {n_turns} asst turn)"
" — history + one assistant tool_call, "
"exactly what the model learns to predict"
"
"
""
f"{_html.escape(rendered_ex)}"
f"
"
f"Q: "
f"{_html.escape(question)}"
f" [{badge}, {prompt_len} tok]\n"
f"→ "
f"{_html.escape(text)}"
)
# NOTE(review): the lines below appear to belong to a different method
# (restoring train mode and updating a display handle after sampling).
if was_training:
self._model.train()
if self._display_handle is not None:
from IPython.display import HTML
self._display_handle.update(HTML("\n".join(parts)))
except Exception as exc:
_logger.debug("SFT sample generation failed: %s", exc)