""" Training script for Resonance 200M. ClimbMix data, own BPE tokenizer (Rust backend), AdamW optimizer. Shows BOTH train loss AND val loss. Always. """ import os import sys import time import math import struct import argparse import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch.amp import autocast, GradScaler from model import Resonance, RESONANCE_200M from bpe_tokenizer import BPETokenizer # ───────────────────────────────────────────────────────────────────────────── # Data # ───────────────────────────────────────────────────────────────────────────── def download_climbmix_shards(data_dir, n_shards=100): """Download ClimbMix parquet shards from HuggingFace.""" os.makedirs(data_dir, exist_ok=True) try: import pyarrow.parquet as pq except ImportError: print("pip install pyarrow pandas") sys.exit(1) base_url = "https://huggingface.co/datasets/karpathy/climbmix-400b-shuffle/resolve/main" texts_path = os.path.join(data_dir, "texts.txt") if os.path.exists(texts_path): size = os.path.getsize(texts_path) print(f" [Data] texts.txt exists ({size/1e9:.2f} GB), skipping download") return texts_path import urllib.request import ssl ctx = ssl.create_default_context() ctx.check_hostname = False ctx.verify_mode = ssl.CERT_NONE total_bytes = 0 with open(texts_path, 'w', encoding='utf-8') as out: for i in range(n_shards): shard_name = f"shard_{i:05d}.parquet" shard_path = os.path.join(data_dir, shard_name) url = f"{base_url}/{shard_name}" if not os.path.exists(shard_path): print(f" [Data] Downloading shard {i+1}/{n_shards}...", end=" ", flush=True) try: urllib.request.urlretrieve(url, shard_path) print("OK") except Exception as e: print(f"FAIL: {e}") continue # Extract text try: table = pq.read_table(shard_path, columns=['text']) texts = table.column('text').to_pylist() for text in texts: if text and len(text) > 100: out.write(text + '\n') total_bytes += len(text) # Remove parquet to save disk os.remove(shard_path) except Exception as e: print(f" [Data] Error reading shard {i}: {e}") continue if (i + 1) % 10 == 0: print(f" [Data] {i+1}/{n_shards} shards, {total_bytes/1e9:.2f} GB text") print(f" [Data] Total: {total_bytes/1e9:.2f} GB text from {n_shards} shards") return texts_path def tokenize_data(texts_path, tokenizer, data_dir, context_len): """Tokenize text into binary shards (uint16 for vocab < 65536). Streams to disk — no OOM on 16GB+ corpora.""" train_path = os.path.join(data_dir, "train.bin") val_path = os.path.join(data_dir, "val.bin") if os.path.exists(train_path) and os.path.exists(val_path): train_tokens = os.path.getsize(train_path) // 2 val_tokens = os.path.getsize(val_path) // 2 print(f" [Data] Tokenized data exists: train={train_tokens:,} val={val_tokens:,}") return train_tokens, val_tokens print(f" [Data] Tokenizing...") tmp_path = os.path.join(data_dir, "tokens_all.bin") total_tokens = 0 t0 = time.time() with open(texts_path, 'r', encoding='utf-8', errors='replace') as f_in, \ open(tmp_path, 'wb') as f_out: chunk_size = 10_000_000 # 10MB chunks total_chars = 0 while True: text = f_in.read(chunk_size) if not text: break ids = tokenizer.encode(text) arr = np.array(ids, dtype=np.uint16) f_out.write(arr.tobytes()) total_tokens += len(ids) total_chars += len(text) if total_chars % 100_000_000 < chunk_size: elapsed = time.time() - t0 rate = total_chars / elapsed / 1e6 print(f" [Data] {total_chars/1e9:.2f} GB text → {total_tokens:,} tokens " f"({rate:.1f} MB/s, {elapsed:.0f}s)") elapsed = time.time() - t0 print(f" [Data] Tokenized {total_chars/1e9:.2f} GB → {total_tokens:,} tokens in {elapsed:.0f}s") # Split 95/5 train/val — stream from memmap to avoid loading all into RAM split = int(total_tokens * 0.95) print(f" [Data] Splitting: train={split:,} val={total_tokens - split:,}") all_data = np.memmap(tmp_path, dtype=np.uint16, mode='r') # Write train split in chunks chunk = 50_000_000 # 50M tokens per chunk with open(train_path, 'wb') as f: for start in range(0, split, chunk): end = min(start + chunk, split) f.write(all_data[start:end].tobytes()) # Write val split with open(val_path, 'wb') as f: for start in range(split, total_tokens, chunk): end = min(start + chunk, total_tokens) f.write(all_data[start:end].tobytes()) del all_data os.remove(tmp_path) train_tokens = split val_tokens = total_tokens - split print(f" [Data] train: {train_tokens:,} tokens ({train_tokens*2/1e9:.2f} GB)") print(f" [Data] val: {val_tokens:,} tokens ({val_tokens*2/1e9:.2f} GB)") return train_tokens, val_tokens class DataLoader: """Simple random-chunk dataloader from mmap'd binary file.""" def __init__(self, path, context_len, batch_size, device): self.data = np.memmap(path, dtype=np.uint16, mode='r') self.context_len = context_len self.batch_size = batch_size self.device = device self.n_tokens = len(self.data) def get_batch(self): T = self.context_len B = self.batch_size ix = torch.randint(0, self.n_tokens - T - 1, (B,)) x = torch.stack([torch.from_numpy(self.data[i:i+T].astype(np.int64)) for i in ix]) y = torch.stack([torch.from_numpy(self.data[i+1:i+T+1].astype(np.int64)) for i in ix]) return x.to(self.device), y.to(self.device) # ───────────────────────────────────────────────────────────────────────────── # Training # ───────────────────────────────────────────────────────────────────────────── def get_lr(step, warmup_steps, total_steps, max_lr, min_lr=0.0): """WSD schedule: warmup → stable → linear decay.""" if step < warmup_steps: return max_lr * (step + 1) / warmup_steps decay_start = total_steps // 2 if step < decay_start: return max_lr # Linear decay progress = (step - decay_start) / (total_steps - decay_start) return max_lr * (1.0 - progress) + min_lr * progress @torch.no_grad() def evaluate(model, val_loader, n_batches=50): """Evaluate val loss. Returns average loss.""" model.eval() losses = [] for _ in range(n_batches): x, y = val_loader.get_batch() with autocast('cuda', dtype=torch.bfloat16): _, loss = model(x, y) losses.append(loss.item()) model.train() return sum(losses) / len(losses) def save_checkpoint(model, optimizer, step, train_loss, val_loss, config, path): """Save PyTorch checkpoint.""" torch.save({ 'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'step': step, 'train_loss': train_loss, 'val_loss': val_loss, 'config': config, }, path) def save_c_weights(model, tokenizer, config, path): """Save weights in C-compatible binary format for resonance-bpe.c.""" with open(path, 'wb') as f: # Header: magic + config f.write(struct.pack('= 2: decay_params.append(p) else: no_decay_params.append(p) optimizer = torch.optim.AdamW([ {'params': decay_params, 'weight_decay': args.weight_decay}, {'params': no_decay_params, 'weight_decay': 0.0}, ], lr=args.lr, betas=(0.9, 0.95), eps=1e-8) scaler = GradScaler('cuda') # Step 6: Data loaders (micro-batch for gradient accumulation) T = config['context_len'] micro_B = args.micro_batch // T # sequences per micro-batch grad_accum = args.batch_size // args.micro_batch print(f"\n[6] DataLoader: effective_batch={args.batch_size} tokens " f"({grad_accum} x {args.micro_batch} micro), {micro_B} seq x {T} ctx") train_loader = DataLoader(os.path.join(data_dir, "train.bin"), T, micro_B, device) val_loader = DataLoader(os.path.join(data_dir, "val.bin"), T, micro_B, device) total_steps = n_train // args.batch_size print(f" Total steps: {total_steps:,}") # Step 7: Train loop print(f"\n[7] Training resonance-200m...") print(f" {'step':>8} | {'train_loss':>10} | {'val_loss':>10} | {'lr':>10} | {'tok/s':>10} | {'time':>8}") print(" " + "-" * 75) best_val_loss = float('inf') running_loss = 0.0 t0 = time.time() tokens_seen = 0 model.train() for step in range(total_steps): # LR schedule lr = get_lr(step, args.warmup_steps, total_steps, args.lr) for pg in optimizer.param_groups: pg['lr'] = lr # Gradient accumulation: grad_accum micro-batches per optimizer step optimizer.zero_grad(set_to_none=True) step_loss = 0.0 for micro_step in range(grad_accum): x, y = train_loader.get_batch() with autocast('cuda', dtype=torch.bfloat16): _, loss = model(x, y) loss = loss / grad_accum scaler.scale(loss).backward() step_loss += loss.item() * grad_accum scaler.unscale_(optimizer) nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) scaler.step(optimizer) scaler.update() train_loss = step_loss / grad_accum running_loss += train_loss tokens_seen += args.batch_size # Log every N steps if (step + 1) % args.log_every == 0: avg_train = running_loss / args.log_every running_loss = 0.0 elapsed = time.time() - t0 tok_per_sec = tokens_seen / elapsed # Val loss val_loss = evaluate(model, val_loader, n_batches=args.val_batches) print(f" {step+1:>8} | {avg_train:>10.4f} | {val_loss:>10.4f} | " f"{lr:>10.2e} | {tok_per_sec/1000:>8.1f}k | {elapsed:>7.0f}s") # Save best if val_loss < best_val_loss: best_val_loss = val_loss save_checkpoint(model, optimizer, step, avg_train, val_loss, config, os.path.join(args.save_dir, "best.pt")) # Checkpoint every N steps if (step + 1) % args.save_every == 0: save_checkpoint(model, optimizer, step, train_loss, val_loss if 'val_loss' in dir() else 0, config, os.path.join(args.save_dir, f"step_{step+1}.pt")) save_c_weights(model, tokenizer, config, os.path.join(args.save_dir, f"resonance_200m_step{step+1}.bin")) # Gate monitoring every N steps if (step + 1) % (args.log_every * 5) == 0: gates = [] for block in model._orig_mod.blocks if hasattr(model, '_orig_mod') else model.blocks: g = torch.sigmoid(block.gate).detach().cpu().numpy() gates.append(g.mean()) gate_str = " ".join(f"{g:.2f}" for g in gates) print(f" [gates] {gate_str}") # Final save elapsed = time.time() - t0 print(f"\n Training complete. {elapsed/3600:.1f} hours, {tokens_seen:,} tokens") save_checkpoint(model, optimizer, total_steps, train_loss, best_val_loss, config, os.path.join(args.save_dir, "final.pt")) save_c_weights(model, tokenizer, config, os.path.join(args.save_dir, "resonance_200m_final.bin")) # Re-save tokenizer (paranoia) tokenizer.save_copies(os.path.join(args.save_dir, "tokenizer.bin"), n=3) print(f"\n Best val loss: {best_val_loss:.4f}") print(f" resonance is unbreakable.") if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--data-dir', type=str, default='data/') parser.add_argument('--save-dir', type=str, default='checkpoints/') parser.add_argument('--n-shards', type=int, default=65, help='Number of ClimbMix shards to download (~65 for ~4B tokens)') parser.add_argument('--vocab-size', type=int, default=None, help='Override vocab size (default: 16384)') parser.add_argument('--batch-size', type=int, default=131072, help='Effective batch size in tokens (default: 131072)') parser.add_argument('--micro-batch', type=int, default=65536, help='Micro-batch size in tokens for grad accum (default: 65536)') parser.add_argument('--lr', type=float, default=3e-4) parser.add_argument('--warmup-steps', type=int, default=800) parser.add_argument('--weight-decay', type=float, default=0.1) parser.add_argument('--grad-clip', type=float, default=1.0) parser.add_argument('--log-every', type=int, default=100) parser.add_argument('--save-every', type=int, default=2000) parser.add_argument('--val-batches', type=int, default=50) args = parser.parse_args() train(args)