| |
| """Inspect SFT training data — stats from JSON, or render the actual |
| model input using the tokenizer (requires transformers). |
| |
| Usage: |
| uv run python scripts/inspect_sft_data.py # stats |
| uv run python scripts/inspect_sft_data.py --render # render + save |
| uv run python scripts/inspect_sft_data.py --render -n 5 |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import sys |
| from collections import Counter |
| from pathlib import Path |
|
|
# Repository root, resolved from this script's location (scripts/ -> root).
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Default input: trajectories produced by scripts/generate_sft_data.py.
DEFAULT_PATH = PROJECT_ROOT / "data" / "sft" / "sft_trajectories.json"
# Default output for --render: tokenizer-rendered training text.
RENDER_PATH = PROJECT_ROOT / "data" / "sft" / "sft_rendered.txt"
|
|
|
|
def compute_stats(data: list[dict]) -> str:
    """Summarize the SFT dataset from raw JSON (no tokenizer needed).

    Reports trajectory/message counts, assistant tool-call frequencies,
    the distribution of `describe` calls per question, and how many
    trajectories contain at least one `query` / `answer` call.

    Args:
        data: List of trajectory dicts, each with a "messages" list of
            chat messages (assistant messages may carry "tool_calls").

    Returns:
        A multi-line, human-readable stats report.
    """
    # Guard: min()/max()/division below would crash on an empty dataset.
    if not data:
        return "Trajectories: 0"

    tool_counts: Counter[str] = Counter()
    msg_counts: list[int] = []
    tables_per_q: list[int] = []
    # Tool names seen per trajectory, across ALL tool calls in ALL
    # assistant messages. The previous implementation re-scanned the data
    # looking only at tool_calls[0] of each message (missing later calls)
    # and raised IndexError on an empty tool_calls list.
    tools_per_traj: list[set[str]] = []

    for ex in data:
        msgs = ex["messages"]
        msg_counts.append(len(msgs))
        n_describe = 0
        seen: set[str] = set()
        for m in msgs:
            if m["role"] == "assistant" and "tool_calls" in m:
                for tc in m["tool_calls"]:
                    # Tool calls may nest the spec under "function"
                    # (OpenAI format) or be flat; handle both.
                    fn = tc.get("function", tc)
                    name = fn["name"]
                    tool_counts[name] += 1
                    seen.add(name)
                    if name == "describe":
                        n_describe += 1
        tables_per_q.append(n_describe)
        tools_per_traj.append(seen)

    lines = [
        f"Trajectories: {len(data)}",
        f"Messages per trajectory: min={min(msg_counts)}, "
        f"max={max(msg_counts)}, avg={sum(msg_counts) / len(msg_counts):.1f}",
        "",
        "Assistant tool calls:",
    ]
    for name in ["describe", "query", "answer", "sample"]:
        if tool_counts[name]:
            lines.append(f"  {name}: {tool_counts[name]}")
    lines.append(f"  total: {sum(tool_counts.values())}")
    lines.append("")

    tbl_dist = Counter(tables_per_q)
    lines.append("Tables described per question:")
    for k in sorted(tbl_dist):
        lines.append(f"  {k} table(s): {tbl_dist[k]} questions")

    # Coverage: trajectories containing at least one query / answer call.
    n_with_query = sum(1 for seen in tools_per_traj if "query" in seen)
    n_with_answer = sum(1 for seen in tools_per_traj if "answer" in seen)
    lines.append("")
    lines.append(f"Trajectories with query: {n_with_query}/{len(data)}")
    lines.append(f"Trajectories with answer: {n_with_answer}/{len(data)}")

    return "\n".join(lines)
|
|
|
|
def render_examples(
    data: list[dict],
    model_name: str,
    n: int | None = None,
    output_path: Path = RENDER_PATH,
) -> None:
    """Render SFT examples through the actual tokenizer and save to file.

    This produces the exact text the model will train on —
    same apply_chat_template call, same template patch, same tool
    definitions. The output file is the ground truth for inspection.

    Args:
        data: Trajectory dicts, each with "messages" and optional "tools".
        model_name: Hugging Face model id whose tokenizer to load.
        n: Render only the first ``n`` examples; ``None`` renders all.
        output_path: Destination file for the rendered text.
    """
    # Imported lazily so the stats-only path works without transformers.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    # Patch {% generation %} ... {% endgeneration %} around the assistant
    # turn so return_assistant_tokens_mask can flag assistant tokens.
    # NOTE(review): the anchor strings target a Qwen-style chat template
    # (exact whitespace must match) — if they don't match, the template is
    # left untouched and the assistant mask comes back empty.
    tmpl = tokenizer.chat_template
    if "{% generation %}" not in tmpl:
        _ASST_START = '{%- elif message.role == "assistant" %}'
        _ASST_END = "{{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}"
        patched = tmpl.replace(
            _ASST_START,
            _ASST_START + "\n {% generation %}",
        ).replace(
            _ASST_END,
            "{% endgeneration %}" + _ASST_END,
        )
        if "{% generation %}" in patched:
            tokenizer.chat_template = patched
            print("Template patched with {% generation %} tags")

    # Fix: `data[:n] if n else data` treated n=0 as "all"; distinguish
    # 0 (render nothing) from None (render everything).
    examples = data if n is None else data[:n]
    if not examples:
        # Guard: the summary below divides by len(examples)/total_tokens.
        print("No examples to render.")
        return

    rendered_parts: list[str] = []
    total_tokens = 0
    total_asst_tokens = 0

    for i, ex in enumerate(examples):
        msgs = ex["messages"]
        tools = ex.get("tools")

        # Rendered text — exactly what the trainer consumes.
        text = tokenizer.apply_chat_template(
            msgs,
            tools=tools,
            tokenize=False,
        )

        # Tokenize separately to count total / assistant-masked tokens.
        tokenized = tokenizer.apply_chat_template(
            msgs,
            tools=tools,
            tokenize=True,
            return_dict=True,
            return_assistant_tokens_mask=True,
        )
        n_tokens = len(tokenized["input_ids"])
        mask = tokenized.get("assistant_masks", [])
        n_asst = sum(mask) if mask else 0
        total_tokens += n_tokens
        total_asst_tokens += n_asst

        # Guard against a pathological zero-length tokenization.
        pct = n_asst / n_tokens if n_tokens else 0.0
        header = (
            f"{'=' * 70}\n"
            f"Example {i} | {n_tokens} tokens | "
            f"{n_asst} assistant tokens ({pct:.0%} of sequence)\n"
            f"{'=' * 70}"
        )
        rendered_parts.append(f"{header}\n{text}")

    asst_pct = total_asst_tokens / total_tokens if total_tokens else 0.0
    summary = (
        f"SFT Training Data Preview\n"
        f"Model: {model_name}\n"
        f"Examples: {len(examples)}\n"
        f"Total tokens: {total_tokens} | "
        f"Assistant tokens: {total_asst_tokens} "
        f"({asst_pct:.0%})\n"
        f"Avg tokens/example: {total_tokens / len(examples):.0f}\n"
    )

    full_output = summary + "\n" + "\n\n".join(rendered_parts) + "\n"

    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(full_output)
    print(f"Rendered {len(examples)} examples to {output_path}")
    print(
        f"Total: {total_tokens} tokens, {total_asst_tokens} assistant tokens "
        f"({asst_pct:.0%})"
    )
|
|
|
|
def main() -> None:
    """Parse CLI args, print dataset stats, and optionally render examples."""
    ap = argparse.ArgumentParser(description="Inspect SFT training data")
    ap.add_argument(
        "path",
        nargs="?",
        default=str(DEFAULT_PATH),
        help="Path to sft_trajectories.json",
    )
    ap.add_argument(
        "--stats",
        action="store_true",
        help="Show stats only (default if --render not given)",
    )
    ap.add_argument(
        "--render",
        action="store_true",
        help="Render through tokenizer and save to file",
    )
    ap.add_argument(
        "--model",
        default="Qwen/Qwen3-1.7B",
        help="Model name for tokenizer (default: Qwen/Qwen3-1.7B)",
    )
    ap.add_argument(
        "-n",
        "--num",
        type=int,
        default=None,
        help="Number of examples to render (default: all)",
    )
    ap.add_argument(
        "-o", "--output", default=str(RENDER_PATH), help="Output path for rendered data"
    )
    ns = ap.parse_args()

    # Bail out early with a hint when the dataset hasn't been generated yet.
    src = Path(ns.path)
    if not src.exists():
        print(f"File not found: {src}")
        print("Run: uv run python scripts/generate_sft_data.py")
        sys.exit(1)

    data = json.loads(src.read_text())

    # Stats always print; --render additionally writes the rendered file.
    print(compute_stats(data))

    if ns.render:
        print()
        render_examples(data, ns.model, ns.num, Path(ns.output))
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|