from __future__ import annotations

import os

import torch
from torch.utils.data import Dataset


class SequenceTokenDataset(Dataset):
    """Wraps pre-chunked token tensors. Labels mirror input_ids unshifted;
    the consumer is expected to perform the next-token shift."""

    def __init__(self, chunks: torch.Tensor):
        self.chunks = chunks

    def __len__(self) -> int:
        return self.chunks.size(0)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        chunk = self.chunks[idx]
        return {"input_ids": chunk, "labels": chunk}


class PreTokenizedDataset(Dataset):
    """Slices a flat token tensor into (seq_len + 1)-length chunks and
    returns shifted (input, target) pairs."""

    def __init__(self, ids: torch.Tensor, seq_len: int):
        n = ids.numel() // (seq_len + 1)
        self.chunks = ids[: n * (seq_len + 1)].view(n, seq_len + 1)
        self.seq_len = seq_len

    def __len__(self) -> int:
        return self.chunks.size(0)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        chunk = self.chunks[idx]
        return {"input_ids": chunk[:-1], "labels": chunk[1:]}


class GrowLengthDataset(Dataset):
    """Curriculum-style dataset: re-chunks the same flat token tensor on the
    fly when set_seq_len() is called, so sequence length can grow mid-run."""

    def __init__(self, all_ids: torch.Tensor, seq_len: int = 16):
        self.all_ids = all_ids
        self._seq_len = 0
        self._n = 0
        self.set_seq_len(seq_len)

    def set_seq_len(self, seq_len: int) -> None:
        self._seq_len = int(seq_len)
        self._n = self.all_ids.numel() // (self._seq_len + 1)

    @property
    def seq_len(self) -> int:
        return self._seq_len

    def __len__(self) -> int:
        return self._n

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        start = idx * (self._seq_len + 1)
        chunk = self.all_ids[start : start + self._seq_len + 1]
        return {"input_ids": chunk[:-1], "labels": chunk[1:]}


def matches_category_filter(example: dict, filters: list[str]) -> bool:
    """Case-insensitive substring match of any filter against the example's
    `category` field; examples without a category never match."""
    category = example.get("category", "") or ""
    if not category:
        return False
    category_lower = category.lower()
    return any(f.lower() in category_lower for f in filters)


def format_dataset_example(
    ex: dict, tok, text_column: str = "auto", include_reasoning: bool = False
) -> str:
    # Auto-detect the text column from the most common field names.
    if text_column == "auto":
        for candidate in ("messages", "text", "content", "conversation"):
            if candidate in ex:
                text_column = candidate
                break
        else:
            text_column = ""

    if text_column == "messages" and "messages" in ex:
        messages = ex["messages"]
        # Optionally inline assistant reasoning traces as <|thinking|> blocks
        # ahead of the visible reply.
        if include_reasoning and isinstance(messages, list):
            rewritten = []
            for message in messages:
                if (
                    isinstance(message, dict)
                    and message.get("role") == "assistant"
                    and "reasoning" in message
                ):
                    rewritten.append(
                        {
                            "role": "assistant",
                            "content": (
                                f"<|thinking|>\n{message['reasoning']}\n<|/thinking|>\n"
                                f"{message.get('content', '')}"
                            ),
                        }
                    )
                else:
                    rewritten.append(message)
            messages = rewritten
        return tok.apply_chat_template(messages)

    if text_column and text_column in ex:
        value = ex[text_column]
        if isinstance(value, str):
            return value
        if isinstance(value, list) and value and isinstance(value[0], dict):
            return tok.apply_chat_template(value)
        return str(value)
    return str(ex)


def build_token_buffer(
    dataset_name: str,
    split: str,
    text_column: str,
    max_tokens: int,
    cache_dir: str,
    *,
    dataset_config: str | None = None,
    category_filter: str | None = None,
    include_reasoning: bool = False,
):
    from datasets import load_dataset

    from chimera import ChimeraTokenizer

    # The cache key covers every option that changes the token stream, so
    # runs with different configs or filters don't collide on one file.
    tag = ""
    if dataset_config:
        tag += f"_{dataset_config}"
    if category_filter:
        tag += f"_{category_filter.replace(',', '-').replace('/', '_')}"
    if include_reasoning:
        tag += "_reasoning"
    cache_name = f"{dataset_name.replace('/', '_')}{tag}_{split}_{max_tokens}.pt"
    cache_path = os.path.join(cache_dir, cache_name)
    os.makedirs(cache_dir, exist_ok=True)
    if os.path.exists(cache_path):
        print(f"[DATA] Cache hit: {cache_path}")
        return torch.load(cache_path, weights_only=True)

    print(f"[DATA] Streaming {dataset_name} ({split})...")
    load_kwargs = {"split": split, "streaming": True}
    if dataset_config:
        load_kwargs["name"] = dataset_config
    ds = load_dataset(dataset_name, **load_kwargs)
    tok = ChimeraTokenizer(pretrained="o200k_base")

    filters = (
        [c.strip() for c in category_filter.split(",") if c.strip()]
        if category_filter
        else None
    )
    if filters:
        print(f"[DATA] Filtering categories: {filters}")

    # Pre-allocate the full token budget and fill it document by document,
    # appending an EOS token after each document.
    buf = torch.empty(max_tokens, dtype=torch.long)
    idx = processed = skipped = 0
    for ex in ds:
        if filters and not matches_category_filter(ex, filters):
            skipped += 1
            continue
        text = format_dataset_example(ex, tok, text_column, include_reasoning)
        if not text or not text.strip():
            skipped += 1
            continue
        ids = tok.encode(text, add_special_tokens=False)
        ids.append(tok.eos_token_id)
        n = min(len(ids), max_tokens - idx)
        if n <= 0:
            break  # budget exhausted
        buf[idx : idx + n] = torch.tensor(ids[:n], dtype=torch.long)
        idx += n
        processed += 1
        if processed % 5000 == 0:
            print(f"  {processed:,} docs  {idx:,}/{max_tokens:,} tokens")

    token_buf = buf[:idx].contiguous()
    torch.save(token_buf, cache_path)
    print(f"[DATA] Processed {processed:,} examples, skipped {skipped:,}.")
    print(f"[DATA] {idx:,} tokens -> {cache_path}")
    return token_buf


def build_sequence_dataset(
    seq_len: int,
    *,
    max_samples: int | None = None,
    max_tokens: int | None = None,
    split: str = "train",
    dataset_name: str = "roneneldan/TinyStories",
    dataset_config: str | None = None,
    text_column: str = "auto",
    category_filter: str | None = None,
    include_reasoning: bool = False,
    cache_dir: str = "./cache",
):
    # Derive a token budget: an explicit max_tokens wins, then max_samples,
    # then a 500k-token floor as the fallback.
    token_budget = int(max_tokens) if max_tokens is not None else None
    if token_budget is None and max_samples is not None:
        token_budget = int(max_samples) * (seq_len + 1)
    if token_budget is None or token_budget <= 0:
        token_budget = max(
            500_000, (int(max_samples) if max_samples else 10_000) * (seq_len + 1)
        )

    token_buffer = build_token_buffer(
        dataset_name,
        split,
        text_column,
        token_budget,
        cache_dir,
        dataset_config=dataset_config,
        category_filter=category_filter,
        include_reasoning=include_reasoning,
    )
    if token_buffer.numel() == 0:
        raise ValueError("No data matched filters.")

    n = token_buffer.numel() // (seq_len + 1)
    if max_samples:
        n = min(n, max_samples)
    chunks = token_buffer[: n * (seq_len + 1)].view(n, seq_len + 1)
    print(f"[DATA] {n:,} chunks × {seq_len} tokens = {n * seq_len:,} total")
    return SequenceTokenDataset(chunks)
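

# --- Usage sketch: a minimal smoke test of the pipeline above, assuming the
# `datasets` package and the project-local `chimera` tokenizer are installed,
# the default TinyStories dataset is reachable, and ./cache is writable.
# All names mirror this module; nothing else is assumed.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    ds = build_sequence_dataset(
        seq_len=128,
        max_tokens=200_000,  # small budget keeps the smoke test quick
        cache_dir="./cache",
    )
    loader = DataLoader(ds, batch_size=8, shuffle=True)
    batch = next(iter(loader))
    # SequenceTokenDataset yields unshifted (seq_len + 1)-token chunks, so
    # both tensors are (8, 129); the training loop performs the shift.
    print(batch["input_ids"].shape, batch["labels"].shape)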