from __future__ import annotations

import os

import torch
from torch.utils.data import Dataset
|
|
class SequenceTokenDataset(Dataset):
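    """Dataset over pre-chunked token tensors.

    Each item returns the full chunk as both "input_ids" and "labels";
    any next-token shift is assumed to happen downstream.
    """
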
    def __init__(self, chunks: torch.Tensor):
        self.chunks = chunks

    def __len__(self) -> int:
        return self.chunks.size(0)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        chunk = self.chunks[idx]
        return {"input_ids": chunk, "labels": chunk}
|
|
class PreTokenizedDataset(Dataset):
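    """Dataset over a flat token tensor split into (seq_len + 1)-token chunks.

    Each item is shifted by one position: "input_ids" is chunk[:-1] and
    "labels" is chunk[1:], for next-token prediction.
    """
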
    def __init__(self, ids: torch.Tensor, seq_len: int):
        n = ids.numel() // (seq_len + 1)
        self.chunks = ids[: n * (seq_len + 1)].view(n, seq_len + 1)
        self.seq_len = seq_len

    def __len__(self) -> int:
        return self.chunks.size(0)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        chunk = self.chunks[idx]
        return {"input_ids": chunk[:-1], "labels": chunk[1:]}
|
|
class GrowLengthDataset(Dataset):
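    """Dataset whose sequence length can be changed during training.

    Chunks are sliced on the fly from one flat token buffer, so calling
    set_seq_len() re-partitions the data (and changes __len__) without copying.
    """
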
    def __init__(self, all_ids: torch.Tensor, seq_len: int = 16):
        self.all_ids = all_ids
        self._seq_len = 0
        self._n = 0
        self.set_seq_len(seq_len)

    def set_seq_len(self, seq_len: int) -> None:
        self._seq_len = int(seq_len)
        self._n = self.all_ids.numel() // (self._seq_len + 1)

    @property
    def seq_len(self) -> int:
        return self._seq_len

    def __len__(self) -> int:
        return self._n

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        start = idx * (self._seq_len + 1)
        chunk = self.all_ids[start : start + self._seq_len + 1]
        return {"input_ids": chunk[:-1], "labels": chunk[1:]}
|
|
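# Example length-growth curriculum (illustrative sketch; build_token_buffer is defined below):
#   token_buffer = build_token_buffer("roneneldan/TinyStories", "train", "auto", 1_000_000, "./cache")
#   ds = GrowLengthDataset(token_buffer, seq_len=16)
#   for stage_len in (16, 32, 64):
#       ds.set_seq_len(stage_len)
#       loader = torch.utils.data.DataLoader(ds, batch_size=8, shuffle=True)
#       ...  # train one stage at this sequence length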
|
|
def matches_category_filter(example: dict, filters: list[str]) -> bool:
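    """Return True if the example's "category" field contains any of the filter strings (case-insensitive)."""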
    category = example.get("category", "") or ""
    if not category:
        return False
    category_lower = category.lower()
    return any(f.lower() in category_lower for f in filters)
|
|
def format_dataset_example(ex: dict, tok, text_column: str = "auto", include_reasoning: bool = False) -> str:
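    """Convert a raw dataset example into a single training string.

    With text_column="auto", the first of "messages", "text", "content", or
    "conversation" present in the example is used. Chat-style message lists are
    rendered with the tokenizer's chat template; when include_reasoning is set,
    assistant "reasoning" fields are inlined between <|thinking|> and
    <|/thinking|> tags.
    """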
    if text_column == "auto":
        for candidate in ("messages", "text", "content", "conversation"):
            if candidate in ex:
                text_column = candidate
                break
        else:
            text_column = ""

    if text_column == "messages" and "messages" in ex:
        messages = ex["messages"]
        if include_reasoning and isinstance(messages, list):
            rewritten = []
            for message in messages:
                if isinstance(message, dict) and message.get("role") == "assistant" and "reasoning" in message:
                    rewritten.append(
                        {
                            "role": "assistant",
                            "content": (
                                f"<|thinking|>\n{message['reasoning']}\n<|/thinking|>\n"
                                f"{message.get('content', '')}"
                            ),
                        }
                    )
                else:
                    rewritten.append(message)
            messages = rewritten
        return tok.apply_chat_template(messages)

    if text_column and text_column in ex:
        value = ex[text_column]
        if isinstance(value, str):
            return value
        if isinstance(value, list) and value and isinstance(value[0], dict):
            return tok.apply_chat_template(value)
        return str(value)
    return str(ex)
|
|
def build_token_buffer(
    dataset_name: str,
    split: str,
    text_column: str,
    max_tokens: int,
    cache_dir: str,
    *,
    dataset_config: str | None = None,
    category_filter: str | None = None,
    include_reasoning: bool = False,
):
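    """Stream a dataset, tokenize it, and return a flat LongTensor of at most max_tokens tokens.

    The result is cached under cache_dir; note that the cache key encodes only
    the dataset name, split, and token budget (not config, filters, or the
    reasoning flag).
    """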
    from datasets import load_dataset
    from chimera import ChimeraTokenizer

    cache_name = f"{dataset_name.replace('/', '_')}_{split}_{max_tokens}.pt"
    cache_path = os.path.join(cache_dir, cache_name)
    os.makedirs(cache_dir, exist_ok=True)

    if os.path.exists(cache_path):
        print(f"[DATA] Cache hit: {cache_path}")
        return torch.load(cache_path, weights_only=True)

    print(f"[DATA] Streaming {dataset_name} ({split})...")
    load_kwargs = {"split": split, "streaming": True}
    if dataset_config:
        load_kwargs["name"] = dataset_config
    ds = load_dataset(dataset_name, **load_kwargs)
    tok = ChimeraTokenizer(pretrained="o200k_base")

    filters = [c.strip() for c in category_filter.split(",") if c.strip()] if category_filter else None
    if filters:
        print(f"[DATA] Filtering categories: {filters}")

    buf = torch.empty(max_tokens, dtype=torch.long)
    idx = processed = skipped = 0
    for ex in ds:
        if filters and not matches_category_filter(ex, filters):
            skipped += 1
            continue
        text = format_dataset_example(ex, tok, text_column, include_reasoning)
        if not text or not text.strip():
            skipped += 1
            continue
        ids = tok.encode(text, add_special_tokens=False)
        ids.append(tok.eos_token_id)
        n = min(len(ids), max_tokens - idx)
        if n <= 0:
            break
        buf[idx : idx + n] = torch.tensor(ids[:n], dtype=torch.long)
        idx += n
        processed += 1
        if processed % 5000 == 0:
            print(f" {processed:,} docs {idx:,}/{max_tokens} tokens")

    token_buf = buf[:idx].contiguous()
    torch.save(token_buf, cache_path)
    print(f"[DATA] Processed {processed:,} examples, skipped {skipped:,}.")
    print(f"[DATA] {idx:,} tokens -> {cache_path}")
    return token_buf
|
|
def build_sequence_dataset(
    seq_len: int,
    *,
    max_samples: int | None = None,
    max_tokens: int | None = None,
    split: str = "train",
    dataset_name: str = "roneneldan/TinyStories",
    dataset_config: str | None = None,
    text_column: str = "auto",
    category_filter: str | None = None,
    include_reasoning: bool = False,
    cache_dir: str = "./cache",
):
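    """Build a SequenceTokenDataset of (seq_len + 1)-token chunks.

    The token budget is max_tokens if given, otherwise max_samples * (seq_len + 1);
    if neither yields a positive budget, a default of at least 500k tokens is used.
    """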
    token_budget = int(max_tokens) if max_tokens is not None else None
    if token_budget is None and max_samples is not None:
        token_budget = int(max_samples) * (seq_len + 1)
    if token_budget is None or token_budget <= 0:
        token_budget = max(500_000, (int(max_samples) if max_samples else 10000) * (seq_len + 1))

    token_buffer = build_token_buffer(
        dataset_name,
        split,
        text_column,
        token_budget,
        cache_dir,
        dataset_config=dataset_config,
        category_filter=category_filter,
        include_reasoning=include_reasoning,
    )

    if token_buffer.numel() == 0:
        raise ValueError("No data matched filters.")

    n = token_buffer.numel() // (seq_len + 1)
    if max_samples:
        n = min(n, max_samples)
    chunks = token_buffer[: n * (seq_len + 1)].view(n, seq_len + 1)
    print(f"[DATA] {n:,} chunks × {seq_len} tokens = {n * seq_len:,} total")
    return SequenceTokenDataset(chunks)
|
|
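# Example end-to-end usage (illustrative sketch; assumes the default dataset and ChimeraTokenizer are available):
#   ds = build_sequence_dataset(seq_len=256, max_tokens=2_000_000)
#   loader = torch.utils.data.DataLoader(ds, batch_size=16, shuffle=True)
#   batch = next(iter(loader))
#   # batch["input_ids"].shape == (16, 257)  # chunks are seq_len + 1 tokens long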