// Silence warnings from newer CUDA Toolkits about deprecated CUB APIs,
// such as the debug_synchronous argument accepted below.
#define CUB_IGNORE_DEPRECATED_API

// Wrap CUB symbols in a project-specific namespace to avoid collisions
// with other libraries that also ship CUB.
#undef CUB_WRAPPED_NAMESPACE
#define CUB_WRAPPED_NAMESPACE megablocks

#include <cstdint>
#include <string>

#include <cub/cub.cuh>
#include <c10/cuda/CUDAStream.h>
#include <torch/all.h>
|
|
// Evaluate a CUDA runtime call and convert any failure into a PyTorch
// exception carrying the CUDA error string.
#define CUDA_CALL(code)                                     \
  do {                                                      \
    cudaError_t status = code;                              \
    std::string err = cudaGetErrorString(status);           \
    TORCH_CHECK(status == cudaSuccess, err);                \
  } while (0)
|
|
namespace megablocks {
|
|
// Tag types used to select the scan variant at compile time.
struct Inclusive {};
struct Exclusive {};
|
|
// The primary template performs an exclusive scan; the inclusive
// variant is provided by the Cumsum<Inclusive> specialization below.
template <typename Type> struct Cumsum {
  template <typename InputIteratorT, typename OutputIteratorT>
  static void Run(void* d_temp_storage,
                  size_t& temp_storage_bytes,
                  InputIteratorT d_in,
                  OutputIteratorT d_out,
                  int num_items,
                  cudaStream_t stream = 0,
                  bool debug_synchronous = false) {
    // debug_synchronous is unused; it mirrors the older CUB interface.
    CUDA_CALL(cub::DeviceScan::ExclusiveSum(d_temp_storage,
                                            temp_storage_bytes,
                                            d_in,
                                            d_out,
                                            num_items,
                                            stream));
  }
};
|
|
// Inclusive-scan specialization.
template <> struct Cumsum<Inclusive> {
  template <typename InputIteratorT, typename OutputIteratorT>
  static void Run(void* d_temp_storage,
                  size_t& temp_storage_bytes,
                  InputIteratorT d_in,
                  OutputIteratorT d_out,
                  int num_items,
                  cudaStream_t stream = 0,
                  bool debug_synchronous = false) {
    CUDA_CALL(cub::DeviceScan::InclusiveSum(d_temp_storage,
                                            temp_storage_bytes,
                                            d_in,
                                            d_out,
                                            num_items,
                                            stream));
  }
};
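// Note: CUB device-wide algorithms follow a two-phase protocol. A first
// call with d_temp_storage == nullptr performs no work and only writes
// the required scratch size into temp_storage_bytes; a second call with
// a real allocation executes the scan. cub_cumsum below follows this
// pattern, first sizing the scratchpad and then running one scan per row.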
|
|
template <typename SumType, typename T>
void cub_cumsum(torch::Tensor x, int dim, torch::Tensor out) {
  // Size-query phase: with a null scratchpad pointer, CUB only reports
  // how many scratch bytes a single row's scan requires.
  size_t scratchpad_bytes = 0;
  Cumsum<SumType>::Run(nullptr,
                       scratchpad_bytes,
                       x.data_ptr<T>(),
                       out.data_ptr<T>(),
                       x.size(1),
                       c10::cuda::getCurrentCUDAStream());

  // Allocate scratchpad space for all rows up front.
  auto options = torch::TensorOptions()
    .dtype(torch::kInt8)
    .device(x.device());
  torch::Tensor scratchpad = torch::empty(scratchpad_bytes * x.size(0),
                                          options);
|
|
  // Scan phase: cub::DeviceScan operates on a single contiguous
  // sequence, so the batch dimension is handled by launching one scan
  // per row, each with its own slice of the scratchpad. The launches
  // are asynchronous on the current stream.
  for (int i = 0; i < x.size(0); ++i) {
    void* scratchpad_ptr = (int8_t*)scratchpad.data_ptr() + scratchpad_bytes * i;
    Cumsum<SumType>::Run(scratchpad_ptr,
                         scratchpad_bytes,
                         x.data_ptr<T>() + x.size(1) * i,
                         out.data_ptr<T>() + x.size(1) * i,
                         x.size(1),
                         c10::cuda::getCurrentCUDAStream());
  }
}
|
|
void exclusive_cumsum(torch::Tensor x, int dim, torch::Tensor out) {
  // Validate the input and output tensors. The per-row pointer
  // arithmetic in cub_cumsum assumes contiguous storage.
  TORCH_CHECK(x.is_cuda());
  TORCH_CHECK(x.ndimension() == 2);
  TORCH_CHECK(x.is_contiguous());
  TORCH_CHECK(x.scalar_type() == torch::kInt16 ||
              x.scalar_type() == torch::kInt32 ||
              x.scalar_type() == torch::kInt64);
  TORCH_CHECK(out.is_cuda());
  TORCH_CHECK(out.ndimension() == 2);
  TORCH_CHECK(out.is_contiguous());
  TORCH_CHECK(out.scalar_type() == x.scalar_type());

  // Only a scan over the innermost dimension is supported.
  TORCH_CHECK(dim == 1);

  // Dispatch on the element type, using fixed-width types so the
  // template argument matches the tensor dtype on all platforms.
  switch (x.scalar_type()) {
    case torch::kInt16:
      cub_cumsum<Exclusive, int16_t>(x, dim, out);
      return;
    case torch::kInt32:
      cub_cumsum<Exclusive, int32_t>(x, dim, out);
      return;
    default:
      TORCH_CHECK(x.scalar_type() == torch::kInt64);
      cub_cumsum<Exclusive, int64_t>(x, dim, out);
  }
}
|
|
void inclusive_cumsum(torch::Tensor x, int dim, torch::Tensor out) {
  // Validate the input and output tensors. The per-row pointer
  // arithmetic in cub_cumsum assumes contiguous storage.
  TORCH_CHECK(x.is_cuda());
  TORCH_CHECK(x.ndimension() == 2);
  TORCH_CHECK(x.is_contiguous());
  TORCH_CHECK(x.scalar_type() == torch::kInt16 ||
              x.scalar_type() == torch::kInt32 ||
              x.scalar_type() == torch::kInt64);
  TORCH_CHECK(out.is_cuda());
  TORCH_CHECK(out.ndimension() == 2);
  TORCH_CHECK(out.is_contiguous());
  TORCH_CHECK(out.scalar_type() == x.scalar_type());

  // Only a scan over the innermost dimension is supported.
  TORCH_CHECK(dim == 1);

  // Dispatch on the element type, using fixed-width types so the
  // template argument matches the tensor dtype on all platforms.
  switch (x.scalar_type()) {
    case torch::kInt16:
      cub_cumsum<Inclusive, int16_t>(x, dim, out);
      return;
    case torch::kInt32:
      cub_cumsum<Inclusive, int32_t>(x, dim, out);
      return;
    default:
      TORCH_CHECK(x.scalar_type() == torch::kInt64);
      cub_cumsum<Inclusive, int64_t>(x, dim, out);
  }
}
|
|
}  // namespace megablocks
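// A minimal usage sketch (illustrative only; the shapes, values, and the
// binding through which callers reach these functions are assumptions,
// not part of this file):
//
//   torch::Tensor x = torch::randint(
//       0, 8, {4, 16}, torch::dtype(torch::kInt32).device(torch::kCUDA));
//   torch::Tensor out = torch::empty_like(x);
//   megablocks::exclusive_cumsum(x, /*dim=*/1, out);  // out[i][j] = sum of x[i][0..j-1]
//   megablocks::inclusive_cumsum(x, /*dim=*/1, out);  // out[i][j] = sum of x[i][0..j]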
|
|
#undef CUB_WRAPPED_NAMESPACE