import torch
from models.gem_model import GEM
from utils.data_preprocessing import load_tokenizer
from configs.config import MODEL_CONFIG

def generate_text(model, tokenizer, prompt, max_length=100, temperature=0.7):
    """Generate a continuation of `prompt` with the trained GEM model."""
    device = torch.device(MODEL_CONFIG['DEVICE'])
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    # Inference needs no gradients; skip building the autograd graph.
    with torch.no_grad():
        generated = model.generate(input_ids, max_length=max_length, temperature=temperature)
    return tokenizer.decode(generated[0], skip_special_tokens=True)
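
# A minimal manual sampling loop, sketched as an alternative for when more
# control is needed than model.generate() exposes (e.g. top-k filtering).
# It assumes the model's forward pass returns logits of shape
# (batch, seq_len, vocab_size); adjust the indexing if GEM's forward differs.
def sample_text(model, tokenizer, prompt, max_new_tokens=100, temperature=0.7, top_k=50):
    device = torch.device(MODEL_CONFIG['DEVICE'])
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Keep the context within the model's maximum sequence length.
            context = input_ids[:, -MODEL_CONFIG['MAX_SEQ_LEN']:]
            logits = model(context)[:, -1, :] / temperature
            # Restrict sampling to the top-k most likely tokens.
            top_values, top_indices = torch.topk(logits, top_k)
            probs = torch.softmax(top_values, dim=-1)
            next_token = top_indices.gather(-1, torch.multinomial(probs, num_samples=1))
            input_ids = torch.cat([input_ids, next_token], dim=-1)
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)
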
def main():
    device = torch.device(MODEL_CONFIG['DEVICE'])

    # Load the same tokenizer that was used during training.
    tokenizer = load_tokenizer()

    # Rebuild the model with the hyperparameters it was trained with.
    model = GEM(
        vocab_size=MODEL_CONFIG['VOCAB_SIZE'],
        d_model=MODEL_CONFIG['D_MODEL'],
        n_heads=MODEL_CONFIG['N_HEADS'],
        d_ff=MODEL_CONFIG['D_FF'],
        n_layers=MODEL_CONFIG['N_LAYERS'],
        max_seq_len=MODEL_CONFIG['MAX_SEQ_LEN'],
        dropout=MODEL_CONFIG['DROPOUT']
    ).to(device)

    # Restore the trained weights; map_location keeps this working on CPU-only machines.
    checkpoint = torch.load('final_model/model.pt', map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()  # disable dropout for deterministic inference

    prompt = "Once upon a time"
    generated_text = generate_text(model, tokenizer, prompt, max_length=100)
    print(f"Generated text:\n{generated_text}")
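
    # Optional: sweep the sampling temperature to see its effect on output
    # diversity; the values here are illustrative, not tuned.
    for temp in (0.5, 1.0):
        print(f"\n--- temperature={temp} ---")
        print(generate_text(model, tokenizer, prompt, max_length=100, temperature=temp))
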
if __name__ == "__main__":
    main()