| #pragma once |
|
|
| #include "scaled_mm_c2x.cuh" |
| #include "cutlass/float8.h" |
|
|
| |
| |
| |
| |
|
|
| namespace vllm { |
|
|
// Fallback GEMM configuration for SM89 FP8 kernels.
// Conservative tile/warp shape with the plain OpMultiplyAdd operator
// (not the FastAccum variant used by most tuned configs below), intended
// to be broadly runnable when a tuned configuration cannot be used.
// NOTE(review): selection between the tuned gemm and this fallback is
// presumably done inside fallback_cutlass_gemm_caller — defined elsewhere.
template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue>
struct sm89_fp8_fallback_gemm {
  // Only e4m3 FP8 input is supported by this fallback.
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using TileShape = typename cutlass::gemm::GemmShape<64, 128, 64>;
  using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  using FP8MathOperator = typename cutlass::arch::OpMultiplyAdd;
  // 5 mainloop stages.
  using Cutlass2xGemm =
      cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90, InType, OutType,
                      Epilogue, TileShape, WarpShape, InstructionShape, 5,
                      FP8MathOperator>;
};
|
|
// Tuned SM89 FP8 config used for the largest M bucket (M > 256; see
// cutlass_gemm_sm89_fp8_dispatch). The tile shape and stage count are
// further specialized on next_pow_2(N).
struct sm89_fp8_config_default {
  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;

  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
    TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);

    // Conservative fallback gemm shared by all N branches.
    using FallbackGemm =
        typename sm89_fp8_fallback_gemm<InType, OutType,
                                        Epilogue>::Cutlass2xGemm;

    uint32_t const n = out.size(1);
    uint32_t const np2 = next_pow_2(n);

    if (np2 <= 4096) {
      // N <= 4096
      using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, 5, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else if (np2 <= 8192) {
      // 4096 < N <= 8192: larger M tile, fewer stages.
      using TileShape = typename cutlass::gemm::GemmShape<256, 128, 64>;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, 3, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);

    } else {
      // N > 8192
      using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, 5, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  }
};
|
|
// Tuned SM89 FP8 config for 128 < M <= 256 (see
// cutlass_gemm_sm89_fp8_dispatch). Tile shape / stage count specialized
// on next_pow_2(N).
struct sm89_fp8_config_M256 {
  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;

  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
    TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);

    // Conservative fallback gemm shared by all N branches.
    using FallbackGemm =
        typename sm89_fp8_fallback_gemm<InType, OutType,
                                        Epilogue>::Cutlass2xGemm;

    uint32_t const n = out.size(1);
    uint32_t const np2 = next_pow_2(n);

    if (np2 <= 4096) {
      // N <= 4096: deeper K tile, fewer stages.
      using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, 3, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      // N > 4096
      using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, 5, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  }
};
|
|
// Tuned SM89 FP8 config for 64 < M <= 128 (see
// cutlass_gemm_sm89_fp8_dispatch). Tile shape / stage count specialized
// on next_pow_2(N).
struct sm89_fp8_config_M128 {
  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;

  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
    TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);

    // Conservative fallback gemm shared by all N branches.
    using FallbackGemm =
        typename sm89_fp8_fallback_gemm<InType, OutType,
                                        Epilogue>::Cutlass2xGemm;

    uint32_t const n = out.size(1);
    uint32_t const np2 = next_pow_2(n);

    if (np2 <= 8192) {
      // N <= 8192: deeper K tile, fewer stages.
      using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, 3, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);

    } else if (np2 <= 16384) {
      // 8192 < N <= 16384
      using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, 5, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      // N > 16384: narrow N tile, deep K tile, fewer stages.
      using TileShape = typename cutlass::gemm::GemmShape<128, 64, 128>;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, 3, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  }
};
|
|
// Tuned SM89 FP8 config for 32 < M <= 64 (see
// cutlass_gemm_sm89_fp8_dispatch). Unlike the larger-M configs, the
// warp shape and math operator also vary per N branch, so only the
// instruction shape is shared at struct scope.
struct sm89_fp8_config_M64 {
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;

  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
    TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);

    // Conservative fallback gemm shared by all N branches.
    using FallbackGemm =
        typename sm89_fp8_fallback_gemm<InType, OutType,
                                        Epilogue>::Cutlass2xGemm;

    uint32_t const n = out.size(1);
    uint32_t const np2 = next_pow_2(n);

    // Fixed: threshold was 8196, which is not a power of two. np2 is
    // next_pow_2(n), so only power-of-two values occur and 8196 behaved
    // exactly like 8192 — but it was misleading and inconsistent with
    // every sibling config (4096/8192/16384/24576).
    if (np2 <= 8192) {
      // N <= 8192: small tile, plain multiply-add.
      using TileShape = typename cutlass::gemm::GemmShape<64, 64, 128>;
      using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
      using FP8MathOperator = typename cutlass::arch::OpMultiplyAdd;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, 5, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else if (np2 <= 16384) {
      // 8192 < N <= 16384: wider tile, fast-accumulate operator.
      using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
      using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
      using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, 3, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      // N > 16384: same shape as the first branch.
      using TileShape = typename cutlass::gemm::GemmShape<64, 64, 128>;
      using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
      using FP8MathOperator = typename cutlass::arch::OpMultiplyAdd;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, 5, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  }
};
|
|
// Tuned SM89 FP8 config for 16 < M <= 32 (see
// cutlass_gemm_sm89_fp8_dispatch). Tile and warp shapes vary per N
// branch; instruction shape and math operator are shared.
struct sm89_fp8_config_M32 {
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;

  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
    TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);

    // Conservative fallback gemm shared by all N branches.
    using FallbackGemm =
        typename sm89_fp8_fallback_gemm<InType, OutType,
                                        Epilogue>::Cutlass2xGemm;

    uint32_t const n = out.size(1);
    uint32_t const np2 = next_pow_2(n);

    if (np2 <= 8192) {
      // N <= 8192
      using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
      using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, 5, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else if (np2 <= 16384) {
      // 8192 < N <= 16384: wider tile, 4 stages.
      using TileShape = typename cutlass::gemm::GemmShape<32, 128, 128>;
      using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, 4, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      // N > 16384: same shape as the first branch.
      using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
      using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, 5, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  }
};
|
|
// Tuned SM89 FP8 config for M <= 16 (see cutlass_gemm_sm89_fp8_dispatch;
// mp2 is clamped up to 16 there). Only the tile shape varies per N
// branch; warp shape, operator, and stage count are shared.
struct sm89_fp8_config_M16 {
  using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
  // All N branches in this config use the same mainloop depth.
  static const int32_t MainLoopStages = 5;

  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
    TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);

    // Conservative fallback gemm shared by all N branches.
    using FallbackGemm =
        typename sm89_fp8_fallback_gemm<InType, OutType,
                                        Epilogue>::Cutlass2xGemm;

    uint32_t const n = out.size(1);
    uint32_t const np2 = next_pow_2(n);

    if (np2 <= 8192) {
      // N <= 8192
      using TileShape = typename cutlass::gemm::GemmShape<16, 64, 128>;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, MainLoopStages,
                                FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else if (np2 <= 24576) {
      // 8192 < N <= 24576 (note: 24576 is not a power of two, but np2 is;
      // this bound effectively means np2 <= 16384).
      using TileShape = typename cutlass::gemm::GemmShape<16, 128, 64>;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, MainLoopStages,
                                FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      // N > 24576: taller M tile than the struct-level M=16 bucket.
      using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, MainLoopStages,
                                FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  }
};
|
|
// Entry point for SM89 FP8 scaled-mm dispatch. Buckets the problem by
// next_pow_2(M) (rows of a, clamped to a minimum of 16) and forwards to
// the matching tuned config struct, which further specializes on N.
// Both inputs must be fp8 e4m3.
template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue,
          typename... EpilogueArgs>
inline void cutlass_gemm_sm89_fp8_dispatch(torch::Tensor& out,
                                           torch::Tensor const& a,
                                           torch::Tensor const& b,
                                           EpilogueArgs&&... args) {
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

  uint32_t const m = a.size(0);
  // Clamp to >= 16 so tiny M still lands in the M16 bucket.
  uint32_t const mp2 =
      std::max(static_cast<uint32_t>(16), next_pow_2(m));

  if (mp2 <= 16) {
    // M in [1, 16]
    return sm89_fp8_config_M16::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 32) {
    // M in (16, 32]
    return sm89_fp8_config_M32::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 64) {
    // M in (32, 64]
    return sm89_fp8_config_M64::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 128) {
    // M in (64, 128]
    return sm89_fp8_config_M128::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 256) {
    // M in (128, 256]
    return sm89_fp8_config_M256::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else {
    // M > 256
    return sm89_fp8_config_default::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
}
|
|
| } |
|
|