#pragma once

#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <cstdint>
#include <iostream>

// Name of the namespace that wraps everything declared in this header.
// Define MARLIN_NAMESPACE_NAME before including the header to override it;
// otherwise it defaults to `marlin`.
#ifndef MARLIN_NAMESPACE_NAME
  #define MARLIN_NAMESPACE_NAME marlin
#endif

namespace MARLIN_NAMESPACE_NAME {

// ---------------------------------------------------------------------------
// Marlin parameters
//
// Compile-time tuning constants. The kernels and launch heuristics that
// consume them are defined outside this header.
// ---------------------------------------------------------------------------

// Default thread count: 256 threads, i.e. 8 warps of 32 threads.
static constexpr int default_threads = 256;

// Number of pipeline stages.
static constexpr int pipe_stages = 4;

// NOTE(review): presumably lower/upper bounds on the per-block tile extent
// along the N and K dimensions -- confirm against the launch heuristics.
static constexpr int min_thread_n = 64;
static constexpr int min_thread_k = 64;
static constexpr int max_thread_n = 256;

// Edge length of the basic tile; tile_k_size below is defined from it.
static constexpr int tile_size = 16;
// NOTE(review): meaning not visible in this header; presumably an upper bound
// on a parallelism factor -- confirm with the kernel launcher.
static constexpr int max_par = 16;

// ---------------------------------------------------------------------------
// Repack parameters
// ---------------------------------------------------------------------------

// Stage count and thread count for the repack path (consumers not shown here).
static constexpr int repack_stages = 8;
static constexpr int repack_threads = 256;

// Repack tile extents: K matches tile_size (16); N is four times that (64).
static constexpr int tile_k_size = tile_size;
static constexpr int tile_n_size = tile_k_size * 4;

// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------

// Thin fixed-size array wrapper holding `n` elements of type `T`, indexable
// from device code. Indexing is unchecked.
template <typename T, int n>
struct Vec {
  T elems[n];

  // Mutable element access.
  __device__ T& operator[](int i) { return elems[i]; }

  // Read-only element access, so that a `const Vec` can be indexed as well.
  __device__ const T& operator[](int i) const { return elems[i]; }
};

// Four 32-bit integers.
using I4 = Vec<int, 4>;

// Integer division rounded up: the smallest q with q * b >= a, for b > 0.
//
// Written as quotient plus "is there a positive remainder" rather than
// (a + b - 1) / b, because the latter overflows `int` when `a` is close to
// INT_MAX. Results are identical to that formula for every a >= 0, b > 0.
// `b` must not be zero.
constexpr int div_ceil(int a, int b) { return a / b + (a % b > 0 ? 1 : 0); }

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// The cp.async instructions used below require sm_80 or newer, so none of the
// asynchronous-copy helpers are defined when compiling device code for older
// architectures.
#else

// Issues an asynchronous 16-byte copy from global memory (`glob_ptr`) to
// shared memory (`smem_ptr`), but only when `pred` is true; with `pred` false
// no copy is issued. The copy completes asynchronously: pair it with
// cp_async_fence() / cp_async_wait<>() before reading the destination.
// Per the PTX ISA, a 16-byte cp.async needs both addresses 16-byte aligned.
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
                                      bool pred = true) {
  constexpr int kBytes = 16;
  // cp.async takes the destination as a 32-bit shared-state-space address.
  const uint32_t smem_addr =
      static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      " .reg .pred p;\n"
      " setp.ne.b32 p, %0, 0;\n"
      " @p cp.async.cg.shared.global [%1], [%2], %3;\n"
      "}\n" ::"r"((int)pred),
      "r"(smem_addr), "l"(glob_ptr), "n"(kBytes));
}

// Unconditional variant of cp_async4_pred(): always issues the asynchronous
// 16-byte global -> shared copy. Same alignment and completion rules apply.
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
  constexpr int kBytes = 16;
  // cp.async takes the destination as a 32-bit shared-state-space address.
  const uint32_t smem_addr =
      static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      " cp.async.cg.shared.global [%0], [%1], %2;\n"
      "}\n" ::"r"(smem_addr),
      "l"(glob_ptr), "n"(kBytes));
}

// Commits all cp.async operations issued so far by this thread, and not yet
// committed, into one new group that cp_async_wait<>() can wait on.
__device__ inline void cp_async_fence() {
  asm volatile("cp.async.commit_group;\n" ::);
}

// Blocks the calling thread until at most `n` of its most recently committed
// cp.async groups are still pending; cp_async_wait<0>() waits for all of them.
template <int n>
__device__ inline void cp_async_wait() {
  asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
}

#endif

}  // namespace MARLIN_NAMESPACE_NAME
