"""
TinyBuddy-500K: Educational ~500K parameter Llama-style model

MIT License
"""

from dataclasses import dataclass  # noqa: F401
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast


class TinyBuddyConfig(PretrainedConfig):
    """Hyper-parameters for :class:`TinyBuddyForCausalLM`.

    Args:
        vocab_size: Number of token ids (rows of the embedding / lm_head).
        hidden_size: Width of the residual stream.
        num_hidden_layers: Number of stacked decoder layers.
        num_attention_heads: Number of query heads per attention layer.
        num_key_value_heads: Number of shared key/value heads; must divide
            ``num_attention_heads``.
        intermediate_size: Inner width of the gated MLP.
        max_position_embeddings: Stored on the config; the modeling code in
            this file does not read it.
        rms_norm_eps: Epsilon added to the mean square in RMSNorm.
        tie_word_embeddings: If true, ``lm_head`` shares its weight with the
            token embedding.
        bos_token_id / eos_token_id: Special token ids, forwarded to
            ``PretrainedConfig`` so they are not reset to ``None`` by it.
        **kwargs: Any other ``PretrainedConfig`` argument.
    """

    model_type = "tinybuddy"

    def __init__(
        self,
        vocab_size: int = 2048,
        hidden_size: int = 96,
        num_hidden_layers: int = 2,
        num_attention_heads: int = 4,
        num_key_value_heads: int = 2,
        intermediate_size: int = 384,
        max_position_embeddings: int = 512,
        rms_norm_eps: float = 1e-6,
        tie_word_embeddings: bool = True,
        bos_token_id: int = 2,
        eos_token_id: int = 2,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.intermediate_size = intermediate_size
        self.max_position_embeddings = max_position_embeddings
        self.rms_norm_eps = rms_norm_eps
        # These three are owned by PretrainedConfig.__init__, which would
        # otherwise overwrite them with its own defaults (None / True).
        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )


class RMSNorm(nn.Module):
    """Root-mean-square layer norm with a learned per-channel scale (no bias)."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x):
        """Normalise ``x`` over its last dimension and apply the learned scale."""
        input_dtype = x.dtype
        # Compute the statistics in float32: squaring fp16 activations can
        # overflow. For float32 inputs this is a no-op.
        x = x.to(torch.float32)
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps)
        return self.weight * x.to(input_dtype)


class GroupedQueryAttention(nn.Module):
    """Causal self-attention with grouped key/value heads (GQA).

    ``num_attention_heads`` query heads share ``num_key_value_heads`` K/V heads;
    each K/V head is repeated to serve ``num_attention_heads //
    num_key_value_heads`` consecutive query heads.

    Raises:
        ValueError: if ``num_attention_heads`` is not divisible by
            ``num_key_value_heads``.
    """

    def __init__(self, config):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        if self.num_heads % self.num_kv_heads != 0:
            raise ValueError(
                f"num_attention_heads ({self.num_heads}) must be divisible by "
                f"num_key_value_heads ({self.num_kv_heads})"
            )
        self.num_kv_groups = self.num_heads // self.num_kv_heads
        self.head_dim = config.hidden_size // self.num_heads
        self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)

    def forward(self, x):
        """Attend over ``x`` of shape (batch, seq, hidden); return the same shape."""
        B, T, _ = x.shape
        q = self.q_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # Expand the K/V heads so every query head has a partner.
        k = k.repeat_interleave(self.num_kv_groups, dim=1)
        v = v.repeat_interleave(self.num_kv_groups, dim=1)

        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        # Causal mask: position i may only attend to positions <= i, otherwise
        # the model can read the very tokens it is trained to predict. The
        # diagonal is never masked, so no row is all -inf.
        future = torch.ones(T, T, dtype=torch.bool, device=x.device).triu(diagonal=1)
        scores = scores.masked_fill(future, float("-inf"))
        # Softmax in float32 for stability under half precision, then cast back.
        attn = F.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().view(B, T, self.num_heads * self.head_dim)
        return self.o_proj(out)


class MLP(nn.Module):
    """Gated feed-forward block: ``down(silu(gate(x)) * up(x))`` (SwiGLU)."""

    def __init__(self, config):
        super().__init__()
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)

    def forward(self, x):
        """Apply the gated MLP to the last dimension of ``x``."""
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))


class DecoderLayer(nn.Module):
    """Pre-norm decoder block: ``x + attn(norm(x))`` then ``x + mlp(norm(x))``."""

    def __init__(self, config):
        super().__init__()
        self.self_attn = GroupedQueryAttention(config)
        self.mlp = MLP(config)
        self.input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps)

    def forward(self, x):
        """Run attention and MLP sub-blocks, each wrapped in a residual add."""
        x = x + self.self_attn(self.input_layernorm(x))
        x = x + self.mlp(self.post_attention_layernorm(x))
        return x


class TinyBuddyForCausalLM(PreTrainedModel):
    """Decoder-only language model: embedding -> decoder layers -> RMSNorm -> lm_head.

    NOTE: no positional encoding (e.g. RoPE) is applied anywhere in this file;
    token order reaches the model only through the causal attention mask.
    """

    config_class = TinyBuddyConfig
    base_model_prefix = "tinybuddy"

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        if config.tie_word_embeddings:
            self.lm_head.weight = self.embed_tokens.weight
        self.post_init()

    # Standard accessors so Transformers utilities (weight tying,
    # resize_token_embeddings) can locate the embedding matrices.
    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(self, input_ids, labels: Optional[torch.Tensor] = None, **kwargs):
        """Compute next-token logits and, optionally, the language-modelling loss.

        Args:
            input_ids: Token ids of shape (batch, seq).
            labels: Optional target ids with the same shape as ``input_ids``
                (typically ``input_ids`` itself). They are shifted inside this
                method so that position t is scored against token t+1; entries
                equal to -100 are ignored by ``F.cross_entropy``.
            **kwargs: Accepted for API compatibility (e.g. ``attention_mask``)
                and ignored.

        Returns:
            ``CausalLMOutputWithPast`` with ``logits`` of shape
            (batch, seq, vocab_size) and ``loss`` (``None`` without labels).
        """
        x = self.embed_tokens(input_ids)
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            # Shift: logits at position t predict token t+1, matching how
            # generate() samples from the last position's logits.
            shift_logits = logits[:, :-1, :]
            shift_labels = labels[:, 1:].to(logits.device)
            loss = F.cross_entropy(
                shift_logits.reshape(-1, shift_logits.size(-1)),
                shift_labels.reshape(-1),
            )
        return CausalLMOutputWithPast(loss=loss, logits=logits)

    @torch.no_grad()
    def generate(self, input_ids, max_new_tokens=50, temperature=0.8, top_k=50, **kwargs):
        """Autoregressively append ``max_new_tokens`` tokens to ``input_ids``.

        There is no KV cache: the whole sequence is re-run every step.

        Args:
            input_ids: Prompt token ids of shape (batch, seq).
            max_new_tokens: Number of tokens to append.
            temperature: Softmax temperature; ``<= 0`` selects greedy decoding.
            top_k: Keep only the ``top_k`` most likely tokens before sampling;
                ``None`` or ``0`` disables the filter.
            **kwargs: Accepted for API compatibility and ignored.

        Returns:
            The prompt with the generated tokens concatenated on dim 1.
        """
        for _ in range(max_new_tokens):
            logits = self(input_ids).logits[:, -1, :]
            if temperature <= 0:
                # Dividing by zero would yield NaN probabilities; use argmax instead.
                next_token = logits.argmax(dim=-1, keepdim=True)
            else:
                logits = logits / temperature
                if top_k is not None and top_k > 0:
                    kth_best = torch.topk(logits, min(top_k, logits.size(-1))).values[:, [-1]]
                    logits = logits.masked_fill(logits < kth_best, float("-inf"))
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            input_ids = torch.cat([input_ids, next_token], dim=1)
        return input_ids


# Register both classes so that a checkpoint saved with save_pretrained() can be
# reloaded through AutoConfig / AutoModelForCausalLM with trust_remote_code=True.
TinyBuddyConfig.register_for_auto_class()
TinyBuddyForCausalLM.register_for_auto_class("AutoModelForCausalLM")