# llama3-edge / modeling_llama_edge.py
# Uploaded by crab27 with huggingface_hub ("Upload folder using huggingface_hub"), commit cd53fcc (verified).
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
from transformers import PreTrainedModel
try:
from .configuration_llama_edge import LlamaEdgeConfig
except ImportError:
from configuration_llama_edge import LlamaEdgeConfig
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learned per-channel gain.

    Args:
        dim: size of the last dimension of the inputs (length of ``weight``).
        eps: added to the mean square before the reciprocal square root.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        """Divide ``x`` by the RMS of its last dimension."""
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize in float32: squaring half-precision activations overflows
        # fp16 (max ~65504), which would turn rsqrt(inf) into 0 and zero the
        # output. Cast back to the input dtype before applying the gain.
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Build the table of unit-magnitude complex rotations used by rotary position embeddings.

    Args:
        dim: feature width being rotated (one attention head's width).
        end: number of positions to precompute (positions ``0 .. end-1``).
        theta: base of the geometric frequency progression.

    Returns:
        complex64 tensor of shape ``(end, dim // 2)`` whose entry ``[p, k]`` is
        ``exp(1j * p * theta ** (-2k / dim))``.
    """
    even_idx = torch.arange(0, dim, 2)[: dim // 2].float()
    inv_freq = 1.0 / (theta ** (even_idx / dim))
    positions = torch.arange(end, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq).float()
    magnitudes = torch.ones_like(angles)
    return torch.polar(magnitudes, angles)
class FeedForward(nn.Module):
    """Gated feed-forward sublayer computing ``w2(silu(w1(x)) * w3(x))``.

    Args:
        dim: input and output width.
        hidden_dim: inner width. When given (e.g. the config's
            ``intermediate_size``) it is used as-is. When ``None`` it is derived
            with the Llama sizing rule (see ``_llama_hidden_dim``).
        multiple_of: rounding granularity for the derived width; ignored when
            ``hidden_dim`` is given.
        ffn_dim_multiplier: optional scale applied to the derived width; ignored
            when ``hidden_dim`` is given.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: Optional[int],
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
    ):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = self._llama_hidden_dim(dim, multiple_of, ffn_dim_multiplier)
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    @staticmethod
    def _llama_hidden_dim(dim: int, multiple_of: int, ffn_dim_multiplier: Optional[float]) -> int:
        """Derive the inner width: int(2/3 * 4 * dim), optionally scaled, rounded up to ``multiple_of``.

        For example dim=4096, multiple_of=1024, ffn_dim_multiplier=1.3 gives 14336.
        """
        hidden = int(2 * (4 * dim) / 3)
        if ffn_dim_multiplier is not None:
            hidden = int(ffn_dim_multiplier * hidden)
        # Ceiling division, then scale back up to the next multiple.
        return multiple_of * ((hidden + multiple_of - 1) // multiple_of)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
class Attention(nn.Module):
    """Multi-head self-attention with grouped-query K/V heads and rotary position embeddings.

    ``config.n_heads`` query heads share ``config.n_kv_heads`` key/value heads.
    Each K/V head is repeated ``n_heads // n_kv_heads`` times before the
    attention product. Only ``config.dim``, ``config.n_heads`` and
    ``config.n_kv_heads`` are read.
    """

    def __init__(self, config: "LlamaEdgeConfig"):
        super().__init__()
        self.n_heads = config.n_heads
        self.n_kv_heads = config.n_kv_heads
        self.head_dim = config.dim // config.n_heads
        self.n_rep = self.n_heads // self.n_kv_heads
        self.wq = nn.Linear(config.dim, config.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(config.dim, config.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(config.dim, config.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(config.n_heads * self.head_dim, config.dim, bias=False)

    @staticmethod
    def _apply_rotary_emb(x, freqs_cis):
        """Rotate consecutive feature pairs of ``x`` by per-position complex factors.

        ``x`` is (bsz, seqlen, heads, head_dim). ``freqs_cis`` is a complex
        (seqlen, head_dim // 2) table such as ``precompute_freqs_cis`` output
        sliced to ``seqlen``. Each pair (x[2k], x[2k+1]) is treated as
        x[2k] + 1j*x[2k+1] and multiplied by freqs_cis[pos, k]. Only real
        arithmetic is used on the activations. The result has x's dtype.
        """
        cos = freqs_cis.real[None, :, None, :]
        sin = freqs_cis.imag[None, :, None, :]
        pairs = x.float().reshape(*x.shape[:-1], -1, 2)
        even, odd = pairs[..., 0], pairs[..., 1]
        # (a + ib)(c + is) = (ac - bs) + i(as + bc), re-interleaved as [re, im].
        rotated = torch.stack((even * cos - odd * sin, even * sin + odd * cos), dim=-1)
        return rotated.flatten(3).type_as(x)

    def forward(self, x, freqs_cis, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: (bsz, seqlen, dim) hidden states.
            freqs_cis: complex RoPE table of shape (seqlen, head_dim // 2);
                ``None`` skips the rotation.
            mask: additive mask broadcastable to (bsz, n_heads, seqlen, seqlen),
                e.g. ``-inf`` above the diagonal for causal attention; ``None``
                leaves scores unmasked.

        Returns:
            (bsz, seqlen, dim) tensor.
        """
        bsz, seqlen, _ = x.shape
        xq = self.wq(x).view(bsz, seqlen, self.n_heads, self.head_dim)
        xk = self.wk(x).view(bsz, seqlen, self.n_kv_heads, self.head_dim)
        xv = self.wv(x).view(bsz, seqlen, self.n_kv_heads, self.head_dim)

        # Rotate queries and keys by their positions so q.k depends on the
        # relative offset between tokens (RoPE).
        if freqs_cis is not None:
            xq = self._apply_rotary_emb(xq, freqs_cis)
            xk = self._apply_rotary_emb(xk, freqs_cis)

        # Repeat K and V heads for GQA (if n_kv_heads < n_heads).
        if self.n_rep > 1:
            xk = xk.unsqueeze(3).repeat(1, 1, 1, self.n_rep, 1).reshape(bsz, seqlen, self.n_heads, self.head_dim)
            xv = xv.unsqueeze(3).repeat(1, 1, 1, self.n_rep, 1).reshape(bsz, seqlen, self.n_heads, self.head_dim)

        # (bsz, heads, seqlen, head_dim) for the batched matmuls.
        xq = xq.transpose(1, 2)
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)

        scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            # Additive mask: -inf entries get zero weight after softmax.
            scores = scores + mask
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, xv)

        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)
class TransformerBlock(nn.Module):
    """One decoder layer: normalized self-attention then a normalized feed-forward, each added residually.

    ``layer_id`` is accepted for API compatibility but not used.
    """

    def __init__(self, layer_id: int, config: LlamaEdgeConfig):
        super().__init__()
        width = config.dim
        eps = config.norm_eps
        self.attention = Attention(config)
        self.feed_forward = FeedForward(
            dim=width,
            hidden_dim=config.intermediate_size,
            multiple_of=config.multiple_of,
            ffn_dim_multiplier=config.ffn_dim_multiplier,
        )
        self.attention_norm = RMSNorm(width, eps=eps)
        self.ffn_norm = RMSNorm(width, eps=eps)

    def forward(self, x, freqs_cis, mask=None):
        # Norm is applied to each sublayer's input; the residual path carries the raw stream.
        attn_delta = self.attention(self.attention_norm(x), freqs_cis, mask)
        residual = x + attn_delta
        ffn_delta = self.feed_forward(self.ffn_norm(residual))
        return residual + ffn_delta
class LlamaEdgeForCausalLM(PreTrainedModel):
    """Decoder-only language model mapping token ids to next-token logits.

    Pipeline: token embedding -> ``config.n_layers`` TransformerBlocks ->
    final RMSNorm -> bias-free linear projection to ``config.vocab_size``.
    """

    config_class = LlamaEdgeConfig

    def __init__(self, config: LlamaEdgeConfig):
        super().__init__(config)
        width = config.dim
        self.token_embedding = nn.Embedding(config.vocab_size, width)
        self.layers = nn.ModuleList(
            TransformerBlock(layer_idx, config) for layer_idx in range(config.n_layers)
        )
        self.norm = RMSNorm(width, eps=config.norm_eps)
        self.output = nn.Linear(width, config.vocab_size, bias=False)
        # RoPE table for one head's width, positions [0, max_seq_len).
        # Stored as a plain attribute (not a registered buffer), so it is not in
        # state_dict and not moved by .to(); forward() copies the slice it needs
        # to the input's device.
        self.freqs_cis = precompute_freqs_cis(
            width // config.n_heads, config.max_seq_len, config.rope_theta,
        )

    def forward(self, x):
        """Return logits of shape (bsz, seqlen, vocab_size) for token ids ``x`` of shape (bsz, seqlen).

        NOTE(review): when seqlen > config.max_seq_len, ``freqs_cis[:seqlen]``
        silently yields fewer rows than positions; confirm callers never exceed it.
        """
        _, seq_len = x.shape
        device = x.device
        rope = self.freqs_cis[:seq_len].to(device)

        # Additive causal mask: -inf strictly above the diagonal, 0 elsewhere.
        causal_mask = torch.triu(
            torch.full((seq_len, seq_len), float("-inf"), device=device),
            diagonal=1,
        )

        hidden = self.token_embedding(x)
        for block in self.layers:
            hidden = block(hidden, rope, causal_mask)
        return self.output(self.norm(hidden))