import math

import torch
import torch.nn as nn
from torch.nn import functional as F

from transformers import PreTrainedModel, PretrainedConfig
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutput


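# Configuration: token embeddings are only n_embed-dimensional; each one is
# tiled d_model // n_embed times to fill the residual stream (see `scale`).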
class BVVConfig(PretrainedConfig):
    model_type = "model_n_embed_16_binary_n_layer_32"

    def __init__(
        self,
        vocab_size=65536,
        n_embed=16,
        d_model=1024,
        n_head=32,
        n_layer=32,
        block_size=1024,
        dropout=0.0,
        layer_norm_eps=1e-5,
        initializer_range=0.02,
        pad_token_id=57344,
        pad_id=57344,
        bos_token_id=None,
        eos_token_id=None,
        tie_word_embeddings=False,
        use_cache=False,
        **kwargs,
    ):
        # Honour the legacy `pad_id` alias when `pad_token_id` is not given.
        if pad_token_id is None:
            pad_token_id = 57344 if pad_id is None else pad_id

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            use_cache=use_cache,
            **kwargs,
        )

        if d_model % n_embed != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by n_embed ({n_embed})")
        if d_model % n_head != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by n_head ({n_head})")
        if (d_model // n_head) % 2 != 0:
            raise ValueError("head_dim must be even for rotary embeddings")

        self.vocab_size = vocab_size
        self.block_size = block_size
        self.max_position_embeddings = block_size

        self.n_embed = n_embed
        self.d_model = d_model
        self.n_head = n_head
        self.n_layer = n_layer

        self.dropout = dropout
        self.layer_norm_eps = layer_norm_eps
        self.initializer_range = initializer_range

        # How many times each n_embed-wide embedding is repeated to reach d_model.
        self.scale = d_model // n_embed

        self.pad_id = pad_token_id


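# Precompute the complex rotary-embedding phases e^{i * t * theta_k} for every
# position t < end and every frequency theta_k, following the usual RoPE recipe.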
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis


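# Reshape the (seq_len, head_dim / 2) phase table so it broadcasts against a
# (batch, seq_len, n_head, head_dim / 2) complex tensor.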
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert ndim > 1
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


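# Apply rotary position embeddings by viewing consecutive pairs of channels as
# complex numbers and multiplying by the precomputed unit phases.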
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
):
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


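# Multi-head self-attention with RoPE applied to queries and keys and an
# additive attention mask (0 for visible positions, a large negative value
# for masked ones).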
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d_model, n_head, dropout=0.0):
        super().__init__()
        assert d_model % n_head == 0

        self.d_model = d_model
        self.n_head = n_head
        self.head_dim = d_model // n_head

        assert self.head_dim % 2 == 0, "head_dim must be even for rotary embeddings"

        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, freqs_cis, mask=None):
        B, T, C = x.shape

        # (B, T, C) -> (B, T, n_head, head_dim)
        q = self.q_proj(x).view(B, T, self.n_head, self.head_dim)
        k = self.k_proj(x).view(B, T, self.n_head, self.head_dim)
        v = self.v_proj(x).view(B, T, self.n_head, self.head_dim)

        q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis)

        # (B, T, n_head, head_dim) -> (B, n_head, T, head_dim)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Scaled dot-product attention scores: (B, n_head, T, T).
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if mask is not None:
            attn_scores = attn_scores + mask

        # Softmax in float32 for numerical stability, then cast back.
        attn_probs = F.softmax(attn_scores.float(), dim=-1).type_as(q)
        attn_probs = self.dropout(attn_probs)

        out = torch.matmul(attn_probs, v)
        out = out.transpose(1, 2).contiguous().view(B, T, C)

        return self.o_proj(out)


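# Position-wise feed-forward network with the conventional 4x hidden expansion.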
class TransformerMLP(nn.Module):
    def __init__(self, d_model, dropout=0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


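# Pre-norm transformer block: LayerNorm feeds each sub-layer, and the sub-layer
# output is added back to the residual stream.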
class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_head, dropout=0.0, layer_norm_eps=1e-5):
        super().__init__()
        self.self_attn = MultiHeadSelfAttention(d_model, n_head, dropout=dropout)
        self.mlp = TransformerMLP(d_model, dropout=dropout)
        self.input_layernorm = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(d_model, eps=layer_norm_eps)

    def forward(self, x, freqs_cis, mask=None):
        x = x + self.self_attn(self.input_layernorm(x), freqs_cis, mask)
        x = x + self.mlp(self.post_attention_layernorm(x))
        return x


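# Causal LM whose distinguishing feature is the tiny embedding table: tokens are
# embedded in n_embed dimensions and tiled `scale` times to span d_model, so the
# parameters live almost entirely in the transformer stack and the lm_head.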
class BVVForCausalLM(PreTrainedModel, GenerationMixin):
    config_class = BVVConfig
    main_input_name = "input_ids"

    def __init__(self, config: BVVConfig):
        super().__init__(config)

        self.token_embeddings = nn.Embedding(
            config.vocab_size,
            config.n_embed,
            padding_idx=config.pad_token_id,
        )
        self.scale = config.scale

        self.transformer_layers = nn.ModuleList([
            TransformerBlock(
                config.d_model,
                n_head=config.n_head,
                dropout=config.dropout,
                layer_norm_eps=config.layer_norm_eps,
            )
            for _ in range(config.n_layer)
        ])

        self.final_layernorm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size)

        # Rotary phases for every position up to block_size; non-persistent so
        # they are rebuilt on load instead of being stored in checkpoints.
        self.register_buffer(
            "freqs_cis",
            precompute_freqs_cis(
                config.d_model // config.n_head,
                config.block_size,
            ),
            persistent=False,
        )

        self.post_init()

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def get_input_embeddings(self):
        return self.token_embeddings

    def set_input_embeddings(self, value):
        self.token_embeddings = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        # Keep only the most recent block_size tokens; the model has no KV cache.
        if input_ids.shape[1] > self.config.block_size:
            input_ids = input_ids[:, -self.config.block_size:]
            if attention_mask is not None:
                attention_mask = attention_mask[:, -self.config.block_size:]

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }

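    # `labels` follows the Hugging Face convention (shifted internally), while
    # `targets` is a legacy path that takes already-aligned targets and computes
    # the loss position-for-position against the logits.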
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels=None,
        targets=None,
        return_dict=None,
        output_logits=True,
        **kwargs,
    ):
        if input_ids is None:
            raise ValueError("input_ids must be provided")

        if labels is not None and targets is not None:
            raise ValueError("Use either labels or targets, not both.")

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        B, T = input_ids.shape
        if T > self.config.block_size:
            raise ValueError(f"Sequence length {T} exceeds block_size {self.config.block_size}")

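        # Embed into n_embed dimensions, then tile the embedding `scale` times
        # along the channel axis to reach the model width d_model.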
        token_emb = self.token_embeddings(input_ids)
        x = token_emb.repeat(1, 1, self.scale)

        freqs_cis = self.freqs_cis[:T]
        if not torch.is_complex(freqs_cis):
            freqs_cis = torch.view_as_complex(freqs_cis.contiguous())
        freqs_cis = freqs_cis.to(x.device)

        # Additive masks: 0 keeps a position, finfo(dtype).min removes it.
        mask = None
        mask_value = torch.finfo(x.dtype).min

        # Causal mask: each position may attend only to itself and the past.
        if T > 1:
            mask = torch.full((1, 1, T, T), mask_value, device=x.device, dtype=x.dtype)
            mask = torch.triu(mask, diagonal=1)

        # Padding mask derived from attention_mask (1 = real token, 0 = padding).
        if attention_mask is not None:
            if attention_mask.shape != (B, T):
                raise ValueError(f"attention_mask must have shape {(B, T)}, got {tuple(attention_mask.shape)}")
            pad_mask = torch.zeros((B, 1, 1, T), device=x.device, dtype=x.dtype)
            pad_mask = pad_mask.masked_fill(attention_mask[:, None, None, :].eq(0), mask_value)
            mask = pad_mask if mask is None else mask + pad_mask

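        # Run the pre-norm transformer stack, then project to vocabulary logits.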
        for layer in self.transformer_layers:
            x = layer(x, freqs_cis, mask)

        x = self.final_layernorm(x)
        logits = self.lm_head(x)

        loss = None

        if labels is not None:
            # Hugging Face convention: predict token t+1 from position t.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()

            if attention_mask is not None:
                shift_labels = shift_labels.masked_fill(attention_mask[:, 1:].eq(0), -100)

            if self.config.pad_token_id is not None:
                shift_labels = shift_labels.masked_fill(shift_labels == self.config.pad_token_id, -100)

            loss = F.cross_entropy(
                shift_logits.float().view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )

        elif targets is not None:
            # Legacy path: `targets` is assumed to be pre-aligned, so no shift
            # is applied before the cross-entropy.
            legacy_targets = targets.contiguous()

            if attention_mask is not None:
                legacy_targets = legacy_targets.masked_fill(attention_mask.eq(0), -100)

            if self.config.pad_token_id is not None:
                legacy_targets = legacy_targets.masked_fill(legacy_targets == self.config.pad_token_id, -100)

            loss = F.cross_entropy(
                logits.float().view(-1, logits.size(-1)),
                legacy_targets.view(-1),
                ignore_index=-100,
            )

        if not return_dict:
            if output_logits:
                output = (logits,)
                return ((loss,) + output) if loss is not None else output
            return (loss,) if loss is not None else tuple()

        if output_logits:
            return CausalLMOutput(loss=loss, logits=logits)
        return CausalLMOutput(loss=loss, logits=None)

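    # Greedy / multinomial decoding loop. This overrides GenerationMixin.generate
    # with a simpler signature; each step re-runs the full forward pass because
    # there is no KV cache.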
    def generate(self, input_ids, max_new_tokens, attention_mask=None, do_sample=False):
        was_training = self.training
        self.eval()

        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.long)

        with torch.no_grad():
            for _ in range(max_new_tokens):
                # Crop the context to the most recent block_size tokens.
                input_ids_cond = input_ids[:, -self.config.block_size:]
                attention_mask_cond = attention_mask[:, -self.config.block_size:]

                outputs = self(
                    input_ids=input_ids_cond,
                    attention_mask=attention_mask_cond,
                    return_dict=True,
                )
                logits = outputs.logits[:, -1, :]

                if do_sample:
                    probs = F.softmax(logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                else:
                    next_token = torch.argmax(logits, dim=-1, keepdim=True)

                input_ids = torch.cat([input_ids, next_token], dim=1)
                attention_mask = torch.cat(
                    [attention_mask, torch.ones_like(next_token, dtype=attention_mask.dtype)],
                    dim=1,
                )

        if was_training:
            self.train()

        return input_ids
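

if __name__ == "__main__":
    # Minimal smoke test (illustrative only): the hyperparameters below are
    # assumptions chosen to keep the run small on CPU, not the defaults above.
    config = BVVConfig(
        vocab_size=256,
        n_embed=16,
        d_model=64,
        n_head=4,
        n_layer=2,
        block_size=32,
        pad_token_id=0,
        pad_id=0,
    )
    model = BVVForCausalLM(config)

    prompt = torch.randint(1, config.vocab_size, (1, 8))
    out = model.generate(prompt, max_new_tokens=4)
    print(out.shape)  # expected: torch.Size([1, 12])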