import os
import time
import math
import pickle
from contextlib import nullcontext
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
# from_pretrained is assumed to be provided by mamba_lm alongside MambaLM;
# it is needed for the 'state-spaces/*' init path below.
from mamba_lm import MambaLM, MambaLMConfig, from_pretrained
import pyarrow.parquet as pq
import random
from torch.utils.data import Dataset, DataLoader
import glob
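# Trains a Mamba language model (MambaLM) on tokenized chess games stored as
# parquet shards. Every simple (int/float/bool/str) global below is a default
# that configurator.py (executed further down) can override, in the usual
# nanoGPT style, from a config file and/or --key=value flags. A hypothetical
# invocation (the config filename here is only an example):
#   $ python train.py config/chess_mamba.py --batch_size=16 --compile=False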
# I/O and evaluation
out_dir = 'out'
eval_interval = 2000
log_interval = 1
eval_iters = 5
eval_only = False  # if True, exit right after the first eval
always_save_checkpoint = True
init_from = 'resume'  # 'scratch', 'resume', 'anneal', or a 'state-spaces/*' pretrained name

# wandb logging
wandb_log = False
wandb_project = 'mamba'
wandb_run_name = 'mamba_run'

# data
dataset = 'chess'
gradient_accumulation_steps = 5 * 8  # micro-steps accumulated per optimizer step
batch_size = 12  # games per micro-step
base_batch_size = batch_size
effective_batch_size = batch_size
max_seq_len = 1024  # games longer than this are truncated in get_batch
train_file_update_interval = 7  # batches between hops to a random train parquet shard

# model
n_layer = 12
d_model = 768
dt_rank = 'auto'
d_state = 16
expand_factor = 2
bias = False
conv_bias = True
pscan = True  # use the parallel scan implementation
vocab_size = 32000  # overridden by meta.pkl / move_num_in_gamestate below
move_num_in_gamestate = True

# optimizer (AdamW)
learning_rate = 6e-4
max_iters = 600000
weight_decay = 1e-1
beta1 = 0.9
beta2 = 0.95
grad_clip = 1.0  # clip gradients at this norm; 0.0 disables clipping
auto_clip = False  # if True, set grad_clip from a running percentile of gradient norms
grad_clip_start_size = 100  # norms to collect before auto-clipping kicks in
grad_clip_max_size = 500  # keep at most this many recent norms
grad_clip_percentile = 10
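# When auto_clip is enabled, grad_clip is initialised to 0 (no clipping) on a
# fresh run; once at least grad_clip_start_size gradient norms have been
# recorded, grad_clip is reset at every eval to the grad_clip_percentile-th
# percentile of the most recent norms (at most grad_clip_max_size are kept).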
# learning-rate schedule (see get_lr below)
decay_lr = True
warmup_iters = 2000
lr_decay_iters = 600000
min_lr = 6e-5

# DDP settings
backend = 'nccl'

# system
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = 'bfloat16' if torch.cuda.is_bf16_supported() else 'float32'
compile = False  # torch.compile the model (PyTorch 2.0+)
# collect all simple config values so they can be overridden and logged
config_keys = [k for k, v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open('configurator.py').read())  # overrides from the command line or a config file
config = {k: globals()[k] for k in config_keys}  # snapshot, saved in checkpoints and wandb

# annealing (init_from='anneal'): resume from anneal_checkpoint and decay the LR
anneal_checkpoint = 'anneal/ckpt.pt'
anneal_dir = os.path.join(out_dir, 'anneal/')
anneal_start_iters = None  # set when the anneal run starts/resumes
anneal_decay_iters = None  # set when the anneal run starts/resumes
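# An anneal run (init_from='anneal') resumes from out_dir/anneal/ckpt.pt and
# linearly decays the learning rate from learning_rate down to min_lr over
# anneal_decay_iters iterations (defaulting to iter_num / 7 at the point the
# anneal starts). It also divides weight_decay by 10, lowers beta2 to
# beta2**1.5, enables auto_clip with a tighter percentile, and writes its
# checkpoints to the anneal/ subdirectory.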
mamba_config = MambaLMConfig(
    d_model=d_model,
    n_layers=n_layer,
    dt_rank=dt_rank,
    d_state=d_state,
    expand_factor=expand_factor,
    bias=bias,
    conv_bias=conv_bias,
    pscan=pscan,
    vocab_size=vocab_size
)
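# DDP setup. Under torchrun the RANK / LOCAL_RANK / WORLD_SIZE environment
# variables are set automatically; a single-GPU run just works without them.
# Illustrative multi-GPU launch (assuming the standard torchrun CLI):
#   $ torchrun --standalone --nproc_per_node=4 <this_script.py>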
ddp = int(os.environ.get('RANK', -1)) != -1  # is this a distributed (DDP) run?
if ddp:
    init_process_group(backend=backend)
    ddp_rank = int(os.environ['RANK'])
    ddp_local_rank = int(os.environ['LOCAL_RANK'])
    ddp_world_size = int(os.environ['WORLD_SIZE'])
    device = f'cuda:{ddp_local_rank}'
    torch.cuda.set_device(device)
    master_process = ddp_rank == 0  # rank 0 handles logging and checkpointing
    seed_offset = ddp_rank  # each rank gets a different seed
    # split the accumulation steps across ranks so the global batch stays the same
    assert gradient_accumulation_steps % ddp_world_size == 0
    gradient_accumulation_steps //= ddp_world_size
else:
    master_process = True
    seed_offset = 0
    ddp_world_size = 1
tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * max_seq_len
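# With the defaults (40 accumulation steps, world size 1, batch_size 12,
# max_seq_len 1024) that is an upper bound of 40 * 12 * 1024 = 491,520 tokens
# per optimizer step; batches are only padded to the longest game they contain,
# so the real count is usually lower.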
if master_process:
    os.makedirs(out_dir, exist_ok=True)
    os.makedirs(anneal_dir, exist_ok=True)
torch.manual_seed(1337 + seed_offset)
torch.backends.cuda.matmul.allow_tf32 = True  # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True  # allow tf32 on cudnn
device_type = 'cuda' if 'cuda' in device else 'cpu'
# 'float16' is included in case configurator overrides dtype to it
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

# load every training shard into memory; each row's 'tokenized' column holds a
# list of token ids for one game
data_dir = os.path.join('data', dataset)
current_train_file_index = 0
train_files = glob.glob(os.path.join(data_dir, 'train*.parquet'))
train_datasets = []
for f in train_files:
    df = pq.read_table(f).to_pandas()
    df = df[df['tokenized'].apply(len) >= 8]  # drop games shorter than 8 tokens
    train_datasets.append(df)
truncated_games_count = 0
total_games_count = 0
games_seen = 0
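# get_batch samples `batch_size` games from the currently selected shard, pads
# them with token id 0 up to the longest game in the batch (capped at
# max_seq_len), truncates longer games (keeping the start, the end, or a random
# window), and every train_file_update_interval batches hops to a random shard.
# Note that the training loss below does not mask the padded positions.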
def get_batch(split):
    global truncated_games_count, total_games_count, current_train_file_index

    # only train shards exist; estimate_loss samples 'train' and reports it as 'val'
    df = train_datasets[current_train_file_index] if split == 'train' else None
    sample_df = df.sample(batch_size)
    games = sample_df['tokenized'].tolist()

    # pad only up to the longest game in this batch (capped at max_seq_len)
    max_length_in_batch = min(max(len(game) for game in games), max_seq_len)
    sequences = torch.zeros((batch_size, max_length_in_batch), dtype=torch.int64)

    for i, game in enumerate(games):
        total_games_count += 1

        if len(game) > max_seq_len:
            truncated_games_count += 1

            truncation_choice = random.choice(['beginning', 'end', 'end2', 'random'])
            if truncation_choice == 'beginning':
                # drop the beginning, keep the end of the game
                truncated_game = game[-max_seq_len:]
            elif truncation_choice.startswith('end'):
                # drop the end, keep the beginning (picked twice as often)
                truncated_game = game[:max_seq_len]
            else:
                # keep a random window of the game
                start_idx = random.randint(0, len(game) - max_seq_len)
                truncated_game = game[start_idx:start_idx + max_seq_len]
            sequences[i, :len(truncated_game)] = torch.tensor(truncated_game, dtype=torch.int64)

            if truncated_games_count > 0 and truncated_games_count % 50 == 0:
                truncated_percentage = (truncated_games_count / total_games_count) * 100
                print(f"Percentage of truncated games: {truncated_percentage:.2f}%\t\t({truncated_games_count}/{total_games_count})")
        else:
            sequences[i, :len(game)] = torch.tensor(game, dtype=torch.int64)

    # periodically hop to a random training shard
    if (total_games_count // batch_size) % train_file_update_interval == 0:
        current_train_file_index = random.randint(0, len(train_files) - 1)

    if device_type == 'cuda':
        # pin memory and move asynchronously to the GPU
        sequences = sequences.pin_memory().to(device, non_blocking=True)
    else:
        sequences = sequences.to(device)

    return sequences
iter_num = 0
best_val_loss = 1e9
grad_norm_history = []  # recorded gradient norms (used by auto_clip and logging)

# derive the vocab size: either a fixed 28-token vocab (no move numbers in the
# game state) or the value recorded in the dataset's meta.pkl
meta_path = os.path.join(data_dir, 'meta.pkl')
meta_vocab_size = None
if not move_num_in_gamestate:
    meta_vocab_size = 28
elif os.path.exists(meta_path):
    with open(meta_path, 'rb') as f:
        meta = pickle.load(f)
    meta_vocab_size = meta['vocab_size']
    print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})")

if init_from == 'scratch':
    print("Initializing a new Mamba model from scratch")
    if meta_vocab_size is None:
        print(f"defaulting to vocab_size of {vocab_size}")
    else:
        mamba_config.vocab_size = meta_vocab_size
    model = MambaLM(mamba_config)
    if auto_clip:
        grad_clip = 0
        config['grad_clip'] = 0
        grad_norm_history = []
elif init_from == 'resume' or init_from == 'anneal':
    print(f"Resuming training from {out_dir}")
    if init_from == 'anneal':
        ckpt_path = os.path.join(out_dir, anneal_checkpoint)
    else:
        ckpt_path = os.path.join(out_dir, 'ckpt.pt')
    checkpoint = torch.load(ckpt_path, map_location=device)
    mamba_config = checkpoint['model_args']
    model = MambaLM(mamba_config)
    state_dict = checkpoint['model']
    # strip the prefix that torch.compile adds to parameter names
    unwanted_prefix = '_orig_mod.'
    for k, v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
    if 'effective_batch_size' not in checkpoint['config']:
        print("Checkpoint was saved without `effective_batch_size`, assuming current value (will save with next checkpoint). This is used for correcting `iter_num` when the effective batch size is changed.")
        checkpoint['config']['effective_batch_size'] = effective_batch_size
    # rescale iter_num so that games_seen stays consistent if the effective batch size changed
    iter_num = int(round(checkpoint['iter_num'] * (checkpoint['config']['effective_batch_size'] / effective_batch_size)))
    if 'games_seen' in checkpoint:
        games_seen = checkpoint['games_seen']
    else:
        games_seen = checkpoint['config']['effective_batch_size'] * checkpoint['iter_num']
        checkpoint['games_seen'] = games_seen
        print(f"Checkpoint was saved without `games_seen`, assuming checkpoint's effective batch size * iters (will save with next checkpoint). {games_seen}")
    best_val_loss = checkpoint['best_val_loss']
    print(f"Best val loss: {best_val_loss}")
    if auto_clip:
        grad_clip = checkpoint['config']['grad_clip']
        config['grad_clip'] = grad_clip
    grad_norm_history = checkpoint.get('grad_norm_history', [])
    if init_from == 'anneal':
        print(f"\n\nANNEAL STARTING/RESUMING FROM ITERNUM: {iter_num} ({games_seen} games)\n\n")
        anneal_start_iters = iter_num if 'anneal_start_iters' not in checkpoint else checkpoint['anneal_start_iters']
        anneal_decay_iters = iter_num / 7.0 if 'anneal_decay_iters' not in checkpoint else checkpoint['anneal_decay_iters']
        print(anneal_start_iters)
        print(anneal_decay_iters)
        if 'anneal_start_iters' not in checkpoint:
            # fresh anneal: reset gradient-norm tracking before re-enabling auto_clip
            grad_clip = 0
            config['grad_clip'] = 0
            grad_norm_history = []
        print(f"Starting anneal. Resumed from {anneal_checkpoint}, will now decay learning rate for {anneal_decay_iters} iters (until iter_num {anneal_start_iters + anneal_decay_iters}).")
        out_dir = anneal_dir
        weight_decay = weight_decay / 10.0
        beta2 = np.sqrt(beta2) * beta2  # i.e. beta2 ** 1.5
        auto_clip = True
        grad_clip_percentile = 6.3333
elif init_from.startswith('state-spaces'):
    print(f"Initializing from Mamba pre-trained weights: {init_from}")
    model = from_pretrained(init_from)
    mamba_config = model.config
else:
    raise ValueError("Invalid init_from value")

model.to(device)

print(f'Model with {sum(p.numel() for p in model.parameters())} parameters loaded.')

# optimizer and (optional) loss scaler
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2))
scaler = torch.cuda.amp.GradScaler(enabled=dtype == 'float16')
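# The GradScaler is only active when dtype == 'float16'; with the bfloat16 /
# float32 defaults above it is a no-op passthrough, so scaler.scale() and
# scaler.step() behave like a plain backward() and optimizer.step().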
if init_from == 'resume':
    optimizer.load_state_dict(checkpoint['optimizer'])
checkpoint = None  # free checkpoint memory

# optionally compile the model (PyTorch 2.0+)
if compile:
    print("compiling the model... (takes a ~minute)")
    model = torch.compile(model)

# wrap the model in DDP for multi-GPU training
if ddp:
    model = DDP(model, device_ids=[ddp_local_rank])

@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    # there is no held-out split: the 'val' loss is measured on eval_iters
    # freshly sampled training batches
    for split in ['train']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            tokens = get_batch(split)
            logits = model(tokens[:, :-1])

            # targets are the inputs shifted one position to the left
            targets = tokens[:, 1:].reshape(-1)

            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets)
            losses[k] = loss.item()

        out['val'] = losses.mean()
    model.train()
    return out
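# Learning-rate schedule: linear warmup from 0 to learning_rate over
# warmup_iters, then held constant (there is no cosine decay here, even though
# lr_decay_iters/min_lr are defined); in anneal mode the rate instead decays
# linearly from learning_rate to min_lr over anneal_decay_iters. For example,
# with the default warmup_iters=2000, iteration 500 runs at
# 0.25 * learning_rate = 1.5e-4.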
def get_lr(it):
    if init_from == 'anneal':
        # linear decay from learning_rate to min_lr over anneal_decay_iters
        decay_ratio = min(it - anneal_start_iters, anneal_decay_iters) / anneal_decay_iters
        return learning_rate - decay_ratio * (learning_rate - min_lr)

    if it < warmup_iters:
        # linear warmup
        return learning_rate * it / warmup_iters

    # constant after warmup
    return learning_rate

# wandb logging
if wandb_log and master_process:
    import wandb
    wandb.init(project=wandb_project, name=wandb_run_name, config=config)

# training loop state
local_iter_num = 0
last_crossed_multiple = 0
save_every_n_games = 150000  # keep a periodic checkpoint every N games
raw_model = model.module if ddp else model  # unwrap DDP for checkpointing

t0 = time.time()
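# Main training loop: each iteration sets the LR, optionally evaluates and
# checkpoints (master process only), accumulates gradients over
# gradient_accumulation_steps micro-batches, clips, steps the optimizer, and
# logs. games_seen advances by effective_batch_size per iteration.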
while True:
    # set the learning rate for this iteration
    lr = get_lr(iter_num) if decay_lr else learning_rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    # evaluate, adjust auto-clip, and checkpoint every eval_interval iterations
    if iter_num % eval_interval == 0 and master_process:
        losses = estimate_loss()
        print(f"\ngame {games_seen} ({iter_num}, {(iter_num / max_iters)*100.0:.3f}%): 'val' loss {losses['val']:.4f}")

        if auto_clip and len(grad_norm_history) >= grad_clip_start_size:
            grad_clip = np.percentile(grad_norm_history, grad_clip_percentile)
            config['grad_clip'] = grad_clip
            print(f"Auto adjusted grad_clip to {grad_clip}")
        if wandb_log:
            wandb.log({
                "iter": iter_num,
                "games": games_seen,
                "grad_clip": grad_clip,
                "val/loss": losses['val'],
                "lr": lr,
            })
        if losses['val'] < best_val_loss or always_save_checkpoint:
            if iter_num > 0:
                checkpoint = {
                    'model': raw_model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'model_args': mamba_config,
                    'iter_num': iter_num,
                    'games_seen': games_seen,
                    'best_val_loss': min(best_val_loss, losses['val']),
                    'config': config,
                }
                checkpoint['grad_norm_history'] = grad_norm_history
                if init_from == 'anneal':
                    checkpoint['anneal_start_iters'] = anneal_start_iters
                    checkpoint['anneal_decay_iters'] = anneal_decay_iters
                print(f"saving checkpoint to {out_dir}\n")
                torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
                current_nearest_multiple = (games_seen // save_every_n_games) * save_every_n_games
                if losses['val'] < best_val_loss:
                    best_val_loss = losses['val']
                    torch.save(checkpoint, os.path.join(out_dir, f'ckpt_{int(games_seen)}b.pt'))
                elif current_nearest_multiple != last_crossed_multiple:  # crossed a save_every_n_games boundary
                    last_crossed_multiple = current_nearest_multiple
                    torch.save(checkpoint, os.path.join(out_dir, f'ckpt_{int(games_seen)}.pt'))

    if iter_num == 0 and eval_only:
        break
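    # Gradient accumulation: run gradient_accumulation_steps forward/backward
    # passes before a single optimizer step. Under DDP, gradients are only
    # synchronized on the last micro-step, and the loss is scaled by
    # 1/gradient_accumulation_steps so the accumulated gradient matches one
    # large batch.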
    for micro_step in range(gradient_accumulation_steps):
        if ddp:
            # only sync gradients on the last micro-step
            model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)

        sequences = get_batch('train')
        with ctx:
            logits = model(sequences[:, :-1])
            # next-token prediction: targets are the inputs shifted by one
            targets = sequences[:, 1:].reshape(-1)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets)
            loss = loss / gradient_accumulation_steps

        scaler.scale(loss).backward()

    # clip gradients and track their norms (auto_clip records norms even while
    # grad_clip is still 0, using a large dummy threshold)
    if grad_clip != 0.0 or auto_clip:
        scaler.unscale_(optimizer)
        total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip if grad_clip != 0.0 else 999.9)
        grad_norm_history.append(total_norm.item())
        grad_norm_history = grad_norm_history[-grad_clip_max_size:]

    # optimizer step, scaler update, and gradient reset
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad(set_to_none=True)

    # timing and logging
    t1 = time.time()
    dt = t1 - t0
    t0 = t1
    if iter_num % log_interval == 0 and master_process:
        # loss is from the last micro-step; scale back up to approximate the full-batch loss
        lossf = loss.item() * gradient_accumulation_steps
        print(f"game {games_seen} ({iter_num}, {(iter_num / max_iters)*100.0:.3f}%): loss {lossf:.4f}, time {dt*1000:.2f}ms")
        if wandb_log:
            wandb.log({
                "iter": iter_num,
                "games": games_seen,
                "grad_norm": grad_norm_history[-1] if grad_norm_history else 0,
                "train/loss": lossf,
                "lr": lr,
            })
    iter_num += 1
    local_iter_num += 1
    games_seen += effective_batch_size

    # termination: save a final checkpoint once max_iters is reached
    if iter_num > max_iters:
        checkpoint = {
            'model': raw_model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'model_args': mamba_config,
            'iter_num': iter_num,
            'games_seen': games_seen,
            'best_val_loss': best_val_loss,
            'config': config,
        }
        checkpoint['grad_norm_history'] = grad_norm_history
        if init_from == 'anneal':
            checkpoint['anneal_start_iters'] = anneal_start_iters
            checkpoint['anneal_decay_iters'] = anneal_decay_iters
        print(f"Max_iters reached. Saving checkpoint to {out_dir}")
        torch.save(checkpoint, os.path.join(out_dir, 'ckpt_final.pt'))
        break

    # anneal termination: save and stop once the decay window has elapsed
    if init_from == 'anneal' and iter_num >= anneal_start_iters + anneal_decay_iters:
        checkpoint = {
            'model': raw_model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'model_args': mamba_config,
            'iter_num': iter_num,
            'games_seen': games_seen,
            'best_val_loss': best_val_loss,
            'config': config,
        }
        checkpoint['grad_norm_history'] = grad_norm_history
        checkpoint['anneal_start_iters'] = anneal_start_iters
        checkpoint['anneal_decay_iters'] = anneal_decay_iters
        print(f"Anneal complete. Saving checkpoint to {out_dir}")
        torch.save(checkpoint, os.path.join(out_dir, 'anneal_complete.pt'))
        break

if ddp:
    destroy_process_group()