""" Data loading and preprocessing. Supported datasets: - WikiText-2 (char-level and word-level) - WikiText-103 - Custom text files - Synthetic random data (debugging) Tokenization: character-level by default. Simple, deterministic, no external deps. """ import torch from torch.utils.data import Dataset, DataLoader from typing import Optional, Tuple, Dict from collections import Counter class CharTokenizer: """Character-level tokenizer. Vocabulary built from data.""" def __init__(self, min_freq: int = 1): self.min_freq = min_freq self.char_to_idx: Dict[str, int] = {} self.idx_to_char: Dict[int, str] = {} self.vocab_size = 0 self.special_tokens = { "": 0, "": 1, "": 2, "": 3, } def fit(self, texts: list[str]): """Build vocabulary from texts.""" char_counts = Counter() for text in texts: char_counts.update(text) # Special tokens first self.char_to_idx = dict(self.special_tokens) # Freq-filtered chars idx = len(self.special_tokens) for char, count in char_counts.most_common(): if count >= self.min_freq: self.char_to_idx[char] = idx idx += 1 self.idx_to_char = {v: k for k, v in self.char_to_idx.items()} self.vocab_size = len(self.char_to_idx) def encode(self, text: str, add_bos: bool = True, add_eos: bool = True, max_len: int = None) -> list[int]: """Convert text to token indices.""" tokens = [] if add_bos: tokens.append(self.special_tokens[""]) for ch in text: tokens.append(self.char_to_idx.get(ch, self.special_tokens[""])) if add_eos: tokens.append(self.special_tokens[""]) if max_len is not None: if len(tokens) > max_len: tokens = tokens[:max_len] else: tokens.extend([self.special_tokens[""]] * (max_len - len(tokens))) return tokens def decode(self, indices: list[int], skip_special: bool = True) -> str: """Convert indices back to text.""" chars = [] for idx in indices: ch = self.idx_to_char.get(idx, "?") if skip_special and idx in self.special_tokens.values(): continue chars.append(ch) return "".join(chars) def save(self, path: str): torch.save({ "char_to_idx": self.char_to_idx, "idx_to_char": self.idx_to_char, "vocab_size": self.vocab_size, "special_tokens": self.special_tokens, }, path) @classmethod def load(cls, path: str) -> "CharTokenizer": data = torch.load(path) tok = cls() tok.char_to_idx = data["char_to_idx"] tok.idx_to_char = data["idx_to_char"] tok.vocab_size = data["vocab_size"] tok.special_tokens = data["special_tokens"] return tok class TextDataset(Dataset): """ Causal language modeling dataset. Splits text into overlapping sequences of length seq_len. Target = input shifted by 1 (next-token prediction). """ def __init__(self, texts: list[str], tokenizer: CharTokenizer, seq_len: int = 128, stride: int = None): self.seq_len = seq_len self.stride = stride or seq_len // 2 # Tokenize all texts all_tokens = [] for text in texts: all_tokens.extend(tokenizer.encode(text, add_bos=False, add_eos=True)) self.tokens = torch.tensor(all_tokens, dtype=torch.long) # Compute valid starting positions self.n_samples = max(0, (len(self.tokens) - seq_len - 1) // self.stride + 1) def __len__(self): return self.n_samples def __getitem__(self, idx): start = idx * self.stride end = start + self.seq_len x = self.tokens[start:end] y = self.tokens[start + 1:end + 1] assert len(x) == len(y) == self.seq_len, f"len={len(x)} at idx={idx}" return x, y def load_wikitext2(tokenizer: CharTokenizer = None, seq_len: int = 128, batch_size: int = 16) -> Tuple[DataLoader, DataLoader, DataLoader, CharTokenizer]: """ Load WikiText-2 with char-level tokenization. 

    Returns:
        train_loader, val_loader, test_loader, tokenizer
    """
    try:
        from datasets import load_dataset
    except ImportError:
        raise ImportError("pip install datasets")

    ds = load_dataset("wikitext", "wikitext-2-raw-v1")

    # Filter empty lines
    train_texts = [t for t in ds["train"]["text"] if t.strip()]
    val_texts = [t for t in ds["validation"]["text"] if t.strip()]
    test_texts = [t for t in ds["test"]["text"] if t.strip()]

    if tokenizer is None:
        tokenizer = CharTokenizer()
        tokenizer.fit(train_texts)

    train_ds = TextDataset(train_texts, tokenizer, seq_len)
    val_ds = TextDataset(val_texts, tokenizer, seq_len)
    test_ds = TextDataset(test_texts, tokenizer, seq_len)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              num_workers=0, drop_last=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=0)
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=0)

    return train_loader, val_loader, test_loader, tokenizer


def load_synthetic_data(vocab_size: int = 5000, seq_len: int = 128,
                        n_samples: int = 2000, batch_size: int = 16):
    """Synthetic random data for debugging."""

    class _SynthDataset(Dataset):
        def __init__(self, n, vocab, slen):
            self.data = torch.randint(1, vocab, (n, slen + 1))

        def __len__(self):
            return len(self.data)

        def __getitem__(self, i):
            return self.data[i, :-1], self.data[i, 1:]

    ds = _SynthDataset(n_samples, vocab_size, seq_len)
    return DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=0)
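

# ---------------------------------------------------------------------------
# Minimal smoke test (illustrative sketch, not part of the library API).
# It exercises only the CharTokenizer round trip and the synthetic loader,
# so it runs offline; swap in load_wikitext2() if the `datasets` package and
# a network connection are available.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Tokenizer round trip on a toy corpus.
    tok = CharTokenizer()
    tok.fit(["hello world", "goodbye world"])
    ids = tok.encode("hello")
    print("vocab_size:", tok.vocab_size)
    print("encode('hello'):", ids)
    print("decode:", tok.decode(ids))  # special tokens skipped -> "hello"

    # Synthetic loader: batches of (input, target) pairs shifted by one token.
    loader = load_synthetic_data(vocab_size=100, seq_len=32, n_samples=64, batch_size=8)
    x, y = next(iter(loader))
    print("synthetic batch:", x.shape, y.shape)  # torch.Size([8, 32]) twice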