Lgr54HFi commited on
Commit
c4fa83f
Β·
verified Β·
1 Parent(s): 092c193

Upload train.py

Browse files
Files changed (1) hide show
  1. train.py +138 -12
train.py CHANGED
@@ -13,6 +13,7 @@ Optimizations implemented:
13
  7. Intel IPEX integration (optional) β€” auto-detected
14
  8. Cosine LR with warmup
15
  9. Standard AdamW with backprop as fallback mode
 
16
 
17
  Usage:
18
  # MeZO mode (recommended for CPU β€” no backward pass):
@@ -21,8 +22,11 @@ Usage:
21
  # AdamW mode (standard backprop with gradient checkpointing + bf16):
22
  python train.py --optimizer adamw --scale tiny --seq_len 64 --max_steps 100
23
 
24
- # Full run:
25
- python train.py --optimizer mezo --scale small --seq_len 256 --max_steps 10000 --compile
 
 
 
26
  """
27
 
28
  import os
@@ -255,25 +259,120 @@ class TokenDataset(Dataset):
255
  return {"input_ids": self.chunks[idx], "labels": self.chunks[idx]}
256
 
257
 
258
- def build_dataset(seq_len: int, max_samples=None, split: str = "train"):
259
- """Build dataset from TinyStories with splintr tokenizer."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
  from datasets import load_dataset
261
  from chimera import ChimeraTokenizer
262
 
263
- print(f"[DATA] Loading TinyStories ({split})...")
264
- ds = load_dataset("roneneldan/TinyStories", split=split, streaming=True)
 
 
 
 
265
  print(f"[DATA] Loading tokenizer (splintr o200k_base)...")
266
  tok = ChimeraTokenizer(pretrained="o200k_base")
267
 
 
 
 
 
 
 
268
  all_ids = []
269
  target = max_samples * (seq_len + 1) if max_samples else float('inf')
 
 
 
270
  for i, ex in enumerate(ds):
271
- all_ids.extend(tok.encode(ex["text"], add_special_tokens=False))
 
 
 
 
 
 
 
 
 
 
272
  all_ids.append(tok.eos_token_id)
 
 
273
  if len(all_ids) >= target:
274
  break
275
- if (i + 1) % 10000 == 0:
276
- print(f" {i + 1} texts, {len(all_ids):,} tokens...")
 
 
 
 
 
 
 
 
277
 
278
  all_ids = torch.tensor(all_ids, dtype=torch.long)
279
  n = len(all_ids) // (seq_len + 1)
@@ -387,6 +486,11 @@ def train(args):
387
  print(f"Device: CPU ({torch.get_num_threads()} threads)")
388
  print(f"IPEX: {HAS_IPEX}")
389
  print(f"Tokenizer: splintr o200k_base ({config['vocab_size']} tokens)")
 
 
 
 
 
390
 
391
  # ─── Build model ───
392
  model = Chimera51ForCausalLM(config)
@@ -423,8 +527,16 @@ def train(args):
423
  print("[OPT] Compilation deferred (will compile on first forward pass)")
424
 
425
  # ─── Dataset ───
426
- dataset, tok = build_dataset(args.seq_len, max_samples=args.max_samples,
427
- split="train")
 
 
 
 
 
 
 
 
428
  loader = DataLoader(
429
  dataset,
430
  batch_size=args.batch_size,
@@ -616,7 +728,21 @@ if __name__ == "__main__":
616
  action="store_false", default=True,
617
  help="Regenerate directions instead of caching them for the step")
618
 
619
- # Data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
620
  p.add_argument("--num_workers", type=int, default=4)
621
  p.add_argument("--log_every", type=int, default=10)
622
  p.add_argument("--save_every", type=int, default=1000)
 
13
  7. Intel IPEX integration (optional) β€” auto-detected
14
  8. Cosine LR with warmup
15
  9. Standard AdamW with backprop as fallback mode
16
+ 10. Generic dataset loading β€” supports any HF dataset, messages/text columns, category filtering
17
 
18
  Usage:
19
  # MeZO mode (recommended for CPU β€” no backward pass):
 
22
  # AdamW mode (standard backprop with gradient checkpointing + bf16):
23
  python train.py --optimizer adamw --scale tiny --seq_len 64 --max_steps 100
24
 
25
+ # Full run with custom dataset and category filter:
26
+ python train.py --optimizer mezo --scale tiny --seq_len 64 --max_steps 10000 \
27
+ --dataset_name Roman1111111/claude-sonnet-4.6-120000x \
28
+ --dataset_split train --text_column messages \
29
+ --category_filter "C++,organic chemistry"
30
  """
31
 
32
  import os
 
259
  return {"input_ids": self.chunks[idx], "labels": self.chunks[idx]}
260
 
261
 
262
+ def _matches_category_filter(ex: dict, filters: list) -> bool:
263
+ """Check if example matches any of the requested category substrings."""
264
+ cat = ex.get("category", "")
265
+ if not cat:
266
+ return False
267
+ cat_lower = cat.lower()
268
+ return any(f.lower() in cat_lower for f in filters)
269
+
270
+
271
+ def _format_example(ex: dict, tok, text_column: str = "auto", include_reasoning: bool = False) -> str:
272
+ """Convert an example dict to a single text string for tokenization."""
273
+ # Auto-detect text column
274
+ if text_column == "auto":
275
+ if "messages" in ex:
276
+ text_column = "messages"
277
+ elif "text" in ex:
278
+ text_column = "text"
279
+ elif "content" in ex:
280
+ text_column = "content"
281
+ elif "conversation" in ex:
282
+ text_column = "conversation"
283
+ else:
284
+ text_column = None
285
+
286
+ if text_column == "messages" and "messages" in ex:
287
+ msgs = ex["messages"]
288
+ # Inject reasoning into assistant messages if requested
289
+ if include_reasoning and isinstance(msgs, list):
290
+ msgs = []
291
+ for m in ex["messages"]:
292
+ if isinstance(m, dict) and m.get("role") == "assistant" and "reasoning" in m:
293
+ content = f"<|thinking|>\n{m['reasoning']}\n<|/thinking|>\n{m.get('content', '')}"
294
+ msgs.append({"role": "assistant", "content": content})
295
+ else:
296
+ msgs.append(m)
297
+ return tok.apply_chat_template(msgs)
298
+
299
+ if text_column and text_column in ex:
300
+ val = ex[text_column]
301
+ if isinstance(val, str):
302
+ return val
303
+ # Some datasets store conversation as list of dicts even in 'text' col
304
+ if isinstance(val, list) and len(val) > 0 and isinstance(val[0], dict):
305
+ return tok.apply_chat_template(val)
306
+ return str(val)
307
+
308
+ # Fallback: stringify the whole example
309
+ return str(ex)
310
+
311
+
312
+ def build_dataset(seq_len: int, max_samples=None, split: str = "train",
313
+ dataset_name: str = "roneneldan/TinyStories",
314
+ dataset_config: str = None,
315
+ text_column: str = "auto",
316
+ category_filter: str = None,
317
+ include_reasoning: bool = False):
318
+ """Build dataset from any HuggingFace dataset with splintr tokenizer.
319
+
320
+ Supports:
321
+ - Generic text columns ('text', 'content', etc.)
322
+ - Messages/chat format (auto-detected, uses apply_chat_template)
323
+ - Category filtering (comma-separated substrings)
324
+ - Streaming for huge datasets
325
+ """
326
  from datasets import load_dataset
327
  from chimera import ChimeraTokenizer
328
 
329
+ print(f"[DATA] Loading {dataset_name} ({split})...")
330
+ load_kwargs = {"split": split, "streaming": True}
331
+ if dataset_config:
332
+ load_kwargs["name"] = dataset_config
333
+ ds = load_dataset(dataset_name, **load_kwargs)
334
+
335
  print(f"[DATA] Loading tokenizer (splintr o200k_base)...")
336
  tok = ChimeraTokenizer(pretrained="o200k_base")
337
 
338
+ # Parse category filters
339
+ cat_filters = None
340
+ if category_filter:
341
+ cat_filters = [c.strip() for c in category_filter.split(",") if c.strip()]
342
+ print(f"[DATA] Filtering categories: {cat_filters}")
343
+
344
  all_ids = []
345
  target = max_samples * (seq_len + 1) if max_samples else float('inf')
346
+ processed = 0
347
+ skipped = 0
348
+
349
  for i, ex in enumerate(ds):
350
+ # Category filter
351
+ if cat_filters and not _matches_category_filter(ex, cat_filters):
352
+ skipped += 1
353
+ continue
354
+
355
+ text = _format_example(ex, tok, text_column, include_reasoning)
356
+ if not text or not text.strip():
357
+ skipped += 1
358
+ continue
359
+
360
+ all_ids.extend(tok.encode(text, add_special_tokens=False))
361
  all_ids.append(tok.eos_token_id)
362
+ processed += 1
363
+
364
  if len(all_ids) >= target:
365
  break
366
+ if (processed + 1) % 10000 == 0:
367
+ print(f" {processed:,} examples, {len(all_ids):,} tokens...")
368
+
369
+ print(f"[DATA] Processed {processed:,} examples, skipped {skipped:,} (category/text mismatch)")
370
+
371
+ if len(all_ids) == 0:
372
+ raise ValueError(
373
+ f"No data matched filters. dataset={dataset_name}, "
374
+ f"category_filter={category_filter}, text_column={text_column}"
375
+ )
376
 
377
  all_ids = torch.tensor(all_ids, dtype=torch.long)
378
  n = len(all_ids) // (seq_len + 1)
 
486
  print(f"Device: CPU ({torch.get_num_threads()} threads)")
487
  print(f"IPEX: {HAS_IPEX}")
488
  print(f"Tokenizer: splintr o200k_base ({config['vocab_size']} tokens)")
489
+ print(f"Dataset: {args.dataset_name} / {args.dataset_split}")
490
+ if args.category_filter:
491
+ print(f"Category filter: {args.category_filter}")
492
+ if args.include_reasoning:
493
+ print("Reasoning: INCLUDED (<|thinking|> ... <|/thinking|>)")
494
 
495
  # ─── Build model ───
496
  model = Chimera51ForCausalLM(config)
 
527
  print("[OPT] Compilation deferred (will compile on first forward pass)")
528
 
529
  # ─── Dataset ───
530
+ dataset, tok = build_dataset(
531
+ args.seq_len,
532
+ max_samples=args.max_samples,
533
+ split=args.dataset_split,
534
+ dataset_name=args.dataset_name,
535
+ dataset_config=args.dataset_config,
536
+ text_column=args.text_column,
537
+ category_filter=args.category_filter,
538
+ include_reasoning=args.include_reasoning,
539
+ )
540
  loader = DataLoader(
541
  dataset,
542
  batch_size=args.batch_size,
 
728
  action="store_false", default=True,
729
  help="Regenerate directions instead of caching them for the step")
730
 
731
+ # Data β€” fully configurable
732
+ p.add_argument("--dataset_name", default="roneneldan/TinyStories",
733
+ help="HuggingFace dataset name (e.g. Roman1111111/claude-sonnet-4.6-120000x)")
734
+ p.add_argument("--dataset_config", default=None,
735
+ help="Dataset config/subset name")
736
+ p.add_argument("--dataset_split", default="train",
737
+ help="Dataset split to use")
738
+ p.add_argument("--text_column", default="auto",
739
+ help="Column containing text. 'auto' detects 'messages'/'text'/'content'/'conversation'")
740
+ p.add_argument("--category_filter", default=None,
741
+ help="Comma-separated category substrings to filter on (e.g. 'C++,python,math')")
742
+ p.add_argument("--include_reasoning", action="store_true", default=False,
743
+ help="Include reasoning/thinking content from assistant messages as <|thinking|>...<|/thinking|>")
744
+
745
+ # Logging / Output
746
  p.add_argument("--num_workers", type=int, default=4)
747
  p.add_argument("--log_every", type=int, default=10)
748
  p.add_argument("--save_every", type=int, default=1000)