"""
Data loading and preprocessing.
Supported datasets:
- WikiText-2 (char-level and word-level)
- WikiText-103
- Custom text files
- Synthetic random data (debugging)
Tokenization: character-level by default. Simple, deterministic, and dependency-free
(only the WikiText loaders require the `datasets` package).
"""
import torch
from torch.utils.data import Dataset, DataLoader
from typing import Optional, Tuple, Dict
from collections import Counter
class CharTokenizer:
"""Character-level tokenizer. Vocabulary built from data."""
def __init__(self, min_freq: int = 1):
self.min_freq = min_freq
self.char_to_idx: Dict[str, int] = {}
self.idx_to_char: Dict[int, str] = {}
self.vocab_size = 0
self.special_tokens = {
"<pad>": 0,
"<bos>": 1,
"<eos>": 2,
"<unk>": 3,
}
def fit(self, texts: list[str]):
"""Build vocabulary from texts."""
char_counts = Counter()
for text in texts:
char_counts.update(text)
# Special tokens first
self.char_to_idx = dict(self.special_tokens)
# Freq-filtered chars
idx = len(self.special_tokens)
for char, count in char_counts.most_common():
if count >= self.min_freq:
self.char_to_idx[char] = idx
idx += 1
self.idx_to_char = {v: k for k, v in self.char_to_idx.items()}
self.vocab_size = len(self.char_to_idx)
    def encode(self, text: str, add_bos: bool = True,
               add_eos: bool = True, max_len: Optional[int] = None) -> list[int]:
        """Convert text to token indices.

        Prepends <bos> / appends <eos> when requested. If max_len is given,
        the result is truncated or right-padded with <pad> to exactly max_len.
        """
tokens = []
if add_bos:
tokens.append(self.special_tokens["<bos>"])
for ch in text:
tokens.append(self.char_to_idx.get(ch, self.special_tokens["<unk>"]))
if add_eos:
tokens.append(self.special_tokens["<eos>"])
if max_len is not None:
if len(tokens) > max_len:
tokens = tokens[:max_len]
else:
tokens.extend([self.special_tokens["<pad>"]] * (max_len - len(tokens)))
return tokens
def decode(self, indices: list[int], skip_special: bool = True) -> str:
"""Convert indices back to text."""
chars = []
for idx in indices:
ch = self.idx_to_char.get(idx, "?")
if skip_special and idx in self.special_tokens.values():
continue
chars.append(ch)
return "".join(chars)
def save(self, path: str):
torch.save({
"char_to_idx": self.char_to_idx,
"idx_to_char": self.idx_to_char,
"vocab_size": self.vocab_size,
"special_tokens": self.special_tokens,
}, path)
@classmethod
def load(cls, path: str) -> "CharTokenizer":
data = torch.load(path)
tok = cls()
tok.char_to_idx = data["char_to_idx"]
tok.idx_to_char = data["idx_to_char"]
tok.vocab_size = data["vocab_size"]
tok.special_tokens = data["special_tokens"]
return tok
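# Usage sketch for CharTokenizer (illustrative only; the calls below are an
# assumed workflow and are not exercised anywhere in this module):
#
#     tok = CharTokenizer()
#     tok.fit(["hello world"])
#     ids = tok.encode("hello", max_len=16)   # <bos> + chars + <eos>, padded to 16
#     tok.decode(ids)                          # -> "hello" (special tokens skipped)
#     tok.save("char_tok.pt"); tok = CharTokenizer.load("char_tok.pt")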
class TextDataset(Dataset):
"""
Causal language modeling dataset.
Splits text into overlapping sequences of length seq_len.
Target = input shifted by 1 (next-token prediction).
"""
    def __init__(self, texts: list[str], tokenizer: CharTokenizer,
                 seq_len: int = 128, stride: Optional[int] = None):
self.seq_len = seq_len
self.stride = stride or seq_len // 2
# Tokenize all texts
all_tokens = []
for text in texts:
all_tokens.extend(tokenizer.encode(text, add_bos=False, add_eos=True))
self.tokens = torch.tensor(all_tokens, dtype=torch.long)
        # Number of windows that fit: each window needs seq_len + 1 tokens
        # (the input plus its 1-shifted target)
        self.n_samples = max(0, (len(self.tokens) - seq_len - 1) // self.stride + 1)
def __len__(self):
return self.n_samples
def __getitem__(self, idx):
start = idx * self.stride
end = start + self.seq_len
x = self.tokens[start:end]
y = self.tokens[start + 1:end + 1]
assert len(x) == len(y) == self.seq_len, f"len={len(x)} at idx={idx}"
return x, y
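# Windowing sketch for TextDataset (hypothetical numbers, for illustration):
# with 1,000 tokens, seq_len=128 and the default stride of 64, this yields
# (1000 - 128 - 1) // 64 + 1 = 14 overlapping windows, each a next-token pair:
#
#     ds = TextDataset(texts, tokenizer, seq_len=128)
#     x, y = ds[0]                  # both shape (128,), dtype torch.long
#     assert torch.equal(x[1:], y[:-1])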
def load_wikitext2(tokenizer: Optional[CharTokenizer] = None,
                   seq_len: int = 128,
                   batch_size: int = 16) -> Tuple[DataLoader, DataLoader, DataLoader, CharTokenizer]:
"""
Load WikiText-2 with char-level tokenization.
Returns:
train_loader, val_loader, test_loader, tokenizer
"""
try:
from datasets import load_dataset
    except ImportError:
        raise ImportError("load_wikitext2 requires HuggingFace `datasets`: pip install datasets")
ds = load_dataset("wikitext", "wikitext-2-raw-v1")
# Filter empty lines
train_texts = [t for t in ds["train"]["text"] if t.strip()]
val_texts = [t for t in ds["validation"]["text"] if t.strip()]
test_texts = [t for t in ds["test"]["text"] if t.strip()]
if tokenizer is None:
tokenizer = CharTokenizer()
tokenizer.fit(train_texts)
train_ds = TextDataset(train_texts, tokenizer, seq_len)
val_ds = TextDataset(val_texts, tokenizer, seq_len)
test_ds = TextDataset(test_texts, tokenizer, seq_len)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
num_workers=0, drop_last=True)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=0)
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=0)
return train_loader, val_loader, test_loader, tokenizer
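# Example call (assumes the `datasets` package is installed and the "wikitext"
# dataset can be downloaded or found in the local cache):
#
#     train_loader, val_loader, test_loader, tok = load_wikitext2(seq_len=128, batch_size=16)
#     x, y = next(iter(train_loader))   # x, y: (16, 128) LongTensors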
def load_synthetic_data(vocab_size: int = 5000, seq_len: int = 128,
                        n_samples: int = 2000, batch_size: int = 16) -> DataLoader:
"""Synthetic random data for debugging."""
class _SynthDataset(Dataset):
        def __init__(self, n, vocab, slen):
            # Uniform random token ids in [1, vocab); index 0 never appears,
            # matching this module's convention that 0 is the padding id.
            self.data = torch.randint(1, vocab, (n, slen + 1))
def __len__(self):
return len(self.data)
def __getitem__(self, i):
return self.data[i, :-1], self.data[i, 1:]
ds = _SynthDataset(n_samples, vocab_size, seq_len)
return DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=0)
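
if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the public API): round-trips the
    # character tokenizer and draws one batch from the synthetic loader, with no
    # dataset download required.
    tok = CharTokenizer()
    tok.fit(["the quick brown fox jumps over the lazy dog"])
    ids = tok.encode("the fox", max_len=16)
    print("round-trip:", tok.decode(ids))
    print("vocab size:", tok.vocab_size)
    loader = load_synthetic_data(vocab_size=100, seq_len=32, n_samples=64, batch_size=8)
    x, y = next(iter(loader))
    print("synthetic batch:", tuple(x.shape), tuple(y.shape))  # (8, 32) (8, 32)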