| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| import math |
| import queue |
| import threading |
| import warnings |
| from typing import List, Optional, Tuple, Union |
|
|
| import torch |
| import torch.nn.functional as F |
| import torch.utils.checkpoint |
| from einops import rearrange |
| from torch import nn |
| from torch.nn import CrossEntropyLoss |
| from transformers.activations import ACT2FN |
| from transformers.modeling_outputs import (BaseModelOutputWithPast, |
| CausalLMOutputWithPast) |
| from transformers.modeling_utils import PreTrainedModel |
| from transformers.utils import (add_start_docstrings, |
| add_start_docstrings_to_model_forward, logging, |
| replace_return_docstrings) |
| from fused_norm_gate import FusedRMSNormSwishGate |
|
|
| from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined |
| from mamba_ssm.ops.triton.selective_state_update import selective_state_update |
| from causal_conv1d import causal_conv1d_fn, causal_conv1d_update |
|
|
|
|
# BaseStreamer is optional: when the streamers module cannot be imported the
# name falls back to None so the rest of the file can still be loaded.
try:
    from transformers.generation.streamers import BaseStreamer
except Exception:  # best-effort import, but never swallow KeyboardInterrupt/SystemExit
    BaseStreamer = None
|
|
| from .configuration_mmMamba import mmMambaConfig |
|
|
| logger = logging.get_logger(__name__) |
|
|
| _CONFIG_FOR_DOC = 'mmMambaConfig' |
|
|
# Optional flash-attn bindings. The public names default to None and are only
# rebound when every import below succeeds; `has_flash_attn` records the result.
flash_attn_func, flash_attn_varlen_func = None, None
pad_input, index_first_axis, unpad_input = None, None, None
try:
    from flash_attn import flash_attn_func as _flash_attn_func
    from flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis as _index_first_axis
    from flash_attn.bert_padding import pad_input as _pad_input
    from flash_attn.bert_padding import unpad_input as _unpad_input
except Exception:
    # Any import-time failure disables the fast path. `Exception` (instead of
    # a bare except) keeps KeyboardInterrupt/SystemExit propagating.
    has_flash_attn = False
else:
    flash_attn_func, flash_attn_varlen_func = _flash_attn_func, _flash_attn_varlen_func
    pad_input, index_first_axis, unpad_input = _pad_input, _index_first_axis, _unpad_input
    has_flash_attn = True
|
|
| try: |
| from flash_attn import flash_attn_with_kvcache |
| except ImportError: |
| flash_attn_with_kvcache = None |
|
|
| try: |
| from flash_attn.layers.rotary import RotaryEmbedding |
| except ImportError: |
| RotaryEmbedding = None |
|
|
| import torch.nn.functional as F |
|
|
| def _update_kv_cache(kv, inference_params, layer_idx): |
| """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)""" |
| |
| num_heads, head_dim = kv.shape[-2:] |
| assert layer_idx in inference_params.key_value_memory_dict |
| kv_cache, _ = inference_params.key_value_memory_dict[layer_idx] |
| |
| batch_start = inference_params.batch_size_offset |
| batch_end = batch_start + kv.shape[0] |
| sequence_start = inference_params.seqlen_offset |
| sequence_end = sequence_start + kv.shape[1] |
| assert batch_end <= kv_cache.shape[0] |
| assert sequence_end <= kv_cache.shape[1] |
| assert kv_cache is not None |
| kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv |
| return kv_cache[batch_start:batch_end, :sequence_end, ...] |
|
|
| def _import_flash_attn(): |
| global flash_attn_func, flash_attn_varlen_func |
| global pad_input, index_first_axis, unpad_input |
| try: |
| from flash_attn import flash_attn_func as _flash_attn_func |
| from flash_attn import \ |
| flash_attn_varlen_func as _flash_attn_varlen_func |
| from flash_attn.bert_padding import \ |
| index_first_axis as _index_first_axis |
| from flash_attn.bert_padding import pad_input as _pad_input |
| from flash_attn.bert_padding import unpad_input as _unpad_input |
| flash_attn_func, flash_attn_varlen_func = _flash_attn_func, _flash_attn_varlen_func |
| pad_input, index_first_axis, unpad_input = _pad_input, _index_first_axis, _unpad_input |
| except ImportError: |
| raise ImportError('flash_attn is not installed.') |
|
|
|
|
| |
class mmMambaRMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned gain."""

    def __init__(self, hidden_size, eps=1e-6):
        """
        mmMambaRMSNorm is equivalent to T5LayerNorm

        Args:
            hidden_size: length of the normalised (last) dimension.
            eps: constant added to the mean square before the reciprocal root.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalise ``hidden_states`` in float32, cast back, then apply the gain."""
        orig_dtype = hidden_states.dtype
        as_fp32 = hidden_states.to(torch.float32)
        mean_square = as_fp32.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        normalised = (as_fp32 * inv_rms).to(orig_dtype)
        return self.weight * normalised
|
|
|
|
| |
class mmMambaRotaryEmbedding(nn.Module):
    """Rotary position embedding tables (cos/sin) cached per position.

    The cache is built for ``max_position_embeddings`` positions at
    construction and rebuilt whenever a longer sequence is requested.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        """
        Args:
            dim: rotary dimension; the tables have ``dim`` columns.
            max_position_embeddings: number of positions cached up front.
            base: base of the geometric frequency progression.
            device: device on which the frequency buffer is created.
        """
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer('inv_freq', inv_freq, persistent=False)

        # Build the initial tables eagerly in the default dtype.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build ``cos_cached`` / ``sin_cached`` for ``seq_len`` positions."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device).to(dtype=self.inv_freq.dtype)

        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        # The frequencies are duplicated so the tables span all ``dim`` columns.
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer('cos_cached', emb.cos().to(dtype), persistent=False)
        self.register_buffer('sin_cached', emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """Return ``(cos, sin)`` for the first ``seq_len`` positions, cast to ``x.dtype``.

        Args:
            x: tensor whose dtype (and, when the cache must grow, device) is used.
            seq_len: number of positions wanted. ``None`` returns the whole
                cached table.
        """
        if seq_len is None:
            # The signature advertises None as the default; comparing it
            # against an int used to raise TypeError.
            seq_len = self.max_seq_len_cached
        elif seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=torch.float32)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
|
|
|
|
| |
class mmMambaLinearScalingRotaryEmbedding(mmMambaRotaryEmbedding):
    """mmMambaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        # Stored before delegating because the parent constructor already
        # calls _set_cos_sin_cache, which reads scaling_factor.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Rebuild the cos/sin tables with positions divided by ``scaling_factor``."""
        self.max_seq_len_cached = seq_len

        positions = torch.arange(seq_len, device=device).to(dtype=self.inv_freq.dtype)
        positions = positions / self.scaling_factor

        angles = torch.einsum('i,j->ij', positions, self.inv_freq)
        table = torch.cat((angles, angles), dim=-1)

        cos_table = table.cos().to(dtype)
        sin_table = table.sin().to(dtype)
        self.register_buffer('cos_cached', cos_table, persistent=False)
        self.register_buffer('sin_cached', sin_table, persistent=False)
|
|
|
|
| |
class mmMambaDynamicNTKScalingRotaryEmbedding(mmMambaRotaryEmbedding):
    """mmMambaRotaryEmbedding extended with Dynamic NTK scaling.
    Credits to the Reddit users /u/bloc97 and /u/emozilla.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        # Stored before delegating because the parent constructor already
        # calls _set_cos_sin_cache, which reads scaling_factor.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Rebuild the cos/sin tables, enlarging the base once ``seq_len`` exceeds the trained length."""
        self.max_seq_len_cached = seq_len

        beyond_trained_length = seq_len > self.max_position_embeddings
        if beyond_trained_length:
            stretch = (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
            scaled_base = self.base * stretch ** (self.dim / (self.dim - 2))
            exponents = torch.arange(0, self.dim, 2).float().to(device) / self.dim
            new_inv_freq = 1.0 / (scaled_base ** exponents)
            self.register_buffer('inv_freq', new_inv_freq, persistent=False)

        positions = torch.arange(seq_len, device=device).to(dtype=self.inv_freq.dtype)
        angles = torch.einsum('i,j->ij', positions, self.inv_freq)
        table = torch.cat((angles, angles), dim=-1)

        cos_table = table.cos().to(dtype)
        sin_table = table.sin().to(dtype)
        self.register_buffer('cos_cached', cos_table, persistent=False)
        self.register_buffer('sin_cached', sin_table, persistent=False)
|
|
|
|
|
|
class mmMambaMLP(nn.Module):
    """Gated feed-forward block: ``w2(act(w1(x)) * w3(x))`` with bias-free projections."""

    def __init__(self, config):
        """Build the three projections from ``config.hidden_size``,
        ``config.intermediate_size`` and the activation named by ``config.hidden_act``."""
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        width, inner = self.hidden_size, self.intermediate_size
        self.w1 = nn.Linear(width, inner, bias=False)
        self.w3 = nn.Linear(width, inner, bias=False)
        self.w2 = nn.Linear(inner, width, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        """Apply the gated projection to ``x`` and map back to the hidden size."""
        gate = self.act_fn(self.w1(x))
        up = self.w3(x)
        return self.w2(gate * up)
|
|
|
|
| |
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head ``n_rep`` times along the head axis, matching
    torch.repeat_interleave(x, dim=1, repeats=n_rep). Input is (batch,
    num_key_value_heads, seqlen, head_dim); output is (batch,
    num_key_value_heads * n_rep, seqlen, head_dim). With ``n_rep == 1`` the
    input tensor itself is returned.
    """
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis right after the head axis, broadcast it, then fold
    # it into the head axis.
    tiled = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
    return tiled.reshape(bsz, kv_heads * n_rep, seq_len, dim)
|
|
def repeat_kv2(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Single-step variant of ``repeat_kv`` for tensors without a sequence axis.
    Matches torch.repeat_interleave(x, dim=1, repeats=n_rep): the input is
    (batch, num_key_value_heads, head_dim) and the output is (batch,
    num_key_value_heads * n_rep, head_dim). With ``n_rep == 1`` the input
    tensor itself is returned.
    """
    bsz, kv_heads, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis right after the head axis, broadcast it, then fold
    # it into the head axis.
    tiled = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, dim)
    return tiled.reshape(bsz, kv_heads * n_rep, dim)
|
|
|
|
class MHA_LM(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Grouped-query variant: ``num_key_value_heads`` key/value heads are shared
    by ``num_heads`` query heads through one fused ``wqkv`` projection. Rotary
    embeddings come from flash-attn's ``RotaryEmbedding``. During generation
    the keys/values live in a pre-allocated cache stored under
    ``inference_params.key_value_memory_dict[layer_idx]``.
    """

    def __init__(self, config: mmMambaConfig, layer_idx: int):
        """
        Args:
            config: model configuration; reads ``hidden_size``,
                ``num_attention_heads``, ``num_key_value_heads``,
                ``max_position_embeddings`` and ``rope_theta``.
            layer_idx: key under which this layer stores its KV cache.

        Raises:
            ValueError: if ``hidden_size`` is not divisible by the head count.
            ImportError: if flash-attn's ``RotaryEmbedding`` is unavailable.
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.is_causal = True
        self.rotary_emb_dim = self.head_dim
        self.softmax_scale = None
        self.causal = True

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.wqkv = nn.Linear(
            self.hidden_size,
            (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim,
            bias=False,
        )
        self.wo = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        if RotaryEmbedding is None:
            # The module-level import fell back to None; fail with a clear
            # message instead of "'NoneType' object is not callable".
            raise ImportError("flash_attn is required for MHA_LM (flash_attn.layers.rotary.RotaryEmbedding)")
        self.rotary_emb = RotaryEmbedding(
            self.head_dim,
            base=self.config.rope_theta,
            interleaved=False,
            device=self.wo.weight.device,
        )

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """View ``tensor`` as (bsz, seq_len, num_heads, head_dim) and move heads before the sequence axis."""
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def _update_kv_cache(self, kv, inference_params):
        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
        return _update_kv_cache(kv, inference_params, self.layer_idx)

    def _cache_and_lengths(self, batch, inference_params):
        """Return this layer's cache trimmed to ``batch`` rows plus the per-sample (or shared) cache length."""
        kv_cache, _ = inference_params.key_value_memory_dict[self.layer_idx]
        kv_cache = kv_cache[:batch]
        if inference_params.lengths_per_sample is not None:
            cache_seqlens = inference_params.lengths_per_sample[:batch]
        else:
            cache_seqlens = inference_params.seqlen_offset
        return kv_cache, cache_seqlens

    def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
        """
        Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
        q: (batch_size, seqlen_q, nheads, head_dim)
        kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
        """
        assert inference_params is not None and inference_params.seqlen_offset > 0
        if self.rotary_emb_dim > 0:
            self.rotary_emb._update_cos_sin_cache(
                inference_params.max_seqlen, device=q.device, dtype=q.dtype
            )
            rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
        else:
            rotary_cos, rotary_sin = None, None

        kv_cache, cache_seqlens = self._cache_and_lengths(q.shape[0], inference_params)
        assert flash_attn_with_kvcache is not None, "flash_attn must be installed"
        context = flash_attn_with_kvcache(
            q,
            kv_cache[:, :, 0],
            kv_cache[:, :, 1],
            kv[:, :, 0],
            kv[:, :, 1],
            rotary_cos=rotary_cos,
            rotary_sin=rotary_sin,
            cache_seqlens=cache_seqlens,
            softmax_scale=self.softmax_scale,
            causal=self.causal,
            rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
        )
        return context

    def _update_kvcache_attention(self, q, kv, inference_params):
        """Write kv to inference_params, then do attention"""
        if (
            inference_params.seqlen_offset == 0
            or flash_attn_with_kvcache is None
        ):
            # Prefill (or no fused kernel): store kv, then attend over the
            # whole cached prefix.
            kv = self._update_kv_cache(kv, inference_params)
            k, v = kv.unbind(dim=-3)
            assert flash_attn_func is not None, "flash_attn must be installed"
            attn_output = flash_attn_func(
                q, k, v, 0.0, softmax_scale=self.softmax_scale, causal=self.causal
            )
            return attn_output

        kv_cache, cache_seqlens = self._cache_and_lengths(q.shape[0], inference_params)
        return flash_attn_with_kvcache(
            q,
            kv_cache[:, :, 0],
            kv_cache[:, :, 1],
            kv[:, :, 0],
            kv[:, :, 1],
            cache_seqlens=cache_seqlens,
            softmax_scale=self.softmax_scale,
            causal=self.causal,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        inference_params = None,
        output_attentions: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        use_cache: bool = False,
        **kwargs,
    ) -> torch.Tensor:
        """Run causal attention over ``hidden_states`` and project back to the hidden size.

        Args:
            hidden_states: (batch, seq_len, hidden_size) input.
            inference_params: generation state exposing
                ``key_value_memory_dict``, ``max_seqlen``, ``seqlen_offset``
                and ``lengths_per_sample``; ``None`` disables caching.
            output_attentions, cache_position, use_cache, **kwargs: accepted
                for interface compatibility; not read by this method.

        Returns:
            Tensor of shape (batch, seq_len, hidden_size).
        """
        if inference_params is not None and self.layer_idx not in inference_params.key_value_memory_dict:
            inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache(
                hidden_states.shape[0], inference_params.max_seqlen, dtype=hidden_states.dtype
            )

        if inference_params is None:
            seqlen_offset = 0
        elif inference_params.lengths_per_sample is not None:
            seqlen_offset = inference_params.lengths_per_sample
        else:
            seqlen_offset = inference_params.seqlen_offset

        bsz, q_len, _ = hidden_states.size()
        rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None

        qkv = self.wqkv(hidden_states)
        # Each KV head owns a group of `num_key_value_groups` query slots
        # followed by one key slot and one value slot.
        qkv = rearrange(
            qkv,
            "b q (h gs d) -> b q h gs d",
            gs=2 + self.num_key_value_groups,
            d=self.head_dim,
        )

        q = qkv[..., : self.num_key_value_groups, :]
        q = rearrange(q, "b q h gs d -> b q (h gs) d")
        kv = qkv[..., self.num_key_value_groups:, :].transpose(2, 3)

        if (
            inference_params is None
            or inference_params.seqlen_offset == 0
            or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
        ):
            if self.rotary_emb_dim > 0:
                # The offset is a plain int unless per-sample lengths are
                # provided; only a tensor can (and needs to) be cut down to
                # the current batch. Slicing the int raised TypeError.
                if torch.is_tensor(seqlen_offset):
                    seqlen_offset = seqlen_offset[:bsz, ...]
                q, kv = self.rotary_emb(
                    q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                )
            if inference_params is None:
                k, v = kv.unbind(dim=-3)
                n_rep = self.num_heads // self.num_key_value_heads
                k = torch.repeat_interleave(k, dim=2, repeats=n_rep)
                v = torch.repeat_interleave(v, dim=2, repeats=n_rep)
                context = F.scaled_dot_product_attention(
                    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True, scale=None
                ).transpose(1, 2)
            else:
                context = self._update_kvcache_attention(q, kv, inference_params)
        else:
            context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)

        context = rearrange(context, "... h d -> ... (h d)")
        out = self.wo(context)
        return out

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
        """Allocate an uninitialised KV cache of shape
        (batch_size, max_seqlen, 2, num_key_value_heads, head_dim).

        Returns ``(kv_cache, None)``; ``dtype`` defaults to the dtype of ``wo``.
        """
        dtype = self.wo.weight.dtype if dtype is None else dtype
        device = self.wo.weight.device
        kv_cache = torch.empty(
            batch_size, max_seqlen, 2, self.num_key_value_heads, self.head_dim, dtype=dtype, device=device,
        )
        return kv_cache, None
|
|
| |
class Mamba2_LM(nn.Module):
    """Mamba-2 sequence mixer fed by attention-style v/k/q projections.

    One fused projection (``wvkqgdt``) yields, per token, value/key/query
    features, an output gate ``g`` and one step size ``dt`` per head. v/k/q go
    through a depthwise causal convolution and then the SSM kernels with
    ``x=v``, ``B=k`` and ``C=q``. The result is normalised and gated by ``g``
    and projected back with ``o_proj``.

    - For now we don't support padding.
    """

    def __init__(self, config: mmMambaConfig, layer_idx: Optional[int] = None,
                 elementwise_affine: Optional[bool] = True,
                 norm_eps: float = 1e-5,
                 ):
        """
        Args:
            config: model configuration; reads ``hidden_size``,
                ``num_attention_heads``, ``num_key_value_heads``,
                ``max_position_embeddings`` and ``rope_theta``.
            layer_idx: key under which this layer stores its recurrent state.
            elementwise_affine: forwarded to ``FusedRMSNormSwishGate``.
            norm_eps: epsilon forwarded to ``FusedRMSNormSwishGate``.

        Raises:
            ValueError: if ``self.activation`` names an unsupported activation.
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.layer_idx = layer_idx
        self.bias = False
        self.chunk_size = 128
        self.conv_bias = True
        self.d_conv = 2
        self.activation = "silu"
        self.rope_theta = config.rope_theta

        self.wvkqgdt = nn.Linear(
            self.hidden_size,
            (self.num_heads + 2 * self.num_key_value_heads + self.num_heads) * self.head_dim + self.num_heads,
            bias=self.bias
        )
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        # Snapshot taken at construction time; used to create the remaining
        # parameters on the same device/dtype as the projection.
        self.device = self.wvkqgdt.weight.device
        self.dtype = self.wvkqgdt.weight.dtype

        # Width of the concatenated v/k/q features that pass through the conv.
        self.conv_dim = self.num_heads * self.head_dim + 2 * self.num_key_value_heads * self.head_dim

        self.conv1d = nn.Conv1d(
            in_channels=self.conv_dim,
            out_channels=self.conv_dim,
            bias=self.conv_bias,
            kernel_size=self.d_conv,
            groups=self.conv_dim,
            padding=self.d_conv - 1,
            device=self.device,
            dtype=self.dtype
        )
        with torch.no_grad():
            # Only kernel tap 1 is non-zero and the bias is zero, so the conv
            # presumably starts as a pass-through (before the activation).
            self.conv1d.weight.zero_()
            self.conv1d.weight[:, 0, 1] = 1
            self.conv1d.bias.zero_()

        if self.activation == "identity":
            self.act = nn.Identity()
        elif self.activation in ["silu", "swish"]:
            self.act = nn.SiLU()
        else:
            raise ValueError(f"Unknown activation {self.activation}")

        self.g_norm_swish_gate = FusedRMSNormSwishGate(
            hidden_size=self.head_dim, elementwise_affine=elementwise_affine, eps=norm_eps
        ).to(self.dtype).to(self.device)

        # dt is drawn log-uniformly from [0.001, 0.1].
        dt = torch.exp(
            torch.rand(self.num_heads, dtype=self.dtype, device=self.device) * (math.log(0.1) - math.log(0.001))
            + math.log(0.001)
        )
        dt = torch.clamp(dt, min=0.001)
        # Inverse of softplus, so that softplus(dt_bias) reproduces dt.
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        self.dt_bias = nn.Parameter(inv_dt)
        self.dt_bias._no_weight_decay = True

        A_log_bias = torch.zeros(self.num_heads, dtype=self.dtype, device=self.device)
        self.A_log_bias = nn.Parameter(A_log_bias)
        self.A_log_bias._no_weight_decay = True

    def _conv_activation(self):
        """Activation name for the causal-conv kernels (``None`` means no activation)."""
        return None if self.activation == "identity" else self.activation

    def _vkq_sizes(self):
        """Channel counts of the v, k and q slices inside the fused conv features."""
        kv_width = self.num_key_value_heads * self.head_dim
        return [kv_width, kv_width, self.num_heads * self.head_dim]

    def _scan_sequence(self, vkq, dt, return_final_states):
        """Run the causal conv and the chunked scan over a full sequence.

        Returns ``(y, conv_final_states, ssm_final_states)``; both state
        entries are ``None`` unless ``return_final_states`` is true.
        """
        conv_out = causal_conv1d_fn(
            vkq.transpose(1, 2),
            rearrange(self.conv1d.weight, "d 1 w -> d w"),
            self.conv1d.bias,
            initial_states=None,
            return_final_states=return_final_states,
            activation=self._conv_activation(),
        )
        if return_final_states:
            vkq, conv_final = conv_out
        else:
            vkq, conv_final = conv_out, None

        v, k, q = torch.split(vkq, self._vkq_sizes(), dim=1)
        v = rearrange(v, "b (h n) l -> b h l n", h=self.num_key_value_heads)
        k = rearrange(k, "b (h n) l -> b h l n", h=self.num_key_value_heads)
        q = rearrange(q, "b (h n) l -> b l h n", h=self.num_heads)
        # Expand the shared k/v heads to one per query head, then put the
        # sequence axis before the head axis like q.
        k = repeat_kv(k, self.num_key_value_groups).transpose(1, 2)
        v = repeat_kv(v, self.num_key_value_groups).transpose(1, 2)

        A = -torch.exp(self.A_log_bias.float())

        scan_out = mamba_chunk_scan_combined(
            x=v,
            dt=dt,
            dt_softplus=True,
            A=A,
            B=k,
            C=q,
            chunk_size=self.chunk_size,
            dt_bias=self.dt_bias,
            initial_states=None,
            return_final_states=return_final_states,
        )
        if return_final_states:
            y, ssm_final = scan_out
        else:
            y, ssm_final = scan_out, None
        return y, conv_final, ssm_final

    def _decode_step(self, vkq, dt, conv_state, ssm_state):
        """Advance the recurrence by one token, updating ``conv_state`` and ``ssm_state`` in place."""
        vkq = causal_conv1d_update(
            vkq.transpose(1, 2).squeeze(-1),
            conv_state,
            self.conv1d.weight.squeeze(1),
            self.conv1d.bias,
            self._conv_activation(),
        )

        v, k, q = torch.split(vkq, self._vkq_sizes(), dim=1)
        v = rearrange(v, "b (h n) -> b h n", h=self.num_key_value_heads)
        k = rearrange(k, "b (h n) -> b h n", h=self.num_key_value_heads)
        q = rearrange(q, "b (h n) -> b h n", h=self.num_heads)
        k = repeat_kv2(k, self.num_key_value_groups)
        v = repeat_kv2(v, self.num_key_value_groups)

        # The single-step kernel is given per-head scalars broadcast over the
        # head / state dimensions.
        dt = dt.transpose(1, 2).squeeze(-1)
        dt = dt[:, :, None].expand(-1, -1, self.head_dim)
        dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
        A = -torch.exp(self.A_log_bias.float())
        A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.head_dim).to(dtype=torch.float32)
        D = torch.zeros((self.num_heads, self.head_dim), dtype=A.dtype, device=A.device)

        return selective_state_update(
            ssm_state,
            v,
            dt,
            A=A,
            B=k,
            C=q,
            D=D,
            dt_bias=dt_bias,
            dt_softplus=True,
        )

    def forward(self,
        hidden_states: torch.Tensor,
        inference_params = None,
        output_attentions: bool = False,
        use_cache: bool = True,
        **kwargs,
    ) -> torch.Tensor:
        """Mix ``hidden_states`` along the sequence axis.

        Args:
            hidden_states: (batch, seq_len, hidden_size) input.
            inference_params: generation state exposing
                ``key_value_memory_dict`` and ``seqlen_offset``; ``None``
                means no recurrent state is read or written.
            output_attentions, **kwargs: accepted for interface
                compatibility; not read by this method.
            use_cache: when true and ``inference_params`` is given, the
                recurrent state is written at ``seqlen_offset == 0`` and
                advanced one token at a time afterwards.

        Returns:
            Tensor of shape (batch, seq_len, hidden_size).
        """
        # Use the projection's current dtype: the dtype recorded in __init__
        # goes stale if the module is cast after construction.
        hidden_states = hidden_states.to(self.wvkqgdt.weight.dtype)
        vkqgdt = self.wvkqgdt(hidden_states)
        vkq, g, dt = torch.split(
            vkqgdt,
            [
                (2 * self.num_key_value_heads + self.num_heads) * self.head_dim,
                self.num_heads * self.head_dim,
                self.num_heads,
            ],
            dim=2,
        )
        batch, seqlen, _ = hidden_states.shape

        conv_state, ssm_state = None, None
        if inference_params is not None:
            conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
            conv_state = conv_state[:batch, ...]
            ssm_state = ssm_state[:batch, ...]

        # Caching needs both the flag and somewhere to keep the state;
        # use_cache defaults to True, so inference_params=None must fall back
        # to the stateless path instead of dereferencing None.
        caching = bool(use_cache) and inference_params is not None

        if caching and inference_params.seqlen_offset > 0:
            y = self._decode_step(vkq, dt, conv_state, ssm_state)
        else:
            prefill = caching and inference_params.seqlen_offset == 0
            y, conv_final, ssm_final = self._scan_sequence(vkq, dt, return_final_states=prefill)
            if prefill:
                conv_state.copy_(conv_final)
                ssm_state.copy_(ssm_final)

        g = rearrange(g, 'b l (h d) -> b l h d', h=self.num_heads)
        y_true = self.g_norm_swish_gate(y, g)
        y_true = y_true.view(batch, seqlen, self.hidden_size)
        y_true = self.o_proj(y_true)

        return y_true

    def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
        """Fetch this layer's ``(conv_state, ssm_state)``, allocating zeroed states on first use.

        With ``initialize_states`` the returned states are zeroed in place.
        """
        assert self.layer_idx is not None
        if self.layer_idx not in inference_params.key_value_memory_dict:
            inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache(
                batch_size, None
            )
        conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]

        if initialize_states:
            conv_state.zero_()
            ssm_state.zero_()
        return conv_state, ssm_state

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Allocate zeroed recurrent state for ``batch_size`` sequences.

        Returns ``(conv_state, ssm_state)`` shaped
        (batch_size, conv_dim, d_conv - 1) and
        (batch_size, num_heads, head_dim, head_dim).
        ``max_seqlen`` and ``dtype`` are accepted for interface compatibility
        but not used: the states always take the conv weight's dtype.
        """
        device = self.conv1d.weight.device
        dtype = self.conv1d.weight.dtype
        # The channel count must equal the conv width so the final states
        # returned by the conv kernel can be copied in; 2 * hidden_size only
        # matches when num_key_value_heads == num_heads / 2.
        conv_state = torch.zeros(
            batch_size, self.conv_dim, self.d_conv - 1, device=device, dtype=dtype
        )
        ssm_state = torch.zeros(
            batch_size, self.num_heads, self.head_dim, self.head_dim, device=device, dtype=dtype
        )
        return conv_state, ssm_state
| |
|
|
# Registry of sequence-mixer classes, keyed by the names used in
# ``config.layers_block_type``.
mmMamba_ATTENTION_CLASSES = dict(
    mha=MHA_LM,
    mamba2=Mamba2_LM,
)
|
|
|
|
| |
class mmMambaDecoderLayer(nn.Module):
    """Pre-norm residual block: a sequence mixer followed by a gated MLP."""

    def __init__(self, config: "mmMambaConfig", layer_idx: int):
        """
        Args:
            config: model configuration; reads ``hidden_size``,
                ``rms_norm_eps`` and ``layers_block_type``.
            layer_idx: index handed to the mixer; ``layer_idx - 8`` selects
                this layer's entry in ``config.layers_block_type``.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        self.layer_idx = layer_idx
        # mmMambaModel builds its layers with layer_idx = position + 8, so
        # subtracting 8 recovers the position in layers_block_type.
        mixer_cls = mmMamba_ATTENTION_CLASSES[config.layers_block_type[layer_idx - 8]]
        self.attention = mixer_cls(config=config, layer_idx=layer_idx)

        self.feed_forward = mmMambaMLP(config)
        self.attention_norm = mmMambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.ffn_norm = mmMambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        inference_params = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = True,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            inference_params: generation state forwarded unchanged to the mixer.
            output_attentions (`bool`, *optional*):
                When true, a second tuple entry is returned for the attention weights. The mixers in this file
                do not return attention weights, so that entry is always `None`.
            use_cache (`bool`, *optional*):
                Forwarded to the mixer.

        Returns:
            `(hidden_states,)`, or `(hidden_states, None)` when `output_attentions` is true.
        """
        # Mixer sub-block with residual connection.
        residual = hidden_states
        hidden_states = self.attention_norm(hidden_states)
        hidden_states = self.attention(
            hidden_states=hidden_states,
            inference_params=inference_params,
            output_attentions=output_attentions,
            use_cache=use_cache,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Feed-forward sub-block with residual connection.
        residual = hidden_states
        hidden_states = self.ffn_norm(hidden_states)
        hidden_states = self.feed_forward(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            # The mixer returns only the hidden states, so there are no
            # weights to report (the old code referenced an undefined name
            # here). A None placeholder keeps the tuple layout.
            outputs += (None,)

        return outputs

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Delegate cache allocation to the mixer of this layer."""
        return self.attention.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
|
|
|
|
| mmMamba_START_DOCSTRING = r""" |
| This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the |
| library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads |
| etc.) |
| This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. |
| Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage |
| and behavior. |
| Parameters: |
| config ([`mmMambaConfig`]): |
| Model configuration class with all the parameters of the model. Initializing with a config file does not |
| load the weights associated with the model, only the configuration. Check out the |
| [`~PreTrainedModel.from_pretrained`] method to load the model weights. |
| """ |
|
|
|
|
| |
# Common base for the mmMamba models: declares the configuration class and the
# Hugging Face loading/placement flags, and provides the weight initialiser.
# (No class docstring on purpose: the decorator builds __doc__ at runtime.)
@add_start_docstrings(
    'The bare mmMamba Model outputting raw hidden-states without any specific head on top.',
    mmMamba_START_DOCSTRING,
)
class mmMambaPreTrainedModel(PreTrainedModel):
    config_class = mmMambaConfig
    base_model_prefix = 'model'
    supports_gradient_checkpointing = True
    _no_split_modules = ['mmMambaDecoderLayer']
    _skip_keys_device_placement = 'past_key_values'
    _supports_flash_attn_2 = True

    def _init_weights(self, module):
        # Linear and embedding weights are drawn from N(0, initializer_range);
        # linear biases and the padding embedding row are zeroed. Any other
        # module type is left untouched.
        init_std = self.config.initializer_range

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=init_std)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=init_std)
            pad_row = module.padding_idx
            if pad_row is not None:
                module.weight.data[pad_row].zero_()
|
|
|
|
| mmMamba_INPUTS_DOCSTRING = r""" |
| Args: |
| input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): |
| Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide |
| it. |
| Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and |
| [`PreTrainedTokenizer.__call__`] for details. |
| [What are input IDs?](../glossary#input-ids) |
| inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): |
| Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This |
| is useful if you want more control over how to convert `input_ids` indices into associated vectors than the |
| model's internal embedding lookup matrix. |
| use_cache (`bool`, *optional*): |
| If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see |
| `past_key_values`). |
| output_attentions (`bool`, *optional*): |
| Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned |
| tensors for more detail. |
| output_hidden_states (`bool`, *optional*): |
| Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for |
| more detail. |
| return_dict (`bool`, *optional*): |
| Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
| """ |
|
|
|
|
| |
@add_start_docstrings(
    'The bare mmMamba Model outputting raw hidden-states without any specific head on top.',
    mmMamba_START_DOCSTRING,
)
class mmMambaModel(mmMambaPreTrainedModel):
    """
    Decoder stack consisting of *config.num_hidden_layers* layers. Each layer is a [`mmMambaDecoderLayer`].

    Inputs are embedded with `tok_embeddings`, passed through every layer in order and finally
    normalised with an [`mmMambaRMSNorm`].

    Args:
        config: mmMambaConfig
    """

    _auto_class = 'AutoModel'

    def __init__(self, config: mmMambaConfig):
        """Build the embedding table, the decoder layers and the final norm from `config`."""
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.config = config
        self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)

        # NOTE(review): layers are numbered from 8 (layer_idx + 8), not from 0; the reason is not
        # visible in this part of the file -- confirm against mmMambaDecoderLayer before changing.
        self.layers = nn.ModuleList(
            [mmMambaDecoderLayer(config, (layer_idx + 8)) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = mmMambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False

        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.tok_embeddings

    def set_input_embeddings(self, value):
        """Replace the token embedding module with `value`."""
        self.tok_embeddings = value

    @add_start_docstrings_to_model_forward(mmMamba_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        inference_params=None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = True,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """
        Run the decoder stack on `input_ids` or on pre-computed `inputs_embeds` (exactly one of the two).

        `inference_params` is forwarded unchanged to every decoder layer. The returned
        `past_key_values` entry is always `None`; this method does not collect a cache itself.

        Raises:
            ValueError: if both or neither of `input_ids` / `inputs_embeds` are given.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.config.attn_implementation == 'flash_attention_2':
            _import_flash_attn()

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError('You cannot specify both input_ids and inputs_embeds at the same time')
        elif input_ids is not None:
            # The values are not used further; the unpacking rejects inputs with fewer than two dims.
            batch_size, seq_length = input_ids.shape[:2]
        elif inputs_embeds is not None:
            batch_size, seq_length = inputs_embeds.shape[:2]
        else:
            raise ValueError('You have to specify either input_ids or inputs_embeds')

        if inputs_embeds is None:
            inputs_embeds = self.tok_embeddings(input_ids)

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...'
                )
                use_cache = False

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers:
            if output_hidden_states:
                # Record the input of each layer; the final normalised output is added after the loop.
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # The layer is called positionally as
                        # (hidden_states, inference_params, None, output_attentions, None).
                        return module(*inputs, output_attentions, None)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    inference_params,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    inference_params=inference_params,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                # Append one entry per layer. The previous `+= layer_outputs[1]` tried to extend the
                # tuple with the raw value instead of adding it as a single element.
                # Assumes the layer returns its attention output at index 1 -- TODO confirm.
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Return a dict mapping each layer's `layer_idx` to that layer's own inference cache."""
        return {
            layer.layer_idx: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
            for layer in self.layers
        }
|
|
| |
class mmMambaForCausalLM(mmMambaPreTrainedModel):
    """
    [`mmMambaModel`] backbone followed by a bias-free linear head (`output`) that maps hidden
    states to vocabulary logits.
    """

    _auto_class = 'AutoModelForCausalLM'

    # presumably consumed by the weight-tying logic of the `transformers` base class -- verify there.
    _tied_weights_keys = ['output.weight']

    def __init__(self, config):
        """Create the backbone and the vocabulary projection from `config`."""
        super().__init__(config)
        self.model = mmMambaModel(config)
        self.vocab_size = config.vocab_size
        self.output = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.post_init()

    def get_input_embeddings(self):
        """Return the backbone's token embedding module."""
        return self.model.tok_embeddings

    def set_input_embeddings(self, value):
        """Replace the backbone's token embedding module with `value`."""
        self.model.tok_embeddings = value

    def get_output_embeddings(self):
        """Return the vocabulary projection layer."""
        return self.output

    def set_output_embeddings(self, new_embeddings):
        """Replace the vocabulary projection layer with `new_embeddings`."""
        self.output = new_embeddings

    def set_decoder(self, decoder):
        """Replace the backbone model with `decoder`."""
        self.model = decoder

    def get_decoder(self):
        """Return the backbone model."""
        return self.model

    @add_start_docstrings_to_model_forward(mmMamba_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        inference_params=None,
        num_last_tokens=0,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = True,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
            inference_params (*optional*):
                Forwarded unchanged to the backbone model.
            num_last_tokens (`int`, *optional*, defaults to 0):
                When greater than zero, logits are computed only for the last `num_last_tokens` positions.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, mmMambaForCausalLM
        >>> model = mmMambaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")
        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            inference_params=inference_params,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]

        if num_last_tokens > 0:
            # Project only the trailing positions to logits.
            hidden_states = hidden_states[:, -num_last_tokens:]

        logits = self.output(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Position t predicts token t + 1: drop the last logit and the first label.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)

            # Labels may live on another device than the logits.
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        # Hand the logits back on the device the caller's inputs came from.
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits.to(device),
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Delegate inference-cache allocation to the backbone model."""
        return self.model.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Return `past_key_values` with dim 0 of every state re-indexed by `beam_idx`."""
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past

    @torch.no_grad()
    def stream_chat(
        self,
        tokenizer,
        query: str,
        history: Optional[List[Tuple[str, str]]] = None,
        max_new_tokens: int = 1024,
        do_sample: bool = True,
        temperature: float = 0.8,
        top_p: float = 0.8,
        **kwargs,
    ):
        """
        Return a generator in format: (response, history)
        Eg.
        ('Hello, how can I help you', [('Hello', 'Hello, how can I help you')])
        ('Hello, how can I help you?', [('Hello', 'Hello, how can I help you?')])

        Generation runs through `self.chat` in a background thread; partial responses are passed
        to the returned generator through a bounded queue.

        Args:
            tokenizer: tokenizer used to decode the streamed token ids.
            query: the new user message.
            history: earlier `(query, response)` pairs; `None` (the default) means no history.
            max_new_tokens, do_sample, temperature, top_p, **kwargs: forwarded to `self.chat`.

        Raises:
            ModuleNotFoundError: if `BaseStreamer` could not be imported from `transformers`.
        """
        if BaseStreamer is None:
            raise ModuleNotFoundError(
                'The version of `transformers` is too low. Please make sure '
                'that you have installed `transformers>=4.28.0`.'
            )

        # A fresh list per call; a literal `[]` default would be shared between calls.
        if history is None:
            history = []

        response_queue = queue.Queue(maxsize=20)

        class ChatStreamer(BaseStreamer):
            """Streamer that decodes incoming tokens and pushes (response, history) pairs to the queue."""

            def __init__(self, tokenizer) -> None:
                super().__init__()
                self.tokenizer = tokenizer
                self.queue = response_queue
                self.query = query
                self.history = history
                self.response = ''
                self.cache = []
                self.received_inputs = False
                # Emit an initial item with an empty response before any token arrives.
                self.queue.put((self.response, history + [(self.query, self.response)]))

            def put(self, value):
                if len(value.shape) > 1 and value.shape[0] > 1:
                    raise ValueError('ChatStreamer only supports batch size 1')
                elif len(value.shape) > 1:
                    value = value[0]

                if not self.received_inputs:
                    # The first call is skipped -- presumably it carries the prompt ids; confirm
                    # against the `generate` streamer protocol.
                    self.received_inputs = True
                    return

                self.cache.extend(value.tolist())
                token = self.tokenizer.decode(self.cache, skip_special_tokens=True)
                if token.strip() != '<|im_end|>':
                    self.response = self.response + token
                    updated_history = self.history + [(self.query, self.response)]
                    self.queue.put((self.response, updated_history))
                    self.cache = []
                else:
                    self.end()

            def end(self):
                # `None` is the end-of-stream sentinel understood by `consumer`.
                self.queue.put(None)

        def stream_producer():
            try:
                return self.chat(
                    tokenizer=tokenizer,
                    query=query,
                    streamer=ChatStreamer(tokenizer=tokenizer),
                    history=history,
                    max_new_tokens=max_new_tokens,
                    do_sample=do_sample,
                    temperature=temperature,
                    top_p=top_p,
                    **kwargs,
                )
            finally:
                # Always deliver the sentinel, even when generation fails, so the consumer does
                # not block forever on `get()`. An extra sentinel after a normal end is harmless
                # because the consumer stops at the first one.
                response_queue.put(None)

        def consumer():
            producer = threading.Thread(target=stream_producer)
            producer.start()
            while True:
                res = response_queue.get()
                if res is None:
                    return
                yield res

        return consumer()
|
|
|
|