| |
|
|
| #pragma once |
|
|
| #include <assert.h> |
| #include <stdint.h> |
| #include <stdlib.h> |
|
|
| #include <cuda_bf16.h> |
|
|
| #include <cute/tensor.hpp> |
|
|
| #include <cutlass/array.h> |
| #include <cutlass/cutlass.h> |
| #include <cutlass/numeric_conversion.h> |
| #include <cutlass/numeric_types.h> |
|
|
| |
|
|
| namespace flash { |
|
|
| |
|
|
/// Binary functor that returns the larger of its two arguments.
///
/// The generic version compares with `operator>`; when that comparison is
/// false (ties included) the second argument `y` is returned.
template <typename T>
struct MaxOp {
    // const-qualified: the functor holds no state, so it should also be
    // callable through a const object or a const reference.
    __device__ __forceinline__ T operator()(T const &x, T const &y) const {
        return x > y ? x : y;
    }
};

/// `float` specialization: defers to the `max(float, float)` overload instead
/// of spelling out the comparison.
template <>
struct MaxOp<float> {
    __device__ __forceinline__ float operator()(float const &x, float const &y) const {
        return max(x, y);
    }
};
|
|
| |
|
|
/// Binary functor that returns `x + y`.
template <typename T>
struct SumOp {
    // const-qualified: the functor holds no state, so it should also be
    // callable through a const object or a const reference.
    __device__ __forceinline__ T operator()(T const &x, T const &y) const {
        return x + y;
    }
};
|
|
| |
|
|
/// All-reduce across groups of THREADS lanes built from XOR shuffles.
///
/// Each step combines the caller's value with the value held by the lane whose
/// index differs in the THREADS/2 bit, then recurses on half the width. The
/// recursion bottoms out in the `Allreduce<2>` specialization.
///
/// The shuffle is issued with an all-ones member mask, so every lane named by
/// that mask has to reach the call.
///
/// \tparam THREADS  Group width; restricted to 32, 16, 8 or 4 here.
template <int THREADS>
struct Allreduce {
    static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);

    /// \param x   This lane's contribution.
    /// \param op  Binary combiner, invoked as `op(mine, partner)`.
    /// \return    The value produced after all exchange steps.
    template <typename T, typename Operator>
    static __device__ __forceinline__ T run(T x, Operator &op) {
        constexpr uint32_t kAllLanes = 0xffffffffu;
        constexpr int kHalfWidth = THREADS / 2;
        // Fetch the partner lane's value, fold it in, then continue on the
        // narrower group.
        T const partner = __shfl_xor_sync(kAllLanes, x, kHalfWidth);
        x = op(x, partner);
        return Allreduce<kHalfWidth>::run(x, op);
    }
};
|
|
| |
|
|
/// Terminal case of the shuffle all-reduce: one exchange with the lane whose
/// index differs in bit 0 (XOR lane mask of 1), issued with an all-ones member
/// mask.
template <>
struct Allreduce<2> {
    /// \param x   This lane's contribution.
    /// \param op  Binary combiner, invoked as `op(mine, partner)`.
    /// \return    `op` applied to this lane's value and its partner's value.
    template <typename T, typename Operator>
    static __device__ __forceinline__ T run(T x, Operator &op) {
        constexpr uint32_t kAllLanes = 0xffffffffu;
        T const partner = __shfl_xor_sync(kAllLanes, x, 1);
        x = op(x, partner);
        return x;
    }
};
|
|
| |
|
|
/// Issues one `cute::gemm` per slice of mode 2 (the k-blocks) of `tCrA` /
/// `tCrB`, accumulating into `tCrC`, bracketed by warpgroup fences and the
/// optional arrive / commit / wait calls.
///
/// \tparam zero_init  When true, `tiled_mma.accumulate_` is set to
///                    `GMMA::ScaleOut::Zero` before the first k-block.
/// \tparam wg_wait    When >= 0, `warpgroup_wait<wg_wait>()` is called after the
///                    loop; a negative value skips the wait.
/// \tparam arrive     When true, `warpgroup_arrive()` is called before the loop.
/// \tparam commit     When true, `warpgroup_commit_batch()` is called after the
///                    loop.
///
/// Side effect: after every k-block `tiled_mma.accumulate_` is set to
/// `GMMA::ScaleOut::One`, and it is left that way on return.
template <bool zero_init=false, int wg_wait=0, bool arrive=true, bool commit=true, typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) {
    // Operand A is fenced only when the MMA's A fragment type does not derive
    // from the GMMA descriptor iterator.
    // NOTE(review): presumably that means A is held in registers rather than
    // described by a shared-memory descriptor -- confirm.
    constexpr bool kFenceOperandA =
        !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;

    // tCrA arrives as const; the fence is called on a non-const reference.
    if constexpr (kFenceOperandA) { cute::warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
    warpgroup_fence_operand(tCrC);
    if constexpr (arrive) { warpgroup_arrive(); }

    // Only the first k-block may overwrite the accumulator; every later one
    // accumulates because the scale is switched to One inside the loop.
    if constexpr (zero_init) { tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; }

    CUTLASS_PRAGMA_UNROLL
    for (int kb = 0; kb < size<2>(tCrA); ++kb) {
        cute::gemm(tiled_mma, tCrA(_,_,kb), tCrB(_,_,kb), tCrC);
        tiled_mma.accumulate_ = GMMA::ScaleOut::One;
    }

    if constexpr (commit) { warpgroup_commit_batch(); }
    if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
    warpgroup_fence_operand(tCrC);
    if constexpr (kFenceOperandA) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
}
|
|
| |
|
|
| |
| |
/// Regroups the modes of a rank-3 accumulator layout into a rank-2 layout.
///
/// Two input forms are accepted (checked by the static_asserts):
///  * mode 0 has rank 3 with its first two sub-modes of size 2, i.e.
///    ((2, 2, V), M, N). Result: ((mode<0,1>, M), (mode<0,0>, mode<0,2>, N)).
///  * mode 0 has size 4, i.e. (4, M, N). Mode 0 is first divided by Shape<_2>
///    into (lo, hi). Result: ((hi, M), (lo, N)).
///
/// With `Transposed == true` the two top-level modes of the result are swapped.
/// NOTE(review): the function name suggests the result is (rows, cols) of the
/// accumulator tile -- confirm against the MMA atoms in use.
template<bool Transposed=false, typename Layout0>
__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout0 acc_layout) {
    constexpr bool kNestedValueMode = decltype(rank<0>(acc_layout))::value == 3;
    if constexpr (kNestedValueMode) {
        static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
        static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
        static_assert(decltype(rank(acc_layout))::value == 3);
        auto first_group = make_layout(get<0, 1>(acc_layout), get<1>(acc_layout));
        auto second_group = make_layout(get<0, 0>(acc_layout), get<0, 2>(acc_layout), get<2>(acc_layout));
        if constexpr (Transposed) {
            return make_layout(second_group, first_group);
        } else {
            return make_layout(first_group, second_group);
        }
    } else {
        static_assert(decltype(size<0>(acc_layout))::value == 4);
        static_assert(decltype(rank(acc_layout))::value == 3);
        // Split the size-4 value mode into two halves of size 2.
        auto halves = logical_divide(acc_layout, Shape<_2>{});
        auto first_group = make_layout(get<0, 1>(halves), get<1>(halves));
        auto second_group = make_layout(get<0, 0>(halves), get<2>(halves));
        if constexpr (Transposed) {
            return make_layout(second_group, first_group);
        } else {
            return make_layout(first_group, second_group);
        }
    }
}
|
|
| |
|
|
| |
| |
| |
| |
/// Regroups a rank-3 accumulator layout into another rank-3 layout.
/// NOTE(review): judging by the name, the result is meant to be usable as the
/// A-operand register layout of a following MMA -- confirm with the callers.
///
/// Input forms (checked by the static_asserts):
///  * mode 0 has rank 3 with leading sub-modes of size 2: ((2, 2, V), M, N).
///      - sizeof(ValTypeA) == 2: V is divided by 2 into (V_lo, V_hi). Result:
///        ((mode<0,0>, mode<0,1>, V_lo), M, coalesce((V_hi, N))).
///      - sizeof(ValTypeA) == 1: V is divided by a (2, 2) tile. The two leading
///        size-2 sub-modes, asserted to have strides 1 and 2, are replaced by a
///        single Layout<_4>. Result:
///        ((4, tile<0>, tile<1>), M, coalesce((V_rest, N))).
///  * mode 0 has size 4: (4, M, N).
///      - MMA K extent 8: the layout is returned unchanged.
///      - MMA K extent 16: mode 2 is divided by 2 into (N_lo, N_hi). Result:
///        ((mode<0>, N_lo), M, N_hi).
///
/// \tparam MMA_Traits  Supplies `ValTypeA` (element size of operand A) and
///                     `Shape_MNK` (K extent read from index 2).
template<typename MMA_Traits, typename Layout0>
__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout0 acc_layout) {
    if constexpr (decltype(rank<0>(acc_layout))::value == 3) {
        static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
        static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
        static_assert(decltype(rank(acc_layout))::value == 3);
        static_assert(decltype(rank(get<0>(acc_layout)))::value == 3);
        constexpr int kBytesA = sizeof(typename MMA_Traits::ValTypeA);
        if constexpr (kBytesA == 2) {
            auto v_split = logical_divide(get<0, 2>(acc_layout), Tile<_2>{});
            auto value_mode = make_layout(get<0, 0>(acc_layout), get<0, 1>(acc_layout), get<0, 0>(v_split));
            auto last_mode = coalesce(make_layout(get<0, 1>(v_split), get<2>(acc_layout)));
            return make_layout(value_mode, get<1>(acc_layout), last_mode);
        } else {
            static_assert(sizeof(typename MMA_Traits::ValTypeA) == 1);
            // These two strides are what allow the leading (2, 2) to be
            // expressed as one contiguous Layout<_4> below.
            static_assert(decltype(stride<0, 0>(acc_layout))::value == 1);
            static_assert(decltype(stride<0, 1>(acc_layout))::value == 2);
            auto v_split = logical_divide(get<0, 2>(acc_layout), Tile<Layout<Shape<_2, _2>>>{});
            auto value_mode = make_layout(Layout<_4>{}, get<0, 0, 0>(v_split), get<0, 0, 1>(v_split));
            auto last_mode = coalesce(make_layout(get<0, 1>(v_split), get<2>(acc_layout)));
            return make_layout(value_mode, get<1>(acc_layout), last_mode);
        }
    } else {
        static_assert(decltype(size<0>(acc_layout))::value == 4);
        static_assert(decltype(rank(acc_layout))::value == 3);
        constexpr int mma_shape_K = get<2>(typename MMA_Traits::Shape_MNK{});
        static_assert(mma_shape_K == 8 || mma_shape_K == 16);
        if constexpr (mma_shape_K == 8) {
            return acc_layout;
        } else {
            // Leave modes 0 and 1 alone (Underscore) and pair up mode 2.
            auto n_split = logical_divide(acc_layout, Shape<Underscore, Underscore, _2>{});
            auto value_mode = make_layout(get<0>(n_split), get<2, 0>(n_split));
            return make_layout(value_mode, get<1>(n_split), get<2, 1>(n_split));
        }
    }
}
|
|
| |
|
|
/// Converts every element of `tensor` to `To_type` with
/// `cutlass::NumericArrayConverter` and returns a tensor that has the same
/// layout as the input.
///
/// The input's storage is reinterpreted as one `cutlass::Array` of
/// `size(tensor)` elements starting at `tensor.data()`, so the elements must be
/// laid out contiguously there; the element count must be a compile-time
/// constant.
///
/// NOTE(review): the returned tensor is a non-owning view over `frag`, a local
/// of this function. It is only meaningful if the call is fully inlined into
/// the caller; otherwise the view refers to storage that no longer exists.
/// Making this safe requires returning an owning tensor, which changes the
/// return type -- audit callers before doing so.
template <typename To_type, typename Engine, typename Layout>
__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
    using From_type = typename Engine::value_type;
    constexpr int kNumel = decltype(size(tensor))::value;
    using SourceArray = cutlass::Array<From_type, kNumel>;

    cutlass::NumericArrayConverter<To_type, From_type, kNumel> convert_op;
    // View the input elements as a single flat array for the bulk conversion.
    SourceArray const *flat_src = reinterpret_cast<SourceArray const *>(tensor.data());
    auto frag = convert_op(*flat_src);
    return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
}
|
|
| |
|
|
| |
| |
| |
| |
| |
/// Emits the PTX instruction `cp.async.wait_group N`, with `N` passed as an
/// immediate operand.
///
/// Per the PTX ISA this waits until at most `N` of the most recently committed
/// cp.async groups are still pending. The instruction is emitted verbatim for
/// every `N`, including `N == 0`.
///
/// Compiles to a no-op when `CUTE_ARCH_CP_ASYNC_SM80_ENABLED` is not defined.
///
/// \tparam N  Compile-time group count.
template <int N>
CUTE_HOST_DEVICE
void cp_async_wait() {
#ifdef CUTE_ARCH_CP_ASYNC_SM80_ENABLED
    asm volatile("cp.async.wait_group %0;\n" :: "n"(N));
#endif
}
|
|
| |
|
|
/// Predicated tiled copy from `S` to `D`, one (m, k) slice at a time.
///
/// Both tensors must have rank 3 and identical extents in every mode. Mode 1 is
/// iterated as `m`, mode 2 as `k`.
///
/// \tparam Is_even_MN    When true, every `m` slice is treated as in range and
///                       `identity_MN` / `max_MN` are not consulted.
/// \tparam Is_even_K     When true, every `k` slice is treated as in range and
///                       `predicate_K` is not consulted.
/// \tparam Clear_OOB_MN  When true, an out-of-range `m` slice of `D` is cleared
///                       across all `k`.
/// \tparam Clear_OOB_K   When true, an out-of-range `k` slice of `D` (within an
///                       in-range `m`) is cleared.
///
/// \param tiled_copy   Copy object forwarded to `cute::copy`.
/// \param S            Source tensor.
/// \param D            Destination tensor.
/// \param identity_MN  Coordinate tensor; `get<0>(identity_MN(0, m, 0))` is the
///                     value compared against `max_MN`.
/// \param predicate_K  `predicate_K(k)` is true when slice `k` is in range.
/// \param max_MN       Exclusive upper bound for the coordinate above.
///
/// Out-of-range slices that are not cleared are left untouched in `D`.
template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
          typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
                                     Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
                                     Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {
    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));
    // Clearing out-of-range MN slices without also clearing out-of-range K
    // slices is not a supported combination.
    static_assert(!Clear_OOB_MN || Clear_OOB_K);

    #pragma unroll
    for (int mi = 0; mi < size<1>(S); ++mi) {
        // Short-circuits: the coordinate is only read when Is_even_MN is false.
        bool const m_in_range = Is_even_MN || get<0>(identity_MN(0, mi, 0)) < max_MN;
        if (!m_in_range) {
            if constexpr (Clear_OOB_MN) {
                cute::clear(D(_, mi, _));
            }
        } else {
            #pragma unroll
            for (int ki = 0; ki < size<2>(S); ++ki) {
                // Short-circuits: the predicate is only read when Is_even_K is false.
                bool const k_in_range = Is_even_K || predicate_K(ki);
                if (k_in_range) {
                    cute::copy(tiled_copy, S(_, mi, ki), D(_, mi, ki));
                } else if constexpr (Clear_OOB_K) {
                    cute::clear(D(_, mi, ki));
                }
            }
        }
    }
}
|
|
| |
|
|
| } |
|
|