diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..0cd58331b2a989b68be4ec5676383437fca8687b --- /dev/null +++ b/.gitattributes @@ -0,0 +1,36 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text +*.so filter=lfs diff=lfs merge=lfs -text diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..8e2bf2d587bd92c011ce44896ead4fc9bbc68c7f --- /dev/null +++ b/README.md @@ -0,0 +1,22 @@ +--- +license: bsd-3-clause +tags: + - kernels +--- + +## causal-conv1d + +Causal [depthwise conv1d kernel](https://github.com/Dao-AILab/causal-conv1d/) by Tri Dao. + +Kernel source: https://github.com/huggingface/kernels-community/tree/main/causal-conv1d + +### Performance + + + + + + + + + diff --git a/benchmarks/benchmark.py b/benchmarks/benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..f1e59c3b6fb91250556ddf55a423b33aee0af2c5 --- /dev/null +++ b/benchmarks/benchmark.py @@ -0,0 +1,92 @@ +import torch +import torch.nn.functional as F + +from kernels.benchmark import Benchmark + + +class CausalConv1dBenchmark(Benchmark): + seed: int = 42 + + def setup(self): + batch_size, dim, seqlen, width = 2, 64, 128, 4 + self.x = torch.randn( + batch_size, dim, seqlen, device=self.device, dtype=torch.float16 + ) + self.weight = torch.randn(dim, width, device=self.device, dtype=torch.float32) + self.bias = torch.randn(dim, device=self.device, dtype=torch.float32) + self.out = torch.empty( + batch_size, dim, seqlen, device=self.device, dtype=torch.float16 + ) + self.dim = dim + self.width = width + self.seqlen = seqlen + + def benchmark_base(self): + self.out = self.kernel.causal_conv1d_fn(self.x, self.weight, self.bias) + + def verify_base(self) -> torch.Tensor: + x_fp32 = self.x.to(self.weight.dtype) + out = F.conv1d( + x_fp32, + self.weight.unsqueeze(1), + self.bias, + padding=self.width - 1, + groups=self.dim, + ) + return out[..., : self.seqlen].to(self.x.dtype) + + def setup_large(self): + batch_size, dim, seqlen, width = 8, 256, 512, 4 + self.x = torch.randn( + batch_size, dim, seqlen, device=self.device, dtype=torch.float16 + ) + self.weight = torch.randn(dim, width, device=self.device, dtype=torch.float32) + self.bias = torch.randn(dim, device=self.device, dtype=torch.float32) + self.out = torch.empty( + batch_size, dim, seqlen, device=self.device, dtype=torch.float16 + ) + self.dim = dim + self.width = width + self.seqlen = seqlen + + def benchmark_large(self): + self.out = self.kernel.causal_conv1d_fn(self.x, self.weight, self.bias) + + def verify_large(self) -> torch.Tensor: + x_fp32 = self.x.to(self.weight.dtype) + out = F.conv1d( + x_fp32, + self.weight.unsqueeze(1), + self.bias, + padding=self.width - 1, + groups=self.dim, + ) + return out[..., : self.seqlen].to(self.x.dtype) + + def setup_xlarge(self): + batch_size, dim, seqlen, width = 16, 512, 1024, 4 + self.x = torch.randn( + batch_size, dim, seqlen, device=self.device, dtype=torch.float16 + ) + self.weight = torch.randn(dim, width, device=self.device, dtype=torch.float32) + self.bias = torch.randn(dim, device=self.device, dtype=torch.float32) + self.out = torch.empty( + batch_size, dim, seqlen, device=self.device, dtype=torch.float16 + ) + self.dim = dim + self.width = width + self.seqlen = seqlen + + def benchmark_xlarge(self): + self.out = self.kernel.causal_conv1d_fn(self.x, self.weight, self.bias) + + def verify_xlarge(self) -> torch.Tensor: + x_fp32 = self.x.to(self.weight.dtype) + out = F.conv1d( + x_fp32, + self.weight.unsqueeze(1), + self.bias, + padding=self.width - 1, + groups=self.dim, + ) + return out[..., : self.seqlen].to(self.x.dtype) diff --git a/build/torch210-cxx11-cu126-aarch64-linux/__init__.py b/build/torch210-cxx11-cu126-aarch64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c --- /dev/null +++ b/build/torch210-cxx11-cu126-aarch64-linux/__init__.py @@ -0,0 +1,4 @@ +from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update +from .causal_conv1d_varlen import causal_conv1d_varlen_states + +__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch210-cxx11-cu126-aarch64-linux/_causal_conv1d_cuda_6b83b83.abi3.so b/build/torch210-cxx11-cu126-aarch64-linux/_causal_conv1d_cuda_6b83b83.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..bfba996f987a4a91f20a5908b13fd9d83ac93b7b --- /dev/null +++ b/build/torch210-cxx11-cu126-aarch64-linux/_causal_conv1d_cuda_6b83b83.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:83b8ab4db3d387552329f75f775db33b59380b18ba2af057504ad810fab09295 +size 80857232 diff --git a/build/torch210-cxx11-cu126-aarch64-linux/_ops.py b/build/torch210-cxx11-cu126-aarch64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6796899661ef6f73609047ca344503d13ca050bd --- /dev/null +++ b/build/torch210-cxx11-cu126-aarch64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _causal_conv1d_cuda_6b83b83 +ops = torch.ops._causal_conv1d_cuda_6b83b83 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_causal_conv1d_cuda_6b83b83::{op_name}" diff --git a/build/torch210-cxx11-cu126-aarch64-linux/causal_conv1d/__init__.py b/build/torch210-cxx11-cu126-aarch64-linux/causal_conv1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch210-cxx11-cu126-aarch64-linux/causal_conv1d/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch210-cxx11-cu126-aarch64-linux/causal_conv1d_interface.py b/build/torch210-cxx11-cu126-aarch64-linux/causal_conv1d_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377 --- /dev/null +++ b/build/torch210-cxx11-cu126-aarch64-linux/causal_conv1d_interface.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024, Tri Dao. + +import torch +import torch.nn.functional as F + +from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function + + +class CausalConv1dFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, + ): + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert ( + initial_states is None + ), "initial_states must be None if seq_idx is not None" + assert ( + not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and ( + initial_states.stride(2) != 1 and initial_states.stride(1) != 1 + ): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert ( + final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 + ) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty( + batch, width - 1, dim, device=x.device, dtype=x.dtype + ).transpose(1, 2) + else: + final_states_out = None + ctx.activation = activation in ["silu", "swish"] + out = causal_conv1d_fwd_function( + x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation + ) + ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) + ctx.return_final_states = return_final_states + ctx.return_dinitial_states = ( + initial_states is not None and initial_states.requires_grad + ) + return out if not return_final_states else (out, final_states_out) + + @staticmethod + def backward(ctx, dout, *args): + x, weight, bias, seq_idx, initial_states = ctx.saved_tensors + dfinal_states = args[0] if ctx.return_final_states else None + if dout.stride(2) != 1 and dout.stride(1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). + # Here we just pass in None and dx will be allocated in the C++ code. + dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( + x, + weight, + bias, + dout, + seq_idx, + initial_states, + dfinal_states, + None, + ctx.return_dinitial_states, + ctx.activation, + ) + return ( + dx, + dweight, + dbias if bias is not None else None, + None, + dinitial_states if initial_states is not None else None, + None, + None, + None, + ) + + +def causal_conv1d_fn( + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + return CausalConv1dFn.apply( + x, + weight, + bias, + seq_idx, + initial_states, + return_final_states, + final_states_out, + activation, + ) + + +def causal_conv1d_ref( + x, + weight, + bias=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in + ) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return out if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + out = causal_conv1d_update_function( + x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices + ) + if unsqueeze: + out = out.squeeze(-1) + return out + + +def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + assert conv_state.shape == (batch, dim, state_len) + assert weight.shape == (dim, width) + if cache_seqlens is None: + x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) + conv_state.copy_(x_new[:, :, -state_len:]) + else: + width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) + copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + conv_state.scatter_(2, copy_idx, x) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] + if unsqueeze: + out = out.squeeze(-1) + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch210-cxx11-cu126-aarch64-linux/causal_conv1d_varlen.py b/build/torch210-cxx11-cu126-aarch64-linux/causal_conv1d_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb --- /dev/null +++ b/build/torch210-cxx11-cu126-aarch64-linux/causal_conv1d_varlen.py @@ -0,0 +1,86 @@ +import torch +from torch import Tensor + +import triton +import triton.language as tl + + +@triton.jit +def _causal_conv1d_varlen_states( + X, + CU_SEQLENS, + STATES, + state_len, + dim, + stride_x_seqlen, stride_x_dim, + stride_states_batch, stride_states_seqlen, stride_states_dim, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr +): + batch_idx = tl.program_id(2) + STATES += batch_idx * stride_states_batch + end_idx = tl.load(CU_SEQLENS + batch_idx + 1) + start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) + rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, + mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), + other=0) + rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, + x, + mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) + + +def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + BLOCK_M = min(triton.next_power_of_2(state_len), 16) + BLOCK_N = min(triton.next_power_of_2(dim), 256) + grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) + with torch.cuda.device(x.device.index): + _causal_conv1d_varlen_states[grid]( + x, + cu_seqlens, + states, + state_len, + dim, + x.stride(0), x.stride(1), + states.stride(0), states.stride(2), states.stride(1), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + return states + + +def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + for i in range(batch): + end_idx = cu_seqlens[i + 1] + start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) + states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T + return states diff --git a/build/torch210-cxx11-cu126-aarch64-linux/cpp_functions.py b/build/torch210-cxx11-cu126-aarch64-linux/cpp_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b --- /dev/null +++ b/build/torch210-cxx11-cu126-aarch64-linux/cpp_functions.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, Tri Dao. + +import torch + +from ._ops import ops + +def causal_conv1d_fwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + final_states_out: torch.Tensor | None, + silu_activation: bool, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_fwd( + x=x, + weight=weight, + bias=bias, + seq_idx=seq_idx, + initial_states=initial_states, + out=out, + final_states_out=final_states_out, + silu_activation=silu_activation, + ) + return out + + +def causal_conv1d_bwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + dout: torch.Tensor, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + dfinal_states: torch.Tensor | None, + dx: torch.Tensor | None, + return_dinitial_states: torch.Tensor, + silu_activation: bool, +) -> tuple[torch.Tensor | None]: + batch_size, dim = x.size()[:2] + width = weight.size(-1) + + if dx is None: + dx = torch.empty_like(x) + dweight = torch.zeros_like(weight, dtype=torch.float32) + dbias = None + if bias is not None: + dbias = torch.zeros_like(bias, dtype=torch.float32) + dinitial_states = None + if return_dinitial_states: + dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) + + ops.causal_conv1d_bwd( + x=x, + weight=weight, + bias=bias, + dout=dout, + seq_idx=seq_idx, + initial_states=initial_states, + dfinal_states=dfinal_states, + dx=dx, + dweight=dweight, + dbias=dbias, + dinitial_states=dinitial_states, + silu_activation=silu_activation, + ) + + dweight = dweight.type_as(weight) + if dbias is not None: + dbias = dbias.type_as(bias) + return dx, dweight, dbias, dinitial_states + + +def causal_conv1d_update_function( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + silu_activation: bool, + cache_seqlens: torch.Tensor | None, + conv_state_indices: torch.Tensor | None, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_update( + x=x, + conv_state=conv_state, + weight=weight, + bias=bias, + out=out, + silu_activation=silu_activation, + cache_seqlens=cache_seqlens, + conv_state_indices=conv_state_indices, + ) + return out diff --git a/build/torch210-cxx11-cu126-aarch64-linux/metadata.json b/build/torch210-cxx11-cu126-aarch64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..0dacb99125f1112a811819ca1ffdde15c8c0faff --- /dev/null +++ b/build/torch210-cxx11-cu126-aarch64-linux/metadata.json @@ -0,0 +1,18 @@ +{ + "version": 1, + "license": "BSD-3-Clause", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "7.0", + "7.2", + "7.5", + "8.0", + "8.6", + "8.7", + "8.9", + "9.0+PTX" + ] + } +} diff --git a/build/torch210-cxx11-cu126-x86_64-linux/__init__.py b/build/torch210-cxx11-cu126-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c --- /dev/null +++ b/build/torch210-cxx11-cu126-x86_64-linux/__init__.py @@ -0,0 +1,4 @@ +from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update +from .causal_conv1d_varlen import causal_conv1d_varlen_states + +__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_causal_conv1d_cuda_6b83b83.abi3.so b/build/torch210-cxx11-cu126-x86_64-linux/_causal_conv1d_cuda_6b83b83.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..a1a1b4d30e2106991416a65da38c8dad70c8a7bb --- /dev/null +++ b/build/torch210-cxx11-cu126-x86_64-linux/_causal_conv1d_cuda_6b83b83.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:06c71255dcc14bbe4e00c85170d0f1dce0d6510e091a4237bc1c2e61368d47f2 +size 80694472 diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_ops.py b/build/torch210-cxx11-cu126-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6796899661ef6f73609047ca344503d13ca050bd --- /dev/null +++ b/build/torch210-cxx11-cu126-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _causal_conv1d_cuda_6b83b83 +ops = torch.ops._causal_conv1d_cuda_6b83b83 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_causal_conv1d_cuda_6b83b83::{op_name}" diff --git a/build/torch210-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py b/build/torch210-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch210-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch210-cxx11-cu126-x86_64-linux/causal_conv1d_interface.py b/build/torch210-cxx11-cu126-x86_64-linux/causal_conv1d_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377 --- /dev/null +++ b/build/torch210-cxx11-cu126-x86_64-linux/causal_conv1d_interface.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024, Tri Dao. + +import torch +import torch.nn.functional as F + +from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function + + +class CausalConv1dFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, + ): + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert ( + initial_states is None + ), "initial_states must be None if seq_idx is not None" + assert ( + not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and ( + initial_states.stride(2) != 1 and initial_states.stride(1) != 1 + ): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert ( + final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 + ) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty( + batch, width - 1, dim, device=x.device, dtype=x.dtype + ).transpose(1, 2) + else: + final_states_out = None + ctx.activation = activation in ["silu", "swish"] + out = causal_conv1d_fwd_function( + x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation + ) + ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) + ctx.return_final_states = return_final_states + ctx.return_dinitial_states = ( + initial_states is not None and initial_states.requires_grad + ) + return out if not return_final_states else (out, final_states_out) + + @staticmethod + def backward(ctx, dout, *args): + x, weight, bias, seq_idx, initial_states = ctx.saved_tensors + dfinal_states = args[0] if ctx.return_final_states else None + if dout.stride(2) != 1 and dout.stride(1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). + # Here we just pass in None and dx will be allocated in the C++ code. + dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( + x, + weight, + bias, + dout, + seq_idx, + initial_states, + dfinal_states, + None, + ctx.return_dinitial_states, + ctx.activation, + ) + return ( + dx, + dweight, + dbias if bias is not None else None, + None, + dinitial_states if initial_states is not None else None, + None, + None, + None, + ) + + +def causal_conv1d_fn( + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + return CausalConv1dFn.apply( + x, + weight, + bias, + seq_idx, + initial_states, + return_final_states, + final_states_out, + activation, + ) + + +def causal_conv1d_ref( + x, + weight, + bias=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in + ) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return out if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + out = causal_conv1d_update_function( + x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices + ) + if unsqueeze: + out = out.squeeze(-1) + return out + + +def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + assert conv_state.shape == (batch, dim, state_len) + assert weight.shape == (dim, width) + if cache_seqlens is None: + x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) + conv_state.copy_(x_new[:, :, -state_len:]) + else: + width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) + copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + conv_state.scatter_(2, copy_idx, x) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] + if unsqueeze: + out = out.squeeze(-1) + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch210-cxx11-cu126-x86_64-linux/causal_conv1d_varlen.py b/build/torch210-cxx11-cu126-x86_64-linux/causal_conv1d_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb --- /dev/null +++ b/build/torch210-cxx11-cu126-x86_64-linux/causal_conv1d_varlen.py @@ -0,0 +1,86 @@ +import torch +from torch import Tensor + +import triton +import triton.language as tl + + +@triton.jit +def _causal_conv1d_varlen_states( + X, + CU_SEQLENS, + STATES, + state_len, + dim, + stride_x_seqlen, stride_x_dim, + stride_states_batch, stride_states_seqlen, stride_states_dim, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr +): + batch_idx = tl.program_id(2) + STATES += batch_idx * stride_states_batch + end_idx = tl.load(CU_SEQLENS + batch_idx + 1) + start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) + rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, + mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), + other=0) + rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, + x, + mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) + + +def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + BLOCK_M = min(triton.next_power_of_2(state_len), 16) + BLOCK_N = min(triton.next_power_of_2(dim), 256) + grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) + with torch.cuda.device(x.device.index): + _causal_conv1d_varlen_states[grid]( + x, + cu_seqlens, + states, + state_len, + dim, + x.stride(0), x.stride(1), + states.stride(0), states.stride(2), states.stride(1), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + return states + + +def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + for i in range(batch): + end_idx = cu_seqlens[i + 1] + start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) + states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T + return states diff --git a/build/torch210-cxx11-cu126-x86_64-linux/cpp_functions.py b/build/torch210-cxx11-cu126-x86_64-linux/cpp_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b --- /dev/null +++ b/build/torch210-cxx11-cu126-x86_64-linux/cpp_functions.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, Tri Dao. + +import torch + +from ._ops import ops + +def causal_conv1d_fwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + final_states_out: torch.Tensor | None, + silu_activation: bool, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_fwd( + x=x, + weight=weight, + bias=bias, + seq_idx=seq_idx, + initial_states=initial_states, + out=out, + final_states_out=final_states_out, + silu_activation=silu_activation, + ) + return out + + +def causal_conv1d_bwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + dout: torch.Tensor, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + dfinal_states: torch.Tensor | None, + dx: torch.Tensor | None, + return_dinitial_states: torch.Tensor, + silu_activation: bool, +) -> tuple[torch.Tensor | None]: + batch_size, dim = x.size()[:2] + width = weight.size(-1) + + if dx is None: + dx = torch.empty_like(x) + dweight = torch.zeros_like(weight, dtype=torch.float32) + dbias = None + if bias is not None: + dbias = torch.zeros_like(bias, dtype=torch.float32) + dinitial_states = None + if return_dinitial_states: + dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) + + ops.causal_conv1d_bwd( + x=x, + weight=weight, + bias=bias, + dout=dout, + seq_idx=seq_idx, + initial_states=initial_states, + dfinal_states=dfinal_states, + dx=dx, + dweight=dweight, + dbias=dbias, + dinitial_states=dinitial_states, + silu_activation=silu_activation, + ) + + dweight = dweight.type_as(weight) + if dbias is not None: + dbias = dbias.type_as(bias) + return dx, dweight, dbias, dinitial_states + + +def causal_conv1d_update_function( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + silu_activation: bool, + cache_seqlens: torch.Tensor | None, + conv_state_indices: torch.Tensor | None, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_update( + x=x, + conv_state=conv_state, + weight=weight, + bias=bias, + out=out, + silu_activation=silu_activation, + cache_seqlens=cache_seqlens, + conv_state_indices=conv_state_indices, + ) + return out diff --git a/build/torch210-cxx11-cu126-x86_64-linux/metadata.json b/build/torch210-cxx11-cu126-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..0dacb99125f1112a811819ca1ffdde15c8c0faff --- /dev/null +++ b/build/torch210-cxx11-cu126-x86_64-linux/metadata.json @@ -0,0 +1,18 @@ +{ + "version": 1, + "license": "BSD-3-Clause", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "7.0", + "7.2", + "7.5", + "8.0", + "8.6", + "8.7", + "8.9", + "9.0+PTX" + ] + } +} diff --git a/build/torch210-cxx11-cu128-aarch64-linux/__init__.py b/build/torch210-cxx11-cu128-aarch64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/__init__.py @@ -0,0 +1,4 @@ +from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update +from .causal_conv1d_varlen import causal_conv1d_varlen_states + +__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch210-cxx11-cu128-aarch64-linux/_causal_conv1d_cuda_6b83b83.abi3.so b/build/torch210-cxx11-cu128-aarch64-linux/_causal_conv1d_cuda_6b83b83.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..9adbbb9e769872f8226f3bb3eae46e537e353869 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/_causal_conv1d_cuda_6b83b83.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c9bcd794221ea9d6cb4f2c3ec409e75f83cd955823a6406f693c50d77d1c4b28 +size 107312656 diff --git a/build/torch210-cxx11-cu128-aarch64-linux/_ops.py b/build/torch210-cxx11-cu128-aarch64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6796899661ef6f73609047ca344503d13ca050bd --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _causal_conv1d_cuda_6b83b83 +ops = torch.ops._causal_conv1d_cuda_6b83b83 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_causal_conv1d_cuda_6b83b83::{op_name}" diff --git a/build/torch210-cxx11-cu128-aarch64-linux/causal_conv1d/__init__.py b/build/torch210-cxx11-cu128-aarch64-linux/causal_conv1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/causal_conv1d/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch210-cxx11-cu128-aarch64-linux/causal_conv1d_interface.py b/build/torch210-cxx11-cu128-aarch64-linux/causal_conv1d_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/causal_conv1d_interface.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024, Tri Dao. + +import torch +import torch.nn.functional as F + +from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function + + +class CausalConv1dFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, + ): + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert ( + initial_states is None + ), "initial_states must be None if seq_idx is not None" + assert ( + not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and ( + initial_states.stride(2) != 1 and initial_states.stride(1) != 1 + ): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert ( + final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 + ) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty( + batch, width - 1, dim, device=x.device, dtype=x.dtype + ).transpose(1, 2) + else: + final_states_out = None + ctx.activation = activation in ["silu", "swish"] + out = causal_conv1d_fwd_function( + x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation + ) + ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) + ctx.return_final_states = return_final_states + ctx.return_dinitial_states = ( + initial_states is not None and initial_states.requires_grad + ) + return out if not return_final_states else (out, final_states_out) + + @staticmethod + def backward(ctx, dout, *args): + x, weight, bias, seq_idx, initial_states = ctx.saved_tensors + dfinal_states = args[0] if ctx.return_final_states else None + if dout.stride(2) != 1 and dout.stride(1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). + # Here we just pass in None and dx will be allocated in the C++ code. + dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( + x, + weight, + bias, + dout, + seq_idx, + initial_states, + dfinal_states, + None, + ctx.return_dinitial_states, + ctx.activation, + ) + return ( + dx, + dweight, + dbias if bias is not None else None, + None, + dinitial_states if initial_states is not None else None, + None, + None, + None, + ) + + +def causal_conv1d_fn( + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + return CausalConv1dFn.apply( + x, + weight, + bias, + seq_idx, + initial_states, + return_final_states, + final_states_out, + activation, + ) + + +def causal_conv1d_ref( + x, + weight, + bias=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in + ) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return out if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + out = causal_conv1d_update_function( + x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices + ) + if unsqueeze: + out = out.squeeze(-1) + return out + + +def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + assert conv_state.shape == (batch, dim, state_len) + assert weight.shape == (dim, width) + if cache_seqlens is None: + x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) + conv_state.copy_(x_new[:, :, -state_len:]) + else: + width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) + copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + conv_state.scatter_(2, copy_idx, x) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] + if unsqueeze: + out = out.squeeze(-1) + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch210-cxx11-cu128-aarch64-linux/causal_conv1d_varlen.py b/build/torch210-cxx11-cu128-aarch64-linux/causal_conv1d_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/causal_conv1d_varlen.py @@ -0,0 +1,86 @@ +import torch +from torch import Tensor + +import triton +import triton.language as tl + + +@triton.jit +def _causal_conv1d_varlen_states( + X, + CU_SEQLENS, + STATES, + state_len, + dim, + stride_x_seqlen, stride_x_dim, + stride_states_batch, stride_states_seqlen, stride_states_dim, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr +): + batch_idx = tl.program_id(2) + STATES += batch_idx * stride_states_batch + end_idx = tl.load(CU_SEQLENS + batch_idx + 1) + start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) + rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, + mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), + other=0) + rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, + x, + mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) + + +def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + BLOCK_M = min(triton.next_power_of_2(state_len), 16) + BLOCK_N = min(triton.next_power_of_2(dim), 256) + grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) + with torch.cuda.device(x.device.index): + _causal_conv1d_varlen_states[grid]( + x, + cu_seqlens, + states, + state_len, + dim, + x.stride(0), x.stride(1), + states.stride(0), states.stride(2), states.stride(1), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + return states + + +def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + for i in range(batch): + end_idx = cu_seqlens[i + 1] + start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) + states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T + return states diff --git a/build/torch210-cxx11-cu128-aarch64-linux/cpp_functions.py b/build/torch210-cxx11-cu128-aarch64-linux/cpp_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/cpp_functions.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, Tri Dao. + +import torch + +from ._ops import ops + +def causal_conv1d_fwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + final_states_out: torch.Tensor | None, + silu_activation: bool, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_fwd( + x=x, + weight=weight, + bias=bias, + seq_idx=seq_idx, + initial_states=initial_states, + out=out, + final_states_out=final_states_out, + silu_activation=silu_activation, + ) + return out + + +def causal_conv1d_bwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + dout: torch.Tensor, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + dfinal_states: torch.Tensor | None, + dx: torch.Tensor | None, + return_dinitial_states: torch.Tensor, + silu_activation: bool, +) -> tuple[torch.Tensor | None]: + batch_size, dim = x.size()[:2] + width = weight.size(-1) + + if dx is None: + dx = torch.empty_like(x) + dweight = torch.zeros_like(weight, dtype=torch.float32) + dbias = None + if bias is not None: + dbias = torch.zeros_like(bias, dtype=torch.float32) + dinitial_states = None + if return_dinitial_states: + dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) + + ops.causal_conv1d_bwd( + x=x, + weight=weight, + bias=bias, + dout=dout, + seq_idx=seq_idx, + initial_states=initial_states, + dfinal_states=dfinal_states, + dx=dx, + dweight=dweight, + dbias=dbias, + dinitial_states=dinitial_states, + silu_activation=silu_activation, + ) + + dweight = dweight.type_as(weight) + if dbias is not None: + dbias = dbias.type_as(bias) + return dx, dweight, dbias, dinitial_states + + +def causal_conv1d_update_function( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + silu_activation: bool, + cache_seqlens: torch.Tensor | None, + conv_state_indices: torch.Tensor | None, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_update( + x=x, + conv_state=conv_state, + weight=weight, + bias=bias, + out=out, + silu_activation=silu_activation, + cache_seqlens=cache_seqlens, + conv_state_indices=conv_state_indices, + ) + return out diff --git a/build/torch210-cxx11-cu128-aarch64-linux/metadata.json b/build/torch210-cxx11-cu128-aarch64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..a794c92436c3827ae79b48d55f7ea964afd50f52 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/metadata.json @@ -0,0 +1,21 @@ +{ + "version": 1, + "license": "BSD-3-Clause", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0", + "10.1", + "12.0+PTX", + "7.0", + "7.2", + "7.5", + "8.0", + "8.6", + "8.7", + "8.9", + "9.0" + ] + } +} diff --git a/build/torch210-cxx11-cu128-x86_64-linux/__init__.py b/build/torch210-cxx11-cu128-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/__init__.py @@ -0,0 +1,4 @@ +from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update +from .causal_conv1d_varlen import causal_conv1d_varlen_states + +__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_causal_conv1d_cuda_6b83b83.abi3.so b/build/torch210-cxx11-cu128-x86_64-linux/_causal_conv1d_cuda_6b83b83.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..58e4cc8a7164d312d42fb7ab5a4cc5fec69e3a4c --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/_causal_conv1d_cuda_6b83b83.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ed02f049828da6a24af2b061c4f2a9f440b66ae258411071fe4575c0f577d5d4 +size 107169840 diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_ops.py b/build/torch210-cxx11-cu128-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6796899661ef6f73609047ca344503d13ca050bd --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _causal_conv1d_cuda_6b83b83 +ops = torch.ops._causal_conv1d_cuda_6b83b83 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_causal_conv1d_cuda_6b83b83::{op_name}" diff --git a/build/torch210-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py b/build/torch210-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch210-cxx11-cu128-x86_64-linux/causal_conv1d_interface.py b/build/torch210-cxx11-cu128-x86_64-linux/causal_conv1d_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/causal_conv1d_interface.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024, Tri Dao. + +import torch +import torch.nn.functional as F + +from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function + + +class CausalConv1dFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, + ): + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert ( + initial_states is None + ), "initial_states must be None if seq_idx is not None" + assert ( + not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and ( + initial_states.stride(2) != 1 and initial_states.stride(1) != 1 + ): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert ( + final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 + ) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty( + batch, width - 1, dim, device=x.device, dtype=x.dtype + ).transpose(1, 2) + else: + final_states_out = None + ctx.activation = activation in ["silu", "swish"] + out = causal_conv1d_fwd_function( + x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation + ) + ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) + ctx.return_final_states = return_final_states + ctx.return_dinitial_states = ( + initial_states is not None and initial_states.requires_grad + ) + return out if not return_final_states else (out, final_states_out) + + @staticmethod + def backward(ctx, dout, *args): + x, weight, bias, seq_idx, initial_states = ctx.saved_tensors + dfinal_states = args[0] if ctx.return_final_states else None + if dout.stride(2) != 1 and dout.stride(1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). + # Here we just pass in None and dx will be allocated in the C++ code. + dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( + x, + weight, + bias, + dout, + seq_idx, + initial_states, + dfinal_states, + None, + ctx.return_dinitial_states, + ctx.activation, + ) + return ( + dx, + dweight, + dbias if bias is not None else None, + None, + dinitial_states if initial_states is not None else None, + None, + None, + None, + ) + + +def causal_conv1d_fn( + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + return CausalConv1dFn.apply( + x, + weight, + bias, + seq_idx, + initial_states, + return_final_states, + final_states_out, + activation, + ) + + +def causal_conv1d_ref( + x, + weight, + bias=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in + ) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return out if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + out = causal_conv1d_update_function( + x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices + ) + if unsqueeze: + out = out.squeeze(-1) + return out + + +def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + assert conv_state.shape == (batch, dim, state_len) + assert weight.shape == (dim, width) + if cache_seqlens is None: + x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) + conv_state.copy_(x_new[:, :, -state_len:]) + else: + width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) + copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + conv_state.scatter_(2, copy_idx, x) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] + if unsqueeze: + out = out.squeeze(-1) + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch210-cxx11-cu128-x86_64-linux/causal_conv1d_varlen.py b/build/torch210-cxx11-cu128-x86_64-linux/causal_conv1d_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/causal_conv1d_varlen.py @@ -0,0 +1,86 @@ +import torch +from torch import Tensor + +import triton +import triton.language as tl + + +@triton.jit +def _causal_conv1d_varlen_states( + X, + CU_SEQLENS, + STATES, + state_len, + dim, + stride_x_seqlen, stride_x_dim, + stride_states_batch, stride_states_seqlen, stride_states_dim, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr +): + batch_idx = tl.program_id(2) + STATES += batch_idx * stride_states_batch + end_idx = tl.load(CU_SEQLENS + batch_idx + 1) + start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) + rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, + mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), + other=0) + rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, + x, + mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) + + +def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + BLOCK_M = min(triton.next_power_of_2(state_len), 16) + BLOCK_N = min(triton.next_power_of_2(dim), 256) + grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) + with torch.cuda.device(x.device.index): + _causal_conv1d_varlen_states[grid]( + x, + cu_seqlens, + states, + state_len, + dim, + x.stride(0), x.stride(1), + states.stride(0), states.stride(2), states.stride(1), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + return states + + +def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + for i in range(batch): + end_idx = cu_seqlens[i + 1] + start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) + states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T + return states diff --git a/build/torch210-cxx11-cu128-x86_64-linux/cpp_functions.py b/build/torch210-cxx11-cu128-x86_64-linux/cpp_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/cpp_functions.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, Tri Dao. + +import torch + +from ._ops import ops + +def causal_conv1d_fwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + final_states_out: torch.Tensor | None, + silu_activation: bool, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_fwd( + x=x, + weight=weight, + bias=bias, + seq_idx=seq_idx, + initial_states=initial_states, + out=out, + final_states_out=final_states_out, + silu_activation=silu_activation, + ) + return out + + +def causal_conv1d_bwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + dout: torch.Tensor, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + dfinal_states: torch.Tensor | None, + dx: torch.Tensor | None, + return_dinitial_states: torch.Tensor, + silu_activation: bool, +) -> tuple[torch.Tensor | None]: + batch_size, dim = x.size()[:2] + width = weight.size(-1) + + if dx is None: + dx = torch.empty_like(x) + dweight = torch.zeros_like(weight, dtype=torch.float32) + dbias = None + if bias is not None: + dbias = torch.zeros_like(bias, dtype=torch.float32) + dinitial_states = None + if return_dinitial_states: + dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) + + ops.causal_conv1d_bwd( + x=x, + weight=weight, + bias=bias, + dout=dout, + seq_idx=seq_idx, + initial_states=initial_states, + dfinal_states=dfinal_states, + dx=dx, + dweight=dweight, + dbias=dbias, + dinitial_states=dinitial_states, + silu_activation=silu_activation, + ) + + dweight = dweight.type_as(weight) + if dbias is not None: + dbias = dbias.type_as(bias) + return dx, dweight, dbias, dinitial_states + + +def causal_conv1d_update_function( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + silu_activation: bool, + cache_seqlens: torch.Tensor | None, + conv_state_indices: torch.Tensor | None, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_update( + x=x, + conv_state=conv_state, + weight=weight, + bias=bias, + out=out, + silu_activation=silu_activation, + cache_seqlens=cache_seqlens, + conv_state_indices=conv_state_indices, + ) + return out diff --git a/build/torch210-cxx11-cu128-x86_64-linux/metadata.json b/build/torch210-cxx11-cu128-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..a794c92436c3827ae79b48d55f7ea964afd50f52 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/metadata.json @@ -0,0 +1,21 @@ +{ + "version": 1, + "license": "BSD-3-Clause", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0", + "10.1", + "12.0+PTX", + "7.0", + "7.2", + "7.5", + "8.0", + "8.6", + "8.7", + "8.9", + "9.0" + ] + } +} diff --git a/build/torch210-cxx11-cu130-aarch64-linux/__init__.py b/build/torch210-cxx11-cu130-aarch64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/__init__.py @@ -0,0 +1,4 @@ +from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update +from .causal_conv1d_varlen import causal_conv1d_varlen_states + +__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch210-cxx11-cu130-aarch64-linux/_causal_conv1d_cuda_6b83b83.abi3.so b/build/torch210-cxx11-cu130-aarch64-linux/_causal_conv1d_cuda_6b83b83.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..89f0bf651fb60fe51bf08fded9a291386aae7015 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/_causal_conv1d_cuda_6b83b83.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8b8c42b85e433d28b00dd22919b0abe1152b68b2701cf5de31a46d3d1fabd128 +size 64755512 diff --git a/build/torch210-cxx11-cu130-aarch64-linux/_ops.py b/build/torch210-cxx11-cu130-aarch64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6796899661ef6f73609047ca344503d13ca050bd --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _causal_conv1d_cuda_6b83b83 +ops = torch.ops._causal_conv1d_cuda_6b83b83 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_causal_conv1d_cuda_6b83b83::{op_name}" diff --git a/build/torch210-cxx11-cu130-aarch64-linux/causal_conv1d/__init__.py b/build/torch210-cxx11-cu130-aarch64-linux/causal_conv1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/causal_conv1d/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch210-cxx11-cu130-aarch64-linux/causal_conv1d_interface.py b/build/torch210-cxx11-cu130-aarch64-linux/causal_conv1d_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/causal_conv1d_interface.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024, Tri Dao. + +import torch +import torch.nn.functional as F + +from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function + + +class CausalConv1dFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, + ): + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert ( + initial_states is None + ), "initial_states must be None if seq_idx is not None" + assert ( + not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and ( + initial_states.stride(2) != 1 and initial_states.stride(1) != 1 + ): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert ( + final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 + ) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty( + batch, width - 1, dim, device=x.device, dtype=x.dtype + ).transpose(1, 2) + else: + final_states_out = None + ctx.activation = activation in ["silu", "swish"] + out = causal_conv1d_fwd_function( + x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation + ) + ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) + ctx.return_final_states = return_final_states + ctx.return_dinitial_states = ( + initial_states is not None and initial_states.requires_grad + ) + return out if not return_final_states else (out, final_states_out) + + @staticmethod + def backward(ctx, dout, *args): + x, weight, bias, seq_idx, initial_states = ctx.saved_tensors + dfinal_states = args[0] if ctx.return_final_states else None + if dout.stride(2) != 1 and dout.stride(1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). + # Here we just pass in None and dx will be allocated in the C++ code. + dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( + x, + weight, + bias, + dout, + seq_idx, + initial_states, + dfinal_states, + None, + ctx.return_dinitial_states, + ctx.activation, + ) + return ( + dx, + dweight, + dbias if bias is not None else None, + None, + dinitial_states if initial_states is not None else None, + None, + None, + None, + ) + + +def causal_conv1d_fn( + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + return CausalConv1dFn.apply( + x, + weight, + bias, + seq_idx, + initial_states, + return_final_states, + final_states_out, + activation, + ) + + +def causal_conv1d_ref( + x, + weight, + bias=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in + ) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return out if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + out = causal_conv1d_update_function( + x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices + ) + if unsqueeze: + out = out.squeeze(-1) + return out + + +def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + assert conv_state.shape == (batch, dim, state_len) + assert weight.shape == (dim, width) + if cache_seqlens is None: + x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) + conv_state.copy_(x_new[:, :, -state_len:]) + else: + width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) + copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + conv_state.scatter_(2, copy_idx, x) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] + if unsqueeze: + out = out.squeeze(-1) + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch210-cxx11-cu130-aarch64-linux/causal_conv1d_varlen.py b/build/torch210-cxx11-cu130-aarch64-linux/causal_conv1d_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/causal_conv1d_varlen.py @@ -0,0 +1,86 @@ +import torch +from torch import Tensor + +import triton +import triton.language as tl + + +@triton.jit +def _causal_conv1d_varlen_states( + X, + CU_SEQLENS, + STATES, + state_len, + dim, + stride_x_seqlen, stride_x_dim, + stride_states_batch, stride_states_seqlen, stride_states_dim, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr +): + batch_idx = tl.program_id(2) + STATES += batch_idx * stride_states_batch + end_idx = tl.load(CU_SEQLENS + batch_idx + 1) + start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) + rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, + mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), + other=0) + rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, + x, + mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) + + +def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + BLOCK_M = min(triton.next_power_of_2(state_len), 16) + BLOCK_N = min(triton.next_power_of_2(dim), 256) + grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) + with torch.cuda.device(x.device.index): + _causal_conv1d_varlen_states[grid]( + x, + cu_seqlens, + states, + state_len, + dim, + x.stride(0), x.stride(1), + states.stride(0), states.stride(2), states.stride(1), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + return states + + +def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + for i in range(batch): + end_idx = cu_seqlens[i + 1] + start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) + states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T + return states diff --git a/build/torch210-cxx11-cu130-aarch64-linux/cpp_functions.py b/build/torch210-cxx11-cu130-aarch64-linux/cpp_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/cpp_functions.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, Tri Dao. + +import torch + +from ._ops import ops + +def causal_conv1d_fwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + final_states_out: torch.Tensor | None, + silu_activation: bool, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_fwd( + x=x, + weight=weight, + bias=bias, + seq_idx=seq_idx, + initial_states=initial_states, + out=out, + final_states_out=final_states_out, + silu_activation=silu_activation, + ) + return out + + +def causal_conv1d_bwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + dout: torch.Tensor, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + dfinal_states: torch.Tensor | None, + dx: torch.Tensor | None, + return_dinitial_states: torch.Tensor, + silu_activation: bool, +) -> tuple[torch.Tensor | None]: + batch_size, dim = x.size()[:2] + width = weight.size(-1) + + if dx is None: + dx = torch.empty_like(x) + dweight = torch.zeros_like(weight, dtype=torch.float32) + dbias = None + if bias is not None: + dbias = torch.zeros_like(bias, dtype=torch.float32) + dinitial_states = None + if return_dinitial_states: + dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) + + ops.causal_conv1d_bwd( + x=x, + weight=weight, + bias=bias, + dout=dout, + seq_idx=seq_idx, + initial_states=initial_states, + dfinal_states=dfinal_states, + dx=dx, + dweight=dweight, + dbias=dbias, + dinitial_states=dinitial_states, + silu_activation=silu_activation, + ) + + dweight = dweight.type_as(weight) + if dbias is not None: + dbias = dbias.type_as(bias) + return dx, dweight, dbias, dinitial_states + + +def causal_conv1d_update_function( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + silu_activation: bool, + cache_seqlens: torch.Tensor | None, + conv_state_indices: torch.Tensor | None, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_update( + x=x, + conv_state=conv_state, + weight=weight, + bias=bias, + out=out, + silu_activation=silu_activation, + cache_seqlens=cache_seqlens, + conv_state_indices=conv_state_indices, + ) + return out diff --git a/build/torch210-cxx11-cu130-aarch64-linux/metadata.json b/build/torch210-cxx11-cu130-aarch64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..eff725542128e103dfb5df382d74940efff77214 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/metadata.json @@ -0,0 +1,19 @@ +{ + "version": 1, + "license": "BSD-3-Clause", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0", + "11.0", + "12.0+PTX", + "7.5", + "8.0", + "8.6", + "8.7", + "8.9", + "9.0" + ] + } +} diff --git a/build/torch210-cxx11-cu130-x86_64-linux/__init__.py b/build/torch210-cxx11-cu130-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/__init__.py @@ -0,0 +1,4 @@ +from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update +from .causal_conv1d_varlen import causal_conv1d_varlen_states + +__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_causal_conv1d_cuda_6b83b83.abi3.so b/build/torch210-cxx11-cu130-x86_64-linux/_causal_conv1d_cuda_6b83b83.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..3cd522d7fff5ef2f0a3d3661c17f440e1e4031ed --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/_causal_conv1d_cuda_6b83b83.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5e5ec24f997ea256acf15d45d41acd47e841d26f50dab276b6f8f3600247501e +size 64618472 diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_ops.py b/build/torch210-cxx11-cu130-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6796899661ef6f73609047ca344503d13ca050bd --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _causal_conv1d_cuda_6b83b83 +ops = torch.ops._causal_conv1d_cuda_6b83b83 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_causal_conv1d_cuda_6b83b83::{op_name}" diff --git a/build/torch210-cxx11-cu130-x86_64-linux/causal_conv1d/__init__.py b/build/torch210-cxx11-cu130-x86_64-linux/causal_conv1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/causal_conv1d/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch210-cxx11-cu130-x86_64-linux/causal_conv1d_interface.py b/build/torch210-cxx11-cu130-x86_64-linux/causal_conv1d_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377 --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/causal_conv1d_interface.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024, Tri Dao. + +import torch +import torch.nn.functional as F + +from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function + + +class CausalConv1dFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, + ): + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert ( + initial_states is None + ), "initial_states must be None if seq_idx is not None" + assert ( + not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and ( + initial_states.stride(2) != 1 and initial_states.stride(1) != 1 + ): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert ( + final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 + ) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty( + batch, width - 1, dim, device=x.device, dtype=x.dtype + ).transpose(1, 2) + else: + final_states_out = None + ctx.activation = activation in ["silu", "swish"] + out = causal_conv1d_fwd_function( + x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation + ) + ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) + ctx.return_final_states = return_final_states + ctx.return_dinitial_states = ( + initial_states is not None and initial_states.requires_grad + ) + return out if not return_final_states else (out, final_states_out) + + @staticmethod + def backward(ctx, dout, *args): + x, weight, bias, seq_idx, initial_states = ctx.saved_tensors + dfinal_states = args[0] if ctx.return_final_states else None + if dout.stride(2) != 1 and dout.stride(1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). + # Here we just pass in None and dx will be allocated in the C++ code. + dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( + x, + weight, + bias, + dout, + seq_idx, + initial_states, + dfinal_states, + None, + ctx.return_dinitial_states, + ctx.activation, + ) + return ( + dx, + dweight, + dbias if bias is not None else None, + None, + dinitial_states if initial_states is not None else None, + None, + None, + None, + ) + + +def causal_conv1d_fn( + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + return CausalConv1dFn.apply( + x, + weight, + bias, + seq_idx, + initial_states, + return_final_states, + final_states_out, + activation, + ) + + +def causal_conv1d_ref( + x, + weight, + bias=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in + ) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return out if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + out = causal_conv1d_update_function( + x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices + ) + if unsqueeze: + out = out.squeeze(-1) + return out + + +def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + assert conv_state.shape == (batch, dim, state_len) + assert weight.shape == (dim, width) + if cache_seqlens is None: + x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) + conv_state.copy_(x_new[:, :, -state_len:]) + else: + width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) + copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + conv_state.scatter_(2, copy_idx, x) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] + if unsqueeze: + out = out.squeeze(-1) + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch210-cxx11-cu130-x86_64-linux/causal_conv1d_varlen.py b/build/torch210-cxx11-cu130-x86_64-linux/causal_conv1d_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/causal_conv1d_varlen.py @@ -0,0 +1,86 @@ +import torch +from torch import Tensor + +import triton +import triton.language as tl + + +@triton.jit +def _causal_conv1d_varlen_states( + X, + CU_SEQLENS, + STATES, + state_len, + dim, + stride_x_seqlen, stride_x_dim, + stride_states_batch, stride_states_seqlen, stride_states_dim, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr +): + batch_idx = tl.program_id(2) + STATES += batch_idx * stride_states_batch + end_idx = tl.load(CU_SEQLENS + batch_idx + 1) + start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) + rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, + mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), + other=0) + rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, + x, + mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) + + +def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + BLOCK_M = min(triton.next_power_of_2(state_len), 16) + BLOCK_N = min(triton.next_power_of_2(dim), 256) + grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) + with torch.cuda.device(x.device.index): + _causal_conv1d_varlen_states[grid]( + x, + cu_seqlens, + states, + state_len, + dim, + x.stride(0), x.stride(1), + states.stride(0), states.stride(2), states.stride(1), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + return states + + +def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + for i in range(batch): + end_idx = cu_seqlens[i + 1] + start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) + states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T + return states diff --git a/build/torch210-cxx11-cu130-x86_64-linux/cpp_functions.py b/build/torch210-cxx11-cu130-x86_64-linux/cpp_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/cpp_functions.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, Tri Dao. + +import torch + +from ._ops import ops + +def causal_conv1d_fwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + final_states_out: torch.Tensor | None, + silu_activation: bool, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_fwd( + x=x, + weight=weight, + bias=bias, + seq_idx=seq_idx, + initial_states=initial_states, + out=out, + final_states_out=final_states_out, + silu_activation=silu_activation, + ) + return out + + +def causal_conv1d_bwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + dout: torch.Tensor, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + dfinal_states: torch.Tensor | None, + dx: torch.Tensor | None, + return_dinitial_states: torch.Tensor, + silu_activation: bool, +) -> tuple[torch.Tensor | None]: + batch_size, dim = x.size()[:2] + width = weight.size(-1) + + if dx is None: + dx = torch.empty_like(x) + dweight = torch.zeros_like(weight, dtype=torch.float32) + dbias = None + if bias is not None: + dbias = torch.zeros_like(bias, dtype=torch.float32) + dinitial_states = None + if return_dinitial_states: + dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) + + ops.causal_conv1d_bwd( + x=x, + weight=weight, + bias=bias, + dout=dout, + seq_idx=seq_idx, + initial_states=initial_states, + dfinal_states=dfinal_states, + dx=dx, + dweight=dweight, + dbias=dbias, + dinitial_states=dinitial_states, + silu_activation=silu_activation, + ) + + dweight = dweight.type_as(weight) + if dbias is not None: + dbias = dbias.type_as(bias) + return dx, dweight, dbias, dinitial_states + + +def causal_conv1d_update_function( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + silu_activation: bool, + cache_seqlens: torch.Tensor | None, + conv_state_indices: torch.Tensor | None, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_update( + x=x, + conv_state=conv_state, + weight=weight, + bias=bias, + out=out, + silu_activation=silu_activation, + cache_seqlens=cache_seqlens, + conv_state_indices=conv_state_indices, + ) + return out diff --git a/build/torch210-cxx11-cu130-x86_64-linux/metadata.json b/build/torch210-cxx11-cu130-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..eff725542128e103dfb5df382d74940efff77214 --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/metadata.json @@ -0,0 +1,19 @@ +{ + "version": 1, + "license": "BSD-3-Clause", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0", + "11.0", + "12.0+PTX", + "7.5", + "8.0", + "8.6", + "8.7", + "8.9", + "9.0" + ] + } +} diff --git a/build/torch211-cxx11-cu126-aarch64-linux/__init__.py b/build/torch211-cxx11-cu126-aarch64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c --- /dev/null +++ b/build/torch211-cxx11-cu126-aarch64-linux/__init__.py @@ -0,0 +1,4 @@ +from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update +from .causal_conv1d_varlen import causal_conv1d_varlen_states + +__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch211-cxx11-cu126-aarch64-linux/_causal_conv1d_cuda_6b83b83.abi3.so b/build/torch211-cxx11-cu126-aarch64-linux/_causal_conv1d_cuda_6b83b83.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..7d6a25eecfefa85762ba0a837eb7c9e3cc82fc31 --- /dev/null +++ b/build/torch211-cxx11-cu126-aarch64-linux/_causal_conv1d_cuda_6b83b83.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:54d057ef1f5e12e7715f8dfd2879190e4ae82c170ae0d412bc1696be8031d1ef +size 80857352 diff --git a/build/torch211-cxx11-cu126-aarch64-linux/_ops.py b/build/torch211-cxx11-cu126-aarch64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6796899661ef6f73609047ca344503d13ca050bd --- /dev/null +++ b/build/torch211-cxx11-cu126-aarch64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _causal_conv1d_cuda_6b83b83 +ops = torch.ops._causal_conv1d_cuda_6b83b83 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_causal_conv1d_cuda_6b83b83::{op_name}" diff --git a/build/torch211-cxx11-cu126-aarch64-linux/causal_conv1d/__init__.py b/build/torch211-cxx11-cu126-aarch64-linux/causal_conv1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch211-cxx11-cu126-aarch64-linux/causal_conv1d/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch211-cxx11-cu126-aarch64-linux/causal_conv1d_interface.py b/build/torch211-cxx11-cu126-aarch64-linux/causal_conv1d_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377 --- /dev/null +++ b/build/torch211-cxx11-cu126-aarch64-linux/causal_conv1d_interface.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024, Tri Dao. + +import torch +import torch.nn.functional as F + +from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function + + +class CausalConv1dFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, + ): + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert ( + initial_states is None + ), "initial_states must be None if seq_idx is not None" + assert ( + not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and ( + initial_states.stride(2) != 1 and initial_states.stride(1) != 1 + ): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert ( + final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 + ) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty( + batch, width - 1, dim, device=x.device, dtype=x.dtype + ).transpose(1, 2) + else: + final_states_out = None + ctx.activation = activation in ["silu", "swish"] + out = causal_conv1d_fwd_function( + x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation + ) + ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) + ctx.return_final_states = return_final_states + ctx.return_dinitial_states = ( + initial_states is not None and initial_states.requires_grad + ) + return out if not return_final_states else (out, final_states_out) + + @staticmethod + def backward(ctx, dout, *args): + x, weight, bias, seq_idx, initial_states = ctx.saved_tensors + dfinal_states = args[0] if ctx.return_final_states else None + if dout.stride(2) != 1 and dout.stride(1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). + # Here we just pass in None and dx will be allocated in the C++ code. + dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( + x, + weight, + bias, + dout, + seq_idx, + initial_states, + dfinal_states, + None, + ctx.return_dinitial_states, + ctx.activation, + ) + return ( + dx, + dweight, + dbias if bias is not None else None, + None, + dinitial_states if initial_states is not None else None, + None, + None, + None, + ) + + +def causal_conv1d_fn( + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + return CausalConv1dFn.apply( + x, + weight, + bias, + seq_idx, + initial_states, + return_final_states, + final_states_out, + activation, + ) + + +def causal_conv1d_ref( + x, + weight, + bias=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in + ) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return out if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + out = causal_conv1d_update_function( + x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices + ) + if unsqueeze: + out = out.squeeze(-1) + return out + + +def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + assert conv_state.shape == (batch, dim, state_len) + assert weight.shape == (dim, width) + if cache_seqlens is None: + x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) + conv_state.copy_(x_new[:, :, -state_len:]) + else: + width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) + copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + conv_state.scatter_(2, copy_idx, x) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] + if unsqueeze: + out = out.squeeze(-1) + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch211-cxx11-cu126-aarch64-linux/causal_conv1d_varlen.py b/build/torch211-cxx11-cu126-aarch64-linux/causal_conv1d_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb --- /dev/null +++ b/build/torch211-cxx11-cu126-aarch64-linux/causal_conv1d_varlen.py @@ -0,0 +1,86 @@ +import torch +from torch import Tensor + +import triton +import triton.language as tl + + +@triton.jit +def _causal_conv1d_varlen_states( + X, + CU_SEQLENS, + STATES, + state_len, + dim, + stride_x_seqlen, stride_x_dim, + stride_states_batch, stride_states_seqlen, stride_states_dim, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr +): + batch_idx = tl.program_id(2) + STATES += batch_idx * stride_states_batch + end_idx = tl.load(CU_SEQLENS + batch_idx + 1) + start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) + rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, + mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), + other=0) + rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, + x, + mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) + + +def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + BLOCK_M = min(triton.next_power_of_2(state_len), 16) + BLOCK_N = min(triton.next_power_of_2(dim), 256) + grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) + with torch.cuda.device(x.device.index): + _causal_conv1d_varlen_states[grid]( + x, + cu_seqlens, + states, + state_len, + dim, + x.stride(0), x.stride(1), + states.stride(0), states.stride(2), states.stride(1), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + return states + + +def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + for i in range(batch): + end_idx = cu_seqlens[i + 1] + start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) + states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T + return states diff --git a/build/torch211-cxx11-cu126-aarch64-linux/cpp_functions.py b/build/torch211-cxx11-cu126-aarch64-linux/cpp_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b --- /dev/null +++ b/build/torch211-cxx11-cu126-aarch64-linux/cpp_functions.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, Tri Dao. + +import torch + +from ._ops import ops + +def causal_conv1d_fwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + final_states_out: torch.Tensor | None, + silu_activation: bool, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_fwd( + x=x, + weight=weight, + bias=bias, + seq_idx=seq_idx, + initial_states=initial_states, + out=out, + final_states_out=final_states_out, + silu_activation=silu_activation, + ) + return out + + +def causal_conv1d_bwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + dout: torch.Tensor, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + dfinal_states: torch.Tensor | None, + dx: torch.Tensor | None, + return_dinitial_states: torch.Tensor, + silu_activation: bool, +) -> tuple[torch.Tensor | None]: + batch_size, dim = x.size()[:2] + width = weight.size(-1) + + if dx is None: + dx = torch.empty_like(x) + dweight = torch.zeros_like(weight, dtype=torch.float32) + dbias = None + if bias is not None: + dbias = torch.zeros_like(bias, dtype=torch.float32) + dinitial_states = None + if return_dinitial_states: + dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) + + ops.causal_conv1d_bwd( + x=x, + weight=weight, + bias=bias, + dout=dout, + seq_idx=seq_idx, + initial_states=initial_states, + dfinal_states=dfinal_states, + dx=dx, + dweight=dweight, + dbias=dbias, + dinitial_states=dinitial_states, + silu_activation=silu_activation, + ) + + dweight = dweight.type_as(weight) + if dbias is not None: + dbias = dbias.type_as(bias) + return dx, dweight, dbias, dinitial_states + + +def causal_conv1d_update_function( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + silu_activation: bool, + cache_seqlens: torch.Tensor | None, + conv_state_indices: torch.Tensor | None, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_update( + x=x, + conv_state=conv_state, + weight=weight, + bias=bias, + out=out, + silu_activation=silu_activation, + cache_seqlens=cache_seqlens, + conv_state_indices=conv_state_indices, + ) + return out diff --git a/build/torch211-cxx11-cu126-aarch64-linux/metadata.json b/build/torch211-cxx11-cu126-aarch64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..0dacb99125f1112a811819ca1ffdde15c8c0faff --- /dev/null +++ b/build/torch211-cxx11-cu126-aarch64-linux/metadata.json @@ -0,0 +1,18 @@ +{ + "version": 1, + "license": "BSD-3-Clause", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "7.0", + "7.2", + "7.5", + "8.0", + "8.6", + "8.7", + "8.9", + "9.0+PTX" + ] + } +} diff --git a/build/torch211-cxx11-cu126-x86_64-linux/__init__.py b/build/torch211-cxx11-cu126-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c --- /dev/null +++ b/build/torch211-cxx11-cu126-x86_64-linux/__init__.py @@ -0,0 +1,4 @@ +from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update +from .causal_conv1d_varlen import causal_conv1d_varlen_states + +__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch211-cxx11-cu126-x86_64-linux/_causal_conv1d_cuda_6b83b83.abi3.so b/build/torch211-cxx11-cu126-x86_64-linux/_causal_conv1d_cuda_6b83b83.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..f5466087716c10109e093c4f6dc98cb1400ea837 --- /dev/null +++ b/build/torch211-cxx11-cu126-x86_64-linux/_causal_conv1d_cuda_6b83b83.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:71ecf103b1b26f969ecd7734e196f24210bb5e2937f63b9f65e2e127fd5e8e5f +size 80694560 diff --git a/build/torch211-cxx11-cu126-x86_64-linux/_ops.py b/build/torch211-cxx11-cu126-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6796899661ef6f73609047ca344503d13ca050bd --- /dev/null +++ b/build/torch211-cxx11-cu126-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _causal_conv1d_cuda_6b83b83 +ops = torch.ops._causal_conv1d_cuda_6b83b83 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_causal_conv1d_cuda_6b83b83::{op_name}" diff --git a/build/torch211-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py b/build/torch211-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch211-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch211-cxx11-cu126-x86_64-linux/causal_conv1d_interface.py b/build/torch211-cxx11-cu126-x86_64-linux/causal_conv1d_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377 --- /dev/null +++ b/build/torch211-cxx11-cu126-x86_64-linux/causal_conv1d_interface.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024, Tri Dao. + +import torch +import torch.nn.functional as F + +from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function + + +class CausalConv1dFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, + ): + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert ( + initial_states is None + ), "initial_states must be None if seq_idx is not None" + assert ( + not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and ( + initial_states.stride(2) != 1 and initial_states.stride(1) != 1 + ): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert ( + final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 + ) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty( + batch, width - 1, dim, device=x.device, dtype=x.dtype + ).transpose(1, 2) + else: + final_states_out = None + ctx.activation = activation in ["silu", "swish"] + out = causal_conv1d_fwd_function( + x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation + ) + ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) + ctx.return_final_states = return_final_states + ctx.return_dinitial_states = ( + initial_states is not None and initial_states.requires_grad + ) + return out if not return_final_states else (out, final_states_out) + + @staticmethod + def backward(ctx, dout, *args): + x, weight, bias, seq_idx, initial_states = ctx.saved_tensors + dfinal_states = args[0] if ctx.return_final_states else None + if dout.stride(2) != 1 and dout.stride(1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). + # Here we just pass in None and dx will be allocated in the C++ code. + dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( + x, + weight, + bias, + dout, + seq_idx, + initial_states, + dfinal_states, + None, + ctx.return_dinitial_states, + ctx.activation, + ) + return ( + dx, + dweight, + dbias if bias is not None else None, + None, + dinitial_states if initial_states is not None else None, + None, + None, + None, + ) + + +def causal_conv1d_fn( + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + return CausalConv1dFn.apply( + x, + weight, + bias, + seq_idx, + initial_states, + return_final_states, + final_states_out, + activation, + ) + + +def causal_conv1d_ref( + x, + weight, + bias=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in + ) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return out if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + out = causal_conv1d_update_function( + x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices + ) + if unsqueeze: + out = out.squeeze(-1) + return out + + +def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + assert conv_state.shape == (batch, dim, state_len) + assert weight.shape == (dim, width) + if cache_seqlens is None: + x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) + conv_state.copy_(x_new[:, :, -state_len:]) + else: + width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) + copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + conv_state.scatter_(2, copy_idx, x) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] + if unsqueeze: + out = out.squeeze(-1) + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch211-cxx11-cu126-x86_64-linux/causal_conv1d_varlen.py b/build/torch211-cxx11-cu126-x86_64-linux/causal_conv1d_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb --- /dev/null +++ b/build/torch211-cxx11-cu126-x86_64-linux/causal_conv1d_varlen.py @@ -0,0 +1,86 @@ +import torch +from torch import Tensor + +import triton +import triton.language as tl + + +@triton.jit +def _causal_conv1d_varlen_states( + X, + CU_SEQLENS, + STATES, + state_len, + dim, + stride_x_seqlen, stride_x_dim, + stride_states_batch, stride_states_seqlen, stride_states_dim, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr +): + batch_idx = tl.program_id(2) + STATES += batch_idx * stride_states_batch + end_idx = tl.load(CU_SEQLENS + batch_idx + 1) + start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) + rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, + mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), + other=0) + rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, + x, + mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) + + +def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + BLOCK_M = min(triton.next_power_of_2(state_len), 16) + BLOCK_N = min(triton.next_power_of_2(dim), 256) + grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) + with torch.cuda.device(x.device.index): + _causal_conv1d_varlen_states[grid]( + x, + cu_seqlens, + states, + state_len, + dim, + x.stride(0), x.stride(1), + states.stride(0), states.stride(2), states.stride(1), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + return states + + +def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + for i in range(batch): + end_idx = cu_seqlens[i + 1] + start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) + states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T + return states diff --git a/build/torch211-cxx11-cu126-x86_64-linux/cpp_functions.py b/build/torch211-cxx11-cu126-x86_64-linux/cpp_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b --- /dev/null +++ b/build/torch211-cxx11-cu126-x86_64-linux/cpp_functions.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, Tri Dao. + +import torch + +from ._ops import ops + +def causal_conv1d_fwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + final_states_out: torch.Tensor | None, + silu_activation: bool, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_fwd( + x=x, + weight=weight, + bias=bias, + seq_idx=seq_idx, + initial_states=initial_states, + out=out, + final_states_out=final_states_out, + silu_activation=silu_activation, + ) + return out + + +def causal_conv1d_bwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + dout: torch.Tensor, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + dfinal_states: torch.Tensor | None, + dx: torch.Tensor | None, + return_dinitial_states: torch.Tensor, + silu_activation: bool, +) -> tuple[torch.Tensor | None]: + batch_size, dim = x.size()[:2] + width = weight.size(-1) + + if dx is None: + dx = torch.empty_like(x) + dweight = torch.zeros_like(weight, dtype=torch.float32) + dbias = None + if bias is not None: + dbias = torch.zeros_like(bias, dtype=torch.float32) + dinitial_states = None + if return_dinitial_states: + dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) + + ops.causal_conv1d_bwd( + x=x, + weight=weight, + bias=bias, + dout=dout, + seq_idx=seq_idx, + initial_states=initial_states, + dfinal_states=dfinal_states, + dx=dx, + dweight=dweight, + dbias=dbias, + dinitial_states=dinitial_states, + silu_activation=silu_activation, + ) + + dweight = dweight.type_as(weight) + if dbias is not None: + dbias = dbias.type_as(bias) + return dx, dweight, dbias, dinitial_states + + +def causal_conv1d_update_function( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + silu_activation: bool, + cache_seqlens: torch.Tensor | None, + conv_state_indices: torch.Tensor | None, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_update( + x=x, + conv_state=conv_state, + weight=weight, + bias=bias, + out=out, + silu_activation=silu_activation, + cache_seqlens=cache_seqlens, + conv_state_indices=conv_state_indices, + ) + return out diff --git a/build/torch211-cxx11-cu126-x86_64-linux/metadata.json b/build/torch211-cxx11-cu126-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..0dacb99125f1112a811819ca1ffdde15c8c0faff --- /dev/null +++ b/build/torch211-cxx11-cu126-x86_64-linux/metadata.json @@ -0,0 +1,18 @@ +{ + "version": 1, + "license": "BSD-3-Clause", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "7.0", + "7.2", + "7.5", + "8.0", + "8.6", + "8.7", + "8.9", + "9.0+PTX" + ] + } +} diff --git a/build/torch211-cxx11-cu128-aarch64-linux/__init__.py b/build/torch211-cxx11-cu128-aarch64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c --- /dev/null +++ b/build/torch211-cxx11-cu128-aarch64-linux/__init__.py @@ -0,0 +1,4 @@ +from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update +from .causal_conv1d_varlen import causal_conv1d_varlen_states + +__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch211-cxx11-cu128-aarch64-linux/_causal_conv1d_cuda_6b83b83.abi3.so b/build/torch211-cxx11-cu128-aarch64-linux/_causal_conv1d_cuda_6b83b83.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..df1664c45faf3611842b70b761f4ca490042413b --- /dev/null +++ b/build/torch211-cxx11-cu128-aarch64-linux/_causal_conv1d_cuda_6b83b83.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:69e30c7a01b7d693affe82ae3556b7d8428c095d86fc5dca1cf4014d3ceff43b +size 107312776 diff --git a/build/torch211-cxx11-cu128-aarch64-linux/_ops.py b/build/torch211-cxx11-cu128-aarch64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6796899661ef6f73609047ca344503d13ca050bd --- /dev/null +++ b/build/torch211-cxx11-cu128-aarch64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _causal_conv1d_cuda_6b83b83 +ops = torch.ops._causal_conv1d_cuda_6b83b83 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_causal_conv1d_cuda_6b83b83::{op_name}" diff --git a/build/torch211-cxx11-cu128-aarch64-linux/causal_conv1d/__init__.py b/build/torch211-cxx11-cu128-aarch64-linux/causal_conv1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch211-cxx11-cu128-aarch64-linux/causal_conv1d/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch211-cxx11-cu128-aarch64-linux/causal_conv1d_interface.py b/build/torch211-cxx11-cu128-aarch64-linux/causal_conv1d_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377 --- /dev/null +++ b/build/torch211-cxx11-cu128-aarch64-linux/causal_conv1d_interface.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024, Tri Dao. + +import torch +import torch.nn.functional as F + +from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function + + +class CausalConv1dFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, + ): + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert ( + initial_states is None + ), "initial_states must be None if seq_idx is not None" + assert ( + not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and ( + initial_states.stride(2) != 1 and initial_states.stride(1) != 1 + ): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert ( + final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 + ) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty( + batch, width - 1, dim, device=x.device, dtype=x.dtype + ).transpose(1, 2) + else: + final_states_out = None + ctx.activation = activation in ["silu", "swish"] + out = causal_conv1d_fwd_function( + x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation + ) + ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) + ctx.return_final_states = return_final_states + ctx.return_dinitial_states = ( + initial_states is not None and initial_states.requires_grad + ) + return out if not return_final_states else (out, final_states_out) + + @staticmethod + def backward(ctx, dout, *args): + x, weight, bias, seq_idx, initial_states = ctx.saved_tensors + dfinal_states = args[0] if ctx.return_final_states else None + if dout.stride(2) != 1 and dout.stride(1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). + # Here we just pass in None and dx will be allocated in the C++ code. + dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( + x, + weight, + bias, + dout, + seq_idx, + initial_states, + dfinal_states, + None, + ctx.return_dinitial_states, + ctx.activation, + ) + return ( + dx, + dweight, + dbias if bias is not None else None, + None, + dinitial_states if initial_states is not None else None, + None, + None, + None, + ) + + +def causal_conv1d_fn( + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + return CausalConv1dFn.apply( + x, + weight, + bias, + seq_idx, + initial_states, + return_final_states, + final_states_out, + activation, + ) + + +def causal_conv1d_ref( + x, + weight, + bias=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in + ) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return out if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + out = causal_conv1d_update_function( + x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices + ) + if unsqueeze: + out = out.squeeze(-1) + return out + + +def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + assert conv_state.shape == (batch, dim, state_len) + assert weight.shape == (dim, width) + if cache_seqlens is None: + x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) + conv_state.copy_(x_new[:, :, -state_len:]) + else: + width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) + copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + conv_state.scatter_(2, copy_idx, x) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] + if unsqueeze: + out = out.squeeze(-1) + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch211-cxx11-cu128-aarch64-linux/causal_conv1d_varlen.py b/build/torch211-cxx11-cu128-aarch64-linux/causal_conv1d_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb --- /dev/null +++ b/build/torch211-cxx11-cu128-aarch64-linux/causal_conv1d_varlen.py @@ -0,0 +1,86 @@ +import torch +from torch import Tensor + +import triton +import triton.language as tl + + +@triton.jit +def _causal_conv1d_varlen_states( + X, + CU_SEQLENS, + STATES, + state_len, + dim, + stride_x_seqlen, stride_x_dim, + stride_states_batch, stride_states_seqlen, stride_states_dim, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr +): + batch_idx = tl.program_id(2) + STATES += batch_idx * stride_states_batch + end_idx = tl.load(CU_SEQLENS + batch_idx + 1) + start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) + rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, + mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), + other=0) + rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, + x, + mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) + + +def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + BLOCK_M = min(triton.next_power_of_2(state_len), 16) + BLOCK_N = min(triton.next_power_of_2(dim), 256) + grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) + with torch.cuda.device(x.device.index): + _causal_conv1d_varlen_states[grid]( + x, + cu_seqlens, + states, + state_len, + dim, + x.stride(0), x.stride(1), + states.stride(0), states.stride(2), states.stride(1), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + return states + + +def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + for i in range(batch): + end_idx = cu_seqlens[i + 1] + start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) + states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T + return states diff --git a/build/torch211-cxx11-cu128-aarch64-linux/cpp_functions.py b/build/torch211-cxx11-cu128-aarch64-linux/cpp_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b --- /dev/null +++ b/build/torch211-cxx11-cu128-aarch64-linux/cpp_functions.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, Tri Dao. + +import torch + +from ._ops import ops + +def causal_conv1d_fwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + final_states_out: torch.Tensor | None, + silu_activation: bool, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_fwd( + x=x, + weight=weight, + bias=bias, + seq_idx=seq_idx, + initial_states=initial_states, + out=out, + final_states_out=final_states_out, + silu_activation=silu_activation, + ) + return out + + +def causal_conv1d_bwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + dout: torch.Tensor, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + dfinal_states: torch.Tensor | None, + dx: torch.Tensor | None, + return_dinitial_states: torch.Tensor, + silu_activation: bool, +) -> tuple[torch.Tensor | None]: + batch_size, dim = x.size()[:2] + width = weight.size(-1) + + if dx is None: + dx = torch.empty_like(x) + dweight = torch.zeros_like(weight, dtype=torch.float32) + dbias = None + if bias is not None: + dbias = torch.zeros_like(bias, dtype=torch.float32) + dinitial_states = None + if return_dinitial_states: + dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) + + ops.causal_conv1d_bwd( + x=x, + weight=weight, + bias=bias, + dout=dout, + seq_idx=seq_idx, + initial_states=initial_states, + dfinal_states=dfinal_states, + dx=dx, + dweight=dweight, + dbias=dbias, + dinitial_states=dinitial_states, + silu_activation=silu_activation, + ) + + dweight = dweight.type_as(weight) + if dbias is not None: + dbias = dbias.type_as(bias) + return dx, dweight, dbias, dinitial_states + + +def causal_conv1d_update_function( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + silu_activation: bool, + cache_seqlens: torch.Tensor | None, + conv_state_indices: torch.Tensor | None, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_update( + x=x, + conv_state=conv_state, + weight=weight, + bias=bias, + out=out, + silu_activation=silu_activation, + cache_seqlens=cache_seqlens, + conv_state_indices=conv_state_indices, + ) + return out diff --git a/build/torch211-cxx11-cu128-aarch64-linux/metadata.json b/build/torch211-cxx11-cu128-aarch64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..a794c92436c3827ae79b48d55f7ea964afd50f52 --- /dev/null +++ b/build/torch211-cxx11-cu128-aarch64-linux/metadata.json @@ -0,0 +1,21 @@ +{ + "version": 1, + "license": "BSD-3-Clause", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0", + "10.1", + "12.0+PTX", + "7.0", + "7.2", + "7.5", + "8.0", + "8.6", + "8.7", + "8.9", + "9.0" + ] + } +} diff --git a/build/torch211-cxx11-cu128-x86_64-linux/__init__.py b/build/torch211-cxx11-cu128-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/__init__.py @@ -0,0 +1,4 @@ +from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update +from .causal_conv1d_varlen import causal_conv1d_varlen_states + +__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch211-cxx11-cu128-x86_64-linux/_causal_conv1d_cuda_6b83b83.abi3.so b/build/torch211-cxx11-cu128-x86_64-linux/_causal_conv1d_cuda_6b83b83.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..d2b16fb61e2d09e8316f35076d9ed257c716c246 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/_causal_conv1d_cuda_6b83b83.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:67a94922da263147a2328b548a12d12bfe77654431dc234659eb45e3337cb948 +size 107169936 diff --git a/build/torch211-cxx11-cu128-x86_64-linux/_ops.py b/build/torch211-cxx11-cu128-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6796899661ef6f73609047ca344503d13ca050bd --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _causal_conv1d_cuda_6b83b83 +ops = torch.ops._causal_conv1d_cuda_6b83b83 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_causal_conv1d_cuda_6b83b83::{op_name}" diff --git a/build/torch211-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py b/build/torch211-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch211-cxx11-cu128-x86_64-linux/causal_conv1d_interface.py b/build/torch211-cxx11-cu128-x86_64-linux/causal_conv1d_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/causal_conv1d_interface.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024, Tri Dao. + +import torch +import torch.nn.functional as F + +from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function + + +class CausalConv1dFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, + ): + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert ( + initial_states is None + ), "initial_states must be None if seq_idx is not None" + assert ( + not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and ( + initial_states.stride(2) != 1 and initial_states.stride(1) != 1 + ): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert ( + final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 + ) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty( + batch, width - 1, dim, device=x.device, dtype=x.dtype + ).transpose(1, 2) + else: + final_states_out = None + ctx.activation = activation in ["silu", "swish"] + out = causal_conv1d_fwd_function( + x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation + ) + ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) + ctx.return_final_states = return_final_states + ctx.return_dinitial_states = ( + initial_states is not None and initial_states.requires_grad + ) + return out if not return_final_states else (out, final_states_out) + + @staticmethod + def backward(ctx, dout, *args): + x, weight, bias, seq_idx, initial_states = ctx.saved_tensors + dfinal_states = args[0] if ctx.return_final_states else None + if dout.stride(2) != 1 and dout.stride(1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). + # Here we just pass in None and dx will be allocated in the C++ code. + dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( + x, + weight, + bias, + dout, + seq_idx, + initial_states, + dfinal_states, + None, + ctx.return_dinitial_states, + ctx.activation, + ) + return ( + dx, + dweight, + dbias if bias is not None else None, + None, + dinitial_states if initial_states is not None else None, + None, + None, + None, + ) + + +def causal_conv1d_fn( + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + return CausalConv1dFn.apply( + x, + weight, + bias, + seq_idx, + initial_states, + return_final_states, + final_states_out, + activation, + ) + + +def causal_conv1d_ref( + x, + weight, + bias=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in + ) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return out if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + out = causal_conv1d_update_function( + x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices + ) + if unsqueeze: + out = out.squeeze(-1) + return out + + +def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + assert conv_state.shape == (batch, dim, state_len) + assert weight.shape == (dim, width) + if cache_seqlens is None: + x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) + conv_state.copy_(x_new[:, :, -state_len:]) + else: + width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) + copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + conv_state.scatter_(2, copy_idx, x) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] + if unsqueeze: + out = out.squeeze(-1) + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch211-cxx11-cu128-x86_64-linux/causal_conv1d_varlen.py b/build/torch211-cxx11-cu128-x86_64-linux/causal_conv1d_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/causal_conv1d_varlen.py @@ -0,0 +1,86 @@ +import torch +from torch import Tensor + +import triton +import triton.language as tl + + +@triton.jit +def _causal_conv1d_varlen_states( + X, + CU_SEQLENS, + STATES, + state_len, + dim, + stride_x_seqlen, stride_x_dim, + stride_states_batch, stride_states_seqlen, stride_states_dim, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr +): + batch_idx = tl.program_id(2) + STATES += batch_idx * stride_states_batch + end_idx = tl.load(CU_SEQLENS + batch_idx + 1) + start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) + rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, + mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), + other=0) + rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, + x, + mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) + + +def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + BLOCK_M = min(triton.next_power_of_2(state_len), 16) + BLOCK_N = min(triton.next_power_of_2(dim), 256) + grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) + with torch.cuda.device(x.device.index): + _causal_conv1d_varlen_states[grid]( + x, + cu_seqlens, + states, + state_len, + dim, + x.stride(0), x.stride(1), + states.stride(0), states.stride(2), states.stride(1), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + return states + + +def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + for i in range(batch): + end_idx = cu_seqlens[i + 1] + start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) + states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T + return states diff --git a/build/torch211-cxx11-cu128-x86_64-linux/cpp_functions.py b/build/torch211-cxx11-cu128-x86_64-linux/cpp_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/cpp_functions.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, Tri Dao. + +import torch + +from ._ops import ops + +def causal_conv1d_fwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + final_states_out: torch.Tensor | None, + silu_activation: bool, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_fwd( + x=x, + weight=weight, + bias=bias, + seq_idx=seq_idx, + initial_states=initial_states, + out=out, + final_states_out=final_states_out, + silu_activation=silu_activation, + ) + return out + + +def causal_conv1d_bwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + dout: torch.Tensor, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + dfinal_states: torch.Tensor | None, + dx: torch.Tensor | None, + return_dinitial_states: torch.Tensor, + silu_activation: bool, +) -> tuple[torch.Tensor | None]: + batch_size, dim = x.size()[:2] + width = weight.size(-1) + + if dx is None: + dx = torch.empty_like(x) + dweight = torch.zeros_like(weight, dtype=torch.float32) + dbias = None + if bias is not None: + dbias = torch.zeros_like(bias, dtype=torch.float32) + dinitial_states = None + if return_dinitial_states: + dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) + + ops.causal_conv1d_bwd( + x=x, + weight=weight, + bias=bias, + dout=dout, + seq_idx=seq_idx, + initial_states=initial_states, + dfinal_states=dfinal_states, + dx=dx, + dweight=dweight, + dbias=dbias, + dinitial_states=dinitial_states, + silu_activation=silu_activation, + ) + + dweight = dweight.type_as(weight) + if dbias is not None: + dbias = dbias.type_as(bias) + return dx, dweight, dbias, dinitial_states + + +def causal_conv1d_update_function( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + silu_activation: bool, + cache_seqlens: torch.Tensor | None, + conv_state_indices: torch.Tensor | None, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_update( + x=x, + conv_state=conv_state, + weight=weight, + bias=bias, + out=out, + silu_activation=silu_activation, + cache_seqlens=cache_seqlens, + conv_state_indices=conv_state_indices, + ) + return out diff --git a/build/torch211-cxx11-cu128-x86_64-linux/metadata.json b/build/torch211-cxx11-cu128-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..a794c92436c3827ae79b48d55f7ea964afd50f52 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/metadata.json @@ -0,0 +1,21 @@ +{ + "version": 1, + "license": "BSD-3-Clause", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0", + "10.1", + "12.0+PTX", + "7.0", + "7.2", + "7.5", + "8.0", + "8.6", + "8.7", + "8.9", + "9.0" + ] + } +} diff --git a/build/torch211-cxx11-cu130-aarch64-linux/__init__.py b/build/torch211-cxx11-cu130-aarch64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c --- /dev/null +++ b/build/torch211-cxx11-cu130-aarch64-linux/__init__.py @@ -0,0 +1,4 @@ +from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update +from .causal_conv1d_varlen import causal_conv1d_varlen_states + +__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch211-cxx11-cu130-aarch64-linux/_causal_conv1d_cuda_6b83b83.abi3.so b/build/torch211-cxx11-cu130-aarch64-linux/_causal_conv1d_cuda_6b83b83.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..83e375375beaba6eece156ddda39eb9f96151365 --- /dev/null +++ b/build/torch211-cxx11-cu130-aarch64-linux/_causal_conv1d_cuda_6b83b83.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5948aec7b42fd865dd38266aad64a34dd3ac7ce5c58c5f04babe9ee28bba1d5f +size 64755624 diff --git a/build/torch211-cxx11-cu130-aarch64-linux/_ops.py b/build/torch211-cxx11-cu130-aarch64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6796899661ef6f73609047ca344503d13ca050bd --- /dev/null +++ b/build/torch211-cxx11-cu130-aarch64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _causal_conv1d_cuda_6b83b83 +ops = torch.ops._causal_conv1d_cuda_6b83b83 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_causal_conv1d_cuda_6b83b83::{op_name}" diff --git a/build/torch211-cxx11-cu130-aarch64-linux/causal_conv1d/__init__.py b/build/torch211-cxx11-cu130-aarch64-linux/causal_conv1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch211-cxx11-cu130-aarch64-linux/causal_conv1d/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch211-cxx11-cu130-aarch64-linux/causal_conv1d_interface.py b/build/torch211-cxx11-cu130-aarch64-linux/causal_conv1d_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377 --- /dev/null +++ b/build/torch211-cxx11-cu130-aarch64-linux/causal_conv1d_interface.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024, Tri Dao. + +import torch +import torch.nn.functional as F + +from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function + + +class CausalConv1dFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, + ): + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert ( + initial_states is None + ), "initial_states must be None if seq_idx is not None" + assert ( + not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and ( + initial_states.stride(2) != 1 and initial_states.stride(1) != 1 + ): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert ( + final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 + ) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty( + batch, width - 1, dim, device=x.device, dtype=x.dtype + ).transpose(1, 2) + else: + final_states_out = None + ctx.activation = activation in ["silu", "swish"] + out = causal_conv1d_fwd_function( + x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation + ) + ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) + ctx.return_final_states = return_final_states + ctx.return_dinitial_states = ( + initial_states is not None and initial_states.requires_grad + ) + return out if not return_final_states else (out, final_states_out) + + @staticmethod + def backward(ctx, dout, *args): + x, weight, bias, seq_idx, initial_states = ctx.saved_tensors + dfinal_states = args[0] if ctx.return_final_states else None + if dout.stride(2) != 1 and dout.stride(1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). + # Here we just pass in None and dx will be allocated in the C++ code. + dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( + x, + weight, + bias, + dout, + seq_idx, + initial_states, + dfinal_states, + None, + ctx.return_dinitial_states, + ctx.activation, + ) + return ( + dx, + dweight, + dbias if bias is not None else None, + None, + dinitial_states if initial_states is not None else None, + None, + None, + None, + ) + + +def causal_conv1d_fn( + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + return CausalConv1dFn.apply( + x, + weight, + bias, + seq_idx, + initial_states, + return_final_states, + final_states_out, + activation, + ) + + +def causal_conv1d_ref( + x, + weight, + bias=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in + ) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return out if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + out = causal_conv1d_update_function( + x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices + ) + if unsqueeze: + out = out.squeeze(-1) + return out + + +def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + assert conv_state.shape == (batch, dim, state_len) + assert weight.shape == (dim, width) + if cache_seqlens is None: + x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) + conv_state.copy_(x_new[:, :, -state_len:]) + else: + width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) + copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + conv_state.scatter_(2, copy_idx, x) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] + if unsqueeze: + out = out.squeeze(-1) + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch211-cxx11-cu130-aarch64-linux/causal_conv1d_varlen.py b/build/torch211-cxx11-cu130-aarch64-linux/causal_conv1d_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb --- /dev/null +++ b/build/torch211-cxx11-cu130-aarch64-linux/causal_conv1d_varlen.py @@ -0,0 +1,86 @@ +import torch +from torch import Tensor + +import triton +import triton.language as tl + + +@triton.jit +def _causal_conv1d_varlen_states( + X, + CU_SEQLENS, + STATES, + state_len, + dim, + stride_x_seqlen, stride_x_dim, + stride_states_batch, stride_states_seqlen, stride_states_dim, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr +): + batch_idx = tl.program_id(2) + STATES += batch_idx * stride_states_batch + end_idx = tl.load(CU_SEQLENS + batch_idx + 1) + start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) + rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, + mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), + other=0) + rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, + x, + mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) + + +def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + BLOCK_M = min(triton.next_power_of_2(state_len), 16) + BLOCK_N = min(triton.next_power_of_2(dim), 256) + grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) + with torch.cuda.device(x.device.index): + _causal_conv1d_varlen_states[grid]( + x, + cu_seqlens, + states, + state_len, + dim, + x.stride(0), x.stride(1), + states.stride(0), states.stride(2), states.stride(1), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + return states + + +def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + for i in range(batch): + end_idx = cu_seqlens[i + 1] + start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) + states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T + return states diff --git a/build/torch211-cxx11-cu130-aarch64-linux/cpp_functions.py b/build/torch211-cxx11-cu130-aarch64-linux/cpp_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b --- /dev/null +++ b/build/torch211-cxx11-cu130-aarch64-linux/cpp_functions.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, Tri Dao. + +import torch + +from ._ops import ops + +def causal_conv1d_fwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + final_states_out: torch.Tensor | None, + silu_activation: bool, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_fwd( + x=x, + weight=weight, + bias=bias, + seq_idx=seq_idx, + initial_states=initial_states, + out=out, + final_states_out=final_states_out, + silu_activation=silu_activation, + ) + return out + + +def causal_conv1d_bwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + dout: torch.Tensor, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + dfinal_states: torch.Tensor | None, + dx: torch.Tensor | None, + return_dinitial_states: torch.Tensor, + silu_activation: bool, +) -> tuple[torch.Tensor | None]: + batch_size, dim = x.size()[:2] + width = weight.size(-1) + + if dx is None: + dx = torch.empty_like(x) + dweight = torch.zeros_like(weight, dtype=torch.float32) + dbias = None + if bias is not None: + dbias = torch.zeros_like(bias, dtype=torch.float32) + dinitial_states = None + if return_dinitial_states: + dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) + + ops.causal_conv1d_bwd( + x=x, + weight=weight, + bias=bias, + dout=dout, + seq_idx=seq_idx, + initial_states=initial_states, + dfinal_states=dfinal_states, + dx=dx, + dweight=dweight, + dbias=dbias, + dinitial_states=dinitial_states, + silu_activation=silu_activation, + ) + + dweight = dweight.type_as(weight) + if dbias is not None: + dbias = dbias.type_as(bias) + return dx, dweight, dbias, dinitial_states + + +def causal_conv1d_update_function( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + silu_activation: bool, + cache_seqlens: torch.Tensor | None, + conv_state_indices: torch.Tensor | None, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_update( + x=x, + conv_state=conv_state, + weight=weight, + bias=bias, + out=out, + silu_activation=silu_activation, + cache_seqlens=cache_seqlens, + conv_state_indices=conv_state_indices, + ) + return out diff --git a/build/torch211-cxx11-cu130-aarch64-linux/metadata.json b/build/torch211-cxx11-cu130-aarch64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..eff725542128e103dfb5df382d74940efff77214 --- /dev/null +++ b/build/torch211-cxx11-cu130-aarch64-linux/metadata.json @@ -0,0 +1,19 @@ +{ + "version": 1, + "license": "BSD-3-Clause", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0", + "11.0", + "12.0+PTX", + "7.5", + "8.0", + "8.6", + "8.7", + "8.9", + "9.0" + ] + } +} diff --git a/build/torch211-cxx11-cu130-x86_64-linux/__init__.py b/build/torch211-cxx11-cu130-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/__init__.py @@ -0,0 +1,4 @@ +from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update +from .causal_conv1d_varlen import causal_conv1d_varlen_states + +__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch211-cxx11-cu130-x86_64-linux/_causal_conv1d_cuda_6b83b83.abi3.so b/build/torch211-cxx11-cu130-x86_64-linux/_causal_conv1d_cuda_6b83b83.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..e96f3ef9130d3cc40a9be87e6f625d6bd408f801 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/_causal_conv1d_cuda_6b83b83.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:075f1884806147ccb391ee3949e30fc6e8009a4bb02894f2c97b5819d0979c8d +size 64618568 diff --git a/build/torch211-cxx11-cu130-x86_64-linux/_ops.py b/build/torch211-cxx11-cu130-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6796899661ef6f73609047ca344503d13ca050bd --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _causal_conv1d_cuda_6b83b83 +ops = torch.ops._causal_conv1d_cuda_6b83b83 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_causal_conv1d_cuda_6b83b83::{op_name}" diff --git a/build/torch211-cxx11-cu130-x86_64-linux/causal_conv1d/__init__.py b/build/torch211-cxx11-cu130-x86_64-linux/causal_conv1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/causal_conv1d/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch211-cxx11-cu130-x86_64-linux/causal_conv1d_interface.py b/build/torch211-cxx11-cu130-x86_64-linux/causal_conv1d_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/causal_conv1d_interface.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024, Tri Dao. + +import torch +import torch.nn.functional as F + +from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function + + +class CausalConv1dFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, + ): + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert ( + initial_states is None + ), "initial_states must be None if seq_idx is not None" + assert ( + not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and ( + initial_states.stride(2) != 1 and initial_states.stride(1) != 1 + ): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert ( + final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 + ) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty( + batch, width - 1, dim, device=x.device, dtype=x.dtype + ).transpose(1, 2) + else: + final_states_out = None + ctx.activation = activation in ["silu", "swish"] + out = causal_conv1d_fwd_function( + x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation + ) + ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) + ctx.return_final_states = return_final_states + ctx.return_dinitial_states = ( + initial_states is not None and initial_states.requires_grad + ) + return out if not return_final_states else (out, final_states_out) + + @staticmethod + def backward(ctx, dout, *args): + x, weight, bias, seq_idx, initial_states = ctx.saved_tensors + dfinal_states = args[0] if ctx.return_final_states else None + if dout.stride(2) != 1 and dout.stride(1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). + # Here we just pass in None and dx will be allocated in the C++ code. + dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( + x, + weight, + bias, + dout, + seq_idx, + initial_states, + dfinal_states, + None, + ctx.return_dinitial_states, + ctx.activation, + ) + return ( + dx, + dweight, + dbias if bias is not None else None, + None, + dinitial_states if initial_states is not None else None, + None, + None, + None, + ) + + +def causal_conv1d_fn( + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + return CausalConv1dFn.apply( + x, + weight, + bias, + seq_idx, + initial_states, + return_final_states, + final_states_out, + activation, + ) + + +def causal_conv1d_ref( + x, + weight, + bias=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in + ) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return out if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + out = causal_conv1d_update_function( + x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices + ) + if unsqueeze: + out = out.squeeze(-1) + return out + + +def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + assert conv_state.shape == (batch, dim, state_len) + assert weight.shape == (dim, width) + if cache_seqlens is None: + x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) + conv_state.copy_(x_new[:, :, -state_len:]) + else: + width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) + copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + conv_state.scatter_(2, copy_idx, x) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] + if unsqueeze: + out = out.squeeze(-1) + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch211-cxx11-cu130-x86_64-linux/causal_conv1d_varlen.py b/build/torch211-cxx11-cu130-x86_64-linux/causal_conv1d_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/causal_conv1d_varlen.py @@ -0,0 +1,86 @@ +import torch +from torch import Tensor + +import triton +import triton.language as tl + + +@triton.jit +def _causal_conv1d_varlen_states( + X, + CU_SEQLENS, + STATES, + state_len, + dim, + stride_x_seqlen, stride_x_dim, + stride_states_batch, stride_states_seqlen, stride_states_dim, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr +): + batch_idx = tl.program_id(2) + STATES += batch_idx * stride_states_batch + end_idx = tl.load(CU_SEQLENS + batch_idx + 1) + start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) + rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, + mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), + other=0) + rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, + x, + mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) + + +def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + BLOCK_M = min(triton.next_power_of_2(state_len), 16) + BLOCK_N = min(triton.next_power_of_2(dim), 256) + grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) + with torch.cuda.device(x.device.index): + _causal_conv1d_varlen_states[grid]( + x, + cu_seqlens, + states, + state_len, + dim, + x.stride(0), x.stride(1), + states.stride(0), states.stride(2), states.stride(1), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + return states + + +def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + for i in range(batch): + end_idx = cu_seqlens[i + 1] + start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) + states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T + return states diff --git a/build/torch211-cxx11-cu130-x86_64-linux/cpp_functions.py b/build/torch211-cxx11-cu130-x86_64-linux/cpp_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/cpp_functions.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, Tri Dao. + +import torch + +from ._ops import ops + +def causal_conv1d_fwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + final_states_out: torch.Tensor | None, + silu_activation: bool, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_fwd( + x=x, + weight=weight, + bias=bias, + seq_idx=seq_idx, + initial_states=initial_states, + out=out, + final_states_out=final_states_out, + silu_activation=silu_activation, + ) + return out + + +def causal_conv1d_bwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + dout: torch.Tensor, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + dfinal_states: torch.Tensor | None, + dx: torch.Tensor | None, + return_dinitial_states: torch.Tensor, + silu_activation: bool, +) -> tuple[torch.Tensor | None]: + batch_size, dim = x.size()[:2] + width = weight.size(-1) + + if dx is None: + dx = torch.empty_like(x) + dweight = torch.zeros_like(weight, dtype=torch.float32) + dbias = None + if bias is not None: + dbias = torch.zeros_like(bias, dtype=torch.float32) + dinitial_states = None + if return_dinitial_states: + dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) + + ops.causal_conv1d_bwd( + x=x, + weight=weight, + bias=bias, + dout=dout, + seq_idx=seq_idx, + initial_states=initial_states, + dfinal_states=dfinal_states, + dx=dx, + dweight=dweight, + dbias=dbias, + dinitial_states=dinitial_states, + silu_activation=silu_activation, + ) + + dweight = dweight.type_as(weight) + if dbias is not None: + dbias = dbias.type_as(bias) + return dx, dweight, dbias, dinitial_states + + +def causal_conv1d_update_function( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + silu_activation: bool, + cache_seqlens: torch.Tensor | None, + conv_state_indices: torch.Tensor | None, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_update( + x=x, + conv_state=conv_state, + weight=weight, + bias=bias, + out=out, + silu_activation=silu_activation, + cache_seqlens=cache_seqlens, + conv_state_indices=conv_state_indices, + ) + return out diff --git a/build/torch211-cxx11-cu130-x86_64-linux/metadata.json b/build/torch211-cxx11-cu130-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..eff725542128e103dfb5df382d74940efff77214 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/metadata.json @@ -0,0 +1,19 @@ +{ + "version": 1, + "license": "BSD-3-Clause", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0", + "11.0", + "12.0+PTX", + "7.5", + "8.0", + "8.6", + "8.7", + "8.9", + "9.0" + ] + } +} diff --git a/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__init__.py b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__init__.py @@ -0,0 +1,4 @@ +from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update +from .causal_conv1d_varlen import causal_conv1d_varlen_states + +__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cebc7817a292011587f4941dfff502f5e5c98cbd Binary files /dev/null and b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f17eb8838cb893f61b2f42a345df9f69eb7e8bf Binary files /dev/null and b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fdd84db8007b53b4dd4368dc01ad9296d4b95830 Binary files /dev/null and b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67e3eba29657209ffccf401d6e8c340a28057265 Binary files /dev/null and b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f78fec2e6a76348b50183b154618fd095d197d13 Binary files /dev/null and b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/_causal_conv1d_90f5a60.abi3.so b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/_causal_conv1d_90f5a60.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..0a843146f753e99ec745a7e2deb9c3db543a3482 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/_causal_conv1d_90f5a60.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9c7b5ae8af9477be3049ba1ae6af3f9e2d8bf82979fa9e9632c485a8d49f532a +size 64503960 diff --git a/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/_ops.py b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..c1d217d97eaddf8812c504cd7ca9656b8b72fba4 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _causal_conv1d_90f5a60 +ops = torch.ops._causal_conv1d_90f5a60 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_causal_conv1d_90f5a60::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/causal_conv1d_interface.py b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/causal_conv1d_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/causal_conv1d_interface.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024, Tri Dao. + +import torch +import torch.nn.functional as F + +from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function + + +class CausalConv1dFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, + ): + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert ( + initial_states is None + ), "initial_states must be None if seq_idx is not None" + assert ( + not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and ( + initial_states.stride(2) != 1 and initial_states.stride(1) != 1 + ): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert ( + final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 + ) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty( + batch, width - 1, dim, device=x.device, dtype=x.dtype + ).transpose(1, 2) + else: + final_states_out = None + ctx.activation = activation in ["silu", "swish"] + out = causal_conv1d_fwd_function( + x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation + ) + ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) + ctx.return_final_states = return_final_states + ctx.return_dinitial_states = ( + initial_states is not None and initial_states.requires_grad + ) + return out if not return_final_states else (out, final_states_out) + + @staticmethod + def backward(ctx, dout, *args): + x, weight, bias, seq_idx, initial_states = ctx.saved_tensors + dfinal_states = args[0] if ctx.return_final_states else None + if dout.stride(2) != 1 and dout.stride(1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). + # Here we just pass in None and dx will be allocated in the C++ code. + dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( + x, + weight, + bias, + dout, + seq_idx, + initial_states, + dfinal_states, + None, + ctx.return_dinitial_states, + ctx.activation, + ) + return ( + dx, + dweight, + dbias if bias is not None else None, + None, + dinitial_states if initial_states is not None else None, + None, + None, + None, + ) + + +def causal_conv1d_fn( + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + return CausalConv1dFn.apply( + x, + weight, + bias, + seq_idx, + initial_states, + return_final_states, + final_states_out, + activation, + ) + + +def causal_conv1d_ref( + x, + weight, + bias=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in + ) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return out if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + out = causal_conv1d_update_function( + x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices + ) + if unsqueeze: + out = out.squeeze(-1) + return out + + +def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + assert conv_state.shape == (batch, dim, state_len) + assert weight.shape == (dim, width) + if cache_seqlens is None: + x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) + conv_state.copy_(x_new[:, :, -state_len:]) + else: + width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) + copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + conv_state.scatter_(2, copy_idx, x) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] + if unsqueeze: + out = out.squeeze(-1) + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/causal_conv1d_varlen.py b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/causal_conv1d_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/causal_conv1d_varlen.py @@ -0,0 +1,86 @@ +import torch +from torch import Tensor + +import triton +import triton.language as tl + + +@triton.jit +def _causal_conv1d_varlen_states( + X, + CU_SEQLENS, + STATES, + state_len, + dim, + stride_x_seqlen, stride_x_dim, + stride_states_batch, stride_states_seqlen, stride_states_dim, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr +): + batch_idx = tl.program_id(2) + STATES += batch_idx * stride_states_batch + end_idx = tl.load(CU_SEQLENS + batch_idx + 1) + start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) + rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, + mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), + other=0) + rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, + x, + mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) + + +def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + BLOCK_M = min(triton.next_power_of_2(state_len), 16) + BLOCK_N = min(triton.next_power_of_2(dim), 256) + grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) + with torch.cuda.device(x.device.index): + _causal_conv1d_varlen_states[grid]( + x, + cu_seqlens, + states, + state_len, + dim, + x.stride(0), x.stride(1), + states.stride(0), states.stride(2), states.stride(1), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + return states + + +def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + for i in range(batch): + end_idx = cu_seqlens[i + 1] + start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) + states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T + return states diff --git a/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/cpp_functions.py b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/cpp_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/cpp_functions.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, Tri Dao. + +import torch + +from ._ops import ops + +def causal_conv1d_fwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + final_states_out: torch.Tensor | None, + silu_activation: bool, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_fwd( + x=x, + weight=weight, + bias=bias, + seq_idx=seq_idx, + initial_states=initial_states, + out=out, + final_states_out=final_states_out, + silu_activation=silu_activation, + ) + return out + + +def causal_conv1d_bwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + dout: torch.Tensor, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + dfinal_states: torch.Tensor | None, + dx: torch.Tensor | None, + return_dinitial_states: torch.Tensor, + silu_activation: bool, +) -> tuple[torch.Tensor | None]: + batch_size, dim = x.size()[:2] + width = weight.size(-1) + + if dx is None: + dx = torch.empty_like(x) + dweight = torch.zeros_like(weight, dtype=torch.float32) + dbias = None + if bias is not None: + dbias = torch.zeros_like(bias, dtype=torch.float32) + dinitial_states = None + if return_dinitial_states: + dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) + + ops.causal_conv1d_bwd( + x=x, + weight=weight, + bias=bias, + dout=dout, + seq_idx=seq_idx, + initial_states=initial_states, + dfinal_states=dfinal_states, + dx=dx, + dweight=dweight, + dbias=dbias, + dinitial_states=dinitial_states, + silu_activation=silu_activation, + ) + + dweight = dweight.type_as(weight) + if dbias is not None: + dbias = dbias.type_as(bias) + return dx, dweight, dbias, dinitial_states + + +def causal_conv1d_update_function( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + silu_activation: bool, + cache_seqlens: torch.Tensor | None, + conv_state_indices: torch.Tensor | None, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_update( + x=x, + conv_state=conv_state, + weight=weight, + bias=bias, + out=out, + silu_activation=silu_activation, + cache_seqlens=cache_seqlens, + conv_state_indices=conv_state_indices, + ) + return out diff --git a/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py @@ -0,0 +1,4 @@ +from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update +from .causal_conv1d_varlen import causal_conv1d_varlen_states + +__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..786bd24d264db90632165d514db4f5521d8133e0 Binary files /dev/null and b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9d95dcbdfd288de6d7077fd18259eb3bddb4958 Binary files /dev/null and b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..297eb3e2197477c1df29aaeaff32a0c4b4ba8611 Binary files /dev/null and b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40dd2dec4c9da46e89fe46f6974c794ad6432a52 Binary files /dev/null and b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6d47e6accd7dc9a8d27df0e15f3363d6579b752 Binary files /dev/null and b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/_causal_conv1d_90f5a60.abi3.so b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/_causal_conv1d_90f5a60.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..770b6a40b54b2b7dc3ff89d20de6dae3cfd5a06a --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/_causal_conv1d_90f5a60.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:299bc47bf7fdea21eb71f9b0d0cd329a32e792106029d4fd5d6c637c76b9c6f7 +size 64213568 diff --git a/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/_ops.py b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..c1d217d97eaddf8812c504cd7ca9656b8b72fba4 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _causal_conv1d_90f5a60 +ops = torch.ops._causal_conv1d_90f5a60 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_causal_conv1d_90f5a60::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/causal_conv1d_interface.py b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/causal_conv1d_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/causal_conv1d_interface.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024, Tri Dao. + +import torch +import torch.nn.functional as F + +from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function + + +class CausalConv1dFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, + ): + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert ( + initial_states is None + ), "initial_states must be None if seq_idx is not None" + assert ( + not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and ( + initial_states.stride(2) != 1 and initial_states.stride(1) != 1 + ): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert ( + final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 + ) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty( + batch, width - 1, dim, device=x.device, dtype=x.dtype + ).transpose(1, 2) + else: + final_states_out = None + ctx.activation = activation in ["silu", "swish"] + out = causal_conv1d_fwd_function( + x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation + ) + ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) + ctx.return_final_states = return_final_states + ctx.return_dinitial_states = ( + initial_states is not None and initial_states.requires_grad + ) + return out if not return_final_states else (out, final_states_out) + + @staticmethod + def backward(ctx, dout, *args): + x, weight, bias, seq_idx, initial_states = ctx.saved_tensors + dfinal_states = args[0] if ctx.return_final_states else None + if dout.stride(2) != 1 and dout.stride(1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). + # Here we just pass in None and dx will be allocated in the C++ code. + dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( + x, + weight, + bias, + dout, + seq_idx, + initial_states, + dfinal_states, + None, + ctx.return_dinitial_states, + ctx.activation, + ) + return ( + dx, + dweight, + dbias if bias is not None else None, + None, + dinitial_states if initial_states is not None else None, + None, + None, + None, + ) + + +def causal_conv1d_fn( + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + return CausalConv1dFn.apply( + x, + weight, + bias, + seq_idx, + initial_states, + return_final_states, + final_states_out, + activation, + ) + + +def causal_conv1d_ref( + x, + weight, + bias=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in + ) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return out if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + out = causal_conv1d_update_function( + x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices + ) + if unsqueeze: + out = out.squeeze(-1) + return out + + +def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + assert conv_state.shape == (batch, dim, state_len) + assert weight.shape == (dim, width) + if cache_seqlens is None: + x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) + conv_state.copy_(x_new[:, :, -state_len:]) + else: + width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) + copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + conv_state.scatter_(2, copy_idx, x) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] + if unsqueeze: + out = out.squeeze(-1) + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/causal_conv1d_varlen.py b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/causal_conv1d_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/causal_conv1d_varlen.py @@ -0,0 +1,86 @@ +import torch +from torch import Tensor + +import triton +import triton.language as tl + + +@triton.jit +def _causal_conv1d_varlen_states( + X, + CU_SEQLENS, + STATES, + state_len, + dim, + stride_x_seqlen, stride_x_dim, + stride_states_batch, stride_states_seqlen, stride_states_dim, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr +): + batch_idx = tl.program_id(2) + STATES += batch_idx * stride_states_batch + end_idx = tl.load(CU_SEQLENS + batch_idx + 1) + start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) + rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, + mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), + other=0) + rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, + x, + mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) + + +def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + BLOCK_M = min(triton.next_power_of_2(state_len), 16) + BLOCK_N = min(triton.next_power_of_2(dim), 256) + grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) + with torch.cuda.device(x.device.index): + _causal_conv1d_varlen_states[grid]( + x, + cu_seqlens, + states, + state_len, + dim, + x.stride(0), x.stride(1), + states.stride(0), states.stride(2), states.stride(1), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + return states + + +def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + for i in range(batch): + end_idx = cu_seqlens[i + 1] + start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) + states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T + return states diff --git a/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/cpp_functions.py b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/cpp_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/cpp_functions.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, Tri Dao. + +import torch + +from ._ops import ops + +def causal_conv1d_fwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + final_states_out: torch.Tensor | None, + silu_activation: bool, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_fwd( + x=x, + weight=weight, + bias=bias, + seq_idx=seq_idx, + initial_states=initial_states, + out=out, + final_states_out=final_states_out, + silu_activation=silu_activation, + ) + return out + + +def causal_conv1d_bwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + dout: torch.Tensor, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + dfinal_states: torch.Tensor | None, + dx: torch.Tensor | None, + return_dinitial_states: torch.Tensor, + silu_activation: bool, +) -> tuple[torch.Tensor | None]: + batch_size, dim = x.size()[:2] + width = weight.size(-1) + + if dx is None: + dx = torch.empty_like(x) + dweight = torch.zeros_like(weight, dtype=torch.float32) + dbias = None + if bias is not None: + dbias = torch.zeros_like(bias, dtype=torch.float32) + dinitial_states = None + if return_dinitial_states: + dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) + + ops.causal_conv1d_bwd( + x=x, + weight=weight, + bias=bias, + dout=dout, + seq_idx=seq_idx, + initial_states=initial_states, + dfinal_states=dfinal_states, + dx=dx, + dweight=dweight, + dbias=dbias, + dinitial_states=dinitial_states, + silu_activation=silu_activation, + ) + + dweight = dweight.type_as(weight) + if dbias is not None: + dbias = dbias.type_as(bias) + return dx, dweight, dbias, dinitial_states + + +def causal_conv1d_update_function( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + silu_activation: bool, + cache_seqlens: torch.Tensor | None, + conv_state_indices: torch.Tensor | None, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_update( + x=x, + conv_state=conv_state, + weight=weight, + bias=bias, + out=out, + silu_activation=silu_activation, + cache_seqlens=cache_seqlens, + conv_state_indices=conv_state_indices, + ) + return out diff --git a/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__init__.py b/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c --- /dev/null +++ b/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__init__.py @@ -0,0 +1,4 @@ +from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update +from .causal_conv1d_varlen import causal_conv1d_varlen_states + +__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc b/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f5e5e7f7807ad1ef65b13252d2c019610e63ea1 Binary files /dev/null and b/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc b/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc420da27d8204045a0413fb32b05b30f65d7ed4 Binary files /dev/null and b/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc b/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06007695980ef777fd26f660043ecdb3d633b439 Binary files /dev/null and b/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc b/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ade83fbdcf1c14aa347c29d3b12afa3833385854 Binary files /dev/null and b/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc b/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d7c1d92c27a92656d2b0fe8c5fd52fbe61f15d6f Binary files /dev/null and b/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/_causal_conv1d_306ae84.abi3.so b/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/_causal_conv1d_306ae84.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..40a950b0a1cf66eedb98e1fe4bd018c280a29005 --- /dev/null +++ b/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/_causal_conv1d_306ae84.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7ec3c41413afbb69d499eae6a432fa9d41a580e7b1c6ee83d09e8dab51f91803 +size 90795560 diff --git a/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/_ops.py b/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..2257797ae235d25abede0851de00a59f5220a87d --- /dev/null +++ b/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _causal_conv1d_306ae84 +ops = torch.ops._causal_conv1d_306ae84 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_causal_conv1d_306ae84::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/causal_conv1d_interface.py b/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/causal_conv1d_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377 --- /dev/null +++ b/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/causal_conv1d_interface.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024, Tri Dao. + +import torch +import torch.nn.functional as F + +from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function + + +class CausalConv1dFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, + ): + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert ( + initial_states is None + ), "initial_states must be None if seq_idx is not None" + assert ( + not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and ( + initial_states.stride(2) != 1 and initial_states.stride(1) != 1 + ): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert ( + final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 + ) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty( + batch, width - 1, dim, device=x.device, dtype=x.dtype + ).transpose(1, 2) + else: + final_states_out = None + ctx.activation = activation in ["silu", "swish"] + out = causal_conv1d_fwd_function( + x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation + ) + ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) + ctx.return_final_states = return_final_states + ctx.return_dinitial_states = ( + initial_states is not None and initial_states.requires_grad + ) + return out if not return_final_states else (out, final_states_out) + + @staticmethod + def backward(ctx, dout, *args): + x, weight, bias, seq_idx, initial_states = ctx.saved_tensors + dfinal_states = args[0] if ctx.return_final_states else None + if dout.stride(2) != 1 and dout.stride(1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). + # Here we just pass in None and dx will be allocated in the C++ code. + dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( + x, + weight, + bias, + dout, + seq_idx, + initial_states, + dfinal_states, + None, + ctx.return_dinitial_states, + ctx.activation, + ) + return ( + dx, + dweight, + dbias if bias is not None else None, + None, + dinitial_states if initial_states is not None else None, + None, + None, + None, + ) + + +def causal_conv1d_fn( + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + return CausalConv1dFn.apply( + x, + weight, + bias, + seq_idx, + initial_states, + return_final_states, + final_states_out, + activation, + ) + + +def causal_conv1d_ref( + x, + weight, + bias=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in + ) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return out if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + out = causal_conv1d_update_function( + x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices + ) + if unsqueeze: + out = out.squeeze(-1) + return out + + +def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + assert conv_state.shape == (batch, dim, state_len) + assert weight.shape == (dim, width) + if cache_seqlens is None: + x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) + conv_state.copy_(x_new[:, :, -state_len:]) + else: + width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) + copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + conv_state.scatter_(2, copy_idx, x) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] + if unsqueeze: + out = out.squeeze(-1) + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/causal_conv1d_varlen.py b/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/causal_conv1d_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb --- /dev/null +++ b/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/causal_conv1d_varlen.py @@ -0,0 +1,86 @@ +import torch +from torch import Tensor + +import triton +import triton.language as tl + + +@triton.jit +def _causal_conv1d_varlen_states( + X, + CU_SEQLENS, + STATES, + state_len, + dim, + stride_x_seqlen, stride_x_dim, + stride_states_batch, stride_states_seqlen, stride_states_dim, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr +): + batch_idx = tl.program_id(2) + STATES += batch_idx * stride_states_batch + end_idx = tl.load(CU_SEQLENS + batch_idx + 1) + start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) + rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, + mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), + other=0) + rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, + x, + mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) + + +def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + BLOCK_M = min(triton.next_power_of_2(state_len), 16) + BLOCK_N = min(triton.next_power_of_2(dim), 256) + grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) + with torch.cuda.device(x.device.index): + _causal_conv1d_varlen_states[grid]( + x, + cu_seqlens, + states, + state_len, + dim, + x.stride(0), x.stride(1), + states.stride(0), states.stride(2), states.stride(1), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + return states + + +def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + for i in range(batch): + end_idx = cu_seqlens[i + 1] + start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) + states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T + return states diff --git a/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/cpp_functions.py b/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/cpp_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b --- /dev/null +++ b/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/cpp_functions.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, Tri Dao. + +import torch + +from ._ops import ops + +def causal_conv1d_fwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + final_states_out: torch.Tensor | None, + silu_activation: bool, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_fwd( + x=x, + weight=weight, + bias=bias, + seq_idx=seq_idx, + initial_states=initial_states, + out=out, + final_states_out=final_states_out, + silu_activation=silu_activation, + ) + return out + + +def causal_conv1d_bwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + dout: torch.Tensor, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + dfinal_states: torch.Tensor | None, + dx: torch.Tensor | None, + return_dinitial_states: torch.Tensor, + silu_activation: bool, +) -> tuple[torch.Tensor | None]: + batch_size, dim = x.size()[:2] + width = weight.size(-1) + + if dx is None: + dx = torch.empty_like(x) + dweight = torch.zeros_like(weight, dtype=torch.float32) + dbias = None + if bias is not None: + dbias = torch.zeros_like(bias, dtype=torch.float32) + dinitial_states = None + if return_dinitial_states: + dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) + + ops.causal_conv1d_bwd( + x=x, + weight=weight, + bias=bias, + dout=dout, + seq_idx=seq_idx, + initial_states=initial_states, + dfinal_states=dfinal_states, + dx=dx, + dweight=dweight, + dbias=dbias, + dinitial_states=dinitial_states, + silu_activation=silu_activation, + ) + + dweight = dweight.type_as(weight) + if dbias is not None: + dbias = dbias.type_as(bias) + return dx, dweight, dbias, dinitial_states + + +def causal_conv1d_update_function( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + silu_activation: bool, + cache_seqlens: torch.Tensor | None, + conv_state_indices: torch.Tensor | None, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_update( + x=x, + conv_state=conv_state, + weight=weight, + bias=bias, + out=out, + silu_activation=silu_activation, + cache_seqlens=cache_seqlens, + conv_state_indices=conv_state_indices, + ) + return out diff --git a/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py @@ -0,0 +1,4 @@ +from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update +from .causal_conv1d_varlen import causal_conv1d_varlen_states + +__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18d3139112b33e61419b6a2728a795c6a358861e Binary files /dev/null and b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2857339f18ec0febb5d24266e3d91fbbe2fa820c Binary files /dev/null and b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bef01edf5863f420831c4fd7b62444df502bf29e Binary files /dev/null and b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb24c381faa7b75b6519cd48d85261bec5d03f1d Binary files /dev/null and b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7283c7bf629a1f669d5802a0aecc629cfcda5eb0 Binary files /dev/null and b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/_causal_conv1d_90f5a60.abi3.so b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/_causal_conv1d_90f5a60.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..bc680b2d853196f8878bb9cb5f6d73bce3cecb80 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/_causal_conv1d_90f5a60.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:859247ab0b3e7852c4e1a6ac76f3c62b3aebea729241058075eb2e6f29139a50 +size 90656256 diff --git a/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/_ops.py b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..c1d217d97eaddf8812c504cd7ca9656b8b72fba4 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _causal_conv1d_90f5a60 +ops = torch.ops._causal_conv1d_90f5a60 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_causal_conv1d_90f5a60::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/causal_conv1d_interface.py b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/causal_conv1d_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/causal_conv1d_interface.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024, Tri Dao. + +import torch +import torch.nn.functional as F + +from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function + + +class CausalConv1dFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, + ): + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert ( + initial_states is None + ), "initial_states must be None if seq_idx is not None" + assert ( + not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and ( + initial_states.stride(2) != 1 and initial_states.stride(1) != 1 + ): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert ( + final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 + ) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty( + batch, width - 1, dim, device=x.device, dtype=x.dtype + ).transpose(1, 2) + else: + final_states_out = None + ctx.activation = activation in ["silu", "swish"] + out = causal_conv1d_fwd_function( + x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation + ) + ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) + ctx.return_final_states = return_final_states + ctx.return_dinitial_states = ( + initial_states is not None and initial_states.requires_grad + ) + return out if not return_final_states else (out, final_states_out) + + @staticmethod + def backward(ctx, dout, *args): + x, weight, bias, seq_idx, initial_states = ctx.saved_tensors + dfinal_states = args[0] if ctx.return_final_states else None + if dout.stride(2) != 1 and dout.stride(1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). + # Here we just pass in None and dx will be allocated in the C++ code. + dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( + x, + weight, + bias, + dout, + seq_idx, + initial_states, + dfinal_states, + None, + ctx.return_dinitial_states, + ctx.activation, + ) + return ( + dx, + dweight, + dbias if bias is not None else None, + None, + dinitial_states if initial_states is not None else None, + None, + None, + None, + ) + + +def causal_conv1d_fn( + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + return CausalConv1dFn.apply( + x, + weight, + bias, + seq_idx, + initial_states, + return_final_states, + final_states_out, + activation, + ) + + +def causal_conv1d_ref( + x, + weight, + bias=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in + ) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return out if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + out = causal_conv1d_update_function( + x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices + ) + if unsqueeze: + out = out.squeeze(-1) + return out + + +def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + assert conv_state.shape == (batch, dim, state_len) + assert weight.shape == (dim, width) + if cache_seqlens is None: + x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) + conv_state.copy_(x_new[:, :, -state_len:]) + else: + width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) + copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + conv_state.scatter_(2, copy_idx, x) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] + if unsqueeze: + out = out.squeeze(-1) + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/causal_conv1d_varlen.py b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/causal_conv1d_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/causal_conv1d_varlen.py @@ -0,0 +1,86 @@ +import torch +from torch import Tensor + +import triton +import triton.language as tl + + +@triton.jit +def _causal_conv1d_varlen_states( + X, + CU_SEQLENS, + STATES, + state_len, + dim, + stride_x_seqlen, stride_x_dim, + stride_states_batch, stride_states_seqlen, stride_states_dim, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr +): + batch_idx = tl.program_id(2) + STATES += batch_idx * stride_states_batch + end_idx = tl.load(CU_SEQLENS + batch_idx + 1) + start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) + rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, + mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), + other=0) + rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, + x, + mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) + + +def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + BLOCK_M = min(triton.next_power_of_2(state_len), 16) + BLOCK_N = min(triton.next_power_of_2(dim), 256) + grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) + with torch.cuda.device(x.device.index): + _causal_conv1d_varlen_states[grid]( + x, + cu_seqlens, + states, + state_len, + dim, + x.stride(0), x.stride(1), + states.stride(0), states.stride(2), states.stride(1), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + return states + + +def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + for i in range(batch): + end_idx = cu_seqlens[i + 1] + start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) + states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T + return states diff --git a/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/cpp_functions.py b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/cpp_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/cpp_functions.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, Tri Dao. + +import torch + +from ._ops import ops + +def causal_conv1d_fwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + final_states_out: torch.Tensor | None, + silu_activation: bool, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_fwd( + x=x, + weight=weight, + bias=bias, + seq_idx=seq_idx, + initial_states=initial_states, + out=out, + final_states_out=final_states_out, + silu_activation=silu_activation, + ) + return out + + +def causal_conv1d_bwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + dout: torch.Tensor, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + dfinal_states: torch.Tensor | None, + dx: torch.Tensor | None, + return_dinitial_states: torch.Tensor, + silu_activation: bool, +) -> tuple[torch.Tensor | None]: + batch_size, dim = x.size()[:2] + width = weight.size(-1) + + if dx is None: + dx = torch.empty_like(x) + dweight = torch.zeros_like(weight, dtype=torch.float32) + dbias = None + if bias is not None: + dbias = torch.zeros_like(bias, dtype=torch.float32) + dinitial_states = None + if return_dinitial_states: + dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) + + ops.causal_conv1d_bwd( + x=x, + weight=weight, + bias=bias, + dout=dout, + seq_idx=seq_idx, + initial_states=initial_states, + dfinal_states=dfinal_states, + dx=dx, + dweight=dweight, + dbias=dbias, + dinitial_states=dinitial_states, + silu_activation=silu_activation, + ) + + dweight = dweight.type_as(weight) + if dbias is not None: + dbias = dbias.type_as(bias) + return dx, dweight, dbias, dinitial_states + + +def causal_conv1d_update_function( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + silu_activation: bool, + cache_seqlens: torch.Tensor | None, + conv_state_indices: torch.Tensor | None, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_update( + x=x, + conv_state=conv_state, + weight=weight, + bias=bias, + out=out, + silu_activation=silu_activation, + cache_seqlens=cache_seqlens, + conv_state_indices=conv_state_indices, + ) + return out diff --git a/build/torch28-cxx11-cu126-x86_64-linux/__init__.py b/build/torch28-cxx11-cu126-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/__init__.py @@ -0,0 +1,4 @@ +from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update +from .causal_conv1d_varlen import causal_conv1d_varlen_states + +__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_causal_conv1d_e7e5852.abi3.so b/build/torch28-cxx11-cu126-x86_64-linux/_causal_conv1d_e7e5852.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..06feae7dcffaefead4119f47cf223bda523ef0b8 --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/_causal_conv1d_e7e5852.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e8a4c3c1eb4c667ed0ef6affd83922fc4b76d96c491b26afb012bbb4e84ac245 +size 80684768 diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_ops.py b/build/torch28-cxx11-cu126-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..68c9592575922d5a8d400f767f6d5f31fa8dbcb3 --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _causal_conv1d_e7e5852 +ops = torch.ops._causal_conv1d_e7e5852 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_causal_conv1d_e7e5852::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d_interface.py b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377 --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d_interface.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024, Tri Dao. + +import torch +import torch.nn.functional as F + +from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function + + +class CausalConv1dFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, + ): + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert ( + initial_states is None + ), "initial_states must be None if seq_idx is not None" + assert ( + not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and ( + initial_states.stride(2) != 1 and initial_states.stride(1) != 1 + ): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert ( + final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 + ) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty( + batch, width - 1, dim, device=x.device, dtype=x.dtype + ).transpose(1, 2) + else: + final_states_out = None + ctx.activation = activation in ["silu", "swish"] + out = causal_conv1d_fwd_function( + x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation + ) + ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) + ctx.return_final_states = return_final_states + ctx.return_dinitial_states = ( + initial_states is not None and initial_states.requires_grad + ) + return out if not return_final_states else (out, final_states_out) + + @staticmethod + def backward(ctx, dout, *args): + x, weight, bias, seq_idx, initial_states = ctx.saved_tensors + dfinal_states = args[0] if ctx.return_final_states else None + if dout.stride(2) != 1 and dout.stride(1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). + # Here we just pass in None and dx will be allocated in the C++ code. + dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( + x, + weight, + bias, + dout, + seq_idx, + initial_states, + dfinal_states, + None, + ctx.return_dinitial_states, + ctx.activation, + ) + return ( + dx, + dweight, + dbias if bias is not None else None, + None, + dinitial_states if initial_states is not None else None, + None, + None, + None, + ) + + +def causal_conv1d_fn( + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + return CausalConv1dFn.apply( + x, + weight, + bias, + seq_idx, + initial_states, + return_final_states, + final_states_out, + activation, + ) + + +def causal_conv1d_ref( + x, + weight, + bias=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in + ) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return out if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + out = causal_conv1d_update_function( + x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices + ) + if unsqueeze: + out = out.squeeze(-1) + return out + + +def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + assert conv_state.shape == (batch, dim, state_len) + assert weight.shape == (dim, width) + if cache_seqlens is None: + x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) + conv_state.copy_(x_new[:, :, -state_len:]) + else: + width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) + copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + conv_state.scatter_(2, copy_idx, x) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] + if unsqueeze: + out = out.squeeze(-1) + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d_varlen.py b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d_varlen.py @@ -0,0 +1,86 @@ +import torch +from torch import Tensor + +import triton +import triton.language as tl + + +@triton.jit +def _causal_conv1d_varlen_states( + X, + CU_SEQLENS, + STATES, + state_len, + dim, + stride_x_seqlen, stride_x_dim, + stride_states_batch, stride_states_seqlen, stride_states_dim, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr +): + batch_idx = tl.program_id(2) + STATES += batch_idx * stride_states_batch + end_idx = tl.load(CU_SEQLENS + batch_idx + 1) + start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) + rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, + mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), + other=0) + rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, + x, + mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) + + +def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + BLOCK_M = min(triton.next_power_of_2(state_len), 16) + BLOCK_N = min(triton.next_power_of_2(dim), 256) + grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) + with torch.cuda.device(x.device.index): + _causal_conv1d_varlen_states[grid]( + x, + cu_seqlens, + states, + state_len, + dim, + x.stride(0), x.stride(1), + states.stride(0), states.stride(2), states.stride(1), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + return states + + +def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + for i in range(batch): + end_idx = cu_seqlens[i + 1] + start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) + states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T + return states diff --git a/build/torch28-cxx11-cu126-x86_64-linux/cpp_functions.py b/build/torch28-cxx11-cu126-x86_64-linux/cpp_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/cpp_functions.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, Tri Dao. + +import torch + +from ._ops import ops + +def causal_conv1d_fwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + final_states_out: torch.Tensor | None, + silu_activation: bool, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_fwd( + x=x, + weight=weight, + bias=bias, + seq_idx=seq_idx, + initial_states=initial_states, + out=out, + final_states_out=final_states_out, + silu_activation=silu_activation, + ) + return out + + +def causal_conv1d_bwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + dout: torch.Tensor, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + dfinal_states: torch.Tensor | None, + dx: torch.Tensor | None, + return_dinitial_states: torch.Tensor, + silu_activation: bool, +) -> tuple[torch.Tensor | None]: + batch_size, dim = x.size()[:2] + width = weight.size(-1) + + if dx is None: + dx = torch.empty_like(x) + dweight = torch.zeros_like(weight, dtype=torch.float32) + dbias = None + if bias is not None: + dbias = torch.zeros_like(bias, dtype=torch.float32) + dinitial_states = None + if return_dinitial_states: + dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) + + ops.causal_conv1d_bwd( + x=x, + weight=weight, + bias=bias, + dout=dout, + seq_idx=seq_idx, + initial_states=initial_states, + dfinal_states=dfinal_states, + dx=dx, + dweight=dweight, + dbias=dbias, + dinitial_states=dinitial_states, + silu_activation=silu_activation, + ) + + dweight = dweight.type_as(weight) + if dbias is not None: + dbias = dbias.type_as(bias) + return dx, dweight, dbias, dinitial_states + + +def causal_conv1d_update_function( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + silu_activation: bool, + cache_seqlens: torch.Tensor | None, + conv_state_indices: torch.Tensor | None, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_update( + x=x, + conv_state=conv_state, + weight=weight, + bias=bias, + out=out, + silu_activation=silu_activation, + cache_seqlens=cache_seqlens, + conv_state_indices=conv_state_indices, + ) + return out diff --git a/build/torch28-cxx11-cu126-x86_64-linux/metadata.json b/build/torch28-cxx11-cu126-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..9cf5deed9898dce769f4cc73913d3530b92a0bd8 --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/metadata.json @@ -0,0 +1,4 @@ +{ + "version": 1, + "python-depends": [] +} \ No newline at end of file diff --git a/build/torch28-cxx11-cu128-x86_64-linux/__init__.py b/build/torch28-cxx11-cu128-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/__init__.py @@ -0,0 +1,4 @@ +from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update +from .causal_conv1d_varlen import causal_conv1d_varlen_states + +__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_causal_conv1d_e7e5852.abi3.so b/build/torch28-cxx11-cu128-x86_64-linux/_causal_conv1d_e7e5852.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..75b81a3487cf14abd1be1a8f1db4cfe11db65fc8 --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/_causal_conv1d_e7e5852.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:488a09e6d74f4f4f8b6c0dfd26df892dab4e8bc2283a2e95be24c65ed043ec70 +size 107168432 diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_ops.py b/build/torch28-cxx11-cu128-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..68c9592575922d5a8d400f767f6d5f31fa8dbcb3 --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _causal_conv1d_e7e5852 +ops = torch.ops._causal_conv1d_e7e5852 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_causal_conv1d_e7e5852::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d_interface.py b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377 --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d_interface.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024, Tri Dao. + +import torch +import torch.nn.functional as F + +from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function + + +class CausalConv1dFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, + ): + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert ( + initial_states is None + ), "initial_states must be None if seq_idx is not None" + assert ( + not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and ( + initial_states.stride(2) != 1 and initial_states.stride(1) != 1 + ): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert ( + final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 + ) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty( + batch, width - 1, dim, device=x.device, dtype=x.dtype + ).transpose(1, 2) + else: + final_states_out = None + ctx.activation = activation in ["silu", "swish"] + out = causal_conv1d_fwd_function( + x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation + ) + ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) + ctx.return_final_states = return_final_states + ctx.return_dinitial_states = ( + initial_states is not None and initial_states.requires_grad + ) + return out if not return_final_states else (out, final_states_out) + + @staticmethod + def backward(ctx, dout, *args): + x, weight, bias, seq_idx, initial_states = ctx.saved_tensors + dfinal_states = args[0] if ctx.return_final_states else None + if dout.stride(2) != 1 and dout.stride(1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). + # Here we just pass in None and dx will be allocated in the C++ code. + dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( + x, + weight, + bias, + dout, + seq_idx, + initial_states, + dfinal_states, + None, + ctx.return_dinitial_states, + ctx.activation, + ) + return ( + dx, + dweight, + dbias if bias is not None else None, + None, + dinitial_states if initial_states is not None else None, + None, + None, + None, + ) + + +def causal_conv1d_fn( + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + return CausalConv1dFn.apply( + x, + weight, + bias, + seq_idx, + initial_states, + return_final_states, + final_states_out, + activation, + ) + + +def causal_conv1d_ref( + x, + weight, + bias=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in + ) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return out if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + out = causal_conv1d_update_function( + x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices + ) + if unsqueeze: + out = out.squeeze(-1) + return out + + +def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + assert conv_state.shape == (batch, dim, state_len) + assert weight.shape == (dim, width) + if cache_seqlens is None: + x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) + conv_state.copy_(x_new[:, :, -state_len:]) + else: + width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) + copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + conv_state.scatter_(2, copy_idx, x) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] + if unsqueeze: + out = out.squeeze(-1) + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d_varlen.py b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d_varlen.py @@ -0,0 +1,86 @@ +import torch +from torch import Tensor + +import triton +import triton.language as tl + + +@triton.jit +def _causal_conv1d_varlen_states( + X, + CU_SEQLENS, + STATES, + state_len, + dim, + stride_x_seqlen, stride_x_dim, + stride_states_batch, stride_states_seqlen, stride_states_dim, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr +): + batch_idx = tl.program_id(2) + STATES += batch_idx * stride_states_batch + end_idx = tl.load(CU_SEQLENS + batch_idx + 1) + start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) + rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, + mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), + other=0) + rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, + x, + mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) + + +def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + BLOCK_M = min(triton.next_power_of_2(state_len), 16) + BLOCK_N = min(triton.next_power_of_2(dim), 256) + grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) + with torch.cuda.device(x.device.index): + _causal_conv1d_varlen_states[grid]( + x, + cu_seqlens, + states, + state_len, + dim, + x.stride(0), x.stride(1), + states.stride(0), states.stride(2), states.stride(1), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + return states + + +def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + for i in range(batch): + end_idx = cu_seqlens[i + 1] + start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) + states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T + return states diff --git a/build/torch28-cxx11-cu128-x86_64-linux/cpp_functions.py b/build/torch28-cxx11-cu128-x86_64-linux/cpp_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/cpp_functions.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, Tri Dao. + +import torch + +from ._ops import ops + +def causal_conv1d_fwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + final_states_out: torch.Tensor | None, + silu_activation: bool, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_fwd( + x=x, + weight=weight, + bias=bias, + seq_idx=seq_idx, + initial_states=initial_states, + out=out, + final_states_out=final_states_out, + silu_activation=silu_activation, + ) + return out + + +def causal_conv1d_bwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + dout: torch.Tensor, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + dfinal_states: torch.Tensor | None, + dx: torch.Tensor | None, + return_dinitial_states: torch.Tensor, + silu_activation: bool, +) -> tuple[torch.Tensor | None]: + batch_size, dim = x.size()[:2] + width = weight.size(-1) + + if dx is None: + dx = torch.empty_like(x) + dweight = torch.zeros_like(weight, dtype=torch.float32) + dbias = None + if bias is not None: + dbias = torch.zeros_like(bias, dtype=torch.float32) + dinitial_states = None + if return_dinitial_states: + dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) + + ops.causal_conv1d_bwd( + x=x, + weight=weight, + bias=bias, + dout=dout, + seq_idx=seq_idx, + initial_states=initial_states, + dfinal_states=dfinal_states, + dx=dx, + dweight=dweight, + dbias=dbias, + dinitial_states=dinitial_states, + silu_activation=silu_activation, + ) + + dweight = dweight.type_as(weight) + if dbias is not None: + dbias = dbias.type_as(bias) + return dx, dweight, dbias, dinitial_states + + +def causal_conv1d_update_function( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + silu_activation: bool, + cache_seqlens: torch.Tensor | None, + conv_state_indices: torch.Tensor | None, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_update( + x=x, + conv_state=conv_state, + weight=weight, + bias=bias, + out=out, + silu_activation=silu_activation, + cache_seqlens=cache_seqlens, + conv_state_indices=conv_state_indices, + ) + return out diff --git a/build/torch28-cxx11-cu128-x86_64-linux/metadata.json b/build/torch28-cxx11-cu128-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..9cf5deed9898dce769f4cc73913d3530b92a0bd8 --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/metadata.json @@ -0,0 +1,4 @@ +{ + "version": 1, + "python-depends": [] +} \ No newline at end of file diff --git a/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__init__.py b/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c --- /dev/null +++ b/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__init__.py @@ -0,0 +1,4 @@ +from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update +from .causal_conv1d_varlen import causal_conv1d_varlen_states + +__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc b/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd27c508caa49bc189bc90b3dadc231801ff39a7 Binary files /dev/null and b/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc b/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7052f3156badccc6818d6e61442177ee29be2e61 Binary files /dev/null and b/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc b/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4d41ce73daaeb59359e6dc777f3b5a596aa8ddd Binary files /dev/null and b/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc b/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7df49a782e7c05722c1b4f05c72b7a40ae91e057 Binary files /dev/null and b/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc b/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e8599cda250865c0f80bce799b2593833c99384c Binary files /dev/null and b/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/_causal_conv1d_306ae84.abi3.so b/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/_causal_conv1d_306ae84.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..e26a04f6ebfd8fa02bf23fc08727ce37bea5a617 --- /dev/null +++ b/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/_causal_conv1d_306ae84.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4c3535b795cbc5baf363b0cd8636649153b017599c9992ea6c08aa4ab23ceae0 +size 97678768 diff --git a/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/_ops.py b/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..2257797ae235d25abede0851de00a59f5220a87d --- /dev/null +++ b/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _causal_conv1d_306ae84 +ops = torch.ops._causal_conv1d_306ae84 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_causal_conv1d_306ae84::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/causal_conv1d_interface.py b/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/causal_conv1d_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377 --- /dev/null +++ b/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/causal_conv1d_interface.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024, Tri Dao. + +import torch +import torch.nn.functional as F + +from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function + + +class CausalConv1dFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, + ): + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert ( + initial_states is None + ), "initial_states must be None if seq_idx is not None" + assert ( + not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and ( + initial_states.stride(2) != 1 and initial_states.stride(1) != 1 + ): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert ( + final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 + ) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty( + batch, width - 1, dim, device=x.device, dtype=x.dtype + ).transpose(1, 2) + else: + final_states_out = None + ctx.activation = activation in ["silu", "swish"] + out = causal_conv1d_fwd_function( + x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation + ) + ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) + ctx.return_final_states = return_final_states + ctx.return_dinitial_states = ( + initial_states is not None and initial_states.requires_grad + ) + return out if not return_final_states else (out, final_states_out) + + @staticmethod + def backward(ctx, dout, *args): + x, weight, bias, seq_idx, initial_states = ctx.saved_tensors + dfinal_states = args[0] if ctx.return_final_states else None + if dout.stride(2) != 1 and dout.stride(1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). + # Here we just pass in None and dx will be allocated in the C++ code. + dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( + x, + weight, + bias, + dout, + seq_idx, + initial_states, + dfinal_states, + None, + ctx.return_dinitial_states, + ctx.activation, + ) + return ( + dx, + dweight, + dbias if bias is not None else None, + None, + dinitial_states if initial_states is not None else None, + None, + None, + None, + ) + + +def causal_conv1d_fn( + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + return CausalConv1dFn.apply( + x, + weight, + bias, + seq_idx, + initial_states, + return_final_states, + final_states_out, + activation, + ) + + +def causal_conv1d_ref( + x, + weight, + bias=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in + ) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return out if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + out = causal_conv1d_update_function( + x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices + ) + if unsqueeze: + out = out.squeeze(-1) + return out + + +def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + assert conv_state.shape == (batch, dim, state_len) + assert weight.shape == (dim, width) + if cache_seqlens is None: + x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) + conv_state.copy_(x_new[:, :, -state_len:]) + else: + width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) + copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + conv_state.scatter_(2, copy_idx, x) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] + if unsqueeze: + out = out.squeeze(-1) + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/causal_conv1d_varlen.py b/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/causal_conv1d_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb --- /dev/null +++ b/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/causal_conv1d_varlen.py @@ -0,0 +1,86 @@ +import torch +from torch import Tensor + +import triton +import triton.language as tl + + +@triton.jit +def _causal_conv1d_varlen_states( + X, + CU_SEQLENS, + STATES, + state_len, + dim, + stride_x_seqlen, stride_x_dim, + stride_states_batch, stride_states_seqlen, stride_states_dim, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr +): + batch_idx = tl.program_id(2) + STATES += batch_idx * stride_states_batch + end_idx = tl.load(CU_SEQLENS + batch_idx + 1) + start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) + rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, + mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), + other=0) + rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, + x, + mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) + + +def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + BLOCK_M = min(triton.next_power_of_2(state_len), 16) + BLOCK_N = min(triton.next_power_of_2(dim), 256) + grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) + with torch.cuda.device(x.device.index): + _causal_conv1d_varlen_states[grid]( + x, + cu_seqlens, + states, + state_len, + dim, + x.stride(0), x.stride(1), + states.stride(0), states.stride(2), states.stride(1), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + return states + + +def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + for i in range(batch): + end_idx = cu_seqlens[i + 1] + start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) + states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T + return states diff --git a/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/cpp_functions.py b/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/cpp_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b --- /dev/null +++ b/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/cpp_functions.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, Tri Dao. + +import torch + +from ._ops import ops + +def causal_conv1d_fwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + final_states_out: torch.Tensor | None, + silu_activation: bool, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_fwd( + x=x, + weight=weight, + bias=bias, + seq_idx=seq_idx, + initial_states=initial_states, + out=out, + final_states_out=final_states_out, + silu_activation=silu_activation, + ) + return out + + +def causal_conv1d_bwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + dout: torch.Tensor, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + dfinal_states: torch.Tensor | None, + dx: torch.Tensor | None, + return_dinitial_states: torch.Tensor, + silu_activation: bool, +) -> tuple[torch.Tensor | None]: + batch_size, dim = x.size()[:2] + width = weight.size(-1) + + if dx is None: + dx = torch.empty_like(x) + dweight = torch.zeros_like(weight, dtype=torch.float32) + dbias = None + if bias is not None: + dbias = torch.zeros_like(bias, dtype=torch.float32) + dinitial_states = None + if return_dinitial_states: + dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) + + ops.causal_conv1d_bwd( + x=x, + weight=weight, + bias=bias, + dout=dout, + seq_idx=seq_idx, + initial_states=initial_states, + dfinal_states=dfinal_states, + dx=dx, + dweight=dweight, + dbias=dbias, + dinitial_states=dinitial_states, + silu_activation=silu_activation, + ) + + dweight = dweight.type_as(weight) + if dbias is not None: + dbias = dbias.type_as(bias) + return dx, dweight, dbias, dinitial_states + + +def causal_conv1d_update_function( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + silu_activation: bool, + cache_seqlens: torch.Tensor | None, + conv_state_indices: torch.Tensor | None, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_update( + x=x, + conv_state=conv_state, + weight=weight, + bias=bias, + out=out, + silu_activation=silu_activation, + cache_seqlens=cache_seqlens, + conv_state_indices=conv_state_indices, + ) + return out diff --git a/build/torch28-cxx11-cu129-x86_64-linux/__init__.py b/build/torch28-cxx11-cu129-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/__init__.py @@ -0,0 +1,4 @@ +from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update +from .causal_conv1d_varlen import causal_conv1d_varlen_states + +__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_causal_conv1d_e7e5852.abi3.so b/build/torch28-cxx11-cu129-x86_64-linux/_causal_conv1d_e7e5852.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..d50e82bfa5604d3b8a591901aed86e54a7e07afc --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/_causal_conv1d_e7e5852.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dbee5d58b825b18e0751347cc6ed27982623257ee037e3a9c3da47bee3dd8f53 +size 115140584 diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_ops.py b/build/torch28-cxx11-cu129-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..68c9592575922d5a8d400f767f6d5f31fa8dbcb3 --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _causal_conv1d_e7e5852 +ops = torch.ops._causal_conv1d_e7e5852 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_causal_conv1d_e7e5852::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__init__.py b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d_interface.py b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377 --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d_interface.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024, Tri Dao. + +import torch +import torch.nn.functional as F + +from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function + + +class CausalConv1dFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, + ): + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert ( + initial_states is None + ), "initial_states must be None if seq_idx is not None" + assert ( + not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and ( + initial_states.stride(2) != 1 and initial_states.stride(1) != 1 + ): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert ( + final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 + ) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty( + batch, width - 1, dim, device=x.device, dtype=x.dtype + ).transpose(1, 2) + else: + final_states_out = None + ctx.activation = activation in ["silu", "swish"] + out = causal_conv1d_fwd_function( + x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation + ) + ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) + ctx.return_final_states = return_final_states + ctx.return_dinitial_states = ( + initial_states is not None and initial_states.requires_grad + ) + return out if not return_final_states else (out, final_states_out) + + @staticmethod + def backward(ctx, dout, *args): + x, weight, bias, seq_idx, initial_states = ctx.saved_tensors + dfinal_states = args[0] if ctx.return_final_states else None + if dout.stride(2) != 1 and dout.stride(1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). + # Here we just pass in None and dx will be allocated in the C++ code. + dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( + x, + weight, + bias, + dout, + seq_idx, + initial_states, + dfinal_states, + None, + ctx.return_dinitial_states, + ctx.activation, + ) + return ( + dx, + dweight, + dbias if bias is not None else None, + None, + dinitial_states if initial_states is not None else None, + None, + None, + None, + ) + + +def causal_conv1d_fn( + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + return CausalConv1dFn.apply( + x, + weight, + bias, + seq_idx, + initial_states, + return_final_states, + final_states_out, + activation, + ) + + +def causal_conv1d_ref( + x, + weight, + bias=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in + ) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return out if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + out = causal_conv1d_update_function( + x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices + ) + if unsqueeze: + out = out.squeeze(-1) + return out + + +def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + assert conv_state.shape == (batch, dim, state_len) + assert weight.shape == (dim, width) + if cache_seqlens is None: + x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) + conv_state.copy_(x_new[:, :, -state_len:]) + else: + width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) + copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + conv_state.scatter_(2, copy_idx, x) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] + if unsqueeze: + out = out.squeeze(-1) + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d_varlen.py b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d_varlen.py @@ -0,0 +1,86 @@ +import torch +from torch import Tensor + +import triton +import triton.language as tl + + +@triton.jit +def _causal_conv1d_varlen_states( + X, + CU_SEQLENS, + STATES, + state_len, + dim, + stride_x_seqlen, stride_x_dim, + stride_states_batch, stride_states_seqlen, stride_states_dim, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr +): + batch_idx = tl.program_id(2) + STATES += batch_idx * stride_states_batch + end_idx = tl.load(CU_SEQLENS + batch_idx + 1) + start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) + rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, + mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), + other=0) + rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, + x, + mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) + + +def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + BLOCK_M = min(triton.next_power_of_2(state_len), 16) + BLOCK_N = min(triton.next_power_of_2(dim), 256) + grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) + with torch.cuda.device(x.device.index): + _causal_conv1d_varlen_states[grid]( + x, + cu_seqlens, + states, + state_len, + dim, + x.stride(0), x.stride(1), + states.stride(0), states.stride(2), states.stride(1), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + return states + + +def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + for i in range(batch): + end_idx = cu_seqlens[i + 1] + start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) + states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T + return states diff --git a/build/torch28-cxx11-cu129-x86_64-linux/cpp_functions.py b/build/torch28-cxx11-cu129-x86_64-linux/cpp_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/cpp_functions.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, Tri Dao. + +import torch + +from ._ops import ops + +def causal_conv1d_fwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + final_states_out: torch.Tensor | None, + silu_activation: bool, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_fwd( + x=x, + weight=weight, + bias=bias, + seq_idx=seq_idx, + initial_states=initial_states, + out=out, + final_states_out=final_states_out, + silu_activation=silu_activation, + ) + return out + + +def causal_conv1d_bwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + dout: torch.Tensor, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + dfinal_states: torch.Tensor | None, + dx: torch.Tensor | None, + return_dinitial_states: torch.Tensor, + silu_activation: bool, +) -> tuple[torch.Tensor | None]: + batch_size, dim = x.size()[:2] + width = weight.size(-1) + + if dx is None: + dx = torch.empty_like(x) + dweight = torch.zeros_like(weight, dtype=torch.float32) + dbias = None + if bias is not None: + dbias = torch.zeros_like(bias, dtype=torch.float32) + dinitial_states = None + if return_dinitial_states: + dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) + + ops.causal_conv1d_bwd( + x=x, + weight=weight, + bias=bias, + dout=dout, + seq_idx=seq_idx, + initial_states=initial_states, + dfinal_states=dfinal_states, + dx=dx, + dweight=dweight, + dbias=dbias, + dinitial_states=dinitial_states, + silu_activation=silu_activation, + ) + + dweight = dweight.type_as(weight) + if dbias is not None: + dbias = dbias.type_as(bias) + return dx, dweight, dbias, dinitial_states + + +def causal_conv1d_update_function( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + silu_activation: bool, + cache_seqlens: torch.Tensor | None, + conv_state_indices: torch.Tensor | None, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_update( + x=x, + conv_state=conv_state, + weight=weight, + bias=bias, + out=out, + silu_activation=silu_activation, + cache_seqlens=cache_seqlens, + conv_state_indices=conv_state_indices, + ) + return out diff --git a/build/torch28-cxx11-cu129-x86_64-linux/metadata.json b/build/torch28-cxx11-cu129-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..9cf5deed9898dce769f4cc73913d3530b92a0bd8 --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/metadata.json @@ -0,0 +1,4 @@ +{ + "version": 1, + "python-depends": [] +} \ No newline at end of file diff --git a/build/torch29-cxx11-cu126-aarch64-linux/__init__.py b/build/torch29-cxx11-cu126-aarch64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/__init__.py @@ -0,0 +1,4 @@ +from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update +from .causal_conv1d_varlen import causal_conv1d_varlen_states + +__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch29-cxx11-cu126-aarch64-linux/_causal_conv1d_cuda_7579ac2.abi3.so b/build/torch29-cxx11-cu126-aarch64-linux/_causal_conv1d_cuda_7579ac2.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..83ab043e6f16f9edd3796ebdb1e3a4106a2b78e4 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/_causal_conv1d_cuda_7579ac2.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:984a88b1b33598f95d2b2c6f19ace193be19f5933f11642bbf9cf8d8cecc9050 +size 80789912 diff --git a/build/torch29-cxx11-cu126-aarch64-linux/_ops.py b/build/torch29-cxx11-cu126-aarch64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..265c44512f222c8028a7141ca7bb227d24107b1a --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _causal_conv1d_cuda_7579ac2 +ops = torch.ops._causal_conv1d_cuda_7579ac2 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_causal_conv1d_cuda_7579ac2::{op_name}" diff --git a/build/torch29-cxx11-cu126-aarch64-linux/causal_conv1d/__init__.py b/build/torch29-cxx11-cu126-aarch64-linux/causal_conv1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/causal_conv1d/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch29-cxx11-cu126-aarch64-linux/causal_conv1d_interface.py b/build/torch29-cxx11-cu126-aarch64-linux/causal_conv1d_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/causal_conv1d_interface.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024, Tri Dao. + +import torch +import torch.nn.functional as F + +from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function + + +class CausalConv1dFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, + ): + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert ( + initial_states is None + ), "initial_states must be None if seq_idx is not None" + assert ( + not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and ( + initial_states.stride(2) != 1 and initial_states.stride(1) != 1 + ): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert ( + final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 + ) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty( + batch, width - 1, dim, device=x.device, dtype=x.dtype + ).transpose(1, 2) + else: + final_states_out = None + ctx.activation = activation in ["silu", "swish"] + out = causal_conv1d_fwd_function( + x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation + ) + ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) + ctx.return_final_states = return_final_states + ctx.return_dinitial_states = ( + initial_states is not None and initial_states.requires_grad + ) + return out if not return_final_states else (out, final_states_out) + + @staticmethod + def backward(ctx, dout, *args): + x, weight, bias, seq_idx, initial_states = ctx.saved_tensors + dfinal_states = args[0] if ctx.return_final_states else None + if dout.stride(2) != 1 and dout.stride(1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). + # Here we just pass in None and dx will be allocated in the C++ code. + dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( + x, + weight, + bias, + dout, + seq_idx, + initial_states, + dfinal_states, + None, + ctx.return_dinitial_states, + ctx.activation, + ) + return ( + dx, + dweight, + dbias if bias is not None else None, + None, + dinitial_states if initial_states is not None else None, + None, + None, + None, + ) + + +def causal_conv1d_fn( + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + return CausalConv1dFn.apply( + x, + weight, + bias, + seq_idx, + initial_states, + return_final_states, + final_states_out, + activation, + ) + + +def causal_conv1d_ref( + x, + weight, + bias=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in + ) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return out if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + out = causal_conv1d_update_function( + x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices + ) + if unsqueeze: + out = out.squeeze(-1) + return out + + +def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + assert conv_state.shape == (batch, dim, state_len) + assert weight.shape == (dim, width) + if cache_seqlens is None: + x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) + conv_state.copy_(x_new[:, :, -state_len:]) + else: + width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) + copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + conv_state.scatter_(2, copy_idx, x) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] + if unsqueeze: + out = out.squeeze(-1) + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch29-cxx11-cu126-aarch64-linux/causal_conv1d_varlen.py b/build/torch29-cxx11-cu126-aarch64-linux/causal_conv1d_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/causal_conv1d_varlen.py @@ -0,0 +1,86 @@ +import torch +from torch import Tensor + +import triton +import triton.language as tl + + +@triton.jit +def _causal_conv1d_varlen_states( + X, + CU_SEQLENS, + STATES, + state_len, + dim, + stride_x_seqlen, stride_x_dim, + stride_states_batch, stride_states_seqlen, stride_states_dim, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr +): + batch_idx = tl.program_id(2) + STATES += batch_idx * stride_states_batch + end_idx = tl.load(CU_SEQLENS + batch_idx + 1) + start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) + rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, + mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), + other=0) + rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, + x, + mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) + + +def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + BLOCK_M = min(triton.next_power_of_2(state_len), 16) + BLOCK_N = min(triton.next_power_of_2(dim), 256) + grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) + with torch.cuda.device(x.device.index): + _causal_conv1d_varlen_states[grid]( + x, + cu_seqlens, + states, + state_len, + dim, + x.stride(0), x.stride(1), + states.stride(0), states.stride(2), states.stride(1), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + return states + + +def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + for i in range(batch): + end_idx = cu_seqlens[i + 1] + start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) + states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T + return states diff --git a/build/torch29-cxx11-cu126-aarch64-linux/cpp_functions.py b/build/torch29-cxx11-cu126-aarch64-linux/cpp_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/cpp_functions.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, Tri Dao. + +import torch + +from ._ops import ops + +def causal_conv1d_fwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + final_states_out: torch.Tensor | None, + silu_activation: bool, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_fwd( + x=x, + weight=weight, + bias=bias, + seq_idx=seq_idx, + initial_states=initial_states, + out=out, + final_states_out=final_states_out, + silu_activation=silu_activation, + ) + return out + + +def causal_conv1d_bwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + dout: torch.Tensor, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + dfinal_states: torch.Tensor | None, + dx: torch.Tensor | None, + return_dinitial_states: torch.Tensor, + silu_activation: bool, +) -> tuple[torch.Tensor | None]: + batch_size, dim = x.size()[:2] + width = weight.size(-1) + + if dx is None: + dx = torch.empty_like(x) + dweight = torch.zeros_like(weight, dtype=torch.float32) + dbias = None + if bias is not None: + dbias = torch.zeros_like(bias, dtype=torch.float32) + dinitial_states = None + if return_dinitial_states: + dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) + + ops.causal_conv1d_bwd( + x=x, + weight=weight, + bias=bias, + dout=dout, + seq_idx=seq_idx, + initial_states=initial_states, + dfinal_states=dfinal_states, + dx=dx, + dweight=dweight, + dbias=dbias, + dinitial_states=dinitial_states, + silu_activation=silu_activation, + ) + + dweight = dweight.type_as(weight) + if dbias is not None: + dbias = dbias.type_as(bias) + return dx, dweight, dbias, dinitial_states + + +def causal_conv1d_update_function( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + silu_activation: bool, + cache_seqlens: torch.Tensor | None, + conv_state_indices: torch.Tensor | None, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_update( + x=x, + conv_state=conv_state, + weight=weight, + bias=bias, + out=out, + silu_activation=silu_activation, + cache_seqlens=cache_seqlens, + conv_state_indices=conv_state_indices, + ) + return out diff --git a/build/torch29-cxx11-cu126-aarch64-linux/metadata.json b/build/torch29-cxx11-cu126-aarch64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..0dacb99125f1112a811819ca1ffdde15c8c0faff --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/metadata.json @@ -0,0 +1,18 @@ +{ + "version": 1, + "license": "BSD-3-Clause", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "7.0", + "7.2", + "7.5", + "8.0", + "8.6", + "8.7", + "8.9", + "9.0+PTX" + ] + } +} diff --git a/build/torch29-cxx11-cu126-x86_64-linux/__init__.py b/build/torch29-cxx11-cu126-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/__init__.py @@ -0,0 +1,4 @@ +from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update +from .causal_conv1d_varlen import causal_conv1d_varlen_states + +__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch29-cxx11-cu126-x86_64-linux/_causal_conv1d_cuda_7579ac2.abi3.so b/build/torch29-cxx11-cu126-x86_64-linux/_causal_conv1d_cuda_7579ac2.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..32d5cba202a2eab17d06c5e19bd6c429cc6c323f --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/_causal_conv1d_cuda_7579ac2.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b648cc9f51b076e4cf8b6fd739012ae066a797f4948730273f5ece8adf950976 +size 80684872 diff --git a/build/torch29-cxx11-cu126-x86_64-linux/_ops.py b/build/torch29-cxx11-cu126-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..265c44512f222c8028a7141ca7bb227d24107b1a --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _causal_conv1d_cuda_7579ac2 +ops = torch.ops._causal_conv1d_cuda_7579ac2 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_causal_conv1d_cuda_7579ac2::{op_name}" diff --git a/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py b/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d_interface.py b/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377 --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d_interface.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024, Tri Dao. + +import torch +import torch.nn.functional as F + +from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function + + +class CausalConv1dFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, + ): + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert ( + initial_states is None + ), "initial_states must be None if seq_idx is not None" + assert ( + not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and ( + initial_states.stride(2) != 1 and initial_states.stride(1) != 1 + ): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert ( + final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 + ) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty( + batch, width - 1, dim, device=x.device, dtype=x.dtype + ).transpose(1, 2) + else: + final_states_out = None + ctx.activation = activation in ["silu", "swish"] + out = causal_conv1d_fwd_function( + x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation + ) + ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) + ctx.return_final_states = return_final_states + ctx.return_dinitial_states = ( + initial_states is not None and initial_states.requires_grad + ) + return out if not return_final_states else (out, final_states_out) + + @staticmethod + def backward(ctx, dout, *args): + x, weight, bias, seq_idx, initial_states = ctx.saved_tensors + dfinal_states = args[0] if ctx.return_final_states else None + if dout.stride(2) != 1 and dout.stride(1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). + # Here we just pass in None and dx will be allocated in the C++ code. + dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( + x, + weight, + bias, + dout, + seq_idx, + initial_states, + dfinal_states, + None, + ctx.return_dinitial_states, + ctx.activation, + ) + return ( + dx, + dweight, + dbias if bias is not None else None, + None, + dinitial_states if initial_states is not None else None, + None, + None, + None, + ) + + +def causal_conv1d_fn( + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + return CausalConv1dFn.apply( + x, + weight, + bias, + seq_idx, + initial_states, + return_final_states, + final_states_out, + activation, + ) + + +def causal_conv1d_ref( + x, + weight, + bias=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in + ) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return out if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + out = causal_conv1d_update_function( + x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices + ) + if unsqueeze: + out = out.squeeze(-1) + return out + + +def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + assert conv_state.shape == (batch, dim, state_len) + assert weight.shape == (dim, width) + if cache_seqlens is None: + x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) + conv_state.copy_(x_new[:, :, -state_len:]) + else: + width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) + copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + conv_state.scatter_(2, copy_idx, x) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] + if unsqueeze: + out = out.squeeze(-1) + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d_varlen.py b/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d_varlen.py @@ -0,0 +1,86 @@ +import torch +from torch import Tensor + +import triton +import triton.language as tl + + +@triton.jit +def _causal_conv1d_varlen_states( + X, + CU_SEQLENS, + STATES, + state_len, + dim, + stride_x_seqlen, stride_x_dim, + stride_states_batch, stride_states_seqlen, stride_states_dim, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr +): + batch_idx = tl.program_id(2) + STATES += batch_idx * stride_states_batch + end_idx = tl.load(CU_SEQLENS + batch_idx + 1) + start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) + rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, + mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), + other=0) + rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, + x, + mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) + + +def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + BLOCK_M = min(triton.next_power_of_2(state_len), 16) + BLOCK_N = min(triton.next_power_of_2(dim), 256) + grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) + with torch.cuda.device(x.device.index): + _causal_conv1d_varlen_states[grid]( + x, + cu_seqlens, + states, + state_len, + dim, + x.stride(0), x.stride(1), + states.stride(0), states.stride(2), states.stride(1), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + return states + + +def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + for i in range(batch): + end_idx = cu_seqlens[i + 1] + start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) + states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T + return states diff --git a/build/torch29-cxx11-cu126-x86_64-linux/cpp_functions.py b/build/torch29-cxx11-cu126-x86_64-linux/cpp_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/cpp_functions.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, Tri Dao. + +import torch + +from ._ops import ops + +def causal_conv1d_fwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + final_states_out: torch.Tensor | None, + silu_activation: bool, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_fwd( + x=x, + weight=weight, + bias=bias, + seq_idx=seq_idx, + initial_states=initial_states, + out=out, + final_states_out=final_states_out, + silu_activation=silu_activation, + ) + return out + + +def causal_conv1d_bwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + dout: torch.Tensor, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + dfinal_states: torch.Tensor | None, + dx: torch.Tensor | None, + return_dinitial_states: torch.Tensor, + silu_activation: bool, +) -> tuple[torch.Tensor | None]: + batch_size, dim = x.size()[:2] + width = weight.size(-1) + + if dx is None: + dx = torch.empty_like(x) + dweight = torch.zeros_like(weight, dtype=torch.float32) + dbias = None + if bias is not None: + dbias = torch.zeros_like(bias, dtype=torch.float32) + dinitial_states = None + if return_dinitial_states: + dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) + + ops.causal_conv1d_bwd( + x=x, + weight=weight, + bias=bias, + dout=dout, + seq_idx=seq_idx, + initial_states=initial_states, + dfinal_states=dfinal_states, + dx=dx, + dweight=dweight, + dbias=dbias, + dinitial_states=dinitial_states, + silu_activation=silu_activation, + ) + + dweight = dweight.type_as(weight) + if dbias is not None: + dbias = dbias.type_as(bias) + return dx, dweight, dbias, dinitial_states + + +def causal_conv1d_update_function( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + silu_activation: bool, + cache_seqlens: torch.Tensor | None, + conv_state_indices: torch.Tensor | None, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_update( + x=x, + conv_state=conv_state, + weight=weight, + bias=bias, + out=out, + silu_activation=silu_activation, + cache_seqlens=cache_seqlens, + conv_state_indices=conv_state_indices, + ) + return out diff --git a/build/torch29-cxx11-cu126-x86_64-linux/metadata.json b/build/torch29-cxx11-cu126-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..0dacb99125f1112a811819ca1ffdde15c8c0faff --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/metadata.json @@ -0,0 +1,18 @@ +{ + "version": 1, + "license": "BSD-3-Clause", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "7.0", + "7.2", + "7.5", + "8.0", + "8.6", + "8.7", + "8.9", + "9.0+PTX" + ] + } +} diff --git a/build/torch29-cxx11-cu128-aarch64-linux/__init__.py b/build/torch29-cxx11-cu128-aarch64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/__init__.py @@ -0,0 +1,4 @@ +from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update +from .causal_conv1d_varlen import causal_conv1d_varlen_states + +__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch29-cxx11-cu128-aarch64-linux/_causal_conv1d_cuda_7579ac2.abi3.so b/build/torch29-cxx11-cu128-aarch64-linux/_causal_conv1d_cuda_7579ac2.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..83e2fea72221ac78ca56ec46f7a2a982f1e13c02 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/_causal_conv1d_cuda_7579ac2.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:21bfbb1a2c4685ef84881f72d5a63acb9f38840688fd8587f0be78ad31bb24df +size 107310800 diff --git a/build/torch29-cxx11-cu128-aarch64-linux/_ops.py b/build/torch29-cxx11-cu128-aarch64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..265c44512f222c8028a7141ca7bb227d24107b1a --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _causal_conv1d_cuda_7579ac2 +ops = torch.ops._causal_conv1d_cuda_7579ac2 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_causal_conv1d_cuda_7579ac2::{op_name}" diff --git a/build/torch29-cxx11-cu128-aarch64-linux/causal_conv1d/__init__.py b/build/torch29-cxx11-cu128-aarch64-linux/causal_conv1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/causal_conv1d/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch29-cxx11-cu128-aarch64-linux/causal_conv1d_interface.py b/build/torch29-cxx11-cu128-aarch64-linux/causal_conv1d_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/causal_conv1d_interface.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024, Tri Dao. + +import torch +import torch.nn.functional as F + +from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function + + +class CausalConv1dFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, + ): + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert ( + initial_states is None + ), "initial_states must be None if seq_idx is not None" + assert ( + not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and ( + initial_states.stride(2) != 1 and initial_states.stride(1) != 1 + ): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert ( + final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 + ) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty( + batch, width - 1, dim, device=x.device, dtype=x.dtype + ).transpose(1, 2) + else: + final_states_out = None + ctx.activation = activation in ["silu", "swish"] + out = causal_conv1d_fwd_function( + x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation + ) + ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) + ctx.return_final_states = return_final_states + ctx.return_dinitial_states = ( + initial_states is not None and initial_states.requires_grad + ) + return out if not return_final_states else (out, final_states_out) + + @staticmethod + def backward(ctx, dout, *args): + x, weight, bias, seq_idx, initial_states = ctx.saved_tensors + dfinal_states = args[0] if ctx.return_final_states else None + if dout.stride(2) != 1 and dout.stride(1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). + # Here we just pass in None and dx will be allocated in the C++ code. + dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( + x, + weight, + bias, + dout, + seq_idx, + initial_states, + dfinal_states, + None, + ctx.return_dinitial_states, + ctx.activation, + ) + return ( + dx, + dweight, + dbias if bias is not None else None, + None, + dinitial_states if initial_states is not None else None, + None, + None, + None, + ) + + +def causal_conv1d_fn( + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + return CausalConv1dFn.apply( + x, + weight, + bias, + seq_idx, + initial_states, + return_final_states, + final_states_out, + activation, + ) + + +def causal_conv1d_ref( + x, + weight, + bias=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in + ) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return out if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + out = causal_conv1d_update_function( + x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices + ) + if unsqueeze: + out = out.squeeze(-1) + return out + + +def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + assert conv_state.shape == (batch, dim, state_len) + assert weight.shape == (dim, width) + if cache_seqlens is None: + x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) + conv_state.copy_(x_new[:, :, -state_len:]) + else: + width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) + copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + conv_state.scatter_(2, copy_idx, x) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] + if unsqueeze: + out = out.squeeze(-1) + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch29-cxx11-cu128-aarch64-linux/causal_conv1d_varlen.py b/build/torch29-cxx11-cu128-aarch64-linux/causal_conv1d_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/causal_conv1d_varlen.py @@ -0,0 +1,86 @@ +import torch +from torch import Tensor + +import triton +import triton.language as tl + + +@triton.jit +def _causal_conv1d_varlen_states( + X, + CU_SEQLENS, + STATES, + state_len, + dim, + stride_x_seqlen, stride_x_dim, + stride_states_batch, stride_states_seqlen, stride_states_dim, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr +): + batch_idx = tl.program_id(2) + STATES += batch_idx * stride_states_batch + end_idx = tl.load(CU_SEQLENS + batch_idx + 1) + start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) + rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, + mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), + other=0) + rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, + x, + mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) + + +def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + BLOCK_M = min(triton.next_power_of_2(state_len), 16) + BLOCK_N = min(triton.next_power_of_2(dim), 256) + grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) + with torch.cuda.device(x.device.index): + _causal_conv1d_varlen_states[grid]( + x, + cu_seqlens, + states, + state_len, + dim, + x.stride(0), x.stride(1), + states.stride(0), states.stride(2), states.stride(1), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + return states + + +def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + for i in range(batch): + end_idx = cu_seqlens[i + 1] + start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) + states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T + return states diff --git a/build/torch29-cxx11-cu128-aarch64-linux/cpp_functions.py b/build/torch29-cxx11-cu128-aarch64-linux/cpp_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/cpp_functions.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, Tri Dao. + +import torch + +from ._ops import ops + +def causal_conv1d_fwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + final_states_out: torch.Tensor | None, + silu_activation: bool, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_fwd( + x=x, + weight=weight, + bias=bias, + seq_idx=seq_idx, + initial_states=initial_states, + out=out, + final_states_out=final_states_out, + silu_activation=silu_activation, + ) + return out + + +def causal_conv1d_bwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + dout: torch.Tensor, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + dfinal_states: torch.Tensor | None, + dx: torch.Tensor | None, + return_dinitial_states: torch.Tensor, + silu_activation: bool, +) -> tuple[torch.Tensor | None]: + batch_size, dim = x.size()[:2] + width = weight.size(-1) + + if dx is None: + dx = torch.empty_like(x) + dweight = torch.zeros_like(weight, dtype=torch.float32) + dbias = None + if bias is not None: + dbias = torch.zeros_like(bias, dtype=torch.float32) + dinitial_states = None + if return_dinitial_states: + dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) + + ops.causal_conv1d_bwd( + x=x, + weight=weight, + bias=bias, + dout=dout, + seq_idx=seq_idx, + initial_states=initial_states, + dfinal_states=dfinal_states, + dx=dx, + dweight=dweight, + dbias=dbias, + dinitial_states=dinitial_states, + silu_activation=silu_activation, + ) + + dweight = dweight.type_as(weight) + if dbias is not None: + dbias = dbias.type_as(bias) + return dx, dweight, dbias, dinitial_states + + +def causal_conv1d_update_function( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + silu_activation: bool, + cache_seqlens: torch.Tensor | None, + conv_state_indices: torch.Tensor | None, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_update( + x=x, + conv_state=conv_state, + weight=weight, + bias=bias, + out=out, + silu_activation=silu_activation, + cache_seqlens=cache_seqlens, + conv_state_indices=conv_state_indices, + ) + return out diff --git a/build/torch29-cxx11-cu128-aarch64-linux/metadata.json b/build/torch29-cxx11-cu128-aarch64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..a794c92436c3827ae79b48d55f7ea964afd50f52 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/metadata.json @@ -0,0 +1,21 @@ +{ + "version": 1, + "license": "BSD-3-Clause", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0", + "10.1", + "12.0+PTX", + "7.0", + "7.2", + "7.5", + "8.0", + "8.6", + "8.7", + "8.9", + "9.0" + ] + } +} diff --git a/build/torch29-cxx11-cu128-x86_64-linux/__init__.py b/build/torch29-cxx11-cu128-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/__init__.py @@ -0,0 +1,4 @@ +from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update +from .causal_conv1d_varlen import causal_conv1d_varlen_states + +__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_causal_conv1d_cuda_7579ac2.abi3.so b/build/torch29-cxx11-cu128-x86_64-linux/_causal_conv1d_cuda_7579ac2.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..8fedf1104c4d17fabcd905d1208d390fa3244f10 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/_causal_conv1d_cuda_7579ac2.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e26af82de4f1fc452d1a1cba90cb154ce269f1f5d20bccb94ccc665e2c970e5d +size 107172632 diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_ops.py b/build/torch29-cxx11-cu128-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..265c44512f222c8028a7141ca7bb227d24107b1a --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _causal_conv1d_cuda_7579ac2 +ops = torch.ops._causal_conv1d_cuda_7579ac2 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_causal_conv1d_cuda_7579ac2::{op_name}" diff --git a/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py b/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d_interface.py b/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d_interface.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024, Tri Dao. + +import torch +import torch.nn.functional as F + +from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function + + +class CausalConv1dFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, + ): + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert ( + initial_states is None + ), "initial_states must be None if seq_idx is not None" + assert ( + not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and ( + initial_states.stride(2) != 1 and initial_states.stride(1) != 1 + ): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert ( + final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 + ) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty( + batch, width - 1, dim, device=x.device, dtype=x.dtype + ).transpose(1, 2) + else: + final_states_out = None + ctx.activation = activation in ["silu", "swish"] + out = causal_conv1d_fwd_function( + x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation + ) + ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) + ctx.return_final_states = return_final_states + ctx.return_dinitial_states = ( + initial_states is not None and initial_states.requires_grad + ) + return out if not return_final_states else (out, final_states_out) + + @staticmethod + def backward(ctx, dout, *args): + x, weight, bias, seq_idx, initial_states = ctx.saved_tensors + dfinal_states = args[0] if ctx.return_final_states else None + if dout.stride(2) != 1 and dout.stride(1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). + # Here we just pass in None and dx will be allocated in the C++ code. + dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( + x, + weight, + bias, + dout, + seq_idx, + initial_states, + dfinal_states, + None, + ctx.return_dinitial_states, + ctx.activation, + ) + return ( + dx, + dweight, + dbias if bias is not None else None, + None, + dinitial_states if initial_states is not None else None, + None, + None, + None, + ) + + +def causal_conv1d_fn( + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + return CausalConv1dFn.apply( + x, + weight, + bias, + seq_idx, + initial_states, + return_final_states, + final_states_out, + activation, + ) + + +def causal_conv1d_ref( + x, + weight, + bias=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in + ) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return out if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + out = causal_conv1d_update_function( + x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices + ) + if unsqueeze: + out = out.squeeze(-1) + return out + + +def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + assert conv_state.shape == (batch, dim, state_len) + assert weight.shape == (dim, width) + if cache_seqlens is None: + x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) + conv_state.copy_(x_new[:, :, -state_len:]) + else: + width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) + copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + conv_state.scatter_(2, copy_idx, x) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] + if unsqueeze: + out = out.squeeze(-1) + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d_varlen.py b/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d_varlen.py @@ -0,0 +1,86 @@ +import torch +from torch import Tensor + +import triton +import triton.language as tl + + +@triton.jit +def _causal_conv1d_varlen_states( + X, + CU_SEQLENS, + STATES, + state_len, + dim, + stride_x_seqlen, stride_x_dim, + stride_states_batch, stride_states_seqlen, stride_states_dim, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr +): + batch_idx = tl.program_id(2) + STATES += batch_idx * stride_states_batch + end_idx = tl.load(CU_SEQLENS + batch_idx + 1) + start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) + rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, + mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), + other=0) + rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, + x, + mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) + + +def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + BLOCK_M = min(triton.next_power_of_2(state_len), 16) + BLOCK_N = min(triton.next_power_of_2(dim), 256) + grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) + with torch.cuda.device(x.device.index): + _causal_conv1d_varlen_states[grid]( + x, + cu_seqlens, + states, + state_len, + dim, + x.stride(0), x.stride(1), + states.stride(0), states.stride(2), states.stride(1), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + return states + + +def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + for i in range(batch): + end_idx = cu_seqlens[i + 1] + start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) + states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T + return states diff --git a/build/torch29-cxx11-cu128-x86_64-linux/cpp_functions.py b/build/torch29-cxx11-cu128-x86_64-linux/cpp_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/cpp_functions.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, Tri Dao. + +import torch + +from ._ops import ops + +def causal_conv1d_fwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + final_states_out: torch.Tensor | None, + silu_activation: bool, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_fwd( + x=x, + weight=weight, + bias=bias, + seq_idx=seq_idx, + initial_states=initial_states, + out=out, + final_states_out=final_states_out, + silu_activation=silu_activation, + ) + return out + + +def causal_conv1d_bwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + dout: torch.Tensor, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + dfinal_states: torch.Tensor | None, + dx: torch.Tensor | None, + return_dinitial_states: torch.Tensor, + silu_activation: bool, +) -> tuple[torch.Tensor | None]: + batch_size, dim = x.size()[:2] + width = weight.size(-1) + + if dx is None: + dx = torch.empty_like(x) + dweight = torch.zeros_like(weight, dtype=torch.float32) + dbias = None + if bias is not None: + dbias = torch.zeros_like(bias, dtype=torch.float32) + dinitial_states = None + if return_dinitial_states: + dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) + + ops.causal_conv1d_bwd( + x=x, + weight=weight, + bias=bias, + dout=dout, + seq_idx=seq_idx, + initial_states=initial_states, + dfinal_states=dfinal_states, + dx=dx, + dweight=dweight, + dbias=dbias, + dinitial_states=dinitial_states, + silu_activation=silu_activation, + ) + + dweight = dweight.type_as(weight) + if dbias is not None: + dbias = dbias.type_as(bias) + return dx, dweight, dbias, dinitial_states + + +def causal_conv1d_update_function( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + silu_activation: bool, + cache_seqlens: torch.Tensor | None, + conv_state_indices: torch.Tensor | None, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_update( + x=x, + conv_state=conv_state, + weight=weight, + bias=bias, + out=out, + silu_activation=silu_activation, + cache_seqlens=cache_seqlens, + conv_state_indices=conv_state_indices, + ) + return out diff --git a/build/torch29-cxx11-cu128-x86_64-linux/metadata.json b/build/torch29-cxx11-cu128-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..a794c92436c3827ae79b48d55f7ea964afd50f52 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/metadata.json @@ -0,0 +1,21 @@ +{ + "version": 1, + "license": "BSD-3-Clause", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0", + "10.1", + "12.0+PTX", + "7.0", + "7.2", + "7.5", + "8.0", + "8.6", + "8.7", + "8.9", + "9.0" + ] + } +} diff --git a/build/torch29-cxx11-cu129-aarch64-linux/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/__init__.py @@ -0,0 +1,4 @@ +from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update +from .causal_conv1d_varlen import causal_conv1d_varlen_states + +__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch29-cxx11-cu129-aarch64-linux/_causal_conv1d_cuda_6b83b83.abi3.so b/build/torch29-cxx11-cu129-aarch64-linux/_causal_conv1d_cuda_6b83b83.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..cf187ce7f99f214ff94fb55fa237bc0d0df0ee77 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/_causal_conv1d_cuda_6b83b83.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cc4012b62d077e5479ffa580abe50036943b3a2dcccec65055b1e13752be1d62 +size 115308008 diff --git a/build/torch29-cxx11-cu129-aarch64-linux/_ops.py b/build/torch29-cxx11-cu129-aarch64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6796899661ef6f73609047ca344503d13ca050bd --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _causal_conv1d_cuda_6b83b83 +ops = torch.ops._causal_conv1d_cuda_6b83b83 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_causal_conv1d_cuda_6b83b83::{op_name}" diff --git a/build/torch29-cxx11-cu129-aarch64-linux/causal_conv1d/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/causal_conv1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/causal_conv1d/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/causal_conv1d_interface.py b/build/torch29-cxx11-cu129-aarch64-linux/causal_conv1d_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/causal_conv1d_interface.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024, Tri Dao. + +import torch +import torch.nn.functional as F + +from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function + + +class CausalConv1dFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, + ): + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert ( + initial_states is None + ), "initial_states must be None if seq_idx is not None" + assert ( + not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and ( + initial_states.stride(2) != 1 and initial_states.stride(1) != 1 + ): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert ( + final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 + ) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty( + batch, width - 1, dim, device=x.device, dtype=x.dtype + ).transpose(1, 2) + else: + final_states_out = None + ctx.activation = activation in ["silu", "swish"] + out = causal_conv1d_fwd_function( + x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation + ) + ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) + ctx.return_final_states = return_final_states + ctx.return_dinitial_states = ( + initial_states is not None and initial_states.requires_grad + ) + return out if not return_final_states else (out, final_states_out) + + @staticmethod + def backward(ctx, dout, *args): + x, weight, bias, seq_idx, initial_states = ctx.saved_tensors + dfinal_states = args[0] if ctx.return_final_states else None + if dout.stride(2) != 1 and dout.stride(1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). + # Here we just pass in None and dx will be allocated in the C++ code. + dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( + x, + weight, + bias, + dout, + seq_idx, + initial_states, + dfinal_states, + None, + ctx.return_dinitial_states, + ctx.activation, + ) + return ( + dx, + dweight, + dbias if bias is not None else None, + None, + dinitial_states if initial_states is not None else None, + None, + None, + None, + ) + + +def causal_conv1d_fn( + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + return CausalConv1dFn.apply( + x, + weight, + bias, + seq_idx, + initial_states, + return_final_states, + final_states_out, + activation, + ) + + +def causal_conv1d_ref( + x, + weight, + bias=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in + ) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return out if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + out = causal_conv1d_update_function( + x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices + ) + if unsqueeze: + out = out.squeeze(-1) + return out + + +def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + assert conv_state.shape == (batch, dim, state_len) + assert weight.shape == (dim, width) + if cache_seqlens is None: + x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) + conv_state.copy_(x_new[:, :, -state_len:]) + else: + width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) + copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + conv_state.scatter_(2, copy_idx, x) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] + if unsqueeze: + out = out.squeeze(-1) + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/causal_conv1d_varlen.py b/build/torch29-cxx11-cu129-aarch64-linux/causal_conv1d_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/causal_conv1d_varlen.py @@ -0,0 +1,86 @@ +import torch +from torch import Tensor + +import triton +import triton.language as tl + + +@triton.jit +def _causal_conv1d_varlen_states( + X, + CU_SEQLENS, + STATES, + state_len, + dim, + stride_x_seqlen, stride_x_dim, + stride_states_batch, stride_states_seqlen, stride_states_dim, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr +): + batch_idx = tl.program_id(2) + STATES += batch_idx * stride_states_batch + end_idx = tl.load(CU_SEQLENS + batch_idx + 1) + start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) + rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, + mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), + other=0) + rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, + x, + mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) + + +def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + BLOCK_M = min(triton.next_power_of_2(state_len), 16) + BLOCK_N = min(triton.next_power_of_2(dim), 256) + grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) + with torch.cuda.device(x.device.index): + _causal_conv1d_varlen_states[grid]( + x, + cu_seqlens, + states, + state_len, + dim, + x.stride(0), x.stride(1), + states.stride(0), states.stride(2), states.stride(1), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + return states + + +def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + for i in range(batch): + end_idx = cu_seqlens[i + 1] + start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) + states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T + return states diff --git a/build/torch29-cxx11-cu129-aarch64-linux/cpp_functions.py b/build/torch29-cxx11-cu129-aarch64-linux/cpp_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/cpp_functions.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, Tri Dao. + +import torch + +from ._ops import ops + +def causal_conv1d_fwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + final_states_out: torch.Tensor | None, + silu_activation: bool, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_fwd( + x=x, + weight=weight, + bias=bias, + seq_idx=seq_idx, + initial_states=initial_states, + out=out, + final_states_out=final_states_out, + silu_activation=silu_activation, + ) + return out + + +def causal_conv1d_bwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + dout: torch.Tensor, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + dfinal_states: torch.Tensor | None, + dx: torch.Tensor | None, + return_dinitial_states: torch.Tensor, + silu_activation: bool, +) -> tuple[torch.Tensor | None]: + batch_size, dim = x.size()[:2] + width = weight.size(-1) + + if dx is None: + dx = torch.empty_like(x) + dweight = torch.zeros_like(weight, dtype=torch.float32) + dbias = None + if bias is not None: + dbias = torch.zeros_like(bias, dtype=torch.float32) + dinitial_states = None + if return_dinitial_states: + dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) + + ops.causal_conv1d_bwd( + x=x, + weight=weight, + bias=bias, + dout=dout, + seq_idx=seq_idx, + initial_states=initial_states, + dfinal_states=dfinal_states, + dx=dx, + dweight=dweight, + dbias=dbias, + dinitial_states=dinitial_states, + silu_activation=silu_activation, + ) + + dweight = dweight.type_as(weight) + if dbias is not None: + dbias = dbias.type_as(bias) + return dx, dweight, dbias, dinitial_states + + +def causal_conv1d_update_function( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + silu_activation: bool, + cache_seqlens: torch.Tensor | None, + conv_state_indices: torch.Tensor | None, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_update( + x=x, + conv_state=conv_state, + weight=weight, + bias=bias, + out=out, + silu_activation=silu_activation, + cache_seqlens=cache_seqlens, + conv_state_indices=conv_state_indices, + ) + return out diff --git a/build/torch29-cxx11-cu129-aarch64-linux/metadata.json b/build/torch29-cxx11-cu129-aarch64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..a794c92436c3827ae79b48d55f7ea964afd50f52 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/metadata.json @@ -0,0 +1,21 @@ +{ + "version": 1, + "license": "BSD-3-Clause", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0", + "10.1", + "12.0+PTX", + "7.0", + "7.2", + "7.5", + "8.0", + "8.6", + "8.7", + "8.9", + "9.0" + ] + } +} diff --git a/build/torch29-cxx11-cu129-x86_64-linux/__init__.py b/build/torch29-cxx11-cu129-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/__init__.py @@ -0,0 +1,4 @@ +from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update +from .causal_conv1d_varlen import causal_conv1d_varlen_states + +__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch29-cxx11-cu129-x86_64-linux/_causal_conv1d_cuda_6b83b83.abi3.so b/build/torch29-cxx11-cu129-x86_64-linux/_causal_conv1d_cuda_6b83b83.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..107538ba364830b8a0b1cb7c4c1e628c8abc775c --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/_causal_conv1d_cuda_6b83b83.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:800c4a26ae3ece97637afa81f9ae4e294e643d2259204b0be027e4d9cb82e147 +size 115140688 diff --git a/build/torch29-cxx11-cu129-x86_64-linux/_ops.py b/build/torch29-cxx11-cu129-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6796899661ef6f73609047ca344503d13ca050bd --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _causal_conv1d_cuda_6b83b83 +ops = torch.ops._causal_conv1d_cuda_6b83b83 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_causal_conv1d_cuda_6b83b83::{op_name}" diff --git a/build/torch29-cxx11-cu129-x86_64-linux/causal_conv1d/__init__.py b/build/torch29-cxx11-cu129-x86_64-linux/causal_conv1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/causal_conv1d/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch29-cxx11-cu129-x86_64-linux/causal_conv1d_interface.py b/build/torch29-cxx11-cu129-x86_64-linux/causal_conv1d_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/causal_conv1d_interface.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024, Tri Dao. + +import torch +import torch.nn.functional as F + +from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function + + +class CausalConv1dFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, + ): + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert ( + initial_states is None + ), "initial_states must be None if seq_idx is not None" + assert ( + not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and ( + initial_states.stride(2) != 1 and initial_states.stride(1) != 1 + ): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert ( + final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 + ) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty( + batch, width - 1, dim, device=x.device, dtype=x.dtype + ).transpose(1, 2) + else: + final_states_out = None + ctx.activation = activation in ["silu", "swish"] + out = causal_conv1d_fwd_function( + x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation + ) + ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) + ctx.return_final_states = return_final_states + ctx.return_dinitial_states = ( + initial_states is not None and initial_states.requires_grad + ) + return out if not return_final_states else (out, final_states_out) + + @staticmethod + def backward(ctx, dout, *args): + x, weight, bias, seq_idx, initial_states = ctx.saved_tensors + dfinal_states = args[0] if ctx.return_final_states else None + if dout.stride(2) != 1 and dout.stride(1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). + # Here we just pass in None and dx will be allocated in the C++ code. + dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( + x, + weight, + bias, + dout, + seq_idx, + initial_states, + dfinal_states, + None, + ctx.return_dinitial_states, + ctx.activation, + ) + return ( + dx, + dweight, + dbias if bias is not None else None, + None, + dinitial_states if initial_states is not None else None, + None, + None, + None, + ) + + +def causal_conv1d_fn( + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + return CausalConv1dFn.apply( + x, + weight, + bias, + seq_idx, + initial_states, + return_final_states, + final_states_out, + activation, + ) + + +def causal_conv1d_ref( + x, + weight, + bias=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in + ) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return out if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + out = causal_conv1d_update_function( + x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices + ) + if unsqueeze: + out = out.squeeze(-1) + return out + + +def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + assert conv_state.shape == (batch, dim, state_len) + assert weight.shape == (dim, width) + if cache_seqlens is None: + x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) + conv_state.copy_(x_new[:, :, -state_len:]) + else: + width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) + copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + conv_state.scatter_(2, copy_idx, x) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] + if unsqueeze: + out = out.squeeze(-1) + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch29-cxx11-cu129-x86_64-linux/causal_conv1d_varlen.py b/build/torch29-cxx11-cu129-x86_64-linux/causal_conv1d_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/causal_conv1d_varlen.py @@ -0,0 +1,86 @@ +import torch +from torch import Tensor + +import triton +import triton.language as tl + + +@triton.jit +def _causal_conv1d_varlen_states( + X, + CU_SEQLENS, + STATES, + state_len, + dim, + stride_x_seqlen, stride_x_dim, + stride_states_batch, stride_states_seqlen, stride_states_dim, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr +): + batch_idx = tl.program_id(2) + STATES += batch_idx * stride_states_batch + end_idx = tl.load(CU_SEQLENS + batch_idx + 1) + start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) + rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, + mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), + other=0) + rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, + x, + mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) + + +def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + BLOCK_M = min(triton.next_power_of_2(state_len), 16) + BLOCK_N = min(triton.next_power_of_2(dim), 256) + grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) + with torch.cuda.device(x.device.index): + _causal_conv1d_varlen_states[grid]( + x, + cu_seqlens, + states, + state_len, + dim, + x.stride(0), x.stride(1), + states.stride(0), states.stride(2), states.stride(1), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + return states + + +def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + for i in range(batch): + end_idx = cu_seqlens[i + 1] + start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) + states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T + return states diff --git a/build/torch29-cxx11-cu129-x86_64-linux/cpp_functions.py b/build/torch29-cxx11-cu129-x86_64-linux/cpp_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/cpp_functions.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, Tri Dao. + +import torch + +from ._ops import ops + +def causal_conv1d_fwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + final_states_out: torch.Tensor | None, + silu_activation: bool, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_fwd( + x=x, + weight=weight, + bias=bias, + seq_idx=seq_idx, + initial_states=initial_states, + out=out, + final_states_out=final_states_out, + silu_activation=silu_activation, + ) + return out + + +def causal_conv1d_bwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + dout: torch.Tensor, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + dfinal_states: torch.Tensor | None, + dx: torch.Tensor | None, + return_dinitial_states: torch.Tensor, + silu_activation: bool, +) -> tuple[torch.Tensor | None]: + batch_size, dim = x.size()[:2] + width = weight.size(-1) + + if dx is None: + dx = torch.empty_like(x) + dweight = torch.zeros_like(weight, dtype=torch.float32) + dbias = None + if bias is not None: + dbias = torch.zeros_like(bias, dtype=torch.float32) + dinitial_states = None + if return_dinitial_states: + dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) + + ops.causal_conv1d_bwd( + x=x, + weight=weight, + bias=bias, + dout=dout, + seq_idx=seq_idx, + initial_states=initial_states, + dfinal_states=dfinal_states, + dx=dx, + dweight=dweight, + dbias=dbias, + dinitial_states=dinitial_states, + silu_activation=silu_activation, + ) + + dweight = dweight.type_as(weight) + if dbias is not None: + dbias = dbias.type_as(bias) + return dx, dweight, dbias, dinitial_states + + +def causal_conv1d_update_function( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + silu_activation: bool, + cache_seqlens: torch.Tensor | None, + conv_state_indices: torch.Tensor | None, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_update( + x=x, + conv_state=conv_state, + weight=weight, + bias=bias, + out=out, + silu_activation=silu_activation, + cache_seqlens=cache_seqlens, + conv_state_indices=conv_state_indices, + ) + return out diff --git a/build/torch29-cxx11-cu129-x86_64-linux/metadata.json b/build/torch29-cxx11-cu129-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..a794c92436c3827ae79b48d55f7ea964afd50f52 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/metadata.json @@ -0,0 +1,21 @@ +{ + "version": 1, + "license": "BSD-3-Clause", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0", + "10.1", + "12.0+PTX", + "7.0", + "7.2", + "7.5", + "8.0", + "8.6", + "8.7", + "8.9", + "9.0" + ] + } +} diff --git a/build/torch29-cxx11-cu130-aarch64-linux/__init__.py b/build/torch29-cxx11-cu130-aarch64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/__init__.py @@ -0,0 +1,4 @@ +from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update +from .causal_conv1d_varlen import causal_conv1d_varlen_states + +__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch29-cxx11-cu130-aarch64-linux/_causal_conv1d_cuda_7579ac2.abi3.so b/build/torch29-cxx11-cu130-aarch64-linux/_causal_conv1d_cuda_7579ac2.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..822c94dee6c0812e147a0ababd144414e364d9b4 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/_causal_conv1d_cuda_7579ac2.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:081a03764e03c62943760b2eb9baf991dacafaf985bfc0374f58eeb86b89ffc7 +size 64753648 diff --git a/build/torch29-cxx11-cu130-aarch64-linux/_ops.py b/build/torch29-cxx11-cu130-aarch64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..265c44512f222c8028a7141ca7bb227d24107b1a --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _causal_conv1d_cuda_7579ac2 +ops = torch.ops._causal_conv1d_cuda_7579ac2 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_causal_conv1d_cuda_7579ac2::{op_name}" diff --git a/build/torch29-cxx11-cu130-aarch64-linux/causal_conv1d/__init__.py b/build/torch29-cxx11-cu130-aarch64-linux/causal_conv1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/causal_conv1d/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch29-cxx11-cu130-aarch64-linux/causal_conv1d_interface.py b/build/torch29-cxx11-cu130-aarch64-linux/causal_conv1d_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/causal_conv1d_interface.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024, Tri Dao. + +import torch +import torch.nn.functional as F + +from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function + + +class CausalConv1dFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, + ): + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert ( + initial_states is None + ), "initial_states must be None if seq_idx is not None" + assert ( + not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and ( + initial_states.stride(2) != 1 and initial_states.stride(1) != 1 + ): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert ( + final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 + ) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty( + batch, width - 1, dim, device=x.device, dtype=x.dtype + ).transpose(1, 2) + else: + final_states_out = None + ctx.activation = activation in ["silu", "swish"] + out = causal_conv1d_fwd_function( + x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation + ) + ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) + ctx.return_final_states = return_final_states + ctx.return_dinitial_states = ( + initial_states is not None and initial_states.requires_grad + ) + return out if not return_final_states else (out, final_states_out) + + @staticmethod + def backward(ctx, dout, *args): + x, weight, bias, seq_idx, initial_states = ctx.saved_tensors + dfinal_states = args[0] if ctx.return_final_states else None + if dout.stride(2) != 1 and dout.stride(1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). + # Here we just pass in None and dx will be allocated in the C++ code. + dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( + x, + weight, + bias, + dout, + seq_idx, + initial_states, + dfinal_states, + None, + ctx.return_dinitial_states, + ctx.activation, + ) + return ( + dx, + dweight, + dbias if bias is not None else None, + None, + dinitial_states if initial_states is not None else None, + None, + None, + None, + ) + + +def causal_conv1d_fn( + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + return CausalConv1dFn.apply( + x, + weight, + bias, + seq_idx, + initial_states, + return_final_states, + final_states_out, + activation, + ) + + +def causal_conv1d_ref( + x, + weight, + bias=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in + ) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return out if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + out = causal_conv1d_update_function( + x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices + ) + if unsqueeze: + out = out.squeeze(-1) + return out + + +def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + assert conv_state.shape == (batch, dim, state_len) + assert weight.shape == (dim, width) + if cache_seqlens is None: + x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) + conv_state.copy_(x_new[:, :, -state_len:]) + else: + width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) + copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + conv_state.scatter_(2, copy_idx, x) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] + if unsqueeze: + out = out.squeeze(-1) + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch29-cxx11-cu130-aarch64-linux/causal_conv1d_varlen.py b/build/torch29-cxx11-cu130-aarch64-linux/causal_conv1d_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/causal_conv1d_varlen.py @@ -0,0 +1,86 @@ +import torch +from torch import Tensor + +import triton +import triton.language as tl + + +@triton.jit +def _causal_conv1d_varlen_states( + X, + CU_SEQLENS, + STATES, + state_len, + dim, + stride_x_seqlen, stride_x_dim, + stride_states_batch, stride_states_seqlen, stride_states_dim, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr +): + batch_idx = tl.program_id(2) + STATES += batch_idx * stride_states_batch + end_idx = tl.load(CU_SEQLENS + batch_idx + 1) + start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) + rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, + mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), + other=0) + rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, + x, + mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) + + +def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + BLOCK_M = min(triton.next_power_of_2(state_len), 16) + BLOCK_N = min(triton.next_power_of_2(dim), 256) + grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) + with torch.cuda.device(x.device.index): + _causal_conv1d_varlen_states[grid]( + x, + cu_seqlens, + states, + state_len, + dim, + x.stride(0), x.stride(1), + states.stride(0), states.stride(2), states.stride(1), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + return states + + +def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + for i in range(batch): + end_idx = cu_seqlens[i + 1] + start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) + states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T + return states diff --git a/build/torch29-cxx11-cu130-aarch64-linux/cpp_functions.py b/build/torch29-cxx11-cu130-aarch64-linux/cpp_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/cpp_functions.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, Tri Dao. + +import torch + +from ._ops import ops + +def causal_conv1d_fwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + final_states_out: torch.Tensor | None, + silu_activation: bool, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_fwd( + x=x, + weight=weight, + bias=bias, + seq_idx=seq_idx, + initial_states=initial_states, + out=out, + final_states_out=final_states_out, + silu_activation=silu_activation, + ) + return out + + +def causal_conv1d_bwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + dout: torch.Tensor, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + dfinal_states: torch.Tensor | None, + dx: torch.Tensor | None, + return_dinitial_states: torch.Tensor, + silu_activation: bool, +) -> tuple[torch.Tensor | None]: + batch_size, dim = x.size()[:2] + width = weight.size(-1) + + if dx is None: + dx = torch.empty_like(x) + dweight = torch.zeros_like(weight, dtype=torch.float32) + dbias = None + if bias is not None: + dbias = torch.zeros_like(bias, dtype=torch.float32) + dinitial_states = None + if return_dinitial_states: + dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) + + ops.causal_conv1d_bwd( + x=x, + weight=weight, + bias=bias, + dout=dout, + seq_idx=seq_idx, + initial_states=initial_states, + dfinal_states=dfinal_states, + dx=dx, + dweight=dweight, + dbias=dbias, + dinitial_states=dinitial_states, + silu_activation=silu_activation, + ) + + dweight = dweight.type_as(weight) + if dbias is not None: + dbias = dbias.type_as(bias) + return dx, dweight, dbias, dinitial_states + + +def causal_conv1d_update_function( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + silu_activation: bool, + cache_seqlens: torch.Tensor | None, + conv_state_indices: torch.Tensor | None, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_update( + x=x, + conv_state=conv_state, + weight=weight, + bias=bias, + out=out, + silu_activation=silu_activation, + cache_seqlens=cache_seqlens, + conv_state_indices=conv_state_indices, + ) + return out diff --git a/build/torch29-cxx11-cu130-aarch64-linux/metadata.json b/build/torch29-cxx11-cu130-aarch64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..eff725542128e103dfb5df382d74940efff77214 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/metadata.json @@ -0,0 +1,19 @@ +{ + "version": 1, + "license": "BSD-3-Clause", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0", + "11.0", + "12.0+PTX", + "7.5", + "8.0", + "8.6", + "8.7", + "8.9", + "9.0" + ] + } +} diff --git a/build/torch29-cxx11-cu130-x86_64-linux/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/__init__.py @@ -0,0 +1,4 @@ +from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update +from .causal_conv1d_varlen import causal_conv1d_varlen_states + +__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_causal_conv1d_cuda_7579ac2.abi3.so b/build/torch29-cxx11-cu130-x86_64-linux/_causal_conv1d_cuda_7579ac2.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..f0629040266fe8fb76e0bb910d7fe4ae47e95248 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/_causal_conv1d_cuda_7579ac2.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ed0967161460ebedec3b7216fb284ceef4a290cfedf5b71368e7a71f2e289e65 +size 64613072 diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_ops.py b/build/torch29-cxx11-cu130-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..265c44512f222c8028a7141ca7bb227d24107b1a --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _causal_conv1d_cuda_7579ac2 +ops = torch.ops._causal_conv1d_cuda_7579ac2 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_causal_conv1d_cuda_7579ac2::{op_name}" diff --git a/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d_interface.py b/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d_interface.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024, Tri Dao. + +import torch +import torch.nn.functional as F + +from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function + + +class CausalConv1dFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, + ): + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert ( + initial_states is None + ), "initial_states must be None if seq_idx is not None" + assert ( + not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and ( + initial_states.stride(2) != 1 and initial_states.stride(1) != 1 + ): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert ( + final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 + ) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty( + batch, width - 1, dim, device=x.device, dtype=x.dtype + ).transpose(1, 2) + else: + final_states_out = None + ctx.activation = activation in ["silu", "swish"] + out = causal_conv1d_fwd_function( + x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation + ) + ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) + ctx.return_final_states = return_final_states + ctx.return_dinitial_states = ( + initial_states is not None and initial_states.requires_grad + ) + return out if not return_final_states else (out, final_states_out) + + @staticmethod + def backward(ctx, dout, *args): + x, weight, bias, seq_idx, initial_states = ctx.saved_tensors + dfinal_states = args[0] if ctx.return_final_states else None + if dout.stride(2) != 1 and dout.stride(1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). + # Here we just pass in None and dx will be allocated in the C++ code. + dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( + x, + weight, + bias, + dout, + seq_idx, + initial_states, + dfinal_states, + None, + ctx.return_dinitial_states, + ctx.activation, + ) + return ( + dx, + dweight, + dbias if bias is not None else None, + None, + dinitial_states if initial_states is not None else None, + None, + None, + None, + ) + + +def causal_conv1d_fn( + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + return CausalConv1dFn.apply( + x, + weight, + bias, + seq_idx, + initial_states, + return_final_states, + final_states_out, + activation, + ) + + +def causal_conv1d_ref( + x, + weight, + bias=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in + ) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return out if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + out = causal_conv1d_update_function( + x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices + ) + if unsqueeze: + out = out.squeeze(-1) + return out + + +def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + assert conv_state.shape == (batch, dim, state_len) + assert weight.shape == (dim, width) + if cache_seqlens is None: + x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) + conv_state.copy_(x_new[:, :, -state_len:]) + else: + width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) + copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + conv_state.scatter_(2, copy_idx, x) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] + if unsqueeze: + out = out.squeeze(-1) + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d_varlen.py b/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d_varlen.py @@ -0,0 +1,86 @@ +import torch +from torch import Tensor + +import triton +import triton.language as tl + + +@triton.jit +def _causal_conv1d_varlen_states( + X, + CU_SEQLENS, + STATES, + state_len, + dim, + stride_x_seqlen, stride_x_dim, + stride_states_batch, stride_states_seqlen, stride_states_dim, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr +): + batch_idx = tl.program_id(2) + STATES += batch_idx * stride_states_batch + end_idx = tl.load(CU_SEQLENS + batch_idx + 1) + start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) + rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, + mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), + other=0) + rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, + x, + mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) + + +def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + BLOCK_M = min(triton.next_power_of_2(state_len), 16) + BLOCK_N = min(triton.next_power_of_2(dim), 256) + grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) + with torch.cuda.device(x.device.index): + _causal_conv1d_varlen_states[grid]( + x, + cu_seqlens, + states, + state_len, + dim, + x.stride(0), x.stride(1), + states.stride(0), states.stride(2), states.stride(1), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + return states + + +def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + for i in range(batch): + end_idx = cu_seqlens[i + 1] + start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) + states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T + return states diff --git a/build/torch29-cxx11-cu130-x86_64-linux/cpp_functions.py b/build/torch29-cxx11-cu130-x86_64-linux/cpp_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/cpp_functions.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, Tri Dao. + +import torch + +from ._ops import ops + +def causal_conv1d_fwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + final_states_out: torch.Tensor | None, + silu_activation: bool, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_fwd( + x=x, + weight=weight, + bias=bias, + seq_idx=seq_idx, + initial_states=initial_states, + out=out, + final_states_out=final_states_out, + silu_activation=silu_activation, + ) + return out + + +def causal_conv1d_bwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + dout: torch.Tensor, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + dfinal_states: torch.Tensor | None, + dx: torch.Tensor | None, + return_dinitial_states: torch.Tensor, + silu_activation: bool, +) -> tuple[torch.Tensor | None]: + batch_size, dim = x.size()[:2] + width = weight.size(-1) + + if dx is None: + dx = torch.empty_like(x) + dweight = torch.zeros_like(weight, dtype=torch.float32) + dbias = None + if bias is not None: + dbias = torch.zeros_like(bias, dtype=torch.float32) + dinitial_states = None + if return_dinitial_states: + dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) + + ops.causal_conv1d_bwd( + x=x, + weight=weight, + bias=bias, + dout=dout, + seq_idx=seq_idx, + initial_states=initial_states, + dfinal_states=dfinal_states, + dx=dx, + dweight=dweight, + dbias=dbias, + dinitial_states=dinitial_states, + silu_activation=silu_activation, + ) + + dweight = dweight.type_as(weight) + if dbias is not None: + dbias = dbias.type_as(bias) + return dx, dweight, dbias, dinitial_states + + +def causal_conv1d_update_function( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + silu_activation: bool, + cache_seqlens: torch.Tensor | None, + conv_state_indices: torch.Tensor | None, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_update( + x=x, + conv_state=conv_state, + weight=weight, + bias=bias, + out=out, + silu_activation=silu_activation, + cache_seqlens=cache_seqlens, + conv_state_indices=conv_state_indices, + ) + return out diff --git a/build/torch29-cxx11-cu130-x86_64-linux/metadata.json b/build/torch29-cxx11-cu130-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..eff725542128e103dfb5df382d74940efff77214 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/metadata.json @@ -0,0 +1,19 @@ +{ + "version": 1, + "license": "BSD-3-Clause", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0", + "11.0", + "12.0+PTX", + "7.5", + "8.0", + "8.6", + "8.7", + "8.9", + "9.0" + ] + } +} diff --git a/media/benches_dark_animation.svg b/media/benches_dark_animation.svg new file mode 100644 index 0000000000000000000000000000000000000000..776391a87df7a62d9cdd3610975d70c842b0c86a --- /dev/null +++ b/media/benches_dark_animation.svg @@ -0,0 +1,42 @@ + +kernels-community/causal-conv1d vs Torch - Relative Speed +PyTorch 2.11.0+cu130 · CPU + +CausalConv1dBenchmark.base +1.60x + + + + + + + +CausalConv1dBenchmark.large +1.64x + + + + + + + +CausalConv1dBenchmark.xlarge +8.57x + + + + + + + +Kernel + +Torch (ref) + + + + + + + + \ No newline at end of file diff --git a/media/benches_dark_latency.svg b/media/benches_dark_latency.svg new file mode 100644 index 0000000000000000000000000000000000000000..8d4353f58c0bc50685e813ae21c9fbbfda0d38e4 --- /dev/null +++ b/media/benches_dark_latency.svg @@ -0,0 +1,2104 @@ + + + + + + + + 2026-03-25T23:51:47.559961 + image/svg+xml + + + Matplotlib v3.10.8, https://matplotlib.org/ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/media/benches_dark_throughput.svg b/media/benches_dark_throughput.svg new file mode 100644 index 0000000000000000000000000000000000000000..8b57257a5035a7df5bfe9e2eb864a665e54d7a90 --- /dev/null +++ b/media/benches_dark_throughput.svg @@ -0,0 +1,2228 @@ + + + + + + + + 2026-03-25T23:51:47.724478 + image/svg+xml + + + Matplotlib v3.10.8, https://matplotlib.org/ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/media/benches_light_animation.svg b/media/benches_light_animation.svg new file mode 100644 index 0000000000000000000000000000000000000000..1d344696328ce4024da15f12f714ec488b02cb61 --- /dev/null +++ b/media/benches_light_animation.svg @@ -0,0 +1,42 @@ + +kernels-community/causal-conv1d vs Torch - Relative Speed +PyTorch 2.11.0+cu130 · CPU + +CausalConv1dBenchmark.base +1.60x + + + + + + + +CausalConv1dBenchmark.large +1.64x + + + + + + + +CausalConv1dBenchmark.xlarge +8.57x + + + + + + + +Kernel + +Torch (ref) + + + + + + + + \ No newline at end of file diff --git a/media/benches_light_latency.svg b/media/benches_light_latency.svg new file mode 100644 index 0000000000000000000000000000000000000000..a72eab3d163ce8a87261c589cff9529b53247d31 --- /dev/null +++ b/media/benches_light_latency.svg @@ -0,0 +1,2104 @@ + + + + + + + + 2026-03-25T23:51:46.815109 + image/svg+xml + + + Matplotlib v3.10.8, https://matplotlib.org/ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/media/benches_light_throughput.svg b/media/benches_light_throughput.svg new file mode 100644 index 0000000000000000000000000000000000000000..3d92b810d5b39ec7438bde650d00ab1479e6e31c --- /dev/null +++ b/media/benches_light_throughput.svg @@ -0,0 +1,2228 @@ + + + + + + + + 2026-03-25T23:51:47.289202 + image/svg+xml + + + Matplotlib v3.10.8, https://matplotlib.org/ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +