"""
Data loading and preprocessing.
Supported datasets:
- WikiText-2 (char-level and word-level)
- WikiText-103
- Custom text files
- Synthetic random data (debugging)
Tokenization: character-level by default. Simple, deterministic, no external deps.
"""
import torch
from torch.utils.data import Dataset, DataLoader
from typing import Optional, Tuple, Dict
from collections import Counter


class CharTokenizer:
    """Character-level tokenizer. Vocabulary built from data."""

    def __init__(self, min_freq: int = 1):
        self.min_freq = min_freq
        self.char_to_idx: Dict[str, int] = {}
        self.idx_to_char: Dict[int, str] = {}
        self.vocab_size = 0
        self.special_tokens = {
            "<pad>": 0,
            "<bos>": 1,
            "<eos>": 2,
            "<unk>": 3,
        }

    def fit(self, texts: list[str]):
        """Build vocabulary from texts."""
        char_counts = Counter()
        for text in texts:
            char_counts.update(text)
        # Special tokens first
        self.char_to_idx = dict(self.special_tokens)
        # Then frequency-filtered characters
        idx = len(self.special_tokens)
        for char, count in char_counts.most_common():
            if count >= self.min_freq:
                self.char_to_idx[char] = idx
                idx += 1
        self.idx_to_char = {v: k for k, v in self.char_to_idx.items()}
        self.vocab_size = len(self.char_to_idx)
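
    # For example, fit(["abba"]) yields a mapping like
    # {"<pad>": 0, "<bos>": 1, "<eos>": 2, "<unk>": 3, "a": 4, "b": 5}
    # (ties in frequency follow Counter.most_common's first-seen order).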

    def encode(self, text: str, add_bos: bool = True,
               add_eos: bool = True, max_len: Optional[int] = None) -> list[int]:
        """Convert text to token indices, optionally truncating or padding to max_len."""
        tokens = []
        if add_bos:
            tokens.append(self.special_tokens["<bos>"])
        for ch in text:
            tokens.append(self.char_to_idx.get(ch, self.special_tokens["<unk>"]))
        if add_eos:
            tokens.append(self.special_tokens["<eos>"])
        if max_len is not None:
            if len(tokens) > max_len:
                tokens = tokens[:max_len]
            else:
                tokens.extend([self.special_tokens["<pad>"]] * (max_len - len(tokens)))
        return tokens

    def decode(self, indices: list[int], skip_special: bool = True) -> str:
        """Convert indices back to text."""
        chars = []
        for idx in indices:
            if skip_special and idx in self.special_tokens.values():
                continue
            chars.append(self.idx_to_char.get(idx, "?"))
        return "".join(chars)

    def save(self, path: str):
        torch.save({
            "char_to_idx": self.char_to_idx,
            "idx_to_char": self.idx_to_char,
            "vocab_size": self.vocab_size,
            "special_tokens": self.special_tokens,
        }, path)

    @classmethod
    def load(cls, path: str) -> "CharTokenizer":
        data = torch.load(path)
        tok = cls()
        tok.char_to_idx = data["char_to_idx"]
        tok.idx_to_char = data["idx_to_char"]
        tok.vocab_size = data["vocab_size"]
        tok.special_tokens = data["special_tokens"]
        return tok
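
# Illustrative usage sketch (kept as comments so it does not run on import;
# the training strings are toy examples):
#
#     tok = CharTokenizer()
#     tok.fit(["hello world", "hello there"])
#     ids = tok.encode("hello", max_len=12)   # <bos> + 5 char ids + <eos>, padded to 12
#     tok.decode(ids)                          # -> "hello" (special tokens skipped)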


class TextDataset(Dataset):
    """
    Causal language modeling dataset.
    Splits text into overlapping sequences of length seq_len.
    Target = input shifted by 1 (next-token prediction).
    """

    def __init__(self, texts: list[str], tokenizer: CharTokenizer,
                 seq_len: int = 128, stride: Optional[int] = None):
        self.seq_len = seq_len
        self.stride = stride or seq_len // 2
        # Tokenize and concatenate all texts into one long token stream
        all_tokens = []
        for text in texts:
            all_tokens.extend(tokenizer.encode(text, add_bos=False, add_eos=True))
        self.tokens = torch.tensor(all_tokens, dtype=torch.long)
        # Number of valid windows: each needs seq_len + 1 tokens (input plus shifted target)
        self.n_samples = max(0, (len(self.tokens) - seq_len - 1) // self.stride + 1)

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        start = idx * self.stride
        end = start + self.seq_len
        x = self.tokens[start:end]
        y = self.tokens[start + 1:end + 1]
        assert len(x) == len(y) == self.seq_len, f"len={len(x)} at idx={idx}"
        return x, y
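
# Windowing sketch (illustrative numbers): with seq_len=4 and stride=2 over a
# token stream [t0 .. t9], the dataset yields
#     idx 0: x = t0..t3, y = t1..t4
#     idx 1: x = t2..t5, y = t3..t6
#     idx 2: x = t4..t7, y = t5..t8
# i.e. (10 - 4 - 1) // 2 + 1 = 3 samples.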


def load_wikitext2(tokenizer: Optional[CharTokenizer] = None,
                   seq_len: int = 128,
                   batch_size: int = 16) -> Tuple[DataLoader, DataLoader, DataLoader, CharTokenizer]:
    """
    Load WikiText-2 with char-level tokenization.
    Returns:
        train_loader, val_loader, test_loader, tokenizer
    """
    try:
        from datasets import load_dataset
    except ImportError:
        raise ImportError("The 'datasets' package is required: pip install datasets")
    ds = load_dataset("wikitext", "wikitext-2-raw-v1")
    # Filter out empty lines
    train_texts = [t for t in ds["train"]["text"] if t.strip()]
    val_texts = [t for t in ds["validation"]["text"] if t.strip()]
    test_texts = [t for t in ds["test"]["text"] if t.strip()]
    if tokenizer is None:
        tokenizer = CharTokenizer()
        tokenizer.fit(train_texts)
    train_ds = TextDataset(train_texts, tokenizer, seq_len)
    val_ds = TextDataset(val_texts, tokenizer, seq_len)
    test_ds = TextDataset(test_texts, tokenizer, seq_len)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              num_workers=0, drop_last=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=0)
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=0)
    return train_loader, val_loader, test_loader, tokenizer
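
# Typical call (requires the Hugging Face 'datasets' package and a network
# connection for the first download; shapes are illustrative):
#
#     train_loader, val_loader, test_loader, tok = load_wikitext2(seq_len=128, batch_size=16)
#     x, y = next(iter(train_loader))          # x, y: LongTensor of shape (16, 128)
#     print(tok.vocab_size)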


def load_synthetic_data(vocab_size: int = 5000, seq_len: int = 128,
                        n_samples: int = 2000, batch_size: int = 16) -> DataLoader:
    """Synthetic random token data for debugging and smoke tests."""

    class _SynthDataset(Dataset):
        def __init__(self, n, vocab, slen):
            # Random token ids in [1, vocab); index 0 is reserved for <pad>
            self.data = torch.randint(1, vocab, (n, slen + 1))

        def __len__(self):
            return len(self.data)

        def __getitem__(self, i):
            # Next-token prediction: target is the input shifted by one
            return self.data[i, :-1], self.data[i, 1:]

    ds = _SynthDataset(n_samples, vocab_size, seq_len)
    return DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=0)
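

if __name__ == "__main__":
    # Minimal smoke test using the synthetic loader (no downloads needed);
    # the sizes below are illustrative and deliberately small.
    loader = load_synthetic_data(vocab_size=100, seq_len=32, n_samples=64, batch_size=8)
    x, y = next(iter(loader))
    print("synthetic batch:", x.shape, y.shape)  # torch.Size([8, 32]) torch.Size([8, 32])

    tok = CharTokenizer()
    tok.fit(["hello world"])
    ids = tok.encode("hello", max_len=16)
    print("round trip:", tok.decode(ids))  # -> "hello"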