# sql_env/scripts/inspect_sft_data.py — uploaded via huggingface_hub (commit 9e64e71)
#!/usr/bin/env python3
"""Inspect SFT training data — stats from JSON, or render the actual
model input using the tokenizer (requires transformers).
Usage:
uv run python scripts/inspect_sft_data.py # stats
uv run python scripts/inspect_sft_data.py --render # render + save
uv run python scripts/inspect_sft_data.py --render -n 5
"""
from __future__ import annotations
import argparse
import json
import sys
from collections import Counter
from pathlib import Path
# Repository root, resolved relative to this script (scripts/ -> project root).
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Default input: trajectories produced by scripts/generate_sft_data.py.
DEFAULT_PATH = PROJECT_ROOT / "data" / "sft" / "sft_trajectories.json"
# Default output of --render: human-readable preview of the tokenized data.
RENDER_PATH = PROJECT_ROOT / "data" / "sft" / "sft_rendered.txt"
def compute_stats(data: list[dict]) -> str:
    """Dataset statistics from raw JSON (no tokenizer needed).

    Args:
        data: Trajectory dicts, each with a "messages" list of chat
            messages; assistant messages may carry a "tool_calls" list
            (OpenAI-style, name nested under "function", or flat).

    Returns:
        A human-readable multi-line summary string.
    """
    if not data:
        # Guard: min()/max() and the averages below would raise on [].
        return "Trajectories: 0 (empty dataset)"

    lines: list[str] = []
    tool_counts: Counter[str] = Counter()
    msg_counts: list[int] = []
    tables_per_q: list[int] = []

    for ex in data:
        msgs = ex["messages"]
        msg_counts.append(len(msgs))
        n_describe = 0
        for m in msgs:
            if m["role"] == "assistant" and "tool_calls" in m:
                for tc in m["tool_calls"]:
                    # OpenAI-style entries nest the name under "function";
                    # fall back to the call dict itself otherwise.
                    fn = tc.get("function", tc)
                    tool_counts[fn["name"]] += 1
                    if fn["name"] == "describe":
                        n_describe += 1
        tables_per_q.append(n_describe)

    lines.append(f"Trajectories: {len(data)}")
    lines.append(
        f"Messages per trajectory: min={min(msg_counts)}, "
        f"max={max(msg_counts)}, avg={sum(msg_counts) / len(msg_counts):.1f}"
    )
    lines.append("")
    lines.append("Assistant tool calls:")
    for name in ["describe", "query", "answer", "sample"]:
        if tool_counts[name]:
            lines.append(f" {name}: {tool_counts[name]}")
    lines.append(f" total: {sum(tool_counts.values())}")
    lines.append("")

    tbl_dist = Counter(tables_per_q)
    lines.append("Tables described per question:")
    for k in sorted(tbl_dist):
        lines.append(f" {k} table(s): {tbl_dist[k]} questions")

    def _uses_tool(ex: dict, tool: str) -> bool:
        """True if any assistant tool call in *ex* invokes *tool*.

        Fixes two bugs in the previous inline version: it only inspected
        tool_calls[0] (missing calls later in the same message), and it
        raised IndexError on an empty tool_calls list.
        """
        return any(
            tc.get("function", tc).get("name") == tool
            for m in ex["messages"]
            if m["role"] == "assistant"
            for tc in m.get("tool_calls", [])
        )

    n_with_query = sum(1 for ex in data if _uses_tool(ex, "query"))
    n_with_answer = sum(1 for ex in data if _uses_tool(ex, "answer"))
    lines.append("")
    lines.append(f"Trajectories with query: {n_with_query}/{len(data)}")
    lines.append(f"Trajectories with answer: {n_with_answer}/{len(data)}")
    return "\n".join(lines)
def render_examples(
    data: list[dict],
    model_name: str,
    n: int | None = None,
    output_path: Path = RENDER_PATH,
) -> None:
    """Render SFT examples through the actual tokenizer and save to file.

    This produces the exact text the model will train on —
    same apply_chat_template call, same template patch, same tool
    definitions. The output file is the ground truth for inspection.

    Args:
        data: Trajectories, each with "messages" and optionally "tools".
        model_name: HF model id whose tokenizer/chat template to load.
        n: Render only the first n examples (None or 0 renders all).
        output_path: File the rendered preview is written to.
    """
    from transformers import AutoTokenizer

    if not data:
        # Guard: the per-example averages below would divide by zero.
        print("No examples to render.")
        return

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    # Apply the same Qwen3 template patch as the training notebook.
    # Without this, assistant_only_loss won't work and the rendered
    # output won't match what training sees.
    tmpl = tokenizer.chat_template
    if "{% generation %}" not in tmpl:
        _ASST_START = '{%- elif message.role == "assistant" %}'
        _ASST_END = "{{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}"
        patched = tmpl.replace(
            _ASST_START,
            _ASST_START + "\n {% generation %}",
        ).replace(
            _ASST_END,
            "{% endgeneration %}" + _ASST_END,
        )
        if "{% generation %}" in patched:
            tokenizer.chat_template = patched
            print("Template patched with {% generation %} tags")
        else:
            # The anchor strings were not found in this model's template.
            # Previously this failed silently; warn so the user knows the
            # rendered output will NOT match assistant_only_loss training.
            print(
                "WARNING: chat template could not be patched with "
                "{% generation %} tags; assistant token counts will be 0"
            )

    examples = data[:n] if n else data
    rendered_parts: list[str] = []
    total_tokens = 0
    total_asst_tokens = 0
    for i, ex in enumerate(examples):
        msgs = ex["messages"]
        tools = ex.get("tools")
        # Render text — same call as TRL's SFTTrainer.tokenize_fn
        text = tokenizer.apply_chat_template(
            msgs,
            tools=tools,
            tokenize=False,
        )
        # Tokenize with mask — same call as TRL with assistant_only_loss
        tokenized = tokenizer.apply_chat_template(
            msgs,
            tools=tools,
            tokenize=True,
            return_dict=True,
            return_assistant_tokens_mask=True,
        )
        n_tokens = len(tokenized["input_ids"])
        mask = tokenized.get("assistant_masks", [])
        n_asst = sum(mask) if mask else 0
        total_tokens += n_tokens
        total_asst_tokens += n_asst
        # max(..., 1) avoids ZeroDivisionError on a degenerate empty render.
        header = (
            f"{'=' * 70}\n"
            f"Example {i} | {n_tokens} tokens | "
            f"{n_asst} assistant tokens ({n_asst / max(n_tokens, 1):.0%} of sequence)\n"
            f"{'=' * 70}"
        )
        rendered_parts.append(f"{header}\n{text}")

    # Summary header
    summary = (
        f"SFT Training Data Preview\n"
        f"Model: {model_name}\n"
        f"Examples: {len(examples)}\n"
        f"Total tokens: {total_tokens} | "
        f"Assistant tokens: {total_asst_tokens} "
        f"({total_asst_tokens / max(total_tokens, 1):.0%})\n"
        f"Avg tokens/example: {total_tokens / len(examples):.0f}\n"
    )
    full_output = summary + "\n" + "\n\n".join(rendered_parts) + "\n"
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(full_output)
    print(f"Rendered {len(examples)} examples to {output_path}")
    print(
        f"Total: {total_tokens} tokens, {total_asst_tokens} assistant tokens "
        f"({total_asst_tokens / max(total_tokens, 1):.0%})"
    )
def main() -> None:
    """CLI entry point: print dataset stats and optionally render examples.

    Exits with status 1 when the input JSON file does not exist.
    """
    parser = argparse.ArgumentParser(description="Inspect SFT training data")
    parser.add_argument(
        "path",
        nargs="?",
        default=str(DEFAULT_PATH),
        help="Path to sft_trajectories.json",
    )
    parser.add_argument(
        "--stats",
        action="store_true",
        help="Show stats only (default if --render not given)",
    )
    parser.add_argument(
        "--render",
        action="store_true",
        help="Render through tokenizer and save to file",
    )
    parser.add_argument(
        "--model",
        default="Qwen/Qwen3-1.7B",
        help="Model name for tokenizer (default: Qwen/Qwen3-1.7B)",
    )
    parser.add_argument(
        "-n",
        "--num",
        type=int,
        default=None,
        help="Number of examples to render (default: all)",
    )
    parser.add_argument(
        "-o", "--output", default=str(RENDER_PATH), help="Output path for rendered data"
    )
    args = parser.parse_args()

    path = Path(args.path)
    if not path.exists():
        print(f"File not found: {path}")
        print("Run: uv run python scripts/generate_sft_data.py")
        sys.exit(1)

    # JSON is UTF-8 by spec; don't depend on the platform default encoding
    # (e.g. cp1252 on Windows would mangle non-ASCII content).
    with open(path, encoding="utf-8") as f:
        data = json.load(f)

    # Stats are always printed; --stats is accepted for explicitness.
    print(compute_stats(data))
    if args.render:
        print()
        render_examples(data, args.model, args.num, Path(args.output))
# Script entry guard: only run the CLI when executed directly, not on import.
if __name__ == "__main__":
    main()