| import torch.nn as nn |
| import torch |
| from .configuration_mamba import MambaConfig |
| from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss |
| from transformers.modeling_utils import PreTrainedModel |
| from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast |
| import math |
| import json |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from dataclasses import dataclass |
| from einops import rearrange, repeat, einsum |
| from typing import Optional , Union ,Tuple |
|
|
| |
|
|
|
|
class MambaRMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension.

    Every vector along the last axis is divided by ``sqrt(mean(x**2) + eps)``
    and then scaled elementwise by a learnable ``weight`` vector of length
    ``d_model`` (initialised to ones). No mean is subtracted and no bias is
    added.
    """

    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        """Return ``x`` scaled by its inverse RMS over the last dim, times ``weight``."""
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.weight
| |
|
|
class MambaBlock(nn.Module):
    """A single residual Mamba block (Figure 3, Section 3.4 of the Mamba paper [1]).

    Pipeline: RMSNorm -> in_proj (split into an SSM branch and a gate branch)
    -> causal depthwise conv1d -> SiLU -> selective SSM scan -> multiply by
    SiLU(gate) -> out_proj -> add the un-normalised block input (residual).

    Shape legend used in the docstrings below:
        b     batch size
        l     sequence length
        d     model width (``config.d_model``)
        d_in  inner width (``config.d_inner``)
        n     SSM state size (``config.d_state``)

    References:
        [1] Gu & Dao, "Mamba: Linear-Time Sequence Modeling with Selective State Spaces".
        [2] "The Annotated S4".
    """

    def __init__(self, config: MambaConfig):
        super().__init__()
        self.config = config

        # d -> 2 * d_in: the first half feeds the SSM path, the second half is the gate.
        self.in_proj = nn.Linear(config.d_model, config.d_inner * 2, bias=config.bias)

        # Depthwise conv (groups == channels). Padding d_conv - 1 on both sides and
        # keeping only the first l outputs in forward() makes it causal.
        self.conv1d = nn.Conv1d(
            in_channels=config.d_inner,
            out_channels=config.d_inner,
            bias=config.conv_bias,
            kernel_size=config.d_conv,
            groups=config.d_inner,
            padding=config.d_conv - 1,
        )

        # Produces the input-dependent low-rank delta (dt_rank wide) plus B and C (n wide each).
        self.x_proj = nn.Linear(config.d_inner, config.dt_rank + config.d_state * 2, bias=False)

        # Expands delta from dt_rank back to d_in.
        self.dt_proj = nn.Linear(config.dt_rank, config.d_inner, bias=True)

        # A is stored as a log so that A = -exp(A_log) (see ssm) stays strictly negative.
        # Each of the d_in rows starts as [1, 2, ..., n].
        A = repeat(torch.arange(1, config.d_state + 1), 'n -> d n', d=config.d_inner)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(config.d_inner))
        self.out_proj = nn.Linear(config.d_inner, config.d_model, bias=config.bias)
        self.norm = MambaRMSNorm(config.d_model)

    def forward(self, x):
        """Run the block on ``x`` and add the residual.

        Args:
            x: shape (b, l, d)

        Returns:
            output: shape (b, l, d)

        Official Implementation:
            class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L119
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
        """
        (b, l, d) = x.shape
        residual = x
        x = self.norm(x)
        x_and_res = self.in_proj(x)
        (x, res) = x_and_res.split(split_size=[self.config.d_inner, self.config.d_inner], dim=-1)

        # Conv1d expects channels-first; truncate to l to drop the look-ahead outputs.
        x = rearrange(x, 'b l d_in -> b d_in l')
        x = self.conv1d(x)[:, :, :l]
        x = rearrange(x, 'b d_in l -> b l d_in')
        x = F.silu(x)

        y = self.ssm(x)
        y = y * F.silu(res)

        return self.out_proj(y) + residual

    def ssm(self, x):
        """Compute the input-dependent SSM parameters and run the selective scan.

        See Algorithm 2 in Section 3.2 of [1] and run_SSM(A, B, C, u) in [2].

        Args:
            x: shape (b, l, d_in)

        Returns:
            output: shape (b, l, d_in)

        Official Implementation:
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
        """
        (_, n) = self.A_log.shape

        # A and D are input-independent; they are taken in fp32 here.
        A = -torch.exp(self.A_log.float())
        D = self.D.float()

        # delta, B and C are input-dependent (the "selective" part).
        x_dbl = self.x_proj(x)
        (delta, B, C) = x_dbl.split(split_size=[self.config.dt_rank, n, n], dim=-1)
        delta = F.softplus(self.dt_proj(delta))

        return self.selective_scan(x, delta, A, B, C, D)

    def selective_scan(self, u, delta, A, B, C, D):
        """Sequential reference implementation of the selective scan.

        Discretised state-space recurrence (Section 2 and Algorithm 2 of [1]):
            x(t + 1) = exp(delta * A) x(t) + delta * B u(t)
            y(t)     = C x(t) + D u(t)
        where delta, B and C depend on the input.

        The recurrence is evaluated in at least fp32, following the official
        ``selective_scan_ref``, which upcasts and casts the result back. This
        avoids mixing half-precision inputs with the fp32 ``A``/``D`` and limits
        error accumulation over long sequences. The output is returned in
        ``u``'s dtype.

        Args:
            u: shape (b, l, d_in)
            delta: shape (b, l, d_in)
            A: shape (d_in, n)
            B: shape (b, l, n)
            C: shape (b, l, n)
            D: shape (d_in,)

        Returns:
            output: shape (b, l, d_in), same dtype as ``u``

        Official Implementation:
            selective_scan_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L86
            Note: parts of `selective_scan_ref` are refactored out, so the functionality does not match exactly.
        """
        dtype_in = u.dtype
        # fp16/bf16 -> fp32; fp32 and fp64 are kept as-is (.to is a no-op then).
        compute_dtype = torch.promote_types(dtype_in, torch.float32)
        u, delta, A, B, C, D = (t.to(compute_dtype) for t in (u, delta, A, B, C, D))

        (b, l, d_in) = u.shape
        n = A.shape[1]

        # Discretise A (zero-order hold) and B (simplified Euler) for every step at once.
        deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b d_in l n'))
        deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b d_in l n')

        # Sequential scan over time; the hidden state has shape (b, d_in, n).
        x = torch.zeros((b, d_in, n), device=deltaA.device, dtype=deltaA.dtype)
        ys = []
        for i in range(l):
            x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
            y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in')
            ys.append(y)
        y = torch.stack(ys, dim=1)

        y = y + u * D

        return y.to(dtype_in)
| |
class MambaPreTrainedModel(PreTrainedModel):
    """Shared base class hooking the Mamba models into `PreTrainedModel`.

    Declares the config class, the attribute name of the base model
    (``model``), gradient-checkpointing support, and that a `MambaBlock` must
    not be split across devices. Also provides the per-module weight
    initialiser.
    """

    config_class = MambaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["MambaBlock"]

    def _init_weights(self, module):
        """Initialise one submodule in place.

        Linear / Conv1d: weight ~ N(0, 0.02), bias (if any) zeroed.
        Embedding: weight ~ N(0, 0.02), padding row (if any) zeroed.
        All other module types are left untouched.
        """
        init_std = 0.02

        if isinstance(module, (nn.Linear, nn.Conv1d)):
            module.weight.data.normal_(mean=0.0, std=init_std)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if not isinstance(module, nn.Embedding):
            return

        module.weight.data.normal_(mean=0.0, std=init_std)
        pad_row = module.padding_idx
        if pad_row is not None:
            module.weight.data[pad_row].zero_()
|
|
class MambaModel(MambaPreTrainedModel):
    """Bare Mamba decoder: token embedding, ``config.n_layer`` [`MambaBlock`]s, final RMSNorm."""

    def __init__(self, config: MambaConfig):
        """Build the full Mamba model.

        Mamba model decoder consisting of *config.n_layer* layers. Each layer is a [`MambaBlock`].

        Args:
            config: MambaConfig
        """
        super().__init__(config)
        self.config = config

        self.embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.layers = nn.ModuleList([MambaBlock(config) for _ in range(config.n_layer)])
        self.norm_f = MambaRMSNorm(config.d_model)

        # Toggled externally (e.g. by `gradient_checkpointing_enable()`); read in `forward`.
        self.gradient_checkpointing = False
        self.post_init()

    def get_input_embeddings(self):
        return self.embedding

    def set_input_embeddings(self, value):
        self.embedding = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Embed ``input_ids``, run every block, then apply the final norm.

        When ``self.gradient_checkpointing`` is set and the module is in training
        mode, each block is wrapped in activation checkpointing. Its activations
        are then recomputed during backward instead of being stored.

        Args:
            input_ids: token ids fed straight to ``nn.Embedding`` (presumably
                shape (batch, seq_len) — confirm against callers).
            return_dict: accepted for Hugging Face API compatibility but
                ignored; a `BaseModelOutputWithPast` is always returned.

        Returns:
            BaseModelOutputWithPast with:
              - ``last_hidden_state``: the final-normalised output of the last block.
              - ``hidden_states``: a list with each block's un-normalised output,
                one entry per layer. The embedding output is not included.
        """
        x = self.embedding(input_ids)

        use_checkpointing = self.gradient_checkpointing and self.training
        checkpoint_fn = None
        if use_checkpointing:
            # Prefer the function installed by Hugging Face's
            # `gradient_checkpointing_enable()`. Fall back to plain torch checkpointing.
            checkpoint_fn = getattr(self, "_gradient_checkpointing_func", None)
            if checkpoint_fn is None:
                import functools
                from torch.utils.checkpoint import checkpoint

                checkpoint_fn = functools.partial(checkpoint, use_reentrant=False)

        all_hidden_states = list()
        for layer in self.layers:
            if use_checkpointing:
                x = checkpoint_fn(layer.__call__, x)
            else:
                x = layer(x)
            all_hidden_states.append(x)

        hidden_states = self.norm_f(x)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
        )
class MambaForCausalLM(MambaPreTrainedModel):
    """Mamba decoder plus a bias-free language-modelling head.

    The head's weight is the same Parameter object as the input embedding
    table (explicitly tied in ``__init__``).
    """

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = MambaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Share a single Parameter between the embedding and the output projection.
        self.lm_head.weight = self.model.embedding.weight
        self.post_init()

    # --- accessors expected by the PreTrainedModel machinery -------------------

    def get_input_embeddings(self):
        return self.model.embedding

    def set_input_embeddings(self, value):
        self.model.embedding = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    # --- forward pass -----------------------------------------------------------

    def _shifted_lm_loss(self, logits, labels):
        """Mean next-token cross-entropy: logits at position t are scored against labels[t + 1].

        Uses `CrossEntropyLoss` defaults, so label entries equal to -100 are ignored.
        """
        vocab = self.config.vocab_size
        predictions = logits[..., :-1, :].contiguous().view(-1, vocab)
        targets = labels[..., 1:].contiguous().view(-1)
        targets = targets.to(predictions.device)
        return CrossEntropyLoss()(predictions, targets)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Compute fp32 logits and, when ``labels`` is given, the shifted LM loss.

        Args:
            input_ids: token ids forwarded to the base model.
            labels: optional target ids in the same layout as ``input_ids``;
                they are shifted internally by one position.
            output_attentions: accepted for API compatibility; ignored.
            output_hidden_states: accepted for API compatibility; ignored.
            return_dict: when truthy, return a `CausalLMOutputWithPast`. Otherwise
                return a tuple: ``logits`` followed by ``outputs[1:]`` of the base
                model (its hidden states), prefixed with ``loss`` when labels were
                given.
        """
        base_out = self.model(
            input_ids=input_ids,
            return_dict=return_dict,
        )
        logits = self.lm_head(base_out[0]).float()

        loss = None if labels is None else self._shifted_lm_loss(logits, labels)

        if return_dict:
            return CausalLMOutputWithPast(
                loss=loss,
                logits=logits,
                hidden_states=base_out.hidden_states,
            )

        flat = (logits,) + base_out[1:]
        if loss is None:
            return flat
        return (loss,) + flat

    def prepare_inputs_for_generation(
        self, input_ids, **kwargs
    ):
        """Feed only ``input_ids`` to ``forward`` on each generation step.

        Every other keyword argument is dropped. No recurrent-state cache is
        used, so the whole prefix is re-processed at every step.
        """
        return {"input_ids": input_ids}
|
|
|
|
|
|