File size: 7,425 Bytes
9e64e71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
#!/usr/bin/env python3
"""Inspect SFT training data — stats from JSON, or render the actual
model input using the tokenizer (requires transformers).

Usage:
    uv run python scripts/inspect_sft_data.py          # stats
    uv run python scripts/inspect_sft_data.py --render  # render + save
    uv run python scripts/inspect_sft_data.py --render -n 5
"""

from __future__ import annotations

import argparse
import json
import sys
from collections import Counter
from pathlib import Path

# Repository root, resolved relative to this script's location (scripts/..).
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Default input: SFT trajectories JSON (see "Run: ... generate_sft_data.py" hint in main()).
DEFAULT_PATH = PROJECT_ROOT / "data" / "sft" / "sft_trajectories.json"
# Default output for --render: plain-text preview of tokenizer-rendered examples.
RENDER_PATH = PROJECT_ROOT / "data" / "sft" / "sft_rendered.txt"


def compute_stats(data: list[dict]) -> str:
    """Dataset statistics from raw JSON (no tokenizer needed).

    Args:
        data: List of trajectory dicts, each with a "messages" list of chat
            messages; assistant messages may carry a "tool_calls" list.

    Returns:
        A multi-line, human-readable statistics report.
    """
    lines: list[str] = []

    tool_counts: Counter[str] = Counter()
    msg_counts: list[int] = []
    tables_per_q: list[int] = []
    n_with_query = 0
    n_with_answer = 0

    for ex in data:
        msgs = ex["messages"]
        msg_counts.append(len(msgs))
        n_describe = 0
        # All tool names used anywhere in this trajectory.
        names_in_ex: set[str] = set()
        for m in msgs:
            if m["role"] == "assistant" and "tool_calls" in m:
                for tc in m["tool_calls"]:
                    # Tolerate both OpenAI-style {"function": {"name": ...}}
                    # and flat {"name": ...} tool-call dicts.
                    fn = tc.get("function", tc)
                    name = fn["name"]
                    tool_counts[name] += 1
                    names_in_ex.add(name)
                    if name == "describe":
                        n_describe += 1
        tables_per_q.append(n_describe)
        # Fix: the previous implementation only inspected the *first* tool
        # call of each assistant message, undercounting trajectories whose
        # query/answer call was not in first position (and could raise
        # IndexError on an empty tool_calls list).
        if "query" in names_in_ex:
            n_with_query += 1
        if "answer" in names_in_ex:
            n_with_answer += 1

    lines.append(f"Trajectories: {len(data)}")
    if msg_counts:  # guard: min()/max() would raise on an empty dataset
        lines.append(
            f"Messages per trajectory: min={min(msg_counts)}, "
            f"max={max(msg_counts)}, avg={sum(msg_counts) / len(msg_counts):.1f}"
        )
    lines.append("")
    lines.append("Assistant tool calls:")
    for name in ["describe", "query", "answer", "sample"]:
        if tool_counts[name]:
            lines.append(f"  {name}: {tool_counts[name]}")
    lines.append(f"  total: {sum(tool_counts.values())}")
    lines.append("")

    tbl_dist = Counter(tables_per_q)
    lines.append("Tables described per question:")
    for k in sorted(tbl_dist):
        lines.append(f"  {k} table(s): {tbl_dist[k]} questions")

    lines.append("")
    lines.append(f"Trajectories with query:  {n_with_query}/{len(data)}")
    lines.append(f"Trajectories with answer: {n_with_answer}/{len(data)}")

    return "\n".join(lines)


def render_examples(
    data: list[dict],
    model_name: str,
    n: int | None = None,
    output_path: Path = RENDER_PATH,
) -> None:
    """Render SFT examples through the actual tokenizer and save to file.

    This produces the exact text the model will train on —
    same apply_chat_template call, same template patch, same tool
    definitions. The output file is the ground truth for inspection.

    Args:
        data: Trajectory dicts with a "messages" list and optional "tools".
        model_name: Hugging Face model id whose tokenizer/chat template to use.
        n: Render only the first n examples. None renders all; note n=0 also
            renders all because of the truthiness check below.
        output_path: Where the rendered plain-text preview is written.
    """
    # Deferred import: the stats-only path works without transformers installed.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    # Apply the same Qwen3 template patch as the training notebook.
    # Without this, assistant_only_loss won't work and the rendered
    # output won't match what training sees.
    tmpl = tokenizer.chat_template
    if "{% generation %}" not in tmpl:
        # Anchor strings inside the Jinja chat template: the start of the
        # assistant branch and the transition into the tool branch. The
        # generation tags are wrapped around everything between them.
        _ASST_START = '{%- elif message.role == "assistant" %}'
        _ASST_END = "{{- '<|im_end|>\\n' }}\n    {%- elif message.role == \"tool\" %}"
        patched = tmpl.replace(
            _ASST_START,
            _ASST_START + "\n        {% generation %}",
        ).replace(
            _ASST_END,
            "{% endgeneration %}" + _ASST_END,
        )
        # Only adopt the patch if both anchors were actually found; otherwise
        # the template is left untouched (str.replace is a no-op on a miss).
        if "{% generation %}" in patched:
            tokenizer.chat_template = patched
            print("Template patched with {% generation %} tags")

    examples = data[:n] if n else data
    rendered_parts: list[str] = []
    total_tokens = 0
    total_asst_tokens = 0

    for i, ex in enumerate(examples):
        msgs = ex["messages"]
        tools = ex.get("tools")

        # Render text — same call as TRL's SFTTrainer.tokenize_fn
        text = tokenizer.apply_chat_template(
            msgs,
            tools=tools,
            tokenize=False,
        )

        # Tokenize with mask — same call as TRL with assistant_only_loss
        tokenized = tokenizer.apply_chat_template(
            msgs,
            tools=tools,
            tokenize=True,
            return_dict=True,
            return_assistant_tokens_mask=True,
        )
        n_tokens = len(tokenized["input_ids"])
        # NOTE(review): assistant_masks presumably marks tokens inside the
        # {% generation %} blocks and is empty without the patch above —
        # confirm against the transformers chat-template docs.
        mask = tokenized.get("assistant_masks", [])
        n_asst = sum(mask) if mask else 0
        total_tokens += n_tokens
        total_asst_tokens += n_asst

        header = (
            f"{'=' * 70}\n"
            f"Example {i} | {n_tokens} tokens | "
            f"{n_asst} assistant tokens ({n_asst / n_tokens:.0%} of sequence)\n"
            f"{'=' * 70}"
        )
        rendered_parts.append(f"{header}\n{text}")

    # Summary header
    summary = (
        f"SFT Training Data Preview\n"
        f"Model: {model_name}\n"
        f"Examples: {len(examples)}\n"
        f"Total tokens: {total_tokens} | "
        f"Assistant tokens: {total_asst_tokens} "
        f"({total_asst_tokens / total_tokens:.0%})\n"
        f"Avg tokens/example: {total_tokens / len(examples):.0f}\n"
    )

    full_output = summary + "\n" + "\n\n".join(rendered_parts) + "\n"

    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(full_output)
    print(f"Rendered {len(examples)} examples to {output_path}")
    print(
        f"Total: {total_tokens} tokens, {total_asst_tokens} assistant tokens "
        f"({total_asst_tokens / total_tokens:.0%})"
    )


def main() -> None:
    """CLI entry point: always print stats; optionally render via tokenizer."""
    ap = argparse.ArgumentParser(description="Inspect SFT training data")
    ap.add_argument(
        "path",
        nargs="?",
        default=str(DEFAULT_PATH),
        help="Path to sft_trajectories.json",
    )
    ap.add_argument(
        "--stats",
        action="store_true",
        help="Show stats only (default if --render not given)",
    )
    ap.add_argument(
        "--render",
        action="store_true",
        help="Render through tokenizer and save to file",
    )
    ap.add_argument(
        "--model",
        default="Qwen/Qwen3-1.7B",
        help="Model name for tokenizer (default: Qwen/Qwen3-1.7B)",
    )
    ap.add_argument(
        "-n",
        "--num",
        type=int,
        default=None,
        help="Number of examples to render (default: all)",
    )
    ap.add_argument(
        "-o", "--output", default=str(RENDER_PATH), help="Output path for rendered data"
    )
    opts = ap.parse_args()

    data_file = Path(opts.path)
    if not data_file.exists():
        print(f"File not found: {data_file}")
        print("Run: uv run python scripts/generate_sft_data.py")
        sys.exit(1)

    data = json.loads(data_file.read_text())

    print(compute_stats(data))

    if not opts.render:
        return
    print()
    render_examples(data, opts.model, opts.num, Path(opts.output))


if __name__ == "__main__":
    main()