| import gradio as gr |
| import torch |
| from gpt_dev import GPTLanguageModel, encode, decode, generate_text |
|
|
| |
# Select GPU when available; the model and checkpoint tensors are placed here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Model hyperparameters.
# NOTE(review): these values are defined but never passed to GPTLanguageModel()
# below — presumably gpt_dev reads its own module-level config; confirm they
# match the architecture the checkpoint was trained with.
block_size = 256   # maximum context length in tokens
n_embd = 384       # embedding dimension
n_head = 6         # attention heads per layer
n_layer = 6        # transformer layers
vocab_size = 95    # vocabulary size — presumably printable-ASCII chars; TODO confirm


# Build the model and move it to the chosen device.
model = GPTLanguageModel()
model.to(device)


# Load trained weights; map_location keeps this working on CPU-only hosts.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted
# checkpoint files (consider weights_only=True on newer PyTorch).
checkpoint = torch.load("gpt_language_model.pth", map_location=device)
model.load_state_dict(checkpoint)
model.eval()  # inference mode: disables dropout, etc.
|
|
| |
def generate_response(prompt, max_length=100, temperature=1.0):
    """Generate a text completion for *prompt* with the loaded GPT model.

    Args:
        prompt: Seed string that conditions the generation.
        max_length: Maximum number of tokens to generate.
        temperature: Sampling temperature (higher = more random output).

    Returns:
        The generated text as a string.
    """
    return generate_text(
        model,
        prompt,
        max_length=max_length,
        temperature=temperature,
    )
|
|
| |
def gradio_interface(prompt, max_length=100, temperature=1.0):
    """Gradio-facing wrapper: forward the UI inputs to generate_response."""
    response = generate_response(prompt, max_length, temperature)
    return response
|
|
| |
# Gradio UI definition.
# Fix: gradio_interface accepts a `temperature` parameter, but the original
# inputs list had no control for it, so it was permanently stuck at its
# default of 1.0 — a Temperature slider is now wired in as the third input.
interface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Textbox(label="Prompt", value="Once upon a time"),
        # NOTE(review): max tops out at 240 while block_size is 256 —
        # presumably a safety margin for the prompt tokens; confirm.
        gr.Slider(50, 240, step=1, value=75, label="Max Length"),
        gr.Slider(0.1, 2.0, step=0.05, value=1.0, label="Temperature"),
    ],
    outputs="text",
    # NOTE(review): "Odeyssey" looks like a typo for "Odyssey" — confirm branding.
    title="Odeyssey Rhyme Generator",
    description="Enter a prompt to generate text.",
)
|
|
| |
if __name__ == "__main__":
    # share=True asks Gradio to create a temporary public URL in addition to
    # the local server — anyone with the link can query the model.
    interface.launch(share=True)
|
|