| from typing import List, Optional, Tuple |
| import torch |
|
|
| from mamba_ssm.ops.triton.ssd_combined import _mamba_chunk_scan_combined_fwd, _mamba_chunk_scan_combined_bwd |
|
|
|
|
@torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
def _compiled_mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, cu_seqlens=None, dt_softplus=False, dt_limit=None):
    """torch.compile'd pass-through to mamba_ssm's ``_mamba_chunk_scan_combined_fwd``.

    ``dt_limit=None`` is normalized to ``(0.0, float("inf"))``, mamba_ssm's own
    default, because the upstream helpers index ``dt_limit[0]`` / ``dt_limit[1]``
    and cannot accept ``None``.
    """
    if dt_limit is None:
        dt_limit = (0.0, float("inf"))
    return _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, cu_seqlens=cu_seqlens, dt_softplus=dt_softplus, dt_limit=dt_limit)
|
|
@torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
def _compiled_mamba_chunk_scan_combined_bwd(dout, x, dt, A, B, C, out, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, dfinal_states=None, seq_idx=None, dt_softplus=False, dt_limit=None):
    """torch.compile'd pass-through to mamba_ssm's ``_mamba_chunk_scan_combined_bwd``.

    ``dt_limit=None`` is normalized to ``(0.0, float("inf"))``, mamba_ssm's own
    default, because the upstream helpers index ``dt_limit[0]`` / ``dt_limit[1]``
    and cannot accept ``None``.
    """
    if dt_limit is None:
        dt_limit = (0.0, float("inf"))
    return _mamba_chunk_scan_combined_bwd(dout, x, dt, A, B, C, out, chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, dfinal_states=dfinal_states, seq_idx=seq_idx, dt_softplus=dt_softplus, dt_limit=dt_limit)
|
|
|
|
@torch.library.custom_op(
    "mamba_ssm::ssm_chunk_scan_combined_fwd",
    mutates_args=(),
    device_types="cuda",
)
def ssm_chunk_scan_combined_fwd(
    x: torch.Tensor,
    dt: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    chunk_size: int,
    D: Optional[torch.Tensor] = None,
    z: Optional[torch.Tensor] = None,
    dt_bias: Optional[torch.Tensor] = None,
    initial_states: Optional[torch.Tensor] = None,
    seq_idx: Optional[torch.Tensor] = None,
    cu_seqlens: Optional[torch.Tensor] = None,
    dt_softplus: bool = False,
    dt_limit: Optional[List[float]] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Opaque (to torch.compile) custom op running the mamba_ssm SSD forward.

    Returns:
        ``(out, out_x, varlen_states)``. The schema declares three tensors, so
        outputs that do not apply are returned as 0-element tensors instead of
        None: ``out_x`` when the upstream forward yields None for it, and
        ``varlen_states`` when ``cu_seqlens`` is None.

    ``dt_limit=None`` is normalized to ``(0.0, float("inf"))`` (mamba_ssm's
    default); the upstream helpers index it and cannot accept None.
    """
    if dt_limit is None:
        dt_limit = (0.0, float("inf"))
    out, out_x, _dt_out, _dA_cumsum, _states, _final_states, *rest = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, cu_seqlens=cu_seqlens, dt_softplus=dt_softplus, dt_limit=dt_limit)

    if out_x is None:
        out_x = out.new_empty(0)
    # The extra trailing element (varlen states) only exists when cu_seqlens is given.
    varlen_states = rest[0] if cu_seqlens is not None else out.new_empty(0)
    return out, out_x, varlen_states
|
|
@ssm_chunk_scan_combined_fwd.register_fake
def _ssm_chunk_scan_combined_fwd_fake(
    x: torch.Tensor,
    dt: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    chunk_size: int,
    D: Optional[torch.Tensor] = None,
    z: Optional[torch.Tensor] = None,
    dt_bias: Optional[torch.Tensor] = None,
    initial_states: Optional[torch.Tensor] = None,
    seq_idx: Optional[torch.Tensor] = None,
    cu_seqlens: Optional[torch.Tensor] = None,
    dt_softplus: bool = False,
    dt_limit: Optional[List[float]] = None
):
    """Metadata-only (fake) implementation used by torch.compile tracing.

    Must agree with the real op, which always returns three tensors. When z is
    absent, ``out_x`` is a 0-element placeholder. When ``cu_seqlens`` is absent,
    ``varlen_states`` is a 0-element placeholder; otherwise it is
    ``(num_sequences, nheads, headdim, dstate)``.
    """
    _, _, n_heads, head_dim = x.shape
    # new_empty(shape) yields a contiguous tensor; empty_like would copy x's
    # strides. The real outputs are assumed to be freshly allocated contiguous
    # tensors (verify against mamba_ssm).
    out = x.new_empty(x.shape)
    out_x = x.new_empty(x.shape) if z is not None else x.new_empty(0)
    if cu_seqlens is not None:
        # B is (batch, seqlen, ngroups, dstate): dstate is the LAST dim, not dim 0.
        d_state = B.size(-1)
        varlen_states = x.new_empty((cu_seqlens.size(0) - 1, n_heads, head_dim, d_state))
    else:
        varlen_states = x.new_empty(0)
    return out, out_x, varlen_states
|
|
@torch.library.custom_op(
    "mamba_ssm::ssm_chunk_scan_combined_bwd",
    mutates_args=(),
    device_types="cuda",
)
def ssm_chunk_scan_combined_bwd(
    dout: torch.Tensor,
    x: torch.Tensor,
    dt: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    out: torch.Tensor,
    chunk_size: int,
    D: Optional[torch.Tensor] = None,
    z: Optional[torch.Tensor] = None,
    dt_bias: Optional[torch.Tensor] = None,
    initial_states: Optional[torch.Tensor] = None,
    seq_idx: Optional[torch.Tensor] = None,
    dt_softplus: bool = False,
    dt_limit: Optional[List[float]] = None
)-> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Opaque custom op running the mamba_ssm SSD backward.

    Returns:
        ``(dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states)``. Gradients
        for optional inputs that were not supplied are returned as 0-element
        tensors, since the schema declares nine tensors and cannot hold None.

    ``dfinal_states`` is always None here. ``dt_limit=None`` is normalized to
    ``(0.0, float("inf"))`` (mamba_ssm's default) because the upstream helpers
    index it.
    """
    if dt_limit is None:
        dt_limit = (0.0, float("inf"))
    dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states = _mamba_chunk_scan_combined_bwd(dout, x, dt, A, B, C, out, chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, dfinal_states=None, seq_idx=seq_idx, dt_softplus=dt_softplus, dt_limit=dt_limit)

    def _or_empty(grad):
        # Fresh allocation per placeholder so no two outputs share storage.
        return grad if grad is not None else dx.new_empty(0)

    return (
        dx,
        ddt,
        dA,
        dB,
        dC,
        _or_empty(dD),
        _or_empty(dz),
        _or_empty(ddt_bias),
        _or_empty(dinitial_states),
    )
|
|
@ssm_chunk_scan_combined_bwd.register_fake
def _ssm_chunk_scan_combined_bwd_fake(
    dout: torch.Tensor,
    x: torch.Tensor,
    dt: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    out: torch.Tensor,
    chunk_size: int,
    D: Optional[torch.Tensor] = None,
    z: Optional[torch.Tensor] = None,
    dt_bias: Optional[torch.Tensor] = None,
    initial_states: Optional[torch.Tensor] = None,
    seq_idx: Optional[torch.Tensor] = None,
    dt_softplus: bool = False,
    dt_limit: Optional[List[float]] = None
):
    """Metadata-only (fake) implementation of the backward op for tracing.

    Mirrors the real op: each gradient matches its input's metadata. Absent
    optional inputs get a 0-element placeholder tensor (with x's dtype and
    device, matching the real op's ``dx.new_empty(0)``) rather than None.
    """
    def _like_or_empty(t):
        return torch.empty_like(t) if t is not None else x.new_empty(0)

    return (
        torch.empty_like(x),
        torch.empty_like(dt),
        torch.empty_like(A),
        torch.empty_like(B),
        torch.empty_like(C),
        _like_or_empty(D),
        _like_or_empty(z),
        _like_or_empty(dt_bias),
        _like_or_empty(initial_states),
    )
|
|
|
|
def ssm_chunk_scan_combined_setup_context(ctx, inputs, output):
    """Record what the backward bridge needs from one forward call.

    ``inputs`` holds the forward op's 14 arguments in declaration order, and
    ``output`` is its ``(out, out_x, varlen_states)`` triple. Saved tensors are
    ``(out-or-out_x, x, dt, A, B, C, D, z, dt_bias, initial_states, seq_idx)``.
    """
    (x, dt, A, B, C, chunk_size, D, z, dt_bias,
     initial_states, seq_idx, _cu_seqlens, dt_softplus, dt_limit) = inputs
    out, out_x, _ = output

    # The backward receives out_x when z was supplied, and the plain output otherwise.
    saved_out = out_x if z is not None else out

    ctx.chunk_size = chunk_size
    ctx.dt_softplus = dt_softplus
    ctx.dt_limit = dt_limit
    ctx.save_for_backward(saved_out, x, dt, A, B, C, D, z, dt_bias, initial_states, seq_idx)
|
|
def ssm_chunk_scan_combined_bridge(ctx, dout, dout_x, dout_state_varlen):
    """Backward formula registered for ``mamba_ssm::ssm_chunk_scan_combined_fwd``.

    Only the gradient of ``out`` (``dout``) is propagated; gradients arriving for
    ``out_x`` and the varlen states are ignored. Returns one entry per forward
    input (14 in total): None for ``chunk_size``, ``seq_idx``, ``cu_seqlens``,
    ``dt_softplus``, ``dt_limit``, and for any optional tensor input that was
    not supplied.
    """
    (saved_out, x, dt, A, B, C, D, z, dt_bias,
     initial_states, seq_idx) = ctx.saved_tensors

    grads = ssm_chunk_scan_combined_bwd(
        dout, x, dt, A, B, C, saved_out, ctx.chunk_size,
        D=D, z=z, dt_bias=dt_bias, initial_states=initial_states,
        seq_idx=seq_idx, dt_softplus=ctx.dt_softplus, dt_limit=ctx.dt_limit,
    )
    dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states = grads

    def _only_if_given(grad, given):
        # The backward op returns 0-element placeholders for absent optional
        # inputs; autograd expects None for those.
        return None if given is None else grad

    return (
        dx,
        ddt,
        dA,
        dB,
        dC,
        None,  # chunk_size
        _only_if_given(dD, D),
        _only_if_given(dz, z),
        _only_if_given(ddt_bias, dt_bias),
        _only_if_given(dinitial_states, initial_states),
        None,  # seq_idx
        None,  # cu_seqlens
        None,  # dt_softplus
        None,  # dt_limit
    )
|
|
| |
# Attach the backward bridge and its context-setup hook to the forward custom
# op ("mamba_ssm::ssm_chunk_scan_combined_fwd") so autograd can differentiate it.
ssm_chunk_scan_combined_fwd.register_autograd(
    ssm_chunk_scan_combined_bridge,
    setup_context=ssm_chunk_scan_combined_setup_context,
)
|
|
def mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, cu_seqlens=None, dt_softplus=False, dt_limit=(0.0, float("inf"))):
    """
    Chunked SSD scan routed through the ``mamba_ssm::ssm_chunk_scan_combined_fwd``
    custom op, so it can be captured by ``torch.compile`` and differentiated via
    the registered autograd bridge.

    Argument:
        x: (batch, seqlen, nheads, headdim)
        dt: (batch, seqlen, nheads)
        A: (nheads)
        B: (batch, seqlen, ngroups, dstate)
        C: (batch, seqlen, ngroups, dstate)
        chunk_size: int
        D: (nheads, headdim) or (nheads,)
        z: (batch, seqlen, nheads, headdim)
        dt_bias: (nheads,)
        initial_states: (batch, nheads, headdim, dstate)
        seq_idx: (batch, seqlen)
        cu_seqlens: (num_sequences + 1) or None. Upstream mamba_ssm presumably
            requires batch == 1 in this case (verify).
        dt_softplus: Whether to apply softplus to dt
        dt_limit: (min, max) pair forwarded unchanged to the kernels as dt limits.
    Return:
        out: (batch, seqlen, nheads, headdim), when cu_seqlens is None.
        (out, varlen_states), when cu_seqlens is given. varlen_states is
            (num_sequences, nheads, headdim, dstate). No gradient is propagated
            through varlen_states.
    """
    out, _, varlen_states = ssm_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, cu_seqlens=cu_seqlens, dt_softplus=dt_softplus, dt_limit=dt_limit)
    if cu_seqlens is not None:
        return out, varlen_states
    return out
|
|
if __name__ == "__main__":
    # Smoke test (requires a CUDA device and mamba_ssm). Runs the custom-op
    # wrapper eagerly and under torch.compile, then checks both against the
    # reference mamba_ssm implementation instead of only printing statistics.
    from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined as mamba_chunk_scan_combined_ref

    torch.manual_seed(0)
    torch.cuda.manual_seed(0)

    # Drawn on the CPU generator and then moved, so the values depend only on
    # torch.manual_seed.
    x = torch.randn(2, 3, 4, 5).cuda()
    dt = torch.randn(2, 3, 4).cuda()
    A = torch.randn(4).cuda()
    B = torch.randn(2, 3, 4, 5).cuda()
    C = torch.randn(2, 3, 4, 5).cuda()
    chunk_size = 2
    D = torch.randn(4, 5).cuda()
    z = torch.randn(2, 3, 4, 5).cuda()
    dt_bias = torch.randn(4).cuda()
    optional_kwargs = dict(D=D, z=z, dt_bias=dt_bias)

    def _summary(t):
        return t.min(), t.max(), t.mean(), t.std()

    out = mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, **optional_kwargs)
    print(*_summary(out))

    compiled_mamba_chunk_scan_combined = torch.compile(mamba_chunk_scan_combined)
    out_compiled = compiled_mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, **optional_kwargs)
    print(*_summary(out_compiled))

    out_ref = mamba_chunk_scan_combined_ref(x, dt, A, B, C, chunk_size, **optional_kwargs)
    print(*_summary(out_ref))

    torch.testing.assert_close(out, out_ref)
    torch.testing.assert_close(out_compiled, out_ref)
    print("eager and compiled outputs match the reference implementation")
|
|
|
|