File size: 6,844 Bytes
11c11f8 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 | """
Chimera 5.1 — Splintr (Rust) Tokenizer Wrapper — o200k_base (OpenAI o1/o3)
Wraps splintr's high-performance Rust tokenizer for transformers-compatible API.
Vocab: o200k_base (200,073 tokens) — OpenAI's o1/o3 tokenizer.
Falls back to a minimal byte-level tokenizer when splintr is not installed.
Optimizations:
- Cached special token set for fast skip_special_tokens filtering
- Batch encode uses list comprehension (minimizes Python overhead)
"""
import torch
from typing import List, Union, Optional
try:
from splintr import Tokenizer as _SplintrTokenizer, O200K_AGENT_TOKENS
HAS_SPLINTR = True
except ImportError:
HAS_SPLINTR = False
# Public API of this module: ``from <module> import *`` exports only the tokenizer class.
__all__ = ["ChimeraTokenizer"]
class ChimeraTokenizer:
    """
    High-performance Rust-backed tokenizer (splintr) with a HuggingFace-like interface.

    If splintr could not be imported, a minimal byte-level fallback is used
    instead: every UTF-8 byte ``b`` of the input becomes token id ``b + 3``
    (ids 3..258), and the special-token ids are clamped to ``vocab_size - 1``.
    """

    # Role name -> header tag emitted by ``apply_chat_template``.
    # Messages whose role is not listed here are skipped.
    _ROLE_TAGS = {
        "system": "<|system|>",
        "user": "<|user|>",
        "assistant": "<|assistant|>",
    }

    def __init__(self, pretrained: str = "o200k_base", vocab_size: int = 200073):
        """
        Create the tokenizer.

        Args:
            pretrained: name passed to splintr's ``Tokenizer.from_pretrained``.
                Ignored by the byte-level fallback.
            vocab_size: vocabulary size used by the byte-level fallback.
                Ignored when splintr is available (the loaded tokenizer's
                ``vocab_size`` is used instead).
        """
        if HAS_SPLINTR:
            self._tok = _SplintrTokenizer.from_pretrained(pretrained)
            self.vocab_size = self._tok.vocab_size
            # o200k_base single-token special IDs
            self.eos_token_id = 199999
            self.pad_token_id = O200K_AGENT_TOKENS.PAD  # 200058
            self.sep_token_id = O200K_AGENT_TOKENS.SEP  # 200060
            self.stop_token_id = O200K_AGENT_TOKENS.STOP  # 200059
            self.user_token_id = O200K_AGENT_TOKENS.USER  # 200020
            self.assistant_token_id = O200K_AGENT_TOKENS.ASSISTANT  # 200021
            self.system_token_id = 200019
            self.endofprompt_token_id = 200018
        else:
            self._tok = None
            self.vocab_size = int(vocab_size)
            # Clamp the o200k ids so they stay inside a smaller vocabulary.
            # NOTE(review): with vocab_size < 259 the byte ids (3..258) exceed
            # the vocab and clamped special ids can collide with byte ids.
            top = self.vocab_size - 1
            self.eos_token_id = min(top, 199999)
            self.pad_token_id = min(top, 200058)
            self.sep_token_id = min(top, 200060)
            self.stop_token_id = min(top, 200059)
            self.user_token_id = min(top, 200020)
            self.assistant_token_id = min(top, 200021)
            self.system_token_id = min(top, 200019)
            self.endofprompt_token_id = min(top, 200018)
            # Byte value b is encoded as token id b + _byte_offset.
            self._byte_offset = 3

        # Settings shared by both backends.
        self.bos_token_id = self.eos_token_id
        self.eos_token = "<|endoftext|>"
        self.pad_token = "<|pad|>"
        self.model_max_length = 4194304
        # Cached set for fast skip_special_tokens filtering in ``decode``.
        self._special_ids = frozenset({
            self.eos_token_id, self.pad_token_id, self.sep_token_id,
            self.stop_token_id, self.user_token_id,
            self.assistant_token_id, self.system_token_id,
            self.endofprompt_token_id,
        })

    def __len__(self) -> int:
        """Return ``vocab_size``."""
        return self.vocab_size

    def encode(self, text: str, add_special_tokens: bool = True,
               max_length: Optional[int] = None) -> List[int]:
        """
        Encode ``text`` into a list of token ids.

        Args:
            text: input string.
            add_special_tokens: if True, append ``eos_token_id``.
            max_length: if given, keep only the first ``max_length`` ids.
                Truncation happens after EOS is appended, so a truncated
                sequence does not end with EOS.

        Returns:
            List of token ids.
        """
        if self._tok is None:
            offset = self._byte_offset
            ids = [offset + b for b in text.encode("utf-8", errors="replace")]
        else:
            ids = self._tok.encode(text)
        if add_special_tokens:
            ids = ids + [self.eos_token_id]
        if max_length is not None and len(ids) > max_length:
            ids = ids[:max_length]
        return ids

    def encode_batch(self, texts: List[str], add_special_tokens: bool = True,
                     max_length: Optional[int] = None,
                     padding: bool = False,
                     truncation: bool = False,
                     return_tensors: Optional[str] = None):
        """
        Encode several strings with ``encode``.

        Args:
            texts: input strings.
            add_special_tokens: forwarded to ``encode``.
            max_length: forwarded to ``encode``; truncation is applied whenever
                this is set.
            padding: if truthy, right-pad every row with ``pad_token_id`` to
                the longest row in the batch.
            truncation: accepted for HuggingFace compatibility; not used.
            return_tensors: ``"pt"`` returns ``{"input_ids": LongTensor}``;
                any other value returns the list of id lists.

        Returns:
            ``{"input_ids": torch.Tensor}`` when ``return_tensors == "pt"``,
            otherwise ``List[List[int]]``.
        """
        all_ids = [self.encode(t, add_special_tokens=add_special_tokens,
                               max_length=max_length)
                   for t in texts]
        if padding:
            # default=0: an empty batch stays empty instead of raising ValueError.
            max_len = max((len(ids) for ids in all_ids), default=0)
            pad = self.pad_token_id
            all_ids = [ids + [pad] * (max_len - len(ids)) for ids in all_ids]
        if return_tensors == "pt":
            return {"input_ids": torch.tensor(all_ids, dtype=torch.long)}
        return all_ids

    def decode(self, token_ids: Union[List[int], torch.Tensor, int],
               skip_special_tokens: bool = True) -> str:
        """
        Decode token ids back into a string.

        Args:
            token_ids: a sequence of ids, a tensor (converted via ``tolist``),
                or a single id / 0-d tensor.
            skip_special_tokens: if True, drop ids in the cached special set.

        Returns:
            Decoded text. In the byte-level fallback, ids outside the byte
            range ``[_byte_offset, _byte_offset + 256)`` have no byte
            representation and are dropped; invalid UTF-8 is replaced.
        """
        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.tolist()
        if isinstance(token_ids, int):
            # A bare id, or a 0-d tensor after tolist(): treat as a 1-element sequence.
            token_ids = [token_ids]
        if skip_special_tokens:
            special = self._special_ids
            token_ids = [t for t in token_ids if t not in special]
        if self._tok is None:
            lo = self._byte_offset
            hi = lo + 256
            # Only ids that encode a byte are kept; clamping others to 0xFF
            # would inject bytes that were never in the input.
            data = bytes(v - lo for v in map(int, token_ids) if lo <= v < hi)
            return data.decode("utf-8", errors="replace")
        return self._tok.decode(token_ids)

    def decode_batch(self, token_ids_list, skip_special_tokens: bool = True) -> List[str]:
        """Decode each row of ``token_ids_list`` with ``decode``."""
        return [self.decode(ids, skip_special_tokens=skip_special_tokens)
                for ids in token_ids_list]

    def __call__(self, text, **kwargs) -> dict:
        """
        HuggingFace-style call: encode one string or a list of strings.

        Recognized kwargs: ``return_tensors`` (default ``"pt"``), ``padding``
        (default False), ``max_length`` (default None) and
        ``add_special_tokens`` (default True); other kwargs are ignored.

        Returns:
            ``{"input_ids": LongTensor}`` regardless of ``return_tensors``.
            Rows of different lengths need ``padding=True``, otherwise
            ``torch.tensor`` presumably raises on the ragged list.
        """
        return_tensors = kwargs.get("return_tensors", "pt")
        padding = kwargs.get("padding", False)
        max_length = kwargs.get("max_length", None)
        add_special_tokens = kwargs.get("add_special_tokens", True)
        if isinstance(text, str):
            text = [text]
        result = self.encode_batch(
            text, add_special_tokens=add_special_tokens,
            max_length=max_length, padding=padding,
            return_tensors=return_tensors,
        )
        if isinstance(result, list):
            return {"input_ids": torch.tensor(result, dtype=torch.long)}
        return result

    def get_vocab(self) -> dict:
        """
        Return a small id -> token-string map for a few special tokens.

        NOTE: keyed by id, i.e. the inverse of HuggingFace's token -> id
        ``get_vocab``; it does not list the full vocabulary.
        """
        return {
            self.eos_token_id: self.eos_token,
            self.pad_token_id: self.pad_token,
            self.user_token_id: "<|user|>",
            self.assistant_token_id: "<|assistant|>",
            self.system_token_id: "<|system|>",
        }

    def apply_chat_template(self, messages: List[dict],
                            add_generation_prompt: bool = False) -> str:
        """
        Render chat messages into a single prompt string.

        Each message becomes ``"<tag>\\n<content>\\n<|endofprompt|>"`` and
        messages are joined with newlines. Messages with a role outside
        system/user/assistant are skipped; a missing role defaults to "user".

        Args:
            messages: dicts with optional ``"role"`` and ``"content"`` keys.
            add_generation_prompt: if True, append an open assistant header.

        Returns:
            The rendered (untokenized) prompt string.
        """
        parts = []
        for msg in messages:
            tag = self._ROLE_TAGS.get(msg.get("role", "user"))
            if tag is None:
                continue
            content = msg.get("content", "")
            parts.append(f"{tag}\n{content}\n<|endofprompt|>")
        text = "\n".join(parts)
        if add_generation_prompt:
            text += "\n<|assistant|>\n"
        return text
|