| """ |
| Data loading and preprocessing. |
| |
| Supported datasets: |
| - WikiText-2 (char-level and word-level) |
| - WikiText-103 |
| - Custom text files |
| - Synthetic random data (debugging) |
| |
| Tokenization: character-level by default. Simple, deterministic, no external deps. |
| """ |

import torch
from torch.utils.data import Dataset, DataLoader
from typing import Optional, Tuple, Dict
from collections import Counter


class CharTokenizer:
    """Character-level tokenizer. Vocabulary is built from the data by fit()."""

    def __init__(self, min_freq: int = 1):
        self.min_freq = min_freq
        self.char_to_idx: Dict[str, int] = {}
        self.idx_to_char: Dict[int, str] = {}
        self.vocab_size = 0
        # Fixed ids for special tokens; characters are assigned ids from 4 up.
        self.special_tokens = {
            "<pad>": 0,
            "<bos>": 1,
            "<eos>": 2,
            "<unk>": 3,
        }

    def fit(self, texts: list[str]):
        """Build the vocabulary from texts."""
        char_counts = Counter()
        for text in texts:
            char_counts.update(text)

        # Special tokens keep their fixed ids; characters start after them.
        self.char_to_idx = dict(self.special_tokens)
        idx = len(self.special_tokens)
        for char, count in char_counts.most_common():
            if count >= self.min_freq:
                self.char_to_idx[char] = idx
                idx += 1

        self.idx_to_char = {v: k for k, v in self.char_to_idx.items()}
        self.vocab_size = len(self.char_to_idx)
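
    # Illustrative result: fit(["abca"]) yields
    #   {"<pad>": 0, "<bos>": 1, "<eos>": 2, "<unk>": 3, "a": 4, "b": 5, "c": 6}
    # (characters are ordered by descending frequency via most_common()).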
    def encode(self, text: str, add_bos: bool = True,
               add_eos: bool = True, max_len: Optional[int] = None) -> list[int]:
        """Convert text to token indices, padding or truncating to max_len if set."""
        tokens = []
        if add_bos:
            tokens.append(self.special_tokens["<bos>"])
        for ch in text:
            tokens.append(self.char_to_idx.get(ch, self.special_tokens["<unk>"]))
        if add_eos:
            tokens.append(self.special_tokens["<eos>"])
        if max_len is not None:
            if len(tokens) > max_len:
                # Truncation can drop <eos>; acceptable for LM training windows.
                tokens = tokens[:max_len]
            else:
                tokens.extend([self.special_tokens["<pad>"]] * (max_len - len(tokens)))
        return tokens

    def decode(self, indices: list[int], skip_special: bool = True) -> str:
        """Convert token indices back to text."""
        chars = []
        for idx in indices:
            if skip_special and idx in self.special_tokens.values():
                continue
            chars.append(self.idx_to_char.get(idx, "?"))
        return "".join(chars)

    def save(self, path: str):
        torch.save({
            "char_to_idx": self.char_to_idx,
            "idx_to_char": self.idx_to_char,
            "vocab_size": self.vocab_size,
            "special_tokens": self.special_tokens,
        }, path)

| @classmethod |
| def load(cls, path: str) -> "CharTokenizer": |
| data = torch.load(path) |
| tok = cls() |
| tok.char_to_idx = data["char_to_idx"] |
| tok.idx_to_char = data["idx_to_char"] |
| tok.vocab_size = data["vocab_size"] |
| tok.special_tokens = data["special_tokens"] |
| return tok |
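
# A minimal round-trip sketch for CharTokenizer (illustrative strings only):
#
#   tok = CharTokenizer()
#   tok.fit(["hello world"])
#   ids = tok.encode("hello", max_len=12)  # <bos> h e l l o <eos> + 5 <pad>
#   tok.decode(ids)                        # -> "hello" (specials skipped)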


class TextDataset(Dataset):
    """
    Causal language modeling dataset.

    Concatenates all texts into one token stream, then slices it into
    overlapping windows of length seq_len (a new window every `stride` tokens).
    Target = input shifted by one token (next-token prediction).
    """

    def __init__(self, texts: list[str], tokenizer: CharTokenizer,
                 seq_len: int = 128, stride: Optional[int] = None):
        self.seq_len = seq_len
        self.stride = stride or seq_len // 2  # default: 50% window overlap

        all_tokens = []
        for text in texts:
            all_tokens.extend(tokenizer.encode(text, add_bos=False, add_eos=True))
        self.tokens = torch.tensor(all_tokens, dtype=torch.long)

        # Number of full windows; the extra -1 reserves one token for the
        # shifted target in __getitem__.
        self.n_samples = max(0, (len(self.tokens) - seq_len - 1) // self.stride + 1)

    def __len__(self):
        return self.n_samples

| def __getitem__(self, idx): |
| start = idx * self.stride |
| end = start + self.seq_len |
| x = self.tokens[start:end] |
| y = self.tokens[start + 1:end + 1] |
| assert len(x) == len(y) == self.seq_len, f"len={len(x)} at idx={idx}" |
| return x, y |
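
    # Windowing sketch: with seq_len=4, stride=2 and tokens [t0 .. t7]:
    #   n_samples = (8 - 4 - 1) // 2 + 1 = 2
    #   idx 0: x = t0..t3, y = t1..t4
    #   idx 1: x = t2..t5, y = t3..t6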


def load_wikitext2(tokenizer: Optional[CharTokenizer] = None,
                   seq_len: int = 128,
                   batch_size: int = 16) -> Tuple[DataLoader, DataLoader, DataLoader, CharTokenizer]:
    """
    Load WikiText-2 with char-level tokenization.

    Returns:
        train_loader, val_loader, test_loader, tokenizer
    """
    try:
        from datasets import load_dataset
    except ImportError as e:
        raise ImportError(
            "load_wikitext2 requires the HuggingFace `datasets` package "
            "(pip install datasets)"
        ) from e

    ds = load_dataset("wikitext", "wikitext-2-raw-v1")

    # Drop blank/whitespace-only lines.
    train_texts = [t for t in ds["train"]["text"] if t.strip()]
    val_texts = [t for t in ds["validation"]["text"] if t.strip()]
    test_texts = [t for t in ds["test"]["text"] if t.strip()]

    # Fit the vocabulary on the training split only, to avoid leakage.
    if tokenizer is None:
        tokenizer = CharTokenizer()
        tokenizer.fit(train_texts)

    train_ds = TextDataset(train_texts, tokenizer, seq_len)
    val_ds = TextDataset(val_texts, tokenizer, seq_len)
    test_ds = TextDataset(test_texts, tokenizer, seq_len)

    # drop_last avoids a ragged final batch during training.
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              num_workers=0, drop_last=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=0)
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=0)

    return train_loader, val_loader, test_loader, tokenizer
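
# Illustrative usage (downloads WikiText-2 via the `datasets` package on first
# call; shapes assume the defaults above):
#
#   train_loader, val_loader, test_loader, tok = load_wikitext2()
#   x, y = next(iter(train_loader))
#   # x.shape == y.shape == (16, 128); y[i] holds the next-token targets for x[i]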


def load_synthetic_data(vocab_size: int = 5000, seq_len: int = 128,
                        n_samples: int = 2000, batch_size: int = 16) -> DataLoader:
    """Synthetic random-token data for debugging and smoke tests."""
    class _SynthDataset(Dataset):
        def __init__(self, n, vocab, slen):
            # slen + 1 tokens per row: x is the first slen, y the last slen.
            # Low bound of 1 keeps id 0 free (matches the <pad> id convention).
            self.data = torch.randint(1, vocab, (n, slen + 1))

        def __len__(self):
            return len(self.data)

        def __getitem__(self, i):
            return self.data[i, :-1], self.data[i, 1:]

    ds = _SynthDataset(n_samples, vocab_size, seq_len)
    return DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=0)
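
# Quick smoke-test sketch: each target row is its input row shifted by one
# position, so the overlap check below always holds.
#
#   loader = load_synthetic_data(vocab_size=100, seq_len=32, n_samples=64)
#   x, y = next(iter(loader))
#   assert x.shape == (16, 32) and bool((y[:, :-1] == x[:, 1:]).all())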
|
|