Upload train.py
Browse files
train.py
CHANGED
|
@@ -309,7 +309,7 @@ def _format_example(ex: dict, tok, text_column: str = "auto", include_reasoning:
|
|
| 309 |
return str(ex)
|
| 310 |
|
| 311 |
|
| 312 |
-
def build_dataset(seq_len: int, max_samples=None, split: str = "train",
|
| 313 |
dataset_name: str = "roneneldan/TinyStories",
|
| 314 |
dataset_config: str = None,
|
| 315 |
text_column: str = "auto",
|
|
@@ -322,6 +322,7 @@ def build_dataset(seq_len: int, max_samples=None, split: str = "train",
|
|
| 322 |
- Messages/chat format (auto-detected, uses apply_chat_template)
|
| 323 |
- Category filtering (comma-separated substrings)
|
| 324 |
- Streaming for huge datasets
|
|
|
|
| 325 |
"""
|
| 326 |
from datasets import load_dataset
|
| 327 |
from chimera import ChimeraTokenizer
|
|
@@ -341,30 +342,77 @@ def build_dataset(seq_len: int, max_samples=None, split: str = "train",
|
|
| 341 |
cat_filters = [c.strip() for c in category_filter.split(",") if c.strip()]
|
| 342 |
print(f"[DATA] Filtering categories: {cat_filters}")
|
| 343 |
|
| 344 |
-
|
| 345 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 346 |
processed = 0
|
| 347 |
skipped = 0
|
| 348 |
|
| 349 |
-
|
| 350 |
-
#
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 368 |
|
| 369 |
print(f"[DATA] Processed {processed:,} examples, skipped {skipped:,} (category/text mismatch)")
|
| 370 |
|
|
@@ -374,7 +422,6 @@ def build_dataset(seq_len: int, max_samples=None, split: str = "train",
|
|
| 374 |
f"category_filter={category_filter}, text_column={text_column}"
|
| 375 |
)
|
| 376 |
|
| 377 |
-
all_ids = torch.tensor(all_ids, dtype=torch.long)
|
| 378 |
n = len(all_ids) // (seq_len + 1)
|
| 379 |
if max_samples:
|
| 380 |
n = min(n, max_samples)
|
|
@@ -487,6 +534,8 @@ def train(args):
|
|
| 487 |
print(f"IPEX: {HAS_IPEX}")
|
| 488 |
print(f"Tokenizer: splintr o200k_base ({config['vocab_size']} tokens)")
|
| 489 |
print(f"Dataset: {args.dataset_name} / {args.dataset_split}")
|
|
|
|
|
|
|
| 490 |
if args.category_filter:
|
| 491 |
print(f"Category filter: {args.category_filter}")
|
| 492 |
if args.include_reasoning:
|
|
@@ -530,6 +579,7 @@ def train(args):
|
|
| 530 |
dataset, tok = build_dataset(
|
| 531 |
args.seq_len,
|
| 532 |
max_samples=args.max_samples,
|
|
|
|
| 533 |
split=args.dataset_split,
|
| 534 |
dataset_name=args.dataset_name,
|
| 535 |
dataset_config=args.dataset_config,
|
|
@@ -710,7 +760,10 @@ if __name__ == "__main__":
|
|
| 710 |
p.add_argument("--lr", type=float, default=1e-3)
|
| 711 |
p.add_argument("--warmup", type=int, default=200)
|
| 712 |
p.add_argument("--max_steps", type=int, default=5000)
|
| 713 |
-
p.add_argument("--max_samples", type=int, default=None
|
|
|
|
|
|
|
|
|
|
| 714 |
|
| 715 |
# CPU Optimizations
|
| 716 |
p.add_argument("--bf16", action="store_true", default=True,
|
|
|
|
| 309 |
return str(ex)
|
| 310 |
|
| 311 |
|
| 312 |
+
def build_dataset(seq_len: int, max_samples=None, max_tokens=None, split: str = "train",
|
| 313 |
dataset_name: str = "roneneldan/TinyStories",
|
| 314 |
dataset_config: str = None,
|
| 315 |
text_column: str = "auto",
|
|
|
|
| 322 |
- Messages/chat format (auto-detected, uses apply_chat_template)
|
| 323 |
- Category filtering (comma-separated substrings)
|
| 324 |
- Streaming for huge datasets
|
| 325 |
+
- Pre-allocated token buffer to avoid OOM on billion-token datasets
|
| 326 |
"""
|
| 327 |
from datasets import load_dataset
|
| 328 |
from chimera import ChimeraTokenizer
|
|
|
|
| 342 |
cat_filters = [c.strip() for c in category_filter.split(",") if c.strip()]
|
| 343 |
print(f"[DATA] Filtering categories: {cat_filters}")
|
| 344 |
|
| 345 |
+
# Determine token budget
|
| 346 |
+
if max_tokens is not None:
|
| 347 |
+
token_budget = max_tokens
|
| 348 |
+
elif max_samples is not None:
|
| 349 |
+
token_budget = max_samples * (seq_len + 1)
|
| 350 |
+
else:
|
| 351 |
+
token_budget = None
|
| 352 |
+
|
| 353 |
processed = 0
|
| 354 |
skipped = 0
|
| 355 |
|
| 356 |
+
if token_budget is not None and token_budget > 0:
|
| 357 |
+
# Pre-allocated flat buffer — avoids Python list overhead (~28 bytes/token)
|
| 358 |
+
buffer = torch.empty(token_budget, dtype=torch.long)
|
| 359 |
+
buf_idx = 0
|
| 360 |
+
|
| 361 |
+
for i, ex in enumerate(ds):
|
| 362 |
+
if cat_filters and not _matches_category_filter(ex, cat_filters):
|
| 363 |
+
skipped += 1
|
| 364 |
+
continue
|
| 365 |
+
|
| 366 |
+
text = _format_example(ex, tok, text_column, include_reasoning)
|
| 367 |
+
if not text or not text.strip():
|
| 368 |
+
skipped += 1
|
| 369 |
+
continue
|
| 370 |
+
|
| 371 |
+
ids = tok.encode(text, add_special_tokens=False)
|
| 372 |
+
ids.append(tok.eos_token_id)
|
| 373 |
+
n_ids = len(ids)
|
| 374 |
+
|
| 375 |
+
# Truncate if we would exceed the buffer
|
| 376 |
+
if buf_idx + n_ids > token_budget:
|
| 377 |
+
n_ids = token_budget - buf_idx
|
| 378 |
+
if n_ids <= 0:
|
| 379 |
+
break
|
| 380 |
+
ids = ids[:n_ids]
|
| 381 |
+
|
| 382 |
+
if n_ids > 0:
|
| 383 |
+
buffer[buf_idx:buf_idx + n_ids] = torch.tensor(ids, dtype=torch.long)
|
| 384 |
+
buf_idx += n_ids
|
| 385 |
+
processed += 1
|
| 386 |
+
|
| 387 |
+
if buf_idx >= token_budget:
|
| 388 |
+
break
|
| 389 |
+
if (processed + 1) % 10000 == 0:
|
| 390 |
+
print(f" {processed:,} examples, {buf_idx:,} tokens...")
|
| 391 |
+
|
| 392 |
+
all_ids = buffer[:buf_idx]
|
| 393 |
+
else:
|
| 394 |
+
# Fallback: old list approach for unbounded collection
|
| 395 |
+
all_ids = []
|
| 396 |
+
target = max_samples * (seq_len + 1) if max_samples else float('inf')
|
| 397 |
+
for i, ex in enumerate(ds):
|
| 398 |
+
if cat_filters and not _matches_category_filter(ex, cat_filters):
|
| 399 |
+
skipped += 1
|
| 400 |
+
continue
|
| 401 |
+
|
| 402 |
+
text = _format_example(ex, tok, text_column, include_reasoning)
|
| 403 |
+
if not text or not text.strip():
|
| 404 |
+
skipped += 1
|
| 405 |
+
continue
|
| 406 |
+
|
| 407 |
+
all_ids.extend(tok.encode(text, add_special_tokens=False))
|
| 408 |
+
all_ids.append(tok.eos_token_id)
|
| 409 |
+
processed += 1
|
| 410 |
+
|
| 411 |
+
if len(all_ids) >= target:
|
| 412 |
+
break
|
| 413 |
+
if (processed + 1) % 10000 == 0:
|
| 414 |
+
print(f" {processed:,} examples, {len(all_ids):,} tokens...")
|
| 415 |
+
all_ids = torch.tensor(all_ids, dtype=torch.long)
|
| 416 |
|
| 417 |
print(f"[DATA] Processed {processed:,} examples, skipped {skipped:,} (category/text mismatch)")
|
| 418 |
|
|
|
|
| 422 |
f"category_filter={category_filter}, text_column={text_column}"
|
| 423 |
)
|
| 424 |
|
|
|
|
| 425 |
n = len(all_ids) // (seq_len + 1)
|
| 426 |
if max_samples:
|
| 427 |
n = min(n, max_samples)
|
|
|
|
| 534 |
print(f"IPEX: {HAS_IPEX}")
|
| 535 |
print(f"Tokenizer: splintr o200k_base ({config['vocab_size']} tokens)")
|
| 536 |
print(f"Dataset: {args.dataset_name} / {args.dataset_split}")
|
| 537 |
+
if args.dataset_config:
|
| 538 |
+
print(f"Dataset config: {args.dataset_config}")
|
| 539 |
if args.category_filter:
|
| 540 |
print(f"Category filter: {args.category_filter}")
|
| 541 |
if args.include_reasoning:
|
|
|
|
| 579 |
dataset, tok = build_dataset(
|
| 580 |
args.seq_len,
|
| 581 |
max_samples=args.max_samples,
|
| 582 |
+
max_tokens=args.max_tokens,
|
| 583 |
split=args.dataset_split,
|
| 584 |
dataset_name=args.dataset_name,
|
| 585 |
dataset_config=args.dataset_config,
|
|
|
|
| 760 |
p.add_argument("--lr", type=float, default=1e-3)
|
| 761 |
p.add_argument("--warmup", type=int, default=200)
|
| 762 |
p.add_argument("--max_steps", type=int, default=5000)
|
| 763 |
+
p.add_argument("--max_samples", type=int, default=None,
|
| 764 |
+
help="Maximum number of chunks to generate")
|
| 765 |
+
p.add_argument("--max_tokens", type=int, default=None,
|
| 766 |
+
help="Maximum total tokens to collect (pre-allocated buffer, prevents OOM on huge datasets)")
|
| 767 |
|
| 768 |
# CPU Optimizations
|
| 769 |
p.add_argument("--bf16", action="store_true", default=True,
|