import torch
import torch.nn as nn

from torch.utils.data import Dataset, DataLoader, random_split
from datasets import load_dataset, concatenate_datasets
from tokenizers import Tokenizer
from tokenizers.models import BPE, WordLevel
from tokenizers.trainers import BpeTrainer, WordLevelTrainer
from tokenizers.pre_tokenizers import ByteLevel, Whitespace
from tokenizers.processors import TemplateProcessing
from tokenizers import decoders
from torch.cuda.amp import autocast, GradScaler

import time

from torch.utils.tensorboard import SummaryWriter
from itertools import islice

from config import get_weights_file_path, get_config

from tqdm import tqdm

from pathlib import Path

import warnings
# causal_mask is used throughout this file but was never imported; it is
# assumed to be defined in dataset.py alongside BilingualDataset.
from dataset import BilingualDataset, causal_mask
from model import build_gpt

# Seeded generator for reproducible shuffling/splitting.
g = torch.Generator()
g.manual_seed(23)

def greedy_decode(model, text, mask, tokenizer, max_len, device):
    """Greedy decoding: always pick the most probable next token."""
    sos_idx = tokenizer.token_to_id('<s>')
    eos_idx = tokenizer.token_to_id('</s>')

    # Start from <s> and append one token per step until max_len or </s>.
    decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(text).to(device)
    while decoder_input.size(1) < max_len:
        # Causal mask so each position attends only to earlier positions.
        decoder_mask = causal_mask(decoder_input.size(1)).type_as(mask).to(device)

        out = model.decode(decoder_input, decoder_mask)

        # Project only the last position and take the argmax token.
        prob = model.project(out[:, -1])
        _, next_word = torch.max(prob, dim=1)

        decoder_input = torch.cat(
            [decoder_input, torch.empty(1, 1).type_as(text).fill_(next_word.item()).to(device)],
            dim=1,
        )

        if next_word == eos_idx:
            break

    return decoder_input.squeeze(0)


def generate_text(
    model, text, mask, tokenizer, max_len, device,
    temperature=0.7, top_k=50
):
    """Autoregressive sampling with temperature and top-k, streamed to stdout."""
    eos_idx = tokenizer.token_to_id('</s>')

    decoder_input = text.to(device)
    if decoder_input.dim() == 1:
        decoder_input = decoder_input.unsqueeze(0)

    # Echo the prompt, then stream each sampled token as it is generated.
    prompt_text = tokenizer.decode(text.squeeze(0).tolist())
    print(prompt_text, end="", flush=True)

    while decoder_input.size(1) < max_len - 3:
        decoder_mask = causal_mask(decoder_input.size(1)).type_as(mask).to(device)

        out = model.decode(decoder_input, decoder_mask)
        logits = model.project(out[:, -1])

        # Temperature scaling, then sample from the renormalised top-k distribution.
        logits = logits / temperature
        top_k_logits, top_k_indices = torch.topk(logits, top_k)
        probs = torch.softmax(top_k_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        next_token = top_k_indices.gather(-1, next_token)

        word = tokenizer.decode([next_token.item()])
        print(word, end="", flush=True)

        decoder_input = torch.cat([decoder_input, next_token], dim=1)

        if next_token.item() == eos_idx:
            break

    print()
    return decoder_input.squeeze(0)


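# Minimal usage sketch for generate_text, assuming a trained `model` and built
# `tokenizer`; the helper name and no-grad wrapper are illustrative, not part of
# this repo's verified API. The mask argument only supplies a dtype inside
# generate_text, so a causal mask of the prompt length suffices.
def sample_from_prompt(model, tokenizer, prompt, max_len, device):
    # Encode without the <s>/</s> template so generation continues the prompt.
    ids = [tokenizer.token_to_id('<s>')] + tokenizer.encode(prompt, add_special_tokens=False).ids
    text = torch.tensor(ids, dtype=torch.int64, device=device)
    mask = causal_mask(text.size(0))
    model.eval()
    with torch.no_grad():
        return generate_text(model, text, mask, tokenizer, max_len, device)

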
def generate_text_(model, text, m, tokenizer, max_len, device, temperature=0.7, top_k=50):
    """Legacy variant of generate_text that takes a raw prompt string."""
    sos_idx = tokenizer.token_to_id('<s>')
    eos_idx = tokenizer.token_to_id('</s>')
    pad_idx = tokenizer.token_to_id('<pad>')

    # Prepend <s> and truncate so at least one token can still be generated.
    input_tokens = [sos_idx] + tokenizer.encode(text).ids
    input_tokens = input_tokens[:max_len - 1]

    decoder_input = torch.tensor(input_tokens, device=device).unsqueeze(0)

    for _ in range(max_len - len(input_tokens)):
        decoder_mask = causal_mask(decoder_input.size(1)).to(device)

        out = model.decode(decoder_input, decoder_mask)
        logits = model.project(out[:, -1])

        logits = logits / temperature
        top_k_logits, top_k_indices = torch.topk(logits, top_k)
        probs = torch.softmax(top_k_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        next_token = top_k_indices.gather(-1, next_token)

        word = tokenizer.decode([next_token.item()])
        print(word, end="", flush=True)

        # next_token is already shaped (1, 1), so concatenate it directly.
        decoder_input = torch.cat([decoder_input, next_token], dim=1)

        if next_token.item() == eos_idx:
            break

    return decoder_input.squeeze(0)


def run_validation(model, validation_ds, tokenizer, max_len, device, print_msg, global_state, writer, num_examples=2):
    """Sample from the model on a few validation examples and print the results."""
    model.eval()

    count = 0
    with torch.no_grad():
        for batch in validation_ds:
            count += 1
            input_tokens = batch['input']

            print_msg(f"TOKENIZED INPUT : {input_tokens}")

            # Accepts either a list of ids or an existing tensor.
            input_tokens = torch.as_tensor(input_tokens)

            # The mask only supplies a dtype via type_as() inside generate_text.
            mask = causal_mask(input_tokens.size(0))

            model_output = generate_text(model, input_tokens, mask, tokenizer, max_len, device)

            print_msg("Model Output Embedding : ")
            print_msg(str(model_output.tolist()))

            model_out_text = tokenizer.decode(model_output.detach().cpu().tolist())

            print_msg(f'SOURCE : {input_tokens}')
            print_msg(f'PREDICTED : {model_out_text}')

            if count == num_examples:
                break

def get_all_sentences(ds):
    """Yield the raw text field from every dataset record."""
    for item in ds:
        yield item['text']


def get_or_build_tokenizer_(config, ds):
    """Legacy word-level tokenizer builder (superseded by the BPE version below)."""
    tokenizer_path = Path(config['tokenizer_file'])
    if not tokenizer_path.exists():
        tokenizer = Tokenizer(WordLevel(unk_token="<unk>"))
        tokenizer.pre_tokenizer = Whitespace()
        trainer = WordLevelTrainer(
            special_tokens=["<s>", "</s>", "<pad>", "<unk>", "<mask>", "<user>", "<ai>",
                            "<search_start>", "<search_end>", "<think>", "</think>"],
            min_frequency=2,
        )
        tokenizer.train_from_iterator(get_all_sentences(ds), trainer=trainer)
        tokenizer.save(str(tokenizer_path))
    else:
        tokenizer = Tokenizer.from_file(str(tokenizer_path))
    return tokenizer


def get_or_build_tokenizer(config, ds):
    """Load the byte-level BPE tokenizer from disk, or train and save it."""
    tokenizer_path = Path(config['tokenizer_file'])

    if not tokenizer_path.exists():
        tokenizer = Tokenizer(BPE(unk_token="<unk>"))

        # Byte-level pre-tokenization with the matching byte-level decoder.
        tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=True)
        tokenizer.decoder = decoders.ByteLevel()

        # Wrap every sequence in <s> ... </s>. The hard-coded ids must match
        # the order of the trainer's special_tokens list (<s> = 0, </s> = 1).
        tokenizer.post_processor = TemplateProcessing(
            single="<s> $A </s>",
            pair="<s> $A </s> <s> $B </s>",
            special_tokens=[
                ("<s>", 0),
                ("</s>", 1),
            ],
        )

        trainer = BpeTrainer(
            vocab_size=30000,
            min_frequency=2,
            special_tokens=["<s>", "</s>", "<pad>", "<unk>", "<mask>", "<user>", "<ai>",
                            "<search_start>", "<search_end>", "<think>", "</think>"],
        )

        tokenizer.train_from_iterator(get_all_sentences(ds), trainer=trainer)

        tokenizer.save(str(tokenizer_path))
    else:
        tokenizer = Tokenizer.from_file(str(tokenizer_path))

    return tokenizer
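
# Quick sanity check (sketch, assuming the tokenizer file has been built):
#   tok = get_or_build_tokenizer(config, ds)
#   enc = tok.encode("hello world")
#   print(enc.tokens)           # ['<s>', ..., '</s>'] via TemplateProcessing
#   print(tok.decode(enc.ids))  # round-trips back to "hello world"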


def get_ds(config):
    """Build streaming train/validation dataloaders and the tokenizer."""
    ds_raw = load_dataset("json", data_files='./dataset/openwebtext_500k_docs.jsonl', split="train", streaming=True)
    ds_test = load_dataset("json", data_files='./dataset/openwebtext_test.jsonl', split="train", streaming=True)

    tokenizer = get_or_build_tokenizer(config, ds_raw)

    train_ds = BilingualDataset(ds_raw, tokenizer, config['seq_len'])
    val_ds = BilingualDataset(ds_test, tokenizer, config['seq_len'])
    train_dataloader = DataLoader(train_ds, num_workers=6, prefetch_factor=2, pin_memory=True, batch_size=config['batch_size'])
    val_dataloader = DataLoader(val_ds, batch_size=1)

    return train_dataloader, val_dataloader, tokenizer


def get_model(config, vocab_size):
    """Construct the GPT model from config hyperparameters."""
    model = build_gpt(vocab_size, config['seq_len'], config['d_model'], config['N'], config['h'], config['d_ff'], config['dropout'])
    return model
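
# Sketch: report model size after construction (plain PyTorch, no repo-specific API):
#   n_params = sum(p.numel() for p in model.parameters())
#   print(f"{n_params / 1e6:.1f}M parameters")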


def validate_model(val_dataloader, model, device, loss_fn, vocab_size):
    """Compute and print the average cross-entropy loss over the validation set."""
    total_loss = 0.0
    model.eval()
    i = 0
    with torch.no_grad():
        for batch in val_dataloader:
            input_tokens = batch['input'].to(device, non_blocking=True)
            label = batch['label'].to(device, non_blocking=True)
            decoder_output = model.decode(input_tokens)
            project_output = model.project(decoder_output)
            total_loss += loss_fn(
                project_output.view(-1, vocab_size),
                label.view(-1)
            ).item()
            i += 1
    # exp(average loss) gives perplexity, if a more interpretable number is wanted.
    print(f"Validation loss : {total_loss / i:.4f}")
    # Restore training mode so a mid-epoch call does not leave the model in eval.
    model.train()


def train_model(config):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device : {device}")

    # Allow TF32 matmuls/convolutions on Ampere+ GPUs for extra throughput.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    Path(config['model_folder']).mkdir(parents=True, exist_ok=True)

    train_dataloader, val_dataloader, tokenizer = get_ds(config)
    print(tokenizer.get_vocab_size())
    model = get_model(config, tokenizer.get_vocab_size()).to(device)

    writer = SummaryWriter(config['experiment_name'])

    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps=1e-9)
    scaler = GradScaler()

    initial_epoch = 0
    global_step = 0
    tqdm_state = {'n': 0}

    model_filename = None
    if config['preload']:
        model_filename = get_weights_file_path(config, config['preload'])
        print(f"Preloading Model {model_filename}")
        state = torch.load(model_filename)
        model.load_state_dict(state['model_state_dict'])
        optimizer.load_state_dict(state['optimizer_state_dict'])
        # 'mid-' checkpoints were saved mid-epoch, so resume the same epoch.
        initial_epoch = state['epoch'] if 'mid-' in model_filename else state['epoch'] + 1
        global_step = state['global_step']
        tqdm_state = state['tqdm_state'] if 'mid-' in model_filename else {'n': 0}
    else:
        print("No model to preload. Starting from scratch.")

    loss_fn = nn.CrossEntropyLoss(
        ignore_index=tokenizer.token_to_id('<pad>'),
        label_smoothing=0.05
    ).to(device)
    e = 0

    try:
        for epoch in range(initial_epoch, config['num_epochs']):
            model.train()
            # Skip batches already consumed when resuming a mid-epoch checkpoint.
            batch_iterator = tqdm(islice(train_dataloader, tqdm_state['n'], None), desc=f'Processing epoch {epoch:02d}', initial=tqdm_state['n'], total=140000)
            e = epoch
            if model_filename and 'mid-' in model_filename and 'elapsed' in tqdm_state:
                batch_iterator.start_t = time.time() - tqdm_state['elapsed']

            for batch in batch_iterator:
                input_tokens = batch['input'].to(device, non_blocking=True)
                label = batch['label'].to(device, non_blocking=True)

                optimizer.zero_grad(set_to_none=True)

                # Mixed-precision forward pass.
                with autocast(dtype=torch.float16):
                    decoder_output = model.decode(input_tokens)
                    project_output = model.project(decoder_output)

                    loss = loss_fn(
                        project_output.view(-1, tokenizer.get_vocab_size()),
                        label.view(-1)
                    )
                if global_step % 10 == 0:
                    batch_iterator.set_postfix({"loss": f"{loss.item():6.3f}"})
                    writer.add_scalar("train loss", loss.item(), global_step)
                    writer.flush()
                if global_step % 10000 == 0 and global_step != 0:
                    validate_model(val_dataloader, model, device, loss_fn, tokenizer.get_vocab_size())

                # Scaled backward pass and optimizer step for float16 stability.
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

                global_step += 1
                tqdm_state = {'n': batch_iterator.n, 'elapsed': batch_iterator.format_dict["elapsed"]}

            # Epoch finished: the next epoch starts from batch 0.
            tqdm_state = {'n': 0}

            model_filename = get_weights_file_path(config, f'{epoch:02d}')
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'global_step': global_step,
                'tqdm_state': tqdm_state
            }, model_filename)
            validate_model(val_dataloader, model, device, loss_fn, tokenizer.get_vocab_size())
    except KeyboardInterrupt:
        print("Stopping training...")
        model_filename = get_weights_file_path(config, f'mid-{e:02d}{input("Checkpoint Name: ")}')
        torch.save({
            'epoch': e,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'global_step': global_step,
            'tqdm_state': tqdm_state
        }, model_filename)
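
# Resuming (sketch): point config['preload'] at a saved tag — an epoch tag such
# as "03", or a "mid-..." tag written by the KeyboardInterrupt handler above —
# and train_model restores the epoch, global step, and tqdm position.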

if __name__ == "__main__":
    warnings.filterwarnings('ignore')
    config = get_config("./openweb.config.json")
    train_model(config)