| import json |
| import torch |
| import torch.nn as nn |
| from transformers import GPT2Tokenizer |
|
|
|
|
def load_tinylm(model_dir, device="cpu"):
    """Load a TinyLM checkpoint plus its tokenizer from *model_dir*.

    Expects the directory to contain:
      - ``config.json``        — architecture hyperparameters
      - ``pytorch_model.bin``  — a plain state dict
      - GPT-2 tokenizer files loadable by ``GPT2Tokenizer.from_pretrained``

    The model classes are defined inside this function so they can close over
    the hyperparameters read from ``config.json``.

    Args:
        model_dir: Path to the checkpoint directory.
        device: Torch device string to place the model on (default ``"cpu"``).

    Returns:
        ``(model, tokenizer, config)`` — the model is already in eval mode.
    """
    with open(f"{model_dir}/config.json", encoding="utf-8") as f:
        config = json.load(f)

    vocab_size = config["vocab_size"]
    embed_rank = config["embed_rank"]
    d_model = config["d_model"]
    n_heads = config["n_heads"]
    ffn_dim = config["ffn_dim"]
    n_layers = config["n_layers"]
    max_seq_len = config["max_seq_len"]
    dropout = config["dropout"]

    class FactoredEmbedding(nn.Module):
        """Low-rank factored token embedding: vocab -> rank -> d_model."""

        def __init__(self, vocab, rank, dim):
            super().__init__()
            self.in_proj = nn.Embedding(vocab, rank)
            self.out_proj = nn.Linear(rank, dim, bias=False)

        def forward(self, x):
            return self.out_proj(self.in_proj(x))

    class TransformerBlock(nn.Module):
        """Pre-LayerNorm transformer block with causal self-attention."""

        def __init__(self):
            super().__init__()
            self.ln1 = nn.LayerNorm(d_model)
            self.attn = nn.MultiheadAttention(
                d_model, n_heads, dropout=dropout, batch_first=True
            )
            self.ln2 = nn.LayerNorm(d_model)
            self.ffn = nn.Sequential(
                nn.Linear(d_model, ffn_dim),
                nn.GELU(),
                nn.Linear(ffn_dim, d_model),
                nn.Dropout(dropout),
            )

        def forward(self, x, attn_mask=None, key_padding_mask=None):
            x_norm = self.ln1(x)
            # need_weights=False: the attention map was computed and discarded
            # before — skipping it saves work with identical outputs.
            # is_causal=True is only a hint that attn_mask is causal; callers
            # always pass the causal mask explicitly.
            attn_out, _ = self.attn(x_norm, x_norm, x_norm,
                                    attn_mask=attn_mask,
                                    key_padding_mask=key_padding_mask,
                                    need_weights=False,
                                    is_causal=True)
            x = x + attn_out
            x = x + self.ffn(self.ln2(x))
            return x

    class TinyLM(nn.Module):
        """Decoder-only LM with factored embeddings and a tied output head."""

        def __init__(self):
            super().__init__()
            self.tok_emb = FactoredEmbedding(vocab_size, embed_rank, d_model)
            self.pos_emb = nn.Embedding(max_seq_len, d_model)
            self.drop = nn.Dropout(dropout)
            self.blocks = nn.ModuleList(TransformerBlock() for _ in range(n_layers))
            self.ln_final = nn.LayerNorm(d_model)
            self.head_down = nn.Linear(d_model, embed_rank, bias=False)
            self.head_vocab = nn.Linear(embed_rank, vocab_size, bias=False)
            # True weight tying: share the *same* Parameter object. The old
            # code wrapped it in a fresh nn.Parameter, which shares storage
            # but creates a second Parameter (separate .grad buffer, counted
            # twice by optimizers). Inference results are unchanged.
            self.head_vocab.weight = self.tok_emb.in_proj.weight

        def forward(self, idx):
            B, T = idx.shape
            if T > max_seq_len:
                # NOTE(review): truncation keeps the *first* window; for
                # next-token prediction keeping the last tokens is more usual.
                # Preserved as-is — callers (generate) already window the tail.
                idx = idx[:, :max_seq_len]
                T = idx.shape[1]
            positions = torch.arange(T, device=idx.device).unsqueeze(0)
            x = self.drop(self.tok_emb(idx) + self.pos_emb(positions))
            mask = nn.Transformer.generate_square_subsequent_mask(T, device=idx.device)
            for block in self.blocks:
                x = block(x, attn_mask=mask)
            x = self.ln_final(x)
            # Project back down to the embedding rank, then up to vocab logits.
            return self.head_vocab(self.head_down(x))

    model = TinyLM().to(device)
    # weights_only=True: a state dict needs only tensors; this blocks
    # arbitrary-code execution from a tampered pickle file.
    state_dict = torch.load(f"{model_dir}/pytorch_model.bin",
                            map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()

    tokenizer = GPT2Tokenizer.from_pretrained(model_dir)
    # GPT-2 ships without a pad token; reuse EOS so padding works.
    tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer, config
|
|
|
|
def generate(model, tokenizer, prompt, max_new_tokens=100, temperature=0.1, top_k=25, device="cpu"):
    """Autoregressively sample a continuation of *prompt* from *model*.

    Args:
        model: A TinyLM-like module: callable on a ``(B, T)`` id tensor,
            returning ``(B, T, vocab)`` logits, with ``pos_emb.num_embeddings``
            giving its context window.
        tokenizer: Encodes the prompt and decodes the sampled ids; must
            expose ``eos_token_id``.
        prompt: Text to continue.
        max_new_tokens: Upper bound on sampled tokens.
        temperature: Softmax temperature; must be > 0.
        top_k: Keep only the k most likely tokens before sampling, or
            ``None`` to sample from the full distribution. Values larger
            than the vocabulary are clamped (previously this crashed).
        device: Device for the input ids.

    Returns:
        The decoded prompt plus continuation, special tokens stripped.
    """
    max_seq_len = model.pos_emb.num_embeddings
    model.eval()
    ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Condition on at most the last max_seq_len tokens.
            idx_cond = ids[:, -max_seq_len:]
            logits = model(idx_cond)[:, -1, :] / temperature
            if top_k is not None:
                # Clamp: torch.topk raises if k exceeds the vocab size
                # (the default top_k=25 crashed on any vocab < 25).
                k = min(top_k, logits.size(-1))
                values, _ = torch.topk(logits, k)
                # Mask everything below the k-th best logit.
                logits[logits < values[:, -1:]] = -float("inf")
            probs = torch.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            if next_id.item() == tokenizer.eos_token_id:
                break  # EOS is dropped, not appended
            ids = torch.cat([ids, next_id], dim=1)

    return tokenizer.decode(ids[0], skip_special_tokens=True)
|
|
|
|
if __name__ == "__main__":
    # Smoke-test entry point: load the checkpoint from the default directory
    # and tell the user how to sample from it.
    loaded = load_tinylm("./tinylm")
    model, tokenizer, config = loaded
    print("Model loaded!")
    print("Use 'module.generate(model, tokenizer, \"Once upon a time\")' to generate.")
|
|