# chimera/training/datasets.py
from __future__ import annotations
import os
import torch
from torch.utils.data import Dataset

class SequenceTokenDataset(Dataset):
    """Wraps a pre-chunked (n, seq_len + 1) token tensor; labels mirror
    input_ids, so any causal shift presumably happens inside the model."""

    def __init__(self, chunks: torch.Tensor):
        self.chunks = chunks

    def __len__(self) -> int:
        return self.chunks.size(0)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        chunk = self.chunks[idx]
        return {"input_ids": chunk, "labels": chunk}

class PreTokenizedDataset(Dataset):
    """Slices a flat token stream into fixed windows with an explicit
    one-token shift between input_ids and labels."""

    def __init__(self, ids: torch.Tensor, seq_len: int):
        # Drop the tail so the stream splits evenly into (seq_len + 1) blocks.
        n = ids.numel() // (seq_len + 1)
        self.chunks = ids[: n * (seq_len + 1)].view(n, seq_len + 1)
        self.seq_len = seq_len

    def __len__(self) -> int:
        return self.chunks.size(0)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        chunk = self.chunks[idx]
        return {"input_ids": chunk[:-1], "labels": chunk[1:]}

class GrowLengthDataset(Dataset):
    """Length-curriculum dataset: set_seq_len() re-partitions the same flat
    token stream into longer windows as training progresses."""

    def __init__(self, all_ids: torch.Tensor, seq_len: int = 16):
        self.all_ids = all_ids
        self._seq_len = 0
        self._n = 0
        self.set_seq_len(seq_len)

    def set_seq_len(self, seq_len: int) -> None:
        self._seq_len = int(seq_len)
        self._n = self.all_ids.numel() // (self._seq_len + 1)

    @property
    def seq_len(self) -> int:
        return self._seq_len

    def __len__(self) -> int:
        return self._n

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        start = idx * (self._seq_len + 1)
        chunk = self.all_ids[start : start + self._seq_len + 1]
        return {"input_ids": chunk[:-1], "labels": chunk[1:]}

def matches_category_filter(example: dict, filters: list[str]) -> bool:
    """Case-insensitive substring match of any filter against the example's
    "category" field; examples without a category never match."""
    category = example.get("category", "") or ""
    if not category:
        return False
    category_lower = category.lower()
    return any(f.lower() in category_lower for f in filters)
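
# Examples of the substring semantics (hypothetical records):
#
#     matches_category_filter({"category": "Mathematics"}, ["math"])  # True
#     matches_category_filter({"category": "History"}, ["math"])      # False
#     matches_category_filter({}, ["math"])                           # False: no category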

def format_dataset_example(ex: dict, tok, text_column: str = "auto", include_reasoning: bool = False) -> str:
    """Renders one raw dataset example to plain text, auto-detecting the text
    column and optionally inlining assistant reasoning as <|thinking|> blocks."""
    if text_column == "auto":
        for candidate in ("messages", "text", "content", "conversation"):
            if candidate in ex:
                text_column = candidate
                break
        else:
            text_column = ""
    if text_column == "messages" and "messages" in ex:
        messages = ex["messages"]
        if include_reasoning and isinstance(messages, list):
            # Fold each assistant turn's "reasoning" field into its content,
            # wrapped in <|thinking|> markers, before applying the chat template.
            rewritten = []
            for message in messages:
                if isinstance(message, dict) and message.get("role") == "assistant" and "reasoning" in message:
                    rewritten.append(
                        {
                            "role": "assistant",
                            "content": (
                                f"<|thinking|>\n{message['reasoning']}\n<|/thinking|>\n"
                                f"{message.get('content', '')}"
                            ),
                        }
                    )
                else:
                    rewritten.append(message)
            messages = rewritten
        return tok.apply_chat_template(messages)
    if text_column and text_column in ex:
        value = ex[text_column]
        if isinstance(value, str):
            return value
        # A list of dicts is treated as a chat transcript.
        if isinstance(value, list) and value and isinstance(value[0], dict):
            return tok.apply_chat_template(value)
        return str(value)
    # Last resort: stringify the whole example.
    return str(ex)
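
# Sketch of the reasoning inlining on a hypothetical chat record: with
# include_reasoning=True, the assistant message handed to the chat template
# has content "<|thinking|>\n2+2=4\n<|/thinking|>\n4" (the final rendered
# string depends on ChimeraTokenizer's template).
#
#     ex = {"messages": [
#         {"role": "user", "content": "2+2?"},
#         {"role": "assistant", "content": "4", "reasoning": "2+2=4"},
#     ]}
#     text = format_dataset_example(ex, tok, include_reasoning=True)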

def build_token_buffer(
    dataset_name: str,
    split: str,
    text_column: str,
    max_tokens: int,
    cache_dir: str,
    *,
    dataset_config: str | None = None,
    category_filter: str | None = None,
    include_reasoning: bool = False,
):
    """Streams a dataset, tokenizes it, and returns a flat long tensor of at
    most max_tokens ids, caching the result on disk."""
    from datasets import load_dataset
    from chimera import ChimeraTokenizer

    # The cache key covers every option that changes the token stream;
    # keying only on name/split/budget would let differently configured or
    # filtered runs collide on the same file.
    cfg = dataset_config or "default"
    filt = (category_filter or "all").replace("/", "_")
    cache_name = (
        f"{dataset_name.replace('/', '_')}_{cfg}_{split}_{text_column}_"
        f"{filt}_r{int(include_reasoning)}_{max_tokens}.pt"
    )
    cache_path = os.path.join(cache_dir, cache_name)
    os.makedirs(cache_dir, exist_ok=True)
    if os.path.exists(cache_path):
        print(f"[DATA] Cache hit: {cache_path}")
        return torch.load(cache_path, weights_only=True)

    print(f"[DATA] Streaming {dataset_name} ({split})...")
    load_kwargs = {"split": split, "streaming": True}
    if dataset_config:
        load_kwargs["name"] = dataset_config
    ds = load_dataset(dataset_name, **load_kwargs)
    tok = ChimeraTokenizer(pretrained="o200k_base")
    filters = [c.strip() for c in category_filter.split(",") if c.strip()] if category_filter else None
    if filters:
        print(f"[DATA] Filtering categories: {filters}")

    buf = torch.empty(max_tokens, dtype=torch.long)
    idx = processed = skipped = 0
    for ex in ds:
        if filters and not matches_category_filter(ex, filters):
            skipped += 1
            continue
        text = format_dataset_example(ex, tok, text_column, include_reasoning)
        if not text or not text.strip():
            skipped += 1
            continue
        ids = tok.encode(text, add_special_tokens=False)
        ids.append(tok.eos_token_id)  # document separator
        n = min(len(ids), max_tokens - idx)
        if n <= 0:
            break
        buf[idx : idx + n] = torch.tensor(ids[:n], dtype=torch.long)
        idx += n
        processed += 1
        if processed % 5000 == 0:
            print(f"  {processed:,} docs  {idx:,}/{max_tokens} tokens")
    token_buf = buf[:idx].contiguous()
    torch.save(token_buf, cache_path)
    print(f"[DATA] Processed {processed:,} examples, skipped {skipped:,}.")
    print(f"[DATA] {idx:,} tokens -> {cache_path}")
    return token_buf
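
# Usage sketch (dataset and budget are illustrative): build or reload a
# 2M-token buffer from TinyStories, then reuse it across runs via the cache.
#
#     ids = build_token_buffer("roneneldan/TinyStories", "train", "auto",
#                              2_000_000, "./cache")
#     ids.dtype, ids.numel()  # torch.int64, <= 2_000_000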

def build_sequence_dataset(
    seq_len: int,
    *,
    max_samples=None,
    max_tokens=None,
    split: str = "train",
    dataset_name: str = "roneneldan/TinyStories",
    dataset_config: str | None = None,
    text_column: str = "auto",
    category_filter: str | None = None,
    include_reasoning: bool = False,
    cache_dir: str = "./cache",
):
    """Builds a SequenceTokenDataset of (seq_len + 1)-token chunks, deriving
    the token budget from max_tokens or, failing that, max_samples."""
    token_budget = int(max_tokens) if max_tokens is not None else None
    if token_budget is None and max_samples is not None:
        token_budget = int(max_samples) * (seq_len + 1)
    if token_budget is None or token_budget <= 0:
        # Fallback: at least 500k tokens, or enough for the requested samples.
        token_budget = max(500_000, (int(max_samples) if max_samples else 10000) * (seq_len + 1))
    token_buffer = build_token_buffer(
        dataset_name,
        split,
        text_column,
        token_budget,
        cache_dir,
        dataset_config=dataset_config,
        category_filter=category_filter,
        include_reasoning=include_reasoning,
    )
    if token_buffer.numel() == 0:
        raise ValueError("No data matched filters.")
    n = token_buffer.numel() // (seq_len + 1)
    if max_samples:
        n = min(n, max_samples)
    chunks = token_buffer[: n * (seq_len + 1)].view(n, seq_len + 1)
    print(f"[DATA] {n:,} chunks × {seq_len + 1} tokens = {n * (seq_len + 1):,} total")
    return SequenceTokenDataset(chunks)
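
if __name__ == "__main__":
    # Minimal smoke test (illustrative; streaming TinyStories needs network
    # access on the first run, after which the token cache is reused).
    ds = build_sequence_dataset(seq_len=64, max_samples=100)
    item = ds[0]
    print(item["input_ids"].shape, item["labels"].shape)  # both torch.Size([65])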