"""
SeqCond tokenizer — tiktoken cl100k_base with 4 additional special tokens.

Special tokens (the four SeqCond tokens get IDs 100278-100281; cl100k_base's
native special tokens are remapped to 100282-100286):
  <|im_start|>    — marks the start of a chat turn
  <|im_end|>      — marks the end of a chat turn (also used as EOS)
  <|think_start|> — marks the start of chain-of-thought reasoning
  <|think_end|>   — marks the end of chain-of-thought reasoning

Chat template:
  <|im_start|>user
  {prompt}
  <|im_end|><|im_start|>assistant
  <|think_start|>{thinking}<|think_end|>
  {answer}
  <|im_end|>
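
Illustrative usage (a sketch; assumes tiktoken is installed):
    tok = SeqCondTokenizer()
    ids = tok.encode_chat("What is 2+2?")
    ids[-1] == tok.think_start_id  # True: generation begins inside the thinking span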
"""

from typing import Dict, List, Optional, Tuple

from transformers import PreTrainedTokenizer

_SPECIAL_TOKENS = ["<|im_start|>", "<|im_end|>", "<|think_start|>", "<|think_end|>"]
_SPECIAL_TOKEN_IDS = {
    # SeqCond chat tokens.
    "<|im_start|>": 100278,
    "<|im_end|>": 100279,
    "<|think_start|>": 100280,
    "<|think_end|>": 100281,
    # cl100k_base's native special tokens, remapped so they stay encodable.
    "<|endoftext|>": 100282,
    "<|fim_prefix|>": 100283,
    "<|fim_middle|>": 100284,
    "<|fim_suffix|>": 100285,
    "<|endofprompt|>": 100286,
}
_BASE_VOCAB_SIZE = 100256  # number of cl100k_base mergeable ranks (BPE IDs 0-100255)
_VOCAB_SIZE = max(_SPECIAL_TOKEN_IDS.values()) + 1


def _build_tiktoken_enc():
    """Build tiktoken encoding with SeqCond special tokens."""
    try:
        import tiktoken
    except ImportError as e:
        raise ImportError("tiktoken is required: pip install tiktoken") from e

    base = tiktoken.get_encoding("cl100k_base")
    return tiktoken.Encoding(
        name="seqcond",
        pat_str=base._pat_str,
        mergeable_ranks=base._mergeable_ranks,
        special_tokens=_SPECIAL_TOKEN_IDS,
    )


class SeqCondTokenizer(PreTrainedTokenizer):
    """
    Tokenizer for SeqCond models, backed by tiktoken cl100k_base.

    This is a slow tokenizer that wraps tiktoken. Tokens are represented
    internally as their stringified integer IDs (e.g. "42", "100256").
    This avoids building a full vocab dict while remaining compatible with
    HuggingFace's PreTrainedTokenizer interface.

    Requires: pip install tiktoken
    """

    vocab_files_names: Dict[str, str] = {}
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        eos_token: str = "<|im_end|>",
        bos_token: Optional[str] = None,
        unk_token: Optional[str] = None,
        pad_token: str = "<|im_end|>",
        add_bos_token: bool = False,
        **kwargs,
    ):
        self._enc = _build_tiktoken_enc()
        self._special_to_id: Dict[str, int] = dict(_SPECIAL_TOKEN_IDS)
        self._id_to_special: Dict[int, str] = {idx: tok for tok, idx in _SPECIAL_TOKEN_IDS.items()}

        # Register special tokens before calling super().__init__
        kwargs.setdefault(
            "additional_special_tokens",
            [t for t in _SPECIAL_TOKENS if t not in (eos_token, bos_token, unk_token, pad_token)],
        )

        super().__init__(
            eos_token=eos_token,
            bos_token=bos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            add_bos_token=add_bos_token,
            **kwargs,
        )

    @property
    def vocab_size(self) -> int:
        return _VOCAB_SIZE

    # ------------------------------------------------------------------
    # Core token ↔ id mappings
    # ------------------------------------------------------------------

    def _tokenize(self, text: str, **kwargs) -> List[str]:
        """Encode text into a list of token-id strings."""
        ids = self._enc.encode(text, allowed_special="all")
        # Shift non-special BPE IDs by +1 to match convectors.Tiktokenize
        # offset used during training (ID 0 reserved).
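        # E.g. a base BPE id 42 becomes "43"; a special id such as 100279
        # (<|im_end|>) passes through unchanged.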
        shifted = [i if i in self._id_to_special else i + 1 for i in ids]
        return [str(i) for i in shifted]

    def _convert_token_to_id(self, token: str) -> int:
        """Convert a token string (or id-string) to an integer id."""
        if token in self._special_to_id:
            return self._special_to_id[token]
        try:
            return int(token)
        except ValueError:
            # Unknown token strings fall back to the reserved ID 0.
            return 0

    def _convert_id_to_token(self, index: int) -> str:
        """Convert an integer id to its token string."""
        if index in self._id_to_special:
            return self._id_to_special[index]
        return str(index)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Decode a list of token strings back to text."""
        ids = []
        for t in tokens:
            if t in self._special_to_id:
                ids.append(self._special_to_id[t])
            else:
                try:
                    ids.append(int(t))
                except ValueError:
                    pass
        # Reverse the +1 BPE shift before decoding; skip invalid/ID 0 tokens.
        real_ids = []
        for i in ids:
            if i in self._id_to_special:
                real_ids.append(i)
            elif i >= 1:
                real_ids.append(i - 1)
        return self._enc.decode(real_ids)

    def get_vocab(self) -> Dict[str, int]:
        """
        Return a vocab dict. Only special tokens are included with their names;
        regular BPE tokens are included as their id-string representation.
        (Building a full 100k-entry reverse BPE map is expensive and rarely needed.)
        """
        vocab = {str(i): i for i in range(self.vocab_size)}
        for tok, idx in self._special_to_id.items():
            vocab[tok] = idx
        return vocab

    def save_vocabulary(
        self, save_directory: str, filename_prefix: Optional[str] = None
    ) -> Tuple[str, ...]:
        """
        No vocabulary file is needed — the tiktoken encoding is fetched from
        the tiktoken package at runtime. Returns an empty tuple.
        """
        return ()

    # ------------------------------------------------------------------
    # Convenience helpers
    # ------------------------------------------------------------------

    @property
    def im_start_id(self) -> int:
        return self._special_to_id["<|im_start|>"]

    @property
    def im_end_id(self) -> int:
        return self._special_to_id["<|im_end|>"]

    @property
    def think_start_id(self) -> int:
        return self._special_to_id["<|think_start|>"]

    @property
    def think_end_id(self) -> int:
        return self._special_to_id["<|think_end|>"]

    def encode_chat(self, prompt: str, add_think_start: bool = True) -> List[int]:
        """
        Format and encode a user prompt using the standard chat template.

        Args:
            prompt: The user's message (plain text).
            add_think_start: If True (default), append <|think_start|> so the
                model begins generating its chain-of-thought immediately.

        Returns:
            List of token ids (prompt already encoded, ready for prefill).
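
        Sketch (illustrative):
            ids = tok.encode_chat("What is 2+2?")
            # prefill the model with `ids`; generation continues the thinking span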
        """
        text = f"<|im_start|>user\n{prompt}\n<|im_end|><|im_start|>assistant\n"
        if add_think_start:
            text += "<|think_start|>"
        ids = self._enc.encode(text, allowed_special="all")
        return [i if i in self._id_to_special else i + 1 for i in ids]

    def apply_chat_template(self, conversation, add_generation_prompt: bool = True, **kwargs) -> List[int]:
        """
        Minimal chat template support for HF pipeline compatibility.

        Expects conversation as a list of {"role": ..., "content": ...} dicts.
        Each turn is wrapped as <|im_start|>{role}\n{content}\n<|im_end|>; prior
        assistant thinking spans are not re-inserted.
        """
        text = ""
        for msg in conversation:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            text += f"<|im_start|>{role}\n{content}\n<|im_end|>"
        if add_generation_prompt:
            text += "<|im_start|>assistant\n<|think_start|>"
        ids = self._enc.encode(text, allowed_special="all")
        return [i if i in self._id_to_special else i + 1 for i in ids]
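

if __name__ == "__main__":
    # Minimal smoke test; a sketch, not part of the public API (assumes
    # tiktoken and transformers are installed).
    tok = SeqCondTokenizer()
    ids = tok.encode_chat("What is 2+2?")
    assert ids[-1] == tok.think_start_id
    print(f"{len(ids)} tokens")
    print(tok.decode(ids))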