File size: 3,887 Bytes
6848cb6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
"""SFT dataset with proper assistant-only loss masking and safe packing.

Each example is a chat-formatted string with `<|system|> <|user|> <|assistant|> <|end|>`
turn delimiters.  We tokenize on the fly (corpus is small, ~25M tokens) and build a
mask=1 only on tokens that are part of an assistant response (everything between
`<|assistant|>` and the next `<|end|>`).

Pre-training-style packing without cross-example contamination would group multiple
short examples into a fixed-length window, using `cu_seqlens`-style document boundaries
implemented via a per-document attention reset.  That is not done here; we keep it
simple and pad/truncate each example to `block_size`.  Throughput is still high
(>40k tok/s on L4) for this volume.
"""

import json
import random
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import Dataset


def _read_jsonl(path):
    """Load non-empty text records from a JSON Lines file.

    Each line is expected to hold one JSON object with a string ``"text"`` field
    and an optional ``"source"`` field.  Unusable lines are skipped instead of
    raising: blank lines, lines that are not valid JSON, JSON values that are not
    objects (lists, strings, numbers, ...), and objects whose ``"text"`` is
    missing, empty, or not a string.

    Args:
        path: Path of the ``.jsonl`` file to read.

    Returns:
        A list of ``{"text": ..., "source": ...}`` dicts in file order.  ``source``
        falls back to the file name without its extension when the record has no
        ``"source"`` key.
    """
    default_source = Path(path).stem
    out = []
    # errors="replace" keeps a badly encoded byte from aborting the whole load.
    with open(path, "r", encoding="utf-8", errors="replace") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError:
                continue
            # A line can be valid JSON without being an object (e.g. `[1, 2]`);
            # such values have no `.get`, so skip them like any other bad line.
            if not isinstance(obj, dict):
                continue
            t = obj.get("text")
            # Only non-empty strings can be tokenized later on.
            if not isinstance(t, str) or not t:
                continue
            out.append({"text": t, "source": obj.get("source", default_source)})
    return out


def build_assistant_mask(token_ids, assistant_id, end_id):
    """Return a 0/1 int64 array flagging the tokens of assistant responses.

    A span opens on the token that follows `<|assistant|>` and closes on the
    next `<|end|>`.  The opening tag is left at 0, while the closing `<|end|>`
    is flagged with 1 so that emitting the delimiter is part of the training
    signal.  A span that never sees `<|end|>` stays open to the end of the
    sequence.
    """
    flags = np.zeros(len(token_ids), dtype=np.int64)
    span_open = False
    for pos, tok in enumerate(token_ids):
        if span_open:
            flags[pos] = 1
            # Stay inside the span until its closing delimiter has been flagged.
            span_open = tok != end_id
        elif tok == assistant_id:
            # The tag itself is not a target; the span starts on the next token.
            span_open = True
    return flags


class SFTDataset(Dataset):
    """Chat SFT dataset yielding fixed-length samples with assistant-only targets.

    Records are read from JSONL files at construction time and tokenized lazily
    in ``__getitem__``.  Every item is truncated or padded to ``block_size + 1``
    tokens and then shifted by one to form inputs and next-token targets.
    """

    def __init__(self, jsonl_paths, sp, block_size, assistant_token="<|assistant|>",
                 end_token="<|end|>", pad_id=0, seed=42, mix_weights=None):
        """Load, optionally subsample, and shuffle the examples.

        Args:
            jsonl_paths: Iterable of JSONL file paths to load.
            sp: Tokenizer exposing ``piece_to_id(piece)`` and
                ``encode(text, out_type=int)``; presumably a SentencePiece
                processor (confirm against the caller).  ``unk_id()`` is used
                when available.
            block_size: Number of input/target positions per item.
            assistant_token: Piece that opens an assistant response.
            end_token: Piece that closes a turn.
            pad_id: Token id used to right-pad short examples.
            seed: Seed for the subsampling and shuffling RNG.
            mix_weights: Optional mapping from file name (``Path(p).name``) to a
                keep fraction.  Files not listed use 1.0.  Weights above 1.0 do
                not upsample: the sample size is capped at the file's size.

        Raises:
            ValueError: If either special token is not known to the tokenizer.
        """
        self.sp = sp
        self.block_size = block_size
        self.pad_id = pad_id
        self.assistant_id = sp.piece_to_id(assistant_token)
        self.end_id = sp.piece_to_id(end_token)

        # SentencePiece maps unknown pieces to the <unk> id rather than to a
        # negative value, so a `< 0` test alone never detects a missing special
        # token.  Treat the tokenizer's unk id as "missing" as well.
        unk_fn = getattr(sp, "unk_id", None)
        unk_id = unk_fn() if callable(unk_fn) else None
        special_ids = (self.assistant_id, self.end_id)
        if any(i < 0 or (unk_id is not None and i == unk_id) for i in special_ids):
            raise ValueError(f"missing special tokens in tokenizer: "
                             f"{assistant_token}={self.assistant_id} {end_token}={self.end_id}")

        self.examples = []
        rng = random.Random(seed)
        weights = mix_weights or {}
        for p in jsonl_paths:
            recs = _read_jsonl(p)
            w = weights.get(Path(p).name, 1.0)
            if w != 1.0:
                k = int(len(recs) * w)
                # min() caps the sample at the population size (no upsampling).
                recs = rng.sample(recs, min(k, len(recs)))
            self.examples.extend(recs)
            print(f"  [sft] {p}: {len(recs):,} ex (w={w})")
        rng.shuffle(self.examples)
        print(f"[sft] total: {len(self.examples):,} examples")

    def __len__(self):
        """Return the number of loaded examples."""
        return len(self.examples)

    def __getitem__(self, idx):
        """Tokenize example ``idx`` and return ``(x, y, m)`` int64 tensors.

        Returns:
            x: Input ids, length ``block_size``.
            y: Next-token targets, length ``block_size``; every position that is
               not part of an assistant response (prompt and padding) is -100.
            m: 0/1 mask aligned with ``y``; 1 where the target is an assistant
               token.
        """
        text = self.examples[idx]["text"]
        # One extra token so inputs and targets can be shifted by one position.
        window = self.block_size + 1
        ids = self.sp.encode(text, out_type=int)[:window]
        mask = build_assistant_mask(ids, self.assistant_id, self.end_id)

        need = window - len(ids)
        if need > 0:
            ids = ids + [self.pad_id] * need
            mask = np.concatenate([mask, np.zeros(need, dtype=np.int64)])

        ids = np.asarray(ids, dtype=np.int64)
        x = torch.from_numpy(ids[:-1])
        # Copy the targets: the -100 writes below must not leak into `x`, which
        # shares memory with `ids`.
        y = torch.from_numpy(ids[1:].copy())
        m = torch.from_numpy(mask[1:].copy())  # mask aligned with targets
        # Hide every non-assistant target (prompt and padding) from the loss;
        # -100 is the default ignore_index of PyTorch's cross-entropy.
        y[m == 0] = -100
        return x, y, m