diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000000000000000000000000000000000000..0cd58331b2a989b68be4ec5676383437fca8687b
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,36 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.so filter=lfs diff=lfs merge=lfs -text
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..8e2bf2d587bd92c011ce44896ead4fc9bbc68c7f
--- /dev/null
+++ b/README.md
@@ -0,0 +1,22 @@
+---
+license: bsd-3-clause
+tags:
+ - kernels
+---
+
+## causal-conv1d
+
+Causal [depthwise conv1d kernel](https://github.com/Dao-AILab/causal-conv1d/) by Tri Dao.
+
+Kernel source: https://github.com/huggingface/kernels-community/tree/main/causal-conv1d
+
+### Performance
+
+
+
+
+
+
+
+
+
diff --git a/benchmarks/benchmark.py b/benchmarks/benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1e59c3b6fb91250556ddf55a423b33aee0af2c5
--- /dev/null
+++ b/benchmarks/benchmark.py
@@ -0,0 +1,92 @@
+import torch
+import torch.nn.functional as F
+
+from kernels.benchmark import Benchmark
+
+
+class CausalConv1dBenchmark(Benchmark):
+ seed: int = 42
+
+ def setup(self):
+ batch_size, dim, seqlen, width = 2, 64, 128, 4
+ self.x = torch.randn(
+ batch_size, dim, seqlen, device=self.device, dtype=torch.float16
+ )
+ self.weight = torch.randn(dim, width, device=self.device, dtype=torch.float32)
+ self.bias = torch.randn(dim, device=self.device, dtype=torch.float32)
+ self.out = torch.empty(
+ batch_size, dim, seqlen, device=self.device, dtype=torch.float16
+ )
+ self.dim = dim
+ self.width = width
+ self.seqlen = seqlen
+
+ def benchmark_base(self):
+ self.out = self.kernel.causal_conv1d_fn(self.x, self.weight, self.bias)
+
+ def verify_base(self) -> torch.Tensor:
+ x_fp32 = self.x.to(self.weight.dtype)
+ out = F.conv1d(
+ x_fp32,
+ self.weight.unsqueeze(1),
+ self.bias,
+ padding=self.width - 1,
+ groups=self.dim,
+ )
+ return out[..., : self.seqlen].to(self.x.dtype)
+
+ def setup_large(self):
+ batch_size, dim, seqlen, width = 8, 256, 512, 4
+ self.x = torch.randn(
+ batch_size, dim, seqlen, device=self.device, dtype=torch.float16
+ )
+ self.weight = torch.randn(dim, width, device=self.device, dtype=torch.float32)
+ self.bias = torch.randn(dim, device=self.device, dtype=torch.float32)
+ self.out = torch.empty(
+ batch_size, dim, seqlen, device=self.device, dtype=torch.float16
+ )
+ self.dim = dim
+ self.width = width
+ self.seqlen = seqlen
+
+ def benchmark_large(self):
+ self.out = self.kernel.causal_conv1d_fn(self.x, self.weight, self.bias)
+
+ def verify_large(self) -> torch.Tensor:
+ x_fp32 = self.x.to(self.weight.dtype)
+ out = F.conv1d(
+ x_fp32,
+ self.weight.unsqueeze(1),
+ self.bias,
+ padding=self.width - 1,
+ groups=self.dim,
+ )
+ return out[..., : self.seqlen].to(self.x.dtype)
+
+ def setup_xlarge(self):
+ batch_size, dim, seqlen, width = 16, 512, 1024, 4
+ self.x = torch.randn(
+ batch_size, dim, seqlen, device=self.device, dtype=torch.float16
+ )
+ self.weight = torch.randn(dim, width, device=self.device, dtype=torch.float32)
+ self.bias = torch.randn(dim, device=self.device, dtype=torch.float32)
+ self.out = torch.empty(
+ batch_size, dim, seqlen, device=self.device, dtype=torch.float16
+ )
+ self.dim = dim
+ self.width = width
+ self.seqlen = seqlen
+
+ def benchmark_xlarge(self):
+ self.out = self.kernel.causal_conv1d_fn(self.x, self.weight, self.bias)
+
+ def verify_xlarge(self) -> torch.Tensor:
+ x_fp32 = self.x.to(self.weight.dtype)
+ out = F.conv1d(
+ x_fp32,
+ self.weight.unsqueeze(1),
+ self.bias,
+ padding=self.width - 1,
+ groups=self.dim,
+ )
+ return out[..., : self.seqlen].to(self.x.dtype)
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/__init__.py b/build/torch210-cxx11-cu126-aarch64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/__init__.py
@@ -0,0 +1,4 @@
+from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
+from .causal_conv1d_varlen import causal_conv1d_varlen_states
+
+__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"]
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/_causal_conv1d_cuda_6b83b83.abi3.so b/build/torch210-cxx11-cu126-aarch64-linux/_causal_conv1d_cuda_6b83b83.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..bfba996f987a4a91f20a5908b13fd9d83ac93b7b
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/_causal_conv1d_cuda_6b83b83.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:83b8ab4db3d387552329f75f775db33b59380b18ba2af057504ad810fab09295
+size 80857232
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/_ops.py b/build/torch210-cxx11-cu126-aarch64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..6796899661ef6f73609047ca344503d13ca050bd
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _causal_conv1d_cuda_6b83b83
+ops = torch.ops._causal_conv1d_cuda_6b83b83
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_causal_conv1d_cuda_6b83b83::{op_name}"
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/causal_conv1d/__init__.py b/build/torch210-cxx11-cu126-aarch64-linux/causal_conv1d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/causal_conv1d/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import importlib.util
+import sys
+from pathlib import Path
+from types import ModuleType
+
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/causal_conv1d_interface.py b/build/torch210-cxx11-cu126-aarch64-linux/causal_conv1d_interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/causal_conv1d_interface.py
@@ -0,0 +1,242 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+import torch.nn.functional as F
+
+from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function
+
+
+class CausalConv1dFn(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ x,
+ weight,
+ bias=None,
+ seq_idx=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+ ):
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ if x.stride(2) != 1 and x.stride(1) != 1:
+ x = x.contiguous()
+ bias = bias.contiguous() if bias is not None else None
+ if seq_idx is not None:
+ assert (
+ initial_states is None
+ ), "initial_states must be None if seq_idx is not None"
+ assert (
+ not return_final_states
+ ), "If seq_idx is not None, we don't return final_states_out"
+ seq_idx = seq_idx.contiguous() if seq_idx is not None else None
+ if initial_states is not None and (
+ initial_states.stride(2) != 1 and initial_states.stride(1) != 1
+ ):
+ initial_states = initial_states.contiguous()
+ if return_final_states:
+ assert (
+ x.stride(1) == 1
+ ), "Only channel-last layout support returning final_states_out"
+ if final_states_out is not None:
+ assert (
+ final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1
+ )
+ else:
+ batch, dim, seqlen = x.shape
+ width = weight.shape[1]
+ final_states_out = torch.empty(
+ batch, width - 1, dim, device=x.device, dtype=x.dtype
+ ).transpose(1, 2)
+ else:
+ final_states_out = None
+ ctx.activation = activation in ["silu", "swish"]
+ out = causal_conv1d_fwd_function(
+ x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
+ )
+ ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
+ ctx.return_final_states = return_final_states
+ ctx.return_dinitial_states = (
+ initial_states is not None and initial_states.requires_grad
+ )
+ return out if not return_final_states else (out, final_states_out)
+
+ @staticmethod
+ def backward(ctx, dout, *args):
+ x, weight, bias, seq_idx, initial_states = ctx.saved_tensors
+ dfinal_states = args[0] if ctx.return_final_states else None
+ if dout.stride(2) != 1 and dout.stride(1) != 1:
+ dout = dout.contiguous()
+ # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
+ # backward of conv1d with the backward of chunk).
+ # Here we just pass in None and dx will be allocated in the C++ code.
+ dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function(
+ x,
+ weight,
+ bias,
+ dout,
+ seq_idx,
+ initial_states,
+ dfinal_states,
+ None,
+ ctx.return_dinitial_states,
+ ctx.activation,
+ )
+ return (
+ dx,
+ dweight,
+ dbias if bias is not None else None,
+ None,
+ dinitial_states if initial_states is not None else None,
+ None,
+ None,
+ None,
+ )
+
+
+def causal_conv1d_fn(
+ x,
+ weight,
+ bias=None,
+ seq_idx=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+):
+ """
+ x: (batch, dim, seqlen)
+ weight: (dim, width)
+ bias: (dim,)
+ seq_idx: (batch, seqlen)
+ initial_states: (batch, dim, width - 1)
+ final_states_out: (batch, dim, width - 1), to be written to
+ activation: either None or "silu" or "swish"
+
+ out: (batch, dim, seqlen)
+ """
+ return CausalConv1dFn.apply(
+ x,
+ weight,
+ bias,
+ seq_idx,
+ initial_states,
+ return_final_states,
+ final_states_out,
+ activation,
+ )
+
+
+def causal_conv1d_ref(
+ x,
+ weight,
+ bias=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+):
+ """
+ x: (batch, dim, seqlen)
+ weight: (dim, width)
+ bias: (dim,)
+ initial_states: (batch, dim, width - 1)
+ final_states_out: (batch, dim, width - 1)
+
+ out: (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ dtype_in = x.dtype
+ x = x.to(weight.dtype)
+ seqlen = x.shape[-1]
+ dim, width = weight.shape
+ if initial_states is None:
+ out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
+ else:
+ x = torch.cat([initial_states, x], dim=-1)
+ out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
+ out = out[..., :seqlen]
+ if return_final_states:
+ final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
+ dtype_in
+ ) # (batch, dim, width - 1)
+ if final_states_out is not None:
+ final_states_out.copy_(final_states)
+ else:
+ final_states_out = final_states
+ out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
+ return out if not return_final_states else (out, final_states_out)
+
+
+def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None):
+ """
+ x: (batch, dim) or (batch, dim, seqlen)
+ conv_state: (batch, dim, state_len), where state_len >= width - 1
+ weight: (dim, width)
+ bias: (dim,)
+ cache_seqlens: (batch,), dtype int32.
+ If not None, the conv_state is treated as a circular buffer.
+ The conv_state will be updated by copying x to the conv_state starting at the index
+ @cache_seqlens % state_len.
+ conv_state_indices: (batch,), dtype int32
+        If not None, the conv_state is a larger tensor along the batch dim,
+ and we are selecting the batch coords specified by conv_state_indices.
+ Useful for a continuous batching scenario.
+
+ out: (batch, dim) or (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ activation = activation in ["silu", "swish"]
+ unsqueeze = x.dim() == 2
+ if unsqueeze:
+ x = x.unsqueeze(-1)
+ out = causal_conv1d_update_function(
+ x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices
+ )
+ if unsqueeze:
+ out = out.squeeze(-1)
+ return out
+
+
+def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
+ """
+ x: (batch, dim) or (batch, dim, seqlen)
+ conv_state: (batch, dim, state_len), where state_len >= width - 1
+ weight: (dim, width)
+ bias: (dim,)
+ cache_seqlens: (batch,), dtype int32.
+ If not None, the conv_state is treated as a circular buffer.
+ The conv_state will be updated by copying x to the conv_state starting at the index
+ @cache_seqlens % state_len before performing the convolution.
+
+ out: (batch, dim) or (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ dtype_in = x.dtype
+ unsqueeze = x.dim() == 2
+ if unsqueeze:
+ x = x.unsqueeze(-1)
+ batch, dim, seqlen = x.shape
+ width = weight.shape[1]
+ state_len = conv_state.shape[-1]
+ assert conv_state.shape == (batch, dim, state_len)
+ assert weight.shape == (dim, width)
+ if cache_seqlens is None:
+ x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen)
+ conv_state.copy_(x_new[:, :, -state_len:])
+ else:
+ width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
+ width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
+ x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
+ copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
+ copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
+ conv_state.scatter_(2, copy_idx, x)
+ out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:]
+ if unsqueeze:
+ out = out.squeeze(-1)
+ return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/causal_conv1d_varlen.py b/build/torch210-cxx11-cu126-aarch64-linux/causal_conv1d_varlen.py
new file mode 100644
index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/causal_conv1d_varlen.py
@@ -0,0 +1,86 @@
+import torch
+from torch import Tensor
+
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _causal_conv1d_varlen_states(
+ X,
+ CU_SEQLENS,
+ STATES,
+ state_len,
+ dim,
+ stride_x_seqlen, stride_x_dim,
+ stride_states_batch, stride_states_seqlen, stride_states_dim,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr
+):
+ batch_idx = tl.program_id(2)
+ STATES += batch_idx * stride_states_batch
+ end_idx = tl.load(CU_SEQLENS + batch_idx + 1)
+ start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)
+ rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
+ cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
+ x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,
+ mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim),
+ other=0)
+ rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
+ tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,
+ x,
+ mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim))
+
+
+def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
+ """
+ Forward pass only, does not support backward pass.
+ Parameters:
+ x: (total_tokens, dim)
+ cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
+ state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
+ If some of those elements belong to a different sequence, the value of the states will be zero.
+ Return:
+ states: (batch, dim, state_len)
+ """
+ _, dim = x.shape
+ batch = cu_seqlens.shape[0] - 1
+ cu_seqlens = cu_seqlens.contiguous()
+ states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
+ BLOCK_M = min(triton.next_power_of_2(state_len), 16)
+ BLOCK_N = min(triton.next_power_of_2(dim), 256)
+ grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch)
+ with torch.cuda.device(x.device.index):
+ _causal_conv1d_varlen_states[grid](
+ x,
+ cu_seqlens,
+ states,
+ state_len,
+ dim,
+ x.stride(0), x.stride(1),
+ states.stride(0), states.stride(2), states.stride(1),
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
+ )
+ return states
+
+
+def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
+ """
+ Forward pass only, does not support backward pass.
+ Parameters:
+ x: (total_tokens, dim)
+ cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
+ state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
+ If some of those elements belong to a different sequence, the value of the states will be zero.
+ Return:
+ states: (batch, dim, state_len)
+ """
+ _, dim = x.shape
+ batch = cu_seqlens.shape[0] - 1
+ cu_seqlens = cu_seqlens.contiguous()
+ states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
+ for i in range(batch):
+ end_idx = cu_seqlens[i + 1]
+ start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len)
+ states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T
+ return states
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/cpp_functions.py b/build/torch210-cxx11-cu126-aarch64-linux/cpp_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/cpp_functions.py
@@ -0,0 +1,96 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+
+from ._ops import ops
+
+def causal_conv1d_fwd_function(
+ x: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None,
+ seq_idx: torch.Tensor | None,
+ initial_states: torch.Tensor | None,
+ final_states_out: torch.Tensor | None,
+ silu_activation: bool,
+) -> torch.Tensor:
+ out = torch.empty_like(x)
+ ops.causal_conv1d_fwd(
+ x=x,
+ weight=weight,
+ bias=bias,
+ seq_idx=seq_idx,
+ initial_states=initial_states,
+ out=out,
+ final_states_out=final_states_out,
+ silu_activation=silu_activation,
+ )
+ return out
+
+
+def causal_conv1d_bwd_function(
+ x: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None,
+ dout: torch.Tensor,
+ seq_idx: torch.Tensor | None,
+ initial_states: torch.Tensor | None,
+ dfinal_states: torch.Tensor | None,
+ dx: torch.Tensor | None,
+ return_dinitial_states: torch.Tensor,
+ silu_activation: bool,
+) -> tuple[torch.Tensor | None]:
+ batch_size, dim = x.size()[:2]
+ width = weight.size(-1)
+
+ if dx is None:
+ dx = torch.empty_like(x)
+ dweight = torch.zeros_like(weight, dtype=torch.float32)
+ dbias = None
+ if bias is not None:
+ dbias = torch.zeros_like(bias, dtype=torch.float32)
+ dinitial_states = None
+ if return_dinitial_states:
+ dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2)
+
+ ops.causal_conv1d_bwd(
+ x=x,
+ weight=weight,
+ bias=bias,
+ dout=dout,
+ seq_idx=seq_idx,
+ initial_states=initial_states,
+ dfinal_states=dfinal_states,
+ dx=dx,
+ dweight=dweight,
+ dbias=dbias,
+ dinitial_states=dinitial_states,
+ silu_activation=silu_activation,
+ )
+
+ dweight = dweight.type_as(weight)
+ if dbias is not None:
+ dbias = dbias.type_as(bias)
+ return dx, dweight, dbias, dinitial_states
+
+
+def causal_conv1d_update_function(
+ x: torch.Tensor,
+ conv_state: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None,
+ silu_activation: bool,
+ cache_seqlens: torch.Tensor | None,
+ conv_state_indices: torch.Tensor | None,
+) -> torch.Tensor:
+ out = torch.empty_like(x)
+ ops.causal_conv1d_update(
+ x=x,
+ conv_state=conv_state,
+ weight=weight,
+ bias=bias,
+ out=out,
+ silu_activation=silu_activation,
+ cache_seqlens=cache_seqlens,
+ conv_state_indices=conv_state_indices,
+ )
+ return out
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/metadata.json b/build/torch210-cxx11-cu126-aarch64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..0dacb99125f1112a811819ca1ffdde15c8c0faff
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/metadata.json
@@ -0,0 +1,18 @@
+{
+ "version": 1,
+ "license": "BSD-3-Clause",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "7.0",
+ "7.2",
+ "7.5",
+ "8.0",
+ "8.6",
+ "8.7",
+ "8.9",
+ "9.0+PTX"
+ ]
+ }
+}
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/__init__.py b/build/torch210-cxx11-cu126-x86_64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/__init__.py
@@ -0,0 +1,4 @@
+from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
+from .causal_conv1d_varlen import causal_conv1d_varlen_states
+
+__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"]
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_causal_conv1d_cuda_6b83b83.abi3.so b/build/torch210-cxx11-cu126-x86_64-linux/_causal_conv1d_cuda_6b83b83.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..a1a1b4d30e2106991416a65da38c8dad70c8a7bb
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/_causal_conv1d_cuda_6b83b83.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:06c71255dcc14bbe4e00c85170d0f1dce0d6510e091a4237bc1c2e61368d47f2
+size 80694472
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_ops.py b/build/torch210-cxx11-cu126-x86_64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..6796899661ef6f73609047ca344503d13ca050bd
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _causal_conv1d_cuda_6b83b83
+ops = torch.ops._causal_conv1d_cuda_6b83b83
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_causal_conv1d_cuda_6b83b83::{op_name}"
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py b/build/torch210-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import importlib.util
+import sys
+from pathlib import Path
+from types import ModuleType
+
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/causal_conv1d_interface.py b/build/torch210-cxx11-cu126-x86_64-linux/causal_conv1d_interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/causal_conv1d_interface.py
@@ -0,0 +1,242 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+import torch.nn.functional as F
+
+from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function
+
+
+class CausalConv1dFn(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ x,
+ weight,
+ bias=None,
+ seq_idx=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+ ):
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ if x.stride(2) != 1 and x.stride(1) != 1:
+ x = x.contiguous()
+ bias = bias.contiguous() if bias is not None else None
+ if seq_idx is not None:
+ assert (
+ initial_states is None
+ ), "initial_states must be None if seq_idx is not None"
+ assert (
+ not return_final_states
+ ), "If seq_idx is not None, we don't return final_states_out"
+ seq_idx = seq_idx.contiguous() if seq_idx is not None else None
+ if initial_states is not None and (
+ initial_states.stride(2) != 1 and initial_states.stride(1) != 1
+ ):
+ initial_states = initial_states.contiguous()
+ if return_final_states:
+ assert (
+ x.stride(1) == 1
+ ), "Only channel-last layout support returning final_states_out"
+ if final_states_out is not None:
+ assert (
+ final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1
+ )
+ else:
+ batch, dim, seqlen = x.shape
+ width = weight.shape[1]
+ final_states_out = torch.empty(
+ batch, width - 1, dim, device=x.device, dtype=x.dtype
+ ).transpose(1, 2)
+ else:
+ final_states_out = None
+ ctx.activation = activation in ["silu", "swish"]
+ out = causal_conv1d_fwd_function(
+ x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
+ )
+ ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
+ ctx.return_final_states = return_final_states
+ ctx.return_dinitial_states = (
+ initial_states is not None and initial_states.requires_grad
+ )
+ return out if not return_final_states else (out, final_states_out)
+
+ @staticmethod
+ def backward(ctx, dout, *args):
+ x, weight, bias, seq_idx, initial_states = ctx.saved_tensors
+ dfinal_states = args[0] if ctx.return_final_states else None
+ if dout.stride(2) != 1 and dout.stride(1) != 1:
+ dout = dout.contiguous()
+ # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
+ # backward of conv1d with the backward of chunk).
+ # Here we just pass in None and dx will be allocated in the C++ code.
+ dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function(
+ x,
+ weight,
+ bias,
+ dout,
+ seq_idx,
+ initial_states,
+ dfinal_states,
+ None,
+ ctx.return_dinitial_states,
+ ctx.activation,
+ )
+ return (
+ dx,
+ dweight,
+ dbias if bias is not None else None,
+ None,
+ dinitial_states if initial_states is not None else None,
+ None,
+ None,
+ None,
+ )
+
+
+def causal_conv1d_fn(
+ x,
+ weight,
+ bias=None,
+ seq_idx=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+):
+ """
+ x: (batch, dim, seqlen)
+ weight: (dim, width)
+ bias: (dim,)
+ seq_idx: (batch, seqlen)
+ initial_states: (batch, dim, width - 1)
+ final_states_out: (batch, dim, width - 1), to be written to
+ activation: either None or "silu" or "swish"
+
+ out: (batch, dim, seqlen)
+ """
+ return CausalConv1dFn.apply(
+ x,
+ weight,
+ bias,
+ seq_idx,
+ initial_states,
+ return_final_states,
+ final_states_out,
+ activation,
+ )
+
+
+def causal_conv1d_ref(
+ x,
+ weight,
+ bias=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+):
+ """
+ x: (batch, dim, seqlen)
+ weight: (dim, width)
+ bias: (dim,)
+ initial_states: (batch, dim, width - 1)
+ final_states_out: (batch, dim, width - 1)
+
+ out: (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ dtype_in = x.dtype
+ x = x.to(weight.dtype)
+ seqlen = x.shape[-1]
+ dim, width = weight.shape
+ if initial_states is None:
+ out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
+ else:
+ x = torch.cat([initial_states, x], dim=-1)
+ out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
+ out = out[..., :seqlen]
+ if return_final_states:
+ final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
+ dtype_in
+ ) # (batch, dim, width - 1)
+ if final_states_out is not None:
+ final_states_out.copy_(final_states)
+ else:
+ final_states_out = final_states
+ out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
+ return out if not return_final_states else (out, final_states_out)
+
+
+def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None):
+ """
+ x: (batch, dim) or (batch, dim, seqlen)
+ conv_state: (batch, dim, state_len), where state_len >= width - 1
+ weight: (dim, width)
+ bias: (dim,)
+ cache_seqlens: (batch,), dtype int32.
+ If not None, the conv_state is treated as a circular buffer.
+ The conv_state will be updated by copying x to the conv_state starting at the index
+ @cache_seqlens % state_len.
+ conv_state_indices: (batch,), dtype int32
+        If not None, the conv_state is a larger tensor along the batch dim,
+ and we are selecting the batch coords specified by conv_state_indices.
+ Useful for a continuous batching scenario.
+
+ out: (batch, dim) or (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ activation = activation in ["silu", "swish"]
+ unsqueeze = x.dim() == 2
+ if unsqueeze:
+ x = x.unsqueeze(-1)
+ out = causal_conv1d_update_function(
+ x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices
+ )
+ if unsqueeze:
+ out = out.squeeze(-1)
+ return out
+
+
+def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
+ """
+ x: (batch, dim) or (batch, dim, seqlen)
+ conv_state: (batch, dim, state_len), where state_len >= width - 1
+ weight: (dim, width)
+ bias: (dim,)
+ cache_seqlens: (batch,), dtype int32.
+ If not None, the conv_state is treated as a circular buffer.
+ The conv_state will be updated by copying x to the conv_state starting at the index
+ @cache_seqlens % state_len before performing the convolution.
+
+ out: (batch, dim) or (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ dtype_in = x.dtype
+ unsqueeze = x.dim() == 2
+ if unsqueeze:
+ x = x.unsqueeze(-1)
+ batch, dim, seqlen = x.shape
+ width = weight.shape[1]
+ state_len = conv_state.shape[-1]
+ assert conv_state.shape == (batch, dim, state_len)
+ assert weight.shape == (dim, width)
+ if cache_seqlens is None:
+ x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen)
+ conv_state.copy_(x_new[:, :, -state_len:])
+ else:
+ width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
+ width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
+ x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
+ copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
+ copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
+ conv_state.scatter_(2, copy_idx, x)
+ out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:]
+ if unsqueeze:
+ out = out.squeeze(-1)
+ return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/causal_conv1d_varlen.py b/build/torch210-cxx11-cu126-x86_64-linux/causal_conv1d_varlen.py
new file mode 100644
index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/causal_conv1d_varlen.py
@@ -0,0 +1,86 @@
+import torch
+from torch import Tensor
+
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _causal_conv1d_varlen_states(
+ X,
+ CU_SEQLENS,
+ STATES,
+ state_len,
+ dim,
+ stride_x_seqlen, stride_x_dim,
+ stride_states_batch, stride_states_seqlen, stride_states_dim,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr
+):
+ batch_idx = tl.program_id(2)
+ STATES += batch_idx * stride_states_batch
+ end_idx = tl.load(CU_SEQLENS + batch_idx + 1)
+ start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)
+ rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
+ cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
+ x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,
+ mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim),
+ other=0)
+ rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
+ tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,
+ x,
+ mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim))
+
+
+def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
+ """
+ Forward pass only, does not support backward pass.
+ Parameters:
+ x: (total_tokens, dim)
+ cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
+ state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
+ If some of those elements belong to a different sequence, the value of the states will be zero.
+ Return:
+ states: (batch, dim, state_len)
+ """
+ _, dim = x.shape
+ batch = cu_seqlens.shape[0] - 1
+ cu_seqlens = cu_seqlens.contiguous()
+ states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
+ BLOCK_M = min(triton.next_power_of_2(state_len), 16)
+ BLOCK_N = min(triton.next_power_of_2(dim), 256)
+ grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch)
+ with torch.cuda.device(x.device.index):
+ _causal_conv1d_varlen_states[grid](
+ x,
+ cu_seqlens,
+ states,
+ state_len,
+ dim,
+ x.stride(0), x.stride(1),
+ states.stride(0), states.stride(2), states.stride(1),
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
+ )
+ return states
+
+
+def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
+ """
+ Forward pass only, does not support backward pass.
+ Parameters:
+ x: (total_tokens, dim)
+ cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
+ state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
+ If some of those elements belong to a different sequence, the value of the states will be zero.
+ Return:
+ states: (batch, dim, state_len)
+ """
+ _, dim = x.shape
+ batch = cu_seqlens.shape[0] - 1
+ cu_seqlens = cu_seqlens.contiguous()
+ states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
+ for i in range(batch):
+ end_idx = cu_seqlens[i + 1]
+ start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len)
+ states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T
+ return states
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/cpp_functions.py b/build/torch210-cxx11-cu126-x86_64-linux/cpp_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/cpp_functions.py
@@ -0,0 +1,96 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+
+from ._ops import ops
+
+def causal_conv1d_fwd_function(
+ x: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None,
+ seq_idx: torch.Tensor | None,
+ initial_states: torch.Tensor | None,
+ final_states_out: torch.Tensor | None,
+ silu_activation: bool,
+) -> torch.Tensor:
+ out = torch.empty_like(x)
+ ops.causal_conv1d_fwd(
+ x=x,
+ weight=weight,
+ bias=bias,
+ seq_idx=seq_idx,
+ initial_states=initial_states,
+ out=out,
+ final_states_out=final_states_out,
+ silu_activation=silu_activation,
+ )
+ return out
+
+
+def causal_conv1d_bwd_function(
+ x: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None,
+ dout: torch.Tensor,
+ seq_idx: torch.Tensor | None,
+ initial_states: torch.Tensor | None,
+ dfinal_states: torch.Tensor | None,
+ dx: torch.Tensor | None,
+ return_dinitial_states: torch.Tensor,
+ silu_activation: bool,
+) -> tuple[torch.Tensor | None]:
+ batch_size, dim = x.size()[:2]
+ width = weight.size(-1)
+
+ if dx is None:
+ dx = torch.empty_like(x)
+ dweight = torch.zeros_like(weight, dtype=torch.float32)
+ dbias = None
+ if bias is not None:
+ dbias = torch.zeros_like(bias, dtype=torch.float32)
+ dinitial_states = None
+ if return_dinitial_states:
+ dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2)
+
+ ops.causal_conv1d_bwd(
+ x=x,
+ weight=weight,
+ bias=bias,
+ dout=dout,
+ seq_idx=seq_idx,
+ initial_states=initial_states,
+ dfinal_states=dfinal_states,
+ dx=dx,
+ dweight=dweight,
+ dbias=dbias,
+ dinitial_states=dinitial_states,
+ silu_activation=silu_activation,
+ )
+
+ dweight = dweight.type_as(weight)
+ if dbias is not None:
+ dbias = dbias.type_as(bias)
+ return dx, dweight, dbias, dinitial_states
+
+
+def causal_conv1d_update_function(
+ x: torch.Tensor,
+ conv_state: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None,
+ silu_activation: bool,
+ cache_seqlens: torch.Tensor | None,
+ conv_state_indices: torch.Tensor | None,
+) -> torch.Tensor:
+ out = torch.empty_like(x)
+ ops.causal_conv1d_update(
+ x=x,
+ conv_state=conv_state,
+ weight=weight,
+ bias=bias,
+ out=out,
+ silu_activation=silu_activation,
+ cache_seqlens=cache_seqlens,
+ conv_state_indices=conv_state_indices,
+ )
+ return out
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/metadata.json b/build/torch210-cxx11-cu126-x86_64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..0dacb99125f1112a811819ca1ffdde15c8c0faff
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/metadata.json
@@ -0,0 +1,18 @@
+{
+ "version": 1,
+ "license": "BSD-3-Clause",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "7.0",
+ "7.2",
+ "7.5",
+ "8.0",
+ "8.6",
+ "8.7",
+ "8.9",
+ "9.0+PTX"
+ ]
+ }
+}
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/__init__.py b/build/torch210-cxx11-cu128-aarch64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/__init__.py
@@ -0,0 +1,4 @@
+from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
+from .causal_conv1d_varlen import causal_conv1d_varlen_states
+
+__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"]
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/_causal_conv1d_cuda_6b83b83.abi3.so b/build/torch210-cxx11-cu128-aarch64-linux/_causal_conv1d_cuda_6b83b83.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..9adbbb9e769872f8226f3bb3eae46e537e353869
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/_causal_conv1d_cuda_6b83b83.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c9bcd794221ea9d6cb4f2c3ec409e75f83cd955823a6406f693c50d77d1c4b28
+size 107312656
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/_ops.py b/build/torch210-cxx11-cu128-aarch64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..6796899661ef6f73609047ca344503d13ca050bd
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _causal_conv1d_cuda_6b83b83
+ops = torch.ops._causal_conv1d_cuda_6b83b83
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_causal_conv1d_cuda_6b83b83::{op_name}"
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/causal_conv1d/__init__.py b/build/torch210-cxx11-cu128-aarch64-linux/causal_conv1d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/causal_conv1d/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import importlib.util
+import sys
+from pathlib import Path
+from types import ModuleType
+
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/causal_conv1d_interface.py b/build/torch210-cxx11-cu128-aarch64-linux/causal_conv1d_interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/causal_conv1d_interface.py
@@ -0,0 +1,242 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+import torch.nn.functional as F
+
+from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function
+
+
+class CausalConv1dFn(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ x,
+ weight,
+ bias=None,
+ seq_idx=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+ ):
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ if x.stride(2) != 1 and x.stride(1) != 1:
+ x = x.contiguous()
+ bias = bias.contiguous() if bias is not None else None
+ if seq_idx is not None:
+ assert (
+ initial_states is None
+ ), "initial_states must be None if seq_idx is not None"
+ assert (
+ not return_final_states
+ ), "If seq_idx is not None, we don't return final_states_out"
+ seq_idx = seq_idx.contiguous() if seq_idx is not None else None
+ if initial_states is not None and (
+ initial_states.stride(2) != 1 and initial_states.stride(1) != 1
+ ):
+ initial_states = initial_states.contiguous()
+ if return_final_states:
+ assert (
+ x.stride(1) == 1
+ ), "Only channel-last layout support returning final_states_out"
+ if final_states_out is not None:
+ assert (
+ final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1
+ )
+ else:
+ batch, dim, seqlen = x.shape
+ width = weight.shape[1]
+ final_states_out = torch.empty(
+ batch, width - 1, dim, device=x.device, dtype=x.dtype
+ ).transpose(1, 2)
+ else:
+ final_states_out = None
+ ctx.activation = activation in ["silu", "swish"]
+ out = causal_conv1d_fwd_function(
+ x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
+ )
+ ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
+ ctx.return_final_states = return_final_states
+ ctx.return_dinitial_states = (
+ initial_states is not None and initial_states.requires_grad
+ )
+ return out if not return_final_states else (out, final_states_out)
+
+ @staticmethod
+ def backward(ctx, dout, *args):
+ x, weight, bias, seq_idx, initial_states = ctx.saved_tensors
+ dfinal_states = args[0] if ctx.return_final_states else None
+ if dout.stride(2) != 1 and dout.stride(1) != 1:
+ dout = dout.contiguous()
+ # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
+ # backward of conv1d with the backward of chunk).
+ # Here we just pass in None and dx will be allocated in the C++ code.
+ dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function(
+ x,
+ weight,
+ bias,
+ dout,
+ seq_idx,
+ initial_states,
+ dfinal_states,
+ None,
+ ctx.return_dinitial_states,
+ ctx.activation,
+ )
+ return (
+ dx,
+ dweight,
+ dbias if bias is not None else None,
+ None,
+ dinitial_states if initial_states is not None else None,
+ None,
+ None,
+ None,
+ )
+
+
+def causal_conv1d_fn(
+ x,
+ weight,
+ bias=None,
+ seq_idx=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+):
+ """
+ x: (batch, dim, seqlen)
+ weight: (dim, width)
+ bias: (dim,)
+ seq_idx: (batch, seqlen)
+ initial_states: (batch, dim, width - 1)
+ final_states_out: (batch, dim, width - 1), to be written to
+ activation: either None or "silu" or "swish"
+
+ out: (batch, dim, seqlen)
+ """
+ return CausalConv1dFn.apply(
+ x,
+ weight,
+ bias,
+ seq_idx,
+ initial_states,
+ return_final_states,
+ final_states_out,
+ activation,
+ )
+
+
+def causal_conv1d_ref(
+ x,
+ weight,
+ bias=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+):
+ """
+ x: (batch, dim, seqlen)
+ weight: (dim, width)
+ bias: (dim,)
+ initial_states: (batch, dim, width - 1)
+ final_states_out: (batch, dim, width - 1)
+
+ out: (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ dtype_in = x.dtype
+ x = x.to(weight.dtype)
+ seqlen = x.shape[-1]
+ dim, width = weight.shape
+ if initial_states is None:
+ out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
+ else:
+ x = torch.cat([initial_states, x], dim=-1)
+ out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
+ out = out[..., :seqlen]
+ if return_final_states:
+ final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
+ dtype_in
+ ) # (batch, dim, width - 1)
+ if final_states_out is not None:
+ final_states_out.copy_(final_states)
+ else:
+ final_states_out = final_states
+ out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
+ return out if not return_final_states else (out, final_states_out)
+
+
+def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None):
+ """
+ x: (batch, dim) or (batch, dim, seqlen)
+ conv_state: (batch, dim, state_len), where state_len >= width - 1
+ weight: (dim, width)
+ bias: (dim,)
+ cache_seqlens: (batch,), dtype int32.
+ If not None, the conv_state is treated as a circular buffer.
+ The conv_state will be updated by copying x to the conv_state starting at the index
+ @cache_seqlens % state_len.
+ conv_state_indices: (batch,), dtype int32
+        If not None, the conv_state is a larger tensor along the batch dim,
+ and we are selecting the batch coords specified by conv_state_indices.
+ Useful for a continuous batching scenario.
+
+ out: (batch, dim) or (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ activation = activation in ["silu", "swish"]
+ unsqueeze = x.dim() == 2
+ if unsqueeze:
+ x = x.unsqueeze(-1)
+ out = causal_conv1d_update_function(
+ x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices
+ )
+ if unsqueeze:
+ out = out.squeeze(-1)
+ return out
+
+
+def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
+ """
+ x: (batch, dim) or (batch, dim, seqlen)
+ conv_state: (batch, dim, state_len), where state_len >= width - 1
+ weight: (dim, width)
+ bias: (dim,)
+ cache_seqlens: (batch,), dtype int32.
+ If not None, the conv_state is treated as a circular buffer.
+ The conv_state will be updated by copying x to the conv_state starting at the index
+ @cache_seqlens % state_len before performing the convolution.
+
+ out: (batch, dim) or (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ dtype_in = x.dtype
+ unsqueeze = x.dim() == 2
+ if unsqueeze:
+ x = x.unsqueeze(-1)
+ batch, dim, seqlen = x.shape
+ width = weight.shape[1]
+ state_len = conv_state.shape[-1]
+ assert conv_state.shape == (batch, dim, state_len)
+ assert weight.shape == (dim, width)
+ if cache_seqlens is None:
+ x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen)
+ conv_state.copy_(x_new[:, :, -state_len:])
+ else:
+ width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
+ width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
+ x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
+ copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
+ copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
+ conv_state.scatter_(2, copy_idx, x)
+ out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:]
+ if unsqueeze:
+ out = out.squeeze(-1)
+ return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/causal_conv1d_varlen.py b/build/torch210-cxx11-cu128-aarch64-linux/causal_conv1d_varlen.py
new file mode 100644
index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/causal_conv1d_varlen.py
@@ -0,0 +1,86 @@
+import torch
+from torch import Tensor
+
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _causal_conv1d_varlen_states(
+ X,
+ CU_SEQLENS,
+ STATES,
+ state_len,
+ dim,
+ stride_x_seqlen, stride_x_dim,
+ stride_states_batch, stride_states_seqlen, stride_states_dim,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr
+):
+ batch_idx = tl.program_id(2)
+ STATES += batch_idx * stride_states_batch
+ end_idx = tl.load(CU_SEQLENS + batch_idx + 1)
+ start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)
+ rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
+ cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
+ x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,
+ mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim),
+ other=0)
+ rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
+ tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,
+ x,
+ mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim))
+
+
+def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
+ """
+ Forward pass only, does not support backward pass.
+ Parameters:
+ x: (total_tokens, dim)
+ cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
+ state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
+ If some of those elements belong to a different sequence, the value of the states will be zero.
+ Return:
+ states: (batch, dim, state_len)
+ """
+ _, dim = x.shape
+ batch = cu_seqlens.shape[0] - 1
+ cu_seqlens = cu_seqlens.contiguous()
+ states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
+ BLOCK_M = min(triton.next_power_of_2(state_len), 16)
+ BLOCK_N = min(triton.next_power_of_2(dim), 256)
+ grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch)
+ with torch.cuda.device(x.device.index):
+ _causal_conv1d_varlen_states[grid](
+ x,
+ cu_seqlens,
+ states,
+ state_len,
+ dim,
+ x.stride(0), x.stride(1),
+ states.stride(0), states.stride(2), states.stride(1),
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
+ )
+ return states
+
+
+def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
+ """
+ Forward pass only, does not support backward pass.
+ Parameters:
+ x: (total_tokens, dim)
+ cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
+ state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
+ If some of those elements belong to a different sequence, the value of the states will be zero.
+ Return:
+ states: (batch, dim, state_len)
+ """
+ _, dim = x.shape
+ batch = cu_seqlens.shape[0] - 1
+ cu_seqlens = cu_seqlens.contiguous()
+ states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
+ for i in range(batch):
+ end_idx = cu_seqlens[i + 1]
+ start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len)
+ states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T
+ return states
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/cpp_functions.py b/build/torch210-cxx11-cu128-aarch64-linux/cpp_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/cpp_functions.py
@@ -0,0 +1,96 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+
+from ._ops import ops
+
+def causal_conv1d_fwd_function(
+ x: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None,
+ seq_idx: torch.Tensor | None,
+ initial_states: torch.Tensor | None,
+ final_states_out: torch.Tensor | None,
+ silu_activation: bool,
+) -> torch.Tensor:
+ out = torch.empty_like(x)
+ ops.causal_conv1d_fwd(
+ x=x,
+ weight=weight,
+ bias=bias,
+ seq_idx=seq_idx,
+ initial_states=initial_states,
+ out=out,
+ final_states_out=final_states_out,
+ silu_activation=silu_activation,
+ )
+ return out
+
+
+def causal_conv1d_bwd_function(
+ x: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None,
+ dout: torch.Tensor,
+ seq_idx: torch.Tensor | None,
+ initial_states: torch.Tensor | None,
+ dfinal_states: torch.Tensor | None,
+ dx: torch.Tensor | None,
+ return_dinitial_states: torch.Tensor,
+ silu_activation: bool,
+) -> tuple[torch.Tensor | None]:
+ batch_size, dim = x.size()[:2]
+ width = weight.size(-1)
+
+ if dx is None:
+ dx = torch.empty_like(x)
+ dweight = torch.zeros_like(weight, dtype=torch.float32)
+ dbias = None
+ if bias is not None:
+ dbias = torch.zeros_like(bias, dtype=torch.float32)
+ dinitial_states = None
+ if return_dinitial_states:
+ dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2)
+
+ ops.causal_conv1d_bwd(
+ x=x,
+ weight=weight,
+ bias=bias,
+ dout=dout,
+ seq_idx=seq_idx,
+ initial_states=initial_states,
+ dfinal_states=dfinal_states,
+ dx=dx,
+ dweight=dweight,
+ dbias=dbias,
+ dinitial_states=dinitial_states,
+ silu_activation=silu_activation,
+ )
+
+ dweight = dweight.type_as(weight)
+ if dbias is not None:
+ dbias = dbias.type_as(bias)
+ return dx, dweight, dbias, dinitial_states
+
+
+def causal_conv1d_update_function(
+ x: torch.Tensor,
+ conv_state: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None,
+ silu_activation: bool,
+ cache_seqlens: torch.Tensor | None,
+ conv_state_indices: torch.Tensor | None,
+) -> torch.Tensor:
+ out = torch.empty_like(x)
+ ops.causal_conv1d_update(
+ x=x,
+ conv_state=conv_state,
+ weight=weight,
+ bias=bias,
+ out=out,
+ silu_activation=silu_activation,
+ cache_seqlens=cache_seqlens,
+ conv_state_indices=conv_state_indices,
+ )
+ return out
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/metadata.json b/build/torch210-cxx11-cu128-aarch64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..a794c92436c3827ae79b48d55f7ea964afd50f52
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/metadata.json
@@ -0,0 +1,21 @@
+{
+ "version": 1,
+ "license": "BSD-3-Clause",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "10.0",
+ "10.1",
+ "12.0+PTX",
+ "7.0",
+ "7.2",
+ "7.5",
+ "8.0",
+ "8.6",
+ "8.7",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/__init__.py b/build/torch210-cxx11-cu128-x86_64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/__init__.py
@@ -0,0 +1,4 @@
+from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
+from .causal_conv1d_varlen import causal_conv1d_varlen_states
+
+__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"]
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_causal_conv1d_cuda_6b83b83.abi3.so b/build/torch210-cxx11-cu128-x86_64-linux/_causal_conv1d_cuda_6b83b83.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..58e4cc8a7164d312d42fb7ab5a4cc5fec69e3a4c
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/_causal_conv1d_cuda_6b83b83.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ed02f049828da6a24af2b061c4f2a9f440b66ae258411071fe4575c0f577d5d4
+size 107169840
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_ops.py b/build/torch210-cxx11-cu128-x86_64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..6796899661ef6f73609047ca344503d13ca050bd
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _causal_conv1d_cuda_6b83b83
+ops = torch.ops._causal_conv1d_cuda_6b83b83
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_causal_conv1d_cuda_6b83b83::{op_name}"
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py b/build/torch210-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import importlib.util
+import sys
+from pathlib import Path
+from types import ModuleType
+
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/causal_conv1d_interface.py b/build/torch210-cxx11-cu128-x86_64-linux/causal_conv1d_interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/causal_conv1d_interface.py
@@ -0,0 +1,242 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+import torch.nn.functional as F
+
+from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function
+
+
+class CausalConv1dFn(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ x,
+ weight,
+ bias=None,
+ seq_idx=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+ ):
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ if x.stride(2) != 1 and x.stride(1) != 1:
+ x = x.contiguous()
+ bias = bias.contiguous() if bias is not None else None
+ if seq_idx is not None:
+ assert (
+ initial_states is None
+ ), "initial_states must be None if seq_idx is not None"
+ assert (
+ not return_final_states
+ ), "If seq_idx is not None, we don't return final_states_out"
+ seq_idx = seq_idx.contiguous() if seq_idx is not None else None
+ if initial_states is not None and (
+ initial_states.stride(2) != 1 and initial_states.stride(1) != 1
+ ):
+ initial_states = initial_states.contiguous()
+ if return_final_states:
+ assert (
+ x.stride(1) == 1
+ ), "Only channel-last layout support returning final_states_out"
+ if final_states_out is not None:
+ assert (
+ final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1
+ )
+ else:
+ batch, dim, seqlen = x.shape
+ width = weight.shape[1]
+ final_states_out = torch.empty(
+ batch, width - 1, dim, device=x.device, dtype=x.dtype
+ ).transpose(1, 2)
+ else:
+ final_states_out = None
+ ctx.activation = activation in ["silu", "swish"]
+ out = causal_conv1d_fwd_function(
+ x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
+ )
+ ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
+ ctx.return_final_states = return_final_states
+ ctx.return_dinitial_states = (
+ initial_states is not None and initial_states.requires_grad
+ )
+ return out if not return_final_states else (out, final_states_out)
+
+ @staticmethod
+ def backward(ctx, dout, *args):
+ x, weight, bias, seq_idx, initial_states = ctx.saved_tensors
+ dfinal_states = args[0] if ctx.return_final_states else None
+ if dout.stride(2) != 1 and dout.stride(1) != 1:
+ dout = dout.contiguous()
+ # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
+ # backward of conv1d with the backward of chunk).
+ # Here we just pass in None and dx will be allocated in the C++ code.
+ dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function(
+ x,
+ weight,
+ bias,
+ dout,
+ seq_idx,
+ initial_states,
+ dfinal_states,
+ None,
+ ctx.return_dinitial_states,
+ ctx.activation,
+ )
+ return (
+ dx,
+ dweight,
+ dbias if bias is not None else None,
+ None,
+ dinitial_states if initial_states is not None else None,
+ None,
+ None,
+ None,
+ )
+
+
+def causal_conv1d_fn(
+ x,
+ weight,
+ bias=None,
+ seq_idx=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+):
+ """
+ x: (batch, dim, seqlen)
+ weight: (dim, width)
+ bias: (dim,)
+ seq_idx: (batch, seqlen)
+ initial_states: (batch, dim, width - 1)
+ final_states_out: (batch, dim, width - 1), to be written to
+ activation: either None or "silu" or "swish"
+
+ out: (batch, dim, seqlen)
+ """
+ return CausalConv1dFn.apply(
+ x,
+ weight,
+ bias,
+ seq_idx,
+ initial_states,
+ return_final_states,
+ final_states_out,
+ activation,
+ )
+
+
+def causal_conv1d_ref(
+ x,
+ weight,
+ bias=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+):
+ """
+ x: (batch, dim, seqlen)
+ weight: (dim, width)
+ bias: (dim,)
+ initial_states: (batch, dim, width - 1)
+ final_states_out: (batch, dim, width - 1)
+
+ out: (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ dtype_in = x.dtype
+ x = x.to(weight.dtype)
+ seqlen = x.shape[-1]
+ dim, width = weight.shape
+ if initial_states is None:
+ out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
+ else:
+ x = torch.cat([initial_states, x], dim=-1)
+ out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
+ out = out[..., :seqlen]
+ if return_final_states:
+ final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
+ dtype_in
+ ) # (batch, dim, width - 1)
+ if final_states_out is not None:
+ final_states_out.copy_(final_states)
+ else:
+ final_states_out = final_states
+ out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
+ return out if not return_final_states else (out, final_states_out)
+
+
+def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None):
+ """
+ x: (batch, dim) or (batch, dim, seqlen)
+ conv_state: (batch, dim, state_len), where state_len >= width - 1
+ weight: (dim, width)
+ bias: (dim,)
+ cache_seqlens: (batch,), dtype int32.
+ If not None, the conv_state is treated as a circular buffer.
+ The conv_state will be updated by copying x to the conv_state starting at the index
+ @cache_seqlens % state_len.
+ conv_state_indices: (batch,), dtype int32
+        If not None, the conv_state is a larger tensor along the batch dim,
+ and we are selecting the batch coords specified by conv_state_indices.
+ Useful for a continuous batching scenario.
+
+ out: (batch, dim) or (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ activation = activation in ["silu", "swish"]
+ unsqueeze = x.dim() == 2
+ if unsqueeze:
+ x = x.unsqueeze(-1)
+ out = causal_conv1d_update_function(
+ x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices
+ )
+ if unsqueeze:
+ out = out.squeeze(-1)
+ return out
+
+
+def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
+ """
+ x: (batch, dim) or (batch, dim, seqlen)
+ conv_state: (batch, dim, state_len), where state_len >= width - 1
+ weight: (dim, width)
+ bias: (dim,)
+ cache_seqlens: (batch,), dtype int32.
+ If not None, the conv_state is treated as a circular buffer.
+ The conv_state will be updated by copying x to the conv_state starting at the index
+ @cache_seqlens % state_len before performing the convolution.
+
+ out: (batch, dim) or (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ dtype_in = x.dtype
+ unsqueeze = x.dim() == 2
+ if unsqueeze:
+ x = x.unsqueeze(-1)
+ batch, dim, seqlen = x.shape
+ width = weight.shape[1]
+ state_len = conv_state.shape[-1]
+ assert conv_state.shape == (batch, dim, state_len)
+ assert weight.shape == (dim, width)
+ if cache_seqlens is None:
+ x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen)
+ conv_state.copy_(x_new[:, :, -state_len:])
+ else:
+ width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
+ width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
+ x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
+ copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
+ copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
+ conv_state.scatter_(2, copy_idx, x)
+ out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:]
+ if unsqueeze:
+ out = out.squeeze(-1)
+ return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/causal_conv1d_varlen.py b/build/torch210-cxx11-cu128-x86_64-linux/causal_conv1d_varlen.py
new file mode 100644
index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/causal_conv1d_varlen.py
@@ -0,0 +1,86 @@
+import torch
+from torch import Tensor
+
+import triton
+import triton.language as tl
+
+
@triton.jit
def _causal_conv1d_varlen_states(
    X,  # ptr to (total_tokens, dim) packed varlen input
    CU_SEQLENS,  # ptr to (batch + 1,) cumulative sequence lengths
    STATES,  # ptr to (batch, dim, state_len) output states
    state_len,
    dim,
    stride_x_seqlen, stride_x_dim,
    stride_states_batch, stride_states_seqlen, stride_states_dim,
    BLOCK_M: tl.constexpr,  # tile size along the token/state axis
    BLOCK_N: tl.constexpr  # tile size along the dim axis
):
    # Copy the last `state_len` tokens of each variable-length sequence in X into
    # the right-aligned tail of STATES[batch_idx], zero-filling positions that
    # fall before the start of the sequence.
    # Grid: (cdiv(dim, BLOCK_N), cdiv(state_len, BLOCK_M), batch).
    batch_idx = tl.program_id(2)
    STATES += batch_idx * stride_states_batch
    # Sequence i occupies rows [CU_SEQLENS[i], CU_SEQLENS[i+1]) of X; clamp the
    # start so at most `state_len` trailing tokens are read.
    end_idx = tl.load(CU_SEQLENS + batch_idx + 1)
    start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)
    # Tiles are laid out backwards from the sequence end; program_id(1) == 0 is
    # the last BLOCK_M tokens.
    rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
    # Rows before start_idx are masked to 0, so short sequences come out zero-padded.
    x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,
                mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim),
                other=0)
    # Mirror the tile placement in STATES: trailing tokens land at the end of
    # the state axis; negative rows (state longer than this tile covers) are masked.
    rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
    tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,
             x,
             mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim))
+
+
def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
    """
    Gather per-sequence convolution states from a packed varlen batch.

    Forward pass only, does not support backward pass.
    Parameters:
        x: (total_tokens, dim)
        cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of
            the sequence lengths, starting from 0.
        state_len: int. For each cu_seqlens, how many elements from x should be
            copied to the state. If some of those elements belong to a different
            sequence, the value of the states will be zero.
    Return:
        states: (batch, dim, state_len)
    """
    feat_dim = x.shape[1]
    num_seqs = cu_seqlens.shape[0] - 1
    cu_seqlens = cu_seqlens.contiguous()
    # Allocate (batch, state_len, dim) storage and view it as
    # (batch, dim, state_len) so the dim axis stays unit-stride.
    states = torch.empty(
        num_seqs, state_len, feat_dim, dtype=x.dtype, device=x.device
    ).transpose(1, 2)
    tile_m = min(triton.next_power_of_2(state_len), 16)
    tile_n = min(triton.next_power_of_2(feat_dim), 256)
    launch_grid = (
        triton.cdiv(feat_dim, tile_n),
        triton.cdiv(state_len, tile_m),
        num_seqs,
    )
    # Pin the CUDA device so the kernel launches where x lives.
    with torch.cuda.device(x.device.index):
        _causal_conv1d_varlen_states[launch_grid](
            x,
            cu_seqlens,
            states,
            state_len,
            feat_dim,
            x.stride(0), x.stride(1),
            states.stride(0), states.stride(2), states.stride(1),
            BLOCK_M=tile_m, BLOCK_N=tile_n,
        )
    return states
+
+
def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
    """
    Pure-PyTorch reference for `causal_conv1d_varlen_states`.

    Forward pass only, does not support backward pass.
    Parameters:
        x: (total_tokens, dim)
        cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of
            the sequence lengths, starting from 0.
        state_len: int. For each cu_seqlens, how many elements from x should be
            copied to the state. If some of those elements belong to a different
            sequence, the value of the states will be zero.
    Return:
        states: (batch, dim, state_len)
    """
    _, dim = x.shape
    batch = cu_seqlens.shape[0] - 1
    cu_seqlens = cu_seqlens.contiguous()
    # (batch, state_len, dim) storage viewed as (batch, dim, state_len),
    # matching the layout produced by the Triton kernel.
    states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
    for i in range(batch):
        end_idx = int(cu_seqlens[i + 1])
        # Copy at most the trailing `state_len` tokens of sequence i.
        start_idx = max(int(cu_seqlens[i]), end_idx - state_len)
        n_copy = end_idx - start_idx
        # Guard zero-length sequences (repeated cu_seqlens entries): without it,
        # `states[i, :, -0:]` selects the WHOLE state axis and assigning an
        # empty (dim, 0) tensor raises a shape-mismatch RuntimeError.
        if n_copy > 0:
            states[i, :, -n_copy:] = x[start_idx:end_idx].T
    return states
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/cpp_functions.py b/build/torch210-cxx11-cu128-x86_64-linux/cpp_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/cpp_functions.py
@@ -0,0 +1,96 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+
+from ._ops import ops
+
def causal_conv1d_fwd_function(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    seq_idx: torch.Tensor | None,
    initial_states: torch.Tensor | None,
    final_states_out: torch.Tensor | None,
    silu_activation: bool,
) -> torch.Tensor:
    """Run the compiled causal-conv1d forward kernel.

    Allocates the output tensor and dispatches to `ops.causal_conv1d_fwd`,
    which writes the result (and, when provided, `final_states_out`) in place.
    """
    result = torch.empty_like(x)
    ops.causal_conv1d_fwd(
        x=x,
        weight=weight,
        bias=bias,
        seq_idx=seq_idx,
        initial_states=initial_states,
        out=result,
        final_states_out=final_states_out,
        silu_activation=silu_activation,
    )
    return result
+
+
def causal_conv1d_bwd_function(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    dout: torch.Tensor,
    seq_idx: torch.Tensor | None,
    initial_states: torch.Tensor | None,
    dfinal_states: torch.Tensor | None,
    dx: torch.Tensor | None,
    return_dinitial_states: bool,
    silu_activation: bool,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """Run the compiled causal-conv1d backward kernel.

    Allocates the gradient buffers (``dx`` may be pre-allocated by the caller,
    e.g. to fuse with another backward pass) and dispatches to
    ``ops.causal_conv1d_bwd``, which fills them in place.

    Returns ``(dx, dweight, dbias, dinitial_states)``; ``dbias`` is None when
    ``bias`` is None, and ``dinitial_states`` is None when
    ``return_dinitial_states`` is falsy.
    """
    batch_size, dim = x.size()[:2]
    width = weight.size(-1)

    if dx is None:
        dx = torch.empty_like(x)
    # Weight/bias gradients are accumulated in float32, then cast back to the
    # parameter dtype below.
    dweight = torch.zeros_like(weight, dtype=torch.float32)
    dbias = None
    if bias is not None:
        dbias = torch.zeros_like(bias, dtype=torch.float32)
    dinitial_states = None
    if return_dinitial_states:
        # (batch, width - 1, dim) storage viewed as (batch, dim, width - 1) so
        # the dim axis is unit-stride (channel-last layout).
        dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2)

    ops.causal_conv1d_bwd(
        x=x,
        weight=weight,
        bias=bias,
        dout=dout,
        seq_idx=seq_idx,
        initial_states=initial_states,
        dfinal_states=dfinal_states,
        dx=dx,
        dweight=dweight,
        dbias=dbias,
        dinitial_states=dinitial_states,
        silu_activation=silu_activation,
    )

    dweight = dweight.type_as(weight)
    if dbias is not None:
        dbias = dbias.type_as(bias)
    return dx, dweight, dbias, dinitial_states
+
+
def causal_conv1d_update_function(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    silu_activation: bool,
    cache_seqlens: torch.Tensor | None,
    conv_state_indices: torch.Tensor | None,
) -> torch.Tensor:
    """Run the compiled causal-conv1d decode-step kernel.

    Allocates the output tensor and dispatches to `ops.causal_conv1d_update`,
    which also updates `conv_state` in place.
    """
    result = torch.empty_like(x)
    ops.causal_conv1d_update(
        x=x,
        conv_state=conv_state,
        weight=weight,
        bias=bias,
        out=result,
        silu_activation=silu_activation,
        cache_seqlens=cache_seqlens,
        conv_state_indices=conv_state_indices,
    )
    return result
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/metadata.json b/build/torch210-cxx11-cu128-x86_64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..a794c92436c3827ae79b48d55f7ea964afd50f52
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/metadata.json
@@ -0,0 +1,21 @@
+{
+ "version": 1,
+ "license": "BSD-3-Clause",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "10.0",
+ "10.1",
+ "12.0+PTX",
+ "7.0",
+ "7.2",
+ "7.5",
+ "8.0",
+ "8.6",
+ "8.7",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/__init__.py b/build/torch210-cxx11-cu130-aarch64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/__init__.py
@@ -0,0 +1,4 @@
+from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
+from .causal_conv1d_varlen import causal_conv1d_varlen_states
+
+__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"]
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/_causal_conv1d_cuda_6b83b83.abi3.so b/build/torch210-cxx11-cu130-aarch64-linux/_causal_conv1d_cuda_6b83b83.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..89f0bf651fb60fe51bf08fded9a291386aae7015
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/_causal_conv1d_cuda_6b83b83.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8b8c42b85e433d28b00dd22919b0abe1152b68b2701cf5de31a46d3d1fabd128
+size 64755512
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/_ops.py b/build/torch210-cxx11-cu130-aarch64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..6796899661ef6f73609047ca344503d13ca050bd
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _causal_conv1d_cuda_6b83b83
+ops = torch.ops._causal_conv1d_cuda_6b83b83
+
def add_op_namespace_prefix(op_name: str):
    """Qualify *op_name* with this extension's torch.ops namespace."""
    return "_causal_conv1d_cuda_6b83b83::" + op_name
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/causal_conv1d/__init__.py b/build/torch210-cxx11-cu130-aarch64-linux/causal_conv1d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/causal_conv1d/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import importlib.util
+import sys
+from pathlib import Path
+from types import ModuleType
+
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/causal_conv1d_interface.py b/build/torch210-cxx11-cu130-aarch64-linux/causal_conv1d_interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/causal_conv1d_interface.py
@@ -0,0 +1,242 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+import torch.nn.functional as F
+
+from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function
+
+
class CausalConv1dFn(torch.autograd.Function):
    """autograd.Function wrapping the compiled causal depthwise-conv1d kernels.

    forward/backward dispatch to `causal_conv1d_fwd_function` /
    `causal_conv1d_bwd_function`; see `causal_conv1d_fn` for the user-facing
    signature and shape documentation.
    """

    @staticmethod
    def forward(
        ctx,
        x,
        weight,
        bias=None,
        seq_idx=None,
        initial_states=None,
        return_final_states=False,
        final_states_out=None,
        activation=None,
    ):
        if activation not in [None, "silu", "swish"]:
            raise NotImplementedError("activation must be None, silu, or swish")
        # Make x contiguous unless either seqlen or dim is already unit-stride.
        if x.stride(2) != 1 and x.stride(1) != 1:
            x = x.contiguous()
        bias = bias.contiguous() if bias is not None else None
        if seq_idx is not None:
            assert (
                initial_states is None
            ), "initial_states must be None if seq_idx is not None"
            assert (
                not return_final_states
            ), "If seq_idx is not None, we don't return final_states_out"
        seq_idx = seq_idx.contiguous() if seq_idx is not None else None
        if initial_states is not None and (
            initial_states.stride(2) != 1 and initial_states.stride(1) != 1
        ):
            initial_states = initial_states.contiguous()
        if return_final_states:
            assert (
                x.stride(1) == 1
            ), "Only channel-last layout support returning final_states_out"
            if final_states_out is not None:
                assert (
                    final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1
                )
            else:
                # Allocate (batch, width - 1, dim) storage viewed as
                # (batch, dim, width - 1) so the dim axis is unit-stride.
                batch, dim, seqlen = x.shape
                width = weight.shape[1]
                final_states_out = torch.empty(
                    batch, width - 1, dim, device=x.device, dtype=x.dtype
                ).transpose(1, 2)
        else:
            final_states_out = None
        # Stored as a bool: whether SiLU is fused into the kernel.
        ctx.activation = activation in ["silu", "swish"]
        out = causal_conv1d_fwd_function(
            x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
        )
        ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
        ctx.return_final_states = return_final_states
        ctx.return_dinitial_states = (
            initial_states is not None and initial_states.requires_grad
        )
        return out if not return_final_states else (out, final_states_out)

    @staticmethod
    def backward(ctx, dout, *args):
        x, weight, bias, seq_idx, initial_states = ctx.saved_tensors
        # When forward returned (out, final_states), args[0] carries dfinal_states.
        dfinal_states = args[0] if ctx.return_final_states else None
        if dout.stride(2) != 1 and dout.stride(1) != 1:
            dout = dout.contiguous()
        # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
        # backward of conv1d with the backward of chunk).
        # Here we just pass in None and dx is allocated inside
        # causal_conv1d_bwd_function.
        dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function(
            x,
            weight,
            bias,
            dout,
            seq_idx,
            initial_states,
            dfinal_states,
            None,
            ctx.return_dinitial_states,
            ctx.activation,
        )
        # One gradient slot per forward argument (None for the non-tensor ones).
        return (
            dx,
            dweight,
            dbias if bias is not None else None,
            None,
            dinitial_states if initial_states is not None else None,
            None,
            None,
            None,
        )
+
+
def causal_conv1d_fn(
    x,
    weight,
    bias=None,
    seq_idx=None,
    initial_states=None,
    return_final_states=False,
    final_states_out=None,
    activation=None,
):
    """
    Causal depthwise conv1d over the sequence axis (autograd-enabled).

    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    seq_idx: (batch, seqlen)
    initial_states: (batch, dim, width - 1)
    final_states_out: (batch, dim, width - 1), to be written to
    activation: either None or "silu" or "swish"

    out: (batch, dim, seqlen)
    """
    # autograd.Function.apply only accepts positional arguments.
    return CausalConv1dFn.apply(
        x, weight, bias, seq_idx, initial_states,
        return_final_states, final_states_out, activation,
    )
+
+
def causal_conv1d_ref(
    x,
    weight,
    bias=None,
    initial_states=None,
    return_final_states=False,
    final_states_out=None,
    activation=None,
):
    """
    Pure-PyTorch reference for `causal_conv1d_fn`.

    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    initial_states: (batch, dim, width - 1)
    final_states_out: (batch, dim, width - 1)

    out: (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    dtype_in = x.dtype
    x = x.to(weight.dtype)
    seqlen = x.shape[-1]
    dim, width = weight.shape
    if initial_states is None:
        # width-1 of left zero-padding makes the depthwise conv causal.
        out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
    else:
        # Prepend the provided states instead of zero padding.
        x = torch.cat([initial_states, x], dim=-1)
        out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
    out = out[..., :seqlen]
    if return_final_states:
        # Final states are the last width-1 inputs, left-padded with zeros
        # when the (possibly extended) input is shorter than width-1.
        final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(dtype_in)
        if final_states_out is None:
            final_states_out = final_states
        else:
            final_states_out.copy_(final_states)
    out = out if activation is None else F.silu(out)
    out = out.to(dtype=dtype_in)
    return (out, final_states_out) if return_final_states else out
+
+
def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None):
    """
    Decode-step causal conv1d; updates `conv_state` in place.

    x: (batch, dim) or (batch, dim, seqlen)
    conv_state: (batch, dim, state_len), where state_len >= width - 1
    weight: (dim, width)
    bias: (dim,)
    cache_seqlens: (batch,), dtype int32.
        If not None, the conv_state is treated as a circular buffer.
        The conv_state will be updated by copying x to the conv_state starting
        at the index @cache_seqlens % state_len.
    conv_state_indices: (batch,), dtype int32
        If not None, the conv_state is a larger tensor along the batch dim,
        and we are selecting the batch coords specified by conv_state_indices.
        Useful for a continuous batching scenario.

    out: (batch, dim) or (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    use_silu = activation in ["silu", "swish"]
    # The kernel works on (batch, dim, seqlen); promote single-step input.
    squeeze_out = x.dim() == 2
    if squeeze_out:
        x = x.unsqueeze(-1)
    out = causal_conv1d_update_function(
        x, conv_state, weight, bias, use_silu, cache_seqlens, conv_state_indices
    )
    return out.squeeze(-1) if squeeze_out else out
+
+
def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
    """
    Pure-PyTorch reference for `causal_conv1d_update`; updates `conv_state` in place.

    x: (batch, dim) or (batch, dim, seqlen)
    conv_state: (batch, dim, state_len), where state_len >= width - 1
    weight: (dim, width)
    bias: (dim,)
    cache_seqlens: (batch,), dtype int32.
        If not None, the conv_state is treated as a circular buffer.
        The conv_state will be updated by copying x to the conv_state starting
        at the index @cache_seqlens % state_len before performing the convolution.

    out: (batch, dim) or (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    dtype_in = x.dtype
    squeeze_out = x.dim() == 2
    if squeeze_out:
        x = x.unsqueeze(-1)
    batch, dim, seqlen = x.shape
    width = weight.shape[1]
    state_len = conv_state.shape[-1]
    assert conv_state.shape == (batch, dim, state_len)
    assert weight.shape == (dim, width)
    if cache_seqlens is None:
        # Shift-register update: append x and keep the trailing state_len entries.
        conv_input = torch.cat([conv_state, x], dim=-1).to(weight.dtype)  # (batch, dim, state_len + seqlen)
        conv_state.copy_(conv_input[:, :, -state_len:])
    else:
        # Circular-buffer update: gather the width-1 entries preceding each
        # cache position, then scatter x back in at cache_seqlens % state_len.
        hist_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
        hist_idx = torch.remainder(hist_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
        conv_input = torch.cat([conv_state.gather(2, hist_idx), x], dim=-1).to(weight.dtype)
        write_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
        write_idx = torch.remainder(write_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
        conv_state.scatter_(2, write_idx, x)
    out = F.conv1d(conv_input, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:]
    if squeeze_out:
        out = out.squeeze(-1)
    if activation is not None:
        out = F.silu(out)
    return out.to(dtype=dtype_in)
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/causal_conv1d_varlen.py b/build/torch210-cxx11-cu130-aarch64-linux/causal_conv1d_varlen.py
new file mode 100644
index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/causal_conv1d_varlen.py
@@ -0,0 +1,86 @@
+import torch
+from torch import Tensor
+
+import triton
+import triton.language as tl
+
+
@triton.jit
def _causal_conv1d_varlen_states(
    X,  # ptr to (total_tokens, dim) packed varlen input
    CU_SEQLENS,  # ptr to (batch + 1,) cumulative sequence lengths
    STATES,  # ptr to (batch, dim, state_len) output states
    state_len,
    dim,
    stride_x_seqlen, stride_x_dim,
    stride_states_batch, stride_states_seqlen, stride_states_dim,
    BLOCK_M: tl.constexpr,  # tile size along the token/state axis
    BLOCK_N: tl.constexpr  # tile size along the dim axis
):
    # Copy the last `state_len` tokens of each variable-length sequence in X into
    # the right-aligned tail of STATES[batch_idx], zero-filling positions that
    # fall before the start of the sequence.
    # Grid: (cdiv(dim, BLOCK_N), cdiv(state_len, BLOCK_M), batch).
    batch_idx = tl.program_id(2)
    STATES += batch_idx * stride_states_batch
    # Sequence i occupies rows [CU_SEQLENS[i], CU_SEQLENS[i+1]) of X; clamp the
    # start so at most `state_len` trailing tokens are read.
    end_idx = tl.load(CU_SEQLENS + batch_idx + 1)
    start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)
    # Tiles are laid out backwards from the sequence end; program_id(1) == 0 is
    # the last BLOCK_M tokens.
    rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
    # Rows before start_idx are masked to 0, so short sequences come out zero-padded.
    x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,
                mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim),
                other=0)
    # Mirror the tile placement in STATES: trailing tokens land at the end of
    # the state axis; negative rows (state longer than this tile covers) are masked.
    rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
    tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,
             x,
             mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim))
+
+
def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
    """
    Gather per-sequence convolution states from a packed varlen batch.

    Forward pass only, does not support backward pass.
    Parameters:
        x: (total_tokens, dim)
        cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of
            the sequence lengths, starting from 0.
        state_len: int. For each cu_seqlens, how many elements from x should be
            copied to the state. If some of those elements belong to a different
            sequence, the value of the states will be zero.
    Return:
        states: (batch, dim, state_len)
    """
    feat_dim = x.shape[1]
    num_seqs = cu_seqlens.shape[0] - 1
    cu_seqlens = cu_seqlens.contiguous()
    # Allocate (batch, state_len, dim) storage and view it as
    # (batch, dim, state_len) so the dim axis stays unit-stride.
    states = torch.empty(
        num_seqs, state_len, feat_dim, dtype=x.dtype, device=x.device
    ).transpose(1, 2)
    tile_m = min(triton.next_power_of_2(state_len), 16)
    tile_n = min(triton.next_power_of_2(feat_dim), 256)
    launch_grid = (
        triton.cdiv(feat_dim, tile_n),
        triton.cdiv(state_len, tile_m),
        num_seqs,
    )
    # Pin the CUDA device so the kernel launches where x lives.
    with torch.cuda.device(x.device.index):
        _causal_conv1d_varlen_states[launch_grid](
            x,
            cu_seqlens,
            states,
            state_len,
            feat_dim,
            x.stride(0), x.stride(1),
            states.stride(0), states.stride(2), states.stride(1),
            BLOCK_M=tile_m, BLOCK_N=tile_n,
        )
    return states
+
+
def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
    """
    Pure-PyTorch reference for `causal_conv1d_varlen_states`.

    Forward pass only, does not support backward pass.
    Parameters:
        x: (total_tokens, dim)
        cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of
            the sequence lengths, starting from 0.
        state_len: int. For each cu_seqlens, how many elements from x should be
            copied to the state. If some of those elements belong to a different
            sequence, the value of the states will be zero.
    Return:
        states: (batch, dim, state_len)
    """
    _, dim = x.shape
    batch = cu_seqlens.shape[0] - 1
    cu_seqlens = cu_seqlens.contiguous()
    # (batch, state_len, dim) storage viewed as (batch, dim, state_len),
    # matching the layout produced by the Triton kernel.
    states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
    for i in range(batch):
        end_idx = int(cu_seqlens[i + 1])
        # Copy at most the trailing `state_len` tokens of sequence i.
        start_idx = max(int(cu_seqlens[i]), end_idx - state_len)
        n_copy = end_idx - start_idx
        # Guard zero-length sequences (repeated cu_seqlens entries): without it,
        # `states[i, :, -0:]` selects the WHOLE state axis and assigning an
        # empty (dim, 0) tensor raises a shape-mismatch RuntimeError.
        if n_copy > 0:
            states[i, :, -n_copy:] = x[start_idx:end_idx].T
    return states
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/cpp_functions.py b/build/torch210-cxx11-cu130-aarch64-linux/cpp_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/cpp_functions.py
@@ -0,0 +1,96 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+
+from ._ops import ops
+
def causal_conv1d_fwd_function(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    seq_idx: torch.Tensor | None,
    initial_states: torch.Tensor | None,
    final_states_out: torch.Tensor | None,
    silu_activation: bool,
) -> torch.Tensor:
    """Run the compiled causal-conv1d forward kernel.

    Allocates the output tensor and dispatches to `ops.causal_conv1d_fwd`,
    which writes the result (and, when provided, `final_states_out`) in place.
    """
    result = torch.empty_like(x)
    ops.causal_conv1d_fwd(
        x=x,
        weight=weight,
        bias=bias,
        seq_idx=seq_idx,
        initial_states=initial_states,
        out=result,
        final_states_out=final_states_out,
        silu_activation=silu_activation,
    )
    return result
+
+
def causal_conv1d_bwd_function(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    dout: torch.Tensor,
    seq_idx: torch.Tensor | None,
    initial_states: torch.Tensor | None,
    dfinal_states: torch.Tensor | None,
    dx: torch.Tensor | None,
    return_dinitial_states: bool,
    silu_activation: bool,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """Run the compiled causal-conv1d backward kernel.

    Allocates the gradient buffers (``dx`` may be pre-allocated by the caller,
    e.g. to fuse with another backward pass) and dispatches to
    ``ops.causal_conv1d_bwd``, which fills them in place.

    Returns ``(dx, dweight, dbias, dinitial_states)``; ``dbias`` is None when
    ``bias`` is None, and ``dinitial_states`` is None when
    ``return_dinitial_states`` is falsy.
    """
    batch_size, dim = x.size()[:2]
    width = weight.size(-1)

    if dx is None:
        dx = torch.empty_like(x)
    # Weight/bias gradients are accumulated in float32, then cast back to the
    # parameter dtype below.
    dweight = torch.zeros_like(weight, dtype=torch.float32)
    dbias = None
    if bias is not None:
        dbias = torch.zeros_like(bias, dtype=torch.float32)
    dinitial_states = None
    if return_dinitial_states:
        # (batch, width - 1, dim) storage viewed as (batch, dim, width - 1) so
        # the dim axis is unit-stride (channel-last layout).
        dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2)

    ops.causal_conv1d_bwd(
        x=x,
        weight=weight,
        bias=bias,
        dout=dout,
        seq_idx=seq_idx,
        initial_states=initial_states,
        dfinal_states=dfinal_states,
        dx=dx,
        dweight=dweight,
        dbias=dbias,
        dinitial_states=dinitial_states,
        silu_activation=silu_activation,
    )

    dweight = dweight.type_as(weight)
    if dbias is not None:
        dbias = dbias.type_as(bias)
    return dx, dweight, dbias, dinitial_states
+
+
def causal_conv1d_update_function(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    silu_activation: bool,
    cache_seqlens: torch.Tensor | None,
    conv_state_indices: torch.Tensor | None,
) -> torch.Tensor:
    """Run the compiled causal-conv1d decode-step kernel.

    Allocates the output tensor and dispatches to `ops.causal_conv1d_update`,
    which also updates `conv_state` in place.
    """
    result = torch.empty_like(x)
    ops.causal_conv1d_update(
        x=x,
        conv_state=conv_state,
        weight=weight,
        bias=bias,
        out=result,
        silu_activation=silu_activation,
        cache_seqlens=cache_seqlens,
        conv_state_indices=conv_state_indices,
    )
    return result
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/metadata.json b/build/torch210-cxx11-cu130-aarch64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..eff725542128e103dfb5df382d74940efff77214
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/metadata.json
@@ -0,0 +1,19 @@
+{
+ "version": 1,
+ "license": "BSD-3-Clause",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "10.0",
+ "11.0",
+ "12.0+PTX",
+ "7.5",
+ "8.0",
+ "8.6",
+ "8.7",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/__init__.py b/build/torch210-cxx11-cu130-x86_64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/__init__.py
@@ -0,0 +1,4 @@
+from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
+from .causal_conv1d_varlen import causal_conv1d_varlen_states
+
+__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"]
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_causal_conv1d_cuda_6b83b83.abi3.so b/build/torch210-cxx11-cu130-x86_64-linux/_causal_conv1d_cuda_6b83b83.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..3cd522d7fff5ef2f0a3d3661c17f440e1e4031ed
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/_causal_conv1d_cuda_6b83b83.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5e5ec24f997ea256acf15d45d41acd47e841d26f50dab276b6f8f3600247501e
+size 64618472
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_ops.py b/build/torch210-cxx11-cu130-x86_64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..6796899661ef6f73609047ca344503d13ca050bd
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _causal_conv1d_cuda_6b83b83
+ops = torch.ops._causal_conv1d_cuda_6b83b83
+
def add_op_namespace_prefix(op_name: str):
    """Qualify *op_name* with this extension's torch.ops namespace."""
    return "_causal_conv1d_cuda_6b83b83::" + op_name
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/causal_conv1d/__init__.py b/build/torch210-cxx11-cu130-x86_64-linux/causal_conv1d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/causal_conv1d/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import importlib.util
+import sys
+from pathlib import Path
+from types import ModuleType
+
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/causal_conv1d_interface.py b/build/torch210-cxx11-cu130-x86_64-linux/causal_conv1d_interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/causal_conv1d_interface.py
@@ -0,0 +1,242 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+import torch.nn.functional as F
+
+from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function
+
+
class CausalConv1dFn(torch.autograd.Function):
    """autograd.Function wrapping the compiled causal depthwise-conv1d kernels.

    forward/backward dispatch to `causal_conv1d_fwd_function` /
    `causal_conv1d_bwd_function`; see `causal_conv1d_fn` for the user-facing
    signature and shape documentation.
    """

    @staticmethod
    def forward(
        ctx,
        x,
        weight,
        bias=None,
        seq_idx=None,
        initial_states=None,
        return_final_states=False,
        final_states_out=None,
        activation=None,
    ):
        if activation not in [None, "silu", "swish"]:
            raise NotImplementedError("activation must be None, silu, or swish")
        # Make x contiguous unless either seqlen or dim is already unit-stride.
        if x.stride(2) != 1 and x.stride(1) != 1:
            x = x.contiguous()
        bias = bias.contiguous() if bias is not None else None
        if seq_idx is not None:
            assert (
                initial_states is None
            ), "initial_states must be None if seq_idx is not None"
            assert (
                not return_final_states
            ), "If seq_idx is not None, we don't return final_states_out"
        seq_idx = seq_idx.contiguous() if seq_idx is not None else None
        if initial_states is not None and (
            initial_states.stride(2) != 1 and initial_states.stride(1) != 1
        ):
            initial_states = initial_states.contiguous()
        if return_final_states:
            assert (
                x.stride(1) == 1
            ), "Only channel-last layout support returning final_states_out"
            if final_states_out is not None:
                assert (
                    final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1
                )
            else:
                # Allocate (batch, width - 1, dim) storage viewed as
                # (batch, dim, width - 1) so the dim axis is unit-stride.
                batch, dim, seqlen = x.shape
                width = weight.shape[1]
                final_states_out = torch.empty(
                    batch, width - 1, dim, device=x.device, dtype=x.dtype
                ).transpose(1, 2)
        else:
            final_states_out = None
        # Stored as a bool: whether SiLU is fused into the kernel.
        ctx.activation = activation in ["silu", "swish"]
        out = causal_conv1d_fwd_function(
            x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
        )
        ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
        ctx.return_final_states = return_final_states
        ctx.return_dinitial_states = (
            initial_states is not None and initial_states.requires_grad
        )
        return out if not return_final_states else (out, final_states_out)

    @staticmethod
    def backward(ctx, dout, *args):
        x, weight, bias, seq_idx, initial_states = ctx.saved_tensors
        # When forward returned (out, final_states), args[0] carries dfinal_states.
        dfinal_states = args[0] if ctx.return_final_states else None
        if dout.stride(2) != 1 and dout.stride(1) != 1:
            dout = dout.contiguous()
        # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
        # backward of conv1d with the backward of chunk).
        # Here we just pass in None and dx is allocated inside
        # causal_conv1d_bwd_function.
        dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function(
            x,
            weight,
            bias,
            dout,
            seq_idx,
            initial_states,
            dfinal_states,
            None,
            ctx.return_dinitial_states,
            ctx.activation,
        )
        # One gradient slot per forward argument (None for the non-tensor ones).
        return (
            dx,
            dweight,
            dbias if bias is not None else None,
            None,
            dinitial_states if initial_states is not None else None,
            None,
            None,
            None,
        )
+
+
def causal_conv1d_fn(
    x,
    weight,
    bias=None,
    seq_idx=None,
    initial_states=None,
    return_final_states=False,
    final_states_out=None,
    activation=None,
):
    """
    Causal depthwise conv1d over the sequence axis (autograd-enabled).

    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    seq_idx: (batch, seqlen)
    initial_states: (batch, dim, width - 1)
    final_states_out: (batch, dim, width - 1), to be written to
    activation: either None or "silu" or "swish"

    out: (batch, dim, seqlen)
    """
    # autograd.Function.apply only accepts positional arguments.
    return CausalConv1dFn.apply(
        x, weight, bias, seq_idx, initial_states,
        return_final_states, final_states_out, activation,
    )
+
+
def causal_conv1d_ref(
    x,
    weight,
    bias=None,
    initial_states=None,
    return_final_states=False,
    final_states_out=None,
    activation=None,
):
    """
    Pure-PyTorch reference for `causal_conv1d_fn`.

    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    initial_states: (batch, dim, width - 1)
    final_states_out: (batch, dim, width - 1)

    out: (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    dtype_in = x.dtype
    x = x.to(weight.dtype)
    seqlen = x.shape[-1]
    dim, width = weight.shape
    if initial_states is None:
        # width-1 of left zero-padding makes the depthwise conv causal.
        out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
    else:
        # Prepend the provided states instead of zero padding.
        x = torch.cat([initial_states, x], dim=-1)
        out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
    out = out[..., :seqlen]
    if return_final_states:
        # Final states are the last width-1 inputs, left-padded with zeros
        # when the (possibly extended) input is shorter than width-1.
        final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(dtype_in)
        if final_states_out is None:
            final_states_out = final_states
        else:
            final_states_out.copy_(final_states)
    out = out if activation is None else F.silu(out)
    out = out.to(dtype=dtype_in)
    return (out, final_states_out) if return_final_states else out
+
+
def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None):
    """
    Decode-step causal conv1d; updates `conv_state` in place.

    x: (batch, dim) or (batch, dim, seqlen)
    conv_state: (batch, dim, state_len), where state_len >= width - 1
    weight: (dim, width)
    bias: (dim,)
    cache_seqlens: (batch,), dtype int32.
        If not None, the conv_state is treated as a circular buffer.
        The conv_state will be updated by copying x to the conv_state starting
        at the index @cache_seqlens % state_len.
    conv_state_indices: (batch,), dtype int32
        If not None, the conv_state is a larger tensor along the batch dim,
        and we are selecting the batch coords specified by conv_state_indices.
        Useful for a continuous batching scenario.

    out: (batch, dim) or (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    use_silu = activation in ["silu", "swish"]
    # The kernel works on (batch, dim, seqlen); promote single-step input.
    squeeze_out = x.dim() == 2
    if squeeze_out:
        x = x.unsqueeze(-1)
    out = causal_conv1d_update_function(
        x, conv_state, weight, bias, use_silu, cache_seqlens, conv_state_indices
    )
    return out.squeeze(-1) if squeeze_out else out
+
+
def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
    """
    Pure-PyTorch reference for `causal_conv1d_update`; updates `conv_state` in place.

    x: (batch, dim) or (batch, dim, seqlen)
    conv_state: (batch, dim, state_len), where state_len >= width - 1
    weight: (dim, width)
    bias: (dim,)
    cache_seqlens: (batch,), dtype int32.
        If not None, the conv_state is treated as a circular buffer.
        The conv_state will be updated by copying x to the conv_state starting
        at the index @cache_seqlens % state_len before performing the convolution.

    out: (batch, dim) or (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    dtype_in = x.dtype
    squeeze_out = x.dim() == 2
    if squeeze_out:
        x = x.unsqueeze(-1)
    batch, dim, seqlen = x.shape
    width = weight.shape[1]
    state_len = conv_state.shape[-1]
    assert conv_state.shape == (batch, dim, state_len)
    assert weight.shape == (dim, width)
    if cache_seqlens is None:
        # Shift-register update: append x and keep the trailing state_len entries.
        conv_input = torch.cat([conv_state, x], dim=-1).to(weight.dtype)  # (batch, dim, state_len + seqlen)
        conv_state.copy_(conv_input[:, :, -state_len:])
    else:
        # Circular-buffer update: gather the width-1 entries preceding each
        # cache position, then scatter x back in at cache_seqlens % state_len.
        hist_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
        hist_idx = torch.remainder(hist_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
        conv_input = torch.cat([conv_state.gather(2, hist_idx), x], dim=-1).to(weight.dtype)
        write_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
        write_idx = torch.remainder(write_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
        conv_state.scatter_(2, write_idx, x)
    out = F.conv1d(conv_input, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:]
    if squeeze_out:
        out = out.squeeze(-1)
    if activation is not None:
        out = F.silu(out)
    return out.to(dtype=dtype_in)
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/causal_conv1d_varlen.py b/build/torch210-cxx11-cu130-x86_64-linux/causal_conv1d_varlen.py
new file mode 100644
index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/causal_conv1d_varlen.py
@@ -0,0 +1,86 @@
+import torch
+from torch import Tensor
+
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _causal_conv1d_varlen_states(
+ X,
+ CU_SEQLENS,
+ STATES,
+ state_len,
+ dim,
+ stride_x_seqlen, stride_x_dim,
+ stride_states_batch, stride_states_seqlen, stride_states_dim,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr
+):
+ batch_idx = tl.program_id(2)
+ STATES += batch_idx * stride_states_batch
+ end_idx = tl.load(CU_SEQLENS + batch_idx + 1)
+ start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)
+ rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
+ cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
+ x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,
+ mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim),
+ other=0)
+ rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
+ tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,
+ x,
+ mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim))
+
+
+def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
+ """
+ Forward pass only, does not support backward pass.
+ Parameters:
+ x: (total_tokens, dim)
+ cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
+ state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
+ If some of those elements belong to a different sequence, the value of the states will be zero.
+ Return:
+ states: (batch, dim, state_len)
+ """
+ _, dim = x.shape
+ batch = cu_seqlens.shape[0] - 1
+ cu_seqlens = cu_seqlens.contiguous()
+ states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
+ BLOCK_M = min(triton.next_power_of_2(state_len), 16)
+ BLOCK_N = min(triton.next_power_of_2(dim), 256)
+ grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch)
+ with torch.cuda.device(x.device.index):
+ _causal_conv1d_varlen_states[grid](
+ x,
+ cu_seqlens,
+ states,
+ state_len,
+ dim,
+ x.stride(0), x.stride(1),
+ states.stride(0), states.stride(2), states.stride(1),
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
+ )
+ return states
+
+
+def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
+ """
+ Forward pass only, does not support backward pass.
+ Parameters:
+ x: (total_tokens, dim)
+ cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
+ state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
+ If some of those elements belong to a different sequence, the value of the states will be zero.
+ Return:
+ states: (batch, dim, state_len)
+ """
+ _, dim = x.shape
+ batch = cu_seqlens.shape[0] - 1
+ cu_seqlens = cu_seqlens.contiguous()
+ states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
+ for i in range(batch):
+ end_idx = cu_seqlens[i + 1]
+ start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len)
+ states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T
+ return states
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/cpp_functions.py b/build/torch210-cxx11-cu130-x86_64-linux/cpp_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/cpp_functions.py
@@ -0,0 +1,96 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+
+from ._ops import ops
+
+def causal_conv1d_fwd_function(
+ x: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None,
+ seq_idx: torch.Tensor | None,
+ initial_states: torch.Tensor | None,
+ final_states_out: torch.Tensor | None,
+ silu_activation: bool,
+) -> torch.Tensor:
+ out = torch.empty_like(x)
+ ops.causal_conv1d_fwd(
+ x=x,
+ weight=weight,
+ bias=bias,
+ seq_idx=seq_idx,
+ initial_states=initial_states,
+ out=out,
+ final_states_out=final_states_out,
+ silu_activation=silu_activation,
+ )
+ return out
+
+
+def causal_conv1d_bwd_function(
+ x: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None,
+ dout: torch.Tensor,
+ seq_idx: torch.Tensor | None,
+ initial_states: torch.Tensor | None,
+ dfinal_states: torch.Tensor | None,
+ dx: torch.Tensor | None,
+ return_dinitial_states: torch.Tensor,
+ silu_activation: bool,
+) -> tuple[torch.Tensor | None]:
+ batch_size, dim = x.size()[:2]
+ width = weight.size(-1)
+
+ if dx is None:
+ dx = torch.empty_like(x)
+ dweight = torch.zeros_like(weight, dtype=torch.float32)
+ dbias = None
+ if bias is not None:
+ dbias = torch.zeros_like(bias, dtype=torch.float32)
+ dinitial_states = None
+ if return_dinitial_states:
+ dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2)
+
+ ops.causal_conv1d_bwd(
+ x=x,
+ weight=weight,
+ bias=bias,
+ dout=dout,
+ seq_idx=seq_idx,
+ initial_states=initial_states,
+ dfinal_states=dfinal_states,
+ dx=dx,
+ dweight=dweight,
+ dbias=dbias,
+ dinitial_states=dinitial_states,
+ silu_activation=silu_activation,
+ )
+
+ dweight = dweight.type_as(weight)
+ if dbias is not None:
+ dbias = dbias.type_as(bias)
+ return dx, dweight, dbias, dinitial_states
+
+
+def causal_conv1d_update_function(
+ x: torch.Tensor,
+ conv_state: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None,
+ silu_activation: bool,
+ cache_seqlens: torch.Tensor | None,
+ conv_state_indices: torch.Tensor | None,
+) -> torch.Tensor:
+ out = torch.empty_like(x)
+ ops.causal_conv1d_update(
+ x=x,
+ conv_state=conv_state,
+ weight=weight,
+ bias=bias,
+ out=out,
+ silu_activation=silu_activation,
+ cache_seqlens=cache_seqlens,
+ conv_state_indices=conv_state_indices,
+ )
+ return out
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/metadata.json b/build/torch210-cxx11-cu130-x86_64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..eff725542128e103dfb5df382d74940efff77214
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/metadata.json
@@ -0,0 +1,19 @@
+{
+ "version": 1,
+ "license": "BSD-3-Clause",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "10.0",
+ "11.0",
+ "12.0+PTX",
+ "7.5",
+ "8.0",
+ "8.6",
+ "8.7",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/__init__.py b/build/torch211-cxx11-cu126-aarch64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/__init__.py
@@ -0,0 +1,4 @@
+from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
+from .causal_conv1d_varlen import causal_conv1d_varlen_states
+
+__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"]
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/_causal_conv1d_cuda_6b83b83.abi3.so b/build/torch211-cxx11-cu126-aarch64-linux/_causal_conv1d_cuda_6b83b83.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..7d6a25eecfefa85762ba0a837eb7c9e3cc82fc31
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/_causal_conv1d_cuda_6b83b83.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:54d057ef1f5e12e7715f8dfd2879190e4ae82c170ae0d412bc1696be8031d1ef
+size 80857352
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/_ops.py b/build/torch211-cxx11-cu126-aarch64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..6796899661ef6f73609047ca344503d13ca050bd
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _causal_conv1d_cuda_6b83b83
+ops = torch.ops._causal_conv1d_cuda_6b83b83
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_causal_conv1d_cuda_6b83b83::{op_name}"
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/causal_conv1d/__init__.py b/build/torch211-cxx11-cu126-aarch64-linux/causal_conv1d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/causal_conv1d/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import importlib.util
+import sys
+from pathlib import Path
+from types import ModuleType
+
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/causal_conv1d_interface.py b/build/torch211-cxx11-cu126-aarch64-linux/causal_conv1d_interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/causal_conv1d_interface.py
@@ -0,0 +1,242 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+import torch.nn.functional as F
+
+from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function
+
+
+class CausalConv1dFn(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ x,
+ weight,
+ bias=None,
+ seq_idx=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+ ):
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ if x.stride(2) != 1 and x.stride(1) != 1:
+ x = x.contiguous()
+ bias = bias.contiguous() if bias is not None else None
+ if seq_idx is not None:
+ assert (
+ initial_states is None
+ ), "initial_states must be None if seq_idx is not None"
+ assert (
+ not return_final_states
+ ), "If seq_idx is not None, we don't return final_states_out"
+ seq_idx = seq_idx.contiguous() if seq_idx is not None else None
+ if initial_states is not None and (
+ initial_states.stride(2) != 1 and initial_states.stride(1) != 1
+ ):
+ initial_states = initial_states.contiguous()
+ if return_final_states:
+ assert (
+ x.stride(1) == 1
+ ), "Only channel-last layout support returning final_states_out"
+ if final_states_out is not None:
+ assert (
+ final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1
+ )
+ else:
+ batch, dim, seqlen = x.shape
+ width = weight.shape[1]
+ final_states_out = torch.empty(
+ batch, width - 1, dim, device=x.device, dtype=x.dtype
+ ).transpose(1, 2)
+ else:
+ final_states_out = None
+ ctx.activation = activation in ["silu", "swish"]
+ out = causal_conv1d_fwd_function(
+ x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
+ )
+ ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
+ ctx.return_final_states = return_final_states
+ ctx.return_dinitial_states = (
+ initial_states is not None and initial_states.requires_grad
+ )
+ return out if not return_final_states else (out, final_states_out)
+
+ @staticmethod
+ def backward(ctx, dout, *args):
+ x, weight, bias, seq_idx, initial_states = ctx.saved_tensors
+ dfinal_states = args[0] if ctx.return_final_states else None
+ if dout.stride(2) != 1 and dout.stride(1) != 1:
+ dout = dout.contiguous()
+ # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
+ # backward of conv1d with the backward of chunk).
+ # Here we just pass in None and dx will be allocated in the C++ code.
+ dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function(
+ x,
+ weight,
+ bias,
+ dout,
+ seq_idx,
+ initial_states,
+ dfinal_states,
+ None,
+ ctx.return_dinitial_states,
+ ctx.activation,
+ )
+ return (
+ dx,
+ dweight,
+ dbias if bias is not None else None,
+ None,
+ dinitial_states if initial_states is not None else None,
+ None,
+ None,
+ None,
+ )
+
+
+def causal_conv1d_fn(
+ x,
+ weight,
+ bias=None,
+ seq_idx=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+):
+ """
+ x: (batch, dim, seqlen)
+ weight: (dim, width)
+ bias: (dim,)
+ seq_idx: (batch, seqlen)
+ initial_states: (batch, dim, width - 1)
+ final_states_out: (batch, dim, width - 1), to be written to
+ activation: either None or "silu" or "swish"
+
+ out: (batch, dim, seqlen)
+ """
+ return CausalConv1dFn.apply(
+ x,
+ weight,
+ bias,
+ seq_idx,
+ initial_states,
+ return_final_states,
+ final_states_out,
+ activation,
+ )
+
+
+def causal_conv1d_ref(
+ x,
+ weight,
+ bias=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+):
+ """
+ x: (batch, dim, seqlen)
+ weight: (dim, width)
+ bias: (dim,)
+ initial_states: (batch, dim, width - 1)
+ final_states_out: (batch, dim, width - 1)
+
+ out: (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ dtype_in = x.dtype
+ x = x.to(weight.dtype)
+ seqlen = x.shape[-1]
+ dim, width = weight.shape
+ if initial_states is None:
+ out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
+ else:
+ x = torch.cat([initial_states, x], dim=-1)
+ out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
+ out = out[..., :seqlen]
+ if return_final_states:
+ final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
+ dtype_in
+ ) # (batch, dim, width - 1)
+ if final_states_out is not None:
+ final_states_out.copy_(final_states)
+ else:
+ final_states_out = final_states
+ out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
+ return out if not return_final_states else (out, final_states_out)
+
+
+def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None):
+ """
+ x: (batch, dim) or (batch, dim, seqlen)
+ conv_state: (batch, dim, state_len), where state_len >= width - 1
+ weight: (dim, width)
+ bias: (dim,)
+ cache_seqlens: (batch,), dtype int32.
+ If not None, the conv_state is treated as a circular buffer.
+ The conv_state will be updated by copying x to the conv_state starting at the index
+ @cache_seqlens % state_len.
+ conv_state_indices: (batch,), dtype int32
+ If not None, the conv_state is a larger tensor along the batch dim,
+ and we are selecting the batch coords specified by conv_state_indices.
+ Useful for a continuous batching scenario.
+
+ out: (batch, dim) or (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ activation = activation in ["silu", "swish"]
+ unsqueeze = x.dim() == 2
+ if unsqueeze:
+ x = x.unsqueeze(-1)
+ out = causal_conv1d_update_function(
+ x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices
+ )
+ if unsqueeze:
+ out = out.squeeze(-1)
+ return out
+
+
+def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
+ """
+ x: (batch, dim) or (batch, dim, seqlen)
+ conv_state: (batch, dim, state_len), where state_len >= width - 1
+ weight: (dim, width)
+ bias: (dim,)
+ cache_seqlens: (batch,), dtype int32.
+ If not None, the conv_state is treated as a circular buffer.
+ The conv_state will be updated by copying x to the conv_state starting at the index
+ @cache_seqlens % state_len before performing the convolution.
+
+ out: (batch, dim) or (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ dtype_in = x.dtype
+ unsqueeze = x.dim() == 2
+ if unsqueeze:
+ x = x.unsqueeze(-1)
+ batch, dim, seqlen = x.shape
+ width = weight.shape[1]
+ state_len = conv_state.shape[-1]
+ assert conv_state.shape == (batch, dim, state_len)
+ assert weight.shape == (dim, width)
+ if cache_seqlens is None:
+ x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen)
+ conv_state.copy_(x_new[:, :, -state_len:])
+ else:
+ width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
+ width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
+ x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
+ copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
+ copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
+ conv_state.scatter_(2, copy_idx, x)
+ out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:]
+ if unsqueeze:
+ out = out.squeeze(-1)
+ return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/causal_conv1d_varlen.py b/build/torch211-cxx11-cu126-aarch64-linux/causal_conv1d_varlen.py
new file mode 100644
index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/causal_conv1d_varlen.py
@@ -0,0 +1,86 @@
+import torch
+from torch import Tensor
+
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _causal_conv1d_varlen_states(
+ X,
+ CU_SEQLENS,
+ STATES,
+ state_len,
+ dim,
+ stride_x_seqlen, stride_x_dim,
+ stride_states_batch, stride_states_seqlen, stride_states_dim,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr
+):
+ batch_idx = tl.program_id(2)
+ STATES += batch_idx * stride_states_batch
+ end_idx = tl.load(CU_SEQLENS + batch_idx + 1)
+ start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)
+ rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
+ cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
+ x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,
+ mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim),
+ other=0)
+ rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
+ tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,
+ x,
+ mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim))
+
+
+def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
+ """
+ Forward pass only, does not support backward pass.
+ Parameters:
+ x: (total_tokens, dim)
+ cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
+ state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
+ If some of those elements belong to a different sequence, the value of the states will be zero.
+ Return:
+ states: (batch, dim, state_len)
+ """
+ _, dim = x.shape
+ batch = cu_seqlens.shape[0] - 1
+ cu_seqlens = cu_seqlens.contiguous()
+ states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
+ BLOCK_M = min(triton.next_power_of_2(state_len), 16)
+ BLOCK_N = min(triton.next_power_of_2(dim), 256)
+ grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch)
+ with torch.cuda.device(x.device.index):
+ _causal_conv1d_varlen_states[grid](
+ x,
+ cu_seqlens,
+ states,
+ state_len,
+ dim,
+ x.stride(0), x.stride(1),
+ states.stride(0), states.stride(2), states.stride(1),
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
+ )
+ return states
+
+
+def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
+ """
+ Forward pass only, does not support backward pass.
+ Parameters:
+ x: (total_tokens, dim)
+ cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
+ state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
+ If some of those elements belong to a different sequence, the value of the states will be zero.
+ Return:
+ states: (batch, dim, state_len)
+ """
+ _, dim = x.shape
+ batch = cu_seqlens.shape[0] - 1
+ cu_seqlens = cu_seqlens.contiguous()
+ states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
+ for i in range(batch):
+ end_idx = cu_seqlens[i + 1]
+ start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len)
+ states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T
+ return states
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/cpp_functions.py b/build/torch211-cxx11-cu126-aarch64-linux/cpp_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/cpp_functions.py
@@ -0,0 +1,96 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+
+from ._ops import ops
+
+def causal_conv1d_fwd_function(
+ x: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None,
+ seq_idx: torch.Tensor | None,
+ initial_states: torch.Tensor | None,
+ final_states_out: torch.Tensor | None,
+ silu_activation: bool,
+) -> torch.Tensor:
+ out = torch.empty_like(x)
+ ops.causal_conv1d_fwd(
+ x=x,
+ weight=weight,
+ bias=bias,
+ seq_idx=seq_idx,
+ initial_states=initial_states,
+ out=out,
+ final_states_out=final_states_out,
+ silu_activation=silu_activation,
+ )
+ return out
+
+
+def causal_conv1d_bwd_function(
+ x: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None,
+ dout: torch.Tensor,
+ seq_idx: torch.Tensor | None,
+ initial_states: torch.Tensor | None,
+ dfinal_states: torch.Tensor | None,
+ dx: torch.Tensor | None,
+ return_dinitial_states: torch.Tensor,
+ silu_activation: bool,
+) -> tuple[torch.Tensor | None]:
+ batch_size, dim = x.size()[:2]
+ width = weight.size(-1)
+
+ if dx is None:
+ dx = torch.empty_like(x)
+ dweight = torch.zeros_like(weight, dtype=torch.float32)
+ dbias = None
+ if bias is not None:
+ dbias = torch.zeros_like(bias, dtype=torch.float32)
+ dinitial_states = None
+ if return_dinitial_states:
+ dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2)
+
+ ops.causal_conv1d_bwd(
+ x=x,
+ weight=weight,
+ bias=bias,
+ dout=dout,
+ seq_idx=seq_idx,
+ initial_states=initial_states,
+ dfinal_states=dfinal_states,
+ dx=dx,
+ dweight=dweight,
+ dbias=dbias,
+ dinitial_states=dinitial_states,
+ silu_activation=silu_activation,
+ )
+
+ dweight = dweight.type_as(weight)
+ if dbias is not None:
+ dbias = dbias.type_as(bias)
+ return dx, dweight, dbias, dinitial_states
+
+
+def causal_conv1d_update_function(
+ x: torch.Tensor,
+ conv_state: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None,
+ silu_activation: bool,
+ cache_seqlens: torch.Tensor | None,
+ conv_state_indices: torch.Tensor | None,
+) -> torch.Tensor:
+ out = torch.empty_like(x)
+ ops.causal_conv1d_update(
+ x=x,
+ conv_state=conv_state,
+ weight=weight,
+ bias=bias,
+ out=out,
+ silu_activation=silu_activation,
+ cache_seqlens=cache_seqlens,
+ conv_state_indices=conv_state_indices,
+ )
+ return out
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/metadata.json b/build/torch211-cxx11-cu126-aarch64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..0dacb99125f1112a811819ca1ffdde15c8c0faff
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/metadata.json
@@ -0,0 +1,18 @@
+{
+ "version": 1,
+ "license": "BSD-3-Clause",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "7.0",
+ "7.2",
+ "7.5",
+ "8.0",
+ "8.6",
+ "8.7",
+ "8.9",
+ "9.0+PTX"
+ ]
+ }
+}
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/__init__.py b/build/torch211-cxx11-cu126-x86_64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/__init__.py
@@ -0,0 +1,4 @@
+from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
+from .causal_conv1d_varlen import causal_conv1d_varlen_states
+
+__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"]
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/_causal_conv1d_cuda_6b83b83.abi3.so b/build/torch211-cxx11-cu126-x86_64-linux/_causal_conv1d_cuda_6b83b83.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..f5466087716c10109e093c4f6dc98cb1400ea837
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/_causal_conv1d_cuda_6b83b83.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:71ecf103b1b26f969ecd7734e196f24210bb5e2937f63b9f65e2e127fd5e8e5f
+size 80694560
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/_ops.py b/build/torch211-cxx11-cu126-x86_64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..6796899661ef6f73609047ca344503d13ca050bd
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _causal_conv1d_cuda_6b83b83
+ops = torch.ops._causal_conv1d_cuda_6b83b83
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_causal_conv1d_cuda_6b83b83::{op_name}"
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py b/build/torch211-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import importlib.util
+import sys
+from pathlib import Path
+from types import ModuleType
+
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/causal_conv1d_interface.py b/build/torch211-cxx11-cu126-x86_64-linux/causal_conv1d_interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/causal_conv1d_interface.py
@@ -0,0 +1,242 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+import torch.nn.functional as F
+
+from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function
+
+
+class CausalConv1dFn(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ x,
+ weight,
+ bias=None,
+ seq_idx=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+ ):
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ if x.stride(2) != 1 and x.stride(1) != 1:
+ x = x.contiguous()
+ bias = bias.contiguous() if bias is not None else None
+ if seq_idx is not None:
+ assert (
+ initial_states is None
+ ), "initial_states must be None if seq_idx is not None"
+ assert (
+ not return_final_states
+ ), "If seq_idx is not None, we don't return final_states_out"
+ seq_idx = seq_idx.contiguous() if seq_idx is not None else None
+ if initial_states is not None and (
+ initial_states.stride(2) != 1 and initial_states.stride(1) != 1
+ ):
+ initial_states = initial_states.contiguous()
+ if return_final_states:
+ assert (
+ x.stride(1) == 1
+ ), "Only channel-last layout support returning final_states_out"
+ if final_states_out is not None:
+ assert (
+ final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1
+ )
+ else:
+ batch, dim, seqlen = x.shape
+ width = weight.shape[1]
+ final_states_out = torch.empty(
+ batch, width - 1, dim, device=x.device, dtype=x.dtype
+ ).transpose(1, 2)
+ else:
+ final_states_out = None
+ ctx.activation = activation in ["silu", "swish"]
+ out = causal_conv1d_fwd_function(
+ x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
+ )
+ ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
+ ctx.return_final_states = return_final_states
+ ctx.return_dinitial_states = (
+ initial_states is not None and initial_states.requires_grad
+ )
+ return out if not return_final_states else (out, final_states_out)
+
+ @staticmethod
+ def backward(ctx, dout, *args):
+ x, weight, bias, seq_idx, initial_states = ctx.saved_tensors
+ dfinal_states = args[0] if ctx.return_final_states else None
+ if dout.stride(2) != 1 and dout.stride(1) != 1:
+ dout = dout.contiguous()
+ # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
+ # backward of conv1d with the backward of chunk).
+ # Here we just pass in None and dx will be allocated in the C++ code.
+ dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function(
+ x,
+ weight,
+ bias,
+ dout,
+ seq_idx,
+ initial_states,
+ dfinal_states,
+ None,
+ ctx.return_dinitial_states,
+ ctx.activation,
+ )
+ return (
+ dx,
+ dweight,
+ dbias if bias is not None else None,
+ None,
+ dinitial_states if initial_states is not None else None,
+ None,
+ None,
+ None,
+ )
+
+
+def causal_conv1d_fn(
+ x,
+ weight,
+ bias=None,
+ seq_idx=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+):
+ """
+ x: (batch, dim, seqlen)
+ weight: (dim, width)
+ bias: (dim,)
+ seq_idx: (batch, seqlen)
+ initial_states: (batch, dim, width - 1)
+ final_states_out: (batch, dim, width - 1), to be written to
+ activation: either None or "silu" or "swish"
+
+ out: (batch, dim, seqlen)
+ """
+ return CausalConv1dFn.apply(
+ x,
+ weight,
+ bias,
+ seq_idx,
+ initial_states,
+ return_final_states,
+ final_states_out,
+ activation,
+ )
+
+
+def causal_conv1d_ref(
+ x,
+ weight,
+ bias=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+):
+ """
+ x: (batch, dim, seqlen)
+ weight: (dim, width)
+ bias: (dim,)
+ initial_states: (batch, dim, width - 1)
+ final_states_out: (batch, dim, width - 1)
+
+ out: (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ dtype_in = x.dtype
+ x = x.to(weight.dtype)
+ seqlen = x.shape[-1]
+ dim, width = weight.shape
+ if initial_states is None:
+ out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
+ else:
+ x = torch.cat([initial_states, x], dim=-1)
+ out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
+ out = out[..., :seqlen]
+ if return_final_states:
+ final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
+ dtype_in
+ ) # (batch, dim, width - 1)
+ if final_states_out is not None:
+ final_states_out.copy_(final_states)
+ else:
+ final_states_out = final_states
+ out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
+ return out if not return_final_states else (out, final_states_out)
+
+
+def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None):
+ """
+ x: (batch, dim) or (batch, dim, seqlen)
+ conv_state: (batch, dim, state_len), where state_len >= width - 1
+ weight: (dim, width)
+ bias: (dim,)
+ cache_seqlens: (batch,), dtype int32.
+ If not None, the conv_state is treated as a circular buffer.
+ The conv_state will be updated by copying x to the conv_state starting at the index
+ @cache_seqlens % state_len.
+ conv_state_indices: (batch,), dtype int32
+ If not None, the conv_state is a larger tensor along the batch dim,
+ and we are selecting the batch coords specified by conv_state_indices.
+ Useful for a continuous batching scenario.
+
+ out: (batch, dim) or (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ activation = activation in ["silu", "swish"]
+ unsqueeze = x.dim() == 2
+ if unsqueeze:
+ x = x.unsqueeze(-1)
+ out = causal_conv1d_update_function(
+ x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices
+ )
+ if unsqueeze:
+ out = out.squeeze(-1)
+ return out
+
+
+def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
+ """
+ x: (batch, dim) or (batch, dim, seqlen)
+ conv_state: (batch, dim, state_len), where state_len >= width - 1
+ weight: (dim, width)
+ bias: (dim,)
+ cache_seqlens: (batch,), dtype int32.
+ If not None, the conv_state is treated as a circular buffer.
+ The conv_state will be updated by copying x to the conv_state starting at the index
+ @cache_seqlens % state_len before performing the convolution.
+
+ out: (batch, dim) or (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ dtype_in = x.dtype
+ unsqueeze = x.dim() == 2
+ if unsqueeze:
+ x = x.unsqueeze(-1)
+ batch, dim, seqlen = x.shape
+ width = weight.shape[1]
+ state_len = conv_state.shape[-1]
+ assert conv_state.shape == (batch, dim, state_len)
+ assert weight.shape == (dim, width)
+ if cache_seqlens is None:
+ x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen)
+ conv_state.copy_(x_new[:, :, -state_len:])
+ else:
+ width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
+ width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
+ x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
+ copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
+ copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
+ conv_state.scatter_(2, copy_idx, x)
+ out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:]
+ if unsqueeze:
+ out = out.squeeze(-1)
+ return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/causal_conv1d_varlen.py b/build/torch211-cxx11-cu126-x86_64-linux/causal_conv1d_varlen.py
new file mode 100644
index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/causal_conv1d_varlen.py
@@ -0,0 +1,86 @@
+import torch
+from torch import Tensor
+
+import triton
+import triton.language as tl
+
+
@triton.jit
def _causal_conv1d_varlen_states(
    X,
    CU_SEQLENS,
    STATES,
    state_len,
    dim,
    stride_x_seqlen, stride_x_dim,
    stride_states_batch, stride_states_seqlen, stride_states_dim,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr
):
    # Copy the last `state_len` tokens of each packed sequence in X into
    # STATES[batch], right-aligned. Grid: program_id(0) tiles the feature dim
    # (BLOCK_N cols), program_id(1) tiles the state positions (BLOCK_M rows),
    # program_id(2) is the batch index.
    batch_idx = tl.program_id(2)
    STATES += batch_idx * stride_states_batch
    end_idx = tl.load(CU_SEQLENS + batch_idx + 1)
    # At most `state_len` tokens are taken, counted back from the sequence end.
    start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)
    # Token rows are enumerated backwards from end_idx so the copy lands
    # right-aligned in the state buffer.
    rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
    # Rows before start_idx load 0 (other=0) and are still stored below, so
    # state slots with no corresponding token of this sequence become zero.
    x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,
                mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim),
                other=0)
    rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
    tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,
             x,
             mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim))
+
+
def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
    """
    Forward pass only, does not support backward pass.
    Parameters:
        x: (total_tokens, dim)
        cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
        state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
            If some of those elements belong to a different sequence, the value of the states will be zero.
    Return:
        states: (batch, dim, state_len)
    """
    _, dim = x.shape
    batch = cu_seqlens.shape[0] - 1
    cu_seqlens = cu_seqlens.contiguous()
    # Allocate (batch, state_len, dim) storage and view it as
    # (batch, dim, state_len): channel-last layout, dim axis contiguous.
    states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
    BLOCK_M = min(triton.next_power_of_2(state_len), 16)
    BLOCK_N = min(triton.next_power_of_2(dim), 256)
    grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch)
    # Launch on x's device so multi-GPU callers don't hit the default device.
    with torch.cuda.device(x.device.index):
        _causal_conv1d_varlen_states[grid](
            x,
            cu_seqlens,
            states,
            state_len,
            dim,
            x.stride(0), x.stride(1),
            # Note the swap: stride(2) is the state_len ("seqlen") axis and
            # stride(1) the dim axis of the transposed view.
            states.stride(0), states.stride(2), states.stride(1),
            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
        )
    return states
+
+
def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
    """
    Reference (pure PyTorch) implementation of causal_conv1d_varlen_states.

    Forward pass only, does not support backward pass.
    Parameters:
        x: (total_tokens, dim)
        cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
        state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
            If some of those elements belong to a different sequence, the value of the states will be zero.
    Return:
        states: (batch, dim, state_len)
    """
    _, dim = x.shape
    batch = cu_seqlens.shape[0] - 1
    cu_seqlens = cu_seqlens.contiguous()
    # (batch, state_len, dim) storage viewed channel-last as (batch, dim, state_len);
    # zero-init so slots without a corresponding token stay zero.
    states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
    for i in range(batch):
        end_idx = int(cu_seqlens[i + 1])
        # Copy at most the last `state_len` tokens of sequence i.
        start_idx = max(int(cu_seqlens[i]), end_idx - state_len)
        length = end_idx - start_idx
        # Guard the empty-sequence case: `states[i, :, -0:]` would select the
        # whole buffer and fail to broadcast against the empty slice of x.
        if length > 0:
            states[i, :, -length:] = x[start_idx:end_idx].T
    return states
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/cpp_functions.py b/build/torch211-cxx11-cu126-x86_64-linux/cpp_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/cpp_functions.py
@@ -0,0 +1,96 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+
+from ._ops import ops
+
def causal_conv1d_fwd_function(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    seq_idx: torch.Tensor | None,
    initial_states: torch.Tensor | None,
    final_states_out: torch.Tensor | None,
    silu_activation: bool,
) -> torch.Tensor:
    """Thin wrapper over the compiled `causal_conv1d_fwd` CUDA op.

    Allocates the output tensor and forwards all arguments by keyword; the op
    fills `out` in place (and `final_states_out`, when provided — presumably
    with the trailing width-1 inputs; see the reference implementation).
    """
    out = torch.empty_like(x)
    ops.causal_conv1d_fwd(
        x=x,
        weight=weight,
        bias=bias,
        seq_idx=seq_idx,
        initial_states=initial_states,
        out=out,
        final_states_out=final_states_out,
        silu_activation=silu_activation,
    )
    return out
+
+
def causal_conv1d_bwd_function(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    dout: torch.Tensor,
    seq_idx: torch.Tensor | None,
    initial_states: torch.Tensor | None,
    dfinal_states: torch.Tensor | None,
    dx: torch.Tensor | None,
    return_dinitial_states: bool,
    silu_activation: bool,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """Wrapper over the compiled `causal_conv1d_bwd` CUDA op.

    Allocates the gradient buffers (dx only when the caller did not pass a
    pre-allocated one), lets the op fill them in place, and returns
    (dx, dweight, dbias, dinitial_states).
    """
    batch_size, dim = x.size()[:2]
    width = weight.size(-1)

    if dx is None:
        dx = torch.empty_like(x)
    # Weight/bias gradients are accumulated in float32 and cast back to the
    # parameter dtype below.
    dweight = torch.zeros_like(weight, dtype=torch.float32)
    dbias = None
    if bias is not None:
        dbias = torch.zeros_like(bias, dtype=torch.float32)
    dinitial_states = None
    if return_dinitial_states:
        # Channel-last allocation: (batch, width - 1, dim) storage exposed as
        # (batch, dim, width - 1).
        dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2)

    ops.causal_conv1d_bwd(
        x=x,
        weight=weight,
        bias=bias,
        dout=dout,
        seq_idx=seq_idx,
        initial_states=initial_states,
        dfinal_states=dfinal_states,
        dx=dx,
        dweight=dweight,
        dbias=dbias,
        dinitial_states=dinitial_states,
        silu_activation=silu_activation,
    )

    dweight = dweight.type_as(weight)
    if dbias is not None:
        dbias = dbias.type_as(bias)
    return dx, dweight, dbias, dinitial_states
+
+
def causal_conv1d_update_function(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    silu_activation: bool,
    cache_seqlens: torch.Tensor | None,
    conv_state_indices: torch.Tensor | None,
) -> torch.Tensor:
    """Wrapper over the compiled `causal_conv1d_update` CUDA op.

    Allocates the output tensor and forwards all arguments by keyword; the op
    fills `out` in place (and presumably updates `conv_state` in place, as the
    reference implementation does — confirm against the kernel).
    """
    out = torch.empty_like(x)
    ops.causal_conv1d_update(
        x=x,
        conv_state=conv_state,
        weight=weight,
        bias=bias,
        out=out,
        silu_activation=silu_activation,
        cache_seqlens=cache_seqlens,
        conv_state_indices=conv_state_indices,
    )
    return out
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/metadata.json b/build/torch211-cxx11-cu126-x86_64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..0dacb99125f1112a811819ca1ffdde15c8c0faff
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/metadata.json
@@ -0,0 +1,18 @@
+{
+ "version": 1,
+ "license": "BSD-3-Clause",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "7.0",
+ "7.2",
+ "7.5",
+ "8.0",
+ "8.6",
+ "8.7",
+ "8.9",
+ "9.0+PTX"
+ ]
+ }
+}
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/__init__.py b/build/torch211-cxx11-cu128-aarch64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/__init__.py
@@ -0,0 +1,4 @@
+from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
+from .causal_conv1d_varlen import causal_conv1d_varlen_states
+
+__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"]
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/_causal_conv1d_cuda_6b83b83.abi3.so b/build/torch211-cxx11-cu128-aarch64-linux/_causal_conv1d_cuda_6b83b83.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..df1664c45faf3611842b70b761f4ca490042413b
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/_causal_conv1d_cuda_6b83b83.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:69e30c7a01b7d693affe82ae3556b7d8428c095d86fc5dca1cf4014d3ceff43b
+size 107312776
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/_ops.py b/build/torch211-cxx11-cu128-aarch64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..6796899661ef6f73609047ca344503d13ca050bd
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _causal_conv1d_cuda_6b83b83
+ops = torch.ops._causal_conv1d_cuda_6b83b83
+
def add_op_namespace_prefix(op_name: str):
    """Qualify *op_name* with this extension's `torch.ops` namespace."""
    return "_causal_conv1d_cuda_6b83b83::" + op_name
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/causal_conv1d/__init__.py b/build/torch211-cxx11-cu128-aarch64-linux/causal_conv1d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/causal_conv1d/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import importlib.util
+import sys
+from pathlib import Path
+from types import ModuleType
+
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/causal_conv1d_interface.py b/build/torch211-cxx11-cu128-aarch64-linux/causal_conv1d_interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/causal_conv1d_interface.py
@@ -0,0 +1,242 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+import torch.nn.functional as F
+
+from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function
+
+
class CausalConv1dFn(torch.autograd.Function):
    """Autograd Function wrapping the fused causal depthwise conv1d CUDA ops.

    forward dispatches to the compiled forward kernel; backward dispatches to
    the compiled backward kernel and maps its outputs back onto the eight
    forward arguments.
    """

    @staticmethod
    def forward(
        ctx,
        x,
        weight,
        bias=None,
        seq_idx=None,
        initial_states=None,
        return_final_states=False,
        final_states_out=None,
        activation=None,
    ):
        if activation not in [None, "silu", "swish"]:
            raise NotImplementedError("activation must be None, silu, or swish")
        # The kernel needs at least one of the last two axes contiguous.
        if x.stride(2) != 1 and x.stride(1) != 1:
            x = x.contiguous()
        bias = bias.contiguous() if bias is not None else None
        if seq_idx is not None:
            # seq_idx marks sequence boundaries inside a packed batch, which
            # is incompatible with carrying explicit states across the call.
            assert (
                initial_states is None
            ), "initial_states must be None if seq_idx is not None"
            assert (
                not return_final_states
            ), "If seq_idx is not None, we don't return final_states_out"
        seq_idx = seq_idx.contiguous() if seq_idx is not None else None
        if initial_states is not None and (
            initial_states.stride(2) != 1 and initial_states.stride(1) != 1
        ):
            initial_states = initial_states.contiguous()
        if return_final_states:
            assert (
                x.stride(1) == 1
            ), "Only channel-last layout support returning final_states_out"
            if final_states_out is not None:
                assert (
                    final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1
                )
            else:
                batch, dim, seqlen = x.shape
                width = weight.shape[1]
                # Channel-last allocation: (batch, width - 1, dim) storage
                # viewed as (batch, dim, width - 1).
                final_states_out = torch.empty(
                    batch, width - 1, dim, device=x.device, dtype=x.dtype
                ).transpose(1, 2)
        else:
            final_states_out = None
        # Stored as a bool: the kernels take a silu_activation flag.
        ctx.activation = activation in ["silu", "swish"]
        out = causal_conv1d_fwd_function(
            x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
        )
        ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
        ctx.return_final_states = return_final_states
        ctx.return_dinitial_states = (
            initial_states is not None and initial_states.requires_grad
        )
        return out if not return_final_states else (out, final_states_out)

    @staticmethod
    def backward(ctx, dout, *args):
        x, weight, bias, seq_idx, initial_states = ctx.saved_tensors
        # When forward returned (out, final_states), autograd passes the
        # gradient of final_states as the extra positional argument.
        dfinal_states = args[0] if ctx.return_final_states else None
        if dout.stride(2) != 1 and dout.stride(1) != 1:
            dout = dout.contiguous()
        # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
        # backward of conv1d with the backward of chunk).
        # Here we just pass in None and dx will be allocated in the C++ code.
        dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function(
            x,
            weight,
            bias,
            dout,
            seq_idx,
            initial_states,
            dfinal_states,
            None,
            ctx.return_dinitial_states,
            ctx.activation,
        )
        # One gradient slot per forward argument (None for non-tensor args).
        return (
            dx,
            dweight,
            dbias if bias is not None else None,
            None,
            dinitial_states if initial_states is not None else None,
            None,
            None,
            None,
        )
+
+
def causal_conv1d_fn(
    x,
    weight,
    bias=None,
    seq_idx=None,
    initial_states=None,
    return_final_states=False,
    final_states_out=None,
    activation=None,
):
    """Causal depthwise conv1d (fused CUDA path).

    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    seq_idx: (batch, seqlen)
    initial_states: (batch, dim, width - 1)
    final_states_out: (batch, dim, width - 1), to be written to
    activation: either None or "silu" or "swish"

    out: (batch, dim, seqlen)
    """
    return CausalConv1dFn.apply(
        x, weight, bias, seq_idx, initial_states,
        return_final_states, final_states_out, activation,
    )
+
+
def causal_conv1d_ref(
    x,
    weight,
    bias=None,
    initial_states=None,
    return_final_states=False,
    final_states_out=None,
    activation=None,
):
    """Pure-PyTorch reference for the causal depthwise conv1d.

    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    initial_states: (batch, dim, width - 1)
    final_states_out: (batch, dim, width - 1)

    out: (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    dtype_in = x.dtype
    x = x.to(weight.dtype)
    seqlen = x.shape[-1]
    dim, width = weight.shape
    if initial_states is not None:
        # Prepend the carried-over context instead of zero padding.
        x = torch.cat([initial_states, x], dim=-1)
        padding = 0
    else:
        padding = width - 1
    out = F.conv1d(x, weight.unsqueeze(1), bias, padding=padding, groups=dim)[..., :seqlen]
    if return_final_states:
        # Last width-1 input columns, left-padded with zeros when shorter.
        final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(dtype_in)  # (batch, dim, width - 1)
        if final_states_out is None:
            final_states_out = final_states
        else:
            final_states_out.copy_(final_states)
    if activation is not None:
        out = F.silu(out)
    out = out.to(dtype=dtype_in)
    return (out, final_states_out) if return_final_states else out
+
+
def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None):
    """
    x: (batch, dim) or (batch, dim, seqlen)
    conv_state: (batch, dim, state_len), where state_len >= width - 1
    weight: (dim, width)
    bias: (dim,)
    cache_seqlens: (batch,), dtype int32.
        If not None, the conv_state is treated as a circular buffer.
        The conv_state will be updated by copying x to the conv_state starting at the index
        @cache_seqlens % state_len.
    conv_state_indices: (batch,), dtype int32
        If not None, the conv_state is a larger tensor along the batch dim,
        and we are selecting the batch coords specified by conv_state_indices.
        Useful for a continuous batching scenario.

    out: (batch, dim) or (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    # The CUDA op takes a boolean silu flag rather than the activation name.
    activation = activation in ["silu", "swish"]
    # Accept (batch, dim) input by treating it as seqlen == 1.
    unsqueeze = x.dim() == 2
    if unsqueeze:
        x = x.unsqueeze(-1)
    out = causal_conv1d_update_function(
        x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices
    )
    if unsqueeze:
        out = out.squeeze(-1)
    return out
+
+
def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
    """
    Pure-PyTorch reference for the decoding-time conv state update.

    x: (batch, dim) or (batch, dim, seqlen)
    conv_state: (batch, dim, state_len), where state_len >= width - 1; updated in place
    weight: (dim, width)
    bias: (dim,)
    cache_seqlens: (batch,), dtype int32.
        If not None, the conv_state is treated as a circular buffer.
        The conv_state will be updated by copying x to the conv_state starting at the index
        @cache_seqlens % state_len before performing the convolution.

    out: (batch, dim) or (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    dtype_in = x.dtype
    squeeze_out = x.dim() == 2
    if squeeze_out:
        x = x.unsqueeze(-1)
    batch, dim, seqlen = x.shape
    width = weight.shape[1]
    state_len = conv_state.shape[-1]
    assert conv_state.shape == (batch, dim, state_len)
    assert weight.shape == (dim, width)
    if cache_seqlens is None:
        # Linear buffer: convolve over [state, x], keep newest state_len cols.
        x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype)  # (batch, dim, state_len + seqlen)
        conv_state.copy_(x_new[:, :, -state_len:])
    else:
        # Circular buffer: gather the width-1 entries preceding each write
        # position...
        offsets = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device)
        width_idx = torch.remainder(offsets.unsqueeze(0) + cache_seqlens.unsqueeze(1), state_len)
        width_idx = width_idx.unsqueeze(1).expand(-1, dim, -1)
        x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
        # ...then scatter the new inputs into the ring starting at cache_seqlens.
        positions = torch.arange(seqlen, dtype=torch.long, device=x.device)
        copy_idx = torch.remainder(positions.unsqueeze(0) + cache_seqlens.unsqueeze(1), state_len)
        copy_idx = copy_idx.unsqueeze(1).expand(-1, dim, -1)
        conv_state.scatter_(2, copy_idx, x)
    out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:]
    if activation is not None:
        out = F.silu(out)
    out = out.to(dtype=dtype_in)
    return out.squeeze(-1) if squeeze_out else out
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/causal_conv1d_varlen.py b/build/torch211-cxx11-cu128-aarch64-linux/causal_conv1d_varlen.py
new file mode 100644
index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/causal_conv1d_varlen.py
@@ -0,0 +1,86 @@
+import torch
+from torch import Tensor
+
+import triton
+import triton.language as tl
+
+
@triton.jit
def _causal_conv1d_varlen_states(
    X,
    CU_SEQLENS,
    STATES,
    state_len,
    dim,
    stride_x_seqlen, stride_x_dim,
    stride_states_batch, stride_states_seqlen, stride_states_dim,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr
):
    # Copy the last `state_len` tokens of each packed sequence in X into
    # STATES[batch], right-aligned. Grid: program_id(0) tiles the feature dim
    # (BLOCK_N cols), program_id(1) tiles the state positions (BLOCK_M rows),
    # program_id(2) is the batch index.
    batch_idx = tl.program_id(2)
    STATES += batch_idx * stride_states_batch
    end_idx = tl.load(CU_SEQLENS + batch_idx + 1)
    # At most `state_len` tokens are taken, counted back from the sequence end.
    start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)
    # Token rows are enumerated backwards from end_idx so the copy lands
    # right-aligned in the state buffer.
    rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
    # Rows before start_idx load 0 (other=0) and are still stored below, so
    # state slots with no corresponding token of this sequence become zero.
    x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,
                mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim),
                other=0)
    rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
    tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,
             x,
             mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim))
+
+
def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
    """
    Forward pass only, does not support backward pass.
    Parameters:
        x: (total_tokens, dim)
        cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
        state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
            If some of those elements belong to a different sequence, the value of the states will be zero.
    Return:
        states: (batch, dim, state_len)
    """
    _, dim = x.shape
    batch = cu_seqlens.shape[0] - 1
    cu_seqlens = cu_seqlens.contiguous()
    # Allocate (batch, state_len, dim) storage and view it as
    # (batch, dim, state_len): channel-last layout, dim axis contiguous.
    states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
    BLOCK_M = min(triton.next_power_of_2(state_len), 16)
    BLOCK_N = min(triton.next_power_of_2(dim), 256)
    grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch)
    # Launch on x's device so multi-GPU callers don't hit the default device.
    with torch.cuda.device(x.device.index):
        _causal_conv1d_varlen_states[grid](
            x,
            cu_seqlens,
            states,
            state_len,
            dim,
            x.stride(0), x.stride(1),
            # Note the swap: stride(2) is the state_len ("seqlen") axis and
            # stride(1) the dim axis of the transposed view.
            states.stride(0), states.stride(2), states.stride(1),
            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
        )
    return states
+
+
def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
    """
    Reference (pure PyTorch) implementation of causal_conv1d_varlen_states.

    Forward pass only, does not support backward pass.
    Parameters:
        x: (total_tokens, dim)
        cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
        state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
            If some of those elements belong to a different sequence, the value of the states will be zero.
    Return:
        states: (batch, dim, state_len)
    """
    _, dim = x.shape
    batch = cu_seqlens.shape[0] - 1
    cu_seqlens = cu_seqlens.contiguous()
    # (batch, state_len, dim) storage viewed channel-last as (batch, dim, state_len);
    # zero-init so slots without a corresponding token stay zero.
    states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
    for i in range(batch):
        end_idx = int(cu_seqlens[i + 1])
        # Copy at most the last `state_len` tokens of sequence i.
        start_idx = max(int(cu_seqlens[i]), end_idx - state_len)
        length = end_idx - start_idx
        # Guard the empty-sequence case: `states[i, :, -0:]` would select the
        # whole buffer and fail to broadcast against the empty slice of x.
        if length > 0:
            states[i, :, -length:] = x[start_idx:end_idx].T
    return states
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/cpp_functions.py b/build/torch211-cxx11-cu128-aarch64-linux/cpp_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/cpp_functions.py
@@ -0,0 +1,96 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+
+from ._ops import ops
+
def causal_conv1d_fwd_function(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    seq_idx: torch.Tensor | None,
    initial_states: torch.Tensor | None,
    final_states_out: torch.Tensor | None,
    silu_activation: bool,
) -> torch.Tensor:
    """Thin wrapper over the compiled `causal_conv1d_fwd` CUDA op.

    Allocates the output tensor and forwards all arguments by keyword; the op
    fills `out` in place (and `final_states_out`, when provided — presumably
    with the trailing width-1 inputs; see the reference implementation).
    """
    out = torch.empty_like(x)
    ops.causal_conv1d_fwd(
        x=x,
        weight=weight,
        bias=bias,
        seq_idx=seq_idx,
        initial_states=initial_states,
        out=out,
        final_states_out=final_states_out,
        silu_activation=silu_activation,
    )
    return out
+
+
def causal_conv1d_bwd_function(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    dout: torch.Tensor,
    seq_idx: torch.Tensor | None,
    initial_states: torch.Tensor | None,
    dfinal_states: torch.Tensor | None,
    dx: torch.Tensor | None,
    return_dinitial_states: bool,
    silu_activation: bool,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """Wrapper over the compiled `causal_conv1d_bwd` CUDA op.

    Allocates the gradient buffers (dx only when the caller did not pass a
    pre-allocated one), lets the op fill them in place, and returns
    (dx, dweight, dbias, dinitial_states).
    """
    batch_size, dim = x.size()[:2]
    width = weight.size(-1)

    if dx is None:
        dx = torch.empty_like(x)
    # Weight/bias gradients are accumulated in float32 and cast back to the
    # parameter dtype below.
    dweight = torch.zeros_like(weight, dtype=torch.float32)
    dbias = None
    if bias is not None:
        dbias = torch.zeros_like(bias, dtype=torch.float32)
    dinitial_states = None
    if return_dinitial_states:
        # Channel-last allocation: (batch, width - 1, dim) storage exposed as
        # (batch, dim, width - 1).
        dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2)

    ops.causal_conv1d_bwd(
        x=x,
        weight=weight,
        bias=bias,
        dout=dout,
        seq_idx=seq_idx,
        initial_states=initial_states,
        dfinal_states=dfinal_states,
        dx=dx,
        dweight=dweight,
        dbias=dbias,
        dinitial_states=dinitial_states,
        silu_activation=silu_activation,
    )

    dweight = dweight.type_as(weight)
    if dbias is not None:
        dbias = dbias.type_as(bias)
    return dx, dweight, dbias, dinitial_states
+
+
def causal_conv1d_update_function(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    silu_activation: bool,
    cache_seqlens: torch.Tensor | None,
    conv_state_indices: torch.Tensor | None,
) -> torch.Tensor:
    """Wrapper over the compiled `causal_conv1d_update` CUDA op.

    Allocates the output tensor and forwards all arguments by keyword; the op
    fills `out` in place (and presumably updates `conv_state` in place, as the
    reference implementation does — confirm against the kernel).
    """
    out = torch.empty_like(x)
    ops.causal_conv1d_update(
        x=x,
        conv_state=conv_state,
        weight=weight,
        bias=bias,
        out=out,
        silu_activation=silu_activation,
        cache_seqlens=cache_seqlens,
        conv_state_indices=conv_state_indices,
    )
    return out
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/metadata.json b/build/torch211-cxx11-cu128-aarch64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..a794c92436c3827ae79b48d55f7ea964afd50f52
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/metadata.json
@@ -0,0 +1,21 @@
+{
+ "version": 1,
+ "license": "BSD-3-Clause",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "10.0",
+ "10.1",
+ "12.0+PTX",
+ "7.0",
+ "7.2",
+ "7.5",
+ "8.0",
+ "8.6",
+ "8.7",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/__init__.py b/build/torch211-cxx11-cu128-x86_64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/__init__.py
@@ -0,0 +1,4 @@
+from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
+from .causal_conv1d_varlen import causal_conv1d_varlen_states
+
+__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"]
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/_causal_conv1d_cuda_6b83b83.abi3.so b/build/torch211-cxx11-cu128-x86_64-linux/_causal_conv1d_cuda_6b83b83.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..d2b16fb61e2d09e8316f35076d9ed257c716c246
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/_causal_conv1d_cuda_6b83b83.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:67a94922da263147a2328b548a12d12bfe77654431dc234659eb45e3337cb948
+size 107169936
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/_ops.py b/build/torch211-cxx11-cu128-x86_64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..6796899661ef6f73609047ca344503d13ca050bd
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _causal_conv1d_cuda_6b83b83
+ops = torch.ops._causal_conv1d_cuda_6b83b83
+
def add_op_namespace_prefix(op_name: str):
    """Qualify *op_name* with this extension's `torch.ops` namespace."""
    return "_causal_conv1d_cuda_6b83b83::" + op_name
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py b/build/torch211-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import importlib.util
+import sys
+from pathlib import Path
+from types import ModuleType
+
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/causal_conv1d_interface.py b/build/torch211-cxx11-cu128-x86_64-linux/causal_conv1d_interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/causal_conv1d_interface.py
@@ -0,0 +1,242 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+import torch.nn.functional as F
+
+from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function
+
+
class CausalConv1dFn(torch.autograd.Function):
    """Autograd Function wrapping the fused causal depthwise conv1d CUDA ops.

    forward dispatches to the compiled forward kernel; backward dispatches to
    the compiled backward kernel and maps its outputs back onto the eight
    forward arguments.
    """

    @staticmethod
    def forward(
        ctx,
        x,
        weight,
        bias=None,
        seq_idx=None,
        initial_states=None,
        return_final_states=False,
        final_states_out=None,
        activation=None,
    ):
        if activation not in [None, "silu", "swish"]:
            raise NotImplementedError("activation must be None, silu, or swish")
        # The kernel needs at least one of the last two axes contiguous.
        if x.stride(2) != 1 and x.stride(1) != 1:
            x = x.contiguous()
        bias = bias.contiguous() if bias is not None else None
        if seq_idx is not None:
            # seq_idx marks sequence boundaries inside a packed batch, which
            # is incompatible with carrying explicit states across the call.
            assert (
                initial_states is None
            ), "initial_states must be None if seq_idx is not None"
            assert (
                not return_final_states
            ), "If seq_idx is not None, we don't return final_states_out"
        seq_idx = seq_idx.contiguous() if seq_idx is not None else None
        if initial_states is not None and (
            initial_states.stride(2) != 1 and initial_states.stride(1) != 1
        ):
            initial_states = initial_states.contiguous()
        if return_final_states:
            assert (
                x.stride(1) == 1
            ), "Only channel-last layout support returning final_states_out"
            if final_states_out is not None:
                assert (
                    final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1
                )
            else:
                batch, dim, seqlen = x.shape
                width = weight.shape[1]
                # Channel-last allocation: (batch, width - 1, dim) storage
                # viewed as (batch, dim, width - 1).
                final_states_out = torch.empty(
                    batch, width - 1, dim, device=x.device, dtype=x.dtype
                ).transpose(1, 2)
        else:
            final_states_out = None
        # Stored as a bool: the kernels take a silu_activation flag.
        ctx.activation = activation in ["silu", "swish"]
        out = causal_conv1d_fwd_function(
            x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
        )
        ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
        ctx.return_final_states = return_final_states
        ctx.return_dinitial_states = (
            initial_states is not None and initial_states.requires_grad
        )
        return out if not return_final_states else (out, final_states_out)

    @staticmethod
    def backward(ctx, dout, *args):
        x, weight, bias, seq_idx, initial_states = ctx.saved_tensors
        # When forward returned (out, final_states), autograd passes the
        # gradient of final_states as the extra positional argument.
        dfinal_states = args[0] if ctx.return_final_states else None
        if dout.stride(2) != 1 and dout.stride(1) != 1:
            dout = dout.contiguous()
        # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
        # backward of conv1d with the backward of chunk).
        # Here we just pass in None and dx will be allocated in the C++ code.
        dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function(
            x,
            weight,
            bias,
            dout,
            seq_idx,
            initial_states,
            dfinal_states,
            None,
            ctx.return_dinitial_states,
            ctx.activation,
        )
        # One gradient slot per forward argument (None for non-tensor args).
        return (
            dx,
            dweight,
            dbias if bias is not None else None,
            None,
            dinitial_states if initial_states is not None else None,
            None,
            None,
            None,
        )
+
+
def causal_conv1d_fn(
    x,
    weight,
    bias=None,
    seq_idx=None,
    initial_states=None,
    return_final_states=False,
    final_states_out=None,
    activation=None,
):
    """Causal depthwise conv1d (fused CUDA path).

    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    seq_idx: (batch, seqlen)
    initial_states: (batch, dim, width - 1)
    final_states_out: (batch, dim, width - 1), to be written to
    activation: either None or "silu" or "swish"

    out: (batch, dim, seqlen)
    """
    return CausalConv1dFn.apply(
        x, weight, bias, seq_idx, initial_states,
        return_final_states, final_states_out, activation,
    )
+
+
def causal_conv1d_ref(
    x,
    weight,
    bias=None,
    initial_states=None,
    return_final_states=False,
    final_states_out=None,
    activation=None,
):
    """Pure-PyTorch reference for the causal depthwise conv1d.

    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    initial_states: (batch, dim, width - 1)
    final_states_out: (batch, dim, width - 1)

    out: (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    dtype_in = x.dtype
    x = x.to(weight.dtype)
    seqlen = x.shape[-1]
    dim, width = weight.shape
    if initial_states is not None:
        # Prepend the carried-over context instead of zero padding.
        x = torch.cat([initial_states, x], dim=-1)
        padding = 0
    else:
        padding = width - 1
    out = F.conv1d(x, weight.unsqueeze(1), bias, padding=padding, groups=dim)[..., :seqlen]
    if return_final_states:
        # Last width-1 input columns, left-padded with zeros when shorter.
        final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(dtype_in)  # (batch, dim, width - 1)
        if final_states_out is None:
            final_states_out = final_states
        else:
            final_states_out.copy_(final_states)
    if activation is not None:
        out = F.silu(out)
    out = out.to(dtype=dtype_in)
    return (out, final_states_out) if return_final_states else out
+
+
def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None):
    """
    x: (batch, dim) or (batch, dim, seqlen)
    conv_state: (batch, dim, state_len), where state_len >= width - 1
    weight: (dim, width)
    bias: (dim,)
    cache_seqlens: (batch,), dtype int32.
        If not None, the conv_state is treated as a circular buffer.
        The conv_state will be updated by copying x to the conv_state starting at the index
        @cache_seqlens % state_len.
    conv_state_indices: (batch,), dtype int32
        If not None, the conv_state is a larger tensor along the batch dim,
        and we are selecting the batch coords specified by conv_state_indices.
        Useful for a continuous batching scenario.

    out: (batch, dim) or (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    # The CUDA op takes a boolean silu flag rather than the activation name.
    activation = activation in ["silu", "swish"]
    # Accept (batch, dim) input by treating it as seqlen == 1.
    unsqueeze = x.dim() == 2
    if unsqueeze:
        x = x.unsqueeze(-1)
    out = causal_conv1d_update_function(
        x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices
    )
    if unsqueeze:
        out = out.squeeze(-1)
    return out
+
+
def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
    """
    Pure-PyTorch reference for the decoding-time conv state update.

    x: (batch, dim) or (batch, dim, seqlen)
    conv_state: (batch, dim, state_len), where state_len >= width - 1; updated in place
    weight: (dim, width)
    bias: (dim,)
    cache_seqlens: (batch,), dtype int32.
        If not None, the conv_state is treated as a circular buffer.
        The conv_state will be updated by copying x to the conv_state starting at the index
        @cache_seqlens % state_len before performing the convolution.

    out: (batch, dim) or (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    dtype_in = x.dtype
    squeeze_out = x.dim() == 2
    if squeeze_out:
        x = x.unsqueeze(-1)
    batch, dim, seqlen = x.shape
    width = weight.shape[1]
    state_len = conv_state.shape[-1]
    assert conv_state.shape == (batch, dim, state_len)
    assert weight.shape == (dim, width)
    if cache_seqlens is None:
        # Linear buffer: convolve over [state, x], keep newest state_len cols.
        x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype)  # (batch, dim, state_len + seqlen)
        conv_state.copy_(x_new[:, :, -state_len:])
    else:
        # Circular buffer: gather the width-1 entries preceding each write
        # position...
        offsets = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device)
        width_idx = torch.remainder(offsets.unsqueeze(0) + cache_seqlens.unsqueeze(1), state_len)
        width_idx = width_idx.unsqueeze(1).expand(-1, dim, -1)
        x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
        # ...then scatter the new inputs into the ring starting at cache_seqlens.
        positions = torch.arange(seqlen, dtype=torch.long, device=x.device)
        copy_idx = torch.remainder(positions.unsqueeze(0) + cache_seqlens.unsqueeze(1), state_len)
        copy_idx = copy_idx.unsqueeze(1).expand(-1, dim, -1)
        conv_state.scatter_(2, copy_idx, x)
    out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:]
    if activation is not None:
        out = F.silu(out)
    out = out.to(dtype=dtype_in)
    return out.squeeze(-1) if squeeze_out else out
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/causal_conv1d_varlen.py b/build/torch211-cxx11-cu128-x86_64-linux/causal_conv1d_varlen.py
new file mode 100644
index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/causal_conv1d_varlen.py
@@ -0,0 +1,86 @@
+import torch
+from torch import Tensor
+
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _causal_conv1d_varlen_states(
+    X,              # ptr to (total_tokens, dim) input
+    CU_SEQLENS,     # ptr to (batch + 1,) cumulative sequence lengths (sorted)
+    STATES,         # ptr to (batch, dim, state_len) output states
+    state_len,
+    dim,
+    stride_x_seqlen, stride_x_dim,
+    stride_states_batch, stride_states_seqlen, stride_states_dim,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr
+):
+    # Copy the last `state_len` tokens of each variable-length sequence into its
+    # state slot, right-aligned.  Grid: (dim tiles, state_len tiles, batch).
+    batch_idx = tl.program_id(2)
+    STATES += batch_idx * stride_states_batch
+    end_idx = tl.load(CU_SEQLENS + batch_idx + 1)
+    # Clamp the window start so we never read tokens of the previous sequence.
+    start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)
+    rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
+    cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
+    # Rows before start_idx load 0 (other=0) but are still stored below, which
+    # zero-fills the left padding of sequences shorter than state_len.
+    x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,
+                mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim),
+                other=0)
+    rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
+    tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,
+             x,
+             mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim))
+
+
+def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
+    """
+    Forward pass only, does not support backward pass.
+    Parameters:
+        x: (total_tokens, dim)
+        cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
+        state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
+            If some of those elements belong to a different sequence, the value of the states will be zero.
+    Return:
+        states: (batch, dim, state_len)
+    """
+    _, dim = x.shape
+    batch = cu_seqlens.shape[0] - 1
+    cu_seqlens = cu_seqlens.contiguous()
+    # Allocate as (batch, state_len, dim) and transpose so the returned
+    # (batch, dim, state_len) view is contiguous along dim (channel-last).
+    states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
+    BLOCK_M = min(triton.next_power_of_2(state_len), 16)
+    BLOCK_N = min(triton.next_power_of_2(dim), 256)
+    # Grid axes: (dim tiles, state_len tiles, batch).
+    grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch)
+    # Make sure the kernel launches on the same device as x.
+    with torch.cuda.device(x.device.index):
+        _causal_conv1d_varlen_states[grid](
+            x,
+            cu_seqlens,
+            states,
+            state_len,
+            dim,
+            x.stride(0), x.stride(1),
+            states.stride(0), states.stride(2), states.stride(1),
+            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
+        )
+    return states
+
+
+def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
+    """
+    Forward pass only, does not support backward pass.
+    Parameters:
+        x: (total_tokens, dim)
+        cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
+        state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
+            If some of those elements belong to a different sequence, the value of the states will be zero.
+    Return:
+        states: (batch, dim, state_len)
+    """
+    _, dim = x.shape
+    batch = cu_seqlens.shape[0] - 1
+    cu_seqlens = cu_seqlens.contiguous()
+    # Zero-init so state positions with no corresponding token stay zero.
+    states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
+    for i in range(batch):
+        end_idx = cu_seqlens[i + 1]
+        # Take at most state_len tokens, never crossing into sequence i-1.
+        start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len)
+        # Right-align the copied window in the state.
+        states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T
+    return states
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/cpp_functions.py b/build/torch211-cxx11-cu128-x86_64-linux/cpp_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/cpp_functions.py
@@ -0,0 +1,96 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+
+from ._ops import ops
+
+def causal_conv1d_fwd_function(
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor | None,
+    seq_idx: torch.Tensor | None,
+    initial_states: torch.Tensor | None,
+    final_states_out: torch.Tensor | None,
+    silu_activation: bool,
+) -> torch.Tensor:
+    """Thin wrapper around the CUDA causal_conv1d forward op.
+
+    Allocates the output tensor and dispatches to ops.causal_conv1d_fwd.
+    final_states_out, if given, is passed through for the op to write into.
+    """
+    out = torch.empty_like(x)
+    ops.causal_conv1d_fwd(
+        x=x,
+        weight=weight,
+        bias=bias,
+        seq_idx=seq_idx,
+        initial_states=initial_states,
+        out=out,
+        final_states_out=final_states_out,
+        silu_activation=silu_activation,
+    )
+    return out
+
+
+def causal_conv1d_bwd_function(
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor | None,
+    dout: torch.Tensor,
+    seq_idx: torch.Tensor | None,
+    initial_states: torch.Tensor | None,
+    dfinal_states: torch.Tensor | None,
+    dx: torch.Tensor | None,
+    return_dinitial_states: bool,
+    silu_activation: bool,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
+    """Thin wrapper around the CUDA causal_conv1d backward op.
+
+    Allocates the gradient buffers (dweight/dbias accumulate in float32, then
+    are cast back to the parameter dtype) and dispatches to
+    ops.causal_conv1d_bwd, which is expected to fill them in place.
+
+    Returns (dx, dweight, dbias, dinitial_states); dbias is None when bias is
+    None, dinitial_states is None when return_dinitial_states is falsy.
+    """
+    batch_size, dim = x.size()[:2]
+    width = weight.size(-1)
+
+    # A pre-allocated dx may be passed in (e.g. to fuse with another backward).
+    if dx is None:
+        dx = torch.empty_like(x)
+    dweight = torch.zeros_like(weight, dtype=torch.float32)
+    dbias = None
+    if bias is not None:
+        dbias = torch.zeros_like(bias, dtype=torch.float32)
+    dinitial_states = None
+    if return_dinitial_states:
+        # Channel-last layout: allocate (batch, width-1, dim), then transpose.
+        dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2)
+
+    ops.causal_conv1d_bwd(
+        x=x,
+        weight=weight,
+        bias=bias,
+        dout=dout,
+        seq_idx=seq_idx,
+        initial_states=initial_states,
+        dfinal_states=dfinal_states,
+        dx=dx,
+        dweight=dweight,
+        dbias=dbias,
+        dinitial_states=dinitial_states,
+        silu_activation=silu_activation,
+    )
+
+    # Cast float32 accumulators back to the parameter dtype.
+    dweight = dweight.type_as(weight)
+    if dbias is not None:
+        dbias = dbias.type_as(bias)
+    return dx, dweight, dbias, dinitial_states
+
+
+def causal_conv1d_update_function(
+    x: torch.Tensor,
+    conv_state: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor | None,
+    silu_activation: bool,
+    cache_seqlens: torch.Tensor | None,
+    conv_state_indices: torch.Tensor | None,
+) -> torch.Tensor:
+    """Thin wrapper around the CUDA causal_conv1d_update op.
+
+    Allocates the output tensor and dispatches to ops.causal_conv1d_update.
+    NOTE(review): conv_state is presumably updated in place by the op —
+    confirm against the CUDA kernel.
+    """
+    out = torch.empty_like(x)
+    ops.causal_conv1d_update(
+        x=x,
+        conv_state=conv_state,
+        weight=weight,
+        bias=bias,
+        out=out,
+        silu_activation=silu_activation,
+        cache_seqlens=cache_seqlens,
+        conv_state_indices=conv_state_indices,
+    )
+    return out
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/metadata.json b/build/torch211-cxx11-cu128-x86_64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..a794c92436c3827ae79b48d55f7ea964afd50f52
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/metadata.json
@@ -0,0 +1,21 @@
+{
+ "version": 1,
+ "license": "BSD-3-Clause",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "10.0",
+ "10.1",
+ "12.0+PTX",
+ "7.0",
+ "7.2",
+ "7.5",
+ "8.0",
+ "8.6",
+ "8.7",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/__init__.py b/build/torch211-cxx11-cu130-aarch64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/__init__.py
@@ -0,0 +1,4 @@
+from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
+from .causal_conv1d_varlen import causal_conv1d_varlen_states
+
+__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"]
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/_causal_conv1d_cuda_6b83b83.abi3.so b/build/torch211-cxx11-cu130-aarch64-linux/_causal_conv1d_cuda_6b83b83.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..83e375375beaba6eece156ddda39eb9f96151365
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/_causal_conv1d_cuda_6b83b83.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5948aec7b42fd865dd38266aad64a34dd3ac7ce5c58c5f04babe9ee28bba1d5f
+size 64755624
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/_ops.py b/build/torch211-cxx11-cu130-aarch64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..6796899661ef6f73609047ca344503d13ca050bd
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _causal_conv1d_cuda_6b83b83
+ops = torch.ops._causal_conv1d_cuda_6b83b83
+
+def add_op_namespace_prefix(op_name: str):
+    """
+    Prefix op by namespace.
+
+    e.g. "causal_conv1d_fwd" -> "_causal_conv1d_cuda_6b83b83::causal_conv1d_fwd".
+    """
+    return f"_causal_conv1d_cuda_6b83b83::{op_name}"
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/causal_conv1d/__init__.py b/build/torch211-cxx11-cu130-aarch64-linux/causal_conv1d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/causal_conv1d/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import importlib.util
+import sys
+from pathlib import Path
+from types import ModuleType
+
+
+def _import_from_path(file_path: Path) -> ModuleType:
+    """Load the module at *file_path* under a path-derived unique name."""
+    # We cannot use the module name as-is, after adding it to `sys.modules`,
+    # it would also be used for other imports. So, we make a module name that
+    # depends on the path for it to be unique using the hex-encoded hash of
+    # the path.
+    path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+    module_name = path_hash
+    spec = importlib.util.spec_from_file_location(module_name, file_path)
+    if spec is None:
+        raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+    module = importlib.util.module_from_spec(spec)
+    if module is None:
+        raise ImportError(f"Cannot load module {module_name} from spec")
+    sys.modules[module_name] = module
+    spec.loader.exec_module(module) # type: ignore
+    return module
+
+
+# Re-export the parent build directory's __init__.py so this subpackage
+# exposes the same public names.
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/causal_conv1d_interface.py b/build/torch211-cxx11-cu130-aarch64-linux/causal_conv1d_interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/causal_conv1d_interface.py
@@ -0,0 +1,242 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+import torch.nn.functional as F
+
+from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function
+
+
+class CausalConv1dFn(torch.autograd.Function):
+    """Autograd wrapper for the fused causal depthwise conv1d CUDA ops."""
+
+    @staticmethod
+    def forward(
+        ctx,
+        x,
+        weight,
+        bias=None,
+        seq_idx=None,
+        initial_states=None,
+        return_final_states=False,
+        final_states_out=None,
+        activation=None,
+    ):
+        if activation not in [None, "silu", "swish"]:
+            raise NotImplementedError("activation must be None, silu, or swish")
+        # Require either the seqlen axis or the channel axis to be contiguous.
+        if x.stride(2) != 1 and x.stride(1) != 1:
+            x = x.contiguous()
+        bias = bias.contiguous() if bias is not None else None
+        if seq_idx is not None:
+            assert (
+                initial_states is None
+            ), "initial_states must be None if seq_idx is not None"
+            assert (
+                not return_final_states
+            ), "If seq_idx is not None, we don't return final_states_out"
+        seq_idx = seq_idx.contiguous() if seq_idx is not None else None
+        if initial_states is not None and (
+            initial_states.stride(2) != 1 and initial_states.stride(1) != 1
+        ):
+            initial_states = initial_states.contiguous()
+        if return_final_states:
+            assert (
+                x.stride(1) == 1
+            ), "Only channel-last layout support returning final_states_out"
+            if final_states_out is not None:
+                assert (
+                    final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1
+                )
+            else:
+                batch, dim, seqlen = x.shape
+                width = weight.shape[1]
+                # Channel-last allocation: (batch, width-1, dim), transposed.
+                final_states_out = torch.empty(
+                    batch, width - 1, dim, device=x.device, dtype=x.dtype
+                ).transpose(1, 2)
+        else:
+            final_states_out = None
+        # Store the activation as a bool flag (silu and swish are the same op).
+        ctx.activation = activation in ["silu", "swish"]
+        out = causal_conv1d_fwd_function(
+            x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
+        )
+        ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
+        ctx.return_final_states = return_final_states
+        ctx.return_dinitial_states = (
+            initial_states is not None and initial_states.requires_grad
+        )
+        return out if not return_final_states else (out, final_states_out)
+
+    @staticmethod
+    def backward(ctx, dout, *args):
+        x, weight, bias, seq_idx, initial_states = ctx.saved_tensors
+        # When forward returned (out, final_states), args[0] is dfinal_states.
+        dfinal_states = args[0] if ctx.return_final_states else None
+        if dout.stride(2) != 1 and dout.stride(1) != 1:
+            dout = dout.contiguous()
+        # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
+        # backward of conv1d with the backward of chunk).
+        # Here we just pass in None and dx will be allocated in the C++ code.
+        dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function(
+            x,
+            weight,
+            bias,
+            dout,
+            seq_idx,
+            initial_states,
+            dfinal_states,
+            None,
+            ctx.return_dinitial_states,
+            ctx.activation,
+        )
+        # One gradient (or None) per forward() argument, in order.
+        return (
+            dx,
+            dweight,
+            dbias if bias is not None else None,
+            None,
+            dinitial_states if initial_states is not None else None,
+            None,
+            None,
+            None,
+        )
+
+
+def causal_conv1d_fn(
+    x,
+    weight,
+    bias=None,
+    seq_idx=None,
+    initial_states=None,
+    return_final_states=False,
+    final_states_out=None,
+    activation=None,
+):
+    """
+    x: (batch, dim, seqlen)
+    weight: (dim, width)
+    bias: (dim,)
+    seq_idx: (batch, seqlen)
+    initial_states: (batch, dim, width - 1)
+    final_states_out: (batch, dim, width - 1), to be written to
+    activation: either None or "silu" or "swish"
+
+    out: (batch, dim, seqlen)
+    """
+    # Functional entry point; all validation happens in CausalConv1dFn.forward.
+    return CausalConv1dFn.apply(
+        x,
+        weight,
+        bias,
+        seq_idx,
+        initial_states,
+        return_final_states,
+        final_states_out,
+        activation,
+    )
+
+
+def causal_conv1d_ref(
+    x,
+    weight,
+    bias=None,
+    initial_states=None,
+    return_final_states=False,
+    final_states_out=None,
+    activation=None,
+):
+    """
+    Reference (pure PyTorch) implementation of the causal depthwise conv1d.
+
+    x: (batch, dim, seqlen)
+    weight: (dim, width)
+    bias: (dim,)
+    initial_states: (batch, dim, width - 1)
+    final_states_out: (batch, dim, width - 1)
+
+    out: (batch, dim, seqlen)
+    """
+    if activation not in [None, "silu", "swish"]:
+        raise NotImplementedError("activation must be None, silu, or swish")
+    dtype_in = x.dtype
+    x = x.to(weight.dtype)
+    seqlen = x.shape[-1]
+    dim, width = weight.shape
+    if initial_states is None:
+        # Zero left-padding via conv padding; each output only sees the past.
+        out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
+    else:
+        # Prepend the provided history instead of zero padding.
+        x = torch.cat([initial_states, x], dim=-1)
+        out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
+    out = out[..., :seqlen]
+    if return_final_states:
+        # Keep the last width-1 inputs: a negative pad truncates from the left,
+        # a positive pad zero-fills sequences shorter than width-1.
+        final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
+            dtype_in
+        )  # (batch, dim, width - 1)
+        if final_states_out is not None:
+            final_states_out.copy_(final_states)
+        else:
+            final_states_out = final_states
+    out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
+    return out if not return_final_states else (out, final_states_out)
+
+
+def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None):
+    """
+    x: (batch, dim) or (batch, dim, seqlen)
+    conv_state: (batch, dim, state_len), where state_len >= width - 1
+    weight: (dim, width)
+    bias: (dim,)
+    cache_seqlens: (batch,), dtype int32.
+        If not None, the conv_state is treated as a circular buffer.
+        The conv_state will be updated by copying x to the conv_state starting at the index
+        @cache_seqlens % state_len.
+    conv_state_indices: (batch,), dtype int32
+        If not None, the conv_state is a larger tensor along the batch dim,
+        and we are selecting the batch coords specified by conv_state_indices.
+        Useful for a continuous batching scenario.
+
+    out: (batch, dim) or (batch, dim, seqlen)
+    """
+    if activation not in [None, "silu", "swish"]:
+        raise NotImplementedError("activation must be None, silu, or swish")
+    # The op takes a bool flag; silu and swish are the same activation.
+    activation = activation in ["silu", "swish"]
+    # Accept a single-step (batch, dim) input by adding a seqlen axis of 1.
+    unsqueeze = x.dim() == 2
+    if unsqueeze:
+        x = x.unsqueeze(-1)
+    out = causal_conv1d_update_function(
+        x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices
+    )
+    if unsqueeze:
+        out = out.squeeze(-1)
+    return out
+
+
+def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
+    """
+    Reference (pure PyTorch) implementation of causal_conv1d_update.
+
+    x: (batch, dim) or (batch, dim, seqlen)
+    conv_state: (batch, dim, state_len), where state_len >= width - 1
+    weight: (dim, width)
+    bias: (dim,)
+    cache_seqlens: (batch,), dtype int32.
+        If not None, the conv_state is treated as a circular buffer.
+        The conv_state will be updated by copying x to the conv_state starting at the index
+        @cache_seqlens % state_len before performing the convolution.
+
+    out: (batch, dim) or (batch, dim, seqlen)
+    """
+    if activation not in [None, "silu", "swish"]:
+        raise NotImplementedError("activation must be None, silu, or swish")
+    dtype_in = x.dtype
+    # Accept a single-step (batch, dim) input by adding a seqlen axis of 1.
+    unsqueeze = x.dim() == 2
+    if unsqueeze:
+        x = x.unsqueeze(-1)
+    batch, dim, seqlen = x.shape
+    width = weight.shape[1]
+    state_len = conv_state.shape[-1]
+    assert conv_state.shape == (batch, dim, state_len)
+    assert weight.shape == (dim, width)
+    if cache_seqlens is None:
+        # Shift-register update: append x, keep the last state_len steps.
+        x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype)  # (batch, dim, state_len + seqlen)
+        conv_state.copy_(x_new[:, :, -state_len:])
+    else:
+        # Circular buffer: gather the width-1 most recent entries, i.e. positions
+        # cache_seqlens-(width-1) .. cache_seqlens-1 modulo state_len ...
+        width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
+        width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
+        x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
+        # ... then scatter the new tokens back into the buffer at cache_seqlens.
+        copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
+        copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
+        conv_state.scatter_(2, copy_idx, x)
+    # Depthwise (groups=dim) conv; keep only the outputs for the new tokens.
+    out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:]
+    if unsqueeze:
+        out = out.squeeze(-1)
+    return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/causal_conv1d_varlen.py b/build/torch211-cxx11-cu130-aarch64-linux/causal_conv1d_varlen.py
new file mode 100644
index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/causal_conv1d_varlen.py
@@ -0,0 +1,86 @@
+import torch
+from torch import Tensor
+
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _causal_conv1d_varlen_states(
+    X,              # ptr to (total_tokens, dim) input
+    CU_SEQLENS,     # ptr to (batch + 1,) cumulative sequence lengths (sorted)
+    STATES,         # ptr to (batch, dim, state_len) output states
+    state_len,
+    dim,
+    stride_x_seqlen, stride_x_dim,
+    stride_states_batch, stride_states_seqlen, stride_states_dim,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr
+):
+    # Copy the last `state_len` tokens of each variable-length sequence into its
+    # state slot, right-aligned.  Grid: (dim tiles, state_len tiles, batch).
+    batch_idx = tl.program_id(2)
+    STATES += batch_idx * stride_states_batch
+    end_idx = tl.load(CU_SEQLENS + batch_idx + 1)
+    # Clamp the window start so we never read tokens of the previous sequence.
+    start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)
+    rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
+    cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
+    # Rows before start_idx load 0 (other=0) but are still stored below, which
+    # zero-fills the left padding of sequences shorter than state_len.
+    x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,
+                mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim),
+                other=0)
+    rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
+    tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,
+             x,
+             mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim))
+
+
+def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
+    """
+    Forward pass only, does not support backward pass.
+    Parameters:
+        x: (total_tokens, dim)
+        cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
+        state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
+            If some of those elements belong to a different sequence, the value of the states will be zero.
+    Return:
+        states: (batch, dim, state_len)
+    """
+    _, dim = x.shape
+    batch = cu_seqlens.shape[0] - 1
+    cu_seqlens = cu_seqlens.contiguous()
+    # Allocate as (batch, state_len, dim) and transpose so the returned
+    # (batch, dim, state_len) view is contiguous along dim (channel-last).
+    states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
+    BLOCK_M = min(triton.next_power_of_2(state_len), 16)
+    BLOCK_N = min(triton.next_power_of_2(dim), 256)
+    # Grid axes: (dim tiles, state_len tiles, batch).
+    grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch)
+    # Make sure the kernel launches on the same device as x.
+    with torch.cuda.device(x.device.index):
+        _causal_conv1d_varlen_states[grid](
+            x,
+            cu_seqlens,
+            states,
+            state_len,
+            dim,
+            x.stride(0), x.stride(1),
+            states.stride(0), states.stride(2), states.stride(1),
+            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
+        )
+    return states
+
+
+def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
+    """
+    Forward pass only, does not support backward pass.
+    Parameters:
+        x: (total_tokens, dim)
+        cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
+        state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
+            If some of those elements belong to a different sequence, the value of the states will be zero.
+    Return:
+        states: (batch, dim, state_len)
+    """
+    _, dim = x.shape
+    batch = cu_seqlens.shape[0] - 1
+    cu_seqlens = cu_seqlens.contiguous()
+    # Zero-init so state positions with no corresponding token stay zero.
+    states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
+    for i in range(batch):
+        end_idx = cu_seqlens[i + 1]
+        # Take at most state_len tokens, never crossing into sequence i-1.
+        start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len)
+        # Right-align the copied window in the state.
+        states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T
+    return states
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/cpp_functions.py b/build/torch211-cxx11-cu130-aarch64-linux/cpp_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/cpp_functions.py
@@ -0,0 +1,96 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+
+from ._ops import ops
+
+def causal_conv1d_fwd_function(
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor | None,
+    seq_idx: torch.Tensor | None,
+    initial_states: torch.Tensor | None,
+    final_states_out: torch.Tensor | None,
+    silu_activation: bool,
+) -> torch.Tensor:
+    """Thin wrapper around the CUDA causal_conv1d forward op.
+
+    Allocates the output tensor and dispatches to ops.causal_conv1d_fwd.
+    final_states_out, if given, is passed through for the op to write into.
+    """
+    out = torch.empty_like(x)
+    ops.causal_conv1d_fwd(
+        x=x,
+        weight=weight,
+        bias=bias,
+        seq_idx=seq_idx,
+        initial_states=initial_states,
+        out=out,
+        final_states_out=final_states_out,
+        silu_activation=silu_activation,
+    )
+    return out
+
+
+def causal_conv1d_bwd_function(
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor | None,
+    dout: torch.Tensor,
+    seq_idx: torch.Tensor | None,
+    initial_states: torch.Tensor | None,
+    dfinal_states: torch.Tensor | None,
+    dx: torch.Tensor | None,
+    return_dinitial_states: bool,
+    silu_activation: bool,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
+    """Thin wrapper around the CUDA causal_conv1d backward op.
+
+    Allocates the gradient buffers (dweight/dbias accumulate in float32, then
+    are cast back to the parameter dtype) and dispatches to
+    ops.causal_conv1d_bwd, which is expected to fill them in place.
+
+    Returns (dx, dweight, dbias, dinitial_states); dbias is None when bias is
+    None, dinitial_states is None when return_dinitial_states is falsy.
+    """
+    batch_size, dim = x.size()[:2]
+    width = weight.size(-1)
+
+    # A pre-allocated dx may be passed in (e.g. to fuse with another backward).
+    if dx is None:
+        dx = torch.empty_like(x)
+    dweight = torch.zeros_like(weight, dtype=torch.float32)
+    dbias = None
+    if bias is not None:
+        dbias = torch.zeros_like(bias, dtype=torch.float32)
+    dinitial_states = None
+    if return_dinitial_states:
+        # Channel-last layout: allocate (batch, width-1, dim), then transpose.
+        dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2)
+
+    ops.causal_conv1d_bwd(
+        x=x,
+        weight=weight,
+        bias=bias,
+        dout=dout,
+        seq_idx=seq_idx,
+        initial_states=initial_states,
+        dfinal_states=dfinal_states,
+        dx=dx,
+        dweight=dweight,
+        dbias=dbias,
+        dinitial_states=dinitial_states,
+        silu_activation=silu_activation,
+    )
+
+    # Cast float32 accumulators back to the parameter dtype.
+    dweight = dweight.type_as(weight)
+    if dbias is not None:
+        dbias = dbias.type_as(bias)
+    return dx, dweight, dbias, dinitial_states
+
+
+def causal_conv1d_update_function(
+    x: torch.Tensor,
+    conv_state: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor | None,
+    silu_activation: bool,
+    cache_seqlens: torch.Tensor | None,
+    conv_state_indices: torch.Tensor | None,
+) -> torch.Tensor:
+    """Thin wrapper around the CUDA causal_conv1d_update op.
+
+    Allocates the output tensor and dispatches to ops.causal_conv1d_update.
+    NOTE(review): conv_state is presumably updated in place by the op —
+    confirm against the CUDA kernel.
+    """
+    out = torch.empty_like(x)
+    ops.causal_conv1d_update(
+        x=x,
+        conv_state=conv_state,
+        weight=weight,
+        bias=bias,
+        out=out,
+        silu_activation=silu_activation,
+        cache_seqlens=cache_seqlens,
+        conv_state_indices=conv_state_indices,
+    )
+    return out
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/metadata.json b/build/torch211-cxx11-cu130-aarch64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..eff725542128e103dfb5df382d74940efff77214
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/metadata.json
@@ -0,0 +1,19 @@
+{
+ "version": 1,
+ "license": "BSD-3-Clause",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "10.0",
+ "11.0",
+ "12.0+PTX",
+ "7.5",
+ "8.0",
+ "8.6",
+ "8.7",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/__init__.py b/build/torch211-cxx11-cu130-x86_64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/__init__.py
@@ -0,0 +1,4 @@
+from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
+from .causal_conv1d_varlen import causal_conv1d_varlen_states
+
+__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"]
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/_causal_conv1d_cuda_6b83b83.abi3.so b/build/torch211-cxx11-cu130-x86_64-linux/_causal_conv1d_cuda_6b83b83.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..e96f3ef9130d3cc40a9be87e6f625d6bd408f801
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/_causal_conv1d_cuda_6b83b83.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:075f1884806147ccb391ee3949e30fc6e8009a4bb02894f2c97b5819d0979c8d
+size 64618568
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/_ops.py b/build/torch211-cxx11-cu130-x86_64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..6796899661ef6f73609047ca344503d13ca050bd
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _causal_conv1d_cuda_6b83b83
+ops = torch.ops._causal_conv1d_cuda_6b83b83
+
+def add_op_namespace_prefix(op_name: str):
+    """
+    Prefix op by namespace.
+
+    e.g. "causal_conv1d_fwd" -> "_causal_conv1d_cuda_6b83b83::causal_conv1d_fwd".
+    """
+    return f"_causal_conv1d_cuda_6b83b83::{op_name}"
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/causal_conv1d/__init__.py b/build/torch211-cxx11-cu130-x86_64-linux/causal_conv1d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/causal_conv1d/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import importlib.util
+import sys
+from pathlib import Path
+from types import ModuleType
+
+
+def _import_from_path(file_path: Path) -> ModuleType:
+    """Load the module at *file_path* under a path-derived unique name."""
+    # We cannot use the module name as-is, after adding it to `sys.modules`,
+    # it would also be used for other imports. So, we make a module name that
+    # depends on the path for it to be unique using the hex-encoded hash of
+    # the path.
+    path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+    module_name = path_hash
+    spec = importlib.util.spec_from_file_location(module_name, file_path)
+    if spec is None:
+        raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+    module = importlib.util.module_from_spec(spec)
+    if module is None:
+        raise ImportError(f"Cannot load module {module_name} from spec")
+    sys.modules[module_name] = module
+    spec.loader.exec_module(module) # type: ignore
+    return module
+
+
+# Re-export the parent build directory's __init__.py so this subpackage
+# exposes the same public names.
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/causal_conv1d_interface.py b/build/torch211-cxx11-cu130-x86_64-linux/causal_conv1d_interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/causal_conv1d_interface.py
@@ -0,0 +1,242 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+import torch.nn.functional as F
+
+from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function
+
+
+class CausalConv1dFn(torch.autograd.Function):
+    """Autograd wrapper for the fused causal depthwise conv1d CUDA ops."""
+
+    @staticmethod
+    def forward(
+        ctx,
+        x,
+        weight,
+        bias=None,
+        seq_idx=None,
+        initial_states=None,
+        return_final_states=False,
+        final_states_out=None,
+        activation=None,
+    ):
+        if activation not in [None, "silu", "swish"]:
+            raise NotImplementedError("activation must be None, silu, or swish")
+        # Require either the seqlen axis or the channel axis to be contiguous.
+        if x.stride(2) != 1 and x.stride(1) != 1:
+            x = x.contiguous()
+        bias = bias.contiguous() if bias is not None else None
+        if seq_idx is not None:
+            assert (
+                initial_states is None
+            ), "initial_states must be None if seq_idx is not None"
+            assert (
+                not return_final_states
+            ), "If seq_idx is not None, we don't return final_states_out"
+        seq_idx = seq_idx.contiguous() if seq_idx is not None else None
+        if initial_states is not None and (
+            initial_states.stride(2) != 1 and initial_states.stride(1) != 1
+        ):
+            initial_states = initial_states.contiguous()
+        if return_final_states:
+            assert (
+                x.stride(1) == 1
+            ), "Only channel-last layout support returning final_states_out"
+            if final_states_out is not None:
+                assert (
+                    final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1
+                )
+            else:
+                batch, dim, seqlen = x.shape
+                width = weight.shape[1]
+                # Channel-last allocation: (batch, width-1, dim), transposed.
+                final_states_out = torch.empty(
+                    batch, width - 1, dim, device=x.device, dtype=x.dtype
+                ).transpose(1, 2)
+        else:
+            final_states_out = None
+        # Store the activation as a bool flag (silu and swish are the same op).
+        ctx.activation = activation in ["silu", "swish"]
+        out = causal_conv1d_fwd_function(
+            x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
+        )
+        ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
+        ctx.return_final_states = return_final_states
+        ctx.return_dinitial_states = (
+            initial_states is not None and initial_states.requires_grad
+        )
+        return out if not return_final_states else (out, final_states_out)
+
+    @staticmethod
+    def backward(ctx, dout, *args):
+        x, weight, bias, seq_idx, initial_states = ctx.saved_tensors
+        # When forward returned (out, final_states), args[0] is dfinal_states.
+        dfinal_states = args[0] if ctx.return_final_states else None
+        if dout.stride(2) != 1 and dout.stride(1) != 1:
+            dout = dout.contiguous()
+        # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
+        # backward of conv1d with the backward of chunk).
+        # Here we just pass in None and dx will be allocated in the C++ code.
+        dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function(
+            x,
+            weight,
+            bias,
+            dout,
+            seq_idx,
+            initial_states,
+            dfinal_states,
+            None,
+            ctx.return_dinitial_states,
+            ctx.activation,
+        )
+        # One gradient (or None) per forward() argument, in order.
+        return (
+            dx,
+            dweight,
+            dbias if bias is not None else None,
+            None,
+            dinitial_states if initial_states is not None else None,
+            None,
+            None,
+            None,
+        )
+
+
+def causal_conv1d_fn(
+    x,
+    weight,
+    bias=None,
+    seq_idx=None,
+    initial_states=None,
+    return_final_states=False,
+    final_states_out=None,
+    activation=None,
+):
+    """
+    x: (batch, dim, seqlen)
+    weight: (dim, width)
+    bias: (dim,)
+    seq_idx: (batch, seqlen)
+    initial_states: (batch, dim, width - 1)
+    final_states_out: (batch, dim, width - 1), to be written to
+    activation: either None or "silu" or "swish"
+
+    out: (batch, dim, seqlen)
+    """
+    # Functional entry point; all validation happens in CausalConv1dFn.forward.
+    return CausalConv1dFn.apply(
+        x,
+        weight,
+        bias,
+        seq_idx,
+        initial_states,
+        return_final_states,
+        final_states_out,
+        activation,
+    )
+
+
+def causal_conv1d_ref(
+    x,
+    weight,
+    bias=None,
+    initial_states=None,
+    return_final_states=False,
+    final_states_out=None,
+    activation=None,
+):
+    """
+    Reference (pure PyTorch) implementation of the causal depthwise conv1d.
+
+    x: (batch, dim, seqlen)
+    weight: (dim, width)
+    bias: (dim,)
+    initial_states: (batch, dim, width - 1)
+    final_states_out: (batch, dim, width - 1)
+
+    out: (batch, dim, seqlen)
+    """
+    if activation not in [None, "silu", "swish"]:
+        raise NotImplementedError("activation must be None, silu, or swish")
+    dtype_in = x.dtype
+    x = x.to(weight.dtype)
+    seqlen = x.shape[-1]
+    dim, width = weight.shape
+    if initial_states is None:
+        # Zero left-padding via conv padding; each output only sees the past.
+        out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
+    else:
+        # Prepend the provided history instead of zero padding.
+        x = torch.cat([initial_states, x], dim=-1)
+        out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
+    out = out[..., :seqlen]
+    if return_final_states:
+        # Keep the last width-1 inputs: a negative pad truncates from the left,
+        # a positive pad zero-fills sequences shorter than width-1.
+        final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
+            dtype_in
+        )  # (batch, dim, width - 1)
+        if final_states_out is not None:
+            final_states_out.copy_(final_states)
+        else:
+            final_states_out = final_states
+    out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
+    return out if not return_final_states else (out, final_states_out)
+
+
+def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None):
+    """
+    x: (batch, dim) or (batch, dim, seqlen)
+    conv_state: (batch, dim, state_len), where state_len >= width - 1
+    weight: (dim, width)
+    bias: (dim,)
+    cache_seqlens: (batch,), dtype int32.
+        If not None, the conv_state is treated as a circular buffer.
+        The conv_state will be updated by copying x to the conv_state starting at the index
+        @cache_seqlens % state_len.
+    conv_state_indices: (batch,), dtype int32
+        If not None, the conv_state is a larger tensor along the batch dim,
+        and we are selecting the batch coords specified by conv_state_indices.
+        Useful for a continuous batching scenario.
+
+    out: (batch, dim) or (batch, dim, seqlen)
+    """
+    if activation not in [None, "silu", "swish"]:
+        raise NotImplementedError("activation must be None, silu, or swish")
+    # The op takes a bool flag; silu and swish are the same activation.
+    activation = activation in ["silu", "swish"]
+    # Accept a single-step (batch, dim) input by adding a seqlen axis of 1.
+    unsqueeze = x.dim() == 2
+    if unsqueeze:
+        x = x.unsqueeze(-1)
+    out = causal_conv1d_update_function(
+        x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices
+    )
+    if unsqueeze:
+        out = out.squeeze(-1)
+    return out
+
+
+def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
+    """
+    Reference (pure PyTorch) implementation of causal_conv1d_update.
+
+    x: (batch, dim) or (batch, dim, seqlen)
+    conv_state: (batch, dim, state_len), where state_len >= width - 1
+    weight: (dim, width)
+    bias: (dim,)
+    cache_seqlens: (batch,), dtype int32.
+        If not None, the conv_state is treated as a circular buffer.
+        The conv_state will be updated by copying x to the conv_state starting at the index
+        @cache_seqlens % state_len before performing the convolution.
+
+    out: (batch, dim) or (batch, dim, seqlen)
+    """
+    if activation not in [None, "silu", "swish"]:
+        raise NotImplementedError("activation must be None, silu, or swish")
+    dtype_in = x.dtype
+    # Accept a single-step (batch, dim) input by adding a seqlen axis of 1.
+    unsqueeze = x.dim() == 2
+    if unsqueeze:
+        x = x.unsqueeze(-1)
+    batch, dim, seqlen = x.shape
+    width = weight.shape[1]
+    state_len = conv_state.shape[-1]
+    assert conv_state.shape == (batch, dim, state_len)
+    assert weight.shape == (dim, width)
+    if cache_seqlens is None:
+        # Shift-register update: append x, keep the last state_len steps.
+        x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype)  # (batch, dim, state_len + seqlen)
+        conv_state.copy_(x_new[:, :, -state_len:])
+    else:
+        # Circular buffer: gather the width-1 most recent entries, i.e. positions
+        # cache_seqlens-(width-1) .. cache_seqlens-1 modulo state_len ...
+        width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
+        width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
+        x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
+        # ... then scatter the new tokens back into the buffer at cache_seqlens.
+        copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
+        copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
+        conv_state.scatter_(2, copy_idx, x)
+    # Depthwise (groups=dim) conv; keep only the outputs for the new tokens.
+    out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:]
+    if unsqueeze:
+        out = out.squeeze(-1)
+    return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/causal_conv1d_varlen.py b/build/torch211-cxx11-cu130-x86_64-linux/causal_conv1d_varlen.py
new file mode 100644
index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/causal_conv1d_varlen.py
@@ -0,0 +1,86 @@
+import torch
+from torch import Tensor
+
+import triton
+import triton.language as tl
+
+
@triton.jit
def _causal_conv1d_varlen_states(
    X,  # (total_tokens, dim) packed input tokens
    CU_SEQLENS,  # (batch + 1,) cumulative sequence lengths, sorted, starting at 0
    STATES,  # (batch, dim, state_len) output, addressed via the strides below
    state_len,  # number of trailing tokens per sequence to copy into the state
    dim,
    stride_x_seqlen, stride_x_dim,
    stride_states_batch, stride_states_seqlen, stride_states_dim,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr
):
    # One program per (dim tile, seqlen tile, batch element).  Copies the last
    # `state_len` tokens of sequence `batch_idx` into STATES, right-aligned;
    # positions before the sequence start are stored as 0 (via `other=0`).
    batch_idx = tl.program_id(2)
    STATES += batch_idx * stride_states_batch
    end_idx = tl.load(CU_SEQLENS + batch_idx + 1)
    # Clamp so we never read tokens belonging to the previous sequence.
    start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)
    # Row indices into X for this tile, counted backwards from the sequence end.
    rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
    x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,
                mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim),
                other=0)
    # Matching rows in STATES: right-aligned so the last token lands at state_len - 1.
    rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
    tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,
             x,
             mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim))
+
+
def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
    """
    Forward pass only, does not support backward pass.
    Parameters:
        x: (total_tokens, dim)
        cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
        state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
            If some of those elements belong to a different sequence, the value of the states will be zero.
    Return:
        states: (batch, dim, state_len)
    """
    _, dim = x.shape
    batch = cu_seqlens.shape[0] - 1
    cu_seqlens = cu_seqlens.contiguous()
    # Allocate as (batch, state_len, dim) then transpose to (batch, dim, state_len)
    # so the dim axis stays contiguous (channel-last layout).
    states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
    BLOCK_M = min(triton.next_power_of_2(state_len), 16)
    BLOCK_N = min(triton.next_power_of_2(dim), 256)
    # Grid axes: (dim tiles, seqlen tiles, batch).
    grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch)
    # Launch on x's device rather than whatever the current default device is.
    with torch.cuda.device(x.device.index):
        _causal_conv1d_varlen_states[grid](
            x,
            cu_seqlens,
            states,
            state_len,
            dim,
            x.stride(0), x.stride(1),
            # seqlen/dim strides are swapped relative to `states`' logical axes
            # because `states` was transposed above.
            states.stride(0), states.stride(2), states.stride(1),
            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
        )
    return states
+
+
def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
    """
    Pure-PyTorch reference for `causal_conv1d_varlen_states`.

    Forward pass only, does not support backward pass.
    Parameters:
        x: (total_tokens, dim)
        cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
        state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
            If some of those elements belong to a different sequence, the value of the states will be zero.
    Return:
        states: (batch, dim, state_len)
    """
    _, dim = x.shape
    batch = cu_seqlens.shape[0] - 1
    cu_seqlens = cu_seqlens.contiguous()
    # (batch, state_len, dim) transposed to (batch, dim, state_len) so the dim
    # axis is contiguous, matching the layout produced by the Triton kernel.
    states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
    for i in range(batch):
        end_idx = cu_seqlens[i + 1]
        # Take at most `state_len` trailing tokens of sequence i.
        start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len)
        seqlen = int(end_idx - start_idx)
        # Guard empty sequences: with seqlen == 0 the old `states[i, :, -0:]`
        # selected the whole state and assigning a zero-width tensor raised.
        if seqlen > 0:
            states[i, :, -seqlen:] = x[start_idx:end_idx].T
    return states
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/cpp_functions.py b/build/torch211-cxx11-cu130-x86_64-linux/cpp_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/cpp_functions.py
@@ -0,0 +1,96 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+
+from ._ops import ops
+
def causal_conv1d_fwd_function(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    seq_idx: torch.Tensor | None,
    initial_states: torch.Tensor | None,
    final_states_out: torch.Tensor | None,
    silu_activation: bool,
) -> torch.Tensor:
    """Thin wrapper around the compiled `causal_conv1d_fwd` op.

    Allocates the output tensor and forwards all arguments to the C++/CUDA
    kernel, which fills `out` (and, when provided, `final_states_out`)
    in place.  Returns `out`.
    """
    out = torch.empty_like(x)
    ops.causal_conv1d_fwd(
        x=x,
        weight=weight,
        bias=bias,
        seq_idx=seq_idx,
        initial_states=initial_states,
        out=out,
        final_states_out=final_states_out,
        silu_activation=silu_activation,
    )
    return out
+
+
def causal_conv1d_bwd_function(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    dout: torch.Tensor,
    seq_idx: torch.Tensor | None,
    initial_states: torch.Tensor | None,
    dfinal_states: torch.Tensor | None,
    dx: torch.Tensor | None,
    return_dinitial_states: bool,  # was annotated torch.Tensor; it is used as a flag below
    silu_activation: bool,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """Thin wrapper around the compiled `causal_conv1d_bwd` op.

    Allocates the gradient buffers (unless a pre-allocated `dx` is supplied),
    lets the kernel fill them in place, and returns
    (dx, dweight, dbias, dinitial_states).  `dbias` is None when `bias` is
    None; `dinitial_states` is None when `return_dinitial_states` is falsy.
    """
    batch_size, dim = x.size()[:2]
    width = weight.size(-1)

    if dx is None:
        dx = torch.empty_like(x)
    # Accumulate weight/bias gradients in fp32, cast back to the param dtype at the end.
    dweight = torch.zeros_like(weight, dtype=torch.float32)
    dbias = None
    if bias is not None:
        dbias = torch.zeros_like(bias, dtype=torch.float32)
    dinitial_states = None
    if return_dinitial_states:
        # Channel-last (batch, dim, width - 1): allocate transposed so dim is contiguous.
        dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2)

    ops.causal_conv1d_bwd(
        x=x,
        weight=weight,
        bias=bias,
        dout=dout,
        seq_idx=seq_idx,
        initial_states=initial_states,
        dfinal_states=dfinal_states,
        dx=dx,
        dweight=dweight,
        dbias=dbias,
        dinitial_states=dinitial_states,
        silu_activation=silu_activation,
    )

    dweight = dweight.type_as(weight)
    if dbias is not None:
        dbias = dbias.type_as(bias)
    return dx, dweight, dbias, dinitial_states
+
+
def causal_conv1d_update_function(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    silu_activation: bool,
    cache_seqlens: torch.Tensor | None,
    conv_state_indices: torch.Tensor | None,
) -> torch.Tensor:
    """Thin wrapper around the compiled `causal_conv1d_update` op.

    Allocates the output and invokes the kernel, which writes the result into
    `out` and updates `conv_state` in place (see `causal_conv1d_update_ref`
    for the reference semantics).
    """
    out = torch.empty_like(x)
    ops.causal_conv1d_update(
        x=x,
        conv_state=conv_state,
        weight=weight,
        bias=bias,
        out=out,
        silu_activation=silu_activation,
        cache_seqlens=cache_seqlens,
        conv_state_indices=conv_state_indices,
    )
    return out
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/metadata.json b/build/torch211-cxx11-cu130-x86_64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..eff725542128e103dfb5df382d74940efff77214
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/metadata.json
@@ -0,0 +1,19 @@
+{
+ "version": 1,
+ "license": "BSD-3-Clause",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "10.0",
+ "11.0",
+ "12.0+PTX",
+ "7.5",
+ "8.0",
+ "8.6",
+ "8.7",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__init__.py b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__init__.py
@@ -0,0 +1,4 @@
+from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
+from .causal_conv1d_varlen import causal_conv1d_varlen_states
+
+__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"]
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cebc7817a292011587f4941dfff502f5e5c98cbd
Binary files /dev/null and b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc differ
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4f17eb8838cb893f61b2f42a345df9f69eb7e8bf
Binary files /dev/null and b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc differ
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fdd84db8007b53b4dd4368dc01ad9296d4b95830
Binary files /dev/null and b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc differ
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..67e3eba29657209ffccf401d6e8c340a28057265
Binary files /dev/null and b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc differ
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f78fec2e6a76348b50183b154618fd095d197d13
Binary files /dev/null and b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc differ
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/_causal_conv1d_90f5a60.abi3.so b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/_causal_conv1d_90f5a60.abi3.so
new file mode 100755
index 0000000000000000000000000000000000000000..0a843146f753e99ec745a7e2deb9c3db543a3482
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/_causal_conv1d_90f5a60.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9c7b5ae8af9477be3049ba1ae6af3f9e2d8bf82979fa9e9632c485a8d49f532a
+size 64503960
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/_ops.py b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1d217d97eaddf8812c504cd7ca9656b8b72fba4
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _causal_conv1d_90f5a60
+ops = torch.ops._causal_conv1d_90f5a60
+
def add_op_namespace_prefix(op_name: str) -> str:
    """Return *op_name* qualified with this extension's torch.ops namespace."""
    namespace = "_causal_conv1d_90f5a60"
    return f"{namespace}::{op_name}"
\ No newline at end of file
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/causal_conv1d_interface.py b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/causal_conv1d_interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/causal_conv1d_interface.py
@@ -0,0 +1,242 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+import torch.nn.functional as F
+
+from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function
+
+
class CausalConv1dFn(torch.autograd.Function):
    """Autograd Function wrapping the fused causal depthwise conv1d kernels.

    forward/backward dispatch to the compiled ops via `cpp_functions`.
    Callers should use `causal_conv1d_fn` rather than this class directly.
    """

    @staticmethod
    def forward(
        ctx,
        x,
        weight,
        bias=None,
        seq_idx=None,
        initial_states=None,
        return_final_states=False,
        final_states_out=None,
        activation=None,
    ):
        if activation not in [None, "silu", "swish"]:
            raise NotImplementedError("activation must be None, silu, or swish")
        # The kernel requires either the seqlen dim or the channel dim to be contiguous.
        if x.stride(2) != 1 and x.stride(1) != 1:
            x = x.contiguous()
        bias = bias.contiguous() if bias is not None else None
        if seq_idx is not None:
            assert (
                initial_states is None
            ), "initial_states must be None if seq_idx is not None"
            assert (
                not return_final_states
            ), "If seq_idx is not None, we don't return final_states_out"
        seq_idx = seq_idx.contiguous() if seq_idx is not None else None
        if initial_states is not None and (
            initial_states.stride(2) != 1 and initial_states.stride(1) != 1
        ):
            initial_states = initial_states.contiguous()
        if return_final_states:
            assert (
                x.stride(1) == 1
            ), "Only channel-last layout support returning final_states_out"
            if final_states_out is not None:
                assert (
                    final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1
                )
            else:
                batch, dim, seqlen = x.shape
                width = weight.shape[1]
                # Channel-last (batch, dim, width - 1) buffer for the final states.
                final_states_out = torch.empty(
                    batch, width - 1, dim, device=x.device, dtype=x.dtype
                ).transpose(1, 2)
        else:
            final_states_out = None
        # Stored as a bool: True iff a SiLU/Swish epilogue should be fused in.
        ctx.activation = activation in ["silu", "swish"]
        out = causal_conv1d_fwd_function(
            x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
        )
        ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
        ctx.return_final_states = return_final_states
        ctx.return_dinitial_states = (
            initial_states is not None and initial_states.requires_grad
        )
        return out if not return_final_states else (out, final_states_out)

    @staticmethod
    def backward(ctx, dout, *args):
        x, weight, bias, seq_idx, initial_states = ctx.saved_tensors
        # args[0] is the gradient w.r.t. final_states_out when forward returned it.
        dfinal_states = args[0] if ctx.return_final_states else None
        if dout.stride(2) != 1 and dout.stride(1) != 1:
            dout = dout.contiguous()
        # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
        # backward of conv1d with the backward of chunk).
        # Here we just pass in None and dx will be allocated in the C++ code.
        dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function(
            x,
            weight,
            bias,
            dout,
            seq_idx,
            initial_states,
            dfinal_states,
            None,
            ctx.return_dinitial_states,
            ctx.activation,
        )
        # One gradient (or None) per forward() argument, in order.
        return (
            dx,
            dweight,
            dbias if bias is not None else None,
            None,
            dinitial_states if initial_states is not None else None,
            None,
            None,
            None,
        )
+
+
def causal_conv1d_fn(
    x,
    weight,
    bias=None,
    seq_idx=None,
    initial_states=None,
    return_final_states=False,
    final_states_out=None,
    activation=None,
):
    """Causal depthwise conv1d (fused CUDA kernel) with autograd support.

    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    seq_idx: (batch, seqlen)
    initial_states: (batch, dim, width - 1)
    final_states_out: (batch, dim, width - 1), to be written to
    activation: either None or "silu" or "swish"

    out: (batch, dim, seqlen)
    """
    return CausalConv1dFn.apply(
        x, weight, bias, seq_idx, initial_states,
        return_final_states, final_states_out, activation,
    )
+
+
def causal_conv1d_ref(
    x,
    weight,
    bias=None,
    initial_states=None,
    return_final_states=False,
    final_states_out=None,
    activation=None,
):
    """Pure-PyTorch reference for the causal depthwise conv1d forward pass.

    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    initial_states: (batch, dim, width - 1)
    final_states_out: (batch, dim, width - 1)

    out: (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    orig_dtype = x.dtype
    x = x.to(weight.dtype)
    seqlen = x.shape[-1]
    dim, width = weight.shape
    if initial_states is not None:
        # Prepend the carried-over states, then run a valid (no-padding) conv.
        x = torch.cat([initial_states, x], dim=-1)
        padding = 0
    else:
        # Symmetric zero padding of width-1; the right overhang is cut below.
        padding = width - 1
    out = F.conv1d(x, weight.unsqueeze(1), bias, padding=padding, groups=dim)[..., :seqlen]
    if return_final_states:
        # Last width-1 positions of the (possibly states-prepended) input,
        # left-padded with zeros if the input is shorter than width-1.
        final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
            orig_dtype
        )  # (batch, dim, width - 1)
        if final_states_out is None:
            final_states_out = final_states
        else:
            final_states_out.copy_(final_states)
    if activation is not None:
        out = F.silu(out)
    out = out.to(dtype=orig_dtype)
    return (out, final_states_out) if return_final_states else out
+
+
def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None):
    """
    x: (batch, dim) or (batch, dim, seqlen)
    conv_state: (batch, dim, state_len), where state_len >= width - 1; updated in place
    weight: (dim, width)
    bias: (dim,)
    cache_seqlens: (batch,), dtype int32.
        If not None, the conv_state is treated as a circular buffer.
        The conv_state will be updated by copying x to the conv_state starting at the index
        @cache_seqlens % state_len.
    conv_state_indices: (batch,), dtype int32
        If not None, the conv_state is a larger tensor along the batch dim,
        and we are selecting the batch coords specified by conv_state_indices.
        Useful for a continuous batching scenario.

    out: (batch, dim) or (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    # The kernel takes a bool flag; both "silu" and "swish" mean SiLU.
    activation = activation in ["silu", "swish"]
    # Accept (batch, dim) input by treating it as seqlen == 1.
    unsqueeze = x.dim() == 2
    if unsqueeze:
        x = x.unsqueeze(-1)
    out = causal_conv1d_update_function(
        x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices
    )
    if unsqueeze:
        out = out.squeeze(-1)
    return out
+
+
def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
    """Pure-PyTorch reference for the incremental (decode-time) conv update.

    x: (batch, dim) or (batch, dim, seqlen)
    conv_state: (batch, dim, state_len), where state_len >= width - 1; updated in place
    weight: (dim, width)
    bias: (dim,)
    cache_seqlens: (batch,), dtype int32.
        If not None, the conv_state is treated as a circular buffer.
        The conv_state will be updated by copying x to the conv_state starting at the index
        @cache_seqlens % state_len before performing the convolution.

    out: (batch, dim) or (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    orig_dtype = x.dtype
    squeeze_out = x.dim() == 2
    if squeeze_out:
        x = x.unsqueeze(-1)
    batch, dim, seqlen = x.shape
    width = weight.shape[1]
    state_len = conv_state.shape[-1]
    assert conv_state.shape == (batch, dim, state_len)
    assert weight.shape == (dim, width)
    if cache_seqlens is None:
        # Linear buffer: history is simply the tail of conv_state.
        x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype)  # (batch, dim, state_len + seqlen)
        conv_state.copy_(x_new[:, :, -state_len:])
    else:
        # Circular buffer: gather the width-1 most recent entries per batch row...
        hist_offsets = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device)
        width_idx = torch.remainder(hist_offsets.unsqueeze(0) + cache_seqlens.unsqueeze(1), state_len)
        width_idx = width_idx.unsqueeze(1).expand(-1, dim, -1)
        x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
        # ...then scatter the new tokens into the ring starting at cache_seqlens.
        new_offsets = torch.arange(seqlen, dtype=torch.long, device=x.device)
        copy_idx = torch.remainder(new_offsets.unsqueeze(0) + cache_seqlens.unsqueeze(1), state_len)
        copy_idx = copy_idx.unsqueeze(1).expand(-1, dim, -1)
        conv_state.scatter_(2, copy_idx, x)
    out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:]
    if squeeze_out:
        out = out.squeeze(-1)
    if activation is not None:
        out = F.silu(out)
    return out.to(dtype=orig_dtype)
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/causal_conv1d_varlen.py b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/causal_conv1d_varlen.py
new file mode 100644
index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/causal_conv1d_varlen.py
@@ -0,0 +1,86 @@
+import torch
+from torch import Tensor
+
+import triton
+import triton.language as tl
+
+
@triton.jit
def _causal_conv1d_varlen_states(
    X,  # (total_tokens, dim) packed input tokens
    CU_SEQLENS,  # (batch + 1,) cumulative sequence lengths, sorted, starting at 0
    STATES,  # (batch, dim, state_len) output, addressed via the strides below
    state_len,  # number of trailing tokens per sequence to copy into the state
    dim,
    stride_x_seqlen, stride_x_dim,
    stride_states_batch, stride_states_seqlen, stride_states_dim,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr
):
    # One program per (dim tile, seqlen tile, batch element).  Copies the last
    # `state_len` tokens of sequence `batch_idx` into STATES, right-aligned;
    # positions before the sequence start are stored as 0 (via `other=0`).
    batch_idx = tl.program_id(2)
    STATES += batch_idx * stride_states_batch
    end_idx = tl.load(CU_SEQLENS + batch_idx + 1)
    # Clamp so we never read tokens belonging to the previous sequence.
    start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)
    # Row indices into X for this tile, counted backwards from the sequence end.
    rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
    x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,
                mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim),
                other=0)
    # Matching rows in STATES: right-aligned so the last token lands at state_len - 1.
    rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
    tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,
             x,
             mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim))
+
+
def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
    """
    Forward pass only, does not support backward pass.
    Parameters:
        x: (total_tokens, dim)
        cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
        state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
            If some of those elements belong to a different sequence, the value of the states will be zero.
    Return:
        states: (batch, dim, state_len)
    """
    _, dim = x.shape
    batch = cu_seqlens.shape[0] - 1
    cu_seqlens = cu_seqlens.contiguous()
    # Allocate as (batch, state_len, dim) then transpose to (batch, dim, state_len)
    # so the dim axis stays contiguous (channel-last layout).
    states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
    BLOCK_M = min(triton.next_power_of_2(state_len), 16)
    BLOCK_N = min(triton.next_power_of_2(dim), 256)
    # Grid axes: (dim tiles, seqlen tiles, batch).
    grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch)
    # Launch on x's device rather than whatever the current default device is.
    with torch.cuda.device(x.device.index):
        _causal_conv1d_varlen_states[grid](
            x,
            cu_seqlens,
            states,
            state_len,
            dim,
            x.stride(0), x.stride(1),
            # seqlen/dim strides are swapped relative to `states`' logical axes
            # because `states` was transposed above.
            states.stride(0), states.stride(2), states.stride(1),
            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
        )
    return states
+
+
def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
    """
    Pure-PyTorch reference for `causal_conv1d_varlen_states`.

    Forward pass only, does not support backward pass.
    Parameters:
        x: (total_tokens, dim)
        cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
        state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
            If some of those elements belong to a different sequence, the value of the states will be zero.
    Return:
        states: (batch, dim, state_len)
    """
    _, dim = x.shape
    batch = cu_seqlens.shape[0] - 1
    cu_seqlens = cu_seqlens.contiguous()
    # (batch, state_len, dim) transposed to (batch, dim, state_len) so the dim
    # axis is contiguous, matching the layout produced by the Triton kernel.
    states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
    for i in range(batch):
        end_idx = cu_seqlens[i + 1]
        # Take at most `state_len` trailing tokens of sequence i.
        start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len)
        seqlen = int(end_idx - start_idx)
        # Guard empty sequences: with seqlen == 0 the old `states[i, :, -0:]`
        # selected the whole state and assigning a zero-width tensor raised.
        if seqlen > 0:
            states[i, :, -seqlen:] = x[start_idx:end_idx].T
    return states
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/cpp_functions.py b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/cpp_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/cpp_functions.py
@@ -0,0 +1,96 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+
+from ._ops import ops
+
def causal_conv1d_fwd_function(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    seq_idx: torch.Tensor | None,
    initial_states: torch.Tensor | None,
    final_states_out: torch.Tensor | None,
    silu_activation: bool,
) -> torch.Tensor:
    """Thin wrapper around the compiled `causal_conv1d_fwd` op.

    Allocates the output tensor and forwards all arguments to the C++/CUDA
    kernel, which fills `out` (and, when provided, `final_states_out`)
    in place.  Returns `out`.
    """
    out = torch.empty_like(x)
    ops.causal_conv1d_fwd(
        x=x,
        weight=weight,
        bias=bias,
        seq_idx=seq_idx,
        initial_states=initial_states,
        out=out,
        final_states_out=final_states_out,
        silu_activation=silu_activation,
    )
    return out
+
+
def causal_conv1d_bwd_function(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    dout: torch.Tensor,
    seq_idx: torch.Tensor | None,
    initial_states: torch.Tensor | None,
    dfinal_states: torch.Tensor | None,
    dx: torch.Tensor | None,
    return_dinitial_states: bool,  # was annotated torch.Tensor; it is used as a flag below
    silu_activation: bool,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """Thin wrapper around the compiled `causal_conv1d_bwd` op.

    Allocates the gradient buffers (unless a pre-allocated `dx` is supplied),
    lets the kernel fill them in place, and returns
    (dx, dweight, dbias, dinitial_states).  `dbias` is None when `bias` is
    None; `dinitial_states` is None when `return_dinitial_states` is falsy.
    """
    batch_size, dim = x.size()[:2]
    width = weight.size(-1)

    if dx is None:
        dx = torch.empty_like(x)
    # Accumulate weight/bias gradients in fp32, cast back to the param dtype at the end.
    dweight = torch.zeros_like(weight, dtype=torch.float32)
    dbias = None
    if bias is not None:
        dbias = torch.zeros_like(bias, dtype=torch.float32)
    dinitial_states = None
    if return_dinitial_states:
        # Channel-last (batch, dim, width - 1): allocate transposed so dim is contiguous.
        dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2)

    ops.causal_conv1d_bwd(
        x=x,
        weight=weight,
        bias=bias,
        dout=dout,
        seq_idx=seq_idx,
        initial_states=initial_states,
        dfinal_states=dfinal_states,
        dx=dx,
        dweight=dweight,
        dbias=dbias,
        dinitial_states=dinitial_states,
        silu_activation=silu_activation,
    )

    dweight = dweight.type_as(weight)
    if dbias is not None:
        dbias = dbias.type_as(bias)
    return dx, dweight, dbias, dinitial_states
+
+
def causal_conv1d_update_function(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    silu_activation: bool,
    cache_seqlens: torch.Tensor | None,
    conv_state_indices: torch.Tensor | None,
) -> torch.Tensor:
    """Thin wrapper around the compiled `causal_conv1d_update` op.

    Allocates the output and invokes the kernel, which writes the result into
    `out` and updates `conv_state` in place (see `causal_conv1d_update_ref`
    for the reference semantics).
    """
    out = torch.empty_like(x)
    ops.causal_conv1d_update(
        x=x,
        conv_state=conv_state,
        weight=weight,
        bias=bias,
        out=out,
        silu_activation=silu_activation,
        cache_seqlens=cache_seqlens,
        conv_state_indices=conv_state_indices,
    )
    return out
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py
@@ -0,0 +1,4 @@
+from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
+from .causal_conv1d_varlen import causal_conv1d_varlen_states
+
+__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"]
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..786bd24d264db90632165d514db4f5521d8133e0
Binary files /dev/null and b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc differ
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c9d95dcbdfd288de6d7077fd18259eb3bddb4958
Binary files /dev/null and b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc differ
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..297eb3e2197477c1df29aaeaff32a0c4b4ba8611
Binary files /dev/null and b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc differ
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..40dd2dec4c9da46e89fe46f6974c794ad6432a52
Binary files /dev/null and b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc differ
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c6d47e6accd7dc9a8d27df0e15f3363d6579b752
Binary files /dev/null and b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc differ
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/_causal_conv1d_90f5a60.abi3.so b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/_causal_conv1d_90f5a60.abi3.so
new file mode 100755
index 0000000000000000000000000000000000000000..770b6a40b54b2b7dc3ff89d20de6dae3cfd5a06a
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/_causal_conv1d_90f5a60.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:299bc47bf7fdea21eb71f9b0d0cd329a32e792106029d4fd5d6c637c76b9c6f7
+size 64213568
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/_ops.py b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1d217d97eaddf8812c504cd7ca9656b8b72fba4
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _causal_conv1d_90f5a60
+ops = torch.ops._causal_conv1d_90f5a60
+
def add_op_namespace_prefix(op_name: str) -> str:
    """Return *op_name* qualified with this extension's torch.ops namespace."""
    namespace = "_causal_conv1d_90f5a60"
    return f"{namespace}::{op_name}"
\ No newline at end of file
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/causal_conv1d_interface.py b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/causal_conv1d_interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/causal_conv1d_interface.py
@@ -0,0 +1,242 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+import torch.nn.functional as F
+
+from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function
+
+
+class CausalConv1dFn(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ x,
+ weight,
+ bias=None,
+ seq_idx=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+ ):
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ if x.stride(2) != 1 and x.stride(1) != 1:
+ x = x.contiguous()
+ bias = bias.contiguous() if bias is not None else None
+ if seq_idx is not None:
+ assert (
+ initial_states is None
+ ), "initial_states must be None if seq_idx is not None"
+ assert (
+ not return_final_states
+ ), "If seq_idx is not None, we don't return final_states_out"
+ seq_idx = seq_idx.contiguous() if seq_idx is not None else None
+ if initial_states is not None and (
+ initial_states.stride(2) != 1 and initial_states.stride(1) != 1
+ ):
+ initial_states = initial_states.contiguous()
+ if return_final_states:
+ assert (
+ x.stride(1) == 1
+ ), "Only channel-last layout support returning final_states_out"
+ if final_states_out is not None:
+ assert (
+ final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1
+ )
+ else:
+ batch, dim, seqlen = x.shape
+ width = weight.shape[1]
+ final_states_out = torch.empty(
+ batch, width - 1, dim, device=x.device, dtype=x.dtype
+ ).transpose(1, 2)
+ else:
+ final_states_out = None
+ ctx.activation = activation in ["silu", "swish"]
+ out = causal_conv1d_fwd_function(
+ x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
+ )
+ ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
+ ctx.return_final_states = return_final_states
+ ctx.return_dinitial_states = (
+ initial_states is not None and initial_states.requires_grad
+ )
+ return out if not return_final_states else (out, final_states_out)
+
+ @staticmethod
+ def backward(ctx, dout, *args):
+ x, weight, bias, seq_idx, initial_states = ctx.saved_tensors
+ dfinal_states = args[0] if ctx.return_final_states else None
+ if dout.stride(2) != 1 and dout.stride(1) != 1:
+ dout = dout.contiguous()
+ # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
+ # backward of conv1d with the backward of chunk).
+ # Here we just pass in None and dx will be allocated in the C++ code.
+ dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function(
+ x,
+ weight,
+ bias,
+ dout,
+ seq_idx,
+ initial_states,
+ dfinal_states,
+ None,
+ ctx.return_dinitial_states,
+ ctx.activation,
+ )
+ return (
+ dx,
+ dweight,
+ dbias if bias is not None else None,
+ None,
+ dinitial_states if initial_states is not None else None,
+ None,
+ None,
+ None,
+ )
+
+
+def causal_conv1d_fn(
+ x,
+ weight,
+ bias=None,
+ seq_idx=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+):
+ """
+ x: (batch, dim, seqlen)
+ weight: (dim, width)
+ bias: (dim,)
+ seq_idx: (batch, seqlen)
+ initial_states: (batch, dim, width - 1)
+ final_states_out: (batch, dim, width - 1), to be written to
+ activation: either None or "silu" or "swish"
+
+ out: (batch, dim, seqlen)
+ """
+ return CausalConv1dFn.apply(
+ x,
+ weight,
+ bias,
+ seq_idx,
+ initial_states,
+ return_final_states,
+ final_states_out,
+ activation,
+ )
+
+
+def causal_conv1d_ref(
+ x,
+ weight,
+ bias=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+):
+ """
+ x: (batch, dim, seqlen)
+ weight: (dim, width)
+ bias: (dim,)
+ initial_states: (batch, dim, width - 1)
+ final_states_out: (batch, dim, width - 1)
+
+ out: (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ dtype_in = x.dtype
+ x = x.to(weight.dtype)
+ seqlen = x.shape[-1]
+ dim, width = weight.shape
+ if initial_states is None:
+ out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
+ else:
+ x = torch.cat([initial_states, x], dim=-1)
+ out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
+ out = out[..., :seqlen]
+ if return_final_states:
+ final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
+ dtype_in
+ ) # (batch, dim, width - 1)
+ if final_states_out is not None:
+ final_states_out.copy_(final_states)
+ else:
+ final_states_out = final_states
+ out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
+ return out if not return_final_states else (out, final_states_out)
+
+
+def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None):
+ """
+ x: (batch, dim) or (batch, dim, seqlen)
+ conv_state: (batch, dim, state_len), where state_len >= width - 1
+ weight: (dim, width)
+ bias: (dim,)
+ cache_seqlens: (batch,), dtype int32.
+ If not None, the conv_state is treated as a circular buffer.
+ The conv_state will be updated by copying x to the conv_state starting at the index
+ @cache_seqlens % state_len.
+ conv_state_indices: (batch,), dtype int32
+ If None, the conv_state is a larger tensor along the batch dim,
+ and we are selecting the batch coords specified by conv_state_indices.
+ Useful for a continuous batching scenario.
+
+ out: (batch, dim) or (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ activation = activation in ["silu", "swish"]
+ unsqueeze = x.dim() == 2
+ if unsqueeze:
+ x = x.unsqueeze(-1)
+ out = causal_conv1d_update_function(
+ x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices
+ )
+ if unsqueeze:
+ out = out.squeeze(-1)
+ return out
+
+
+def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
+ """
+ x: (batch, dim) or (batch, dim, seqlen)
+ conv_state: (batch, dim, state_len), where state_len >= width - 1
+ weight: (dim, width)
+ bias: (dim,)
+ cache_seqlens: (batch,), dtype int32.
+ If not None, the conv_state is treated as a circular buffer.
+ The conv_state will be updated by copying x to the conv_state starting at the index
+ @cache_seqlens % state_len before performing the convolution.
+
+ out: (batch, dim) or (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ dtype_in = x.dtype
+ unsqueeze = x.dim() == 2
+ if unsqueeze:
+ x = x.unsqueeze(-1)
+ batch, dim, seqlen = x.shape
+ width = weight.shape[1]
+ state_len = conv_state.shape[-1]
+ assert conv_state.shape == (batch, dim, state_len)
+ assert weight.shape == (dim, width)
+ if cache_seqlens is None:
+ x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen)
+ conv_state.copy_(x_new[:, :, -state_len:])
+ else:
+ width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
+ width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
+ x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
+ copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
+ copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
+ conv_state.scatter_(2, copy_idx, x)
+ out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:]
+ if unsqueeze:
+ out = out.squeeze(-1)
+ return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/causal_conv1d_varlen.py b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/causal_conv1d_varlen.py
new file mode 100644
index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/causal_conv1d_varlen.py
@@ -0,0 +1,86 @@
+import torch
+from torch import Tensor
+
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _causal_conv1d_varlen_states(
+ X,
+ CU_SEQLENS,
+ STATES,
+ state_len,
+ dim,
+ stride_x_seqlen, stride_x_dim,
+ stride_states_batch, stride_states_seqlen, stride_states_dim,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr
+):
+ batch_idx = tl.program_id(2)
+ STATES += batch_idx * stride_states_batch
+ end_idx = tl.load(CU_SEQLENS + batch_idx + 1)
+ start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)
+ rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
+ cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
+ x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,
+ mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim),
+ other=0)
+ rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
+ tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,
+ x,
+ mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim))
+
+
+def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
+ """
+ Forward pass only, does not support backward pass.
+ Parameters:
+ x: (total_tokens, dim)
+ cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
+ state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
+ If some of those elements belong to a different sequence, the value of the states will be zero.
+ Return:
+ states: (batch, dim, state_len)
+ """
+ _, dim = x.shape
+ batch = cu_seqlens.shape[0] - 1
+ cu_seqlens = cu_seqlens.contiguous()
+ states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
+ BLOCK_M = min(triton.next_power_of_2(state_len), 16)
+ BLOCK_N = min(triton.next_power_of_2(dim), 256)
+ grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch)
+ with torch.cuda.device(x.device.index):
+ _causal_conv1d_varlen_states[grid](
+ x,
+ cu_seqlens,
+ states,
+ state_len,
+ dim,
+ x.stride(0), x.stride(1),
+ states.stride(0), states.stride(2), states.stride(1),
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
+ )
+ return states
+
+
+def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
+ """
+ Forward pass only, does not support backward pass.
+ Parameters:
+ x: (total_tokens, dim)
+ cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
+ state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
+ If some of those elements belong to a different sequence, the value of the states will be zero.
+ Return:
+ states: (batch, dim, state_len)
+ """
+ _, dim = x.shape
+ batch = cu_seqlens.shape[0] - 1
+ cu_seqlens = cu_seqlens.contiguous()
+ states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
+ for i in range(batch):
+ end_idx = cu_seqlens[i + 1]
+ start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len)
+ states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T
+ return states
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/cpp_functions.py b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/cpp_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/cpp_functions.py
@@ -0,0 +1,96 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+
+from ._ops import ops
+
+def causal_conv1d_fwd_function(
+ x: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None,
+ seq_idx: torch.Tensor | None,
+ initial_states: torch.Tensor | None,
+ final_states_out: torch.Tensor | None,
+ silu_activation: bool,
+) -> torch.Tensor:
+ out = torch.empty_like(x)
+ ops.causal_conv1d_fwd(
+ x=x,
+ weight=weight,
+ bias=bias,
+ seq_idx=seq_idx,
+ initial_states=initial_states,
+ out=out,
+ final_states_out=final_states_out,
+ silu_activation=silu_activation,
+ )
+ return out
+
+
+def causal_conv1d_bwd_function(
+ x: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None,
+ dout: torch.Tensor,
+ seq_idx: torch.Tensor | None,
+ initial_states: torch.Tensor | None,
+ dfinal_states: torch.Tensor | None,
+ dx: torch.Tensor | None,
+ return_dinitial_states: torch.Tensor,
+ silu_activation: bool,
+) -> tuple[torch.Tensor | None]:
+ batch_size, dim = x.size()[:2]
+ width = weight.size(-1)
+
+ if dx is None:
+ dx = torch.empty_like(x)
+ dweight = torch.zeros_like(weight, dtype=torch.float32)
+ dbias = None
+ if bias is not None:
+ dbias = torch.zeros_like(bias, dtype=torch.float32)
+ dinitial_states = None
+ if return_dinitial_states:
+ dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2)
+
+ ops.causal_conv1d_bwd(
+ x=x,
+ weight=weight,
+ bias=bias,
+ dout=dout,
+ seq_idx=seq_idx,
+ initial_states=initial_states,
+ dfinal_states=dfinal_states,
+ dx=dx,
+ dweight=dweight,
+ dbias=dbias,
+ dinitial_states=dinitial_states,
+ silu_activation=silu_activation,
+ )
+
+ dweight = dweight.type_as(weight)
+ if dbias is not None:
+ dbias = dbias.type_as(bias)
+ return dx, dweight, dbias, dinitial_states
+
+
+def causal_conv1d_update_function(
+ x: torch.Tensor,
+ conv_state: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None,
+ silu_activation: bool,
+ cache_seqlens: torch.Tensor | None,
+ conv_state_indices: torch.Tensor | None,
+) -> torch.Tensor:
+ out = torch.empty_like(x)
+ ops.causal_conv1d_update(
+ x=x,
+ conv_state=conv_state,
+ weight=weight,
+ bias=bias,
+ out=out,
+ silu_activation=silu_activation,
+ cache_seqlens=cache_seqlens,
+ conv_state_indices=conv_state_indices,
+ )
+ return out
diff --git a/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__init__.py b/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c
--- /dev/null
+++ b/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__init__.py
@@ -0,0 +1,4 @@
+from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
+from .causal_conv1d_varlen import causal_conv1d_varlen_states
+
+__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"]
diff --git a/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc b/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7f5e5e7f7807ad1ef65b13252d2c019610e63ea1
Binary files /dev/null and b/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc differ
diff --git a/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc b/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cc420da27d8204045a0413fb32b05b30f65d7ed4
Binary files /dev/null and b/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc differ
diff --git a/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc b/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..06007695980ef777fd26f660043ecdb3d633b439
Binary files /dev/null and b/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc differ
diff --git a/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc b/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ade83fbdcf1c14aa347c29d3b12afa3833385854
Binary files /dev/null and b/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc differ
diff --git a/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc b/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d7c1d92c27a92656d2b0fe8c5fd52fbe61f15d6f
Binary files /dev/null and b/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc differ
diff --git a/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/_causal_conv1d_306ae84.abi3.so b/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/_causal_conv1d_306ae84.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..40a950b0a1cf66eedb98e1fe4bd018c280a29005
--- /dev/null
+++ b/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/_causal_conv1d_306ae84.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7ec3c41413afbb69d499eae6a432fa9d41a580e7b1c6ee83d09e8dab51f91803
+size 90795560
diff --git a/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/_ops.py b/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..2257797ae235d25abede0851de00a59f5220a87d
--- /dev/null
+++ b/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _causal_conv1d_306ae84
+ops = torch.ops._causal_conv1d_306ae84
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_causal_conv1d_306ae84::{op_name}"
\ No newline at end of file
diff --git a/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/causal_conv1d_interface.py b/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/causal_conv1d_interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377
--- /dev/null
+++ b/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/causal_conv1d_interface.py
@@ -0,0 +1,242 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+import torch.nn.functional as F
+
+from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function
+
+
+class CausalConv1dFn(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ x,
+ weight,
+ bias=None,
+ seq_idx=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+ ):
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ if x.stride(2) != 1 and x.stride(1) != 1:
+ x = x.contiguous()
+ bias = bias.contiguous() if bias is not None else None
+ if seq_idx is not None:
+ assert (
+ initial_states is None
+ ), "initial_states must be None if seq_idx is not None"
+ assert (
+ not return_final_states
+ ), "If seq_idx is not None, we don't return final_states_out"
+ seq_idx = seq_idx.contiguous() if seq_idx is not None else None
+ if initial_states is not None and (
+ initial_states.stride(2) != 1 and initial_states.stride(1) != 1
+ ):
+ initial_states = initial_states.contiguous()
+ if return_final_states:
+ assert (
+ x.stride(1) == 1
+ ), "Only channel-last layout support returning final_states_out"
+ if final_states_out is not None:
+ assert (
+ final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1
+ )
+ else:
+ batch, dim, seqlen = x.shape
+ width = weight.shape[1]
+ final_states_out = torch.empty(
+ batch, width - 1, dim, device=x.device, dtype=x.dtype
+ ).transpose(1, 2)
+ else:
+ final_states_out = None
+ ctx.activation = activation in ["silu", "swish"]
+ out = causal_conv1d_fwd_function(
+ x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
+ )
+ ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
+ ctx.return_final_states = return_final_states
+ ctx.return_dinitial_states = (
+ initial_states is not None and initial_states.requires_grad
+ )
+ return out if not return_final_states else (out, final_states_out)
+
+ @staticmethod
+ def backward(ctx, dout, *args):
+ x, weight, bias, seq_idx, initial_states = ctx.saved_tensors
+ dfinal_states = args[0] if ctx.return_final_states else None
+ if dout.stride(2) != 1 and dout.stride(1) != 1:
+ dout = dout.contiguous()
+ # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
+ # backward of conv1d with the backward of chunk).
+ # Here we just pass in None and dx will be allocated in the C++ code.
+ dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function(
+ x,
+ weight,
+ bias,
+ dout,
+ seq_idx,
+ initial_states,
+ dfinal_states,
+ None,
+ ctx.return_dinitial_states,
+ ctx.activation,
+ )
+ return (
+ dx,
+ dweight,
+ dbias if bias is not None else None,
+ None,
+ dinitial_states if initial_states is not None else None,
+ None,
+ None,
+ None,
+ )
+
+
+def causal_conv1d_fn(
+ x,
+ weight,
+ bias=None,
+ seq_idx=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+):
+ """
+ x: (batch, dim, seqlen)
+ weight: (dim, width)
+ bias: (dim,)
+ seq_idx: (batch, seqlen)
+ initial_states: (batch, dim, width - 1)
+ final_states_out: (batch, dim, width - 1), to be written to
+ activation: either None or "silu" or "swish"
+
+ out: (batch, dim, seqlen)
+ """
+ return CausalConv1dFn.apply(
+ x,
+ weight,
+ bias,
+ seq_idx,
+ initial_states,
+ return_final_states,
+ final_states_out,
+ activation,
+ )
+
+
+def causal_conv1d_ref(
+ x,
+ weight,
+ bias=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+):
+ """
+ x: (batch, dim, seqlen)
+ weight: (dim, width)
+ bias: (dim,)
+ initial_states: (batch, dim, width - 1)
+ final_states_out: (batch, dim, width - 1)
+
+ out: (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ dtype_in = x.dtype
+ x = x.to(weight.dtype)
+ seqlen = x.shape[-1]
+ dim, width = weight.shape
+ if initial_states is None:
+ out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
+ else:
+ x = torch.cat([initial_states, x], dim=-1)
+ out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
+ out = out[..., :seqlen]
+ if return_final_states:
+ final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
+ dtype_in
+ ) # (batch, dim, width - 1)
+ if final_states_out is not None:
+ final_states_out.copy_(final_states)
+ else:
+ final_states_out = final_states
+ out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
+ return out if not return_final_states else (out, final_states_out)
+
+
+def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None):
+ """
+ x: (batch, dim) or (batch, dim, seqlen)
+ conv_state: (batch, dim, state_len), where state_len >= width - 1
+ weight: (dim, width)
+ bias: (dim,)
+ cache_seqlens: (batch,), dtype int32.
+ If not None, the conv_state is treated as a circular buffer.
+ The conv_state will be updated by copying x to the conv_state starting at the index
+ @cache_seqlens % state_len.
+ conv_state_indices: (batch,), dtype int32
+ If None, the conv_state is a larger tensor along the batch dim,
+ and we are selecting the batch coords specified by conv_state_indices.
+ Useful for a continuous batching scenario.
+
+ out: (batch, dim) or (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ activation = activation in ["silu", "swish"]
+ unsqueeze = x.dim() == 2
+ if unsqueeze:
+ x = x.unsqueeze(-1)
+ out = causal_conv1d_update_function(
+ x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices
+ )
+ if unsqueeze:
+ out = out.squeeze(-1)
+ return out
+
+
+def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
+ """
+ x: (batch, dim) or (batch, dim, seqlen)
+ conv_state: (batch, dim, state_len), where state_len >= width - 1
+ weight: (dim, width)
+ bias: (dim,)
+ cache_seqlens: (batch,), dtype int32.
+ If not None, the conv_state is treated as a circular buffer.
+ The conv_state will be updated by copying x to the conv_state starting at the index
+ @cache_seqlens % state_len before performing the convolution.
+
+ out: (batch, dim) or (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ dtype_in = x.dtype
+ unsqueeze = x.dim() == 2
+ if unsqueeze:
+ x = x.unsqueeze(-1)
+ batch, dim, seqlen = x.shape
+ width = weight.shape[1]
+ state_len = conv_state.shape[-1]
+ assert conv_state.shape == (batch, dim, state_len)
+ assert weight.shape == (dim, width)
+ if cache_seqlens is None:
+ x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen)
+ conv_state.copy_(x_new[:, :, -state_len:])
+ else:
+ width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
+ width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
+ x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
+ copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
+ copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
+ conv_state.scatter_(2, copy_idx, x)
+ out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:]
+ if unsqueeze:
+ out = out.squeeze(-1)
+ return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
diff --git a/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/causal_conv1d_varlen.py b/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/causal_conv1d_varlen.py
new file mode 100644
index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb
--- /dev/null
+++ b/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/causal_conv1d_varlen.py
@@ -0,0 +1,86 @@
+import torch
+from torch import Tensor
+
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _causal_conv1d_varlen_states(
+ X,
+ CU_SEQLENS,
+ STATES,
+ state_len,
+ dim,
+ stride_x_seqlen, stride_x_dim,
+ stride_states_batch, stride_states_seqlen, stride_states_dim,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr
+):
+ batch_idx = tl.program_id(2)
+ STATES += batch_idx * stride_states_batch
+ end_idx = tl.load(CU_SEQLENS + batch_idx + 1)
+ start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)
+ rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
+ cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
+ x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,
+ mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim),
+ other=0)
+ rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
+ tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,
+ x,
+ mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim))
+
+
+def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
+ """
+ Forward pass only, does not support backward pass.
+ Parameters:
+ x: (total_tokens, dim)
+ cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
+ state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
+ If some of those elements belong to a different sequence, the value of the states will be zero.
+ Return:
+ states: (batch, dim, state_len)
+ """
+ _, dim = x.shape
+ batch = cu_seqlens.shape[0] - 1
+ cu_seqlens = cu_seqlens.contiguous()
+ states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
+ BLOCK_M = min(triton.next_power_of_2(state_len), 16)
+ BLOCK_N = min(triton.next_power_of_2(dim), 256)
+ grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch)
+ with torch.cuda.device(x.device.index):
+ _causal_conv1d_varlen_states[grid](
+ x,
+ cu_seqlens,
+ states,
+ state_len,
+ dim,
+ x.stride(0), x.stride(1),
+ states.stride(0), states.stride(2), states.stride(1),
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
+ )
+ return states
+
+
+def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
+ """
+ Forward pass only, does not support backward pass.
+ Parameters:
+ x: (total_tokens, dim)
+ cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
+ state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
+ If some of those elements belong to a different sequence, the value of the states will be zero.
+ Return:
+ states: (batch, dim, state_len)
+ """
+ _, dim = x.shape
+ batch = cu_seqlens.shape[0] - 1
+ cu_seqlens = cu_seqlens.contiguous()
+ states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
+ for i in range(batch):
+ end_idx = cu_seqlens[i + 1]
+ start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len)
+ states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T
+ return states
diff --git a/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/cpp_functions.py b/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/cpp_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b
--- /dev/null
+++ b/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/cpp_functions.py
@@ -0,0 +1,96 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+
+from ._ops import ops
+
+def causal_conv1d_fwd_function(
+ x: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None,
+ seq_idx: torch.Tensor | None,
+ initial_states: torch.Tensor | None,
+ final_states_out: torch.Tensor | None,
+ silu_activation: bool,
+) -> torch.Tensor:
+ out = torch.empty_like(x)
+ ops.causal_conv1d_fwd(
+ x=x,
+ weight=weight,
+ bias=bias,
+ seq_idx=seq_idx,
+ initial_states=initial_states,
+ out=out,
+ final_states_out=final_states_out,
+ silu_activation=silu_activation,
+ )
+ return out
+
+
+def causal_conv1d_bwd_function(
+ x: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None,
+ dout: torch.Tensor,
+ seq_idx: torch.Tensor | None,
+ initial_states: torch.Tensor | None,
+ dfinal_states: torch.Tensor | None,
+ dx: torch.Tensor | None,
+ return_dinitial_states: torch.Tensor,
+ silu_activation: bool,
+) -> tuple[torch.Tensor | None]:
+ batch_size, dim = x.size()[:2]
+ width = weight.size(-1)
+
+ if dx is None:
+ dx = torch.empty_like(x)
+ dweight = torch.zeros_like(weight, dtype=torch.float32)
+ dbias = None
+ if bias is not None:
+ dbias = torch.zeros_like(bias, dtype=torch.float32)
+ dinitial_states = None
+ if return_dinitial_states:
+ dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2)
+
+ ops.causal_conv1d_bwd(
+ x=x,
+ weight=weight,
+ bias=bias,
+ dout=dout,
+ seq_idx=seq_idx,
+ initial_states=initial_states,
+ dfinal_states=dfinal_states,
+ dx=dx,
+ dweight=dweight,
+ dbias=dbias,
+ dinitial_states=dinitial_states,
+ silu_activation=silu_activation,
+ )
+
+ dweight = dweight.type_as(weight)
+ if dbias is not None:
+ dbias = dbias.type_as(bias)
+ return dx, dweight, dbias, dinitial_states
+
+
+def causal_conv1d_update_function(
+ x: torch.Tensor,
+ conv_state: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None,
+ silu_activation: bool,
+ cache_seqlens: torch.Tensor | None,
+ conv_state_indices: torch.Tensor | None,
+) -> torch.Tensor:
+ out = torch.empty_like(x)
+ ops.causal_conv1d_update(
+ x=x,
+ conv_state=conv_state,
+ weight=weight,
+ bias=bias,
+ out=out,
+ silu_activation=silu_activation,
+ cache_seqlens=cache_seqlens,
+ conv_state_indices=conv_state_indices,
+ )
+ return out
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py
@@ -0,0 +1,4 @@
+from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
+from .causal_conv1d_varlen import causal_conv1d_varlen_states
+
+__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"]
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..18d3139112b33e61419b6a2728a795c6a358861e
Binary files /dev/null and b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc differ
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2857339f18ec0febb5d24266e3d91fbbe2fa820c
Binary files /dev/null and b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc differ
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bef01edf5863f420831c4fd7b62444df502bf29e
Binary files /dev/null and b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc differ
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eb24c381faa7b75b6519cd48d85261bec5d03f1d
Binary files /dev/null and b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc differ
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7283c7bf629a1f669d5802a0aecc629cfcda5eb0
Binary files /dev/null and b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc differ
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/_causal_conv1d_90f5a60.abi3.so b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/_causal_conv1d_90f5a60.abi3.so
new file mode 100755
index 0000000000000000000000000000000000000000..bc680b2d853196f8878bb9cb5f6d73bce3cecb80
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/_causal_conv1d_90f5a60.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:859247ab0b3e7852c4e1a6ac76f3c62b3aebea729241058075eb2e6f29139a50
+size 90656256
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/_ops.py b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1d217d97eaddf8812c504cd7ca9656b8b72fba4
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _causal_conv1d_90f5a60
+ops = torch.ops._causal_conv1d_90f5a60
+
def add_op_namespace_prefix(op_name: str):
    """Qualify *op_name* with this build's private torch op namespace."""
    return "::".join(("_causal_conv1d_90f5a60", op_name))
\ No newline at end of file
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/causal_conv1d_interface.py b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/causal_conv1d_interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/causal_conv1d_interface.py
@@ -0,0 +1,242 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+import torch.nn.functional as F
+
+from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function
+
+
class CausalConv1dFn(torch.autograd.Function):
    """Autograd wrapper around the causal depthwise conv1d CUDA kernels.

    Forward runs the fused kernel (optionally with SiLU); backward returns
    gradients for x, weight, bias and (optionally) initial_states.
    """

    @staticmethod
    def forward(
        ctx,
        x,
        weight,
        bias=None,
        seq_idx=None,
        initial_states=None,
        return_final_states=False,
        final_states_out=None,
        activation=None,
    ):
        # x: (batch, dim, seqlen); weight: (dim, width); bias: (dim,) or None.
        # "silu" and "swish" name the same activation.
        if activation not in [None, "silu", "swish"]:
            raise NotImplementedError("activation must be None, silu, or swish")
        # Kernel needs either contiguous or channel-last layout (one unit stride).
        if x.stride(2) != 1 and x.stride(1) != 1:
            x = x.contiguous()
        bias = bias.contiguous() if bias is not None else None
        if seq_idx is not None:
            assert (
                initial_states is None
            ), "initial_states must be None if seq_idx is not None"
            assert (
                not return_final_states
            ), "If seq_idx is not None, we don't return final_states_out"
        seq_idx = seq_idx.contiguous() if seq_idx is not None else None
        if initial_states is not None and (
            initial_states.stride(2) != 1 and initial_states.stride(1) != 1
        ):
            initial_states = initial_states.contiguous()
        if return_final_states:
            assert (
                x.stride(1) == 1
            ), "Only channel-last layout support returning final_states_out"
            if final_states_out is not None:
                assert (
                    final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1
                )
            else:
                batch, dim, seqlen = x.shape
                width = weight.shape[1]
                # Allocate a channel-last (batch, dim, width - 1) buffer.
                final_states_out = torch.empty(
                    batch, width - 1, dim, device=x.device, dtype=x.dtype
                ).transpose(1, 2)
        else:
            final_states_out = None
        # Stored as a bool: whether SiLU is fused into the kernel.
        ctx.activation = activation in ["silu", "swish"]
        out = causal_conv1d_fwd_function(
            x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
        )
        ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
        ctx.return_final_states = return_final_states
        ctx.return_dinitial_states = (
            initial_states is not None and initial_states.requires_grad
        )
        return out if not return_final_states else (out, final_states_out)

    @staticmethod
    def backward(ctx, dout, *args):
        x, weight, bias, seq_idx, initial_states = ctx.saved_tensors
        # If forward returned (out, final_states), args[0] is dfinal_states.
        dfinal_states = args[0] if ctx.return_final_states else None
        if dout.stride(2) != 1 and dout.stride(1) != 1:
            dout = dout.contiguous()
        # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
        # backward of conv1d with the backward of chunk).
        # Here we just pass in None and dx will be allocated in the C++ code.
        dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function(
            x,
            weight,
            bias,
            dout,
            seq_idx,
            initial_states,
            dfinal_states,
            None,
            ctx.return_dinitial_states,
            ctx.activation,
        )
        # One gradient (or None) per forward argument, in order: x, weight,
        # bias, seq_idx, initial_states, return_final_states, final_states_out,
        # activation.
        return (
            dx,
            dweight,
            dbias if bias is not None else None,
            None,
            dinitial_states if initial_states is not None else None,
            None,
            None,
            None,
        )
+
+
def causal_conv1d_fn(
    x,
    weight,
    bias=None,
    seq_idx=None,
    initial_states=None,
    return_final_states=False,
    final_states_out=None,
    activation=None,
):
    """Causal depthwise conv1d (autograd-enabled entry point).

    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    seq_idx: (batch, seqlen)
    initial_states: (batch, dim, width - 1)
    final_states_out: (batch, dim, width - 1), to be written to
    activation: either None or "silu" or "swish"

    out: (batch, dim, seqlen)
    """
    forward_args = (
        x,
        weight,
        bias,
        seq_idx,
        initial_states,
        return_final_states,
        final_states_out,
        activation,
    )
    return CausalConv1dFn.apply(*forward_args)
+
+
def causal_conv1d_ref(
    x,
    weight,
    bias=None,
    initial_states=None,
    return_final_states=False,
    final_states_out=None,
    activation=None,
):
    """Pure-PyTorch reference implementation of the causal depthwise conv1d.

    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    initial_states: (batch, dim, width - 1)
    final_states_out: (batch, dim, width - 1)

    out: (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    dtype_in = x.dtype
    x = x.to(weight.dtype)
    seqlen = x.shape[-1]
    dim, width = weight.shape
    kernel = weight.unsqueeze(1)  # (dim, 1, width) for a grouped (depthwise) conv
    if initial_states is None:
        # Causality via the conv's own left zero-padding.
        out = F.conv1d(x, kernel, bias, padding=width - 1, groups=dim)
    else:
        # Prepend the carried-over state, then run an unpadded conv.
        x = torch.cat([initial_states, x], dim=-1)
        out = F.conv1d(x, kernel, bias, padding=0, groups=dim)
    out = out[..., :seqlen]
    if return_final_states:
        # Last (width - 1) input positions, left-padded with zeros if shorter.
        final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(dtype_in)
        if final_states_out is None:
            final_states_out = final_states
        else:
            final_states_out.copy_(final_states)
    if activation is not None:
        out = F.silu(out)
    out = out.to(dtype=dtype_in)
    if return_final_states:
        return out, final_states_out
    return out
+
+
def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None):
    """Single decoding step of the causal conv1d; updates conv_state in place.

    x: (batch, dim) or (batch, dim, seqlen)
    conv_state: (batch, dim, state_len), where state_len >= width - 1
    weight: (dim, width)
    bias: (dim,)
    activation: None, "silu" or "swish"
    cache_seqlens: (batch,), dtype int32.
        If not None, conv_state is treated as a circular buffer: x is written
        into it starting at index cache_seqlens % state_len.
    conv_state_indices: (batch,), dtype int32.
        If not None, conv_state is a larger tensor along the batch dim and the
        rows selected by conv_state_indices are used/updated. Useful for a
        continuous batching scenario.

    out: (batch, dim) or (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    use_silu = activation in ["silu", "swish"]
    squeeze_out = x.dim() == 2
    if squeeze_out:
        x = x.unsqueeze(-1)
    out = causal_conv1d_update_function(
        x, conv_state, weight, bias, use_silu, cache_seqlens, conv_state_indices
    )
    return out.squeeze(-1) if squeeze_out else out
+
+
def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
    """Pure-PyTorch reference for the decode-step conv1d; updates conv_state in place.

    x: (batch, dim) or (batch, dim, seqlen)
    conv_state: (batch, dim, state_len), where state_len >= width - 1
    weight: (dim, width)
    bias: (dim,)
    cache_seqlens: (batch,), dtype int32.
        If not None, conv_state is treated as a circular buffer: x is written
        into it starting at index cache_seqlens % state_len before the
        convolution is applied.

    out: (batch, dim) or (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    dtype_in = x.dtype
    squeeze_out = x.dim() == 2
    if squeeze_out:
        x = x.unsqueeze(-1)
    batch, dim, seqlen = x.shape
    width = weight.shape[1]
    state_len = conv_state.shape[-1]
    assert conv_state.shape == (batch, dim, state_len)
    assert weight.shape == (dim, width)
    if cache_seqlens is None:
        # Append x, then keep the trailing state_len positions as the new state.
        x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype)  # (batch, dim, state_len + seqlen)
        conv_state.copy_(x_new[:, :, -state_len:])
    else:
        # Circular buffer: gather the (width - 1) positions preceding each
        # cache_seqlens entry, then scatter x back in modulo state_len.
        width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
        width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
        x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
        copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
        copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
        conv_state.scatter_(2, copy_idx, x)
    out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:]
    if squeeze_out:
        out = out.squeeze(-1)
    if activation is not None:
        out = F.silu(out)
    return out.to(dtype=dtype_in)
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/causal_conv1d_varlen.py b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/causal_conv1d_varlen.py
new file mode 100644
index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/causal_conv1d_varlen.py
@@ -0,0 +1,86 @@
+import torch
+from torch import Tensor
+
+import triton
+import triton.language as tl
+
+
@triton.jit
def _causal_conv1d_varlen_states(
    X,
    CU_SEQLENS,
    STATES,
    state_len,
    dim,
    stride_x_seqlen, stride_x_dim,
    stride_states_batch, stride_states_seqlen, stride_states_dim,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr
):
    # One program per (dim tile, state tile, batch element): copy the last
    # `state_len` tokens of each sequence into STATES, right-aligned, writing
    # zeros for positions before the sequence start.
    batch_idx = tl.program_id(2)
    STATES += batch_idx * stride_states_batch
    end_idx = tl.load(CU_SEQLENS + batch_idx + 1)
    # Don't reach back past the start of this sequence.
    start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)
    # Token rows for this tile, counted backwards from the sequence end.
    rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
    # Out-of-range rows load 0, which is then stored into STATES.
    x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,
                mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim),
                other=0)
    # Matching rows in the state buffer, counted backwards from state_len.
    rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
    tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,
             x,
             mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim))
+
+
def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
    """
    Forward pass only, does not support backward pass.
    Parameters:
        x: (total_tokens, dim)
        cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
        state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
            If some of those elements belong to a different sequence, the value of the states will be zero.
    Return:
        states: (batch, dim, state_len)
    """
    _, dim = x.shape
    batch = cu_seqlens.shape[0] - 1
    cu_seqlens = cu_seqlens.contiguous()
    # Allocate (batch, dim, state_len) with channel-last memory layout; the
    # kernel fills every state row, so empty (not zeros) is safe.
    states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
    BLOCK_M = min(triton.next_power_of_2(state_len), 16)
    BLOCK_N = min(triton.next_power_of_2(dim), 256)
    # Grid: (dim tiles, state_len tiles, batch).
    grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch)
    # Launch on the device that owns x.
    with torch.cuda.device(x.device.index):
        _causal_conv1d_varlen_states[grid](
            x,
            cu_seqlens,
            states,
            state_len,
            dim,
            x.stride(0), x.stride(1),
            # seqlen/dim strides are swapped because `states` was transposed above.
            states.stride(0), states.stride(2), states.stride(1),
            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
        )
    return states
+
+
def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
    """
    Forward pass only, does not support backward pass.
    Parameters:
        x: (total_tokens, dim)
        cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
        state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
            If some of those elements belong to a different sequence, the value of the states will be zero.
    Return:
        states: (batch, dim, state_len)
    """
    _, dim = x.shape
    batch = cu_seqlens.shape[0] - 1
    cu_seqlens = cu_seqlens.contiguous()
    # Zero-filled so positions before a short sequence's start stay 0.
    states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
    for b, (start, end) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])):
        # Take at most state_len trailing tokens of sequence b, right-aligned.
        begin = torch.maximum(start, end - state_len)
        states[b, :, -(end - begin):] = x[begin:end].T
    return states
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/cpp_functions.py b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/cpp_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/cpp_functions.py
@@ -0,0 +1,96 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+
+from ._ops import ops
+
def causal_conv1d_fwd_function(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    seq_idx: torch.Tensor | None,
    initial_states: torch.Tensor | None,
    final_states_out: torch.Tensor | None,
    silu_activation: bool,
) -> torch.Tensor:
    """Run the causal_conv1d forward kernel into a freshly allocated output.

    `final_states_out`, when provided, is written in place by the kernel.
    """
    result = torch.empty_like(x)
    kernel_args = dict(
        x=x,
        weight=weight,
        bias=bias,
        seq_idx=seq_idx,
        initial_states=initial_states,
        out=result,
        final_states_out=final_states_out,
        silu_activation=silu_activation,
    )
    ops.causal_conv1d_fwd(**kernel_args)
    return result
+
+
def causal_conv1d_bwd_function(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    dout: torch.Tensor,
    seq_idx: torch.Tensor | None,
    initial_states: torch.Tensor | None,
    dfinal_states: torch.Tensor | None,
    dx: torch.Tensor | None,
    return_dinitial_states: bool,
    silu_activation: bool,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """Run the causal_conv1d backward kernel.

    Arguments mirror the forward pass; `dout` is the gradient w.r.t. the
    forward output. `dx` may be pre-allocated by the caller (e.g. to fuse
    with another backward); when None it is allocated here.

    Returns (dx, dweight, dbias, dinitial_states); dbias is None when bias
    is None, and dinitial_states is None unless return_dinitial_states.
    """
    batch_size, dim = x.size()[:2]
    width = weight.size(-1)

    if dx is None:
        dx = torch.empty_like(x)
    # Accumulate weight/bias grads in fp32, cast back to the param dtype below.
    dweight = torch.zeros_like(weight, dtype=torch.float32)
    dbias = None
    if bias is not None:
        dbias = torch.zeros_like(bias, dtype=torch.float32)
    dinitial_states = None
    if return_dinitial_states:
        # Channel-last (batch, dim, width - 1), matching the kernel's layout.
        dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2)

    ops.causal_conv1d_bwd(
        x=x,
        weight=weight,
        bias=bias,
        dout=dout,
        seq_idx=seq_idx,
        initial_states=initial_states,
        dfinal_states=dfinal_states,
        dx=dx,
        dweight=dweight,
        dbias=dbias,
        dinitial_states=dinitial_states,
        silu_activation=silu_activation,
    )

    dweight = dweight.type_as(weight)
    if dbias is not None:
        dbias = dbias.type_as(bias)
    return dx, dweight, dbias, dinitial_states
+
+
def causal_conv1d_update_function(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    silu_activation: bool,
    cache_seqlens: torch.Tensor | None,
    conv_state_indices: torch.Tensor | None,
) -> torch.Tensor:
    """Invoke the decode-step conv1d kernel into a freshly allocated output.

    The kernel also updates `conv_state` in place (see the interface-level
    causal_conv1d_update docstring for the state semantics).
    """
    result = torch.empty_like(x)
    kernel_args = dict(
        x=x,
        conv_state=conv_state,
        weight=weight,
        bias=bias,
        out=result,
        silu_activation=silu_activation,
        cache_seqlens=cache_seqlens,
        conv_state_indices=conv_state_indices,
    )
    ops.causal_conv1d_update(**kernel_args)
    return result
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/__init__.py b/build/torch28-cxx11-cu126-x86_64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/__init__.py
@@ -0,0 +1,4 @@
+from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
+from .causal_conv1d_varlen import causal_conv1d_varlen_states
+
+__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"]
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_causal_conv1d_e7e5852.abi3.so b/build/torch28-cxx11-cu126-x86_64-linux/_causal_conv1d_e7e5852.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..06feae7dcffaefead4119f47cf223bda523ef0b8
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/_causal_conv1d_e7e5852.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e8a4c3c1eb4c667ed0ef6affd83922fc4b76d96c491b26afb012bbb4e84ac245
+size 80684768
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_ops.py b/build/torch28-cxx11-cu126-x86_64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..68c9592575922d5a8d400f767f6d5f31fa8dbcb3
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _causal_conv1d_e7e5852
+ops = torch.ops._causal_conv1d_e7e5852
+
def add_op_namespace_prefix(op_name: str):
    """Qualify *op_name* with this build's private torch op namespace."""
    return "::".join(("_causal_conv1d_e7e5852", op_name))
\ No newline at end of file
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import sys
+
+import importlib
+from pathlib import Path
+from types import ModuleType
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d_interface.py b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d_interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d_interface.py
@@ -0,0 +1,242 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+import torch.nn.functional as F
+
+from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function
+
+
class CausalConv1dFn(torch.autograd.Function):
    """Autograd wrapper around the causal depthwise conv1d CUDA kernels.

    Forward runs the fused kernel (optionally with SiLU); backward returns
    gradients for x, weight, bias and (optionally) initial_states.
    """

    @staticmethod
    def forward(
        ctx,
        x,
        weight,
        bias=None,
        seq_idx=None,
        initial_states=None,
        return_final_states=False,
        final_states_out=None,
        activation=None,
    ):
        # x: (batch, dim, seqlen); weight: (dim, width); bias: (dim,) or None.
        # "silu" and "swish" name the same activation.
        if activation not in [None, "silu", "swish"]:
            raise NotImplementedError("activation must be None, silu, or swish")
        # Kernel needs either contiguous or channel-last layout (one unit stride).
        if x.stride(2) != 1 and x.stride(1) != 1:
            x = x.contiguous()
        bias = bias.contiguous() if bias is not None else None
        if seq_idx is not None:
            assert (
                initial_states is None
            ), "initial_states must be None if seq_idx is not None"
            assert (
                not return_final_states
            ), "If seq_idx is not None, we don't return final_states_out"
        seq_idx = seq_idx.contiguous() if seq_idx is not None else None
        if initial_states is not None and (
            initial_states.stride(2) != 1 and initial_states.stride(1) != 1
        ):
            initial_states = initial_states.contiguous()
        if return_final_states:
            assert (
                x.stride(1) == 1
            ), "Only channel-last layout support returning final_states_out"
            if final_states_out is not None:
                assert (
                    final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1
                )
            else:
                batch, dim, seqlen = x.shape
                width = weight.shape[1]
                # Allocate a channel-last (batch, dim, width - 1) buffer.
                final_states_out = torch.empty(
                    batch, width - 1, dim, device=x.device, dtype=x.dtype
                ).transpose(1, 2)
        else:
            final_states_out = None
        # Stored as a bool: whether SiLU is fused into the kernel.
        ctx.activation = activation in ["silu", "swish"]
        out = causal_conv1d_fwd_function(
            x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
        )
        ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
        ctx.return_final_states = return_final_states
        ctx.return_dinitial_states = (
            initial_states is not None and initial_states.requires_grad
        )
        return out if not return_final_states else (out, final_states_out)

    @staticmethod
    def backward(ctx, dout, *args):
        x, weight, bias, seq_idx, initial_states = ctx.saved_tensors
        # If forward returned (out, final_states), args[0] is dfinal_states.
        dfinal_states = args[0] if ctx.return_final_states else None
        if dout.stride(2) != 1 and dout.stride(1) != 1:
            dout = dout.contiguous()
        # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
        # backward of conv1d with the backward of chunk).
        # Here we just pass in None and dx will be allocated in the C++ code.
        dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function(
            x,
            weight,
            bias,
            dout,
            seq_idx,
            initial_states,
            dfinal_states,
            None,
            ctx.return_dinitial_states,
            ctx.activation,
        )
        # One gradient (or None) per forward argument, in order: x, weight,
        # bias, seq_idx, initial_states, return_final_states, final_states_out,
        # activation.
        return (
            dx,
            dweight,
            dbias if bias is not None else None,
            None,
            dinitial_states if initial_states is not None else None,
            None,
            None,
            None,
        )
+
+
def causal_conv1d_fn(
    x,
    weight,
    bias=None,
    seq_idx=None,
    initial_states=None,
    return_final_states=False,
    final_states_out=None,
    activation=None,
):
    """Causal depthwise conv1d (autograd-enabled entry point).

    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    seq_idx: (batch, seqlen)
    initial_states: (batch, dim, width - 1)
    final_states_out: (batch, dim, width - 1), to be written to
    activation: either None or "silu" or "swish"

    out: (batch, dim, seqlen)
    """
    forward_args = (
        x,
        weight,
        bias,
        seq_idx,
        initial_states,
        return_final_states,
        final_states_out,
        activation,
    )
    return CausalConv1dFn.apply(*forward_args)
+
+
def causal_conv1d_ref(
    x,
    weight,
    bias=None,
    initial_states=None,
    return_final_states=False,
    final_states_out=None,
    activation=None,
):
    """Pure-PyTorch reference implementation of the causal depthwise conv1d.

    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    initial_states: (batch, dim, width - 1)
    final_states_out: (batch, dim, width - 1)

    out: (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    dtype_in = x.dtype
    x = x.to(weight.dtype)
    seqlen = x.shape[-1]
    dim, width = weight.shape
    kernel = weight.unsqueeze(1)  # (dim, 1, width) for a grouped (depthwise) conv
    if initial_states is None:
        # Causality via the conv's own left zero-padding.
        out = F.conv1d(x, kernel, bias, padding=width - 1, groups=dim)
    else:
        # Prepend the carried-over state, then run an unpadded conv.
        x = torch.cat([initial_states, x], dim=-1)
        out = F.conv1d(x, kernel, bias, padding=0, groups=dim)
    out = out[..., :seqlen]
    if return_final_states:
        # Last (width - 1) input positions, left-padded with zeros if shorter.
        final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(dtype_in)
        if final_states_out is None:
            final_states_out = final_states
        else:
            final_states_out.copy_(final_states)
    if activation is not None:
        out = F.silu(out)
    out = out.to(dtype=dtype_in)
    if return_final_states:
        return out, final_states_out
    return out
+
+
def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None):
    """Single decoding step of the causal conv1d; updates conv_state in place.

    x: (batch, dim) or (batch, dim, seqlen)
    conv_state: (batch, dim, state_len), where state_len >= width - 1
    weight: (dim, width)
    bias: (dim,)
    activation: None, "silu" or "swish"
    cache_seqlens: (batch,), dtype int32.
        If not None, conv_state is treated as a circular buffer: x is written
        into it starting at index cache_seqlens % state_len.
    conv_state_indices: (batch,), dtype int32.
        If not None, conv_state is a larger tensor along the batch dim and the
        rows selected by conv_state_indices are used/updated. Useful for a
        continuous batching scenario.

    out: (batch, dim) or (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    use_silu = activation in ["silu", "swish"]
    squeeze_out = x.dim() == 2
    if squeeze_out:
        x = x.unsqueeze(-1)
    out = causal_conv1d_update_function(
        x, conv_state, weight, bias, use_silu, cache_seqlens, conv_state_indices
    )
    return out.squeeze(-1) if squeeze_out else out
+
+
def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
    """Pure-PyTorch reference for the decode-step conv1d; updates conv_state in place.

    x: (batch, dim) or (batch, dim, seqlen)
    conv_state: (batch, dim, state_len), where state_len >= width - 1
    weight: (dim, width)
    bias: (dim,)
    cache_seqlens: (batch,), dtype int32.
        If not None, conv_state is treated as a circular buffer: x is written
        into it starting at index cache_seqlens % state_len before the
        convolution is applied.

    out: (batch, dim) or (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    dtype_in = x.dtype
    squeeze_out = x.dim() == 2
    if squeeze_out:
        x = x.unsqueeze(-1)
    batch, dim, seqlen = x.shape
    width = weight.shape[1]
    state_len = conv_state.shape[-1]
    assert conv_state.shape == (batch, dim, state_len)
    assert weight.shape == (dim, width)
    if cache_seqlens is None:
        # Append x, then keep the trailing state_len positions as the new state.
        x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype)  # (batch, dim, state_len + seqlen)
        conv_state.copy_(x_new[:, :, -state_len:])
    else:
        # Circular buffer: gather the (width - 1) positions preceding each
        # cache_seqlens entry, then scatter x back in modulo state_len.
        width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
        width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
        x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
        copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
        copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
        conv_state.scatter_(2, copy_idx, x)
    out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:]
    if squeeze_out:
        out = out.squeeze(-1)
    if activation is not None:
        out = F.silu(out)
    return out.to(dtype=dtype_in)
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d_varlen.py b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d_varlen.py
new file mode 100644
index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d_varlen.py
@@ -0,0 +1,86 @@
+import torch
+from torch import Tensor
+
+import triton
+import triton.language as tl
+
+
@triton.jit
def _causal_conv1d_varlen_states(
    X,
    CU_SEQLENS,
    STATES,
    state_len,
    dim,
    stride_x_seqlen, stride_x_dim,
    stride_states_batch, stride_states_seqlen, stride_states_dim,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr
):
    # One program per (dim tile, state tile, batch element): copy the last
    # `state_len` tokens of each sequence into STATES, right-aligned, writing
    # zeros for positions before the sequence start.
    batch_idx = tl.program_id(2)
    STATES += batch_idx * stride_states_batch
    end_idx = tl.load(CU_SEQLENS + batch_idx + 1)
    # Don't reach back past the start of this sequence.
    start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)
    # Token rows for this tile, counted backwards from the sequence end.
    rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
    # Out-of-range rows load 0, which is then stored into STATES.
    x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,
                mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim),
                other=0)
    # Matching rows in the state buffer, counted backwards from state_len.
    rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
    tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,
             x,
             mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim))
+
+
def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
    """
    Forward pass only, does not support backward pass.
    Parameters:
        x: (total_tokens, dim)
        cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
        state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
            If some of those elements belong to a different sequence, the value of the states will be zero.
    Return:
        states: (batch, dim, state_len)
    """
    _, dim = x.shape
    batch = cu_seqlens.shape[0] - 1
    cu_seqlens = cu_seqlens.contiguous()
    # Allocate (batch, dim, state_len) with channel-last memory layout; the
    # kernel fills every state row, so empty (not zeros) is safe.
    states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
    BLOCK_M = min(triton.next_power_of_2(state_len), 16)
    BLOCK_N = min(triton.next_power_of_2(dim), 256)
    # Grid: (dim tiles, state_len tiles, batch).
    grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch)
    # Launch on the device that owns x.
    with torch.cuda.device(x.device.index):
        _causal_conv1d_varlen_states[grid](
            x,
            cu_seqlens,
            states,
            state_len,
            dim,
            x.stride(0), x.stride(1),
            # seqlen/dim strides are swapped because `states` was transposed above.
            states.stride(0), states.stride(2), states.stride(1),
            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
        )
    return states
+
+
def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
    """
    Forward pass only, does not support backward pass.
    Parameters:
        x: (total_tokens, dim)
        cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
        state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
            If some of those elements belong to a different sequence, the value of the states will be zero.
    Return:
        states: (batch, dim, state_len)
    """
    _, dim = x.shape
    batch = cu_seqlens.shape[0] - 1
    cu_seqlens = cu_seqlens.contiguous()
    # Zero-filled so positions before a short sequence's start stay 0.
    states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
    for b, (start, end) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])):
        # Take at most state_len trailing tokens of sequence b, right-aligned.
        begin = torch.maximum(start, end - state_len)
        states[b, :, -(end - begin):] = x[begin:end].T
    return states
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/cpp_functions.py b/build/torch28-cxx11-cu126-x86_64-linux/cpp_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/cpp_functions.py
@@ -0,0 +1,96 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+
+from ._ops import ops
+
+def causal_conv1d_fwd_function(
+ x: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None,
+ seq_idx: torch.Tensor | None,
+ initial_states: torch.Tensor | None,
+ final_states_out: torch.Tensor | None,
+ silu_activation: bool,
+) -> torch.Tensor:
+ out = torch.empty_like(x)
+ ops.causal_conv1d_fwd(
+ x=x,
+ weight=weight,
+ bias=bias,
+ seq_idx=seq_idx,
+ initial_states=initial_states,
+ out=out,
+ final_states_out=final_states_out,
+ silu_activation=silu_activation,
+ )
+ return out
+
+
+def causal_conv1d_bwd_function(
+ x: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None,
+ dout: torch.Tensor,
+ seq_idx: torch.Tensor | None,
+ initial_states: torch.Tensor | None,
+ dfinal_states: torch.Tensor | None,
+ dx: torch.Tensor | None,
+ return_dinitial_states: torch.Tensor,
+ silu_activation: bool,
+) -> tuple[torch.Tensor | None]:
+ batch_size, dim = x.size()[:2]
+ width = weight.size(-1)
+
+ if dx is None:
+ dx = torch.empty_like(x)
+ dweight = torch.zeros_like(weight, dtype=torch.float32)
+ dbias = None
+ if bias is not None:
+ dbias = torch.zeros_like(bias, dtype=torch.float32)
+ dinitial_states = None
+ if return_dinitial_states:
+ dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2)
+
+ ops.causal_conv1d_bwd(
+ x=x,
+ weight=weight,
+ bias=bias,
+ dout=dout,
+ seq_idx=seq_idx,
+ initial_states=initial_states,
+ dfinal_states=dfinal_states,
+ dx=dx,
+ dweight=dweight,
+ dbias=dbias,
+ dinitial_states=dinitial_states,
+ silu_activation=silu_activation,
+ )
+
+ dweight = dweight.type_as(weight)
+ if dbias is not None:
+ dbias = dbias.type_as(bias)
+ return dx, dweight, dbias, dinitial_states
+
+
+def causal_conv1d_update_function(
+ x: torch.Tensor,
+ conv_state: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None,
+ silu_activation: bool,
+ cache_seqlens: torch.Tensor | None,
+ conv_state_indices: torch.Tensor | None,
+) -> torch.Tensor:
+ out = torch.empty_like(x)
+ ops.causal_conv1d_update(
+ x=x,
+ conv_state=conv_state,
+ weight=weight,
+ bias=bias,
+ out=out,
+ silu_activation=silu_activation,
+ cache_seqlens=cache_seqlens,
+ conv_state_indices=conv_state_indices,
+ )
+ return out
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/metadata.json b/build/torch28-cxx11-cu126-x86_64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..9cf5deed9898dce769f4cc73913d3530b92a0bd8
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/metadata.json
@@ -0,0 +1,4 @@
+{
+ "version": 1,
+ "python-depends": []
+}
\ No newline at end of file
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/__init__.py b/build/torch28-cxx11-cu128-x86_64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/__init__.py
@@ -0,0 +1,4 @@
+from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
+from .causal_conv1d_varlen import causal_conv1d_varlen_states
+
+__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"]
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_causal_conv1d_e7e5852.abi3.so b/build/torch28-cxx11-cu128-x86_64-linux/_causal_conv1d_e7e5852.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..75b81a3487cf14abd1be1a8f1db4cfe11db65fc8
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/_causal_conv1d_e7e5852.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:488a09e6d74f4f4f8b6c0dfd26df892dab4e8bc2283a2e95be24c65ed043ec70
+size 107168432
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_ops.py b/build/torch28-cxx11-cu128-x86_64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..68c9592575922d5a8d400f767f6d5f31fa8dbcb3
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _causal_conv1d_e7e5852
+ops = torch.ops._causal_conv1d_e7e5852
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_causal_conv1d_e7e5852::{op_name}"
\ No newline at end of file
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import sys
+
+import importlib
+from pathlib import Path
+from types import ModuleType
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d_interface.py b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d_interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d_interface.py
@@ -0,0 +1,242 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+import torch.nn.functional as F
+
+from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function
+
+
+class CausalConv1dFn(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ x,
+ weight,
+ bias=None,
+ seq_idx=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+ ):
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ if x.stride(2) != 1 and x.stride(1) != 1:
+ x = x.contiguous()
+ bias = bias.contiguous() if bias is not None else None
+ if seq_idx is not None:
+ assert (
+ initial_states is None
+ ), "initial_states must be None if seq_idx is not None"
+ assert (
+ not return_final_states
+ ), "If seq_idx is not None, we don't return final_states_out"
+ seq_idx = seq_idx.contiguous() if seq_idx is not None else None
+ if initial_states is not None and (
+ initial_states.stride(2) != 1 and initial_states.stride(1) != 1
+ ):
+ initial_states = initial_states.contiguous()
+ if return_final_states:
+ assert (
+ x.stride(1) == 1
+ ), "Only channel-last layout support returning final_states_out"
+ if final_states_out is not None:
+ assert (
+ final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1
+ )
+ else:
+ batch, dim, seqlen = x.shape
+ width = weight.shape[1]
+ final_states_out = torch.empty(
+ batch, width - 1, dim, device=x.device, dtype=x.dtype
+ ).transpose(1, 2)
+ else:
+ final_states_out = None
+ ctx.activation = activation in ["silu", "swish"]
+ out = causal_conv1d_fwd_function(
+ x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
+ )
+ ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
+ ctx.return_final_states = return_final_states
+ ctx.return_dinitial_states = (
+ initial_states is not None and initial_states.requires_grad
+ )
+ return out if not return_final_states else (out, final_states_out)
+
+ @staticmethod
+ def backward(ctx, dout, *args):
+ x, weight, bias, seq_idx, initial_states = ctx.saved_tensors
+ dfinal_states = args[0] if ctx.return_final_states else None
+ if dout.stride(2) != 1 and dout.stride(1) != 1:
+ dout = dout.contiguous()
+ # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
+ # backward of conv1d with the backward of chunk).
+ # Here we just pass in None and dx will be allocated in the C++ code.
+ dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function(
+ x,
+ weight,
+ bias,
+ dout,
+ seq_idx,
+ initial_states,
+ dfinal_states,
+ None,
+ ctx.return_dinitial_states,
+ ctx.activation,
+ )
+ return (
+ dx,
+ dweight,
+ dbias if bias is not None else None,
+ None,
+ dinitial_states if initial_states is not None else None,
+ None,
+ None,
+ None,
+ )
+
+
+def causal_conv1d_fn(
+ x,
+ weight,
+ bias=None,
+ seq_idx=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+):
+ """
+ x: (batch, dim, seqlen)
+ weight: (dim, width)
+ bias: (dim,)
+ seq_idx: (batch, seqlen)
+ initial_states: (batch, dim, width - 1)
+ final_states_out: (batch, dim, width - 1), to be written to
+ activation: either None or "silu" or "swish"
+
+ out: (batch, dim, seqlen)
+ """
+ return CausalConv1dFn.apply(
+ x,
+ weight,
+ bias,
+ seq_idx,
+ initial_states,
+ return_final_states,
+ final_states_out,
+ activation,
+ )
+
+
+def causal_conv1d_ref(
+ x,
+ weight,
+ bias=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+):
+ """
+ x: (batch, dim, seqlen)
+ weight: (dim, width)
+ bias: (dim,)
+ initial_states: (batch, dim, width - 1)
+ final_states_out: (batch, dim, width - 1)
+
+ out: (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ dtype_in = x.dtype
+ x = x.to(weight.dtype)
+ seqlen = x.shape[-1]
+ dim, width = weight.shape
+ if initial_states is None:
+ out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
+ else:
+ x = torch.cat([initial_states, x], dim=-1)
+ out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
+ out = out[..., :seqlen]
+ if return_final_states:
+ final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
+ dtype_in
+ ) # (batch, dim, width - 1)
+ if final_states_out is not None:
+ final_states_out.copy_(final_states)
+ else:
+ final_states_out = final_states
+ out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
+ return out if not return_final_states else (out, final_states_out)
+
+
+def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None):
+ """
+ x: (batch, dim) or (batch, dim, seqlen)
+ conv_state: (batch, dim, state_len), where state_len >= width - 1
+ weight: (dim, width)
+ bias: (dim,)
+ cache_seqlens: (batch,), dtype int32.
+ If not None, the conv_state is treated as a circular buffer.
+ The conv_state will be updated by copying x to the conv_state starting at the index
+ @cache_seqlens % state_len.
+ conv_state_indices: (batch,), dtype int32
+ If None, the conv_state is a larger tensor along the batch dim,
+ and we are selecting the batch coords specified by conv_state_indices.
+ Useful for a continuous batching scenario.
+
+ out: (batch, dim) or (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ activation = activation in ["silu", "swish"]
+ unsqueeze = x.dim() == 2
+ if unsqueeze:
+ x = x.unsqueeze(-1)
+ out = causal_conv1d_update_function(
+ x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices
+ )
+ if unsqueeze:
+ out = out.squeeze(-1)
+ return out
+
+
+def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
+ """
+ x: (batch, dim) or (batch, dim, seqlen)
+ conv_state: (batch, dim, state_len), where state_len >= width - 1
+ weight: (dim, width)
+ bias: (dim,)
+ cache_seqlens: (batch,), dtype int32.
+ If not None, the conv_state is treated as a circular buffer.
+ The conv_state will be updated by copying x to the conv_state starting at the index
+ @cache_seqlens % state_len before performing the convolution.
+
+ out: (batch, dim) or (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ dtype_in = x.dtype
+ unsqueeze = x.dim() == 2
+ if unsqueeze:
+ x = x.unsqueeze(-1)
+ batch, dim, seqlen = x.shape
+ width = weight.shape[1]
+ state_len = conv_state.shape[-1]
+ assert conv_state.shape == (batch, dim, state_len)
+ assert weight.shape == (dim, width)
+ if cache_seqlens is None:
+ x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen)
+ conv_state.copy_(x_new[:, :, -state_len:])
+ else:
+ width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
+ width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
+ x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
+ copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
+ copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
+ conv_state.scatter_(2, copy_idx, x)
+ out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:]
+ if unsqueeze:
+ out = out.squeeze(-1)
+ return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d_varlen.py b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d_varlen.py
new file mode 100644
index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d_varlen.py
@@ -0,0 +1,86 @@
+import torch
+from torch import Tensor
+
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _causal_conv1d_varlen_states(
+ X,
+ CU_SEQLENS,
+ STATES,
+ state_len,
+ dim,
+ stride_x_seqlen, stride_x_dim,
+ stride_states_batch, stride_states_seqlen, stride_states_dim,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr
+):
+ batch_idx = tl.program_id(2)
+ STATES += batch_idx * stride_states_batch
+ end_idx = tl.load(CU_SEQLENS + batch_idx + 1)
+ start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)
+ rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
+ cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
+ x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,
+ mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim),
+ other=0)
+ rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
+ tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,
+ x,
+ mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim))
+
+
+def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
+ """
+ Forward pass only, does not support backward pass.
+ Parameters:
+ x: (total_tokens, dim)
+ cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
+ state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
+ If some of those elements belong to a different sequence, the value of the states will be zero.
+ Return:
+ states: (batch, dim, state_len)
+ """
+ _, dim = x.shape
+ batch = cu_seqlens.shape[0] - 1
+ cu_seqlens = cu_seqlens.contiguous()
+ states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
+ BLOCK_M = min(triton.next_power_of_2(state_len), 16)
+ BLOCK_N = min(triton.next_power_of_2(dim), 256)
+ grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch)
+ with torch.cuda.device(x.device.index):
+ _causal_conv1d_varlen_states[grid](
+ x,
+ cu_seqlens,
+ states,
+ state_len,
+ dim,
+ x.stride(0), x.stride(1),
+ states.stride(0), states.stride(2), states.stride(1),
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
+ )
+ return states
+
+
+def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
+ """
+ Forward pass only, does not support backward pass.
+ Parameters:
+ x: (total_tokens, dim)
+ cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
+ state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
+ If some of those elements belong to a different sequence, the value of the states will be zero.
+ Return:
+ states: (batch, dim, state_len)
+ """
+ _, dim = x.shape
+ batch = cu_seqlens.shape[0] - 1
+ cu_seqlens = cu_seqlens.contiguous()
+ states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
+ for i in range(batch):
+ end_idx = cu_seqlens[i + 1]
+ start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len)
+ states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T
+ return states
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/cpp_functions.py b/build/torch28-cxx11-cu128-x86_64-linux/cpp_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/cpp_functions.py
@@ -0,0 +1,96 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+
+from ._ops import ops
+
+def causal_conv1d_fwd_function(
+ x: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None,
+ seq_idx: torch.Tensor | None,
+ initial_states: torch.Tensor | None,
+ final_states_out: torch.Tensor | None,
+ silu_activation: bool,
+) -> torch.Tensor:
+ out = torch.empty_like(x)
+ ops.causal_conv1d_fwd(
+ x=x,
+ weight=weight,
+ bias=bias,
+ seq_idx=seq_idx,
+ initial_states=initial_states,
+ out=out,
+ final_states_out=final_states_out,
+ silu_activation=silu_activation,
+ )
+ return out
+
+
+def causal_conv1d_bwd_function(
+ x: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None,
+ dout: torch.Tensor,
+ seq_idx: torch.Tensor | None,
+ initial_states: torch.Tensor | None,
+ dfinal_states: torch.Tensor | None,
+ dx: torch.Tensor | None,
+ return_dinitial_states: torch.Tensor,
+ silu_activation: bool,
+) -> tuple[torch.Tensor | None]:
+ batch_size, dim = x.size()[:2]
+ width = weight.size(-1)
+
+ if dx is None:
+ dx = torch.empty_like(x)
+ dweight = torch.zeros_like(weight, dtype=torch.float32)
+ dbias = None
+ if bias is not None:
+ dbias = torch.zeros_like(bias, dtype=torch.float32)
+ dinitial_states = None
+ if return_dinitial_states:
+ dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2)
+
+ ops.causal_conv1d_bwd(
+ x=x,
+ weight=weight,
+ bias=bias,
+ dout=dout,
+ seq_idx=seq_idx,
+ initial_states=initial_states,
+ dfinal_states=dfinal_states,
+ dx=dx,
+ dweight=dweight,
+ dbias=dbias,
+ dinitial_states=dinitial_states,
+ silu_activation=silu_activation,
+ )
+
+ dweight = dweight.type_as(weight)
+ if dbias is not None:
+ dbias = dbias.type_as(bias)
+ return dx, dweight, dbias, dinitial_states
+
+
+def causal_conv1d_update_function(
+ x: torch.Tensor,
+ conv_state: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None,
+ silu_activation: bool,
+ cache_seqlens: torch.Tensor | None,
+ conv_state_indices: torch.Tensor | None,
+) -> torch.Tensor:
+ out = torch.empty_like(x)
+ ops.causal_conv1d_update(
+ x=x,
+ conv_state=conv_state,
+ weight=weight,
+ bias=bias,
+ out=out,
+ silu_activation=silu_activation,
+ cache_seqlens=cache_seqlens,
+ conv_state_indices=conv_state_indices,
+ )
+ return out
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/metadata.json b/build/torch28-cxx11-cu128-x86_64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..9cf5deed9898dce769f4cc73913d3530b92a0bd8
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/metadata.json
@@ -0,0 +1,4 @@
+{
+ "version": 1,
+ "python-depends": []
+}
\ No newline at end of file
diff --git a/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__init__.py b/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c
--- /dev/null
+++ b/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__init__.py
@@ -0,0 +1,4 @@
+from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
+from .causal_conv1d_varlen import causal_conv1d_varlen_states
+
+__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"]
diff --git a/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc b/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bd27c508caa49bc189bc90b3dadc231801ff39a7
Binary files /dev/null and b/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc differ
diff --git a/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc b/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7052f3156badccc6818d6e61442177ee29be2e61
Binary files /dev/null and b/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc differ
diff --git a/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc b/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e4d41ce73daaeb59359e6dc777f3b5a596aa8ddd
Binary files /dev/null and b/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc differ
diff --git a/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc b/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7df49a782e7c05722c1b4f05c72b7a40ae91e057
Binary files /dev/null and b/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc differ
diff --git a/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc b/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e8599cda250865c0f80bce799b2593833c99384c
Binary files /dev/null and b/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc differ
diff --git a/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/_causal_conv1d_306ae84.abi3.so b/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/_causal_conv1d_306ae84.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..e26a04f6ebfd8fa02bf23fc08727ce37bea5a617
--- /dev/null
+++ b/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/_causal_conv1d_306ae84.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4c3535b795cbc5baf363b0cd8636649153b017599c9992ea6c08aa4ab23ceae0
+size 97678768
diff --git a/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/_ops.py b/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..2257797ae235d25abede0851de00a59f5220a87d
--- /dev/null
+++ b/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _causal_conv1d_306ae84
+ops = torch.ops._causal_conv1d_306ae84
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_causal_conv1d_306ae84::{op_name}"
\ No newline at end of file
diff --git a/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/causal_conv1d_interface.py b/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/causal_conv1d_interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377
--- /dev/null
+++ b/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/causal_conv1d_interface.py
@@ -0,0 +1,242 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+import torch.nn.functional as F
+
+from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function
+
+
+class CausalConv1dFn(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ x,
+ weight,
+ bias=None,
+ seq_idx=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+ ):
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ if x.stride(2) != 1 and x.stride(1) != 1:
+ x = x.contiguous()
+ bias = bias.contiguous() if bias is not None else None
+ if seq_idx is not None:
+ assert (
+ initial_states is None
+ ), "initial_states must be None if seq_idx is not None"
+ assert (
+ not return_final_states
+ ), "If seq_idx is not None, we don't return final_states_out"
+ seq_idx = seq_idx.contiguous() if seq_idx is not None else None
+ if initial_states is not None and (
+ initial_states.stride(2) != 1 and initial_states.stride(1) != 1
+ ):
+ initial_states = initial_states.contiguous()
+ if return_final_states:
+ assert (
+ x.stride(1) == 1
+ ), "Only channel-last layout support returning final_states_out"
+ if final_states_out is not None:
+ assert (
+ final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1
+ )
+ else:
+ batch, dim, seqlen = x.shape
+ width = weight.shape[1]
+ final_states_out = torch.empty(
+ batch, width - 1, dim, device=x.device, dtype=x.dtype
+ ).transpose(1, 2)
+ else:
+ final_states_out = None
+ ctx.activation = activation in ["silu", "swish"]
+ out = causal_conv1d_fwd_function(
+ x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
+ )
+ ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
+ ctx.return_final_states = return_final_states
+ ctx.return_dinitial_states = (
+ initial_states is not None and initial_states.requires_grad
+ )
+ return out if not return_final_states else (out, final_states_out)
+
+ @staticmethod
+ def backward(ctx, dout, *args):
+ x, weight, bias, seq_idx, initial_states = ctx.saved_tensors
+ dfinal_states = args[0] if ctx.return_final_states else None
+ if dout.stride(2) != 1 and dout.stride(1) != 1:
+ dout = dout.contiguous()
+ # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
+ # backward of conv1d with the backward of chunk).
+ # Here we just pass in None and dx will be allocated in the C++ code.
+ dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function(
+ x,
+ weight,
+ bias,
+ dout,
+ seq_idx,
+ initial_states,
+ dfinal_states,
+ None,
+ ctx.return_dinitial_states,
+ ctx.activation,
+ )
+ return (
+ dx,
+ dweight,
+ dbias if bias is not None else None,
+ None,
+ dinitial_states if initial_states is not None else None,
+ None,
+ None,
+ None,
+ )
+
+
+def causal_conv1d_fn(
+ x,
+ weight,
+ bias=None,
+ seq_idx=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+):
+ """
+ x: (batch, dim, seqlen)
+ weight: (dim, width)
+ bias: (dim,)
+ seq_idx: (batch, seqlen)
+ initial_states: (batch, dim, width - 1)
+ final_states_out: (batch, dim, width - 1), to be written to
+ activation: either None or "silu" or "swish"
+
+ out: (batch, dim, seqlen)
+ """
+ return CausalConv1dFn.apply(
+ x,
+ weight,
+ bias,
+ seq_idx,
+ initial_states,
+ return_final_states,
+ final_states_out,
+ activation,
+ )
+
+
+def causal_conv1d_ref(
+ x,
+ weight,
+ bias=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+):
+ """
+ x: (batch, dim, seqlen)
+ weight: (dim, width)
+ bias: (dim,)
+ initial_states: (batch, dim, width - 1)
+ final_states_out: (batch, dim, width - 1)
+
+ out: (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ dtype_in = x.dtype
+ x = x.to(weight.dtype)
+ seqlen = x.shape[-1]
+ dim, width = weight.shape
+ if initial_states is None:
+ out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
+ else:
+ x = torch.cat([initial_states, x], dim=-1)
+ out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
+ out = out[..., :seqlen]
+ if return_final_states:
+ final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
+ dtype_in
+ ) # (batch, dim, width - 1)
+ if final_states_out is not None:
+ final_states_out.copy_(final_states)
+ else:
+ final_states_out = final_states
+ out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
+ return out if not return_final_states else (out, final_states_out)
+
+
+def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None):
+ """
+ x: (batch, dim) or (batch, dim, seqlen)
+ conv_state: (batch, dim, state_len), where state_len >= width - 1
+ weight: (dim, width)
+ bias: (dim,)
+ cache_seqlens: (batch,), dtype int32.
+ If not None, the conv_state is treated as a circular buffer.
+ The conv_state will be updated by copying x to the conv_state starting at the index
+ @cache_seqlens % state_len.
+ conv_state_indices: (batch,), dtype int32
+ If None, the conv_state is a larger tensor along the batch dim,
+ and we are selecting the batch coords specified by conv_state_indices.
+ Useful for a continuous batching scenario.
+
+ out: (batch, dim) or (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ activation = activation in ["silu", "swish"]
+ unsqueeze = x.dim() == 2
+ if unsqueeze:
+ x = x.unsqueeze(-1)
+ out = causal_conv1d_update_function(
+ x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices
+ )
+ if unsqueeze:
+ out = out.squeeze(-1)
+ return out
+
+
+def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
+ """
+ x: (batch, dim) or (batch, dim, seqlen)
+ conv_state: (batch, dim, state_len), where state_len >= width - 1
+ weight: (dim, width)
+ bias: (dim,)
+ cache_seqlens: (batch,), dtype int32.
+ If not None, the conv_state is treated as a circular buffer.
+ The conv_state will be updated by copying x to the conv_state starting at the index
+ @cache_seqlens % state_len before performing the convolution.
+
+ out: (batch, dim) or (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ dtype_in = x.dtype
+ unsqueeze = x.dim() == 2
+ if unsqueeze:
+ x = x.unsqueeze(-1)
+ batch, dim, seqlen = x.shape
+ width = weight.shape[1]
+ state_len = conv_state.shape[-1]
+ assert conv_state.shape == (batch, dim, state_len)
+ assert weight.shape == (dim, width)
+ if cache_seqlens is None:
+ x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen)
+ conv_state.copy_(x_new[:, :, -state_len:])
+ else:
+ width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
+ width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
+ x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
+ copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
+ copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
+ conv_state.scatter_(2, copy_idx, x)
+ out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:]
+ if unsqueeze:
+ out = out.squeeze(-1)
+ return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
diff --git a/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/causal_conv1d_varlen.py b/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/causal_conv1d_varlen.py
new file mode 100644
index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb
--- /dev/null
+++ b/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/causal_conv1d_varlen.py
@@ -0,0 +1,86 @@
+import torch
+from torch import Tensor
+
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _causal_conv1d_varlen_states(
+ X,
+ CU_SEQLENS,
+ STATES,
+ state_len,
+ dim,
+ stride_x_seqlen, stride_x_dim,
+ stride_states_batch, stride_states_seqlen, stride_states_dim,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr
+):
+ batch_idx = tl.program_id(2)
+ STATES += batch_idx * stride_states_batch
+ end_idx = tl.load(CU_SEQLENS + batch_idx + 1)
+ start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)
+ rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
+ cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
+ x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,
+ mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim),
+ other=0)
+ rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
+ tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,
+ x,
+ mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim))
+
+
+def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
+ """
+ Forward pass only, does not support backward pass.
+ Parameters:
+ x: (total_tokens, dim)
+ cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
+ state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
+ If some of those elements belong to a different sequence, the value of the states will be zero.
+ Return:
+ states: (batch, dim, state_len)
+ """
+ _, dim = x.shape
+ batch = cu_seqlens.shape[0] - 1
+ cu_seqlens = cu_seqlens.contiguous()
+ states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
+ BLOCK_M = min(triton.next_power_of_2(state_len), 16)
+ BLOCK_N = min(triton.next_power_of_2(dim), 256)
+ grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch)
+ with torch.cuda.device(x.device.index):
+ _causal_conv1d_varlen_states[grid](
+ x,
+ cu_seqlens,
+ states,
+ state_len,
+ dim,
+ x.stride(0), x.stride(1),
+ states.stride(0), states.stride(2), states.stride(1),
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
+ )
+ return states
+
+
+def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
+ """
+ Forward pass only, does not support backward pass.
+ Parameters:
+ x: (total_tokens, dim)
+ cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
+ state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
+ If some of those elements belong to a different sequence, the value of the states will be zero.
+ Return:
+ states: (batch, dim, state_len)
+ """
+ _, dim = x.shape
+ batch = cu_seqlens.shape[0] - 1
+ cu_seqlens = cu_seqlens.contiguous()
+ states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
+ for i in range(batch):
+ end_idx = cu_seqlens[i + 1]
+ start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len)
+ states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T
+ return states
diff --git a/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/cpp_functions.py b/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/cpp_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b
--- /dev/null
+++ b/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/cpp_functions.py
@@ -0,0 +1,96 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+
+from ._ops import ops
+
+def causal_conv1d_fwd_function(
+ x: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None,
+ seq_idx: torch.Tensor | None,
+ initial_states: torch.Tensor | None,
+ final_states_out: torch.Tensor | None,
+ silu_activation: bool,
+) -> torch.Tensor:
+ out = torch.empty_like(x)
+ ops.causal_conv1d_fwd(
+ x=x,
+ weight=weight,
+ bias=bias,
+ seq_idx=seq_idx,
+ initial_states=initial_states,
+ out=out,
+ final_states_out=final_states_out,
+ silu_activation=silu_activation,
+ )
+ return out
+
+
+def causal_conv1d_bwd_function(
+ x: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None,
+ dout: torch.Tensor,
+ seq_idx: torch.Tensor | None,
+ initial_states: torch.Tensor | None,
+ dfinal_states: torch.Tensor | None,
+ dx: torch.Tensor | None,
+ return_dinitial_states: torch.Tensor,
+ silu_activation: bool,
+) -> tuple[torch.Tensor | None]:
+ batch_size, dim = x.size()[:2]
+ width = weight.size(-1)
+
+ if dx is None:
+ dx = torch.empty_like(x)
+ dweight = torch.zeros_like(weight, dtype=torch.float32)
+ dbias = None
+ if bias is not None:
+ dbias = torch.zeros_like(bias, dtype=torch.float32)
+ dinitial_states = None
+ if return_dinitial_states:
+ dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2)
+
+ ops.causal_conv1d_bwd(
+ x=x,
+ weight=weight,
+ bias=bias,
+ dout=dout,
+ seq_idx=seq_idx,
+ initial_states=initial_states,
+ dfinal_states=dfinal_states,
+ dx=dx,
+ dweight=dweight,
+ dbias=dbias,
+ dinitial_states=dinitial_states,
+ silu_activation=silu_activation,
+ )
+
+ dweight = dweight.type_as(weight)
+ if dbias is not None:
+ dbias = dbias.type_as(bias)
+ return dx, dweight, dbias, dinitial_states
+
+
+def causal_conv1d_update_function(
+ x: torch.Tensor,
+ conv_state: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None,
+ silu_activation: bool,
+ cache_seqlens: torch.Tensor | None,
+ conv_state_indices: torch.Tensor | None,
+) -> torch.Tensor:
+ out = torch.empty_like(x)
+ ops.causal_conv1d_update(
+ x=x,
+ conv_state=conv_state,
+ weight=weight,
+ bias=bias,
+ out=out,
+ silu_activation=silu_activation,
+ cache_seqlens=cache_seqlens,
+ conv_state_indices=conv_state_indices,
+ )
+ return out
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/__init__.py b/build/torch28-cxx11-cu129-x86_64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/__init__.py
@@ -0,0 +1,4 @@
+from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
+from .causal_conv1d_varlen import causal_conv1d_varlen_states
+
+__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"]
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_causal_conv1d_e7e5852.abi3.so b/build/torch28-cxx11-cu129-x86_64-linux/_causal_conv1d_e7e5852.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..d50e82bfa5604d3b8a591901aed86e54a7e07afc
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/_causal_conv1d_e7e5852.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dbee5d58b825b18e0751347cc6ed27982623257ee037e3a9c3da47bee3dd8f53
+size 115140584
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_ops.py b/build/torch28-cxx11-cu129-x86_64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..68c9592575922d5a8d400f767f6d5f31fa8dbcb3
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _causal_conv1d_e7e5852
+ops = torch.ops._causal_conv1d_e7e5852
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_causal_conv1d_e7e5852::{op_name}"
\ No newline at end of file
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__init__.py b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import sys
+
+import importlib
+from pathlib import Path
+from types import ModuleType
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d_interface.py b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d_interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d_interface.py
@@ -0,0 +1,242 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+import torch.nn.functional as F
+
+from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function
+
+
+class CausalConv1dFn(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ x,
+ weight,
+ bias=None,
+ seq_idx=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+ ):
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ if x.stride(2) != 1 and x.stride(1) != 1:
+ x = x.contiguous()
+ bias = bias.contiguous() if bias is not None else None
+ if seq_idx is not None:
+ assert (
+ initial_states is None
+ ), "initial_states must be None if seq_idx is not None"
+ assert (
+ not return_final_states
+ ), "If seq_idx is not None, we don't return final_states_out"
+ seq_idx = seq_idx.contiguous() if seq_idx is not None else None
+ if initial_states is not None and (
+ initial_states.stride(2) != 1 and initial_states.stride(1) != 1
+ ):
+ initial_states = initial_states.contiguous()
+ if return_final_states:
+ assert (
+ x.stride(1) == 1
+ ), "Only channel-last layout support returning final_states_out"
+ if final_states_out is not None:
+ assert (
+ final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1
+ )
+ else:
+ batch, dim, seqlen = x.shape
+ width = weight.shape[1]
+ final_states_out = torch.empty(
+ batch, width - 1, dim, device=x.device, dtype=x.dtype
+ ).transpose(1, 2)
+ else:
+ final_states_out = None
+ ctx.activation = activation in ["silu", "swish"]
+ out = causal_conv1d_fwd_function(
+ x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
+ )
+ ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
+ ctx.return_final_states = return_final_states
+ ctx.return_dinitial_states = (
+ initial_states is not None and initial_states.requires_grad
+ )
+ return out if not return_final_states else (out, final_states_out)
+
+ @staticmethod
+ def backward(ctx, dout, *args):
+ x, weight, bias, seq_idx, initial_states = ctx.saved_tensors
+ dfinal_states = args[0] if ctx.return_final_states else None
+ if dout.stride(2) != 1 and dout.stride(1) != 1:
+ dout = dout.contiguous()
+ # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
+ # backward of conv1d with the backward of chunk).
+ # Here we just pass in None and dx will be allocated in the C++ code.
+ dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function(
+ x,
+ weight,
+ bias,
+ dout,
+ seq_idx,
+ initial_states,
+ dfinal_states,
+ None,
+ ctx.return_dinitial_states,
+ ctx.activation,
+ )
+ return (
+ dx,
+ dweight,
+ dbias if bias is not None else None,
+ None,
+ dinitial_states if initial_states is not None else None,
+ None,
+ None,
+ None,
+ )
+
+
+def causal_conv1d_fn(
+ x,
+ weight,
+ bias=None,
+ seq_idx=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+):
+ """
+ x: (batch, dim, seqlen)
+ weight: (dim, width)
+ bias: (dim,)
+ seq_idx: (batch, seqlen)
+ initial_states: (batch, dim, width - 1)
+ final_states_out: (batch, dim, width - 1), to be written to
+ activation: either None or "silu" or "swish"
+
+ out: (batch, dim, seqlen)
+ """
+ return CausalConv1dFn.apply(
+ x,
+ weight,
+ bias,
+ seq_idx,
+ initial_states,
+ return_final_states,
+ final_states_out,
+ activation,
+ )
+
+
+def causal_conv1d_ref(
+ x,
+ weight,
+ bias=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+):
+ """
+ x: (batch, dim, seqlen)
+ weight: (dim, width)
+ bias: (dim,)
+ initial_states: (batch, dim, width - 1)
+ final_states_out: (batch, dim, width - 1)
+
+ out: (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ dtype_in = x.dtype
+ x = x.to(weight.dtype)
+ seqlen = x.shape[-1]
+ dim, width = weight.shape
+ if initial_states is None:
+ out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
+ else:
+ x = torch.cat([initial_states, x], dim=-1)
+ out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
+ out = out[..., :seqlen]
+ if return_final_states:
+ final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
+ dtype_in
+ ) # (batch, dim, width - 1)
+ if final_states_out is not None:
+ final_states_out.copy_(final_states)
+ else:
+ final_states_out = final_states
+ out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
+ return out if not return_final_states else (out, final_states_out)
+
+
+def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None):
+ """
+ x: (batch, dim) or (batch, dim, seqlen)
+ conv_state: (batch, dim, state_len), where state_len >= width - 1
+ weight: (dim, width)
+ bias: (dim,)
+ cache_seqlens: (batch,), dtype int32.
+ If not None, the conv_state is treated as a circular buffer.
+ The conv_state will be updated by copying x to the conv_state starting at the index
+ @cache_seqlens % state_len.
+ conv_state_indices: (batch,), dtype int32
+ If None, the conv_state is a larger tensor along the batch dim,
+ and we are selecting the batch coords specified by conv_state_indices.
+ Useful for a continuous batching scenario.
+
+ out: (batch, dim) or (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ activation = activation in ["silu", "swish"]
+ unsqueeze = x.dim() == 2
+ if unsqueeze:
+ x = x.unsqueeze(-1)
+ out = causal_conv1d_update_function(
+ x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices
+ )
+ if unsqueeze:
+ out = out.squeeze(-1)
+ return out
+
+
+def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
+ """
+ x: (batch, dim) or (batch, dim, seqlen)
+ conv_state: (batch, dim, state_len), where state_len >= width - 1
+ weight: (dim, width)
+ bias: (dim,)
+ cache_seqlens: (batch,), dtype int32.
+ If not None, the conv_state is treated as a circular buffer.
+ The conv_state will be updated by copying x to the conv_state starting at the index
+ @cache_seqlens % state_len before performing the convolution.
+
+ out: (batch, dim) or (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ dtype_in = x.dtype
+ unsqueeze = x.dim() == 2
+ if unsqueeze:
+ x = x.unsqueeze(-1)
+ batch, dim, seqlen = x.shape
+ width = weight.shape[1]
+ state_len = conv_state.shape[-1]
+ assert conv_state.shape == (batch, dim, state_len)
+ assert weight.shape == (dim, width)
+ if cache_seqlens is None:
+ x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen)
+ conv_state.copy_(x_new[:, :, -state_len:])
+ else:
+ width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
+ width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
+ x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
+ copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
+ copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
+ conv_state.scatter_(2, copy_idx, x)
+ out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:]
+ if unsqueeze:
+ out = out.squeeze(-1)
+ return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d_varlen.py b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d_varlen.py
new file mode 100644
index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d_varlen.py
@@ -0,0 +1,86 @@
+import torch
+from torch import Tensor
+
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _causal_conv1d_varlen_states(
+ X,
+ CU_SEQLENS,
+ STATES,
+ state_len,
+ dim,
+ stride_x_seqlen, stride_x_dim,
+ stride_states_batch, stride_states_seqlen, stride_states_dim,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr
+):
+ batch_idx = tl.program_id(2)
+ STATES += batch_idx * stride_states_batch
+ end_idx = tl.load(CU_SEQLENS + batch_idx + 1)
+ start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)
+ rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
+ cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
+ x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,
+ mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim),
+ other=0)
+ rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
+ tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,
+ x,
+ mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim))
+
+
+def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
+ """
+ Forward pass only, does not support backward pass.
+ Parameters:
+ x: (total_tokens, dim)
+ cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
+ state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
+ If some of those elements belong to a different sequence, the value of the states will be zero.
+ Return:
+ states: (batch, dim, state_len)
+ """
+ _, dim = x.shape
+ batch = cu_seqlens.shape[0] - 1
+ cu_seqlens = cu_seqlens.contiguous()
+ states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
+ BLOCK_M = min(triton.next_power_of_2(state_len), 16)
+ BLOCK_N = min(triton.next_power_of_2(dim), 256)
+ grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch)
+ with torch.cuda.device(x.device.index):
+ _causal_conv1d_varlen_states[grid](
+ x,
+ cu_seqlens,
+ states,
+ state_len,
+ dim,
+ x.stride(0), x.stride(1),
+ states.stride(0), states.stride(2), states.stride(1),
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
+ )
+ return states
+
+
+def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
+ """
+ Forward pass only, does not support backward pass.
+ Parameters:
+ x: (total_tokens, dim)
+ cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
+ state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
+ If some of those elements belong to a different sequence, the value of the states will be zero.
+ Return:
+ states: (batch, dim, state_len)
+ """
+ _, dim = x.shape
+ batch = cu_seqlens.shape[0] - 1
+ cu_seqlens = cu_seqlens.contiguous()
+ states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
+ for i in range(batch):
+ end_idx = cu_seqlens[i + 1]
+ start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len)
+ states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T
+ return states
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/cpp_functions.py b/build/torch28-cxx11-cu129-x86_64-linux/cpp_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/cpp_functions.py
@@ -0,0 +1,96 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+
+from ._ops import ops
+
+def causal_conv1d_fwd_function(
+ x: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None,
+ seq_idx: torch.Tensor | None,
+ initial_states: torch.Tensor | None,
+ final_states_out: torch.Tensor | None,
+ silu_activation: bool,
+) -> torch.Tensor:
+ out = torch.empty_like(x)
+ ops.causal_conv1d_fwd(
+ x=x,
+ weight=weight,
+ bias=bias,
+ seq_idx=seq_idx,
+ initial_states=initial_states,
+ out=out,
+ final_states_out=final_states_out,
+ silu_activation=silu_activation,
+ )
+ return out
+
+
+def causal_conv1d_bwd_function(
+ x: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None,
+ dout: torch.Tensor,
+ seq_idx: torch.Tensor | None,
+ initial_states: torch.Tensor | None,
+ dfinal_states: torch.Tensor | None,
+ dx: torch.Tensor | None,
+ return_dinitial_states: torch.Tensor,
+ silu_activation: bool,
+) -> tuple[torch.Tensor | None]:
+ batch_size, dim = x.size()[:2]
+ width = weight.size(-1)
+
+ if dx is None:
+ dx = torch.empty_like(x)
+ dweight = torch.zeros_like(weight, dtype=torch.float32)
+ dbias = None
+ if bias is not None:
+ dbias = torch.zeros_like(bias, dtype=torch.float32)
+ dinitial_states = None
+ if return_dinitial_states:
+ dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2)
+
+ ops.causal_conv1d_bwd(
+ x=x,
+ weight=weight,
+ bias=bias,
+ dout=dout,
+ seq_idx=seq_idx,
+ initial_states=initial_states,
+ dfinal_states=dfinal_states,
+ dx=dx,
+ dweight=dweight,
+ dbias=dbias,
+ dinitial_states=dinitial_states,
+ silu_activation=silu_activation,
+ )
+
+ dweight = dweight.type_as(weight)
+ if dbias is not None:
+ dbias = dbias.type_as(bias)
+ return dx, dweight, dbias, dinitial_states
+
+
+def causal_conv1d_update_function(
+ x: torch.Tensor,
+ conv_state: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None,
+ silu_activation: bool,
+ cache_seqlens: torch.Tensor | None,
+ conv_state_indices: torch.Tensor | None,
+) -> torch.Tensor:
+ out = torch.empty_like(x)
+ ops.causal_conv1d_update(
+ x=x,
+ conv_state=conv_state,
+ weight=weight,
+ bias=bias,
+ out=out,
+ silu_activation=silu_activation,
+ cache_seqlens=cache_seqlens,
+ conv_state_indices=conv_state_indices,
+ )
+ return out
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/metadata.json b/build/torch28-cxx11-cu129-x86_64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..9cf5deed9898dce769f4cc73913d3530b92a0bd8
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/metadata.json
@@ -0,0 +1,4 @@
+{
+ "version": 1,
+ "python-depends": []
+}
\ No newline at end of file
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/__init__.py b/build/torch29-cxx11-cu126-aarch64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/__init__.py
@@ -0,0 +1,4 @@
+from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
+from .causal_conv1d_varlen import causal_conv1d_varlen_states
+
+__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"]
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/_causal_conv1d_cuda_7579ac2.abi3.so b/build/torch29-cxx11-cu126-aarch64-linux/_causal_conv1d_cuda_7579ac2.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..83ab043e6f16f9edd3796ebdb1e3a4106a2b78e4
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/_causal_conv1d_cuda_7579ac2.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:984a88b1b33598f95d2b2c6f19ace193be19f5933f11642bbf9cf8d8cecc9050
+size 80789912
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/_ops.py b/build/torch29-cxx11-cu126-aarch64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..265c44512f222c8028a7141ca7bb227d24107b1a
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _causal_conv1d_cuda_7579ac2
+ops = torch.ops._causal_conv1d_cuda_7579ac2
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_causal_conv1d_cuda_7579ac2::{op_name}"
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/causal_conv1d/__init__.py b/build/torch29-cxx11-cu126-aarch64-linux/causal_conv1d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/causal_conv1d/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import sys
+
+import importlib
+from pathlib import Path
+from types import ModuleType
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/causal_conv1d_interface.py b/build/torch29-cxx11-cu126-aarch64-linux/causal_conv1d_interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/causal_conv1d_interface.py
@@ -0,0 +1,242 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+import torch.nn.functional as F
+
+from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function
+
+
+class CausalConv1dFn(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ x,
+ weight,
+ bias=None,
+ seq_idx=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+ ):
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ if x.stride(2) != 1 and x.stride(1) != 1:
+ x = x.contiguous()
+ bias = bias.contiguous() if bias is not None else None
+ if seq_idx is not None:
+ assert (
+ initial_states is None
+ ), "initial_states must be None if seq_idx is not None"
+ assert (
+ not return_final_states
+ ), "If seq_idx is not None, we don't return final_states_out"
+ seq_idx = seq_idx.contiguous() if seq_idx is not None else None
+ if initial_states is not None and (
+ initial_states.stride(2) != 1 and initial_states.stride(1) != 1
+ ):
+ initial_states = initial_states.contiguous()
+ if return_final_states:
+ assert (
+ x.stride(1) == 1
+ ), "Only channel-last layout support returning final_states_out"
+ if final_states_out is not None:
+ assert (
+ final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1
+ )
+ else:
+ batch, dim, seqlen = x.shape
+ width = weight.shape[1]
+ final_states_out = torch.empty(
+ batch, width - 1, dim, device=x.device, dtype=x.dtype
+ ).transpose(1, 2)
+ else:
+ final_states_out = None
+ ctx.activation = activation in ["silu", "swish"]
+ out = causal_conv1d_fwd_function(
+ x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
+ )
+ ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
+ ctx.return_final_states = return_final_states
+ ctx.return_dinitial_states = (
+ initial_states is not None and initial_states.requires_grad
+ )
+ return out if not return_final_states else (out, final_states_out)
+
+ @staticmethod
+ def backward(ctx, dout, *args):
+ x, weight, bias, seq_idx, initial_states = ctx.saved_tensors
+ dfinal_states = args[0] if ctx.return_final_states else None
+ if dout.stride(2) != 1 and dout.stride(1) != 1:
+ dout = dout.contiguous()
+ # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
+ # backward of conv1d with the backward of chunk).
+ # Here we just pass in None and dx will be allocated in the C++ code.
+ dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function(
+ x,
+ weight,
+ bias,
+ dout,
+ seq_idx,
+ initial_states,
+ dfinal_states,
+ None,
+ ctx.return_dinitial_states,
+ ctx.activation,
+ )
+ return (
+ dx,
+ dweight,
+ dbias if bias is not None else None,
+ None,
+ dinitial_states if initial_states is not None else None,
+ None,
+ None,
+ None,
+ )
+
+
+def causal_conv1d_fn(
+ x,
+ weight,
+ bias=None,
+ seq_idx=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+):
+ """
+ x: (batch, dim, seqlen)
+ weight: (dim, width)
+ bias: (dim,)
+ seq_idx: (batch, seqlen)
+ initial_states: (batch, dim, width - 1)
+ final_states_out: (batch, dim, width - 1), to be written to
+ activation: either None or "silu" or "swish"
+
+ out: (batch, dim, seqlen)
+ """
+ return CausalConv1dFn.apply(
+ x,
+ weight,
+ bias,
+ seq_idx,
+ initial_states,
+ return_final_states,
+ final_states_out,
+ activation,
+ )
+
+
+def causal_conv1d_ref(
+ x,
+ weight,
+ bias=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+):
+ """
+ x: (batch, dim, seqlen)
+ weight: (dim, width)
+ bias: (dim,)
+ initial_states: (batch, dim, width - 1)
+ final_states_out: (batch, dim, width - 1)
+
+ out: (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ dtype_in = x.dtype
+ x = x.to(weight.dtype)
+ seqlen = x.shape[-1]
+ dim, width = weight.shape
+ if initial_states is None:
+ out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
+ else:
+ x = torch.cat([initial_states, x], dim=-1)
+ out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
+ out = out[..., :seqlen]
+ if return_final_states:
+ final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
+ dtype_in
+ ) # (batch, dim, width - 1)
+ if final_states_out is not None:
+ final_states_out.copy_(final_states)
+ else:
+ final_states_out = final_states
+ out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
+ return out if not return_final_states else (out, final_states_out)
+
+
+def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None):
+ """
+ x: (batch, dim) or (batch, dim, seqlen)
+ conv_state: (batch, dim, state_len), where state_len >= width - 1
+ weight: (dim, width)
+ bias: (dim,)
+ cache_seqlens: (batch,), dtype int32.
+ If not None, the conv_state is treated as a circular buffer.
+ The conv_state will be updated by copying x to the conv_state starting at the index
+ @cache_seqlens % state_len.
+ conv_state_indices: (batch,), dtype int32
+ If None, the conv_state is a larger tensor along the batch dim,
+ and we are selecting the batch coords specified by conv_state_indices.
+ Useful for a continuous batching scenario.
+
+ out: (batch, dim) or (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ activation = activation in ["silu", "swish"]
+ unsqueeze = x.dim() == 2
+ if unsqueeze:
+ x = x.unsqueeze(-1)
+ out = causal_conv1d_update_function(
+ x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices
+ )
+ if unsqueeze:
+ out = out.squeeze(-1)
+ return out
+
+
+def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
+ """
+ x: (batch, dim) or (batch, dim, seqlen)
+ conv_state: (batch, dim, state_len), where state_len >= width - 1
+ weight: (dim, width)
+ bias: (dim,)
+ cache_seqlens: (batch,), dtype int32.
+ If not None, the conv_state is treated as a circular buffer.
+ The conv_state will be updated by copying x to the conv_state starting at the index
+ @cache_seqlens % state_len before performing the convolution.
+
+ out: (batch, dim) or (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ dtype_in = x.dtype
+ unsqueeze = x.dim() == 2
+ if unsqueeze:
+ x = x.unsqueeze(-1)
+ batch, dim, seqlen = x.shape
+ width = weight.shape[1]
+ state_len = conv_state.shape[-1]
+ assert conv_state.shape == (batch, dim, state_len)
+ assert weight.shape == (dim, width)
+ if cache_seqlens is None:
+ x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen)
+ conv_state.copy_(x_new[:, :, -state_len:])
+ else:
+ width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
+ width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
+ x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
+ copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
+ copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
+ conv_state.scatter_(2, copy_idx, x)
+ out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:]
+ if unsqueeze:
+ out = out.squeeze(-1)
+ return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/causal_conv1d_varlen.py b/build/torch29-cxx11-cu126-aarch64-linux/causal_conv1d_varlen.py
new file mode 100644
index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/causal_conv1d_varlen.py
@@ -0,0 +1,86 @@
+import torch
+from torch import Tensor
+
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _causal_conv1d_varlen_states(
+ X,
+ CU_SEQLENS,
+ STATES,
+ state_len,
+ dim,
+ stride_x_seqlen, stride_x_dim,
+ stride_states_batch, stride_states_seqlen, stride_states_dim,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr
+):
+ batch_idx = tl.program_id(2)
+ STATES += batch_idx * stride_states_batch
+ end_idx = tl.load(CU_SEQLENS + batch_idx + 1)
+ start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)
+ rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
+ cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
+ x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,
+ mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim),
+ other=0)
+ rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
+ tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,
+ x,
+ mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim))
+
+
+def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
+ """
+ Forward pass only, does not support backward pass.
+ Parameters:
+ x: (total_tokens, dim)
+ cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
+ state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
+ If some of those elements belong to a different sequence, the value of the states will be zero.
+ Return:
+ states: (batch, dim, state_len)
+ """
+ _, dim = x.shape
+ batch = cu_seqlens.shape[0] - 1
+ cu_seqlens = cu_seqlens.contiguous()
+ states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
+ BLOCK_M = min(triton.next_power_of_2(state_len), 16)
+ BLOCK_N = min(triton.next_power_of_2(dim), 256)
+ grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch)
+ with torch.cuda.device(x.device.index):
+ _causal_conv1d_varlen_states[grid](
+ x,
+ cu_seqlens,
+ states,
+ state_len,
+ dim,
+ x.stride(0), x.stride(1),
+ states.stride(0), states.stride(2), states.stride(1),
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
+ )
+ return states
+
+
+def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
+ """
+ Forward pass only, does not support backward pass.
+ Parameters:
+ x: (total_tokens, dim)
+ cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
+ state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
+ If some of those elements belong to a different sequence, the value of the states will be zero.
+ Return:
+ states: (batch, dim, state_len)
+ """
+ _, dim = x.shape
+ batch = cu_seqlens.shape[0] - 1
+ cu_seqlens = cu_seqlens.contiguous()
+ states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
+ for i in range(batch):
+ end_idx = cu_seqlens[i + 1]
+ start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len)
+ states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T
+ return states
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/cpp_functions.py b/build/torch29-cxx11-cu126-aarch64-linux/cpp_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/cpp_functions.py
@@ -0,0 +1,96 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+
+from ._ops import ops
+
+def causal_conv1d_fwd_function(
+ x: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None,
+ seq_idx: torch.Tensor | None,
+ initial_states: torch.Tensor | None,
+ final_states_out: torch.Tensor | None,
+ silu_activation: bool,
+) -> torch.Tensor:
+ out = torch.empty_like(x)
+ ops.causal_conv1d_fwd(
+ x=x,
+ weight=weight,
+ bias=bias,
+ seq_idx=seq_idx,
+ initial_states=initial_states,
+ out=out,
+ final_states_out=final_states_out,
+ silu_activation=silu_activation,
+ )
+ return out
+
+
+def causal_conv1d_bwd_function(
+ x: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None,
+ dout: torch.Tensor,
+ seq_idx: torch.Tensor | None,
+ initial_states: torch.Tensor | None,
+ dfinal_states: torch.Tensor | None,
+ dx: torch.Tensor | None,
+ return_dinitial_states: torch.Tensor,
+ silu_activation: bool,
+) -> tuple[torch.Tensor | None]:
+ batch_size, dim = x.size()[:2]
+ width = weight.size(-1)
+
+ if dx is None:
+ dx = torch.empty_like(x)
+ dweight = torch.zeros_like(weight, dtype=torch.float32)
+ dbias = None
+ if bias is not None:
+ dbias = torch.zeros_like(bias, dtype=torch.float32)
+ dinitial_states = None
+ if return_dinitial_states:
+ dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2)
+
+ ops.causal_conv1d_bwd(
+ x=x,
+ weight=weight,
+ bias=bias,
+ dout=dout,
+ seq_idx=seq_idx,
+ initial_states=initial_states,
+ dfinal_states=dfinal_states,
+ dx=dx,
+ dweight=dweight,
+ dbias=dbias,
+ dinitial_states=dinitial_states,
+ silu_activation=silu_activation,
+ )
+
+ dweight = dweight.type_as(weight)
+ if dbias is not None:
+ dbias = dbias.type_as(bias)
+ return dx, dweight, dbias, dinitial_states
+
+
+def causal_conv1d_update_function(
+ x: torch.Tensor,
+ conv_state: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None,
+ silu_activation: bool,
+ cache_seqlens: torch.Tensor | None,
+ conv_state_indices: torch.Tensor | None,
+) -> torch.Tensor:
+ out = torch.empty_like(x)
+ ops.causal_conv1d_update(
+ x=x,
+ conv_state=conv_state,
+ weight=weight,
+ bias=bias,
+ out=out,
+ silu_activation=silu_activation,
+ cache_seqlens=cache_seqlens,
+ conv_state_indices=conv_state_indices,
+ )
+ return out
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/metadata.json b/build/torch29-cxx11-cu126-aarch64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..0dacb99125f1112a811819ca1ffdde15c8c0faff
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/metadata.json
@@ -0,0 +1,18 @@
+{
+ "version": 1,
+ "license": "BSD-3-Clause",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "7.0",
+ "7.2",
+ "7.5",
+ "8.0",
+ "8.6",
+ "8.7",
+ "8.9",
+ "9.0+PTX"
+ ]
+ }
+}
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/__init__.py b/build/torch29-cxx11-cu126-x86_64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/__init__.py
@@ -0,0 +1,4 @@
+from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
+from .causal_conv1d_varlen import causal_conv1d_varlen_states
+
+__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"]
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/_causal_conv1d_cuda_7579ac2.abi3.so b/build/torch29-cxx11-cu126-x86_64-linux/_causal_conv1d_cuda_7579ac2.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..32d5cba202a2eab17d06c5e19bd6c429cc6c323f
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/_causal_conv1d_cuda_7579ac2.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b648cc9f51b076e4cf8b6fd739012ae066a797f4948730273f5ece8adf950976
+size 80684872
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/_ops.py b/build/torch29-cxx11-cu126-x86_64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..265c44512f222c8028a7141ca7bb227d24107b1a
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _causal_conv1d_cuda_7579ac2
+ops = torch.ops._causal_conv1d_cuda_7579ac2
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_causal_conv1d_cuda_7579ac2::{op_name}"
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py b/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import sys
+
+import importlib
+from pathlib import Path
+from types import ModuleType
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d_interface.py b/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d_interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d_interface.py
@@ -0,0 +1,242 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+import torch.nn.functional as F
+
+from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function
+
+
+class CausalConv1dFn(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ x,
+ weight,
+ bias=None,
+ seq_idx=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+ ):
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ if x.stride(2) != 1 and x.stride(1) != 1:
+ x = x.contiguous()
+ bias = bias.contiguous() if bias is not None else None
+ if seq_idx is not None:
+ assert (
+ initial_states is None
+ ), "initial_states must be None if seq_idx is not None"
+ assert (
+ not return_final_states
+ ), "If seq_idx is not None, we don't return final_states_out"
+ seq_idx = seq_idx.contiguous() if seq_idx is not None else None
+ if initial_states is not None and (
+ initial_states.stride(2) != 1 and initial_states.stride(1) != 1
+ ):
+ initial_states = initial_states.contiguous()
+ if return_final_states:
+ assert (
+ x.stride(1) == 1
+ ), "Only channel-last layout support returning final_states_out"
+ if final_states_out is not None:
+ assert (
+ final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1
+ )
+ else:
+ batch, dim, seqlen = x.shape
+ width = weight.shape[1]
+ final_states_out = torch.empty(
+ batch, width - 1, dim, device=x.device, dtype=x.dtype
+ ).transpose(1, 2)
+ else:
+ final_states_out = None
+ ctx.activation = activation in ["silu", "swish"]
+ out = causal_conv1d_fwd_function(
+ x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
+ )
+ ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
+ ctx.return_final_states = return_final_states
+ ctx.return_dinitial_states = (
+ initial_states is not None and initial_states.requires_grad
+ )
+ return out if not return_final_states else (out, final_states_out)
+
+ @staticmethod
+ def backward(ctx, dout, *args):
+ x, weight, bias, seq_idx, initial_states = ctx.saved_tensors
+ dfinal_states = args[0] if ctx.return_final_states else None
+ if dout.stride(2) != 1 and dout.stride(1) != 1:
+ dout = dout.contiguous()
+ # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
+ # backward of conv1d with the backward of chunk).
+ # Here we just pass in None and dx will be allocated in the C++ code.
+ dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function(
+ x,
+ weight,
+ bias,
+ dout,
+ seq_idx,
+ initial_states,
+ dfinal_states,
+ None,
+ ctx.return_dinitial_states,
+ ctx.activation,
+ )
+ return (
+ dx,
+ dweight,
+ dbias if bias is not None else None,
+ None,
+ dinitial_states if initial_states is not None else None,
+ None,
+ None,
+ None,
+ )
+
+
+def causal_conv1d_fn(
+ x,
+ weight,
+ bias=None,
+ seq_idx=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+):
+ """
+ x: (batch, dim, seqlen)
+ weight: (dim, width)
+ bias: (dim,)
+ seq_idx: (batch, seqlen)
+ initial_states: (batch, dim, width - 1)
+ final_states_out: (batch, dim, width - 1), to be written to
+ activation: either None or "silu" or "swish"
+
+ out: (batch, dim, seqlen)
+ """
+ return CausalConv1dFn.apply(
+ x,
+ weight,
+ bias,
+ seq_idx,
+ initial_states,
+ return_final_states,
+ final_states_out,
+ activation,
+ )
+
+
+def causal_conv1d_ref(
+ x,
+ weight,
+ bias=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+):
+ """
+ x: (batch, dim, seqlen)
+ weight: (dim, width)
+ bias: (dim,)
+ initial_states: (batch, dim, width - 1)
+ final_states_out: (batch, dim, width - 1)
+
+ out: (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ dtype_in = x.dtype
+ x = x.to(weight.dtype)
+ seqlen = x.shape[-1]
+ dim, width = weight.shape
+ if initial_states is None:
+ out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
+ else:
+ x = torch.cat([initial_states, x], dim=-1)
+ out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
+ out = out[..., :seqlen]
+ if return_final_states:
+ final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
+ dtype_in
+ ) # (batch, dim, width - 1)
+ if final_states_out is not None:
+ final_states_out.copy_(final_states)
+ else:
+ final_states_out = final_states
+ out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
+ return out if not return_final_states else (out, final_states_out)
+
+
+def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None):
+ """
+ x: (batch, dim) or (batch, dim, seqlen)
+ conv_state: (batch, dim, state_len), where state_len >= width - 1
+ weight: (dim, width)
+ bias: (dim,)
+ cache_seqlens: (batch,), dtype int32.
+ If not None, the conv_state is treated as a circular buffer.
+ The conv_state will be updated by copying x to the conv_state starting at the index
+ @cache_seqlens % state_len.
+ conv_state_indices: (batch,), dtype int32
+        If not None, the conv_state is a larger tensor along the batch dim,
+ and we are selecting the batch coords specified by conv_state_indices.
+ Useful for a continuous batching scenario.
+
+ out: (batch, dim) or (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ activation = activation in ["silu", "swish"]
+ unsqueeze = x.dim() == 2
+ if unsqueeze:
+ x = x.unsqueeze(-1)
+ out = causal_conv1d_update_function(
+ x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices
+ )
+ if unsqueeze:
+ out = out.squeeze(-1)
+ return out
+
+
+def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
+ """
+ x: (batch, dim) or (batch, dim, seqlen)
+ conv_state: (batch, dim, state_len), where state_len >= width - 1
+ weight: (dim, width)
+ bias: (dim,)
+ cache_seqlens: (batch,), dtype int32.
+ If not None, the conv_state is treated as a circular buffer.
+ The conv_state will be updated by copying x to the conv_state starting at the index
+ @cache_seqlens % state_len before performing the convolution.
+
+ out: (batch, dim) or (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ dtype_in = x.dtype
+ unsqueeze = x.dim() == 2
+ if unsqueeze:
+ x = x.unsqueeze(-1)
+ batch, dim, seqlen = x.shape
+ width = weight.shape[1]
+ state_len = conv_state.shape[-1]
+ assert conv_state.shape == (batch, dim, state_len)
+ assert weight.shape == (dim, width)
+ if cache_seqlens is None:
+ x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen)
+ conv_state.copy_(x_new[:, :, -state_len:])
+ else:
+ width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
+ width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
+ x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
+ copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
+ copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
+ conv_state.scatter_(2, copy_idx, x)
+ out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:]
+ if unsqueeze:
+ out = out.squeeze(-1)
+ return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d_varlen.py b/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d_varlen.py
new file mode 100644
index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d_varlen.py
@@ -0,0 +1,86 @@
+import torch
+from torch import Tensor
+
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _causal_conv1d_varlen_states(
+ X,
+ CU_SEQLENS,
+ STATES,
+ state_len,
+ dim,
+ stride_x_seqlen, stride_x_dim,
+ stride_states_batch, stride_states_seqlen, stride_states_dim,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr
+):
+ batch_idx = tl.program_id(2)
+ STATES += batch_idx * stride_states_batch
+ end_idx = tl.load(CU_SEQLENS + batch_idx + 1)
+ start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)
+ rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
+ cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
+ x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,
+ mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim),
+ other=0)
+ rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
+ tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,
+ x,
+ mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim))
+
+
+def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
+ """
+ Forward pass only, does not support backward pass.
+ Parameters:
+ x: (total_tokens, dim)
+ cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
+ state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
+ If some of those elements belong to a different sequence, the value of the states will be zero.
+ Return:
+ states: (batch, dim, state_len)
+ """
+ _, dim = x.shape
+ batch = cu_seqlens.shape[0] - 1
+ cu_seqlens = cu_seqlens.contiguous()
+ states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
+ BLOCK_M = min(triton.next_power_of_2(state_len), 16)
+ BLOCK_N = min(triton.next_power_of_2(dim), 256)
+ grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch)
+ with torch.cuda.device(x.device.index):
+ _causal_conv1d_varlen_states[grid](
+ x,
+ cu_seqlens,
+ states,
+ state_len,
+ dim,
+ x.stride(0), x.stride(1),
+ states.stride(0), states.stride(2), states.stride(1),
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
+ )
+ return states
+
+
+def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
+ """
+ Forward pass only, does not support backward pass.
+ Parameters:
+ x: (total_tokens, dim)
+ cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
+ state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
+ If some of those elements belong to a different sequence, the value of the states will be zero.
+ Return:
+ states: (batch, dim, state_len)
+ """
+ _, dim = x.shape
+ batch = cu_seqlens.shape[0] - 1
+ cu_seqlens = cu_seqlens.contiguous()
+ states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
+ for i in range(batch):
+ end_idx = cu_seqlens[i + 1]
+ start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len)
+ states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T
+ return states
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/cpp_functions.py b/build/torch29-cxx11-cu126-x86_64-linux/cpp_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/cpp_functions.py
@@ -0,0 +1,96 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+
+from ._ops import ops
+
+def causal_conv1d_fwd_function(
+ x: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None,
+ seq_idx: torch.Tensor | None,
+ initial_states: torch.Tensor | None,
+ final_states_out: torch.Tensor | None,
+ silu_activation: bool,
+) -> torch.Tensor:
+ out = torch.empty_like(x)
+ ops.causal_conv1d_fwd(
+ x=x,
+ weight=weight,
+ bias=bias,
+ seq_idx=seq_idx,
+ initial_states=initial_states,
+ out=out,
+ final_states_out=final_states_out,
+ silu_activation=silu_activation,
+ )
+ return out
+
+
+def causal_conv1d_bwd_function(
+ x: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None,
+ dout: torch.Tensor,
+ seq_idx: torch.Tensor | None,
+ initial_states: torch.Tensor | None,
+ dfinal_states: torch.Tensor | None,
+ dx: torch.Tensor | None,
+ return_dinitial_states: torch.Tensor,
+ silu_activation: bool,
+) -> tuple[torch.Tensor | None]:
+ batch_size, dim = x.size()[:2]
+ width = weight.size(-1)
+
+ if dx is None:
+ dx = torch.empty_like(x)
+ dweight = torch.zeros_like(weight, dtype=torch.float32)
+ dbias = None
+ if bias is not None:
+ dbias = torch.zeros_like(bias, dtype=torch.float32)
+ dinitial_states = None
+ if return_dinitial_states:
+ dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2)
+
+ ops.causal_conv1d_bwd(
+ x=x,
+ weight=weight,
+ bias=bias,
+ dout=dout,
+ seq_idx=seq_idx,
+ initial_states=initial_states,
+ dfinal_states=dfinal_states,
+ dx=dx,
+ dweight=dweight,
+ dbias=dbias,
+ dinitial_states=dinitial_states,
+ silu_activation=silu_activation,
+ )
+
+ dweight = dweight.type_as(weight)
+ if dbias is not None:
+ dbias = dbias.type_as(bias)
+ return dx, dweight, dbias, dinitial_states
+
+
+def causal_conv1d_update_function(
+ x: torch.Tensor,
+ conv_state: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None,
+ silu_activation: bool,
+ cache_seqlens: torch.Tensor | None,
+ conv_state_indices: torch.Tensor | None,
+) -> torch.Tensor:
+ out = torch.empty_like(x)
+ ops.causal_conv1d_update(
+ x=x,
+ conv_state=conv_state,
+ weight=weight,
+ bias=bias,
+ out=out,
+ silu_activation=silu_activation,
+ cache_seqlens=cache_seqlens,
+ conv_state_indices=conv_state_indices,
+ )
+ return out
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/metadata.json b/build/torch29-cxx11-cu126-x86_64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..0dacb99125f1112a811819ca1ffdde15c8c0faff
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/metadata.json
@@ -0,0 +1,18 @@
+{
+ "version": 1,
+ "license": "BSD-3-Clause",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "7.0",
+ "7.2",
+ "7.5",
+ "8.0",
+ "8.6",
+ "8.7",
+ "8.9",
+ "9.0+PTX"
+ ]
+ }
+}
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/__init__.py b/build/torch29-cxx11-cu128-aarch64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/__init__.py
@@ -0,0 +1,4 @@
+from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
+from .causal_conv1d_varlen import causal_conv1d_varlen_states
+
+__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"]
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/_causal_conv1d_cuda_7579ac2.abi3.so b/build/torch29-cxx11-cu128-aarch64-linux/_causal_conv1d_cuda_7579ac2.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..83e2fea72221ac78ca56ec46f7a2a982f1e13c02
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/_causal_conv1d_cuda_7579ac2.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:21bfbb1a2c4685ef84881f72d5a63acb9f38840688fd8587f0be78ad31bb24df
+size 107310800
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/_ops.py b/build/torch29-cxx11-cu128-aarch64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..265c44512f222c8028a7141ca7bb227d24107b1a
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _causal_conv1d_cuda_7579ac2
+ops = torch.ops._causal_conv1d_cuda_7579ac2
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_causal_conv1d_cuda_7579ac2::{op_name}"
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/causal_conv1d/__init__.py b/build/torch29-cxx11-cu128-aarch64-linux/causal_conv1d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/causal_conv1d/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import sys
+
+import importlib
+from pathlib import Path
+from types import ModuleType
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/causal_conv1d_interface.py b/build/torch29-cxx11-cu128-aarch64-linux/causal_conv1d_interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/causal_conv1d_interface.py
@@ -0,0 +1,242 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+import torch.nn.functional as F
+
+from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function
+
+
+class CausalConv1dFn(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ x,
+ weight,
+ bias=None,
+ seq_idx=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+ ):
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ if x.stride(2) != 1 and x.stride(1) != 1:
+ x = x.contiguous()
+ bias = bias.contiguous() if bias is not None else None
+ if seq_idx is not None:
+ assert (
+ initial_states is None
+ ), "initial_states must be None if seq_idx is not None"
+ assert (
+ not return_final_states
+ ), "If seq_idx is not None, we don't return final_states_out"
+ seq_idx = seq_idx.contiguous() if seq_idx is not None else None
+ if initial_states is not None and (
+ initial_states.stride(2) != 1 and initial_states.stride(1) != 1
+ ):
+ initial_states = initial_states.contiguous()
+ if return_final_states:
+ assert (
+ x.stride(1) == 1
+ ), "Only channel-last layout support returning final_states_out"
+ if final_states_out is not None:
+ assert (
+ final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1
+ )
+ else:
+ batch, dim, seqlen = x.shape
+ width = weight.shape[1]
+ final_states_out = torch.empty(
+ batch, width - 1, dim, device=x.device, dtype=x.dtype
+ ).transpose(1, 2)
+ else:
+ final_states_out = None
+ ctx.activation = activation in ["silu", "swish"]
+ out = causal_conv1d_fwd_function(
+ x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
+ )
+ ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
+ ctx.return_final_states = return_final_states
+ ctx.return_dinitial_states = (
+ initial_states is not None and initial_states.requires_grad
+ )
+ return out if not return_final_states else (out, final_states_out)
+
+ @staticmethod
+ def backward(ctx, dout, *args):
+ x, weight, bias, seq_idx, initial_states = ctx.saved_tensors
+ dfinal_states = args[0] if ctx.return_final_states else None
+ if dout.stride(2) != 1 and dout.stride(1) != 1:
+ dout = dout.contiguous()
+ # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
+ # backward of conv1d with the backward of chunk).
+ # Here we just pass in None and dx will be allocated in the C++ code.
+ dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function(
+ x,
+ weight,
+ bias,
+ dout,
+ seq_idx,
+ initial_states,
+ dfinal_states,
+ None,
+ ctx.return_dinitial_states,
+ ctx.activation,
+ )
+ return (
+ dx,
+ dweight,
+ dbias if bias is not None else None,
+ None,
+ dinitial_states if initial_states is not None else None,
+ None,
+ None,
+ None,
+ )
+
+
+def causal_conv1d_fn(
+ x,
+ weight,
+ bias=None,
+ seq_idx=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+):
+ """
+ x: (batch, dim, seqlen)
+ weight: (dim, width)
+ bias: (dim,)
+ seq_idx: (batch, seqlen)
+ initial_states: (batch, dim, width - 1)
+ final_states_out: (batch, dim, width - 1), to be written to
+ activation: either None or "silu" or "swish"
+
+ out: (batch, dim, seqlen)
+ """
+ return CausalConv1dFn.apply(
+ x,
+ weight,
+ bias,
+ seq_idx,
+ initial_states,
+ return_final_states,
+ final_states_out,
+ activation,
+ )
+
+
+def causal_conv1d_ref(
+ x,
+ weight,
+ bias=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+):
+ """
+ x: (batch, dim, seqlen)
+ weight: (dim, width)
+ bias: (dim,)
+ initial_states: (batch, dim, width - 1)
+ final_states_out: (batch, dim, width - 1)
+
+ out: (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ dtype_in = x.dtype
+ x = x.to(weight.dtype)
+ seqlen = x.shape[-1]
+ dim, width = weight.shape
+ if initial_states is None:
+ out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
+ else:
+ x = torch.cat([initial_states, x], dim=-1)
+ out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
+ out = out[..., :seqlen]
+ if return_final_states:
+ final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
+ dtype_in
+ ) # (batch, dim, width - 1)
+ if final_states_out is not None:
+ final_states_out.copy_(final_states)
+ else:
+ final_states_out = final_states
+ out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
+ return out if not return_final_states else (out, final_states_out)
+
+
+def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None):
+ """
+ x: (batch, dim) or (batch, dim, seqlen)
+ conv_state: (batch, dim, state_len), where state_len >= width - 1
+ weight: (dim, width)
+ bias: (dim,)
+ cache_seqlens: (batch,), dtype int32.
+ If not None, the conv_state is treated as a circular buffer.
+ The conv_state will be updated by copying x to the conv_state starting at the index
+ @cache_seqlens % state_len.
+ conv_state_indices: (batch,), dtype int32
+        If not None, the conv_state is a larger tensor along the batch dim,
+ and we are selecting the batch coords specified by conv_state_indices.
+ Useful for a continuous batching scenario.
+
+ out: (batch, dim) or (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ activation = activation in ["silu", "swish"]
+ unsqueeze = x.dim() == 2
+ if unsqueeze:
+ x = x.unsqueeze(-1)
+ out = causal_conv1d_update_function(
+ x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices
+ )
+ if unsqueeze:
+ out = out.squeeze(-1)
+ return out
+
+
+def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
+ """
+ x: (batch, dim) or (batch, dim, seqlen)
+ conv_state: (batch, dim, state_len), where state_len >= width - 1
+ weight: (dim, width)
+ bias: (dim,)
+ cache_seqlens: (batch,), dtype int32.
+ If not None, the conv_state is treated as a circular buffer.
+ The conv_state will be updated by copying x to the conv_state starting at the index
+ @cache_seqlens % state_len before performing the convolution.
+
+ out: (batch, dim) or (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ dtype_in = x.dtype
+ unsqueeze = x.dim() == 2
+ if unsqueeze:
+ x = x.unsqueeze(-1)
+ batch, dim, seqlen = x.shape
+ width = weight.shape[1]
+ state_len = conv_state.shape[-1]
+ assert conv_state.shape == (batch, dim, state_len)
+ assert weight.shape == (dim, width)
+ if cache_seqlens is None:
+ x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen)
+ conv_state.copy_(x_new[:, :, -state_len:])
+ else:
+ width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
+ width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
+ x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
+ copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
+ copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
+ conv_state.scatter_(2, copy_idx, x)
+ out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:]
+ if unsqueeze:
+ out = out.squeeze(-1)
+ return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/causal_conv1d_varlen.py b/build/torch29-cxx11-cu128-aarch64-linux/causal_conv1d_varlen.py
new file mode 100644
index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/causal_conv1d_varlen.py
@@ -0,0 +1,86 @@
+import torch
+from torch import Tensor
+
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _causal_conv1d_varlen_states(
+ X,
+ CU_SEQLENS,
+ STATES,
+ state_len,
+ dim,
+ stride_x_seqlen, stride_x_dim,
+ stride_states_batch, stride_states_seqlen, stride_states_dim,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr
+):
+ batch_idx = tl.program_id(2)
+ STATES += batch_idx * stride_states_batch
+ end_idx = tl.load(CU_SEQLENS + batch_idx + 1)
+ start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)
+ rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
+ cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
+ x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,
+ mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim),
+ other=0)
+ rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
+ tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,
+ x,
+ mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim))
+
+
+def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
+ """
+ Forward pass only, does not support backward pass.
+ Parameters:
+ x: (total_tokens, dim)
+ cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
+ state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
+ If some of those elements belong to a different sequence, the value of the states will be zero.
+ Return:
+ states: (batch, dim, state_len)
+ """
+ _, dim = x.shape
+ batch = cu_seqlens.shape[0] - 1
+ cu_seqlens = cu_seqlens.contiguous()
+ states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
+ BLOCK_M = min(triton.next_power_of_2(state_len), 16)
+ BLOCK_N = min(triton.next_power_of_2(dim), 256)
+ grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch)
+ with torch.cuda.device(x.device.index):
+ _causal_conv1d_varlen_states[grid](
+ x,
+ cu_seqlens,
+ states,
+ state_len,
+ dim,
+ x.stride(0), x.stride(1),
+ states.stride(0), states.stride(2), states.stride(1),
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
+ )
+ return states
+
+
+def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
+ """
+ Forward pass only, does not support backward pass.
+ Parameters:
+ x: (total_tokens, dim)
+ cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
+ state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
+ If some of those elements belong to a different sequence, the value of the states will be zero.
+ Return:
+ states: (batch, dim, state_len)
+ """
+ _, dim = x.shape
+ batch = cu_seqlens.shape[0] - 1
+ cu_seqlens = cu_seqlens.contiguous()
+ states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
+ for i in range(batch):
+ end_idx = cu_seqlens[i + 1]
+ start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len)
+ states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T
+ return states
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/cpp_functions.py b/build/torch29-cxx11-cu128-aarch64-linux/cpp_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/cpp_functions.py
@@ -0,0 +1,96 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+
+from ._ops import ops
+
+def causal_conv1d_fwd_function(
+ x: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None,
+ seq_idx: torch.Tensor | None,
+ initial_states: torch.Tensor | None,
+ final_states_out: torch.Tensor | None,
+ silu_activation: bool,
+) -> torch.Tensor:
+ out = torch.empty_like(x)
+ ops.causal_conv1d_fwd(
+ x=x,
+ weight=weight,
+ bias=bias,
+ seq_idx=seq_idx,
+ initial_states=initial_states,
+ out=out,
+ final_states_out=final_states_out,
+ silu_activation=silu_activation,
+ )
+ return out
+
+
+def causal_conv1d_bwd_function(
+ x: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None,
+ dout: torch.Tensor,
+ seq_idx: torch.Tensor | None,
+ initial_states: torch.Tensor | None,
+ dfinal_states: torch.Tensor | None,
+ dx: torch.Tensor | None,
+ return_dinitial_states: torch.Tensor,
+ silu_activation: bool,
+) -> tuple[torch.Tensor | None]:
+ batch_size, dim = x.size()[:2]
+ width = weight.size(-1)
+
+ if dx is None:
+ dx = torch.empty_like(x)
+ dweight = torch.zeros_like(weight, dtype=torch.float32)
+ dbias = None
+ if bias is not None:
+ dbias = torch.zeros_like(bias, dtype=torch.float32)
+ dinitial_states = None
+ if return_dinitial_states:
+ dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2)
+
+ ops.causal_conv1d_bwd(
+ x=x,
+ weight=weight,
+ bias=bias,
+ dout=dout,
+ seq_idx=seq_idx,
+ initial_states=initial_states,
+ dfinal_states=dfinal_states,
+ dx=dx,
+ dweight=dweight,
+ dbias=dbias,
+ dinitial_states=dinitial_states,
+ silu_activation=silu_activation,
+ )
+
+ dweight = dweight.type_as(weight)
+ if dbias is not None:
+ dbias = dbias.type_as(bias)
+ return dx, dweight, dbias, dinitial_states
+
+
+def causal_conv1d_update_function(
+ x: torch.Tensor,
+ conv_state: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None,
+ silu_activation: bool,
+ cache_seqlens: torch.Tensor | None,
+ conv_state_indices: torch.Tensor | None,
+) -> torch.Tensor:
+ out = torch.empty_like(x)
+ ops.causal_conv1d_update(
+ x=x,
+ conv_state=conv_state,
+ weight=weight,
+ bias=bias,
+ out=out,
+ silu_activation=silu_activation,
+ cache_seqlens=cache_seqlens,
+ conv_state_indices=conv_state_indices,
+ )
+ return out
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/metadata.json b/build/torch29-cxx11-cu128-aarch64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..a794c92436c3827ae79b48d55f7ea964afd50f52
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/metadata.json
@@ -0,0 +1,21 @@
+{
+ "version": 1,
+ "license": "BSD-3-Clause",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "10.0",
+ "10.1",
+ "12.0+PTX",
+ "7.0",
+ "7.2",
+ "7.5",
+ "8.0",
+ "8.6",
+ "8.7",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/__init__.py b/build/torch29-cxx11-cu128-x86_64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/__init__.py
@@ -0,0 +1,4 @@
+from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
+from .causal_conv1d_varlen import causal_conv1d_varlen_states
+
+__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"]
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_causal_conv1d_cuda_7579ac2.abi3.so b/build/torch29-cxx11-cu128-x86_64-linux/_causal_conv1d_cuda_7579ac2.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..8fedf1104c4d17fabcd905d1208d390fa3244f10
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/_causal_conv1d_cuda_7579ac2.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e26af82de4f1fc452d1a1cba90cb154ce269f1f5d20bccb94ccc665e2c970e5d
+size 107172632
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_ops.py b/build/torch29-cxx11-cu128-x86_64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..265c44512f222c8028a7141ca7bb227d24107b1a
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _causal_conv1d_cuda_7579ac2
+ops = torch.ops._causal_conv1d_cuda_7579ac2
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_causal_conv1d_cuda_7579ac2::{op_name}"
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py b/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import sys
+
+import importlib
+from pathlib import Path
+from types import ModuleType
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d_interface.py b/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d_interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d_interface.py
@@ -0,0 +1,242 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+import torch.nn.functional as F
+
+from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function
+
+
+class CausalConv1dFn(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ x,
+ weight,
+ bias=None,
+ seq_idx=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+ ):
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ if x.stride(2) != 1 and x.stride(1) != 1:
+ x = x.contiguous()
+ bias = bias.contiguous() if bias is not None else None
+ if seq_idx is not None:
+ assert (
+ initial_states is None
+ ), "initial_states must be None if seq_idx is not None"
+ assert (
+ not return_final_states
+ ), "If seq_idx is not None, we don't return final_states_out"
+ seq_idx = seq_idx.contiguous() if seq_idx is not None else None
+ if initial_states is not None and (
+ initial_states.stride(2) != 1 and initial_states.stride(1) != 1
+ ):
+ initial_states = initial_states.contiguous()
+ if return_final_states:
+ assert (
+ x.stride(1) == 1
+ ), "Only channel-last layout support returning final_states_out"
+ if final_states_out is not None:
+ assert (
+ final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1
+ )
+ else:
+ batch, dim, seqlen = x.shape
+ width = weight.shape[1]
+ final_states_out = torch.empty(
+ batch, width - 1, dim, device=x.device, dtype=x.dtype
+ ).transpose(1, 2)
+ else:
+ final_states_out = None
+ ctx.activation = activation in ["silu", "swish"]
+ out = causal_conv1d_fwd_function(
+ x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
+ )
+ ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
+ ctx.return_final_states = return_final_states
+ ctx.return_dinitial_states = (
+ initial_states is not None and initial_states.requires_grad
+ )
+ return out if not return_final_states else (out, final_states_out)
+
+ @staticmethod
+ def backward(ctx, dout, *args):
+ x, weight, bias, seq_idx, initial_states = ctx.saved_tensors
+ dfinal_states = args[0] if ctx.return_final_states else None
+ if dout.stride(2) != 1 and dout.stride(1) != 1:
+ dout = dout.contiguous()
+ # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
+ # backward of conv1d with the backward of chunk).
+ # Here we just pass in None and dx will be allocated in the C++ code.
+ dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function(
+ x,
+ weight,
+ bias,
+ dout,
+ seq_idx,
+ initial_states,
+ dfinal_states,
+ None,
+ ctx.return_dinitial_states,
+ ctx.activation,
+ )
+ return (
+ dx,
+ dweight,
+ dbias if bias is not None else None,
+ None,
+ dinitial_states if initial_states is not None else None,
+ None,
+ None,
+ None,
+ )
+
+
+def causal_conv1d_fn(
+ x,
+ weight,
+ bias=None,
+ seq_idx=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+):
+ """
+ x: (batch, dim, seqlen)
+ weight: (dim, width)
+ bias: (dim,)
+ seq_idx: (batch, seqlen)
+ initial_states: (batch, dim, width - 1)
+ final_states_out: (batch, dim, width - 1), to be written to
+ activation: either None or "silu" or "swish"
+
+ out: (batch, dim, seqlen)
+ """
+ return CausalConv1dFn.apply(
+ x,
+ weight,
+ bias,
+ seq_idx,
+ initial_states,
+ return_final_states,
+ final_states_out,
+ activation,
+ )
+
+
+def causal_conv1d_ref(
+ x,
+ weight,
+ bias=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+):
+ """
+ x: (batch, dim, seqlen)
+ weight: (dim, width)
+ bias: (dim,)
+ initial_states: (batch, dim, width - 1)
+ final_states_out: (batch, dim, width - 1)
+
+ out: (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ dtype_in = x.dtype
+ x = x.to(weight.dtype)
+ seqlen = x.shape[-1]
+ dim, width = weight.shape
+ if initial_states is None:
+ out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
+ else:
+ x = torch.cat([initial_states, x], dim=-1)
+ out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
+ out = out[..., :seqlen]
+ if return_final_states:
+ final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
+ dtype_in
+ ) # (batch, dim, width - 1)
+ if final_states_out is not None:
+ final_states_out.copy_(final_states)
+ else:
+ final_states_out = final_states
+ out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
+ return out if not return_final_states else (out, final_states_out)
+
+
+def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None):
+ """
+ x: (batch, dim) or (batch, dim, seqlen)
+ conv_state: (batch, dim, state_len), where state_len >= width - 1
+ weight: (dim, width)
+ bias: (dim,)
+ cache_seqlens: (batch,), dtype int32.
+ If not None, the conv_state is treated as a circular buffer.
+ The conv_state will be updated by copying x to the conv_state starting at the index
+ @cache_seqlens % state_len.
+ conv_state_indices: (batch,), dtype int32
+        If not None, the conv_state is a larger tensor along the batch dim,
+ and we are selecting the batch coords specified by conv_state_indices.
+ Useful for a continuous batching scenario.
+
+ out: (batch, dim) or (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ activation = activation in ["silu", "swish"]
+ unsqueeze = x.dim() == 2
+ if unsqueeze:
+ x = x.unsqueeze(-1)
+ out = causal_conv1d_update_function(
+ x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices
+ )
+ if unsqueeze:
+ out = out.squeeze(-1)
+ return out
+
+
+def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
+ """
+ x: (batch, dim) or (batch, dim, seqlen)
+ conv_state: (batch, dim, state_len), where state_len >= width - 1
+ weight: (dim, width)
+ bias: (dim,)
+ cache_seqlens: (batch,), dtype int32.
+ If not None, the conv_state is treated as a circular buffer.
+ The conv_state will be updated by copying x to the conv_state starting at the index
+ @cache_seqlens % state_len before performing the convolution.
+
+ out: (batch, dim) or (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ dtype_in = x.dtype
+ unsqueeze = x.dim() == 2
+ if unsqueeze:
+ x = x.unsqueeze(-1)
+ batch, dim, seqlen = x.shape
+ width = weight.shape[1]
+ state_len = conv_state.shape[-1]
+ assert conv_state.shape == (batch, dim, state_len)
+ assert weight.shape == (dim, width)
+ if cache_seqlens is None:
+ x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen)
+ conv_state.copy_(x_new[:, :, -state_len:])
+ else:
+ width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
+ width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
+ x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
+ copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
+ copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
+ conv_state.scatter_(2, copy_idx, x)
+ out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:]
+ if unsqueeze:
+ out = out.squeeze(-1)
+ return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d_varlen.py b/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d_varlen.py
new file mode 100644
index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d_varlen.py
@@ -0,0 +1,86 @@
+import torch
+from torch import Tensor
+
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _causal_conv1d_varlen_states(
+ X,
+ CU_SEQLENS,
+ STATES,
+ state_len,
+ dim,
+ stride_x_seqlen, stride_x_dim,
+ stride_states_batch, stride_states_seqlen, stride_states_dim,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr
+):
+ batch_idx = tl.program_id(2)
+ STATES += batch_idx * stride_states_batch
+ end_idx = tl.load(CU_SEQLENS + batch_idx + 1)
+ start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)
+ rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
+ cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
+ x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,
+ mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim),
+ other=0)
+ rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
+ tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,
+ x,
+ mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim))
+
+
+def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
+ """
+ Forward pass only, does not support backward pass.
+ Parameters:
+ x: (total_tokens, dim)
+ cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
+ state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
+ If some of those elements belong to a different sequence, the value of the states will be zero.
+ Return:
+ states: (batch, dim, state_len)
+ """
+ _, dim = x.shape
+ batch = cu_seqlens.shape[0] - 1
+ cu_seqlens = cu_seqlens.contiguous()
+ states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
+ BLOCK_M = min(triton.next_power_of_2(state_len), 16)
+ BLOCK_N = min(triton.next_power_of_2(dim), 256)
+ grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch)
+ with torch.cuda.device(x.device.index):
+ _causal_conv1d_varlen_states[grid](
+ x,
+ cu_seqlens,
+ states,
+ state_len,
+ dim,
+ x.stride(0), x.stride(1),
+ states.stride(0), states.stride(2), states.stride(1),
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
+ )
+ return states
+
+
+def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
+ """
+ Forward pass only, does not support backward pass.
+ Parameters:
+ x: (total_tokens, dim)
+ cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
+ state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
+ If some of those elements belong to a different sequence, the value of the states will be zero.
+ Return:
+ states: (batch, dim, state_len)
+ """
+ _, dim = x.shape
+ batch = cu_seqlens.shape[0] - 1
+ cu_seqlens = cu_seqlens.contiguous()
+ states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
+ for i in range(batch):
+ end_idx = cu_seqlens[i + 1]
+ start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len)
+ states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T
+ return states
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/cpp_functions.py b/build/torch29-cxx11-cu128-x86_64-linux/cpp_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/cpp_functions.py
@@ -0,0 +1,96 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+
+from ._ops import ops
+
+def causal_conv1d_fwd_function(
+ x: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None,
+ seq_idx: torch.Tensor | None,
+ initial_states: torch.Tensor | None,
+ final_states_out: torch.Tensor | None,
+ silu_activation: bool,
+) -> torch.Tensor:
+ out = torch.empty_like(x)
+ ops.causal_conv1d_fwd(
+ x=x,
+ weight=weight,
+ bias=bias,
+ seq_idx=seq_idx,
+ initial_states=initial_states,
+ out=out,
+ final_states_out=final_states_out,
+ silu_activation=silu_activation,
+ )
+ return out
+
+
+def causal_conv1d_bwd_function(
+ x: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None,
+ dout: torch.Tensor,
+ seq_idx: torch.Tensor | None,
+ initial_states: torch.Tensor | None,
+ dfinal_states: torch.Tensor | None,
+ dx: torch.Tensor | None,
+ return_dinitial_states: torch.Tensor,
+ silu_activation: bool,
+) -> tuple[torch.Tensor | None]:
+ batch_size, dim = x.size()[:2]
+ width = weight.size(-1)
+
+ if dx is None:
+ dx = torch.empty_like(x)
+ dweight = torch.zeros_like(weight, dtype=torch.float32)
+ dbias = None
+ if bias is not None:
+ dbias = torch.zeros_like(bias, dtype=torch.float32)
+ dinitial_states = None
+ if return_dinitial_states:
+ dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2)
+
+ ops.causal_conv1d_bwd(
+ x=x,
+ weight=weight,
+ bias=bias,
+ dout=dout,
+ seq_idx=seq_idx,
+ initial_states=initial_states,
+ dfinal_states=dfinal_states,
+ dx=dx,
+ dweight=dweight,
+ dbias=dbias,
+ dinitial_states=dinitial_states,
+ silu_activation=silu_activation,
+ )
+
+ dweight = dweight.type_as(weight)
+ if dbias is not None:
+ dbias = dbias.type_as(bias)
+ return dx, dweight, dbias, dinitial_states
+
+
+def causal_conv1d_update_function(
+ x: torch.Tensor,
+ conv_state: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None,
+ silu_activation: bool,
+ cache_seqlens: torch.Tensor | None,
+ conv_state_indices: torch.Tensor | None,
+) -> torch.Tensor:
+ out = torch.empty_like(x)
+ ops.causal_conv1d_update(
+ x=x,
+ conv_state=conv_state,
+ weight=weight,
+ bias=bias,
+ out=out,
+ silu_activation=silu_activation,
+ cache_seqlens=cache_seqlens,
+ conv_state_indices=conv_state_indices,
+ )
+ return out
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/metadata.json b/build/torch29-cxx11-cu128-x86_64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..a794c92436c3827ae79b48d55f7ea964afd50f52
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/metadata.json
@@ -0,0 +1,21 @@
+{
+ "version": 1,
+ "license": "BSD-3-Clause",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "10.0",
+ "10.1",
+ "12.0+PTX",
+ "7.0",
+ "7.2",
+ "7.5",
+ "8.0",
+ "8.6",
+ "8.7",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/__init__.py
@@ -0,0 +1,4 @@
+from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
+from .causal_conv1d_varlen import causal_conv1d_varlen_states
+
+__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"]
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/_causal_conv1d_cuda_6b83b83.abi3.so b/build/torch29-cxx11-cu129-aarch64-linux/_causal_conv1d_cuda_6b83b83.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..cf187ce7f99f214ff94fb55fa237bc0d0df0ee77
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/_causal_conv1d_cuda_6b83b83.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cc4012b62d077e5479ffa580abe50036943b3a2dcccec65055b1e13752be1d62
+size 115308008
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/_ops.py b/build/torch29-cxx11-cu129-aarch64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..6796899661ef6f73609047ca344503d13ca050bd
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _causal_conv1d_cuda_6b83b83
+ops = torch.ops._causal_conv1d_cuda_6b83b83
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_causal_conv1d_cuda_6b83b83::{op_name}"
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/causal_conv1d/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/causal_conv1d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/causal_conv1d/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import importlib.util
+import sys
+from pathlib import Path
+from types import ModuleType
+
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/causal_conv1d_interface.py b/build/torch29-cxx11-cu129-aarch64-linux/causal_conv1d_interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/causal_conv1d_interface.py
@@ -0,0 +1,242 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+import torch.nn.functional as F
+
+from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function
+
+
+class CausalConv1dFn(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ x,
+ weight,
+ bias=None,
+ seq_idx=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+ ):
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ if x.stride(2) != 1 and x.stride(1) != 1:
+ x = x.contiguous()
+ bias = bias.contiguous() if bias is not None else None
+ if seq_idx is not None:
+ assert (
+ initial_states is None
+ ), "initial_states must be None if seq_idx is not None"
+ assert (
+ not return_final_states
+ ), "If seq_idx is not None, we don't return final_states_out"
+ seq_idx = seq_idx.contiguous() if seq_idx is not None else None
+ if initial_states is not None and (
+ initial_states.stride(2) != 1 and initial_states.stride(1) != 1
+ ):
+ initial_states = initial_states.contiguous()
+ if return_final_states:
+ assert (
+ x.stride(1) == 1
+ ), "Only channel-last layout support returning final_states_out"
+ if final_states_out is not None:
+ assert (
+ final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1
+ )
+ else:
+ batch, dim, seqlen = x.shape
+ width = weight.shape[1]
+ final_states_out = torch.empty(
+ batch, width - 1, dim, device=x.device, dtype=x.dtype
+ ).transpose(1, 2)
+ else:
+ final_states_out = None
+ ctx.activation = activation in ["silu", "swish"]
+ out = causal_conv1d_fwd_function(
+ x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
+ )
+ ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
+ ctx.return_final_states = return_final_states
+ ctx.return_dinitial_states = (
+ initial_states is not None and initial_states.requires_grad
+ )
+ return out if not return_final_states else (out, final_states_out)
+
+ @staticmethod
+ def backward(ctx, dout, *args):
+ x, weight, bias, seq_idx, initial_states = ctx.saved_tensors
+ dfinal_states = args[0] if ctx.return_final_states else None
+ if dout.stride(2) != 1 and dout.stride(1) != 1:
+ dout = dout.contiguous()
+ # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
+ # backward of conv1d with the backward of chunk).
+ # Here we just pass in None and dx will be allocated in the C++ code.
+ dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function(
+ x,
+ weight,
+ bias,
+ dout,
+ seq_idx,
+ initial_states,
+ dfinal_states,
+ None,
+ ctx.return_dinitial_states,
+ ctx.activation,
+ )
+ return (
+ dx,
+ dweight,
+ dbias if bias is not None else None,
+ None,
+ dinitial_states if initial_states is not None else None,
+ None,
+ None,
+ None,
+ )
+
+
+def causal_conv1d_fn(
+ x,
+ weight,
+ bias=None,
+ seq_idx=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+):
+ """
+ x: (batch, dim, seqlen)
+ weight: (dim, width)
+ bias: (dim,)
+ seq_idx: (batch, seqlen)
+ initial_states: (batch, dim, width - 1)
+ final_states_out: (batch, dim, width - 1), to be written to
+ activation: either None or "silu" or "swish"
+
+ out: (batch, dim, seqlen)
+ """
+ return CausalConv1dFn.apply(
+ x,
+ weight,
+ bias,
+ seq_idx,
+ initial_states,
+ return_final_states,
+ final_states_out,
+ activation,
+ )
+
+
+def causal_conv1d_ref(
+ x,
+ weight,
+ bias=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+):
+ """
+ x: (batch, dim, seqlen)
+ weight: (dim, width)
+ bias: (dim,)
+ initial_states: (batch, dim, width - 1)
+ final_states_out: (batch, dim, width - 1)
+
+ out: (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ dtype_in = x.dtype
+ x = x.to(weight.dtype)
+ seqlen = x.shape[-1]
+ dim, width = weight.shape
+ if initial_states is None:
+ out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
+ else:
+ x = torch.cat([initial_states, x], dim=-1)
+ out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
+ out = out[..., :seqlen]
+ if return_final_states:
+ final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
+ dtype_in
+ ) # (batch, dim, width - 1)
+ if final_states_out is not None:
+ final_states_out.copy_(final_states)
+ else:
+ final_states_out = final_states
+ out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
+ return out if not return_final_states else (out, final_states_out)
+
+
+def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None):
+ """
+ x: (batch, dim) or (batch, dim, seqlen)
+ conv_state: (batch, dim, state_len), where state_len >= width - 1
+ weight: (dim, width)
+ bias: (dim,)
+ cache_seqlens: (batch,), dtype int32.
+ If not None, the conv_state is treated as a circular buffer.
+ The conv_state will be updated by copying x to the conv_state starting at the index
+ @cache_seqlens % state_len.
+ conv_state_indices: (batch,), dtype int32
+        If not None, the conv_state is a larger tensor along the batch dim,
+ and we are selecting the batch coords specified by conv_state_indices.
+ Useful for a continuous batching scenario.
+
+ out: (batch, dim) or (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ activation = activation in ["silu", "swish"]
+ unsqueeze = x.dim() == 2
+ if unsqueeze:
+ x = x.unsqueeze(-1)
+ out = causal_conv1d_update_function(
+ x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices
+ )
+ if unsqueeze:
+ out = out.squeeze(-1)
+ return out
+
+
+def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
+ """
+ x: (batch, dim) or (batch, dim, seqlen)
+ conv_state: (batch, dim, state_len), where state_len >= width - 1
+ weight: (dim, width)
+ bias: (dim,)
+ cache_seqlens: (batch,), dtype int32.
+ If not None, the conv_state is treated as a circular buffer.
+ The conv_state will be updated by copying x to the conv_state starting at the index
+ @cache_seqlens % state_len before performing the convolution.
+
+ out: (batch, dim) or (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ dtype_in = x.dtype
+ unsqueeze = x.dim() == 2
+ if unsqueeze:
+ x = x.unsqueeze(-1)
+ batch, dim, seqlen = x.shape
+ width = weight.shape[1]
+ state_len = conv_state.shape[-1]
+ assert conv_state.shape == (batch, dim, state_len)
+ assert weight.shape == (dim, width)
+ if cache_seqlens is None:
+ x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen)
+ conv_state.copy_(x_new[:, :, -state_len:])
+ else:
+ width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
+ width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
+ x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
+ copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
+ copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
+ conv_state.scatter_(2, copy_idx, x)
+ out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:]
+ if unsqueeze:
+ out = out.squeeze(-1)
+ return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/causal_conv1d_varlen.py b/build/torch29-cxx11-cu129-aarch64-linux/causal_conv1d_varlen.py
new file mode 100644
index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/causal_conv1d_varlen.py
@@ -0,0 +1,86 @@
+import torch
+from torch import Tensor
+
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _causal_conv1d_varlen_states(
+ X,
+ CU_SEQLENS,
+ STATES,
+ state_len,
+ dim,
+ stride_x_seqlen, stride_x_dim,
+ stride_states_batch, stride_states_seqlen, stride_states_dim,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr
+):
+ batch_idx = tl.program_id(2)
+ STATES += batch_idx * stride_states_batch
+ end_idx = tl.load(CU_SEQLENS + batch_idx + 1)
+ start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)
+ rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
+ cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
+ x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,
+ mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim),
+ other=0)
+ rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
+ tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,
+ x,
+ mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim))
+
+
+def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
+ """
+ Forward pass only, does not support backward pass.
+ Parameters:
+ x: (total_tokens, dim)
+ cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
+ state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
+ If some of those elements belong to a different sequence, the value of the states will be zero.
+ Return:
+ states: (batch, dim, state_len)
+ """
+ _, dim = x.shape
+ batch = cu_seqlens.shape[0] - 1
+ cu_seqlens = cu_seqlens.contiguous()
+ states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
+ BLOCK_M = min(triton.next_power_of_2(state_len), 16)
+ BLOCK_N = min(triton.next_power_of_2(dim), 256)
+ grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch)
+ with torch.cuda.device(x.device.index):
+ _causal_conv1d_varlen_states[grid](
+ x,
+ cu_seqlens,
+ states,
+ state_len,
+ dim,
+ x.stride(0), x.stride(1),
+ states.stride(0), states.stride(2), states.stride(1),
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
+ )
+ return states
+
+
+def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
+ """
+ Forward pass only, does not support backward pass.
+ Parameters:
+ x: (total_tokens, dim)
+ cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
+ state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
+ If some of those elements belong to a different sequence, the value of the states will be zero.
+ Return:
+ states: (batch, dim, state_len)
+ """
+ _, dim = x.shape
+ batch = cu_seqlens.shape[0] - 1
+ cu_seqlens = cu_seqlens.contiguous()
+ states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
+ for i in range(batch):
+ end_idx = cu_seqlens[i + 1]
+ start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len)
+ states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T
+ return states
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/cpp_functions.py b/build/torch29-cxx11-cu129-aarch64-linux/cpp_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/cpp_functions.py
@@ -0,0 +1,96 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+
+from ._ops import ops
+
+def causal_conv1d_fwd_function(
+ x: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None,
+ seq_idx: torch.Tensor | None,
+ initial_states: torch.Tensor | None,
+ final_states_out: torch.Tensor | None,
+ silu_activation: bool,
+) -> torch.Tensor:
+ out = torch.empty_like(x)
+ ops.causal_conv1d_fwd(
+ x=x,
+ weight=weight,
+ bias=bias,
+ seq_idx=seq_idx,
+ initial_states=initial_states,
+ out=out,
+ final_states_out=final_states_out,
+ silu_activation=silu_activation,
+ )
+ return out
+
+
+def causal_conv1d_bwd_function(
+ x: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None,
+ dout: torch.Tensor,
+ seq_idx: torch.Tensor | None,
+ initial_states: torch.Tensor | None,
+ dfinal_states: torch.Tensor | None,
+ dx: torch.Tensor | None,
+ return_dinitial_states: torch.Tensor,
+ silu_activation: bool,
+) -> tuple[torch.Tensor | None]:
+ batch_size, dim = x.size()[:2]
+ width = weight.size(-1)
+
+ if dx is None:
+ dx = torch.empty_like(x)
+ dweight = torch.zeros_like(weight, dtype=torch.float32)
+ dbias = None
+ if bias is not None:
+ dbias = torch.zeros_like(bias, dtype=torch.float32)
+ dinitial_states = None
+ if return_dinitial_states:
+ dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2)
+
+ ops.causal_conv1d_bwd(
+ x=x,
+ weight=weight,
+ bias=bias,
+ dout=dout,
+ seq_idx=seq_idx,
+ initial_states=initial_states,
+ dfinal_states=dfinal_states,
+ dx=dx,
+ dweight=dweight,
+ dbias=dbias,
+ dinitial_states=dinitial_states,
+ silu_activation=silu_activation,
+ )
+
+ dweight = dweight.type_as(weight)
+ if dbias is not None:
+ dbias = dbias.type_as(bias)
+ return dx, dweight, dbias, dinitial_states
+
+
+def causal_conv1d_update_function(
+ x: torch.Tensor,
+ conv_state: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None,
+ silu_activation: bool,
+ cache_seqlens: torch.Tensor | None,
+ conv_state_indices: torch.Tensor | None,
+) -> torch.Tensor:
+ out = torch.empty_like(x)
+ ops.causal_conv1d_update(
+ x=x,
+ conv_state=conv_state,
+ weight=weight,
+ bias=bias,
+ out=out,
+ silu_activation=silu_activation,
+ cache_seqlens=cache_seqlens,
+ conv_state_indices=conv_state_indices,
+ )
+ return out
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/metadata.json b/build/torch29-cxx11-cu129-aarch64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..a794c92436c3827ae79b48d55f7ea964afd50f52
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/metadata.json
@@ -0,0 +1,21 @@
+{
+ "version": 1,
+ "license": "BSD-3-Clause",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "10.0",
+ "10.1",
+ "12.0+PTX",
+ "7.0",
+ "7.2",
+ "7.5",
+ "8.0",
+ "8.6",
+ "8.7",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/__init__.py b/build/torch29-cxx11-cu129-x86_64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/__init__.py
@@ -0,0 +1,4 @@
+from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
+from .causal_conv1d_varlen import causal_conv1d_varlen_states
+
+__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"]
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/_causal_conv1d_cuda_6b83b83.abi3.so b/build/torch29-cxx11-cu129-x86_64-linux/_causal_conv1d_cuda_6b83b83.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..107538ba364830b8a0b1cb7c4c1e628c8abc775c
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/_causal_conv1d_cuda_6b83b83.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:800c4a26ae3ece97637afa81f9ae4e294e643d2259204b0be027e4d9cb82e147
+size 115140688
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/_ops.py b/build/torch29-cxx11-cu129-x86_64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..6796899661ef6f73609047ca344503d13ca050bd
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _causal_conv1d_cuda_6b83b83
+ops = torch.ops._causal_conv1d_cuda_6b83b83
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_causal_conv1d_cuda_6b83b83::{op_name}"
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/causal_conv1d/__init__.py b/build/torch29-cxx11-cu129-x86_64-linux/causal_conv1d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/causal_conv1d/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import importlib.util
+import sys
+from pathlib import Path
+from types import ModuleType
+
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/causal_conv1d_interface.py b/build/torch29-cxx11-cu129-x86_64-linux/causal_conv1d_interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/causal_conv1d_interface.py
@@ -0,0 +1,242 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+import torch.nn.functional as F
+
+from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function
+
+
+class CausalConv1dFn(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ x,
+ weight,
+ bias=None,
+ seq_idx=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+ ):
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ if x.stride(2) != 1 and x.stride(1) != 1:
+ x = x.contiguous()
+ bias = bias.contiguous() if bias is not None else None
+ if seq_idx is not None:
+ assert (
+ initial_states is None
+ ), "initial_states must be None if seq_idx is not None"
+ assert (
+ not return_final_states
+ ), "If seq_idx is not None, we don't return final_states_out"
+ seq_idx = seq_idx.contiguous() if seq_idx is not None else None
+ if initial_states is not None and (
+ initial_states.stride(2) != 1 and initial_states.stride(1) != 1
+ ):
+ initial_states = initial_states.contiguous()
+ if return_final_states:
+ assert (
+ x.stride(1) == 1
+ ), "Only channel-last layout support returning final_states_out"
+ if final_states_out is not None:
+ assert (
+ final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1
+ )
+ else:
+ batch, dim, seqlen = x.shape
+ width = weight.shape[1]
+ final_states_out = torch.empty(
+ batch, width - 1, dim, device=x.device, dtype=x.dtype
+ ).transpose(1, 2)
+ else:
+ final_states_out = None
+ ctx.activation = activation in ["silu", "swish"]
+ out = causal_conv1d_fwd_function(
+ x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
+ )
+ ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
+ ctx.return_final_states = return_final_states
+ ctx.return_dinitial_states = (
+ initial_states is not None and initial_states.requires_grad
+ )
+ return out if not return_final_states else (out, final_states_out)
+
+ @staticmethod
+ def backward(ctx, dout, *args):
+ x, weight, bias, seq_idx, initial_states = ctx.saved_tensors
+ dfinal_states = args[0] if ctx.return_final_states else None
+ if dout.stride(2) != 1 and dout.stride(1) != 1:
+ dout = dout.contiguous()
+ # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
+ # backward of conv1d with the backward of chunk).
+ # Here we just pass in None and dx will be allocated in the C++ code.
+ dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function(
+ x,
+ weight,
+ bias,
+ dout,
+ seq_idx,
+ initial_states,
+ dfinal_states,
+ None,
+ ctx.return_dinitial_states,
+ ctx.activation,
+ )
+ return (
+ dx,
+ dweight,
+ dbias if bias is not None else None,
+ None,
+ dinitial_states if initial_states is not None else None,
+ None,
+ None,
+ None,
+ )
+
+
+def causal_conv1d_fn(
+ x,
+ weight,
+ bias=None,
+ seq_idx=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+):
+ """
+ x: (batch, dim, seqlen)
+ weight: (dim, width)
+ bias: (dim,)
+ seq_idx: (batch, seqlen)
+ initial_states: (batch, dim, width - 1)
+ final_states_out: (batch, dim, width - 1), to be written to
+ activation: either None or "silu" or "swish"
+
+ out: (batch, dim, seqlen)
+ """
+ return CausalConv1dFn.apply(
+ x,
+ weight,
+ bias,
+ seq_idx,
+ initial_states,
+ return_final_states,
+ final_states_out,
+ activation,
+ )
+
+
+def causal_conv1d_ref(
+ x,
+ weight,
+ bias=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+):
+ """
+ x: (batch, dim, seqlen)
+ weight: (dim, width)
+ bias: (dim,)
+ initial_states: (batch, dim, width - 1)
+ final_states_out: (batch, dim, width - 1)
+
+ out: (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ dtype_in = x.dtype
+ x = x.to(weight.dtype)
+ seqlen = x.shape[-1]
+ dim, width = weight.shape
+ if initial_states is None:
+ out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
+ else:
+ x = torch.cat([initial_states, x], dim=-1)
+ out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
+ out = out[..., :seqlen]
+ if return_final_states:
+ final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
+ dtype_in
+ ) # (batch, dim, width - 1)
+ if final_states_out is not None:
+ final_states_out.copy_(final_states)
+ else:
+ final_states_out = final_states
+ out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
+ return out if not return_final_states else (out, final_states_out)
+
+
+def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None):
+ """
+ x: (batch, dim) or (batch, dim, seqlen)
+ conv_state: (batch, dim, state_len), where state_len >= width - 1
+ weight: (dim, width)
+ bias: (dim,)
+ cache_seqlens: (batch,), dtype int32.
+ If not None, the conv_state is treated as a circular buffer.
+ The conv_state will be updated by copying x to the conv_state starting at the index
+ @cache_seqlens % state_len.
+ conv_state_indices: (batch,), dtype int32
+        If not None, the conv_state is a larger tensor along the batch dim,
+ and we are selecting the batch coords specified by conv_state_indices.
+ Useful for a continuous batching scenario.
+
+ out: (batch, dim) or (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ activation = activation in ["silu", "swish"]
+ unsqueeze = x.dim() == 2
+ if unsqueeze:
+ x = x.unsqueeze(-1)
+ out = causal_conv1d_update_function(
+ x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices
+ )
+ if unsqueeze:
+ out = out.squeeze(-1)
+ return out
+
+
+def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
+ """
+ x: (batch, dim) or (batch, dim, seqlen)
+ conv_state: (batch, dim, state_len), where state_len >= width - 1
+ weight: (dim, width)
+ bias: (dim,)
+ cache_seqlens: (batch,), dtype int32.
+ If not None, the conv_state is treated as a circular buffer.
+ The conv_state will be updated by copying x to the conv_state starting at the index
+ @cache_seqlens % state_len before performing the convolution.
+
+ out: (batch, dim) or (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ dtype_in = x.dtype
+ unsqueeze = x.dim() == 2
+ if unsqueeze:
+ x = x.unsqueeze(-1)
+ batch, dim, seqlen = x.shape
+ width = weight.shape[1]
+ state_len = conv_state.shape[-1]
+ assert conv_state.shape == (batch, dim, state_len)
+ assert weight.shape == (dim, width)
+ if cache_seqlens is None:
+ x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen)
+ conv_state.copy_(x_new[:, :, -state_len:])
+ else:
+ width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
+ width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
+ x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
+ copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
+ copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
+ conv_state.scatter_(2, copy_idx, x)
+ out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:]
+ if unsqueeze:
+ out = out.squeeze(-1)
+ return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/causal_conv1d_varlen.py b/build/torch29-cxx11-cu129-x86_64-linux/causal_conv1d_varlen.py
new file mode 100644
index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/causal_conv1d_varlen.py
@@ -0,0 +1,86 @@
+import torch
+from torch import Tensor
+
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _causal_conv1d_varlen_states(
+ X,
+ CU_SEQLENS,
+ STATES,
+ state_len,
+ dim,
+ stride_x_seqlen, stride_x_dim,
+ stride_states_batch, stride_states_seqlen, stride_states_dim,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr
+):
+ batch_idx = tl.program_id(2)
+ STATES += batch_idx * stride_states_batch
+ end_idx = tl.load(CU_SEQLENS + batch_idx + 1)
+ start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)
+ rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
+ cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
+ x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,
+ mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim),
+ other=0)
+ rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
+ tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,
+ x,
+ mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim))
+
+
+def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
+ """
+ Forward pass only, does not support backward pass.
+ Parameters:
+ x: (total_tokens, dim)
+ cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
+ state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
+ If some of those elements belong to a different sequence, the value of the states will be zero.
+ Return:
+ states: (batch, dim, state_len)
+ """
+ _, dim = x.shape
+ batch = cu_seqlens.shape[0] - 1
+ cu_seqlens = cu_seqlens.contiguous()
+ states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
+ BLOCK_M = min(triton.next_power_of_2(state_len), 16)
+ BLOCK_N = min(triton.next_power_of_2(dim), 256)
+ grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch)
+ with torch.cuda.device(x.device.index):
+ _causal_conv1d_varlen_states[grid](
+ x,
+ cu_seqlens,
+ states,
+ state_len,
+ dim,
+ x.stride(0), x.stride(1),
+ states.stride(0), states.stride(2), states.stride(1),
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
+ )
+ return states
+
+
+def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
+ """
+ Forward pass only, does not support backward pass.
+ Parameters:
+ x: (total_tokens, dim)
+ cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
+ state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
+ If some of those elements belong to a different sequence, the value of the states will be zero.
+ Return:
+ states: (batch, dim, state_len)
+ """
+ _, dim = x.shape
+ batch = cu_seqlens.shape[0] - 1
+ cu_seqlens = cu_seqlens.contiguous()
+ states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
+ for i in range(batch):
+ end_idx = cu_seqlens[i + 1]
+ start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len)
+ states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T
+ return states
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/cpp_functions.py b/build/torch29-cxx11-cu129-x86_64-linux/cpp_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/cpp_functions.py
@@ -0,0 +1,96 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+
+from ._ops import ops
+
+def causal_conv1d_fwd_function(
+ x: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None,
+ seq_idx: torch.Tensor | None,
+ initial_states: torch.Tensor | None,
+ final_states_out: torch.Tensor | None,
+ silu_activation: bool,
+) -> torch.Tensor:
+ out = torch.empty_like(x)
+ ops.causal_conv1d_fwd(
+ x=x,
+ weight=weight,
+ bias=bias,
+ seq_idx=seq_idx,
+ initial_states=initial_states,
+ out=out,
+ final_states_out=final_states_out,
+ silu_activation=silu_activation,
+ )
+ return out
+
+
+def causal_conv1d_bwd_function(
+ x: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None,
+ dout: torch.Tensor,
+ seq_idx: torch.Tensor | None,
+ initial_states: torch.Tensor | None,
+ dfinal_states: torch.Tensor | None,
+ dx: torch.Tensor | None,
+ return_dinitial_states: torch.Tensor,
+ silu_activation: bool,
+) -> tuple[torch.Tensor | None]:
+ batch_size, dim = x.size()[:2]
+ width = weight.size(-1)
+
+ if dx is None:
+ dx = torch.empty_like(x)
+ dweight = torch.zeros_like(weight, dtype=torch.float32)
+ dbias = None
+ if bias is not None:
+ dbias = torch.zeros_like(bias, dtype=torch.float32)
+ dinitial_states = None
+ if return_dinitial_states:
+ dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2)
+
+ ops.causal_conv1d_bwd(
+ x=x,
+ weight=weight,
+ bias=bias,
+ dout=dout,
+ seq_idx=seq_idx,
+ initial_states=initial_states,
+ dfinal_states=dfinal_states,
+ dx=dx,
+ dweight=dweight,
+ dbias=dbias,
+ dinitial_states=dinitial_states,
+ silu_activation=silu_activation,
+ )
+
+ dweight = dweight.type_as(weight)
+ if dbias is not None:
+ dbias = dbias.type_as(bias)
+ return dx, dweight, dbias, dinitial_states
+
+
+def causal_conv1d_update_function(
+ x: torch.Tensor,
+ conv_state: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None,
+ silu_activation: bool,
+ cache_seqlens: torch.Tensor | None,
+ conv_state_indices: torch.Tensor | None,
+) -> torch.Tensor:
+ out = torch.empty_like(x)
+ ops.causal_conv1d_update(
+ x=x,
+ conv_state=conv_state,
+ weight=weight,
+ bias=bias,
+ out=out,
+ silu_activation=silu_activation,
+ cache_seqlens=cache_seqlens,
+ conv_state_indices=conv_state_indices,
+ )
+ return out
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/metadata.json b/build/torch29-cxx11-cu129-x86_64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..a794c92436c3827ae79b48d55f7ea964afd50f52
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/metadata.json
@@ -0,0 +1,21 @@
+{
+ "version": 1,
+ "license": "BSD-3-Clause",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "10.0",
+ "10.1",
+ "12.0+PTX",
+ "7.0",
+ "7.2",
+ "7.5",
+ "8.0",
+ "8.6",
+ "8.7",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/__init__.py b/build/torch29-cxx11-cu130-aarch64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/__init__.py
@@ -0,0 +1,4 @@
+from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
+from .causal_conv1d_varlen import causal_conv1d_varlen_states
+
+__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"]
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/_causal_conv1d_cuda_7579ac2.abi3.so b/build/torch29-cxx11-cu130-aarch64-linux/_causal_conv1d_cuda_7579ac2.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..822c94dee6c0812e147a0ababd144414e364d9b4
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/_causal_conv1d_cuda_7579ac2.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:081a03764e03c62943760b2eb9baf991dacafaf985bfc0374f58eeb86b89ffc7
+size 64753648
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/_ops.py b/build/torch29-cxx11-cu130-aarch64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..265c44512f222c8028a7141ca7bb227d24107b1a
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _causal_conv1d_cuda_7579ac2
+ops = torch.ops._causal_conv1d_cuda_7579ac2
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_causal_conv1d_cuda_7579ac2::{op_name}"
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/causal_conv1d/__init__.py b/build/torch29-cxx11-cu130-aarch64-linux/causal_conv1d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/causal_conv1d/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import sys
+
+import importlib
+from pathlib import Path
+from types import ModuleType
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/causal_conv1d_interface.py b/build/torch29-cxx11-cu130-aarch64-linux/causal_conv1d_interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/causal_conv1d_interface.py
@@ -0,0 +1,242 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+import torch.nn.functional as F
+
+from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function
+
+
+class CausalConv1dFn(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ x,
+ weight,
+ bias=None,
+ seq_idx=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+ ):
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ if x.stride(2) != 1 and x.stride(1) != 1:
+ x = x.contiguous()
+ bias = bias.contiguous() if bias is not None else None
+ if seq_idx is not None:
+ assert (
+ initial_states is None
+ ), "initial_states must be None if seq_idx is not None"
+ assert (
+ not return_final_states
+ ), "If seq_idx is not None, we don't return final_states_out"
+ seq_idx = seq_idx.contiguous() if seq_idx is not None else None
+ if initial_states is not None and (
+ initial_states.stride(2) != 1 and initial_states.stride(1) != 1
+ ):
+ initial_states = initial_states.contiguous()
+ if return_final_states:
+ assert (
+ x.stride(1) == 1
+ ), "Only channel-last layout support returning final_states_out"
+ if final_states_out is not None:
+ assert (
+ final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1
+ )
+ else:
+ batch, dim, seqlen = x.shape
+ width = weight.shape[1]
+ final_states_out = torch.empty(
+ batch, width - 1, dim, device=x.device, dtype=x.dtype
+ ).transpose(1, 2)
+ else:
+ final_states_out = None
+ ctx.activation = activation in ["silu", "swish"]
+ out = causal_conv1d_fwd_function(
+ x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
+ )
+ ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
+ ctx.return_final_states = return_final_states
+ ctx.return_dinitial_states = (
+ initial_states is not None and initial_states.requires_grad
+ )
+ return out if not return_final_states else (out, final_states_out)
+
+ @staticmethod
+ def backward(ctx, dout, *args):
+ x, weight, bias, seq_idx, initial_states = ctx.saved_tensors
+ dfinal_states = args[0] if ctx.return_final_states else None
+ if dout.stride(2) != 1 and dout.stride(1) != 1:
+ dout = dout.contiguous()
+ # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
+ # backward of conv1d with the backward of chunk).
+ # Here we just pass in None and dx will be allocated in the C++ code.
+ dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function(
+ x,
+ weight,
+ bias,
+ dout,
+ seq_idx,
+ initial_states,
+ dfinal_states,
+ None,
+ ctx.return_dinitial_states,
+ ctx.activation,
+ )
+ return (
+ dx,
+ dweight,
+ dbias if bias is not None else None,
+ None,
+ dinitial_states if initial_states is not None else None,
+ None,
+ None,
+ None,
+ )
+
+
+def causal_conv1d_fn(
+ x,
+ weight,
+ bias=None,
+ seq_idx=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+):
+ """
+ x: (batch, dim, seqlen)
+ weight: (dim, width)
+ bias: (dim,)
+ seq_idx: (batch, seqlen)
+ initial_states: (batch, dim, width - 1)
+ final_states_out: (batch, dim, width - 1), to be written to
+ activation: either None or "silu" or "swish"
+
+ out: (batch, dim, seqlen)
+ """
+ return CausalConv1dFn.apply(
+ x,
+ weight,
+ bias,
+ seq_idx,
+ initial_states,
+ return_final_states,
+ final_states_out,
+ activation,
+ )
+
+
+def causal_conv1d_ref(
+ x,
+ weight,
+ bias=None,
+ initial_states=None,
+ return_final_states=False,
+ final_states_out=None,
+ activation=None,
+):
+ """
+ x: (batch, dim, seqlen)
+ weight: (dim, width)
+ bias: (dim,)
+ initial_states: (batch, dim, width - 1)
+ final_states_out: (batch, dim, width - 1)
+
+ out: (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ dtype_in = x.dtype
+ x = x.to(weight.dtype)
+ seqlen = x.shape[-1]
+ dim, width = weight.shape
+ if initial_states is None:
+ out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
+ else:
+ x = torch.cat([initial_states, x], dim=-1)
+ out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
+ out = out[..., :seqlen]
+ if return_final_states:
+ final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
+ dtype_in
+ ) # (batch, dim, width - 1)
+ if final_states_out is not None:
+ final_states_out.copy_(final_states)
+ else:
+ final_states_out = final_states
+ out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
+ return out if not return_final_states else (out, final_states_out)
+
+
+def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None):
+ """
+ x: (batch, dim) or (batch, dim, seqlen)
+ conv_state: (batch, dim, state_len), where state_len >= width - 1
+ weight: (dim, width)
+ bias: (dim,)
+ cache_seqlens: (batch,), dtype int32.
+ If not None, the conv_state is treated as a circular buffer.
+ The conv_state will be updated by copying x to the conv_state starting at the index
+ @cache_seqlens % state_len.
+ conv_state_indices: (batch,), dtype int32
+        If not None, the conv_state is a larger tensor along the batch dim,
+ and we are selecting the batch coords specified by conv_state_indices.
+ Useful for a continuous batching scenario.
+
+ out: (batch, dim) or (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ activation = activation in ["silu", "swish"]
+ unsqueeze = x.dim() == 2
+ if unsqueeze:
+ x = x.unsqueeze(-1)
+ out = causal_conv1d_update_function(
+ x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices
+ )
+ if unsqueeze:
+ out = out.squeeze(-1)
+ return out
+
+
+def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
+ """
+ x: (batch, dim) or (batch, dim, seqlen)
+ conv_state: (batch, dim, state_len), where state_len >= width - 1
+ weight: (dim, width)
+ bias: (dim,)
+ cache_seqlens: (batch,), dtype int32.
+ If not None, the conv_state is treated as a circular buffer.
+ The conv_state will be updated by copying x to the conv_state starting at the index
+ @cache_seqlens % state_len before performing the convolution.
+
+ out: (batch, dim) or (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ dtype_in = x.dtype
+ unsqueeze = x.dim() == 2
+ if unsqueeze:
+ x = x.unsqueeze(-1)
+ batch, dim, seqlen = x.shape
+ width = weight.shape[1]
+ state_len = conv_state.shape[-1]
+ assert conv_state.shape == (batch, dim, state_len)
+ assert weight.shape == (dim, width)
+ if cache_seqlens is None:
+ x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen)
+ conv_state.copy_(x_new[:, :, -state_len:])
+ else:
+ width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
+ width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
+ x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
+ copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
+ copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
+ conv_state.scatter_(2, copy_idx, x)
+ out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:]
+ if unsqueeze:
+ out = out.squeeze(-1)
+ return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/causal_conv1d_varlen.py b/build/torch29-cxx11-cu130-aarch64-linux/causal_conv1d_varlen.py
new file mode 100644
index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/causal_conv1d_varlen.py
@@ -0,0 +1,86 @@
+import torch
+from torch import Tensor
+
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _causal_conv1d_varlen_states(
+    X,
+    CU_SEQLENS,
+    STATES,
+    state_len,
+    dim,
+    stride_x_seqlen, stride_x_dim,
+    stride_states_batch, stride_states_seqlen, stride_states_dim,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr
+):
+    batch_idx = tl.program_id(2)  # grid is (dim blocks, seqlen blocks, batch)
+    STATES += batch_idx * stride_states_batch  # point at this sequence's state slot
+    end_idx = tl.load(CU_SEQLENS + batch_idx + 1)  # one past this sequence's last token in X
+    start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)  # clamp so we never read a previous sequence's tokens
+    rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)  # token rows, counted back from the sequence end
+    cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)  # feature columns handled by this program
+    x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,
+                mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim),
+                other=0)  # out-of-sequence rows read as 0, so short sequences are zero-padded on the left
+    rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)  # right-align into the state buffer
+    tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,
+             x,
+             mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim))  # rows before the buffer start are dropped
+
+
+def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
+    """
+    Forward pass only, does not support backward pass.
+    Parameters:
+        x: (total_tokens, dim)
+        cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
+        state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
+            If some of those elements belong to a different sequence, the value of the states will be zero.
+    Return:
+        states: (batch, dim, state_len)
+    """
+    _, dim = x.shape
+    batch = cu_seqlens.shape[0] - 1
+    cu_seqlens = cu_seqlens.contiguous()  # kernel indexes it with raw pointer arithmetic
+    states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)  # (batch, dim, state_len) view; dim axis stays contiguous
+    BLOCK_M = min(triton.next_power_of_2(state_len), 16)
+    BLOCK_N = min(triton.next_power_of_2(dim), 256)
+    grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch)
+    with torch.cuda.device(x.device.index):  # launch on x's device, not the current default device
+        _causal_conv1d_varlen_states[grid](
+            x,
+            cu_seqlens,
+            states,
+            state_len,
+            dim,
+            x.stride(0), x.stride(1),
+            states.stride(0), states.stride(2), states.stride(1),  # batch / seqlen / dim strides of the transposed view
+            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
+        )
+    return states
+
+
+def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
+    """
+    Forward pass only, does not support backward pass.
+    Parameters:
+        x: (total_tokens, dim)
+        cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
+        state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
+            If some of those elements belong to a different sequence, the value of the states will be zero.
+    Return:
+        states: (batch, dim, state_len)
+    """
+    _, dim = x.shape
+    batch = cu_seqlens.shape[0] - 1
+    cu_seqlens = cu_seqlens.contiguous()
+    states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)  # zero-init: left padding stays zero
+    for i in range(batch):
+        end_idx = cu_seqlens[i + 1]
+        start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len)  # clamp to this sequence's start
+        states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T  # NOTE(review): an empty sequence makes this slice full-width (-0 == 0) — confirm cu_seqlens has no empty sequences
+    return states
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/cpp_functions.py b/build/torch29-cxx11-cu130-aarch64-linux/cpp_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/cpp_functions.py
@@ -0,0 +1,96 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+
+from ._ops import ops
+
+def causal_conv1d_fwd_function(
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor | None,
+    seq_idx: torch.Tensor | None,
+    initial_states: torch.Tensor | None,
+    final_states_out: torch.Tensor | None,
+    silu_activation: bool,
+) -> torch.Tensor:
+    out = torch.empty_like(x)  # the compiled kernel fills this in place
+    ops.causal_conv1d_fwd(
+        x=x,
+        weight=weight,
+        bias=bias,
+        seq_idx=seq_idx,
+        initial_states=initial_states,
+        out=out,
+        final_states_out=final_states_out,  # written in place by the kernel when not None
+        silu_activation=silu_activation,
+    )
+    return out
+
+
+def causal_conv1d_bwd_function(
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor | None,
+    dout: torch.Tensor,
+    seq_idx: torch.Tensor | None,
+    initial_states: torch.Tensor | None,
+    dfinal_states: torch.Tensor | None,
+    dx: torch.Tensor | None,
+    return_dinitial_states: bool,
+    silu_activation: bool,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
+    batch_size, dim = x.size()[:2]
+    width = weight.size(-1)
+
+    if dx is None:  # caller may pass a pre-allocated dx (e.g. to fuse with another backward)
+        dx = torch.empty_like(x)
+    dweight = torch.zeros_like(weight, dtype=torch.float32)  # accumulated in fp32, cast back below
+    dbias = None
+    if bias is not None:
+        dbias = torch.zeros_like(bias, dtype=torch.float32)  # fp32 accumulation as for dweight
+    dinitial_states = None
+    if return_dinitial_states:
+        dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2)  # channel-last (batch, dim, width-1)
+
+    ops.causal_conv1d_bwd(
+        x=x,
+        weight=weight,
+        bias=bias,
+        dout=dout,
+        seq_idx=seq_idx,
+        initial_states=initial_states,
+        dfinal_states=dfinal_states,
+        dx=dx,
+        dweight=dweight,
+        dbias=dbias,
+        dinitial_states=dinitial_states,
+        silu_activation=silu_activation,
+    )
+
+    dweight = dweight.type_as(weight)  # back to the parameter dtype
+    if dbias is not None:
+        dbias = dbias.type_as(bias)
+    return dx, dweight, dbias, dinitial_states
+
+
+def causal_conv1d_update_function(
+    x: torch.Tensor,
+    conv_state: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor | None,
+    silu_activation: bool,
+    cache_seqlens: torch.Tensor | None,
+    conv_state_indices: torch.Tensor | None,
+) -> torch.Tensor:
+    out = torch.empty_like(x)  # the compiled kernel fills this in place
+    ops.causal_conv1d_update(
+        x=x,
+        conv_state=conv_state,  # updated in place by the kernel
+        weight=weight,
+        bias=bias,
+        out=out,
+        silu_activation=silu_activation,
+        cache_seqlens=cache_seqlens,
+        conv_state_indices=conv_state_indices,  # selects rows of conv_state for continuous batching
+    )
+    return out
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/metadata.json b/build/torch29-cxx11-cu130-aarch64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..eff725542128e103dfb5df382d74940efff77214
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/metadata.json
@@ -0,0 +1,19 @@
+{
+ "version": 1,
+ "license": "BSD-3-Clause",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "10.0",
+ "11.0",
+ "12.0+PTX",
+ "7.5",
+ "8.0",
+ "8.6",
+ "8.7",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/__init__.py
@@ -0,0 +1,4 @@
+from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
+from .causal_conv1d_varlen import causal_conv1d_varlen_states
+
+__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"]
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_causal_conv1d_cuda_7579ac2.abi3.so b/build/torch29-cxx11-cu130-x86_64-linux/_causal_conv1d_cuda_7579ac2.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..f0629040266fe8fb76e0bb910d7fe4ae47e95248
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/_causal_conv1d_cuda_7579ac2.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ed0967161460ebedec3b7216fb284ceef4a290cfedf5b71368e7a71f2e289e65
+size 64613072
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_ops.py b/build/torch29-cxx11-cu130-x86_64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..265c44512f222c8028a7141ca7bb227d24107b1a
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _causal_conv1d_cuda_7579ac2
+ops = torch.ops._causal_conv1d_cuda_7579ac2
+
+def add_op_namespace_prefix(op_name: str) -> str:
+    """
+    Prefix op_name with this extension's torch.ops namespace.
+    """
+    return f"_causal_conv1d_cuda_7579ac2::{op_name}"
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import sys
+
+import importlib
+from pathlib import Path
+from types import ModuleType
+
+def _import_from_path(file_path: Path) -> ModuleType:
+    # We cannot use the module name as-is, after adding it to `sys.modules`,
+    # it would also be used for other imports. So, we make a module name that
+    # depends on the path for it to be unique using the hex-encoded hash of
+    # the path.
+    path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)  # c_size_t makes the hash unsigned so the hex is a valid module name
+    module_name = path_hash
+    spec = importlib.util.spec_from_file_location(module_name, file_path)  # NOTE(review): relies on `importlib.util` being reachable via `import importlib`; an explicit `import importlib.util` would be safer — confirm
+    if spec is None:
+        raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+    module = importlib.util.module_from_spec(spec)
+    if module is None:
+        raise ImportError(f"Cannot load module {module_name} from spec")
+    sys.modules[module_name] = module  # register before exec so imports inside the module can resolve it
+    spec.loader.exec_module(module)  # type: ignore
+    return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d_interface.py b/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d_interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d_interface.py
@@ -0,0 +1,242 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+import torch.nn.functional as F
+
+from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function
+
+
+class CausalConv1dFn(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        x,
+        weight,
+        bias=None,
+        seq_idx=None,
+        initial_states=None,
+        return_final_states=False,
+        final_states_out=None,
+        activation=None,
+    ):
+        if activation not in [None, "silu", "swish"]:
+            raise NotImplementedError("activation must be None, silu, or swish")
+        if x.stride(2) != 1 and x.stride(1) != 1:  # kernel needs seq-last or channel-last contiguity
+            x = x.contiguous()
+        bias = bias.contiguous() if bias is not None else None
+        if seq_idx is not None:
+            assert (
+                initial_states is None
+            ), "initial_states must be None if seq_idx is not None"
+            assert (
+                not return_final_states
+            ), "If seq_idx is not None, we don't return final_states_out"
+        seq_idx = seq_idx.contiguous() if seq_idx is not None else None
+        if initial_states is not None and (
+            initial_states.stride(2) != 1 and initial_states.stride(1) != 1
+        ):
+            initial_states = initial_states.contiguous()
+        if return_final_states:
+            assert (
+                x.stride(1) == 1
+            ), "Only channel-last layout support returning final_states_out"
+            if final_states_out is not None:
+                assert (
+                    final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1
+                )
+            else:
+                batch, dim, seqlen = x.shape
+                width = weight.shape[1]
+                final_states_out = torch.empty(
+                    batch, width - 1, dim, device=x.device, dtype=x.dtype
+                ).transpose(1, 2)  # channel-last (batch, dim, width - 1)
+        else:
+            final_states_out = None
+        ctx.activation = activation in ["silu", "swish"]  # stored as a bool flag for the kernel
+        out = causal_conv1d_fwd_function(
+            x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
+        )
+        ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
+        ctx.return_final_states = return_final_states
+        ctx.return_dinitial_states = (
+            initial_states is not None and initial_states.requires_grad
+        )
+        return out if not return_final_states else (out, final_states_out)
+
+    @staticmethod
+    def backward(ctx, dout, *args):
+        x, weight, bias, seq_idx, initial_states = ctx.saved_tensors
+        dfinal_states = args[0] if ctx.return_final_states else None  # grad w.r.t. final_states when forward returned them
+        if dout.stride(2) != 1 and dout.stride(1) != 1:  # same layout requirement as forward
+            dout = dout.contiguous()
+        # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
+        # backward of conv1d with the backward of chunk).
+        # Here we just pass in None and dx will be allocated in the C++ code.
+        dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function(
+            x,
+            weight,
+            bias,
+            dout,
+            seq_idx,
+            initial_states,
+            dfinal_states,
+            None,
+            ctx.return_dinitial_states,
+            ctx.activation,
+        )
+        return (  # one gradient per forward input; None for the non-tensor arguments
+            dx,
+            dweight,
+            dbias if bias is not None else None,
+            None,
+            dinitial_states if initial_states is not None else None,
+            None,
+            None,
+            None,
+        )
+
+
+def causal_conv1d_fn(
+    x,
+    weight,
+    bias=None,
+    seq_idx=None,
+    initial_states=None,
+    return_final_states=False,
+    final_states_out=None,
+    activation=None,
+):
+    """
+    x: (batch, dim, seqlen)
+    weight: (dim, width)
+    bias: (dim,)
+    seq_idx: (batch, seqlen)
+    initial_states: (batch, dim, width - 1)
+    final_states_out: (batch, dim, width - 1), to be written to
+    activation: either None or "silu" or "swish"
+
+    out: (batch, dim, seqlen)
+    """
+    return CausalConv1dFn.apply(  # autograd.Function handles forward/backward dispatch
+        x,
+        weight,
+        bias,
+        seq_idx,
+        initial_states,
+        return_final_states,
+        final_states_out,
+        activation,
+    )
+
+
+def causal_conv1d_ref(
+    x,
+    weight,
+    bias=None,
+    initial_states=None,
+    return_final_states=False,
+    final_states_out=None,
+    activation=None,
+):
+    """
+    Reference (pure PyTorch) implementation of the causal depthwise conv1d.
+    x: (batch, dim, seqlen)
+    weight: (dim, width)
+    bias: (dim,)
+    initial_states: (batch, dim, width - 1)
+    final_states_out: (batch, dim, width - 1)
+
+    out: (batch, dim, seqlen)
+    """
+    if activation not in [None, "silu", "swish"]:
+        raise NotImplementedError("activation must be None, silu, or swish")
+    dtype_in = x.dtype
+    x = x.to(weight.dtype)
+    seqlen = x.shape[-1]
+    dim, width = weight.shape
+    if initial_states is None:
+        out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)  # depthwise conv, zero-padded on the left
+    else:
+        x = torch.cat([initial_states, x], dim=-1)  # prepend the provided context instead of zero padding
+        out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
+    out = out[..., :seqlen]  # drop the extra positions produced by the padding
+    if return_final_states:
+        final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(  # negative left-pad crops to the last width-1 columns
+            dtype_in
+        )  # (batch, dim, width - 1)
+        if final_states_out is not None:
+            final_states_out.copy_(final_states)  # honor the caller-provided output buffer
+        else:
+            final_states_out = final_states
+    out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
+    return out if not return_final_states else (out, final_states_out)
+
+
+def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None):
+    """
+    x: (batch, dim) or (batch, dim, seqlen)
+    conv_state: (batch, dim, state_len), where state_len >= width - 1
+    weight: (dim, width)
+    bias: (dim,)
+    cache_seqlens: (batch,), dtype int32.
+        If not None, the conv_state is treated as a circular buffer.
+        The conv_state will be updated by copying x to the conv_state starting at the index
+        @cache_seqlens % state_len.
+    conv_state_indices: (batch,), dtype int32
+        If not None, the conv_state is a larger tensor along the batch dim,
+        and we are selecting the batch coords specified by conv_state_indices.
+        Useful for a continuous batching scenario.
+
+    out: (batch, dim) or (batch, dim, seqlen)
+    """
+    if activation not in [None, "silu", "swish"]:
+        raise NotImplementedError("activation must be None, silu, or swish")
+    activation = activation in ["silu", "swish"]  # rebind to a bool flag for the compiled kernel
+    unsqueeze = x.dim() == 2
+    if unsqueeze:
+        x = x.unsqueeze(-1)  # add a seqlen-1 axis; removed again below
+    out = causal_conv1d_update_function(
+        x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices
+    )
+    if unsqueeze:
+        out = out.squeeze(-1)
+    return out
+
+
+def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
+    """
+    Reference (pure PyTorch) implementation of the single-step/decoding update.
+    x: (batch, dim) or (batch, dim, seqlen)
+    conv_state: (batch, dim, state_len), where state_len >= width - 1; updated in place
+    weight: (dim, width)
+    bias: (dim,)
+    cache_seqlens: (batch,), dtype int32.
+        If not None, the conv_state is treated as a circular buffer.
+        The conv_state will be updated by copying x to the conv_state starting at the index
+        @cache_seqlens % state_len before performing the convolution.
+
+    out: (batch, dim) or (batch, dim, seqlen)
+    """
+    if activation not in [None, "silu", "swish"]:
+        raise NotImplementedError("activation must be None, silu, or swish")
+    dtype_in = x.dtype
+    unsqueeze = x.dim() == 2
+    if unsqueeze:
+        x = x.unsqueeze(-1)  # add a seqlen-1 axis; removed again below
+    batch, dim, seqlen = x.shape
+    width = weight.shape[1]
+    state_len = conv_state.shape[-1]
+    assert conv_state.shape == (batch, dim, state_len)
+    assert weight.shape == (dim, width)
+    if cache_seqlens is None:
+        x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype)  # (batch, dim, state_len + seqlen)
+        conv_state.copy_(x_new[:, :, -state_len:])  # shift the cache: keep the most recent state_len tokens
+    else:
+        width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)  # the width-1 cached positions preceding cache_seqlens
+        width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)  # wrap into the circular buffer
+        x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)  # cached context + new tokens
+        copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)  # slots where the new tokens land
+        copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
+        conv_state.scatter_(2, copy_idx, x)  # in-place circular-buffer update
+    out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:]  # depthwise conv; keep outputs for the new tokens only
+    if unsqueeze:
+        out = out.squeeze(-1)
+    return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d_varlen.py b/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d_varlen.py
new file mode 100644
index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d_varlen.py
@@ -0,0 +1,86 @@
+import torch
+from torch import Tensor
+
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _causal_conv1d_varlen_states(
+    X,
+    CU_SEQLENS,
+    STATES,
+    state_len,
+    dim,
+    stride_x_seqlen, stride_x_dim,
+    stride_states_batch, stride_states_seqlen, stride_states_dim,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr
+):
+    batch_idx = tl.program_id(2)  # grid is (dim blocks, seqlen blocks, batch)
+    STATES += batch_idx * stride_states_batch  # point at this sequence's state slot
+    end_idx = tl.load(CU_SEQLENS + batch_idx + 1)  # one past this sequence's last token in X
+    start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)  # clamp so we never read a previous sequence's tokens
+    rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)  # token rows, counted back from the sequence end
+    cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)  # feature columns handled by this program
+    x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,
+                mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim),
+                other=0)  # out-of-sequence rows read as 0, so short sequences are zero-padded on the left
+    rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)  # right-align into the state buffer
+    tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,
+             x,
+             mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim))  # rows before the buffer start are dropped
+
+
+def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
+    """
+    Forward pass only, does not support backward pass.
+    Parameters:
+        x: (total_tokens, dim)
+        cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
+        state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
+            If some of those elements belong to a different sequence, the value of the states will be zero.
+    Return:
+        states: (batch, dim, state_len)
+    """
+    _, dim = x.shape
+    batch = cu_seqlens.shape[0] - 1
+    cu_seqlens = cu_seqlens.contiguous()  # kernel indexes it with raw pointer arithmetic
+    states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)  # (batch, dim, state_len) view; dim axis stays contiguous
+    BLOCK_M = min(triton.next_power_of_2(state_len), 16)
+    BLOCK_N = min(triton.next_power_of_2(dim), 256)
+    grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch)
+    with torch.cuda.device(x.device.index):  # launch on x's device, not the current default device
+        _causal_conv1d_varlen_states[grid](
+            x,
+            cu_seqlens,
+            states,
+            state_len,
+            dim,
+            x.stride(0), x.stride(1),
+            states.stride(0), states.stride(2), states.stride(1),  # batch / seqlen / dim strides of the transposed view
+            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
+        )
+    return states
+
+
+def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
+    """
+    Forward pass only, does not support backward pass.
+    Parameters:
+        x: (total_tokens, dim)
+        cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
+        state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
+            If some of those elements belong to a different sequence, the value of the states will be zero.
+    Return:
+        states: (batch, dim, state_len)
+    """
+    _, dim = x.shape
+    batch = cu_seqlens.shape[0] - 1
+    cu_seqlens = cu_seqlens.contiguous()
+    states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)  # zero-init: left padding stays zero
+    for i in range(batch):
+        end_idx = cu_seqlens[i + 1]
+        start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len)  # clamp to this sequence's start
+        states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T  # NOTE(review): an empty sequence makes this slice full-width (-0 == 0) — confirm cu_seqlens has no empty sequences
+    return states
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/cpp_functions.py b/build/torch29-cxx11-cu130-x86_64-linux/cpp_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/cpp_functions.py
@@ -0,0 +1,96 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+
+from ._ops import ops
+
+def causal_conv1d_fwd_function(
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor | None,
+    seq_idx: torch.Tensor | None,
+    initial_states: torch.Tensor | None,
+    final_states_out: torch.Tensor | None,
+    silu_activation: bool,
+) -> torch.Tensor:
+    out = torch.empty_like(x)  # the compiled kernel fills this in place
+    ops.causal_conv1d_fwd(
+        x=x,
+        weight=weight,
+        bias=bias,
+        seq_idx=seq_idx,
+        initial_states=initial_states,
+        out=out,
+        final_states_out=final_states_out,  # written in place by the kernel when not None
+        silu_activation=silu_activation,
+    )
+    return out
+
+
+def causal_conv1d_bwd_function(
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor | None,
+    dout: torch.Tensor,
+    seq_idx: torch.Tensor | None,
+    initial_states: torch.Tensor | None,
+    dfinal_states: torch.Tensor | None,
+    dx: torch.Tensor | None,
+    return_dinitial_states: bool,
+    silu_activation: bool,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
+    batch_size, dim = x.size()[:2]
+    width = weight.size(-1)
+
+    if dx is None:  # caller may pass a pre-allocated dx (e.g. to fuse with another backward)
+        dx = torch.empty_like(x)
+    dweight = torch.zeros_like(weight, dtype=torch.float32)  # accumulated in fp32, cast back below
+    dbias = None
+    if bias is not None:
+        dbias = torch.zeros_like(bias, dtype=torch.float32)  # fp32 accumulation as for dweight
+    dinitial_states = None
+    if return_dinitial_states:
+        dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2)  # channel-last (batch, dim, width-1)
+
+    ops.causal_conv1d_bwd(
+        x=x,
+        weight=weight,
+        bias=bias,
+        dout=dout,
+        seq_idx=seq_idx,
+        initial_states=initial_states,
+        dfinal_states=dfinal_states,
+        dx=dx,
+        dweight=dweight,
+        dbias=dbias,
+        dinitial_states=dinitial_states,
+        silu_activation=silu_activation,
+    )
+
+    dweight = dweight.type_as(weight)  # back to the parameter dtype
+    if dbias is not None:
+        dbias = dbias.type_as(bias)
+    return dx, dweight, dbias, dinitial_states
+
+
+def causal_conv1d_update_function(
+    x: torch.Tensor,
+    conv_state: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor | None,
+    silu_activation: bool,
+    cache_seqlens: torch.Tensor | None,
+    conv_state_indices: torch.Tensor | None,
+) -> torch.Tensor:
+    out = torch.empty_like(x)  # the compiled kernel fills this in place
+    ops.causal_conv1d_update(
+        x=x,
+        conv_state=conv_state,  # updated in place by the kernel
+        weight=weight,
+        bias=bias,
+        out=out,
+        silu_activation=silu_activation,
+        cache_seqlens=cache_seqlens,
+        conv_state_indices=conv_state_indices,  # selects rows of conv_state for continuous batching
+    )
+    return out
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/metadata.json b/build/torch29-cxx11-cu130-x86_64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..eff725542128e103dfb5df382d74940efff77214
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/metadata.json
@@ -0,0 +1,19 @@
+{
+ "version": 1,
+ "license": "BSD-3-Clause",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "10.0",
+ "11.0",
+ "12.0+PTX",
+ "7.5",
+ "8.0",
+ "8.6",
+ "8.7",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/media/benches_dark_animation.svg b/media/benches_dark_animation.svg
new file mode 100644
index 0000000000000000000000000000000000000000..776391a87df7a62d9cdd3610975d70c842b0c86a
--- /dev/null
+++ b/media/benches_dark_animation.svg
@@ -0,0 +1,42 @@
+
\ No newline at end of file
diff --git a/media/benches_dark_latency.svg b/media/benches_dark_latency.svg
new file mode 100644
index 0000000000000000000000000000000000000000..8d4353f58c0bc50685e813ae21c9fbbfda0d38e4
--- /dev/null
+++ b/media/benches_dark_latency.svg
@@ -0,0 +1,2104 @@
+
+
+
diff --git a/media/benches_dark_throughput.svg b/media/benches_dark_throughput.svg
new file mode 100644
index 0000000000000000000000000000000000000000..8b57257a5035a7df5bfe9e2eb864a665e54d7a90
--- /dev/null
+++ b/media/benches_dark_throughput.svg
@@ -0,0 +1,2228 @@
+
+
+
diff --git a/media/benches_light_animation.svg b/media/benches_light_animation.svg
new file mode 100644
index 0000000000000000000000000000000000000000..1d344696328ce4024da15f12f714ec488b02cb61
--- /dev/null
+++ b/media/benches_light_animation.svg
@@ -0,0 +1,42 @@
+
\ No newline at end of file
diff --git a/media/benches_light_latency.svg b/media/benches_light_latency.svg
new file mode 100644
index 0000000000000000000000000000000000000000..a72eab3d163ce8a87261c589cff9529b53247d31
--- /dev/null
+++ b/media/benches_light_latency.svg
@@ -0,0 +1,2104 @@
+
+
+
diff --git a/media/benches_light_throughput.svg b/media/benches_light_throughput.svg
new file mode 100644
index 0000000000000000000000000000000000000000..3d92b810d5b39ec7438bde650d00ab1479e6e31c
--- /dev/null
+++ b/media/benches_light_throughput.svg
@@ -0,0 +1,2228 @@
+
+
+