Lgr54HFi committed on
Commit
89aac72
·
verified ·
1 Parent(s): ed37c7e

Upload train.py

Browse files
Files changed (1) hide show
  1. train.py +77 -24
train.py CHANGED
@@ -309,7 +309,7 @@ def _format_example(ex: dict, tok, text_column: str = "auto", include_reasoning:
309
  return str(ex)
310
 
311
 
312
- def build_dataset(seq_len: int, max_samples=None, split: str = "train",
313
  dataset_name: str = "roneneldan/TinyStories",
314
  dataset_config: str = None,
315
  text_column: str = "auto",
@@ -322,6 +322,7 @@ def build_dataset(seq_len: int, max_samples=None, split: str = "train",
322
  - Messages/chat format (auto-detected, uses apply_chat_template)
323
  - Category filtering (comma-separated substrings)
324
  - Streaming for huge datasets
 
325
  """
326
  from datasets import load_dataset
327
  from chimera import ChimeraTokenizer
@@ -341,30 +342,77 @@ def build_dataset(seq_len: int, max_samples=None, split: str = "train",
341
  cat_filters = [c.strip() for c in category_filter.split(",") if c.strip()]
342
  print(f"[DATA] Filtering categories: {cat_filters}")
343
 
344
- all_ids = []
345
- target = max_samples * (seq_len + 1) if max_samples else float('inf')
 
 
 
 
 
 
346
  processed = 0
347
  skipped = 0
348
 
349
- for i, ex in enumerate(ds):
350
- # Category filter
351
- if cat_filters and not _matches_category_filter(ex, cat_filters):
352
- skipped += 1
353
- continue
354
-
355
- text = _format_example(ex, tok, text_column, include_reasoning)
356
- if not text or not text.strip():
357
- skipped += 1
358
- continue
359
-
360
- all_ids.extend(tok.encode(text, add_special_tokens=False))
361
- all_ids.append(tok.eos_token_id)
362
- processed += 1
363
-
364
- if len(all_ids) >= target:
365
- break
366
- if (processed + 1) % 10000 == 0:
367
- print(f" {processed:,} examples, {len(all_ids):,} tokens...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
368
 
369
  print(f"[DATA] Processed {processed:,} examples, skipped {skipped:,} (category/text mismatch)")
370
 
@@ -374,7 +422,6 @@ def build_dataset(seq_len: int, max_samples=None, split: str = "train",
374
  f"category_filter={category_filter}, text_column={text_column}"
375
  )
376
 
377
- all_ids = torch.tensor(all_ids, dtype=torch.long)
378
  n = len(all_ids) // (seq_len + 1)
379
  if max_samples:
380
  n = min(n, max_samples)
@@ -487,6 +534,8 @@ def train(args):
487
  print(f"IPEX: {HAS_IPEX}")
488
  print(f"Tokenizer: splintr o200k_base ({config['vocab_size']} tokens)")
489
  print(f"Dataset: {args.dataset_name} / {args.dataset_split}")
 
 
490
  if args.category_filter:
491
  print(f"Category filter: {args.category_filter}")
492
  if args.include_reasoning:
@@ -530,6 +579,7 @@ def train(args):
530
  dataset, tok = build_dataset(
531
  args.seq_len,
532
  max_samples=args.max_samples,
 
533
  split=args.dataset_split,
534
  dataset_name=args.dataset_name,
535
  dataset_config=args.dataset_config,
@@ -710,7 +760,10 @@ if __name__ == "__main__":
710
  p.add_argument("--lr", type=float, default=1e-3)
711
  p.add_argument("--warmup", type=int, default=200)
712
  p.add_argument("--max_steps", type=int, default=5000)
713
- p.add_argument("--max_samples", type=int, default=None)
 
 
 
714
 
715
  # CPU Optimizations
716
  p.add_argument("--bf16", action="store_true", default=True,
 
309
  return str(ex)
310
 
311
 
312
+ def build_dataset(seq_len: int, max_samples=None, max_tokens=None, split: str = "train",
313
  dataset_name: str = "roneneldan/TinyStories",
314
  dataset_config: str = None,
315
  text_column: str = "auto",
 
322
  - Messages/chat format (auto-detected, uses apply_chat_template)
323
  - Category filtering (comma-separated substrings)
324
  - Streaming for huge datasets
325
+ - Pre-allocated token buffer to avoid OOM on billion-token datasets
326
  """
327
  from datasets import load_dataset
328
  from chimera import ChimeraTokenizer
 
342
  cat_filters = [c.strip() for c in category_filter.split(",") if c.strip()]
343
  print(f"[DATA] Filtering categories: {cat_filters}")
344
 
345
+ # Determine token budget
346
+ if max_tokens is not None:
347
+ token_budget = max_tokens
348
+ elif max_samples is not None:
349
+ token_budget = max_samples * (seq_len + 1)
350
+ else:
351
+ token_budget = None
352
+
353
  processed = 0
354
  skipped = 0
355
 
356
+ if token_budget is not None and token_budget > 0:
357
+ # Pre-allocated flat buffer — avoids Python list overhead (~28 bytes/token)
358
+ buffer = torch.empty(token_budget, dtype=torch.long)
359
+ buf_idx = 0
360
+
361
+ for i, ex in enumerate(ds):
362
+ if cat_filters and not _matches_category_filter(ex, cat_filters):
363
+ skipped += 1
364
+ continue
365
+
366
+ text = _format_example(ex, tok, text_column, include_reasoning)
367
+ if not text or not text.strip():
368
+ skipped += 1
369
+ continue
370
+
371
+ ids = tok.encode(text, add_special_tokens=False)
372
+ ids.append(tok.eos_token_id)
373
+ n_ids = len(ids)
374
+
375
+ # Truncate if we would exceed the buffer
376
+ if buf_idx + n_ids > token_budget:
377
+ n_ids = token_budget - buf_idx
378
+ if n_ids <= 0:
379
+ break
380
+ ids = ids[:n_ids]
381
+
382
+ if n_ids > 0:
383
+ buffer[buf_idx:buf_idx + n_ids] = torch.tensor(ids, dtype=torch.long)
384
+ buf_idx += n_ids
385
+ processed += 1
386
+
387
+ if buf_idx >= token_budget:
388
+ break
389
+ if (processed + 1) % 10000 == 0:
390
+ print(f" {processed:,} examples, {buf_idx:,} tokens...")
391
+
392
+ all_ids = buffer[:buf_idx]
393
+ else:
394
+ # Fallback: old list approach for unbounded collection
395
+ all_ids = []
396
+ target = max_samples * (seq_len + 1) if max_samples else float('inf')
397
+ for i, ex in enumerate(ds):
398
+ if cat_filters and not _matches_category_filter(ex, cat_filters):
399
+ skipped += 1
400
+ continue
401
+
402
+ text = _format_example(ex, tok, text_column, include_reasoning)
403
+ if not text or not text.strip():
404
+ skipped += 1
405
+ continue
406
+
407
+ all_ids.extend(tok.encode(text, add_special_tokens=False))
408
+ all_ids.append(tok.eos_token_id)
409
+ processed += 1
410
+
411
+ if len(all_ids) >= target:
412
+ break
413
+ if (processed + 1) % 10000 == 0:
414
+ print(f" {processed:,} examples, {len(all_ids):,} tokens...")
415
+ all_ids = torch.tensor(all_ids, dtype=torch.long)
416
 
417
  print(f"[DATA] Processed {processed:,} examples, skipped {skipped:,} (category/text mismatch)")
418
 
 
422
  f"category_filter={category_filter}, text_column={text_column}"
423
  )
424
 
 
425
  n = len(all_ids) // (seq_len + 1)
426
  if max_samples:
427
  n = min(n, max_samples)
 
534
  print(f"IPEX: {HAS_IPEX}")
535
  print(f"Tokenizer: splintr o200k_base ({config['vocab_size']} tokens)")
536
  print(f"Dataset: {args.dataset_name} / {args.dataset_split}")
537
+ if args.dataset_config:
538
+ print(f"Dataset config: {args.dataset_config}")
539
  if args.category_filter:
540
  print(f"Category filter: {args.category_filter}")
541
  if args.include_reasoning:
 
579
  dataset, tok = build_dataset(
580
  args.seq_len,
581
  max_samples=args.max_samples,
582
+ max_tokens=args.max_tokens,
583
  split=args.dataset_split,
584
  dataset_name=args.dataset_name,
585
  dataset_config=args.dataset_config,
 
760
  p.add_argument("--lr", type=float, default=1e-3)
761
  p.add_argument("--warmup", type=int, default=200)
762
  p.add_argument("--max_steps", type=int, default=5000)
763
+ p.add_argument("--max_samples", type=int, default=None,
764
+ help="Maximum number of chunks to generate")
765
+ p.add_argument("--max_tokens", type=int, default=None,
766
+ help="Maximum total tokens to collect (pre-allocated buffer, prevents OOM on huge datasets)")
767
 
768
  # CPU Optimizations
769
  p.add_argument("--bf16", action="store_true", default=True,