import os

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

import wandb
from tqdm import tqdm
from accelerate import Accelerator
from transformers import BloomForCausalLM, BloomTokenizerFast
from gated_state_spaces_pytorch import GatedStateSpacesLM
from gated_state_spaces_pytorch.autoregressive_wrapper import AutoregressiveWrapper

from pile_hf import ThePile, ThePileTokenized

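# Train a Gated State Spaces language model on The Pile, reusing the frozen
# word embeddings of bigscience/bloomz-1b7 for both the input embedding and
# the output projection.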
def main():
    accelerator = Accelerator(
        log_with="wandb",
        gradient_accumulation_steps=8192,
    )
    accelerator.init_trackers("gated-state-space")

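    # Cache BLOOM's word-embedding weights locally so the full checkpoint
    # only has to be downloaded and loaded once.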
    emb_fn = "emb.pt"
    model_name = "bigscience/bloomz-1b7"
    if not os.path.isfile(emb_fn):
        bloom = BloomForCausalLM.from_pretrained(model_name)
        wte = bloom.transformer.word_embeddings.state_dict()
        torch.save(wte, emb_fn)
    else:
        wte = torch.load(emb_fn)

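    # The GSS model mirrors BLOOM's geometry: 2048-dim embeddings over a
    # 250880-token vocabulary, with 24 gated state space layers in between.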
    f_emb = 2048
    n_vocab = 250880
    model = AutoregressiveWrapper(
        GatedStateSpacesLM(
            num_tokens=n_vocab,
            dim=f_emb,
            depth=24,
        ),
    )

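    # Input embedding and output projection both reuse the pretrained BLOOM
    # embedding matrix and stay frozen during training.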
    model.net.token_emb.requires_grad_(False)
    model.net.token_emb.load_state_dict(wte)

    to_logits = nn.Linear(f_emb, n_vocab, bias=False)
    to_logits.requires_grad_(False)
    to_logits.load_state_dict(wte)

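    # Output head: a trainable LayerNorm feeding the frozen projection above,
    # then resume from the checkpoint written by the previous run.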
    model.net.to_logits = nn.Sequential(
        nn.LayerNorm(f_emb),
        to_logits,
    )
    model.load_state_dict(torch.load("model3.pt"))
    model = model.to(accelerator.device)

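    # Report parameters and gradients to W&B from the main process only.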
    if accelerator.is_main_process:
        wandb.watch(model)

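    # AdamW with cosine annealing warm restarts: first cycle of 1000 steps,
    # each restart twice as long as the last.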
    optim = AdamW(model.parameters(), lr=1e-4)
    sch = CosineAnnealingWarmRestarts(
        optim,
        T_0=1000,
        T_mult=2,
        eta_min=1e-7,
    )

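    # Stream The Pile through the BLOOM tokenizer into fixed-length sequences
    # of kk tokens (ThePile / ThePileTokenized come from the local pile_hf module).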
    bs = 1
    kk = 2048
    tok: BloomTokenizerFast = BloomTokenizerFast.from_pretrained(model_name)
    dsx = ThePileTokenized(
        ThePile("train"),
        tokenizer=tok,
        max_length=kk,
        repeat_factor=4 / 3,
    )
    dlx = DataLoader(
        dsx,
        batch_size=bs,
        num_workers=12,
    )

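    # Prepare everything with Accelerate first, then wrap the *prepared*
    # dataloader in tqdm so the progress bar iterates the right object.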
    model = accelerator.prepare(model)
    optim, dlx, sch = accelerator.prepare(optim, dlx, sch)

    prog = tqdm(dlx, disable=not accelerator.is_main_process)

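    # Training loop: Accelerate handles the 8192-step gradient accumulation;
    # gradients are clipped on sync steps and the LR scheduler only advances
    # when the optimizer actually stepped.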
    optim.zero_grad()
    for i, batch in enumerate(prog):
        batch = batch.to(accelerator.device)
        with accelerator.accumulate(model):
            with accelerator.autocast():
                los = model(batch)
            accelerator.backward(los)
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(model.parameters(), 1.0)
            optim.step()
            optim.zero_grad()
            if not accelerator.optimizer_step_was_skipped:
                sch.step()

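        # Every 1000 batches: sample 512 tokens, log the text to W&B as HTML,
        # and write a checkpoint.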
        if i % 1000 == 0:
            unwrapped_model = accelerator.unwrap_model(model)
            b, n = 1, 512
            init = torch.tensor([[2]] * b).to(accelerator.device)
            prd = unwrapped_model.generate(init, n)
            prd = [tok.decode(p) for p in prd]
            try:
                accelerator.log(
                    dict(
                        text=wandb.Html(
                            "<hr>".join(p.replace("\n", "<br>") for p in prd)
                        )
                    ),
                    step=i,
                )
            except Exception as ex:
                accelerator.print("Failed to log to W&B...", ex)
            sd = unwrapped_model.state_dict()
            accelerator.save(sd, "model4.pt")

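        # Every 10 batches: log loss and learning rate.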
        if i % 10 == 0:
            accelerator.log(
                dict(
                    loss=los.item(),
                    lr=optim.param_groups[0]["lr"],
                ),
                step=i,
            )
            prog.set_postfix(loss=los.item())

if __name__ == "__main__":
    main()