| import gradio as gr
|
| import torch
|
| import torch.nn.functional as F
|
| import tiktoken
|
| import torch
|
| import torch.nn as nn
|
| from torch.utils.data import DataLoader, Dataset
|
| from transformers import AutoTokenizer
|
| import os
|
|
|
|
|
# Pick the compute device once at import time; generate_text() reuses this
# module-level `device` when moving input ids to the model.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# GPT-2 BPE tokenizer from Hugging Face (vocab size 50257, matching the
# `vocab_size` used when the checkpoint below was trained).
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# GPT-2 ships without a pad token; reuse EOS so padded/batched tokenization
# does not raise.
tokenizer.pad_token = tokenizer.eos_token
|
|
|
|
|
class SmolLM(nn.Module):
    """Minimal GPT-style language model.

    Token embeddings plus learned positional embeddings feed a stack of
    ``nn.TransformerEncoderLayer`` blocks and a vocab-sized linear head.

    NOTE(review): the layers are applied without a causal mask, so attention
    is bidirectional. Kept as-is for compatibility with the pickled
    checkpoint this app loads — confirm against the training code.
    """

    def __init__(self, vocab_size, embed_dim, num_heads, num_layers, max_seq_len):
        """Build the model.

        Args:
            vocab_size: size of the token vocabulary (also the output width).
            embed_dim: embedding / hidden width (``d_model``).
            num_heads: attention heads per encoder layer.
            num_layers: number of stacked encoder layers.
            max_seq_len: longest sequence the positional table supports.
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # Learned positional embeddings, zero-initialized; shape (1, L, D)
        # so they broadcast over the batch dimension in forward().
        self.pos_embedding = nn.Parameter(torch.zeros(1, max_seq_len, embed_dim))
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, batch_first=True)
            for _ in range(num_layers)
        ])
        self.fc_out = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        """Map token ids of shape (batch, seq) to logits (batch, seq, vocab).

        Raises:
            ValueError: if ``seq`` exceeds ``max_seq_len`` — previously this
            failed deep inside the positional-embedding addition with an
            opaque broadcasting RuntimeError.
        """
        seq_len = x.size(1)
        max_seq_len = self.pos_embedding.size(1)
        if seq_len > max_seq_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len {max_seq_len}"
            )
        x = self.embedding(x) + self.pos_embedding[:, :seq_len, :]
        for layer in self.layers:
            x = layer(x)
        return self.fc_out(x)
|
|
|
|
|
def load_model():
    """Load the trained SmolLM checkpoint and freeze it for inference.

    The checkpoint was saved with ``torch.save(model, ...)`` (the whole
    pickled module, not a state dict), so no architecture hyperparameters
    are needed here — the previous hard-coded embed_dim/num_heads/etc.
    locals were dead code and have been removed.

    Returns:
        The model on the best available device, in eval mode, with all
        parameters frozen (``requires_grad=False``).

    Raises:
        FileNotFoundError: if the checkpoint file is missing (clearer than
        the raw unpickling error torch.load would produce).
    """
    checkpoint_path = 'final_checkpoint.pth'
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(
            f"Checkpoint not found: {checkpoint_path!r}. "
            "Place the trained model file next to this script."
        )

    # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects.
    # Only ever load checkpoints from a trusted source; prefer saving/loading
    # a state_dict with weights_only=True where possible.
    model = torch.load(checkpoint_path, map_location=device, weights_only=False)
    model = model.to(device)
    model.eval()

    # Freeze everything — this app only ever runs inference.
    for param in model.parameters():
        param.requires_grad = False

    return model
|
|
|
# Build the global model once at import time so every Gradio request reuses
# it. load_model() already puts the model in eval mode with parameters
# frozen, so no further setup (e.g. a redundant model.train(False)) is
# needed here.
model = load_model()
|
|
|
def generate_text(prompt, max_length=100, num_samples=1, temperature=0.8):
    """Autoregressively continue ``prompt`` with the global SmolLM model.

    The previous implementation ignored every parameter except ``prompt``
    and returned a single argmax pass over the prompt positions (i.e. no
    actual generation). This version samples ``max_length`` new tokens per
    sample, one token at a time.

    Args:
        prompt: seed text to continue.
        max_length: number of new tokens to sample for each sample.
        num_samples: how many independent continuations to generate.
        temperature: softmax temperature (> 0); lower values are greedier.

    Returns:
        The decoded continuation(s); multiple samples are separated by a
        blank line.
    """
    # The checkpoint was trained with max_seq_len=128 (see load_model's
    # original constants); feeding longer contexts would overrun the
    # positional-embedding table. TODO confirm against the training config.
    context_window = 128
    max_length = int(max_length)
    num_samples = int(num_samples)
    temperature = max(float(temperature), 1e-6)  # avoid division by zero

    prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

    samples = []
    with torch.no_grad():  # inference only — no autograd graph needed
        for _ in range(num_samples):
            ids = prompt_ids.clone()
            for _ in range(max_length):
                # Keep only the most recent tokens that fit the model's
                # positional table.
                logits = model(ids[:, -context_window:])
                # Sample the next token from the temperature-scaled
                # distribution over the final position.
                probs = F.softmax(logits[:, -1, :] / temperature, dim=-1)
                next_id = torch.multinomial(probs, num_samples=1)
                ids = torch.cat([ids, next_id], dim=1)
            samples.append(tokenizer.decode(ids[0], skip_special_tokens=True))

    return "\n\n".join(samples)
|
|
|
|
|
|
|
# Gradio UI wiring: the three inputs map positionally onto generate_text's
# first three parameters (temperature keeps its default). Typos in the
# user-facing title/description ("Shakesphere") are fixed here.
iface = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Prompt", value="Good night, good night! Parting is such sweet sorrow"),
        gr.Slider(minimum=10, maximum=200, value=100, step=1, label="Max Length"),
        gr.Slider(minimum=1, maximum=5, value=1, step=1, label="Number of Samples"),
    ],
    outputs=gr.Textbox(label="Generated Text"),
    title="Shakespeare Text Generator",
    description="Enter a prompt and the model will continue it in the style of Shakespeare.",
    examples=[
        ["There are more things in heaven and earth, Horatio, than are dreamt of in your philosophy.", 100, 1],
        ["Love all, trust a few, do wrong to none.", 60, 2],
        ["It's not enough to speak, but to speak true", 50, 3],
        ["To be, or not to be: that is the question.", 100, 1],
        ["If you can look into the seeds of time, and say which grain will grow and which will not, speak then to me", 100, 1],
        ["Love sought is good, but given unsought is better.", 100, 1],
    ],
)
|
|
|
# Launch the web UI only when executed as a script (not on import).
if __name__ == "__main__":

    iface.launch()