Upload train.py
Browse files
train.py
CHANGED
|
@@ -13,6 +13,7 @@ Optimizations implemented:
|
|
| 13 |
7. Intel IPEX integration (optional) β auto-detected
|
| 14 |
8. Cosine LR with warmup
|
| 15 |
9. Standard AdamW with backprop as fallback mode
|
|
|
|
| 16 |
|
| 17 |
Usage:
|
| 18 |
# MeZO mode (recommended for CPU β no backward pass):
|
|
@@ -21,8 +22,11 @@ Usage:
|
|
| 21 |
# AdamW mode (standard backprop with gradient checkpointing + bf16):
|
| 22 |
python train.py --optimizer adamw --scale tiny --seq_len 64 --max_steps 100
|
| 23 |
|
| 24 |
-
# Full run:
|
| 25 |
-
python train.py --optimizer mezo --scale
|
|
|
|
|
|
|
|
|
|
| 26 |
"""
|
| 27 |
|
| 28 |
import os
|
|
@@ -255,25 +259,120 @@ class TokenDataset(Dataset):
|
|
| 255 |
return {"input_ids": self.chunks[idx], "labels": self.chunks[idx]}
|
| 256 |
|
| 257 |
|
| 258 |
-
def
|
| 259 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 260 |
from datasets import load_dataset
|
| 261 |
from chimera import ChimeraTokenizer
|
| 262 |
|
| 263 |
-
print(f"[DATA] Loading
|
| 264 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 265 |
print(f"[DATA] Loading tokenizer (splintr o200k_base)...")
|
| 266 |
tok = ChimeraTokenizer(pretrained="o200k_base")
|
| 267 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 268 |
all_ids = []
|
| 269 |
target = max_samples * (seq_len + 1) if max_samples else float('inf')
|
|
|
|
|
|
|
|
|
|
| 270 |
for i, ex in enumerate(ds):
|
| 271 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 272 |
all_ids.append(tok.eos_token_id)
|
|
|
|
|
|
|
| 273 |
if len(all_ids) >= target:
|
| 274 |
break
|
| 275 |
-
if (
|
| 276 |
-
print(f" {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 277 |
|
| 278 |
all_ids = torch.tensor(all_ids, dtype=torch.long)
|
| 279 |
n = len(all_ids) // (seq_len + 1)
|
|
@@ -387,6 +486,11 @@ def train(args):
|
|
| 387 |
print(f"Device: CPU ({torch.get_num_threads()} threads)")
|
| 388 |
print(f"IPEX: {HAS_IPEX}")
|
| 389 |
print(f"Tokenizer: splintr o200k_base ({config['vocab_size']} tokens)")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 390 |
|
| 391 |
# βββ Build model βββ
|
| 392 |
model = Chimera51ForCausalLM(config)
|
|
@@ -423,8 +527,16 @@ def train(args):
|
|
| 423 |
print("[OPT] Compilation deferred (will compile on first forward pass)")
|
| 424 |
|
| 425 |
# βββ Dataset βββ
|
| 426 |
-
dataset, tok = build_dataset(
|
| 427 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 428 |
loader = DataLoader(
|
| 429 |
dataset,
|
| 430 |
batch_size=args.batch_size,
|
|
@@ -616,7 +728,21 @@ if __name__ == "__main__":
|
|
| 616 |
action="store_false", default=True,
|
| 617 |
help="Regenerate directions instead of caching them for the step")
|
| 618 |
|
| 619 |
-
# Data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 620 |
p.add_argument("--num_workers", type=int, default=4)
|
| 621 |
p.add_argument("--log_every", type=int, default=10)
|
| 622 |
p.add_argument("--save_every", type=int, default=1000)
|
|
|
|
| 13 |
7. Intel IPEX integration (optional) β auto-detected
|
| 14 |
8. Cosine LR with warmup
|
| 15 |
9. Standard AdamW with backprop as fallback mode
|
| 16 |
+
10. Generic dataset loading β supports any HF dataset, messages/text columns, category filtering
|
| 17 |
|
| 18 |
Usage:
|
| 19 |
# MeZO mode (recommended for CPU β no backward pass):
|
|
|
|
| 22 |
# AdamW mode (standard backprop with gradient checkpointing + bf16):
|
| 23 |
python train.py --optimizer adamw --scale tiny --seq_len 64 --max_steps 100
|
| 24 |
|
| 25 |
+
# Full run with custom dataset and category filter:
|
| 26 |
+
python train.py --optimizer mezo --scale tiny --seq_len 64 --max_steps 10000 \
|
| 27 |
+
--dataset_name Roman1111111/claude-sonnet-4.6-120000x \
|
| 28 |
+
--dataset_split train --text_column messages \
|
| 29 |
+
--category_filter "C++,organic chemistry"
|
| 30 |
"""
|
| 31 |
|
| 32 |
import os
|
|
|
|
| 259 |
return {"input_ids": self.chunks[idx], "labels": self.chunks[idx]}
|
| 260 |
|
| 261 |
|
| 262 |
+
def _matches_category_filter(ex: dict, filters: list) -> bool:
|
| 263 |
+
"""Check if example matches any of the requested category substrings."""
|
| 264 |
+
cat = ex.get("category", "")
|
| 265 |
+
if not cat:
|
| 266 |
+
return False
|
| 267 |
+
cat_lower = cat.lower()
|
| 268 |
+
return any(f.lower() in cat_lower for f in filters)
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def _format_example(ex: dict, tok, text_column: str = "auto", include_reasoning: bool = False) -> str:
|
| 272 |
+
"""Convert an example dict to a single text string for tokenization."""
|
| 273 |
+
# Auto-detect text column
|
| 274 |
+
if text_column == "auto":
|
| 275 |
+
if "messages" in ex:
|
| 276 |
+
text_column = "messages"
|
| 277 |
+
elif "text" in ex:
|
| 278 |
+
text_column = "text"
|
| 279 |
+
elif "content" in ex:
|
| 280 |
+
text_column = "content"
|
| 281 |
+
elif "conversation" in ex:
|
| 282 |
+
text_column = "conversation"
|
| 283 |
+
else:
|
| 284 |
+
text_column = None
|
| 285 |
+
|
| 286 |
+
if text_column == "messages" and "messages" in ex:
|
| 287 |
+
msgs = ex["messages"]
|
| 288 |
+
# Inject reasoning into assistant messages if requested
|
| 289 |
+
if include_reasoning and isinstance(msgs, list):
|
| 290 |
+
msgs = []
|
| 291 |
+
for m in ex["messages"]:
|
| 292 |
+
if isinstance(m, dict) and m.get("role") == "assistant" and "reasoning" in m:
|
| 293 |
+
content = f"<|thinking|>\n{m['reasoning']}\n<|/thinking|>\n{m.get('content', '')}"
|
| 294 |
+
msgs.append({"role": "assistant", "content": content})
|
| 295 |
+
else:
|
| 296 |
+
msgs.append(m)
|
| 297 |
+
return tok.apply_chat_template(msgs)
|
| 298 |
+
|
| 299 |
+
if text_column and text_column in ex:
|
| 300 |
+
val = ex[text_column]
|
| 301 |
+
if isinstance(val, str):
|
| 302 |
+
return val
|
| 303 |
+
# Some datasets store conversation as list of dicts even in 'text' col
|
| 304 |
+
if isinstance(val, list) and len(val) > 0 and isinstance(val[0], dict):
|
| 305 |
+
return tok.apply_chat_template(val)
|
| 306 |
+
return str(val)
|
| 307 |
+
|
| 308 |
+
# Fallback: stringify the whole example
|
| 309 |
+
return str(ex)
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def build_dataset(seq_len: int, max_samples=None, split: str = "train",
|
| 313 |
+
dataset_name: str = "roneneldan/TinyStories",
|
| 314 |
+
dataset_config: str = None,
|
| 315 |
+
text_column: str = "auto",
|
| 316 |
+
category_filter: str = None,
|
| 317 |
+
include_reasoning: bool = False):
|
| 318 |
+
"""Build dataset from any HuggingFace dataset with splintr tokenizer.
|
| 319 |
+
|
| 320 |
+
Supports:
|
| 321 |
+
- Generic text columns ('text', 'content', etc.)
|
| 322 |
+
- Messages/chat format (auto-detected, uses apply_chat_template)
|
| 323 |
+
- Category filtering (comma-separated substrings)
|
| 324 |
+
- Streaming for huge datasets
|
| 325 |
+
"""
|
| 326 |
from datasets import load_dataset
|
| 327 |
from chimera import ChimeraTokenizer
|
| 328 |
|
| 329 |
+
print(f"[DATA] Loading {dataset_name} ({split})...")
|
| 330 |
+
load_kwargs = {"split": split, "streaming": True}
|
| 331 |
+
if dataset_config:
|
| 332 |
+
load_kwargs["name"] = dataset_config
|
| 333 |
+
ds = load_dataset(dataset_name, **load_kwargs)
|
| 334 |
+
|
| 335 |
print(f"[DATA] Loading tokenizer (splintr o200k_base)...")
|
| 336 |
tok = ChimeraTokenizer(pretrained="o200k_base")
|
| 337 |
|
| 338 |
+
# Parse category filters
|
| 339 |
+
cat_filters = None
|
| 340 |
+
if category_filter:
|
| 341 |
+
cat_filters = [c.strip() for c in category_filter.split(",") if c.strip()]
|
| 342 |
+
print(f"[DATA] Filtering categories: {cat_filters}")
|
| 343 |
+
|
| 344 |
all_ids = []
|
| 345 |
target = max_samples * (seq_len + 1) if max_samples else float('inf')
|
| 346 |
+
processed = 0
|
| 347 |
+
skipped = 0
|
| 348 |
+
|
| 349 |
for i, ex in enumerate(ds):
|
| 350 |
+
# Category filter
|
| 351 |
+
if cat_filters and not _matches_category_filter(ex, cat_filters):
|
| 352 |
+
skipped += 1
|
| 353 |
+
continue
|
| 354 |
+
|
| 355 |
+
text = _format_example(ex, tok, text_column, include_reasoning)
|
| 356 |
+
if not text or not text.strip():
|
| 357 |
+
skipped += 1
|
| 358 |
+
continue
|
| 359 |
+
|
| 360 |
+
all_ids.extend(tok.encode(text, add_special_tokens=False))
|
| 361 |
all_ids.append(tok.eos_token_id)
|
| 362 |
+
processed += 1
|
| 363 |
+
|
| 364 |
if len(all_ids) >= target:
|
| 365 |
break
|
| 366 |
+
if (processed + 1) % 10000 == 0:
|
| 367 |
+
print(f" {processed:,} examples, {len(all_ids):,} tokens...")
|
| 368 |
+
|
| 369 |
+
print(f"[DATA] Processed {processed:,} examples, skipped {skipped:,} (category/text mismatch)")
|
| 370 |
+
|
| 371 |
+
if len(all_ids) == 0:
|
| 372 |
+
raise ValueError(
|
| 373 |
+
f"No data matched filters. dataset={dataset_name}, "
|
| 374 |
+
f"category_filter={category_filter}, text_column={text_column}"
|
| 375 |
+
)
|
| 376 |
|
| 377 |
all_ids = torch.tensor(all_ids, dtype=torch.long)
|
| 378 |
n = len(all_ids) // (seq_len + 1)
|
|
|
|
| 486 |
print(f"Device: CPU ({torch.get_num_threads()} threads)")
|
| 487 |
print(f"IPEX: {HAS_IPEX}")
|
| 488 |
print(f"Tokenizer: splintr o200k_base ({config['vocab_size']} tokens)")
|
| 489 |
+
print(f"Dataset: {args.dataset_name} / {args.dataset_split}")
|
| 490 |
+
if args.category_filter:
|
| 491 |
+
print(f"Category filter: {args.category_filter}")
|
| 492 |
+
if args.include_reasoning:
|
| 493 |
+
print("Reasoning: INCLUDED (<|thinking|> ... <|/thinking|>)")
|
| 494 |
|
| 495 |
# βββ Build model βββ
|
| 496 |
model = Chimera51ForCausalLM(config)
|
|
|
|
| 527 |
print("[OPT] Compilation deferred (will compile on first forward pass)")
|
| 528 |
|
| 529 |
# βββ Dataset βββ
|
| 530 |
+
dataset, tok = build_dataset(
|
| 531 |
+
args.seq_len,
|
| 532 |
+
max_samples=args.max_samples,
|
| 533 |
+
split=args.dataset_split,
|
| 534 |
+
dataset_name=args.dataset_name,
|
| 535 |
+
dataset_config=args.dataset_config,
|
| 536 |
+
text_column=args.text_column,
|
| 537 |
+
category_filter=args.category_filter,
|
| 538 |
+
include_reasoning=args.include_reasoning,
|
| 539 |
+
)
|
| 540 |
loader = DataLoader(
|
| 541 |
dataset,
|
| 542 |
batch_size=args.batch_size,
|
|
|
|
| 728 |
action="store_false", default=True,
|
| 729 |
help="Regenerate directions instead of caching them for the step")
|
| 730 |
|
| 731 |
+
# Data β fully configurable
|
| 732 |
+
p.add_argument("--dataset_name", default="roneneldan/TinyStories",
|
| 733 |
+
help="HuggingFace dataset name (e.g. Roman1111111/claude-sonnet-4.6-120000x)")
|
| 734 |
+
p.add_argument("--dataset_config", default=None,
|
| 735 |
+
help="Dataset config/subset name")
|
| 736 |
+
p.add_argument("--dataset_split", default="train",
|
| 737 |
+
help="Dataset split to use")
|
| 738 |
+
p.add_argument("--text_column", default="auto",
|
| 739 |
+
help="Column containing text. 'auto' detects 'messages'/'text'/'content'/'conversation'")
|
| 740 |
+
p.add_argument("--category_filter", default=None,
|
| 741 |
+
help="Comma-separated category substrings to filter on (e.g. 'C++,python,math')")
|
| 742 |
+
p.add_argument("--include_reasoning", action="store_true", default=False,
|
| 743 |
+
help="Include reasoning/thinking content from assistant messages as <|thinking|>...<|/thinking|>")
|
| 744 |
+
|
| 745 |
+
# Logging / Output
|
| 746 |
p.add_argument("--num_workers", type=int, default=4)
|
| 747 |
p.add_argument("--log_every", type=int, default=10)
|
| 748 |
p.add_argument("--save_every", type=int, default=1000)
|