| import math |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from enum import Enum |
| from dataclasses import dataclass, field |
| from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states |
| from mamba_ssm.ops.triton.selective_state_update import selective_state_update |
| from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined |
|
|
| from .causal_conv1d_compilable import causal_conv1d_fn, causal_conv1d_update |
| from .ssm_compilable import mamba_chunk_scan_combined |
| from .norms import build_norm |
|
|
|
|
| class InitStdFactor(Enum): |
| DISABLED = "disabled" |
| GLOBAL_DEPTH = "global_depth" |
| CURRENT_DEPTH = "current_depth" |
| DIM_RATIO = "dim_ratio" |
|
|
|
|
| @dataclass |
| class InitConfig: |
| dt_max: float = 0.1 |
| dt_min: float = 0.001 |
|
|
| dt_init_floor: float = 1e-4 |
|
|
| A_init_min: float = 1 |
| A_init_max: float = 16 |
|
|
|
|
| DEFAULT_INIT_CONFIG = InitConfig() |
|
|
|
|
| @dataclass |
| class BaseMambaConfig: |
| """ |
| Configuration for the Mamba family of models. |
| """ |
| dim: int = 512 |
| num_layers: int = 8 |
| num_heads: int = 8 |
|
|
| state_dim: int = 128 |
| num_groups: int = 1 |
| conv_size: int | None = 4 |
|
|
| bias: bool = False |
| conv_bias: bool = True |
| dt_bias: bool = False |
| D_has_head_dim: bool = False |
| learnable_init_states: bool = False |
|
|
| ffn_dim_multiplier: float = 2.0 |
| multiple_of: int = 256 |
|
|
| norm_eps: float = 1e-6 |
| norm_type: str = "rmsnorm" |
| |
| |
| ssm_chunk_size: int = 256 |
| use_mem_eff_path: bool = False |
|
|
| |
| init_use_depth: bool = False |
| init_base_std: float | None = None |
| init_std_factor: str = "disabled" |
| init_config: InitConfig = field(default_factory=InitConfig) |
|
|
|
|
| class SSM(nn.Module): |
| """ |
| State Space Model (SSM) implementation with selective state updates and convolution. |
| |
| Implements the core SSM computation with support for both training and inference modes. |
| During inference, uses cached states for efficient token-by-token generation. |
| """ |
| def __init__(self, config: BaseMambaConfig) -> None: |
| """Initialize SSM parameters and layers. |
| Args: |
| config: Configuration containing model hyperparameters |
| """ |
| super().__init__() |
| self.config = config |
| vars(self).update(vars(config)) |
|
|
| assert self.dim > 0, "Model dimension (config.dim) must be positive" |
| assert self.num_heads > 0, "Number of heads (config.num_heads) must be positive" |
| assert self.state_dim > 0, "State dimension (config.state_dim) must be positive" |
|
|
| if self.ffn_dim_multiplier is None: |
| raise ValueError( |
| "ffn_dim_multiplier must be set to a valid float (e.g. 2.0) " |
| "to determine hidden_dim in SSM." |
| ) |
| assert self.ffn_dim_multiplier > 0, "ffn_dim_multiplier must be > 0" |
|
|
| self.hidden_dim = int(self.ffn_dim_multiplier * self.dim) |
| self.hidden_dim = config.multiple_of * ( |
| (self.hidden_dim + self.multiple_of - 1) // self.multiple_of |
| ) |
| |
| assert self.hidden_dim % self.num_heads == 0, ( |
| f"Hidden dim {self.hidden_dim} not divisible by num_heads={self.num_heads}." |
| ) |
|
|
| self.head_dim = self.hidden_dim // self.num_heads |
|
|
| self.dt_limit_kwargs = {} |
| dt_limit = (self.init_config.dt_min, self.init_config.dt_max) |
| if dt_limit != (0.0, float("inf")): |
| self.dt_limit_kwargs = dict(dt_limit=dt_limit) |
|
|
| |
| d_input = ( |
| 2 * self.hidden_dim |
| + 2 * self.num_groups * self.state_dim |
| + self.num_heads |
| ) |
|
|
| self.input = nn.Linear(self.dim, d_input, bias=self.bias) |
|
|
| |
| if self.conv_size is not None: |
| conv_dim = self.hidden_dim + 2 * self.num_groups * self.state_dim |
|
|
| |
| |
| self.conv1d = nn.Conv1d( |
| in_channels=conv_dim, |
| out_channels=conv_dim, |
| kernel_size=self.conv_size, |
| groups=conv_dim, |
| bias=self.conv_bias, |
| padding=self.conv_size - 1 |
| ) |
|
|
| if config.dt_bias: |
| self.dt_bias = nn.Parameter(torch.empty(self.num_heads)) |
| else: |
| self.dt_bias = nn.Parameter(torch.zeros(self.num_heads), requires_grad=False) |
|
|
| self.A_log = nn.Parameter(torch.empty(self.num_heads)) |
|
|
| if config.D_has_head_dim: |
| self.D = nn.Parameter(torch.ones(self.num_heads, self.head_dim)) |
| else: |
| self.D = nn.Parameter(torch.ones(self.num_heads)) |
| |
| if self.learnable_init_states: |
| self.init_states = nn.Parameter(torch.zeros(self.num_heads, self.head_dim, self.state_dim)) |
|
|
| |
| self.norm = build_norm(config.norm_type, dim=self.hidden_dim, eps=self.norm_eps) |
| |
| self.output = nn.Linear(self.hidden_dim, self.dim, bias=self.bias) |
|
|
| def _causal_conv( |
| self, |
| zxbcdt: torch.Tensor, |
| tok_idx: torch.Tensor | None = None, |
| cu_seqlens: torch.Tensor | None = None, |
| ssm_impl: str = "ssm" |
| ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
| |
| """Processes input through causal convolution path, handling both full sequence and incremental cases. |
| |
| This function implements two processing modes: |
| 1. Full sequence ("ssm"): Used during training and initial prompt processing. |
| 2. Incremental ("ssm_update"): Used during token-by-token generation. |
| |
| Args: |
| zxbcdt: Input tensor containing concatenated [z, x, B, C, dt] components |
| tok_idx: Token indices for sequence processing. Required for "ssm" mode. |
| Defaults to None. |
| cu_seqlens: Cumulative sequence lengths for variable length processing. |
| Used only in "ssm" mode with caching. Defaults to None. |
| ssm_impl: Implementation mode, either "ssm" for full sequence processing |
| or "ssm_update" for incremental generation. Defaults to "ssm". |
| |
| Returns: |
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: |
| Tuple containing separated components (z, x, B, C, dt), where: |
| - z: Gating branch |
| - x: Main branch |
| - B, C: SSM state matrices (analogous to K, Q in attention) |
| - dt: Time delta values |
| |
| Notes: |
| - When using "ssm" mode during inference, a cache should be pre-initialized |
| externally. This design allows for flexible caching strategies without |
| modifying model code. |
| - The "ssm_update" mode requires a cache to exist and will use it for |
| incremental state updates during generation. |
| - B, C components correspond to Key, Query in the SSM/attention duality. |
| """ |
| |
| z, xBC, dt = torch.split( |
| zxbcdt, |
| [ |
| self.hidden_dim, |
| self.hidden_dim + 2 * self.num_groups * self.state_dim, |
| self.num_heads, |
| ], |
| dim=-1, |
| ) |
|
|
| if ssm_impl == "ssm": |
| if hasattr(self, "cache"): |
| conv_varlen_states = causal_conv1d_varlen_states( |
| xBC.squeeze(0), |
| cu_seqlens, |
| state_len=self.cache.conv_cache.shape[-1], |
| ) |
| self.cache.conv_cache.copy_(conv_varlen_states) |
|
|
| xBC = causal_conv1d_fn( |
| x=xBC.transpose(1, 2), |
| weight=self.conv1d.weight.squeeze(1), |
| bias=self.conv1d.bias, |
| activation="silu", |
| seq_idx=tok_idx, |
| ).transpose(1, 2) |
| elif ssm_impl == "ssm_update": |
| xBC = causal_conv1d_update( |
| x=xBC.squeeze(0), |
| conv_state=self.cache.conv_cache, |
| weight=self.conv1d.weight.squeeze(1), |
| bias=self.conv1d.bias, |
| activation="silu", |
| ).unsqueeze(0) |
| else: |
| raise NotImplementedError(f"SSM implementation {ssm_impl} not supported") |
|
|
| |
| x, B, C = torch.split( |
| xBC, |
| [ |
| self.hidden_dim, |
| self.num_groups * self.state_dim, |
| self.num_groups * self.state_dim, |
| ], |
| dim=-1, |
| ) |
|
|
| return z, x, B, C, dt |
|
|
| def _non_causal_conv(self, zxbcdt: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
| z, x, B, C, dt = torch.split( |
| zxbcdt, |
| [ |
| self.hidden_dim, |
| self.hidden_dim, |
| self.num_groups * self.state_dim, |
| self.num_groups * self.state_dim, |
| self.num_heads, |
| ], |
| dim=-1, |
| ) |
| return z, x, B, C, dt |
|
|
| def _fwd(self, x, dt, A, B, C, tok_idx, cu_seqlens, initial_states): |
| """ |
| For training |
| |
| Returns: |
| (bsz, seq_len, num_heads, head_dim) |
| """ |
| y = mamba_chunk_scan_combined( |
| x, |
| dt, |
| A, |
| B, |
| C, |
| dt_bias=self.dt_bias, |
| dt_softplus=True, |
| chunk_size=self.ssm_chunk_size, |
| D=self.D, |
| z=None, |
| seq_idx=tok_idx, |
| cu_seqlens=cu_seqlens, |
| initial_states=initial_states, |
| **self.dt_limit_kwargs, |
| ) |
|
|
| if hasattr(self, "cache"): |
| y, varlen_states = y |
| self.cache.state_cache.copy_(varlen_states) |
|
|
| return y |
| |
| def _step(self, x, seq_len, dt, A, B, C): |
| """ |
| For inference / generation. |
| """ |
| x = x.squeeze(0) |
| A = A[..., None, None].expand(self.num_heads, self.head_dim, self.state_dim) |
| dt = dt.permute(1, 2, 0).expand(seq_len, self.num_heads, self.head_dim) |
| D = self.D |
| if D is not None and D.dim() == 1: |
| D = D.unsqueeze(1).expand(self.num_heads, self.head_dim) |
| B, C = B.squeeze(0), C.squeeze(0) |
| y = selective_state_update( |
| self.cache.state_cache, |
| x, |
| dt, |
| A, |
| B, |
| C, |
| D, |
| z=None, |
| dt_bias=( |
| torch.zeros(self.num_heads, self.head_dim).to(x) |
| if self.dt_bias is None |
| else self.dt_bias.unsqueeze(1).expand(self.num_heads, self.head_dim) |
| ), |
| dt_softplus=True, |
| ).unsqueeze(0) |
| |
| return y |
|
|
| def forward( |
| self, |
| x: torch.Tensor, |
| tok_idx: torch.Tensor | None = None, |
| cu_seqlens: torch.Tensor | None = None, |
| ssm_impl: str = "ssm", |
| ) -> torch.Tensor: |
| bsz, seq_len, _ = x.shape |
|
|
| zxbcdt = self.input(x) |
|
|
| A = -torch.exp(self.A_log.float()) |
| initial_states = ( |
| self.init_states.expand(bsz, -1, -1, -1) |
| if self.learnable_init_states else None |
| ) |
|
|
| |
| if self.conv_size is not None: |
| |
| |
| if self.use_mem_eff_path: |
| out = mamba_split_conv1d_scan_combined( |
| zxbcdt, |
| self.conv1d.weight.squeeze(1), |
| self.conv1d.bias, |
| self.dt_bias, |
| A, |
| D=self.D, |
| chunk_size=self.ssm_chunk_size, |
| seq_idx=tok_idx, |
| activation="silu", |
| rmsnorm_weight=self.norm.weight, |
| rmsnorm_eps=self.norm.eps, |
| outproj_weight=self.output.weight, |
| outproj_bias=self.output.bias, |
| headdim=self.head_dim, |
| ngroups=self.num_groups, |
| norm_before_gate=False, |
| initial_states=initial_states, |
| **self.dt_limit_kwargs, |
| ) |
| return out |
| else: |
| |
| z, x, B, C, dt = self._causal_conv(zxbcdt) |
| else: |
| |
| z, x, B, C, dt = self._non_causal_conv(zxbcdt) |
|
|
| x = x.view(bsz, seq_len, self.num_heads, self.head_dim) |
| B = B.view(bsz, seq_len, self.num_groups, self.state_dim) |
| C = C.view(bsz, seq_len, self.num_groups, self.state_dim) |
|
|
| |
| if ssm_impl == "ssm": |
| |
| y = self._fwd(x, dt, A, B, C, tok_idx, cu_seqlens, initial_states) |
| elif ssm_impl == "ssm_update": |
| y = self._step(x, seq_len, dt, A, B, C) |
| else: |
| raise NotImplementedError(f"SSM implementation {ssm_impl} not supported") |
|
|
| y = y.view(bsz, seq_len, self.hidden_dim) |
|
|
| |
| |
| |
| y = self.norm(y * F.silu(z)) |
| out = self.output(y) |
|
|
| return out |
|
|
| @torch.inference_mode() |
| def reset_parameters(self, init_std, factor) -> None: |
| config = self.config |
| init_config = config.init_config |
| if init_config is None: |
| init_config = DEFAULT_INIT_CONFIG |
|
|
| |
| in_init_std = init_std or (self.dim ** (-0.5)) |
| out_init_std = init_std or (self.hidden_dim ** (-0.5)) |
| out_init_std = out_init_std / factor |
|
|
| nn.init.trunc_normal_( |
| self.input.weight, |
| mean=0.0, |
| std=in_init_std, |
| a=-3 * in_init_std, |
| b=3 * in_init_std, |
| ) |
|
|
| nn.init.trunc_normal_( |
| self.output.weight, |
| mean=0.0, |
| std=out_init_std, |
| a=-3 * out_init_std, |
| b=3 * out_init_std, |
| ) |
|
|
| |
| if self.dt_bias is not None and self.dt_bias.requires_grad: |
| log_dt_min = math.log(init_config.dt_min) |
| log_dt_max = math.log(init_config.dt_max) |
| |
| |
| log_dt = torch.rand(self.num_heads, device=self.dt_bias.device) * (log_dt_max - log_dt_min) + log_dt_min |
| dt = torch.exp(log_dt) |
| dt = torch.clamp(dt, min=init_config.dt_init_floor) |
|
|
| |
| inv_dt = dt + torch.log(-torch.expm1(-dt)) |
| self.dt_bias.copy_(inv_dt) |
|
|
| elif self.dt_bias is not None: |
| |
| self.dt_bias.fill_(0.0) |
|
|
| |
| if self.conv_size is not None: |
| conv_std = init_std or (self.conv_size ** (-0.5)) |
| nn.init.trunc_normal_( |
| self.conv1d.weight, |
| mean=0.0, |
| std=conv_std, |
| a=-3 * conv_std, |
| b=3 * conv_std, |
| ) |
| if self.conv1d.bias is not None: |
| nn.init.zeros_(self.conv1d.bias) |
|
|
| |
| if self.learnable_init_states: |
| self.init_states.zero_() |
|
|
| |
| self.A_log.uniform_(init_config.A_init_min, init_config.A_init_max) |
| self.A_log.log_() |
|
|
| if self.D is not None: |
| self.D.data.fill_(1.0) |
|
|
| |
| self.norm.reset_parameters() |
|
|
|
|
| class MambaBlock(nn.Module): |
| def __init__(self, config: BaseMambaConfig): |
| super().__init__() |
| self.norm = build_norm(config.norm_type, dim=config.dim, eps=config.norm_eps) |
| self.ssm = SSM(config) |
|
|
| def forward( |
| self, |
| x: torch.Tensor, |
| tok_idx: torch.Tensor | None, |
| cu_seqlens: torch.Tensor | None, |
| ssm_impl: str = "ssm", |
| ) -> torch.Tensor: |
| x = x + self.ssm(self.norm(x), tok_idx=tok_idx, cu_seqlens=cu_seqlens, ssm_impl=ssm_impl) |
| return x |
|
|
| @torch.inference_mode() |
| def init_weights(self, init_std=None, factor=1.0): |
| self.norm.reset_parameters() |
| self.ssm.reset_parameters(init_std, factor) |
|
|
|
|
| class BaseMamba(nn.Module): |
| def __init__(self, config: BaseMambaConfig): |
| super().__init__() |
| self.model_dim = config.dim |
| self.init_base_std = config.init_base_std |
|
|
| self.init_config = config.init_config |
| self.init_std_factor = InitStdFactor(config.init_std_factor) |
|
|
| self.layers = nn.ModuleList() |
| for _ in range(config.num_layers): |
| self.layers.append(MambaBlock(config)) |
|
|
| def forward( |
| self, |
| h: torch.Tensor, |
| tok_idx: torch.Tensor | None, |
| cu_seqlens: torch.Tensor | None, |
| ssm_impl: str = "ssm", |
| ) -> torch.Tensor: |
| for layer in self.layers: |
| h = layer(h, tok_idx=tok_idx, cu_seqlens=cu_seqlens, ssm_impl=ssm_impl) |
| return h |
|
|
| @torch.inference_mode() |
| def reset_parameters(self): |
| pass |
|
|
| @torch.inference_mode() |
| def init_weights(self): |
| self.reset_parameters() |
| for depth, layer in enumerate(self.layers): |
| factor = { |
| InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5, |
| InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5, |
| InitStdFactor.DIM_RATIO: self.model_dim / 4096, |
| InitStdFactor.DISABLED: 1.0, |
| }[self.init_std_factor] |
|
|
| layer.init_weights(self.init_base_std, factor) |
|
|
|
|
| @dataclass |
| class Mamba2Config(BaseMambaConfig): |
| seed: int = 1337 |
|
|
| vocab_size: int = -1 |
| weight_tying: bool = False |
| torch_dtype: torch.dtype = torch.bfloat16 |
|
|
| loss_reduction: str = "mean" |
|
|
| use_attn: bool = False |
| softcap: float = 50.0 |
|
|
|
|
| class Mamba2(BaseMamba): |
| def __init__(self, config: Mamba2Config) -> None: |
| super().__init__(config) |
| self.weight_tying = config.weight_tying |
| self.loss_reduction = config.loss_reduction |
|
|
| assert config.vocab_size > 0, "vocab_size must be set and > 0" |
|
|
| self.tok_emb = torch.nn.Embedding(config.vocab_size, config.dim) |
|
|
| self.norm = nn.RMSNorm(config.dim, eps=config.norm_eps) |
|
|
| self.output = nn.Linear( |
| config.dim, |
| config.vocab_size, |
| bias=False, |
| ) |
|
|
| if config.weight_tying: |
| self.output.weight = self.tok_emb.weight |
|
|
| print("Model Parameter Count: %.2fM\n" % (self._get_num_params() / 1e6,)) |
|
|
| def _get_num_params(self): |
| n_params = sum(p.numel() for p in self.parameters()) |
| if hasattr(self, "pos_emb") and self.pos_emb is not None: |
| n_params -= self.pos_emb.weight.numel() |
| if self.tok_emb.weight is not self.output.weight: |
| n_params -= self.tok_emb.weight.numel() |
| return n_params |
|
|
| def forward( |
| self, |
| x: torch.Tensor, |
| target: torch.Tensor | None = None, |
| tok_idx: torch.Tensor | None = None, |
| cu_seqlens: torch.Tensor | None = None, |
| ssm_impl: str = "ssm", |
| ) -> torch.Tensor: |
| h = self.tok_emb(x) |
| h = super().forward(h, tok_idx=tok_idx, cu_seqlens=cu_seqlens, ssm_impl=ssm_impl) |
| logits = self.output(self.norm(h)) |
| return logits |
|
|
| @torch.inference_mode() |
| def reset_parameters(self, init_std=None): |
| |
| super().reset_parameters() |
| init_std = init_std or (self.model_dim ** (-0.5)) |
| self.norm.reset_parameters() |
| nn.init.trunc_normal_( |
| self.tok_emb.weight, |
| mean=0.0, |
| std=init_std, |
| a=-3 * init_std, |
| b=3 * init_std, |
| ) |
| if not self.weight_tying: |
| nn.init.trunc_normal_( |
| self.output.weight, |
| mean=0.0, |
| std=init_std, |
| a=-3 * init_std, |
| b=3 * init_std, |
| ) |
|
|
| @torch.inference_mode() |
| def init_weights(self, buffer_device: torch.device = None): |
| """ |
| Initialize model parameters and optionally compute buffers on a specific device. |
| |
| Args: |
| buffer_device (torch.device, optional): If provided, any large or precomputed |
| buffers (like RoPE frequency tensors) will be allocated or re-created on |
| this device during initialization. This can avoid overhead from transferring |
| buffers between CPU and GPU after creation. If None, buffers default to the |
| device of the first parameter or CPU. |
| |
| Usage: |
| - Pass a GPU device (e.g., ``torch.device('cuda')``) when you want to ensure |
| buffers are created directly on GPU, preventing extra transfers. |
| - Pass a CPU device (e.g., ``torch.device('cpu')``) if you want to keep |
| large buffers in CPU memory (common in CPU-offload or pipeline-parallel setups). |
| - Leave it as ``None`` to rely on the model’s existing parameter device or |
| the default PyTorch device context. |
| |
| When / Why: |
| - Useful in distributed or pipeline-parallel training where parameters may |
| initially live on CPU, but you still need certain buffers on GPU to avoid |
| overhead during forward passes. |
| - Prevents large re-allocations or re-copies when big buffers (like RoPE |
| frequency tables) are needed per rank. |
| """ |
| super().init_weights() |
|
|
| @classmethod |
| def from_model_args(cls, config: Mamba2Config) -> "Mamba2": |
| """ |
| Initialize a Mamba model from a MambaConfig object. |
| |
| Args: |
| config (MambaConfig): Mamba configuration arguments. |
| |
| Returns: |
| Mamba: Mamba-2 model. |
| """ |
| return cls(config) |
|
|
|
|
| def get_mamba2_flops( |
| seq_len: int, |
| dim: int, |
| num_layers: int, |
| vocab_size: int, |
| ffn_multiplier: float = 2.0, |
| state_dim: int = 128, |
| conv_size: int = 4, |
| num_heads: int = 8, |
| num_groups: int = 1, |
| multiple_of: int = 256, |
| include_input_embedding: bool = True, |
| include_output_logits: bool = True, |
| forward_backward_multiplier: float = 1.0, |
| ) -> int: |
| """ |
| Estimate the FLOPs for a Mamba-2 style model using a "Chinchilla-like" shape-based approach. |
| |
| By default, this returns the forward-pass cost. If you want a rough |
| forward+backward estimate, set `forward_backward_multiplier=3.0` (common |
| rule-of-thumb for these models). |
| |
| What gets counted: |
| • Hidden dimension is rounded up to 'multiple_of' = 256 (as in Mamba). |
| • Per-layer: |
| 1) Input Linear: [dim → 2*hidden_dim + 2*(groups*state_dim) + num_heads] |
| 2) Depthwise Conv1D: 2*(conv_dim * conv_size), where conv_dim=hidden_dim + 2*groups*state_dim |
| 3) SSM selective scan: ~9*(dim*state_dim) (from Mamba dev discussion) |
| 4) Output Linear: [hidden_dim → dim] |
| • Each layer’s cost is multiplied by (seq_len * num_layers). |
| • Optionally adds: |
| - The cost of the input embedding (treating it as a matmul: seq_len×vocab_size × vocab_size×dim). |
| - The cost of the final projection [dim → vocab_size]. |
| • Finally scaled by `forward_backward_multiplier` if desired. |
| |
| Args: |
| seq_len (int): Sequence length (number of tokens). |
| dim (int): Model (embedding) dimension. |
| num_layers (int): Number of Mamba layers. |
| vocab_size (int): Vocabulary size for final logits projection. |
| ffn_multiplier (float): FFN expansion ratio, e.g. 2.0 => hidden_dim=2×dim (rounded up). |
| state_dim (int): SSM state dimension (commonly 128). |
| conv_size (int): Kernel size for the depthwise conv1d (default=4). |
| num_heads (int): Number of heads (slightly affects input-lin out_dim). |
| num_groups (int): For "grouped" states in some Mamba variants (usually 1). |
| multiple_of (int): Round hidden_dim up to this multiple (commonly 256). |
| include_input_embedding (bool): If True, count the cost of an “embedding matmul” |
| for the input tokens => shape-based approach. |
| include_output_logits (bool): If True, count the cost of final [dim → vocab_size]. |
| forward_backward_multiplier (float): E.g. 1.0 for forward only, 2.0 or 3.0 for forward+backward. |
| |
| Returns: |
| int: Approximate total FLOPs (multiply-adds) for the selected pass(es), |
| as an integer. |
| """ |
| |
| flops_embedding = 0 |
| if include_input_embedding: |
| flops_embedding = 2 * (seq_len * vocab_size * dim) |
|
|
| |
| raw_hidden_dim = int(ffn_multiplier * dim) |
| hidden_dim = multiple_of * ((raw_hidden_dim + multiple_of - 1) // multiple_of) |
|
|
| |
| out_dim_input = 2*hidden_dim + 2*(num_groups*state_dim) + num_heads |
| flops_input_linear = 2 * (dim * out_dim_input) |
| conv_dim = hidden_dim + 2*(num_groups*state_dim) |
| flops_conv = 2 * (conv_dim * conv_size) |
| flops_ssm = 9 * state_dim * dim |
| flops_output_linear = 2 * (hidden_dim * dim) |
| flops_layer = (flops_input_linear + flops_conv + flops_ssm + flops_output_linear) |
|
|
| |
| flops_layers = flops_layer * num_layers * seq_len |
|
|
| |
| flops_vocab = 0 |
| if include_output_logits: |
| flops_vocab = 2 * (seq_len * dim * vocab_size) |
|
|
| |
| flops_forward = flops_embedding + flops_layers + flops_vocab |
|
|
| |
| return int(flops_forward * forward_backward_multiplier) |
|
|
| def get_mamba2_flops_per_token( |
| **kwargs |
| ) -> float: |
| """ |
| Estimate FLOPs per token for a Mamba-2 style model. |
| |
| This function extracts necessary parameters from kwargs and calculates the FLOPs per token. |
| |
| Args: |
| **kwargs: Dictionary containing model configuration parameters. |
| |
| Returns: |
| float: Approximate FLOPs per token. |
| """ |
| defaults = { |
| 'ffn_dim_multiplier': 2.0, |
| 'state_dim': 128, |
| 'conv_size': 4, |
| 'num_heads': 8, |
| 'num_groups': 1, |
| 'multiple_of': 256, |
| 'include_input_embedding': True, |
| 'include_output_logits': True, |
| 'forward_backward_multiplier': 1.0, |
| } |
| |
| for k, v in defaults.items(): |
| kwargs.setdefault(k, v) |
| |
| for required in ['seq_len', 'dim', 'num_layers', 'vocab_size']: |
| if required not in kwargs: |
| raise ValueError(f"Missing required parameter: {required}") |
|
|
| total_flops = get_mamba2_flops( |
| seq_len=kwargs['seq_len'], |
| dim=kwargs['dim'], |
| num_layers=kwargs['num_layers'], |
| vocab_size=kwargs['vocab_size'], |
| ffn_multiplier=kwargs['ffn_dim_multiplier'], |
| state_dim=kwargs['state_dim'], |
| conv_size=kwargs['conv_size'], |
| num_heads=kwargs['num_heads'], |
| num_groups=kwargs['num_groups'], |
| multiple_of=kwargs['multiple_of'], |
| include_input_embedding=kwargs['include_input_embedding'], |
| include_output_logits=kwargs['include_output_logits'], |
| forward_backward_multiplier=kwargs['forward_backward_multiplier'], |
| ) |
| flops_per_token = total_flops / kwargs['seq_len'] |
|
|
| return flops_per_token |
|
|
|
|
| |
| def get_no_recompute_ops(): |
| return { |
| torch.ops.aten.mm.default, |
| torch.ops.aten._scaled_mm.default, |
| torch.ops.c10d_functional.reduce_scatter_tensor.default, |
| torch.ops.mamba_ssm.ssm_chunk_scan_combined_fwd.default, |
|
|
| |
| torch.ops.aten.abs.default, |
| torch.ops.aten.max.default, |
| } |
|
|
|
|
| def main(): |
| from mamba_ssm import Mamba2 as MambaRef |
|
|
| x = torch.randn(2, 64, 192).cuda() |
|
|
| |
| model = MambaRef( |
| d_model=192, |
| expand=2, |
| d_conv=4, |
| d_state=64, |
| headdim=48, |
| ).cuda() |
| y = model(x) |
| print("Mamba reference output: ", y) |
| print("Mean of MambaRef output: ", y.mean().item()) |
| print("Stddev of MambaRef output: ", y.std().item()) |
|
|
| |
| config = Mamba2Config(vocab_size=200064, use_mem_eff_path=True) |
| model2 = Mamba2( |
| config=config, |
| ).cuda() |
|
|
| |
| x_indices = torch.randint(0, config.vocab_size, (2, 64), dtype=torch.long).cuda() |
|
|
| y2 = model2(x_indices) |
| print("Mamba output: ", y2) |
| print("Mean of Mamba output: ", y2.mean().item()) |
| print("Stddev of Mamba output: ", y2.std().item()) |
|
|
| if __name__ == "__main__": |
| main() |
|
|
|
|