| import torch |
|
|
| from tokenizers import Tokenizer |
|
|
|
|
| from pathlib import Path |
| from config import get_config, get_weights_file_path |
| from train import get_model |
|
|
def generate_text(
    model, text, tokenizer, max_len, device,
    temperature=0.7, top_k=50, max_total_tokens=2000
):
    """Sample a continuation of ``text`` and stream it to stdout token by token.

    At every step the logits of the last position are divided by
    ``temperature``, the ``top_k`` largest are kept, and one token is drawn
    from the softmax over those candidates. The decoded token is printed
    immediately (no newline) and appended to the context. Generation stops
    when ``'</s>'`` is sampled or when the prompt plus the generated tokens
    reach ``max_total_tokens``.

    Args:
        model: Object exposing ``decode(tokens)`` and ``project(hidden)``.
        text: Tensor of prompt token ids; a 1-D tensor is given a leading
            batch dimension of size 1.
        tokenizer: Object exposing ``token_to_id(str)`` and ``decode(ids)``.
        max_len: Maximum number of tokens kept in the context fed to the
            model; older tokens are dropped from the left.
        device: Device the prompt tensor is moved to.
        temperature: Divisor applied to the logits before sampling.
        top_k: Number of highest-scoring candidates to sample from. Clamped
            to the size of the logits' last dimension.
        max_total_tokens: Upper bound on prompt length plus generated tokens.
            The bound is tracked separately from the context length so that
            it still terminates the loop when ``max_len`` is smaller.

    Returns:
        The final context with the leading dimension squeezed out: at most
        the last ``max_len`` tokens of prompt plus generation.
    """
    eos_idx = tokenizer.token_to_id('</s>')

    context = text.to(device)
    if context.dim() == 1:
        context = context.unsqueeze(0)

    # Counted independently of ``context.shape[1]``: the context is clipped
    # to ``max_len`` below, so its length alone cannot bound the loop.
    total_tokens = context.shape[1]

    # Pure inference: no autograd graph is needed for sampling.
    with torch.no_grad():
        while total_tokens < max_total_tokens:
            hidden = model.decode(context)
            # NOTE(review): assumes ``hidden`` is (batch, seq, dim) so that
            # ``[:, -1]`` selects the last position -- confirm against model.
            logits = model.project(hidden[:, -1]) / temperature

            # torch.topk raises if k exceeds the size of the dimension.
            k = min(top_k, logits.size(-1))
            top_logits, top_indices = torch.topk(logits, k)
            probs = torch.softmax(top_logits, dim=-1)
            choice = torch.multinomial(probs, num_samples=1)
            # Map the position within the top-k slice back to a vocab id.
            next_token = top_indices.gather(-1, choice)
            token_id = next_token.item()

            print(tokenizer.decode([token_id]), end="", flush=True)

            context = torch.cat([context, next_token], dim=1)
            total_tokens += 1
            if context.shape[1] > max_len:
                context = context[:, -max_len:]

            if token_id == eos_idx:
                break

    print()
    return context.squeeze(0)
|
|
|
|
|
|
def get_tokenizer(config) -> Tokenizer:
    """Load the tokenizer stored at ``config['tokenizer_file']``.

    Args:
        config: Mapping with a ``'tokenizer_file'`` entry holding the path
            of a serialized tokenizer.

    Returns:
        The tokenizer built by ``Tokenizer.from_file``.

    Raises:
        FileNotFoundError: If the configured path does not exist.
    """
    tokenizer_path = Path(config['tokenizer_file'])
    if not tokenizer_path.exists():
        # One formatted message on purpose: two positional arguments to an
        # OSError subclass are interpreted as (errno, strerror), which
        # garbles the rendered error text.
        raise FileNotFoundError(f"Cant find tokenizer file : {tokenizer_path}")

    print("Loading tokenizer from ", tokenizer_path)
    return Tokenizer.from_file(str(tokenizer_path))
| |
def run_model(config):
    """Run an interactive console session with the model until ``exit``.

    Builds the tokenizer and model from ``config``, loads the checkpoint
    selected by ``config['preload']``, then repeatedly reads a line from
    stdin, generates a continuation with ``generate_text`` (which streams
    the text to stdout) and prints the resulting token tensor.

    Args:
        config: Mapping providing at least ``'tokenizer_file'``,
            ``'preload'`` and ``'seq_len'``, plus whatever ``get_model`` and
            ``get_weights_file_path`` read from it.

    Raises:
        FileNotFoundError: If the tokenizer file or the checkpoint file does
            not exist.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device : {device}")
    tokenizer = get_tokenizer(config)
    model = get_model(config, tokenizer.get_vocab_size()).to(device)
    model_path = get_weights_file_path(config, config['preload'])
    model.eval()

    if not Path(model_path).exists():
        raise FileNotFoundError(f"Model File not found : {model_path}")

    print("Loading Model from : ", model_path)
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
    state = torch.load(model_path, map_location=device)
    model.load_state_dict(state['model_state_dict'])

    seq_len = config['seq_len']
    while True:
        # Reading at the top of the loop guarantees that ``continue`` below
        # asks for fresh input instead of re-checking the same text forever.
        print("You : ", end="")
        input_text = input()
        if input_text == "exit":
            break

        # Drops the last id the tokenizer produced -- presumably a trailing
        # end-of-sequence token added by its post-processor; TODO confirm.
        input_tokens = tokenizer.encode(input_text).ids[:-1]
        if len(input_tokens) > seq_len:
            print(f"exceeding max length of input : {seq_len}")
            continue

        output_tokens = generate_text(
            model, torch.tensor(input_tokens), tokenizer, seq_len, device
        )
        print("MODEL : ", output_tokens)
| |
def generate_response(prompt: str, *, temperature=0.7, top_k=50, max_total_tokens=2000):
    """Yield the decoded text of each sampled token for ``prompt``.

    This is a generator: nothing runs until the first item is requested.
    On first use it reads ``./openweb.config.json``, builds the tokenizer and
    model, and loads the checkpoint selected by ``config['preload']``. It
    then samples with temperature and top-k until ``'</s>'`` is produced or
    the prompt plus the generated tokens reach ``max_total_tokens``.

    Args:
        prompt: Text to continue.
        temperature: Divisor applied to the logits before sampling.
        top_k: Number of highest-scoring candidates to sample from. Clamped
            to the size of the logits' last dimension.
        max_total_tokens: Upper bound on prompt length plus generated tokens.

    Yields:
        The decoded string of each sampled token, including the final
        ``'</s>'`` token if one is sampled.

    Raises:
        ValueError: If the encoded prompt is longer than ``config['seq_len']``.
            Raised on first iteration, not at call time.
    """
    config = get_config("./openweb.config.json")
    print(config)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = get_tokenizer(config)
    eos_token_id = tokenizer.token_to_id("</s>")
    seq_len = config['seq_len']

    model = get_model(config, tokenizer.get_vocab_size()).to(device)
    model_path = get_weights_file_path(config, config['preload'])
    model.eval()
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
    state = torch.load(model_path, map_location=device)
    model.load_state_dict(state['model_state_dict'])

    # Drops the last id the tokenizer produced -- presumably a trailing
    # end-of-sequence token added by its post-processor; TODO confirm.
    input_tokens = tokenizer.encode(prompt).ids[:-1]
    if len(input_tokens) > seq_len:
        # Raise instead of exit(): terminating the interpreter from inside a
        # reusable generator would take the whole host process down.
        raise ValueError(f"exceeding max length of input : {seq_len}")

    context = torch.tensor(input_tokens).unsqueeze(0).to(device)

    # Counted independently of ``context.shape[1]``: the context is clipped
    # to ``seq_len`` below, so its length alone cannot bound the loop.
    total_tokens = context.shape[1]

    while total_tokens < max_total_tokens:
        # no_grad wraps only the forward pass and never spans the ``yield``:
        # grad mode is global per thread and would otherwise stay disabled
        # for the consumer while this generator is suspended.
        with torch.no_grad():
            hidden = model.decode(context)
            # NOTE(review): assumes ``hidden`` is (batch, seq, dim) -- confirm.
            logits = model.project(hidden[:, -1]) / temperature

        # torch.topk raises if k exceeds the size of the dimension.
        k = min(top_k, logits.size(-1))
        top_logits, top_indices = torch.topk(logits, k)
        probs = torch.softmax(top_logits, dim=-1)
        choice = torch.multinomial(probs, num_samples=1)
        # Map the position within the top-k slice back to a vocab id.
        next_token = top_indices.gather(-1, choice)
        token_id = next_token.item()

        yield tokenizer.decode([token_id])

        context = torch.cat([context, next_token], dim=1)
        total_tokens += 1
        if context.shape[1] > seq_len:
            context = context[:, -seq_len:]
        if token_id == eos_token_id:
            break
| |
| |
|
|
if __name__ == "__main__":
    # Script entry point: load the project configuration and start the
    # interactive console session.
    run_model(get_config("openweb.config.json"))