| from huggingface_hub import hf_hub_download |
| import torch |
| from torch import nn |
| from torch.nn import functional as F |
| from torch.utils.data import Dataset, DataLoader, random_split |
| import urllib.request |
| import os |
| from transformers import AutoTokenizer, logging |
| import pandas as pd |
| from tqdm import tqdm |
| from safetensors.torch import load_file |
|
|
|
|
class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block.

    Self-attention followed by a 4x-wide GELU MLP, each wrapped in a residual
    connection.

    Args:
        emb_dim: Model width; input and output feature size.
        num_heads: Number of attention heads.
        context_length: Not used by this block; kept in the signature so
            existing callers keep working.
        dropout: Dropout probability passed to ``nn.MultiheadAttention`` and
            applied to the MLP output.
        causal: When True (default), position ``i`` may only attend to
            positions ``<= i``, as required for autoregressive next-token
            prediction. Pass False to reproduce the previous unmasked
            (bidirectional) behaviour.
    """

    def __init__(self, emb_dim, num_heads, context_length, dropout=0.1, causal=True):
        super().__init__()
        self.ln1 = nn.LayerNorm(emb_dim)
        self.ln2 = nn.LayerNorm(emb_dim)
        self.attn = nn.MultiheadAttention(
            emb_dim, num_heads, dropout=dropout, batch_first=True
        )
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, 4 * emb_dim),
            nn.GELU(),
            nn.Linear(4 * emb_dim, emb_dim),
            nn.Dropout(dropout),
        )
        # Plain attribute (not a parameter or buffer) so the state_dict keys
        # are unchanged and existing checkpoints still load.
        self.causal = causal

    def forward(self, x):
        """Apply attention and MLP sub-layers with residual connections.

        Args:
            x: Input tensor, ``(batch, seq, emb_dim)`` since the attention
                layer is built with ``batch_first=True``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Normalize once and reuse the result as query, key and value.
        h = self.ln1(x)
        mask = None
        if self.causal:
            seq_len = x.size(-2)
            # Boolean mask: True marks a (query, key) pair that must NOT be
            # attended to, i.e. every key strictly after the query position.
            mask = torch.ones(
                seq_len, seq_len, dtype=torch.bool, device=x.device
            ).triu(diagonal=1)
        attn_out, _ = self.attn(h, h, h, attn_mask=mask, need_weights=False)
        x = x + attn_out
        x = x + self.mlp(self.ln2(x))
        return x
|
|
|
|
class MiniTransformer(nn.Module):
    """Small language model.

    The model applies token and learned positional embeddings, then a stack of
    ``TransformerBlock`` layers, a final LayerNorm and a bias-free projection
    back to the vocabulary.

    Args:
        vocab_size: Number of token ids; size of the embedding table and of
            the logits' last dimension.
        emb_dim: Model (embedding) width.
        context_length: Size of the positional-embedding table, i.e. the
            longest sequence ``forward`` accepts; ``generate`` conditions on
            at most this many trailing tokens.
        num_heads: Attention heads per block.
        num_layers: Number of ``TransformerBlock`` layers.
        dropout: Dropout probability passed to every block.
    """

    def __init__(
        self,
        vocab_size,
        emb_dim,
        context_length,
        num_heads,
        num_layers,
        dropout=0.1,
    ):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim)
        self.pos_emb = nn.Embedding(context_length, emb_dim)
        self.blocks = nn.Sequential(
            *[
                TransformerBlock(emb_dim, num_heads, context_length, dropout)
                for _ in range(num_layers)
            ]
        )
        self.ln_f = nn.LayerNorm(emb_dim)
        self.head = nn.Linear(emb_dim, vocab_size, bias=False)
        self.context_length = context_length

    def forward(self, x):
        """Return logits for every position.

        Args:
            x: Integer token ids of shape ``(batch, seq)``. ``seq`` must not
                exceed ``context_length`` (the positional-embedding size).

        Returns:
            Logits of shape ``(batch, seq, vocab_size)``.
        """
        _, seq_len = x.shape
        pos = torch.arange(seq_len, device=x.device)
        x = self.emb(x) + self.pos_emb(pos)
        x = self.blocks(x)
        x = self.ln_f(x)
        return self.head(x)

    @torch.no_grad()
    def generate(self, x, max_new_tokens=20, temperature=1.0, top_k=None):
        """Autoregressively sample ``max_new_tokens`` tokens after ``x``.

        This method does not switch the module to eval mode; call
        ``model.eval()`` first if dropout should be disabled.

        Args:
            x: Prompt token ids of shape ``(batch, seq)``.
            max_new_tokens: Number of tokens to append.
            temperature: The last-position logits are divided by this value
                before the softmax; larger values flatten the distribution.
                Must be > 0.
            top_k: If a positive int, sample only among the ``top_k``
                highest-scoring tokens (clamped to the vocabulary size).
                ``None`` or a value <= 0 samples from the full distribution.

        Returns:
            Tensor of shape ``(batch, seq + max_new_tokens)``: the prompt
            followed by the sampled tokens.
        """
        for _ in range(max_new_tokens):
            # Only the last context_length tokens fit the positional table.
            x_cond = x[:, -self.context_length :]
            logits = self(x_cond)[:, -1, :] / temperature

            if top_k is not None and top_k > 0:
                k = min(top_k, logits.size(-1))
                # Value of the k-th best logit per row, kept as (batch, 1) for broadcasting.
                kth_best = torch.topk(logits, k, dim=-1).values[:, -1:]
                logits = logits.masked_fill(logits < kth_best, float("-inf"))

            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            x = torch.cat([x, next_token], dim=1)

        return x
|
|
|
|
CONTEXT_LENGTH = 256
EMBEDDING_DIMENSION = 512
HEAD_NUMBER = 8
N_LAYER = 6
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Prefer CUDA, then Apple-silicon MPS, and fall back to CPU so the script
# still runs on machines that have neither (previously MPS was assumed).
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Download the trained weights from the Hugging Face Hub.
model_path = hf_hub_download(
    repo_id="pierjoe/MiniTransformer",
    filename="checkpoints/mini_transformer_v4/model_50.safetensors",
)

# The hyperparameters above presumably must match those used to train the
# checkpoint, otherwise load_state_dict will reject the tensor shapes.
model = MiniTransformer(
    vocab_size=tokenizer.vocab_size,
    emb_dim=EMBEDDING_DIMENSION,
    context_length=CONTEXT_LENGTH,
    num_heads=HEAD_NUMBER,
    num_layers=N_LAYER,
).to(device)
state_dict = load_file(model_path)
# Strip the "_orig_mod." key prefix (presumably left by saving a
# torch.compile()-wrapped model) so the keys match the plain module.
state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
model.load_state_dict(state_dict)
model.eval()

max_tokens = 100
prompt = "You are a helpful assistant. Provide clear, concise, and accurate responses to the user "
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
# NOTE(review): temperature=5 divides the logits by 5, which makes sampling
# close to uniform over the top_k candidates; confirm this is intended.
output_ids = model.generate(
    input_ids, max_new_tokens=max_tokens, temperature=5, top_k=10
)
generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
# A bare expression only displays in a notebook; print so scripts show output.
print(generated_text)
|
|