Build uploaded using `kernels` (batch 10/10).
Browse files — This view is limited to 50 files because it contains too many changes. See raw diff.
- build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h +985 -0
- build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_bias_relu.h +610 -0
- build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_clamp.h +684 -0
- build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_dgelu.h +250 -0
- build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_drelu.h +452 -0
- build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_gelu.h +70 -0
- build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_generic.h +265 -0
- build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_generic_with_scaling.h +325 -0
- build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_hardswish.h +69 -0
- build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_leaky_relu.h +231 -0
- build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_params.h +75 -0
- build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_planar_complex.h +236 -0
- build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_relu.h +572 -0
- build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_relu0.h +543 -0
- build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_residual_block.h +301 -0
- build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_sigmoid.h +70 -0
- build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_silu.h +69 -0
- build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_tensor_broadcast.hpp +253 -0
- build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_with_elementwise.h +234 -0
- build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/reduction_op.h +97 -0
- build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/scale_type.h +66 -0
- build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op.h +255 -0
- build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op_blas3.h +264 -0
- build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_direct_store.h +74 -0
- build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_planar_complex.h +241 -0
- build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_simt.h +443 -0
- build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h +904 -0
- build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op_blas3.h +175 -0
- build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h +337 -0
- build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_with_absmax.h +126 -0
- build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h +376 -0
- build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_with_reduction.h +177 -0
- build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h +165 -0
- build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_simt.h +127 -0
- build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_tensor_op.h +208 -0
- build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_volta_tensor_op.h +228 -0
- build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_wmma_tensor_op.h +113 -0
- build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/direct_store_epilogue_iterator.h +142 -0
- build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/epilogue.h +548 -0
- build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/epilogue_base.h +234 -0
- build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/epilogue_base_streamk.h +197 -0
- build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/epilogue_depthwise.h +335 -0
- build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/epilogue_direct_store.h +347 -0
- build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/epilogue_gemm_k_reduction.h +206 -0
- build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/epilogue_planar_complex.h +401 -0
- build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/epilogue_smem_accumulator.h +224 -0
- build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/epilogue_streamk_with_broadcast.h +443 -0
- build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h +513 -0
- build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/epilogue_with_absmax.h +922 -0
- build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h +1717 -0
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h
ADDED
|
@@ -0,0 +1,985 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
/*! \file
|
| 33 |
+
\brief Functor performing linear combination operations used by epilogues.
|
| 34 |
+
*/
|
| 35 |
+
|
| 36 |
+
#pragma once
|
| 37 |
+
|
| 38 |
+
#include "cutlass/cutlass.h"
|
| 39 |
+
#include "cutlass/numeric_types.h"
|
| 40 |
+
#include "cutlass/array.h"
|
| 41 |
+
#include "cutlass/functional.h"
|
| 42 |
+
#include "cutlass/numeric_conversion.h"
|
| 43 |
+
#include "cutlass/platform/platform.h"
|
| 44 |
+
|
| 45 |
+
#include "cutlass/epilogue/thread/activation.h"
|
| 46 |
+
#include "cutlass/epilogue/thread/scale_type.h"
|
| 47 |
+
|
| 48 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 49 |
+
|
| 50 |
+
namespace cutlass {
|
| 51 |
+
namespace epilogue {
|
| 52 |
+
namespace thread {
|
| 53 |
+
|
| 54 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 55 |
+
|
| 56 |
+
namespace detail {
|
| 57 |
+
|
| 58 |
+
/// Tag type used as the `Arguments` alias when an elementwise op declares no
/// nested `Arguments` struct (selected by the primary ElementwiseOpDispatcher).
struct EmptyArguments {};
|
| 59 |
+
|
| 60 |
+
template<class T, class = void>
|
| 61 |
+
struct ElementwiseOpDispatcher {
|
| 62 |
+
using Arguments = EmptyArguments;
|
| 63 |
+
|
| 64 |
+
T op;
|
| 65 |
+
|
| 66 |
+
CUTLASS_HOST_DEVICE
|
| 67 |
+
ElementwiseOpDispatcher(Arguments) {}
|
| 68 |
+
|
| 69 |
+
template <typename ValueType>
|
| 70 |
+
CUTLASS_HOST_DEVICE
|
| 71 |
+
ValueType operator()(ValueType value) {
|
| 72 |
+
return op(value);
|
| 73 |
+
}
|
| 74 |
+
};
|
| 75 |
+
|
| 76 |
+
template<class T>
|
| 77 |
+
struct ElementwiseOpDispatcher<T, std::void_t<typename T::Arguments>> {
|
| 78 |
+
using Arguments = typename T::Arguments;
|
| 79 |
+
|
| 80 |
+
Arguments args;
|
| 81 |
+
T op;
|
| 82 |
+
|
| 83 |
+
CUTLASS_HOST_DEVICE
|
| 84 |
+
ElementwiseOpDispatcher(Arguments args_):args(args_) {}
|
| 85 |
+
|
| 86 |
+
template <typename ValueType>
|
| 87 |
+
CUTLASS_HOST_DEVICE
|
| 88 |
+
ValueType operator()(ValueType value) {
|
| 89 |
+
return op(value, args);
|
| 90 |
+
}
|
| 91 |
+
};
|
| 92 |
+
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 96 |
+
|
| 97 |
+
/// This base class is meant to define the concept required of the
|
| 98 |
+
/// EpilogueWithBroadcast::OutputOp
|
| 99 |
+
template <
|
| 100 |
+
typename ElementC_,
|
| 101 |
+
typename ElementAccumulator_,
|
| 102 |
+
typename ElementCompute_,
|
| 103 |
+
typename ElementZ_,
|
| 104 |
+
typename ElementT_,
|
| 105 |
+
int ElementsPerAccess,
|
| 106 |
+
typename ElementwiseOp_ = Identity<ElementCompute_>,
|
| 107 |
+
typename BinaryOp_ = plus<ElementCompute_>,
|
| 108 |
+
bool StoreT_ = true,
|
| 109 |
+
typename ElementVector_ = ElementC_
|
| 110 |
+
>
|
| 111 |
+
class LinearCombinationBiasElementwise {
|
| 112 |
+
public:
|
| 113 |
+
|
| 114 |
+
using ElementOutput = ElementC_;
|
| 115 |
+
using ElementD = ElementOutput;
|
| 116 |
+
using ElementC = ElementC_;
|
| 117 |
+
using ElementAccumulator = ElementAccumulator_;
|
| 118 |
+
using ElementCompute = ElementCompute_;
|
| 119 |
+
using ElementScalar = ElementCompute;
|
| 120 |
+
using ElementZ = ElementZ_;
|
| 121 |
+
using ElementT = ElementT_;
|
| 122 |
+
using ElementVector = ElementVector_;
|
| 123 |
+
static int const kElementsPerAccess = ElementsPerAccess;
|
| 124 |
+
static int const kCount = kElementsPerAccess;
|
| 125 |
+
|
| 126 |
+
/// Follow cutlass3x EVT aliases
|
| 127 |
+
static bool const IsEltActSupported = true;
|
| 128 |
+
|
| 129 |
+
using ElementwiseOp = ElementwiseOp_;
|
| 130 |
+
using BinaryOp = BinaryOp_;
|
| 131 |
+
|
| 132 |
+
using ElementwiseOpDispatcher = detail::ElementwiseOpDispatcher<ElementwiseOp>;
|
| 133 |
+
using ElementwiseArguments = typename ElementwiseOpDispatcher::Arguments;
|
| 134 |
+
|
| 135 |
+
// Indicates that this epilogue applies only one binary operation
|
| 136 |
+
static bool const kIsSingleSource = true;
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;
|
| 140 |
+
using FragmentCompute = Array<ElementCompute, kElementsPerAccess>;
|
| 141 |
+
using FragmentC = Array<ElementC, kElementsPerAccess>;
|
| 142 |
+
using FragmentZ = Array<ElementZ, kElementsPerAccess>;
|
| 143 |
+
using FragmentT = Array<ElementT, kElementsPerAccess>;
|
| 144 |
+
|
| 145 |
+
// Definitions needed for collective epilogue
|
| 146 |
+
using FragmentSource = FragmentC;
|
| 147 |
+
using FragmentOutput = FragmentZ;
|
| 148 |
+
using ElementBias = ElementVector;
|
| 149 |
+
using FragmentBias = Array<ElementBias, kElementsPerAccess>;
|
| 150 |
+
using ActivationFn = ElementwiseOp;
|
| 151 |
+
static const ScaleType::Kind kScale = ScaleType::Default;
|
| 152 |
+
|
| 153 |
+
static bool const kIsHeavy = kIsHeavy_member_or_false<ElementwiseOp>::value;
|
| 154 |
+
|
| 155 |
+
/// If true, the 'Z' tensor is stored
|
| 156 |
+
static bool const kStoreZ = true;
|
| 157 |
+
|
| 158 |
+
/// If true, the 'T' tensor is stored
|
| 159 |
+
static bool const kStoreT = StoreT_;
|
| 160 |
+
|
| 161 |
+
/// Host-constructable parameters structure
|
| 162 |
+
struct Params {
|
| 163 |
+
|
| 164 |
+
ElementCompute alpha; ///< scales accumulators
|
| 165 |
+
ElementCompute beta; ///< scales source tensor
|
| 166 |
+
ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
|
| 167 |
+
ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
|
| 168 |
+
ElementwiseArguments elementwise; ///< Arguments for elementwise operation
|
| 169 |
+
|
| 170 |
+
//
|
| 171 |
+
// Methods
|
| 172 |
+
//
|
| 173 |
+
|
| 174 |
+
CUTLASS_HOST_DEVICE
|
| 175 |
+
Params():
|
| 176 |
+
alpha(ElementCompute(1)),
|
| 177 |
+
beta(ElementCompute(0)),
|
| 178 |
+
alpha_ptr(nullptr),
|
| 179 |
+
beta_ptr(nullptr) { }
|
| 180 |
+
|
| 181 |
+
CUTLASS_HOST_DEVICE
|
| 182 |
+
Params(
|
| 183 |
+
ElementCompute alpha,
|
| 184 |
+
ElementCompute beta,
|
| 185 |
+
ElementwiseArguments elementwise_ = ElementwiseArguments{}
|
| 186 |
+
): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr), elementwise(elementwise_) {
|
| 187 |
+
|
| 188 |
+
}
|
| 189 |
+
|
| 190 |
+
CUTLASS_HOST_DEVICE
|
| 191 |
+
Params(
|
| 192 |
+
ElementCompute alpha
|
| 193 |
+
): alpha(alpha), beta(0), alpha_ptr(nullptr), beta_ptr(nullptr) {
|
| 194 |
+
|
| 195 |
+
}
|
| 196 |
+
|
| 197 |
+
CUTLASS_HOST_DEVICE
|
| 198 |
+
Params(
|
| 199 |
+
ElementCompute const *alpha_ptr,
|
| 200 |
+
ElementCompute const *beta_ptr,
|
| 201 |
+
ElementwiseArguments elementwise_ = ElementwiseArguments{}
|
| 202 |
+
): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr), elementwise(elementwise_) {
|
| 203 |
+
|
| 204 |
+
}
|
| 205 |
+
|
| 206 |
+
CUTLASS_HOST_DEVICE
|
| 207 |
+
Params(
|
| 208 |
+
ElementCompute const *alpha_ptr
|
| 209 |
+
): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(nullptr) {
|
| 210 |
+
|
| 211 |
+
}
|
| 212 |
+
};
|
| 213 |
+
|
| 214 |
+
private:
|
| 215 |
+
|
| 216 |
+
//
|
| 217 |
+
// Data members
|
| 218 |
+
//
|
| 219 |
+
|
| 220 |
+
ElementCompute alpha_;
|
| 221 |
+
ElementCompute beta_;
|
| 222 |
+
ElementwiseArguments const &elementwise_;
|
| 223 |
+
bool skip_elementwise_;
|
| 224 |
+
|
| 225 |
+
public:
|
| 226 |
+
|
| 227 |
+
//
|
| 228 |
+
// Methods
|
| 229 |
+
//
|
| 230 |
+
|
| 231 |
+
/// Constructor from Params
|
| 232 |
+
CUTLASS_HOST_DEVICE
|
| 233 |
+
LinearCombinationBiasElementwise(Params const ¶ms): elementwise_(params.elementwise) {
|
| 234 |
+
|
| 235 |
+
alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
|
| 236 |
+
beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
|
| 237 |
+
skip_elementwise_ = false;
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
/// Returns true if source is needed
|
| 241 |
+
CUTLASS_HOST_DEVICE
|
| 242 |
+
bool is_source_needed() const {
|
| 243 |
+
return beta_ != ElementCompute(0);
|
| 244 |
+
}
|
| 245 |
+
|
| 246 |
+
/// Functionally required for serial reduction in the epilogue
|
| 247 |
+
CUTLASS_HOST_DEVICE
|
| 248 |
+
void set_k_partition(int k_partition, int k_partition_count) {
|
| 249 |
+
if (k_partition) {
|
| 250 |
+
beta_ = ElementCompute(1);
|
| 251 |
+
}
|
| 252 |
+
|
| 253 |
+
if (k_partition != k_partition_count - 1) {
|
| 254 |
+
skip_elementwise_ = true;
|
| 255 |
+
}
|
| 256 |
+
}
|
| 257 |
+
|
| 258 |
+
/// Applies the operation when elementwise_op require arguments and is_source_needed() is true
|
| 259 |
+
template <typename ElementwiseArgs>
|
| 260 |
+
CUTLASS_HOST_DEVICE
|
| 261 |
+
void operator()(
|
| 262 |
+
FragmentZ &frag_Z,
|
| 263 |
+
FragmentT &frag_T,
|
| 264 |
+
FragmentAccumulator const &AB,
|
| 265 |
+
FragmentC const &frag_C,
|
| 266 |
+
FragmentCompute const &V,
|
| 267 |
+
ElementwiseArgs const &elementwise_args) const {
|
| 268 |
+
|
| 269 |
+
ElementwiseOp elementwise_op;
|
| 270 |
+
BinaryOp binary_op;
|
| 271 |
+
|
| 272 |
+
FragmentCompute tmp_Accum = NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess>()(AB);
|
| 273 |
+
FragmentCompute tmp_C = NumericArrayConverter<ElementCompute, ElementC, kElementsPerAccess>()(frag_C);
|
| 274 |
+
FragmentCompute result_Z;
|
| 275 |
+
FragmentCompute result_T;
|
| 276 |
+
|
| 277 |
+
CUTLASS_PRAGMA_UNROLL
|
| 278 |
+
for (int i = 0; i < kElementsPerAccess; ++i) {
|
| 279 |
+
ElementCompute z = binary_op(alpha_ * tmp_Accum[i] + beta_ * tmp_C[i], V[i]);
|
| 280 |
+
result_T[i] = z;
|
| 281 |
+
result_Z[i] = skip_elementwise_ ? z : elementwise_op(z, elementwise_args);
|
| 282 |
+
}
|
| 283 |
+
|
| 284 |
+
NumericArrayConverter<ElementZ, ElementCompute, kElementsPerAccess> convert_z;
|
| 285 |
+
frag_Z = convert_z(result_Z);
|
| 286 |
+
|
| 287 |
+
if constexpr (kStoreT) {
|
| 288 |
+
NumericArrayConverter<ElementT, ElementCompute, kElementsPerAccess> convert_t;
|
| 289 |
+
frag_T = convert_t(result_T);
|
| 290 |
+
}
|
| 291 |
+
}
|
| 292 |
+
|
| 293 |
+
/// Applies the operation when elementwise_op require arguments and is_source_needed() is false
|
| 294 |
+
template <typename ElementwiseArgs>
|
| 295 |
+
CUTLASS_HOST_DEVICE
|
| 296 |
+
void operator()(
|
| 297 |
+
FragmentZ &frag_Z,
|
| 298 |
+
FragmentT &frag_T,
|
| 299 |
+
FragmentAccumulator const &AB,
|
| 300 |
+
FragmentCompute const &V,
|
| 301 |
+
ElementwiseArgs const &elementwise_args) const {
|
| 302 |
+
|
| 303 |
+
ElementwiseOp elementwise_op;
|
| 304 |
+
BinaryOp binary_op;
|
| 305 |
+
|
| 306 |
+
FragmentCompute tmp_Accum = NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess>()(AB);
|
| 307 |
+
FragmentCompute result_Z;
|
| 308 |
+
FragmentCompute result_T;
|
| 309 |
+
|
| 310 |
+
CUTLASS_PRAGMA_UNROLL
|
| 311 |
+
for (int i = 0; i < kElementsPerAccess; ++i) {
|
| 312 |
+
ElementCompute z = binary_op(alpha_ * tmp_Accum[i], V[i]);
|
| 313 |
+
result_T[i] = z;
|
| 314 |
+
result_Z[i] = skip_elementwise_ ? z : elementwise_op(z, elementwise_args);
|
| 315 |
+
}
|
| 316 |
+
|
| 317 |
+
NumericArrayConverter<ElementZ, ElementCompute, kElementsPerAccess> convert_z;
|
| 318 |
+
frag_Z = convert_z(result_Z);
|
| 319 |
+
|
| 320 |
+
if constexpr (kStoreT) {
|
| 321 |
+
NumericArrayConverter<ElementT, ElementCompute, kElementsPerAccess> convert_t;
|
| 322 |
+
frag_T = convert_t(result_T);
|
| 323 |
+
}
|
| 324 |
+
}
|
| 325 |
+
|
| 326 |
+
/// Applies the operation when is_source_needed() is true
|
| 327 |
+
CUTLASS_HOST_DEVICE
|
| 328 |
+
void operator()(
|
| 329 |
+
FragmentZ &frag_Z,
|
| 330 |
+
FragmentT &frag_T,
|
| 331 |
+
FragmentAccumulator const &AB,
|
| 332 |
+
FragmentC const &frag_C,
|
| 333 |
+
FragmentCompute const &V) const {
|
| 334 |
+
|
| 335 |
+
ElementwiseOpDispatcher elementwise_op(elementwise_);
|
| 336 |
+
BinaryOp binary_op;
|
| 337 |
+
|
| 338 |
+
FragmentCompute tmp_Accum = NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess>()(AB);
|
| 339 |
+
FragmentCompute tmp_C = NumericArrayConverter<ElementCompute, ElementC, kElementsPerAccess>()(frag_C);
|
| 340 |
+
FragmentCompute result_Z;
|
| 341 |
+
FragmentCompute result_T;
|
| 342 |
+
|
| 343 |
+
CUTLASS_PRAGMA_UNROLL
|
| 344 |
+
for (int i = 0; i < kElementsPerAccess; ++i) {
|
| 345 |
+
ElementCompute z = binary_op(alpha_ * tmp_Accum[i] + beta_ * tmp_C[i], V[i]);
|
| 346 |
+
result_T[i] = z;
|
| 347 |
+
result_Z[i] = skip_elementwise_ ? z : elementwise_op(z);
|
| 348 |
+
}
|
| 349 |
+
|
| 350 |
+
NumericArrayConverter<ElementZ, ElementCompute, kElementsPerAccess> convert_z;
|
| 351 |
+
frag_Z = convert_z(result_Z);
|
| 352 |
+
|
| 353 |
+
if constexpr (kStoreT) {
|
| 354 |
+
NumericArrayConverter<ElementT, ElementCompute, kElementsPerAccess> convert_t;
|
| 355 |
+
frag_T = convert_t(result_T);
|
| 356 |
+
}
|
| 357 |
+
}
|
| 358 |
+
|
| 359 |
+
/// Applies the operation when is_source_needed() is false
|
| 360 |
+
CUTLASS_HOST_DEVICE
|
| 361 |
+
void operator()(
|
| 362 |
+
FragmentZ &frag_Z,
|
| 363 |
+
FragmentT &frag_T,
|
| 364 |
+
FragmentAccumulator const &AB,
|
| 365 |
+
FragmentCompute const &V) const {
|
| 366 |
+
|
| 367 |
+
ElementwiseOpDispatcher elementwise_op(elementwise_);
|
| 368 |
+
BinaryOp binary_op;
|
| 369 |
+
|
| 370 |
+
FragmentCompute tmp_Accum = NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess>()(AB);
|
| 371 |
+
FragmentCompute result_Z;
|
| 372 |
+
FragmentCompute result_T;
|
| 373 |
+
|
| 374 |
+
CUTLASS_PRAGMA_UNROLL
|
| 375 |
+
for (int i = 0; i < kElementsPerAccess; ++i) {
|
| 376 |
+
ElementCompute z = binary_op(alpha_ * tmp_Accum[i], V[i]);
|
| 377 |
+
result_T[i] = z;
|
| 378 |
+
result_Z[i] = skip_elementwise_ ? z : elementwise_op(z);
|
| 379 |
+
}
|
| 380 |
+
|
| 381 |
+
NumericArrayConverter<ElementZ, ElementCompute, kElementsPerAccess> convert_z;
|
| 382 |
+
frag_Z = convert_z(result_Z);
|
| 383 |
+
|
| 384 |
+
if constexpr (kStoreT) {
|
| 385 |
+
NumericArrayConverter<ElementT, ElementCompute, kElementsPerAccess> convert_t;
|
| 386 |
+
frag_T = convert_t(result_T);
|
| 387 |
+
}
|
| 388 |
+
}
|
| 389 |
+
|
| 390 |
+
/// Applies the operation when elementwise_op require arguments and is_source_needed() is true
|
| 391 |
+
template <typename ElementwiseArgs>
|
| 392 |
+
CUTLASS_HOST_DEVICE
|
| 393 |
+
void operator()(
|
| 394 |
+
ElementZ &Z,
|
| 395 |
+
ElementT &T,
|
| 396 |
+
ElementAccumulator const &AB,
|
| 397 |
+
ElementC const &C,
|
| 398 |
+
ElementCompute const &V,
|
| 399 |
+
ElementwiseArgs const &elementwise_args) const {
|
| 400 |
+
|
| 401 |
+
ElementwiseOp elementwise_op;
|
| 402 |
+
BinaryOp binary_op;
|
| 403 |
+
|
| 404 |
+
ElementCompute tmp_Accum = NumericConverter<ElementCompute, ElementAccumulator>()(AB);
|
| 405 |
+
ElementCompute tmp_C = NumericConverter<ElementCompute, ElementC>()(C);
|
| 406 |
+
|
| 407 |
+
ElementCompute z = binary_op(alpha_ * tmp_Accum + beta_ * tmp_C, V);
|
| 408 |
+
ElementCompute result_Z = skip_elementwise_ ? z : elementwise_op(z, elementwise_args);
|
| 409 |
+
|
| 410 |
+
NumericConverter<ElementZ, ElementCompute> convert_z;
|
| 411 |
+
Z = convert_z(result_Z);
|
| 412 |
+
|
| 413 |
+
if constexpr (kStoreT) {
|
| 414 |
+
ElementCompute result_T = z;
|
| 415 |
+
NumericConverter<ElementT, ElementCompute> convert_t;
|
| 416 |
+
T = convert_t(result_T);
|
| 417 |
+
}
|
| 418 |
+
}
|
| 419 |
+
|
| 420 |
+
/// Applies the operation when elementwise_op require arguments and is_source_needed() is false
|
| 421 |
+
template <typename ElementwiseArgs>
|
| 422 |
+
CUTLASS_HOST_DEVICE
|
| 423 |
+
void operator()(
|
| 424 |
+
ElementZ &Z,
|
| 425 |
+
ElementT &T,
|
| 426 |
+
ElementAccumulator const &AB,
|
| 427 |
+
ElementCompute const &V,
|
| 428 |
+
ElementwiseArgs const &elementwise_args) const {
|
| 429 |
+
|
| 430 |
+
ElementwiseOp elementwise_op;
|
| 431 |
+
BinaryOp binary_op;
|
| 432 |
+
|
| 433 |
+
ElementCompute tmp_Accum = NumericConverter<ElementCompute, ElementAccumulator>()(AB);
|
| 434 |
+
|
| 435 |
+
ElementCompute z = binary_op(alpha_ * tmp_Accum, V);
|
| 436 |
+
ElementCompute result_Z = skip_elementwise_ ? z : elementwise_op(z, elementwise_args);
|
| 437 |
+
|
| 438 |
+
NumericConverter<ElementZ, ElementCompute> convert_z;
|
| 439 |
+
Z = convert_z(result_Z);
|
| 440 |
+
|
| 441 |
+
if constexpr (kStoreT) {
|
| 442 |
+
ElementCompute result_T = z;
|
| 443 |
+
NumericConverter<ElementT, ElementCompute> convert_t;
|
| 444 |
+
T = convert_t(result_T);
|
| 445 |
+
}
|
| 446 |
+
}
|
| 447 |
+
|
| 448 |
+
/// Applies the operation when is_source_needed() is true
|
| 449 |
+
CUTLASS_HOST_DEVICE
|
| 450 |
+
void operator()(
|
| 451 |
+
ElementZ &Z,
|
| 452 |
+
ElementT &T,
|
| 453 |
+
ElementAccumulator const &AB,
|
| 454 |
+
ElementC const &C,
|
| 455 |
+
ElementCompute const &V) const {
|
| 456 |
+
|
| 457 |
+
ElementwiseOpDispatcher elementwise_op(elementwise_);
|
| 458 |
+
BinaryOp binary_op;
|
| 459 |
+
|
| 460 |
+
ElementCompute tmp_Accum = NumericConverter<ElementCompute, ElementAccumulator>()(AB);
|
| 461 |
+
ElementCompute tmp_C = NumericConverter<ElementCompute, ElementC>()(C);
|
| 462 |
+
|
| 463 |
+
ElementCompute z = binary_op(alpha_ * tmp_Accum + beta_ * tmp_C, V);
|
| 464 |
+
ElementCompute result_Z = skip_elementwise_ ? z : elementwise_op(z);
|
| 465 |
+
|
| 466 |
+
NumericConverter<ElementZ, ElementCompute> convert_z;
|
| 467 |
+
Z = convert_z(result_Z);
|
| 468 |
+
|
| 469 |
+
if constexpr (kStoreT) {
|
| 470 |
+
ElementCompute result_T = z;
|
| 471 |
+
NumericConverter<ElementT, ElementCompute> convert_t;
|
| 472 |
+
T = convert_t(result_T);
|
| 473 |
+
}
|
| 474 |
+
}
|
| 475 |
+
|
| 476 |
+
/// Applies the operation when is_source_needed() is false
|
| 477 |
+
CUTLASS_HOST_DEVICE
|
| 478 |
+
void operator()(
|
| 479 |
+
ElementZ &Z,
|
| 480 |
+
ElementT &T,
|
| 481 |
+
ElementAccumulator const &AB,
|
| 482 |
+
ElementCompute const &V) const {
|
| 483 |
+
|
| 484 |
+
ElementwiseOpDispatcher elementwise_op(elementwise_);
|
| 485 |
+
BinaryOp binary_op;
|
| 486 |
+
|
| 487 |
+
ElementCompute tmp_Accum = NumericConverter<ElementCompute, ElementAccumulator>()(AB);
|
| 488 |
+
|
| 489 |
+
ElementCompute z = binary_op(alpha_ * tmp_Accum, V);
|
| 490 |
+
ElementCompute result_Z = skip_elementwise_ ? z : elementwise_op(z);
|
| 491 |
+
|
| 492 |
+
NumericConverter<ElementZ, ElementCompute> convert_z;
|
| 493 |
+
Z = convert_z(result_Z);
|
| 494 |
+
|
| 495 |
+
if constexpr (kStoreT) {
|
| 496 |
+
ElementCompute result_T = z;
|
| 497 |
+
NumericConverter<ElementT, ElementCompute> convert_t;
|
| 498 |
+
T = convert_t(result_T);
|
| 499 |
+
}
|
| 500 |
+
}
|
| 501 |
+
};
|
| 502 |
+
|
| 503 |
+
|
| 504 |
+
/// This base class is meant to define the concept required of the
|
| 505 |
+
/// EpilogueWithBroadcast::OutputOp
|
| 506 |
+
template <
|
| 507 |
+
typename ElementC_,
|
| 508 |
+
typename ElementAccumulator_,
|
| 509 |
+
typename ElementCompute_,
|
| 510 |
+
typename ElementZ_,
|
| 511 |
+
typename ElementT_,
|
| 512 |
+
int ElementsPerAccess,
|
| 513 |
+
typename ElementwiseOp_ = Identity<ElementCompute_>,
|
| 514 |
+
typename BinaryOp_ = plus<ElementCompute_>,
|
| 515 |
+
bool StoreT_ = true,
|
| 516 |
+
typename ElementVector_ = ElementC_
|
| 517 |
+
>
|
| 518 |
+
class LinearCombinationPerChannelScalingBiasElementwise {
|
| 519 |
+
public:
|
| 520 |
+
|
| 521 |
+
using ElementOutput = ElementC_;
|
| 522 |
+
using ElementD = ElementOutput;
|
| 523 |
+
using ElementC = ElementC_;
|
| 524 |
+
using ElementAccumulator = ElementAccumulator_;
|
| 525 |
+
using ElementCompute = ElementCompute_;
|
| 526 |
+
using ElementScalar = ElementCompute;
|
| 527 |
+
using ElementZ = ElementZ_;
|
| 528 |
+
using ElementT = ElementT_;
|
| 529 |
+
using ElementVector = ElementVector_;
|
| 530 |
+
static int const kElementsPerAccess = ElementsPerAccess;
|
| 531 |
+
static int const kCount = kElementsPerAccess;
|
| 532 |
+
|
| 533 |
+
/// Follow cutlass3x EVT aliases
|
| 534 |
+
static bool const IsEltActSupported = true;
|
| 535 |
+
static bool const IsPerChannelScalingSupported = true;
|
| 536 |
+
|
| 537 |
+
using ElementwiseOp = ElementwiseOp_;
|
| 538 |
+
using BinaryOp = BinaryOp_;
|
| 539 |
+
|
| 540 |
+
using ElementwiseOpDispatcher = detail::ElementwiseOpDispatcher<ElementwiseOp>;
|
| 541 |
+
using ElementwiseArguments = typename ElementwiseOpDispatcher::Arguments;
|
| 542 |
+
|
| 543 |
+
// Indicates that this epilogue applies only one binary operation
|
| 544 |
+
static bool const kIsSingleSource = true;
|
| 545 |
+
|
| 546 |
+
|
| 547 |
+
using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;
|
| 548 |
+
using FragmentCompute = Array<ElementCompute, kElementsPerAccess>;
|
| 549 |
+
using FragmentC = Array<ElementC, kElementsPerAccess>;
|
| 550 |
+
using FragmentZ = Array<ElementZ, kElementsPerAccess>;
|
| 551 |
+
using FragmentT = Array<ElementT, kElementsPerAccess>;
|
| 552 |
+
|
| 553 |
+
// Definitions needed for collective epilogue
|
| 554 |
+
using FragmentSource = FragmentC;
|
| 555 |
+
using FragmentOutput = FragmentZ;
|
| 556 |
+
using ElementBias = ElementVector;
|
| 557 |
+
using FragmentBias = Array<ElementBias, kElementsPerAccess>;
|
| 558 |
+
using ActivationFn = ElementwiseOp;
|
| 559 |
+
static const ScaleType::Kind kScale = ScaleType::PerChannelScaling;
|
| 560 |
+
|
| 561 |
+
static bool const kIsHeavy = kIsHeavy_member_or_false<ElementwiseOp>::value;
|
| 562 |
+
|
| 563 |
+
/// If true, the 'Z' tensor is stored
|
| 564 |
+
static bool const kStoreZ = true;
|
| 565 |
+
|
| 566 |
+
/// If true, the 'T' tensor is stored
|
| 567 |
+
static bool const kStoreT = StoreT_;
|
| 568 |
+
|
| 569 |
+
/// Host-constructable parameters structure
|
| 570 |
+
struct Params {
|
| 571 |
+
ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
|
| 572 |
+
ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
|
| 573 |
+
ElementCompute beta; ///< scales source tensor
|
| 574 |
+
ElementwiseArguments elementwise; ///< Arguments for elementwise operation
|
| 575 |
+
|
| 576 |
+
//
|
| 577 |
+
// Methods
|
| 578 |
+
//
|
| 579 |
+
|
| 580 |
+
CUTLASS_HOST_DEVICE
|
| 581 |
+
Params():
|
| 582 |
+
alpha_ptr(nullptr),
|
| 583 |
+
beta_ptr(nullptr),
|
| 584 |
+
beta(ElementCompute(0)) { }
|
| 585 |
+
|
| 586 |
+
CUTLASS_HOST_DEVICE
|
| 587 |
+
Params(
|
| 588 |
+
ElementCompute const *alpha_ptr,
|
| 589 |
+
ElementCompute const *beta_ptr,
|
| 590 |
+
ElementwiseArguments elementwise_ = ElementwiseArguments{}
|
| 591 |
+
): beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr), elementwise(elementwise_) {
|
| 592 |
+
|
| 593 |
+
}
|
| 594 |
+
|
| 595 |
+
CUTLASS_HOST_DEVICE
|
| 596 |
+
Params(
|
| 597 |
+
ElementCompute const *alpha_ptr
|
| 598 |
+
): beta(0), alpha_ptr(alpha_ptr), beta_ptr(nullptr) {
|
| 599 |
+
|
| 600 |
+
}
|
| 601 |
+
};
|
| 602 |
+
|
| 603 |
+
private:
|
| 604 |
+
|
| 605 |
+
//
|
| 606 |
+
// Data members
|
| 607 |
+
//
|
| 608 |
+
|
| 609 |
+
ElementCompute const* beta_ptr_ = nullptr;
|
| 610 |
+
ElementCompute beta_ = 0;
|
| 611 |
+
ElementwiseArguments const &elementwise_;
|
| 612 |
+
bool skip_elementwise_;
|
| 613 |
+
|
| 614 |
+
public:
|
| 615 |
+
|
| 616 |
+
//
|
| 617 |
+
// Methods
|
| 618 |
+
//
|
| 619 |
+
|
| 620 |
+
/// Constructor from Params
|
| 621 |
+
CUTLASS_HOST_DEVICE
|
| 622 |
+
LinearCombinationPerChannelScalingBiasElementwise(Params const ¶ms): elementwise_(params.elementwise) {
|
| 623 |
+
if (params.beta_ptr) {
|
| 624 |
+
beta_ptr_ = params.beta_ptr;
|
| 625 |
+
}
|
| 626 |
+
else {
|
| 627 |
+
beta_ = params.beta;
|
| 628 |
+
}
|
| 629 |
+
skip_elementwise_ = false;
|
| 630 |
+
}
|
| 631 |
+
|
| 632 |
+
/// Returns true if source is needed
|
| 633 |
+
CUTLASS_HOST_DEVICE
|
| 634 |
+
bool is_source_needed() const {
|
| 635 |
+
return beta_ptr_ != nullptr || beta_ != ElementCompute(0);
|
| 636 |
+
}
|
| 637 |
+
|
| 638 |
+
CUTLASS_HOST_DEVICE
|
| 639 |
+
bool is_beta_vector() const {
|
| 640 |
+
return beta_ptr_ != nullptr;
|
| 641 |
+
}
|
| 642 |
+
|
| 643 |
+
/// Functionally required for serial reduction in the epilogue
|
| 644 |
+
CUTLASS_HOST_DEVICE
|
| 645 |
+
void set_k_partition(int k_partition, int k_partition_count) {
|
| 646 |
+
if (k_partition) {
|
| 647 |
+
beta_ = ElementCompute(1);
|
| 648 |
+
}
|
| 649 |
+
|
| 650 |
+
if (k_partition != k_partition_count - 1) {
|
| 651 |
+
skip_elementwise_ = true;
|
| 652 |
+
}
|
| 653 |
+
}
|
| 654 |
+
|
| 655 |
+
/// Applies the operation when elementwise_op require arguments and is_source_needed() is true
|
| 656 |
+
template <typename ElementwiseArgs>
|
| 657 |
+
CUTLASS_HOST_DEVICE
|
| 658 |
+
void operator()(
|
| 659 |
+
FragmentZ &frag_Z,
|
| 660 |
+
FragmentT &frag_T,
|
| 661 |
+
FragmentAccumulator const &AB,
|
| 662 |
+
FragmentC const &frag_C,
|
| 663 |
+
FragmentCompute const & valpha,
|
| 664 |
+
FragmentCompute const & vbias,
|
| 665 |
+
ElementwiseArgs const &elementwise_args) const {
|
| 666 |
+
|
| 667 |
+
ElementwiseOp elementwise_op;
|
| 668 |
+
BinaryOp binary_op;
|
| 669 |
+
|
| 670 |
+
FragmentCompute tmp_Accum = NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess>()(AB);
|
| 671 |
+
FragmentCompute tmp_C = NumericArrayConverter<ElementCompute, ElementC, kElementsPerAccess>()(frag_C);
|
| 672 |
+
FragmentCompute result_Z;
|
| 673 |
+
FragmentCompute result_T;
|
| 674 |
+
|
| 675 |
+
CUTLASS_PRAGMA_UNROLL
|
| 676 |
+
for (int i = 0; i < kElementsPerAccess; ++i) {
|
| 677 |
+
ElementCompute z = binary_op(valpha[i] * tmp_Accum[i] + beta_ * tmp_C[i], vbias[i]);
|
| 678 |
+
result_T[i] = z;
|
| 679 |
+
result_Z[i] = skip_elementwise_ ? z : elementwise_op(z, elementwise_args);
|
| 680 |
+
}
|
| 681 |
+
|
| 682 |
+
NumericArrayConverter<ElementZ, ElementCompute, kElementsPerAccess> convert_z;
|
| 683 |
+
frag_Z = convert_z(result_Z);
|
| 684 |
+
|
| 685 |
+
if constexpr (kStoreT) {
|
| 686 |
+
NumericArrayConverter<ElementT, ElementCompute, kElementsPerAccess> convert_t;
|
| 687 |
+
frag_T = convert_t(result_T);
|
| 688 |
+
}
|
| 689 |
+
}
|
| 690 |
+
|
| 691 |
+
/// Applies the operation when elementwise_op require arguments and is_source_needed() is true
|
| 692 |
+
/// D = elementwise_op(vector_alpha * accumulator + vector_beta * source + bias)
|
| 693 |
+
template <typename ElementwiseArgs>
|
| 694 |
+
CUTLASS_HOST_DEVICE
|
| 695 |
+
void operator()(
|
| 696 |
+
FragmentZ &frag_Z,
|
| 697 |
+
FragmentT &frag_T,
|
| 698 |
+
FragmentAccumulator const &AB,
|
| 699 |
+
FragmentC const &frag_C,
|
| 700 |
+
FragmentCompute const & valpha,
|
| 701 |
+
FragmentCompute const & vbeta,
|
| 702 |
+
FragmentCompute const & vbias,
|
| 703 |
+
ElementwiseArgs const &elementwise_args) const {
|
| 704 |
+
|
| 705 |
+
ElementwiseOp elementwise_op;
|
| 706 |
+
BinaryOp binary_op;
|
| 707 |
+
|
| 708 |
+
FragmentCompute tmp_Accum = NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess>()(AB);
|
| 709 |
+
FragmentCompute tmp_C = NumericArrayConverter<ElementCompute, ElementC, kElementsPerAccess>()(frag_C);
|
| 710 |
+
FragmentCompute result_Z;
|
| 711 |
+
FragmentCompute result_T;
|
| 712 |
+
|
| 713 |
+
CUTLASS_PRAGMA_UNROLL
|
| 714 |
+
for (int i = 0; i < kElementsPerAccess; ++i) {
|
| 715 |
+
ElementCompute z = binary_op(valpha[i] * tmp_Accum[i] + vbeta[i] * tmp_C[i], vbias[i]);
|
| 716 |
+
result_T[i] = z;
|
| 717 |
+
result_Z[i] = skip_elementwise_ ? z : elementwise_op(z, elementwise_args);
|
| 718 |
+
}
|
| 719 |
+
|
| 720 |
+
NumericArrayConverter<ElementZ, ElementCompute, kElementsPerAccess> convert_z;
|
| 721 |
+
frag_Z = convert_z(result_Z);
|
| 722 |
+
|
| 723 |
+
if constexpr (kStoreT) {
|
| 724 |
+
NumericArrayConverter<ElementT, ElementCompute, kElementsPerAccess> convert_t;
|
| 725 |
+
frag_T = convert_t(result_T);
|
| 726 |
+
}
|
| 727 |
+
}
|
| 728 |
+
|
| 729 |
+
/// Applies the operation when elementwise_op require arguments and is_source_needed() is false
|
| 730 |
+
template <typename ElementwiseArgs>
|
| 731 |
+
CUTLASS_HOST_DEVICE
|
| 732 |
+
void operator()(
|
| 733 |
+
FragmentZ &frag_Z,
|
| 734 |
+
FragmentT &frag_T,
|
| 735 |
+
FragmentAccumulator const &AB,
|
| 736 |
+
FragmentCompute const & valpha,
|
| 737 |
+
FragmentCompute const & vbias,
|
| 738 |
+
ElementwiseArgs const &elementwise_args) const {
|
| 739 |
+
|
| 740 |
+
ElementwiseOp elementwise_op;
|
| 741 |
+
BinaryOp binary_op;
|
| 742 |
+
|
| 743 |
+
FragmentCompute tmp_Accum = NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess>()(AB);
|
| 744 |
+
FragmentCompute result_Z;
|
| 745 |
+
FragmentCompute result_T;
|
| 746 |
+
|
| 747 |
+
CUTLASS_PRAGMA_UNROLL
|
| 748 |
+
for (int i = 0; i < kElementsPerAccess; ++i) {
|
| 749 |
+
ElementCompute z = binary_op(valpha[i] * tmp_Accum[i], vbias[i]);
|
| 750 |
+
result_T[i] = z;
|
| 751 |
+
result_Z[i] = skip_elementwise_ ? z : elementwise_op(z, elementwise_args);
|
| 752 |
+
}
|
| 753 |
+
|
| 754 |
+
NumericArrayConverter<ElementZ, ElementCompute, kElementsPerAccess> convert_z;
|
| 755 |
+
frag_Z = convert_z(result_Z);
|
| 756 |
+
|
| 757 |
+
if constexpr (kStoreT) {
|
| 758 |
+
NumericArrayConverter<ElementT, ElementCompute, kElementsPerAccess> convert_t;
|
| 759 |
+
frag_T = convert_t(result_T);
|
| 760 |
+
}
|
| 761 |
+
}
|
| 762 |
+
|
| 763 |
+
/// Applies the operation when is_source_needed() is true
|
| 764 |
+
CUTLASS_HOST_DEVICE
|
| 765 |
+
void operator()(
|
| 766 |
+
FragmentZ &frag_Z,
|
| 767 |
+
FragmentT &frag_T,
|
| 768 |
+
FragmentAccumulator const &AB,
|
| 769 |
+
FragmentC const &frag_C,
|
| 770 |
+
FragmentCompute const & valpha,
|
| 771 |
+
FragmentCompute const & vbias) const {
|
| 772 |
+
|
| 773 |
+
ElementwiseOpDispatcher elementwise_op(elementwise_);
|
| 774 |
+
BinaryOp binary_op;
|
| 775 |
+
|
| 776 |
+
FragmentCompute tmp_Accum = NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess>()(AB);
|
| 777 |
+
FragmentCompute tmp_C = NumericArrayConverter<ElementCompute, ElementC, kElementsPerAccess>()(frag_C);
|
| 778 |
+
FragmentCompute result_Z;
|
| 779 |
+
FragmentCompute result_T;
|
| 780 |
+
|
| 781 |
+
CUTLASS_PRAGMA_UNROLL
|
| 782 |
+
for (int i = 0; i < kElementsPerAccess; ++i) {
|
| 783 |
+
ElementCompute z = binary_op(valpha[i] * tmp_Accum[i] + beta_ * tmp_C[i], vbias[i]);
|
| 784 |
+
result_T[i] = z;
|
| 785 |
+
result_Z[i] = skip_elementwise_ ? z : elementwise_op(z);
|
| 786 |
+
}
|
| 787 |
+
|
| 788 |
+
NumericArrayConverter<ElementZ, ElementCompute, kElementsPerAccess> convert_z;
|
| 789 |
+
frag_Z = convert_z(result_Z);
|
| 790 |
+
|
| 791 |
+
if constexpr (kStoreT) {
|
| 792 |
+
NumericArrayConverter<ElementT, ElementCompute, kElementsPerAccess> convert_t;
|
| 793 |
+
frag_T = convert_t(result_T);
|
| 794 |
+
}
|
| 795 |
+
}
|
| 796 |
+
|
| 797 |
+
/// Applies the operation when is_source_needed() is false
|
| 798 |
+
CUTLASS_HOST_DEVICE
|
| 799 |
+
void operator()(
|
| 800 |
+
FragmentZ &frag_Z,
|
| 801 |
+
FragmentT &frag_T,
|
| 802 |
+
FragmentAccumulator const &AB,
|
| 803 |
+
FragmentCompute const & valpha,
|
| 804 |
+
FragmentCompute const & vbias) const {
|
| 805 |
+
|
| 806 |
+
ElementwiseOpDispatcher elementwise_op(elementwise_);
|
| 807 |
+
BinaryOp binary_op;
|
| 808 |
+
|
| 809 |
+
FragmentCompute tmp_Accum = NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess>()(AB);
|
| 810 |
+
FragmentCompute result_Z;
|
| 811 |
+
FragmentCompute result_T;
|
| 812 |
+
|
| 813 |
+
CUTLASS_PRAGMA_UNROLL
|
| 814 |
+
for (int i = 0; i < kElementsPerAccess; ++i) {
|
| 815 |
+
ElementCompute z = binary_op(valpha[i] * tmp_Accum[i], vbias[i]);
|
| 816 |
+
result_T[i] = z;
|
| 817 |
+
result_Z[i] = skip_elementwise_ ? z : elementwise_op(z);
|
| 818 |
+
}
|
| 819 |
+
|
| 820 |
+
NumericArrayConverter<ElementZ, ElementCompute, kElementsPerAccess> convert_z;
|
| 821 |
+
frag_Z = convert_z(result_Z);
|
| 822 |
+
|
| 823 |
+
if constexpr (kStoreT) {
|
| 824 |
+
NumericArrayConverter<ElementT, ElementCompute, kElementsPerAccess> convert_t;
|
| 825 |
+
frag_T = convert_t(result_T);
|
| 826 |
+
}
|
| 827 |
+
}
|
| 828 |
+
|
| 829 |
+
/// Applies the operation when elementwise_op require arguments and is_source_needed() is true
|
| 830 |
+
template <typename ElementwiseArgs>
|
| 831 |
+
CUTLASS_HOST_DEVICE
|
| 832 |
+
void operator()(
|
| 833 |
+
ElementZ &Z,
|
| 834 |
+
ElementT &T,
|
| 835 |
+
ElementAccumulator const &AB,
|
| 836 |
+
ElementC const &C,
|
| 837 |
+
ElementCompute const & valpha,
|
| 838 |
+
ElementCompute const & vbias,
|
| 839 |
+
ElementwiseArgs const &elementwise_args) const {
|
| 840 |
+
|
| 841 |
+
ElementwiseOp elementwise_op;
|
| 842 |
+
BinaryOp binary_op;
|
| 843 |
+
|
| 844 |
+
ElementCompute tmp_Accum = NumericConverter<ElementCompute, ElementAccumulator>()(AB);
|
| 845 |
+
ElementCompute tmp_C = NumericConverter<ElementCompute, ElementC>()(C);
|
| 846 |
+
|
| 847 |
+
ElementCompute z = binary_op(valpha * tmp_Accum + beta_ * tmp_C, vbias);
|
| 848 |
+
ElementCompute result_Z = skip_elementwise_ ? z : elementwise_op(z, elementwise_args);
|
| 849 |
+
|
| 850 |
+
NumericConverter<ElementZ, ElementCompute> convert_z;
|
| 851 |
+
Z = convert_z(result_Z);
|
| 852 |
+
|
| 853 |
+
if constexpr (kStoreT) {
|
| 854 |
+
ElementCompute result_T = z;
|
| 855 |
+
NumericConverter<ElementT, ElementCompute> convert_t;
|
| 856 |
+
T = convert_t(result_T);
|
| 857 |
+
}
|
| 858 |
+
}
|
| 859 |
+
|
| 860 |
+
/// Applies the operation when elementwise_op require arguments and is_source_needed() is true
|
| 861 |
+
/// D = elementwise_op(vector_alpha * accumulator + vector_beta * source + bias)
|
| 862 |
+
template <typename ElementwiseArgs>
|
| 863 |
+
CUTLASS_HOST_DEVICE
|
| 864 |
+
void operator()(
|
| 865 |
+
ElementZ &Z,
|
| 866 |
+
ElementT &T,
|
| 867 |
+
ElementAccumulator const &AB,
|
| 868 |
+
ElementC const &C,
|
| 869 |
+
ElementCompute const & valpha,
|
| 870 |
+
ElementCompute const & vbeta,
|
| 871 |
+
ElementCompute const & vbias,
|
| 872 |
+
ElementwiseArgs const &elementwise_args) const {
|
| 873 |
+
|
| 874 |
+
ElementwiseOp elementwise_op;
|
| 875 |
+
BinaryOp binary_op;
|
| 876 |
+
|
| 877 |
+
ElementCompute tmp_Accum = NumericConverter<ElementCompute, ElementAccumulator>()(AB);
|
| 878 |
+
ElementCompute tmp_C = NumericConverter<ElementCompute, ElementC>()(C);
|
| 879 |
+
|
| 880 |
+
ElementCompute z = binary_op(valpha * tmp_Accum + vbeta * tmp_C, vbias);
|
| 881 |
+
ElementCompute result_Z = skip_elementwise_ ? z : elementwise_op(z, elementwise_args);
|
| 882 |
+
|
| 883 |
+
NumericConverter<ElementZ, ElementCompute> convert_z;
|
| 884 |
+
Z = convert_z(result_Z);
|
| 885 |
+
|
| 886 |
+
if constexpr (kStoreT) {
|
| 887 |
+
ElementCompute result_T = z;
|
| 888 |
+
NumericConverter<ElementT, ElementCompute> convert_t;
|
| 889 |
+
T = convert_t(result_T);
|
| 890 |
+
}
|
| 891 |
+
}
|
| 892 |
+
|
| 893 |
+
/// Applies the operation when elementwise_op require arguments and is_source_needed() is false
|
| 894 |
+
template <typename ElementwiseArgs>
|
| 895 |
+
CUTLASS_HOST_DEVICE
|
| 896 |
+
void operator()(
|
| 897 |
+
ElementZ &Z,
|
| 898 |
+
ElementT &T,
|
| 899 |
+
ElementAccumulator const &AB,
|
| 900 |
+
ElementCompute const & valpha,
|
| 901 |
+
ElementCompute const & vbias,
|
| 902 |
+
ElementwiseArgs const &elementwise_args) const {
|
| 903 |
+
|
| 904 |
+
ElementwiseOp elementwise_op;
|
| 905 |
+
BinaryOp binary_op;
|
| 906 |
+
|
| 907 |
+
ElementCompute tmp_Accum = NumericConverter<ElementCompute, ElementAccumulator>()(AB);
|
| 908 |
+
|
| 909 |
+
ElementCompute z = binary_op(valpha * tmp_Accum, vbias);
|
| 910 |
+
ElementCompute result_Z = skip_elementwise_ ? z : elementwise_op(z, elementwise_args);
|
| 911 |
+
|
| 912 |
+
NumericConverter<ElementZ, ElementCompute> convert_z;
|
| 913 |
+
Z = convert_z(result_Z);
|
| 914 |
+
|
| 915 |
+
if constexpr (kStoreT) {
|
| 916 |
+
ElementCompute result_T = z;
|
| 917 |
+
NumericConverter<ElementT, ElementCompute> convert_t;
|
| 918 |
+
T = convert_t(result_T);
|
| 919 |
+
}
|
| 920 |
+
}
|
| 921 |
+
|
| 922 |
+
/// Applies the operation when is_source_needed() is true
|
| 923 |
+
CUTLASS_HOST_DEVICE
|
| 924 |
+
void operator()(
|
| 925 |
+
ElementZ &Z,
|
| 926 |
+
ElementT &T,
|
| 927 |
+
ElementAccumulator const &AB,
|
| 928 |
+
ElementC const &C,
|
| 929 |
+
ElementCompute const & valpha,
|
| 930 |
+
ElementCompute const & vbias) const {
|
| 931 |
+
|
| 932 |
+
ElementwiseOpDispatcher elementwise_op(elementwise_);
|
| 933 |
+
BinaryOp binary_op;
|
| 934 |
+
|
| 935 |
+
ElementCompute tmp_Accum = NumericConverter<ElementCompute, ElementAccumulator>()(AB);
|
| 936 |
+
ElementCompute tmp_C = NumericConverter<ElementCompute, ElementC>()(C);
|
| 937 |
+
|
| 938 |
+
ElementCompute z = binary_op(valpha * tmp_Accum + beta_ * tmp_C, vbias);
|
| 939 |
+
ElementCompute result_Z = skip_elementwise_ ? z : elementwise_op(z);
|
| 940 |
+
|
| 941 |
+
NumericConverter<ElementZ, ElementCompute> convert_z;
|
| 942 |
+
Z = convert_z(result_Z);
|
| 943 |
+
|
| 944 |
+
if constexpr (kStoreT) {
|
| 945 |
+
ElementCompute result_T = z;
|
| 946 |
+
NumericConverter<ElementT, ElementCompute> convert_t;
|
| 947 |
+
T = convert_t(result_T);
|
| 948 |
+
}
|
| 949 |
+
}
|
| 950 |
+
|
| 951 |
+
/// Applies the operation when is_source_needed() is false
|
| 952 |
+
CUTLASS_HOST_DEVICE
|
| 953 |
+
void operator()(
|
| 954 |
+
ElementZ &Z,
|
| 955 |
+
ElementT &T,
|
| 956 |
+
ElementAccumulator const &AB,
|
| 957 |
+
ElementCompute const & valpha,
|
| 958 |
+
ElementCompute const & vbias) const {
|
| 959 |
+
|
| 960 |
+
ElementwiseOpDispatcher elementwise_op(elementwise_);
|
| 961 |
+
BinaryOp binary_op;
|
| 962 |
+
|
| 963 |
+
ElementCompute tmp_Accum = NumericConverter<ElementCompute, ElementAccumulator>()(AB);
|
| 964 |
+
|
| 965 |
+
ElementCompute z = binary_op(valpha * tmp_Accum, vbias);
|
| 966 |
+
ElementCompute result_Z = skip_elementwise_ ? z : elementwise_op(z);
|
| 967 |
+
|
| 968 |
+
NumericConverter<ElementZ, ElementCompute> convert_z;
|
| 969 |
+
Z = convert_z(result_Z);
|
| 970 |
+
|
| 971 |
+
if constexpr (kStoreT) {
|
| 972 |
+
ElementCompute result_T = z;
|
| 973 |
+
NumericConverter<ElementT, ElementCompute> convert_t;
|
| 974 |
+
T = convert_t(result_T);
|
| 975 |
+
}
|
| 976 |
+
}
|
| 977 |
+
};
|
| 978 |
+
|
| 979 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 980 |
+
|
| 981 |
+
} // namespace thread
|
| 982 |
+
} // namespace epilogue
|
| 983 |
+
} // namespace cutlass
|
| 984 |
+
|
| 985 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_bias_relu.h
ADDED
|
@@ -0,0 +1,610 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Functor performing linear combination operations used by epilogues.
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include <cuda_fp16.h>
|
| 38 |
+
|
| 39 |
+
#include "cutlass/cutlass.h"
|
| 40 |
+
#include "cutlass/numeric_types.h"
|
| 41 |
+
#include "cutlass/array.h"
|
| 42 |
+
#include "cutlass/functional.h"
|
| 43 |
+
#include "cutlass/numeric_conversion.h"
|
| 44 |
+
#include "cutlass/epilogue/thread/activation.h"
|
| 45 |
+
|
| 46 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 47 |
+
|
| 48 |
+
namespace cutlass {
|
| 49 |
+
namespace epilogue {
|
| 50 |
+
namespace thread {
|
| 51 |
+
|
| 52 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 53 |
+
|
| 54 |
+
namespace detail {
|
| 55 |
+
|
| 56 |
+
/// Elementwise maximum over cutlass::Array operands (generic fallback used by
/// the ReLU-style epilogues in this header; float/half specializations follow).
template <typename Element, int ElementsPerAccess>
struct ArrayMaximum {

  /// Returns the per-element maximum of two arrays.
  ///
  /// NOTE(review): `lhs[i].get()` unwraps Array's reference proxy to a plain
  /// Element while `rhs[i]` is passed through the proxy directly — presumably
  /// platform::max's deduction handles the mixed forms; confirm against
  /// cutlass::Array's reference semantics (esp. for sub-byte Element types).
  CUTLASS_HOST_DEVICE
  Array<Element, ElementsPerAccess> operator()(
    Array<Element, ElementsPerAccess> const &lhs,
    Array<Element, ElementsPerAccess> const &rhs) const {

    Array<Element, ElementsPerAccess> result;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < ElementsPerAccess; ++i) {
      result[i] = platform::max(lhs[i].get(), rhs[i]);
    }

    return result;
  }

  /// Returns the per-element maximum of an array and a broadcast scalar
  /// (e.g. rhs = 0 implements ReLU clamping).
  CUTLASS_HOST_DEVICE
  Array<Element, ElementsPerAccess> operator()(
    Array<Element, ElementsPerAccess> const &lhs,
    Element rhs) const {

    Array<Element, ElementsPerAccess> result;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < ElementsPerAccess; ++i) {
      result[i] = platform::max(lhs[i].get(), rhs);
    }

    return result;
  }
};
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
/// Partial specialization: Element=float
|
| 92 |
+
template <int ElementsPerAccess>
|
| 93 |
+
struct ArrayMaximum<float, ElementsPerAccess> {
|
| 94 |
+
|
| 95 |
+
CUTLASS_HOST_DEVICE
|
| 96 |
+
Array<float, ElementsPerAccess> operator()(
|
| 97 |
+
Array<float, ElementsPerAccess> const &lhs,
|
| 98 |
+
Array<float, ElementsPerAccess> const &rhs) const {
|
| 99 |
+
|
| 100 |
+
Array<float, ElementsPerAccess> result;
|
| 101 |
+
|
| 102 |
+
CUTLASS_PRAGMA_UNROLL
|
| 103 |
+
for (int i = 0; i < ElementsPerAccess; ++i) {
|
| 104 |
+
result[i] = fmax(lhs[i], rhs[i]);
|
| 105 |
+
}
|
| 106 |
+
|
| 107 |
+
return result;
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
CUTLASS_HOST_DEVICE
|
| 111 |
+
Array<float, ElementsPerAccess> operator()(
|
| 112 |
+
Array<float, ElementsPerAccess> const &lhs,
|
| 113 |
+
float rhs) const {
|
| 114 |
+
|
| 115 |
+
Array<float, ElementsPerAccess> result;
|
| 116 |
+
|
| 117 |
+
CUTLASS_PRAGMA_UNROLL
|
| 118 |
+
for (int i = 0; i < ElementsPerAccess; ++i) {
|
| 119 |
+
result[i] = fmax(lhs[i], rhs);
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
return result;
|
| 123 |
+
}
|
| 124 |
+
};
|
| 125 |
+
|
| 126 |
+
/// Partial specialization: Element=half
///
/// Computes an elementwise maximum over packed half-precision arrays. On SM80+
/// the vectorized __hmax2 intrinsic processes two elements per instruction;
/// earlier architectures fall back to a scalar compare-and-select loop.
template <int ElementsPerAccess>
struct ArrayMaximum<half_t, ElementsPerAccess> {

  /// Elementwise maximum of two arrays: result[i] = max(lhs[i], rhs[i]).
  CUTLASS_DEVICE
  Array<half_t, ElementsPerAccess> operator()(
    Array<half_t, ElementsPerAccess> const &lhs,
    Array<half_t, ElementsPerAccess> const &rhs) const {

    Array<half_t, ElementsPerAccess> result;

#if __CUDA_ARCH__ >= 800
    // SM80+: reinterpret the packed storage as __half2 pairs and use the
    // two-wide max intrinsic.
    int const kVectorCount = ElementsPerAccess / 2;

    __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(lhs.raw_data());
    __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(rhs.raw_data());
    __half2 *res_ptr = reinterpret_cast<__half2 *>(result.raw_data());

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kVectorCount; ++i) {
      res_ptr[i] = __hmax2(lhs_ptr[i], rhs_ptr[i]);
    }

    // The vectorized path covers elements in pairs; reject odd lengths at
    // compile time so no tail element is silently dropped.
    static_assert(!(ElementsPerAccess % 2), "Output array must be divisible by vector length.");

#else
    // Pre-SM80: no native half max instruction; compare scalars and select.
    __half const *lhs_ptr = reinterpret_cast<__half const *>(lhs.raw_data());
    __half const *rhs_ptr = reinterpret_cast<__half const *>(rhs.raw_data());
    __half *res_ptr = reinterpret_cast<__half *>(result.raw_data());

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < ElementsPerAccess; ++i) {
      res_ptr[i] = ((lhs_ptr[i] < rhs_ptr[i]) ? rhs_ptr[i] : lhs_ptr[i]);
    }

#endif

    return result;
  }

  /// Elementwise maximum of an array against a broadcast scalar:
  /// result[i] = max(lhs[i], rhs).
  CUTLASS_DEVICE
  Array<half_t, ElementsPerAccess> operator()(
    Array<half_t, ElementsPerAccess> const &lhs,
    half_t const &rhs) const {

    Array<half_t, ElementsPerAccess> result;

#if __CUDA_ARCH__ >= 800
    int const kVectorCount = ElementsPerAccess / 2;

    // Broadcast the scalar into both lanes of a __half2 so the vectorized
    // intrinsic can be used.
    __half rhs_raw = reinterpret_cast<__half const &>(rhs);
    __half2 rhs_pair = __half2half2(rhs_raw);

    __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(lhs.raw_data());
    __half2 *res_ptr = reinterpret_cast<__half2 *>(result.raw_data());

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kVectorCount; ++i) {
      res_ptr[i] = __hmax2(lhs_ptr[i], rhs_pair);
    }

    static_assert(!(ElementsPerAccess % 2), "Output array must be divisible by vector length.");

#else

    // Pre-SM80 scalar fallback.
    __half const *lhs_ptr = reinterpret_cast<__half const *>(lhs.raw_data());
    __half const rhs_raw = reinterpret_cast<__half const &>(rhs);
    __half *res_ptr = reinterpret_cast<__half *>(result.raw_data());

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < ElementsPerAccess; ++i) {
      res_ptr[i] = ((lhs_ptr[i] < rhs_raw) ? rhs_raw : lhs_ptr[i]);
    }

#endif

    return result;
  }
};
|
| 207 |
+
|
| 208 |
+
/// Partial specialization: Element=bfloat16_t
|
| 209 |
+
template <int ElementsPerAccess>
|
| 210 |
+
struct ArrayMaximum<bfloat16_t, ElementsPerAccess> {
|
| 211 |
+
|
| 212 |
+
using NvType = __nv_bfloat16;
|
| 213 |
+
using NvTypeV2 = __nv_bfloat162;
|
| 214 |
+
|
| 215 |
+
CUTLASS_DEVICE
|
| 216 |
+
Array<bfloat16_t, ElementsPerAccess> operator()(
|
| 217 |
+
Array<bfloat16_t, ElementsPerAccess> const &lhs,
|
| 218 |
+
Array<bfloat16_t, ElementsPerAccess> const &rhs) const {
|
| 219 |
+
|
| 220 |
+
Array<bfloat16_t, ElementsPerAccess> result;
|
| 221 |
+
|
| 222 |
+
#if __CUDA_ARCH__ >= 800
|
| 223 |
+
int const kVectorCount = ElementsPerAccess / 2;
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
NvTypeV2 const *lhs_ptr = reinterpret_cast<NvTypeV2 const *>(lhs.raw_data());
|
| 227 |
+
NvTypeV2 const *rhs_ptr = reinterpret_cast<NvTypeV2 const *>(rhs.raw_data());
|
| 228 |
+
NvTypeV2 *res_ptr = reinterpret_cast<NvTypeV2 *>(result.raw_data());
|
| 229 |
+
|
| 230 |
+
CUTLASS_PRAGMA_UNROLL
|
| 231 |
+
for (int i = 0; i < kVectorCount; ++i) {
|
| 232 |
+
res_ptr[i] = __hmax2(lhs_ptr[i], rhs_ptr[i]);
|
| 233 |
+
}
|
| 234 |
+
|
| 235 |
+
#else
|
| 236 |
+
NvType const *lhs_ptr = reinterpret_cast<NvType const *>(lhs.raw_data());
|
| 237 |
+
NvType const *rhs_ptr = reinterpret_cast<NvType const *>(rhs.raw_data());
|
| 238 |
+
NvType *res_ptr = reinterpret_cast<NvType *>(result.raw_data());
|
| 239 |
+
|
| 240 |
+
CUTLASS_PRAGMA_UNROLL
|
| 241 |
+
for (int i = 0; i < ElementsPerAccess; ++i) {
|
| 242 |
+
res_ptr[i] = ((lhs_ptr[i] < rhs_ptr[i]) ? rhs_ptr[i] : lhs_ptr[i]);
|
| 243 |
+
}
|
| 244 |
+
|
| 245 |
+
#endif
|
| 246 |
+
|
| 247 |
+
return result;
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
CUTLASS_DEVICE
|
| 251 |
+
Array<bfloat16_t, ElementsPerAccess> operator()(
|
| 252 |
+
Array<bfloat16_t, ElementsPerAccess> const &lhs,
|
| 253 |
+
bfloat16_t rhs) const {
|
| 254 |
+
|
| 255 |
+
Array<bfloat16_t, ElementsPerAccess> result;
|
| 256 |
+
|
| 257 |
+
#if __CUDA_ARCH__ >= 800
|
| 258 |
+
int const kVectorCount = ElementsPerAccess / 2;
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
NvType rhs_raw = reinterpret_cast<NvType const &>(rhs);
|
| 262 |
+
NvTypeV2 rhs_pair = __bfloat162bfloat162(rhs_raw);
|
| 263 |
+
|
| 264 |
+
NvTypeV2 const *lhs_ptr = reinterpret_cast<NvTypeV2 const *>(lhs.raw_data());
|
| 265 |
+
NvTypeV2 *res_ptr = reinterpret_cast<NvTypeV2 *>(result.raw_data());
|
| 266 |
+
|
| 267 |
+
CUTLASS_PRAGMA_UNROLL
|
| 268 |
+
for (int i = 0; i < kVectorCount; ++i) {
|
| 269 |
+
res_ptr[i] = __hmax2(lhs_ptr[i], rhs_pair);
|
| 270 |
+
}
|
| 271 |
+
|
| 272 |
+
static_assert(!(ElementsPerAccess % 2), "Output array must be divisible by vector length.");
|
| 273 |
+
|
| 274 |
+
#else
|
| 275 |
+
|
| 276 |
+
NvType const *lhs_ptr = reinterpret_cast<NvType const *>(lhs.raw_data());
|
| 277 |
+
NvType const rhs_raw = reinterpret_cast<NvType const &>(rhs);
|
| 278 |
+
NvType *res_ptr = reinterpret_cast<NvType *>(result.raw_data());
|
| 279 |
+
|
| 280 |
+
CUTLASS_PRAGMA_UNROLL
|
| 281 |
+
for (int i = 0; i < ElementsPerAccess; ++i) {
|
| 282 |
+
res_ptr[i] = ((lhs_ptr[i] < rhs_raw) ? rhs_raw : lhs_ptr[i]);
|
| 283 |
+
}
|
| 284 |
+
|
| 285 |
+
#endif
|
| 286 |
+
|
| 287 |
+
return result;
|
| 288 |
+
}
|
| 289 |
+
};
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 293 |
+
|
| 294 |
+
template <typename Element, int ElementsPerAccess>
|
| 295 |
+
struct ReluConditional {
|
| 296 |
+
|
| 297 |
+
CUTLASS_HOST_DEVICE
|
| 298 |
+
void operator()(
|
| 299 |
+
bool conditional[],
|
| 300 |
+
Array<Element, ElementsPerAccess> const &fragment,
|
| 301 |
+
Element threshold) const {
|
| 302 |
+
|
| 303 |
+
CUTLASS_PRAGMA_UNROLL
|
| 304 |
+
for (int i = 0; i < ElementsPerAccess; ++i) {
|
| 305 |
+
conditional[i] = !(fragment[i] < threshold);
|
| 306 |
+
}
|
| 307 |
+
}
|
| 308 |
+
};
|
| 309 |
+
|
| 310 |
+
template <int ElementsPerAccess>
|
| 311 |
+
struct ReluConditional<half_t, ElementsPerAccess> {
|
| 312 |
+
|
| 313 |
+
CUTLASS_DEVICE
|
| 314 |
+
void operator()(
|
| 315 |
+
bool conditional[],
|
| 316 |
+
Array<half_t, ElementsPerAccess> const &fragment,
|
| 317 |
+
half_t threshold) const {
|
| 318 |
+
|
| 319 |
+
__half y = reinterpret_cast<__half const &>(threshold);
|
| 320 |
+
__half const *x = reinterpret_cast<__half const *>(fragment.raw_data());
|
| 321 |
+
|
| 322 |
+
CUTLASS_PRAGMA_UNROLL
|
| 323 |
+
for (int i = 0; i < ElementsPerAccess; ++i) {
|
| 324 |
+
conditional[i] = !__hlt(x[i], y);
|
| 325 |
+
}
|
| 326 |
+
}
|
| 327 |
+
};
|
| 328 |
+
|
| 329 |
+
template <int ElementsPerAccess>
|
| 330 |
+
struct ReluConditional<bfloat16_t, ElementsPerAccess> {
|
| 331 |
+
|
| 332 |
+
CUTLASS_DEVICE
|
| 333 |
+
void operator()(
|
| 334 |
+
bool conditional[],
|
| 335 |
+
Array<bfloat16_t, ElementsPerAccess> const &fragment,
|
| 336 |
+
bfloat16_t threshold) const {
|
| 337 |
+
|
| 338 |
+
__nv_bfloat16 y = reinterpret_cast<__nv_bfloat16 const &>(threshold);
|
| 339 |
+
__nv_bfloat16 const *x = reinterpret_cast<__nv_bfloat16 const *>(fragment.raw_data());
|
| 340 |
+
|
| 341 |
+
CUTLASS_PRAGMA_UNROLL
|
| 342 |
+
for (int i = 0; i < ElementsPerAccess; ++i) {
|
| 343 |
+
conditional[i] = !__hlt(x[i], y);
|
| 344 |
+
}
|
| 345 |
+
}
|
| 346 |
+
};
|
| 347 |
+
|
| 348 |
+
} // namespace detail
|
| 349 |
+
|
| 350 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 351 |
+
|
| 352 |
+
/// This is a partial specialization for fused Bias and ReLU. It supports the option of packing
/// ReLU conditionals in a bit vector that may be used by backwards passes as an optimization.
///
/// This class can only be used with cutlass::epilogue::threadblock::EpilogueWithBroadcast<>.
///
/// This base class is meant to define the concept required of the
/// EpilogueWithBroadcast::OutputOp
///
/// Computes, per element: Z = ReLU(alpha * AB + beta * C + V, threshold), and
/// optionally packs the per-element "kept" predicates into the T fragment.
template <
  typename ElementC_,
  typename ElementAccumulator_,
  typename ElementCompute_,
  typename ElementZ_,
  int ElementsPerAccess,
  bool StoreT_ = true,
  typename ElementVector_ = ElementC_
>
class LinearCombinationBiasRelu {
public:

  using ElementOutput = ElementC_;
  using ElementC = ElementC_;
  using ElementAccumulator = ElementAccumulator_;
  using ElementCompute = ElementCompute_;
  using ElementZ = ElementZ_;
  using ElementVector = ElementVector_;

  // T holds one predicate bit per element (packed).
  using ElementT = uint1b_t;

  static int const kElementsPerAccess = ElementsPerAccess;
  static int const kCount = kElementsPerAccess;

  using ElementwiseOp = ReLu<ElementCompute>;
  using BinaryOp = plus<ElementCompute>;

  // Indicates that this epilogue applies only one binary operation
  static bool const kIsSingleSource = true;

  using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;
  using FragmentCompute = Array<ElementCompute, kElementsPerAccess>;
  using FragmentC = Array<ElementOutput, kElementsPerAccess>;
  using FragmentZ = Array<ElementZ, kElementsPerAccess>;
  using FragmentT = Array<ElementT, kElementsPerAccess>;

  /// If true, the 'Z' tensor is stored
  static bool const kStoreZ = true;

  /// If true, the 'T' tensor is stored
  static bool const kStoreT = StoreT_;

  /// Host-constructable parameters structure
  struct Params {

    ElementCompute alpha;            ///< scales accumulators
    ElementCompute beta;             ///< scales source tensor
    ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
    ElementCompute const *beta_ptr;  ///< pointer to source scalar - if not null, loads it from memory
    ElementZ threshold;              ///< ReLu threshold

    //
    // Methods
    //

    CUTLASS_HOST_DEVICE
    Params():
      alpha(ElementCompute(1)),
      beta(ElementCompute()),
      alpha_ptr(nullptr),
      beta_ptr(nullptr),
      threshold(ElementCompute()) { }

    CUTLASS_HOST_DEVICE
    Params(
      ElementCompute alpha,
      ElementCompute beta,
      ElementCompute threshold_ = ElementCompute()
    ):
      alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) {

      // The threshold is stored in the output element type Z, so convert it
      // once at construction time.
      NumericConverter<ElementZ, ElementCompute> convert_threshold;

      threshold = convert_threshold(threshold_);
    }

    CUTLASS_HOST_DEVICE
    Params(
      ElementCompute alpha
    ): alpha(alpha), beta(0), alpha_ptr(nullptr), beta_ptr(nullptr), threshold(ElementZ()) {

    }

    CUTLASS_HOST_DEVICE
    Params(
      ElementCompute const *alpha_ptr,
      ElementCompute const *beta_ptr,
      ElementCompute threshold_ = ElementCompute()
    ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {

      NumericConverter<ElementZ, ElementCompute> convert_threshold;

      threshold = convert_threshold(threshold_);
    }

    CUTLASS_HOST_DEVICE
    Params(
      ElementCompute const *alpha_ptr
    ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(nullptr), threshold(ElementZ()) {
    }

  };

private:

  //
  // Data members
  //

  ElementCompute alpha_;
  ElementCompute beta_;
  ElementZ threshold_;

public:

  //
  // Methods
  //

  /// Constructor from Params
  CUTLASS_HOST_DEVICE
  LinearCombinationBiasRelu(Params const &params) {

    // Pointer-provided scalars take precedence over immediate values.
    alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
    beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
    threshold_ = params.threshold;
  }

  /// Returns true if source is needed
  CUTLASS_HOST_DEVICE
  bool is_source_needed() const {
    return beta_ != ElementCompute(0);
  }

  /// Functionally required for serial reduction in the epilogue
  CUTLASS_HOST_DEVICE
  void set_k_partition(int k_partition, int k_partition_count) {
    // Every partition after the first accumulates onto the previous partial
    // result, so beta is forced to 1.
    if (k_partition) {
      beta_ = ElementCompute(1);
    }

    if (k_partition != k_partition_count - 1) {
      // set to NaN to make ReLU no-op for all except last k partitions
      // (an all-ones bit pattern is NaN for IEEE float formats; comparisons
      // against NaN are false, so the clamp never triggers).
      int64_t allones = -1;
      threshold_ = reinterpret_cast<ElementZ const &>(allones);
    }
  }

  /// Applies the operation when is_source_needed() is true
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentZ &frag_Z,
    FragmentT &frag_T,
    FragmentAccumulator const &AB,
    FragmentC const &frag_C,
    FragmentCompute const &V) const {

    BinaryOp binary_op;

    // Convert accumulator and source fragments to the compute type.
    FragmentCompute tmp_Accum = NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess>()(AB);
    FragmentCompute tmp_C = NumericArrayConverter<ElementCompute, ElementC, kElementsPerAccess>()(frag_C);
    FragmentCompute result_Z;

    bool conditions[kElementsPerAccess];

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kElementsPerAccess; ++i) {

      // z = alpha * AB + beta * C + V
      ElementCompute z = alpha_ * tmp_Accum[i];
      z += beta_ * tmp_C[i];

      z = binary_op(z, V[i]);
      result_Z[i] = z;
    }

    NumericArrayConverter<ElementZ, ElementCompute, kElementsPerAccess> convert_z;
    frag_Z = convert_z(result_Z);

    //
    // Compute condition
    //

    // Predicates are computed BEFORE clamping, in the output type Z.
    detail::ReluConditional<ElementZ, kElementsPerAccess> relu_conditional;
    relu_conditional(conditions, frag_Z, threshold_);

    detail::ArrayMaximum<ElementZ, kElementsPerAccess> maximum_op;
    frag_Z = maximum_op(frag_Z, threshold_);

    if (kStoreT) {
      PackPredicates<kElementsPerAccess> pack_predicates;
      frag_T = pack_predicates(conditions);
    }
  }

  /// Applies the operation when is_source_needed() is false
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentZ &frag_Z,
    FragmentT &frag_T,
    FragmentAccumulator const &AB,
    FragmentCompute const &V) const {

    BinaryOp binary_op;

    FragmentCompute tmp_Accum = NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess>()(AB);
    FragmentCompute result_Z;

    bool conditions[kElementsPerAccess];

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kElementsPerAccess; ++i) {
      // z = alpha * AB + V  (no source term in this path)
      ElementCompute z = binary_op(alpha_ * tmp_Accum[i], V[i]);
      result_Z[i] = z;
    }

    NumericArrayConverter<ElementZ, ElementCompute, kElementsPerAccess> convert_z;
    frag_Z = convert_z(result_Z);

    //
    // Compute condition
    //

    detail::ReluConditional<ElementZ, kElementsPerAccess> relu_conditional;
    relu_conditional(conditions, frag_Z, threshold_);

    detail::ArrayMaximum<ElementZ, kElementsPerAccess> maximum_op;
    frag_Z = maximum_op(frag_Z, threshold_);

    //
    // Store
    //
    if (kStoreT) {
      PackPredicates<kElementsPerAccess> pack_predicates;
      frag_T = pack_predicates(conditions);
    }
  }
};
|
| 603 |
+
|
| 604 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 605 |
+
|
| 606 |
+
} // namespace thread
|
| 607 |
+
} // namespace epilogue
|
| 608 |
+
} // namespace cutlass
|
| 609 |
+
|
| 610 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_clamp.h
ADDED
|
@@ -0,0 +1,684 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Functor performing linear scaling operations used by epilogues. Values are clamped before
|
| 33 |
+
converting to the output element type.
|
| 34 |
+
*/
|
| 35 |
+
|
| 36 |
+
#pragma once
|
| 37 |
+
|
| 38 |
+
#include "cutlass/cutlass.h"
|
| 39 |
+
#include "cutlass/numeric_types.h"
|
| 40 |
+
#include "cutlass/array.h"
|
| 41 |
+
#include "cutlass/functional.h"
|
| 42 |
+
#include "cutlass/numeric_conversion.h"
|
| 43 |
+
#include "cutlass/epilogue/thread/scale_type.h"
|
| 44 |
+
|
| 45 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 46 |
+
|
| 47 |
+
namespace cutlass {
|
| 48 |
+
namespace epilogue {
|
| 49 |
+
namespace thread {
|
| 50 |
+
|
| 51 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 52 |
+
|
| 53 |
+
namespace detail {

/// Single source of truth for whether to unroll for `LinearCombinationClamp()`.
/// Returning false marks the functor as lightweight, so callers may fully
/// unroll around it.
constexpr bool LinearCombinationClampIsHeavy() { return false; }

}
|
| 61 |
+
|
| 62 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 63 |
+
|
| 64 |
+
/// Applies a linear combination operator to an array of elements then clamps the output before
/// converting to the output element type.
///
/// D = alpha * accumulator + beta * source + uniform
///
template <
  typename ElementOutput_,                             ///< Data type used to load and store tensors
  int Count,                                           ///< Number of elements computed per operation
                                                       ///< Usually it is 128/sizeof_bits<ElementOutput_>,
                                                       ///< but we use 64 or 32 sometimes when there are not enough data to store
  typename ElementAccumulator_ = ElementOutput_,       ///< Accumulator data type
  typename ElementCompute_ = ElementOutput_,           ///< Data type used to compute linear combination
  ScaleType::Kind Scale = ScaleType::Default,          ///< Control Alpha and Beta scaling
  FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
>
class LinearCombinationClamp {
public:

  using ElementOutput = ElementOutput_;
  using ElementAccumulator = ElementAccumulator_;
  using ElementCompute = ElementCompute_;

  static int const kCount = Count;

  using FragmentOutput = Array<ElementOutput, kCount>;
  using FragmentAccumulator = Array<ElementAccumulator, kCount>;
  using ComputeFragment = Array<ElementCompute, kCount>;
  using FragmentSource = Array<ElementOutput, kCount>;

  static FloatRoundStyle const kRound = Round;

  static bool const kIsHeavy = detail::LinearCombinationClampIsHeavy();

  /// Host-constructable parameters structure
  struct Params {

    ElementCompute alpha;            ///< scales accumulators
    ElementCompute beta;             ///< scales source tensor
    ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
    ElementCompute const *beta_ptr;  ///< pointer to source scalar - if not null, loads it from memory

    //
    // Methods
    //

    CUTLASS_HOST_DEVICE
    Params():
      alpha(ElementCompute(1)),
      beta(ElementCompute(0)),
      alpha_ptr(nullptr),
      beta_ptr(nullptr) { }

    CUTLASS_HOST_DEVICE
    Params(
      ElementCompute alpha,
      ElementCompute beta
    ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) {

    }

    CUTLASS_HOST_DEVICE
    Params(
      ElementCompute alpha
    ): alpha(alpha), beta(0), alpha_ptr(nullptr), beta_ptr(nullptr) {

    }

    CUTLASS_HOST_DEVICE
    Params(
      ElementCompute const *alpha_ptr,
      ElementCompute const *beta_ptr
    ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {

    }

    CUTLASS_HOST_DEVICE
    Params(
      ElementCompute const *alpha_ptr
    ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(nullptr) {

    }
  };

private:

  //
  // Data members
  //

  ElementCompute alpha_;
  ElementCompute beta_;

public:

  /// Constructs the function object, possibly loading from pointers in host memory
  CUTLASS_HOST_DEVICE
  LinearCombinationClamp(Params const &params) {

    // Pointer-provided scalars take precedence over immediate values.
    alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
    beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
  }

  /// Returns true if source is needed
  CUTLASS_HOST_DEVICE
  bool is_source_needed() const {
    // The compile-time Scale policy may force the answer either way;
    // otherwise, source is needed iff beta is non-zero.
    if (Scale == ScaleType::NoBetaScaling) return true;

    if (Scale == ScaleType::OnlyAlphaScaling) return false;

    if (Scale == ScaleType::Nothing) return false;

    return beta_ != ElementCompute(0);
  }

  /// Functionally required for serial reduction in the epilogue
  CUTLASS_HOST_DEVICE
  void set_k_partition(int k_partition, int k_partition_count) {
    // Every partition after the first accumulates onto the previous partial
    // result, so beta is forced to 1.
    if (k_partition) {
      beta_ = ElementCompute(1);
    }
  }

  /// Computes linear scaling: D = alpha * accumulator + beta * source
  CUTLASS_HOST_DEVICE
  FragmentOutput operator()(
    FragmentAccumulator const &accumulator,
    FragmentOutput const &source,
    ElementCompute uniform = ElementCompute(0)) const {

    // NOTE(review): the 'uniform' parameter is accepted for interface
    // compatibility but is not applied anywhere in this body.

    // Convert source to internal compute numeric type
    NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round> source_converter;
    NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;

    ComputeFragment converted_source = source_converter(source);
    ComputeFragment converted_accumulator = accumulator_converter(accumulator);

    // Perform binary operations

    ComputeFragment intermediate;

    multiplies<ComputeFragment> mul_add_source;
    multiply_add<ComputeFragment> mul_add_accumulator;

    minimum<ComputeFragment> min_accumulator;
    maximum<ComputeFragment> max_accumulator;

    if (Scale == ScaleType::NoBetaScaling) {
      intermediate = converted_source;
      intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate);   // D = alpha * Accum + X
    } else if (Scale == ScaleType::Nothing) {
      intermediate = converted_accumulator;
    } else {
      intermediate = mul_add_source(beta_, converted_source);                            // X =  beta * C
      intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate);   // D = alpha * Accum + X
    }

    /// Clamping constant value
    // Clamp to the representable range of the output type so the narrowing
    // conversion below cannot overflow.
    ElementCompute const kClampMax =
        ElementCompute(cutlass::platform::numeric_limits<ElementOutput>::max());

    ElementCompute const kClampMin =
        ElementCompute(cutlass::platform::numeric_limits<ElementOutput>::lowest());

    intermediate = max_accumulator(intermediate, kClampMin);
    intermediate = min_accumulator(intermediate, kClampMax);

    // Convert to destination numeric type
    NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;

    return destination_converter(intermediate);
  }

  /// Computes linear scaling: D = alpha * accumulator
  CUTLASS_HOST_DEVICE
  FragmentOutput operator()(
    FragmentAccumulator const &accumulator) const {

    // Convert source to internal compute numeric type
    NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;

    ComputeFragment converted_accumulator = accumulator_converter(accumulator);

    // Perform binary operations

    ComputeFragment intermediate;

    multiplies<ComputeFragment> mul_accumulator;

    minimum<ComputeFragment> min_accumulator;
    maximum<ComputeFragment> max_accumulator;

    if (Scale == ScaleType::Nothing) {
      intermediate = converted_accumulator;
    } else {
      intermediate = mul_accumulator(alpha_, converted_accumulator);    // D = alpha * Accum
    }

    /// Clamping constant value
    ElementCompute const kClampMax =
        ElementCompute(cutlass::platform::numeric_limits<ElementOutput>::max());

    ElementCompute const kClampMin =
        ElementCompute(cutlass::platform::numeric_limits<ElementOutput>::lowest());

    intermediate = max_accumulator(intermediate, kClampMin);
    intermediate = min_accumulator(intermediate, kClampMax);

    // Convert to destination numeric type
    NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;

    return destination_converter(intermediate);
  }
};
|
| 277 |
+
|
| 278 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 279 |
+
|
| 280 |
+
// Conditional guards to enable partial specialization for packed integers
|
| 281 |
+
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && ((__CUDACC_VER_MAJOR__ > 10) || ((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2)))
|
| 282 |
+
|
| 283 |
+
/// Applies a linear combination operator to an array of elements then clamps the output before
|
| 284 |
+
/// converting to the output element type.
|
| 285 |
+
///
|
| 286 |
+
/// D = alpha * accumulator + beta * source + uniform
|
| 287 |
+
///
|
| 288 |
+
template <
|
| 289 |
+
typename ElementOutput_, ///< Data type used to load and store tensors
|
| 290 |
+
int Count, ///< Number of elements computed per operation
|
| 291 |
+
ScaleType::Kind Scale, ///< Control Alpha and Beta scaling
|
| 292 |
+
FloatRoundStyle Round
|
| 293 |
+
>
|
| 294 |
+
class LinearCombinationClamp<ElementOutput_, Count, int, float, Scale, Round> {
|
| 295 |
+
public:
|
| 296 |
+
|
| 297 |
+
using ElementOutput = ElementOutput_;
|
| 298 |
+
using ElementAccumulator = int;
|
| 299 |
+
using ElementCompute = float;
|
| 300 |
+
|
| 301 |
+
static_assert(
|
| 302 |
+
cutlass::platform::numeric_limits<ElementOutput>::is_integer,
|
| 303 |
+
"This elementwise op expects the output to be int.");
|
| 304 |
+
|
| 305 |
+
static int const kCount = Count;
|
| 306 |
+
|
| 307 |
+
using FragmentOutput = Array<ElementOutput, kCount>;
|
| 308 |
+
using FragmentAccumulator = Array<ElementAccumulator, kCount>;
|
| 309 |
+
using ComputeFragment = Array<ElementCompute, kCount>;
|
| 310 |
+
|
| 311 |
+
static FloatRoundStyle const kRound = Round;
|
| 312 |
+
|
| 313 |
+
static bool const kIsHeavy = detail::LinearCombinationClampIsHeavy();
|
| 314 |
+
|
| 315 |
+
/// Host-constructable parameters structure
|
| 316 |
+
struct Params {
|
| 317 |
+
|
| 318 |
+
ElementCompute alpha; ///< scales accumulators
|
| 319 |
+
ElementCompute beta; ///< scales source tensor
|
| 320 |
+
ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
|
| 321 |
+
ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
|
| 322 |
+
|
| 323 |
+
//
|
| 324 |
+
// Methods
|
| 325 |
+
//
|
| 326 |
+
|
| 327 |
+
CUTLASS_HOST_DEVICE
|
| 328 |
+
Params():
|
| 329 |
+
alpha(ElementCompute(1)),
|
| 330 |
+
beta(ElementCompute(0)),
|
| 331 |
+
alpha_ptr(nullptr),
|
| 332 |
+
beta_ptr(nullptr) { }
|
| 333 |
+
|
| 334 |
+
CUTLASS_HOST_DEVICE
|
| 335 |
+
Params(
|
| 336 |
+
ElementCompute alpha,
|
| 337 |
+
ElementCompute beta
|
| 338 |
+
): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) {
|
| 339 |
+
|
| 340 |
+
}
|
| 341 |
+
|
| 342 |
+
CUTLASS_HOST_DEVICE
|
| 343 |
+
Params(
|
| 344 |
+
ElementCompute alpha
|
| 345 |
+
): alpha(alpha), beta(0), alpha_ptr(nullptr), beta_ptr(nullptr) {
|
| 346 |
+
|
| 347 |
+
}
|
| 348 |
+
|
| 349 |
+
CUTLASS_HOST_DEVICE
|
| 350 |
+
Params(
|
| 351 |
+
ElementCompute const *alpha_ptr,
|
| 352 |
+
ElementCompute const *beta_ptr
|
| 353 |
+
): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
|
| 354 |
+
|
| 355 |
+
}
|
| 356 |
+
|
| 357 |
+
CUTLASS_HOST_DEVICE
|
| 358 |
+
Params(
|
| 359 |
+
ElementCompute const *alpha_ptr
|
| 360 |
+
): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(nullptr) {
|
| 361 |
+
|
| 362 |
+
}
|
| 363 |
+
};
|
| 364 |
+
|
| 365 |
+
private:
|
| 366 |
+
|
| 367 |
+
//
|
| 368 |
+
// Data members
|
| 369 |
+
//
|
| 370 |
+
|
| 371 |
+
ElementCompute alpha_;
|
| 372 |
+
ElementCompute beta_;
|
| 373 |
+
|
| 374 |
+
public:
|
| 375 |
+
|
| 376 |
+
/// Constructs the function object, possibly loading from pointers in host memory
|
| 377 |
+
CUTLASS_HOST_DEVICE
|
| 378 |
+
LinearCombinationClamp(Params const ¶ms) {
|
| 379 |
+
|
| 380 |
+
alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
|
| 381 |
+
beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
|
| 382 |
+
}
|
| 383 |
+
|
| 384 |
+
/// Returns true if source is needed
|
| 385 |
+
CUTLASS_HOST_DEVICE
|
| 386 |
+
bool is_source_needed() const {
|
| 387 |
+
if (Scale == ScaleType::NoBetaScaling) return true;
|
| 388 |
+
|
| 389 |
+
if (Scale == ScaleType::OnlyAlphaScaling) return false;
|
| 390 |
+
|
| 391 |
+
if (Scale == ScaleType::Nothing) return false;
|
| 392 |
+
|
| 393 |
+
return beta_ != ElementCompute(0);
|
| 394 |
+
}
|
| 395 |
+
|
| 396 |
+
/// Functionally required for serial reduction in the epilogue
|
| 397 |
+
CUTLASS_HOST_DEVICE
|
| 398 |
+
void set_k_partition(int k_partition, int k_partition_count) {
|
| 399 |
+
if (k_partition) {
|
| 400 |
+
beta_ = ElementCompute(1);
|
| 401 |
+
}
|
| 402 |
+
}
|
| 403 |
+
|
| 404 |
+
/// Computes linear scaling: D = alpha * accumulator + beta * source
|
| 405 |
+
CUTLASS_HOST_DEVICE
|
| 406 |
+
FragmentOutput operator()(
|
| 407 |
+
FragmentAccumulator const &accumulator,
|
| 408 |
+
FragmentOutput const &source,
|
| 409 |
+
ElementCompute uniform = ElementCompute(0)) const {
|
| 410 |
+
|
| 411 |
+
// Convert source to interal compute numeric type
|
| 412 |
+
NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round> source_converter;
|
| 413 |
+
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
|
| 414 |
+
|
| 415 |
+
ComputeFragment converted_source = source_converter(source);
|
| 416 |
+
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
|
| 417 |
+
|
| 418 |
+
// Compute linear scaling in floating point
|
| 419 |
+
ComputeFragment intermediate;
|
| 420 |
+
|
| 421 |
+
multiplies<ComputeFragment> mul_add_source;
|
| 422 |
+
multiply_add<ComputeFragment> mul_add_accumulator;
|
| 423 |
+
|
| 424 |
+
// Float min-max
|
| 425 |
+
if (Scale == ScaleType::NoBetaScaling) {
|
| 426 |
+
intermediate = converted_source;
|
| 427 |
+
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
|
| 428 |
+
} else if (Scale == ScaleType::Nothing) {
|
| 429 |
+
intermediate = converted_accumulator;
|
| 430 |
+
} else {
|
| 431 |
+
intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform
|
| 432 |
+
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
|
| 433 |
+
}
|
| 434 |
+
|
| 435 |
+
//
|
| 436 |
+
// Convert float => ElementOutput_ with clamping
|
| 437 |
+
//
|
| 438 |
+
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
|
| 439 |
+
|
| 440 |
+
return destination_converter(intermediate);
|
| 441 |
+
}
|
| 442 |
+
|
| 443 |
+
/// Computes linear scaling: D = alpha * accumulator
|
| 444 |
+
CUTLASS_HOST_DEVICE
|
| 445 |
+
FragmentOutput operator()(FragmentAccumulator const &accumulator) const {
|
| 446 |
+
|
| 447 |
+
// Convert source to interal compute numeric type
|
| 448 |
+
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
|
| 449 |
+
|
| 450 |
+
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
|
| 451 |
+
|
| 452 |
+
// Compute linear scaling in floating point
|
| 453 |
+
ComputeFragment intermediate;
|
| 454 |
+
|
| 455 |
+
multiplies<ComputeFragment> mul_add_accumulator;
|
| 456 |
+
|
| 457 |
+
// Float min-max
|
| 458 |
+
if (Scale == ScaleType::Nothing) {
|
| 459 |
+
intermediate = converted_accumulator;
|
| 460 |
+
} else {
|
| 461 |
+
intermediate = mul_add_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
|
| 462 |
+
}
|
| 463 |
+
|
| 464 |
+
//
|
| 465 |
+
// Convert float => ElementOutput_ with clamping
|
| 466 |
+
//
|
| 467 |
+
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
|
| 468 |
+
|
| 469 |
+
return destination_converter(intermediate);
|
| 470 |
+
}
|
| 471 |
+
};
|
| 472 |
+
|
| 473 |
+
#endif // Conditional guards to enable partial specialization for packed integers
|
| 474 |
+
|
| 475 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 476 |
+
|
| 477 |
+
/// Applies a linear combination operator to an array of elements then clamps
|
| 478 |
+
/// the output before converting to the output element type.
|
| 479 |
+
///
|
| 480 |
+
/// D = alpha * accumulator + beta * source + uniform
|
| 481 |
+
///
|
| 482 |
+
/// Note: The below method only when problem_size_K <= 256 for signed int8 gemm
|
| 483 |
+
/// or problem_size_K <= 128 for unsigned int8 gemm. The default approach is
|
| 484 |
+
/// above.
|
| 485 |
+
template <
|
| 486 |
+
/// Data type used to load and store< tensors
|
| 487 |
+
typename ElementOutput_,
|
| 488 |
+
/// Number of elements computed per operation
|
| 489 |
+
int Count,
|
| 490 |
+
///< Control Alpha and Beta scaling
|
| 491 |
+
ScaleType::Kind Scale = ScaleType::Default,
|
| 492 |
+
/// Rounding mode
|
| 493 |
+
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest>
|
| 494 |
+
class FastLinearCombinationClamp {
|
| 495 |
+
public:
|
| 496 |
+
using ElementOutput = ElementOutput_;
|
| 497 |
+
using ElementAccumulator = int;
|
| 498 |
+
using ElementCompute = float;
|
| 499 |
+
|
| 500 |
+
static_assert(
|
| 501 |
+
cutlass::platform::numeric_limits<ElementOutput>::is_integer,
|
| 502 |
+
"This elementwise op expects the output to be int.");
|
| 503 |
+
|
| 504 |
+
static int const kCount = Count;
|
| 505 |
+
|
| 506 |
+
using FragmentOutput = Array<ElementOutput, kCount>;
|
| 507 |
+
using FragmentAccumulator = Array<ElementAccumulator, kCount>;
|
| 508 |
+
using ComputeFragment = Array<ElementCompute, kCount>;
|
| 509 |
+
|
| 510 |
+
static FloatRoundStyle const kRound = Round;
|
| 511 |
+
|
| 512 |
+
static bool const kIsHeavy = false;
|
| 513 |
+
|
| 514 |
+
/// Host-constructable parameters structure
|
| 515 |
+
struct Params {
|
| 516 |
+
/// scales accumulators
|
| 517 |
+
ElementCompute alpha;
|
| 518 |
+
/// scales source tensor
|
| 519 |
+
ElementCompute beta;
|
| 520 |
+
/// pointer to accumulator scalar - if not null, loads it from memory
|
| 521 |
+
ElementCompute const *alpha_ptr;
|
| 522 |
+
/// pointer to source scalar - if not null, loads it from memory
|
| 523 |
+
ElementCompute const *beta_ptr;
|
| 524 |
+
|
| 525 |
+
//
|
| 526 |
+
// Methods
|
| 527 |
+
//
|
| 528 |
+
|
| 529 |
+
CUTLASS_HOST_DEVICE
|
| 530 |
+
Params()
|
| 531 |
+
: alpha(ElementCompute(1)),
|
| 532 |
+
beta(ElementCompute(0)),
|
| 533 |
+
alpha_ptr(nullptr),
|
| 534 |
+
beta_ptr(nullptr) {}
|
| 535 |
+
|
| 536 |
+
CUTLASS_HOST_DEVICE
|
| 537 |
+
Params(ElementCompute alpha, ElementCompute beta)
|
| 538 |
+
: alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) {}
|
| 539 |
+
|
| 540 |
+
CUTLASS_HOST_DEVICE
|
| 541 |
+
Params(ElementCompute alpha)
|
| 542 |
+
: alpha(alpha), beta(0), alpha_ptr(nullptr), beta_ptr(nullptr) {}
|
| 543 |
+
|
| 544 |
+
CUTLASS_HOST_DEVICE
|
| 545 |
+
Params(ElementCompute const *alpha_ptr, ElementCompute const *beta_ptr)
|
| 546 |
+
: alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {}
|
| 547 |
+
|
| 548 |
+
CUTLASS_HOST_DEVICE
|
| 549 |
+
Params(ElementCompute const *alpha_ptr)
|
| 550 |
+
: alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(nullptr) {}
|
| 551 |
+
};
|
| 552 |
+
|
| 553 |
+
private:
|
| 554 |
+
//
|
| 555 |
+
// Data members
|
| 556 |
+
//
|
| 557 |
+
|
| 558 |
+
ElementCompute alpha_;
|
| 559 |
+
ElementCompute beta_;
|
| 560 |
+
|
| 561 |
+
public:
|
| 562 |
+
/// Constructs the function object, possibly loading from pointers in host
|
| 563 |
+
/// memory
|
| 564 |
+
CUTLASS_HOST_DEVICE
|
| 565 |
+
FastLinearCombinationClamp(Params const ¶ms) {
|
| 566 |
+
alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
|
| 567 |
+
beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
|
| 568 |
+
}
|
| 569 |
+
|
| 570 |
+
/// Returns true if source is needed
|
| 571 |
+
CUTLASS_HOST_DEVICE
|
| 572 |
+
bool is_source_needed() const {
|
| 573 |
+
if (Scale == ScaleType::NoBetaScaling) return true;
|
| 574 |
+
|
| 575 |
+
if (Scale == ScaleType::OnlyAlphaScaling) return false;
|
| 576 |
+
|
| 577 |
+
if (Scale == ScaleType::Nothing) return false;
|
| 578 |
+
|
| 579 |
+
return beta_ != ElementCompute(0);
|
| 580 |
+
}
|
| 581 |
+
|
| 582 |
+
/// Functionally required for serial reduction in the epilogue
|
| 583 |
+
CUTLASS_HOST_DEVICE
|
| 584 |
+
void set_k_partition(int k_partition, int k_partition_count) {
|
| 585 |
+
if (k_partition) {
|
| 586 |
+
beta_ = ElementCompute(1);
|
| 587 |
+
}
|
| 588 |
+
}
|
| 589 |
+
|
| 590 |
+
/// Computes linear scaling: D = alpha * accumulator + beta * source
|
| 591 |
+
CUTLASS_HOST_DEVICE
|
| 592 |
+
FragmentOutput operator()(FragmentAccumulator const &accumulator,
|
| 593 |
+
FragmentOutput const &source,
|
| 594 |
+
ElementCompute uniform = ElementCompute(0)) const {
|
| 595 |
+
// Convert source to interal compute numeric type
|
| 596 |
+
FastNumericArrayConverter<ElementCompute, ElementOutput, kCount, Round>
|
| 597 |
+
source_converter;
|
| 598 |
+
FastNumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
|
| 599 |
+
accumulator_converter;
|
| 600 |
+
|
| 601 |
+
ComputeFragment converted_source = source_converter(source);
|
| 602 |
+
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
|
| 603 |
+
|
| 604 |
+
// Compute linear scaling in floating point
|
| 605 |
+
ComputeFragment intermediate;
|
| 606 |
+
|
| 607 |
+
multiplies<ComputeFragment> mul_add_source;
|
| 608 |
+
multiply_add<ComputeFragment> mul_add_accumulator;
|
| 609 |
+
|
| 610 |
+
minimum<ComputeFragment> min_accumulator;
|
| 611 |
+
maximum<ComputeFragment> max_accumulator;
|
| 612 |
+
|
| 613 |
+
// Float min-max
|
| 614 |
+
if (Scale == ScaleType::NoBetaScaling) {
|
| 615 |
+
intermediate = converted_source;
|
| 616 |
+
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
|
| 617 |
+
} else if (Scale == ScaleType::Nothing) {
|
| 618 |
+
intermediate = converted_accumulator;
|
| 619 |
+
} else {
|
| 620 |
+
intermediate =
|
| 621 |
+
mul_add_source(beta_, converted_source); // X = beta * C + uniform
|
| 622 |
+
intermediate = mul_add_accumulator(alpha_, converted_accumulator,
|
| 623 |
+
intermediate); // D = alpha * Accum + X
|
| 624 |
+
}
|
| 625 |
+
|
| 626 |
+
/// Clamping constant value
|
| 627 |
+
ElementCompute const kClamp =
|
| 628 |
+
ElementCompute(1 << (sizeof_bits<ElementOutput>::value - 1));
|
| 629 |
+
|
| 630 |
+
intermediate = max_accumulator(intermediate, -kClamp);
|
| 631 |
+
intermediate = min_accumulator(intermediate, kClamp - ElementCompute(1));
|
| 632 |
+
|
| 633 |
+
// Convert to destination numeric type
|
| 634 |
+
FastNumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
|
| 635 |
+
destination_converter;
|
| 636 |
+
|
| 637 |
+
return destination_converter(intermediate);
|
| 638 |
+
}
|
| 639 |
+
|
| 640 |
+
/// Computes linear scaling: D = alpha * accumulator + beta * source
|
| 641 |
+
CUTLASS_HOST_DEVICE
|
| 642 |
+
FragmentOutput operator()(FragmentAccumulator const &accumulator) const {
|
| 643 |
+
|
| 644 |
+
// Convert source to interal compute numeric type
|
| 645 |
+
FastNumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
|
| 646 |
+
accumulator_converter;
|
| 647 |
+
|
| 648 |
+
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
|
| 649 |
+
|
| 650 |
+
// Compute linear scaling in floating point
|
| 651 |
+
ComputeFragment intermediate;
|
| 652 |
+
|
| 653 |
+
multiplies<ComputeFragment> mul_accumulator;
|
| 654 |
+
|
| 655 |
+
minimum<ComputeFragment> min_accumulator;
|
| 656 |
+
maximum<ComputeFragment> max_accumulator;
|
| 657 |
+
|
| 658 |
+
// Float min-max
|
| 659 |
+
if (Scale == ScaleType::Nothing) {
|
| 660 |
+
intermediate = converted_accumulator;
|
| 661 |
+
} else {
|
| 662 |
+
intermediate = mul_accumulator(alpha_, converted_accumulator);
|
| 663 |
+
}
|
| 664 |
+
|
| 665 |
+
/// Clamping constant value
|
| 666 |
+
ElementCompute const kClamp =
|
| 667 |
+
ElementCompute(1 << (sizeof_bits<ElementOutput>::value - 1));
|
| 668 |
+
|
| 669 |
+
intermediate = max_accumulator(intermediate, -kClamp);
|
| 670 |
+
intermediate = min_accumulator(intermediate, kClamp - ElementCompute(1));
|
| 671 |
+
|
| 672 |
+
// Convert to destination numeric type
|
| 673 |
+
FastNumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
|
| 674 |
+
destination_converter;
|
| 675 |
+
|
| 676 |
+
return destination_converter(intermediate);
|
| 677 |
+
}
|
| 678 |
+
};
|
| 679 |
+
|
| 680 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 681 |
+
|
| 682 |
+
} // namespace thread
|
| 683 |
+
} // namespace epilogue
|
| 684 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_dgelu.h
ADDED
|
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
|
| 33 |
+
\brief Functor performing linear combination followed by dGelu operation
|
| 34 |
+
*/
|
| 35 |
+
|
| 36 |
+
#pragma once
|
| 37 |
+
|
| 38 |
+
#include "cutlass/half.h"
|
| 39 |
+
#include "cutlass/cutlass.h"
|
| 40 |
+
#include "cutlass/numeric_types.h"
|
| 41 |
+
#include "cutlass/array.h"
|
| 42 |
+
#include "cutlass/constants.h"
|
| 43 |
+
#include "cutlass/fast_math.h"
|
| 44 |
+
#include "cutlass/functional.h"
|
| 45 |
+
#include "cutlass/numeric_conversion.h"
|
| 46 |
+
#include "cutlass/epilogue/thread/activation.h"
|
| 47 |
+
|
| 48 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 49 |
+
|
| 50 |
+
namespace cutlass {
|
| 51 |
+
namespace epilogue {
|
| 52 |
+
namespace thread {
|
| 53 |
+
|
| 54 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 55 |
+
|
| 56 |
+
/// Applies a linear combination operator to an array of elements.
|
| 57 |
+
///
|
| 58 |
+
/// D = alpha * accumulator + beta * source + uniform
|
| 59 |
+
///
|
| 60 |
+
template <
|
| 61 |
+
typename ElementCompute_, ///< Data type returned by this functor
|
| 62 |
+
typename ElementAccumulator_, ///< Data type of accumulators
|
| 63 |
+
typename ElementSource_, ///< Data type of source tensor
|
| 64 |
+
typename ElementTensor_, ///< Data type of additional tensor
|
| 65 |
+
int Count, ///< Number of elements computed per operation
|
| 66 |
+
///< Usually it is 128/sizeof_bits<ElementOutput_>,
|
| 67 |
+
///< but we use 64 or 32 sometimes when there are not enough data to store
|
| 68 |
+
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
|
| 69 |
+
>
|
| 70 |
+
class LinearCombinationDGelu {
|
| 71 |
+
public:
|
| 72 |
+
|
| 73 |
+
using ElementOutput = ElementSource_;
|
| 74 |
+
using ElementCompute = ElementCompute_;
|
| 75 |
+
using ElementAccumulator = ElementAccumulator_;
|
| 76 |
+
using ElementSource = ElementSource_;
|
| 77 |
+
using ElementTensor = ElementTensor_;
|
| 78 |
+
|
| 79 |
+
static bool const kIsHeavy = true;
|
| 80 |
+
|
| 81 |
+
static int const kCount = Count;
|
| 82 |
+
|
| 83 |
+
using FragmentCompute = Array<ElementCompute, kCount>;
|
| 84 |
+
using FragmentAccumulator = Array<ElementAccumulator, kCount>;
|
| 85 |
+
using FragmentSource = Array<ElementSource, kCount>;
|
| 86 |
+
using FragmentTensor = Array<ElementTensor, kCount>;
|
| 87 |
+
|
| 88 |
+
static FloatRoundStyle const kRound = Round;
|
| 89 |
+
|
| 90 |
+
/// Host-constructable parameters structure
|
| 91 |
+
struct Params {
|
| 92 |
+
|
| 93 |
+
ElementCompute alpha; ///< scales accumulators
|
| 94 |
+
ElementCompute beta; ///< scales source tensor
|
| 95 |
+
ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
|
| 96 |
+
ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
|
| 97 |
+
ElementCompute threshold; ///< minimum value that is output
|
| 98 |
+
//
|
| 99 |
+
// Methods
|
| 100 |
+
//
|
| 101 |
+
|
| 102 |
+
CUTLASS_HOST_DEVICE
|
| 103 |
+
Params():
|
| 104 |
+
alpha(ElementCompute(1)),
|
| 105 |
+
beta(ElementCompute(0)),
|
| 106 |
+
threshold(ElementCompute(0)),
|
| 107 |
+
alpha_ptr(nullptr),
|
| 108 |
+
beta_ptr(nullptr) { }
|
| 109 |
+
|
| 110 |
+
CUTLASS_HOST_DEVICE
|
| 111 |
+
Params(
|
| 112 |
+
ElementCompute alpha,
|
| 113 |
+
ElementCompute beta,
|
| 114 |
+
ElementCompute threshold = ElementCompute(0)
|
| 115 |
+
): alpha(alpha), beta(beta), threshold(threshold), alpha_ptr(nullptr), beta_ptr(nullptr) {
|
| 116 |
+
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
CUTLASS_HOST_DEVICE
|
| 120 |
+
Params(
|
| 121 |
+
ElementCompute const *alpha_ptr,
|
| 122 |
+
ElementCompute const *beta_ptr,
|
| 123 |
+
ElementCompute threshold = ElementCompute(0)
|
| 124 |
+
): alpha(0), beta(0), threshold(threshold), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
|
| 125 |
+
|
| 126 |
+
}
|
| 127 |
+
};
|
| 128 |
+
|
| 129 |
+
private:
|
| 130 |
+
|
| 131 |
+
//
|
| 132 |
+
// Data members
|
| 133 |
+
//
|
| 134 |
+
|
| 135 |
+
ElementCompute alpha_;
|
| 136 |
+
ElementCompute beta_;
|
| 137 |
+
ElementCompute threshold_;
|
| 138 |
+
bool participates_in_reduction_;
|
| 139 |
+
|
| 140 |
+
public:
|
| 141 |
+
|
| 142 |
+
/// Constructs the function object, possibly loading from pointers in host memory
|
| 143 |
+
CUTLASS_HOST_DEVICE
|
| 144 |
+
LinearCombinationDGelu(Params const ¶ms) {
|
| 145 |
+
|
| 146 |
+
alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
|
| 147 |
+
beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
|
| 148 |
+
threshold_ = params.threshold;
|
| 149 |
+
participates_in_reduction_ = true;
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
/// Returns true if source is needed
|
| 153 |
+
CUTLASS_HOST_DEVICE
|
| 154 |
+
bool is_source_needed() const {
|
| 155 |
+
return beta_ != ElementCompute(0);
|
| 156 |
+
}
|
| 157 |
+
|
| 158 |
+
/// Returns true if the threadblock computes the reduction
|
| 159 |
+
CUTLASS_HOST_DEVICE
|
| 160 |
+
bool participates_in_reduction() const {
|
| 161 |
+
return participates_in_reduction_;
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
/// Functionally required for serial reduction in the epilogue
|
| 165 |
+
CUTLASS_HOST_DEVICE
|
| 166 |
+
void set_k_partition(int k_partition, int k_partition_count) {
|
| 167 |
+
if (k_partition) {
|
| 168 |
+
beta_ = ElementCompute(1);
|
| 169 |
+
}
|
| 170 |
+
|
| 171 |
+
if (k_partition != k_partition_count - 1) {
|
| 172 |
+
// set to NaN to make ReLU no-op for all except last k partitions
|
| 173 |
+
int64_t allones = -1;
|
| 174 |
+
threshold_ = reinterpret_cast<ElementCompute const &>(allones);
|
| 175 |
+
// Avoid computing the reduction if this isn't the final Split-K slice
|
| 176 |
+
participates_in_reduction_ = false;
|
| 177 |
+
}
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
/// Computes linear scaling: D = alpha * accumulator + beta * source
|
| 181 |
+
CUTLASS_HOST_DEVICE
|
| 182 |
+
FragmentCompute operator()(
|
| 183 |
+
FragmentAccumulator const &accumulator,
|
| 184 |
+
FragmentSource const &source,
|
| 185 |
+
FragmentTensor const &tensor) const {
|
| 186 |
+
|
| 187 |
+
// Convert source to interal compute numeric type
|
| 188 |
+
NumericArrayConverter<ElementCompute, ElementSource, kCount, Round> source_converter;
|
| 189 |
+
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
|
| 190 |
+
|
| 191 |
+
FragmentCompute converted_source = source_converter(source);
|
| 192 |
+
FragmentCompute converted_accumulator = accumulator_converter(accumulator);
|
| 193 |
+
|
| 194 |
+
// Perform binary operations
|
| 195 |
+
FragmentCompute intermediate;
|
| 196 |
+
|
| 197 |
+
multiplies<FragmentCompute> mul_add_source;
|
| 198 |
+
multiply_add<FragmentCompute> mul_add_accumulator;
|
| 199 |
+
|
| 200 |
+
intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform
|
| 201 |
+
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
|
| 202 |
+
|
| 203 |
+
dGELU<ElementCompute> gelu_op;
|
| 204 |
+
|
| 205 |
+
// dGelu
|
| 206 |
+
CUTLASS_PRAGMA_UNROLL
|
| 207 |
+
for (int i = 0; i < kCount; ++i) {
|
| 208 |
+
intermediate[i] = gelu_op(intermediate[i], ElementCompute(tensor[i]));
|
| 209 |
+
}
|
| 210 |
+
|
| 211 |
+
return intermediate;
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
/// Computes linear scaling: D = alpha * accumulator
|
| 215 |
+
CUTLASS_HOST_DEVICE
|
| 216 |
+
FragmentCompute operator()(
|
| 217 |
+
FragmentAccumulator const &accumulator,
|
| 218 |
+
FragmentTensor const &tensor) const {
|
| 219 |
+
|
| 220 |
+
// Convert source to interal compute numeric type
|
| 221 |
+
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
|
| 222 |
+
|
| 223 |
+
FragmentCompute converted_accumulator = accumulator_converter(accumulator);
|
| 224 |
+
|
| 225 |
+
// Perform binary operations
|
| 226 |
+
FragmentCompute intermediate;
|
| 227 |
+
|
| 228 |
+
multiplies<FragmentCompute> mul_accumulator;
|
| 229 |
+
|
| 230 |
+
intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
|
| 231 |
+
|
| 232 |
+
dGELU<ElementCompute> gelu_op;
|
| 233 |
+
|
| 234 |
+
// dGelu with conversion
|
| 235 |
+
CUTLASS_PRAGMA_UNROLL
|
| 236 |
+
for (int i = 0; i < kCount; ++i) {
|
| 237 |
+
intermediate[i] = gelu_op(intermediate[i], ElementCompute(tensor[i]));
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
return intermediate;
|
| 241 |
+
}
|
| 242 |
+
};
|
| 243 |
+
|
| 244 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 245 |
+
|
| 246 |
+
} // namespace thread
|
| 247 |
+
} // namespace epilogue
|
| 248 |
+
} // namespace cutlass
|
| 249 |
+
|
| 250 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_drelu.h
ADDED
|
@@ -0,0 +1,452 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Functor performing linear combination with a maximum operation used by epilogues.
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "cutlass/half.h"
|
| 38 |
+
#include "cutlass/cutlass.h"
|
| 39 |
+
#include "cutlass/numeric_types.h"
|
| 40 |
+
#include "cutlass/array.h"
|
| 41 |
+
#include "cutlass/functional.h"
|
| 42 |
+
#include "cutlass/numeric_conversion.h"
|
| 43 |
+
#include "cutlass/epilogue/thread/activation.h"
|
| 44 |
+
|
| 45 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 46 |
+
|
| 47 |
+
namespace cutlass {
|
| 48 |
+
namespace epilogue {
|
| 49 |
+
namespace thread {
|
| 50 |
+
|
| 51 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 52 |
+
|
| 53 |
+
/// Applies a linear combination operator to an array of elements.
|
| 54 |
+
///
|
| 55 |
+
/// D = alpha * accumulator + beta * source + uniform
|
| 56 |
+
///
|
| 57 |
+
template <
|
| 58 |
+
typename ElementCompute_, ///< Data type returned by this functor
|
| 59 |
+
typename ElementAccumulator_, ///< Data type of accumulators
|
| 60 |
+
typename ElementSource_, ///< Data type of source tensor
|
| 61 |
+
typename ElementTensor_, ///< Data type of additional tensor
|
| 62 |
+
int Count, ///< Number of elements computed per operation
|
| 63 |
+
///< Usually it is 128/sizeof_bits<ElementOutput_>,
|
| 64 |
+
///< but we use 64 or 32 sometimes when there are not enough data to store
|
| 65 |
+
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
|
| 66 |
+
>
|
| 67 |
+
class LinearCombinationDRelu {
|
| 68 |
+
public:
|
| 69 |
+
|
| 70 |
+
using ElementOutput = ElementSource_;
|
| 71 |
+
using ElementCompute = ElementCompute_;
|
| 72 |
+
using ElementAccumulator = ElementAccumulator_;
|
| 73 |
+
using ElementSource = ElementSource_;
|
| 74 |
+
using ElementTensor = ElementTensor_;
|
| 75 |
+
|
| 76 |
+
static int const kCount = Count;
|
| 77 |
+
|
| 78 |
+
using FragmentCompute = Array<ElementCompute, kCount>;
|
| 79 |
+
using FragmentAccumulator = Array<ElementAccumulator, kCount>;
|
| 80 |
+
using FragmentSource = Array<ElementSource, kCount>;
|
| 81 |
+
using FragmentTensor = Array<ElementTensor, kCount>;
|
| 82 |
+
|
| 83 |
+
static FloatRoundStyle const kRound = Round;
|
| 84 |
+
|
| 85 |
+
/// Host-constructable parameters structure
|
| 86 |
+
struct Params {
|
| 87 |
+
|
| 88 |
+
ElementCompute alpha; ///< scales accumulators
|
| 89 |
+
ElementCompute beta; ///< scales source tensor
|
| 90 |
+
ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
|
| 91 |
+
ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
|
| 92 |
+
ElementCompute threshold; ///< minimum value that is output
|
| 93 |
+
//
|
| 94 |
+
// Methods
|
| 95 |
+
//
|
| 96 |
+
|
| 97 |
+
CUTLASS_HOST_DEVICE
|
| 98 |
+
Params():
|
| 99 |
+
alpha(ElementCompute(1)),
|
| 100 |
+
beta(ElementCompute(0)),
|
| 101 |
+
threshold(ElementCompute(0)),
|
| 102 |
+
alpha_ptr(nullptr),
|
| 103 |
+
beta_ptr(nullptr) { }
|
| 104 |
+
|
| 105 |
+
CUTLASS_HOST_DEVICE
|
| 106 |
+
Params(
|
| 107 |
+
ElementCompute alpha,
|
| 108 |
+
ElementCompute beta,
|
| 109 |
+
ElementCompute threshold = ElementCompute(0)
|
| 110 |
+
): alpha(alpha), beta(beta), threshold(threshold), alpha_ptr(nullptr), beta_ptr(nullptr) {
|
| 111 |
+
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
CUTLASS_HOST_DEVICE
|
| 115 |
+
Params(
|
| 116 |
+
ElementCompute const *alpha_ptr,
|
| 117 |
+
ElementCompute const *beta_ptr,
|
| 118 |
+
ElementCompute threshold = ElementCompute(0)
|
| 119 |
+
): alpha(0), beta(0), threshold(threshold), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
|
| 120 |
+
|
| 121 |
+
}
|
| 122 |
+
};
|
| 123 |
+
|
| 124 |
+
private:
|
| 125 |
+
|
| 126 |
+
//
|
| 127 |
+
// Data members
|
| 128 |
+
//
|
| 129 |
+
|
| 130 |
+
ElementCompute alpha_;
|
| 131 |
+
ElementCompute beta_;
|
| 132 |
+
ElementTensor threshold_;
|
| 133 |
+
bool participates_in_reduction_;
|
| 134 |
+
|
| 135 |
+
public:
|
| 136 |
+
|
| 137 |
+
/// Constructs the function object, possibly loading from pointers in host memory
|
| 138 |
+
CUTLASS_HOST_DEVICE
|
| 139 |
+
LinearCombinationDRelu(Params const ¶ms) {
|
| 140 |
+
|
| 141 |
+
alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
|
| 142 |
+
beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
|
| 143 |
+
threshold_ = ElementTensor(params.threshold);
|
| 144 |
+
participates_in_reduction_ = true;
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
/// Returns true if source is needed
|
| 148 |
+
CUTLASS_HOST_DEVICE
|
| 149 |
+
bool is_source_needed() const {
|
| 150 |
+
return beta_ != ElementCompute(0);
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
/// Returns true if the threadblock computes the reduction
|
| 154 |
+
CUTLASS_HOST_DEVICE
|
| 155 |
+
bool participates_in_reduction() const {
|
| 156 |
+
return participates_in_reduction_;
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
/// Functionally required for serial reduction in the epilogue
|
| 160 |
+
CUTLASS_DEVICE
|
| 161 |
+
void set_k_partition(int k_partition, int k_partition_count) {
|
| 162 |
+
if (k_partition) {
|
| 163 |
+
beta_ = ElementCompute(1);
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
if (k_partition != k_partition_count - 1) {
|
| 167 |
+
// set to NaN to make ReLU no-op for all except last k partitions
|
| 168 |
+
int64_t allones = -1;
|
| 169 |
+
threshold_ = reinterpret_cast<ElementTensor const &>(allones);
|
| 170 |
+
participates_in_reduction_ = false;
|
| 171 |
+
}
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
/// Computes linear scaling: D = alpha * accumulator + beta * source
|
| 175 |
+
CUTLASS_HOST_DEVICE
|
| 176 |
+
FragmentCompute operator()(
|
| 177 |
+
FragmentAccumulator const &accumulator,
|
| 178 |
+
FragmentSource const &source,
|
| 179 |
+
FragmentTensor const &tensor) const {
|
| 180 |
+
|
| 181 |
+
// Convert source to interal compute numeric type
|
| 182 |
+
NumericArrayConverter<ElementCompute, ElementSource, kCount, Round> source_converter;
|
| 183 |
+
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
|
| 184 |
+
|
| 185 |
+
FragmentCompute converted_source = source_converter(source);
|
| 186 |
+
FragmentCompute converted_accumulator = accumulator_converter(accumulator);
|
| 187 |
+
|
| 188 |
+
// Perform binary operations
|
| 189 |
+
FragmentCompute intermediate;
|
| 190 |
+
|
| 191 |
+
multiplies<FragmentCompute> mul_add_source;
|
| 192 |
+
multiply_add<FragmentCompute> mul_add_accumulator;
|
| 193 |
+
|
| 194 |
+
intermediate = mul_add_source(beta_, converted_source); // X = beta * C
|
| 195 |
+
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
|
| 196 |
+
|
| 197 |
+
// dReLU = (cond ? dy : 0)
|
| 198 |
+
CUTLASS_PRAGMA_UNROLL
|
| 199 |
+
for (int i = 0; i < kCount; ++i) {
|
| 200 |
+
ElementTensor cond = tensor[i];
|
| 201 |
+
if (cond <= threshold_) {
|
| 202 |
+
intermediate[i] = ElementCompute();
|
| 203 |
+
}
|
| 204 |
+
}
|
| 205 |
+
|
| 206 |
+
return intermediate;
|
| 207 |
+
}
|
| 208 |
+
|
| 209 |
+
/// Computes linear scaling: D = alpha * accumulator
|
| 210 |
+
CUTLASS_HOST_DEVICE
|
| 211 |
+
FragmentCompute operator()(
|
| 212 |
+
FragmentAccumulator const &accumulator,
|
| 213 |
+
FragmentTensor const &tensor) const {
|
| 214 |
+
|
| 215 |
+
// Convert source to interal compute numeric type
|
| 216 |
+
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
|
| 217 |
+
|
| 218 |
+
FragmentCompute converted_accumulator = accumulator_converter(accumulator);
|
| 219 |
+
|
| 220 |
+
// Perform binary operations
|
| 221 |
+
FragmentCompute intermediate;
|
| 222 |
+
|
| 223 |
+
multiplies<FragmentCompute> mul_accumulator;
|
| 224 |
+
|
| 225 |
+
intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
|
| 226 |
+
|
| 227 |
+
// dReLU = (cond ? dy : 0)
|
| 228 |
+
CUTLASS_PRAGMA_UNROLL
|
| 229 |
+
for (int i = 0; i < kCount; ++i) {
|
| 230 |
+
ElementTensor cond = tensor[i];
|
| 231 |
+
if (cond <= threshold_) {
|
| 232 |
+
intermediate[i] = ElementCompute();
|
| 233 |
+
}
|
| 234 |
+
}
|
| 235 |
+
|
| 236 |
+
return intermediate;
|
| 237 |
+
}
|
| 238 |
+
};
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 242 |
+
|
| 243 |
+
/// Applies a linear combination operator to an array of elements.
|
| 244 |
+
///
|
| 245 |
+
/// D = alpha * accumulator + beta * source + uniform
|
| 246 |
+
///
|
| 247 |
+
template <
|
| 248 |
+
typename ElementCompute_, ///< Data type returned by this functor
|
| 249 |
+
typename ElementAccumulator_, ///< Data type of accumulators
|
| 250 |
+
typename ElementSource_, ///< Data type of source tensor
|
| 251 |
+
int Count, ///< Number of elements computed per operation
|
| 252 |
+
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
|
| 253 |
+
>
|
| 254 |
+
class LinearCombinationDReluConditionalBits {
|
| 255 |
+
public:
|
| 256 |
+
|
| 257 |
+
using ElementOutput = ElementSource_;
|
| 258 |
+
using ElementCompute = ElementCompute_;
|
| 259 |
+
using ElementAccumulator = ElementAccumulator_;
|
| 260 |
+
using ElementSource = ElementSource_;
|
| 261 |
+
using ElementTensor = uint1b_t;
|
| 262 |
+
|
| 263 |
+
static bool const kIsHeavy = false;
|
| 264 |
+
|
| 265 |
+
static int const kCount = Count;
|
| 266 |
+
|
| 267 |
+
using FragmentCompute = Array<ElementCompute, kCount>;
|
| 268 |
+
using FragmentAccumulator = Array<ElementAccumulator, kCount>;
|
| 269 |
+
using FragmentSource = Array<ElementSource, kCount>;
|
| 270 |
+
using FragmentTensor = Array<ElementTensor, kCount>;
|
| 271 |
+
|
| 272 |
+
static FloatRoundStyle const kRound = Round;
|
| 273 |
+
|
| 274 |
+
/// Host-constructable parameters structure
|
| 275 |
+
struct Params {
|
| 276 |
+
|
| 277 |
+
ElementCompute alpha; ///< scales accumulators
|
| 278 |
+
ElementCompute beta; ///< scales source tensor
|
| 279 |
+
ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
|
| 280 |
+
ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
|
| 281 |
+
//
|
| 282 |
+
// Methods
|
| 283 |
+
//
|
| 284 |
+
|
| 285 |
+
CUTLASS_HOST_DEVICE
|
| 286 |
+
Params():
|
| 287 |
+
alpha(ElementCompute(1)),
|
| 288 |
+
beta(ElementCompute(0)),
|
| 289 |
+
alpha_ptr(nullptr),
|
| 290 |
+
beta_ptr(nullptr) { }
|
| 291 |
+
|
| 292 |
+
CUTLASS_HOST_DEVICE
|
| 293 |
+
Params(
|
| 294 |
+
ElementCompute alpha,
|
| 295 |
+
ElementCompute beta
|
| 296 |
+
): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) {
|
| 297 |
+
|
| 298 |
+
}
|
| 299 |
+
|
| 300 |
+
CUTLASS_HOST_DEVICE
|
| 301 |
+
Params(
|
| 302 |
+
ElementCompute const *alpha_ptr,
|
| 303 |
+
ElementCompute const *beta_ptr
|
| 304 |
+
): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
|
| 305 |
+
|
| 306 |
+
}
|
| 307 |
+
};
|
| 308 |
+
|
| 309 |
+
private:
|
| 310 |
+
|
| 311 |
+
//
|
| 312 |
+
// Data members
|
| 313 |
+
//
|
| 314 |
+
|
| 315 |
+
ElementCompute alpha_;
|
| 316 |
+
ElementCompute beta_;
|
| 317 |
+
FragmentTensor predicate_mask_;
|
| 318 |
+
bool participates_in_reduction_;
|
| 319 |
+
|
| 320 |
+
public:
|
| 321 |
+
|
| 322 |
+
/// Constructs the function object, possibly loading from pointers in host memory
|
| 323 |
+
CUTLASS_HOST_DEVICE
|
| 324 |
+
LinearCombinationDReluConditionalBits(Params const ¶ms) {
|
| 325 |
+
|
| 326 |
+
alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
|
| 327 |
+
beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
|
| 328 |
+
participates_in_reduction_ = true;
|
| 329 |
+
predicate_mask_.clear();
|
| 330 |
+
}
|
| 331 |
+
|
| 332 |
+
/// Returns true if source is needed
|
| 333 |
+
CUTLASS_HOST_DEVICE
|
| 334 |
+
bool is_source_needed() const {
|
| 335 |
+
return beta_ != ElementCompute(0);
|
| 336 |
+
}
|
| 337 |
+
|
| 338 |
+
/// Returns true if the threadblock computes the reduction
|
| 339 |
+
CUTLASS_HOST_DEVICE
|
| 340 |
+
bool participates_in_reduction() const {
|
| 341 |
+
return participates_in_reduction_;
|
| 342 |
+
}
|
| 343 |
+
|
| 344 |
+
/// Functionally required for serial reduction in the epilogue
|
| 345 |
+
CUTLASS_HOST_DEVICE
|
| 346 |
+
void set_k_partition(int k_partition, int k_partition_count) {
|
| 347 |
+
predicate_mask_.clear();
|
| 348 |
+
|
| 349 |
+
if (k_partition) {
|
| 350 |
+
beta_ = ElementCompute(1);
|
| 351 |
+
}
|
| 352 |
+
|
| 353 |
+
if (k_partition != k_partition_count - 1) {
|
| 354 |
+
// Avoid computing the reduction if this isn't the final Split-K slice
|
| 355 |
+
participates_in_reduction_ = false;
|
| 356 |
+
|
| 357 |
+
bit_not<FragmentTensor> not_op;
|
| 358 |
+
predicate_mask_ = not_op(predicate_mask_);
|
| 359 |
+
}
|
| 360 |
+
}
|
| 361 |
+
|
| 362 |
+
/// Computes linear scaling: D = alpha * accumulator + beta * source
|
| 363 |
+
CUTLASS_DEVICE
|
| 364 |
+
FragmentCompute operator()(
|
| 365 |
+
FragmentAccumulator const &accumulator,
|
| 366 |
+
FragmentSource const &source,
|
| 367 |
+
FragmentTensor const &tensor) const {
|
| 368 |
+
|
| 369 |
+
// Convert source to interal compute numeric type
|
| 370 |
+
NumericArrayConverter<ElementCompute, ElementSource, kCount, Round> source_converter;
|
| 371 |
+
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
|
| 372 |
+
|
| 373 |
+
FragmentCompute converted_source = source_converter(source);
|
| 374 |
+
FragmentCompute converted_accumulator = accumulator_converter(accumulator);
|
| 375 |
+
|
| 376 |
+
// Perform binary operations
|
| 377 |
+
FragmentCompute intermediate;
|
| 378 |
+
|
| 379 |
+
multiplies<FragmentCompute> mul_add_source;
|
| 380 |
+
multiply_add<FragmentCompute> mul_add_accumulator;
|
| 381 |
+
|
| 382 |
+
intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform
|
| 383 |
+
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
|
| 384 |
+
|
| 385 |
+
bit_or<FragmentTensor> or_op;
|
| 386 |
+
|
| 387 |
+
FragmentTensor predicates = or_op(tensor, predicate_mask_);
|
| 388 |
+
|
| 389 |
+
// Obtain from packed bits
|
| 390 |
+
bool conditions[kCount];
|
| 391 |
+
UnpackPredicates<kCount> unpack_predicates;
|
| 392 |
+
|
| 393 |
+
unpack_predicates(conditions, predicates);
|
| 394 |
+
|
| 395 |
+
// dReLU = (cond ? dy : 0)
|
| 396 |
+
CUTLASS_PRAGMA_UNROLL
|
| 397 |
+
for (int i = 0; i < kCount; ++i) {
|
| 398 |
+
if (!conditions[i]) {
|
| 399 |
+
intermediate[i] = ElementCompute();
|
| 400 |
+
}
|
| 401 |
+
}
|
| 402 |
+
|
| 403 |
+
return intermediate;
|
| 404 |
+
}
|
| 405 |
+
|
| 406 |
+
/// Computes linear scaling: D = alpha * accumulator
|
| 407 |
+
CUTLASS_HOST_DEVICE
|
| 408 |
+
FragmentCompute operator()(
|
| 409 |
+
FragmentAccumulator const &accumulator,
|
| 410 |
+
FragmentTensor const &tensor) const {
|
| 411 |
+
|
| 412 |
+
// Convert source to interal compute numeric type
|
| 413 |
+
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
|
| 414 |
+
|
| 415 |
+
FragmentCompute converted_accumulator = accumulator_converter(accumulator);
|
| 416 |
+
|
| 417 |
+
// Perform binary operations
|
| 418 |
+
FragmentCompute intermediate;
|
| 419 |
+
|
| 420 |
+
multiplies<FragmentCompute> mul_accumulator;
|
| 421 |
+
|
| 422 |
+
intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
|
| 423 |
+
|
| 424 |
+
bit_or<FragmentTensor> or_op;
|
| 425 |
+
|
| 426 |
+
FragmentTensor predicates = or_op(tensor, predicate_mask_);
|
| 427 |
+
|
| 428 |
+
// Obtain from packed bits
|
| 429 |
+
bool conditions[kCount];
|
| 430 |
+
UnpackPredicates<kCount> unpack_predicates;
|
| 431 |
+
|
| 432 |
+
unpack_predicates(conditions, predicates);
|
| 433 |
+
|
| 434 |
+
// dReLU = (cond ? dy : 0)
|
| 435 |
+
CUTLASS_PRAGMA_UNROLL
|
| 436 |
+
for (int i = 0; i < kCount; ++i) {
|
| 437 |
+
if (!conditions[i]) {
|
| 438 |
+
intermediate[i] = ElementCompute();
|
| 439 |
+
}
|
| 440 |
+
}
|
| 441 |
+
|
| 442 |
+
return intermediate;
|
| 443 |
+
}
|
| 444 |
+
};
|
| 445 |
+
|
| 446 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 447 |
+
|
| 448 |
+
} // namespace thread
|
| 449 |
+
} // namespace epilogue
|
| 450 |
+
} // namespace cutlass
|
| 451 |
+
|
| 452 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_gelu.h
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Functor performing linear combination with GELU operations used by epilogues.
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "cutlass/cutlass.h"
|
| 38 |
+
#include "cutlass/epilogue/thread/activation.h"
|
| 39 |
+
#include "cutlass/epilogue/thread/linear_combination_generic.h"
|
| 40 |
+
|
| 41 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 42 |
+
|
| 43 |
+
namespace cutlass {
|
| 44 |
+
namespace epilogue {
|
| 45 |
+
namespace thread {
|
| 46 |
+
|
| 47 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 48 |
+
|
| 49 |
+
/// Applies a linear combination operator followed by the GELU activation to an array of elements.
|
| 50 |
+
///
|
| 51 |
+
/// D = gelu(alpha * accumulator + beta * source + uniform)
|
| 52 |
+
///
|
| 53 |
+
template <
|
| 54 |
+
typename ElementOutput_, ///< Data type used to load and store tensors
|
| 55 |
+
int Count, ///< Number of elements computed per operation
|
| 56 |
+
///< Usually it is 128/sizeof_bits<ElementOutput_>,
|
| 57 |
+
///< but we use 64 or 32 sometimes when there are not enough data to store
|
| 58 |
+
typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type
|
| 59 |
+
typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination
|
| 60 |
+
ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling
|
| 61 |
+
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
|
| 62 |
+
>
|
| 63 |
+
using LinearCombinationGELU = LinearCombinationGeneric<GELU, ElementOutput_, Count, ElementAccumulator_,
|
| 64 |
+
ElementCompute_, Scale, Round, true>;
|
| 65 |
+
|
| 66 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 67 |
+
|
| 68 |
+
} // namespace thread
|
| 69 |
+
} // namespace epilogue
|
| 70 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_generic.h
ADDED
|
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Functor performing linear combination operations used by epilogues.
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "cutlass/cutlass.h"
|
| 38 |
+
#include "cutlass/numeric_types.h"
|
| 39 |
+
#include "cutlass/array.h"
|
| 40 |
+
#include "cutlass/functional.h"
|
| 41 |
+
#include "cutlass/numeric_conversion.h"
|
| 42 |
+
#include "cutlass/epilogue/thread/scale_type.h"
|
| 43 |
+
|
| 44 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 45 |
+
|
| 46 |
+
namespace cutlass {
|
| 47 |
+
namespace epilogue {
|
| 48 |
+
namespace thread {
|
| 49 |
+
|
| 50 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 51 |
+
|
| 52 |
+
template <class Activation, class = void>
|
| 53 |
+
struct GenericActivationTraits {
|
| 54 |
+
static constexpr bool IsArgumentsNeeded = false;
|
| 55 |
+
struct Arguments {};
|
| 56 |
+
};
|
| 57 |
+
|
| 58 |
+
template <class Activation>
|
| 59 |
+
struct GenericActivationTraits<Activation, decltype(typename Activation::Arguments(), void())> {
|
| 60 |
+
static constexpr bool IsArgumentsNeeded = true;
|
| 61 |
+
using Arguments = typename Activation::Arguments;
|
| 62 |
+
};
|
| 63 |
+
|
| 64 |
+
/// Host-constructable parameters for generic linear-combination epilogues.
///
/// Carries the alpha/beta scalars either by value or via pointers; when a
/// pointer is non-null, the owning epilogue functor dereferences it at
/// construction time and overwrites the by-value member.
template <typename T>
struct LinearCombinationGenericParams {
  T alpha;                  ///< scales accumulators
  T beta;                   ///< scales source tensor
  T const *alpha_ptr;       ///< pointer to accumulator scalar - if not null, loads it from memory
  T const *beta_ptr;        ///< pointer to source scalar - if not null, loads it from memory

  //
  // Methods
  //

  /// Default: alpha = 1, beta = 0 (pure accumulator pass-through).
  CUTLASS_HOST_DEVICE
  LinearCombinationGenericParams():
    alpha(T(1)),
    beta(T(0)),
    alpha_ptr(nullptr),
    beta_ptr(nullptr) { }

  /// Construct from scalars held by value; pointers stay null.
  CUTLASS_HOST_DEVICE
  LinearCombinationGenericParams(
    T alpha,
    T beta = T(0)
  ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) { }

  /// Construct from pointers. The by-value members are zero-initialized and
  /// are expected to be overwritten by the epilogue functor after it
  /// dereferences the pointers.
  CUTLASS_HOST_DEVICE
  LinearCombinationGenericParams(
    T const *alpha_ptr,
    T const *beta_ptr = nullptr
  ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { }
};
|
| 94 |
+
|
| 95 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 96 |
+
|
| 97 |
+
/// Applies a linear combination operator followed by an activation function to an array of elements.
///
/// D = activation(alpha * accumulator + beta * source + uniform)
///
template <
  template<typename T> class ActivationFunctor,
  typename ElementOutput_,                        ///< Data type used to load and store tensors
  int Count,                                      ///< Number of elements computed per operation
                                                  ///< Usually it is 128/sizeof_bits<ElementOutput_>,
                                                  ///< but we use 64 or 32 sometimes when there are not enough data to store
  typename ElementAccumulator_ = ElementOutput_,  ///< Accumulator data type
  typename ElementCompute_ = ElementOutput_,      ///< Data type used to compute linear combination
  ScaleType::Kind Scale = ScaleType::Default,     ///< Control Alpha and Beta scaling
  FloatRoundStyle Round = FloatRoundStyle::round_to_nearest,
  bool IsHeavy = false
>
class LinearCombinationGeneric {
public:

  using ElementOutput = ElementOutput_;
  using ElementAccumulator = ElementAccumulator_;
  using ElementCompute = ElementCompute_;

  // Hint to the epilogue about the relative cost of the activation function
  // (may influence scheduling decisions in the epilogue implementation).
  static bool const kIsHeavy = IsHeavy;
  static int const kCount = Count;
  static const ScaleType::Kind kScale = Scale;

  using FragmentOutput = Array<ElementOutput, kCount>;
  using FragmentAccumulator = Array<ElementAccumulator, kCount>;
  using FragmentSource = Array<ElementOutput, kCount>;
  using FragmentCompute = Array<ElementCompute, kCount>;

  static FloatRoundStyle const kRound = Round;

  /// Host-constructable parameters structure.
  /// Inherits alpha/beta (by value or pointer) from
  /// LinearCombinationGenericParams, plus the activation functor's own
  /// Arguments type (an empty struct when the activation declares none),
  /// so a single Params object can be passed straight to the activation.
  struct Params
      : LinearCombinationGenericParams<ElementCompute>,
        GenericActivationTraits<ActivationFunctor<ElementCompute>>::Arguments {
    using LinearCombinationGenericParams<ElementCompute>::LinearCombinationGenericParams;
  };

private:

  //
  // Data members
  //

  Params params_;
  // Set by set_k_partition(): non-final split-K partitions skip the
  // activation so it is applied exactly once, on the final partition.
  bool skip_elementwise_;

public:

  /// Constructs the function object, possibly loading from pointers in host memory
  CUTLASS_HOST_DEVICE
  LinearCombinationGeneric(Params const &params) {
    params_ = params;
    // Resolve pointer-provided scalars once so the per-fragment call path
    // only reads the by-value members.
    params_.alpha = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
    params_.beta = (params.beta_ptr ? *params.beta_ptr : params.beta);
    skip_elementwise_ = false;
  }

  /// Returns true if source is needed
  CUTLASS_HOST_DEVICE
  bool is_source_needed() const {
    if (Scale == ScaleType::NoBetaScaling) return true;

    if (Scale == ScaleType::OnlyAlphaScaling) return false;

    if (Scale == ScaleType::Nothing) return false;

    // Default scaling: the source tensor only matters when beta is non-zero.
    return params_.beta != ElementCompute(0);
  }

  /// Functionally required for serial reduction in the epilogue
  CUTLASS_HOST_DEVICE
  void set_k_partition(int k_partition, int k_partition_count) {
    if (k_partition) {
      // Partitions after the first accumulate onto the previous partial
      // result, so the "source" term is added unscaled.
      params_.beta = ElementCompute(1);
    }

    if (k_partition != k_partition_count - 1) {
      // Only the final partition applies the activation function.
      skip_elementwise_ = true;
    }
  }

  /// Computes linear scaling: D = activation(alpha * accumulator + beta * source)
  CUTLASS_HOST_DEVICE
  FragmentOutput operator()(
    FragmentAccumulator const &accumulator,
    FragmentOutput const &source) const {

    // Convert source to internal compute numeric type
    NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round> source_converter;
    NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;

    FragmentCompute converted_source = source_converter(source);
    FragmentCompute converted_accumulator = accumulator_converter(accumulator);

    // Perform binary operations

    FragmentCompute intermediate;

    // Despite its name, mul_add_source is a plain element-wise multiply
    // (it computes beta * source only).
    multiplies<FragmentCompute> mul_add_source;
    multiply_add<FragmentCompute> mul_add_accumulator;
    ActivationFunctor<FragmentCompute> activation;

    if (Scale == ScaleType::NoBetaScaling) {
      intermediate = converted_source;
      intermediate = mul_add_accumulator(params_.alpha, converted_accumulator, intermediate);    // D = alpha * Accum + X
    } else if (Scale == ScaleType::Nothing) {
      intermediate = converted_accumulator;
    } else {
      intermediate = mul_add_source(params_.beta, converted_source);                             // X = beta * C + uniform
      intermediate = mul_add_accumulator(params_.alpha, converted_accumulator, intermediate);    // D = alpha * Accum + X
    }

    // If the activation declares an Arguments type, Params inherits from it,
    // so params_ can be passed directly as the activation's extra argument.
    if constexpr (GenericActivationTraits<ActivationFunctor<ElementCompute>>::IsArgumentsNeeded) {
      intermediate = skip_elementwise_ ? intermediate : activation(intermediate, params_);
    } else {
      intermediate = skip_elementwise_ ? intermediate : activation(intermediate);
    }

    // Convert to destination numeric type
    NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;

    return destination_converter(intermediate);
  }

  /// Computes linear scaling (no source): D = activation(alpha * accumulator)
  CUTLASS_HOST_DEVICE
  FragmentOutput operator()(
    FragmentAccumulator const &accumulator) const {

    // Convert accumulator to internal compute numeric type
    NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;

    FragmentCompute converted_accumulator = accumulator_converter(accumulator);

    // Perform binary operations

    FragmentCompute intermediate;

    // Despite its name, mul_add_accumulator is a plain element-wise multiply
    // in this overload (there is no addend).
    multiplies<FragmentCompute> mul_add_accumulator;
    ActivationFunctor<FragmentCompute> activation;

    if (Scale == ScaleType::Nothing) {
      intermediate = converted_accumulator;
    } else {
      intermediate = mul_add_accumulator(params_.alpha, converted_accumulator);    // D = alpha * Accum
    }

    // NOTE(review): this overload instantiates the trait with
    // ActivationFunctor<FragmentCompute> while the two-argument overload uses
    // ActivationFunctor<ElementCompute>; both presumably yield the same
    // Arguments detection for the activations used here — confirm upstream.
    if constexpr (GenericActivationTraits<ActivationFunctor<FragmentCompute>>::IsArgumentsNeeded) {
      intermediate = skip_elementwise_ ? intermediate : activation(intermediate, params_);
    } else {
      intermediate = skip_elementwise_ ? intermediate : activation(intermediate);
    }

    // Convert to destination numeric type
    NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;

    return destination_converter(intermediate);
  }
};
|
| 260 |
+
|
| 261 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 262 |
+
|
| 263 |
+
} // namespace thread
|
| 264 |
+
} // namespace epilogue
|
| 265 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_generic_with_scaling.h
ADDED
|
@@ -0,0 +1,325 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
/*! \file
|
| 33 |
+
\brief Functor performing linear combination operations with a generic element-wise activation
|
| 34 |
+
function. Scaling factors are applied to operands A, B, and C. The pre-activation auxiliary
|
| 35 |
+
output is also returned.
|
| 36 |
+
*/
|
| 37 |
+
|
| 38 |
+
#pragma once
|
| 39 |
+
|
| 40 |
+
#include "cutlass/cutlass.h"
|
| 41 |
+
#include "cutlass/numeric_types.h"
|
| 42 |
+
#include "cutlass/array.h"
|
| 43 |
+
#include "cutlass/functional.h"
|
| 44 |
+
#include "cutlass/numeric_conversion.h"
|
| 45 |
+
#include "cutlass/epilogue/thread/scale_type.h"
|
| 46 |
+
#include "cutlass/epilogue/thread/linear_combination_generic.h"
|
| 47 |
+
|
| 48 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 49 |
+
|
| 50 |
+
namespace cutlass {
|
| 51 |
+
namespace epilogue {
|
| 52 |
+
namespace thread {
|
| 53 |
+
|
| 54 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 55 |
+
|
| 56 |
+
/// Applies a linear combination operator to an array of elements.
|
| 57 |
+
///
|
| 58 |
+
/// Aux = ((alpha * scale_a * scale_b) * accumulator) + ((beta * scale_c) * source) + bias
|
| 59 |
+
/// D = activation(Aux)
|
| 60 |
+
///
|
| 61 |
+
template <
|
| 62 |
+
template<typename T> class ActivationFunctor,
|
| 63 |
+
typename ElementOutput_, ///< Data type used to load and store tensors
|
| 64 |
+
typename ElementAuxOutput_, ///< Data type used to store auxiliary output
|
| 65 |
+
int Count, ///< Number of elements computed per operation
|
| 66 |
+
///< Usually it is 128/sizeof_bits<ElementOutput_>,
|
| 67 |
+
///< but we use 64 or 32 sometimes when there are not enough data to store
|
| 68 |
+
typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type
|
| 69 |
+
typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination
|
| 70 |
+
ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling
|
| 71 |
+
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest,
|
| 72 |
+
bool IsHeavy = false
|
| 73 |
+
>
|
| 74 |
+
class LinearCombinationGenericWithScalingAndAbsMax {
|
| 75 |
+
public:
|
| 76 |
+
|
| 77 |
+
using ElementOutput = ElementOutput_;
|
| 78 |
+
using ElementAuxOutput = ElementAuxOutput_;
|
| 79 |
+
using ElementAccumulator = ElementAccumulator_;
|
| 80 |
+
using ElementCompute = ElementCompute_;
|
| 81 |
+
using ElementScalingFactor = ElementAccumulator_;
|
| 82 |
+
|
| 83 |
+
/// Data type used for absolute maximum value
|
| 84 |
+
using ElementAbsmax = float;
|
| 85 |
+
|
| 86 |
+
static bool const kIsScalingAndAmaxAuxOutputNeeded = (platform::is_same<ElementAuxOutput, cutlass::float_e4m3_t>::value ||
|
| 87 |
+
platform::is_same<ElementAuxOutput, cutlass::float_e5m2_t>::value);
|
| 88 |
+
static bool const kIsScalingAndAmaxOutputNeeded = (platform::is_same<ElementOutput, cutlass::float_e4m3_t>::value ||
|
| 89 |
+
platform::is_same<ElementOutput, cutlass::float_e5m2_t>::value);
|
| 90 |
+
|
| 91 |
+
static bool const kIsHeavy = IsHeavy;
|
| 92 |
+
static int const kCount = Count;
|
| 93 |
+
static const ScaleType::Kind kScale = Scale;
|
| 94 |
+
|
| 95 |
+
using FragmentOutput = Array<ElementOutput, kCount>;
|
| 96 |
+
using FragmentAuxOutput = Array<ElementAuxOutput, kCount>;
|
| 97 |
+
using FragmentAccumulator = Array<ElementAccumulator, kCount>;
|
| 98 |
+
using FragmentCompute = Array<ElementCompute, kCount>;
|
| 99 |
+
|
| 100 |
+
static FloatRoundStyle const kRound = Round;
|
| 101 |
+
|
| 102 |
+
/// Host-constructable parameters structure
|
| 103 |
+
struct Params {
|
| 104 |
+
struct ActivationParams
|
| 105 |
+
: LinearCombinationGenericParams<ElementCompute>,
|
| 106 |
+
GenericActivationTraits<ActivationFunctor<ElementCompute>>::Arguments {
|
| 107 |
+
using LinearCombinationGenericParams<ElementCompute>::LinearCombinationGenericParams;
|
| 108 |
+
};
|
| 109 |
+
|
| 110 |
+
ActivationParams activation;
|
| 111 |
+
ElementScalingFactor const* scale_a_ptr = nullptr; ///< pointer to a scalar - if not null, loads it from memory
|
| 112 |
+
ElementScalingFactor const* scale_b_ptr = nullptr; ///< pointer to b scalar - if not null, loads it from memory
|
| 113 |
+
ElementScalingFactor const* scale_c_ptr = nullptr; ///< pointer to c scalar - if not null, loads it from memory
|
| 114 |
+
ElementScalingFactor const* scale_d_ptr = nullptr; ///< pointer to d scalar - if not null, loads it from memory
|
| 115 |
+
ElementScalingFactor const* scale_aux_ptr = nullptr; ///< pointer to aux scalar - if not null, loads it from memory
|
| 116 |
+
|
| 117 |
+
ElementAbsmax * abs_max_aux_ptr = nullptr; ///< pointer to location to store amax of Aux
|
| 118 |
+
ElementAbsmax * abs_max_D_ptr = nullptr; ///< pointer to location to store amax of D
|
| 119 |
+
|
| 120 |
+
CUTLASS_HOST_DEVICE
|
| 121 |
+
Params() :
|
| 122 |
+
scale_a_ptr(nullptr),
|
| 123 |
+
scale_b_ptr(nullptr),
|
| 124 |
+
scale_c_ptr(nullptr),
|
| 125 |
+
scale_d_ptr(nullptr),
|
| 126 |
+
scale_aux_ptr(nullptr),
|
| 127 |
+
abs_max_aux_ptr(nullptr),
|
| 128 |
+
abs_max_D_ptr(nullptr) {}
|
| 129 |
+
|
| 130 |
+
CUTLASS_HOST_DEVICE
|
| 131 |
+
Params(ActivationParams activation_params,
|
| 132 |
+
ElementScalingFactor const* scale_a_ptr,
|
| 133 |
+
ElementScalingFactor const* scale_b_ptr,
|
| 134 |
+
ElementScalingFactor const* scale_c_ptr,
|
| 135 |
+
ElementScalingFactor const* scale_d_ptr,
|
| 136 |
+
ElementScalingFactor const* scale_aux_ptr,
|
| 137 |
+
ElementAbsmax * abs_max_aux_ptr,
|
| 138 |
+
ElementAbsmax * abs_max_D_ptr) :
|
| 139 |
+
activation(activation_params),
|
| 140 |
+
scale_a_ptr(scale_a_ptr),
|
| 141 |
+
scale_b_ptr(scale_b_ptr),
|
| 142 |
+
scale_c_ptr(scale_c_ptr),
|
| 143 |
+
scale_d_ptr(scale_d_ptr),
|
| 144 |
+
scale_aux_ptr(scale_aux_ptr),
|
| 145 |
+
abs_max_aux_ptr(abs_max_aux_ptr),
|
| 146 |
+
abs_max_D_ptr(abs_max_D_ptr) {}
|
| 147 |
+
};
|
| 148 |
+
|
| 149 |
+
private:
|
| 150 |
+
|
| 151 |
+
//
|
| 152 |
+
// Data members
|
| 153 |
+
//
|
| 154 |
+
|
| 155 |
+
Params params_;
|
| 156 |
+
bool skip_elementwise_;
|
| 157 |
+
|
| 158 |
+
// Scaling factors for output and auxiliary output
|
| 159 |
+
ElementCompute scale_d_;
|
| 160 |
+
ElementCompute scale_aux_;
|
| 161 |
+
|
| 162 |
+
public:
|
| 163 |
+
|
| 164 |
+
/// Constructs the function object, possibly loading from pointers in host memory
|
| 165 |
+
CUTLASS_HOST_DEVICE
|
| 166 |
+
LinearCombinationGenericWithScalingAndAbsMax(Params const ¶ms) :
|
| 167 |
+
params_(params),
|
| 168 |
+
skip_elementwise_(false),
|
| 169 |
+
scale_d_(ElementCompute(params.scale_d_ptr ? *(params.scale_d_ptr) : ElementScalingFactor(1))),
|
| 170 |
+
scale_aux_(ElementCompute(params.scale_aux_ptr ? *(params.scale_aux_ptr) : ElementScalingFactor(1)))
|
| 171 |
+
{
|
| 172 |
+
params_.activation.alpha = (params.activation.alpha_ptr ? *params.activation.alpha_ptr : params.activation.alpha);
|
| 173 |
+
params_.activation.beta = (params.activation.beta_ptr ? *params.activation.beta_ptr : params.activation.beta);
|
| 174 |
+
auto scale_a =
|
| 175 |
+
ElementCompute(params.scale_a_ptr ? *(params.scale_a_ptr) : ElementScalingFactor(1));
|
| 176 |
+
auto scale_b =
|
| 177 |
+
ElementCompute(params.scale_b_ptr ? *(params.scale_b_ptr) : ElementScalingFactor(1));
|
| 178 |
+
auto scale_c =
|
| 179 |
+
ElementCompute(params.scale_c_ptr ? *(params.scale_c_ptr) : ElementScalingFactor(1));
|
| 180 |
+
|
| 181 |
+
multiplies<ElementCompute> multiply;
|
| 182 |
+
params_.activation.alpha = multiply(params.activation.alpha, multiply(scale_a, scale_b));
|
| 183 |
+
params_.activation.beta = multiply(params.activation.beta, scale_c);
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
/// Returns true if source is needed
|
| 187 |
+
CUTLASS_HOST_DEVICE
|
| 188 |
+
bool is_source_needed() const {
|
| 189 |
+
if (Scale == ScaleType::NoBetaScaling) return true;
|
| 190 |
+
|
| 191 |
+
if (Scale == ScaleType::OnlyAlphaScaling) return false;
|
| 192 |
+
|
| 193 |
+
if (Scale == ScaleType::Nothing) return false;
|
| 194 |
+
|
| 195 |
+
return params_.activation.beta != ElementCompute(0);
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
/// Functionally required for serial reduction in the epilogue
|
| 199 |
+
CUTLASS_HOST_DEVICE
|
| 200 |
+
void set_k_partition(int k_partition, int k_partition_count) {
|
| 201 |
+
if (k_partition) {
|
| 202 |
+
params_.activation.beta = ElementCompute(1);
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
// Only the final partition should perform the activation function
|
| 206 |
+
// and scale the output and auxiliary output values.
|
| 207 |
+
if (k_partition != k_partition_count - 1) {
|
| 208 |
+
skip_elementwise_ = true;
|
| 209 |
+
scale_d_ = ElementCompute(1.);
|
| 210 |
+
scale_aux_ = ElementCompute(1.);
|
| 211 |
+
}
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
/// Computes linear scaling:
|
| 215 |
+
/// Aux = (alpha * scale_a * scale_b * accumulator) + (beta * scale_c * source) + bias
|
| 216 |
+
/// D = activation(Aux)
|
| 217 |
+
CUTLASS_HOST_DEVICE
|
| 218 |
+
void operator()(
|
| 219 |
+
FragmentCompute& output,
|
| 220 |
+
FragmentCompute& aux_output,
|
| 221 |
+
FragmentAccumulator const &accumulator,
|
| 222 |
+
FragmentCompute const& bias,
|
| 223 |
+
FragmentOutput const &source) {
|
| 224 |
+
|
| 225 |
+
// Convert source to interal compute numeric type
|
| 226 |
+
NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round> source_converter;
|
| 227 |
+
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
|
| 228 |
+
|
| 229 |
+
FragmentCompute converted_source = source_converter(source);
|
| 230 |
+
FragmentCompute converted_accumulator = accumulator_converter(accumulator);
|
| 231 |
+
|
| 232 |
+
// Perform binary operations
|
| 233 |
+
|
| 234 |
+
FragmentCompute intermediate;
|
| 235 |
+
|
| 236 |
+
multiplies<FragmentCompute> multiply;
|
| 237 |
+
plus<FragmentCompute> add;
|
| 238 |
+
multiply_add<FragmentCompute> mul_add_accumulator;
|
| 239 |
+
ActivationFunctor<FragmentCompute> activation;
|
| 240 |
+
|
| 241 |
+
if (Scale == ScaleType::NoBetaScaling) {
|
| 242 |
+
intermediate = converted_source;
|
| 243 |
+
intermediate = mul_add_accumulator(params_.activation.alpha, converted_accumulator, intermediate);
|
| 244 |
+
} else if (Scale == ScaleType::Nothing) {
|
| 245 |
+
intermediate = converted_accumulator;
|
| 246 |
+
} else {
|
| 247 |
+
intermediate = multiply(params_.activation.beta, converted_source);
|
| 248 |
+
intermediate = mul_add_accumulator(params_.activation.alpha, converted_accumulator, intermediate);
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
intermediate = add(intermediate, bias);
|
| 252 |
+
|
| 253 |
+
aux_output = intermediate;
|
| 254 |
+
if constexpr (GenericActivationTraits<ActivationFunctor<ElementCompute>>::IsArgumentsNeeded) {
|
| 255 |
+
output = skip_elementwise_ ? intermediate : activation(intermediate, params_.activation);
|
| 256 |
+
} else {
|
| 257 |
+
output = skip_elementwise_ ? intermediate : activation(intermediate);
|
| 258 |
+
}
|
| 259 |
+
}
|
| 260 |
+
|
| 261 |
+
/// Computes linear scaling:
|
| 262 |
+
/// Aux = (alpha * scale_a * scale_b * accumulator) + bias
|
| 263 |
+
/// D = activation(Aux)
|
| 264 |
+
CUTLASS_DEVICE
|
| 265 |
+
void operator()(
|
| 266 |
+
FragmentCompute& output,
|
| 267 |
+
FragmentCompute& aux_output,
|
| 268 |
+
FragmentAccumulator const &accumulator,
|
| 269 |
+
FragmentCompute const& bias) {
|
| 270 |
+
|
| 271 |
+
// Convert source to interal compute numeric type
|
| 272 |
+
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
|
| 273 |
+
|
| 274 |
+
FragmentCompute converted_accumulator = accumulator_converter(accumulator);
|
| 275 |
+
|
| 276 |
+
// Perform binary operations
|
| 277 |
+
|
| 278 |
+
FragmentCompute intermediate;
|
| 279 |
+
|
| 280 |
+
multiplies<FragmentCompute> multiply;
|
| 281 |
+
plus<FragmentCompute> add;
|
| 282 |
+
ActivationFunctor<FragmentCompute> activation;
|
| 283 |
+
|
| 284 |
+
if (Scale == ScaleType::Nothing) {
|
| 285 |
+
intermediate = converted_accumulator;
|
| 286 |
+
} else {
|
| 287 |
+
intermediate = multiply(params_.activation.alpha, converted_accumulator);
|
| 288 |
+
}
|
| 289 |
+
|
| 290 |
+
intermediate = add(intermediate, bias);
|
| 291 |
+
|
| 292 |
+
aux_output = intermediate;
|
| 293 |
+
if constexpr (GenericActivationTraits<ActivationFunctor<FragmentCompute>>::IsArgumentsNeeded) {
|
| 294 |
+
output = skip_elementwise_ ? intermediate : activation(intermediate, params_.activation);
|
| 295 |
+
} else {
|
| 296 |
+
output = skip_elementwise_ ? intermediate : activation(intermediate);
|
| 297 |
+
}
|
| 298 |
+
}
|
| 299 |
+
|
| 300 |
+
CUTLASS_HOST_DEVICE
|
| 301 |
+
ElementAbsmax* get_ptr_output_abs_max() const {
|
| 302 |
+
return params_.abs_max_D_ptr;
|
| 303 |
+
}
|
| 304 |
+
|
| 305 |
+
CUTLASS_HOST_DEVICE
|
| 306 |
+
ElementAbsmax* get_ptr_aux_output_abs_max() const {
|
| 307 |
+
return params_.abs_max_aux_ptr;
|
| 308 |
+
}
|
| 309 |
+
|
| 310 |
+
CUTLASS_HOST_DEVICE
|
| 311 |
+
ElementCompute get_scale_d() const {
|
| 312 |
+
return scale_d_;
|
| 313 |
+
}
|
| 314 |
+
|
| 315 |
+
CUTLASS_HOST_DEVICE
|
| 316 |
+
ElementCompute get_scale_aux() const {
|
| 317 |
+
return scale_aux_;
|
| 318 |
+
}
|
| 319 |
+
};
|
| 320 |
+
|
| 321 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 322 |
+
|
| 323 |
+
} // namespace thread
|
| 324 |
+
} // namespace epilogue
|
| 325 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_hardswish.h
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Functor performing linear combination with HardSwish operations used by epilogues.
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "cutlass/cutlass.h"
|
| 38 |
+
#include "cutlass/epilogue/thread/activation.h"
|
| 39 |
+
#include "cutlass/epilogue/thread/linear_combination_generic.h"
|
| 40 |
+
|
| 41 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 42 |
+
|
| 43 |
+
namespace cutlass {
|
| 44 |
+
namespace epilogue {
|
| 45 |
+
namespace thread {
|
| 46 |
+
|
| 47 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 48 |
+
|
| 49 |
+
/// Applies a linear combination operator followed by the HardSwish activation to an array of elements.
///
/// D = hardswish(alpha * accumulator + beta * source + uniform)
///
/// Convenience alias: LinearCombinationGeneric specialized with the HardSwish
/// functor from cutlass/epilogue/thread/activation.h.
template <
  typename ElementOutput_,                        ///< Data type used to load and store tensors
  int Count,                                      ///< Number of elements computed per operation
                                                  ///< Usually it is 128/sizeof_bits<ElementOutput_>,
                                                  ///< but we use 64 or 32 sometimes when there are not enough data to store
  typename ElementAccumulator_ = ElementOutput_,  ///< Accumulator data type
  typename ElementCompute_ = ElementOutput_,      ///< Data type used to compute linear combination
  ScaleType::Kind Scale = ScaleType::Default,     ///< Control Alpha and Beta scaling
  FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
>
using LinearCombinationHardSwish = LinearCombinationGeneric<HardSwish, ElementOutput_, Count, ElementAccumulator_,
                                                            ElementCompute_, Scale, Round>;
|
| 65 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 66 |
+
|
| 67 |
+
} // namespace thread
|
| 68 |
+
} // namespace epilogue
|
| 69 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_leaky_relu.h
ADDED
|
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
#pragma once
|
| 33 |
+
|
| 34 |
+
#include "cutlass/cutlass.h"
|
| 35 |
+
#include "cutlass/numeric_types.h"
|
| 36 |
+
#include "cutlass/array.h"
|
| 37 |
+
#include "cutlass/functional.h"
|
| 38 |
+
#include "cutlass/numeric_conversion.h"
|
| 39 |
+
#include "cutlass/epilogue/thread/activation.h"
|
| 40 |
+
#include "cutlass/epilogue/thread/scale_type.h"
|
| 41 |
+
|
| 42 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 43 |
+
|
| 44 |
+
namespace cutlass {
|
| 45 |
+
namespace epilogue {
|
| 46 |
+
namespace thread {
|
| 47 |
+
|
| 48 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 49 |
+
|
| 50 |
+
/// Applies a linear combination operator to an array of elements.
|
| 51 |
+
///
|
| 52 |
+
/// D = alpha * accumulator + beta * source + uniform
|
| 53 |
+
///
|
| 54 |
+
template <
|
| 55 |
+
typename ElementOutput_, ///< Data type used to load and store tensors
|
| 56 |
+
int Count, ///< Number of elements computed per operation
|
| 57 |
+
typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type
|
| 58 |
+
typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination
|
| 59 |
+
ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling
|
| 60 |
+
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
|
| 61 |
+
>
|
| 62 |
+
class LinearCombinationLeakyRelu {
|
| 63 |
+
public:
|
| 64 |
+
|
| 65 |
+
using ElementOutput = ElementOutput_;
|
| 66 |
+
using ElementAccumulator = ElementAccumulator_;
|
| 67 |
+
using ElementCompute = ElementCompute_;
|
| 68 |
+
|
| 69 |
+
static int const kCount = Count;
|
| 70 |
+
static const ScaleType::Kind kScale = Scale;
|
| 71 |
+
|
| 72 |
+
using FragmentOutput = Array<ElementOutput, kCount>;
|
| 73 |
+
using FragmentAccumulator = Array<ElementAccumulator, kCount>;
|
| 74 |
+
using ComputeFragment = Array<ElementCompute, kCount>;
|
| 75 |
+
using FragmentSource = Array<ElementOutput, kCount>;
|
| 76 |
+
|
| 77 |
+
static FloatRoundStyle const kRound = Round;
|
| 78 |
+
|
| 79 |
+
/// Host-constructable parameters structure
|
| 80 |
+
struct Params {
|
| 81 |
+
|
| 82 |
+
ElementCompute alpha; ///< scales accumulators
|
| 83 |
+
ElementCompute beta_bias; ///< scales bias tensor
|
| 84 |
+
ElementCompute leaky_alpha; ///< leaky_alpha
|
| 85 |
+
//
|
| 86 |
+
// Methods
|
| 87 |
+
//
|
| 88 |
+
|
| 89 |
+
CUTLASS_HOST_DEVICE
|
| 90 |
+
Params():
|
| 91 |
+
alpha(ElementCompute(1)),
|
| 92 |
+
beta_bias(ElementCompute(0)),
|
| 93 |
+
leaky_alpha(ElementCompute(1))
|
| 94 |
+
{ }
|
| 95 |
+
|
| 96 |
+
CUTLASS_HOST_DEVICE
|
| 97 |
+
Params(
|
| 98 |
+
ElementCompute alpha,
|
| 99 |
+
ElementCompute beta_bias,
|
| 100 |
+
ElementCompute leaky_alpha = ElementCompute(1)
|
| 101 |
+
): alpha(alpha), beta_bias(beta_bias), leaky_alpha(leaky_alpha) {
|
| 102 |
+
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
};
|
| 106 |
+
|
| 107 |
+
private:
|
| 108 |
+
|
| 109 |
+
//
|
| 110 |
+
// Data members
|
| 111 |
+
//
|
| 112 |
+
|
| 113 |
+
ElementCompute alpha_;
|
| 114 |
+
ElementCompute beta_bias_;
|
| 115 |
+
ElementCompute leaky_alpha_recip_;
|
| 116 |
+
|
| 117 |
+
public:
|
| 118 |
+
|
| 119 |
+
/// Constructs the function object, possibly loading from pointers in host memory
|
| 120 |
+
CUTLASS_HOST_DEVICE
|
| 121 |
+
LinearCombinationLeakyRelu(Params const ¶ms) {
|
| 122 |
+
alpha_ = (params.alpha);
|
| 123 |
+
beta_bias_ = (params.beta_bias);
|
| 124 |
+
leaky_alpha_recip_ = (ElementCompute(params.leaky_alpha));
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
/// Returns true if source is needed
|
| 128 |
+
CUTLASS_HOST_DEVICE
|
| 129 |
+
bool is_source_needed() const {
|
| 130 |
+
if (Scale == ScaleType::NoBetaScaling) return true;
|
| 131 |
+
|
| 132 |
+
if (Scale == ScaleType::OnlyAlphaScaling) return false;
|
| 133 |
+
|
| 134 |
+
if (Scale == ScaleType::Nothing) return false;
|
| 135 |
+
|
| 136 |
+
return beta_bias_ != ElementCompute(0);
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
/// Functionally required for serial reduction in the epilogue
|
| 140 |
+
CUTLASS_HOST_DEVICE
|
| 141 |
+
void set_k_partition(int k_partition) {
|
| 142 |
+
if (k_partition) {
|
| 143 |
+
beta_bias_ = ElementCompute(1);
|
| 144 |
+
}
|
| 145 |
+
}
|
| 146 |
+
CUTLASS_HOST_DEVICE
|
| 147 |
+
void set_k_partition(int k_partition, int k_partition_count) {
|
| 148 |
+
if (k_partition) {
|
| 149 |
+
beta_bias_ = ElementCompute(1);
|
| 150 |
+
}
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
/// Computes linear scaling: D = alpha * accumulator + beta * source
|
| 154 |
+
CUTLASS_HOST_DEVICE
|
| 155 |
+
FragmentOutput operator()(
|
| 156 |
+
FragmentAccumulator const &accumulator,
|
| 157 |
+
FragmentOutput const &source) const {
|
| 158 |
+
|
| 159 |
+
// Convert source to interal compute numeric type
|
| 160 |
+
NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round> source_converter;
|
| 161 |
+
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
|
| 162 |
+
|
| 163 |
+
ComputeFragment converted_source = source_converter(source);
|
| 164 |
+
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
|
| 165 |
+
|
| 166 |
+
// Perform binary operations
|
| 167 |
+
ComputeFragment intermediate;
|
| 168 |
+
|
| 169 |
+
multiplies<ComputeFragment> mul_add_source;
|
| 170 |
+
multiply_add<ComputeFragment> mul_add_accumulator;
|
| 171 |
+
|
| 172 |
+
LeakyReLU<ComputeFragment> leakyrelu;
|
| 173 |
+
|
| 174 |
+
if (Scale == ScaleType::NoBetaScaling) {
|
| 175 |
+
intermediate = converted_source;
|
| 176 |
+
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
|
| 177 |
+
} else if (Scale == ScaleType::Nothing) {
|
| 178 |
+
intermediate = converted_accumulator;
|
| 179 |
+
} else {
|
| 180 |
+
intermediate = mul_add_source(beta_bias_, converted_source); // X = beta * C + uniform
|
| 181 |
+
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
|
| 182 |
+
}
|
| 183 |
+
// Compute threshold optionally
|
| 184 |
+
intermediate = leakyrelu(intermediate, leaky_alpha_recip_);
|
| 185 |
+
|
| 186 |
+
// Convert to destination numeric type
|
| 187 |
+
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
|
| 188 |
+
|
| 189 |
+
return destination_converter(intermediate);
|
| 190 |
+
}
|
| 191 |
+
|
| 192 |
+
/// Computes linear scaling: D = alpha * accumulator
|
| 193 |
+
CUTLASS_HOST_DEVICE
|
| 194 |
+
FragmentOutput operator()(
|
| 195 |
+
FragmentAccumulator const &accumulator) const {
|
| 196 |
+
|
| 197 |
+
// Convert source to interal compute numeric type
|
| 198 |
+
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
|
| 199 |
+
|
| 200 |
+
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
|
| 201 |
+
|
| 202 |
+
// Perform binary operations
|
| 203 |
+
ComputeFragment intermediate;
|
| 204 |
+
|
| 205 |
+
multiplies<ComputeFragment> mul_accumulator;
|
| 206 |
+
LeakyReLU<ComputeFragment> leakyrelu;
|
| 207 |
+
//printf("in doing with bias");
|
| 208 |
+
if (Scale == ScaleType::Nothing) {
|
| 209 |
+
intermediate = converted_accumulator;
|
| 210 |
+
} else {
|
| 211 |
+
intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
// Compute threshold optionally
|
| 215 |
+
intermediate = leakyrelu(intermediate, leaky_alpha_recip_);
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
// Convert to destination numeric type
|
| 219 |
+
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
|
| 220 |
+
|
| 221 |
+
return destination_converter(intermediate);
|
| 222 |
+
}
|
| 223 |
+
};
|
| 224 |
+
|
| 225 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 226 |
+
|
| 227 |
+
} // namespace thread
|
| 228 |
+
} // namespace epilogue
|
| 229 |
+
} // namespace cutlass
|
| 230 |
+
|
| 231 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_params.h
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 38 |
+
|
| 39 |
+
namespace cutlass {
|
| 40 |
+
namespace epilogue {
|
| 41 |
+
namespace thread {
|
| 42 |
+
|
| 43 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 44 |
+
|
| 45 |
+
struct LinearCombinationParams {
|
| 46 |
+
uint64_t alpha_data[2];
|
| 47 |
+
uint64_t beta_data[2];
|
| 48 |
+
|
| 49 |
+
CUTLASS_HOST_DEVICE
|
| 50 |
+
LinearCombinationParams()
|
| 51 |
+
: alpha_data {0lu, 0lu}, beta_data {0lu, 0lu}
|
| 52 |
+
{ }
|
| 53 |
+
|
| 54 |
+
template <typename ElementCompute>
|
| 55 |
+
CUTLASS_HOST_DEVICE
|
| 56 |
+
LinearCombinationParams(ElementCompute alpha, ElementCompute beta)
|
| 57 |
+
: alpha_data {0lu, 0lu}, beta_data {0lu, 0lu}
|
| 58 |
+
{
|
| 59 |
+
#if defined(__CUDA_ARCH__)
|
| 60 |
+
reinterpret_cast<ElementCompute&>(alpha_data) = alpha;
|
| 61 |
+
reinterpret_cast<ElementCompute&>(beta_data) = beta;
|
| 62 |
+
#else
|
| 63 |
+
memcpy( alpha_data, &alpha, sizeof(ElementCompute) );
|
| 64 |
+
memcpy( beta_data, &beta, sizeof(ElementCompute) );
|
| 65 |
+
#endif
|
| 66 |
+
}
|
| 67 |
+
};
|
| 68 |
+
|
| 69 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 70 |
+
|
| 71 |
+
} // namespace thread
|
| 72 |
+
} // namespace epilogue
|
| 73 |
+
} // namespace cutlass
|
| 74 |
+
|
| 75 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_planar_complex.h
ADDED
|
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Functor performing linear combination operations on planar-complex arrays
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "cutlass/cutlass.h"
|
| 38 |
+
#include "cutlass/numeric_types.h"
|
| 39 |
+
#include "cutlass/complex.h"
|
| 40 |
+
#include "cutlass/array_planar_complex.h"
|
| 41 |
+
#include "cutlass/functional.h"
|
| 42 |
+
#include "cutlass/numeric_conversion.h"
|
| 43 |
+
#include "cutlass/epilogue/thread/scale_type.h"
|
| 44 |
+
|
| 45 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 46 |
+
|
| 47 |
+
namespace cutlass {
|
| 48 |
+
namespace epilogue {
|
| 49 |
+
namespace thread {
|
| 50 |
+
|
| 51 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 52 |
+
|
| 53 |
+
/// Applies a linear combination operator to arrays of planar-complex elements.
|
| 54 |
+
///
|
| 55 |
+
/// D = alpha * accumulator + beta * source + uniform
|
| 56 |
+
///
|
| 57 |
+
/// Note, as with most CUTLASS components for planar complex, the template arguments describe
|
| 58 |
+
/// the underlying real data type.
|
| 59 |
+
template <
|
| 60 |
+
typename ElementOutput_, ///< Data type used to load and store tensors
|
| 61 |
+
int Count, ///< Number of elements computed per operation
|
| 62 |
+
///< Usually it is 128/sizeof_bits<ElementOutput_>,
|
| 63 |
+
///< but we use 64 or 32 sometimes when there are not enough data to store
|
| 64 |
+
typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type
|
| 65 |
+
typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination
|
| 66 |
+
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest,
|
| 67 |
+
ScaleType::Kind Scale = ScaleType::Default ///< Control Alpha and Beta scaling
|
| 68 |
+
>
|
| 69 |
+
class LinearCombinationPlanarComplex {
|
| 70 |
+
public:
|
| 71 |
+
|
| 72 |
+
using ElementOutput = ElementOutput_;
|
| 73 |
+
using ElementAccumulator = ElementAccumulator_;
|
| 74 |
+
using ElementCompute = ElementCompute_;
|
| 75 |
+
using ElementScalar = complex<ElementCompute>;
|
| 76 |
+
|
| 77 |
+
static int const kCount = Count;
|
| 78 |
+
static const ScaleType::Kind kScale = Scale;
|
| 79 |
+
|
| 80 |
+
using FragmentOutput = ArrayPlanarComplex<ElementOutput, kCount>;
|
| 81 |
+
using FragmentAccumulator = ArrayPlanarComplex<ElementAccumulator, kCount>;
|
| 82 |
+
using ComputeFragment = ArrayPlanarComplex<ElementCompute, kCount>;
|
| 83 |
+
|
| 84 |
+
static FloatRoundStyle const kRound = Round;
|
| 85 |
+
|
| 86 |
+
/// Host-constructable parameters structure
|
| 87 |
+
struct Params {
|
| 88 |
+
|
| 89 |
+
ElementScalar alpha{ElementCompute(1)}; ///< scales accumulators
|
| 90 |
+
ElementScalar beta{ElementCompute(0)}; ///< scales source tensor
|
| 91 |
+
ElementScalar const* alpha_ptr{nullptr}; ///< pointer to accumulator scalar - if not null, loads it from memory
|
| 92 |
+
ElementScalar const* beta_ptr{nullptr}; ///< pointer to source scalar - if not null, loads it from memory
|
| 93 |
+
|
| 94 |
+
//
|
| 95 |
+
// Methods
|
| 96 |
+
//
|
| 97 |
+
|
| 98 |
+
Params() = default;
|
| 99 |
+
|
| 100 |
+
CUTLASS_HOST_DEVICE
|
| 101 |
+
Params(
|
| 102 |
+
ElementScalar alpha,
|
| 103 |
+
ElementScalar beta
|
| 104 |
+
): alpha(alpha), beta(beta)
|
| 105 |
+
{}
|
| 106 |
+
|
| 107 |
+
CUTLASS_HOST_DEVICE
|
| 108 |
+
Params(
|
| 109 |
+
ElementScalar const *alpha_ptr,
|
| 110 |
+
ElementScalar const *beta_ptr
|
| 111 |
+
): alpha_ptr(alpha_ptr), beta_ptr(beta_ptr)
|
| 112 |
+
{}
|
| 113 |
+
};
|
| 114 |
+
|
| 115 |
+
private:
|
| 116 |
+
|
| 117 |
+
//
|
| 118 |
+
// Data members
|
| 119 |
+
//
|
| 120 |
+
|
| 121 |
+
ElementScalar alpha_;
|
| 122 |
+
ElementScalar beta_;
|
| 123 |
+
|
| 124 |
+
public:
|
| 125 |
+
|
| 126 |
+
/// Constructs the function object, possibly loading from pointers in host memory
|
| 127 |
+
CUTLASS_HOST_DEVICE
|
| 128 |
+
LinearCombinationPlanarComplex(Params const ¶ms) {
|
| 129 |
+
alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
|
| 130 |
+
beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
/// Returns true if source is needed
|
| 134 |
+
CUTLASS_HOST_DEVICE
|
| 135 |
+
bool is_source_needed() const {
|
| 136 |
+
if (Scale == ScaleType::OnlyAlphaScaling) return false;
|
| 137 |
+
|
| 138 |
+
return beta_.real() != ElementCompute(0) || beta_.imag() != ElementCompute(0);
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
/// Functionally required for serial reduction in the epilogue
|
| 142 |
+
CUTLASS_HOST_DEVICE
|
| 143 |
+
void set_k_partition(int k_partition, int k_partition_count) {
|
| 144 |
+
if (k_partition) {
|
| 145 |
+
beta_ = ElementCompute(1);
|
| 146 |
+
}
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
/// Computes linear scaling: D = alpha * accumulator + beta * source
|
| 150 |
+
CUTLASS_HOST_DEVICE
|
| 151 |
+
FragmentOutput operator()(
|
| 152 |
+
FragmentAccumulator const &accumulator,
|
| 153 |
+
FragmentOutput const &source) const {
|
| 154 |
+
|
| 155 |
+
// Convert source to interal compute numeric type
|
| 156 |
+
NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round> source_converter;
|
| 157 |
+
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
|
| 158 |
+
|
| 159 |
+
ComputeFragment converted_source{
|
| 160 |
+
source_converter(source.real),
|
| 161 |
+
source_converter(source.imag)};
|
| 162 |
+
|
| 163 |
+
ComputeFragment converted_accumulator{
|
| 164 |
+
accumulator_converter(accumulator.real),
|
| 165 |
+
accumulator_converter(accumulator.imag)};
|
| 166 |
+
|
| 167 |
+
multiplies<Array<ElementCompute, kCount> > mul_op;
|
| 168 |
+
multiply_add<Array<ElementCompute, kCount> > mul_add_op;
|
| 169 |
+
|
| 170 |
+
// Perform binary operations
|
| 171 |
+
|
| 172 |
+
// complex multiply: I = beta * C
|
| 173 |
+
ComputeFragment intermediate {
|
| 174 |
+
mul_op(beta_.real(), converted_source.real),
|
| 175 |
+
mul_op(beta_.real(), converted_source.imag)
|
| 176 |
+
};
|
| 177 |
+
|
| 178 |
+
intermediate.real = mul_add_op(-beta_.imag(), converted_source.imag, intermediate.real);
|
| 179 |
+
intermediate.imag = mul_add_op( beta_.imag(), converted_source.real, intermediate.imag);
|
| 180 |
+
|
| 181 |
+
// complex multiply-add: I = alpha * AB + I
|
| 182 |
+
intermediate.real = mul_add_op(alpha_.real(), converted_accumulator.real, intermediate.real);
|
| 183 |
+
intermediate.imag = mul_add_op(alpha_.real(), converted_accumulator.imag, intermediate.imag);
|
| 184 |
+
|
| 185 |
+
intermediate.real = mul_add_op(-alpha_.imag(), converted_accumulator.imag, intermediate.real);
|
| 186 |
+
intermediate.imag = mul_add_op( alpha_.imag(), converted_accumulator.real, intermediate.imag);
|
| 187 |
+
|
| 188 |
+
// Convert to destination numeric type
|
| 189 |
+
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
|
| 190 |
+
|
| 191 |
+
return FragmentOutput{
|
| 192 |
+
destination_converter(intermediate.real),
|
| 193 |
+
destination_converter(intermediate.imag)};
|
| 194 |
+
}
|
| 195 |
+
|
| 196 |
+
/// Computes linear scaling: D = alpha * accumulator + beta * source
|
| 197 |
+
CUTLASS_HOST_DEVICE
|
| 198 |
+
FragmentOutput operator()(
|
| 199 |
+
FragmentAccumulator const &accumulator) const {
|
| 200 |
+
|
| 201 |
+
// Convert source to interal compute numeric type
|
| 202 |
+
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
|
| 203 |
+
|
| 204 |
+
ComputeFragment converted_accumulator{
|
| 205 |
+
accumulator_converter(accumulator.real),
|
| 206 |
+
accumulator_converter(accumulator.imag)};
|
| 207 |
+
|
| 208 |
+
// Perform binary operations
|
| 209 |
+
multiplies<Array<ElementCompute, kCount> > mul_op;
|
| 210 |
+
multiply_add<Array<ElementCompute, kCount> > mul_add_op;
|
| 211 |
+
|
| 212 |
+
// complex multiply-add: I = alpha * AB + I
|
| 213 |
+
ComputeFragment intermediate {
|
| 214 |
+
mul_op(alpha_.real(), converted_accumulator.real),
|
| 215 |
+
mul_op(alpha_.real(), converted_accumulator.imag)
|
| 216 |
+
};
|
| 217 |
+
|
| 218 |
+
intermediate.real = mul_add_op(-alpha_.imag(), converted_accumulator.imag, intermediate.real);
|
| 219 |
+
intermediate.imag = mul_add_op( alpha_.imag(), converted_accumulator.real, intermediate.imag);
|
| 220 |
+
|
| 221 |
+
// Convert to destination numeric type
|
| 222 |
+
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
|
| 223 |
+
|
| 224 |
+
return FragmentOutput{
|
| 225 |
+
destination_converter(intermediate.real),
|
| 226 |
+
destination_converter(intermediate.imag)};
|
| 227 |
+
}
|
| 228 |
+
};
|
| 229 |
+
|
| 230 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 231 |
+
|
| 232 |
+
} // namespace thread
|
| 233 |
+
} // namespace epilogue
|
| 234 |
+
} // namespace cutlass
|
| 235 |
+
|
| 236 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_relu.h
ADDED
|
@@ -0,0 +1,572 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Functor performing linear combination with a maximum operation used by epilogues.
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "cutlass/half.h"
|
| 38 |
+
#include "cutlass/cutlass.h"
|
| 39 |
+
#include "cutlass/numeric_types.h"
|
| 40 |
+
#include "cutlass/array.h"
|
| 41 |
+
#include "cutlass/functional.h"
|
| 42 |
+
#include "cutlass/numeric_conversion.h"
|
| 43 |
+
#include "cutlass/epilogue/thread/activation.h"
|
| 44 |
+
#include "cutlass/epilogue/thread/scale_type.h"
|
| 45 |
+
|
| 46 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 47 |
+
|
| 48 |
+
namespace cutlass {
|
| 49 |
+
namespace epilogue {
|
| 50 |
+
namespace thread {
|
| 51 |
+
|
| 52 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 53 |
+
|
| 54 |
+
namespace detail {
|
| 55 |
+
|
| 56 |
+
/// Single source of truth for whether to unroll for `LinearCombinationClamp()`
|
| 57 |
+
constexpr bool LinearCombinationReluIsHeavy() {
|
| 58 |
+
return false;
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 64 |
+
|
| 65 |
+
/// Applies a linear combination operator to an array of elements.
|
| 66 |
+
///
|
| 67 |
+
/// D = alpha * accumulator + beta * source + uniform
|
| 68 |
+
///
|
| 69 |
+
template <
|
| 70 |
+
typename ElementOutput_, ///< Data type used to load and store tensors
|
| 71 |
+
int Count, ///< Number of elements computed per operation
|
| 72 |
+
///< Usually it is 128/sizeof_bits<ElementOutput_>,
|
| 73 |
+
///< but we use 64 or 32 sometimes when there are not enough data to store
|
| 74 |
+
typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type
|
| 75 |
+
typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination
|
| 76 |
+
ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling
|
| 77 |
+
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
|
| 78 |
+
>
|
| 79 |
+
class LinearCombinationRelu {
|
| 80 |
+
public:
|
| 81 |
+
|
| 82 |
+
using ElementOutput = ElementOutput_;
|
| 83 |
+
using ElementAccumulator = ElementAccumulator_;
|
| 84 |
+
using ElementCompute = ElementCompute_;
|
| 85 |
+
|
| 86 |
+
static int const kCount = Count;
|
| 87 |
+
static const ScaleType::Kind kScale = Scale;
|
| 88 |
+
|
| 89 |
+
using FragmentOutput = Array<ElementOutput, kCount>;
|
| 90 |
+
using FragmentAccumulator = Array<ElementAccumulator, kCount>;
|
| 91 |
+
using FragmentCompute = Array<ElementCompute, kCount>;
|
| 92 |
+
using FragmentScaleBias = Array<ElementCompute, kCount>;
|
| 93 |
+
using FragmentSource = Array<ElementOutput, kCount>;
|
| 94 |
+
|
| 95 |
+
static FloatRoundStyle const kRound = Round;
|
| 96 |
+
|
| 97 |
+
static bool const kIsHeavy = detail::LinearCombinationReluIsHeavy();
|
| 98 |
+
|
| 99 |
+
/// Host-constructable parameters structure
|
| 100 |
+
struct Params {
|
| 101 |
+
|
| 102 |
+
ElementCompute alpha; ///< scales accumulators
|
| 103 |
+
ElementCompute beta; ///< scales source tensor
|
| 104 |
+
ElementCompute threshold; ///< minimum value that is output
|
| 105 |
+
ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
|
| 106 |
+
ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
|
| 107 |
+
//
|
| 108 |
+
// Methods
|
| 109 |
+
//
|
| 110 |
+
|
| 111 |
+
CUTLASS_HOST_DEVICE
|
| 112 |
+
Params():
|
| 113 |
+
alpha(ElementCompute(1)),
|
| 114 |
+
beta(ElementCompute(0)),
|
| 115 |
+
threshold(ElementCompute(0)),
|
| 116 |
+
alpha_ptr(nullptr),
|
| 117 |
+
beta_ptr(nullptr) { }
|
| 118 |
+
|
| 119 |
+
CUTLASS_HOST_DEVICE
|
| 120 |
+
Params(
|
| 121 |
+
ElementCompute alpha,
|
| 122 |
+
ElementCompute beta = ElementCompute(0),
|
| 123 |
+
ElementCompute threshold = ElementCompute(0)
|
| 124 |
+
): alpha(alpha), beta(beta), threshold(threshold), alpha_ptr(nullptr), beta_ptr(nullptr) {
|
| 125 |
+
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
CUTLASS_HOST_DEVICE
|
| 129 |
+
Params(
|
| 130 |
+
ElementCompute const *alpha_ptr,
|
| 131 |
+
ElementCompute const *beta_ptr = nullptr,
|
| 132 |
+
ElementCompute threshold = ElementCompute(0)
|
| 133 |
+
): alpha(0), beta(0), threshold(threshold), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
|
| 134 |
+
|
| 135 |
+
}
|
| 136 |
+
};
|
| 137 |
+
|
| 138 |
+
private:
|
| 139 |
+
|
| 140 |
+
//
|
| 141 |
+
// Data members
|
| 142 |
+
//
|
| 143 |
+
|
| 144 |
+
ElementCompute alpha_;
|
| 145 |
+
ElementCompute beta_;
|
| 146 |
+
ElementCompute threshold_;
|
| 147 |
+
|
| 148 |
+
public:
|
| 149 |
+
|
| 150 |
+
/// Constructs the function object, possibly loading from pointers in host memory
|
| 151 |
+
CUTLASS_HOST_DEVICE
|
| 152 |
+
LinearCombinationRelu(Params const ¶ms) {
|
| 153 |
+
|
| 154 |
+
alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
|
| 155 |
+
beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
|
| 156 |
+
threshold_ = params.threshold;
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
/// Returns true if source is needed
|
| 160 |
+
CUTLASS_HOST_DEVICE
|
| 161 |
+
bool is_source_needed() const {
|
| 162 |
+
if (Scale == ScaleType::NoBetaScaling) return true;
|
| 163 |
+
|
| 164 |
+
if (Scale == ScaleType::OnlyAlphaScaling) return false;
|
| 165 |
+
|
| 166 |
+
if (Scale == ScaleType::OnlyAlphaPerChannelScaling) return false;
|
| 167 |
+
|
| 168 |
+
if (Scale == ScaleType::Nothing) return false;
|
| 169 |
+
|
| 170 |
+
return beta_ != ElementCompute(0);
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
/// Functionally required for serial reduction in the epilogue
|
| 174 |
+
CUTLASS_HOST_DEVICE
|
| 175 |
+
void set_k_partition(int k_partition, int k_partition_count) {
|
| 176 |
+
if (k_partition) {
|
| 177 |
+
beta_ = ElementCompute(1);
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
if (k_partition != k_partition_count - 1) {
|
| 181 |
+
// set to NaN to make ReLU no-op for all except last k partitions
|
| 182 |
+
int64_t allones = -1;
|
| 183 |
+
threshold_ = reinterpret_cast<ElementCompute const &>(allones);
|
| 184 |
+
}
|
| 185 |
+
}
|
| 186 |
+
|
| 187 |
+
/// Computes linear scaling: D = alpha * accumulator + beta * source
|
| 188 |
+
CUTLASS_HOST_DEVICE
|
| 189 |
+
FragmentOutput operator()(
|
| 190 |
+
FragmentAccumulator const &accumulator,
|
| 191 |
+
FragmentOutput const &source) const {
|
| 192 |
+
|
| 193 |
+
// Convert source to interal compute numeric type
|
| 194 |
+
NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round> source_converter;
|
| 195 |
+
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
|
| 196 |
+
|
| 197 |
+
FragmentCompute converted_source = source_converter(source);
|
| 198 |
+
FragmentCompute converted_accumulator = accumulator_converter(accumulator);
|
| 199 |
+
|
| 200 |
+
// Perform binary operations
|
| 201 |
+
FragmentCompute intermediate;
|
| 202 |
+
|
| 203 |
+
multiplies<FragmentCompute> mul_add_source;
|
| 204 |
+
multiply_add<FragmentCompute> mul_add_accumulator;
|
| 205 |
+
ReLu<FragmentCompute> relu;
|
| 206 |
+
|
| 207 |
+
if (Scale == ScaleType::NoBetaScaling) {
|
| 208 |
+
intermediate = converted_source;
|
| 209 |
+
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
|
| 210 |
+
} else if (Scale == ScaleType::Nothing) {
|
| 211 |
+
intermediate = converted_accumulator;
|
| 212 |
+
} else {
|
| 213 |
+
intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform
|
| 214 |
+
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
|
| 215 |
+
}
|
| 216 |
+
|
| 217 |
+
// Compute threshold optionally
|
| 218 |
+
intermediate = relu(threshold_, intermediate);
|
| 219 |
+
|
| 220 |
+
// Convert to destination numeric type
|
| 221 |
+
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
|
| 222 |
+
|
| 223 |
+
return destination_converter(intermediate);
|
| 224 |
+
}
|
| 225 |
+
|
| 226 |
+
/// Computes linear scaling: D = alpha * accumulator
|
| 227 |
+
CUTLASS_HOST_DEVICE
|
| 228 |
+
FragmentOutput operator()(
|
| 229 |
+
FragmentAccumulator const &accumulator) const {
|
| 230 |
+
|
| 231 |
+
// Convert source to interal compute numeric type
|
| 232 |
+
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
|
| 233 |
+
|
| 234 |
+
FragmentCompute converted_accumulator = accumulator_converter(accumulator);
|
| 235 |
+
|
| 236 |
+
// Perform binary operations
|
| 237 |
+
FragmentCompute intermediate;
|
| 238 |
+
|
| 239 |
+
multiplies<FragmentCompute> mul_accumulator;
|
| 240 |
+
ReLu<FragmentCompute> relu;
|
| 241 |
+
|
| 242 |
+
if (Scale == ScaleType::Nothing) {
|
| 243 |
+
intermediate = converted_accumulator;
|
| 244 |
+
} else {
|
| 245 |
+
intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
|
| 246 |
+
}
|
| 247 |
+
|
| 248 |
+
// Compute threshold optionally
|
| 249 |
+
intermediate = relu(threshold_, intermediate);
|
| 250 |
+
|
| 251 |
+
// Convert to destination numeric type
|
| 252 |
+
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
|
| 253 |
+
|
| 254 |
+
return destination_converter(intermediate);
|
| 255 |
+
}
|
| 256 |
+
|
| 257 |
+
/// Computes per-channel linear scaling and bias : D = scale * accumulator + bias
|
| 258 |
+
/// Scale and Bias are from input Fragment
|
| 259 |
+
CUTLASS_HOST_DEVICE
|
| 260 |
+
FragmentOutput operator()(
|
| 261 |
+
FragmentAccumulator const &accumulator,
|
| 262 |
+
FragmentScaleBias const &scale,
|
| 263 |
+
FragmentScaleBias const &bias) const {
|
| 264 |
+
|
| 265 |
+
// Convert source to interal compute numeric type
|
| 266 |
+
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
|
| 267 |
+
|
| 268 |
+
FragmentCompute converted_accumulator = accumulator_converter(accumulator);
|
| 269 |
+
|
| 270 |
+
// Perform per-channel scale and bias
|
| 271 |
+
FragmentCompute intermediate;
|
| 272 |
+
|
| 273 |
+
multiply_add<FragmentCompute> mul_add_accumulator;
|
| 274 |
+
|
| 275 |
+
if(Scale == ScaleType::OnlyAlphaPerChannelScaling)
|
| 276 |
+
intermediate = mul_add_accumulator(scale, converted_accumulator, bias); // D = scale * Accum + bias
|
| 277 |
+
else
|
| 278 |
+
intermediate = mul_add_accumulator(alpha_, converted_accumulator, bias); // D = alpha * Accum + bias
|
| 279 |
+
|
| 280 |
+
ReLu<FragmentCompute> relu;
|
| 281 |
+
|
| 282 |
+
// Compute threshold optionally
|
| 283 |
+
intermediate = relu(threshold_, intermediate);
|
| 284 |
+
|
| 285 |
+
// Convert to destination numeric type
|
| 286 |
+
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
|
| 287 |
+
|
| 288 |
+
return destination_converter(intermediate);
|
| 289 |
+
}
|
| 290 |
+
};
|
| 291 |
+
|
| 292 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 293 |
+
|
| 294 |
+
// Conditional guards to enable partial specialization for packed integers
|
| 295 |
+
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && ((__CUDACC_VER_MAJOR__ > 10) || ((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2)))
|
| 296 |
+
|
| 297 |
+
/// Applies a linear combination operator to an array of elements.
|
| 298 |
+
///
|
| 299 |
+
/// D = alpha * accumulator + beta * source + uniform
|
| 300 |
+
///
|
| 301 |
+
/// Special handling for int types
|
| 302 |
+
|
| 303 |
+
template <
|
| 304 |
+
typename ElementOutput_, ///< Data type used to load and store tensors
|
| 305 |
+
int Count, ///< Number of elements computed per operation
|
| 306 |
+
ScaleType::Kind Scale, ///< Control Alpha and Beta scaling
|
| 307 |
+
FloatRoundStyle Round
|
| 308 |
+
>
|
| 309 |
+
class LinearCombinationRelu <ElementOutput_, Count, int, float, Scale, Round> {
|
| 310 |
+
public:
|
| 311 |
+
|
| 312 |
+
using ElementOutput = ElementOutput_;
|
| 313 |
+
using ElementAccumulator = int;
|
| 314 |
+
using ElementCompute = float;
|
| 315 |
+
|
| 316 |
+
static bool const kIsHeavy = detail::LinearCombinationReluIsHeavy();
|
| 317 |
+
|
| 318 |
+
static int const kCount = Count;
|
| 319 |
+
static const ScaleType::Kind kScale = Scale;
|
| 320 |
+
|
| 321 |
+
using FragmentOutput = Array<ElementOutput, kCount>;
|
| 322 |
+
using FragmentAccumulator = Array<ElementAccumulator, kCount>;
|
| 323 |
+
using FragmentCompute = Array<ElementCompute, kCount>;
|
| 324 |
+
using FragmentScaleBias = Array<ElementCompute, kCount>;
|
| 325 |
+
using FragmentSource = Array<ElementOutput, kCount>;
|
| 326 |
+
|
| 327 |
+
static FloatRoundStyle const kRound = Round;
|
| 328 |
+
|
| 329 |
+
/// Host-constructable parameters structure
|
| 330 |
+
struct Params {
|
| 331 |
+
|
| 332 |
+
ElementCompute alpha; ///< scales accumulators
|
| 333 |
+
ElementCompute beta; ///< scales source tensor
|
| 334 |
+
ElementCompute threshold; ///< minimum value that is output
|
| 335 |
+
ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
|
| 336 |
+
ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
|
| 337 |
+
//
|
| 338 |
+
// Methods
|
| 339 |
+
//
|
| 340 |
+
|
| 341 |
+
CUTLASS_HOST_DEVICE
|
| 342 |
+
Params():
|
| 343 |
+
alpha(ElementCompute(1)),
|
| 344 |
+
beta(ElementCompute(0)),
|
| 345 |
+
threshold(ElementCompute(0)),
|
| 346 |
+
alpha_ptr(nullptr),
|
| 347 |
+
beta_ptr(nullptr) { }
|
| 348 |
+
|
| 349 |
+
CUTLASS_HOST_DEVICE
|
| 350 |
+
Params(
|
| 351 |
+
ElementCompute alpha,
|
| 352 |
+
ElementCompute beta = ElementCompute(0),
|
| 353 |
+
ElementCompute threshold = ElementCompute(0)
|
| 354 |
+
): alpha(alpha), beta(beta), threshold(threshold), alpha_ptr(nullptr), beta_ptr(nullptr) {
|
| 355 |
+
|
| 356 |
+
}
|
| 357 |
+
|
| 358 |
+
CUTLASS_HOST_DEVICE
|
| 359 |
+
Params(
|
| 360 |
+
ElementCompute const *alpha_ptr,
|
| 361 |
+
ElementCompute const *beta_ptr = nullptr,
|
| 362 |
+
ElementCompute threshold = ElementCompute(0)
|
| 363 |
+
): alpha(0), beta(0), threshold(threshold), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
|
| 364 |
+
|
| 365 |
+
}
|
| 366 |
+
};
|
| 367 |
+
|
| 368 |
+
private:
|
| 369 |
+
|
| 370 |
+
//
|
| 371 |
+
// Data members
|
| 372 |
+
//
|
| 373 |
+
|
| 374 |
+
ElementCompute alpha_;
|
| 375 |
+
ElementCompute beta_;
|
| 376 |
+
ElementCompute threshold_;
|
| 377 |
+
|
| 378 |
+
public:
|
| 379 |
+
|
| 380 |
+
/// Constructs the function object, possibly loading from pointers in host memory
|
| 381 |
+
CUTLASS_HOST_DEVICE
|
| 382 |
+
LinearCombinationRelu(Params const ¶ms) {
|
| 383 |
+
|
| 384 |
+
alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
|
| 385 |
+
beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
|
| 386 |
+
threshold_ = params.threshold;
|
| 387 |
+
}
|
| 388 |
+
|
| 389 |
+
/// Returns true if source is needed
|
| 390 |
+
CUTLASS_HOST_DEVICE
|
| 391 |
+
bool is_source_needed() const {
|
| 392 |
+
if (Scale == ScaleType::NoBetaScaling) return true;
|
| 393 |
+
|
| 394 |
+
if (Scale == ScaleType::OnlyAlphaScaling) return false;
|
| 395 |
+
|
| 396 |
+
if (Scale == ScaleType::OnlyAlphaPerChannelScaling) return false;
|
| 397 |
+
|
| 398 |
+
if (Scale == ScaleType::Nothing) return false;
|
| 399 |
+
|
| 400 |
+
return beta_ != ElementCompute(0);
|
| 401 |
+
}
|
| 402 |
+
|
| 403 |
+
/// Functionally required for serial reduction in the epilogue
|
| 404 |
+
CUTLASS_HOST_DEVICE
|
| 405 |
+
void set_k_partition(int k_partition, int k_partition_count) {
|
| 406 |
+
if (k_partition) {
|
| 407 |
+
beta_ = ElementCompute(1);
|
| 408 |
+
}
|
| 409 |
+
|
| 410 |
+
if (k_partition != k_partition_count - 1) {
|
| 411 |
+
// set to NaN to make ReLU no-op for all except last k partitions
|
| 412 |
+
int64_t allones = -1;
|
| 413 |
+
threshold_ = reinterpret_cast<ElementCompute const &>(allones);
|
| 414 |
+
}
|
| 415 |
+
}
|
| 416 |
+
|
| 417 |
+
/// Computes linear scaling: D = alpha * accumulator + beta * source
|
| 418 |
+
CUTLASS_HOST_DEVICE
|
| 419 |
+
FragmentOutput operator()(
|
| 420 |
+
FragmentAccumulator const &accumulator,
|
| 421 |
+
FragmentOutput const &source) const {
|
| 422 |
+
|
| 423 |
+
// Convert source to interal compute numeric type
|
| 424 |
+
NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round> source_converter;
|
| 425 |
+
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
|
| 426 |
+
|
| 427 |
+
FragmentCompute converted_source = source_converter(source);
|
| 428 |
+
FragmentCompute converted_accumulator = accumulator_converter(accumulator);
|
| 429 |
+
|
| 430 |
+
// Perform binary operations
|
| 431 |
+
FragmentCompute intermediate;
|
| 432 |
+
|
| 433 |
+
multiplies<FragmentCompute> mul_add_source;
|
| 434 |
+
multiply_add<FragmentCompute> mul_add_accumulator;
|
| 435 |
+
ReLu<FragmentCompute> relu;
|
| 436 |
+
|
| 437 |
+
if (Scale == ScaleType::NoBetaScaling) {
|
| 438 |
+
intermediate = converted_source;
|
| 439 |
+
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
|
| 440 |
+
} else if (Scale == ScaleType::Nothing) {
|
| 441 |
+
intermediate = converted_accumulator;
|
| 442 |
+
} else {
|
| 443 |
+
intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform
|
| 444 |
+
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
|
| 445 |
+
}
|
| 446 |
+
|
| 447 |
+
// Compute threshold optionally
|
| 448 |
+
intermediate = relu(threshold_, intermediate);
|
| 449 |
+
|
| 450 |
+
if (cutlass::platform::numeric_limits<ElementOutput>::is_integer) {
|
| 451 |
+
// Convert floats back to INT
|
| 452 |
+
FragmentAccumulator scaled_accumulator;
|
| 453 |
+
|
| 454 |
+
NumericArrayConverter<int, ElementCompute, kCount, Round> compute_converter;
|
| 455 |
+
|
| 456 |
+
scaled_accumulator = compute_converter(intermediate);
|
| 457 |
+
|
| 458 |
+
// Convert to destination numeric type
|
| 459 |
+
NumericArrayConverter<ElementOutput, int, kCount, Round>
|
| 460 |
+
destination_converter;
|
| 461 |
+
|
| 462 |
+
return destination_converter(scaled_accumulator);
|
| 463 |
+
} else {
|
| 464 |
+
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
|
| 465 |
+
destination_converter;
|
| 466 |
+
return destination_converter(intermediate);
|
| 467 |
+
}
|
| 468 |
+
}
|
| 469 |
+
|
| 470 |
+
/// Computes linear scaling: D = alpha * accumulator
|
| 471 |
+
CUTLASS_HOST_DEVICE
|
| 472 |
+
FragmentOutput operator()(
|
| 473 |
+
FragmentAccumulator const &accumulator) const {
|
| 474 |
+
|
| 475 |
+
// Convert source to interal compute numeric type
|
| 476 |
+
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
|
| 477 |
+
|
| 478 |
+
FragmentCompute converted_accumulator = accumulator_converter(accumulator);
|
| 479 |
+
|
| 480 |
+
// Perform binary operations
|
| 481 |
+
FragmentCompute intermediate;
|
| 482 |
+
|
| 483 |
+
multiplies<FragmentCompute> mul_accumulator;
|
| 484 |
+
ReLu<FragmentCompute> relu;
|
| 485 |
+
|
| 486 |
+
if (Scale == ScaleType::Nothing) {
|
| 487 |
+
intermediate = converted_accumulator;
|
| 488 |
+
} else {
|
| 489 |
+
intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
|
| 490 |
+
}
|
| 491 |
+
|
| 492 |
+
// Compute threshold optionally
|
| 493 |
+
intermediate = relu(threshold_, intermediate);
|
| 494 |
+
|
| 495 |
+
if (cutlass::platform::numeric_limits<ElementOutput>::is_integer) {
|
| 496 |
+
// Convert floats back to INT
|
| 497 |
+
FragmentAccumulator scaled_accumulator;
|
| 498 |
+
|
| 499 |
+
NumericArrayConverter<int, ElementCompute, kCount, Round> compute_converter;
|
| 500 |
+
|
| 501 |
+
scaled_accumulator = compute_converter(intermediate);
|
| 502 |
+
|
| 503 |
+
// Convert to destination numeric type
|
| 504 |
+
NumericArrayConverter<ElementOutput, int, kCount, Round>
|
| 505 |
+
destination_converter;
|
| 506 |
+
|
| 507 |
+
return destination_converter(scaled_accumulator);
|
| 508 |
+
} else {
|
| 509 |
+
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
|
| 510 |
+
destination_converter;
|
| 511 |
+
return destination_converter(intermediate);
|
| 512 |
+
}
|
| 513 |
+
}
|
| 514 |
+
|
| 515 |
+
/// Computes per-channel linear scaling and bias : D = scale * accumulator + bias
|
| 516 |
+
/// Scale and Bias are from input Fragment
|
| 517 |
+
CUTLASS_HOST_DEVICE
|
| 518 |
+
FragmentOutput operator()(
|
| 519 |
+
FragmentAccumulator const &accumulator,
|
| 520 |
+
FragmentScaleBias const &scale,
|
| 521 |
+
FragmentScaleBias const &bias) const {
|
| 522 |
+
|
| 523 |
+
// Convert source to interal compute numeric type
|
| 524 |
+
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
|
| 525 |
+
|
| 526 |
+
FragmentCompute converted_accumulator = accumulator_converter(accumulator);
|
| 527 |
+
|
| 528 |
+
// Perform per-channel scale and bias
|
| 529 |
+
FragmentCompute intermediate;
|
| 530 |
+
|
| 531 |
+
multiply_add<FragmentCompute> mul_add_accumulator;
|
| 532 |
+
|
| 533 |
+
if(Scale == ScaleType::OnlyAlphaPerChannelScaling)
|
| 534 |
+
intermediate = mul_add_accumulator(scale, converted_accumulator, bias); // D = scale * Accum + bias
|
| 535 |
+
else
|
| 536 |
+
intermediate = mul_add_accumulator(alpha_, converted_accumulator, bias); // D = alpha * Accum + bias
|
| 537 |
+
|
| 538 |
+
ReLu<FragmentCompute> relu;
|
| 539 |
+
|
| 540 |
+
// Compute threshold optionally
|
| 541 |
+
intermediate = relu(threshold_, intermediate);
|
| 542 |
+
|
| 543 |
+
if (cutlass::platform::numeric_limits<ElementOutput>::is_integer) {
|
| 544 |
+
// Convert floats back to INT
|
| 545 |
+
FragmentAccumulator scaled_accumulator;
|
| 546 |
+
|
| 547 |
+
NumericArrayConverter<int, ElementCompute, kCount, Round> compute_converter;
|
| 548 |
+
|
| 549 |
+
scaled_accumulator = compute_converter(intermediate);
|
| 550 |
+
|
| 551 |
+
// Convert to destination numeric type
|
| 552 |
+
NumericArrayConverter<ElementOutput, int, kCount, Round>
|
| 553 |
+
destination_converter;
|
| 554 |
+
|
| 555 |
+
return destination_converter(scaled_accumulator);
|
| 556 |
+
} else {
|
| 557 |
+
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
|
| 558 |
+
destination_converter;
|
| 559 |
+
return destination_converter(intermediate);
|
| 560 |
+
}
|
| 561 |
+
}
|
| 562 |
+
};
|
| 563 |
+
|
| 564 |
+
#endif // Conditional guards to enable partial specialization for packed integers
|
| 565 |
+
|
| 566 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 567 |
+
|
| 568 |
+
} // namespace thread
|
| 569 |
+
} // namespace epilogue
|
| 570 |
+
} // namespace cutlass
|
| 571 |
+
|
| 572 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_relu0.h
ADDED
|
@@ -0,0 +1,543 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Functor performing linear combination with a relu operation used by epilogues.
|
| 33 |
+
This one only supports relu0 and tries to folding relu into other instructions. Thus,
|
| 34 |
+
serial splitk is not supported by this one. For example, relu can be folded into
|
| 35 |
+
hfma2/hmul2 for sm80+
|
| 36 |
+
*/
|
| 37 |
+
|
| 38 |
+
#pragma once
|
| 39 |
+
|
| 40 |
+
#include "cutlass/half.h"
|
| 41 |
+
#include "cutlass/cutlass.h"
|
| 42 |
+
#include "cutlass/numeric_types.h"
|
| 43 |
+
#include "cutlass/array.h"
|
| 44 |
+
#include "cutlass/functional.h"
|
| 45 |
+
#include "cutlass/numeric_conversion.h"
|
| 46 |
+
#include "cutlass/epilogue/thread/activation.h"
|
| 47 |
+
#include "cutlass/epilogue/thread/scale_type.h"
|
| 48 |
+
|
| 49 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 50 |
+
|
| 51 |
+
namespace cutlass {
|
| 52 |
+
namespace epilogue {
|
| 53 |
+
namespace thread {
|
| 54 |
+
|
| 55 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 56 |
+
|
| 57 |
+
namespace detail {
|
| 58 |
+
|
| 59 |
+
/// Single source of truth for whether to unroll for `LinearCombinationClamp()`
|
| 60 |
+
constexpr bool LinearCombinationRelu0IsHeavy() {
|
| 61 |
+
return false;
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 67 |
+
|
| 68 |
+
/// Applies a linear combination operator to an array of elements.
|
| 69 |
+
///
|
| 70 |
+
/// D = alpha * accumulator + beta * source + uniform
|
| 71 |
+
///
|
| 72 |
+
template <
|
| 73 |
+
typename ElementOutput_, ///< Data type used to load and store tensors
|
| 74 |
+
int Count, ///< Number of elements computed per operation
|
| 75 |
+
///< Usually it is 128/sizeof_bits<ElementOutput_>,
|
| 76 |
+
///< but we use 64 or 32 sometimes when there are not enough data to store
|
| 77 |
+
typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type
|
| 78 |
+
typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination
|
| 79 |
+
ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling
|
| 80 |
+
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
|
| 81 |
+
>
|
| 82 |
+
class LinearCombinationRelu0 {
|
| 83 |
+
public:
|
| 84 |
+
|
| 85 |
+
using ElementOutput = ElementOutput_;
|
| 86 |
+
using ElementAccumulator = ElementAccumulator_;
|
| 87 |
+
using ElementCompute = ElementCompute_;
|
| 88 |
+
|
| 89 |
+
static int const kCount = Count;
|
| 90 |
+
static const ScaleType::Kind kScale = Scale;
|
| 91 |
+
|
| 92 |
+
using FragmentOutput = Array<ElementOutput, kCount>;
|
| 93 |
+
using FragmentAccumulator = Array<ElementAccumulator, kCount>;
|
| 94 |
+
using FragmentCompute = Array<ElementCompute, kCount>;
|
| 95 |
+
using FragmentScaleBias = Array<ElementCompute, kCount>;
|
| 96 |
+
using FragmentSource = Array<ElementOutput, kCount>;
|
| 97 |
+
|
| 98 |
+
static FloatRoundStyle const kRound = Round;
|
| 99 |
+
|
| 100 |
+
static bool const kIsHeavy = detail::LinearCombinationRelu0IsHeavy();
|
| 101 |
+
|
| 102 |
+
/// Host-constructable parameters structure
|
| 103 |
+
struct Params {
|
| 104 |
+
|
| 105 |
+
ElementCompute alpha; ///< scales accumulators
|
| 106 |
+
ElementCompute beta; ///< scales source tensor
|
| 107 |
+
ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
|
| 108 |
+
ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
|
| 109 |
+
//
|
| 110 |
+
// Methods
|
| 111 |
+
//
|
| 112 |
+
|
| 113 |
+
CUTLASS_HOST_DEVICE
|
| 114 |
+
Params():
|
| 115 |
+
alpha(ElementCompute(1)),
|
| 116 |
+
beta(ElementCompute(0)),
|
| 117 |
+
alpha_ptr(nullptr),
|
| 118 |
+
beta_ptr(nullptr) { }
|
| 119 |
+
|
| 120 |
+
CUTLASS_HOST_DEVICE
|
| 121 |
+
Params(
|
| 122 |
+
ElementCompute alpha,
|
| 123 |
+
ElementCompute beta = ElementCompute(0)
|
| 124 |
+
): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) {
|
| 125 |
+
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
CUTLASS_HOST_DEVICE
|
| 129 |
+
Params(
|
| 130 |
+
ElementCompute const *alpha_ptr,
|
| 131 |
+
ElementCompute const *beta_ptr = nullptr
|
| 132 |
+
): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
|
| 133 |
+
|
| 134 |
+
}
|
| 135 |
+
};
|
| 136 |
+
|
| 137 |
+
private:
|
| 138 |
+
|
| 139 |
+
//
|
| 140 |
+
// Data members
|
| 141 |
+
//
|
| 142 |
+
|
| 143 |
+
ElementCompute alpha_;
|
| 144 |
+
ElementCompute beta_;
|
| 145 |
+
|
| 146 |
+
public:
|
| 147 |
+
|
| 148 |
+
/// Constructs the function object, possibly loading from pointers in host memory
|
| 149 |
+
CUTLASS_HOST_DEVICE
|
| 150 |
+
LinearCombinationRelu0(Params const ¶ms) {
|
| 151 |
+
|
| 152 |
+
alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
|
| 153 |
+
beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
/// Returns true if source is needed
|
| 157 |
+
CUTLASS_HOST_DEVICE
|
| 158 |
+
bool is_source_needed() const {
|
| 159 |
+
if (Scale == ScaleType::NoBetaScaling) return true;
|
| 160 |
+
|
| 161 |
+
if (Scale == ScaleType::OnlyAlphaScaling) return false;
|
| 162 |
+
|
| 163 |
+
if (Scale == ScaleType::Nothing) return false;
|
| 164 |
+
|
| 165 |
+
return beta_ != ElementCompute(0);
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
/// This is used for serial reduction which is not supported by Relu0
|
| 169 |
+
CUTLASS_HOST_DEVICE
|
| 170 |
+
void set_k_partition(int k_partition, int k_partition_count) {
|
| 171 |
+
assert(k_partition == 0);
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
/// Computes linear scaling: D = alpha * accumulator + beta * source
|
| 175 |
+
CUTLASS_HOST_DEVICE
|
| 176 |
+
FragmentOutput operator()(
|
| 177 |
+
FragmentAccumulator const &accumulator,
|
| 178 |
+
FragmentOutput const &source) const {
|
| 179 |
+
|
| 180 |
+
// Convert source to interal compute numeric type
|
| 181 |
+
NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round> source_converter;
|
| 182 |
+
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
|
| 183 |
+
|
| 184 |
+
FragmentCompute converted_source = source_converter(source);
|
| 185 |
+
FragmentCompute converted_accumulator = accumulator_converter(accumulator);
|
| 186 |
+
|
| 187 |
+
// Perform binary operations
|
| 188 |
+
FragmentCompute intermediate;
|
| 189 |
+
|
| 190 |
+
multiplies<FragmentCompute> mul_add_source;
|
| 191 |
+
multiply_add_relu0<FragmentCompute> mul_add_relu0_accumulator;
|
| 192 |
+
ReLu<FragmentCompute> relu;
|
| 193 |
+
|
| 194 |
+
if (Scale == ScaleType::NoBetaScaling) {
|
| 195 |
+
intermediate = converted_source;
|
| 196 |
+
intermediate = mul_add_relu0_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
|
| 197 |
+
} else if (Scale == ScaleType::Nothing) {
|
| 198 |
+
intermediate = converted_accumulator;
|
| 199 |
+
|
| 200 |
+
// Compute threshold optionally
|
| 201 |
+
intermediate = relu(intermediate);
|
| 202 |
+
} else {
|
| 203 |
+
intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform
|
| 204 |
+
intermediate = mul_add_relu0_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
|
| 205 |
+
}
|
| 206 |
+
|
| 207 |
+
// Convert to destination numeric type
|
| 208 |
+
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
|
| 209 |
+
|
| 210 |
+
return destination_converter(intermediate);
|
| 211 |
+
}
|
| 212 |
+
|
| 213 |
+
/// Computes linear scaling: D = alpha * accumulator
|
| 214 |
+
CUTLASS_HOST_DEVICE
|
| 215 |
+
FragmentOutput operator()(
|
| 216 |
+
FragmentAccumulator const &accumulator) const {
|
| 217 |
+
|
| 218 |
+
// Convert source to interal compute numeric type
|
| 219 |
+
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
|
| 220 |
+
|
| 221 |
+
FragmentCompute converted_accumulator = accumulator_converter(accumulator);
|
| 222 |
+
|
| 223 |
+
// Perform binary operations
|
| 224 |
+
FragmentCompute intermediate;
|
| 225 |
+
|
| 226 |
+
multiplies<FragmentCompute> mul_accumulator;
|
| 227 |
+
ReLu<FragmentCompute> relu;
|
| 228 |
+
|
| 229 |
+
if (Scale == ScaleType::Nothing) {
|
| 230 |
+
intermediate = converted_accumulator;
|
| 231 |
+
} else {
|
| 232 |
+
intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
|
| 233 |
+
}
|
| 234 |
+
|
| 235 |
+
// Compute threshold optionally
|
| 236 |
+
intermediate = relu(intermediate);
|
| 237 |
+
|
| 238 |
+
// Convert to destination numeric type
|
| 239 |
+
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
|
| 240 |
+
|
| 241 |
+
return destination_converter(intermediate);
|
| 242 |
+
}
|
| 243 |
+
|
| 244 |
+
/// Computes per-channel linear scaling and bias : D = scale * accumulator + bias
|
| 245 |
+
/// Scale and Bias are from input Fragment
|
| 246 |
+
CUTLASS_HOST_DEVICE
|
| 247 |
+
FragmentOutput operator()(
|
| 248 |
+
FragmentAccumulator const &accumulator,
|
| 249 |
+
FragmentScaleBias const &scale,
|
| 250 |
+
FragmentScaleBias const &bias) const {
|
| 251 |
+
|
| 252 |
+
// Convert source to interal compute numeric type
|
| 253 |
+
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
|
| 254 |
+
|
| 255 |
+
FragmentCompute converted_accumulator = accumulator_converter(accumulator);
|
| 256 |
+
|
| 257 |
+
// Perform per-channel scale and bias
|
| 258 |
+
FragmentCompute intermediate;
|
| 259 |
+
|
| 260 |
+
multiply_add<FragmentCompute> mul_add_accumulator;
|
| 261 |
+
|
| 262 |
+
if(Scale == ScaleType::OnlyAlphaPerChannelScaling)
|
| 263 |
+
intermediate = mul_add_accumulator(scale, converted_accumulator, bias); // D = scale * Accum + bias
|
| 264 |
+
else
|
| 265 |
+
intermediate = mul_add_accumulator(alpha_, converted_accumulator, bias); // D = alpha * Accum + bias
|
| 266 |
+
|
| 267 |
+
ReLu<FragmentCompute> relu;
|
| 268 |
+
|
| 269 |
+
// Compute threshold optionally
|
| 270 |
+
intermediate = relu(intermediate);
|
| 271 |
+
|
| 272 |
+
// Convert to destination numeric type
|
| 273 |
+
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
|
| 274 |
+
|
| 275 |
+
return destination_converter(intermediate);
|
| 276 |
+
}
|
| 277 |
+
};
|
| 278 |
+
|
| 279 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 280 |
+
|
| 281 |
+
// Conditional guards to enable partial specialization for packed integers
|
| 282 |
+
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && ((__CUDACC_VER_MAJOR__ > 10) || ((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2)))
|
| 283 |
+
|
| 284 |
+
/// Applies a linear combination operator to an array of elements.
///
/// D = alpha * accumulator + beta * source + uniform
///
/// Special handling for int types: this partial specialization fixes the
/// accumulator type to `int` and the compute type to `float`. When the output
/// element type is an integer, the float intermediate is first converted to
/// `int` (saturating via NumericArrayConverter) before the final conversion to
/// the packed output type.
template <
  typename ElementOutput_,    ///< Data type used to load and store tensors
  int Count,                  ///< Number of elements computed per operation
  ScaleType::Kind Scale,      ///< Control Alpha and Beta scaling
  FloatRoundStyle Round       ///< Rounding style for numeric conversions
>
class LinearCombinationRelu0 <ElementOutput_, Count, int, float, Scale, Round> {
public:

  using ElementOutput = ElementOutput_;
  using ElementAccumulator = int;
  using ElementCompute = float;

  // Whether this epilogue functor counts as "heavy" for kernel scheduling
  // purposes is decided by the project-level detail helper.
  static bool const kIsHeavy = detail::LinearCombinationRelu0IsHeavy();

  static int const kCount = Count;
  static const ScaleType::Kind kScale = Scale;

  using FragmentOutput = Array<ElementOutput, kCount>;
  using FragmentAccumulator = Array<ElementAccumulator, kCount>;
  using FragmentCompute = Array<ElementCompute, kCount>;
  using FragmentScaleBias = Array<ElementCompute, kCount>;
  using FragmentSource = Array<ElementOutput, kCount>;

  static FloatRoundStyle const kRound = Round;

  /// Host-constructable parameters structure
  struct Params {

    ElementCompute alpha;                  ///< scales accumulators
    ElementCompute beta;                   ///< scales source tensor
    ElementCompute const *alpha_ptr;       ///< pointer to accumulator scalar - if not null, loads it from memory
    ElementCompute const *beta_ptr;        ///< pointer to source scalar - if not null, loads it from memory

    //
    // Methods
    //

    /// Default: alpha = 1, beta = 0 (pure accumulator pass-through + ReLU)
    CUTLASS_HOST_DEVICE
    Params():
      alpha(ElementCompute(1)),
      beta(ElementCompute(0)),
      alpha_ptr(nullptr),
      beta_ptr(nullptr) { }

    /// Construct from host-side scalar values
    CUTLASS_HOST_DEVICE
    Params(
      ElementCompute alpha,
      ElementCompute beta = ElementCompute(0)
    ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) {

    }

    /// Construct from pointers; scalars are loaded by the functor constructor
    CUTLASS_HOST_DEVICE
    Params(
      ElementCompute const *alpha_ptr,
      ElementCompute const *beta_ptr = nullptr
    ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {

    }
  };

private:

  //
  // Data members
  //

  ElementCompute alpha_;
  ElementCompute beta_;

public:

  /// Constructs the function object, possibly loading from pointers in host memory
  CUTLASS_HOST_DEVICE
  LinearCombinationRelu0(Params const &params) {

    // Pointer form takes precedence over the inline scalar when non-null.
    alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
    beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
  }

  /// Returns true if source is needed
  CUTLASS_HOST_DEVICE
  bool is_source_needed() const {
    if (Scale == ScaleType::NoBetaScaling) return true;

    if (Scale == ScaleType::OnlyAlphaScaling) return false;

    if (Scale == ScaleType::Nothing) return false;

    // Default scaling: the source load can be elided when beta is zero at runtime.
    return beta_ != ElementCompute(0);
  }

  /// This is used for serial reduction which is not supported by Relu0
  CUTLASS_HOST_DEVICE
  void set_k_partition(int k_partition, int k_partition_count) {
    // Only the single-partition case (no serial split-k) is valid here.
    assert(k_partition == 0);
  }

  /// Computes linear scaling: D = alpha * accumulator + beta * source
  CUTLASS_HOST_DEVICE
  FragmentOutput operator()(
    FragmentAccumulator const &accumulator,
    FragmentOutput const &source) const {

    // Convert source to internal compute numeric type
    NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round> source_converter;
    NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;

    FragmentCompute converted_source = source_converter(source);
    FragmentCompute converted_accumulator = accumulator_converter(accumulator);

    // Perform binary operations
    FragmentCompute intermediate;

    multiplies<FragmentCompute> mul_add_source;
    multiply_add<FragmentCompute> mul_add_accumulator;
    ReLu<FragmentCompute> relu;

    if (Scale == ScaleType::NoBetaScaling) {
      intermediate = converted_source;
      intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
    } else if (Scale == ScaleType::Nothing) {
      intermediate = converted_accumulator;
    } else {
      intermediate = mul_add_source(beta_, converted_source);                          // X = beta * C + uniform
      intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
    }

    // Compute threshold optionally
    // NOTE: in this specialization ReLU is applied after every scaling branch,
    // whereas the generic template fuses it via multiply_add_relu0.
    intermediate = relu(intermediate);

    if (cutlass::platform::numeric_limits<ElementOutput>::is_integer) {
      // Convert floats back to INT first so the final conversion to the packed
      // integer output type starts from an integer representation.
      FragmentAccumulator scaled_accumulator;

      NumericArrayConverter<int, ElementCompute, kCount, Round> compute_converter;

      scaled_accumulator = compute_converter(intermediate);

      // Convert to destination numeric type
      NumericArrayConverter<ElementOutput, int, kCount, Round>
          destination_converter;

      return destination_converter(scaled_accumulator);
    } else {
      NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
          destination_converter;
      return destination_converter(intermediate);
    }
  }

  /// Computes linear scaling: D = alpha * accumulator
  CUTLASS_HOST_DEVICE
  FragmentOutput operator()(
    FragmentAccumulator const &accumulator) const {

    // Convert accumulator to internal compute numeric type
    NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;

    FragmentCompute converted_accumulator = accumulator_converter(accumulator);

    // Perform binary operations
    FragmentCompute intermediate;

    multiplies<FragmentCompute> mul_accumulator;
    ReLu<FragmentCompute> relu;

    if (Scale == ScaleType::Nothing) {
      intermediate = converted_accumulator;
    } else {
      intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
    }

    // Compute threshold optionally
    intermediate = relu(intermediate);

    if (cutlass::platform::numeric_limits<ElementOutput>::is_integer) {
      // Convert floats back to INT (see note in the two-operand overload)
      FragmentAccumulator scaled_accumulator;

      NumericArrayConverter<int, ElementCompute, kCount, Round> compute_converter;

      scaled_accumulator = compute_converter(intermediate);

      // Convert to destination numeric type
      NumericArrayConverter<ElementOutput, int, kCount, Round>
          destination_converter;

      return destination_converter(scaled_accumulator);
    } else {
      NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
          destination_converter;
      return destination_converter(intermediate);
    }
  }

  /// Computes per-channel linear scaling and bias : D = scale * accumulator + bias
  /// Scale and Bias are from input Fragment
  CUTLASS_HOST_DEVICE
  FragmentOutput operator()(
    FragmentAccumulator const &accumulator,
    FragmentScaleBias const &scale,
    FragmentScaleBias const &bias) const {

    // Convert accumulator to internal compute numeric type
    NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;

    FragmentCompute converted_accumulator = accumulator_converter(accumulator);

    // Perform per-channel scale and bias
    FragmentCompute intermediate;

    multiply_add<FragmentCompute> mul_add_accumulator;

    if(Scale == ScaleType::OnlyAlphaPerChannelScaling)
      intermediate = mul_add_accumulator(scale, converted_accumulator, bias); // D = scale * Accum + bias
    else
      intermediate = mul_add_accumulator(alpha_, converted_accumulator, bias); // D = alpha * Accum + bias

    ReLu<FragmentCompute> relu;

    // Compute threshold optionally
    intermediate = relu(intermediate);

    if (cutlass::platform::numeric_limits<ElementOutput>::is_integer) {
      // Convert floats back to INT (see note in the two-operand overload)
      FragmentAccumulator scaled_accumulator;

      NumericArrayConverter<int, ElementCompute, kCount, Round> compute_converter;

      scaled_accumulator = compute_converter(intermediate);

      // Convert to destination numeric type
      NumericArrayConverter<ElementOutput, int, kCount, Round>
          destination_converter;

      return destination_converter(scaled_accumulator);
    } else {
      NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
          destination_converter;
      return destination_converter(intermediate);
    }
  }
};
|
| 534 |
+
|
| 535 |
+
#endif // Conditional guards to enable partial specialization for packed integers
|
| 536 |
+
|
| 537 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 538 |
+
|
| 539 |
+
} // namespace thread
|
| 540 |
+
} // namespace epilogue
|
| 541 |
+
} // namespace cutlass
|
| 542 |
+
|
| 543 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_residual_block.h
ADDED
|
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
/*! \file
|
| 33 |
+
\brief Epilogue functor specialized for residual blocks in deep neural networks.
|
| 34 |
+
*/
|
| 35 |
+
|
| 36 |
+
#pragma once
|
| 37 |
+
|
| 38 |
+
#include "cutlass/array.h"
|
| 39 |
+
#include "cutlass/functional.h"
|
| 40 |
+
#include "cutlass/numeric_conversion.h"
|
| 41 |
+
#include "cutlass/epilogue/thread/detail.hpp"
|
| 42 |
+
|
| 43 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 44 |
+
|
| 45 |
+
namespace cutlass {
|
| 46 |
+
namespace epilogue {
|
| 47 |
+
namespace thread {
|
| 48 |
+
|
| 49 |
+
/// Models a residual block of the form: UnaryOp(BinaryOp(BinaryOp(ActivationOp(TensorOp(X) + bias), residual1), residual2))
///
/// This primary template consumes TWO residual tensors (kIsSingleSource is
/// false); a partial specialization with BinaryOp2_ = detail::NoOp handles the
/// single-residual form.
template <typename ElementOutput_, typename ElementAccumulator_,
          typename ElementCompute_, typename ElementC_, int ElementsPerAccess,
          template <typename T> class ActivationOp_,
          template <typename T> class BinaryOp1_,
          template <typename T> class UnaryOp_,
          template <typename T> class BinaryOp2_ = detail::NoOp,
          bool StoreT_ = false,
          typename ElementVector_ = ElementC_>
class LinearCombinationResidualBlock {
public:
  static bool const kIsSingleSource = false;

  // NOTE(review): the output fragment is typed on ElementC_, while ElementZ
  // below is typed on ElementOutput_ — this mirrors how callers distinguish
  // the residual/output tensor type from the Z tensor type.
  using ElementOutput = ElementC_;
  using ElementC = ElementC_;
  using ElementAccumulator = ElementAccumulator_;
  using ElementCompute = ElementCompute_;
  using ElementVector = ElementVector_;
  static int const kElementsPerAccess = ElementsPerAccess;
  static int const kCount = kElementsPerAccess;

  using UnaryOp = UnaryOp_<Array<ElementCompute, kCount>>;
  using BinaryOp1 = BinaryOp1_<Array<ElementCompute, kCount>>;
  using BinaryOp2 = BinaryOp2_<Array<ElementCompute, kCount>>;
  using ActivationOp = ActivationOp_<Array<ElementCompute, kCount>>;

  using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;
  using FragmentCompute = Array<ElementCompute, kElementsPerAccess>;
  using FragmentC = Array<ElementC, kElementsPerAccess>;
  using FragmentOutput = Array<ElementOutput, kElementsPerAccess>;

  using ElementZ = ElementOutput_;
  using ElementT = ElementZ;
  using FragmentZ = Array<ElementZ, kElementsPerAccess>;
  using FragmentT = Array<ElementT, kElementsPerAccess>;

  static bool const kIsHeavy = true;
  static bool const kStoreZ = true;
  static bool const kStoreT = StoreT_;

  /// Host-constructable parameters structure
  struct Params {

    ElementCompute alpha;                     ///< scales accumulators
    ElementCompute beta;                      ///< scales residual input
    ElementCompute const *alpha_ptr{nullptr}; ///< pointer to accumulator scalar - if not null, loads it from memory
    ElementCompute const *beta_ptr{nullptr};  ///< pointer to residual scalar - if not null, loads it from memory

    /// Default: alpha = 1 and beta = 1 (residual added unscaled)
    CUTLASS_HOST_DEVICE
    Params() : alpha(ElementCompute(1)), beta(ElementCompute(1)) {}

    /// Construct from host-side scalar values
    CUTLASS_HOST_DEVICE
    Params(ElementCompute alpha, ElementCompute beta)
        : alpha(alpha), beta(beta) {}

    /// Construct from pointers; scalars are loaded by the functor constructor
    CUTLASS_HOST_DEVICE
    Params(ElementCompute const *alpha_ptr, ElementCompute const *beta_ptr)
        : alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {}
  };

private:

  ElementCompute alpha_;
  ElementCompute beta_;
  // When true (intermediate split-k partitions), the final UnaryOp is skipped
  // so partial results can be reduced before the elementwise epilogue.
  bool skip_elementwise_;

public:

  /// Constructor from Params
  CUTLASS_HOST_DEVICE
  LinearCombinationResidualBlock(Params const &params) {
    // Pointer form takes precedence over the inline scalar when non-null.
    alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
    beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
    skip_elementwise_ = false;
  }

  /// The "source" tensor corresponds to the residual input
  CUTLASS_HOST_DEVICE
  bool is_source_needed() const { return true; }

  /// Functionally required for serial reduction in the epilogue
  /// IMPORTANT: Split-k is supported only when ActivationOp is Identity.
  CUTLASS_HOST_DEVICE
  void set_k_partition(int k_partition, int k_partition_count) {
    if (k_partition) {
      // Non-first partitions accumulate onto prior partial output: beta = 1.
      beta_ = ElementCompute(1);
    }

    if (k_partition != k_partition_count - 1) {
      // Defer the final elementwise UnaryOp to the last partition.
      skip_elementwise_ = true;
    }
  }

  /// Applies the operation UnaryOp(BinaryOp(BinaryOp(ActivationOp(AB + bias), residual1), residual2))
  CUTLASS_HOST_DEVICE
  void operator()(FragmentOutput &frag_Z, FragmentOutput &, FragmentAccumulator const &AB,
                  FragmentC const &residual1, FragmentC const &residual2,
                  FragmentCompute const &bias) const {
    UnaryOp unary_op;
    BinaryOp1 binary_op1;
    BinaryOp2 binary_op2;
    ActivationOp activation;

    // Promote all operands to the compute type before combining.
    FragmentCompute tmp_Accum =
        NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess>()(AB);
    FragmentCompute tmp_residual1 =
        NumericArrayConverter<ElementCompute, ElementC, kElementsPerAccess>()(residual1);
    FragmentCompute tmp_residual2 =
        NumericArrayConverter<ElementCompute, ElementC, kElementsPerAccess>()(residual2);

    // Both residuals are scaled by the same beta.
    FragmentCompute z =
        binary_op2(binary_op1(activation(alpha_ * tmp_Accum + bias), beta_ * tmp_residual1), beta_ * tmp_residual2);
    FragmentCompute result_Z = skip_elementwise_ ? z : unary_op(z);

    NumericArrayConverter<ElementOutput, ElementCompute, kElementsPerAccess> convert_z;
    frag_Z = convert_z(result_Z);
  }

  /// Should never be called
  CUTLASS_HOST_DEVICE
  void operator()(FragmentOutput &, FragmentOutput &, FragmentAccumulator const &,
                  FragmentCompute const &) const {}
};
|
| 172 |
+
|
| 173 |
+
/// Models a residual block of the form: UnaryOp(BinaryOp(ActivationOp(TensorOp(X) + bias), residual))
|
| 174 |
+
template <typename ElementOutput_, typename ElementAccumulator_,
|
| 175 |
+
typename ElementCompute_, typename ElementC_, int ElementsPerAccess,
|
| 176 |
+
template <typename T> class ActivationOp_,
|
| 177 |
+
template <typename T> class BinaryOp1_,
|
| 178 |
+
template <typename T> class UnaryOp_,
|
| 179 |
+
bool StoreT_,
|
| 180 |
+
typename ElementVector_>
|
| 181 |
+
class LinearCombinationResidualBlock<ElementOutput_, ElementAccumulator_,
|
| 182 |
+
ElementCompute_, ElementC_, ElementsPerAccess,
|
| 183 |
+
ActivationOp_, BinaryOp1_, UnaryOp_,
|
| 184 |
+
detail::NoOp, StoreT_, ElementVector_> {
|
| 185 |
+
public:
|
| 186 |
+
static bool const kIsSingleSource = true;
|
| 187 |
+
|
| 188 |
+
using ElementOutput = ElementC_;
|
| 189 |
+
using ElementC = ElementC_;
|
| 190 |
+
using ElementAccumulator = ElementAccumulator_;
|
| 191 |
+
using ElementCompute = ElementCompute_;
|
| 192 |
+
using ElementVector = ElementVector_;
|
| 193 |
+
static int const kElementsPerAccess = ElementsPerAccess;
|
| 194 |
+
static int const kCount = kElementsPerAccess;
|
| 195 |
+
|
| 196 |
+
using UnaryOp = UnaryOp_<Array<ElementCompute, kCount>>;
|
| 197 |
+
using BinaryOp = BinaryOp1_<Array<ElementCompute, kCount>>;
|
| 198 |
+
using ActivationOp = ActivationOp_<Array<ElementCompute, kCount>>;
|
| 199 |
+
|
| 200 |
+
using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;
|
| 201 |
+
using FragmentCompute = Array<ElementCompute, kElementsPerAccess>;
|
| 202 |
+
using FragmentC = Array<ElementC, kElementsPerAccess>;
|
| 203 |
+
using FragmentOutput = Array<ElementOutput, kElementsPerAccess>;
|
| 204 |
+
|
| 205 |
+
using ElementZ = ElementOutput_;
|
| 206 |
+
using ElementT = ElementZ;
|
| 207 |
+
using FragmentZ = Array<ElementZ, kElementsPerAccess>;
|
| 208 |
+
using FragmentT = Array<ElementT, kElementsPerAccess>;
|
| 209 |
+
|
| 210 |
+
static bool const kIsHeavy = true;
|
| 211 |
+
static bool const kStoreZ = true;
|
| 212 |
+
static bool const kStoreT = StoreT_;
|
| 213 |
+
|
| 214 |
+
/// Host-constructable parameters structure
|
| 215 |
+
struct Params {
|
| 216 |
+
|
| 217 |
+
ElementCompute alpha; ///< scales accumulators
|
| 218 |
+
ElementCompute beta; ///< scales residual input
|
| 219 |
+
ElementCompute const *alpha_ptr{nullptr}; ///< pointer to accumulator scalar - if not null, loads it from memory
|
| 220 |
+
ElementCompute const *beta_ptr{nullptr}; ///< pointer to residual scalar - if not null, loads it from memory
|
| 221 |
+
|
| 222 |
+
CUTLASS_HOST_DEVICE
|
| 223 |
+
Params() : alpha(ElementCompute(1)), beta(ElementCompute(1)) {}
|
| 224 |
+
|
| 225 |
+
CUTLASS_HOST_DEVICE
|
| 226 |
+
Params(ElementCompute alpha, ElementCompute beta)
|
| 227 |
+
: alpha(alpha), beta(beta) {}
|
| 228 |
+
|
| 229 |
+
CUTLASS_HOST_DEVICE
|
| 230 |
+
Params(ElementCompute const *alpha_ptr, ElementCompute const *beta_ptr)
|
| 231 |
+
: alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {}
|
| 232 |
+
};
|
| 233 |
+
|
| 234 |
+
private:
|
| 235 |
+
|
| 236 |
+
ElementCompute alpha_;
|
| 237 |
+
ElementCompute beta_;
|
| 238 |
+
bool skip_elementwise_;
|
| 239 |
+
|
| 240 |
+
public:
|
| 241 |
+
|
| 242 |
+
/// Constructor from Params
|
| 243 |
+
CUTLASS_HOST_DEVICE
|
| 244 |
+
LinearCombinationResidualBlock(Params const ¶ms) {
|
| 245 |
+
alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
|
| 246 |
+
beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
|
| 247 |
+
skip_elementwise_ = false;
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
/// The "source" tensor corresponds to the residual input, which this
/// functor always consumes, so the epilogue must always load it.
CUTLASS_HOST_DEVICE
bool is_source_needed() const { return true; }
|
| 253 |
+
|
| 254 |
+
/// Functionally required for serial reduction in the epilogue
/// IMPORTANT: Split-k is supported only when ActivationOp is Identity.
CUTLASS_HOST_DEVICE
void set_k_partition(int k_partition, int k_partition_count) {
  // Every partition after the first accumulates on top of the previous
  // partial result, so the "residual" input is the prior partial sum and
  // must be added unscaled.
  if (k_partition) {
    beta_ = ElementCompute(1);
  }

  // Only the final partition applies the elementwise (unary) op; earlier
  // partitions must emit raw partial sums for the next partition to consume.
  if (k_partition != k_partition_count - 1) {
    skip_elementwise_ = true;
  }
}
|
| 266 |
+
|
| 267 |
+
/// Applies the operation UnaryOp(BinaryOp(ActivationOp(AB + bias), residual))
///
/// frag_Z receives the final output; the second (unnamed) FragmentOutput
/// parameter is unused by this specialization. AB is the accumulator
/// fragment, residual the source-tensor fragment, bias the per-channel bias.
CUTLASS_HOST_DEVICE
void operator()(FragmentOutput &frag_Z, FragmentOutput &, FragmentAccumulator const &AB,
                FragmentC const &residual,
                FragmentCompute const &bias) const {
  UnaryOp unary_op;
  BinaryOp binary_op;
  ActivationOp activation;

  // Promote accumulator and residual fragments to the compute type
  FragmentCompute tmp_Accum =
      NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess>()(AB);
  FragmentCompute tmp_residual =
      NumericArrayConverter<ElementCompute, ElementC, kElementsPerAccess>()(residual);

  // z = BinaryOp(ActivationOp(alpha * AB + bias), beta * residual)
  FragmentCompute z =
      binary_op(activation(alpha_ * tmp_Accum + bias), beta_ * tmp_residual);
  // For split-k partial results the final unary op is deferred (see
  // set_k_partition); only the last partition applies unary_op.
  FragmentCompute result_Z = skip_elementwise_ ? z : unary_op(z);

  // Narrow back to the output element type
  NumericArrayConverter<ElementOutput, ElementCompute, kElementsPerAccess> convert_z;
  frag_Z = convert_z(result_Z);
}
|
| 288 |
+
|
| 289 |
+
/// Should never be called: this functor requires the residual (source)
/// fragment, so the no-source overload is a deliberate no-op provided only
/// to satisfy the epilogue's interface.
CUTLASS_HOST_DEVICE
void operator()(FragmentOutput &, FragmentOutput &, FragmentAccumulator const &,
                FragmentCompute const &) const {}
|
| 293 |
+
};
|
| 294 |
+
|
| 295 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 296 |
+
|
| 297 |
+
} // namespace thread
|
| 298 |
+
} // namespace epilogue
|
| 299 |
+
} // namespace cutlass
|
| 300 |
+
|
| 301 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_sigmoid.h
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Functor performing linear combination with Sigmoid operations used by epilogues.
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "cutlass/cutlass.h"
|
| 38 |
+
#include "cutlass/epilogue/thread/activation.h"
|
| 39 |
+
#include "cutlass/epilogue/thread/linear_combination_generic.h"
|
| 40 |
+
|
| 41 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 42 |
+
|
| 43 |
+
namespace cutlass {
|
| 44 |
+
namespace epilogue {
|
| 45 |
+
namespace thread {
|
| 46 |
+
|
| 47 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 48 |
+
|
| 49 |
+
/// Applies a linear combination operator followed by the Sigmoid activation, to an array of elements.
///
/// D = sigmoid(alpha * accumulator + beta * source + uniform)
///
/// Implemented as an alias of LinearCombinationGeneric with the Sigmoid
/// activation functor; the trailing `true` marks the op as heavyweight
/// (kIsHeavy) since sigmoid requires a transcendental evaluation.
template <
  typename ElementOutput_, ///< Data type used to load and store tensors
  int Count, ///< Number of elements computed per operation
             ///< Usually it is 128/sizeof_bits<ElementOutput_>,
             ///< but we use 64 or 32 sometimes when there are not enough data to store
  typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type
  typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination
  ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling
  FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
>
using LinearCombinationSigmoid = LinearCombinationGeneric<Sigmoid, ElementOutput_, Count, ElementAccumulator_,
                                                          ElementCompute_, Scale, Round, true>;
|
| 65 |
+
|
| 66 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 67 |
+
|
| 68 |
+
} // namespace thread
|
| 69 |
+
} // namespace epilogue
|
| 70 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_silu.h
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Functor performing linear combination with SiLU operations used by epilogues.
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "cutlass/cutlass.h"
|
| 38 |
+
#include "cutlass/epilogue/thread/activation.h"
|
| 39 |
+
#include "cutlass/epilogue/thread/linear_combination_generic.h"
|
| 40 |
+
|
| 41 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 42 |
+
|
| 43 |
+
namespace cutlass {
|
| 44 |
+
namespace epilogue {
|
| 45 |
+
namespace thread {
|
| 46 |
+
|
| 47 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 48 |
+
|
| 49 |
+
/// Applies a linear combination operator followed by the SiLU activation to an array of elements.
///
/// D = silu(alpha * accumulator + beta * source + uniform)
///
/// Implemented as an alias of LinearCombinationGeneric with the SiLu
/// activation functor; the trailing `true` marks the op as heavyweight
/// (kIsHeavy) since SiLU involves a sigmoid evaluation.
template <
  typename ElementOutput_, ///< Data type used to load and store tensors
  int Count, ///< Number of elements computed per operation
             ///< Usually it is 128/sizeof_bits<ElementOutput_>,
             ///< but we use 64 or 32 sometimes when there are not enough data to store
  typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type
  typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination
  ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling
  FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
>
using LinearCombinationSilu = LinearCombinationGeneric<SiLu, ElementOutput_, Count, ElementAccumulator_,
                                                       ElementCompute_, Scale, Round, true>;
|
| 65 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 66 |
+
|
| 67 |
+
} // namespace thread
|
| 68 |
+
} // namespace epilogue
|
| 69 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_tensor_broadcast.hpp
ADDED
|
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
/*! \file
|
| 33 |
+
\brief Functor performing linear combination operation, bias addition, and tensor-tensor
|
| 34 |
+
elementwise operations
|
| 35 |
+
*/
|
| 36 |
+
|
| 37 |
+
#pragma once
|
| 38 |
+
|
| 39 |
+
#include "cutlass/cutlass.h"
|
| 40 |
+
#include "cutlass/array.h"
|
| 41 |
+
#include "cutlass/functional.h"
|
| 42 |
+
#include "cutlass/numeric_conversion.h"
|
| 43 |
+
#include "cutlass/numeric_types.h"
|
| 44 |
+
#include "cutlass/epilogue/thread/activation.h"
|
| 45 |
+
#include "cutlass/epilogue/thread/detail.hpp"
|
| 46 |
+
#include "cutlass/epilogue/thread/scale_type.h"
|
| 47 |
+
|
| 48 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 49 |
+
|
| 50 |
+
namespace cutlass {
|
| 51 |
+
namespace epilogue {
|
| 52 |
+
namespace thread {
|
| 53 |
+
|
| 54 |
+
namespace detail {
|
| 55 |
+
|
| 56 |
+
/// Returns whether a source operand is needed for a combination of binary operation and scale
|
| 57 |
+
/// type. Simple specialized checks are made for cases in which 0 is an identity element of
|
| 58 |
+
/// the binary operation.
|
| 59 |
+
template <class BinaryOp, class ElementCompute, ScaleType::Kind Scale>
|
| 60 |
+
CUTLASS_HOST_DEVICE
|
| 61 |
+
bool is_binary_op_source_needed(ElementCompute scale) {
|
| 62 |
+
if constexpr (cute::is_same_v<BinaryOp, NoOp<ElementCompute>>) {
|
| 63 |
+
return false;
|
| 64 |
+
}
|
| 65 |
+
else if constexpr (cute::is_same_v<BinaryOp, plus<ElementCompute>> || cute::is_same_v<BinaryOp, minus<ElementCompute>>) {
|
| 66 |
+
// Cases for binary operators for which 0 is an identity element
|
| 67 |
+
if constexpr (Scale == ScaleType::NoBetaScaling) return true;
|
| 68 |
+
|
| 69 |
+
if constexpr (Scale == ScaleType::OnlyAlphaScaling) return false;
|
| 70 |
+
|
| 71 |
+
if constexpr (Scale == ScaleType::Nothing) return false;
|
| 72 |
+
|
| 73 |
+
return scale != ElementCompute(0);
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
return true;
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
} // namespace detail
|
| 80 |
+
|
| 81 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 82 |
+
|
| 83 |
+
/** Compute a tensor-tensor broadcast epilogue.
 *
 * @param ElementOutput_ Data type used to load and store tensors
 * @param ElementAccumulator_ Accumulator data type
 * @param ElementCompute_ Data type used to compute linear combination
 * @param ElementBias_ Data type of Bias elements
 * @param ActivationFunctor_ Fused Activation
 * @param BinaryOp0_ Binary operation to perform on O0 and C0. detail::NoOp means no operation
 * @param BinaryOp1_ Binary operation to perform on O1 and C1. detail::NoOp means no operation
 * @param UnaryOp_ Unary operation to perform on final result
 * @param Scale Controls the type of Alpha and Beta scaling to perform
 * @param Round How values should be rounded in conversions
 * @param ElementSource_ Data type used for source operands
 *
 * Computes the following:
 * O0 = alpha * accumulator + bias
 * O1 = BinaryOp0(O0, beta * C0)
 * O2 = BinaryOp1(O1, beta * C1)
 * D = UnaryOp(O2)
 */
template <
  class ElementOutput_,
  class ElementAccumulator_ = ElementOutput_,
  class ElementCompute_ = ElementOutput_,
  class ElementBias_ = ElementCompute_,
  template <class T> class ActivationFunctor_ = Identity,
  template <class T> class BinaryOp0_ = plus,
  template <class T> class BinaryOp1_ = detail::NoOp,
  template <class T> class UnaryOp_ = Identity,
  ScaleType::Kind Scale = ScaleType::Default,
  FloatRoundStyle Round = FloatRoundStyle::round_to_nearest,
  class ElementSource_ = ElementOutput_
>
class LinearCombinationTensorBroadcast {
public:

  using ElementOutput = ElementOutput_;
  using ElementAccumulator = ElementAccumulator_;
  using ElementCompute = ElementCompute_;
  using ElementScalar = ElementCompute;
  using ElementBias = ElementBias_;
  using ElementC = ElementSource_;
  using ElementD = ElementOutput_;
  using ElementScalingFactor = ElementAccumulator_;

  using UnaryOp = UnaryOp_<ElementCompute>;
  using BinaryOp0 = BinaryOp0_<ElementCompute>;
  using BinaryOp1 = BinaryOp1_<ElementCompute>;
  using ActivationFunctor = ActivationFunctor_<ElementCompute>;

  // This functor operates one element at a time (scalar specialization).
  static constexpr int kCount = 1;
  static constexpr ScaleType::Kind kScale = Scale;

  using FragmentOutput = Array<ElementOutput, kCount>;
  using FragmentAccumulator = Array<ElementAccumulator, kCount>;
  using ComputeFragment = Array<ElementCompute, kCount>;
  using FragmentBias = Array<ElementBias, kCount>;

  static constexpr FloatRoundStyle kRound = Round;
  using NoOpType = detail::NoOp<ElementCompute>;
  // Compile-time gates: detail::NoOp disables the corresponding stage, and an
  // Identity unary op is elided since it is a no-op at runtime.
  static constexpr bool IsBinaryOp0Enabled = !cute::is_same_v<BinaryOp0, NoOpType>;
  static constexpr bool IsBinaryOp1Enabled = !cute::is_same_v<BinaryOp1, NoOpType>;
  static constexpr bool IsUnaryOpEnabled = !cute::is_same_v<UnaryOp, NoOpType> && !cute::is_same_v<UnaryOp, Identity<ElementCompute>>;

  /// Host-constructable parameters structure
  struct Params {

    ElementCompute alpha{}; ///< scales accumulators
    ElementCompute beta{}; ///< scales source tensor
    ElementCompute const* alpha_ptr = nullptr; ///< pointer to accumulator scalar - if not null, loads it from memory
    ElementCompute const* beta_ptr = nullptr; ///< pointer to source scalar - if not null, loads it from memory

    //
    // Methods
    //
    Params() = default;

    /// Construct with both scalars loaded from memory
    CUTLASS_HOST_DEVICE
    Params(ElementCompute const* alpha_ptr, ElementCompute const* beta_ptr)
        : alpha_ptr(alpha_ptr),
          beta_ptr(beta_ptr) {}

    /// Construct with only alpha loaded from memory (beta defaults to 0)
    CUTLASS_HOST_DEVICE
    Params(ElementCompute const* alpha_ptr)
        : alpha_ptr(alpha_ptr) {}

    /// Construct with both scalars passed by value
    CUTLASS_HOST_DEVICE
    Params(ElementCompute alpha,
           ElementCompute beta)
        : alpha(alpha),
          beta(beta) {}
  };

private:
  //
  // Data members
  //

  ElementCompute alpha_;
  ElementCompute beta_;

public:

  /// Constructs the function object, possibly loading from pointers in host memory
  CUTLASS_HOST_DEVICE
  LinearCombinationTensorBroadcast(Params const& params)
      : alpha_(params.alpha_ptr ? *params.alpha_ptr : params.alpha),
        beta_(params.beta_ptr ? *params.beta_ptr : params.beta) {}

  /// Returns true if source 0 is needed
  CUTLASS_HOST_DEVICE
  bool is_source0_needed() const {
    return detail::is_binary_op_source_needed<BinaryOp0, ElementCompute, Scale>(beta_);
  }

  /// Returns true if source 1 is needed
  CUTLASS_HOST_DEVICE
  bool is_source1_needed() const {
    return detail::is_binary_op_source_needed<BinaryOp1, ElementCompute, Scale>(beta_);
  }

  //
  // Specialization for scalar
  //
  /// Computes D = UnaryOp(BinaryOp1(BinaryOp0(act(alpha * acc + bias), beta * C0), beta * C1))
  /// with stages disabled at compile time collapsing to the identity.
  CUTLASS_HOST_DEVICE
  ElementD operator()(ElementAccumulator const accumulator, ElementC const source0, ElementC source1, ElementBias const bias) {
    // Convert everything to Compute type, do compute, and then store to output type
    NumericConverter<ElementCompute, ElementAccumulator, Round> accumulator_converter;
    NumericConverter<ElementCompute, ElementBias, Round> bias_converter;
    NumericConverter<ElementCompute, ElementC, Round> source_converter;
    NumericConverter<ElementD, ElementCompute, Round> destination_converter;

    ActivationFunctor act;
    multiplies<ElementCompute> mul;
    multiply_add<ElementCompute> madd;

    // O0 = act(alpha * accumulator + bias)
    ElementCompute intermediate = accumulator_converter(accumulator);
    intermediate = madd(alpha_, intermediate, bias_converter(bias));
    intermediate = act(intermediate);

    // Apply BinaryOp0, if needed
    if constexpr (IsBinaryOp0Enabled) {
      BinaryOp0 bin0;
      ElementCompute converted_source = source_converter(source0);
      intermediate = bin0(intermediate, mul(beta_, converted_source));
    }

    // Apply BinaryOp1, if needed
    if constexpr (IsBinaryOp1Enabled) {
      BinaryOp1 bin1;
      ElementCompute converted_source = source_converter(source1);
      intermediate = bin1(intermediate, mul(beta_, converted_source));
    }

    // Apply UnaryOp, if needed
    if constexpr (IsUnaryOpEnabled) {
      UnaryOp unary;
      intermediate = unary(intermediate);
    }

    return destination_converter(intermediate);
  }
};
|
| 246 |
+
|
| 247 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 248 |
+
|
| 249 |
+
} // namespace thread
|
| 250 |
+
} // namespace epilogue
|
| 251 |
+
} // namespace cutlass
|
| 252 |
+
|
| 253 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_with_elementwise.h
ADDED
|
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
|
| 33 |
+
\brief Functor performing linear combination with elementwise
|
| 34 |
+
*/
|
| 35 |
+
|
| 36 |
+
#pragma once
|
| 37 |
+
|
| 38 |
+
#include "cutlass/half.h"
|
| 39 |
+
#include "cutlass/cutlass.h"
|
| 40 |
+
#include "cutlass/numeric_types.h"
|
| 41 |
+
#include "cutlass/array.h"
|
| 42 |
+
#include "cutlass/constants.h"
|
| 43 |
+
#include "cutlass/fast_math.h"
|
| 44 |
+
#include "cutlass/functional.h"
|
| 45 |
+
#include "cutlass/numeric_conversion.h"
|
| 46 |
+
#include "cutlass/epilogue/thread/activation.h"
|
| 47 |
+
|
| 48 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 49 |
+
|
| 50 |
+
namespace cutlass {
|
| 51 |
+
namespace epilogue {
|
| 52 |
+
namespace thread {
|
| 53 |
+
|
| 54 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 55 |
+
|
| 56 |
+
/// Applies a linear combination operator to an array of elements.
///
/// D = alpha * accumulator + beta * source + uniform
///
template <
  typename ElementCompute_,     ///< Data type returned by this functor
  typename ElementAccumulator_, ///< Data type of accumulators
  typename ElementSource_,      ///< Data type of source tensor
  typename ElementTensor_,      ///< Data type of additional tensor
  int Count,                    ///< Number of elements computed per operation
                                ///< Usually it is 128/sizeof_bits<ElementOutput_>,
                                ///< but we use 64 or 32 sometimes when there are not enough data to store
  FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
>
class LinearCombinationWithElementwise {
public:

  // The output element type is taken from the source tensor's element type.
  using ElementOutput = ElementSource_;
  using ElementCompute = ElementCompute_;
  using ElementAccumulator = ElementAccumulator_;
  using ElementSource = ElementSource_;
  using ElementTensor = ElementTensor_;

  // Marks this functor as computationally expensive for epilogue heuristics.
  static bool const kIsHeavy = true;

  static int const kCount = Count;

  using FragmentCompute = Array<ElementCompute, kCount>;
  using FragmentAccumulator = Array<ElementAccumulator, kCount>;
  using FragmentSource = Array<ElementSource, kCount>;
  using FragmentTensor = Array<ElementTensor, kCount>;

  static FloatRoundStyle const kRound = Round;

  /// Host-constructable parameters structure
  struct Params {

    ElementCompute alpha;            ///< scales accumulators
    ElementCompute beta;             ///< scales source tensor
    ElementCompute threshold;        ///< minimum value that is output
    ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
    ElementCompute const *beta_ptr;  ///< pointer to source scalar - if not null, loads it from memory

    //
    // Methods
    //

    /// Default: alpha = 1, beta = 0 (D = accumulator), threshold = 0, no pointers.
    CUTLASS_HOST_DEVICE
    Params():
      alpha(ElementCompute(1)),
      beta(ElementCompute(0)),
      threshold(ElementCompute(0)),
      alpha_ptr(nullptr),
      beta_ptr(nullptr) { }

    /// Constructs from host-side scalar values.
    CUTLASS_HOST_DEVICE
    Params(
      ElementCompute alpha,
      ElementCompute beta,
      ElementCompute threshold = ElementCompute(0)
    ): alpha(alpha), beta(beta), threshold(threshold), alpha_ptr(nullptr), beta_ptr(nullptr) {

    }

    /// Constructs from pointers; the scalars are dereferenced when the
    /// functor itself is constructed (see the class constructor below).
    CUTLASS_HOST_DEVICE
    Params(
      ElementCompute const *alpha_ptr,
      ElementCompute const *beta_ptr,
      ElementCompute threshold = ElementCompute(0)
    ): alpha(0), beta(0), threshold(threshold), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {

    }
  };

private:

  //
  // Data members
  //

  ElementCompute alpha_;
  ElementCompute beta_;
  ElementCompute threshold_;
  // True only when this threadblock should compute the epilogue reduction;
  // cleared for all but the last Split-K slice in set_k_partition().
  bool participates_in_reduction_;

public:

  /// Constructs the function object, possibly loading from pointers in host memory
  CUTLASS_HOST_DEVICE
  LinearCombinationWithElementwise(Params const &params) {

    // Pointer form takes precedence over the embedded scalar when non-null.
    alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
    beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
    threshold_ = params.threshold;
    participates_in_reduction_ = true;
  }

  /// Returns true if source is needed
  CUTLASS_HOST_DEVICE
  bool is_source_needed() const {
    // When beta == 0 the source term contributes nothing, so loading C can be skipped.
    return beta_ != ElementCompute(0);
  }

  /// Returns true if the threadblock computes the reduction
  CUTLASS_HOST_DEVICE
  bool participates_in_reduction() const {
    return participates_in_reduction_;
  }

  /// Functionally required for serial reduction in the epilogue
  ///
  /// For serial Split-K, every slice after the first accumulates into the
  /// previous partial result (beta forced to 1), and only the last slice
  /// applies the elementwise threshold and computes the reduction.
  CUTLASS_HOST_DEVICE
  void set_k_partition(int k_partition, int k_partition_count) {
    if (k_partition) {
      beta_ = ElementCompute(1);
    }

    if (k_partition != k_partition_count - 1) {
      // set to NaN to make ReLU no-op for all except last k partitions
      // (an all-ones bit pattern decodes as NaN for IEEE floating-point ElementCompute)
      int64_t allones = -1;
      threshold_ = reinterpret_cast<ElementCompute const &>(allones);
      // Avoid computing the reduction if this isn't the final Split-K slice
      participates_in_reduction_ = false;
    }
  }

  /// Computes linear scaling: D = alpha * accumulator + beta * source
  ///
  /// NOTE: `tensor` is accepted for interface compatibility with other
  /// elementwise epilogue functors and is not read by this implementation.
  CUTLASS_HOST_DEVICE
  FragmentCompute operator()(
    FragmentAccumulator const &accumulator,
    FragmentSource const &source,
    FragmentTensor const &tensor) const {

    // Convert source to internal compute numeric type
    NumericArrayConverter<ElementCompute, ElementSource, kCount, Round> source_converter;
    NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;

    FragmentCompute converted_source = source_converter(source);
    FragmentCompute converted_accumulator = accumulator_converter(accumulator);

    // Perform binary operations
    FragmentCompute intermediate;

    multiplies<FragmentCompute> mul_add_source;
    multiply_add<FragmentCompute> mul_add_accumulator;

    intermediate = mul_add_source(beta_, converted_source);                          // X = beta * C + uniform
    intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X

    return intermediate;
  }

  /// Computes linear scaling: D = alpha * accumulator
  ///
  /// Source-free overload used when is_source_needed() returns false.
  /// NOTE: `tensor` is accepted for interface compatibility and not read here.
  CUTLASS_HOST_DEVICE
  FragmentCompute operator()(
    FragmentAccumulator const &accumulator,
    FragmentTensor const &tensor) const {

    // Convert source to internal compute numeric type
    NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;

    FragmentCompute converted_accumulator = accumulator_converter(accumulator);

    // Perform binary operations
    FragmentCompute intermediate;

    multiplies<FragmentCompute> mul_accumulator;

    intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum

    return intermediate;
  }
};
|
| 227 |
+
|
| 228 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 229 |
+
|
| 230 |
+
} // namespace thread
|
| 231 |
+
} // namespace epilogue
|
| 232 |
+
} // namespace cutlass
|
| 233 |
+
|
| 234 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/reduction_op.h
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Functor performing reduction operations used by epilogues.
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "cutlass/cutlass.h"
|
| 38 |
+
#include "cutlass/numeric_types.h"
|
| 39 |
+
#include "cutlass/array.h"
|
| 40 |
+
#include "cutlass/functional.h"
|
| 41 |
+
#include "cutlass/numeric_conversion.h"
|
| 42 |
+
|
| 43 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 44 |
+
|
| 45 |
+
namespace cutlass {
|
| 46 |
+
namespace epilogue {
|
| 47 |
+
namespace thread {
|
| 48 |
+
|
| 49 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 50 |
+
|
| 51 |
+
/// Applies a reduction sum to an array of elements.
|
| 52 |
+
///
|
| 53 |
+
///
|
| 54 |
+
template <
|
| 55 |
+
typename Element_, ///< Data type used to load and store tensors
|
| 56 |
+
int Count ///< Number of elements computed per operation
|
| 57 |
+
>
|
| 58 |
+
class ReductionOpPlus {
|
| 59 |
+
public:
|
| 60 |
+
|
| 61 |
+
using Element = Element_;
|
| 62 |
+
static int const kCount = Count;
|
| 63 |
+
|
| 64 |
+
using Fragment = Array<Element, kCount>;
|
| 65 |
+
using Operator = plus<Fragment>;
|
| 66 |
+
|
| 67 |
+
/// Host-constructable parameters structure
|
| 68 |
+
struct Params { };
|
| 69 |
+
|
| 70 |
+
private:
|
| 71 |
+
|
| 72 |
+
/// reduction operator
|
| 73 |
+
Operator operator_;
|
| 74 |
+
|
| 75 |
+
public:
|
| 76 |
+
|
| 77 |
+
/// Constructs the function object, possibly loading from pointers in host memory
|
| 78 |
+
CUTLASS_HOST_DEVICE
|
| 79 |
+
ReductionOpPlus(Params const ¶ms) {
|
| 80 |
+
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
/// Computes Compute =>
|
| 84 |
+
CUTLASS_HOST_DEVICE
|
| 85 |
+
Fragment operator()(
|
| 86 |
+
Fragment const &lhs,
|
| 87 |
+
Fragment const &rhs) const {
|
| 88 |
+
|
| 89 |
+
return operator_(lhs, rhs);
|
| 90 |
+
}
|
| 91 |
+
};
|
| 92 |
+
|
| 93 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 94 |
+
|
| 95 |
+
} // namespace thread
|
| 96 |
+
} // namespace epilogue
|
| 97 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/scale_type.h
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Enum defines the behaviors of the epilogue.
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "cutlass/cutlass.h"
|
| 38 |
+
|
| 39 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 40 |
+
|
| 41 |
+
namespace cutlass {
|
| 42 |
+
namespace epilogue {
|
| 43 |
+
namespace thread {
|
| 44 |
+
|
| 45 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 46 |
+
|
| 47 |
+
/// Specifies internal data type for computation
/// Note :
/// 1. Scalar means alpha/beta is a single value from host(constant param) or device memory.
/// 2. Vector means alpha/beta is a vector always from device memory.
struct ScaleType {
  /// Scaling modes supported by epilogue output operators. The wrapper
  /// struct scopes the plain enum, so values are written ScaleType::Kind.
  /// NOTE: enumerator order determines the underlying integral values -- do not reorder.
  enum Kind {
    Default,                    // D = scalar_alpha x Acc + scalar_beta x C
    NoBetaScaling,              // D = scalar_alpha x Acc + C
    OnlyAlphaScaling,           // D = scalar_alpha x Acc
    PerChannelScaling,          // D = vector_alpha x Acc + vector_beta x C
    OnlyAlphaPerChannelScaling, // D = vector_alpha x Acc
    Nothing                     // D = Acc
  };
};
|
| 61 |
+
|
| 62 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 63 |
+
|
| 64 |
+
} // namespace thread
|
| 65 |
+
} // namespace epilogue
|
| 66 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op.h
ADDED
|
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Epilogue for threadblock scoped complex GEMMs using Tensor Ops.
|
| 33 |
+
|
| 34 |
+
The epilogue rearranges the result of a matrix product through shared memory to match canonical
|
| 35 |
+
tensor layouts in global memory. Epilogues support conversion and reduction operations.
|
| 36 |
+
|
| 37 |
+
*/
|
| 38 |
+
|
| 39 |
+
#pragma once
|
| 40 |
+
|
| 41 |
+
#include "cutlass/cutlass.h"
|
| 42 |
+
#include "cutlass/numeric_types.h"
|
| 43 |
+
#include "cutlass/array.h"
|
| 44 |
+
|
| 45 |
+
#include "cutlass/gemm/gemm.h"
|
| 46 |
+
|
| 47 |
+
#include "cutlass/epilogue/thread/linear_combination.h"
|
| 48 |
+
#include "cutlass/epilogue/thread/linear_combination_relu.h"
|
| 49 |
+
#include "cutlass/epilogue/thread/linear_combination_gelu.h"
|
| 50 |
+
#include "cutlass/epilogue/thread/linear_combination_sigmoid.h"
|
| 51 |
+
#include "cutlass/epilogue/thread/linear_combination_planar_complex.h"
|
| 52 |
+
|
| 53 |
+
#include "cutlass/epilogue/thread/conversion_op.h"
|
| 54 |
+
#include "cutlass/epilogue/thread/reduction_op.h"
|
| 55 |
+
|
| 56 |
+
#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h"
|
| 57 |
+
|
| 58 |
+
#include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h"
|
| 59 |
+
#include "cutlass/epilogue/warp/fragment_iterator_gaussian_complex_tensor_op.h"
|
| 60 |
+
#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h"
|
| 61 |
+
#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h"
|
| 62 |
+
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
|
| 63 |
+
#include "cutlass/epilogue/threadblock/shared_load_iterator.h"
|
| 64 |
+
|
| 65 |
+
#include "cutlass/epilogue/threadblock/epilogue.h"
|
| 66 |
+
|
| 67 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 68 |
+
|
| 69 |
+
namespace cutlass {
|
| 70 |
+
namespace epilogue {
|
| 71 |
+
namespace threadblock {
|
| 72 |
+
|
| 73 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 74 |
+
/// Specialization and defines sensible defaults for epilogues for complex*complex case
// 4 real-valued mma operations (Complex)
// A = (ar + j ai), B (br +j bi), D = AB
// D = dr + j di = (ar*br - ai*bi) + j (ar*bi + ai*br)
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
  /// Epilogue Shape
  typename Shape_,
  /// Warp-level mma operator
  typename WarpMmaTensorOp_,
  /// Number of k partitions
  int PartitionsK,
  /// Epilogue output operator
  typename OutputOp_,
  /// Elements accessed by inner-most loop of AccumulatorFragmentIterator::load()
  int ElementsPerAccess,
  /// Multiply-add operator
  /// Selects between (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex)
  typename Operator_ = arch::OpMultiplyAddComplex
>
struct DefaultEpilogueComplexTensorOp {

  // Re-export template arguments under the names downstream code expects.
  using Shape = Shape_;
  using WarpMmaTensorOp = WarpMmaTensorOp_;
  static int const kPartitionsK = PartitionsK;
  using OutputOp = OutputOp_;
  static int const kElementsPerAccess = ElementsPerAccess;
  using Operator = Operator_;

  // Derived element/layout types pulled from the output op and warp-level MMA.
  using ElementOutput = typename OutputOp::ElementOutput;
  using LayoutC = typename WarpMmaTensorOp::LayoutC;
  using ElementAccumulator = typename WarpMmaTensorOp::ElementC;

  //
  // Thread map
  //

  // Assigns threads of the threadblock to elements of the output tile.
  using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp<
    Shape,
    typename WarpMmaTensorOp::Shape,
    kPartitionsK,
    ElementOutput,
    kElementsPerAccess
  >::Type;

  // Predicated iterator writing the converted output tile.
  using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
    OutputTileThreadMap,
    ElementOutput
  >;

  // Iterates over warp-level complex accumulator fragments (4-mma complex form).
  using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorComplexTensorOp<
    typename WarpMmaTensorOp::Shape,
    typename WarpMmaTensorOp::Policy::Operator::Shape,
    typename WarpMmaTensorOp::Policy::Operator::ElementC,
    typename WarpMmaTensorOp::Policy::Operator::FragmentC,
    LayoutC
  >;

  // Warp-scoped iterator staging accumulator tiles through shared memory.
  using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp<
    typename WarpMmaTensorOp::Shape,
    typename WarpMmaTensorOp::Policy::Operator::Shape,
    ElementAccumulator,
    LayoutC
  >;

  // Reloads staged accumulators using the compacted output-tile thread map.
  using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<
    typename OutputTileThreadMap::CompactedThreadMap,
    ElementAccumulator
  >;

  /// Hard-coded padding elements added
  using Padding = cutlass::MatrixShape<0, 0>;

  //
  // Define the epilogue
  //
  using Epilogue = cutlass::epilogue::threadblock::Epilogue<
    Shape,
    WarpMmaTensorOp,
    kPartitionsK,
    OutputTileIterator,
    AccumulatorFragmentIterator,
    WarpTileIterator,
    SharedLoadIterator,
    OutputOp,
    Padding
  >;
};
|
| 162 |
+
|
| 163 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 164 |
+
/// Partial specialization and defines sensible defaults for epilogues for complex*complex case
// 3 real-valued mma operations (Gaussian Complex)
// A = (ar + j ai), B = (br +j bi), D = AB
// P1 = (ar + ai) * br, P2 = - ar * (br - bi), P3 = ai * (br + bi)
// D = dr + j di = (P1 - P3) + j (P1 + P2)
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
  typename Shape_,
  typename WarpMmaTensorOp_,
  int PartitionsK,
  typename OutputOp_,
  int ElementsPerAccess
>
struct DefaultEpilogueComplexTensorOp <Shape_, WarpMmaTensorOp_, PartitionsK,
                                      OutputOp_, ElementsPerAccess,
                                      arch::OpMultiplyAddGaussianComplex
                                      > {

  // Re-export template arguments; Operator is pinned by this specialization.
  using Shape = Shape_;
  using WarpMmaTensorOp = WarpMmaTensorOp_;
  static int const kPartitionsK = PartitionsK;
  using OutputOp = OutputOp_;
  static int const kElementsPerAccess = ElementsPerAccess;
  using Operator = arch::OpMultiplyAddGaussianComplex;

  // Derived element/layout types pulled from the output op and warp-level MMA.
  using ElementOutput = typename OutputOp::ElementOutput;
  using LayoutC = typename WarpMmaTensorOp::LayoutC;
  using ElementAccumulator = typename WarpMmaTensorOp::ElementC;

  //
  // Thread map
  //

  // Assigns threads of the threadblock to elements of the output tile.
  using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp<
    Shape,
    typename WarpMmaTensorOp::Shape,
    kPartitionsK,
    ElementOutput,
    kElementsPerAccess
  >::Type;

  // Predicated iterator writing the converted output tile.
  using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
    OutputTileThreadMap,
    ElementOutput
  >;

  // Gaussian-complex fragment iterator: reconstructs dr/di from the three
  // partial products (P1, P2, P3) -- the only type that differs from the
  // primary template.
  using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorGaussianComplexTensorOp<
    typename WarpMmaTensorOp::Shape,
    typename WarpMmaTensorOp::Policy::Operator::Shape,
    typename WarpMmaTensorOp::Policy::Operator::ElementC,
    typename WarpMmaTensorOp::Policy::Operator::FragmentC,
    LayoutC
  >;

  // Warp-scoped iterator staging accumulator tiles through shared memory.
  using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp<
    typename WarpMmaTensorOp::Shape,
    typename WarpMmaTensorOp::Policy::Operator::Shape,
    ElementAccumulator,
    LayoutC
  >;

  // Reloads staged accumulators using the compacted output-tile thread map.
  using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<
    typename OutputTileThreadMap::CompactedThreadMap,
    ElementAccumulator
  >;

  /// Hard-coded padding elements added
  using Padding = cutlass::MatrixShape<0, 0>;

  //
  // Define the epilogue
  //
  using Epilogue = cutlass::epilogue::threadblock::Epilogue<
    Shape,
    WarpMmaTensorOp,
    kPartitionsK,
    OutputTileIterator,
    AccumulatorFragmentIterator,
    WarpTileIterator,
    SharedLoadIterator,
    OutputOp,
    Padding
  >;
};
|
| 248 |
+
|
| 249 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 250 |
+
|
| 251 |
+
} // namespace threadblock
|
| 252 |
+
} // namespace epilogue
|
| 253 |
+
} // namespace cutlass
|
| 254 |
+
|
| 255 |
+
////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op_blas3.h
ADDED
|
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Epilogue for threadblock scoped complex GEMMs using Tensor Ops.
|
| 33 |
+
|
| 34 |
+
The epilogue rearranges the result of a matrix product through shared memory to match canonical
|
| 35 |
+
tensor layouts in global memory. Epilogues support conversion and reduction operations.
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
*/
|
| 39 |
+
|
| 40 |
+
#pragma once
|
| 41 |
+
|
| 42 |
+
#include "cutlass/cutlass.h"
|
| 43 |
+
#include "cutlass/numeric_types.h"
|
| 44 |
+
#include "cutlass/array.h"
|
| 45 |
+
|
| 46 |
+
#include "cutlass/gemm/gemm.h"
|
| 47 |
+
|
| 48 |
+
#include "cutlass/epilogue/thread/linear_combination.h"
|
| 49 |
+
#include "cutlass/epilogue/thread/linear_combination_relu.h"
|
| 50 |
+
#include "cutlass/epilogue/thread/linear_combination_gelu.h"
|
| 51 |
+
#include "cutlass/epilogue/thread/linear_combination_sigmoid.h"
|
| 52 |
+
#include "cutlass/epilogue/thread/linear_combination_planar_complex.h"
|
| 53 |
+
|
| 54 |
+
#include "cutlass/epilogue/thread/conversion_op.h"
|
| 55 |
+
#include "cutlass/epilogue/thread/reduction_op.h"
|
| 56 |
+
|
| 57 |
+
#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h"
|
| 58 |
+
|
| 59 |
+
#include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h"
|
| 60 |
+
#include "cutlass/epilogue/warp/fragment_iterator_gaussian_complex_tensor_op.h"
|
| 61 |
+
#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h"
|
| 62 |
+
#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h"
|
| 63 |
+
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_blas3.h"
|
| 64 |
+
#include "cutlass/epilogue/threadblock/shared_load_iterator.h"
|
| 65 |
+
|
| 66 |
+
#include "cutlass/epilogue/threadblock/epilogue.h"
|
| 67 |
+
|
| 68 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 69 |
+
|
| 70 |
+
namespace cutlass {
|
| 71 |
+
namespace epilogue {
|
| 72 |
+
namespace threadblock {
|
| 73 |
+
|
| 74 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 75 |
+
/// Specialization and defines sensible defaults for epilogues for complex*complex case
|
| 76 |
+
// 4 real-valued mma operations (Complex)
|
| 77 |
+
// A = (ar + j ai), B (br +j bi), D = AB
|
| 78 |
+
// D = dr + j di = (ar*br - ai*bi) + j (ar*bi + ai*br)
|
| 79 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 80 |
+
template <
|
| 81 |
+
/// Epilogue Shape
|
| 82 |
+
typename Shape_,
|
| 83 |
+
/// Warp-level mma operator
|
| 84 |
+
typename WarpMmaTensorOp_,
|
| 85 |
+
/// Number of k partitions
|
| 86 |
+
int PartitionsK,
|
| 87 |
+
/// Epilogue output operator
|
| 88 |
+
typename OutputOp_,
|
| 89 |
+
/// Elements accessed by inner-most loop of AccumulatorFragmentIterator::load()
|
| 90 |
+
int ElementsPerAccess,
|
| 91 |
+
/// Multiply-add operator
|
| 92 |
+
/// Selects between (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex)
|
| 93 |
+
typename Operator_ = arch::OpMultiplyAddComplex,
|
| 94 |
+
/// Is for a symmetric kernel
|
| 95 |
+
BlasMode BlasMode_ = BlasMode::kGemm
|
| 96 |
+
>
|
| 97 |
+
struct DefaultEpilogueComplexTensorOpBlas3 {
|
| 98 |
+
|
| 99 |
+
using Shape = Shape_;
|
| 100 |
+
using WarpMmaTensorOp = WarpMmaTensorOp_;
|
| 101 |
+
static int const kPartitionsK = PartitionsK;
|
| 102 |
+
using OutputOp = OutputOp_;
|
| 103 |
+
static int const kElementsPerAccess = ElementsPerAccess;
|
| 104 |
+
using Operator = Operator_;
|
| 105 |
+
static BlasMode const kBlasMode = BlasMode_;
|
| 106 |
+
|
| 107 |
+
using ElementOutput = typename OutputOp::ElementOutput;
|
| 108 |
+
using LayoutC = typename WarpMmaTensorOp::LayoutC;
|
| 109 |
+
using ElementAccumulator = typename WarpMmaTensorOp::ElementC;
|
| 110 |
+
|
| 111 |
+
//
|
| 112 |
+
// Thread map
|
| 113 |
+
//
|
| 114 |
+
|
| 115 |
+
using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp<
|
| 116 |
+
Shape,
|
| 117 |
+
typename WarpMmaTensorOp::Shape,
|
| 118 |
+
kPartitionsK,
|
| 119 |
+
ElementOutput,
|
| 120 |
+
kElementsPerAccess
|
| 121 |
+
>::Type;
|
| 122 |
+
|
| 123 |
+
using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorBlas3<
|
| 124 |
+
OutputTileThreadMap,
|
| 125 |
+
ElementOutput
|
| 126 |
+
, kBlasMode
|
| 127 |
+
>;
|
| 128 |
+
|
| 129 |
+
using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorComplexTensorOp<
|
| 130 |
+
typename WarpMmaTensorOp::Shape,
|
| 131 |
+
typename WarpMmaTensorOp::Policy::Operator::Shape,
|
| 132 |
+
typename WarpMmaTensorOp::Policy::Operator::ElementC,
|
| 133 |
+
typename WarpMmaTensorOp::Policy::Operator::FragmentC,
|
| 134 |
+
LayoutC
|
| 135 |
+
>;
|
| 136 |
+
|
| 137 |
+
using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp<
|
| 138 |
+
typename WarpMmaTensorOp::Shape,
|
| 139 |
+
typename WarpMmaTensorOp::Policy::Operator::Shape,
|
| 140 |
+
ElementAccumulator,
|
| 141 |
+
LayoutC
|
| 142 |
+
>;
|
| 143 |
+
|
| 144 |
+
using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<
|
| 145 |
+
typename OutputTileThreadMap::CompactedThreadMap,
|
| 146 |
+
ElementAccumulator
|
| 147 |
+
>;
|
| 148 |
+
|
| 149 |
+
/// Hard-coded padding elements added
|
| 150 |
+
using Padding = cutlass::MatrixShape<0, 0>;
|
| 151 |
+
|
| 152 |
+
//
|
| 153 |
+
// Define the epilogue
|
| 154 |
+
//
|
| 155 |
+
using Epilogue = cutlass::epilogue::threadblock::Epilogue<
|
| 156 |
+
Shape,
|
| 157 |
+
WarpMmaTensorOp,
|
| 158 |
+
kPartitionsK,
|
| 159 |
+
OutputTileIterator,
|
| 160 |
+
AccumulatorFragmentIterator,
|
| 161 |
+
WarpTileIterator,
|
| 162 |
+
SharedLoadIterator,
|
| 163 |
+
OutputOp,
|
| 164 |
+
Padding
|
| 165 |
+
>;
|
| 166 |
+
};
|
| 167 |
+
|
| 168 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 169 |
+
/// Partial specialization and defines sensible defaults for epilogues for complex*complex case
|
| 170 |
+
// 3 real-valued mma operations (Gaussian Complex)
|
| 171 |
+
// A = (ar + j ai), B = (br +j bi), D = AB
|
| 172 |
+
// P1 = (ar + ai) * br, P2 = - ar * (br - bi), P3 = ai * (br + bi)
|
| 173 |
+
// D = dr + j di = (P1 - P3) + j (P1 + P2)
|
| 174 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 175 |
+
template <
|
| 176 |
+
typename Shape_,
|
| 177 |
+
typename WarpMmaTensorOp_,
|
| 178 |
+
int PartitionsK,
|
| 179 |
+
typename OutputOp_,
|
| 180 |
+
int ElementsPerAccess,
|
| 181 |
+
BlasMode BlasMode_
|
| 182 |
+
>
|
| 183 |
+
struct DefaultEpilogueComplexTensorOpBlas3 <Shape_, WarpMmaTensorOp_, PartitionsK,
|
| 184 |
+
OutputOp_, ElementsPerAccess,
|
| 185 |
+
arch::OpMultiplyAddGaussianComplex
|
| 186 |
+
, BlasMode_
|
| 187 |
+
> {
|
| 188 |
+
|
| 189 |
+
using Shape = Shape_;
|
| 190 |
+
using WarpMmaTensorOp = WarpMmaTensorOp_;
|
| 191 |
+
static int const kPartitionsK = PartitionsK;
|
| 192 |
+
using OutputOp = OutputOp_;
|
| 193 |
+
static int const kElementsPerAccess = ElementsPerAccess;
|
| 194 |
+
using Operator = arch::OpMultiplyAddGaussianComplex;
|
| 195 |
+
static BlasMode const kBlasMode = BlasMode_;
|
| 196 |
+
|
| 197 |
+
using ElementOutput = typename OutputOp::ElementOutput;
|
| 198 |
+
using LayoutC = typename WarpMmaTensorOp::LayoutC;
|
| 199 |
+
using ElementAccumulator = typename WarpMmaTensorOp::ElementC;
|
| 200 |
+
|
| 201 |
+
//
|
| 202 |
+
// Thread map
|
| 203 |
+
//
|
| 204 |
+
|
| 205 |
+
using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp<
|
| 206 |
+
Shape,
|
| 207 |
+
typename WarpMmaTensorOp::Shape,
|
| 208 |
+
kPartitionsK,
|
| 209 |
+
ElementOutput,
|
| 210 |
+
kElementsPerAccess
|
| 211 |
+
>::Type;
|
| 212 |
+
|
| 213 |
+
using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorBlas3<
|
| 214 |
+
OutputTileThreadMap,
|
| 215 |
+
ElementOutput,
|
| 216 |
+
kBlasMode
|
| 217 |
+
>;
|
| 218 |
+
|
| 219 |
+
using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorGaussianComplexTensorOp<
|
| 220 |
+
typename WarpMmaTensorOp::Shape,
|
| 221 |
+
typename WarpMmaTensorOp::Policy::Operator::Shape,
|
| 222 |
+
typename WarpMmaTensorOp::Policy::Operator::ElementC,
|
| 223 |
+
typename WarpMmaTensorOp::Policy::Operator::FragmentC,
|
| 224 |
+
LayoutC
|
| 225 |
+
>;
|
| 226 |
+
|
| 227 |
+
using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp<
|
| 228 |
+
typename WarpMmaTensorOp::Shape,
|
| 229 |
+
typename WarpMmaTensorOp::Policy::Operator::Shape,
|
| 230 |
+
ElementAccumulator,
|
| 231 |
+
LayoutC
|
| 232 |
+
>;
|
| 233 |
+
|
| 234 |
+
using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<
|
| 235 |
+
typename OutputTileThreadMap::CompactedThreadMap,
|
| 236 |
+
ElementAccumulator
|
| 237 |
+
>;
|
| 238 |
+
|
| 239 |
+
/// Hard-coded padding elements added
|
| 240 |
+
using Padding = cutlass::MatrixShape<0, 0>;
|
| 241 |
+
|
| 242 |
+
//
|
| 243 |
+
// Define the epilogue
|
| 244 |
+
//
|
| 245 |
+
using Epilogue = cutlass::epilogue::threadblock::Epilogue<
|
| 246 |
+
Shape,
|
| 247 |
+
WarpMmaTensorOp,
|
| 248 |
+
kPartitionsK,
|
| 249 |
+
OutputTileIterator,
|
| 250 |
+
AccumulatorFragmentIterator,
|
| 251 |
+
WarpTileIterator,
|
| 252 |
+
SharedLoadIterator,
|
| 253 |
+
OutputOp,
|
| 254 |
+
Padding
|
| 255 |
+
>;
|
| 256 |
+
};
|
| 257 |
+
|
| 258 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 259 |
+
|
| 260 |
+
} // namespace threadblock
|
| 261 |
+
} // namespace epilogue
|
| 262 |
+
} // namespace cutlass
|
| 263 |
+
|
| 264 |
+
////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_direct_store.h
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Direct store epilogue
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 38 |
+
|
| 39 |
+
#include "cutlass/epilogue/threadblock/epilogue_direct_store.h"
|
| 40 |
+
#include "cutlass/epilogue/threadblock/direct_store_epilogue_iterator.h"
|
| 41 |
+
|
| 42 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 43 |
+
|
| 44 |
+
namespace cutlass {
|
| 45 |
+
namespace epilogue {
|
| 46 |
+
namespace threadblock {
|
| 47 |
+
|
| 48 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 49 |
+
|
| 50 |
+
/// Given a properly constructed epilogue, returns a direct store epilogue
|
| 51 |
+
template <typename EpilogueTensorOp>
|
| 52 |
+
struct DefaultEpilogueDirectStore {
|
| 53 |
+
|
| 54 |
+
using OutputTileIterator = DirectStoreEpilogueIterator<typename EpilogueTensorOp::OutputTileIterator::Element>;
|
| 55 |
+
|
| 56 |
+
using Epilogue = EpilogueDirectStore<
|
| 57 |
+
typename EpilogueTensorOp::Shape,
|
| 58 |
+
typename EpilogueTensorOp::WarpMmaOperator,
|
| 59 |
+
EpilogueTensorOp::kPartitionsK,
|
| 60 |
+
OutputTileIterator,
|
| 61 |
+
typename EpilogueTensorOp::AccumulatorFragmentIterator,
|
| 62 |
+
typename EpilogueTensorOp::WarpTileIterator,
|
| 63 |
+
typename EpilogueTensorOp::SharedLoadIterator,
|
| 64 |
+
typename EpilogueTensorOp::OutputOp
|
| 65 |
+
>;
|
| 66 |
+
};
|
| 67 |
+
|
| 68 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 69 |
+
|
| 70 |
+
} // namespace threadblock
|
| 71 |
+
} // namespace epilogue
|
| 72 |
+
} // namespace cutlass
|
| 73 |
+
|
| 74 |
+
////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_planar_complex.h
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Constructs a default epilogue for planar complex outputs.
|
| 33 |
+
|
| 34 |
+
This template reuses components for real-valued epilogues and applies them to planar complex
|
| 35 |
+
output matrices.
|
| 36 |
+
|
| 37 |
+
*/
|
| 38 |
+
|
| 39 |
+
#pragma once
|
| 40 |
+
|
| 41 |
+
#include "cutlass/cutlass.h"
|
| 42 |
+
#include "cutlass/numeric_types.h"
|
| 43 |
+
#include "cutlass/array.h"
|
| 44 |
+
#include "cutlass/array_planar_complex.h"
|
| 45 |
+
|
| 46 |
+
#include "cutlass/arch/arch.h"
|
| 47 |
+
|
| 48 |
+
#include "cutlass/epilogue/thread/linear_combination_planar_complex.h"
|
| 49 |
+
#include "cutlass/epilogue/threadblock/default_epilogue_simt.h"
|
| 50 |
+
#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
|
| 51 |
+
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
|
| 52 |
+
|
| 53 |
+
#include "cutlass/epilogue/threadblock/epilogue_planar_complex.h"
|
| 54 |
+
|
| 55 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 56 |
+
|
| 57 |
+
namespace cutlass {
|
| 58 |
+
namespace epilogue {
|
| 59 |
+
namespace threadblock {
|
| 60 |
+
|
| 61 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 62 |
+
|
| 63 |
+
/// Defines sensible defaults for epilogues.
|
| 64 |
+
template <
|
| 65 |
+
typename ThreadblockShape_,
|
| 66 |
+
typename WarpMma_,
|
| 67 |
+
typename OpcodeClass_,
|
| 68 |
+
typename ArchTag_,
|
| 69 |
+
int PartitionsK,
|
| 70 |
+
typename OutputOp_,
|
| 71 |
+
int ElementsPerAccess
|
| 72 |
+
>
|
| 73 |
+
struct DefaultEpiloguePlanarComplex;
|
| 74 |
+
|
| 75 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 76 |
+
|
| 77 |
+
/// Defines sensible defaults for epilogues.
|
| 78 |
+
template <
|
| 79 |
+
typename ThreadblockShape_,
|
| 80 |
+
typename WarpMmaOperator_,
|
| 81 |
+
int PartitionsK,
|
| 82 |
+
typename OutputOp_,
|
| 83 |
+
int ElementsPerAccess
|
| 84 |
+
>
|
| 85 |
+
struct DefaultEpiloguePlanarComplex<
|
| 86 |
+
ThreadblockShape_,
|
| 87 |
+
WarpMmaOperator_,
|
| 88 |
+
arch::OpClassTensorOp,
|
| 89 |
+
arch::Sm70,
|
| 90 |
+
PartitionsK,
|
| 91 |
+
OutputOp_,
|
| 92 |
+
ElementsPerAccess> {
|
| 93 |
+
|
| 94 |
+
using RealEpilogue = DefaultEpilogueVoltaTensorOp<
|
| 95 |
+
ThreadblockShape_,
|
| 96 |
+
WarpMmaOperator_,
|
| 97 |
+
PartitionsK,
|
| 98 |
+
OutputOp_,
|
| 99 |
+
ElementsPerAccess
|
| 100 |
+
>;
|
| 101 |
+
|
| 102 |
+
using Epilogue = EpiloguePlanarComplex<
|
| 103 |
+
ThreadblockShape_,
|
| 104 |
+
WarpMmaOperator_,
|
| 105 |
+
PartitionsK,
|
| 106 |
+
typename RealEpilogue::OutputTileIterator,
|
| 107 |
+
typename RealEpilogue::AccumulatorFragmentIterator,
|
| 108 |
+
typename RealEpilogue::WarpTileIterator,
|
| 109 |
+
typename RealEpilogue::SharedLoadIterator,
|
| 110 |
+
OutputOp_,
|
| 111 |
+
typename RealEpilogue::Padding
|
| 112 |
+
>;
|
| 113 |
+
};
|
| 114 |
+
|
| 115 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 116 |
+
|
| 117 |
+
/// Defines sensible defaults for epilogues.
|
| 118 |
+
template <
|
| 119 |
+
typename ThreadblockShape_,
|
| 120 |
+
typename WarpMmaOperator_,
|
| 121 |
+
int PartitionsK,
|
| 122 |
+
typename OutputOp_,
|
| 123 |
+
int ElementsPerAccess
|
| 124 |
+
>
|
| 125 |
+
struct DefaultEpiloguePlanarComplex<
|
| 126 |
+
ThreadblockShape_,
|
| 127 |
+
WarpMmaOperator_,
|
| 128 |
+
arch::OpClassTensorOp,
|
| 129 |
+
arch::Sm75,
|
| 130 |
+
PartitionsK,
|
| 131 |
+
OutputOp_,
|
| 132 |
+
ElementsPerAccess> {
|
| 133 |
+
|
| 134 |
+
using RealEpilogue = DefaultEpilogueTensorOp<
|
| 135 |
+
ThreadblockShape_,
|
| 136 |
+
WarpMmaOperator_,
|
| 137 |
+
PartitionsK,
|
| 138 |
+
OutputOp_,
|
| 139 |
+
ElementsPerAccess
|
| 140 |
+
>;
|
| 141 |
+
|
| 142 |
+
using Epilogue = EpiloguePlanarComplex<
|
| 143 |
+
ThreadblockShape_,
|
| 144 |
+
WarpMmaOperator_,
|
| 145 |
+
PartitionsK,
|
| 146 |
+
typename RealEpilogue::OutputTileIterator,
|
| 147 |
+
typename RealEpilogue::AccumulatorFragmentIterator,
|
| 148 |
+
typename RealEpilogue::WarpTileIterator,
|
| 149 |
+
typename RealEpilogue::SharedLoadIterator,
|
| 150 |
+
OutputOp_,
|
| 151 |
+
typename RealEpilogue::Padding
|
| 152 |
+
>;
|
| 153 |
+
};
|
| 154 |
+
|
| 155 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 156 |
+
|
| 157 |
+
/// Defines sensible defaults for epilogues.
|
| 158 |
+
template <
|
| 159 |
+
typename ThreadblockShape_,
|
| 160 |
+
typename WarpMmaOperator_,
|
| 161 |
+
int PartitionsK,
|
| 162 |
+
typename OutputOp_,
|
| 163 |
+
int ElementsPerAccess
|
| 164 |
+
>
|
| 165 |
+
struct DefaultEpiloguePlanarComplex<
|
| 166 |
+
ThreadblockShape_,
|
| 167 |
+
WarpMmaOperator_,
|
| 168 |
+
arch::OpClassTensorOp,
|
| 169 |
+
arch::Sm80,
|
| 170 |
+
PartitionsK,
|
| 171 |
+
OutputOp_,
|
| 172 |
+
ElementsPerAccess> {
|
| 173 |
+
|
| 174 |
+
using RealEpilogue = DefaultEpilogueTensorOp<
|
| 175 |
+
ThreadblockShape_,
|
| 176 |
+
WarpMmaOperator_,
|
| 177 |
+
PartitionsK,
|
| 178 |
+
OutputOp_,
|
| 179 |
+
ElementsPerAccess
|
| 180 |
+
>;
|
| 181 |
+
|
| 182 |
+
using Epilogue = EpiloguePlanarComplex<
|
| 183 |
+
ThreadblockShape_,
|
| 184 |
+
WarpMmaOperator_,
|
| 185 |
+
PartitionsK,
|
| 186 |
+
typename RealEpilogue::OutputTileIterator,
|
| 187 |
+
typename RealEpilogue::AccumulatorFragmentIterator,
|
| 188 |
+
typename RealEpilogue::WarpTileIterator,
|
| 189 |
+
typename RealEpilogue::SharedLoadIterator,
|
| 190 |
+
OutputOp_,
|
| 191 |
+
typename RealEpilogue::Padding
|
| 192 |
+
>;
|
| 193 |
+
};
|
| 194 |
+
|
| 195 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 196 |
+
|
| 197 |
+
/// Defines sensible defaults for epilogues.
|
| 198 |
+
template <
|
| 199 |
+
typename ThreadblockShape_,
|
| 200 |
+
typename WarpMmaOperator_,
|
| 201 |
+
typename ArchTag_,
|
| 202 |
+
int PartitionsK,
|
| 203 |
+
typename OutputOp_,
|
| 204 |
+
int ElementsPerAccess
|
| 205 |
+
>
|
| 206 |
+
struct DefaultEpiloguePlanarComplex<
|
| 207 |
+
ThreadblockShape_,
|
| 208 |
+
WarpMmaOperator_,
|
| 209 |
+
arch::OpClassSimt,
|
| 210 |
+
ArchTag_,
|
| 211 |
+
PartitionsK,
|
| 212 |
+
OutputOp_,
|
| 213 |
+
ElementsPerAccess> {
|
| 214 |
+
|
| 215 |
+
using RealEpilogue = DefaultEpilogueSimt<
|
| 216 |
+
ThreadblockShape_,
|
| 217 |
+
WarpMmaOperator_,
|
| 218 |
+
OutputOp_,
|
| 219 |
+
ElementsPerAccess
|
| 220 |
+
>;
|
| 221 |
+
|
| 222 |
+
using Epilogue = EpiloguePlanarComplex<
|
| 223 |
+
ThreadblockShape_,
|
| 224 |
+
WarpMmaOperator_,
|
| 225 |
+
PartitionsK,
|
| 226 |
+
typename RealEpilogue::OutputTileIterator,
|
| 227 |
+
typename RealEpilogue::AccumulatorFragmentIterator,
|
| 228 |
+
typename RealEpilogue::WarpTileIterator,
|
| 229 |
+
typename RealEpilogue::SharedLoadIterator,
|
| 230 |
+
OutputOp_,
|
| 231 |
+
typename RealEpilogue::Padding
|
| 232 |
+
>;
|
| 233 |
+
};
|
| 234 |
+
|
| 235 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 236 |
+
|
| 237 |
+
} // namespace threadblock
|
| 238 |
+
} // namespace epilogue
|
| 239 |
+
} // namespace cutlass
|
| 240 |
+
|
| 241 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_simt.h
ADDED
|
@@ -0,0 +1,443 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Epilogue for threadblock scoped GEMMs using SIMT.
|
| 33 |
+
|
| 34 |
+
The epilogue rearranges the result of a matrix product through shared memory to match canonical
|
| 35 |
+
tensor layouts in global memory. Epilogues support conversion and reduction operations.
|
| 36 |
+
|
| 37 |
+
*/
|
| 38 |
+
|
| 39 |
+
#pragma once
|
| 40 |
+
|
| 41 |
+
#include "cutlass/cutlass.h"
|
| 42 |
+
#include "cutlass/numeric_types.h"
|
| 43 |
+
#include "cutlass/array.h"
|
| 44 |
+
|
| 45 |
+
#include "cutlass/arch/mma.h"
|
| 46 |
+
|
| 47 |
+
#include "cutlass/gemm/gemm.h"
|
| 48 |
+
#include "cutlass/gemm/warp/mma.h"
|
| 49 |
+
|
| 50 |
+
#include "cutlass/epilogue/thread/linear_combination.h"
|
| 51 |
+
#include "cutlass/epilogue/thread/linear_combination_clamp.h"
|
| 52 |
+
#include "cutlass/epilogue/thread/linear_combination_relu.h"
|
| 53 |
+
#include "cutlass/epilogue/thread/linear_combination_gelu.h"
|
| 54 |
+
#include "cutlass/epilogue/thread/linear_combination_sigmoid.h"
|
| 55 |
+
#include "cutlass/epilogue/thread/linear_combination_planar_complex.h"
|
| 56 |
+
#include "cutlass/epilogue/thread/conversion_op.h"
|
| 57 |
+
#include "cutlass/epilogue/thread/reduction_op.h"
|
| 58 |
+
|
| 59 |
+
#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h"
|
| 60 |
+
|
| 61 |
+
#include "cutlass/epilogue/warp/fragment_iterator_simt.h"
|
| 62 |
+
#include "cutlass/epilogue/warp/tile_iterator_simt.h"
|
| 63 |
+
#include "cutlass/epilogue/threadblock/default_thread_map_simt.h"
|
| 64 |
+
#include "cutlass/transform/pitch_linear_thread_map.h"
|
| 65 |
+
|
| 66 |
+
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
|
| 67 |
+
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_conv.h"
|
| 68 |
+
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h"
|
| 69 |
+
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h"
|
| 70 |
+
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_direct_conv.h"
|
| 71 |
+
#include "cutlass/epilogue/threadblock/shared_load_iterator.h"
|
| 72 |
+
#include "cutlass/epilogue/threadblock/shared_load_iterator_pitch_linear.h"
|
| 73 |
+
#include "cutlass/epilogue/threadblock/epilogue.h"
|
| 74 |
+
#include "cutlass/epilogue/threadblock/epilogue_depthwise.h"
|
| 75 |
+
|
| 76 |
+
#include "cutlass/layout/permute.h"
|
| 77 |
+
|
| 78 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 79 |
+
|
| 80 |
+
namespace cutlass {
|
| 81 |
+
namespace epilogue {
|
| 82 |
+
namespace threadblock {
|
| 83 |
+
|
| 84 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 85 |
+
|
| 86 |
+
/// Defines sensible defaults for epilogues for SimtOps.
|
| 87 |
+
template <
|
| 88 |
+
typename Shape_,
|
| 89 |
+
typename WarpMmaSimt_,
|
| 90 |
+
typename OutputOp_,
|
| 91 |
+
int ElementsPerAccess,
|
| 92 |
+
bool ScatterD = false,
|
| 93 |
+
typename PermuteDLayout = layout::NoPermute,
|
| 94 |
+
conv::StrideSupport StrideSupport = conv::StrideSupport::kUnity,
|
| 95 |
+
int Rank = 4
|
| 96 |
+
>
|
| 97 |
+
struct DefaultEpilogueSimt {
|
| 98 |
+
|
| 99 |
+
using Shape = Shape_;
|
| 100 |
+
using WarpMmaSimt = WarpMmaSimt_;
|
| 101 |
+
using OutputOp = OutputOp_;
|
| 102 |
+
static int const kElementsPerAccess = ElementsPerAccess;
|
| 103 |
+
static const int kPartitionsK = Shape::kK / WarpMmaSimt::Shape::kK;
|
| 104 |
+
|
| 105 |
+
using ElementOutput = typename OutputOp::ElementOutput;
|
| 106 |
+
using LayoutC = typename WarpMmaSimt::LayoutC;
|
| 107 |
+
using ElementAccumulator = typename WarpMmaSimt::ElementC;
|
| 108 |
+
static conv::StrideSupport const kStrideSupport = StrideSupport;
|
| 109 |
+
static int const kRank = Rank;
|
| 110 |
+
|
| 111 |
+
//
|
| 112 |
+
// Thread map
|
| 113 |
+
//
|
| 114 |
+
|
| 115 |
+
using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapSimt<
|
| 116 |
+
Shape,
|
| 117 |
+
typename WarpMmaSimt::Shape,
|
| 118 |
+
typename WarpMmaSimt::Policy,
|
| 119 |
+
kPartitionsK,
|
| 120 |
+
ElementOutput,
|
| 121 |
+
kElementsPerAccess
|
| 122 |
+
>::Type;
|
| 123 |
+
|
| 124 |
+
static bool const UseCUDAStore = platform::is_same<ElementOutput, double>::value;
|
| 125 |
+
|
| 126 |
+
using PackedOutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
|
| 127 |
+
OutputTileThreadMap,
|
| 128 |
+
ElementOutput,
|
| 129 |
+
ScatterD,
|
| 130 |
+
PermuteDLayout,
|
| 131 |
+
UseCUDAStore
|
| 132 |
+
>;
|
| 133 |
+
|
| 134 |
+
using StridedOutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorConv<
|
| 135 |
+
OutputTileThreadMap,
|
| 136 |
+
ElementOutput,
|
| 137 |
+
ScatterD,
|
| 138 |
+
PermuteDLayout,
|
| 139 |
+
UseCUDAStore,
|
| 140 |
+
kRank
|
| 141 |
+
>;
|
| 142 |
+
|
| 143 |
+
using OutputTileIterator = typename platform::conditional<StrideSupport == cutlass::conv::StrideSupport::kUnity,
|
| 144 |
+
PackedOutputTileIterator,
|
| 145 |
+
StridedOutputTileIterator>::type;
|
| 146 |
+
|
| 147 |
+
using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorSimt<
|
| 148 |
+
typename WarpMmaSimt::Shape,
|
| 149 |
+
typename WarpMmaSimt::ThreadMma,
|
| 150 |
+
layout::RowMajor,
|
| 151 |
+
typename WarpMmaSimt::Policy
|
| 152 |
+
>;
|
| 153 |
+
|
| 154 |
+
using WarpTileIterator = cutlass::epilogue::warp::TileIteratorSimt<
|
| 155 |
+
typename WarpMmaSimt::Shape,
|
| 156 |
+
typename WarpMmaSimt::ThreadMma,
|
| 157 |
+
ElementAccumulator,
|
| 158 |
+
layout::RowMajor,
|
| 159 |
+
typename WarpMmaSimt::Policy
|
| 160 |
+
>;
|
| 161 |
+
|
| 162 |
+
using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<
|
| 163 |
+
typename OutputTileThreadMap::CompactedThreadMap,
|
| 164 |
+
ElementAccumulator
|
| 165 |
+
>;
|
| 166 |
+
|
| 167 |
+
/// Hard-coded padding elements added
|
| 168 |
+
using Padding = typename WarpTileIterator::Padding;
|
| 169 |
+
|
| 170 |
+
//
|
| 171 |
+
// Define the epilogue
|
| 172 |
+
//
|
| 173 |
+
using Epilogue = cutlass::epilogue::threadblock::Epilogue<
|
| 174 |
+
Shape,
|
| 175 |
+
WarpMmaSimt,
|
| 176 |
+
kPartitionsK,
|
| 177 |
+
OutputTileIterator,
|
| 178 |
+
AccumulatorFragmentIterator,
|
| 179 |
+
WarpTileIterator,
|
| 180 |
+
SharedLoadIterator,
|
| 181 |
+
OutputOp,
|
| 182 |
+
Padding
|
| 183 |
+
>;
|
| 184 |
+
};
|
| 185 |
+
|
| 186 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 187 |
+
|
| 188 |
+
/// Defines sensible defaults for epilogues for SimtOps.
|
| 189 |
+
template <
|
| 190 |
+
typename Shape_,
|
| 191 |
+
typename WarpMmaSimt_,
|
| 192 |
+
typename OutputOp_,
|
| 193 |
+
int ElementsPerAccess
|
| 194 |
+
>
|
| 195 |
+
struct DefaultEpilogueSimtStridedDgrad {
|
| 196 |
+
|
| 197 |
+
using Shape = Shape_;
|
| 198 |
+
using WarpMmaSimt = WarpMmaSimt_;
|
| 199 |
+
using OutputOp = OutputOp_;
|
| 200 |
+
static int const kElementsPerAccess = ElementsPerAccess;
|
| 201 |
+
static const int kPartitionsK = Shape::kK / WarpMmaSimt::Shape::kK;
|
| 202 |
+
|
| 203 |
+
using ElementOutput = typename OutputOp::ElementOutput;
|
| 204 |
+
using LayoutC = typename WarpMmaSimt::LayoutC;
|
| 205 |
+
using ElementAccumulator = typename WarpMmaSimt::ElementC;
|
| 206 |
+
|
| 207 |
+
//
|
| 208 |
+
// Thread map
|
| 209 |
+
//
|
| 210 |
+
|
| 211 |
+
using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapSimt<
|
| 212 |
+
Shape,
|
| 213 |
+
typename WarpMmaSimt::Shape,
|
| 214 |
+
typename WarpMmaSimt::Policy,
|
| 215 |
+
kPartitionsK,
|
| 216 |
+
ElementOutput,
|
| 217 |
+
kElementsPerAccess
|
| 218 |
+
>::Type;
|
| 219 |
+
|
| 220 |
+
using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorStridedDgrad<
|
| 221 |
+
OutputTileThreadMap,
|
| 222 |
+
ElementOutput
|
| 223 |
+
>;
|
| 224 |
+
|
| 225 |
+
using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorSimt<
|
| 226 |
+
typename WarpMmaSimt::Shape,
|
| 227 |
+
typename WarpMmaSimt::ThreadMma,
|
| 228 |
+
layout::RowMajor,
|
| 229 |
+
typename WarpMmaSimt::Policy
|
| 230 |
+
>;
|
| 231 |
+
|
| 232 |
+
using WarpTileIterator = cutlass::epilogue::warp::TileIteratorSimt<
|
| 233 |
+
typename WarpMmaSimt::Shape,
|
| 234 |
+
typename WarpMmaSimt::ThreadMma,
|
| 235 |
+
ElementAccumulator,
|
| 236 |
+
layout::RowMajor,
|
| 237 |
+
typename WarpMmaSimt::Policy
|
| 238 |
+
>;
|
| 239 |
+
|
| 240 |
+
using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<
|
| 241 |
+
typename OutputTileThreadMap::CompactedThreadMap,
|
| 242 |
+
ElementAccumulator
|
| 243 |
+
>;
|
| 244 |
+
|
| 245 |
+
/// Hard-coded padding elements added
|
| 246 |
+
using Padding = typename WarpTileIterator::Padding;
|
| 247 |
+
|
| 248 |
+
//
|
| 249 |
+
// Define the epilogue
|
| 250 |
+
//
|
| 251 |
+
using Epilogue = cutlass::epilogue::threadblock::Epilogue<
|
| 252 |
+
Shape,
|
| 253 |
+
WarpMmaSimt,
|
| 254 |
+
kPartitionsK,
|
| 255 |
+
OutputTileIterator,
|
| 256 |
+
AccumulatorFragmentIterator,
|
| 257 |
+
WarpTileIterator,
|
| 258 |
+
SharedLoadIterator,
|
| 259 |
+
OutputOp,
|
| 260 |
+
Padding
|
| 261 |
+
>;
|
| 262 |
+
};
|
| 263 |
+
|
| 264 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 265 |
+
|
| 266 |
+
/// Defines sensible defaults for epilogues for SimtOps.
|
| 267 |
+
template <
|
| 268 |
+
int Rank,
|
| 269 |
+
typename Shape_,
|
| 270 |
+
typename WarpMmaSimt_,
|
| 271 |
+
typename OutputOp_,
|
| 272 |
+
int ElementsPerAccess
|
| 273 |
+
>
|
| 274 |
+
struct DefaultEpilogueSimtAffineRankN {
|
| 275 |
+
|
| 276 |
+
using Shape = Shape_;
|
| 277 |
+
using WarpMmaSimt = WarpMmaSimt_;
|
| 278 |
+
using OutputOp = OutputOp_;
|
| 279 |
+
static int const kElementsPerAccess = ElementsPerAccess;
|
| 280 |
+
static const int kPartitionsK = Shape::kK / WarpMmaSimt::Shape::kK;
|
| 281 |
+
|
| 282 |
+
using ElementOutput = typename OutputOp::ElementOutput;
|
| 283 |
+
using LayoutC = typename WarpMmaSimt::LayoutC;
|
| 284 |
+
using ElementAccumulator = typename WarpMmaSimt::ElementC;
|
| 285 |
+
|
| 286 |
+
//
|
| 287 |
+
// Thread map
|
| 288 |
+
//
|
| 289 |
+
|
| 290 |
+
using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapSimt<
|
| 291 |
+
Shape,
|
| 292 |
+
typename WarpMmaSimt::Shape,
|
| 293 |
+
typename WarpMmaSimt::Policy,
|
| 294 |
+
kPartitionsK,
|
| 295 |
+
ElementOutput,
|
| 296 |
+
kElementsPerAccess
|
| 297 |
+
>::Type;
|
| 298 |
+
|
| 299 |
+
using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorAffineRankN<
|
| 300 |
+
OutputTileThreadMap,
|
| 301 |
+
ElementOutput,
|
| 302 |
+
Rank
|
| 303 |
+
>;
|
| 304 |
+
|
| 305 |
+
using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorSimt<
|
| 306 |
+
typename WarpMmaSimt::Shape,
|
| 307 |
+
typename WarpMmaSimt::ThreadMma,
|
| 308 |
+
layout::RowMajor,
|
| 309 |
+
typename WarpMmaSimt::Policy
|
| 310 |
+
>;
|
| 311 |
+
|
| 312 |
+
using WarpTileIterator = cutlass::epilogue::warp::TileIteratorSimt<
|
| 313 |
+
typename WarpMmaSimt::Shape,
|
| 314 |
+
typename WarpMmaSimt::ThreadMma,
|
| 315 |
+
ElementAccumulator,
|
| 316 |
+
layout::RowMajor,
|
| 317 |
+
typename WarpMmaSimt::Policy
|
| 318 |
+
>;
|
| 319 |
+
|
| 320 |
+
using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<
|
| 321 |
+
typename OutputTileThreadMap::CompactedThreadMap,
|
| 322 |
+
ElementAccumulator
|
| 323 |
+
>;
|
| 324 |
+
|
| 325 |
+
/// Hard-coded padding elements added
|
| 326 |
+
using Padding = typename WarpTileIterator::Padding;
|
| 327 |
+
|
| 328 |
+
//
|
| 329 |
+
// Define the epilogue
|
| 330 |
+
//
|
| 331 |
+
using Epilogue = cutlass::epilogue::threadblock::Epilogue<
|
| 332 |
+
Shape,
|
| 333 |
+
WarpMmaSimt,
|
| 334 |
+
kPartitionsK,
|
| 335 |
+
OutputTileIterator,
|
| 336 |
+
AccumulatorFragmentIterator,
|
| 337 |
+
WarpTileIterator,
|
| 338 |
+
SharedLoadIterator,
|
| 339 |
+
OutputOp,
|
| 340 |
+
Padding
|
| 341 |
+
>;
|
| 342 |
+
};
|
| 343 |
+
|
| 344 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 345 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 346 |
+
|
| 347 |
+
/// Defines sensible defaults for epilogues for SimtOps.
|
| 348 |
+
template <typename Shape_, // ThreadBlock Shape
|
| 349 |
+
typename WarpMmaSimt_, // mma_depthwise_simt
|
| 350 |
+
typename OutputOp_,
|
| 351 |
+
int ElementsPerAccess_,
|
| 352 |
+
typename ThreadOutputShape_ = cutlass::conv::TensorNHWCShape<1, 1, 1, 1>,
|
| 353 |
+
typename ThreadBlockOutputShape_ = cutlass::conv::TensorNHWCShape<1, 1, 1, 1> >
|
| 354 |
+
struct DefaultDirectConvEpilogueSimt {
|
| 355 |
+
using Shape = Shape_;
|
| 356 |
+
using WarpMmaSimt = WarpMmaSimt_;
|
| 357 |
+
using WarpShape = typename WarpMmaSimt::Shape;
|
| 358 |
+
using OutputOp = OutputOp_;
|
| 359 |
+
using ThreadOutputShape = ThreadOutputShape_;
|
| 360 |
+
using ThreadBlockOutputShape = ThreadBlockOutputShape_;
|
| 361 |
+
static int const kElementsPerAccess = ElementsPerAccess_;
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
using ElementOutput = typename OutputOp::ElementOutput;
|
| 365 |
+
using LayoutC = typename WarpMmaSimt::LayoutC;
|
| 366 |
+
using ElementAccumulator = typename WarpMmaSimt::ElementC;
|
| 367 |
+
|
| 368 |
+
/// Number of threads total
|
| 369 |
+
using WarpCount = gemm::GemmShape<
|
| 370 |
+
Shape::kM / WarpShape::kM,
|
| 371 |
+
Shape::kN / WarpShape::kN
|
| 372 |
+
>;
|
| 373 |
+
|
| 374 |
+
static int const kWarpSize = cutlass::gemm::warp::WarpSize<arch::OpClassSimt>::value;
|
| 375 |
+
|
| 376 |
+
static int const kThreads = WarpCount::kCount * kWarpSize;
|
| 377 |
+
|
| 378 |
+
//
|
| 379 |
+
// Thread map
|
| 380 |
+
//
|
| 381 |
+
|
| 382 |
+
using OutputTileThreadMap = cutlass::transform::PitchLinearStripminedThreadMap<
|
| 383 |
+
layout::PitchLinearShape<ThreadBlockOutputShape::kC, ThreadBlockOutputShape::kNHW>,
|
| 384 |
+
kThreads,
|
| 385 |
+
kElementsPerAccess
|
| 386 |
+
>;
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorDirectConv<
|
| 390 |
+
OutputTileThreadMap,
|
| 391 |
+
ElementOutput,
|
| 392 |
+
ThreadOutputShape,
|
| 393 |
+
ThreadBlockOutputShape
|
| 394 |
+
>;
|
| 395 |
+
|
| 396 |
+
using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorSimt<
|
| 397 |
+
typename WarpMmaSimt::Shape,
|
| 398 |
+
typename WarpMmaSimt::ThreadMma,
|
| 399 |
+
layout::RowMajor,
|
| 400 |
+
typename WarpMmaSimt::Policy
|
| 401 |
+
>;
|
| 402 |
+
|
| 403 |
+
using WarpTileIterator = cutlass::epilogue::warp::TileIteratorSimtDirect2dConv<
|
| 404 |
+
typename WarpMmaSimt::Shape,
|
| 405 |
+
ThreadOutputShape,
|
| 406 |
+
ThreadBlockOutputShape,
|
| 407 |
+
typename WarpMmaSimt::ThreadMma,
|
| 408 |
+
ElementAccumulator,
|
| 409 |
+
layout::RowMajor,
|
| 410 |
+
typename WarpMmaSimt::Policy
|
| 411 |
+
>;
|
| 412 |
+
|
| 413 |
+
using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIteratorPitchLinear<
|
| 414 |
+
OutputTileThreadMap,
|
| 415 |
+
ElementAccumulator
|
| 416 |
+
>;
|
| 417 |
+
|
| 418 |
+
/// Hard-coded padding elements added
|
| 419 |
+
using Padding = typename WarpTileIterator::Padding;
|
| 420 |
+
//
|
| 421 |
+
// Define the epilogue
|
| 422 |
+
//
|
| 423 |
+
using Epilogue = cutlass::epilogue::threadblock::EpilogueDepthwise<
|
| 424 |
+
Shape,
|
| 425 |
+
ThreadOutputShape,
|
| 426 |
+
ThreadBlockOutputShape,
|
| 427 |
+
WarpMmaSimt,
|
| 428 |
+
OutputTileIterator,
|
| 429 |
+
AccumulatorFragmentIterator,
|
| 430 |
+
WarpTileIterator,
|
| 431 |
+
SharedLoadIterator,
|
| 432 |
+
OutputOp,
|
| 433 |
+
Padding
|
| 434 |
+
>;
|
| 435 |
+
};
|
| 436 |
+
|
| 437 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 438 |
+
|
| 439 |
+
} // namespace threadblock
|
| 440 |
+
} // namespace epilogue
|
| 441 |
+
} // namespace cutlass
|
| 442 |
+
|
| 443 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h
ADDED
|
@@ -0,0 +1,904 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
|
| 33 |
+
|
| 34 |
+
The epilogue rearranges the result of a matrix product through shared memory to match canonical
|
| 35 |
+
tensor layouts in global memory. Epilogues support conversion and reduction operations.
|
| 36 |
+
|
| 37 |
+
*/
|
| 38 |
+
|
| 39 |
+
#pragma once
|
| 40 |
+
|
| 41 |
+
#include "cutlass/cutlass.h"
|
| 42 |
+
#include "cutlass/numeric_types.h"
|
| 43 |
+
#include "cutlass/array.h"
|
| 44 |
+
|
| 45 |
+
#include "cutlass/platform/platform.h"
|
| 46 |
+
|
| 47 |
+
#include "cutlass/gemm/gemm.h"
|
| 48 |
+
|
| 49 |
+
#include "cutlass/epilogue/thread/linear_combination.h"
|
| 50 |
+
#include "cutlass/epilogue/thread/linear_combination_clamp.h"
|
| 51 |
+
#include "cutlass/epilogue/thread/linear_combination_relu.h"
|
| 52 |
+
#include "cutlass/epilogue/thread/linear_combination_relu0.h"
|
| 53 |
+
#include "cutlass/epilogue/thread/linear_combination_gelu.h"
|
| 54 |
+
#include "cutlass/epilogue/thread/linear_combination_sigmoid.h"
|
| 55 |
+
#include "cutlass/epilogue/thread/linear_combination_hardswish.h"
|
| 56 |
+
#include "cutlass/epilogue/thread/linear_combination_planar_complex.h"
|
| 57 |
+
|
| 58 |
+
#include "cutlass/epilogue/thread/conversion_op.h"
|
| 59 |
+
#include "cutlass/epilogue/thread/reduction_op.h"
|
| 60 |
+
|
| 61 |
+
#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h"
|
| 62 |
+
|
| 63 |
+
#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h"
|
| 64 |
+
#include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h"
|
| 65 |
+
#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h"
|
| 66 |
+
#include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h"
|
| 67 |
+
#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h"
|
| 68 |
+
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
|
| 69 |
+
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_conv.h"
|
| 70 |
+
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h"
|
| 71 |
+
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h"
|
| 72 |
+
#include "cutlass/epilogue/threadblock/shared_load_iterator.h"
|
| 73 |
+
#include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h"
|
| 74 |
+
|
| 75 |
+
#include "cutlass/epilogue/threadblock/epilogue.h"
|
| 76 |
+
#include "cutlass/epilogue/threadblock/interleaved_epilogue.h"
|
| 77 |
+
|
| 78 |
+
#include "cutlass/layout/permute.h"
|
| 79 |
+
|
| 80 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 81 |
+
|
| 82 |
+
namespace cutlass {
|
| 83 |
+
namespace epilogue {
|
| 84 |
+
namespace threadblock {
|
| 85 |
+
|
| 86 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 87 |
+
|
| 88 |
+
namespace detail {
|
| 89 |
+
|
| 90 |
+
template <
|
| 91 |
+
typename ElementOutput,
|
| 92 |
+
typename ElementAccumulator,
|
| 93 |
+
int ElementsPerAccess,
|
| 94 |
+
typename ThreadblockShape,
|
| 95 |
+
typename WarpShape,
|
| 96 |
+
typename InstructionShape,
|
| 97 |
+
typename ThreadMap
|
| 98 |
+
>
|
| 99 |
+
struct DefaultIteratorsTensorOp {
|
| 100 |
+
|
| 101 |
+
using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp<
|
| 102 |
+
WarpShape,
|
| 103 |
+
InstructionShape,
|
| 104 |
+
ElementAccumulator,
|
| 105 |
+
layout::RowMajor
|
| 106 |
+
>;
|
| 107 |
+
|
| 108 |
+
using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<
|
| 109 |
+
ThreadMap,
|
| 110 |
+
ElementAccumulator
|
| 111 |
+
>;
|
| 112 |
+
|
| 113 |
+
static int const kFragmentsPerIteration = 1;
|
| 114 |
+
};
|
| 115 |
+
|
| 116 |
+
/// Partial specialization for float <= float x 4
|
| 117 |
+
template <
|
| 118 |
+
typename ThreadblockShape,
|
| 119 |
+
typename WarpShape,
|
| 120 |
+
typename InstructionShape,
|
| 121 |
+
typename ThreadMap
|
| 122 |
+
>
|
| 123 |
+
struct DefaultIteratorsTensorOp<float, float, 4, ThreadblockShape, WarpShape, InstructionShape, ThreadMap> {
|
| 124 |
+
|
| 125 |
+
using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp<
|
| 126 |
+
WarpShape,
|
| 127 |
+
InstructionShape,
|
| 128 |
+
float,
|
| 129 |
+
layout::RowMajor
|
| 130 |
+
>;
|
| 131 |
+
|
| 132 |
+
using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<
|
| 133 |
+
ThreadMap,
|
| 134 |
+
float
|
| 135 |
+
>;
|
| 136 |
+
|
| 137 |
+
static int const kFragmentsPerIteration = 2;
|
| 138 |
+
};
|
| 139 |
+
|
| 140 |
+
/// Partial specialization for int32_t <= int32_t
|
| 141 |
+
template <
|
| 142 |
+
int ElementsPerAccess,
|
| 143 |
+
typename ThreadblockShape,
|
| 144 |
+
typename WarpShape,
|
| 145 |
+
typename InstructionShape,
|
| 146 |
+
typename ThreadMap
|
| 147 |
+
>
|
| 148 |
+
struct DefaultIteratorsTensorOp<int32_t, int32_t, ElementsPerAccess, ThreadblockShape, WarpShape, InstructionShape, ThreadMap> {
|
| 149 |
+
|
| 150 |
+
using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp<
|
| 151 |
+
WarpShape,
|
| 152 |
+
InstructionShape,
|
| 153 |
+
int32_t,
|
| 154 |
+
layout::RowMajor
|
| 155 |
+
>;
|
| 156 |
+
|
| 157 |
+
using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<
|
| 158 |
+
ThreadMap,
|
| 159 |
+
int32_t
|
| 160 |
+
>;
|
| 161 |
+
|
| 162 |
+
static int const kFragmentsPerIteration = 1;
|
| 163 |
+
};
|
| 164 |
+
|
| 165 |
+
/// Partial specialization for float <= int32_t
|
| 166 |
+
template <
|
| 167 |
+
int ElementsPerAccess,
|
| 168 |
+
typename ThreadblockShape,
|
| 169 |
+
typename WarpShape,
|
| 170 |
+
typename InstructionShape,
|
| 171 |
+
typename ThreadMap
|
| 172 |
+
>
|
| 173 |
+
struct DefaultIteratorsTensorOp<float, int32_t, ElementsPerAccess, ThreadblockShape, WarpShape, InstructionShape, ThreadMap> {
|
| 174 |
+
|
| 175 |
+
using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp<
|
| 176 |
+
WarpShape,
|
| 177 |
+
InstructionShape,
|
| 178 |
+
int32_t,
|
| 179 |
+
layout::RowMajor
|
| 180 |
+
>;
|
| 181 |
+
|
| 182 |
+
using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<
|
| 183 |
+
ThreadMap,
|
| 184 |
+
int32_t
|
| 185 |
+
>;
|
| 186 |
+
|
| 187 |
+
static int const kFragmentsPerIteration = 1;
|
| 188 |
+
};
|
| 189 |
+
|
| 190 |
+
/// Partial specialization for half <= float x 8 epilogues avoids shared memory bank conflicts.
|
| 191 |
+
template <
|
| 192 |
+
typename ThreadblockShape,
|
| 193 |
+
typename WarpShape,
|
| 194 |
+
typename InstructionShape,
|
| 195 |
+
typename ThreadMap
|
| 196 |
+
>
|
| 197 |
+
struct DefaultIteratorsTensorOp<
|
| 198 |
+
half_t,
|
| 199 |
+
float,
|
| 200 |
+
8,
|
| 201 |
+
ThreadblockShape,
|
| 202 |
+
WarpShape,
|
| 203 |
+
InstructionShape,
|
| 204 |
+
ThreadMap> {
|
| 205 |
+
|
| 206 |
+
using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOpMixed<
|
| 207 |
+
WarpShape,
|
| 208 |
+
InstructionShape,
|
| 209 |
+
float,
|
| 210 |
+
32,
|
| 211 |
+
16,
|
| 212 |
+
8,
|
| 213 |
+
8
|
| 214 |
+
>;
|
| 215 |
+
|
| 216 |
+
using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIteratorMixed<
|
| 217 |
+
ThreadMap,
|
| 218 |
+
float,
|
| 219 |
+
32,
|
| 220 |
+
16,
|
| 221 |
+
8,
|
| 222 |
+
8
|
| 223 |
+
>;
|
| 224 |
+
|
| 225 |
+
static int const kFragmentsPerIteration = 2;
|
| 226 |
+
};
|
| 227 |
+
|
| 228 |
+
/// Partial specialization for half <= int32_t x 8 epilogues avoids shared memory bank conflicts.
|
| 229 |
+
template <
|
| 230 |
+
typename ThreadblockShape,
|
| 231 |
+
typename WarpShape,
|
| 232 |
+
typename InstructionShape,
|
| 233 |
+
typename ThreadMap
|
| 234 |
+
>
|
| 235 |
+
struct DefaultIteratorsTensorOp<
|
| 236 |
+
bfloat16_t,
|
| 237 |
+
int32_t,
|
| 238 |
+
8,
|
| 239 |
+
ThreadblockShape,
|
| 240 |
+
WarpShape,
|
| 241 |
+
InstructionShape,
|
| 242 |
+
ThreadMap> {
|
| 243 |
+
|
| 244 |
+
using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOpMixed<
|
| 245 |
+
WarpShape,
|
| 246 |
+
InstructionShape,
|
| 247 |
+
int32_t,
|
| 248 |
+
32,
|
| 249 |
+
16,
|
| 250 |
+
8,
|
| 251 |
+
8
|
| 252 |
+
>;
|
| 253 |
+
|
| 254 |
+
using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIteratorMixed<
|
| 255 |
+
ThreadMap,
|
| 256 |
+
int32_t,
|
| 257 |
+
32,
|
| 258 |
+
16,
|
| 259 |
+
8,
|
| 260 |
+
8
|
| 261 |
+
>;
|
| 262 |
+
|
| 263 |
+
static int const kFragmentsPerIteration = 2;
|
| 264 |
+
};
|
| 265 |
+
|
| 266 |
+
/// Partial specialization for half <= int32_t x 8 epilogues avoids shared memory bank conflicts.
|
| 267 |
+
template <
|
| 268 |
+
typename ThreadblockShape,
|
| 269 |
+
typename WarpShape,
|
| 270 |
+
typename InstructionShape,
|
| 271 |
+
typename ThreadMap
|
| 272 |
+
>
|
| 273 |
+
struct DefaultIteratorsTensorOp<
|
| 274 |
+
half_t,
|
| 275 |
+
int32_t,
|
| 276 |
+
8,
|
| 277 |
+
ThreadblockShape,
|
| 278 |
+
WarpShape,
|
| 279 |
+
InstructionShape,
|
| 280 |
+
ThreadMap> {
|
| 281 |
+
|
| 282 |
+
using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOpMixed<
|
| 283 |
+
WarpShape,
|
| 284 |
+
InstructionShape,
|
| 285 |
+
int32_t,
|
| 286 |
+
32,
|
| 287 |
+
16,
|
| 288 |
+
8,
|
| 289 |
+
8
|
| 290 |
+
>;
|
| 291 |
+
|
| 292 |
+
using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIteratorMixed<
|
| 293 |
+
ThreadMap,
|
| 294 |
+
int32_t,
|
| 295 |
+
32,
|
| 296 |
+
16,
|
| 297 |
+
8,
|
| 298 |
+
8
|
| 299 |
+
>;
|
| 300 |
+
|
| 301 |
+
static int const kFragmentsPerIteration = 2;
|
| 302 |
+
};
|
| 303 |
+
|
| 304 |
+
/// Partial specialization for int8/int4b_t <= int32 x 16/8 epilogues avoids shared memory bank conflicts.
|
| 305 |
+
/// Threadblock::kN = 256 still has bank conflicts.
|
| 306 |
+
template <
|
| 307 |
+
typename ElementOutput,
|
| 308 |
+
int ElementsPerAccess,
|
| 309 |
+
typename ThreadblockShape,
|
| 310 |
+
typename WarpShape,
|
| 311 |
+
typename InstructionShape,
|
| 312 |
+
typename ThreadMap
|
| 313 |
+
>
|
| 314 |
+
struct DefaultIteratorsTensorOp<
|
| 315 |
+
ElementOutput,
|
| 316 |
+
int32_t,
|
| 317 |
+
ElementsPerAccess,
|
| 318 |
+
ThreadblockShape,
|
| 319 |
+
WarpShape,
|
| 320 |
+
InstructionShape,
|
| 321 |
+
ThreadMap> {
|
| 322 |
+
|
| 323 |
+
static_assert(platform::is_same<ElementOutput, cutlass::int4b_t>::value ||
|
| 324 |
+
platform::is_same<ElementOutput, cutlass::uint4b_t>::value ||
|
| 325 |
+
platform::is_same<ElementOutput, int8_t>::value ||
|
| 326 |
+
platform::is_same<ElementOutput, uint8_t>::value,
|
| 327 |
+
"ElementOutput needs to be 4 or 8 bit (unsigned) int.");
|
| 328 |
+
|
| 329 |
+
static_assert((ElementsPerAccess == 16 || ElementsPerAccess == 8 || ElementsPerAccess == 4),
|
| 330 |
+
"ElementsPerAccess needs to be 16 or 8.");
|
| 331 |
+
|
| 332 |
+
using WarpTileIteratorMixed = cutlass::epilogue::warp::TileIteratorTensorOpMixed<
|
| 333 |
+
WarpShape,
|
| 334 |
+
InstructionShape,
|
| 335 |
+
int32_t,
|
| 336 |
+
32,
|
| 337 |
+
cutlass::sizeof_bits<ElementOutput>::value,
|
| 338 |
+
ElementsPerAccess,
|
| 339 |
+
8
|
| 340 |
+
>;
|
| 341 |
+
|
| 342 |
+
using WarpTileIteratorNotMixed = cutlass::epilogue::warp::TileIteratorTensorOp<
|
| 343 |
+
WarpShape,
|
| 344 |
+
InstructionShape,
|
| 345 |
+
int32_t,
|
| 346 |
+
layout::RowMajor
|
| 347 |
+
>;
|
| 348 |
+
|
| 349 |
+
using WarpTileIterator = typename platform::conditional<
|
| 350 |
+
(ThreadblockShape::kN == 256) || (ThreadblockShape::kN == 128 && ElementsPerAccess == 8) || (ElementsPerAccess == 4),
|
| 351 |
+
WarpTileIteratorNotMixed,
|
| 352 |
+
WarpTileIteratorMixed>::type;
|
| 353 |
+
|
| 354 |
+
using SharedLoadIteratorMixed = cutlass::epilogue::threadblock::SharedLoadIteratorMixed<
|
| 355 |
+
ThreadMap,
|
| 356 |
+
int32_t,
|
| 357 |
+
32,
|
| 358 |
+
cutlass::sizeof_bits<ElementOutput>::value,
|
| 359 |
+
ElementsPerAccess,
|
| 360 |
+
8
|
| 361 |
+
>;
|
| 362 |
+
|
| 363 |
+
using SharedLoadIteratorNotMixed = cutlass::epilogue::threadblock::SharedLoadIterator<
|
| 364 |
+
ThreadMap,
|
| 365 |
+
int32_t
|
| 366 |
+
>;
|
| 367 |
+
|
| 368 |
+
using SharedLoadIterator = typename platform::conditional<
|
| 369 |
+
(ThreadblockShape::kN == 256) || (ThreadblockShape::kN == 128 && ElementsPerAccess == 8) || (ElementsPerAccess == 4),
|
| 370 |
+
SharedLoadIteratorNotMixed,
|
| 371 |
+
SharedLoadIteratorMixed>::type;
|
| 372 |
+
|
| 373 |
+
static int const kFragmentsPerIteration = 1;
|
| 374 |
+
};
|
| 375 |
+
|
| 376 |
+
/// Partial specialization for float_e4m3_t <= float x 16/8 epilogues avoids shared memory bank conflicts.
|
| 377 |
+
/// Threadblock::kN = 256 still has bank conflicts.
|
| 378 |
+
template <
|
| 379 |
+
int ElementsPerAccess,
|
| 380 |
+
typename ThreadblockShape,
|
| 381 |
+
typename WarpShape,
|
| 382 |
+
typename InstructionShape,
|
| 383 |
+
typename ThreadMap
|
| 384 |
+
>
|
| 385 |
+
struct DefaultIteratorsTensorOp<
|
| 386 |
+
cutlass::float_e4m3_t,
|
| 387 |
+
float,
|
| 388 |
+
ElementsPerAccess,
|
| 389 |
+
ThreadblockShape,
|
| 390 |
+
WarpShape,
|
| 391 |
+
InstructionShape,
|
| 392 |
+
ThreadMap> {
|
| 393 |
+
|
| 394 |
+
using ElementOutput = cutlass::float_e4m3_t;
|
| 395 |
+
|
| 396 |
+
static_assert((ElementsPerAccess == 16 || ElementsPerAccess == 8 || ElementsPerAccess == 4),
|
| 397 |
+
"ElementsPerAccess needs to be 16 or 8.");
|
| 398 |
+
|
| 399 |
+
using WarpTileIteratorMixed = cutlass::epilogue::warp::TileIteratorTensorOpMixed<
|
| 400 |
+
WarpShape,
|
| 401 |
+
InstructionShape,
|
| 402 |
+
float,
|
| 403 |
+
32,
|
| 404 |
+
cutlass::sizeof_bits<ElementOutput>::value,
|
| 405 |
+
ElementsPerAccess,
|
| 406 |
+
8
|
| 407 |
+
>;
|
| 408 |
+
|
| 409 |
+
using WarpTileIteratorNotMixed = cutlass::epilogue::warp::TileIteratorTensorOp<
|
| 410 |
+
WarpShape,
|
| 411 |
+
InstructionShape,
|
| 412 |
+
float,
|
| 413 |
+
layout::RowMajor
|
| 414 |
+
>;
|
| 415 |
+
|
| 416 |
+
using WarpTileIterator = typename platform::conditional<
|
| 417 |
+
(ThreadblockShape::kN == 256) || (ThreadblockShape::kN == 128 && ElementsPerAccess == 8) || (ElementsPerAccess == 4),
|
| 418 |
+
WarpTileIteratorNotMixed,
|
| 419 |
+
WarpTileIteratorMixed>::type;
|
| 420 |
+
|
| 421 |
+
using SharedLoadIteratorMixed = cutlass::epilogue::threadblock::SharedLoadIteratorMixed<
|
| 422 |
+
ThreadMap,
|
| 423 |
+
float,
|
| 424 |
+
32,
|
| 425 |
+
cutlass::sizeof_bits<ElementOutput>::value,
|
| 426 |
+
ElementsPerAccess,
|
| 427 |
+
8
|
| 428 |
+
>;
|
| 429 |
+
|
| 430 |
+
using SharedLoadIteratorNotMixed = cutlass::epilogue::threadblock::SharedLoadIterator<
|
| 431 |
+
ThreadMap,
|
| 432 |
+
float
|
| 433 |
+
>;
|
| 434 |
+
|
| 435 |
+
using SharedLoadIterator = typename platform::conditional<
|
| 436 |
+
(ThreadblockShape::kN == 256) || (ThreadblockShape::kN == 128 && ElementsPerAccess == 8) || (ElementsPerAccess == 4),
|
| 437 |
+
SharedLoadIteratorNotMixed,
|
| 438 |
+
SharedLoadIteratorMixed>::type;
|
| 439 |
+
|
| 440 |
+
static int const kFragmentsPerIteration = 1;
|
| 441 |
+
};
|
| 442 |
+
|
| 443 |
+
/// Partial specialization for float_e5m2_t <= float x 16/8 epilogues avoids shared memory bank conflicts.
|
| 444 |
+
/// Threadblock::kN = 256 still has bank conflicts.
|
| 445 |
+
template <
|
| 446 |
+
int ElementsPerAccess,
|
| 447 |
+
typename ThreadblockShape,
|
| 448 |
+
typename WarpShape,
|
| 449 |
+
typename InstructionShape,
|
| 450 |
+
typename ThreadMap
|
| 451 |
+
>
|
| 452 |
+
struct DefaultIteratorsTensorOp<
|
| 453 |
+
cutlass::float_e5m2_t,
|
| 454 |
+
float,
|
| 455 |
+
ElementsPerAccess,
|
| 456 |
+
ThreadblockShape,
|
| 457 |
+
WarpShape,
|
| 458 |
+
InstructionShape,
|
| 459 |
+
ThreadMap> {
|
| 460 |
+
|
| 461 |
+
using ElementOutput = cutlass::float_e5m2_t;
|
| 462 |
+
|
| 463 |
+
static_assert((ElementsPerAccess == 16 || ElementsPerAccess == 8 || ElementsPerAccess == 4),
|
| 464 |
+
"ElementsPerAccess needs to be 16 or 8.");
|
| 465 |
+
|
| 466 |
+
using WarpTileIteratorMixed = cutlass::epilogue::warp::TileIteratorTensorOpMixed<
|
| 467 |
+
WarpShape,
|
| 468 |
+
InstructionShape,
|
| 469 |
+
float,
|
| 470 |
+
32,
|
| 471 |
+
cutlass::sizeof_bits<ElementOutput>::value,
|
| 472 |
+
ElementsPerAccess,
|
| 473 |
+
8
|
| 474 |
+
>;
|
| 475 |
+
|
| 476 |
+
using WarpTileIteratorNotMixed = cutlass::epilogue::warp::TileIteratorTensorOp<
|
| 477 |
+
WarpShape,
|
| 478 |
+
InstructionShape,
|
| 479 |
+
float,
|
| 480 |
+
layout::RowMajor
|
| 481 |
+
>;
|
| 482 |
+
|
| 483 |
+
using WarpTileIterator = typename platform::conditional<
|
| 484 |
+
(ThreadblockShape::kN == 256) || (ThreadblockShape::kN == 128 && ElementsPerAccess == 8) || (ElementsPerAccess == 4),
|
| 485 |
+
WarpTileIteratorNotMixed,
|
| 486 |
+
WarpTileIteratorMixed>::type;
|
| 487 |
+
|
| 488 |
+
using SharedLoadIteratorMixed = cutlass::epilogue::threadblock::SharedLoadIteratorMixed<
|
| 489 |
+
ThreadMap,
|
| 490 |
+
float,
|
| 491 |
+
32,
|
| 492 |
+
cutlass::sizeof_bits<ElementOutput>::value,
|
| 493 |
+
ElementsPerAccess,
|
| 494 |
+
8
|
| 495 |
+
>;
|
| 496 |
+
|
| 497 |
+
using SharedLoadIteratorNotMixed = cutlass::epilogue::threadblock::SharedLoadIterator<
|
| 498 |
+
ThreadMap,
|
| 499 |
+
float
|
| 500 |
+
>;
|
| 501 |
+
|
| 502 |
+
using SharedLoadIterator = typename platform::conditional<
|
| 503 |
+
(ThreadblockShape::kN == 256) || (ThreadblockShape::kN == 128 && ElementsPerAccess == 8) || (ElementsPerAccess == 4),
|
| 504 |
+
SharedLoadIteratorNotMixed,
|
| 505 |
+
SharedLoadIteratorMixed>::type;
|
| 506 |
+
|
| 507 |
+
static int const kFragmentsPerIteration = 1;
|
| 508 |
+
};
|
| 509 |
+
|
| 510 |
+
} // namespace detail
|
| 511 |
+
|
| 512 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 513 |
+
|
| 514 |
+
/// Defines sensible defaults for epilogues for TensorOps.
|
| 515 |
+
template <
|
| 516 |
+
typename Shape_,
|
| 517 |
+
typename WarpMmaTensorOp_,
|
| 518 |
+
int PartitionsK,
|
| 519 |
+
typename OutputOp_,
|
| 520 |
+
int ElementsPerAccess,
|
| 521 |
+
bool ScatterD = false,
|
| 522 |
+
typename PermuteDLayout = layout::NoPermute,
|
| 523 |
+
conv::StrideSupport StrideSupport = conv::StrideSupport::kUnity,
|
| 524 |
+
int Rank = 4
|
| 525 |
+
>
|
| 526 |
+
struct DefaultEpilogueTensorOp {
|
| 527 |
+
|
| 528 |
+
using Shape = Shape_;
|
| 529 |
+
using WarpMmaTensorOp = WarpMmaTensorOp_;
|
| 530 |
+
static int const kPartitionsK = PartitionsK;
|
| 531 |
+
using OutputOp = OutputOp_;
|
| 532 |
+
static int const kElementsPerAccess = ElementsPerAccess;
|
| 533 |
+
|
| 534 |
+
using ElementOutput = typename OutputOp::ElementOutput;
|
| 535 |
+
using LayoutC = typename WarpMmaTensorOp::LayoutC;
|
| 536 |
+
using ElementAccumulator = typename WarpMmaTensorOp::ElementC;
|
| 537 |
+
static conv::StrideSupport const kStrideSupport = StrideSupport;
|
| 538 |
+
static int const kRank = Rank;
|
| 539 |
+
|
| 540 |
+
//
|
| 541 |
+
// Thread map
|
| 542 |
+
//
|
| 543 |
+
|
| 544 |
+
using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp<
|
| 545 |
+
Shape,
|
| 546 |
+
typename WarpMmaTensorOp::Shape,
|
| 547 |
+
kPartitionsK,
|
| 548 |
+
ElementOutput,
|
| 549 |
+
kElementsPerAccess
|
| 550 |
+
>::Type;
|
| 551 |
+
|
| 552 |
+
static bool const UseCUDAStore = platform::is_same<ElementOutput, double>::value;
|
| 553 |
+
|
| 554 |
+
using PackedOutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
|
| 555 |
+
OutputTileThreadMap,
|
| 556 |
+
ElementOutput,
|
| 557 |
+
ScatterD,
|
| 558 |
+
PermuteDLayout,
|
| 559 |
+
UseCUDAStore
|
| 560 |
+
>;
|
| 561 |
+
|
| 562 |
+
using StridedOutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorConv<
|
| 563 |
+
OutputTileThreadMap,
|
| 564 |
+
ElementOutput,
|
| 565 |
+
ScatterD,
|
| 566 |
+
PermuteDLayout,
|
| 567 |
+
UseCUDAStore,
|
| 568 |
+
kRank
|
| 569 |
+
>;
|
| 570 |
+
|
| 571 |
+
using OutputTileIterator = typename platform::conditional<StrideSupport == cutlass::conv::StrideSupport::kUnity,
|
| 572 |
+
PackedOutputTileIterator,
|
| 573 |
+
StridedOutputTileIterator>::type;
|
| 574 |
+
|
| 575 |
+
using AccumulatorFragmentIterator = typename platform::conditional<is_complex<ElementOutput>::value,
|
| 576 |
+
cutlass::epilogue::warp::FragmentIteratorComplexTensorOp<
|
| 577 |
+
typename WarpMmaTensorOp::Shape,
|
| 578 |
+
typename WarpMmaTensorOp::Policy::Operator::Shape,
|
| 579 |
+
typename WarpMmaTensorOp::Policy::Operator::ElementC,
|
| 580 |
+
typename WarpMmaTensorOp::Policy::Operator::FragmentC,
|
| 581 |
+
LayoutC>,
|
| 582 |
+
cutlass::epilogue::warp::FragmentIteratorTensorOp<
|
| 583 |
+
typename WarpMmaTensorOp::Shape,
|
| 584 |
+
typename WarpMmaTensorOp::Policy::Operator::Shape,
|
| 585 |
+
typename WarpMmaTensorOp::Policy::Operator::ElementC,
|
| 586 |
+
typename WarpMmaTensorOp::Policy::Operator::FragmentC,
|
| 587 |
+
LayoutC> >::type;
|
| 588 |
+
|
| 589 |
+
/// Support several implementations depending on structure of epilogue
|
| 590 |
+
using DefaultIterators = detail::DefaultIteratorsTensorOp<
|
| 591 |
+
ElementOutput,
|
| 592 |
+
ElementAccumulator,
|
| 593 |
+
kElementsPerAccess,
|
| 594 |
+
Shape,
|
| 595 |
+
typename WarpMmaTensorOp::Shape,
|
| 596 |
+
typename WarpMmaTensorOp::Policy::Operator::Shape,
|
| 597 |
+
typename OutputTileThreadMap::CompactedThreadMap
|
| 598 |
+
>;
|
| 599 |
+
|
| 600 |
+
using WarpTileIterator = typename DefaultIterators::WarpTileIterator;
|
| 601 |
+
using SharedLoadIterator = typename DefaultIterators::SharedLoadIterator;
|
| 602 |
+
|
| 603 |
+
/// Hard-coded padding elements added
|
| 604 |
+
using Padding = cutlass::MatrixShape<0, 64 / sizeof_bits<ElementAccumulator>::value * 4>;
|
| 605 |
+
|
| 606 |
+
static int const kFragmentsPerIteration = (kPartitionsK == 1 ? DefaultIterators::kFragmentsPerIteration : 1);
|
| 607 |
+
|
| 608 |
+
//
|
| 609 |
+
// Define the epilogue
|
| 610 |
+
//
|
| 611 |
+
using Epilogue = cutlass::epilogue::threadblock::Epilogue<
|
| 612 |
+
Shape,
|
| 613 |
+
WarpMmaTensorOp,
|
| 614 |
+
kPartitionsK,
|
| 615 |
+
OutputTileIterator,
|
| 616 |
+
AccumulatorFragmentIterator,
|
| 617 |
+
WarpTileIterator,
|
| 618 |
+
SharedLoadIterator,
|
| 619 |
+
OutputOp,
|
| 620 |
+
Padding,
|
| 621 |
+
kFragmentsPerIteration
|
| 622 |
+
>;
|
| 623 |
+
};
|
| 624 |
+
|
| 625 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 626 |
+
|
| 627 |
+
/// Defines sensible defaults for epilogues for TensorOps.
|
| 628 |
+
template <
|
| 629 |
+
typename Shape_,
|
| 630 |
+
typename WarpMmaTensorOp_,
|
| 631 |
+
int PartitionsK,
|
| 632 |
+
typename OutputOp_,
|
| 633 |
+
int ElementsPerAccess
|
| 634 |
+
>
|
| 635 |
+
struct DefaultEpilogueTensorOpStridedDgrad {
|
| 636 |
+
|
| 637 |
+
using Shape = Shape_;
|
| 638 |
+
using WarpMmaTensorOp = WarpMmaTensorOp_;
|
| 639 |
+
static int const kPartitionsK = PartitionsK;
|
| 640 |
+
using OutputOp = OutputOp_;
|
| 641 |
+
static int const kElementsPerAccess = ElementsPerAccess;
|
| 642 |
+
|
| 643 |
+
using ElementOutput = typename OutputOp::ElementOutput;
|
| 644 |
+
using LayoutC = typename WarpMmaTensorOp::LayoutC;
|
| 645 |
+
using ElementAccumulator = typename WarpMmaTensorOp::ElementC;
|
| 646 |
+
|
| 647 |
+
//
|
| 648 |
+
// Thread map
|
| 649 |
+
//
|
| 650 |
+
|
| 651 |
+
using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp<
|
| 652 |
+
Shape,
|
| 653 |
+
typename WarpMmaTensorOp::Shape,
|
| 654 |
+
kPartitionsK,
|
| 655 |
+
ElementOutput,
|
| 656 |
+
kElementsPerAccess
|
| 657 |
+
>::Type;
|
| 658 |
+
|
| 659 |
+
using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorStridedDgrad<
|
| 660 |
+
OutputTileThreadMap,
|
| 661 |
+
ElementOutput
|
| 662 |
+
>;
|
| 663 |
+
|
| 664 |
+
using AccumulatorFragmentIterator = typename platform::conditional<is_complex<ElementOutput>::value,
|
| 665 |
+
cutlass::epilogue::warp::FragmentIteratorComplexTensorOp<
|
| 666 |
+
typename WarpMmaTensorOp::Shape,
|
| 667 |
+
typename WarpMmaTensorOp::Policy::Operator::Shape,
|
| 668 |
+
typename WarpMmaTensorOp::Policy::Operator::ElementC,
|
| 669 |
+
typename WarpMmaTensorOp::Policy::Operator::FragmentC,
|
| 670 |
+
LayoutC>,
|
| 671 |
+
cutlass::epilogue::warp::FragmentIteratorTensorOp<
|
| 672 |
+
typename WarpMmaTensorOp::Shape,
|
| 673 |
+
typename WarpMmaTensorOp::Policy::Operator::Shape,
|
| 674 |
+
typename WarpMmaTensorOp::Policy::Operator::ElementC,
|
| 675 |
+
typename WarpMmaTensorOp::Policy::Operator::FragmentC,
|
| 676 |
+
LayoutC> >::type;
|
| 677 |
+
|
| 678 |
+
/// Support several implementations depending on structure of epilogue
|
| 679 |
+
using DefaultIterators = detail::DefaultIteratorsTensorOp<
|
| 680 |
+
ElementOutput,
|
| 681 |
+
ElementAccumulator,
|
| 682 |
+
kElementsPerAccess,
|
| 683 |
+
Shape,
|
| 684 |
+
typename WarpMmaTensorOp::Shape,
|
| 685 |
+
typename WarpMmaTensorOp::Policy::Operator::Shape,
|
| 686 |
+
typename OutputTileThreadMap::CompactedThreadMap
|
| 687 |
+
>;
|
| 688 |
+
|
| 689 |
+
using WarpTileIterator = typename DefaultIterators::WarpTileIterator;
|
| 690 |
+
using SharedLoadIterator = typename DefaultIterators::SharedLoadIterator;
|
| 691 |
+
|
| 692 |
+
/// Hard-coded padding elements added
|
| 693 |
+
using Padding = cutlass::MatrixShape<0, 64 / sizeof_bits<ElementAccumulator>::value * 4>;
|
| 694 |
+
|
| 695 |
+
static int const kFragmentsPerIteration = (kPartitionsK == 1 ? DefaultIterators::kFragmentsPerIteration : 1);
|
| 696 |
+
|
| 697 |
+
//
|
| 698 |
+
// Define the epilogue
|
| 699 |
+
//
|
| 700 |
+
using Epilogue = cutlass::epilogue::threadblock::Epilogue<
|
| 701 |
+
Shape,
|
| 702 |
+
WarpMmaTensorOp,
|
| 703 |
+
kPartitionsK,
|
| 704 |
+
OutputTileIterator,
|
| 705 |
+
AccumulatorFragmentIterator,
|
| 706 |
+
WarpTileIterator,
|
| 707 |
+
SharedLoadIterator,
|
| 708 |
+
OutputOp,
|
| 709 |
+
Padding,
|
| 710 |
+
kFragmentsPerIteration
|
| 711 |
+
>;
|
| 712 |
+
};
|
| 713 |
+
|
| 714 |
+
|
| 715 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 716 |
+
|
| 717 |
+
/// Defines sensible defaults for epilogues for TensorOps.
|
| 718 |
+
template <
|
| 719 |
+
int Rank,
|
| 720 |
+
typename Shape_,
|
| 721 |
+
typename WarpMmaTensorOp_,
|
| 722 |
+
int PartitionsK,
|
| 723 |
+
typename OutputOp_,
|
| 724 |
+
int ElementsPerAccess
|
| 725 |
+
>
|
| 726 |
+
struct DefaultEpilogueTensorOpAffineRankN {
|
| 727 |
+
|
| 728 |
+
using Shape = Shape_;
|
| 729 |
+
using WarpMmaTensorOp = WarpMmaTensorOp_;
|
| 730 |
+
static int const kPartitionsK = PartitionsK;
|
| 731 |
+
using OutputOp = OutputOp_;
|
| 732 |
+
static int const kElementsPerAccess = ElementsPerAccess;
|
| 733 |
+
|
| 734 |
+
using ElementOutput = typename OutputOp::ElementOutput;
|
| 735 |
+
using LayoutC = typename WarpMmaTensorOp::LayoutC;
|
| 736 |
+
using ElementAccumulator = typename WarpMmaTensorOp::ElementC;
|
| 737 |
+
|
| 738 |
+
//
|
| 739 |
+
// Thread map
|
| 740 |
+
//
|
| 741 |
+
|
| 742 |
+
using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp<
|
| 743 |
+
Shape,
|
| 744 |
+
typename WarpMmaTensorOp::Shape,
|
| 745 |
+
kPartitionsK,
|
| 746 |
+
ElementOutput,
|
| 747 |
+
kElementsPerAccess
|
| 748 |
+
>::Type;
|
| 749 |
+
|
| 750 |
+
using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorAffineRankN<
|
| 751 |
+
OutputTileThreadMap,
|
| 752 |
+
ElementOutput,
|
| 753 |
+
Rank
|
| 754 |
+
>;
|
| 755 |
+
|
| 756 |
+
// Map to the row major iterator since the iterator selection for affineN is the same.
|
| 757 |
+
using AccumulatorFragmentIterator = typename platform::conditional<is_complex<ElementOutput>::value,
|
| 758 |
+
cutlass::epilogue::warp::FragmentIteratorComplexTensorOp<
|
| 759 |
+
typename WarpMmaTensorOp::Shape,
|
| 760 |
+
typename WarpMmaTensorOp::Policy::Operator::Shape,
|
| 761 |
+
typename WarpMmaTensorOp::Policy::Operator::ElementC,
|
| 762 |
+
typename WarpMmaTensorOp::Policy::Operator::FragmentC,
|
| 763 |
+
layout::RowMajor>,
|
| 764 |
+
cutlass::epilogue::warp::FragmentIteratorTensorOp<
|
| 765 |
+
typename WarpMmaTensorOp::Shape,
|
| 766 |
+
typename WarpMmaTensorOp::Policy::Operator::Shape,
|
| 767 |
+
typename WarpMmaTensorOp::Policy::Operator::ElementC,
|
| 768 |
+
typename WarpMmaTensorOp::Policy::Operator::FragmentC,
|
| 769 |
+
layout::RowMajor> >::type;
|
| 770 |
+
|
| 771 |
+
/// Support several implementations depending on structure of epilogue
|
| 772 |
+
using DefaultIterators = detail::DefaultIteratorsTensorOp<
|
| 773 |
+
ElementOutput,
|
| 774 |
+
ElementAccumulator,
|
| 775 |
+
kElementsPerAccess,
|
| 776 |
+
Shape,
|
| 777 |
+
typename WarpMmaTensorOp::Shape,
|
| 778 |
+
typename WarpMmaTensorOp::Policy::Operator::Shape,
|
| 779 |
+
typename OutputTileThreadMap::CompactedThreadMap
|
| 780 |
+
>;
|
| 781 |
+
|
| 782 |
+
using WarpTileIterator = typename DefaultIterators::WarpTileIterator;
|
| 783 |
+
using SharedLoadIterator = typename DefaultIterators::SharedLoadIterator;
|
| 784 |
+
|
| 785 |
+
/// Hard-coded padding elements added
|
| 786 |
+
using Padding = cutlass::MatrixShape<0, 64 / sizeof_bits<ElementAccumulator>::value * 4>;
|
| 787 |
+
|
| 788 |
+
static int const kFragmentsPerIteration = (kPartitionsK == 1 ? DefaultIterators::kFragmentsPerIteration : 1);
|
| 789 |
+
|
| 790 |
+
//
|
| 791 |
+
// Define the epilogue
|
| 792 |
+
//
|
| 793 |
+
using Epilogue = cutlass::epilogue::threadblock::Epilogue<
|
| 794 |
+
Shape,
|
| 795 |
+
WarpMmaTensorOp,
|
| 796 |
+
kPartitionsK,
|
| 797 |
+
OutputTileIterator,
|
| 798 |
+
AccumulatorFragmentIterator,
|
| 799 |
+
WarpTileIterator,
|
| 800 |
+
SharedLoadIterator,
|
| 801 |
+
OutputOp,
|
| 802 |
+
Padding,
|
| 803 |
+
kFragmentsPerIteration
|
| 804 |
+
>;
|
| 805 |
+
};
|
| 806 |
+
|
| 807 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 808 |
+
/// Defines sensible defaults for epilogues for TensorOps which uses
|
| 809 |
+
/// intereleaved output layout. For this case, shared memory is not needed.
|
| 810 |
+
template <typename Shape_, typename WarpMmaTensorOp_, int PartitionsK,
|
| 811 |
+
typename OutputOp_, int ElementsPerAccess, int InterleavedK,
|
| 812 |
+
bool isSplitK = false>
|
| 813 |
+
struct DefaultInterleavedEpilogueTensorOp {
|
| 814 |
+
using Shape = Shape_;
|
| 815 |
+
using WarpMmaTensorOp = WarpMmaTensorOp_;
|
| 816 |
+
static int const kPartitionsK = PartitionsK;
|
| 817 |
+
using OutputOp = OutputOp_;
|
| 818 |
+
static int const kElementsPerAccess = ElementsPerAccess;
|
| 819 |
+
|
| 820 |
+
using ElementOutput = typename OutputOp::ElementOutput;
|
| 821 |
+
using LayoutC = typename WarpMmaTensorOp::LayoutC;
|
| 822 |
+
using ElementAccumulator = typename WarpMmaTensorOp::ElementC;
|
| 823 |
+
|
| 824 |
+
//
|
| 825 |
+
// Thread map
|
| 826 |
+
//
|
| 827 |
+
using OutputTileThreadMap = typename cutlass::epilogue::threadblock::
|
| 828 |
+
DefaultInterleavedThreadMapTensorOp<
|
| 829 |
+
Shape, typename WarpMmaTensorOp::Shape, kPartitionsK, ElementOutput,
|
| 830 |
+
kElementsPerAccess, InterleavedK>::Type;
|
| 831 |
+
|
| 832 |
+
using OutputTileIterator =
|
| 833 |
+
cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator<
|
| 834 |
+
OutputTileThreadMap, ElementOutput, InterleavedK>;
|
| 835 |
+
|
| 836 |
+
using AccumulatorFragmentIterator =
|
| 837 |
+
cutlass::epilogue::warp::FragmentIteratorTensorOp<
|
| 838 |
+
typename WarpMmaTensorOp::Shape,
|
| 839 |
+
typename WarpMmaTensorOp::Policy::Operator::Shape,
|
| 840 |
+
typename WarpMmaTensorOp::Policy::Operator::ElementC,
|
| 841 |
+
typename WarpMmaTensorOp::Policy::Operator::FragmentC,
|
| 842 |
+
LayoutC>;
|
| 843 |
+
|
| 844 |
+
//
|
| 845 |
+
// Define the epilogue
|
| 846 |
+
//
|
| 847 |
+
using Epilogue = cutlass::epilogue::threadblock::InterleavedEpilogue<
|
| 848 |
+
Shape, WarpMmaTensorOp, kPartitionsK, OutputTileIterator,
|
| 849 |
+
AccumulatorFragmentIterator, OutputOp, InterleavedK>;
|
| 850 |
+
};
|
| 851 |
+
|
| 852 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 853 |
+
|
| 854 |
+
/// Defines sensible defaults for epilogues for TensorOps which uses
|
| 855 |
+
/// intereleaved output layout. For this case, shared memory is not needed.
|
| 856 |
+
template <typename Shape_, typename WarpMmaTensorOp_, int PartitionsK,
|
| 857 |
+
typename OutputOp_, int ElementsPerAccess, int InterleavedK,
|
| 858 |
+
bool isSplitK = false>
|
| 859 |
+
struct DefaultInterleavedConvEpilogue {
|
| 860 |
+
using Shape = Shape_;
|
| 861 |
+
using WarpMmaTensorOp = WarpMmaTensorOp_;
|
| 862 |
+
static int const kPartitionsK = PartitionsK;
|
| 863 |
+
using OutputOp = OutputOp_;
|
| 864 |
+
static int const kElementsPerAccess = ElementsPerAccess;
|
| 865 |
+
|
| 866 |
+
using ElementOutput = typename OutputOp::ElementOutput;
|
| 867 |
+
using ElementAccumulator = typename WarpMmaTensorOp::ElementC;
|
| 868 |
+
|
| 869 |
+
//
|
| 870 |
+
// Thread map
|
| 871 |
+
//
|
| 872 |
+
using OutputTileThreadMap = typename cutlass::epilogue::threadblock::
|
| 873 |
+
DefaultInterleavedConvThreadMapTensorOp<
|
| 874 |
+
Shape, typename WarpMmaTensorOp::Shape, kPartitionsK, ElementOutput,
|
| 875 |
+
kElementsPerAccess, InterleavedK>::Type;
|
| 876 |
+
|
| 877 |
+
using OutputTileIterator =
|
| 878 |
+
cutlass::epilogue::threadblock::InterleavedConvPredicatedTileIterator<
|
| 879 |
+
OutputTileThreadMap, ElementOutput, InterleavedK>;
|
| 880 |
+
|
| 881 |
+
using AccumulatorFragmentIterator =
|
| 882 |
+
cutlass::epilogue::warp::FragmentIteratorTensorOp<
|
| 883 |
+
typename WarpMmaTensorOp::Shape,
|
| 884 |
+
typename WarpMmaTensorOp::Policy::Operator::Shape,
|
| 885 |
+
typename WarpMmaTensorOp::Policy::Operator::ElementC,
|
| 886 |
+
typename WarpMmaTensorOp::Policy::Operator::FragmentC,
|
| 887 |
+
// can reuse the gemm version here to do element selection
|
| 888 |
+
layout::ColumnMajorInterleaved<InterleavedK>>;
|
| 889 |
+
|
| 890 |
+
//
|
| 891 |
+
// Define the epilogue
|
| 892 |
+
//
|
| 893 |
+
using Epilogue = cutlass::epilogue::threadblock::InterleavedEpilogue<
|
| 894 |
+
Shape, WarpMmaTensorOp, kPartitionsK, OutputTileIterator,
|
| 895 |
+
AccumulatorFragmentIterator, OutputOp, InterleavedK>;
|
| 896 |
+
};
|
| 897 |
+
|
| 898 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 899 |
+
|
| 900 |
+
} // namespace threadblock
|
| 901 |
+
} // namespace epilogue
|
| 902 |
+
} // namespace cutlass
|
| 903 |
+
|
| 904 |
+
////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op_blas3.h
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
|
| 33 |
+
|
| 34 |
+
The epilogue rearranges the result of a matrix product through shared memory to match canonical
|
| 35 |
+
tensor layouts in global memory. Epilogues support conversion and reduction operations.
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
*/
|
| 39 |
+
|
| 40 |
+
#pragma once
|
| 41 |
+
|
| 42 |
+
#include "cutlass/cutlass.h"
|
| 43 |
+
#include "cutlass/numeric_types.h"
|
| 44 |
+
#include "cutlass/array.h"
|
| 45 |
+
|
| 46 |
+
#include "cutlass/gemm/gemm.h"
|
| 47 |
+
|
| 48 |
+
#include "cutlass/epilogue/thread/linear_combination.h"
|
| 49 |
+
#include "cutlass/epilogue/thread/linear_combination_clamp.h"
|
| 50 |
+
#include "cutlass/epilogue/thread/linear_combination_relu.h"
|
| 51 |
+
#include "cutlass/epilogue/thread/linear_combination_gelu.h"
|
| 52 |
+
#include "cutlass/epilogue/thread/linear_combination_sigmoid.h"
|
| 53 |
+
#include "cutlass/epilogue/thread/linear_combination_planar_complex.h"
|
| 54 |
+
|
| 55 |
+
#include "cutlass/epilogue/thread/conversion_op.h"
|
| 56 |
+
#include "cutlass/epilogue/thread/reduction_op.h"
|
| 57 |
+
|
| 58 |
+
#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h"
|
| 59 |
+
|
| 60 |
+
#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h"
|
| 61 |
+
#include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h"
|
| 62 |
+
#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h"
|
| 63 |
+
#include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h"
|
| 64 |
+
#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h"
|
| 65 |
+
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_blas3.h"
|
| 66 |
+
#include "cutlass/epilogue/threadblock/shared_load_iterator.h"
|
| 67 |
+
#include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h"
|
| 68 |
+
|
| 69 |
+
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
|
| 70 |
+
#include "cutlass/epilogue/threadblock/epilogue.h"
|
| 71 |
+
#include "cutlass/epilogue/threadblock/interleaved_epilogue.h"
|
| 72 |
+
|
| 73 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 74 |
+
|
| 75 |
+
namespace cutlass {
|
| 76 |
+
namespace epilogue {
|
| 77 |
+
namespace threadblock {
|
| 78 |
+
|
| 79 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 80 |
+
|
| 81 |
+
/// Defines sensible defaults for epilogues for TensorOps.
|
| 82 |
+
template <
|
| 83 |
+
typename Shape_,
|
| 84 |
+
typename WarpMmaTensorOp_,
|
| 85 |
+
int PartitionsK,
|
| 86 |
+
typename OutputOp_,
|
| 87 |
+
int ElementsPerAccess,
|
| 88 |
+
/// Is for a symmetric kernel
|
| 89 |
+
BlasMode BlasMode_ = BlasMode::kGemm
|
| 90 |
+
>
|
| 91 |
+
struct DefaultEpilogueTensorOpBlas3 {
|
| 92 |
+
|
| 93 |
+
using Shape = Shape_;
|
| 94 |
+
using WarpMmaTensorOp = WarpMmaTensorOp_;
|
| 95 |
+
static int const kPartitionsK = PartitionsK;
|
| 96 |
+
using OutputOp = OutputOp_;
|
| 97 |
+
static int const kElementsPerAccess = ElementsPerAccess;
|
| 98 |
+
static BlasMode const kBlasMode = BlasMode_;
|
| 99 |
+
|
| 100 |
+
using ElementOutput = typename OutputOp::ElementOutput;
|
| 101 |
+
using LayoutC = typename WarpMmaTensorOp::LayoutC;
|
| 102 |
+
using ElementAccumulator = typename WarpMmaTensorOp::ElementC;
|
| 103 |
+
|
| 104 |
+
//
|
| 105 |
+
// Thread map
|
| 106 |
+
//
|
| 107 |
+
|
| 108 |
+
using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp<
|
| 109 |
+
Shape,
|
| 110 |
+
typename WarpMmaTensorOp::Shape,
|
| 111 |
+
kPartitionsK,
|
| 112 |
+
ElementOutput,
|
| 113 |
+
kElementsPerAccess
|
| 114 |
+
>::Type;
|
| 115 |
+
|
| 116 |
+
using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorBlas3<
|
| 117 |
+
OutputTileThreadMap,
|
| 118 |
+
ElementOutput,
|
| 119 |
+
kBlasMode
|
| 120 |
+
>;
|
| 121 |
+
|
| 122 |
+
using AccumulatorFragmentIterator = typename platform::conditional<is_complex<ElementOutput>::value,
|
| 123 |
+
cutlass::epilogue::warp::FragmentIteratorComplexTensorOp<
|
| 124 |
+
typename WarpMmaTensorOp::Shape,
|
| 125 |
+
typename WarpMmaTensorOp::Policy::Operator::Shape,
|
| 126 |
+
typename WarpMmaTensorOp::Policy::Operator::ElementC,
|
| 127 |
+
typename WarpMmaTensorOp::Policy::Operator::FragmentC,
|
| 128 |
+
LayoutC>,
|
| 129 |
+
cutlass::epilogue::warp::FragmentIteratorTensorOp<
|
| 130 |
+
typename WarpMmaTensorOp::Shape,
|
| 131 |
+
typename WarpMmaTensorOp::Policy::Operator::Shape,
|
| 132 |
+
typename WarpMmaTensorOp::Policy::Operator::ElementC,
|
| 133 |
+
typename WarpMmaTensorOp::Policy::Operator::FragmentC,
|
| 134 |
+
LayoutC> >::type;
|
| 135 |
+
|
| 136 |
+
/// Support several implementations depending on structure of epilogue
|
| 137 |
+
using DefaultIterators = detail::DefaultIteratorsTensorOp<
|
| 138 |
+
ElementOutput,
|
| 139 |
+
ElementAccumulator,
|
| 140 |
+
kElementsPerAccess,
|
| 141 |
+
Shape,
|
| 142 |
+
typename WarpMmaTensorOp::Shape,
|
| 143 |
+
typename WarpMmaTensorOp::Policy::Operator::Shape,
|
| 144 |
+
typename OutputTileThreadMap::CompactedThreadMap
|
| 145 |
+
>;
|
| 146 |
+
|
| 147 |
+
using WarpTileIterator = typename DefaultIterators::WarpTileIterator;
|
| 148 |
+
using SharedLoadIterator = typename DefaultIterators::SharedLoadIterator;
|
| 149 |
+
|
| 150 |
+
/// Hard-coded padding elements added
|
| 151 |
+
using Padding = cutlass::MatrixShape<0, 64 / sizeof_bits<ElementAccumulator>::value * 4>;
|
| 152 |
+
|
| 153 |
+
//
|
| 154 |
+
// Define the epilogue
|
| 155 |
+
//
|
| 156 |
+
using Epilogue = cutlass::epilogue::threadblock::Epilogue<
|
| 157 |
+
Shape,
|
| 158 |
+
WarpMmaTensorOp,
|
| 159 |
+
kPartitionsK,
|
| 160 |
+
OutputTileIterator,
|
| 161 |
+
AccumulatorFragmentIterator,
|
| 162 |
+
WarpTileIterator,
|
| 163 |
+
SharedLoadIterator,
|
| 164 |
+
OutputOp,
|
| 165 |
+
Padding
|
| 166 |
+
>;
|
| 167 |
+
};
|
| 168 |
+
|
| 169 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 170 |
+
|
| 171 |
+
} // namespace threadblock
|
| 172 |
+
} // namespace epilogue
|
| 173 |
+
} // namespace cutlass
|
| 174 |
+
|
| 175 |
+
////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h
ADDED
|
@@ -0,0 +1,337 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops on Volta.
|
| 33 |
+
|
| 34 |
+
The epilogue rearranges the result of a matrix product through shared memory to match canonical
|
| 35 |
+
tensor layouts in global memory. Epilogues support conversion and reduction operations.
|
| 36 |
+
|
| 37 |
+
*/
|
| 38 |
+
|
| 39 |
+
#pragma once
|
| 40 |
+
|
| 41 |
+
#include "cutlass/cutlass.h"
|
| 42 |
+
#include "cutlass/numeric_types.h"
|
| 43 |
+
#include "cutlass/array.h"
|
| 44 |
+
|
| 45 |
+
#include "cutlass/gemm/gemm.h"
|
| 46 |
+
|
| 47 |
+
#include "cutlass/epilogue/thread/linear_combination.h"
|
| 48 |
+
#include "cutlass/epilogue/thread/linear_combination_clamp.h"
|
| 49 |
+
#include "cutlass/epilogue/thread/linear_combination_relu.h"
|
| 50 |
+
#include "cutlass/epilogue/thread/linear_combination_gelu.h"
|
| 51 |
+
#include "cutlass/epilogue/thread/linear_combination_sigmoid.h"
|
| 52 |
+
#include "cutlass/epilogue/thread/linear_combination_planar_complex.h"
|
| 53 |
+
|
| 54 |
+
#include "cutlass/epilogue/thread/conversion_op.h"
|
| 55 |
+
#include "cutlass/epilogue/thread/reduction_op.h"
|
| 56 |
+
|
| 57 |
+
#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h"
|
| 58 |
+
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h"
|
| 59 |
+
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
|
| 60 |
+
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h"
|
| 61 |
+
#include "cutlass/epilogue/threadblock/shared_load_iterator.h"
|
| 62 |
+
|
| 63 |
+
#include "cutlass/epilogue/warp/fragment_iterator_volta_tensor_op.h"
|
| 64 |
+
#include "cutlass/epilogue/warp/tile_iterator_volta_tensor_op.h"
|
| 65 |
+
#include "cutlass/epilogue/threadblock/default_thread_map_volta_tensor_op.h"
|
| 66 |
+
|
| 67 |
+
#include "cutlass/epilogue/threadblock/epilogue.h"
|
| 68 |
+
|
| 69 |
+
#include "cutlass/layout/permute.h"
|
| 70 |
+
|
| 71 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 72 |
+
|
| 73 |
+
namespace cutlass {
|
| 74 |
+
namespace epilogue {
|
| 75 |
+
namespace threadblock {
|
| 76 |
+
|
| 77 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 78 |
+
|
| 79 |
+
/// Defines sensible defaults for epilogues for TensorOps.
|
| 80 |
+
template <
|
| 81 |
+
typename Shape_,
|
| 82 |
+
typename WarpMmaTensorOp_,
|
| 83 |
+
int PartitionsK,
|
| 84 |
+
typename OutputOp_,
|
| 85 |
+
int ElementsPerAccess,
|
| 86 |
+
bool ScatterD = false,
|
| 87 |
+
typename PermuteDLayout = layout::NoPermute
|
| 88 |
+
>
|
| 89 |
+
struct DefaultEpilogueVoltaTensorOp {
|
| 90 |
+
|
| 91 |
+
using Shape = Shape_;
|
| 92 |
+
using WarpMmaTensorOp = WarpMmaTensorOp_;
|
| 93 |
+
static int const kPartitionsK = PartitionsK;
|
| 94 |
+
using OutputOp = OutputOp_;
|
| 95 |
+
static int const kElementsPerAccess = ElementsPerAccess;
|
| 96 |
+
|
| 97 |
+
using ElementOutput = typename OutputOp::ElementOutput;
|
| 98 |
+
using LayoutC = typename WarpMmaTensorOp::LayoutC;
|
| 99 |
+
using ElementAccumulator = typename WarpMmaTensorOp::ElementC;
|
| 100 |
+
|
| 101 |
+
//
|
| 102 |
+
// Thread map
|
| 103 |
+
//
|
| 104 |
+
|
| 105 |
+
using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp<
|
| 106 |
+
Shape,
|
| 107 |
+
typename WarpMmaTensorOp::Shape,
|
| 108 |
+
kPartitionsK,
|
| 109 |
+
ElementOutput,
|
| 110 |
+
kElementsPerAccess,
|
| 111 |
+
ElementAccumulator
|
| 112 |
+
>::Type;
|
| 113 |
+
|
| 114 |
+
using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
|
| 115 |
+
OutputTileThreadMap,
|
| 116 |
+
ElementOutput,
|
| 117 |
+
ScatterD,
|
| 118 |
+
PermuteDLayout
|
| 119 |
+
>;
|
| 120 |
+
|
| 121 |
+
using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorVoltaTensorOp<
|
| 122 |
+
typename WarpMmaTensorOp::Shape,
|
| 123 |
+
gemm::GemmShape<32, 32, 4>,
|
| 124 |
+
ElementAccumulator,
|
| 125 |
+
LayoutC
|
| 126 |
+
>;
|
| 127 |
+
|
| 128 |
+
using WarpTileIterator = cutlass::epilogue::warp::TileIteratorVoltaTensorOp<
|
| 129 |
+
typename WarpMmaTensorOp::Shape,
|
| 130 |
+
gemm::GemmShape<32, 32, 4>,
|
| 131 |
+
ElementAccumulator,
|
| 132 |
+
LayoutC
|
| 133 |
+
>;
|
| 134 |
+
|
| 135 |
+
static int const kSharedMemAlignment = sizeof_bits<ElementAccumulator>::value * WarpTileIterator::kElementsPerAccess / 8;
|
| 136 |
+
|
| 137 |
+
static_assert(kSharedMemAlignment == 8, "Shared memory alignment must be 8B");
|
| 138 |
+
|
| 139 |
+
using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<
|
| 140 |
+
typename OutputTileThreadMap::CompactedThreadMap,
|
| 141 |
+
ElementAccumulator,
|
| 142 |
+
kSharedMemAlignment
|
| 143 |
+
>;
|
| 144 |
+
|
| 145 |
+
/// Hard-coded padding elements added
|
| 146 |
+
using Padding = typename WarpTileIterator::Padding;
|
| 147 |
+
|
| 148 |
+
//
|
| 149 |
+
// Define the epilogue
|
| 150 |
+
//
|
| 151 |
+
using Epilogue = cutlass::epilogue::threadblock::Epilogue<
|
| 152 |
+
Shape,
|
| 153 |
+
WarpMmaTensorOp,
|
| 154 |
+
kPartitionsK,
|
| 155 |
+
OutputTileIterator,
|
| 156 |
+
AccumulatorFragmentIterator,
|
| 157 |
+
WarpTileIterator,
|
| 158 |
+
SharedLoadIterator,
|
| 159 |
+
OutputOp,
|
| 160 |
+
Padding
|
| 161 |
+
>;
|
| 162 |
+
};
|
| 163 |
+
|
| 164 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 165 |
+
|
| 166 |
+
/// Defines sensible defaults for epilogues for TensorOps.
|
| 167 |
+
template <
|
| 168 |
+
typename Shape_,
|
| 169 |
+
typename WarpMmaTensorOp_,
|
| 170 |
+
int PartitionsK,
|
| 171 |
+
typename OutputOp_,
|
| 172 |
+
int ElementsPerAccess
|
| 173 |
+
>
|
| 174 |
+
struct DefaultEpilogueVoltaTensorOpStridedDgrad {
|
| 175 |
+
|
| 176 |
+
using Shape = Shape_;
|
| 177 |
+
using WarpMmaTensorOp = WarpMmaTensorOp_;
|
| 178 |
+
static int const kPartitionsK = PartitionsK;
|
| 179 |
+
using OutputOp = OutputOp_;
|
| 180 |
+
static int const kElementsPerAccess = ElementsPerAccess;
|
| 181 |
+
|
| 182 |
+
using ElementOutput = typename OutputOp::ElementOutput;
|
| 183 |
+
using LayoutC = typename WarpMmaTensorOp::LayoutC;
|
| 184 |
+
using ElementAccumulator = typename WarpMmaTensorOp::ElementC;
|
| 185 |
+
|
| 186 |
+
//
|
| 187 |
+
// Thread map
|
| 188 |
+
//
|
| 189 |
+
|
| 190 |
+
using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp<
|
| 191 |
+
Shape,
|
| 192 |
+
typename WarpMmaTensorOp::Shape,
|
| 193 |
+
kPartitionsK,
|
| 194 |
+
ElementOutput,
|
| 195 |
+
kElementsPerAccess,
|
| 196 |
+
ElementAccumulator
|
| 197 |
+
>::Type;
|
| 198 |
+
|
| 199 |
+
using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorStridedDgrad<
|
| 200 |
+
OutputTileThreadMap,
|
| 201 |
+
ElementOutput
|
| 202 |
+
>;
|
| 203 |
+
|
| 204 |
+
using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorVoltaTensorOp<
|
| 205 |
+
typename WarpMmaTensorOp::Shape,
|
| 206 |
+
gemm::GemmShape<32, 32, 4>,
|
| 207 |
+
ElementAccumulator,
|
| 208 |
+
LayoutC
|
| 209 |
+
>;
|
| 210 |
+
|
| 211 |
+
using WarpTileIterator = cutlass::epilogue::warp::TileIteratorVoltaTensorOp<
|
| 212 |
+
typename WarpMmaTensorOp::Shape,
|
| 213 |
+
gemm::GemmShape<32, 32, 4>,
|
| 214 |
+
ElementAccumulator,
|
| 215 |
+
LayoutC
|
| 216 |
+
>;
|
| 217 |
+
|
| 218 |
+
static int const kSharedMemAlignment = sizeof_bits<ElementAccumulator>::value * WarpTileIterator::kElementsPerAccess / 8;
|
| 219 |
+
|
| 220 |
+
static_assert(kSharedMemAlignment == 8, "Shared memory alignment must be 8B");
|
| 221 |
+
|
| 222 |
+
using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<
|
| 223 |
+
typename OutputTileThreadMap::CompactedThreadMap,
|
| 224 |
+
ElementAccumulator,
|
| 225 |
+
kSharedMemAlignment
|
| 226 |
+
>;
|
| 227 |
+
|
| 228 |
+
/// Hard-coded padding elements added
|
| 229 |
+
using Padding = typename WarpTileIterator::Padding;
|
| 230 |
+
|
| 231 |
+
//
|
| 232 |
+
// Define the epilogue
|
| 233 |
+
//
|
| 234 |
+
using Epilogue = cutlass::epilogue::threadblock::Epilogue<
|
| 235 |
+
Shape,
|
| 236 |
+
WarpMmaTensorOp,
|
| 237 |
+
kPartitionsK,
|
| 238 |
+
OutputTileIterator,
|
| 239 |
+
AccumulatorFragmentIterator,
|
| 240 |
+
WarpTileIterator,
|
| 241 |
+
SharedLoadIterator,
|
| 242 |
+
OutputOp,
|
| 243 |
+
Padding
|
| 244 |
+
>;
|
| 245 |
+
};
|
| 246 |
+
|
| 247 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 248 |
+
|
| 249 |
+
/// Defines sensible defaults for epilogues for TensorOps.
|
| 250 |
+
template <
|
| 251 |
+
int Rank,
|
| 252 |
+
typename Shape_,
|
| 253 |
+
typename WarpMmaTensorOp_,
|
| 254 |
+
int PartitionsK,
|
| 255 |
+
typename OutputOp_,
|
| 256 |
+
int ElementsPerAccess
|
| 257 |
+
>
|
| 258 |
+
struct DefaultEpilogueVoltaTensorOpAffineRankN {
|
| 259 |
+
|
| 260 |
+
using Shape = Shape_;
|
| 261 |
+
using WarpMmaTensorOp = WarpMmaTensorOp_;
|
| 262 |
+
static int const kPartitionsK = PartitionsK;
|
| 263 |
+
using OutputOp = OutputOp_;
|
| 264 |
+
static int const kElementsPerAccess = ElementsPerAccess;
|
| 265 |
+
|
| 266 |
+
using ElementOutput = typename OutputOp::ElementOutput;
|
| 267 |
+
using LayoutC = typename WarpMmaTensorOp::LayoutC;
|
| 268 |
+
using ElementAccumulator = typename WarpMmaTensorOp::ElementC;
|
| 269 |
+
|
| 270 |
+
//
|
| 271 |
+
// Thread map
|
| 272 |
+
//
|
| 273 |
+
|
| 274 |
+
using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp<
|
| 275 |
+
Shape,
|
| 276 |
+
typename WarpMmaTensorOp::Shape,
|
| 277 |
+
kPartitionsK,
|
| 278 |
+
ElementOutput,
|
| 279 |
+
kElementsPerAccess,
|
| 280 |
+
ElementAccumulator
|
| 281 |
+
>::Type;
|
| 282 |
+
|
| 283 |
+
using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorAffineRankN<
|
| 284 |
+
OutputTileThreadMap,
|
| 285 |
+
ElementOutput,
|
| 286 |
+
Rank
|
| 287 |
+
>;
|
| 288 |
+
|
| 289 |
+
using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorVoltaTensorOp<
|
| 290 |
+
typename WarpMmaTensorOp::Shape,
|
| 291 |
+
gemm::GemmShape<32, 32, 4>,
|
| 292 |
+
ElementAccumulator,
|
| 293 |
+
LayoutC
|
| 294 |
+
>;
|
| 295 |
+
|
| 296 |
+
using WarpTileIterator = cutlass::epilogue::warp::TileIteratorVoltaTensorOp<
|
| 297 |
+
typename WarpMmaTensorOp::Shape,
|
| 298 |
+
gemm::GemmShape<32, 32, 4>,
|
| 299 |
+
ElementAccumulator,
|
| 300 |
+
LayoutC
|
| 301 |
+
>;
|
| 302 |
+
|
| 303 |
+
static int const kSharedMemAlignment = sizeof_bits<ElementAccumulator>::value * WarpTileIterator::kElementsPerAccess / 8;
|
| 304 |
+
|
| 305 |
+
static_assert(kSharedMemAlignment == 8, "Shared memory alignment must be 8B");
|
| 306 |
+
|
| 307 |
+
using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<
|
| 308 |
+
typename OutputTileThreadMap::CompactedThreadMap,
|
| 309 |
+
ElementAccumulator,
|
| 310 |
+
kSharedMemAlignment
|
| 311 |
+
>;
|
| 312 |
+
|
| 313 |
+
/// Hard-coded padding elements added
|
| 314 |
+
using Padding = typename WarpTileIterator::Padding;
|
| 315 |
+
|
| 316 |
+
//
|
| 317 |
+
// Define the epilogue
|
| 318 |
+
//
|
| 319 |
+
using Epilogue = cutlass::epilogue::threadblock::Epilogue<
|
| 320 |
+
Shape,
|
| 321 |
+
WarpMmaTensorOp,
|
| 322 |
+
kPartitionsK,
|
| 323 |
+
OutputTileIterator,
|
| 324 |
+
AccumulatorFragmentIterator,
|
| 325 |
+
WarpTileIterator,
|
| 326 |
+
SharedLoadIterator,
|
| 327 |
+
OutputOp,
|
| 328 |
+
Padding
|
| 329 |
+
>;
|
| 330 |
+
};
|
| 331 |
+
|
| 332 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 333 |
+
} // namespace threadblock
|
| 334 |
+
} // namespace epilogue
|
| 335 |
+
} // namespace cutlass
|
| 336 |
+
|
| 337 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_with_absmax.h
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
/*! \file
|
| 33 |
+
\brief Default configuration for epilogue computing absolute maximum of output and auxiliary outputs.
|
| 34 |
+
*/
|
| 35 |
+
|
| 36 |
+
#pragma once
|
| 37 |
+
|
| 38 |
+
#include "cutlass/cutlass.h"
|
| 39 |
+
#include "cutlass/numeric_types.h"
|
| 40 |
+
#include "cutlass/array.h"
|
| 41 |
+
|
| 42 |
+
#include "cutlass/gemm/gemm.h"
|
| 43 |
+
|
| 44 |
+
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
|
| 45 |
+
#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
|
| 46 |
+
#include "cutlass/epilogue/threadblock/epilogue.h"
|
| 47 |
+
#include "cutlass/epilogue/threadblock/epilogue_with_absmax.h"
|
| 48 |
+
|
| 49 |
+
#include "cutlass/layout/permute.h"
|
| 50 |
+
|
| 51 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 52 |
+
|
| 53 |
+
namespace cutlass {
|
| 54 |
+
namespace epilogue {
|
| 55 |
+
namespace threadblock {
|
| 56 |
+
|
| 57 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 58 |
+
|
| 59 |
+
/// Defines sensible defaults for absolute-maximum-computing epilogues with TensorOps
|
| 60 |
+
template <
|
| 61 |
+
typename Shape,
|
| 62 |
+
typename WarpMmaTensorOp,
|
| 63 |
+
int PartitionsK,
|
| 64 |
+
typename ElementOutput,
|
| 65 |
+
typename ElementAuxOutput,
|
| 66 |
+
typename ElementVector,
|
| 67 |
+
typename OutputOp,
|
| 68 |
+
int ElementsPerAccess,
|
| 69 |
+
bool ScatterD = false,
|
| 70 |
+
typename PermuteDLayout = layout::NoPermute
|
| 71 |
+
>
|
| 72 |
+
struct DefaultEpilogueWithAbsMax {
|
| 73 |
+
|
| 74 |
+
/// Use defaults related to the existing epilogue
|
| 75 |
+
using Base = DefaultEpilogueTensorOp<
|
| 76 |
+
Shape,
|
| 77 |
+
WarpMmaTensorOp,
|
| 78 |
+
PartitionsK,
|
| 79 |
+
OutputOp,
|
| 80 |
+
ElementsPerAccess
|
| 81 |
+
>;
|
| 82 |
+
|
| 83 |
+
//
|
| 84 |
+
// Stores the output
|
| 85 |
+
//
|
| 86 |
+
using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
|
| 87 |
+
typename Base::OutputTileThreadMap,
|
| 88 |
+
ElementOutput,
|
| 89 |
+
ScatterD,
|
| 90 |
+
PermuteDLayout
|
| 91 |
+
>;
|
| 92 |
+
|
| 93 |
+
//
|
| 94 |
+
// Stores the auxiliary output
|
| 95 |
+
//
|
| 96 |
+
using AuxOutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
|
| 97 |
+
typename Base::OutputTileThreadMap,
|
| 98 |
+
ElementAuxOutput,
|
| 99 |
+
ScatterD,
|
| 100 |
+
PermuteDLayout
|
| 101 |
+
>;
|
| 102 |
+
|
| 103 |
+
/// Define the epilogue
|
| 104 |
+
using Epilogue = EpilogueWithAbsMax<
|
| 105 |
+
Shape,
|
| 106 |
+
WarpMmaTensorOp,
|
| 107 |
+
PartitionsK,
|
| 108 |
+
OutputTileIterator,
|
| 109 |
+
AuxOutputTileIterator,
|
| 110 |
+
ElementVector,
|
| 111 |
+
typename Base::AccumulatorFragmentIterator,
|
| 112 |
+
typename Base::WarpTileIterator,
|
| 113 |
+
typename Base::SharedLoadIterator,
|
| 114 |
+
OutputOp,
|
| 115 |
+
typename Base::Padding,
|
| 116 |
+
Base::kFragmentsPerIteration
|
| 117 |
+
>;
|
| 118 |
+
};
|
| 119 |
+
|
| 120 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 121 |
+
|
| 122 |
+
} // namespace threadblock
|
| 123 |
+
} // namespace epilogue
|
| 124 |
+
} // namespace cutlass
|
| 125 |
+
|
| 126 |
+
////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h
ADDED
|
@@ -0,0 +1,376 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
|
| 33 |
+
|
| 34 |
+
The epilogue rearranges the result of a matrix product through shared memory to match canonical
|
| 35 |
+
tensor layouts in global memory. Epilogues support conversion and reduction operations.
|
| 36 |
+
|
| 37 |
+
*/
|
| 38 |
+
|
| 39 |
+
#pragma once
|
| 40 |
+
|
| 41 |
+
#include "cutlass/cutlass.h"
|
| 42 |
+
#include "cutlass/numeric_types.h"
|
| 43 |
+
#include "cutlass/array.h"
|
| 44 |
+
|
| 45 |
+
#include "cutlass/gemm/gemm.h"
|
| 46 |
+
|
| 47 |
+
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
|
| 48 |
+
#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
|
| 49 |
+
#include "cutlass/epilogue/threadblock/epilogue.h"
|
| 50 |
+
#include "cutlass/epilogue/threadblock/epilogue_with_broadcast.h"
|
| 51 |
+
#include "cutlass/epilogue/threadblock/epilogue_streamk_with_broadcast.h"
|
| 52 |
+
|
| 53 |
+
#include "cutlass/layout/permute.h"
|
| 54 |
+
|
| 55 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 56 |
+
|
| 57 |
+
namespace cutlass {
|
| 58 |
+
namespace epilogue {
|
| 59 |
+
namespace threadblock {
|
| 60 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 61 |
+
|
| 62 |
+
/// Defines sensible defaults for epilogues for SimtOps.
|
| 63 |
+
template <
|
| 64 |
+
typename Shape,
|
| 65 |
+
typename WarpMmaSimt,
|
| 66 |
+
typename ElementOutput,
|
| 67 |
+
typename ElementTensor,
|
| 68 |
+
typename ElementVector,
|
| 69 |
+
typename OutputOp,
|
| 70 |
+
int ElementsPerAccess,
|
| 71 |
+
bool ScatterD = false,
|
| 72 |
+
typename PermuteDLayout = layout::NoPermute,
|
| 73 |
+
conv::StrideSupport StrideSupport = conv::StrideSupport::kUnity,
|
| 74 |
+
int Rank = 4
|
| 75 |
+
>
|
| 76 |
+
struct DefaultEpilogueWithBroadcastSimt {
|
| 77 |
+
|
| 78 |
+
static conv::StrideSupport const kStrideSupport = StrideSupport;
|
| 79 |
+
static int const kRank = Rank;
|
| 80 |
+
|
| 81 |
+
static bool const UseCUDAStore = platform::is_same<ElementOutput, double>::value;
|
| 82 |
+
|
| 83 |
+
/// Use defaults related to the existing epilogue
|
| 84 |
+
using Base = DefaultEpilogueSimt<
|
| 85 |
+
Shape,
|
| 86 |
+
WarpMmaSimt,
|
| 87 |
+
OutputOp,
|
| 88 |
+
ElementsPerAccess
|
| 89 |
+
>;
|
| 90 |
+
|
| 91 |
+
using PackedOutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
|
| 92 |
+
typename Base::OutputTileThreadMap,
|
| 93 |
+
ElementOutput,
|
| 94 |
+
ScatterD,
|
| 95 |
+
PermuteDLayout,
|
| 96 |
+
UseCUDAStore
|
| 97 |
+
>;
|
| 98 |
+
|
| 99 |
+
using StridedOutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorConv<
|
| 100 |
+
typename Base::OutputTileThreadMap,
|
| 101 |
+
ElementOutput,
|
| 102 |
+
ScatterD,
|
| 103 |
+
PermuteDLayout,
|
| 104 |
+
UseCUDAStore,
|
| 105 |
+
kRank
|
| 106 |
+
>;
|
| 107 |
+
|
| 108 |
+
//
|
| 109 |
+
// Stores the result z = (y = GEMM(A, B, C), broadcast)
|
| 110 |
+
//
|
| 111 |
+
using OutputTileIterator = typename platform::conditional<StrideSupport == cutlass::conv::StrideSupport::kUnity,
|
| 112 |
+
PackedOutputTileIterator,
|
| 113 |
+
StridedOutputTileIterator>::type;
|
| 114 |
+
|
| 115 |
+
//
|
| 116 |
+
// Additional tensor tile iterator - stores t = Elementwise(z)
|
| 117 |
+
//
|
| 118 |
+
using TensorTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
|
| 119 |
+
typename Base::OutputTileThreadMap,
|
| 120 |
+
ElementTensor
|
| 121 |
+
>;
|
| 122 |
+
/// Define the epilogue
|
| 123 |
+
using Epilogue = EpilogueWithBroadcast<
|
| 124 |
+
Shape,
|
| 125 |
+
WarpMmaSimt,
|
| 126 |
+
Base::kPartitionsK,
|
| 127 |
+
OutputTileIterator,
|
| 128 |
+
TensorTileIterator,
|
| 129 |
+
ElementVector,
|
| 130 |
+
typename Base::AccumulatorFragmentIterator,
|
| 131 |
+
typename Base::WarpTileIterator,
|
| 132 |
+
typename Base::SharedLoadIterator,
|
| 133 |
+
OutputOp,
|
| 134 |
+
typename Base::Padding
|
| 135 |
+
>;
|
| 136 |
+
};
|
| 137 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 138 |
+
|
| 139 |
+
/// Defines sensible defaults for strided dgrad epilogues for SimtOps.
|
| 140 |
+
template <
|
| 141 |
+
typename Shape,
|
| 142 |
+
typename WarpMmaSimt,
|
| 143 |
+
typename ElementOutput,
|
| 144 |
+
typename ElementTensor,
|
| 145 |
+
typename ElementVector,
|
| 146 |
+
typename OutputOp,
|
| 147 |
+
int ElementsPerAccess,
|
| 148 |
+
bool ScatterD = false,
|
| 149 |
+
typename PermuteDLayout = layout::NoPermute
|
| 150 |
+
>
|
| 151 |
+
struct DefaultEpilogueWithBroadcastSimtStridedDgrad {
|
| 152 |
+
|
| 153 |
+
/// Use defaults related to the existing epilogue
|
| 154 |
+
using Base = DefaultEpilogueSimtStridedDgrad<
|
| 155 |
+
Shape,
|
| 156 |
+
WarpMmaSimt,
|
| 157 |
+
OutputOp,
|
| 158 |
+
ElementsPerAccess
|
| 159 |
+
>;
|
| 160 |
+
|
| 161 |
+
//
|
| 162 |
+
// Stores the result z = (y = GEMM(A, B, C), broadcast)
|
| 163 |
+
//
|
| 164 |
+
using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorStridedDgrad<
|
| 165 |
+
typename Base::OutputTileThreadMap,
|
| 166 |
+
ElementOutput
|
| 167 |
+
>;
|
| 168 |
+
|
| 169 |
+
//
|
| 170 |
+
// Additional tensor tile iterator - stores t = Elementwise(z)
|
| 171 |
+
//
|
| 172 |
+
using TensorTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorStridedDgrad<
|
| 173 |
+
typename Base::OutputTileThreadMap,
|
| 174 |
+
ElementTensor
|
| 175 |
+
>;
|
| 176 |
+
|
| 177 |
+
/// Define the epilogue
|
| 178 |
+
using Epilogue = EpilogueWithBroadcast<
|
| 179 |
+
Shape,
|
| 180 |
+
WarpMmaSimt,
|
| 181 |
+
Base::kPartitionsK,
|
| 182 |
+
OutputTileIterator,
|
| 183 |
+
TensorTileIterator,
|
| 184 |
+
ElementVector,
|
| 185 |
+
typename Base::AccumulatorFragmentIterator,
|
| 186 |
+
typename Base::WarpTileIterator,
|
| 187 |
+
typename Base::SharedLoadIterator,
|
| 188 |
+
OutputOp,
|
| 189 |
+
typename Base::Padding
|
| 190 |
+
>;
|
| 191 |
+
};
|
| 192 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 193 |
+
|
| 194 |
+
/// Defines sensible defaults for epilogues for TensorOps.
|
| 195 |
+
template <
|
| 196 |
+
typename Shape,
|
| 197 |
+
typename WarpMmaTensorOp,
|
| 198 |
+
int PartitionsK,
|
| 199 |
+
typename ElementOutput,
|
| 200 |
+
typename ElementTensor,
|
| 201 |
+
typename ElementVector,
|
| 202 |
+
typename OutputOp,
|
| 203 |
+
int ElementsPerAccess,
|
| 204 |
+
bool ScatterD = false,
|
| 205 |
+
typename PermuteDLayout = layout::NoPermute
|
| 206 |
+
>
|
| 207 |
+
struct DefaultEpilogueWithBroadcastTensorOp {
|
| 208 |
+
|
| 209 |
+
/// Use defaults related to the existing epilogue
|
| 210 |
+
using Base = DefaultEpilogueTensorOp<
|
| 211 |
+
Shape,
|
| 212 |
+
WarpMmaTensorOp,
|
| 213 |
+
PartitionsK,
|
| 214 |
+
OutputOp,
|
| 215 |
+
ElementsPerAccess
|
| 216 |
+
>;
|
| 217 |
+
|
| 218 |
+
//
|
| 219 |
+
// Stores the result z = (y = GEMM(A, B, C), broadcast)
|
| 220 |
+
//
|
| 221 |
+
using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
|
| 222 |
+
typename Base::OutputTileThreadMap,
|
| 223 |
+
ElementOutput,
|
| 224 |
+
ScatterD,
|
| 225 |
+
PermuteDLayout
|
| 226 |
+
>;
|
| 227 |
+
|
| 228 |
+
//
|
| 229 |
+
// Additional tensor tile iterator - stores t = Elementwise(z)
|
| 230 |
+
//
|
| 231 |
+
using TensorTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
|
| 232 |
+
typename Base::OutputTileThreadMap,
|
| 233 |
+
ElementTensor
|
| 234 |
+
>;
|
| 235 |
+
|
| 236 |
+
/// Define the epilogue
|
| 237 |
+
using Epilogue = EpilogueWithBroadcast<
|
| 238 |
+
Shape,
|
| 239 |
+
WarpMmaTensorOp,
|
| 240 |
+
PartitionsK,
|
| 241 |
+
OutputTileIterator,
|
| 242 |
+
TensorTileIterator,
|
| 243 |
+
ElementVector,
|
| 244 |
+
typename Base::AccumulatorFragmentIterator,
|
| 245 |
+
typename Base::WarpTileIterator,
|
| 246 |
+
typename Base::SharedLoadIterator,
|
| 247 |
+
OutputOp,
|
| 248 |
+
typename Base::Padding,
|
| 249 |
+
Base::kFragmentsPerIteration
|
| 250 |
+
>;
|
| 251 |
+
};
|
| 252 |
+
|
| 253 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 254 |
+
|
| 255 |
+
/// Defines sensible defaults for streamk epilogues for TensorOps.
|
| 256 |
+
template <
|
| 257 |
+
typename Shape,
|
| 258 |
+
typename WarpMmaTensorOp,
|
| 259 |
+
int PartitionsK,
|
| 260 |
+
typename ElementOutput,
|
| 261 |
+
typename ElementTensor,
|
| 262 |
+
typename ElementVector,
|
| 263 |
+
typename OutputOp,
|
| 264 |
+
int ElementsPerAccess,
|
| 265 |
+
bool ScatterD = false,
|
| 266 |
+
typename PermuteDLayout = layout::NoPermute
|
| 267 |
+
>
|
| 268 |
+
struct DefaultStreamkEpilogueWithBroadcastTensorOp {
|
| 269 |
+
|
| 270 |
+
/// Use defaults related to the existing epilogue
|
| 271 |
+
using Base = DefaultEpilogueTensorOp<
|
| 272 |
+
Shape,
|
| 273 |
+
WarpMmaTensorOp,
|
| 274 |
+
PartitionsK,
|
| 275 |
+
OutputOp,
|
| 276 |
+
ElementsPerAccess
|
| 277 |
+
>;
|
| 278 |
+
|
| 279 |
+
//
|
| 280 |
+
// Stores the result z = (y = GEMM(A, B, C), broadcast)
|
| 281 |
+
//
|
| 282 |
+
using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
|
| 283 |
+
typename Base::OutputTileThreadMap,
|
| 284 |
+
ElementOutput,
|
| 285 |
+
ScatterD,
|
| 286 |
+
PermuteDLayout
|
| 287 |
+
>;
|
| 288 |
+
|
| 289 |
+
//
|
| 290 |
+
// Additional tensor tile iterator - stores t = Elementwise(z)
|
| 291 |
+
//
|
| 292 |
+
using TensorTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
|
| 293 |
+
typename Base::OutputTileThreadMap,
|
| 294 |
+
ElementTensor
|
| 295 |
+
>;
|
| 296 |
+
|
| 297 |
+
/// Define the epilogue
|
| 298 |
+
using Epilogue = EpilogueStreamkWithBroadcast<
|
| 299 |
+
Shape,
|
| 300 |
+
WarpMmaTensorOp,
|
| 301 |
+
PartitionsK,
|
| 302 |
+
OutputTileIterator,
|
| 303 |
+
TensorTileIterator,
|
| 304 |
+
ElementVector,
|
| 305 |
+
typename Base::AccumulatorFragmentIterator,
|
| 306 |
+
typename Base::WarpTileIterator,
|
| 307 |
+
typename Base::SharedLoadIterator,
|
| 308 |
+
OutputOp,
|
| 309 |
+
typename Base::Padding,
|
| 310 |
+
Base::kFragmentsPerIteration
|
| 311 |
+
>;
|
| 312 |
+
};
|
| 313 |
+
|
| 314 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 315 |
+
|
| 316 |
+
/// Defines sensible defaults for epilogues for VoltaTensorOps.
|
| 317 |
+
template <
|
| 318 |
+
typename Shape,
|
| 319 |
+
typename WarpMmaTensorOp,
|
| 320 |
+
int PartitionsK,
|
| 321 |
+
typename ElementOutput,
|
| 322 |
+
typename ElementTensor,
|
| 323 |
+
typename ElementVector,
|
| 324 |
+
typename OutputOp,
|
| 325 |
+
int ElementsPerAccess
|
| 326 |
+
>
|
| 327 |
+
struct DefaultEpilogueWithBroadcastVoltaTensorOp {
|
| 328 |
+
|
| 329 |
+
/// Use defaults related to the existing epilogue
|
| 330 |
+
using Base = DefaultEpilogueVoltaTensorOp<
|
| 331 |
+
Shape,
|
| 332 |
+
WarpMmaTensorOp,
|
| 333 |
+
PartitionsK,
|
| 334 |
+
OutputOp,
|
| 335 |
+
ElementsPerAccess
|
| 336 |
+
>;
|
| 337 |
+
|
| 338 |
+
//
|
| 339 |
+
// Stores the result z = (y = GEMM(A, B, C), broadcast)
|
| 340 |
+
//
|
| 341 |
+
using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
|
| 342 |
+
typename Base::OutputTileThreadMap,
|
| 343 |
+
ElementOutput
|
| 344 |
+
>;
|
| 345 |
+
|
| 346 |
+
//
|
| 347 |
+
// Additional tensor tile iterator - stores t = Elementwise(z)
|
| 348 |
+
//
|
| 349 |
+
using TensorTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
|
| 350 |
+
typename Base::OutputTileThreadMap,
|
| 351 |
+
ElementTensor
|
| 352 |
+
>;
|
| 353 |
+
|
| 354 |
+
/// Define the epilogue
|
| 355 |
+
using Epilogue = EpilogueWithBroadcast<
|
| 356 |
+
Shape,
|
| 357 |
+
WarpMmaTensorOp,
|
| 358 |
+
PartitionsK,
|
| 359 |
+
OutputTileIterator,
|
| 360 |
+
TensorTileIterator,
|
| 361 |
+
ElementVector,
|
| 362 |
+
typename Base::AccumulatorFragmentIterator,
|
| 363 |
+
typename Base::WarpTileIterator,
|
| 364 |
+
typename Base::SharedLoadIterator,
|
| 365 |
+
OutputOp,
|
| 366 |
+
typename Base::Padding
|
| 367 |
+
>;
|
| 368 |
+
};
|
| 369 |
+
|
| 370 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 371 |
+
|
| 372 |
+
} // namespace threadblock
|
| 373 |
+
} // namespace epilogue
|
| 374 |
+
} // namespace cutlass
|
| 375 |
+
|
| 376 |
+
////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_with_reduction.h
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
|
| 33 |
+
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
|
| 34 |
+
|
| 35 |
+
The epilogue rearranges the result of a matrix product through shared memory to match canonical
|
| 36 |
+
tensor layouts in global memory. Epilogues support conversion and reduction operations.
|
| 37 |
+
|
| 38 |
+
*/
|
| 39 |
+
|
| 40 |
+
#pragma once
|
| 41 |
+
|
| 42 |
+
#include "cutlass/cutlass.h"
|
| 43 |
+
#include "cutlass/numeric_types.h"
|
| 44 |
+
#include "cutlass/array.h"
|
| 45 |
+
|
| 46 |
+
#include "cutlass/gemm/gemm.h"
|
| 47 |
+
|
| 48 |
+
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
|
| 49 |
+
#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
|
| 50 |
+
#include "cutlass/epilogue/threadblock/epilogue.h"
|
| 51 |
+
#include "cutlass/epilogue/threadblock/epilogue_with_reduction.h"
|
| 52 |
+
|
| 53 |
+
#include "cutlass/layout/permute.h"
|
| 54 |
+
|
| 55 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 56 |
+
|
| 57 |
+
namespace cutlass {
|
| 58 |
+
namespace epilogue {
|
| 59 |
+
namespace threadblock {
|
| 60 |
+
|
| 61 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 62 |
+
|
| 63 |
+
/// Defines sensible defaults for epilogues for TensorOps.
|
| 64 |
+
template <
|
| 65 |
+
typename Shape,
|
| 66 |
+
typename WarpMmaTensorOp,
|
| 67 |
+
int PartitionsK,
|
| 68 |
+
typename ElementOutput,
|
| 69 |
+
typename OutputOp,
|
| 70 |
+
typename ReductionOp,
|
| 71 |
+
int ElementsPerAccess,
|
| 72 |
+
bool ScatterD = false,
|
| 73 |
+
typename PermuteDLayout = layout::NoPermute
|
| 74 |
+
>
|
| 75 |
+
struct DefaultEpilogueWithReductionTensorOp {
|
| 76 |
+
|
| 77 |
+
/// Use defaults related to the existing epilogue
|
| 78 |
+
using Base = DefaultEpilogueTensorOp<
|
| 79 |
+
Shape,
|
| 80 |
+
WarpMmaTensorOp,
|
| 81 |
+
PartitionsK,
|
| 82 |
+
OutputOp,
|
| 83 |
+
ElementsPerAccess
|
| 84 |
+
>;
|
| 85 |
+
|
| 86 |
+
/// Additional tensor tile iterator
|
| 87 |
+
using TensorTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
|
| 88 |
+
typename Base::OutputTileThreadMap,
|
| 89 |
+
typename OutputOp::ElementTensor
|
| 90 |
+
>;
|
| 91 |
+
|
| 92 |
+
using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
|
| 93 |
+
typename Base::OutputTileThreadMap,
|
| 94 |
+
ElementOutput,
|
| 95 |
+
ScatterD,
|
| 96 |
+
PermuteDLayout
|
| 97 |
+
>;
|
| 98 |
+
|
| 99 |
+
/// Define the epilogue
|
| 100 |
+
using Epilogue = EpilogueWithReduction<
|
| 101 |
+
Shape,
|
| 102 |
+
WarpMmaTensorOp,
|
| 103 |
+
PartitionsK,
|
| 104 |
+
OutputTileIterator,
|
| 105 |
+
TensorTileIterator,
|
| 106 |
+
typename WarpMmaTensorOp::ElementC,
|
| 107 |
+
typename Base::AccumulatorFragmentIterator,
|
| 108 |
+
typename Base::WarpTileIterator,
|
| 109 |
+
typename Base::SharedLoadIterator,
|
| 110 |
+
typename Base::OutputOp,
|
| 111 |
+
ReductionOp,
|
| 112 |
+
typename Base::Padding
|
| 113 |
+
>;
|
| 114 |
+
};
|
| 115 |
+
|
| 116 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 117 |
+
|
| 118 |
+
/// Defines sensible defaults for epilogues for TensorOps.
|
| 119 |
+
template <
|
| 120 |
+
typename Shape,
|
| 121 |
+
typename WarpMmaTensorOp,
|
| 122 |
+
int PartitionsK,
|
| 123 |
+
typename ElementOutput,
|
| 124 |
+
typename OutputOp,
|
| 125 |
+
typename ReductionOp,
|
| 126 |
+
int ElementsPerAccess,
|
| 127 |
+
bool ScatterD = false,
|
| 128 |
+
typename PermuteDLayout = layout::NoPermute
|
| 129 |
+
>
|
| 130 |
+
struct DefaultEpilogueWithReductionVoltaTensorOp {
|
| 131 |
+
|
| 132 |
+
/// Use defaults related to the existing epilogue
|
| 133 |
+
using Base = DefaultEpilogueVoltaTensorOp<
|
| 134 |
+
Shape,
|
| 135 |
+
WarpMmaTensorOp,
|
| 136 |
+
PartitionsK,
|
| 137 |
+
OutputOp,
|
| 138 |
+
ElementsPerAccess
|
| 139 |
+
>;
|
| 140 |
+
|
| 141 |
+
/// Additional tensor tile iterator
|
| 142 |
+
using TensorTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
|
| 143 |
+
typename Base::OutputTileThreadMap,
|
| 144 |
+
typename OutputOp::ElementTensor
|
| 145 |
+
>;
|
| 146 |
+
|
| 147 |
+
using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
|
| 148 |
+
typename Base::OutputTileThreadMap,
|
| 149 |
+
ElementOutput,
|
| 150 |
+
ScatterD,
|
| 151 |
+
PermuteDLayout
|
| 152 |
+
>;
|
| 153 |
+
|
| 154 |
+
/// Define the epilogue
|
| 155 |
+
using Epilogue = EpilogueWithReduction<
|
| 156 |
+
Shape,
|
| 157 |
+
WarpMmaTensorOp,
|
| 158 |
+
PartitionsK,
|
| 159 |
+
OutputTileIterator,
|
| 160 |
+
TensorTileIterator,
|
| 161 |
+
typename WarpMmaTensorOp::ElementC,
|
| 162 |
+
typename Base::AccumulatorFragmentIterator,
|
| 163 |
+
typename Base::WarpTileIterator,
|
| 164 |
+
typename Base::SharedLoadIterator,
|
| 165 |
+
typename Base::OutputOp,
|
| 166 |
+
ReductionOp,
|
| 167 |
+
typename Base::Padding
|
| 168 |
+
>;
|
| 169 |
+
};
|
| 170 |
+
|
| 171 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 172 |
+
|
| 173 |
+
} // namespace threadblock
|
| 174 |
+
} // namespace epilogue
|
| 175 |
+
} // namespace cutlass
|
| 176 |
+
|
| 177 |
+
////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Epilogue for threadblock scoped GEMMs using WMMA.
|
| 33 |
+
|
| 34 |
+
The epilogue rearranges the result of a matrix product through shared memory to match canonical
|
| 35 |
+
tensor layouts in global memory. Epilogues support conversion and reduction operations.
|
| 36 |
+
|
| 37 |
+
*/
|
| 38 |
+
|
| 39 |
+
#pragma once
|
| 40 |
+
|
| 41 |
+
#include "cutlass/cutlass.h"
|
| 42 |
+
#include "cutlass/numeric_types.h"
|
| 43 |
+
#include "cutlass/array.h"
|
| 44 |
+
|
| 45 |
+
#include "cutlass/gemm/gemm.h"
|
| 46 |
+
|
| 47 |
+
#include "cutlass/epilogue/thread/linear_combination.h"
|
| 48 |
+
#include "cutlass/epilogue/thread/linear_combination_clamp.h"
|
| 49 |
+
#include "cutlass/epilogue/thread/linear_combination_relu.h"
|
| 50 |
+
#include "cutlass/epilogue/thread/linear_combination_gelu.h"
|
| 51 |
+
#include "cutlass/epilogue/thread/linear_combination_sigmoid.h"
|
| 52 |
+
#include "cutlass/epilogue/thread/linear_combination_planar_complex.h"
|
| 53 |
+
|
| 54 |
+
#include "cutlass/epilogue/thread/conversion_op.h"
|
| 55 |
+
#include "cutlass/epilogue/thread/reduction_op.h"
|
| 56 |
+
|
| 57 |
+
#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h"
|
| 58 |
+
|
| 59 |
+
#include "cutlass/epilogue/warp/fragment_iterator_wmma_tensor_op.h"
|
| 60 |
+
#include "cutlass/epilogue/warp/tile_iterator_wmma_tensor_op.h"
|
| 61 |
+
#include "cutlass/epilogue/threadblock/default_thread_map_wmma_tensor_op.h"
|
| 62 |
+
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
|
| 63 |
+
#include "cutlass/epilogue/threadblock/shared_load_iterator.h"
|
| 64 |
+
|
| 65 |
+
#include "cutlass/epilogue/threadblock/epilogue.h"
|
| 66 |
+
|
| 67 |
+
#include "cutlass/layout/permute.h"
|
| 68 |
+
|
| 69 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 70 |
+
|
| 71 |
+
namespace cutlass {
|
| 72 |
+
namespace epilogue {
|
| 73 |
+
namespace threadblock {
|
| 74 |
+
|
| 75 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 76 |
+
|
| 77 |
+
/// Defines sensible defaults for epilogues for WMMA TensorOps.
|
| 78 |
+
template <
|
| 79 |
+
typename Shape_,
|
| 80 |
+
typename WarpMmaTensorOp_,
|
| 81 |
+
int PartitionsK,
|
| 82 |
+
typename OutputOp_,
|
| 83 |
+
int ElementsPerAccess,
|
| 84 |
+
bool ScatterD = false,
|
| 85 |
+
typename PermuteDLayout = layout::NoPermute
|
| 86 |
+
>
|
| 87 |
+
struct DefaultEpilogueWmmaTensorOp {
|
| 88 |
+
|
| 89 |
+
using Shape = Shape_;
|
| 90 |
+
using WarpMmaTensorOp = WarpMmaTensorOp_;
|
| 91 |
+
static int const kPartitionsK = PartitionsK;
|
| 92 |
+
using OutputOp = OutputOp_;
|
| 93 |
+
static int const kElementsPerAccess = ElementsPerAccess;
|
| 94 |
+
|
| 95 |
+
using ElementOutput = typename OutputOp::ElementOutput;
|
| 96 |
+
using LayoutC = typename WarpMmaTensorOp::LayoutC;
|
| 97 |
+
using ElementAccumulator = typename WarpMmaTensorOp::ElementC;
|
| 98 |
+
|
| 99 |
+
//
|
| 100 |
+
// Thread map
|
| 101 |
+
//
|
| 102 |
+
|
| 103 |
+
using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapWmmaTensorOp<
|
| 104 |
+
Shape,
|
| 105 |
+
typename WarpMmaTensorOp::Shape,
|
| 106 |
+
typename WarpMmaTensorOp::Policy::Operator::Shape,
|
| 107 |
+
kPartitionsK,
|
| 108 |
+
ElementOutput,
|
| 109 |
+
kElementsPerAccess
|
| 110 |
+
>::Type;
|
| 111 |
+
|
| 112 |
+
using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
|
| 113 |
+
OutputTileThreadMap,
|
| 114 |
+
ElementOutput,
|
| 115 |
+
ScatterD,
|
| 116 |
+
PermuteDLayout
|
| 117 |
+
>;
|
| 118 |
+
|
| 119 |
+
using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorWmmaTensorOp<
|
| 120 |
+
typename WarpMmaTensorOp::Shape,
|
| 121 |
+
typename WarpMmaTensorOp::Policy::Operator::Shape,
|
| 122 |
+
typename WarpMmaTensorOp::Policy::Operator::ElementC,
|
| 123 |
+
typename WarpMmaTensorOp::Policy::Operator::FragmentC,
|
| 124 |
+
LayoutC
|
| 125 |
+
>;
|
| 126 |
+
|
| 127 |
+
using WarpTileIterator = cutlass::epilogue::warp::TileIteratorWmmaTensorOp<
|
| 128 |
+
typename WarpMmaTensorOp::Shape,
|
| 129 |
+
typename WarpMmaTensorOp::Policy::Operator::Shape,
|
| 130 |
+
typename WarpMmaTensorOp::Policy::Operator::FragmentC,
|
| 131 |
+
LayoutC
|
| 132 |
+
>;
|
| 133 |
+
|
| 134 |
+
using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<
|
| 135 |
+
typename OutputTileThreadMap::CompactedThreadMap,
|
| 136 |
+
ElementAccumulator
|
| 137 |
+
>;
|
| 138 |
+
|
| 139 |
+
/// Hard-coded padding elements added
|
| 140 |
+
using Padding = typename WarpTileIterator::Padding;
|
| 141 |
+
|
| 142 |
+
//
|
| 143 |
+
// Define the epilogue
|
| 144 |
+
//
|
| 145 |
+
using Epilogue = cutlass::epilogue::threadblock::Epilogue<
|
| 146 |
+
Shape,
|
| 147 |
+
WarpMmaTensorOp,
|
| 148 |
+
kPartitionsK,
|
| 149 |
+
OutputTileIterator,
|
| 150 |
+
AccumulatorFragmentIterator,
|
| 151 |
+
WarpTileIterator,
|
| 152 |
+
SharedLoadIterator,
|
| 153 |
+
OutputOp,
|
| 154 |
+
Padding
|
| 155 |
+
>;
|
| 156 |
+
};
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 160 |
+
|
| 161 |
+
} // namespace threadblock
|
| 162 |
+
} // namespace epilogue
|
| 163 |
+
} // namespace cutlass
|
| 164 |
+
|
| 165 |
+
////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_simt.h
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief
|
| 33 |
+
|
| 34 |
+
*/
|
| 35 |
+
|
| 36 |
+
#pragma once
|
| 37 |
+
|
| 38 |
+
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
|
| 39 |
+
#include "cutlass/gemm/gemm.h"
|
| 40 |
+
|
| 41 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 42 |
+
|
| 43 |
+
namespace cutlass {
|
| 44 |
+
namespace epilogue {
|
| 45 |
+
namespace threadblock {
|
| 46 |
+
|
| 47 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 48 |
+
|
| 49 |
+
/// Defines the optimal thread map for SIMT accumulator layouts
|
| 50 |
+
template <
|
| 51 |
+
typename ThreadblockShape_,
|
| 52 |
+
typename WarpShape_,
|
| 53 |
+
typename MmaSimtPolicy_,
|
| 54 |
+
int PartitionsK,
|
| 55 |
+
typename Element_,
|
| 56 |
+
int ElementsPerAccess
|
| 57 |
+
>
|
| 58 |
+
struct DefaultThreadMapSimt {
|
| 59 |
+
|
| 60 |
+
using ThreadblockShape = ThreadblockShape_;
|
| 61 |
+
using WarpShape = WarpShape_;
|
| 62 |
+
using MmaSimtPolicy = MmaSimtPolicy_;
|
| 63 |
+
static int const kPartitionsK = PartitionsK;
|
| 64 |
+
using Element = Element_;
|
| 65 |
+
static int const kElementsPerAccess = ElementsPerAccess;
|
| 66 |
+
|
| 67 |
+
//
|
| 68 |
+
// Definitions
|
| 69 |
+
//
|
| 70 |
+
|
| 71 |
+
struct Detail {
|
| 72 |
+
|
| 73 |
+
static int const kWarpSize = 32;
|
| 74 |
+
|
| 75 |
+
static_assert(
|
| 76 |
+
!(ThreadblockShape::kM % WarpShape::kM) &&
|
| 77 |
+
!(ThreadblockShape::kN % WarpShape::kN), "Divisibility");
|
| 78 |
+
|
| 79 |
+
/// Number of warps
|
| 80 |
+
using WarpCount = gemm::GemmShape<
|
| 81 |
+
ThreadblockShape::kM / WarpShape::kM,
|
| 82 |
+
ThreadblockShape::kN / WarpShape::kN,
|
| 83 |
+
kPartitionsK
|
| 84 |
+
>;
|
| 85 |
+
|
| 86 |
+
/// Computes number of thread-level matrix multiplies are needed to span a warp
|
| 87 |
+
static int const kGroupCount =
|
| 88 |
+
WarpShape::kM / (MmaSimtPolicy::WarpShape::kRow * MmaSimtPolicy::LaneMmaShape::kM);
|
| 89 |
+
|
| 90 |
+
/// Number of participating threads
|
| 91 |
+
static int const kThreads = WarpCount::kCount * kWarpSize;
|
| 92 |
+
|
| 93 |
+
/// Number of iterations
|
| 94 |
+
static int const kIterations = MmaSimtPolicy::LaneMmaShape::kM * kGroupCount;
|
| 95 |
+
};
|
| 96 |
+
|
| 97 |
+
//
|
| 98 |
+
// ThreadMap
|
| 99 |
+
//
|
| 100 |
+
|
| 101 |
+
/// ThreadMap to be used by epilogue::PredicatedTileIterator satisfying concept OutputTileThreadMap
|
| 102 |
+
using Type = OutputTileOptimalThreadMap<
|
| 103 |
+
OutputTileShape< // Shape
|
| 104 |
+
ThreadblockShape::kN,
|
| 105 |
+
1,
|
| 106 |
+
MmaSimtPolicy::WarpShape::kRow,
|
| 107 |
+
Detail::WarpCount::kM,
|
| 108 |
+
1>,
|
| 109 |
+
OutputTileShape< // Count
|
| 110 |
+
1,
|
| 111 |
+
MmaSimtPolicy::LaneMmaShape::kM,
|
| 112 |
+
Detail::kGroupCount,
|
| 113 |
+
1,
|
| 114 |
+
Detail::kIterations>,
|
| 115 |
+
Detail::kThreads,
|
| 116 |
+
kElementsPerAccess,
|
| 117 |
+
sizeof_bits<Element>::value
|
| 118 |
+
>;
|
| 119 |
+
};
|
| 120 |
+
|
| 121 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 122 |
+
|
| 123 |
+
} // namespace threadblock
|
| 124 |
+
} // namespace epilogue
|
| 125 |
+
} // namespace cutlass
|
| 126 |
+
|
| 127 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_tensor_op.h
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief
|
| 33 |
+
|
| 34 |
+
*/
|
| 35 |
+
|
| 36 |
+
#pragma once
|
| 37 |
+
|
| 38 |
+
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
|
| 39 |
+
#include "cutlass/gemm/gemm.h"
|
| 40 |
+
#include "cutlass/layout/pitch_linear.h"
|
| 41 |
+
|
| 42 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 43 |
+
|
| 44 |
+
namespace cutlass {
|
| 45 |
+
namespace epilogue {
|
| 46 |
+
namespace threadblock {
|
| 47 |
+
|
| 48 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 49 |
+
|
| 50 |
+
/// Defines the optimal thread map for TensorOp accumulator layouts
|
| 51 |
+
template <
|
| 52 |
+
typename ThreadblockShape_,
|
| 53 |
+
typename WarpShape_,
|
| 54 |
+
int PartitionsK,
|
| 55 |
+
typename Element_,
|
| 56 |
+
int ElementsPerAccess
|
| 57 |
+
>
|
| 58 |
+
struct DefaultThreadMapTensorOp {
|
| 59 |
+
|
| 60 |
+
using ThreadblockShape = ThreadblockShape_;
|
| 61 |
+
using WarpShape = WarpShape_;
|
| 62 |
+
static int const kPartitionsK = PartitionsK;
|
| 63 |
+
using Element = Element_;
|
| 64 |
+
static int const kElementsPerAccess = ElementsPerAccess;
|
| 65 |
+
|
| 66 |
+
//
|
| 67 |
+
// Definitions
|
| 68 |
+
//
|
| 69 |
+
|
| 70 |
+
struct Detail {
|
| 71 |
+
|
| 72 |
+
/// Tensor Operations fundamentally perform operations on 8 rows
|
| 73 |
+
static int const kTensorOpRows = 8;
|
| 74 |
+
static int const kWarpSize = 32;
|
| 75 |
+
|
| 76 |
+
static_assert(
|
| 77 |
+
!(ThreadblockShape::kM % WarpShape::kM) &&
|
| 78 |
+
!(ThreadblockShape::kN % WarpShape::kN), "Divisibility");
|
| 79 |
+
|
| 80 |
+
/// Number of warps
|
| 81 |
+
using WarpCount = gemm::GemmShape<
|
| 82 |
+
ThreadblockShape::kM / WarpShape::kM,
|
| 83 |
+
ThreadblockShape::kN / WarpShape::kN,
|
| 84 |
+
kPartitionsK
|
| 85 |
+
>;
|
| 86 |
+
|
| 87 |
+
/// Number of participating threads
|
| 88 |
+
static int const kThreads = WarpCount::kCount * kWarpSize;
|
| 89 |
+
};
|
| 90 |
+
|
| 91 |
+
//
|
| 92 |
+
// ThreadMap
|
| 93 |
+
//
|
| 94 |
+
|
| 95 |
+
/// ThreadMap to be used by epilogue::PredicatedTileIterator satisfying concept OutputTileThreadMap
|
| 96 |
+
using Type = OutputTileOptimalThreadMap <
|
| 97 |
+
OutputTileShape<ThreadblockShape::kN, Detail::kTensorOpRows, Detail::WarpCount::kM, 1, 1>,
|
| 98 |
+
OutputTileShape<1, WarpShape::kM / Detail::kTensorOpRows, 1, 1, WarpShape::kM / Detail::kTensorOpRows>,
|
| 99 |
+
Detail::kThreads,
|
| 100 |
+
kElementsPerAccess,
|
| 101 |
+
sizeof_bits<Element>::value
|
| 102 |
+
>;
|
| 103 |
+
};
|
| 104 |
+
|
| 105 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 106 |
+
|
| 107 |
+
/// Defines the optimal thread map for TensorOp accumulator layouts
|
| 108 |
+
template <typename ThreadblockShape_, typename WarpShape_, int PartitionsK,
|
| 109 |
+
typename Element_, int ElementsPerAccess, int InterleavedK>
|
| 110 |
+
struct DefaultInterleavedThreadMapTensorOp {
|
| 111 |
+
using ThreadblockShape = ThreadblockShape_;
|
| 112 |
+
using WarpShape = WarpShape_;
|
| 113 |
+
static int const kPartitionsK = PartitionsK;
|
| 114 |
+
using Element = Element_;
|
| 115 |
+
static int const kElementsPerAccess = ElementsPerAccess;
|
| 116 |
+
static int const kInterleavedK = InterleavedK;
|
| 117 |
+
|
| 118 |
+
//
|
| 119 |
+
// Definitions
|
| 120 |
+
//
|
| 121 |
+
|
| 122 |
+
struct Detail {
|
| 123 |
+
/// Tensor Operations fundamentally perform operations on 8 rows
|
| 124 |
+
static int const kTensorOpRows = 8;
|
| 125 |
+
static int const kWarpSize = 32;
|
| 126 |
+
|
| 127 |
+
static_assert(!(ThreadblockShape::kM % WarpShape::kM) &&
|
| 128 |
+
!(ThreadblockShape::kN % WarpShape::kN),
|
| 129 |
+
"Divisibility");
|
| 130 |
+
|
| 131 |
+
/// Number of warps
|
| 132 |
+
using WarpCount =
|
| 133 |
+
gemm::GemmShape<ThreadblockShape::kM / WarpShape::kM,
|
| 134 |
+
ThreadblockShape::kN / WarpShape::kN, kPartitionsK>;
|
| 135 |
+
|
| 136 |
+
/// Number of participating threads
|
| 137 |
+
static int const kThreads = WarpCount::kCount * kWarpSize;
|
| 138 |
+
};
|
| 139 |
+
|
| 140 |
+
//
|
| 141 |
+
// ThreadMap
|
| 142 |
+
//
|
| 143 |
+
|
| 144 |
+
/// ThreadMap to be used by epilogue::PredicatedTileIterator satisfying concept
|
| 145 |
+
/// InterleavedOutputTileThreadMap
|
| 146 |
+
using Type = InterleavedOutputTileThreadMap<
|
| 147 |
+
layout::PitchLinearShape<Detail::WarpCount::kM, Detail::WarpCount::kN>,
|
| 148 |
+
layout::PitchLinearShape<WarpShape::kM / Detail::kTensorOpRows,
|
| 149 |
+
WarpShape::kN / InterleavedK>,
|
| 150 |
+
Detail::kThreads, kElementsPerAccess, sizeof_bits<Element>::value>;
|
| 151 |
+
};
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 155 |
+
|
| 156 |
+
/// Defines the optimal thread map for TensorOp accumulator layouts
|
| 157 |
+
template <typename ThreadblockShape_, typename WarpShape_, int PartitionsK,
|
| 158 |
+
typename Element_, int ElementsPerAccess, int InterleavedK>
|
| 159 |
+
struct DefaultInterleavedConvThreadMapTensorOp {
|
| 160 |
+
using ThreadblockShape = ThreadblockShape_;
|
| 161 |
+
using WarpShape = WarpShape_;
|
| 162 |
+
static int const kPartitionsK = PartitionsK;
|
| 163 |
+
using Element = Element_;
|
| 164 |
+
static int const kElementsPerAccess = ElementsPerAccess;
|
| 165 |
+
static int const kInterleavedK = InterleavedK;
|
| 166 |
+
|
| 167 |
+
//
|
| 168 |
+
// Definitions
|
| 169 |
+
//
|
| 170 |
+
|
| 171 |
+
struct Detail {
|
| 172 |
+
/// Tensor Operations fundamentally perform operations on 8 rows
|
| 173 |
+
static int const kTensorOpRows = 8;
|
| 174 |
+
static int const kWarpSize = 32;
|
| 175 |
+
|
| 176 |
+
static_assert(!(ThreadblockShape::kM % WarpShape::kM) &&
|
| 177 |
+
!(ThreadblockShape::kN % WarpShape::kN),
|
| 178 |
+
"Divisibility");
|
| 179 |
+
|
| 180 |
+
/// Number of warps
|
| 181 |
+
using WarpCount =
|
| 182 |
+
gemm::GemmShape<ThreadblockShape::kM / WarpShape::kM,
|
| 183 |
+
ThreadblockShape::kN / WarpShape::kN, kPartitionsK>;
|
| 184 |
+
|
| 185 |
+
/// Number of participating threads
|
| 186 |
+
static int const kThreads = WarpCount::kCount * kWarpSize;
|
| 187 |
+
};
|
| 188 |
+
|
| 189 |
+
//
|
| 190 |
+
// ThreadMap
|
| 191 |
+
//
|
| 192 |
+
|
| 193 |
+
/// ThreadMap to be used by epilogue::MaskedTileIterator satisfying concept
|
| 194 |
+
/// InterleavedOutputTileThreadMap
|
| 195 |
+
using Type = InterleavedConvOutputTileThreadMap<
|
| 196 |
+
MatrixShape<Detail::WarpCount::kM, Detail::WarpCount::kN>,
|
| 197 |
+
MatrixShape<WarpShape::kM / Detail::kTensorOpRows,
|
| 198 |
+
WarpShape::kN / InterleavedK>,
|
| 199 |
+
Detail::kThreads, kElementsPerAccess, sizeof_bits<Element>::value>;
|
| 200 |
+
};
|
| 201 |
+
|
| 202 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 203 |
+
|
| 204 |
+
} // namespace threadblock
|
| 205 |
+
} // namespace epilogue
|
| 206 |
+
} // namespace cutlass
|
| 207 |
+
|
| 208 |
+
////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_volta_tensor_op.h
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief
|
| 33 |
+
|
| 34 |
+
*/
|
| 35 |
+
|
| 36 |
+
#pragma once
|
| 37 |
+
|
| 38 |
+
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
|
| 39 |
+
#include "cutlass/gemm/gemm.h"
|
| 40 |
+
|
| 41 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 42 |
+
|
| 43 |
+
namespace cutlass {
|
| 44 |
+
namespace epilogue {
|
| 45 |
+
namespace threadblock {
|
| 46 |
+
|
| 47 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 48 |
+
|
| 49 |
+
/// Defines the optimal thread map for TensorOp accumulator layouts
|
| 50 |
+
template <
|
| 51 |
+
typename ThreadblockShape,
|
| 52 |
+
typename WarpShape,
|
| 53 |
+
int PartitionsK,
|
| 54 |
+
typename ElementOutput,
|
| 55 |
+
int ElementsPerAccess,
|
| 56 |
+
typename ElementAccumulator
|
| 57 |
+
>
|
| 58 |
+
struct DefaultThreadMapVoltaTensorOp;
|
| 59 |
+
|
| 60 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 61 |
+
|
| 62 |
+
/// Defines the optimal thread map for TensorOp accumulator layouts
|
| 63 |
+
template <
|
| 64 |
+
typename ThreadblockShape_,
|
| 65 |
+
typename WarpShape_,
|
| 66 |
+
int PartitionsK,
|
| 67 |
+
typename ElementOutput_,
|
| 68 |
+
int ElementsPerAccess
|
| 69 |
+
>
|
| 70 |
+
struct DefaultThreadMapVoltaTensorOp<
|
| 71 |
+
ThreadblockShape_,
|
| 72 |
+
WarpShape_,
|
| 73 |
+
PartitionsK,
|
| 74 |
+
ElementOutput_,
|
| 75 |
+
ElementsPerAccess,
|
| 76 |
+
half_t> {
|
| 77 |
+
|
| 78 |
+
using ThreadblockShape = ThreadblockShape_;
|
| 79 |
+
using WarpShape = WarpShape_;
|
| 80 |
+
static int const kPartitionsK = PartitionsK;
|
| 81 |
+
using ElementOutput = ElementOutput_;
|
| 82 |
+
static int const kElementsPerAccess = ElementsPerAccess;
|
| 83 |
+
using ElementAccumulator = half_t;
|
| 84 |
+
|
| 85 |
+
//
|
| 86 |
+
// Definitions
|
| 87 |
+
//
|
| 88 |
+
|
| 89 |
+
struct Detail {
|
| 90 |
+
|
| 91 |
+
static int const kTensorOpRows = 16;
|
| 92 |
+
static int const kWarpSize = 32;
|
| 93 |
+
static int const kInterleavedTilesM = WarpShape::kM / 32;
|
| 94 |
+
|
| 95 |
+
static_assert(
|
| 96 |
+
!(ThreadblockShape::kM % WarpShape::kM) &&
|
| 97 |
+
!(ThreadblockShape::kN % WarpShape::kN), "Divisibility");
|
| 98 |
+
|
| 99 |
+
/// Number of warps
|
| 100 |
+
using WarpCount = gemm::GemmShape<
|
| 101 |
+
ThreadblockShape::kM / WarpShape::kM,
|
| 102 |
+
ThreadblockShape::kN / WarpShape::kN,
|
| 103 |
+
kPartitionsK
|
| 104 |
+
>;
|
| 105 |
+
|
| 106 |
+
/// Number of participating threads
|
| 107 |
+
static int const kThreads = WarpCount::kCount * kWarpSize;
|
| 108 |
+
|
| 109 |
+
using Shape = cutlass::epilogue::threadblock::OutputTileShape<
|
| 110 |
+
ThreadblockShape::kN, // column
|
| 111 |
+
4, // row
|
| 112 |
+
4, // group
|
| 113 |
+
WarpCount::kM, // cluster
|
| 114 |
+
1 // tile
|
| 115 |
+
>;
|
| 116 |
+
|
| 117 |
+
/// Number of iterations per subspace
|
| 118 |
+
using Count = cutlass::epilogue::threadblock::OutputTileShape<
|
| 119 |
+
1, // column
|
| 120 |
+
2, // row
|
| 121 |
+
kInterleavedTilesM, // group
|
| 122 |
+
1, // cluster
|
| 123 |
+
WarpShape::kM / kTensorOpRows // iterations
|
| 124 |
+
>;
|
| 125 |
+
};
|
| 126 |
+
|
| 127 |
+
//
|
| 128 |
+
// ThreadMap
|
| 129 |
+
//
|
| 130 |
+
|
| 131 |
+
/// ThreadMap to be used by epilogue::PredicatedTileIterator satisfying concept OutputTileThreadMap
|
| 132 |
+
using Type = OutputTileOptimalThreadMap <
|
| 133 |
+
typename Detail::Shape,
|
| 134 |
+
typename Detail::Count,
|
| 135 |
+
Detail::kThreads,
|
| 136 |
+
kElementsPerAccess,
|
| 137 |
+
sizeof_bits<ElementOutput>::value
|
| 138 |
+
>;
|
| 139 |
+
};
|
| 140 |
+
|
| 141 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 142 |
+
|
| 143 |
+
/// Defines the optimal thread map for TensorOp accumulator layouts
|
| 144 |
+
template <
|
| 145 |
+
typename ThreadblockShape_,
|
| 146 |
+
typename WarpShape_,
|
| 147 |
+
int PartitionsK,
|
| 148 |
+
typename ElementOutput_,
|
| 149 |
+
int ElementsPerAccess
|
| 150 |
+
>
|
| 151 |
+
struct DefaultThreadMapVoltaTensorOp<
|
| 152 |
+
ThreadblockShape_,
|
| 153 |
+
WarpShape_,
|
| 154 |
+
PartitionsK,
|
| 155 |
+
ElementOutput_,
|
| 156 |
+
ElementsPerAccess,
|
| 157 |
+
float> {
|
| 158 |
+
|
| 159 |
+
using ThreadblockShape = ThreadblockShape_;
|
| 160 |
+
using WarpShape = WarpShape_;
|
| 161 |
+
static int const kPartitionsK = PartitionsK;
|
| 162 |
+
using ElementOutput = ElementOutput_;
|
| 163 |
+
static int const kElementsPerAccess = ElementsPerAccess;
|
| 164 |
+
using ElementAccumulator = float;
|
| 165 |
+
|
| 166 |
+
//
|
| 167 |
+
// Definitions
|
| 168 |
+
//
|
| 169 |
+
|
| 170 |
+
struct Detail {
|
| 171 |
+
|
| 172 |
+
static int const kTensorOpRows = 16;
|
| 173 |
+
static int const kWarpSize = 32;
|
| 174 |
+
static int const kInterleavedTilesM = WarpShape::kM / 32;
|
| 175 |
+
|
| 176 |
+
static_assert(
|
| 177 |
+
!(ThreadblockShape::kM % WarpShape::kM) &&
|
| 178 |
+
!(ThreadblockShape::kN % WarpShape::kN), "Divisibility");
|
| 179 |
+
|
| 180 |
+
/// Number of warps
|
| 181 |
+
using WarpCount = gemm::GemmShape<
|
| 182 |
+
ThreadblockShape::kM / WarpShape::kM,
|
| 183 |
+
ThreadblockShape::kN / WarpShape::kN,
|
| 184 |
+
kPartitionsK
|
| 185 |
+
>;
|
| 186 |
+
|
| 187 |
+
/// Number of participating threads
|
| 188 |
+
static int const kThreads = WarpCount::kCount * kWarpSize;
|
| 189 |
+
|
| 190 |
+
using Shape = cutlass::epilogue::threadblock::OutputTileShape<
|
| 191 |
+
ThreadblockShape::kN, // column
|
| 192 |
+
4, // row
|
| 193 |
+
4, // group
|
| 194 |
+
WarpCount::kM, // cluster
|
| 195 |
+
1 // tile
|
| 196 |
+
>;
|
| 197 |
+
|
| 198 |
+
/// Number of iterations per subspace
|
| 199 |
+
using Count = cutlass::epilogue::threadblock::OutputTileShape<
|
| 200 |
+
1, // column
|
| 201 |
+
2, // row
|
| 202 |
+
kInterleavedTilesM, // group
|
| 203 |
+
1, // cluster
|
| 204 |
+
WarpShape::kM / kTensorOpRows // iterations
|
| 205 |
+
>;
|
| 206 |
+
};
|
| 207 |
+
|
| 208 |
+
//
|
| 209 |
+
// ThreadMap
|
| 210 |
+
//
|
| 211 |
+
|
| 212 |
+
/// ThreadMap to be used by epilogue::PredicatedTileIterator satisfying concept OutputTileThreadMap
|
| 213 |
+
using Type = OutputTileOptimalThreadMap <
|
| 214 |
+
typename Detail::Shape,
|
| 215 |
+
typename Detail::Count,
|
| 216 |
+
Detail::kThreads,
|
| 217 |
+
kElementsPerAccess,
|
| 218 |
+
sizeof_bits<ElementOutput>::value
|
| 219 |
+
>;
|
| 220 |
+
};
|
| 221 |
+
|
| 222 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 223 |
+
|
| 224 |
+
} // namespace threadblock
|
| 225 |
+
} // namespace epilogue
|
| 226 |
+
} // namespace cutlass
|
| 227 |
+
|
| 228 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_wmma_tensor_op.h
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief
|
| 33 |
+
|
| 34 |
+
*/
|
| 35 |
+
|
| 36 |
+
#pragma once
|
| 37 |
+
|
| 38 |
+
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
|
| 39 |
+
#include "cutlass/gemm/gemm.h"
|
| 40 |
+
#include "cutlass/layout/pitch_linear.h"
|
| 41 |
+
|
| 42 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 43 |
+
|
| 44 |
+
namespace cutlass {
|
| 45 |
+
namespace epilogue {
|
| 46 |
+
namespace threadblock {
|
| 47 |
+
|
| 48 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 49 |
+
|
| 50 |
+
/// Defines the optimal thread map for Wmma TensorOp accumulator layouts
|
| 51 |
+
template <
|
| 52 |
+
typename ThreadblockShape_,
|
| 53 |
+
typename WarpShape_,
|
| 54 |
+
typename InstructionShape_,
|
| 55 |
+
int PartitionsK,
|
| 56 |
+
typename Element_,
|
| 57 |
+
int ElementsPerAccess
|
| 58 |
+
>
|
| 59 |
+
struct DefaultThreadMapWmmaTensorOp {
|
| 60 |
+
|
| 61 |
+
using ThreadblockShape = ThreadblockShape_;
|
| 62 |
+
using WarpShape = WarpShape_;
|
| 63 |
+
using InstructionShape = InstructionShape_;
|
| 64 |
+
static int const kPartitionsK = PartitionsK;
|
| 65 |
+
using Element = Element_;
|
| 66 |
+
static int const kElementsPerAccess = ElementsPerAccess;
|
| 67 |
+
|
| 68 |
+
//
|
| 69 |
+
// Definitions
|
| 70 |
+
//
|
| 71 |
+
|
| 72 |
+
struct Detail {
|
| 73 |
+
|
| 74 |
+
/// Wmma Tensor Operations fundamentally perform operations on InstructionShape::kM rows
|
| 75 |
+
static int const kTensorOpRows = InstructionShape::kM;
|
| 76 |
+
static int const kWarpSize = 32;
|
| 77 |
+
|
| 78 |
+
static_assert(
|
| 79 |
+
!(ThreadblockShape::kM % WarpShape::kM) &&
|
| 80 |
+
!(ThreadblockShape::kN % WarpShape::kN), "Divisibility");
|
| 81 |
+
|
| 82 |
+
/// Number of warps
|
| 83 |
+
using WarpCount = gemm::GemmShape<
|
| 84 |
+
ThreadblockShape::kM / WarpShape::kM,
|
| 85 |
+
ThreadblockShape::kN / WarpShape::kN,
|
| 86 |
+
kPartitionsK
|
| 87 |
+
>;
|
| 88 |
+
|
| 89 |
+
/// Number of participating threads
|
| 90 |
+
static int const kThreads = WarpCount::kCount * kWarpSize;
|
| 91 |
+
};
|
| 92 |
+
|
| 93 |
+
//
|
| 94 |
+
// ThreadMap
|
| 95 |
+
//
|
| 96 |
+
|
| 97 |
+
/// ThreadMap to be used by epilogue::PredicatedTileIterator satisfying concept OutputTileThreadMap
|
| 98 |
+
using Type = OutputTileOptimalThreadMap <
|
| 99 |
+
OutputTileShape<ThreadblockShape::kN, Detail::kTensorOpRows, Detail::WarpCount::kM, 1, 1>,
|
| 100 |
+
OutputTileShape<1, WarpShape::kM / Detail::kTensorOpRows, 1, 1, WarpShape::kM / Detail::kTensorOpRows>,
|
| 101 |
+
Detail::kThreads,
|
| 102 |
+
kElementsPerAccess,
|
| 103 |
+
sizeof_bits<Element>::value
|
| 104 |
+
>;
|
| 105 |
+
};
|
| 106 |
+
|
| 107 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 108 |
+
|
| 109 |
+
} // namespace threadblock
|
| 110 |
+
} // namespace epilogue
|
| 111 |
+
} // namespace cutlass
|
| 112 |
+
|
| 113 |
+
////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/direct_store_epilogue_iterator.h
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
|
| 33 |
+
|
| 34 |
+
The epilogue rearranges the result of a matrix product through shared memory to match canonical
|
| 35 |
+
tensor layouts in global memory. Epilogues support conversion and reduction operations.
|
| 36 |
+
|
| 37 |
+
*/
|
| 38 |
+
|
| 39 |
+
#pragma once
|
| 40 |
+
|
| 41 |
+
#include "cutlass/cutlass.h"
|
| 42 |
+
#include "cutlass/numeric_types.h"
|
| 43 |
+
#include "cutlass/array.h"
|
| 44 |
+
#include "cutlass/layout/matrix.h"
|
| 45 |
+
#include "cutlass/layout/tensor.h"
|
| 46 |
+
#include "cutlass/matrix_shape.h"
|
| 47 |
+
#include "cutlass/tensor_ref.h"
|
| 48 |
+
#include "cutlass/transform/pitch_linear_thread_map.h"
|
| 49 |
+
#include "cutlass/epilogue/threadblock/output_tile_thread_map.h"
|
| 50 |
+
#include "cutlass/arch/arch.h"
|
| 51 |
+
#include "cutlass/arch/memory.h"
|
| 52 |
+
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h"
|
| 53 |
+
|
| 54 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 55 |
+
|
| 56 |
+
namespace cutlass {
|
| 57 |
+
|
| 58 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 59 |
+
|
| 60 |
+
namespace epilogue {
|
| 61 |
+
namespace threadblock {
|
| 62 |
+
|
| 63 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 64 |
+
|
| 65 |
+
template <typename Element_>
|
| 66 |
+
class DirectStoreEpilogueIterator {
|
| 67 |
+
public:
|
| 68 |
+
|
| 69 |
+
using Element = Element_;
|
| 70 |
+
|
| 71 |
+
using Layout = layout::RowMajor;
|
| 72 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 73 |
+
using ConstTensorRef = typename TensorRef::ConstTensorRef;
|
| 74 |
+
|
| 75 |
+
using Index = typename Layout::Index;
|
| 76 |
+
using LongIndex = typename Layout::LongIndex;
|
| 77 |
+
using TensorCoord = MatrixCoord;
|
| 78 |
+
|
| 79 |
+
static int const kElementsPerAccess = 1;
|
| 80 |
+
|
| 81 |
+
/// Uses a non-template class
|
| 82 |
+
struct Params : PredicatedTileIteratorParams {
|
| 83 |
+
using Base = PredicatedTileIteratorParams;
|
| 84 |
+
|
| 85 |
+
CUTLASS_HOST_DEVICE
|
| 86 |
+
Params() { }
|
| 87 |
+
|
| 88 |
+
CUTLASS_HOST_DEVICE
|
| 89 |
+
Params(Layout const &layout) {
|
| 90 |
+
stride = layout.stride(0) * sizeof(Element);
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
CUTLASS_HOST_DEVICE
|
| 94 |
+
Params(Base const &base) :
|
| 95 |
+
Base(base) { }
|
| 96 |
+
};
|
| 97 |
+
|
| 98 |
+
public:
|
| 99 |
+
|
| 100 |
+
//
|
| 101 |
+
// Data members
|
| 102 |
+
//
|
| 103 |
+
|
| 104 |
+
Element *pointer; // pointer to the output matrix
|
| 105 |
+
|
| 106 |
+
LongIndex stride; // stride in elements between rows
|
| 107 |
+
|
| 108 |
+
TensorCoord extent; // extent of output matrix
|
| 109 |
+
|
| 110 |
+
int thread_idx; // thread index
|
| 111 |
+
|
| 112 |
+
TensorCoord threadblock_offset;
|
| 113 |
+
|
| 114 |
+
public:
|
| 115 |
+
|
| 116 |
+
/// Constructor
|
| 117 |
+
CUTLASS_DEVICE
|
| 118 |
+
DirectStoreEpilogueIterator(
|
| 119 |
+
PredicatedTileIteratorParams const & params,
|
| 120 |
+
Element *pointer_,
|
| 121 |
+
TensorCoord extent_,
|
| 122 |
+
int thread_idx_,
|
| 123 |
+
TensorCoord threadblock_offset_ = TensorCoord(),
|
| 124 |
+
int const * indices = nullptr
|
| 125 |
+
):
|
| 126 |
+
pointer(pointer_),
|
| 127 |
+
stride(params.stride / sizeof(Element)),
|
| 128 |
+
extent(extent_),
|
| 129 |
+
thread_idx(thread_idx_),
|
| 130 |
+
threadblock_offset(threadblock_offset_)
|
| 131 |
+
{
|
| 132 |
+
|
| 133 |
+
}
|
| 134 |
+
};
|
| 135 |
+
|
| 136 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 137 |
+
|
| 138 |
+
} // namespace threadblock
|
| 139 |
+
} // namespace epilogue
|
| 140 |
+
} // namespace cutlass
|
| 141 |
+
|
| 142 |
+
////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/epilogue.h
ADDED
|
@@ -0,0 +1,548 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
|
| 33 |
+
|
| 34 |
+
The epilogue rearranges the result of a matrix product through shared memory to match canonical
|
| 35 |
+
tensor layouts in global memory. Epilogues support conversion and reduction operations.
|
| 36 |
+
|
| 37 |
+
The shared memory resource is time-sliced across warps.
|
| 38 |
+
*/
|
| 39 |
+
|
| 40 |
+
#pragma once
|
| 41 |
+
#include "cutlass/cutlass.h"
|
| 42 |
+
#include CUDA_STD_HEADER(cassert)
|
| 43 |
+
|
| 44 |
+
#include "cutlass/numeric_types.h"
|
| 45 |
+
#include "cutlass/array.h"
|
| 46 |
+
#include "cutlass/layout/vector.h"
|
| 47 |
+
#include "cutlass/layout/tensor.h"
|
| 48 |
+
#include "cutlass/tensor_coord.h"
|
| 49 |
+
#include "cutlass/aligned_buffer.h"
|
| 50 |
+
#include "cutlass/functional.h"
|
| 51 |
+
|
| 52 |
+
#include "cutlass/gemm/gemm.h"
|
| 53 |
+
|
| 54 |
+
#include "cutlass/transform/pitch_linear_thread_map.h"
|
| 55 |
+
#include "cutlass/transform/threadblock/regular_tile_iterator.h"
|
| 56 |
+
|
| 57 |
+
#include "cutlass/epilogue/threadblock/epilogue_base.h"
|
| 58 |
+
#include "cutlass/epilogue/threadblock/epilogue_base_streamk.h"
|
| 59 |
+
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
|
| 60 |
+
|
| 61 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 62 |
+
|
| 63 |
+
namespace cutlass {
|
| 64 |
+
namespace epilogue {
|
| 65 |
+
namespace threadblock {
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 69 |
+
|
| 70 |
+
/// Epilogue operator
|
| 71 |
+
template <
|
| 72 |
+
typename Shape_, ///< Shape of threadblock tile (concept: GemmShape)
|
| 73 |
+
typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp)
|
| 74 |
+
int PartitionsK, ///< Number of partitions of the K dimension
|
| 75 |
+
typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors
|
| 76 |
+
typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators
|
| 77 |
+
typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM
|
| 78 |
+
typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM
|
| 79 |
+
typename OutputOp_, ///< Output operator
|
| 80 |
+
typename Padding_, ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape)
|
| 81 |
+
int FragmentsPerPartition = 1, ///< Used to coarsten the epilogue granularity
|
| 82 |
+
int IterationsUnroll = ///< Used to reduce binary size when epilogue op is large
|
| 83 |
+
(!IsEpilogueFunctorHeavy<OutputOp_>::value)
|
| 84 |
+
>
|
| 85 |
+
class Epilogue :
|
| 86 |
+
public EpilogueBase<
|
| 87 |
+
Shape_,
|
| 88 |
+
typename WarpMmaOperator_::Shape,
|
| 89 |
+
PartitionsK,
|
| 90 |
+
AccumulatorFragmentIterator_,
|
| 91 |
+
WarpTileIterator_,
|
| 92 |
+
Padding_,
|
| 93 |
+
FragmentsPerPartition>,
|
| 94 |
+
public EpilogueBaseStreamK<
|
| 95 |
+
Shape_,
|
| 96 |
+
PartitionsK,
|
| 97 |
+
WarpMmaOperator_,
|
| 98 |
+
AccumulatorFragmentIterator_>
|
| 99 |
+
{
|
| 100 |
+
|
| 101 |
+
public:
|
| 102 |
+
|
| 103 |
+
using Base = EpilogueBase<
|
| 104 |
+
Shape_,
|
| 105 |
+
typename WarpMmaOperator_::Shape,
|
| 106 |
+
PartitionsK,
|
| 107 |
+
AccumulatorFragmentIterator_,
|
| 108 |
+
WarpTileIterator_,
|
| 109 |
+
Padding_,
|
| 110 |
+
FragmentsPerPartition>;
|
| 111 |
+
|
| 112 |
+
using BaseStreamK = EpilogueBaseStreamK<
|
| 113 |
+
Shape_,
|
| 114 |
+
PartitionsK,
|
| 115 |
+
WarpMmaOperator_,
|
| 116 |
+
AccumulatorFragmentIterator_>;
|
| 117 |
+
|
| 118 |
+
using Shape = Shape_;
|
| 119 |
+
using WarpMmaOperator = WarpMmaOperator_;
|
| 120 |
+
static int const kPartitionsK = PartitionsK;
|
| 121 |
+
using OutputTileIterator = OutputTileIterator_;
|
| 122 |
+
using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
|
| 123 |
+
using WarpTileIterator = WarpTileIterator_;
|
| 124 |
+
using SharedLoadIterator = SharedLoadIterator_;
|
| 125 |
+
using OutputOp = OutputOp_;
|
| 126 |
+
using Padding = Padding_;
|
| 127 |
+
using Layout = layout::RowMajor;
|
| 128 |
+
using LongIndex = typename Layout::LongIndex;
|
| 129 |
+
|
| 130 |
+
/// Number of warps per block
|
| 131 |
+
using WarpCount = typename Base::WarpCount;
|
| 132 |
+
|
| 133 |
+
/// Number of threads per block
|
| 134 |
+
static int const kBlockThreads = 32 * WarpCount::kCount;
|
| 135 |
+
|
| 136 |
+
/// Per-thread accumulator tile type
|
| 137 |
+
using AccumulatorTile = typename Base::AccumulatorTile;
|
| 138 |
+
|
| 139 |
+
/// Numerical accumulation element type
|
| 140 |
+
using ElementAccumulator = typename WarpMmaOperator::ElementC;
|
| 141 |
+
|
| 142 |
+
/// Fragment type used by the accumulator tile's fragment iterator
|
| 143 |
+
using AccumulatorFragment = typename AccumulatorFragmentIterator::Fragment;
|
| 144 |
+
|
| 145 |
+
/// Output element
|
| 146 |
+
using ElementOutput = typename OutputTileIterator::Element;
|
| 147 |
+
|
| 148 |
+
/// Output access size
|
| 149 |
+
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
|
| 150 |
+
|
| 151 |
+
/// Tensor reference to destination tensor
|
| 152 |
+
using TensorRef = typename OutputTileIterator::TensorRef;
|
| 153 |
+
|
| 154 |
+
/// Tensor reference to sync tensor
|
| 155 |
+
using SyncTensorRef = typename cutlass::TensorRef<int, cutlass::layout::PackedVectorLayout>;
|
| 156 |
+
|
| 157 |
+
/// Const tensor reference to source tensor
|
| 158 |
+
using ConstTensorRef = typename OutputTileIterator::ConstTensorRef;
|
| 159 |
+
|
| 160 |
+
/// Vector type used by the global output iterator
|
| 161 |
+
using OutputAccessType = Array<
|
| 162 |
+
typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
|
| 163 |
+
|
| 164 |
+
/// Vector type used by the shared output iterator
|
| 165 |
+
using AccumulatorAccessType = Array<typename WarpTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
|
| 166 |
+
|
| 167 |
+
static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 ? Base::kFragmentsPerIteration : kPartitionsK;
|
| 168 |
+
|
| 169 |
+
static int constexpr kSmemPointerOffset = Base::SharedStorage::StorageShape::kCount / kSmemTiles;
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
public:
|
| 173 |
+
|
| 174 |
+
static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements,
|
| 175 |
+
"Mismatch between shared load iterator and output tile iterator.");
|
| 176 |
+
|
| 177 |
+
static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero.");
|
| 178 |
+
|
| 179 |
+
static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess),
|
| 180 |
+
"Divisibility");
|
| 181 |
+
|
| 182 |
+
static_assert(kPartitionsK == 1 || Base::kFragmentsPerIteration == 1, "One of these must be exactly 1.");
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
public:
|
| 186 |
+
|
| 187 |
+
/// Aspect for when epilogue source is not needed
|
| 188 |
+
struct SourceAspectNotNeeded
|
| 189 |
+
{
|
| 190 |
+
/// Constructor
|
| 191 |
+
CUTLASS_DEVICE
|
| 192 |
+
SourceAspectNotNeeded()
|
| 193 |
+
{}
|
| 194 |
+
|
| 195 |
+
// No-op
|
| 196 |
+
CUTLASS_DEVICE
|
| 197 |
+
void load() { }
|
| 198 |
+
|
| 199 |
+
/// Invoke the output functor over each vector of output
|
| 200 |
+
CUTLASS_DEVICE
|
| 201 |
+
void apply_output_operator(
|
| 202 |
+
typename OutputTileIterator::Fragment &output_fragment,
|
| 203 |
+
OutputOp const &output_op,
|
| 204 |
+
typename SharedLoadIterator::Fragment const &aligned_accum_fragment)
|
| 205 |
+
{
|
| 206 |
+
OutputAccessType *output_frag_ptr =
|
| 207 |
+
reinterpret_cast<OutputAccessType *>(&output_fragment);
|
| 208 |
+
|
| 209 |
+
AccumulatorAccessType const *compute_frag_ptr =
|
| 210 |
+
reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment);
|
| 211 |
+
|
| 212 |
+
int const kOutputOpIterations =
|
| 213 |
+
OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
|
| 214 |
+
|
| 215 |
+
CUTLASS_PRAGMA_UNROLL
|
| 216 |
+
for (int i = 0; i < kOutputOpIterations; ++i)
|
| 217 |
+
{
|
| 218 |
+
// Call the output operator
|
| 219 |
+
output_frag_ptr[i] = output_op(compute_frag_ptr[i]);
|
| 220 |
+
}
|
| 221 |
+
}
|
| 222 |
+
};
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
/// Aspect for when epilogue source is needed
|
| 226 |
+
struct SourceAspectNeeded
|
| 227 |
+
{
|
| 228 |
+
OutputTileIterator source_iterator;
|
| 229 |
+
|
| 230 |
+
typename OutputTileIterator::Fragment source_fragment;
|
| 231 |
+
|
| 232 |
+
/// Invoke the output functor over each vector of output
|
| 233 |
+
CUTLASS_DEVICE
|
| 234 |
+
static void apply_output_operator(
|
| 235 |
+
typename OutputTileIterator::Fragment &output_fragment,
|
| 236 |
+
OutputOp const &output_op,
|
| 237 |
+
typename SharedLoadIterator::Fragment const &aligned_accum_fragment,
|
| 238 |
+
typename OutputTileIterator::Fragment const &source_fragment)
|
| 239 |
+
{
|
| 240 |
+
OutputAccessType *output_frag_ptr =
|
| 241 |
+
reinterpret_cast<OutputAccessType *>(&output_fragment);
|
| 242 |
+
|
| 243 |
+
AccumulatorAccessType const *compute_frag_ptr =
|
| 244 |
+
reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment);
|
| 245 |
+
|
| 246 |
+
OutputAccessType const *source_frag_ptr =
|
| 247 |
+
reinterpret_cast<OutputAccessType const *>(&source_fragment);
|
| 248 |
+
|
| 249 |
+
int const kOutputOpIterations =
|
| 250 |
+
OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
|
| 251 |
+
|
| 252 |
+
CUTLASS_PRAGMA_UNROLL
|
| 253 |
+
for (int i = 0; i < kOutputOpIterations; ++i)
|
| 254 |
+
{
|
| 255 |
+
// Call the output operator
|
| 256 |
+
output_frag_ptr[i] = output_op(compute_frag_ptr[i], source_frag_ptr[i]);
|
| 257 |
+
}
|
| 258 |
+
}
|
| 259 |
+
|
| 260 |
+
/// Constructor
|
| 261 |
+
CUTLASS_DEVICE
|
| 262 |
+
SourceAspectNeeded(OutputTileIterator source_iterator) :
|
| 263 |
+
source_iterator(source_iterator)
|
| 264 |
+
{
|
| 265 |
+
source_fragment.clear();
|
| 266 |
+
}
|
| 267 |
+
|
| 268 |
+
// Load addend source fragment from global memory
|
| 269 |
+
CUTLASS_DEVICE
|
| 270 |
+
void load() {
|
| 271 |
+
source_iterator.load(source_fragment);
|
| 272 |
+
++source_iterator;
|
| 273 |
+
}
|
| 274 |
+
|
| 275 |
+
/// Invoke the output functor over each vector of output
|
| 276 |
+
CUTLASS_DEVICE
|
| 277 |
+
void apply_output_operator(
|
| 278 |
+
typename OutputTileIterator::Fragment &output_fragment,
|
| 279 |
+
OutputOp const &output_op,
|
| 280 |
+
typename SharedLoadIterator::Fragment const &aligned_accum_fragment)
|
| 281 |
+
{
|
| 282 |
+
apply_output_operator(output_fragment, output_op, aligned_accum_fragment, source_fragment);
|
| 283 |
+
}
|
| 284 |
+
};
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
private:
|
| 288 |
+
|
| 289 |
+
/// Loads fragment from shared memory aligned with output tensor
|
| 290 |
+
SharedLoadIterator shared_load_iterator_;
|
| 291 |
+
|
| 292 |
+
/// Thread index in the threadblock
|
| 293 |
+
int thread_idx;
|
| 294 |
+
|
| 295 |
+
/// Warp index in the threadblock
|
| 296 |
+
int warp_idx;
|
| 297 |
+
|
| 298 |
+
public:
|
| 299 |
+
|
| 300 |
+
/// Constructor
|
| 301 |
+
CUTLASS_DEVICE
|
| 302 |
+
Epilogue(
|
| 303 |
+
typename Base::SharedStorage &shared_storage, ///< Shared storage object
|
| 304 |
+
int thread_idx, ///< ID of a thread within the threadblock
|
| 305 |
+
int warp_idx, ///< ID of warp within threadblock
|
| 306 |
+
int lane_idx) ///< Id of thread within warp
|
| 307 |
+
:
|
| 308 |
+
Base(shared_storage, thread_idx, warp_idx, lane_idx),
|
| 309 |
+
BaseStreamK(thread_idx),
|
| 310 |
+
shared_load_iterator_(shared_storage.reference(), thread_idx),
|
| 311 |
+
thread_idx(thread_idx),
|
| 312 |
+
warp_idx(warp_idx)
|
| 313 |
+
{}
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
/// Aggregates the accumulator sets shared by peer blocks in the global workspace,
|
| 317 |
+
/// performing epilogue computations, writing to output
|
| 318 |
+
CUTLASS_DEVICE
|
| 319 |
+
void reduce(
|
| 320 |
+
int peer_idx_begin,
|
| 321 |
+
int peer_idx_end,
|
| 322 |
+
int reduce_fragment_idx,
|
| 323 |
+
void *element_workspace,
|
| 324 |
+
OutputOp const &output_op, ///< Output operator
|
| 325 |
+
OutputTileIterator destination_iterator, ///< Tile iterator for destination
|
| 326 |
+
OutputTileIterator source_iterator) ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
|
| 327 |
+
{
|
| 328 |
+
// Reduce peer accumulator fragments into one fragment
|
| 329 |
+
AccumulatorFragment accum_fragment;
|
| 330 |
+
BaseStreamK::reduce(accum_fragment, peer_idx_begin, peer_idx_end, reduce_fragment_idx, element_workspace);
|
| 331 |
+
|
| 332 |
+
// Store fragment to shared memory
|
| 333 |
+
this->warp_tile_iterator_.store(accum_fragment);
|
| 334 |
+
|
| 335 |
+
__syncthreads();
|
| 336 |
+
|
| 337 |
+
// Initialize/load source-fragment data
|
| 338 |
+
typename OutputTileIterator::Fragment source_fragment;
|
| 339 |
+
source_fragment.clear();
|
| 340 |
+
|
| 341 |
+
if (output_op.is_source_needed())
|
| 342 |
+
{
|
| 343 |
+
source_iterator += reduce_fragment_idx;
|
| 344 |
+
source_iterator.load(source_fragment);
|
| 345 |
+
}
|
| 346 |
+
|
| 347 |
+
// Load fragment from shared memory
|
| 348 |
+
typename SharedLoadIterator::Fragment aligned_accum_fragment;
|
| 349 |
+
shared_load_iterator_.load(aligned_accum_fragment);
|
| 350 |
+
|
| 351 |
+
// Add fragments shared by other k partitions
|
| 352 |
+
if (kPartitionsK > 1)
|
| 353 |
+
{
|
| 354 |
+
plus <typename SharedLoadIterator::Fragment> add_fragments;
|
| 355 |
+
|
| 356 |
+
CUTLASS_PRAGMA_UNROLL
|
| 357 |
+
for ( int i = 1; i < kPartitionsK; ++i) {
|
| 358 |
+
typename SharedLoadIterator::Fragment aligned_addend_fragment;
|
| 359 |
+
shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
|
| 360 |
+
shared_load_iterator_.load(aligned_addend_fragment);
|
| 361 |
+
aligned_accum_fragment = add_fragments(aligned_accum_fragment, aligned_addend_fragment);
|
| 362 |
+
}
|
| 363 |
+
}
|
| 364 |
+
|
| 365 |
+
// Compute the output result
|
| 366 |
+
typename OutputTileIterator::Fragment output_fragment;
|
| 367 |
+
|
| 368 |
+
// Apply the output operator
|
| 369 |
+
SourceAspectNeeded::apply_output_operator(
|
| 370 |
+
output_fragment,
|
| 371 |
+
output_op,
|
| 372 |
+
aligned_accum_fragment,
|
| 373 |
+
source_fragment);
|
| 374 |
+
|
| 375 |
+
// Store the final result
|
| 376 |
+
destination_iterator += reduce_fragment_idx;
|
| 377 |
+
destination_iterator.store(output_fragment);
|
| 378 |
+
}
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
/// Perform the epilogue computations and stream the result to global memory.
|
| 382 |
+
CUTLASS_DEVICE
|
| 383 |
+
void operator()(
|
| 384 |
+
OutputOp const &output_op, ///< Output operator
|
| 385 |
+
OutputTileIterator destination_iterator, ///< Tile iterator for destination
|
| 386 |
+
AccumulatorTile const &accumulators) ///< Complete warp-level accumulator tile
|
| 387 |
+
{
|
| 388 |
+
operator()(output_op, destination_iterator, accumulators, SourceAspectNotNeeded());
|
| 389 |
+
}
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
/// Perform the epilogue computations and stream the result to global memory. Implements
|
| 393 |
+
/// two alternative codepaths, depending on whether the output op requires addend data to be loaded.
|
| 394 |
+
CUTLASS_DEVICE
|
| 395 |
+
void operator()(
|
| 396 |
+
OutputOp const &output_op, ///< Output operator
|
| 397 |
+
OutputTileIterator destination_iterator, ///< Tile iterator for destination
|
| 398 |
+
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
|
| 399 |
+
OutputTileIterator source_iterator ) ///< Tile iterator for addend source
|
| 400 |
+
{
|
| 401 |
+
if (output_op.is_source_needed())
|
| 402 |
+
{
|
| 403 |
+
operator()(output_op, destination_iterator, accumulators, SourceAspectNeeded(source_iterator));
|
| 404 |
+
}
|
| 405 |
+
else
|
| 406 |
+
{
|
| 407 |
+
operator()(output_op, destination_iterator, accumulators, SourceAspectNotNeeded());
|
| 408 |
+
}
|
| 409 |
+
}
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
/// Perform the epilogue computations and stream the result to global memory. Implements a
|
| 413 |
+
/// single codepath, regardless of whether the output op requires addend data to be loaded
|
| 414 |
+
CUTLASS_DEVICE
|
| 415 |
+
void unified(
|
| 416 |
+
OutputOp const &output_op, ///< Output operator
|
| 417 |
+
OutputTileIterator destination_iterator, ///< Tile iterator for destination
|
| 418 |
+
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
|
| 419 |
+
OutputTileIterator source_iterator ) ///< Tile iterator for addend source
|
| 420 |
+
{
|
| 421 |
+
if (!output_op.is_source_needed())
|
| 422 |
+
{
|
| 423 |
+
source_iterator.clear_mask();
|
| 424 |
+
__syncthreads(); // Dummy (CUDA 11.0)
|
| 425 |
+
}
|
| 426 |
+
|
| 427 |
+
operator()(output_op, destination_iterator, accumulators, SourceAspectNeeded(source_iterator));
|
| 428 |
+
}
|
| 429 |
+
|
| 430 |
+
template<class Seq>
|
| 431 |
+
struct acc2smem;
|
| 432 |
+
|
| 433 |
+
template <size_t... Seq>
|
| 434 |
+
struct acc2smem<cutlass::index_sequence<Seq...>> {
|
| 435 |
+
template<int Advance>
|
| 436 |
+
CUTLASS_DEVICE
|
| 437 |
+
static void helper(AccumulatorFragmentIterator accum_fragment_iterator,
|
| 438 |
+
WarpTileIterator &warp_tile_iterator) {
|
| 439 |
+
CUTLASS_PRAGMA_UNROLL
|
| 440 |
+
for (int i = 0; i < Advance; i++) {
|
| 441 |
+
++accum_fragment_iterator;
|
| 442 |
+
}
|
| 443 |
+
|
| 444 |
+
typename AccumulatorFragmentIterator::Fragment accum_fragment;
|
| 445 |
+
|
| 446 |
+
accum_fragment_iterator.load(accum_fragment);
|
| 447 |
+
++accum_fragment_iterator;
|
| 448 |
+
warp_tile_iterator.store(accum_fragment);
|
| 449 |
+
}
|
| 450 |
+
|
| 451 |
+
CUTLASS_DEVICE
|
| 452 |
+
static void push(size_t pos,
|
| 453 |
+
AccumulatorFragmentIterator const &iterator_begin,
|
| 454 |
+
WarpTileIterator &warp_tile_iterator) {
|
| 455 |
+
int dummy[] = {(pos == Seq) && (helper<Seq>(iterator_begin, warp_tile_iterator), 0)...};
|
| 456 |
+
}
|
| 457 |
+
};
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
/// Streams the result to global memory
|
| 461 |
+
template <typename SourceAspect>
|
| 462 |
+
CUTLASS_DEVICE
|
| 463 |
+
void operator()(
|
| 464 |
+
OutputOp const &output_op, ///< Output operator
|
| 465 |
+
OutputTileIterator destination_iterator, ///< Tile iterator for destination
|
| 466 |
+
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
|
| 467 |
+
SourceAspect source)
|
| 468 |
+
{
|
| 469 |
+
// Iterator over warp-level accumulator fragment
|
| 470 |
+
AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
|
| 471 |
+
|
| 472 |
+
//
|
| 473 |
+
// Iterate over accumulator tile
|
| 474 |
+
//
|
| 475 |
+
|
| 476 |
+
#ifdef __clang__
|
| 477 |
+
#pragma clang diagnostic push
|
| 478 |
+
#pragma clang diagnostic ignored "-Wcuda-compat"
|
| 479 |
+
// Turn off clangs warning about loop unroll argument using parens.
|
| 480 |
+
#endif
|
| 481 |
+
|
| 482 |
+
#pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1)
|
| 483 |
+
for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter)
|
| 484 |
+
{
|
| 485 |
+
//
|
| 486 |
+
// Load the source
|
| 487 |
+
//
|
| 488 |
+
|
| 489 |
+
source.load();
|
| 490 |
+
//
|
| 491 |
+
// Convert and store fragment
|
| 492 |
+
//
|
| 493 |
+
|
| 494 |
+
__syncthreads();
|
| 495 |
+
|
| 496 |
+
acc2smem<cutlass::make_index_sequence<OutputTileIterator::kIterations>>::push(
|
| 497 |
+
iter, accum_fragment_iterator, this->warp_tile_iterator_);
|
| 498 |
+
|
| 499 |
+
__syncthreads();
|
| 500 |
+
|
| 501 |
+
//
|
| 502 |
+
// Load fragments from shared memory
|
| 503 |
+
//
|
| 504 |
+
|
| 505 |
+
typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];
|
| 506 |
+
shared_load_iterator_.load(aligned_accum_fragment[0]);
|
| 507 |
+
|
| 508 |
+
if (kPartitionsK > 1) {
|
| 509 |
+
plus <typename SharedLoadIterator::Fragment> add_fragments;
|
| 510 |
+
|
| 511 |
+
CUTLASS_PRAGMA_UNROLL
|
| 512 |
+
for ( int i = 1; i < kPartitionsK; ++i) {
|
| 513 |
+
shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
|
| 514 |
+
shared_load_iterator_.load(aligned_accum_fragment[i]);
|
| 515 |
+
aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
|
| 516 |
+
}
|
| 517 |
+
|
| 518 |
+
shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset);
|
| 519 |
+
}
|
| 520 |
+
|
| 521 |
+
//
|
| 522 |
+
// Compute the output result
|
| 523 |
+
//
|
| 524 |
+
|
| 525 |
+
typename OutputTileIterator::Fragment output_fragment;
|
| 526 |
+
source.apply_output_operator(output_fragment, output_op, aligned_accum_fragment[0]);
|
| 527 |
+
|
| 528 |
+
//
|
| 529 |
+
// Store the final result
|
| 530 |
+
//
|
| 531 |
+
|
| 532 |
+
destination_iterator.store(output_fragment);
|
| 533 |
+
++destination_iterator;
|
| 534 |
+
}
|
| 535 |
+
|
| 536 |
+
#ifdef __clang__
|
| 537 |
+
#pragma clang diagnostic pop
|
| 538 |
+
#endif
|
| 539 |
+
}
|
| 540 |
+
};
|
| 541 |
+
|
| 542 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 543 |
+
|
| 544 |
+
} // namespace threadblock
|
| 545 |
+
} // namespace epilogue
|
| 546 |
+
} // namespace cutlass
|
| 547 |
+
|
| 548 |
+
////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/epilogue_base.h
ADDED
|
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
|
| 33 |
+
|
| 34 |
+
The epilogue rearranges the result of a matrix product through shared memory to match canonical
|
| 35 |
+
tensor layouts in global memory. Epilogues support conversion and reduction operations.
|
| 36 |
+
|
| 37 |
+
*/
|
| 38 |
+
|
| 39 |
+
#pragma once
|
| 40 |
+
#include "cutlass/cutlass.h"
|
| 41 |
+
#if !defined(__CUDACC_RTC__)
|
| 42 |
+
#include <type_traits>
|
| 43 |
+
#include <utility>
|
| 44 |
+
#endif
|
| 45 |
+
#include CUDA_STD_HEADER(cassert)
|
| 46 |
+
|
| 47 |
+
#include "cutlass/matrix_shape.h"
|
| 48 |
+
#include "cutlass/numeric_types.h"
|
| 49 |
+
#include "cutlass/array.h"
|
| 50 |
+
#include "cutlass/layout/vector.h"
|
| 51 |
+
#include "cutlass/layout/tensor.h"
|
| 52 |
+
#include "cutlass/tensor_coord.h"
|
| 53 |
+
#include "cutlass/aligned_buffer.h"
|
| 54 |
+
|
| 55 |
+
#include "cutlass/gemm/gemm.h"
|
| 56 |
+
|
| 57 |
+
#include "cutlass/transform/pitch_linear_thread_map.h"
|
| 58 |
+
|
| 59 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 60 |
+
|
| 61 |
+
namespace cutlass {
|
| 62 |
+
namespace epilogue {
|
| 63 |
+
namespace threadblock {
|
| 64 |
+
|
| 65 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 66 |
+
|
| 67 |
+
//
|
| 68 |
+
// This is used for metaprogramming epilogue functors. If they define
|
| 69 |
+
// `static bool const kIsHeavy = true;`, then the epilogue functor itself is
|
| 70 |
+
// not inlined. This results in smaller code and is advantageous if the epilogue
|
| 71 |
+
// functor consists of many instructions.
|
| 72 |
+
//
|
| 73 |
+
// If the epilogue functor does not define `kIsHeavy` or if it is `false`, then
|
| 74 |
+
// the behavior from CUTLASS 2.5 and before is retained. The epilogue is fully
|
| 75 |
+
// unrolled and inlined.
|
| 76 |
+
//
|
| 77 |
+
|
| 78 |
+
// Maps any well-formed type expression to `void`.
// This is a hand-rolled equivalent of std::void_t, used by the SFINAE-based
// member-detection trait below (IsEpilogueFunctorHeavy).
template<class>
struct TypeSink {
  using type = void;
};
|
| 80 |
+
|
| 81 |
+
/// Convenience alias for TypeSink: yields `void` for any well-formed type expression T,
/// enabling the detection-idiom specialization below.
template<class T> using TypeSinkT = typename TypeSink<T>::type;
|
| 82 |
+
|
| 83 |
+
/// Primary template: an epilogue functor that does not declare `kIsHeavy` is
/// considered "light" (value == false), retaining the fully-unrolled/inlined
/// behavior described in the comment block above.
template<class T, class=void> struct IsEpilogueFunctorHeavy {
  static bool const value = false;
};
|
| 86 |
+
|
| 87 |
+
/// Specialization selected via SFINAE when T declares a `kIsHeavy` member;
/// in that case the functor's own declaration decides whether it is heavy
/// (and hence whether the epilogue body is kept out-of-line).
template<class T> struct IsEpilogueFunctorHeavy<T, TypeSinkT< decltype( T::kIsHeavy ) > > {
  static bool const value = T::kIsHeavy;
};
|
| 90 |
+
|
| 91 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 92 |
+
|
| 93 |
+
/// Base class for epilogues defining warp-level
|
| 94 |
+
template <
  typename Shape_,                          ///< Shape of threadblock tile (concept: GemmShape)
  typename WarpShape_,                      ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp)
  int PartitionsK,                          ///< Number of partitions of the K dimension
  typename AccumulatorFragmentIterator_,    ///< Fragment iterator selecting accumulators
  typename WarpTileIterator_,               ///< Warp-scoped tile iterator writing accumulators to SMEM
  typename Padding_,                        ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape)
  int FragmentsPerIteration = 1             ///< Granularity of one epilogue 'iteration' (fragments staged per SMEM round-trip)
>
class EpilogueBase {
public:

  using Shape = Shape_;
  using WarpShape = WarpShape_;
  static int const kPartitionsK = PartitionsK;
  using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
  using WarpTileIterator = WarpTileIterator_;
  using Padding = Padding_;

  /// Output layout is always row-major
  using Layout = layout::RowMajor;

  /// The complete warp-level accumulator tile
  using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile;

  /// Accumulator element
  using ElementAccumulator = typename AccumulatorTile::Element;

  /// Number of warps, derived by tiling the threadblock shape with the warp shape.
  /// NOTE(review): assumes Shape is evenly divisible by WarpShape in M and N — confirm upstream.
  using WarpCount = gemm::GemmShape<
    Shape::kM / WarpShape::kM,
    Shape::kN / WarpShape::kN,
    kPartitionsK
  >;

  /// Use this to control the granularity of one epilogue 'iteration'
  static int const kFragmentsPerIteration = FragmentsPerIteration;

public:

  /// Shared storage allocation needed by the epilogue
  struct SharedStorage {

    //
    // Type definitions
    //

    /// Element type of shared memory
    using Element = typename WarpTileIterator::Element;

    /// Tensor reference to shared memory allocation
    using TensorRef = typename WarpTileIterator::TensorRef;

    /// Layout of shared memory allocation
    using Layout = typename WarpTileIterator::Layout;

    /// Logical shape of the shared memory tile written to by all warps.
    /// Warps from different K partitions are stacked along the row (M) dimension.
    using Shape = MatrixShape<
      WarpCount::kM * WarpTileIterator::Shape::kRow * WarpCount::kK,
      WarpCount::kN * WarpTileIterator::Shape::kColumn
    >;

    /// Shape of the shared memory allocation for the epilogue.
    /// Padding is added per row/column to avoid bank conflicts; rows are
    /// replicated once per epilogue iteration fragment.
    using StorageShape = MatrixShape<
      (Shape::kRow + Padding::kRow) * kFragmentsPerIteration,
      Shape::kColumn + Padding::kColumn
    >;

    //
    // Data members
    //

    AlignedBuffer<Element, StorageShape::kCount> storage;

    //
    // Methods
    //

    /// Returns a pointer to the shared memory buffer
    CUTLASS_DEVICE
    Element *data() {
      return storage.data();
    }

    /// Returns a tensor reference to the shared memory buffer
    CUTLASS_DEVICE
    TensorRef reference() {
      return TensorRef(
        storage.data(),
        Layout::packed({StorageShape::kRow, StorageShape::kColumn}));
    }
  };

protected:

  //
  // Data members
  //

  /// Reference to the threadblock's shared storage (held, not owned)
  SharedStorage &shared_storage_;

  /// Stores a warp's fragment of accumulators to SMEM
  WarpTileIterator warp_tile_iterator_;

public:

  /// Constructor: positions this warp's tile iterator at the warp's location
  /// within the shared-memory staging tile.
  CUTLASS_DEVICE
  EpilogueBase(
    SharedStorage &shared_storage,    ///< Shared storage object
    int thread_idx,                   ///< ID of a thread within the threadblock
    int warp_idx,                     ///< ID of warp within threadblock
    int lane_idx                      ///< Id of thread within warp
  ):
    shared_storage_(shared_storage),
    warp_tile_iterator_(shared_storage.reference(), lane_idx) {

    // Compute warp location within threadblock tile by mapping the warp_id to three coordinates:
    //
    //   _m: the warp's position within the threadblock along the M dimension
    //   _n: the warp's position within the threadblock along the N dimension
    //   _k: the warp's position within the threadblock along the K dimension
    //
    // Warps are laid out M-fastest within an (M x N) plane, with planes stacked along K.

    int warp_k = warp_idx / (WarpCount::kM * WarpCount::kN);
    int warp_mn = warp_idx % (WarpCount::kM * WarpCount::kN);
    int warp_m = warp_mn % WarpCount::kM;
    int warp_n = warp_mn / WarpCount::kM;

    // K partitions are stacked along the row dimension of the shared tile
    // (matching SharedStorage::Shape above), hence the kM scaling of warp_k.
    MatrixCoord warp_offset{warp_k * WarpCount::kM + warp_m, warp_n};

    warp_tile_iterator_.add_tile_offset(warp_offset);
  }
};
|
| 227 |
+
|
| 228 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 229 |
+
|
| 230 |
+
} // namespace threadblock
|
| 231 |
+
} // namespace epilogue
|
| 232 |
+
} // namespace cutlass
|
| 233 |
+
|
| 234 |
+
////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/epilogue_base_streamk.h
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Basic subset of epilogue functionality for supporting StreamK decompositions
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
#pragma once
|
| 37 |
+
|
| 38 |
+
#include "cutlass/cutlass.h"
|
| 39 |
+
#include "cutlass/functional.h"
|
| 40 |
+
#include "cutlass/block_striped.h"
|
| 41 |
+
|
| 42 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 43 |
+
|
| 44 |
+
namespace cutlass {
|
| 45 |
+
namespace epilogue {
|
| 46 |
+
namespace threadblock {
|
| 47 |
+
|
| 48 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
/// StreamK epilogue functionality for cross-block accumulator fragment reduction
|
| 52 |
+
template <
  typename Shape,                        ///< Shape of threadblock tile (concept: GemmShape)
  int PartitionsK,                       ///< Number of partitions of the K dimension
  typename WarpMmaOperator,              ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp)
  typename AccumulatorFragmentIterator>  ///< Iterator for enumerating fragments within the per-thread tile of raw accumulators
class EpilogueBaseStreamK
{

protected:

  /// The per-thread tile of raw accumulators
  using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile;

  /// Number of warps
  using WarpCount = gemm::GemmShape<
      Shape::kM / WarpMmaOperator::Shape::kM,
      Shape::kN / WarpMmaOperator::Shape::kN,
      PartitionsK>;

  /// Number of threads per block (32 threads per warp)
  static int const kBlockThreads = 32 * WarpCount::kCount;

  /// Numerical accumulation element type
  using ElementAccumulator = typename WarpMmaOperator::ElementC;

  /// Fragment type used by the accumulator tile's fragment iterator
  using AccumulatorFragment = typename AccumulatorFragmentIterator::Fragment;

public:

  /// Number of AccumulatorTile fragments per thread
  static int const kAccumulatorFragments = AccumulatorFragmentIterator::Policy::kIterations;

protected:

  /// Number of AccumulatorTile fragments per block output tile
  static int const kOutputTileFragments = kBlockThreads * kAccumulatorFragments;

  /// Block-striped transfer utility for sharing AccumulatorFragment
  using BlockStripedT = BlockStriped<kBlockThreads, AccumulatorFragment>;

  /// AccumulatorFragment stride in the shared workspace between different peer blocks
  /// (each thread block can share accumulators for up to two block output tiles:
  /// a "started" tile followed by a "non-started" tile)
  static const int kPeerFragmentStride = kOutputTileFragments * 2;

public:

  /// Workspace bytes per thread block
  static size_t const kWorkspaceBytesPerBlock = sizeof(AccumulatorFragment) * kPeerFragmentStride;

public:

  /// Thread index in the threadblock
  int thread_idx;

public:

  /// Constructor
  CUTLASS_DEVICE
  EpilogueBaseStreamK(
      int thread_idx)   ///< ID of a thread within the threadblock
  :
      thread_idx(thread_idx)
  {}


  /// Aggregates the accumulator sets shared by peer blocks in the global workspace,
  /// summing them into accum_fragment.
  CUTLASS_DEVICE
  void reduce(
      AccumulatorFragment &accum_fragment,  ///< [out] sum of all shared accumulator fragments for these peer partials
      int peer_idx_begin,                   ///< First peer block index (inclusive)
      int peer_idx_end,                     ///< Last peer block index (exclusive)
      int reduce_fragment_idx,              ///< Which of the kAccumulatorFragments to reduce
      void *workspace_ptr)                  ///< Global workspace holding the peers' shared fragments
  {
    plus<AccumulatorFragment> add_fragments;

    AccumulatorFragment *fragment_workspace = reinterpret_cast<AccumulatorFragment *>(workspace_ptr);

    // Offset of the requested fragment within the first peer's "started" set
    int fragment_offset = (peer_idx_begin * kPeerFragmentStride) + (reduce_fragment_idx * kBlockThreads);

    // Load first peer fragment
    BlockStripedT::load(accum_fragment, fragment_workspace + fragment_offset, this->thread_idx);

    fragment_offset += kPeerFragmentStride;    // Move to next peer
    fragment_offset += kOutputTileFragments;   // Move to the set of fragments for this peer's "non-started" output tile

    // Reduce fragments from additional peers
    #pragma unroll 2
    for (; fragment_offset < peer_idx_end * kPeerFragmentStride; fragment_offset += kPeerFragmentStride)
    {
      // Load peer fragment
      AccumulatorFragment addend_fragment;
      BlockStripedT::load(addend_fragment, fragment_workspace + fragment_offset, this->thread_idx);

      // Add peer fragment
      accum_fragment = add_fragments(accum_fragment, addend_fragment);
    }
  }


  /// Shares the accumulator set with peers in the global workspace
  CUTLASS_DEVICE
  void share(
      int peer_idx,                       ///< This block's peer index
      void *workspace_ptr,                ///< Global workspace to publish fragments into
      AccumulatorTile const &accumulators,///< Raw accumulator tile to share
      bool started_tile)                  ///< Whether this thread block computed the first work volume for the current output tile
  {
    AccumulatorFragment *fragment_workspace = reinterpret_cast<AccumulatorFragment *>(workspace_ptr);

    int fragment_offset = peer_idx * kPeerFragmentStride;

    if (!started_tile) {
      // Move to the set of fragments for the "non-started" output tile
      fragment_offset += kOutputTileFragments;
    }

    AccumulatorFragmentIterator accum_fragment_iterator(accumulators);

    // Convert raw accumulator tile to fragments and store
    CUTLASS_PRAGMA_UNROLL
    for (int iter = 0; iter < kAccumulatorFragments; ++iter)
    {
      // Acquire reordered accumulator fragment
      AccumulatorFragment accum_fragment;
      accum_fragment_iterator.load(accum_fragment);
      ++accum_fragment_iterator;

      // Store accumulator fragment (block-striped across the kBlockThreads threads)
      BlockStripedT::store(fragment_workspace + fragment_offset, accum_fragment, this->thread_idx);

      fragment_offset += kBlockThreads;
    }
  }

};
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 192 |
+
|
| 193 |
+
} // namespace threadblock
|
| 194 |
+
} // namespace epilogue
|
| 195 |
+
} // namespace cutlass
|
| 196 |
+
|
| 197 |
+
////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/epilogue_depthwise.h
ADDED
|
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Epilogue for Depthwise convolution
|
| 33 |
+
|
| 34 |
+
The epilogue rearranges the result of a matrix product through shared memory to match canonical
|
| 35 |
+
tensor layouts in global memory. Epilogues support conversion and reduction operations.
|
| 36 |
+
|
| 37 |
+
*/
|
| 38 |
+
|
| 39 |
+
#pragma once
|
| 40 |
+
|
| 41 |
+
#include "cutlass/array.h"
|
| 42 |
+
#include "cutlass/cutlass.h"
|
| 43 |
+
#include "cutlass/epilogue/thread/conversion_op.h"
|
| 44 |
+
#include "cutlass/epilogue/thread/linear_combination.h"
|
| 45 |
+
#include "cutlass/epilogue/thread/reduction_op.h"
|
| 46 |
+
#include "cutlass/gemm/gemm.h"
|
| 47 |
+
#include "cutlass/numeric_types.h"
|
| 48 |
+
|
| 49 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 50 |
+
|
| 51 |
+
namespace cutlass {
|
| 52 |
+
namespace epilogue {
|
| 53 |
+
namespace threadblock {
|
| 54 |
+
|
| 55 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 56 |
+
|
| 57 |
+
/// Epilogue operator
|
| 58 |
+
template <typename Shape_, ///< Shape of threadblock tile (concept: GemmShape)
|
| 59 |
+
typename ThreadOutputShape_, /// Size of the matrix to load (concept: TensorNHWC)
|
| 60 |
+
typename ThreadBlockOutputShape_, /// Size of the matrix to load (concept: TensorNHWC)
|
| 61 |
+
typename WarpMmaOperator_, ///< Warp-level MMA operator (concept:
|
| 62 |
+
///< gemm::warp::MmaTensorOp)
|
| 63 |
+
typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors
|
| 64 |
+
typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators
|
| 65 |
+
typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM
|
| 66 |
+
typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM
|
| 67 |
+
typename OutputOp_, ///< Output operator
|
| 68 |
+
typename Padding_ ///< Padding added to SMEM allocation to avoid bank conflicts (concept:
|
| 69 |
+
///< MatrixShape)
|
| 70 |
+
>
|
| 71 |
+
class EpilogueDepthwise {
|
| 72 |
+
public:
|
| 73 |
+
using Shape = Shape_;
|
| 74 |
+
using WarpShape = typename WarpMmaOperator_::Shape;
|
| 75 |
+
using ThreadOutputShape = ThreadOutputShape_;
|
| 76 |
+
using ThreadBlockOutputShape = ThreadBlockOutputShape_;
|
| 77 |
+
using WarpMmaOperator = WarpMmaOperator_;
|
| 78 |
+
using OutputTileIterator = OutputTileIterator_;
|
| 79 |
+
using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
|
| 80 |
+
using WarpTileIterator = WarpTileIterator_;
|
| 81 |
+
using SharedLoadIterator = SharedLoadIterator_;
|
| 82 |
+
using OutputOp = OutputOp_;
|
| 83 |
+
using Padding = Padding_;
|
| 84 |
+
|
| 85 |
+
using Layout = layout::RowMajor;
|
| 86 |
+
using LongIndex = typename Layout::LongIndex;
|
| 87 |
+
|
| 88 |
+
/// The complete warp-level accumulator tile
|
| 89 |
+
using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile;
|
| 90 |
+
|
| 91 |
+
/// Accumulator element
|
| 92 |
+
using ElementAccumulator = typename WarpTileIterator::Element;
|
| 93 |
+
|
| 94 |
+
/// Output element
|
| 95 |
+
using ElementOutput = typename OutputTileIterator::Element;
|
| 96 |
+
|
| 97 |
+
/// Output access size
|
| 98 |
+
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
|
| 99 |
+
|
| 100 |
+
/// Tensor reference to destination tensor
|
| 101 |
+
using TensorRef = typename OutputTileIterator::TensorRef;
|
| 102 |
+
|
| 103 |
+
/// Tensor reference to sync tensor
|
| 104 |
+
using SyncTensorRef = typename cutlass::TensorRef<int, cutlass::layout::PackedVectorLayout>;
|
| 105 |
+
|
| 106 |
+
/// Const tensor reference to source tensor
|
| 107 |
+
using ConstTensorRef = typename OutputTileIterator::ConstTensorRef;
|
| 108 |
+
|
| 109 |
+
/// Array type used to output
|
| 110 |
+
using OutputAccessType =
|
| 111 |
+
Array<typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
|
| 112 |
+
|
| 113 |
+
/// Array type used by output functor
|
| 114 |
+
using AccumulatorAccessType =
|
| 115 |
+
Array<typename WarpTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
|
| 116 |
+
|
| 117 |
+
/// Number of warps
|
| 118 |
+
using WarpCount =
|
| 119 |
+
gemm::GemmShape<Shape::kM / WarpShape::kM, Shape::kN / WarpShape::kN>;
|
| 120 |
+
|
| 121 |
+
public:
|
| 122 |
+
static_assert(SharedLoadIterator::Fragment::kElements ==
|
| 123 |
+
OutputTileIterator::Fragment::kElements,
|
| 124 |
+
"Mismatch between shared load iterator and output tile iterator.");
|
| 125 |
+
|
| 126 |
+
static_assert(OutputTileIterator::kElementsPerAccess,
|
| 127 |
+
"OutputTileIterator::kElementsPerAccess must not be zero.");
|
| 128 |
+
|
| 129 |
+
static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess),
|
| 130 |
+
"Divisibility");
|
| 131 |
+
|
| 132 |
+
/// Shared storage allocation needed by the epilogue
|
| 133 |
+
struct SharedStorage {
|
| 134 |
+
//
|
| 135 |
+
// Type definitions
|
| 136 |
+
//
|
| 137 |
+
|
| 138 |
+
/// Element type of shared memory
|
| 139 |
+
using Element = typename WarpTileIterator::Element;
|
| 140 |
+
|
| 141 |
+
/// Tensor reference to shared memory allocation
|
| 142 |
+
using TensorRef = typename WarpTileIterator::TensorRef;
|
| 143 |
+
|
| 144 |
+
/// Layout of shared memory allocation
|
| 145 |
+
using Layout = typename WarpTileIterator::Layout;
|
| 146 |
+
|
| 147 |
+
/// Logical shape of the shared memory tile written to by all warps.
|
| 148 |
+
using Shape = MatrixShape<ThreadBlockOutputShape::kNHW, ThreadBlockOutputShape::kC>;
|
| 149 |
+
|
| 150 |
+
/// Shape of the shared memory allocation for the epilogue
|
| 151 |
+
using StorageShape = MatrixShape<Shape::kRow, Shape::kColumn>;
|
| 152 |
+
|
| 153 |
+
//
|
| 154 |
+
// Data members
|
| 155 |
+
//
|
| 156 |
+
|
| 157 |
+
AlignedBuffer<Element, StorageShape::kCount> storage;
|
| 158 |
+
|
| 159 |
+
//
|
| 160 |
+
// Methods
|
| 161 |
+
//
|
| 162 |
+
|
| 163 |
+
/// Returns a pointer to the shared memory buffer
|
| 164 |
+
CUTLASS_DEVICE
|
| 165 |
+
Element *data() { return storage.data(); }
|
| 166 |
+
|
| 167 |
+
/// Returns a tensor reference to the shared memory buffer
|
| 168 |
+
CUTLASS_DEVICE
|
| 169 |
+
TensorRef reference() {
|
| 170 |
+
return TensorRef(storage.data(), Layout::packed({StorageShape::kRow, StorageShape::kColumn}));
|
| 171 |
+
}
|
| 172 |
+
};
|
| 173 |
+
|
| 174 |
+
private:
|
| 175 |
+
/// Loads fragment from shared memory aligned with output tensor
|
| 176 |
+
SharedLoadIterator shared_load_iterator_;
|
| 177 |
+
|
| 178 |
+
/// Stores a warp's fragment of accumulators to SMEM
|
| 179 |
+
WarpTileIterator warp_tile_iterator_;
|
| 180 |
+
|
| 181 |
+
LongIndex warp_offset;
|
| 182 |
+
int thread_idx;
|
| 183 |
+
int warp_idx;
|
| 184 |
+
int lane_idx;
|
| 185 |
+
int warp_m, warp_n; // warp coordinates within a cta
|
| 186 |
+
int tid_m, tid_n; // thread coordinates within a warp
|
| 187 |
+
|
| 188 |
+
public:
|
| 189 |
+
/// Constructor
|
| 190 |
+
CUTLASS_DEVICE
|
| 191 |
+
EpilogueDepthwise(SharedStorage &shared_storage, ///< Shared storage object
|
| 192 |
+
int thread_idx_, ///< ID of a thread within the threadblock
|
| 193 |
+
int warp_idx_, ///< ID of warp within threadblock
|
| 194 |
+
int lane_idx_ ///< Id of thread within warp
|
| 195 |
+
)
|
| 196 |
+
: thread_idx(thread_idx_),
|
| 197 |
+
warp_idx(warp_idx_),
|
| 198 |
+
lane_idx(lane_idx_),
|
| 199 |
+
shared_load_iterator_(shared_storage.reference(), thread_idx_),
|
| 200 |
+
warp_tile_iterator_(shared_storage.reference(), thread_idx_, lane_idx_) {}
|
| 201 |
+
|
| 202 |
+
/// Streams the result to global memory
|
| 203 |
+
CUTLASS_DEVICE
|
| 204 |
+
void operator()(OutputOp const &output_op, ///< Output operator
|
| 205 |
+
OutputTileIterator destination_iterator, ///< Tile iterator for destination
|
| 206 |
+
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
|
| 207 |
+
OutputTileIterator source_iterator, ///< Threadblock tile coordinate in GEMM (in
|
| 208 |
+
///< units of threadblock tiles)
|
| 209 |
+
const int smem_base_offset) { ///< SMEM base offset for epilogue operation
|
| 210 |
+
// initiate the smem base offset for different output tile.
|
| 211 |
+
warp_tile_iterator_.set_smem_base_address(smem_base_offset);
|
| 212 |
+
|
| 213 |
+
shared_load_iterator_.set_smem_base_address(smem_base_offset);
|
| 214 |
+
|
| 215 |
+
if (!output_op.is_source_needed()) {
|
| 216 |
+
compute_source_not_needed_(output_op, destination_iterator, accumulators);
|
| 217 |
+
} else {
|
| 218 |
+
compute_source_needed_(output_op, destination_iterator, accumulators, source_iterator);
|
| 219 |
+
}
|
| 220 |
+
}
|
| 221 |
+
|
| 222 |
+
private:
|
| 223 |
+
/// Streams the result to global memory
|
| 224 |
+
CUTLASS_DEVICE
|
| 225 |
+
void compute_source_needed_(
|
| 226 |
+
OutputOp const &output_op, ///< Output operator
|
| 227 |
+
OutputTileIterator destination_iterator, ///< Tile iterator for destination
|
| 228 |
+
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
|
| 229 |
+
OutputTileIterator source_iterator) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
|
| 230 |
+
|
| 231 |
+
typename OutputTileIterator::Fragment source_fragment;
|
| 232 |
+
|
| 233 |
+
source_fragment.clear();
|
| 234 |
+
|
| 235 |
+
source_iterator.load(source_fragment);
|
| 236 |
+
|
| 237 |
+
// store to smem
|
| 238 |
+
warp_tile_iterator_.store(accumulators);
|
| 239 |
+
|
| 240 |
+
__syncthreads();
|
| 241 |
+
|
| 242 |
+
typename SharedLoadIterator::Fragment aligned_accum_fragment;
|
| 243 |
+
|
| 244 |
+
// load from smem
|
| 245 |
+
shared_load_iterator_.load(aligned_accum_fragment);
|
| 246 |
+
|
| 247 |
+
typename OutputTileIterator::Fragment output_fragment;
|
| 248 |
+
|
| 249 |
+
apply_output_operator_(output_fragment, output_op, aligned_accum_fragment, source_fragment);
|
| 250 |
+
|
| 251 |
+
// Store to GMEM
|
| 252 |
+
destination_iterator.store(output_fragment);
|
| 253 |
+
}
|
| 254 |
+
|
| 255 |
+
/// Streams the result to global memory
|
| 256 |
+
CUTLASS_DEVICE
|
| 257 |
+
void compute_source_not_needed_(
|
| 258 |
+
OutputOp const &output_op, ///< Output operator
|
| 259 |
+
OutputTileIterator destination_iterator, ///< Tile iterator for destination
|
| 260 |
+
AccumulatorTile const &accumulators) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
|
| 261 |
+
|
| 262 |
+
// store to smem
|
| 263 |
+
warp_tile_iterator_.store(accumulators);
|
| 264 |
+
|
| 265 |
+
__syncthreads();
|
| 266 |
+
|
| 267 |
+
typename SharedLoadIterator::Fragment aligned_accum_fragment;
|
| 268 |
+
|
| 269 |
+
// load from smem
|
| 270 |
+
shared_load_iterator_.load(aligned_accum_fragment);
|
| 271 |
+
|
| 272 |
+
typename OutputTileIterator::Fragment output_fragment;
|
| 273 |
+
|
| 274 |
+
apply_output_operator_source_not_needed_(output_fragment, output_op, aligned_accum_fragment);
|
| 275 |
+
|
| 276 |
+
// Store to GMEM
|
| 277 |
+
destination_iterator.store(output_fragment);
|
| 278 |
+
}
|
| 279 |
+
|
| 280 |
+
/// Helper to invoke the output functor over each vector of output
|
| 281 |
+
CUTLASS_DEVICE
|
| 282 |
+
void apply_output_operator_(
|
| 283 |
+
typename OutputTileIterator::Fragment &output_fragment,
|
| 284 |
+
OutputOp const &output_op, ///< Output operator
|
| 285 |
+
typename SharedLoadIterator::Fragment const &aligned_accum_fragment,
|
| 286 |
+
typename OutputTileIterator::Fragment const &source_fragment) {
|
| 287 |
+
|
| 288 |
+
OutputAccessType *output_frag_ptr =
|
| 289 |
+
reinterpret_cast<OutputAccessType *>(&output_fragment);
|
| 290 |
+
|
| 291 |
+
AccumulatorAccessType const *compute_frag_ptr =
|
| 292 |
+
reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment);
|
| 293 |
+
|
| 294 |
+
OutputAccessType const *source_frag_ptr =
|
| 295 |
+
reinterpret_cast<OutputAccessType const *>(&source_fragment);
|
| 296 |
+
|
| 297 |
+
int const kOutputOpIterations =
|
| 298 |
+
OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
|
| 299 |
+
|
| 300 |
+
CUTLASS_PRAGMA_UNROLL
|
| 301 |
+
for (int i = 0; i < kOutputOpIterations; ++i) {
|
| 302 |
+
// Call the output operator
|
| 303 |
+
output_frag_ptr[i] = output_op(compute_frag_ptr[i], source_frag_ptr[i]);
|
| 304 |
+
}
|
| 305 |
+
}
|
| 306 |
+
|
| 307 |
+
/// Helper to invoke the output functor over each vector of output
|
| 308 |
+
CUTLASS_DEVICE
|
| 309 |
+
void apply_output_operator_source_not_needed_(
|
| 310 |
+
typename OutputTileIterator::Fragment &output_fragment,
|
| 311 |
+
OutputOp const &output_op, ///< Output operator
|
| 312 |
+
typename SharedLoadIterator::Fragment const &aligned_accum_fragment) {
|
| 313 |
+
OutputAccessType *output_frag_ptr = reinterpret_cast<OutputAccessType *>(&output_fragment);
|
| 314 |
+
|
| 315 |
+
AccumulatorAccessType const *compute_frag_ptr =
|
| 316 |
+
reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment);
|
| 317 |
+
|
| 318 |
+
int const kOutputOpIterations =
|
| 319 |
+
OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
|
| 320 |
+
|
| 321 |
+
CUTLASS_PRAGMA_UNROLL
|
| 322 |
+
for (int i = 0; i < kOutputOpIterations; ++i) {
|
| 323 |
+
// Call the output operator
|
| 324 |
+
output_frag_ptr[i] = output_op(compute_frag_ptr[i]);
|
| 325 |
+
}
|
| 326 |
+
}
|
| 327 |
+
};
|
| 328 |
+
|
| 329 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 330 |
+
|
| 331 |
+
} // namespace threadblock
|
| 332 |
+
} // namespace epilogue
|
| 333 |
+
} // namespace cutlass
|
| 334 |
+
|
| 335 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/epilogue_direct_store.h
ADDED
|
@@ -0,0 +1,347 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Epilogue for threadblock scoped GEMMs and convolution using Tensor Ops.
|
| 33 |
+
|
| 34 |
+
The epilogue rearranges the result of a matrix product through shared memory to match canonical
|
| 35 |
+
tensor layouts in global memory. Epilogues support conversion and reduction operations.
|
| 36 |
+
|
| 37 |
+
*/
|
| 38 |
+
|
| 39 |
+
#pragma once
|
| 40 |
+
|
| 41 |
+
#include "cutlass/cutlass.h"
|
| 42 |
+
#include "cutlass/numeric_types.h"
|
| 43 |
+
#include "cutlass/array.h"
|
| 44 |
+
|
| 45 |
+
#include "cutlass/gemm/gemm.h"
|
| 46 |
+
|
| 47 |
+
#include "cutlass/epilogue/thread/linear_combination.h"
|
| 48 |
+
#include "cutlass/epilogue/thread/conversion_op.h"
|
| 49 |
+
#include "cutlass/epilogue/thread/reduction_op.h"
|
| 50 |
+
|
| 51 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 52 |
+
|
| 53 |
+
namespace cutlass {
|
| 54 |
+
namespace epilogue {
|
| 55 |
+
namespace threadblock {
|
| 56 |
+
|
| 57 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 58 |
+
|
| 59 |
+
/// Epilogue operator
|
| 60 |
+
template <
|
| 61 |
+
typename Shape_, ///< Shape of threadblock tile (concept: GemmShape)
|
| 62 |
+
typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp)
|
| 63 |
+
int PartitionsK, ///< Number of partitions of the K dimension
|
| 64 |
+
typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors
|
| 65 |
+
typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators
|
| 66 |
+
typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM
|
| 67 |
+
typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM
|
| 68 |
+
typename OutputOp_ ///< Output operator
|
| 69 |
+
>
|
| 70 |
+
class EpilogueDirectStore {
|
| 71 |
+
public:
|
| 72 |
+
|
| 73 |
+
using Shape = Shape_;
|
| 74 |
+
using WarpMmaOperator = WarpMmaOperator_;
|
| 75 |
+
using WarpShape = typename WarpMmaOperator_::Shape;
|
| 76 |
+
static int const kPartitionsK = PartitionsK;
|
| 77 |
+
using OutputTileIterator = OutputTileIterator_;
|
| 78 |
+
using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
|
| 79 |
+
using WarpTileIterator = WarpTileIterator_;
|
| 80 |
+
using OutputOp = OutputOp_;
|
| 81 |
+
using Padding = MatrixShape<0, 0>;
|
| 82 |
+
|
| 83 |
+
using Layout = layout::RowMajor;
|
| 84 |
+
using LongIndex = typename Layout::LongIndex;
|
| 85 |
+
|
| 86 |
+
/// The complete warp-level accumulator tile
|
| 87 |
+
using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile;
|
| 88 |
+
|
| 89 |
+
/// Accumulator element
|
| 90 |
+
using ElementAccumulator = typename WarpTileIterator::Element;
|
| 91 |
+
|
| 92 |
+
/// Output element
|
| 93 |
+
using ElementOutput = typename OutputTileIterator::Element;
|
| 94 |
+
|
| 95 |
+
/// Output access size
|
| 96 |
+
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
|
| 97 |
+
|
| 98 |
+
/// Tensor reference to destination tensor
|
| 99 |
+
using TensorRef = typename OutputTileIterator::TensorRef;
|
| 100 |
+
|
| 101 |
+
/// Tensor reference to sync tensor
|
| 102 |
+
using SyncTensorRef = typename cutlass::TensorRef<int, cutlass::layout::PackedVectorLayout>;
|
| 103 |
+
|
| 104 |
+
/// Const tensor reference to source tensor
|
| 105 |
+
using ConstTensorRef = typename OutputTileIterator::ConstTensorRef;
|
| 106 |
+
|
| 107 |
+
/// Array type used to output
|
| 108 |
+
using OutputAccessType = Array<
|
| 109 |
+
typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
|
| 110 |
+
|
| 111 |
+
/// Array type used by output functor
|
| 112 |
+
using AccumulatorAccessType = Array<typename WarpTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
|
| 113 |
+
|
| 114 |
+
/// Number of warps
|
| 115 |
+
using WarpCount = gemm::GemmShape<
|
| 116 |
+
Shape::kM / WarpShape::kM,
|
| 117 |
+
Shape::kN / WarpShape::kN,
|
| 118 |
+
kPartitionsK
|
| 119 |
+
>;
|
| 120 |
+
|
| 121 |
+
/// Use this to control the granularity of one epilogue 'iteration'
|
| 122 |
+
static int const kFragmentsPerIteration = 1;
|
| 123 |
+
|
| 124 |
+
static int constexpr kSmemTiles = 1;
|
| 125 |
+
static int constexpr kSmemPointerOffset = 0;
|
| 126 |
+
|
| 127 |
+
/// Shared storage allocation needed by the epilogue
|
| 128 |
+
struct SharedStorage { } ;
|
| 129 |
+
|
| 130 |
+
private:
|
| 131 |
+
|
| 132 |
+
// Assume accumulator tile is multipile interleaved 32x32 tile.
|
| 133 |
+
static int const kElementsPerPartial = 4;
|
| 134 |
+
using EleShapePerPatial = typename platform::conditional<
|
| 135 |
+
platform::is_same<ElementAccumulator, float>::value,
|
| 136 |
+
MatrixShape<2, 2>,
|
| 137 |
+
MatrixShape<1, 4> >::type;
|
| 138 |
+
static int const kElementsPerMma = 8;
|
| 139 |
+
static int const kAccumulatorPatials = 2;
|
| 140 |
+
using QuadShapePerPatialMma = MatrixShape<4, 4>;
|
| 141 |
+
|
| 142 |
+
static_assert(OutputOp::kCount >= 2,
|
| 143 |
+
"The direct store epilogue for Tensor Ops requires the output functor have kCount >= 2.");
|
| 144 |
+
|
| 145 |
+
private:
|
| 146 |
+
|
| 147 |
+
LongIndex warp_offset;
|
| 148 |
+
int thread_idx;
|
| 149 |
+
int warp_idx;
|
| 150 |
+
int lane_idx;
|
| 151 |
+
int warp_m, warp_n; // warp coordinates within a cta
|
| 152 |
+
int tid_m, tid_n; // thread coordinates within a warp
|
| 153 |
+
|
| 154 |
+
public:
|
| 155 |
+
|
| 156 |
+
/// Constructor
|
| 157 |
+
CUTLASS_DEVICE
|
| 158 |
+
EpilogueDirectStore(
|
| 159 |
+
SharedStorage &shared_storage, ///< Shared storage object
|
| 160 |
+
int thread_idx_, ///< ID of a thread within the threadblock
|
| 161 |
+
int warp_idx_, ///< ID of warp within threadblock
|
| 162 |
+
int lane_idx_ ///< Id of thread within warp
|
| 163 |
+
):
|
| 164 |
+
thread_idx(thread_idx_),
|
| 165 |
+
warp_idx(warp_idx_),
|
| 166 |
+
lane_idx(lane_idx_)
|
| 167 |
+
{
|
| 168 |
+
|
| 169 |
+
// warp offsetting calculations
|
| 170 |
+
warp_offset = warp_idx * WarpShape::kM * WarpShape::kN;
|
| 171 |
+
int warp_id_mn = warp_idx % (WarpCount::kM * WarpShape::kN);
|
| 172 |
+
warp_m = warp_id_mn % WarpCount::kM;
|
| 173 |
+
warp_n = warp_id_mn / WarpCount::kM;
|
| 174 |
+
MatrixCoord warp_offset_coord(warp_m*WarpShape::kM, warp_n*WarpShape::kN);
|
| 175 |
+
|
| 176 |
+
// thread offsetting calculations
|
| 177 |
+
int quad = (lane_idx >> 2);
|
| 178 |
+
int lane_in_quad = (lane_idx & 3);
|
| 179 |
+
|
| 180 |
+
// this seems to be te correct layout
|
| 181 |
+
tid_m = quad;
|
| 182 |
+
tid_n = 2 * lane_in_quad;
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
+
/// Streams the result to global memory
|
| 186 |
+
CUTLASS_DEVICE
|
| 187 |
+
void operator()(
|
| 188 |
+
OutputOp const &output_op, ///< Output operator
|
| 189 |
+
OutputTileIterator destination_iterator, ///< Tile iterator for destination
|
| 190 |
+
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
|
| 191 |
+
OutputTileIterator source_iterator) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
|
| 192 |
+
|
| 193 |
+
if (!output_op.is_source_needed()) {
|
| 194 |
+
compute_source_not_needed_(output_op, destination_iterator, accumulators);
|
| 195 |
+
}
|
| 196 |
+
else {
|
| 197 |
+
compute_source_needed_(output_op, destination_iterator, accumulators, source_iterator);
|
| 198 |
+
}
|
| 199 |
+
}
|
| 200 |
+
|
| 201 |
+
private:
|
| 202 |
+
|
| 203 |
+
/// Streams the result to global memory
|
| 204 |
+
CUTLASS_DEVICE
|
| 205 |
+
void compute_source_needed_(
|
| 206 |
+
OutputOp const &output_op, ///< Output operator
|
| 207 |
+
OutputTileIterator destination_iterator, ///< Tile iterator for destination
|
| 208 |
+
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
|
| 209 |
+
OutputTileIterator source_iterator) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
|
| 210 |
+
|
| 211 |
+
const int kAccumBlockN = 2;
|
| 212 |
+
const int kThreadsM = 8;
|
| 213 |
+
const int kThreadsN = 4;
|
| 214 |
+
const int kBlockM = WarpShape::kM / kThreadsM;
|
| 215 |
+
|
| 216 |
+
/// Array type used to output
|
| 217 |
+
using OutputAccessType = AlignedArray<ElementOutput, kAccumBlockN>;
|
| 218 |
+
|
| 219 |
+
/// Array type passed to the output operator - unused elements are optimized away
|
| 220 |
+
using OutputFragmentType = Array<ElementOutput, OutputOp::kCount>;
|
| 221 |
+
|
| 222 |
+
/// Array type used by output functor
|
| 223 |
+
using AccumulatorAccessType = Array<ElementAccumulator, kAccumBlockN>;
|
| 224 |
+
|
| 225 |
+
/// Array type used by output functor
|
| 226 |
+
using AccumulatorFragmentType = Array<ElementAccumulator, OutputOp::kCount>;
|
| 227 |
+
|
| 228 |
+
AccumulatorAccessType const *accumulator_pair = reinterpret_cast<AccumulatorAccessType const *>(&accumulators);
|
| 229 |
+
|
| 230 |
+
CUTLASS_PRAGMA_UNROLL
|
| 231 |
+
for (int accum_m_idx = 0; accum_m_idx < WarpShape::kM / kThreadsM; accum_m_idx++) {
|
| 232 |
+
|
| 233 |
+
int accum_m = kThreadsM * accum_m_idx;
|
| 234 |
+
int mL = destination_iterator.threadblock_offset.row() + WarpShape::kM * warp_m + tid_m + accum_m;
|
| 235 |
+
int nL_base = destination_iterator.threadblock_offset.column() + WarpShape::kN * warp_n + tid_n;
|
| 236 |
+
|
| 237 |
+
ElementOutput *output_ptr = destination_iterator.pointer + mL * destination_iterator.stride;
|
| 238 |
+
ElementOutput *source_ptr = source_iterator.pointer + mL * source_iterator.stride;
|
| 239 |
+
|
| 240 |
+
int const kIterationsN = WarpShape::kN / kThreadsN / kAccumBlockN;
|
| 241 |
+
|
| 242 |
+
CUTLASS_PRAGMA_UNROLL
|
| 243 |
+
for (int accum_n_idx = 0; accum_n_idx < kIterationsN; accum_n_idx++) {
|
| 244 |
+
|
| 245 |
+
int accum_idx = accum_m_idx + kBlockM * accum_n_idx;
|
| 246 |
+
int accum_n = kThreadsM * accum_n_idx;
|
| 247 |
+
|
| 248 |
+
// mL and nL are logical coordinate in 2D mapping of epilogue's 4D output
|
| 249 |
+
int nL = nL_base + accum_n;
|
| 250 |
+
|
| 251 |
+
bool guard = (mL < destination_iterator.extent.row()) && (nL < destination_iterator.extent.column());
|
| 252 |
+
|
| 253 |
+
AccumulatorFragmentType accum_fragment;
|
| 254 |
+
reinterpret_cast<AccumulatorAccessType &>(accum_fragment) = accumulator_pair[accum_idx];
|
| 255 |
+
|
| 256 |
+
OutputFragmentType output_fragment;
|
| 257 |
+
|
| 258 |
+
if(guard) {
|
| 259 |
+
reinterpret_cast<OutputAccessType &>(output_fragment) =
|
| 260 |
+
*reinterpret_cast<OutputAccessType const *>(source_ptr + nL);
|
| 261 |
+
}
|
| 262 |
+
|
| 263 |
+
// Perform output operator
|
| 264 |
+
output_fragment = output_op(accum_fragment, output_fragment);
|
| 265 |
+
|
| 266 |
+
if(guard) {
|
| 267 |
+
// Store
|
| 268 |
+
*reinterpret_cast<OutputAccessType *>(output_ptr + nL) = reinterpret_cast<OutputAccessType const &>(output_fragment);
|
| 269 |
+
}
|
| 270 |
+
}
|
| 271 |
+
}
|
| 272 |
+
}
|
| 273 |
+
|
| 274 |
+
/// Streams the result to global memory
|
| 275 |
+
CUTLASS_DEVICE
|
| 276 |
+
void compute_source_not_needed_(
|
| 277 |
+
OutputOp const &output_op, ///< Output operator
|
| 278 |
+
OutputTileIterator destination_iterator, ///< Tile iterator for destination
|
| 279 |
+
AccumulatorTile const &accumulators) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
|
| 280 |
+
|
| 281 |
+
const int kAccumBlockN = 2;
|
| 282 |
+
const int kThreadsM = 8;
|
| 283 |
+
const int kThreadsN = 4;
|
| 284 |
+
const int kBlockM = WarpShape::kM / kThreadsM;
|
| 285 |
+
|
| 286 |
+
/// Array type used to output
|
| 287 |
+
using OutputAccessType = AlignedArray<ElementOutput, kAccumBlockN>;
|
| 288 |
+
|
| 289 |
+
/// Array type passed to the output operator - unused elements are optimized away
|
| 290 |
+
using OutputFragmentType = Array<ElementOutput, OutputOp::kCount>;
|
| 291 |
+
|
| 292 |
+
/// Array type used by output functor
|
| 293 |
+
using AccumulatorAccessType = Array<ElementAccumulator, kAccumBlockN>;
|
| 294 |
+
|
| 295 |
+
/// Array type used by output functor
|
| 296 |
+
using AccumulatorFragmentType = Array<ElementAccumulator, OutputOp::kCount>;
|
| 297 |
+
|
| 298 |
+
AccumulatorAccessType const *accumulator_pair = reinterpret_cast<AccumulatorAccessType const *>(&accumulators);
|
| 299 |
+
|
| 300 |
+
CUTLASS_PRAGMA_UNROLL
|
| 301 |
+
for (int accum_m_idx = 0; accum_m_idx < WarpShape::kM / kThreadsM; accum_m_idx++) {
|
| 302 |
+
|
| 303 |
+
int accum_m = kThreadsM * accum_m_idx;
|
| 304 |
+
int mL = destination_iterator.threadblock_offset.row() + WarpShape::kM * warp_m + tid_m + accum_m;
|
| 305 |
+
int nL_base = destination_iterator.threadblock_offset.column() + WarpShape::kN * warp_n + tid_n;
|
| 306 |
+
|
| 307 |
+
ElementOutput *output_ptr = destination_iterator.pointer + mL * destination_iterator.stride;
|
| 308 |
+
|
| 309 |
+
int const kIterationsN = WarpShape::kN / kThreadsN / kAccumBlockN;
|
| 310 |
+
|
| 311 |
+
CUTLASS_PRAGMA_UNROLL
|
| 312 |
+
for (int accum_n_idx = 0; accum_n_idx < kIterationsN; accum_n_idx++) {
|
| 313 |
+
|
| 314 |
+
int accum_idx = accum_m_idx + kBlockM * accum_n_idx;
|
| 315 |
+
int accum_n = kThreadsM * accum_n_idx;
|
| 316 |
+
|
| 317 |
+
// mL and nL are logical coordinate in 2D mapping of epilogue's 4D output
|
| 318 |
+
int nL = nL_base + accum_n;
|
| 319 |
+
|
| 320 |
+
bool guard = (mL < destination_iterator.extent.row()) && (nL < destination_iterator.extent.column());
|
| 321 |
+
|
| 322 |
+
AccumulatorFragmentType accum_fragment;
|
| 323 |
+
reinterpret_cast<AccumulatorAccessType &>(accum_fragment) = accumulator_pair[accum_idx];
|
| 324 |
+
|
| 325 |
+
OutputFragmentType output_fragment;
|
| 326 |
+
|
| 327 |
+
// Perform output operator
|
| 328 |
+
output_fragment = output_op(accum_fragment);
|
| 329 |
+
|
| 330 |
+
if(guard) {
|
| 331 |
+
|
| 332 |
+
// Store
|
| 333 |
+
*reinterpret_cast<OutputAccessType *>(output_ptr + nL) =
|
| 334 |
+
reinterpret_cast<OutputAccessType const &>(output_fragment);
|
| 335 |
+
}
|
| 336 |
+
}
|
| 337 |
+
}
|
| 338 |
+
}
|
| 339 |
+
};
|
| 340 |
+
|
| 341 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 342 |
+
|
| 343 |
+
} // namespace threadblock
|
| 344 |
+
} // namespace epilogue
|
| 345 |
+
} // namespace cutlass
|
| 346 |
+
|
| 347 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/epilogue_gemm_k_reduction.h
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
|
| 33 |
+
|
| 34 |
+
The epilogue rearranges the result of a matrix product through shared memory to match canonical
|
| 35 |
+
tensor layouts in global memory. Epilogues support conversion and reduction operations.
|
| 36 |
+
|
| 37 |
+
*/
|
| 38 |
+
|
| 39 |
+
#pragma once
|
| 40 |
+
#include "cutlass/cutlass.h"
|
| 41 |
+
#include CUDA_STD_HEADER(cassert)
|
| 42 |
+
#include "cutlass/numeric_types.h"
|
| 43 |
+
#include "cutlass/array.h"
|
| 44 |
+
#include "cutlass/layout/vector.h"
|
| 45 |
+
#include "cutlass/layout/tensor.h"
|
| 46 |
+
#include "cutlass/tensor_coord.h"
|
| 47 |
+
#include "cutlass/aligned_buffer.h"
|
| 48 |
+
#include "cutlass/functional.h"
|
| 49 |
+
|
| 50 |
+
#include "cutlass/gemm/gemm.h"
|
| 51 |
+
|
| 52 |
+
#include "cutlass/transform/pitch_linear_thread_map.h"
|
| 53 |
+
#include "cutlass/transform/threadblock/regular_tile_iterator.h"
|
| 54 |
+
|
| 55 |
+
#include "cutlass/epilogue/threadblock/epilogue_base.h"
|
| 56 |
+
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
|
| 57 |
+
#include "cutlass/numeric_types.h"
|
| 58 |
+
|
| 59 |
+
namespace cutlass {
|
| 60 |
+
namespace epilogue {
|
| 61 |
+
namespace threadblock {
|
| 62 |
+
|
| 63 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 64 |
+
|
| 65 |
+
/// Epilogue operator
|
| 66 |
+
template <
|
| 67 |
+
typename ElementAccumulator_,
|
| 68 |
+
typename ElementOutput_,
|
| 69 |
+
typename ThreadBlockShape_, ///< Shape of threadblock tile (concept: GemmShape)
|
| 70 |
+
typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp)
|
| 71 |
+
bool ReduceKForA_
|
| 72 |
+
>
|
| 73 |
+
class EpilogueGemmKReduction {
|
| 74 |
+
|
| 75 |
+
public:
|
| 76 |
+
|
| 77 |
+
using ThreadBlockShape = ThreadBlockShape_;
|
| 78 |
+
using WarpMmaOperator = WarpMmaOperator_;
|
| 79 |
+
using WarpShape = typename WarpMmaOperator::Shape;
|
| 80 |
+
using Layout = layout::RowMajor;
|
| 81 |
+
using LongIndex = typename Layout::LongIndex;
|
| 82 |
+
|
| 83 |
+
/// Accumulator element
|
| 84 |
+
using ElementAccumulator = ElementAccumulator_;
|
| 85 |
+
|
| 86 |
+
/// Output element
|
| 87 |
+
using ElementOutput = ElementOutput_;
|
| 88 |
+
|
| 89 |
+
/// Output access size
|
| 90 |
+
static int const kElementsPerAccess = 1;
|
| 91 |
+
|
| 92 |
+
static bool const kReduceKForA = ReduceKForA_;
|
| 93 |
+
|
| 94 |
+
static int const kThreadBlockSize = kReduceKForA ? ThreadBlockShape::kM : ThreadBlockShape::kN;
|
| 95 |
+
|
| 96 |
+
static int const kWarpSize = kReduceKForA ? WarpShape::kM : WarpShape::kN;
|
| 97 |
+
|
| 98 |
+
static int const kIterations = kWarpSize / 8;
|
| 99 |
+
|
| 100 |
+
using FragmentAccumulator = Array<ElementAccumulator, kIterations>;
|
| 101 |
+
|
| 102 |
+
private:
|
| 103 |
+
|
| 104 |
+
int thread_offset_;
|
| 105 |
+
ElementOutput* pointer_;
|
| 106 |
+
int col_;
|
| 107 |
+
public:
|
| 108 |
+
|
| 109 |
+
/// Constructor
|
| 110 |
+
CUTLASS_DEVICE
|
| 111 |
+
EpilogueGemmKReduction(
|
| 112 |
+
int thread_idx, ///< ID of a thread within the threadblock
|
| 113 |
+
int warp_idx, ///< ID of warp within threadblock
|
| 114 |
+
int lane_idx, ///< Id of thread within warp
|
| 115 |
+
int threadblock_offset,
|
| 116 |
+
ElementOutput* pointer
|
| 117 |
+
)
|
| 118 |
+
{
|
| 119 |
+
col_ = lane_idx % 4;
|
| 120 |
+
thread_offset_ = threadblock_offset * kThreadBlockSize
|
| 121 |
+
+ warp_idx * kWarpSize
|
| 122 |
+
+ lane_idx / 4 + col_ * 8;
|
| 123 |
+
|
| 124 |
+
pointer_ = pointer + LongIndex(thread_offset_);
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
/// Streams the result to global memory
|
| 128 |
+
CUTLASS_DEVICE
|
| 129 |
+
void operator()(
|
| 130 |
+
int size,
|
| 131 |
+
FragmentAccumulator &gemm_k_with_reduction_accumulation,
|
| 132 |
+
bool LoadForSerialSplitK
|
| 133 |
+
) {
|
| 134 |
+
bool guard[kIterations / 4];
|
| 135 |
+
|
| 136 |
+
CUTLASS_PRAGMA_UNROLL
|
| 137 |
+
for (int i = 0; i < kIterations / 4; ++i) {
|
| 138 |
+
guard[i] = ((thread_offset_ + i * 32) < size);
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
Array<ElementOutput, kIterations / 4> source;
|
| 142 |
+
source.clear();
|
| 143 |
+
|
| 144 |
+
CUTLASS_PRAGMA_UNROLL
|
| 145 |
+
for (int i = 0; i < kIterations / 4; ++i) {
|
| 146 |
+
ElementOutput *source_ptr = reinterpret_cast<ElementOutput *>(&source);
|
| 147 |
+
cutlass::arch::global_load<ElementOutput, sizeof(ElementOutput)>(
|
| 148 |
+
source_ptr[i],
|
| 149 |
+
(void *)(pointer_ + i * 32),
|
| 150 |
+
guard[i] && LoadForSerialSplitK);
|
| 151 |
+
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
FragmentAccumulator sum = gemm_k_with_reduction_accumulation;
|
| 155 |
+
|
| 156 |
+
CUTLASS_PRAGMA_UNROLL
|
| 157 |
+
for (int i = 0; i < kIterations; ++i) {
|
| 158 |
+
sum[i] += __shfl_xor_sync(0xffffffff, sum[i], 1);
|
| 159 |
+
sum[i] += __shfl_xor_sync(0xffffffff, sum[i], 2);
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
Array<ElementAccumulator, kIterations / 4> intermediate;
|
| 163 |
+
|
| 164 |
+
CUTLASS_PRAGMA_UNROLL
|
| 165 |
+
for (int i = 0; i < kIterations / 4; ++i) {
|
| 166 |
+
if (col_ == 0) {
|
| 167 |
+
intermediate[i] = sum[0 + i * 4];
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
if (col_ == 1) {
|
| 171 |
+
intermediate[i] = sum[1 + i * 4];
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
if (col_ == 2) {
|
| 175 |
+
intermediate[i] = sum[2 + i * 4];
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
if (col_ == 3) {
|
| 179 |
+
intermediate[i] = sum[3 + i * 4];
|
| 180 |
+
}
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
NumericArrayConverter<ElementAccumulator, ElementOutput, kIterations / 4> source_converter;
|
| 184 |
+
Array<ElementAccumulator, kIterations / 4> converted_source = source_converter(source);
|
| 185 |
+
|
| 186 |
+
plus<Array<ElementAccumulator, kIterations / 4>> plus_source;
|
| 187 |
+
intermediate = plus_source(intermediate, converted_source);
|
| 188 |
+
|
| 189 |
+
NumericArrayConverter<ElementOutput, ElementAccumulator, kIterations / 4> converter;
|
| 190 |
+
Array<ElementOutput, kIterations / 4> result = converter(intermediate);
|
| 191 |
+
|
| 192 |
+
CUTLASS_PRAGMA_UNROLL
|
| 193 |
+
for (int i = 0; i < kIterations / 4; ++i) {
|
| 194 |
+
cutlass::arch::global_store<ElementOutput, sizeof(ElementOutput)>(result[i],
|
| 195 |
+
(void *)(pointer_ + i * 32), guard[i]);
|
| 196 |
+
}
|
| 197 |
+
}
|
| 198 |
+
};
|
| 199 |
+
|
| 200 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 201 |
+
|
| 202 |
+
} // namespace threadblock
|
| 203 |
+
} // namespace epilogue
|
| 204 |
+
} // namespace cutlass
|
| 205 |
+
|
| 206 |
+
////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/epilogue_planar_complex.h
ADDED
|
@@ -0,0 +1,401 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
|
| 33 |
+
|
| 34 |
+
The epilogue rearranges the result of a matrix product through shared memory to match canonical
|
| 35 |
+
tensor layouts in global memory. Epilogues support conversion and reduction operations.
|
| 36 |
+
|
| 37 |
+
*/
|
| 38 |
+
|
| 39 |
+
#pragma once
|
| 40 |
+
|
| 41 |
+
#include "cutlass/cutlass.h"
|
| 42 |
+
#include "cutlass/numeric_types.h"
|
| 43 |
+
#include "cutlass/array.h"
|
| 44 |
+
#include "cutlass/array_planar_complex.h"
|
| 45 |
+
#include "cutlass/layout/vector.h"
|
| 46 |
+
#include "cutlass/layout/tensor.h"
|
| 47 |
+
#include "cutlass/tensor_coord.h"
|
| 48 |
+
#include "cutlass/aligned_buffer.h"
|
| 49 |
+
#include "cutlass/functional.h"
|
| 50 |
+
|
| 51 |
+
#include "cutlass/gemm/gemm.h"
|
| 52 |
+
|
| 53 |
+
#include "cutlass/transform/pitch_linear_thread_map.h"
|
| 54 |
+
#include "cutlass/transform/threadblock/regular_tile_iterator.h"
|
| 55 |
+
|
| 56 |
+
#include "cutlass/epilogue/threadblock/epilogue_base.h"
|
| 57 |
+
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
|
| 58 |
+
|
| 59 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 60 |
+
|
| 61 |
+
namespace cutlass {
|
| 62 |
+
namespace epilogue {
|
| 63 |
+
namespace threadblock {
|
| 64 |
+
|
| 65 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 66 |
+
|
| 67 |
+
/// Epilogue operator for planar-complex output representations.
|
| 68 |
+
///
|
| 69 |
+
/// Note, as with most CUTLASS components for planar complex, the template arguments describe
|
| 70 |
+
/// the underlying real data type.
|
| 71 |
+
template <
|
| 72 |
+
typename Shape_, ///< Shape of threadblock tile (concept: GemmShape)
|
| 73 |
+
typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp)
|
| 74 |
+
int PartitionsK, ///< Number of partitions of the K dimension
|
| 75 |
+
typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors
|
| 76 |
+
typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators
|
| 77 |
+
typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM
|
| 78 |
+
typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM
|
| 79 |
+
typename OutputOp_, ///< Output operator
|
| 80 |
+
typename Padding_ ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape)
|
| 81 |
+
>
|
| 82 |
+
class EpiloguePlanarComplex {
|
| 83 |
+
public:
|
| 84 |
+
|
| 85 |
+
using Shape = Shape_;
|
| 86 |
+
using WarpMmaOperator = WarpMmaOperator_;
|
| 87 |
+
static int const kPartitionsK = PartitionsK;
|
| 88 |
+
using OutputTileIterator = OutputTileIterator_;
|
| 89 |
+
using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
|
| 90 |
+
using WarpTileIterator = WarpTileIterator_;
|
| 91 |
+
using SharedLoadIterator = SharedLoadIterator_;
|
| 92 |
+
using OutputOp = OutputOp_;
|
| 93 |
+
using Padding = Padding_;
|
| 94 |
+
|
| 95 |
+
/// Output layout is always row-major
|
| 96 |
+
using Layout = layout::RowMajor;
|
| 97 |
+
using LongIndex = typename Layout::LongIndex;
|
| 98 |
+
|
| 99 |
+
/// The complete warp-level accumulator tile
|
| 100 |
+
using AccumulatorTile = ArrayPlanarComplex<
|
| 101 |
+
typename WarpMmaOperator::FragmentC::Element,
|
| 102 |
+
WarpMmaOperator::FragmentC::kElements
|
| 103 |
+
>;
|
| 104 |
+
|
| 105 |
+
/// Accumulator element
|
| 106 |
+
using ElementAccumulator = typename WarpTileIterator::Element;
|
| 107 |
+
|
| 108 |
+
/// Output element
|
| 109 |
+
using ElementOutput = typename OutputTileIterator::Element;
|
| 110 |
+
|
| 111 |
+
/// Output access size
|
| 112 |
+
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
|
| 113 |
+
|
| 114 |
+
/// Tensor reference to destination tensor
|
| 115 |
+
using TensorRef = typename OutputTileIterator::TensorRef;
|
| 116 |
+
|
| 117 |
+
/// Tensor reference to sync tensor
|
| 118 |
+
using SyncTensorRef = typename cutlass::TensorRef<int, cutlass::layout::PackedVectorLayout>;
|
| 119 |
+
|
| 120 |
+
/// Const tensor reference to source tensor
|
| 121 |
+
using ConstTensorRef = typename OutputTileIterator::ConstTensorRef;
|
| 122 |
+
|
| 123 |
+
/// Array type used to output
|
| 124 |
+
using OutputAccessType = Array<
|
| 125 |
+
typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
|
| 126 |
+
|
| 127 |
+
/// Array type used by output functor
|
| 128 |
+
using AccumulatorAccessType = Array<typename WarpTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
|
| 129 |
+
|
| 130 |
+
/// Shape of each warp-level operation
|
| 131 |
+
using WarpShape = typename WarpMmaOperator::Shape;
|
| 132 |
+
|
| 133 |
+
/// Number of warps
|
| 134 |
+
using WarpCount = gemm::GemmShape<
|
| 135 |
+
Shape::kM / WarpShape::kM,
|
| 136 |
+
Shape::kN / WarpShape::kN,
|
| 137 |
+
kPartitionsK
|
| 138 |
+
>;
|
| 139 |
+
|
| 140 |
+
/// Shared memory allocation
|
| 141 |
+
struct SharedStorage {
|
| 142 |
+
|
| 143 |
+
//
|
| 144 |
+
// Type definitions
|
| 145 |
+
//
|
| 146 |
+
|
| 147 |
+
/// Element type of shared memory
|
| 148 |
+
using Element = typename WarpTileIterator::Element;
|
| 149 |
+
|
| 150 |
+
/// Tensor reference to shared memory allocation
|
| 151 |
+
using TensorRef = typename WarpTileIterator::TensorRef;
|
| 152 |
+
|
| 153 |
+
/// Layout of shared memory allocation
|
| 154 |
+
using Layout = typename WarpTileIterator::Layout;
|
| 155 |
+
|
| 156 |
+
/// Logical shape of the shared memory tile written to by all warps.
|
| 157 |
+
using Shape = MatrixShape<
|
| 158 |
+
WarpCount::kM * WarpTileIterator::Shape::kRow * WarpCount::kK,
|
| 159 |
+
WarpCount::kN * WarpTileIterator::Shape::kColumn
|
| 160 |
+
>;
|
| 161 |
+
|
| 162 |
+
/// Shape of the shared memory allocation for the epilogue
|
| 163 |
+
using StorageShape = MatrixShape<
|
| 164 |
+
Shape::kRow + Padding::kRow,
|
| 165 |
+
Shape::kColumn + Padding::kColumn
|
| 166 |
+
>;
|
| 167 |
+
|
| 168 |
+
static int const kImaginaryStride = StorageShape::kCount;
|
| 169 |
+
|
| 170 |
+
//
|
| 171 |
+
// Data members
|
| 172 |
+
//
|
| 173 |
+
|
| 174 |
+
AlignedBuffer<Element, kImaginaryStride * 2> storage;
|
| 175 |
+
|
| 176 |
+
//
|
| 177 |
+
// Methods
|
| 178 |
+
//
|
| 179 |
+
|
| 180 |
+
/// Returns a pointer to the shared memory buffer
|
| 181 |
+
CUTLASS_DEVICE
|
| 182 |
+
Element *data() {
|
| 183 |
+
return storage.data();
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
/// Returns a tensor reference to the shared memory buffer
|
| 187 |
+
CUTLASS_DEVICE
|
| 188 |
+
TensorRef reference() {
|
| 189 |
+
return TensorRef(
|
| 190 |
+
storage.data(),
|
| 191 |
+
Layout::packed({StorageShape::kRow, StorageShape::kColumn}));
|
| 192 |
+
}
|
| 193 |
+
};
|
| 194 |
+
|
| 195 |
+
private:
|
| 196 |
+
|
| 197 |
+
//
|
| 198 |
+
// Data members
|
| 199 |
+
//
|
| 200 |
+
|
| 201 |
+
SharedStorage &shared_storage_;
|
| 202 |
+
|
| 203 |
+
/// Loads fragment from shared memory aligned with output tensor
|
| 204 |
+
SharedLoadIterator shared_load_iterator_;
|
| 205 |
+
|
| 206 |
+
/// Stores a warp's fragment of accumulators to SMEM
|
| 207 |
+
WarpTileIterator warp_tile_iterator_;
|
| 208 |
+
|
| 209 |
+
public:
|
| 210 |
+
|
| 211 |
+
/// Constructor
|
| 212 |
+
CUTLASS_DEVICE
|
| 213 |
+
EpiloguePlanarComplex(
|
| 214 |
+
SharedStorage &shared_storage, ///< Shared storage object
|
| 215 |
+
int thread_idx, ///< ID of a thread within the threadblock
|
| 216 |
+
int warp_idx, ///< ID of warp within threadblock
|
| 217 |
+
int lane_idx ///< Id of thread within warp
|
| 218 |
+
):
|
| 219 |
+
shared_storage_(shared_storage),
|
| 220 |
+
shared_load_iterator_(shared_storage.reference(), thread_idx),
|
| 221 |
+
warp_tile_iterator_(shared_storage.reference(), lane_idx) {
|
| 222 |
+
|
| 223 |
+
// Compute warp location within threadblock tile by mapping the warp_id to three coordinates:
|
| 224 |
+
//
|
| 225 |
+
// _m: the warp's position within the threadblock along the M dimension
|
| 226 |
+
// _n: the warp's position within the threadblock along the N dimension
|
| 227 |
+
// _k: the warp's position within the threadblock along the K dimension
|
| 228 |
+
|
| 229 |
+
int warp_k = warp_idx / (WarpCount::kM * WarpCount::kN);
|
| 230 |
+
int warp_mn = warp_idx % (WarpCount::kM * WarpCount::kN);
|
| 231 |
+
int warp_m = warp_mn % WarpCount::kM;
|
| 232 |
+
int warp_n = warp_mn / WarpCount::kM;
|
| 233 |
+
|
| 234 |
+
MatrixCoord warp_offset{warp_k * WarpCount::kM + warp_m, warp_n};
|
| 235 |
+
|
| 236 |
+
warp_tile_iterator_.add_tile_offset(warp_offset);
|
| 237 |
+
}
|
| 238 |
+
|
| 239 |
+
/// Streams the result to global memory
|
| 240 |
+
CUTLASS_DEVICE
|
| 241 |
+
void operator()(
|
| 242 |
+
OutputOp const &output_op, ///< Output operator
|
| 243 |
+
OutputTileIterator destination_iterator_real, ///< Tile iterator for destination
|
| 244 |
+
OutputTileIterator destination_iterator_imag, ///< Tile iterator for destination
|
| 245 |
+
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
|
| 246 |
+
OutputTileIterator source_iterator_real, ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
|
| 247 |
+
OutputTileIterator source_iterator_imag) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
|
| 248 |
+
|
| 249 |
+
typename OutputTileIterator::Fragment source_fragment_real;
|
| 250 |
+
typename OutputTileIterator::Fragment source_fragment_imag;
|
| 251 |
+
|
| 252 |
+
if (!output_op.is_source_needed()) {
|
| 253 |
+
source_iterator_real.clear_mask();
|
| 254 |
+
source_iterator_imag.clear_mask();
|
| 255 |
+
}
|
| 256 |
+
|
| 257 |
+
source_fragment_real.clear();
|
| 258 |
+
source_fragment_imag.clear();
|
| 259 |
+
|
| 260 |
+
//
|
| 261 |
+
// Iterator over warp-level accumulator fragment
|
| 262 |
+
//
|
| 263 |
+
|
| 264 |
+
AccumulatorFragmentIterator accum_fragment_iterator_real(accumulators.real);
|
| 265 |
+
AccumulatorFragmentIterator accum_fragment_iterator_imag(accumulators.imag);
|
| 266 |
+
|
| 267 |
+
//
|
| 268 |
+
// Iterate over accumulator tile
|
| 269 |
+
//
|
| 270 |
+
|
| 271 |
+
CUTLASS_PRAGMA_UNROLL
|
| 272 |
+
for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {
|
| 273 |
+
|
| 274 |
+
//
|
| 275 |
+
// Load the source
|
| 276 |
+
//
|
| 277 |
+
|
| 278 |
+
source_iterator_real.load(source_fragment_real);
|
| 279 |
+
source_iterator_imag.load(source_fragment_imag);
|
| 280 |
+
|
| 281 |
+
++source_iterator_real;
|
| 282 |
+
++source_iterator_imag;
|
| 283 |
+
|
| 284 |
+
//
|
| 285 |
+
// Convert and store fragment
|
| 286 |
+
//
|
| 287 |
+
|
| 288 |
+
__syncthreads();
|
| 289 |
+
|
| 290 |
+
typename AccumulatorFragmentIterator::Fragment accum_fragment_real;
|
| 291 |
+
typename AccumulatorFragmentIterator::Fragment accum_fragment_imag;
|
| 292 |
+
|
| 293 |
+
accum_fragment_iterator_real.load(accum_fragment_real);
|
| 294 |
+
accum_fragment_iterator_imag.load(accum_fragment_imag);
|
| 295 |
+
|
| 296 |
+
++accum_fragment_iterator_real;
|
| 297 |
+
++accum_fragment_iterator_imag;
|
| 298 |
+
|
| 299 |
+
this->warp_tile_iterator_.store(accum_fragment_real);
|
| 300 |
+
this->warp_tile_iterator_.store_with_pointer_offset(accum_fragment_imag, SharedStorage::kImaginaryStride);
|
| 301 |
+
|
| 302 |
+
__syncthreads();
|
| 303 |
+
|
| 304 |
+
//
|
| 305 |
+
// Load fragments from shared memory
|
| 306 |
+
//
|
| 307 |
+
|
| 308 |
+
typename SharedLoadIterator::Fragment aligned_accum_fragment_real[kPartitionsK];
|
| 309 |
+
typename SharedLoadIterator::Fragment aligned_accum_fragment_imag[kPartitionsK];
|
| 310 |
+
|
| 311 |
+
shared_load_iterator_.load(aligned_accum_fragment_real[0]);
|
| 312 |
+
shared_load_iterator_.load_with_pointer_offset(aligned_accum_fragment_imag[0], SharedStorage::kImaginaryStride);
|
| 313 |
+
|
| 314 |
+
// If the number of k-slices is > 1 - perform a reduction amongst the k-slices
|
| 315 |
+
static_assert(kPartitionsK == 1, "Sliced-K not supported for planar complex at this time");
|
| 316 |
+
|
| 317 |
+
//
|
| 318 |
+
// Compute the output result
|
| 319 |
+
//
|
| 320 |
+
|
| 321 |
+
typename OutputTileIterator::Fragment output_fragment_real;
|
| 322 |
+
typename OutputTileIterator::Fragment output_fragment_imag;
|
| 323 |
+
|
| 324 |
+
apply_output_operator_(
|
| 325 |
+
output_fragment_real,
|
| 326 |
+
output_fragment_imag,
|
| 327 |
+
output_op,
|
| 328 |
+
aligned_accum_fragment_real[0],
|
| 329 |
+
aligned_accum_fragment_imag[0],
|
| 330 |
+
source_fragment_real,
|
| 331 |
+
source_fragment_imag);
|
| 332 |
+
|
| 333 |
+
//
|
| 334 |
+
// Store the final result
|
| 335 |
+
//
|
| 336 |
+
|
| 337 |
+
destination_iterator_real.store(output_fragment_real);
|
| 338 |
+
destination_iterator_imag.store(output_fragment_imag);
|
| 339 |
+
|
| 340 |
+
++destination_iterator_real;
|
| 341 |
+
++destination_iterator_imag;
|
| 342 |
+
}
|
| 343 |
+
}
|
| 344 |
+
|
| 345 |
+
private:
|
| 346 |
+
|
| 347 |
+
/// Helper to invoke the output functor over each vector of output
|
| 348 |
+
CUTLASS_DEVICE
|
| 349 |
+
void apply_output_operator_(
|
| 350 |
+
typename OutputTileIterator::Fragment &output_fragment_real,
|
| 351 |
+
typename OutputTileIterator::Fragment &output_fragment_imag,
|
| 352 |
+
OutputOp const &output_op, ///< Output operator
|
| 353 |
+
typename SharedLoadIterator::Fragment const &aligned_accum_fragment_real,
|
| 354 |
+
typename SharedLoadIterator::Fragment const &aligned_accum_fragment_imag,
|
| 355 |
+
typename OutputTileIterator::Fragment const &source_fragment_real,
|
| 356 |
+
typename OutputTileIterator::Fragment const &source_fragment_imag) {
|
| 357 |
+
|
| 358 |
+
OutputAccessType *output_frag_real_ptr =
|
| 359 |
+
reinterpret_cast<OutputAccessType *>(&output_fragment_real);
|
| 360 |
+
|
| 361 |
+
OutputAccessType *output_frag_imag_ptr =
|
| 362 |
+
reinterpret_cast<OutputAccessType *>(&output_fragment_imag);
|
| 363 |
+
|
| 364 |
+
AccumulatorAccessType const *compute_frag_real_ptr =
|
| 365 |
+
reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment_real);
|
| 366 |
+
|
| 367 |
+
AccumulatorAccessType const *compute_frag_imag_ptr =
|
| 368 |
+
reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment_imag);
|
| 369 |
+
|
| 370 |
+
OutputAccessType const *source_frag_real_ptr =
|
| 371 |
+
reinterpret_cast<OutputAccessType const *>(&source_fragment_real);
|
| 372 |
+
|
| 373 |
+
OutputAccessType const *source_frag_imag_ptr =
|
| 374 |
+
reinterpret_cast<OutputAccessType const *>(&source_fragment_imag);
|
| 375 |
+
|
| 376 |
+
int const kOutputOpIterations =
|
| 377 |
+
OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
|
| 378 |
+
|
| 379 |
+
CUTLASS_PRAGMA_UNROLL
|
| 380 |
+
for (int i = 0; i < kOutputOpIterations; ++i) {
|
| 381 |
+
|
| 382 |
+
// Call the output operator
|
| 383 |
+
auto result_fragment = output_op(
|
| 384 |
+
make_ArrayPlanarComplex(compute_frag_real_ptr[i], compute_frag_imag_ptr[i]),
|
| 385 |
+
make_ArrayPlanarComplex(source_frag_real_ptr[i], source_frag_imag_ptr[i])
|
| 386 |
+
);
|
| 387 |
+
|
| 388 |
+
output_frag_real_ptr[i] = result_fragment.real;
|
| 389 |
+
output_frag_imag_ptr[i] = result_fragment.imag;
|
| 390 |
+
}
|
| 391 |
+
}
|
| 392 |
+
|
| 393 |
+
};
|
| 394 |
+
|
| 395 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 396 |
+
|
| 397 |
+
} // namespace threadblock
|
| 398 |
+
} // namespace epilogue
|
| 399 |
+
} // namespace cutlass
|
| 400 |
+
|
| 401 |
+
////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/epilogue_smem_accumulator.h
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Epilogue for threadblock scoped GEMM/CONV to store accumulator in shared memory after
|
| 33 |
+
applying scale, bias loaded from global memory and element-wise operations.
|
| 34 |
+
|
| 35 |
+
This Epilogue is typically used in fused GEMM/CONV to stage the intermediate accumulator.
|
| 36 |
+
|
| 37 |
+
*/
|
| 38 |
+
|
| 39 |
+
#pragma once
|
| 40 |
+
#include "cutlass/cutlass.h"
|
| 41 |
+
#include CUDA_STD_HEADER(cassert)
|
| 42 |
+
#include "cutlass/numeric_types.h"
|
| 43 |
+
#include "cutlass/array.h"
|
| 44 |
+
#include "cutlass/layout/vector.h"
|
| 45 |
+
#include "cutlass/layout/tensor.h"
|
| 46 |
+
#include "cutlass/tensor_coord.h"
|
| 47 |
+
#include "cutlass/aligned_buffer.h"
|
| 48 |
+
#include "cutlass/functional.h"
|
| 49 |
+
|
| 50 |
+
#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h"
|
| 51 |
+
#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h"
|
| 52 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 53 |
+
|
| 54 |
+
namespace cutlass {
|
| 55 |
+
namespace epilogue {
|
| 56 |
+
namespace threadblock {
|
| 57 |
+
|
| 58 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 59 |
+
|
| 60 |
+
/// Epilogue operator
|
| 61 |
+
template <
|
| 62 |
+
typename SmemTileIterator_, ///< Shared memory Tile iterator to output to shared memory
|
| 63 |
+
typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators
|
| 64 |
+
typename ScaleBiasIterator_, ///< Iterator to load scale and bias from global memory
|
| 65 |
+
typename OutputOp_ ///< Output operator
|
| 66 |
+
>
|
| 67 |
+
class EpilogueSmemAccumulator {
|
| 68 |
+
|
| 69 |
+
public:
|
| 70 |
+
|
| 71 |
+
using SmemTileIterator = SmemTileIterator_;
|
| 72 |
+
|
| 73 |
+
using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
|
| 74 |
+
|
| 75 |
+
using ScaleBiasIterator = ScaleBiasIterator_;
|
| 76 |
+
|
| 77 |
+
using OutputOp = OutputOp_;
|
| 78 |
+
|
| 79 |
+
/// Fragment of accumulator tile
|
| 80 |
+
using FragmentAccumulator = typename AccumulatorFragmentIterator::Fragment;
|
| 81 |
+
|
| 82 |
+
/// The complete warp-level accumulator tile
|
| 83 |
+
using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile;
|
| 84 |
+
|
| 85 |
+
/// Fragment of Scale and Bias loaded from global memory
|
| 86 |
+
using FragmentScaleBias = typename ScaleBiasIterator::Fragment;
|
| 87 |
+
|
| 88 |
+
static const bool PerChannelScale = (OutputOp::kScale ==
|
| 89 |
+
epilogue::thread::ScaleType::OnlyAlphaPerChannelScaling);
|
| 90 |
+
|
| 91 |
+
/// Constructor
|
| 92 |
+
CUTLASS_DEVICE
|
| 93 |
+
EpilogueSmemAccumulator() {}
|
| 94 |
+
|
| 95 |
+
/// Streams the result to shared memory
|
| 96 |
+
CUTLASS_DEVICE
|
| 97 |
+
void operator()(
|
| 98 |
+
OutputOp const &output_op, ///< Output operator
|
| 99 |
+
SmemTileIterator smem_iterator, ///< Tile iterator for destination in shared memory
|
| 100 |
+
AccumulatorTile const &accumulator, ///< Complete warp-level accumulator tile
|
| 101 |
+
ScaleBiasIterator scale_iterator, ///< iterator for scale vector in global memory
|
| 102 |
+
ScaleBiasIterator bias_iterator) { ///< iterator for bias vector in global memory
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
// Fragment to load scale bias from global memory
|
| 106 |
+
FragmentScaleBias tb_frag_scale;
|
| 107 |
+
FragmentScaleBias tb_frag_bias;
|
| 108 |
+
|
| 109 |
+
/// Fragment Iterator to load slice of accumulator tile
|
| 110 |
+
AccumulatorFragmentIterator frag_iterator_accum(accumulator);
|
| 111 |
+
FragmentAccumulator tb_frag_accum;
|
| 112 |
+
|
| 113 |
+
/// Epilogue output fragment
|
| 114 |
+
typename SmemTileIterator::Fragment tb_frag_smem;
|
| 115 |
+
|
| 116 |
+
/// Load scale and bias from global memory
|
| 117 |
+
|
| 118 |
+
if(PerChannelScale)
|
| 119 |
+
scale_iterator.load(tb_frag_scale);
|
| 120 |
+
|
| 121 |
+
bias_iterator.load(tb_frag_bias);
|
| 122 |
+
|
| 123 |
+
/// Iterate over the accumulator tile and store to shared memory
|
| 124 |
+
CUTLASS_PRAGMA_UNROLL
|
| 125 |
+
for (int rid = 0; rid < AccumulatorFragmentIterator::TileIterations::kRow; ++rid) {
|
| 126 |
+
|
| 127 |
+
CUTLASS_PRAGMA_UNROLL
|
| 128 |
+
for (int cid = 0; cid < AccumulatorFragmentIterator::TileIterations::kColumn; ++cid) {
|
| 129 |
+
|
| 130 |
+
using AccumulatorAccessType = typename OutputOp::FragmentAccumulator;
|
| 131 |
+
using ScaleBiasAccessType = typename OutputOp::FragmentScaleBias;
|
| 132 |
+
using FragmentSmemAccessType = typename OutputOp::FragmentOutput;
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
ScaleBiasAccessType const * scale_frag_ptr =
|
| 136 |
+
reinterpret_cast<ScaleBiasAccessType const *>(&tb_frag_scale);
|
| 137 |
+
ScaleBiasAccessType const * bias_frag_ptr =
|
| 138 |
+
reinterpret_cast<ScaleBiasAccessType const *>(&tb_frag_bias);
|
| 139 |
+
|
| 140 |
+
FragmentSmemAccessType * smem_frag_ptr =
|
| 141 |
+
reinterpret_cast<FragmentSmemAccessType *>(&tb_frag_smem);
|
| 142 |
+
|
| 143 |
+
CUTLASS_PRAGMA_UNROLL
|
| 144 |
+
for (int idx = 0; idx < AccumulatorFragmentIterator::kIterationsPerTile; ++idx) {
|
| 145 |
+
frag_iterator_accum.load(tb_frag_accum);
|
| 146 |
+
++frag_iterator_accum;
|
| 147 |
+
|
| 148 |
+
AccumulatorAccessType const * accumulator_frag_ptr =
|
| 149 |
+
reinterpret_cast<AccumulatorAccessType const *>(&tb_frag_accum);
|
| 150 |
+
const int kOutputIterations = FragmentAccumulator::kElements / OutputOp::kCount;
|
| 151 |
+
|
| 152 |
+
CUTLASS_PRAGMA_UNROLL
|
| 153 |
+
for (int it = 0; it < kOutputIterations; it++) {
|
| 154 |
+
smem_frag_ptr[idx * kOutputIterations + it] = output_op(accumulator_frag_ptr[it],
|
| 155 |
+
scale_frag_ptr[cid * kOutputIterations + it], bias_frag_ptr[cid * kOutputIterations + it]);
|
| 156 |
+
}
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
smem_iterator.store(tb_frag_smem);
|
| 160 |
+
++smem_iterator;
|
| 161 |
+
|
| 162 |
+
}
|
| 163 |
+
}
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
/// Streams the result to shared memory.
///
/// Walks the warp-level accumulator tile, applies the output operator to each
/// vector-sized slice of accumulators, and stores the converted fragments
/// through the shared-memory tile iterator, one tile fragment per (row, column)
/// tile iteration.
CUTLASS_DEVICE
void operator()(
  OutputOp const &output_op,             ///< Output operator applied to each accumulator access
  SmemTileIterator smem_iterator,        ///< Tile iterator for destination in shared memory
  AccumulatorTile const &accumulator) {  ///< Complete warp-level accumulator tile

  /// Fragment Iterator to load slice of accumulator tile
  AccumulatorFragmentIterator frag_iterator_accum(accumulator);
  FragmentAccumulator tb_frag_accum;

  /// Epilogue output fragment
  typename SmemTileIterator::Fragment tb_frag_smem;

  /// Iterate over the accumulator tile and store to shared memory
  CUTLASS_PRAGMA_UNROLL
  for (int rid = 0; rid < AccumulatorFragmentIterator::TileIterations::kRow; ++rid) {

    CUTLASS_PRAGMA_UNROLL
    for (int cid = 0; cid < AccumulatorFragmentIterator::TileIterations::kColumn; ++cid) {

      // Access types used to reinterpret the fragments as arrays of
      // OutputOp-sized vectors.
      using AccumulatorAccessType = typename OutputOp::FragmentAccumulator;
      using FragmentSmemAccessType = typename OutputOp::FragmentOutput;

      FragmentSmemAccessType * smem_frag_ptr =
        reinterpret_cast<FragmentSmemAccessType *>(&tb_frag_smem);

      CUTLASS_PRAGMA_UNROLL
      for (int idx = 0; idx < AccumulatorFragmentIterator::kIterationsPerTile; ++idx) {
        // Load the next accumulator slice and advance the fragment iterator.
        frag_iterator_accum.load(tb_frag_accum);
        ++frag_iterator_accum;

        AccumulatorAccessType const * accumulator_frag_ptr =
          reinterpret_cast<AccumulatorAccessType const *>(&tb_frag_accum);
        // Number of OutputOp invocations needed to cover one accumulator fragment.
        const int kOutputIterations = FragmentAccumulator::kElements / OutputOp::kCount;

        CUTLASS_PRAGMA_UNROLL
        for (int it = 0; it < kOutputIterations; it++) {
          // Convert one vector of accumulators and place it at the matching
          // offset in the shared-memory fragment.
          smem_frag_ptr[idx * kOutputIterations + it] = output_op(accumulator_frag_ptr[it]);
        }
      }

      // Commit the converted tile fragment to shared memory and advance to
      // the next destination tile.
      smem_iterator.store(tb_frag_smem);
      ++smem_iterator;

    }
  }
}
|
| 214 |
+
|
| 215 |
+
};
|
| 216 |
+
|
| 217 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 218 |
+
|
| 219 |
+
} // namespace threadblock
|
| 220 |
+
} // namespace epilogue
|
| 221 |
+
} // namespace cutlass
|
| 222 |
+
|
| 223 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 224 |
+
|
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/epilogue_streamk_with_broadcast.h
ADDED
|
@@ -0,0 +1,443 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
|
| 33 |
+
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
|
| 34 |
+
|
| 35 |
+
The epilogue rearranges the result of a matrix product through shared memory to match canonical
|
| 36 |
+
tensor layouts in global memory. Epilogues support conversion and reduction operations.
|
| 37 |
+
|
| 38 |
+
*/
|
| 39 |
+
|
| 40 |
+
#pragma once
|
| 41 |
+
#include "cutlass/cutlass.h"
|
| 42 |
+
|
| 43 |
+
#include CUDA_STD_HEADER(cassert)
|
| 44 |
+
|
| 45 |
+
#if defined(__CUDACC_RTC__)
|
| 46 |
+
#include CUDA_STD_HEADER(utility)
|
| 47 |
+
#else
|
| 48 |
+
#include <utility>
|
| 49 |
+
#endif
|
| 50 |
+
|
| 51 |
+
#include "cutlass/array.h"
|
| 52 |
+
#include "cutlass/numeric_types.h"
|
| 53 |
+
#include "cutlass/numeric_conversion.h"
|
| 54 |
+
#include "cutlass/tensor_coord.h"
|
| 55 |
+
#include "cutlass/aligned_buffer.h"
|
| 56 |
+
#include "cutlass/functional.h"
|
| 57 |
+
#include "cutlass/fast_math.h"
|
| 58 |
+
#include "cutlass/layout/vector.h"
|
| 59 |
+
#include "cutlass/layout/tensor.h"
|
| 60 |
+
|
| 61 |
+
#include "cutlass/gemm/gemm.h"
|
| 62 |
+
|
| 63 |
+
#include "cutlass/transform/pitch_linear_thread_map.h"
|
| 64 |
+
#include "cutlass/transform/threadblock/regular_tile_iterator.h"
|
| 65 |
+
|
| 66 |
+
#include "cutlass/epilogue/threadblock/epilogue_base.h"
|
| 67 |
+
#include "cutlass/epilogue/threadblock/epilogue_base_streamk.h"
|
| 68 |
+
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
|
| 69 |
+
|
| 70 |
+
#include "cutlass/numeric_types.h"
|
| 71 |
+
|
| 72 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 73 |
+
|
| 74 |
+
namespace cutlass {
|
| 75 |
+
namespace epilogue {
|
| 76 |
+
namespace threadblock {
|
| 77 |
+
|
| 78 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 79 |
+
|
| 80 |
+
/// This base class is meant to define the concept required of the
|
| 81 |
+
/// EpilogueStreamkWithBroadcast::OutputOp
|
| 82 |
+
template <
|
| 83 |
+
typename ElementC_,
|
| 84 |
+
typename ElementAccumulator_,
|
| 85 |
+
typename ElementCompute_,
|
| 86 |
+
typename ElementZ_,
|
| 87 |
+
typename ElementT_,
|
| 88 |
+
int ElementsPerAccess,
|
| 89 |
+
bool StoreZ = true,
|
| 90 |
+
bool StoreT = true
|
| 91 |
+
>
|
| 92 |
+
struct EpilogueStreamkWithBroadcastOpBase : EpilogueWithBroadcastOpBase<
|
| 93 |
+
ElementC_,
|
| 94 |
+
ElementAccumulator_,
|
| 95 |
+
ElementCompute_,
|
| 96 |
+
ElementZ_,
|
| 97 |
+
ElementT_,
|
| 98 |
+
ElementsPerAccess,
|
| 99 |
+
StoreZ,
|
| 100 |
+
StoreT
|
| 101 |
+
>
|
| 102 |
+
{
|
| 103 |
+
|
| 104 |
+
/// Parameters structure - required
|
| 105 |
+
struct Params { };
|
| 106 |
+
|
| 107 |
+
//
|
| 108 |
+
// Methods
|
| 109 |
+
//
|
| 110 |
+
|
| 111 |
+
/// Constructor from Params
|
| 112 |
+
EpilogueStreamkWithBroadcastOpBase(Params const ¶ms_) { }
|
| 113 |
+
};
|
| 114 |
+
|
| 115 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 116 |
+
|
| 117 |
+
/// Epilogue operator with bias vector broadcast over columns, specialized for
/// stream-K decompositions (primary template; see the partial specializations
/// below selected on OutputOp_::kIsSingleSource).
///
/// Computes the following:
///
///
///  Z, T = OutputOp(AB, C, Broadcast)
///
///  if (ElementwiseOp::kStoreZ) {
///    store(converted_u);
///  }
///
///  if (ElementwiseOp::kStoreT) {
///    store(v);
///  }
///
template <
  typename Shape_,                          ///< Shape of threadblock tile (concept: GemmShape)
  typename WarpMmaOperator_,                ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp)
  int PartitionsK,                          ///< Number of partitions of the K dimension
  typename OutputTileIterator_,             ///< Tile iterator reading and writing output tensors (z)
  typename TensorTileIterator_,             ///< Additional tile iterator for tensor-valued operands (t)
  typename ElementVector_,                  ///< Pointer to broadcast vector
  typename AccumulatorFragmentIterator_,    ///< Fragment iterator selecting accumulators
  typename WarpTileIterator_,               ///< Warp-scoped tile iterator writing accumulators to SMEM
  typename SharedLoadIterator_,             ///< Threadblock-scoped tile iterator loading from SMEM
  typename OutputOp_,                       ///< Output operator - concept is EpilogueWithBroadcastOp
  typename Padding_,                        ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape)
  int FragmentsPerPartition = 1,            ///< Used to coarsten the epilogue granularity
  int IterationsUnroll =                    ///< Used to reduce binary size when epilogue op is large
    (!IsEpilogueFunctorHeavy<OutputOp_>::value),
  bool IsSingleSource = OutputOp_::kIsSingleSource
>
class EpilogueStreamkWithBroadcast;
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 153 |
+
|
| 154 |
+
/// EpilogueStreamkWithBroadcast: Two sources
/// (partial specialization selected when OutputOp_::kIsSingleSource == false)

template <
  typename Shape_,
  typename WarpMmaOperator_,
  int PartitionsK,
  typename OutputTileIterator_,
  typename TensorTileIterator_,
  typename ElementVector_,
  typename AccumulatorFragmentIterator_,
  typename WarpTileIterator_,
  typename SharedLoadIterator_,
  typename OutputOp_,
  typename Padding_,
  int FragmentsPerPartition,
  int IterationsUnroll
>
class EpilogueStreamkWithBroadcast<
  Shape_,
  WarpMmaOperator_,
  PartitionsK,
  OutputTileIterator_,
  TensorTileIterator_,
  ElementVector_,
  AccumulatorFragmentIterator_,
  WarpTileIterator_,
  SharedLoadIterator_,
  OutputOp_,
  Padding_,
  FragmentsPerPartition,
  IterationsUnroll,
  false
> :
  public EpilogueWithBroadcast<
    Shape_,
    WarpMmaOperator_,
    PartitionsK,
    OutputTileIterator_,
    TensorTileIterator_,
    ElementVector_,
    AccumulatorFragmentIterator_,
    WarpTileIterator_,
    SharedLoadIterator_,
    OutputOp_,
    Padding_,
    FragmentsPerPartition,
    IterationsUnroll,
    false>,
  public EpilogueBaseStreamK<
    Shape_,
    PartitionsK,
    WarpMmaOperator_,
    AccumulatorFragmentIterator_>
{

public:

  /// Non-stream-K broadcast epilogue this specialization delegates to
  using Base = EpilogueWithBroadcast<
    Shape_,
    WarpMmaOperator_,
    PartitionsK,
    OutputTileIterator_,
    TensorTileIterator_,
    ElementVector_,
    AccumulatorFragmentIterator_,
    WarpTileIterator_,
    SharedLoadIterator_,
    OutputOp_,
    Padding_,
    FragmentsPerPartition,
    IterationsUnroll,
    false>;

  /// Base providing stream-K cross-CTA accumulator reduction
  using BaseStreamK = EpilogueBaseStreamK<
    Shape_,
    PartitionsK,
    WarpMmaOperator_,
    AccumulatorFragmentIterator_>;

  using Shape = Shape_;
  static int const kPartitionsK = PartitionsK;
  using OutputTileIterator = OutputTileIterator_;
  using TensorTileIterator = TensorTileIterator_;
  using ElementVector = ElementVector_;
  using SharedLoadIterator = SharedLoadIterator_;
  using OutputOp = OutputOp_;

  /// Fragment type used by the accumulator tile's fragment iterator
  using AccumulatorFragment = typename Base::AccumulatorFragmentIterator::Fragment;

  /// Shared storage structure (shadows base) with additional SMEM buffer for reduction
  using SharedStorage = typename Base::SharedStorage;

public:

  /// Constructor
  CUTLASS_DEVICE
  EpilogueStreamkWithBroadcast(
    SharedStorage &shared_storage,                    ///< Shared storage object
    int thread_idx,                                   ///< ID of a thread within the threadblock
    int warp_idx,                                     ///< ID of warp within threadblock
    int lane_idx                                      ///< Id of thread within warp
  ):
    Base(shared_storage, thread_idx, warp_idx, lane_idx),
    BaseStreamK(thread_idx)
  { }


  /// Aggregates the accumulator sets shared by peer blocks in the global workspace,
  /// performing epilogue computations, writing to output
  CUTLASS_DEVICE
  void reduce(
    int peer_idx_begin,                               ///< First peer block index in the reduction group
    int peer_idx_end,                                 ///< One past the last peer block index
    int reduce_fragment_idx,                          ///< Index of the fragment this block reduces
    void *element_workspace,                          ///< Global workspace holding peer partial accumulators
    OutputOp const &output_op,                        ///< Output operator
    ElementVector const * broadcast_ptr,              ///< Broadcast vector
    OutputTileIterator destination_iterator,          ///< Tile iterator for destination
    OutputTileIterator source_iterator1,              ///< Tile iterator for first source accumulator matrix
    OutputTileIterator source_iterator2,              ///< Tile iterator for second source accumulator matrix
    TensorTileIterator tensor_iterator,               ///< Threadblock tile iterator for additional tensor operand
    MatrixCoord const &problem_size =                 ///< Problem size needed to guard against out-of-bounds accesses
        MatrixCoord(Shape::kM, Shape::kN),
    MatrixCoord const &threadblock_offset =           ///< Threadblock's initial offset within the problem size space
        MatrixCoord())
  {
    // Reduce peer accumulator fragments into one fragment
    AccumulatorFragment accum_fragment;
    BaseStreamK::reduce(accum_fragment, peer_idx_begin, peer_idx_end, reduce_fragment_idx, element_workspace);

    // Store fragment to shared memory
    this->warp_tile_iterator_.store(accum_fragment);

    // Barrier: all warps must have written before the base epilogue reads SMEM back
    __syncthreads();

    // Delegate the elementwise epilogue (broadcast, Z/T stores) to the base class
    Base::reduce(reduce_fragment_idx, output_op, broadcast_ptr, destination_iterator, source_iterator1, source_iterator2, tensor_iterator, problem_size, threadblock_offset);

  }
};
|
| 294 |
+
|
| 295 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 296 |
+
|
| 297 |
+
/// EpilogueStreamkWithBroadcast: Single source
/// (partial specialization selected when OutputOp_::kIsSingleSource == true)

template <
  typename Shape_,
  typename WarpMmaOperator_,
  int PartitionsK,
  typename OutputTileIterator_,
  typename TensorTileIterator_,
  typename ElementVector_,
  typename AccumulatorFragmentIterator_,
  typename WarpTileIterator_,
  typename SharedLoadIterator_,
  typename OutputOp_,
  typename Padding_,
  int FragmentsPerPartition,
  int IterationsUnroll
>
class EpilogueStreamkWithBroadcast<
  Shape_,
  WarpMmaOperator_,
  PartitionsK,
  OutputTileIterator_,
  TensorTileIterator_,
  ElementVector_,
  AccumulatorFragmentIterator_,
  WarpTileIterator_,
  SharedLoadIterator_,
  OutputOp_,
  Padding_,
  FragmentsPerPartition,
  IterationsUnroll,
  true
> :
  public EpilogueWithBroadcast<
    Shape_,
    WarpMmaOperator_,
    PartitionsK,
    OutputTileIterator_,
    TensorTileIterator_,
    ElementVector_,
    AccumulatorFragmentIterator_,
    WarpTileIterator_,
    SharedLoadIterator_,
    OutputOp_,
    Padding_,
    FragmentsPerPartition,
    IterationsUnroll,
    true>,
  public EpilogueBaseStreamK<
    Shape_,
    PartitionsK,
    WarpMmaOperator_,
    AccumulatorFragmentIterator_>
{

public:

  /// Non-stream-K broadcast epilogue this specialization delegates to
  using Base = EpilogueWithBroadcast<
    Shape_,
    WarpMmaOperator_,
    PartitionsK,
    OutputTileIterator_,
    TensorTileIterator_,
    ElementVector_,
    AccumulatorFragmentIterator_,
    WarpTileIterator_,
    SharedLoadIterator_,
    OutputOp_,
    Padding_,
    FragmentsPerPartition,
    IterationsUnroll,
    true>;

  /// Base providing stream-K cross-CTA accumulator reduction
  using BaseStreamK = EpilogueBaseStreamK<
    Shape_,
    PartitionsK,
    WarpMmaOperator_,
    AccumulatorFragmentIterator_>;

  using Shape = Shape_;
  static int const kPartitionsK = PartitionsK;
  using OutputTileIterator = OutputTileIterator_;
  using TensorTileIterator = TensorTileIterator_;
  using ElementVector = ElementVector_;
  using SharedLoadIterator = SharedLoadIterator_;
  using OutputOp = OutputOp_;

  /// Fragment type used by the accumulator tile's fragment iterator
  using AccumulatorFragment = typename Base::AccumulatorFragmentIterator::Fragment;

  /// Shared storage structure (shadows base) with additional SMEM buffer for reduction
  using SharedStorage = typename Base::SharedStorage;

public:

  /// Constructor
  CUTLASS_DEVICE
  EpilogueStreamkWithBroadcast(
    SharedStorage &shared_storage,                    ///< Shared storage object
    int thread_idx,                                   ///< ID of a thread within the threadblock
    int warp_idx,                                     ///< ID of warp within threadblock
    int lane_idx                                      ///< Id of thread within warp
  ):
    Base(shared_storage, thread_idx, warp_idx, lane_idx),
    BaseStreamK(thread_idx)
  { }


  /// Aggregates the accumulator sets shared by peer blocks in the global workspace,
  /// performing epilogue computations, writing to output
  CUTLASS_DEVICE
  void reduce(
    int peer_idx_begin,                               ///< First peer block index in the reduction group
    int peer_idx_end,                                 ///< One past the last peer block index
    int reduce_fragment_idx,                          ///< Index of the fragment this block reduces
    void *element_workspace,                          ///< Global workspace holding peer partial accumulators
    OutputOp const &output_op,                        ///< Output operator
    ElementVector const * broadcast_ptr,              ///< Broadcast vector
    OutputTileIterator destination_iterator,          ///< Tile iterator for destination
    OutputTileIterator source_iterator,               ///< Tile iterator for the (single) source accumulator matrix
    TensorTileIterator tensor_iterator,               ///< Threadblock tile iterator for additional tensor operand
    MatrixCoord const &problem_size =                 ///< Problem size needed to guard against out-of-bounds accesses
        MatrixCoord(Shape::kM, Shape::kN),
    MatrixCoord const &threadblock_offset =           ///< Threadblock's initial offset within the problem size space
        MatrixCoord())
  {
    // Reduce peer accumulator fragments into one fragment
    AccumulatorFragment accum_fragment;
    BaseStreamK::reduce(accum_fragment, peer_idx_begin, peer_idx_end, reduce_fragment_idx, element_workspace);

    // Store fragment to shared memory
    this->warp_tile_iterator_.store(accum_fragment);

    // Barrier: all warps must have written before the base epilogue reads SMEM back
    __syncthreads();

    // Delegate the elementwise epilogue (broadcast, Z/T stores) to the base class
    Base::reduce(reduce_fragment_idx, output_op, broadcast_ptr, destination_iterator, source_iterator, tensor_iterator, problem_size, threadblock_offset);

  }
};
|
| 436 |
+
|
| 437 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 438 |
+
|
| 439 |
+
} // namespace threadblock
|
| 440 |
+
} // namespace epilogue
|
| 441 |
+
} // namespace cutlass
|
| 442 |
+
|
| 443 |
+
////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h
ADDED
|
@@ -0,0 +1,513 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Epilogue visitor for threadblock scoped GEMMs that process softmax computations in epilogue.
|
| 33 |
+
|
| 34 |
+
The epilogue finds max values in each row of the row-major output matrix and stores them.
|
| 35 |
+
The max values are also used for a further round of threadblock scoped reduction operation, where
|
| 36 |
+
the partial reduction results are stored in a pre-allocated array and used for further full reduction.
|
| 37 |
+
|
| 38 |
+
*/
|
| 39 |
+
|
| 40 |
+
#pragma once
|
| 41 |
+
|
| 42 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 43 |
+
|
| 44 |
+
#include "cutlass/cutlass.h"
|
| 45 |
+
#include "cutlass/arch/memory.h"
|
| 46 |
+
#include "cutlass/arch/memory_sm75.h"
|
| 47 |
+
#include "cutlass/numeric_conversion.h"
|
| 48 |
+
#include "cutlass/fast_math.h"
|
| 49 |
+
|
| 50 |
+
namespace cutlass {
|
| 51 |
+
namespace epilogue {
|
| 52 |
+
namespace threadblock {
|
| 53 |
+
|
| 54 |
+
template <
|
| 55 |
+
typename ThreadblockShape_,
|
| 56 |
+
int ThreadCount,
|
| 57 |
+
typename OutputTileIterator_,
|
| 58 |
+
typename ElementAccumulator_,
|
| 59 |
+
typename ElementNorm_,
|
| 60 |
+
typename ElementSum_,
|
| 61 |
+
typename ElementSoftmaxCompute_,
|
| 62 |
+
typename ElementwiseFunctor_,
|
| 63 |
+
bool UseMasking_ = false
|
| 64 |
+
>
|
| 65 |
+
class EpilogueVisitorSoftmax {
|
| 66 |
+
public:
|
| 67 |
+
|
| 68 |
+
using ThreadblockShape = ThreadblockShape_;
|
| 69 |
+
static int const kThreadCount = ThreadCount;
|
| 70 |
+
|
| 71 |
+
using OutputTileIterator = OutputTileIterator_;
|
| 72 |
+
using ElementwiseFunctor = ElementwiseFunctor_;
|
| 73 |
+
|
| 74 |
+
static int const kIterations = OutputTileIterator::kIterations;
|
| 75 |
+
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
|
| 76 |
+
|
| 77 |
+
using ElementOutput = typename OutputTileIterator::Element;
|
| 78 |
+
using LayoutOutput = cutlass::layout::RowMajor;
|
| 79 |
+
using ElementAccumulator = ElementAccumulator_;
|
| 80 |
+
|
| 81 |
+
using ElementNorm = ElementNorm_;
|
| 82 |
+
using ElementSum = ElementSum_;
|
| 83 |
+
using ElementSoftmaxCompute = ElementSoftmaxCompute_;
|
| 84 |
+
|
| 85 |
+
using AccumulatorFragment = Array<ElementAccumulator, kElementsPerAccess>;
|
| 86 |
+
using SoftmaxFragment = Array<ElementSoftmaxCompute, kElementsPerAccess>;
|
| 87 |
+
using OutputVector = Array<ElementOutput, kElementsPerAccess>;
|
| 88 |
+
using TensorRefD = TensorRef<ElementOutput, LayoutOutput>;
|
| 89 |
+
|
| 90 |
+
static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth;
|
| 91 |
+
static bool const kHasMultiStepsInRow = (OutputTileIterator::ThreadMap::Iterations::kColumn > 1);
|
| 92 |
+
static bool const kUseMasking = UseMasking_;
|
| 93 |
+
|
| 94 |
+
/// Argument structure
///
/// Host-side arguments for the softmax epilogue visitor: the elementwise
/// functor's parameters plus batch strides for the C/D tensors and the
/// per-row Max/Sum reduction outputs.
struct Arguments {

  typename ElementwiseFunctor::Params elementwise;  // forwarded to the elementwise functor
  int64_t batch_stride_C;                           // batch stride of source tensor C
  int64_t batch_stride_D;                           // batch stride of output tensor D
  int64_t batch_stride_Max;                         // batch stride of the row-max output vector
  int64_t batch_stride_Sum;                         // batch stride of the row-sum output vector

  //
  // Methods
  //

  /// Default constructor: zero batch strides (non-batched problem);
  /// `elementwise` is default-constructed.
  Arguments():
    batch_stride_C(0),
    batch_stride_D(0),
    batch_stride_Max(0),
    batch_stride_Sum(0)
  {

  }

  /// Non-batched construction from elementwise functor parameters only
  Arguments(
    typename ElementwiseFunctor::Params elementwise_
  ):
    elementwise(elementwise_),
    batch_stride_C(0),
    batch_stride_D(0),
    batch_stride_Max(0),
    batch_stride_Sum(0)
  {

  }

  /// Fully-specified construction for batched problems
  Arguments(
    typename ElementwiseFunctor::Params elementwise_,
    int64_t batch_stride_C_,
    int64_t batch_stride_D_,
    int64_t batch_stride_Max_,
    int64_t batch_stride_Sum_
  ):
    elementwise(elementwise_),
    batch_stride_C(batch_stride_C_),
    batch_stride_D(batch_stride_D_),
    batch_stride_Max(batch_stride_Max_),
    batch_stride_Sum(batch_stride_Sum_)
  {

  }

};
|
| 144 |
+
|
| 145 |
+
/// Device-side parameters for the softmax epilogue visitor; a direct copy
/// of the corresponding Arguments fields.
struct Params {

  typename ElementwiseFunctor::Params elementwise;  // forwarded to the elementwise functor
  int64_t batch_stride_C;                           // batch stride of source tensor C
  int64_t batch_stride_D;                           // batch stride of output tensor D
  int64_t batch_stride_Max;                         // batch stride of the row-max output vector
  int64_t batch_stride_Sum;                         // batch stride of the row-sum output vector
  //
  // Methods
  //

  /// Default constructor; note the batch strides are left uninitialized here
  /// (the Arguments-based constructor below is the intended entry point).
  CUTLASS_HOST_DEVICE
  Params()
  {

  }

  /// Construct device parameters from host-side Arguments
  CUTLASS_HOST_DEVICE
  Params(Arguments const &args):
    elementwise(args.elementwise),
    batch_stride_C(args.batch_stride_C),
    batch_stride_D(args.batch_stride_D),
    batch_stride_Max(args.batch_stride_Max),
    batch_stride_Sum(args.batch_stride_Sum)
  {

  }
};
|
| 172 |
+
|
| 173 |
+
/// Shared storage
|
| 174 |
+
struct SharedStorage {
|
| 175 |
+
|
| 176 |
+
};
|
| 177 |
+
|
| 178 |
+
private:
|
| 179 |
+
|
| 180 |
+
Params const & params_;
|
| 181 |
+
SharedStorage & shared_storage_;
|
| 182 |
+
MatrixCoord extent_;
|
| 183 |
+
MatrixCoord extent_real_;
|
| 184 |
+
ElementwiseFunctor elementwise_;
|
| 185 |
+
|
| 186 |
+
OutputTileIterator iterator_C_;
|
| 187 |
+
OutputTileIterator iterator_D_;
|
| 188 |
+
typename OutputTileIterator::Fragment fragment_C_;
|
| 189 |
+
typename OutputTileIterator::Fragment fragment_D_;
|
| 190 |
+
|
| 191 |
+
ElementAccumulator alpha_;
|
| 192 |
+
ElementAccumulator beta_;
|
| 193 |
+
|
| 194 |
+
ElementNorm *ptr_Max_;
|
| 195 |
+
ElementSum *ptr_Sum_;
|
| 196 |
+
|
| 197 |
+
int column_offset_;
|
| 198 |
+
|
| 199 |
+
ElementSoftmaxCompute accum_max_;
|
| 200 |
+
ElementSoftmaxCompute accum_sum_;
|
| 201 |
+
|
| 202 |
+
MatrixCoord thread_offset_;
|
| 203 |
+
|
| 204 |
+
float infinity_;
|
| 205 |
+
|
| 206 |
+
public:
|
| 207 |
+
|
| 208 |
+
CUTLASS_DEVICE
|
| 209 |
+
EpilogueVisitorSoftmax(
|
| 210 |
+
Params const ¶ms,
|
| 211 |
+
SharedStorage &shared_storage,
|
| 212 |
+
cutlass::MatrixCoord const &problem_size,
|
| 213 |
+
int thread_idx,
|
| 214 |
+
int warp_idx,
|
| 215 |
+
int lane_idx,
|
| 216 |
+
typename OutputTileIterator::Params params_C,
|
| 217 |
+
typename OutputTileIterator::Params params_D,
|
| 218 |
+
typename OutputTileIterator::Element *ptr_C,
|
| 219 |
+
typename OutputTileIterator::Element *ptr_D,
|
| 220 |
+
ElementNorm *ptr_Max = nullptr,
|
| 221 |
+
ElementSum *ptr_Sum = nullptr,
|
| 222 |
+
cutlass::MatrixCoord const &threadblock_offset = cutlass::MatrixCoord(0, 0),
|
| 223 |
+
int column_offset = 0,
|
| 224 |
+
cutlass::MatrixCoord const &problem_size_real = cutlass::MatrixCoord(0, 0),
|
| 225 |
+
float infinity = 10000.0f
|
| 226 |
+
):
|
| 227 |
+
params_(params),
|
| 228 |
+
shared_storage_(shared_storage),
|
| 229 |
+
extent_(problem_size),
|
| 230 |
+
elementwise_(params.elementwise),
|
| 231 |
+
iterator_C_(params_C, ptr_C, problem_size, thread_idx, threadblock_offset),
|
| 232 |
+
iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset),
|
| 233 |
+
ptr_Max_(ptr_Max),
|
| 234 |
+
ptr_Sum_(ptr_Sum),
|
| 235 |
+
column_offset_(column_offset),
|
| 236 |
+
extent_real_(problem_size_real),
|
| 237 |
+
infinity_(infinity)
|
| 238 |
+
{
|
| 239 |
+
alpha_ = (params.elementwise.alpha_ptr ? *params.elementwise.alpha_ptr : params.elementwise.alpha);
|
| 240 |
+
beta_ = (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr : params.elementwise.beta);
|
| 241 |
+
|
| 242 |
+
if (beta_ == ElementAccumulator()) {
|
| 243 |
+
iterator_C_.clear_mask();
|
| 244 |
+
}
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
/// Helper to indicate split-K behavior
|
| 248 |
+
CUTLASS_DEVICE
|
| 249 |
+
void set_k_partition(
|
| 250 |
+
int split_k_index, ///< Index of this threadblock within split-K partitioned scheme
|
| 251 |
+
int split_k_slices) { ///< Total number of split-K slices
|
| 252 |
+
|
| 253 |
+
}
|
| 254 |
+
|
| 255 |
+
/// Called to set the batch index
|
| 256 |
+
CUTLASS_DEVICE
|
| 257 |
+
void set_batch_index(int batch_idx) {
|
| 258 |
+
iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C);
|
| 259 |
+
iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D);
|
| 260 |
+
}
|
| 261 |
+
|
| 262 |
+
/// Called at the start of the epilogue just before iterating over accumulator slices
|
| 263 |
+
CUTLASS_DEVICE
|
| 264 |
+
void begin_epilogue() {
|
| 265 |
+
|
| 266 |
+
}
|
| 267 |
+
|
| 268 |
+
/// Called at the start of one step before starting accumulator exchange
|
| 269 |
+
CUTLASS_DEVICE
|
| 270 |
+
void begin_step(int step_idx) {
|
| 271 |
+
fragment_D_.clear();
|
| 272 |
+
fragment_C_.clear();
|
| 273 |
+
|
| 274 |
+
if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) {
|
| 275 |
+
iterator_C_.load(fragment_C_);
|
| 276 |
+
++iterator_C_;
|
| 277 |
+
}
|
| 278 |
+
|
| 279 |
+
}
|
| 280 |
+
|
| 281 |
+
/// Called at the start of a row
|
| 282 |
+
CUTLASS_DEVICE
|
| 283 |
+
void begin_row(int row_idx) {
|
| 284 |
+
// Clear accumulators for max and sum when starting a whole row
|
| 285 |
+
clear_accum_();
|
| 286 |
+
|
| 287 |
+
}
|
| 288 |
+
|
| 289 |
+
/// Called after accumulators have been exchanged for each accumulator vector
|
| 290 |
+
CUTLASS_DEVICE
|
| 291 |
+
void visit(
|
| 292 |
+
int iter_idx,
|
| 293 |
+
int row_idx,
|
| 294 |
+
int column_idx,
|
| 295 |
+
int frag_idx,
|
| 296 |
+
AccumulatorFragment const &accum) {
|
| 297 |
+
|
| 298 |
+
using Mul = cutlass::multiplies<SoftmaxFragment>;
|
| 299 |
+
using Minus = cutlass::minus<SoftmaxFragment>;
|
| 300 |
+
using Exp = cutlass::fast_exp_op<SoftmaxFragment>;
|
| 301 |
+
|
| 302 |
+
Minus minus;
|
| 303 |
+
Exp exponential;
|
| 304 |
+
|
| 305 |
+
SoftmaxFragment result;
|
| 306 |
+
|
| 307 |
+
NumericArrayConverter<ElementSoftmaxCompute, ElementOutput, kElementsPerAccess> source_converter;
|
| 308 |
+
OutputVector &source_vector = reinterpret_cast<OutputVector *>(&fragment_C_)[frag_idx];
|
| 309 |
+
|
| 310 |
+
if (elementwise_.kScale == cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) {
|
| 311 |
+
result = source_converter(elementwise_(accum));
|
| 312 |
+
}else{
|
| 313 |
+
result = source_converter(elementwise_(accum, source_vector));
|
| 314 |
+
}
|
| 315 |
+
|
| 316 |
+
thread_offset_ =
|
| 317 |
+
iterator_D_.thread_start() +
|
| 318 |
+
OutputTileIterator::ThreadMap::iteration_offset(frag_idx);
|
| 319 |
+
|
| 320 |
+
bool column_guard = (thread_offset_.column() < extent_.column());
|
| 321 |
+
|
| 322 |
+
if (kUseMasking) {
|
| 323 |
+
int elements_in_boundary = extent_real_.column() - thread_offset_.column();
|
| 324 |
+
elements_in_boundary = (elements_in_boundary > kElementsPerAccess) ? kElementsPerAccess : elements_in_boundary;
|
| 325 |
+
elementwise_padding_(result, elements_in_boundary);
|
| 326 |
+
}
|
| 327 |
+
|
| 328 |
+
ElementSoftmaxCompute accum_max_prev = accum_max_;
|
| 329 |
+
|
| 330 |
+
// Compute the maximum within one row
|
| 331 |
+
if (!column_idx) {
|
| 332 |
+
// This is the first fragment in a new row
|
| 333 |
+
if (column_guard) {
|
| 334 |
+
accum_max_ = maximum_accumulator_(result);
|
| 335 |
+
}
|
| 336 |
+
}
|
| 337 |
+
else {
|
| 338 |
+
// This is an additional fragment in the same row
|
| 339 |
+
if (column_guard) {
|
| 340 |
+
accum_max_ = maximum_accumulator_(result, accum_max_);
|
| 341 |
+
}
|
| 342 |
+
}
|
| 343 |
+
|
| 344 |
+
// proactively compute max in warps
|
| 345 |
+
accum_max_ = warp_reduce_max_(accum_max_);
|
| 346 |
+
|
| 347 |
+
ElementSoftmaxCompute updater = fast_exp(accum_max_prev - accum_max_);
|
| 348 |
+
|
| 349 |
+
SoftmaxFragment intermediate = exponential(minus(result, accum_max_));
|
| 350 |
+
|
| 351 |
+
if (kHasMultiStepsInRow) {
|
| 352 |
+
if (!column_idx) {
|
| 353 |
+
accum_sum_ = (column_guard) ? \
|
| 354 |
+
sum_accumulator_(intermediate) : ElementSoftmaxCompute(0);
|
| 355 |
+
} else {
|
| 356 |
+
// Algorithm in $3.1, https://arxiv.org/pdf/2205.14135v1.pdf
|
| 357 |
+
// S* = S* x updater + sum_row(P'), where updater = exp(M* - M_row)
|
| 358 |
+
accum_sum_ = (column_guard) ? \
|
| 359 |
+
sum_accumulator_(intermediate, accum_sum_ * updater) : accum_sum_ * updater;
|
| 360 |
+
}
|
| 361 |
+
} else {
|
| 362 |
+
accum_sum_ = (column_guard) ? sum_accumulator_(intermediate, accum_sum_) : ElementSoftmaxCompute(0);
|
| 363 |
+
}
|
| 364 |
+
|
| 365 |
+
// Convert to the output
|
| 366 |
+
NumericArrayConverter<ElementOutput, ElementSoftmaxCompute, kElementsPerAccess> output_converter;
|
| 367 |
+
OutputVector &output = reinterpret_cast<OutputVector *>(&fragment_D_)[frag_idx];
|
| 368 |
+
output = output_converter(result);
|
| 369 |
+
}
|
| 370 |
+
|
| 371 |
+
/// Called at the end of a row
|
| 372 |
+
CUTLASS_DEVICE
|
| 373 |
+
void end_row(int row_idx) {
|
| 374 |
+
|
| 375 |
+
using ConvertSumOutput = cutlass::NumericConverter<ElementSum, ElementSoftmaxCompute>;
|
| 376 |
+
using ConvertNormOutput = cutlass::NumericConverter<ElementNorm, ElementSoftmaxCompute>;
|
| 377 |
+
|
| 378 |
+
ConvertSumOutput convert_sum_output;
|
| 379 |
+
ConvertNormOutput convert_norm_output;
|
| 380 |
+
|
| 381 |
+
// Compute accumulate sum only in the last step
|
| 382 |
+
accum_sum_ = warp_reduce_sum_(accum_sum_);
|
| 383 |
+
|
| 384 |
+
bool is_first_thread_in_tile = ((threadIdx.x % kThreadsPerRow) == 0);
|
| 385 |
+
bool row_guard = thread_offset_.row() < extent_.row();
|
| 386 |
+
bool is_write_thread = row_guard && is_first_thread_in_tile;
|
| 387 |
+
|
| 388 |
+
int block_batch = blockIdx.z;
|
| 389 |
+
|
| 390 |
+
ElementNorm *curr_ptr_max = ptr_Max_ + thread_offset_.row() + column_offset_ + block_batch * params_.batch_stride_Max;
|
| 391 |
+
ElementSum *curr_ptr_sum = ptr_Sum_ + thread_offset_.row() + column_offset_ + block_batch * params_.batch_stride_Sum;
|
| 392 |
+
|
| 393 |
+
arch::global_store<ElementNorm, sizeof(ElementNorm)>(
|
| 394 |
+
convert_norm_output(accum_max_),
|
| 395 |
+
(void *)curr_ptr_max,
|
| 396 |
+
is_write_thread);
|
| 397 |
+
|
| 398 |
+
arch::global_store<ElementSum, sizeof(ElementSum)>(
|
| 399 |
+
convert_sum_output(accum_sum_),
|
| 400 |
+
(void *)curr_ptr_sum,
|
| 401 |
+
is_write_thread);
|
| 402 |
+
|
| 403 |
+
// Clear accumulators for max and sum when finishing a whole row
|
| 404 |
+
clear_accum_();
|
| 405 |
+
|
| 406 |
+
}
|
| 407 |
+
|
| 408 |
+
/// Called after all accumulator elements have been visited
|
| 409 |
+
CUTLASS_DEVICE
|
| 410 |
+
void end_step(int step_idx) {
|
| 411 |
+
|
| 412 |
+
iterator_D_.store(fragment_D_);
|
| 413 |
+
++iterator_D_;
|
| 414 |
+
}
|
| 415 |
+
|
| 416 |
+
/// Called after all steps have been completed
|
| 417 |
+
CUTLASS_DEVICE
|
| 418 |
+
void end_epilogue() {
|
| 419 |
+
|
| 420 |
+
}
|
| 421 |
+
|
| 422 |
+
private:
|
| 423 |
+
|
| 424 |
+
CUTLASS_DEVICE
|
| 425 |
+
void elementwise_padding_(SoftmaxFragment &result, int elements_in_boundary) {
|
| 426 |
+
CUTLASS_PRAGMA_UNROLL
|
| 427 |
+
for (int i = 0; i < SoftmaxFragment::kElements; ++i) {
|
| 428 |
+
result[i] = (i < elements_in_boundary) ? result[i] : ElementSoftmaxCompute(-infinity_);
|
| 429 |
+
}
|
| 430 |
+
}
|
| 431 |
+
|
| 432 |
+
CUTLASS_DEVICE
|
| 433 |
+
ElementSoftmaxCompute warp_reduce_sum_(ElementSoftmaxCompute sum_) {
|
| 434 |
+
int half_thread_in_row = (kThreadsPerRow >> 1);
|
| 435 |
+
CUTLASS_PRAGMA_UNROLL
|
| 436 |
+
for (int i = half_thread_in_row; i > 0; i >>= 1) {
|
| 437 |
+
ElementSoftmaxCompute tmp = __shfl_xor_sync(0xFFFFFFFF, sum_, i);
|
| 438 |
+
sum_ += tmp;
|
| 439 |
+
}
|
| 440 |
+
return sum_;
|
| 441 |
+
}
|
| 442 |
+
|
| 443 |
+
CUTLASS_DEVICE
|
| 444 |
+
ElementSoftmaxCompute warp_reduce_max_(ElementSoftmaxCompute max_) {
|
| 445 |
+
int half_thread_in_row = (kThreadsPerRow >> 1);
|
| 446 |
+
CUTLASS_PRAGMA_UNROLL
|
| 447 |
+
for (int i = half_thread_in_row; i > 0; i >>= 1) {
|
| 448 |
+
ElementSoftmaxCompute tmp = __shfl_xor_sync(0xFFFFFFFF, max_, i);
|
| 449 |
+
max_ = fast_max(max_, tmp);
|
| 450 |
+
}
|
| 451 |
+
return max_;
|
| 452 |
+
}
|
| 453 |
+
|
| 454 |
+
CUTLASS_DEVICE
|
| 455 |
+
void clear_accum_() {
|
| 456 |
+
|
| 457 |
+
uint32_t float_max_bits = 0xff7fffff; // -FLT_MAX
|
| 458 |
+
float min_float = reinterpret_cast<float const &>(float_max_bits);
|
| 459 |
+
accum_max_ = ElementSoftmaxCompute(min_float);
|
| 460 |
+
accum_sum_ = ElementSoftmaxCompute(0);
|
| 461 |
+
}
|
| 462 |
+
|
| 463 |
+
CUTLASS_DEVICE
|
| 464 |
+
ElementSoftmaxCompute sum_accumulator_(SoftmaxFragment const &accum) {
|
| 465 |
+
ElementSoftmaxCompute sum_ = ElementSoftmaxCompute(0);
|
| 466 |
+
|
| 467 |
+
CUTLASS_PRAGMA_UNROLL
|
| 468 |
+
for (int i = 0; i < SoftmaxFragment::kElements; ++i) {
|
| 469 |
+
sum_ += ElementSoftmaxCompute(accum[i]);
|
| 470 |
+
}
|
| 471 |
+
|
| 472 |
+
return sum_;
|
| 473 |
+
}
|
| 474 |
+
|
| 475 |
+
CUTLASS_DEVICE
|
| 476 |
+
ElementSoftmaxCompute sum_accumulator_(SoftmaxFragment const &accum, ElementSoftmaxCompute sum_) {
|
| 477 |
+
// ElementSoftmaxCompute sum_ = ElementSoftmaxCompute(0);
|
| 478 |
+
|
| 479 |
+
CUTLASS_PRAGMA_UNROLL
|
| 480 |
+
for (int i = 0; i < SoftmaxFragment::kElements; ++i) {
|
| 481 |
+
sum_ += ElementSoftmaxCompute(accum[i]);
|
| 482 |
+
}
|
| 483 |
+
|
| 484 |
+
return sum_;
|
| 485 |
+
}
|
| 486 |
+
|
| 487 |
+
CUTLASS_DEVICE
|
| 488 |
+
ElementSoftmaxCompute maximum_accumulator_(SoftmaxFragment const &accum) {
|
| 489 |
+
ElementSoftmaxCompute max_ = accum[0];
|
| 490 |
+
|
| 491 |
+
CUTLASS_PRAGMA_UNROLL
|
| 492 |
+
for (int i = 1; i < SoftmaxFragment::kElements; ++i) {
|
| 493 |
+
max_ = fast_max(max_, ElementSoftmaxCompute(accum[i]));
|
| 494 |
+
}
|
| 495 |
+
|
| 496 |
+
return max_;
|
| 497 |
+
}
|
| 498 |
+
|
| 499 |
+
CUTLASS_DEVICE
|
| 500 |
+
ElementSoftmaxCompute maximum_accumulator_(SoftmaxFragment const &accum, ElementSoftmaxCompute max_) {
|
| 501 |
+
|
| 502 |
+
CUTLASS_PRAGMA_UNROLL
|
| 503 |
+
for (int i = 0; i < SoftmaxFragment::kElements; ++i) {
|
| 504 |
+
max_ = fast_max(max_, ElementSoftmaxCompute(accum[i]));
|
| 505 |
+
}
|
| 506 |
+
|
| 507 |
+
return max_;
|
| 508 |
+
}
|
| 509 |
+
};
|
| 510 |
+
|
| 511 |
+
} // namespace threadblock
|
| 512 |
+
} // namespace epilogue
|
| 513 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/epilogue_with_absmax.h
ADDED
|
@@ -0,0 +1,922 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
/*! \file
|
| 33 |
+
|
| 34 |
+
\brief Threadblock-level epilogue computing:
|
| 35 |
+
Aux = ((alpha * scale_a * scale_b) * accumulator) + ((beta * scale_c) * source) + bias
|
| 36 |
+
D = activation(Aux)
|
| 37 |
+
|
| 38 |
+
if Aux is fp8 type:
|
| 39 |
+
abs_max_output = max( abs(aux) | (for every aux in Aux))
|
| 40 |
+
Aux = scale_aux * Aux
|
| 41 |
+
endif
|
| 42 |
+
|
| 43 |
+
if D is fp8 type:
|
| 44 |
+
abs_max_output = max( abs(d) | (for every d in D))
|
| 45 |
+
D = scale_d * D
|
| 46 |
+
endif
|
| 47 |
+
|
| 48 |
+
Parameter Aux is optionally stored to global memory
|
| 49 |
+
*/
|
| 50 |
+
|
| 51 |
+
#pragma once
|
| 52 |
+
#include "cutlass/cutlass.h"
|
| 53 |
+
#include CUDA_STD_HEADER(cassert)
|
| 54 |
+
|
| 55 |
+
#if defined(__CUDACC_RTC__)
|
| 56 |
+
#include CUDA_STD_HEADER(utility)
|
| 57 |
+
#else
|
| 58 |
+
#include <utility>
|
| 59 |
+
#endif
|
| 60 |
+
|
| 61 |
+
#include "cutlass/array.h"
|
| 62 |
+
#include "cutlass/numeric_types.h"
|
| 63 |
+
#include "cutlass/numeric_conversion.h"
|
| 64 |
+
#include "cutlass/tensor_coord.h"
|
| 65 |
+
#include "cutlass/aligned_buffer.h"
|
| 66 |
+
#include "cutlass/functional.h"
|
| 67 |
+
#include "cutlass/fast_math.h"
|
| 68 |
+
#include "cutlass/layout/vector.h"
|
| 69 |
+
#include "cutlass/layout/tensor.h"
|
| 70 |
+
|
| 71 |
+
#include "cutlass/gemm/gemm.h"
|
| 72 |
+
|
| 73 |
+
#include "cutlass/transform/pitch_linear_thread_map.h"
|
| 74 |
+
#include "cutlass/transform/threadblock/regular_tile_iterator.h"
|
| 75 |
+
|
| 76 |
+
#include "cutlass/epilogue/threadblock/epilogue_base.h"
|
| 77 |
+
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
|
| 78 |
+
|
| 79 |
+
#include "cutlass/numeric_types.h"
|
| 80 |
+
|
| 81 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 82 |
+
|
| 83 |
+
namespace cutlass {
|
| 84 |
+
namespace epilogue {
|
| 85 |
+
namespace threadblock {
|
| 86 |
+
|
| 87 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 88 |
+
|
| 89 |
+
namespace detail {
|
| 90 |
+
|
| 91 |
+
/// Helper class for keeping track of absolute maximums and performing scaling
|
| 92 |
+
template <
|
| 93 |
+
typename Iterator, // Iterator type used for storing the data for which absolute maximum and scaling
|
| 94 |
+
// will be computed. This type is used for predicating absolute maximum calculations.
|
| 95 |
+
typename Fragment, // Type of input to be computed on
|
| 96 |
+
bool ScalingAndAmaxNeeded // Whether to perform absolute maximum and scaling operations
|
| 97 |
+
>
|
| 98 |
+
struct ScalingAndAmaxHelper;
|
| 99 |
+
|
| 100 |
+
/// Partial specialization that does not perform scaling or calculate an absolute maximum
|
| 101 |
+
template <typename Iterator, typename Fragment>
|
| 102 |
+
struct ScalingAndAmaxHelper<Iterator, Fragment, false> {
|
| 103 |
+
using Element = typename Fragment::Element;
|
| 104 |
+
|
| 105 |
+
CUTLASS_HOST_DEVICE
|
| 106 |
+
ScalingAndAmaxHelper(Element scale) { }
|
| 107 |
+
|
| 108 |
+
CUTLASS_DEVICE
|
| 109 |
+
Fragment operator()(const Iterator& iterator, const Fragment& inp) {
|
| 110 |
+
return inp;
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
CUTLASS_HOST_DEVICE
|
| 114 |
+
Element get_abs_max() const {
|
| 115 |
+
return Element(0.);
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
CUTLASS_HOST_DEVICE
|
| 119 |
+
void set_scaling_factor(Element scale_) { }
|
| 120 |
+
};
|
| 121 |
+
|
| 122 |
+
/// Partial specialization that keeps track of an absolute maximum value of inputs seen
|
| 123 |
+
/// and scales inputs
|
| 124 |
+
template <typename Iterator, typename Fragment>
|
| 125 |
+
struct ScalingAndAmaxHelper<Iterator, Fragment, true> {
|
| 126 |
+
using Element = typename Fragment::Element;
|
| 127 |
+
using AccessType = typename Iterator::AccessType;
|
| 128 |
+
using ThreadMap = typename Iterator::ThreadMap;
|
| 129 |
+
|
| 130 |
+
Element abs_max;
|
| 131 |
+
Element scale;
|
| 132 |
+
|
| 133 |
+
// Operators
|
| 134 |
+
maximum_with_nan_propogation<Element> max_op;
|
| 135 |
+
absolute_value_op<Element> abs_op;
|
| 136 |
+
multiplies<Fragment> multiply;
|
| 137 |
+
|
| 138 |
+
CUTLASS_HOST_DEVICE
|
| 139 |
+
ScalingAndAmaxHelper(Element scale_) : abs_max(0.), scale(scale_) { }
|
| 140 |
+
|
| 141 |
+
// Compute the absolute maximum value between `abs_max` and the entries
|
| 142 |
+
// of `frag` for predicated-on entries of `iterator`. Return a scaled
|
| 143 |
+
// version of `inp`.
|
| 144 |
+
CUTLASS_DEVICE
|
| 145 |
+
Fragment operator()(const Iterator& iterator, const Fragment& frag) {
|
| 146 |
+
using PredicateGroup = Array<Element, Iterator::ThreadMap::kElementsPerAccess>;
|
| 147 |
+
PredicateGroup const *frag_ptr = reinterpret_cast<PredicateGroup const *>(&frag);
|
| 148 |
+
|
| 149 |
+
typename Iterator::Mask mask;
|
| 150 |
+
iterator.get_mask(mask);
|
| 151 |
+
|
| 152 |
+
CUTLASS_PRAGMA_UNROLL
|
| 153 |
+
for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) {
|
| 154 |
+
|
| 155 |
+
CUTLASS_PRAGMA_UNROLL
|
| 156 |
+
for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
|
| 157 |
+
|
| 158 |
+
CUTLASS_PRAGMA_UNROLL
|
| 159 |
+
for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
|
| 160 |
+
int frag_row_idx =
|
| 161 |
+
(row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster));
|
| 162 |
+
|
| 163 |
+
int row_offset = row * ThreadMap::Delta::kRow
|
| 164 |
+
+ group * ThreadMap::Delta::kGroup
|
| 165 |
+
+ cluster * ThreadMap::Delta::kCluster;
|
| 166 |
+
|
| 167 |
+
bool row_guard = ((row_offset + iterator.thread_start_row()) < iterator.extent_row());
|
| 168 |
+
|
| 169 |
+
CUTLASS_PRAGMA_UNROLL
|
| 170 |
+
for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) {
|
| 171 |
+
bool guard = row_guard && mask.predicates[column];
|
| 172 |
+
|
| 173 |
+
if (guard) {
|
| 174 |
+
int access_idx = frag_row_idx * ThreadMap::Iterations::kColumn + column;
|
| 175 |
+
CUTLASS_PRAGMA_UNROLL
|
| 176 |
+
for (int i = 0; i < PredicateGroup::kElements; ++i) {
|
| 177 |
+
abs_max = max_op(abs_max, abs_op(frag_ptr[access_idx][i]));
|
| 178 |
+
}
|
| 179 |
+
}
|
| 180 |
+
}
|
| 181 |
+
}
|
| 182 |
+
}
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
+
// Perform scaling
|
| 186 |
+
return multiply(scale, frag);
|
| 187 |
+
}
|
| 188 |
+
|
| 189 |
+
CUTLASS_HOST_DEVICE
|
| 190 |
+
Element get_abs_max() const {
|
| 191 |
+
return abs_max;
|
| 192 |
+
}
|
| 193 |
+
|
| 194 |
+
CUTLASS_HOST_DEVICE
|
| 195 |
+
void set_scaling_factor(Element scale_) {
|
| 196 |
+
scale = scale_;
|
| 197 |
+
}
|
| 198 |
+
};
|
| 199 |
+
|
| 200 |
+
} // namespace detail
|
| 201 |
+
|
| 202 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 203 |
+
|
| 204 |
+
template <
|
| 205 |
+
typename Shape_, ///< Shape of threadblock tile (concept: GemmShape)
|
| 206 |
+
typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp)
|
| 207 |
+
int PartitionsK, ///< Number of partitions of the K dimension
|
| 208 |
+
typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors
|
| 209 |
+
typename AuxOutputTileIterator_, ///< Tile iterator writing auxiliary output tensors
|
| 210 |
+
typename ElementVector_, ///< Data type of bias vector
|
| 211 |
+
typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators
|
| 212 |
+
typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM
|
| 213 |
+
typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM
|
| 214 |
+
typename OutputOp_, ///< Output operator
|
| 215 |
+
typename Padding_, ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape)
|
| 216 |
+
int FragmentsPerPartition = 1, ///< Used to coarsen the epilogue granularity
|
| 217 |
+
int IterationsUnroll = ///< Used to reduce binary size when epilogue op is large
|
| 218 |
+
(!IsEpilogueFunctorHeavy<OutputOp_>::value)
|
| 219 |
+
>
|
| 220 |
+
class EpilogueWithAbsMax :
|
| 221 |
+
public EpilogueBase<
|
| 222 |
+
Shape_,
|
| 223 |
+
typename WarpMmaOperator_::Shape,
|
| 224 |
+
PartitionsK,
|
| 225 |
+
AccumulatorFragmentIterator_,
|
| 226 |
+
WarpTileIterator_,
|
| 227 |
+
Padding_,
|
| 228 |
+
FragmentsPerPartition> {
|
| 229 |
+
|
| 230 |
+
public:
|
| 231 |
+
|
| 232 |
+
using Base = EpilogueBase<
|
| 233 |
+
Shape_,
|
| 234 |
+
typename WarpMmaOperator_::Shape,
|
| 235 |
+
PartitionsK,
|
| 236 |
+
AccumulatorFragmentIterator_,
|
| 237 |
+
WarpTileIterator_,
|
| 238 |
+
Padding_,
|
| 239 |
+
FragmentsPerPartition>;
|
| 240 |
+
|
| 241 |
+
static bool const kIsSingleSource = true;
|
| 242 |
+
using Shape = Shape_;
|
| 243 |
+
using WarpMmaOperator = WarpMmaOperator_;
|
| 244 |
+
static int const kPartitionsK = PartitionsK;
|
| 245 |
+
using OutputTileIterator = OutputTileIterator_;
|
| 246 |
+
using AuxOutputTileIterator = AuxOutputTileIterator_;
|
| 247 |
+
using ElementVector = ElementVector_;
|
| 248 |
+
using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
|
| 249 |
+
using WarpTileIterator = WarpTileIterator_;
|
| 250 |
+
using SharedLoadIterator = SharedLoadIterator_;
|
| 251 |
+
using OutputOp = OutputOp_;
|
| 252 |
+
using Padding = Padding_;
|
| 253 |
+
|
| 254 |
+
using Layout = layout::RowMajor;
|
| 255 |
+
using LongIndex = typename Layout::LongIndex;
|
| 256 |
+
|
| 257 |
+
/// The complete warp-level accumulator tile
|
| 258 |
+
using AccumulatorTile = typename Base::AccumulatorTile;
|
| 259 |
+
|
| 260 |
+
/// Accumulator element
|
| 261 |
+
using ElementAccumulator = typename WarpTileIterator::Element;
|
| 262 |
+
|
| 263 |
+
/// Data type used for absolute maximum value
|
| 264 |
+
using ElementAbsmax = typename OutputOp::ElementAbsmax;
|
| 265 |
+
|
| 266 |
+
/// Compute data type produced by the output op
|
| 267 |
+
using ElementCompute = typename OutputOp::ElementCompute;
|
| 268 |
+
|
| 269 |
+
/// Compute fragment
|
| 270 |
+
using FragmentCompute = Array<ElementCompute, OutputTileIterator::Fragment::kElements>;
|
| 271 |
+
|
| 272 |
+
/// Helpers for (optionally) computing absolute maximums and scaling output and auxiliary output
|
| 273 |
+
using OutputScaler = detail::ScalingAndAmaxHelper<OutputTileIterator,
|
| 274 |
+
FragmentCompute,
|
| 275 |
+
OutputOp::kIsScalingAndAmaxOutputNeeded>;
|
| 276 |
+
|
| 277 |
+
using AuxOutputScaler = detail::ScalingAndAmaxHelper<AuxOutputTileIterator,
|
| 278 |
+
FragmentCompute,
|
| 279 |
+
OutputOp::kIsScalingAndAmaxAuxOutputNeeded>;
|
| 280 |
+
|
| 281 |
+
/// Thread map used by output tile iterators
|
| 282 |
+
using ThreadMap = typename OutputTileIterator::ThreadMap;
|
| 283 |
+
|
| 284 |
+
/// Fragment object used to store the broadcast values
|
| 285 |
+
using BroadcastFragment = Array<
|
| 286 |
+
ElementCompute,
|
| 287 |
+
ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess>;
|
| 288 |
+
|
| 289 |
+
/// Output element
|
| 290 |
+
using ElementOutput = typename OutputTileIterator::Element;
|
| 291 |
+
|
| 292 |
+
/// Data type of auxiliary output
|
| 293 |
+
using ElementAuxOutput = typename AuxOutputTileIterator::Element;
|
| 294 |
+
|
| 295 |
+
/// Output access size
|
| 296 |
+
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
|
| 297 |
+
|
| 298 |
+
/// Tensor reference to destination tensor
|
| 299 |
+
using TensorRef = typename OutputTileIterator::TensorRef;
|
| 300 |
+
|
| 301 |
+
/// Tensor reference to sync tensor
|
| 302 |
+
using SyncTensorRef = typename cutlass::TensorRef<int, cutlass::layout::PackedVectorLayout>;
|
| 303 |
+
|
| 304 |
+
/// Const tensor reference to source tensor
|
| 305 |
+
using ConstTensorRef = typename OutputTileIterator::ConstTensorRef;
|
| 306 |
+
|
| 307 |
+
/// Array type used to output
|
| 308 |
+
using OutputAccessType = Array<
|
| 309 |
+
typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
|
| 310 |
+
|
| 311 |
+
/// Array type used by output functor
|
| 312 |
+
using AccumulatorAccessType = Array<typename WarpTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
|
| 313 |
+
|
| 314 |
+
/// Array type used by output functor
|
| 315 |
+
using ComputeAccessType = Array<ElementCompute, OutputTileIterator::kElementsPerAccess>;
|
| 316 |
+
|
| 317 |
+
/// Auxiliary output access type
|
| 318 |
+
using AuxAccessType = Array<ElementAuxOutput, OutputTileIterator::kElementsPerAccess>;
|
| 319 |
+
|
| 320 |
+
/// Number of warps
|
| 321 |
+
using WarpCount = typename Base::WarpCount;
|
| 322 |
+
|
| 323 |
+
/// Shared memory allocation from epilogue base class
|
| 324 |
+
using BaseSharedStorage = typename Base::SharedStorage;
|
| 325 |
+
|
| 326 |
+
static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 ? Base::kFragmentsPerIteration : kPartitionsK;
|
| 327 |
+
static int constexpr kSmemPointerOffset = Base::SharedStorage::StorageShape::kCount / kSmemTiles;
|
| 328 |
+
|
| 329 |
+
/// Used for the broadcast
|
| 330 |
+
struct BroadcastDetail {
|
| 331 |
+
|
| 332 |
+
/// Number of threads per warp
|
| 333 |
+
static int const kWarpSize = 32;
|
| 334 |
+
|
| 335 |
+
static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
|
| 336 |
+
|
| 337 |
+
/// Number of distinct scalar column indices handled by each thread
|
| 338 |
+
static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess;
|
| 339 |
+
|
| 340 |
+
/// Number of distinct scalar row indices handled by each thread
|
| 341 |
+
static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn;
|
| 342 |
+
|
| 343 |
+
/// Number of threads per threadblock
|
| 344 |
+
static int const kThreadCount = kWarpSize * WarpCount::kCount;
|
| 345 |
+
|
| 346 |
+
/// Number of distinct threads per row of output tile
|
| 347 |
+
static int const kThreadsPerRow = (Shape::kN / kColumnsPerThread);
|
| 348 |
+
|
| 349 |
+
/// Number of distinct threads which must be reduced during the final reduction phase within the threadblock.
|
| 350 |
+
static int const kThreadRows = kThreadCount / kThreadsPerRow;
|
| 351 |
+
|
| 352 |
+
/// I'm not sure what I meant here.
|
| 353 |
+
static int const kThreadAccessesPerRow = const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount);
|
| 354 |
+
|
| 355 |
+
/// Shape of the shared memory allocation for the epilogue
|
| 356 |
+
using StorageShape = MatrixShape<
|
| 357 |
+
kThreadRows,
|
| 358 |
+
Shape::kN
|
| 359 |
+
>;
|
| 360 |
+
|
| 361 |
+
/// Debug printing
|
| 362 |
+
CUTLASS_DEVICE
|
| 363 |
+
static void print() {
|
| 364 |
+
#if 0
|
| 365 |
+
printf("BroadcastDetail {\n");
|
| 366 |
+
printf(
|
| 367 |
+
" kColumnsPerThread: %d\nkRowsPerThread: %d\n,kThreadCount: %d\nkThreadsPerRow: %d\n"
|
| 368 |
+
"kThreadRows: %d\nThreadAccessesPerRow: %d\nStorageShape: %d x %d (count: %d)\n",
|
| 369 |
+
kColumnsPerThread,
|
| 370 |
+
kRowsPerThread,
|
| 371 |
+
kThreadCount,
|
| 372 |
+
kThreadsPerRow,
|
| 373 |
+
kThreadRows,
|
| 374 |
+
kThreadAccessesPerRow,
|
| 375 |
+
StorageShape::kRow,
|
| 376 |
+
StorageShape::kColumn,
|
| 377 |
+
StorageShape::kCount
|
| 378 |
+
);
|
| 379 |
+
printf("};\n");
|
| 380 |
+
#endif
|
| 381 |
+
}
|
| 382 |
+
};
|
| 383 |
+
|
| 384 |
+
/// Shared storage structure (shadows base) with additional SMEM buffer for reduction
|
| 385 |
+
struct SharedStorage {
|
| 386 |
+
union {
|
| 387 |
+
BaseSharedStorage base;
|
| 388 |
+
};
|
| 389 |
+
|
| 390 |
+
CUTLASS_HOST_DEVICE
|
| 391 |
+
SharedStorage() { }
|
| 392 |
+
};
|
| 393 |
+
|
| 394 |
+
public:
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements,
|
| 398 |
+
"Mismatch between shared load iterator and output tile iterator.");
|
| 399 |
+
|
| 400 |
+
static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero.");
|
| 401 |
+
|
| 402 |
+
static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess),
|
| 403 |
+
"Divisibility");
|
| 404 |
+
|
| 405 |
+
private:
|
| 406 |
+
|
| 407 |
+
/// Loads fragment from shared memory aligned with output tensor
|
| 408 |
+
SharedLoadIterator shared_load_iterator_;
|
| 409 |
+
|
| 410 |
+
/// Thread index within the threadblock
|
| 411 |
+
int thread_idx_;
|
| 412 |
+
|
| 413 |
+
public:
|
| 414 |
+
|
| 415 |
+
/// Constructor
|
| 416 |
+
CUTLASS_DEVICE
|
| 417 |
+
EpilogueWithAbsMax(
|
| 418 |
+
SharedStorage &shared_storage, ///< Shared storage object
|
| 419 |
+
int thread_idx, ///< ID of a thread within the threadblock
|
| 420 |
+
int warp_idx, ///< ID of warp within threadblock
|
| 421 |
+
int lane_idx ///< Id of thread within warp
|
| 422 |
+
):
|
| 423 |
+
Base(shared_storage.base, thread_idx, warp_idx, lane_idx),
|
| 424 |
+
shared_load_iterator_(shared_storage.base.reference(), thread_idx),
|
| 425 |
+
thread_idx_(thread_idx)
|
| 426 |
+
{
|
| 427 |
+
|
| 428 |
+
}
|
| 429 |
+
|
| 430 |
+
/// Streams the result to global memory
|
| 431 |
+
CUTLASS_DEVICE
|
| 432 |
+
void operator()(
|
| 433 |
+
OutputOp &output_op, ///< Output operator
|
| 434 |
+
ElementVector const * broadcast_ptr, ///< Broadcast vector
|
| 435 |
+
OutputTileIterator destination_iterator, ///< Tile iterator for destination
|
| 436 |
+
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
|
| 437 |
+
OutputTileIterator source_iterator, ///< Tile iterator for source accumulator matrix
|
| 438 |
+
AuxOutputTileIterator aux_iterator, ///< Tile iterator for destination auxiliary output
|
| 439 |
+
MatrixCoord const &problem_size = ///< Problem size needed to guard against out-of-bounds accesses
|
| 440 |
+
MatrixCoord(Shape::kM, Shape::kN),
|
| 441 |
+
MatrixCoord const &threadblock_offset = ///< Threadblock's initial offset within the problem size space
|
| 442 |
+
MatrixCoord()) {
|
| 443 |
+
|
| 444 |
+
BroadcastFragment broadcast_fragment;
|
| 445 |
+
|
| 446 |
+
load_broadcast_fragment_(broadcast_fragment, broadcast_ptr, problem_size, threadblock_offset);
|
| 447 |
+
|
| 448 |
+
OutputScaler output_scaler(output_op.get_scale_d());
|
| 449 |
+
|
| 450 |
+
AuxOutputScaler aux_scaler(output_op.get_scale_aux());
|
| 451 |
+
|
| 452 |
+
if (!output_op.is_source_needed()) {
|
| 453 |
+
compute_source_not_needed_(
|
| 454 |
+
output_op,
|
| 455 |
+
broadcast_fragment,
|
| 456 |
+
destination_iterator,
|
| 457 |
+
accumulators,
|
| 458 |
+
aux_iterator,
|
| 459 |
+
output_scaler,
|
| 460 |
+
aux_scaler);
|
| 461 |
+
}
|
| 462 |
+
else {
|
| 463 |
+
compute_source_needed_(
|
| 464 |
+
output_op,
|
| 465 |
+
broadcast_fragment,
|
| 466 |
+
destination_iterator,
|
| 467 |
+
accumulators,
|
| 468 |
+
source_iterator,
|
| 469 |
+
aux_iterator,
|
| 470 |
+
output_scaler,
|
| 471 |
+
aux_scaler);
|
| 472 |
+
}
|
| 473 |
+
|
| 474 |
+
// Store the absolute maximum values of the output and auxiliar tensors, if needed.
|
| 475 |
+
if (output_op.get_ptr_output_abs_max() != nullptr) {
|
| 476 |
+
ElementAbsmax local_abs_max =
|
| 477 |
+
NumericConverter<ElementAbsmax, ElementCompute, OutputOp::kRound>{}(output_scaler.get_abs_max());
|
| 478 |
+
atomic_maximum<ElementAbsmax>{}(
|
| 479 |
+
output_op.get_ptr_output_abs_max(), local_abs_max);
|
| 480 |
+
}
|
| 481 |
+
|
| 482 |
+
if (output_op.get_ptr_aux_output_abs_max() != nullptr) {
|
| 483 |
+
ElementAbsmax local_abs_max =
|
| 484 |
+
NumericConverter<ElementAbsmax, ElementCompute, OutputOp::kRound>{}(aux_scaler.get_abs_max());
|
| 485 |
+
atomic_maximum<ElementAbsmax>{}(
|
| 486 |
+
output_op.get_ptr_aux_output_abs_max(), local_abs_max);
|
| 487 |
+
}
|
| 488 |
+
}
|
| 489 |
+
|
| 490 |
+
private:
|
| 491 |
+
|
| 492 |
+
CUTLASS_DEVICE
|
| 493 |
+
void load_broadcast_fragment_(
|
| 494 |
+
BroadcastFragment & broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns
|
| 495 |
+
ElementVector const * broadcast_ptr, ///< Broadcast vector
|
| 496 |
+
MatrixCoord const &problem_size, ///< Problem size needed to guard against out-of-bounds accesses
|
| 497 |
+
MatrixCoord const &threadblock_offset ///< Threadblock's initial offset within the problem size space
|
| 498 |
+
) {
|
| 499 |
+
|
| 500 |
+
broadcast_fragment.clear();
|
| 501 |
+
|
| 502 |
+
// If no pointer is supplied, set with all zeros and avoid memory accesses
|
| 503 |
+
if (!broadcast_ptr) {
|
| 504 |
+
return;
|
| 505 |
+
}
|
| 506 |
+
|
| 507 |
+
int thread_initial_column = ThreadMap::initial_offset(thread_idx_).column();
|
| 508 |
+
|
| 509 |
+
int thread_column_idx = threadblock_offset.column() + thread_initial_column;
|
| 510 |
+
broadcast_ptr += thread_initial_column;
|
| 511 |
+
|
| 512 |
+
NumericArrayConverter<ElementCompute, ElementVector, BroadcastDetail::kElementsPerAccess> converter;
|
| 513 |
+
using AccessType = AlignedArray<ElementVector, BroadcastDetail::kElementsPerAccess>;
|
| 514 |
+
using ComputeFragmentType = Array<ElementCompute, BroadcastDetail::kElementsPerAccess>;
|
| 515 |
+
|
| 516 |
+
ComputeFragmentType *frag_ptr = reinterpret_cast<ComputeFragmentType *>(&broadcast_fragment);
|
| 517 |
+
|
| 518 |
+
CUTLASS_PRAGMA_UNROLL
|
| 519 |
+
for (int j = 0; j < ThreadMap::Iterations::kColumn; ++j) {
|
| 520 |
+
|
| 521 |
+
AccessType loaded;
|
| 522 |
+
|
| 523 |
+
loaded.clear();
|
| 524 |
+
|
| 525 |
+
if (thread_column_idx < problem_size.column()) {
|
| 526 |
+
loaded = *reinterpret_cast<AccessType const *>(broadcast_ptr);
|
| 527 |
+
}
|
| 528 |
+
|
| 529 |
+
ComputeFragmentType cvt = converter(loaded);
|
| 530 |
+
frag_ptr[j] = cvt;
|
| 531 |
+
|
| 532 |
+
thread_column_idx += ThreadMap::Delta::kColumn;
|
| 533 |
+
broadcast_ptr += ThreadMap::Delta::kColumn;
|
| 534 |
+
}
|
| 535 |
+
}
|
| 536 |
+
|
| 537 |
+
template <class Seq>
|
| 538 |
+
struct acc2smem_source_not_needed;
|
| 539 |
+
|
| 540 |
+
template <size_t... Seq>
|
| 541 |
+
struct acc2smem_source_not_needed<cutlass::index_sequence<Seq...>> {
|
| 542 |
+
template <int Advance>
|
| 543 |
+
CUTLASS_DEVICE static void helper(AccumulatorFragmentIterator accum_fragment_iterator,
|
| 544 |
+
WarpTileIterator &warp_tile_iterator) {
|
| 545 |
+
CUTLASS_PRAGMA_UNROLL
|
| 546 |
+
for (int i = 0; i < Advance; i++) {
|
| 547 |
+
++accum_fragment_iterator;
|
| 548 |
+
}
|
| 549 |
+
|
| 550 |
+
CUTLASS_PRAGMA_UNROLL
|
| 551 |
+
for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {
|
| 552 |
+
typename AccumulatorFragmentIterator::Fragment accum_fragment;
|
| 553 |
+
|
| 554 |
+
accum_fragment_iterator.load(accum_fragment);
|
| 555 |
+
++accum_fragment_iterator;
|
| 556 |
+
|
| 557 |
+
warp_tile_iterator.store(accum_fragment);
|
| 558 |
+
if (p < Base::kFragmentsPerIteration - 1) {
|
| 559 |
+
warp_tile_iterator.add_pointer_offset(kSmemPointerOffset);
|
| 560 |
+
}
|
| 561 |
+
}
|
| 562 |
+
|
| 563 |
+
if (Base::kFragmentsPerIteration > 1) {
|
| 564 |
+
warp_tile_iterator.add_pointer_offset(kSmemPointerOffset *
|
| 565 |
+
(1 - Base::kFragmentsPerIteration));
|
| 566 |
+
}
|
| 567 |
+
}
|
| 568 |
+
|
| 569 |
+
CUTLASS_DEVICE
|
| 570 |
+
static void push(size_t pos,
|
| 571 |
+
AccumulatorFragmentIterator const &iterator_begin,
|
| 572 |
+
WarpTileIterator &warp_tile_iterator) {
|
| 573 |
+
int dummy[] = {
|
| 574 |
+
(pos == (Seq * Base::kFragmentsPerIteration)) &&
|
| 575 |
+
(helper<Seq * Base::kFragmentsPerIteration>(iterator_begin, warp_tile_iterator), 0)...};
|
| 576 |
+
|
| 577 |
+
CUTLASS_UNUSED(dummy[0]);
|
| 578 |
+
}
|
| 579 |
+
};
|
| 580 |
+
|
| 581 |
+
/// Streams the result to global memory
|
| 582 |
+
CUTLASS_DEVICE
|
| 583 |
+
void compute_source_not_needed_(
|
| 584 |
+
OutputOp &output_op, ///< Output operator
|
| 585 |
+
BroadcastFragment const &broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns
|
| 586 |
+
OutputTileIterator destination_iterator, ///< Tile iterator for destination
|
| 587 |
+
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
|
| 588 |
+
AuxOutputTileIterator aux_iterator, ///< Tile iterator for destination auxiliary output
|
| 589 |
+
OutputScaler& output_scaler, ///< Helper for (optionally) computing the absolute maximum and scaling output
|
| 590 |
+
AuxOutputScaler& aux_scaler ///< Helper for (optionally) computing the absolute maximum and scaling the auxiliary output
|
| 591 |
+
) {
|
| 592 |
+
|
| 593 |
+
//
|
| 594 |
+
// Iterator over warp-level accumulator fragment
|
| 595 |
+
//
|
| 596 |
+
|
| 597 |
+
AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
|
| 598 |
+
|
| 599 |
+
//
|
| 600 |
+
// Iterate over accumulator tile
|
| 601 |
+
//
|
| 602 |
+
|
| 603 |
+
// CUTLASS_PRAGMA_UNROLL
|
| 604 |
+
#pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations / Base::kFragmentsPerIteration : 1)
|
| 605 |
+
for (int iter = 0; iter < OutputTileIterator::kIterations; iter += Base::kFragmentsPerIteration) {
|
| 606 |
+
|
| 607 |
+
//
|
| 608 |
+
// Convert and store fragment
|
| 609 |
+
//
|
| 610 |
+
|
| 611 |
+
|
| 612 |
+
__syncthreads();
|
| 613 |
+
|
| 614 |
+
acc2smem_source_not_needed<
|
| 615 |
+
cutlass::make_index_sequence<OutputTileIterator::kIterations /
|
| 616 |
+
Base::kFragmentsPerIteration>>::push(iter,
|
| 617 |
+
accum_fragment_iterator,
|
| 618 |
+
this->warp_tile_iterator_);
|
| 619 |
+
|
| 620 |
+
__syncthreads();
|
| 621 |
+
|
| 622 |
+
//
|
| 623 |
+
// Load fragments from shared memory
|
| 624 |
+
//
|
| 625 |
+
|
| 626 |
+
CUTLASS_PRAGMA_UNROLL
|
| 627 |
+
for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {
|
| 628 |
+
|
| 629 |
+
|
| 630 |
+
typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];
|
| 631 |
+
|
| 632 |
+
shared_load_iterator_.load(aligned_accum_fragment[0]);
|
| 633 |
+
|
| 634 |
+
if (p < Base::kFragmentsPerIteration - 1) {
|
| 635 |
+
shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
|
| 636 |
+
}
|
| 637 |
+
else if (kPartitionsK > 1) {
|
| 638 |
+
|
| 639 |
+
plus <typename SharedLoadIterator::Fragment> add_fragments;
|
| 640 |
+
|
| 641 |
+
CUTLASS_PRAGMA_UNROLL
|
| 642 |
+
for ( int i = 1; i < kPartitionsK; ++i) {
|
| 643 |
+
shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
|
| 644 |
+
shared_load_iterator_.load(aligned_accum_fragment[i]);
|
| 645 |
+
aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
|
| 646 |
+
}
|
| 647 |
+
|
| 648 |
+
shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset);
|
| 649 |
+
}
|
| 650 |
+
|
| 651 |
+
//
|
| 652 |
+
// Apply output operation
|
| 653 |
+
//
|
| 654 |
+
|
| 655 |
+
FragmentCompute frag_Z_compute;
|
| 656 |
+
FragmentCompute frag_Aux_compute;
|
| 657 |
+
|
| 658 |
+
apply_output_operator_source_not_needed_(
|
| 659 |
+
frag_Z_compute,
|
| 660 |
+
frag_Aux_compute,
|
| 661 |
+
output_op,
|
| 662 |
+
aligned_accum_fragment[0],
|
| 663 |
+
broadcast_fragment);
|
| 664 |
+
|
| 665 |
+
//
|
| 666 |
+
// Conditionally store fragments
|
| 667 |
+
//
|
| 668 |
+
|
| 669 |
+
// (Optionally) compute the absolute maximum of frag_Z and scale frag_Z
|
| 670 |
+
frag_Z_compute = output_scaler(destination_iterator, frag_Z_compute);
|
| 671 |
+
NumericArrayConverter<typename OutputTileIterator::Fragment::Element, ElementCompute,
|
| 672 |
+
OutputTileIterator::Fragment::kElements> cvt_to_dst;
|
| 673 |
+
typename OutputTileIterator::Fragment frag_Z = cvt_to_dst(frag_Z_compute);
|
| 674 |
+
|
| 675 |
+
// Always store the output
|
| 676 |
+
destination_iterator.store(frag_Z);
|
| 677 |
+
++destination_iterator;
|
| 678 |
+
|
| 679 |
+
// Only store the auxiliary output if scaling and absolute-maximum calculation were needed
|
| 680 |
+
if (OutputOp::kIsScalingAndAmaxAuxOutputNeeded) {
|
| 681 |
+
frag_Aux_compute = aux_scaler(aux_iterator, frag_Aux_compute);
|
| 682 |
+
|
| 683 |
+
NumericArrayConverter<typename AuxOutputTileIterator::Fragment::Element, ElementCompute,
|
| 684 |
+
AuxOutputTileIterator::Fragment::kElements> cvt_to_aux;
|
| 685 |
+
typename AuxOutputTileIterator::Fragment frag_Aux = cvt_to_aux(frag_Aux_compute);
|
| 686 |
+
aux_iterator.store(frag_Aux);
|
| 687 |
+
++aux_iterator;
|
| 688 |
+
}
|
| 689 |
+
}
|
| 690 |
+
|
| 691 |
+
if (Base::kFragmentsPerIteration > 1) {
|
| 692 |
+
shared_load_iterator_.add_pointer_offset(kSmemPointerOffset * (1 - Base::kFragmentsPerIteration));
|
| 693 |
+
}
|
| 694 |
+
}
|
| 695 |
+
}
|
| 696 |
+
|
| 697 |
+
|
| 698 |
+
template<class Seq>
|
| 699 |
+
struct acc2smem_source_needed;
|
| 700 |
+
|
| 701 |
+
template <size_t... Seq>
|
| 702 |
+
struct acc2smem_source_needed<cutlass::index_sequence<Seq...>> {
|
| 703 |
+
template<int Advance>
|
| 704 |
+
CUTLASS_DEVICE
|
| 705 |
+
static void helper(AccumulatorFragmentIterator accum_fragment_iterator,
|
| 706 |
+
WarpTileIterator &warp_tile_iterator) {
|
| 707 |
+
CUTLASS_PRAGMA_UNROLL
|
| 708 |
+
for (int i = 0; i < Advance; i++) {
|
| 709 |
+
++accum_fragment_iterator;
|
| 710 |
+
}
|
| 711 |
+
|
| 712 |
+
typename AccumulatorFragmentIterator::Fragment accum_fragment;
|
| 713 |
+
accum_fragment_iterator.load(accum_fragment);
|
| 714 |
+
warp_tile_iterator.store(accum_fragment);
|
| 715 |
+
}
|
| 716 |
+
|
| 717 |
+
CUTLASS_DEVICE
|
| 718 |
+
static void push(size_t pos,
|
| 719 |
+
AccumulatorFragmentIterator const &iterator_begin,
|
| 720 |
+
WarpTileIterator &warp_tile_iterator) {
|
| 721 |
+
int dummy[] = {(pos == Seq) && (helper<Seq>(iterator_begin, warp_tile_iterator), 0)...};
|
| 722 |
+
}
|
| 723 |
+
};
|
| 724 |
+
|
| 725 |
+
|
| 726 |
+
/// Streams the result to global memory
|
| 727 |
+
CUTLASS_DEVICE
|
| 728 |
+
void compute_source_needed_(
|
| 729 |
+
OutputOp &output_op, ///< Output operator
|
| 730 |
+
BroadcastFragment const &broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns
|
| 731 |
+
OutputTileIterator destination_iterator, ///< Tile iterator for destination
|
| 732 |
+
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
|
| 733 |
+
OutputTileIterator source_iterator, ///< Tile iterator for source accumulator matrix
|
| 734 |
+
AuxOutputTileIterator aux_iterator, ///< Tile iterator for destination auxiliary output
|
| 735 |
+
OutputScaler& output_scaler, ///< Helper for (optionally) computing the absolute maximum and scaling output
|
| 736 |
+
AuxOutputScaler& aux_scaler ///< Helper for (optionally) computing the absolute maximum and scaling the auxiliary output
|
| 737 |
+
) {
|
| 738 |
+
|
| 739 |
+
typename OutputTileIterator::Fragment source_fragment;
|
| 740 |
+
source_fragment.clear();
|
| 741 |
+
|
| 742 |
+
//
|
| 743 |
+
// Iterator over warp-level accumulator fragment
|
| 744 |
+
//
|
| 745 |
+
|
| 746 |
+
AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
|
| 747 |
+
|
| 748 |
+
//
|
| 749 |
+
// Iterate over accumulator tile
|
| 750 |
+
//
|
| 751 |
+
|
| 752 |
+
#pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1)
|
| 753 |
+
for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {
|
| 754 |
+
|
| 755 |
+
//
|
| 756 |
+
// Load the source
|
| 757 |
+
//
|
| 758 |
+
|
| 759 |
+
source_iterator.load(source_fragment);
|
| 760 |
+
++source_iterator;
|
| 761 |
+
|
| 762 |
+
//
|
| 763 |
+
// Convert and store fragment
|
| 764 |
+
//
|
| 765 |
+
|
| 766 |
+
__syncthreads();
|
| 767 |
+
|
| 768 |
+
acc2smem_source_needed<cutlass::make_index_sequence<OutputTileIterator::kIterations>>::push(
|
| 769 |
+
iter, accum_fragment_iterator, this->warp_tile_iterator_);
|
| 770 |
+
|
| 771 |
+
__syncthreads();
|
| 772 |
+
|
| 773 |
+
//
|
| 774 |
+
// Load fragments from shared memory
|
| 775 |
+
//
|
| 776 |
+
|
| 777 |
+
typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];
|
| 778 |
+
|
| 779 |
+
shared_load_iterator_.load(aligned_accum_fragment[0]);
|
| 780 |
+
|
| 781 |
+
// If the number of k-slices is > 1 - perform a reduction amongst the k-slices
|
| 782 |
+
if (kPartitionsK > 1)
|
| 783 |
+
{
|
| 784 |
+
plus <typename SharedLoadIterator::Fragment> add_fragments;
|
| 785 |
+
const int tile_row_offset = Base::SharedStorage::StorageShape::kRow / PartitionsK;
|
| 786 |
+
|
| 787 |
+
CUTLASS_PRAGMA_UNROLL
|
| 788 |
+
for ( int i = 1; i < kPartitionsK; ++i) {
|
| 789 |
+
shared_load_iterator_.add_tile_offset({tile_row_offset , 0});
|
| 790 |
+
shared_load_iterator_.load(aligned_accum_fragment[i]);
|
| 791 |
+
aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
|
| 792 |
+
}
|
| 793 |
+
|
| 794 |
+
shared_load_iterator_.add_tile_offset({-1 * (kPartitionsK-1) * tile_row_offset, 0});
|
| 795 |
+
}
|
| 796 |
+
|
| 797 |
+
//
|
| 798 |
+
// Apply output operation
|
| 799 |
+
//
|
| 800 |
+
|
| 801 |
+
FragmentCompute frag_Z_compute;
|
| 802 |
+
FragmentCompute frag_Aux_compute;
|
| 803 |
+
|
| 804 |
+
apply_output_operator_(
|
| 805 |
+
frag_Z_compute,
|
| 806 |
+
frag_Aux_compute,
|
| 807 |
+
output_op,
|
| 808 |
+
aligned_accum_fragment[0],
|
| 809 |
+
source_fragment,
|
| 810 |
+
broadcast_fragment);
|
| 811 |
+
|
| 812 |
+
//
|
| 813 |
+
// Conditionally store fragments
|
| 814 |
+
//
|
| 815 |
+
|
| 816 |
+
// (Optionally) compute the absolute maximum of frag_Z and scale frag_Z
|
| 817 |
+
frag_Z_compute = output_scaler(destination_iterator, frag_Z_compute);
|
| 818 |
+
NumericArrayConverter<typename OutputTileIterator::Fragment::Element, ElementCompute,
|
| 819 |
+
OutputTileIterator::Fragment::kElements> cvt_to_dst;
|
| 820 |
+
typename OutputTileIterator::Fragment frag_Z = cvt_to_dst(frag_Z_compute);
|
| 821 |
+
|
| 822 |
+
// Always store the output
|
| 823 |
+
destination_iterator.store(frag_Z);
|
| 824 |
+
++destination_iterator;
|
| 825 |
+
|
| 826 |
+
// Only store the auxiliary output if scaling and absolute-maximum calculation were needed
|
| 827 |
+
if (OutputOp::kIsScalingAndAmaxAuxOutputNeeded) {
|
| 828 |
+
frag_Aux_compute = aux_scaler(aux_iterator, frag_Aux_compute);
|
| 829 |
+
|
| 830 |
+
NumericArrayConverter<typename AuxOutputTileIterator::Fragment::Element, ElementCompute,
|
| 831 |
+
AuxOutputTileIterator::Fragment::kElements> cvt_to_aux;
|
| 832 |
+
typename AuxOutputTileIterator::Fragment frag_Aux = cvt_to_aux(frag_Aux_compute);
|
| 833 |
+
aux_iterator.store(frag_Aux);
|
| 834 |
+
++aux_iterator;
|
| 835 |
+
}
|
| 836 |
+
}
|
| 837 |
+
}
|
| 838 |
+
|
| 839 |
+
/// Helper to invoke the output functor over each vector of output
|
| 840 |
+
CUTLASS_DEVICE
|
| 841 |
+
void apply_output_operator_(
|
| 842 |
+
FragmentCompute &frag_Z,
|
| 843 |
+
FragmentCompute &frag_Aux,
|
| 844 |
+
OutputOp &output_op,
|
| 845 |
+
typename SharedLoadIterator::Fragment const &frag_AB,
|
| 846 |
+
typename OutputTileIterator::Fragment const &frag_C,
|
| 847 |
+
BroadcastFragment const &frag_Broadcast) {
|
| 848 |
+
|
| 849 |
+
using AccessTypeZ = Array<ElementCompute, kElementsPerAccess>;
|
| 850 |
+
using AccessTypeAux = Array<ElementCompute, kElementsPerAccess>;
|
| 851 |
+
using AccessTypeBroadcast = Array<ElementCompute, kElementsPerAccess>;
|
| 852 |
+
|
| 853 |
+
AccessTypeZ *frag_Z_ptr = reinterpret_cast<AccessTypeZ *>(&frag_Z);
|
| 854 |
+
AccessTypeAux *frag_Aux_ptr = reinterpret_cast<AccessTypeAux *>(&frag_Aux);
|
| 855 |
+
|
| 856 |
+
AccumulatorAccessType const *frag_AB_ptr =
|
| 857 |
+
reinterpret_cast<AccumulatorAccessType const *>(&frag_AB);
|
| 858 |
+
|
| 859 |
+
OutputAccessType const *frag_C_ptr =
|
| 860 |
+
reinterpret_cast<OutputAccessType const *>(&frag_C);
|
| 861 |
+
|
| 862 |
+
AccessTypeBroadcast const *frag_Broadcast_ptr =
|
| 863 |
+
reinterpret_cast<AccessTypeBroadcast const *>(&frag_Broadcast);
|
| 864 |
+
|
| 865 |
+
int const kOutputOpIterations =
|
| 866 |
+
OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
|
| 867 |
+
|
| 868 |
+
CUTLASS_PRAGMA_UNROLL
|
| 869 |
+
for (int i = 0; i < kOutputOpIterations; ++i) {
|
| 870 |
+
output_op(
|
| 871 |
+
frag_Z_ptr[i],
|
| 872 |
+
frag_Aux_ptr[i],
|
| 873 |
+
frag_AB_ptr[i],
|
| 874 |
+
frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn],
|
| 875 |
+
frag_C_ptr[i]);
|
| 876 |
+
}
|
| 877 |
+
}
|
| 878 |
+
|
| 879 |
+
/// Helper to invoke the output functor over each vector of output
|
| 880 |
+
CUTLASS_DEVICE
|
| 881 |
+
void apply_output_operator_source_not_needed_(
|
| 882 |
+
FragmentCompute &frag_Z,
|
| 883 |
+
FragmentCompute &frag_Aux,
|
| 884 |
+
OutputOp &output_op,
|
| 885 |
+
typename SharedLoadIterator::Fragment const &frag_AB,
|
| 886 |
+
BroadcastFragment const &frag_Broadcast) {
|
| 887 |
+
|
| 888 |
+
using AccessTypeZ = Array<ElementCompute, kElementsPerAccess>;
|
| 889 |
+
using AccessTypeAux = Array<ElementCompute, kElementsPerAccess>;
|
| 890 |
+
using AccessTypeBroadcast = Array<ElementCompute, kElementsPerAccess>;
|
| 891 |
+
|
| 892 |
+
AccessTypeZ *frag_Z_ptr = reinterpret_cast<AccessTypeZ *>(&frag_Z);
|
| 893 |
+
AccessTypeAux *frag_Aux_ptr = reinterpret_cast<AccessTypeAux *>(&frag_Aux);
|
| 894 |
+
|
| 895 |
+
AccumulatorAccessType const *frag_AB_ptr =
|
| 896 |
+
reinterpret_cast<AccumulatorAccessType const *>(&frag_AB);
|
| 897 |
+
|
| 898 |
+
AccessTypeBroadcast const *frag_Broadcast_ptr =
|
| 899 |
+
reinterpret_cast<AccessTypeBroadcast const *>(&frag_Broadcast);
|
| 900 |
+
|
| 901 |
+
int const kOutputOpIterations =
|
| 902 |
+
OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
|
| 903 |
+
|
| 904 |
+
CUTLASS_PRAGMA_UNROLL
|
| 905 |
+
for (int i = 0; i < kOutputOpIterations; ++i) {
|
| 906 |
+
|
| 907 |
+
output_op(
|
| 908 |
+
frag_Z_ptr[i],
|
| 909 |
+
frag_Aux_ptr[i],
|
| 910 |
+
frag_AB_ptr[i],
|
| 911 |
+
frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]);
|
| 912 |
+
}
|
| 913 |
+
}
|
| 914 |
+
};
|
| 915 |
+
|
| 916 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 917 |
+
|
| 918 |
+
} // namespace threadblock
|
| 919 |
+
} // namespace epilogue
|
| 920 |
+
} // namespace cutlass
|
| 921 |
+
|
| 922 |
+
////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h
ADDED
|
@@ -0,0 +1,1717 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
|
| 33 |
+
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
|
| 34 |
+
|
| 35 |
+
The epilogue rearranges the result of a matrix product through shared memory to match canonical
|
| 36 |
+
tensor layouts in global memory. Epilogues support conversion and reduction operations.
|
| 37 |
+
|
| 38 |
+
*/
|
| 39 |
+
|
| 40 |
+
#pragma once
|
| 41 |
+
#include "cutlass/cutlass.h"
|
| 42 |
+
#include CUDA_STD_HEADER(cassert)
|
| 43 |
+
|
| 44 |
+
#if defined(__CUDACC_RTC__)
|
| 45 |
+
#include CUDA_STD_HEADER(utility)
|
| 46 |
+
#else
|
| 47 |
+
#include <utility>
|
| 48 |
+
#endif
|
| 49 |
+
|
| 50 |
+
#include "cutlass/array.h"
|
| 51 |
+
#include "cutlass/numeric_types.h"
|
| 52 |
+
#include "cutlass/numeric_conversion.h"
|
| 53 |
+
#include "cutlass/tensor_coord.h"
|
| 54 |
+
#include "cutlass/aligned_buffer.h"
|
| 55 |
+
#include "cutlass/functional.h"
|
| 56 |
+
#include "cutlass/fast_math.h"
|
| 57 |
+
#include "cutlass/layout/vector.h"
|
| 58 |
+
#include "cutlass/layout/tensor.h"
|
| 59 |
+
|
| 60 |
+
#include "cutlass/gemm/gemm.h"
|
| 61 |
+
|
| 62 |
+
#include "cutlass/transform/pitch_linear_thread_map.h"
|
| 63 |
+
#include "cutlass/transform/threadblock/regular_tile_iterator.h"
|
| 64 |
+
|
| 65 |
+
#include "cutlass/epilogue/threadblock/epilogue_base.h"
|
| 66 |
+
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
|
| 67 |
+
|
| 68 |
+
#include "cutlass/numeric_types.h"
|
| 69 |
+
|
| 70 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 71 |
+
|
| 72 |
+
namespace cutlass {
|
| 73 |
+
namespace epilogue {
|
| 74 |
+
namespace threadblock {
|
| 75 |
+
|
| 76 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 77 |
+
|
| 78 |
+
/// This base class is meant to define the concept required of the
|
| 79 |
+
/// EpilogueWithBroadcast::OutputOp
|
| 80 |
+
template <
|
| 81 |
+
typename ElementC_,
|
| 82 |
+
typename ElementAccumulator_,
|
| 83 |
+
typename ElementCompute_,
|
| 84 |
+
typename ElementZ_,
|
| 85 |
+
typename ElementT_,
|
| 86 |
+
int ElementsPerAccess,
|
| 87 |
+
bool StoreZ = true,
|
| 88 |
+
bool StoreT = true
|
| 89 |
+
>
|
| 90 |
+
struct EpilogueWithBroadcastOpBase {
|
| 91 |
+
|
| 92 |
+
using ElementOutput = ElementC_;
|
| 93 |
+
using ElementAccumulator = ElementAccumulator_;
|
| 94 |
+
using ElementCompute = ElementCompute_;
|
| 95 |
+
using ElementZ = ElementZ_;
|
| 96 |
+
using ElementT = ElementT_;
|
| 97 |
+
static int const kElementsPerAccess = ElementsPerAccess;
|
| 98 |
+
|
| 99 |
+
using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;
|
| 100 |
+
using FragmentCompute = Array<ElementCompute, kElementsPerAccess>;
|
| 101 |
+
using FragmentC = Array<ElementOutput, kElementsPerAccess>;
|
| 102 |
+
using FragmentZ = Array<ElementZ, kElementsPerAccess>;
|
| 103 |
+
using FragmentT = Array<ElementT, kElementsPerAccess>;
|
| 104 |
+
|
| 105 |
+
/// If true, the 'Z' tensor is stored
|
| 106 |
+
static bool const kStoreZ = StoreZ;
|
| 107 |
+
|
| 108 |
+
/// If true, the 'T' tensor is stored
|
| 109 |
+
static bool const kStoreT = StoreT;
|
| 110 |
+
|
| 111 |
+
/// Parameters structure - required
|
| 112 |
+
struct Params { };
|
| 113 |
+
|
| 114 |
+
//
|
| 115 |
+
// Methods
|
| 116 |
+
//
|
| 117 |
+
|
| 118 |
+
/// Constructor from Params
|
| 119 |
+
EpilogueWithBroadcastOpBase(Params const ¶ms_) { }
|
| 120 |
+
|
| 121 |
+
/// Determine if the source is needed. May return false if
|
| 122 |
+
bool is_source_needed() const {
|
| 123 |
+
return true;
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
CUTLASS_HOST_DEVICE
|
| 127 |
+
void set_k_partition(int k_partition, int k_partition_count) { }
|
| 128 |
+
|
| 129 |
+
/// Applies the operation when is_source_needed() is true
|
| 130 |
+
CUTLASS_HOST_DEVICE
|
| 131 |
+
void operator()(
|
| 132 |
+
FragmentZ &frag_Z,
|
| 133 |
+
FragmentT &frag_T,
|
| 134 |
+
FragmentAccumulator const &AB,
|
| 135 |
+
FragmentC const &frag_C1,
|
| 136 |
+
FragmentC const &frag_C2,
|
| 137 |
+
FragmentCompute const &V) const {
|
| 138 |
+
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
/// Applies the operation when is_source_needed() is false
|
| 142 |
+
CUTLASS_HOST_DEVICE
|
| 143 |
+
void operator()(
|
| 144 |
+
FragmentZ &frag_Z,
|
| 145 |
+
FragmentT &frag_T,
|
| 146 |
+
FragmentAccumulator const &AB,
|
| 147 |
+
FragmentCompute const &V) const {
|
| 148 |
+
|
| 149 |
+
}
|
| 150 |
+
};
|
| 151 |
+
|
| 152 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 153 |
+
|
| 154 |
+
/// Epilogue operator with bias vector broadcast over columns.
|
| 155 |
+
///
|
| 156 |
+
/// Computes the following:
|
| 157 |
+
///
|
| 158 |
+
///
|
| 159 |
+
/// Z, T = OutputOp(AB, C, Broadcast)
|
| 160 |
+
///
|
| 161 |
+
/// if (ElementwiseOp::kStoreZ) {
|
| 162 |
+
/// store(converted_u);
|
| 163 |
+
/// }
|
| 164 |
+
///
|
| 165 |
+
/// if (ElementwiseOp::kStoreT) {
|
| 166 |
+
/// store(v);
|
| 167 |
+
/// }
|
| 168 |
+
///
|
| 169 |
+
template <
|
| 170 |
+
typename Shape_, ///< Shape of threadblock tile (concept: GemmShape)
|
| 171 |
+
typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp)
|
| 172 |
+
int PartitionsK, ///< Number of partitions of the K dimension
|
| 173 |
+
typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors (z)
|
| 174 |
+
typename TensorTileIterator_, ///< Additional tile iterator for tensor-valued operands (t)
|
| 175 |
+
typename ElementVector_, ///< Pointer to broadcast vector
|
| 176 |
+
typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators
|
| 177 |
+
typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM
|
| 178 |
+
typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM
|
| 179 |
+
typename OutputOp_, ///< Output operator - concept is EpilogueWithBroadcastOp
|
| 180 |
+
typename Padding_, ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape)
|
| 181 |
+
int FragmentsPerPartition = 1, ///< Used to coarsten the epilogue granularity
|
| 182 |
+
int IterationsUnroll = ///< Used to reduce binary size when epilogue op is large
|
| 183 |
+
(!IsEpilogueFunctorHeavy<OutputOp_>::value),
|
| 184 |
+
bool IsSingleSource = OutputOp_::kIsSingleSource
|
| 185 |
+
>
|
| 186 |
+
class EpilogueWithBroadcast;
|
| 187 |
+
|
| 188 |
+
template <
|
| 189 |
+
typename Shape_,
|
| 190 |
+
typename WarpMmaOperator_,
|
| 191 |
+
int PartitionsK,
|
| 192 |
+
typename OutputTileIterator_,
|
| 193 |
+
typename TensorTileIterator_,
|
| 194 |
+
typename ElementVector_,
|
| 195 |
+
typename AccumulatorFragmentIterator_,
|
| 196 |
+
typename WarpTileIterator_,
|
| 197 |
+
typename SharedLoadIterator_,
|
| 198 |
+
typename OutputOp_,
|
| 199 |
+
typename Padding_,
|
| 200 |
+
int FragmentsPerPartition,
|
| 201 |
+
int IterationsUnroll
|
| 202 |
+
>
|
| 203 |
+
class EpilogueWithBroadcast<
|
| 204 |
+
Shape_,
|
| 205 |
+
WarpMmaOperator_,
|
| 206 |
+
PartitionsK,
|
| 207 |
+
OutputTileIterator_,
|
| 208 |
+
TensorTileIterator_,
|
| 209 |
+
ElementVector_,
|
| 210 |
+
AccumulatorFragmentIterator_,
|
| 211 |
+
WarpTileIterator_,
|
| 212 |
+
SharedLoadIterator_,
|
| 213 |
+
OutputOp_,
|
| 214 |
+
Padding_,
|
| 215 |
+
FragmentsPerPartition,
|
| 216 |
+
IterationsUnroll,
|
| 217 |
+
false
|
| 218 |
+
> :
|
| 219 |
+
public EpilogueBase<
|
| 220 |
+
Shape_,
|
| 221 |
+
typename WarpMmaOperator_::Shape,
|
| 222 |
+
PartitionsK,
|
| 223 |
+
AccumulatorFragmentIterator_,
|
| 224 |
+
WarpTileIterator_,
|
| 225 |
+
Padding_,
|
| 226 |
+
FragmentsPerPartition> {
|
| 227 |
+
|
| 228 |
+
public:
|
| 229 |
+
|
| 230 |
+
using Base = EpilogueBase<
|
| 231 |
+
Shape_,
|
| 232 |
+
typename WarpMmaOperator_::Shape,
|
| 233 |
+
PartitionsK,
|
| 234 |
+
AccumulatorFragmentIterator_,
|
| 235 |
+
WarpTileIterator_,
|
| 236 |
+
Padding_,
|
| 237 |
+
FragmentsPerPartition>;
|
| 238 |
+
|
| 239 |
+
static bool const kIsSingleSource = false;
|
| 240 |
+
using Shape = Shape_;
|
| 241 |
+
using WarpMmaOperator = WarpMmaOperator_;
|
| 242 |
+
static int const kPartitionsK = PartitionsK;
|
| 243 |
+
using OutputTileIterator = OutputTileIterator_;
|
| 244 |
+
using TensorTileIterator = TensorTileIterator_;
|
| 245 |
+
using ElementVector = ElementVector_;
|
| 246 |
+
using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
|
| 247 |
+
using WarpTileIterator = WarpTileIterator_;
|
| 248 |
+
using SharedLoadIterator = SharedLoadIterator_;
|
| 249 |
+
using OutputOp = OutputOp_;
|
| 250 |
+
using Padding = Padding_;
|
| 251 |
+
|
| 252 |
+
using Layout = layout::RowMajor;
|
| 253 |
+
using LongIndex = typename Layout::LongIndex;
|
| 254 |
+
|
| 255 |
+
/// The complete warp-level accumulator tile
|
| 256 |
+
using AccumulatorTile = typename Base::AccumulatorTile;
|
| 257 |
+
|
| 258 |
+
/// Accumulator element
|
| 259 |
+
using ElementAccumulator = typename WarpTileIterator::Element;
|
| 260 |
+
|
| 261 |
+
/// Compute data type produced by the output op
|
| 262 |
+
using ElementCompute = typename OutputOp::ElementCompute;
|
| 263 |
+
|
| 264 |
+
/// Compute fragment
|
| 265 |
+
using FragmentCompute = Array<ElementCompute, OutputTileIterator::Fragment::kElements>;
|
| 266 |
+
|
| 267 |
+
/// Thread map used by output tile iterators
|
| 268 |
+
using ThreadMap = typename OutputTileIterator::ThreadMap;
|
| 269 |
+
|
| 270 |
+
/// Fragment object used to store the broadcast values
|
| 271 |
+
using BroadcastFragment = Array<
|
| 272 |
+
ElementCompute,
|
| 273 |
+
ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess>;
|
| 274 |
+
|
| 275 |
+
/// Output element
|
| 276 |
+
using ElementOutput = typename OutputTileIterator::Element;
|
| 277 |
+
|
| 278 |
+
/// Data type of additional tensor
|
| 279 |
+
using ElementTensor = typename TensorTileIterator::Element;
|
| 280 |
+
|
| 281 |
+
/// Output access size
|
| 282 |
+
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
|
| 283 |
+
|
| 284 |
+
/// Tensor reference to destination tensor
|
| 285 |
+
using TensorRef = typename OutputTileIterator::TensorRef;
|
| 286 |
+
|
| 287 |
+
/// Tensor reference to sync tensor
|
| 288 |
+
using SyncTensorRef = typename cutlass::TensorRef<int, cutlass::layout::PackedVectorLayout>;
|
| 289 |
+
|
| 290 |
+
/// Const tensor reference to source tensor
|
| 291 |
+
using ConstTensorRef = typename OutputTileIterator::ConstTensorRef;
|
| 292 |
+
|
| 293 |
+
/// Array type used to output
|
| 294 |
+
using OutputAccessType = Array<
|
| 295 |
+
typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
|
| 296 |
+
|
| 297 |
+
/// Array type used by output functor
|
| 298 |
+
using AccumulatorAccessType = Array<typename WarpTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
|
| 299 |
+
|
| 300 |
+
/// Array type used by output functor
|
| 301 |
+
using ComputeAccessType = Array<ElementCompute, OutputTileIterator::kElementsPerAccess>;
|
| 302 |
+
|
| 303 |
+
/// Tensor access type
|
| 304 |
+
using TensorAccessType = Array<ElementTensor, OutputTileIterator::kElementsPerAccess>;
|
| 305 |
+
|
| 306 |
+
/// Number of warps
|
| 307 |
+
using WarpCount = typename Base::WarpCount;
|
| 308 |
+
|
| 309 |
+
/// Shared memory allocation from epilogue base class
|
| 310 |
+
using BaseSharedStorage = typename Base::SharedStorage;
|
| 311 |
+
|
| 312 |
+
static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 ? Base::kFragmentsPerIteration : kPartitionsK;
|
| 313 |
+
static int constexpr kSmemPointerOffset = Base::SharedStorage::StorageShape::kCount / kSmemTiles;
|
| 314 |
+
|
| 315 |
+
/// Used for the broadcast
|
| 316 |
+
struct BroadcastDetail {
|
| 317 |
+
|
| 318 |
+
/// Number of threads per warp
|
| 319 |
+
static int const kWarpSize = 32;
|
| 320 |
+
|
| 321 |
+
static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
|
| 322 |
+
|
| 323 |
+
/// Number of distinct scalar column indices handled by each thread
|
| 324 |
+
static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess;
|
| 325 |
+
|
| 326 |
+
/// Number of distinct scalar row indices handled by each thread
|
| 327 |
+
static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn;
|
| 328 |
+
|
| 329 |
+
/// Number of threads per threadblock
|
| 330 |
+
static int const kThreadCount = kWarpSize * WarpCount::kCount;
|
| 331 |
+
|
| 332 |
+
/// Number of distinct threads per row of output tile
|
| 333 |
+
static int const kThreadsPerRow = (Shape::kN / kColumnsPerThread);
|
| 334 |
+
|
| 335 |
+
/// Number of distinct threads which must be reduced during the final reduction phase within the threadblock.
|
| 336 |
+
static int const kThreadRows = kThreadCount / kThreadsPerRow;
|
| 337 |
+
|
| 338 |
+
/// I'm not sure what I meant here.
|
| 339 |
+
static int const kThreadAccessesPerRow = const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount);
|
| 340 |
+
|
| 341 |
+
/// Shape of the shared memory allocation for the epilogue
|
| 342 |
+
using StorageShape = MatrixShape<
|
| 343 |
+
kThreadRows,
|
| 344 |
+
Shape::kN
|
| 345 |
+
>;
|
| 346 |
+
|
| 347 |
+
/// Debug printing
|
| 348 |
+
CUTLASS_DEVICE
|
| 349 |
+
static void print() {
|
| 350 |
+
#if 0
|
| 351 |
+
printf("BroadcastDetail {\n");
|
| 352 |
+
printf(
|
| 353 |
+
" kColumnsPerThread: %d\nkRowsPerThread: %d\n,kThreadCount: %d\nkThreadsPerRow: %d\n"
|
| 354 |
+
"kThreadRows: %d\nThreadAccessesPerRow: %d\nStorageShape: %d x %d (count: %d)\n",
|
| 355 |
+
kColumnsPerThread,
|
| 356 |
+
kRowsPerThread,
|
| 357 |
+
kThreadCount,
|
| 358 |
+
kThreadsPerRow,
|
| 359 |
+
kThreadRows,
|
| 360 |
+
kThreadAccessesPerRow,
|
| 361 |
+
StorageShape::kRow,
|
| 362 |
+
StorageShape::kColumn,
|
| 363 |
+
StorageShape::kCount
|
| 364 |
+
);
|
| 365 |
+
printf("};\n");
|
| 366 |
+
#endif
|
| 367 |
+
}
|
| 368 |
+
};
|
| 369 |
+
|
| 370 |
+
/// Shared storage structure (shadows base) with additional SMEM buffer for reduction
|
| 371 |
+
struct SharedStorage {
|
| 372 |
+
union {
|
| 373 |
+
BaseSharedStorage base;
|
| 374 |
+
};
|
| 375 |
+
|
| 376 |
+
CUTLASS_HOST_DEVICE
|
| 377 |
+
SharedStorage() { }
|
| 378 |
+
};
|
| 379 |
+
|
| 380 |
+
public:
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements,
|
| 384 |
+
"Mismatch between shared load iterator and output tile iterator.");
|
| 385 |
+
|
| 386 |
+
static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero.");
|
| 387 |
+
|
| 388 |
+
static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess),
|
| 389 |
+
"Divisibility");
|
| 390 |
+
|
| 391 |
+
private:
|
| 392 |
+
|
| 393 |
+
/// Loads fragment from shared memory aligned with output tensor
|
| 394 |
+
SharedLoadIterator shared_load_iterator_;
|
| 395 |
+
|
| 396 |
+
/// Thread index within the threadblock
|
| 397 |
+
int thread_idx_;
|
| 398 |
+
|
| 399 |
+
public:
|
| 400 |
+
|
| 401 |
+
/// Constructor
|
| 402 |
+
CUTLASS_DEVICE
|
| 403 |
+
EpilogueWithBroadcast(
|
| 404 |
+
SharedStorage &shared_storage, ///< Shared storage object
|
| 405 |
+
int thread_idx, ///< ID of a thread within the threadblock
|
| 406 |
+
int warp_idx, ///< ID of warp within threadblock
|
| 407 |
+
int lane_idx ///< Id of thread within warp
|
| 408 |
+
):
|
| 409 |
+
Base(shared_storage.base, thread_idx, warp_idx, lane_idx),
|
| 410 |
+
shared_load_iterator_(shared_storage.base.reference(), thread_idx),
|
| 411 |
+
thread_idx_(thread_idx)
|
| 412 |
+
{
|
| 413 |
+
|
| 414 |
+
}
|
| 415 |
+
|
| 416 |
+
/// Streams the result to global memory
|
| 417 |
+
CUTLASS_DEVICE
|
| 418 |
+
void operator()(
|
| 419 |
+
OutputOp const &output_op, ///< Output operator
|
| 420 |
+
ElementVector const * broadcast_ptr, ///< Broadcast vector
|
| 421 |
+
OutputTileIterator destination_iterator, ///< Tile iterator for destination
|
| 422 |
+
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
|
| 423 |
+
OutputTileIterator source_iterator1, ///< Tile iterator for first source accumulator matrix
|
| 424 |
+
OutputTileIterator source_iterator2, ///< Tile iterator for second source accumulator matrix
|
| 425 |
+
TensorTileIterator tensor_iterator, ///< Threadblock tile iterator for additional tensor operand
|
| 426 |
+
MatrixCoord const &problem_size = ///< Problem size needed to guard against out-of-bounds accesses
|
| 427 |
+
MatrixCoord(Shape::kM, Shape::kN),
|
| 428 |
+
MatrixCoord const &threadblock_offset = ///< Threadblock's initial offset within the problem size space
|
| 429 |
+
MatrixCoord()) {
|
| 430 |
+
|
| 431 |
+
BroadcastFragment broadcast_fragment;
|
| 432 |
+
|
| 433 |
+
load_broadcast_fragment_(broadcast_fragment, broadcast_ptr, problem_size, threadblock_offset);
|
| 434 |
+
|
| 435 |
+
if (!output_op.is_source_needed()) {
|
| 436 |
+
compute_source_not_needed_(
|
| 437 |
+
output_op,
|
| 438 |
+
broadcast_fragment,
|
| 439 |
+
destination_iterator,
|
| 440 |
+
accumulators,
|
| 441 |
+
tensor_iterator);
|
| 442 |
+
}
|
| 443 |
+
else {
|
| 444 |
+
compute_source_needed_(
|
| 445 |
+
output_op,
|
| 446 |
+
broadcast_fragment,
|
| 447 |
+
destination_iterator,
|
| 448 |
+
accumulators,
|
| 449 |
+
source_iterator1,
|
| 450 |
+
source_iterator2,
|
| 451 |
+
tensor_iterator);
|
| 452 |
+
}
|
| 453 |
+
}
|
| 454 |
+
|
| 455 |
+
private:
|
| 456 |
+
|
| 457 |
+
CUTLASS_DEVICE
|
| 458 |
+
void load_broadcast_fragment_(
|
| 459 |
+
BroadcastFragment & broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns
|
| 460 |
+
ElementVector const * broadcast_ptr, ///< Broadcast vector
|
| 461 |
+
MatrixCoord const &problem_size, ///< Problem size needed to guard against out-of-bounds accesses
|
| 462 |
+
MatrixCoord const &threadblock_offset ///< Threadblock's initial offset within the problem size space
|
| 463 |
+
) {
|
| 464 |
+
|
| 465 |
+
broadcast_fragment.clear();
|
| 466 |
+
|
| 467 |
+
// If no pointer is supplied, set with all zeros and avoid memory accesses
|
| 468 |
+
if (!broadcast_ptr) {
|
| 469 |
+
return;
|
| 470 |
+
}
|
| 471 |
+
|
| 472 |
+
int thread_initial_column = ThreadMap::initial_offset(thread_idx_).column();
|
| 473 |
+
|
| 474 |
+
int thread_column_idx = threadblock_offset.column() + thread_initial_column;
|
| 475 |
+
broadcast_ptr += thread_initial_column;
|
| 476 |
+
|
| 477 |
+
NumericArrayConverter<ElementCompute, ElementVector, BroadcastDetail::kElementsPerAccess> converter;
|
| 478 |
+
using AccessType = AlignedArray<ElementVector, BroadcastDetail::kElementsPerAccess>;
|
| 479 |
+
using ComputeFragmentType = Array<ElementCompute, BroadcastDetail::kElementsPerAccess>;
|
| 480 |
+
|
| 481 |
+
ComputeFragmentType *frag_ptr = reinterpret_cast<ComputeFragmentType *>(&broadcast_fragment);
|
| 482 |
+
|
| 483 |
+
CUTLASS_PRAGMA_UNROLL
|
| 484 |
+
for (int j = 0; j < ThreadMap::Iterations::kColumn; ++j) {
|
| 485 |
+
|
| 486 |
+
AccessType loaded;
|
| 487 |
+
|
| 488 |
+
loaded.clear();
|
| 489 |
+
|
| 490 |
+
if (thread_column_idx < problem_size.column()) {
|
| 491 |
+
loaded = *reinterpret_cast<AccessType const *>(broadcast_ptr);
|
| 492 |
+
}
|
| 493 |
+
|
| 494 |
+
ComputeFragmentType cvt = converter(loaded);
|
| 495 |
+
frag_ptr[j] = cvt;
|
| 496 |
+
|
| 497 |
+
thread_column_idx += ThreadMap::Delta::kColumn;
|
| 498 |
+
broadcast_ptr += ThreadMap::Delta::kColumn;
|
| 499 |
+
}
|
| 500 |
+
}
|
| 501 |
+
|
| 502 |
+
template <class Seq>
|
| 503 |
+
struct acc2smem_source_not_needed;
|
| 504 |
+
|
| 505 |
+
template <size_t... Seq>
|
| 506 |
+
struct acc2smem_source_not_needed<cutlass::index_sequence<Seq...>> {
|
| 507 |
+
template <int Advance>
|
| 508 |
+
CUTLASS_DEVICE static void helper(AccumulatorFragmentIterator accum_fragment_iterator,
|
| 509 |
+
WarpTileIterator &warp_tile_iterator) {
|
| 510 |
+
CUTLASS_PRAGMA_UNROLL
|
| 511 |
+
for (int i = 0; i < Advance; i++) {
|
| 512 |
+
++accum_fragment_iterator;
|
| 513 |
+
}
|
| 514 |
+
|
| 515 |
+
CUTLASS_PRAGMA_UNROLL
|
| 516 |
+
for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {
|
| 517 |
+
typename AccumulatorFragmentIterator::Fragment accum_fragment;
|
| 518 |
+
|
| 519 |
+
accum_fragment_iterator.load(accum_fragment);
|
| 520 |
+
++accum_fragment_iterator;
|
| 521 |
+
|
| 522 |
+
warp_tile_iterator.store(accum_fragment);
|
| 523 |
+
if (p < Base::kFragmentsPerIteration - 1) {
|
| 524 |
+
warp_tile_iterator.add_pointer_offset(kSmemPointerOffset);
|
| 525 |
+
}
|
| 526 |
+
}
|
| 527 |
+
|
| 528 |
+
if (Base::kFragmentsPerIteration > 1) {
|
| 529 |
+
warp_tile_iterator.add_pointer_offset(kSmemPointerOffset *
|
| 530 |
+
(1 - Base::kFragmentsPerIteration));
|
| 531 |
+
}
|
| 532 |
+
}
|
| 533 |
+
|
| 534 |
+
CUTLASS_DEVICE
|
| 535 |
+
static void push(size_t pos,
|
| 536 |
+
AccumulatorFragmentIterator const &iterator_begin,
|
| 537 |
+
WarpTileIterator &warp_tile_iterator) {
|
| 538 |
+
int dummy[] = {
|
| 539 |
+
(pos == (Seq * Base::kFragmentsPerIteration)) &&
|
| 540 |
+
(helper<Seq * Base::kFragmentsPerIteration>(iterator_begin, warp_tile_iterator), 0)...};
|
| 541 |
+
|
| 542 |
+
CUTLASS_UNUSED(dummy[0]);
|
| 543 |
+
}
|
| 544 |
+
};
|
| 545 |
+
|
| 546 |
+
/// Streams the result to global memory
|
| 547 |
+
CUTLASS_DEVICE
|
| 548 |
+
void compute_source_not_needed_(
|
| 549 |
+
OutputOp const &output_op, ///< Output operator
|
| 550 |
+
BroadcastFragment const &broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns
|
| 551 |
+
OutputTileIterator destination_iterator, ///< Tile iterator for destination
|
| 552 |
+
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
|
| 553 |
+
TensorTileIterator tensor_iterator ///< Threadblock tile iterator for additioanl tensor operand
|
| 554 |
+
) {
|
| 555 |
+
|
| 556 |
+
//
|
| 557 |
+
// Iterator over warp-level accumulator fragment
|
| 558 |
+
//
|
| 559 |
+
|
| 560 |
+
AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
|
| 561 |
+
|
| 562 |
+
//
|
| 563 |
+
// Iterate over accumulator tile
|
| 564 |
+
//
|
| 565 |
+
|
| 566 |
+
// CUTLASS_PRAGMA_UNROLL
|
| 567 |
+
#pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations / Base::kFragmentsPerIteration : 1)
|
| 568 |
+
for (int iter = 0; iter < OutputTileIterator::kIterations; iter += Base::kFragmentsPerIteration) {
|
| 569 |
+
|
| 570 |
+
//
|
| 571 |
+
// Convert and store fragment
|
| 572 |
+
//
|
| 573 |
+
|
| 574 |
+
|
| 575 |
+
__syncthreads();
|
| 576 |
+
|
| 577 |
+
acc2smem_source_not_needed<
|
| 578 |
+
cutlass::make_index_sequence<OutputTileIterator::kIterations /
|
| 579 |
+
Base::kFragmentsPerIteration>>::push(iter,
|
| 580 |
+
accum_fragment_iterator,
|
| 581 |
+
this->warp_tile_iterator_);
|
| 582 |
+
|
| 583 |
+
__syncthreads();
|
| 584 |
+
|
| 585 |
+
//
|
| 586 |
+
// Load fragments from shared memory
|
| 587 |
+
//
|
| 588 |
+
|
| 589 |
+
CUTLASS_PRAGMA_UNROLL
|
| 590 |
+
for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {
|
| 591 |
+
|
| 592 |
+
|
| 593 |
+
typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];
|
| 594 |
+
|
| 595 |
+
shared_load_iterator_.load(aligned_accum_fragment[0]);
|
| 596 |
+
|
| 597 |
+
if (p < Base::kFragmentsPerIteration - 1) {
|
| 598 |
+
shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
|
| 599 |
+
}
|
| 600 |
+
else if (kPartitionsK > 1) {
|
| 601 |
+
|
| 602 |
+
plus <typename SharedLoadIterator::Fragment> add_fragments;
|
| 603 |
+
|
| 604 |
+
CUTLASS_PRAGMA_UNROLL
|
| 605 |
+
for ( int i = 1; i < kPartitionsK; ++i) {
|
| 606 |
+
shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
|
| 607 |
+
shared_load_iterator_.load(aligned_accum_fragment[i]);
|
| 608 |
+
aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
|
| 609 |
+
}
|
| 610 |
+
|
| 611 |
+
shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset);
|
| 612 |
+
}
|
| 613 |
+
|
| 614 |
+
//
|
| 615 |
+
// Apply output operation
|
| 616 |
+
//
|
| 617 |
+
|
| 618 |
+
typename OutputTileIterator::Fragment frag_Z;
|
| 619 |
+
typename TensorTileIterator::Fragment frag_T;
|
| 620 |
+
|
| 621 |
+
apply_output_operator_source_not_needed_(
|
| 622 |
+
frag_Z,
|
| 623 |
+
frag_T,
|
| 624 |
+
output_op,
|
| 625 |
+
aligned_accum_fragment[0],
|
| 626 |
+
broadcast_fragment);
|
| 627 |
+
|
| 628 |
+
//
|
| 629 |
+
// Conditionally store fragments
|
| 630 |
+
//
|
| 631 |
+
|
| 632 |
+
if (OutputOp::kStoreZ) {
|
| 633 |
+
destination_iterator.store(frag_Z);
|
| 634 |
+
++destination_iterator;
|
| 635 |
+
}
|
| 636 |
+
|
| 637 |
+
if (OutputOp::kStoreT) {
|
| 638 |
+
tensor_iterator.store(frag_T);
|
| 639 |
+
++tensor_iterator;
|
| 640 |
+
}
|
| 641 |
+
}
|
| 642 |
+
|
| 643 |
+
if (Base::kFragmentsPerIteration > 1) {
|
| 644 |
+
shared_load_iterator_.add_pointer_offset(kSmemPointerOffset * (1 - Base::kFragmentsPerIteration));
|
| 645 |
+
}
|
| 646 |
+
}
|
| 647 |
+
}
|
| 648 |
+
|
| 649 |
+
|
| 650 |
+
template<class Seq>
|
| 651 |
+
struct acc2smem_source_needed;
|
| 652 |
+
|
| 653 |
+
template <size_t... Seq>
|
| 654 |
+
struct acc2smem_source_needed<cutlass::index_sequence<Seq...>> {
|
| 655 |
+
template<int Advance>
|
| 656 |
+
CUTLASS_DEVICE
|
| 657 |
+
static void helper(AccumulatorFragmentIterator accum_fragment_iterator,
|
| 658 |
+
WarpTileIterator &warp_tile_iterator) {
|
| 659 |
+
CUTLASS_PRAGMA_UNROLL
|
| 660 |
+
for (int i = 0; i < Advance; i++) {
|
| 661 |
+
++accum_fragment_iterator;
|
| 662 |
+
}
|
| 663 |
+
|
| 664 |
+
typename AccumulatorFragmentIterator::Fragment accum_fragment;
|
| 665 |
+
accum_fragment_iterator.load(accum_fragment);
|
| 666 |
+
warp_tile_iterator.store(accum_fragment);
|
| 667 |
+
}
|
| 668 |
+
|
| 669 |
+
CUTLASS_DEVICE
|
| 670 |
+
static void push(size_t pos,
|
| 671 |
+
AccumulatorFragmentIterator const &iterator_begin,
|
| 672 |
+
WarpTileIterator &warp_tile_iterator) {
|
| 673 |
+
int dummy[] = {(pos == Seq) && (helper<Seq>(iterator_begin, warp_tile_iterator), 0)...};
|
| 674 |
+
}
|
| 675 |
+
};
|
| 676 |
+
|
| 677 |
+
|
| 678 |
+
/// Streams the result to global memory
|
| 679 |
+
CUTLASS_DEVICE
|
| 680 |
+
void compute_source_needed_(
|
| 681 |
+
OutputOp const &output_op, ///< Output operator
|
| 682 |
+
BroadcastFragment const &broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns
|
| 683 |
+
OutputTileIterator destination_iterator, ///< Tile iterator for destination
|
| 684 |
+
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
|
| 685 |
+
OutputTileIterator source_iterator1, ///< Tile iterator for first source accumulator matrix
|
| 686 |
+
OutputTileIterator source_iterator2, ///< Tile iterator for second source accumulator matrix
|
| 687 |
+
TensorTileIterator tensor_iterator ///< Threadblock tile iterator for additioanl tensor operand
|
| 688 |
+
) {
|
| 689 |
+
|
| 690 |
+
typename OutputTileIterator::Fragment source_fragment1;
|
| 691 |
+
source_fragment1.clear();
|
| 692 |
+
typename OutputTileIterator::Fragment source_fragment2;
|
| 693 |
+
source_fragment2.clear();
|
| 694 |
+
|
| 695 |
+
//
|
| 696 |
+
// Iterator over warp-level accumulator fragment
|
| 697 |
+
//
|
| 698 |
+
|
| 699 |
+
AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
|
| 700 |
+
|
| 701 |
+
//
|
| 702 |
+
// Iterate over accumulator tile
|
| 703 |
+
//
|
| 704 |
+
|
| 705 |
+
#pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1)
|
| 706 |
+
for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {
|
| 707 |
+
|
| 708 |
+
//
|
| 709 |
+
// Load the source
|
| 710 |
+
//
|
| 711 |
+
|
| 712 |
+
source_iterator1.load(source_fragment1);
|
| 713 |
+
++source_iterator1;
|
| 714 |
+
|
| 715 |
+
source_iterator2.load(source_fragment2);
|
| 716 |
+
++source_iterator2;
|
| 717 |
+
|
| 718 |
+
//
|
| 719 |
+
// Convert and store fragment
|
| 720 |
+
//
|
| 721 |
+
|
| 722 |
+
__syncthreads();
|
| 723 |
+
|
| 724 |
+
acc2smem_source_needed<cutlass::make_index_sequence<OutputTileIterator::kIterations>>::push(
|
| 725 |
+
iter, accum_fragment_iterator, this->warp_tile_iterator_);
|
| 726 |
+
|
| 727 |
+
__syncthreads();
|
| 728 |
+
|
| 729 |
+
//
|
| 730 |
+
// Load fragments from shared memory
|
| 731 |
+
//
|
| 732 |
+
|
| 733 |
+
typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];
|
| 734 |
+
|
| 735 |
+
shared_load_iterator_.load(aligned_accum_fragment[0]);
|
| 736 |
+
|
| 737 |
+
// If the number of k-slices is > 1 - perform a reduction amongst the k-slices
|
| 738 |
+
if (kPartitionsK > 1)
|
| 739 |
+
{
|
| 740 |
+
plus <typename SharedLoadIterator::Fragment> add_fragments;
|
| 741 |
+
const int tile_row_offset = Base::SharedStorage::StorageShape::kRow / PartitionsK;
|
| 742 |
+
|
| 743 |
+
CUTLASS_PRAGMA_UNROLL
|
| 744 |
+
for ( int i = 1; i < kPartitionsK; ++i) {
|
| 745 |
+
shared_load_iterator_.add_tile_offset({tile_row_offset , 0});
|
| 746 |
+
shared_load_iterator_.load(aligned_accum_fragment[i]);
|
| 747 |
+
aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
|
| 748 |
+
}
|
| 749 |
+
|
| 750 |
+
shared_load_iterator_.add_tile_offset({-1 * (kPartitionsK-1) * tile_row_offset, 0});
|
| 751 |
+
}
|
| 752 |
+
|
| 753 |
+
//
|
| 754 |
+
// Apply output operation
|
| 755 |
+
//
|
| 756 |
+
|
| 757 |
+
typename OutputTileIterator::Fragment frag_Z;
|
| 758 |
+
typename TensorTileIterator::Fragment frag_T;
|
| 759 |
+
|
| 760 |
+
apply_output_operator_(
|
| 761 |
+
frag_Z,
|
| 762 |
+
frag_T,
|
| 763 |
+
output_op,
|
| 764 |
+
aligned_accum_fragment[0],
|
| 765 |
+
source_fragment1,
|
| 766 |
+
source_fragment2,
|
| 767 |
+
broadcast_fragment);
|
| 768 |
+
|
| 769 |
+
//
|
| 770 |
+
// Conditionally store fragments
|
| 771 |
+
//
|
| 772 |
+
|
| 773 |
+
if (OutputOp::kStoreZ) {
|
| 774 |
+
destination_iterator.store(frag_Z);
|
| 775 |
+
++destination_iterator;
|
| 776 |
+
}
|
| 777 |
+
|
| 778 |
+
if (OutputOp::kStoreT) {
|
| 779 |
+
tensor_iterator.store(frag_T);
|
| 780 |
+
++tensor_iterator;
|
| 781 |
+
}
|
| 782 |
+
}
|
| 783 |
+
}
|
| 784 |
+
|
| 785 |
+
/// Helper to invoke the output functor over each vector of output
|
| 786 |
+
CUTLASS_DEVICE
|
| 787 |
+
void apply_output_operator_(
|
| 788 |
+
typename OutputTileIterator::Fragment &frag_Z,
|
| 789 |
+
typename TensorTileIterator::Fragment &frag_T,
|
| 790 |
+
OutputOp const &output_op,
|
| 791 |
+
typename SharedLoadIterator::Fragment const &frag_AB,
|
| 792 |
+
typename OutputTileIterator::Fragment const &frag_C1,
|
| 793 |
+
typename OutputTileIterator::Fragment const &frag_C2,
|
| 794 |
+
BroadcastFragment const &frag_Broadcast) {
|
| 795 |
+
|
| 796 |
+
using AccessTypeZ = Array<typename OutputTileIterator::Element, kElementsPerAccess>;
|
| 797 |
+
using AccessTypeT = Array<typename TensorTileIterator::Element, kElementsPerAccess>;
|
| 798 |
+
using AccessTypeBroadcast = Array<ElementCompute, kElementsPerAccess>;
|
| 799 |
+
|
| 800 |
+
AccessTypeZ *frag_Z_ptr = reinterpret_cast<AccessTypeZ *>(&frag_Z);
|
| 801 |
+
AccessTypeT *frag_T_ptr = reinterpret_cast<AccessTypeT *>(&frag_T);
|
| 802 |
+
|
| 803 |
+
AccumulatorAccessType const *frag_AB_ptr =
|
| 804 |
+
reinterpret_cast<AccumulatorAccessType const *>(&frag_AB);
|
| 805 |
+
|
| 806 |
+
OutputAccessType const *frag_C1_ptr =
|
| 807 |
+
reinterpret_cast<OutputAccessType const *>(&frag_C1);
|
| 808 |
+
|
| 809 |
+
OutputAccessType const *frag_C2_ptr =
|
| 810 |
+
reinterpret_cast<OutputAccessType const *>(&frag_C2);
|
| 811 |
+
|
| 812 |
+
AccessTypeBroadcast const *frag_Broadcast_ptr =
|
| 813 |
+
reinterpret_cast<AccessTypeBroadcast const *>(&frag_Broadcast);
|
| 814 |
+
|
| 815 |
+
int const kOutputOpIterations =
|
| 816 |
+
OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
|
| 817 |
+
|
| 818 |
+
CUTLASS_PRAGMA_UNROLL
|
| 819 |
+
for (int i = 0; i < kOutputOpIterations; ++i) {
|
| 820 |
+
output_op(
|
| 821 |
+
frag_Z_ptr[i],
|
| 822 |
+
frag_T_ptr[i],
|
| 823 |
+
frag_AB_ptr[i],
|
| 824 |
+
frag_C1_ptr[i],
|
| 825 |
+
frag_C2_ptr[i],
|
| 826 |
+
frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]);
|
| 827 |
+
}
|
| 828 |
+
}
|
| 829 |
+
|
| 830 |
+
/// Helper to invoke the output functor over each vector of output
|
| 831 |
+
CUTLASS_DEVICE
|
| 832 |
+
void apply_output_operator_source_not_needed_(
|
| 833 |
+
typename OutputTileIterator::Fragment &frag_Z,
|
| 834 |
+
typename TensorTileIterator::Fragment &frag_T,
|
| 835 |
+
OutputOp const &output_op,
|
| 836 |
+
typename SharedLoadIterator::Fragment const &frag_AB,
|
| 837 |
+
BroadcastFragment const &frag_Broadcast) {
|
| 838 |
+
|
| 839 |
+
using AccessTypeZ = Array<typename OutputTileIterator::Element, kElementsPerAccess>;
|
| 840 |
+
using AccessTypeT = Array<typename TensorTileIterator::Element, kElementsPerAccess>;
|
| 841 |
+
using AccessTypeBroadcast = Array<ElementCompute, kElementsPerAccess>;
|
| 842 |
+
|
| 843 |
+
AccessTypeZ *frag_Z_ptr = reinterpret_cast<AccessTypeZ *>(&frag_Z);
|
| 844 |
+
AccessTypeT *frag_T_ptr = reinterpret_cast<AccessTypeT *>(&frag_T);
|
| 845 |
+
|
| 846 |
+
AccumulatorAccessType const *frag_AB_ptr =
|
| 847 |
+
reinterpret_cast<AccumulatorAccessType const *>(&frag_AB);
|
| 848 |
+
|
| 849 |
+
AccessTypeBroadcast const *frag_Broadcast_ptr =
|
| 850 |
+
reinterpret_cast<AccessTypeBroadcast const *>(&frag_Broadcast);
|
| 851 |
+
|
| 852 |
+
int const kOutputOpIterations =
|
| 853 |
+
OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
|
| 854 |
+
|
| 855 |
+
CUTLASS_PRAGMA_UNROLL
|
| 856 |
+
for (int i = 0; i < kOutputOpIterations; ++i) {
|
| 857 |
+
|
| 858 |
+
output_op(
|
| 859 |
+
frag_Z_ptr[i],
|
| 860 |
+
frag_T_ptr[i],
|
| 861 |
+
frag_AB_ptr[i],
|
| 862 |
+
frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]);
|
| 863 |
+
}
|
| 864 |
+
}
|
| 865 |
+
|
| 866 |
+
public:
|
| 867 |
+
/// Stream-K reduce helper
|
| 868 |
+
CUTLASS_DEVICE
|
| 869 |
+
void reduce(
|
| 870 |
+
int reduce_fragment_idx, ///< Reduce fragment index
|
| 871 |
+
OutputOp const &output_op, ///< Output operator
|
| 872 |
+
ElementVector const * broadcast_ptr, ///< Broadcast vector
|
| 873 |
+
OutputTileIterator destination_iterator, ///< Tile iterator for destination
|
| 874 |
+
OutputTileIterator source_iterator1, ///< Tile iterator for first source accumulator matrix
|
| 875 |
+
OutputTileIterator source_iterator2, ///< Tile iterator for second source accumulator matrix
|
| 876 |
+
TensorTileIterator tensor_iterator, ///< Threadblock tile iterator for additional tensor operand
|
| 877 |
+
MatrixCoord const &problem_size = ///< Problem size needed to guard against out-of-bounds accesses
|
| 878 |
+
MatrixCoord(Shape::kM, Shape::kN),
|
| 879 |
+
MatrixCoord const &threadblock_offset = ///< Threadblock's initial offset within the problem size space
|
| 880 |
+
MatrixCoord())
|
| 881 |
+
{
|
| 882 |
+
|
| 883 |
+
BroadcastFragment broadcast_fragment;
|
| 884 |
+
load_broadcast_fragment_(broadcast_fragment, broadcast_ptr, problem_size, threadblock_offset);
|
| 885 |
+
|
| 886 |
+
// Initialize/load source-fragment data
|
| 887 |
+
typename OutputTileIterator::Fragment source_fragment1;
|
| 888 |
+
source_fragment1.clear();
|
| 889 |
+
typename OutputTileIterator::Fragment source_fragment2;
|
| 890 |
+
source_fragment2.clear();
|
| 891 |
+
|
| 892 |
+
if (output_op.is_source_needed())
|
| 893 |
+
{
|
| 894 |
+
source_iterator1 += reduce_fragment_idx;
|
| 895 |
+
source_iterator1.load(source_fragment1);
|
| 896 |
+
|
| 897 |
+
source_iterator2 += reduce_fragment_idx;
|
| 898 |
+
source_iterator2.load(source_fragment2);
|
| 899 |
+
}
|
| 900 |
+
|
| 901 |
+
// Load fragment from shared memory
|
| 902 |
+
typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];
|
| 903 |
+
shared_load_iterator_.load(aligned_accum_fragment[0]);
|
| 904 |
+
|
| 905 |
+
// Add fragments shared by other k partitions
|
| 906 |
+
if (kPartitionsK > 1)
|
| 907 |
+
{
|
| 908 |
+
plus <typename SharedLoadIterator::Fragment> add_fragments;
|
| 909 |
+
|
| 910 |
+
CUTLASS_PRAGMA_UNROLL
|
| 911 |
+
for ( int i = 1; i < kPartitionsK; ++i) {
|
| 912 |
+
shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
|
| 913 |
+
shared_load_iterator_.load(aligned_accum_fragment[i]);
|
| 914 |
+
aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
|
| 915 |
+
}
|
| 916 |
+
}
|
| 917 |
+
|
| 918 |
+
//
|
| 919 |
+
// Apply output operation
|
| 920 |
+
//
|
| 921 |
+
|
| 922 |
+
typename OutputTileIterator::Fragment frag_Z;
|
| 923 |
+
typename TensorTileIterator::Fragment frag_T;
|
| 924 |
+
|
| 925 |
+
if (!output_op.is_source_needed()) {
|
| 926 |
+
apply_output_operator_source_not_needed_(
|
| 927 |
+
frag_Z,
|
| 928 |
+
frag_T,
|
| 929 |
+
output_op,
|
| 930 |
+
aligned_accum_fragment[0],
|
| 931 |
+
broadcast_fragment);
|
| 932 |
+
} else {
|
| 933 |
+
apply_output_operator_(
|
| 934 |
+
frag_Z,
|
| 935 |
+
frag_T,
|
| 936 |
+
output_op,
|
| 937 |
+
aligned_accum_fragment[0],
|
| 938 |
+
source_fragment1,
|
| 939 |
+
source_fragment2,
|
| 940 |
+
broadcast_fragment);
|
| 941 |
+
}
|
| 942 |
+
|
| 943 |
+
//
|
| 944 |
+
// Conditionally store fragments
|
| 945 |
+
//
|
| 946 |
+
|
| 947 |
+
if (OutputOp::kStoreZ) {
|
| 948 |
+
destination_iterator += reduce_fragment_idx;
|
| 949 |
+
destination_iterator.store(frag_Z);
|
| 950 |
+
}
|
| 951 |
+
|
| 952 |
+
if (OutputOp::kStoreT) {
|
| 953 |
+
tensor_iterator += reduce_fragment_idx;
|
| 954 |
+
tensor_iterator.store(frag_T);
|
| 955 |
+
}
|
| 956 |
+
}
|
| 957 |
+
};
|
| 958 |
+
|
| 959 |
+
|
| 960 |
+
template <
|
| 961 |
+
typename Shape_,
|
| 962 |
+
typename WarpMmaOperator_,
|
| 963 |
+
int PartitionsK,
|
| 964 |
+
typename OutputTileIterator_,
|
| 965 |
+
typename TensorTileIterator_,
|
| 966 |
+
typename ElementVector_,
|
| 967 |
+
typename AccumulatorFragmentIterator_,
|
| 968 |
+
typename WarpTileIterator_,
|
| 969 |
+
typename SharedLoadIterator_,
|
| 970 |
+
typename OutputOp_,
|
| 971 |
+
typename Padding_,
|
| 972 |
+
int FragmentsPerPartition,
|
| 973 |
+
int IterationsUnroll
|
| 974 |
+
>
|
| 975 |
+
class EpilogueWithBroadcast<
|
| 976 |
+
Shape_,
|
| 977 |
+
WarpMmaOperator_,
|
| 978 |
+
PartitionsK,
|
| 979 |
+
OutputTileIterator_,
|
| 980 |
+
TensorTileIterator_,
|
| 981 |
+
ElementVector_,
|
| 982 |
+
AccumulatorFragmentIterator_,
|
| 983 |
+
WarpTileIterator_,
|
| 984 |
+
SharedLoadIterator_,
|
| 985 |
+
OutputOp_,
|
| 986 |
+
Padding_,
|
| 987 |
+
FragmentsPerPartition,
|
| 988 |
+
IterationsUnroll,
|
| 989 |
+
true
|
| 990 |
+
> :
|
| 991 |
+
public EpilogueBase<
|
| 992 |
+
Shape_,
|
| 993 |
+
typename WarpMmaOperator_::Shape,
|
| 994 |
+
PartitionsK,
|
| 995 |
+
AccumulatorFragmentIterator_,
|
| 996 |
+
WarpTileIterator_,
|
| 997 |
+
Padding_,
|
| 998 |
+
FragmentsPerPartition> {
|
| 999 |
+
|
| 1000 |
+
public:
|
| 1001 |
+
|
| 1002 |
+
using Base = EpilogueBase<
|
| 1003 |
+
Shape_,
|
| 1004 |
+
typename WarpMmaOperator_::Shape,
|
| 1005 |
+
PartitionsK,
|
| 1006 |
+
AccumulatorFragmentIterator_,
|
| 1007 |
+
WarpTileIterator_,
|
| 1008 |
+
Padding_,
|
| 1009 |
+
FragmentsPerPartition>;
|
| 1010 |
+
|
| 1011 |
+
static bool const kIsSingleSource = true;
|
| 1012 |
+
using Shape = Shape_;
|
| 1013 |
+
using WarpMmaOperator = WarpMmaOperator_;
|
| 1014 |
+
static int const kPartitionsK = PartitionsK;
|
| 1015 |
+
using OutputTileIterator = OutputTileIterator_;
|
| 1016 |
+
using TensorTileIterator = TensorTileIterator_;
|
| 1017 |
+
using ElementVector = ElementVector_;
|
| 1018 |
+
using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
|
| 1019 |
+
using WarpTileIterator = WarpTileIterator_;
|
| 1020 |
+
using SharedLoadIterator = SharedLoadIterator_;
|
| 1021 |
+
using OutputOp = OutputOp_;
|
| 1022 |
+
using Padding = Padding_;
|
| 1023 |
+
|
| 1024 |
+
using Layout = layout::RowMajor;
|
| 1025 |
+
using LongIndex = typename Layout::LongIndex;
|
| 1026 |
+
|
| 1027 |
+
/// The complete warp-level accumulator tile
|
| 1028 |
+
using AccumulatorTile = typename Base::AccumulatorTile;
|
| 1029 |
+
|
| 1030 |
+
/// Accumulator element
|
| 1031 |
+
using ElementAccumulator = typename WarpTileIterator::Element;
|
| 1032 |
+
|
| 1033 |
+
/// Compute data type produced by the output op
|
| 1034 |
+
using ElementCompute = typename OutputOp::ElementCompute;
|
| 1035 |
+
|
| 1036 |
+
/// Compute fragment
|
| 1037 |
+
using FragmentCompute = Array<ElementCompute, OutputTileIterator::Fragment::kElements>;
|
| 1038 |
+
|
| 1039 |
+
/// Thread map used by output tile iterators
|
| 1040 |
+
using ThreadMap = typename OutputTileIterator::ThreadMap;
|
| 1041 |
+
|
| 1042 |
+
/// Fragment object used to store the broadcast values
|
| 1043 |
+
using BroadcastFragment = Array<
|
| 1044 |
+
ElementCompute,
|
| 1045 |
+
ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess>;
|
| 1046 |
+
|
| 1047 |
+
/// Output element
|
| 1048 |
+
using ElementOutput = typename OutputTileIterator::Element;
|
| 1049 |
+
|
| 1050 |
+
/// Data type of additional tensor
|
| 1051 |
+
using ElementTensor = typename TensorTileIterator::Element;
|
| 1052 |
+
|
| 1053 |
+
/// Output access size
|
| 1054 |
+
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
|
| 1055 |
+
|
| 1056 |
+
/// Tensor reference to destination tensor
|
| 1057 |
+
using TensorRef = typename OutputTileIterator::TensorRef;
|
| 1058 |
+
|
| 1059 |
+
/// Tensor reference to sync tensor
|
| 1060 |
+
using SyncTensorRef = typename cutlass::TensorRef<int, cutlass::layout::PackedVectorLayout>;
|
| 1061 |
+
|
| 1062 |
+
/// Const tensor reference to source tensor
|
| 1063 |
+
using ConstTensorRef = typename OutputTileIterator::ConstTensorRef;
|
| 1064 |
+
|
| 1065 |
+
/// Array type used to output
|
| 1066 |
+
using OutputAccessType = Array<
|
| 1067 |
+
typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
|
| 1068 |
+
|
| 1069 |
+
/// Array type used by output functor
|
| 1070 |
+
using AccumulatorAccessType = Array<typename WarpTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
|
| 1071 |
+
|
| 1072 |
+
/// Array type used by output functor
|
| 1073 |
+
using ComputeAccessType = Array<ElementCompute, OutputTileIterator::kElementsPerAccess>;
|
| 1074 |
+
|
| 1075 |
+
/// Tensor access type
|
| 1076 |
+
using TensorAccessType = Array<ElementTensor, OutputTileIterator::kElementsPerAccess>;
|
| 1077 |
+
|
| 1078 |
+
/// Number of warps
|
| 1079 |
+
using WarpCount = typename Base::WarpCount;
|
| 1080 |
+
|
| 1081 |
+
/// Shared memory allocation from epilogue base class
|
| 1082 |
+
using BaseSharedStorage = typename Base::SharedStorage;
|
| 1083 |
+
|
| 1084 |
+
static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 ? Base::kFragmentsPerIteration : kPartitionsK;
|
| 1085 |
+
static int constexpr kSmemPointerOffset = Base::SharedStorage::StorageShape::kCount / kSmemTiles;
|
| 1086 |
+
|
| 1087 |
+
/// Used for the broadcast
|
| 1088 |
+
struct BroadcastDetail {
|
| 1089 |
+
|
| 1090 |
+
/// Number of threads per warp
|
| 1091 |
+
static int const kWarpSize = 32;
|
| 1092 |
+
|
| 1093 |
+
static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
|
| 1094 |
+
|
| 1095 |
+
/// Number of distinct scalar column indices handled by each thread
|
| 1096 |
+
static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess;
|
| 1097 |
+
|
| 1098 |
+
/// Number of distinct scalar row indices handled by each thread
|
| 1099 |
+
static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn;
|
| 1100 |
+
|
| 1101 |
+
/// Number of threads per threadblock
|
| 1102 |
+
static int const kThreadCount = kWarpSize * WarpCount::kCount;
|
| 1103 |
+
|
| 1104 |
+
/// Number of distinct threads per row of output tile
|
| 1105 |
+
static int const kThreadsPerRow = (Shape::kN / kColumnsPerThread);
|
| 1106 |
+
|
| 1107 |
+
/// Number of distinct threads which must be reduced during the final reduction phase within the threadblock.
|
| 1108 |
+
static int const kThreadRows = kThreadCount / kThreadsPerRow;
|
| 1109 |
+
|
| 1110 |
+
/// I'm not sure what I meant here.
|
| 1111 |
+
static int const kThreadAccessesPerRow = const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount);
|
| 1112 |
+
|
| 1113 |
+
/// Shape of the shared memory allocation for the epilogue
|
| 1114 |
+
using StorageShape = MatrixShape<
|
| 1115 |
+
kThreadRows,
|
| 1116 |
+
Shape::kN
|
| 1117 |
+
>;
|
| 1118 |
+
|
| 1119 |
+
/// Debug printing
|
| 1120 |
+
CUTLASS_DEVICE
|
| 1121 |
+
static void print() {
|
| 1122 |
+
#if 0
|
| 1123 |
+
printf("BroadcastDetail {\n");
|
| 1124 |
+
printf(
|
| 1125 |
+
" kColumnsPerThread: %d\nkRowsPerThread: %d\n,kThreadCount: %d\nkThreadsPerRow: %d\n"
|
| 1126 |
+
"kThreadRows: %d\nThreadAccessesPerRow: %d\nStorageShape: %d x %d (count: %d)\n",
|
| 1127 |
+
kColumnsPerThread,
|
| 1128 |
+
kRowsPerThread,
|
| 1129 |
+
kThreadCount,
|
| 1130 |
+
kThreadsPerRow,
|
| 1131 |
+
kThreadRows,
|
| 1132 |
+
kThreadAccessesPerRow,
|
| 1133 |
+
StorageShape::kRow,
|
| 1134 |
+
StorageShape::kColumn,
|
| 1135 |
+
StorageShape::kCount
|
| 1136 |
+
);
|
| 1137 |
+
printf("};\n");
|
| 1138 |
+
#endif
|
| 1139 |
+
}
|
| 1140 |
+
};
|
| 1141 |
+
|
| 1142 |
+
/// Shared storage structure (shadows base) with additional SMEM buffer for reduction
|
| 1143 |
+
struct SharedStorage {
|
| 1144 |
+
union {
|
| 1145 |
+
BaseSharedStorage base;
|
| 1146 |
+
};
|
| 1147 |
+
|
| 1148 |
+
CUTLASS_HOST_DEVICE
|
| 1149 |
+
SharedStorage() { }
|
| 1150 |
+
};
|
| 1151 |
+
|
| 1152 |
+
public:
|
| 1153 |
+
|
| 1154 |
+
|
| 1155 |
+
static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements,
|
| 1156 |
+
"Mismatch between shared load iterator and output tile iterator.");
|
| 1157 |
+
|
| 1158 |
+
static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero.");
|
| 1159 |
+
|
| 1160 |
+
static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess),
|
| 1161 |
+
"Divisibility");
|
| 1162 |
+
|
| 1163 |
+
private:
|
| 1164 |
+
|
| 1165 |
+
/// Loads fragment from shared memory aligned with output tensor
|
| 1166 |
+
SharedLoadIterator shared_load_iterator_;
|
| 1167 |
+
|
| 1168 |
+
/// Thread index within the threadblock
|
| 1169 |
+
int thread_idx_;
|
| 1170 |
+
|
| 1171 |
+
public:
|
| 1172 |
+
|
| 1173 |
+
/// Constructor
|
| 1174 |
+
CUTLASS_DEVICE
|
| 1175 |
+
EpilogueWithBroadcast(
|
| 1176 |
+
SharedStorage &shared_storage, ///< Shared storage object
|
| 1177 |
+
int thread_idx, ///< ID of a thread within the threadblock
|
| 1178 |
+
int warp_idx, ///< ID of warp within threadblock
|
| 1179 |
+
int lane_idx ///< Id of thread within warp
|
| 1180 |
+
):
|
| 1181 |
+
Base(shared_storage.base, thread_idx, warp_idx, lane_idx),
|
| 1182 |
+
shared_load_iterator_(shared_storage.base.reference(), thread_idx),
|
| 1183 |
+
thread_idx_(thread_idx)
|
| 1184 |
+
{
|
| 1185 |
+
|
| 1186 |
+
}
|
| 1187 |
+
|
| 1188 |
+
/// Streams the result to global memory
|
| 1189 |
+
CUTLASS_DEVICE
|
| 1190 |
+
void operator()(
|
| 1191 |
+
OutputOp const &output_op, ///< Output operator
|
| 1192 |
+
ElementVector const * broadcast_ptr, ///< Broadcast vector
|
| 1193 |
+
OutputTileIterator destination_iterator, ///< Tile iterator for destination
|
| 1194 |
+
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
|
| 1195 |
+
OutputTileIterator source_iterator, ///< Tile iterator for source accumulator matrix
|
| 1196 |
+
TensorTileIterator tensor_iterator, ///< Threadblock tile iterator for additional tensor operand
|
| 1197 |
+
MatrixCoord const &problem_size = ///< Problem size needed to guard against out-of-bounds accesses
|
| 1198 |
+
MatrixCoord(Shape::kM, Shape::kN),
|
| 1199 |
+
MatrixCoord const &threadblock_offset = ///< Threadblock's initial offset within the problem size space
|
| 1200 |
+
MatrixCoord()) {
|
| 1201 |
+
|
| 1202 |
+
BroadcastFragment broadcast_fragment;
|
| 1203 |
+
|
| 1204 |
+
load_broadcast_fragment_(broadcast_fragment, broadcast_ptr, problem_size, threadblock_offset);
|
| 1205 |
+
|
| 1206 |
+
if (!output_op.is_source_needed()) {
|
| 1207 |
+
compute_source_not_needed_(
|
| 1208 |
+
output_op,
|
| 1209 |
+
broadcast_fragment,
|
| 1210 |
+
destination_iterator,
|
| 1211 |
+
accumulators,
|
| 1212 |
+
tensor_iterator);
|
| 1213 |
+
}
|
| 1214 |
+
else {
|
| 1215 |
+
compute_source_needed_(
|
| 1216 |
+
output_op,
|
| 1217 |
+
broadcast_fragment,
|
| 1218 |
+
destination_iterator,
|
| 1219 |
+
accumulators,
|
| 1220 |
+
source_iterator,
|
| 1221 |
+
tensor_iterator);
|
| 1222 |
+
}
|
| 1223 |
+
}
|
| 1224 |
+
|
| 1225 |
+
private:
|
| 1226 |
+
|
| 1227 |
+
CUTLASS_DEVICE
|
| 1228 |
+
void load_broadcast_fragment_(
|
| 1229 |
+
BroadcastFragment & broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns
|
| 1230 |
+
ElementVector const * broadcast_ptr, ///< Broadcast vector
|
| 1231 |
+
MatrixCoord const &problem_size, ///< Problem size needed to guard against out-of-bounds accesses
|
| 1232 |
+
MatrixCoord const &threadblock_offset ///< Threadblock's initial offset within the problem size space
|
| 1233 |
+
) {
|
| 1234 |
+
|
| 1235 |
+
broadcast_fragment.clear();
|
| 1236 |
+
|
| 1237 |
+
// If no pointer is supplied, set with all zeros and avoid memory accesses
|
| 1238 |
+
if (!broadcast_ptr) {
|
| 1239 |
+
return;
|
| 1240 |
+
}
|
| 1241 |
+
|
| 1242 |
+
int thread_initial_column = ThreadMap::initial_offset(thread_idx_).column();
|
| 1243 |
+
|
| 1244 |
+
int thread_column_idx = threadblock_offset.column() + thread_initial_column;
|
| 1245 |
+
broadcast_ptr += thread_initial_column;
|
| 1246 |
+
|
| 1247 |
+
NumericArrayConverter<ElementCompute, ElementVector, BroadcastDetail::kElementsPerAccess> converter;
|
| 1248 |
+
using AccessType = AlignedArray<ElementVector, BroadcastDetail::kElementsPerAccess>;
|
| 1249 |
+
using ComputeFragmentType = Array<ElementCompute, BroadcastDetail::kElementsPerAccess>;
|
| 1250 |
+
|
| 1251 |
+
ComputeFragmentType *frag_ptr = reinterpret_cast<ComputeFragmentType *>(&broadcast_fragment);
|
| 1252 |
+
|
| 1253 |
+
CUTLASS_PRAGMA_UNROLL
|
| 1254 |
+
for (int j = 0; j < ThreadMap::Iterations::kColumn; ++j) {
|
| 1255 |
+
|
| 1256 |
+
AccessType loaded;
|
| 1257 |
+
|
| 1258 |
+
loaded.clear();
|
| 1259 |
+
|
| 1260 |
+
if (thread_column_idx < problem_size.column()) {
|
| 1261 |
+
loaded = *reinterpret_cast<AccessType const *>(broadcast_ptr);
|
| 1262 |
+
}
|
| 1263 |
+
|
| 1264 |
+
ComputeFragmentType cvt = converter(loaded);
|
| 1265 |
+
frag_ptr[j] = cvt;
|
| 1266 |
+
|
| 1267 |
+
thread_column_idx += ThreadMap::Delta::kColumn;
|
| 1268 |
+
broadcast_ptr += ThreadMap::Delta::kColumn;
|
| 1269 |
+
}
|
| 1270 |
+
}
|
| 1271 |
+
|
| 1272 |
+
template <class Seq>
|
| 1273 |
+
struct acc2smem_source_not_needed;
|
| 1274 |
+
|
| 1275 |
+
template <size_t... Seq>
|
| 1276 |
+
struct acc2smem_source_not_needed<cutlass::index_sequence<Seq...>> {
|
| 1277 |
+
template <int Advance>
|
| 1278 |
+
CUTLASS_DEVICE static void helper(AccumulatorFragmentIterator accum_fragment_iterator,
|
| 1279 |
+
WarpTileIterator &warp_tile_iterator) {
|
| 1280 |
+
CUTLASS_PRAGMA_UNROLL
|
| 1281 |
+
for (int i = 0; i < Advance; i++) {
|
| 1282 |
+
++accum_fragment_iterator;
|
| 1283 |
+
}
|
| 1284 |
+
|
| 1285 |
+
CUTLASS_PRAGMA_UNROLL
|
| 1286 |
+
for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {
|
| 1287 |
+
typename AccumulatorFragmentIterator::Fragment accum_fragment;
|
| 1288 |
+
|
| 1289 |
+
accum_fragment_iterator.load(accum_fragment);
|
| 1290 |
+
++accum_fragment_iterator;
|
| 1291 |
+
|
| 1292 |
+
warp_tile_iterator.store(accum_fragment);
|
| 1293 |
+
if (p < Base::kFragmentsPerIteration - 1) {
|
| 1294 |
+
warp_tile_iterator.add_pointer_offset(kSmemPointerOffset);
|
| 1295 |
+
}
|
| 1296 |
+
}
|
| 1297 |
+
|
| 1298 |
+
if (Base::kFragmentsPerIteration > 1) {
|
| 1299 |
+
warp_tile_iterator.add_pointer_offset(kSmemPointerOffset *
|
| 1300 |
+
(1 - Base::kFragmentsPerIteration));
|
| 1301 |
+
}
|
| 1302 |
+
}
|
| 1303 |
+
|
| 1304 |
+
CUTLASS_DEVICE
|
| 1305 |
+
static void push(size_t pos,
|
| 1306 |
+
AccumulatorFragmentIterator const &iterator_begin,
|
| 1307 |
+
WarpTileIterator &warp_tile_iterator) {
|
| 1308 |
+
int dummy[] = {
|
| 1309 |
+
(pos == (Seq * Base::kFragmentsPerIteration)) &&
|
| 1310 |
+
(helper<Seq * Base::kFragmentsPerIteration>(iterator_begin, warp_tile_iterator), 0)...};
|
| 1311 |
+
|
| 1312 |
+
CUTLASS_UNUSED(dummy[0]);
|
| 1313 |
+
}
|
| 1314 |
+
};
|
| 1315 |
+
|
| 1316 |
+
/// Streams the result to global memory for the case in which the output
/// operator does not read the source tensor. Accumulators are staged through
/// shared memory so threads can cooperatively rearrange data into the output
/// tile layout before the output operator and the final stores are applied.
CUTLASS_DEVICE
void compute_source_not_needed_(
  OutputOp const &output_op,                    ///< Output operator
  BroadcastFragment const &broadcast_fragment,  ///< Fragment containing the accumulated partial reduction over columns
  OutputTileIterator destination_iterator,      ///< Tile iterator for destination
  AccumulatorTile const &accumulators,          ///< Complete warp-level accumulator tile
  TensorTileIterator tensor_iterator            ///< Threadblock tile iterator for additional tensor operand
) {

  //
  // Iterator over warp-level accumulator fragment
  //

  AccumulatorFragmentIterator accum_fragment_iterator(accumulators);

  //
  // Iterate over accumulator tile, kFragmentsPerIteration fragments at a time
  //

  // CUTLASS_PRAGMA_UNROLL
  #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations / Base::kFragmentsPerIteration : 1)
  for (int iter = 0; iter < OutputTileIterator::kIterations; iter += Base::kFragmentsPerIteration) {

    //
    // Convert and store fragment
    //

    // Ensure all threads have finished reading shared memory from the
    // previous iteration before it is overwritten.
    __syncthreads();

    acc2smem_source_not_needed<
        cutlass::make_index_sequence<OutputTileIterator::kIterations /
                                     Base::kFragmentsPerIteration>>::push(iter,
                                                                          accum_fragment_iterator,
                                                                          this->warp_tile_iterator_);

    // Ensure the staged accumulator fragments are visible before loading.
    __syncthreads();

    //
    // Load fragments from shared memory
    //

    CUTLASS_PRAGMA_UNROLL
    for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {

      typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];

      shared_load_iterator_.load(aligned_accum_fragment[0]);

      if (p < Base::kFragmentsPerIteration - 1) {
        // Advance to the next staged fragment within this iteration.
        shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
      }
      else if (kPartitionsK > 1) {

        // Split-K: accumulate the partial sums contributed by the other
        // k-partitions before applying the output operator.
        plus <typename SharedLoadIterator::Fragment> add_fragments;

        CUTLASS_PRAGMA_UNROLL
        for ( int i = 1; i < kPartitionsK; ++i) {
          shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
          shared_load_iterator_.load(aligned_accum_fragment[i]);
          aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
        }

        // Restore the iterator back to the first k-partition.
        shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset);
      }

      //
      // Apply output operation
      //

      typename OutputTileIterator::Fragment frag_Z;
      typename TensorTileIterator::Fragment frag_T;

      apply_output_operator_source_not_needed_(
        frag_Z,
        frag_T,
        output_op,
        aligned_accum_fragment[0],
        broadcast_fragment);

      //
      // Conditionally store fragments
      //

      if (OutputOp::kStoreZ) {
        destination_iterator.store(frag_Z);
        ++destination_iterator;
      }

      if (OutputOp::kStoreT) {
        tensor_iterator.store(frag_T);
        ++tensor_iterator;
      }
    }

    if (Base::kFragmentsPerIteration > 1) {
      // Rewind the shared-memory load iterator for the next outer iteration.
      shared_load_iterator_.add_pointer_offset(kSmemPointerOffset * (1 - Base::kFragmentsPerIteration));
    }
  }
}
|
| 1418 |
+
|
| 1419 |
+
|
| 1420 |
+
template<class Seq>
|
| 1421 |
+
struct acc2smem_source_needed;
|
| 1422 |
+
|
| 1423 |
+
template <size_t... Seq>
|
| 1424 |
+
struct acc2smem_source_needed<cutlass::index_sequence<Seq...>> {
|
| 1425 |
+
template<int Advance>
|
| 1426 |
+
CUTLASS_DEVICE
|
| 1427 |
+
static void helper(AccumulatorFragmentIterator accum_fragment_iterator,
|
| 1428 |
+
WarpTileIterator &warp_tile_iterator) {
|
| 1429 |
+
CUTLASS_PRAGMA_UNROLL
|
| 1430 |
+
for (int i = 0; i < Advance; i++) {
|
| 1431 |
+
++accum_fragment_iterator;
|
| 1432 |
+
}
|
| 1433 |
+
|
| 1434 |
+
typename AccumulatorFragmentIterator::Fragment accum_fragment;
|
| 1435 |
+
accum_fragment_iterator.load(accum_fragment);
|
| 1436 |
+
warp_tile_iterator.store(accum_fragment);
|
| 1437 |
+
}
|
| 1438 |
+
|
| 1439 |
+
CUTLASS_DEVICE
|
| 1440 |
+
static void push(size_t pos,
|
| 1441 |
+
AccumulatorFragmentIterator const &iterator_begin,
|
| 1442 |
+
WarpTileIterator &warp_tile_iterator) {
|
| 1443 |
+
int dummy[] = {(pos == Seq) && (helper<Seq>(iterator_begin, warp_tile_iterator), 0)...};
|
| 1444 |
+
}
|
| 1445 |
+
};
|
| 1446 |
+
|
| 1447 |
+
|
| 1448 |
+
/// Streams the result to global memory for the case in which the output
/// operator reads the source tensor. Each iteration loads one source
/// fragment, stages accumulators through shared memory, applies the output
/// operator, and conditionally stores the Z and T fragments.
CUTLASS_DEVICE
void compute_source_needed_(
  OutputOp const &output_op,                    ///< Output operator
  BroadcastFragment const &broadcast_fragment,  ///< Fragment containing the accumulated partial reduction over columns
  OutputTileIterator destination_iterator,      ///< Tile iterator for destination
  AccumulatorTile const &accumulators,          ///< Complete warp-level accumulator tile
  OutputTileIterator source_iterator,           ///< Tile iterator for source accumulator matrix
  TensorTileIterator tensor_iterator            ///< Threadblock tile iterator for additional tensor operand
) {

  typename OutputTileIterator::Fragment source_fragment;
  source_fragment.clear();

  //
  // Iterator over warp-level accumulator fragment
  //

  AccumulatorFragmentIterator accum_fragment_iterator(accumulators);

  //
  // Iterate over accumulator tile
  //

  #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1)
  for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {

    //
    // Load the source
    //

    source_iterator.load(source_fragment);
    ++source_iterator;

    //
    // Convert and store fragment
    //

    // Ensure all threads have finished reading shared memory from the
    // previous iteration before it is overwritten.
    __syncthreads();

    acc2smem_source_needed<cutlass::make_index_sequence<OutputTileIterator::kIterations>>::push(
        iter, accum_fragment_iterator, this->warp_tile_iterator_);

    // Ensure the staged accumulator fragment is visible before loading.
    __syncthreads();

    //
    // Load fragments from shared memory
    //

    typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];

    shared_load_iterator_.load(aligned_accum_fragment[0]);

    // If the number of k-slices is > 1 - perform a reduction amongst the k-slices
    if (kPartitionsK > 1)
    {
      plus <typename SharedLoadIterator::Fragment> add_fragments;
      const int tile_row_offset = Base::SharedStorage::StorageShape::kRow / PartitionsK;

      CUTLASS_PRAGMA_UNROLL
      for ( int i = 1; i < kPartitionsK; ++i) {
        shared_load_iterator_.add_tile_offset({tile_row_offset , 0});
        shared_load_iterator_.load(aligned_accum_fragment[i]);
        aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
      }

      // Restore the iterator back to the first k-partition.
      shared_load_iterator_.add_tile_offset({-1 * (kPartitionsK-1) * tile_row_offset, 0});
    }

    //
    // Apply output operation
    //

    typename OutputTileIterator::Fragment frag_Z;
    typename TensorTileIterator::Fragment frag_T;

    apply_output_operator_(
      frag_Z,
      frag_T,
      output_op,
      aligned_accum_fragment[0],
      source_fragment,
      broadcast_fragment);

    //
    // Conditionally store fragments
    //

    if (OutputOp::kStoreZ) {
      destination_iterator.store(frag_Z);
      ++destination_iterator;
    }

    if (OutputOp::kStoreT) {
      tensor_iterator.store(frag_T);
      ++tensor_iterator;
    }
  }
}
|
| 1547 |
+
|
| 1548 |
+
/// Helper to invoke the output functor over each vector of output
|
| 1549 |
+
CUTLASS_DEVICE
|
| 1550 |
+
void apply_output_operator_(
|
| 1551 |
+
typename OutputTileIterator::Fragment &frag_Z,
|
| 1552 |
+
typename TensorTileIterator::Fragment &frag_T,
|
| 1553 |
+
OutputOp const &output_op,
|
| 1554 |
+
typename SharedLoadIterator::Fragment const &frag_AB,
|
| 1555 |
+
typename OutputTileIterator::Fragment const &frag_C,
|
| 1556 |
+
BroadcastFragment const &frag_Broadcast) {
|
| 1557 |
+
|
| 1558 |
+
using AccessTypeZ = Array<typename OutputTileIterator::Element, kElementsPerAccess>;
|
| 1559 |
+
using AccessTypeT = Array<typename TensorTileIterator::Element, kElementsPerAccess>;
|
| 1560 |
+
using AccessTypeBroadcast = Array<ElementCompute, kElementsPerAccess>;
|
| 1561 |
+
|
| 1562 |
+
AccessTypeZ *frag_Z_ptr = reinterpret_cast<AccessTypeZ *>(&frag_Z);
|
| 1563 |
+
AccessTypeT *frag_T_ptr = reinterpret_cast<AccessTypeT *>(&frag_T);
|
| 1564 |
+
|
| 1565 |
+
AccumulatorAccessType const *frag_AB_ptr =
|
| 1566 |
+
reinterpret_cast<AccumulatorAccessType const *>(&frag_AB);
|
| 1567 |
+
|
| 1568 |
+
OutputAccessType const *frag_C_ptr =
|
| 1569 |
+
reinterpret_cast<OutputAccessType const *>(&frag_C);
|
| 1570 |
+
|
| 1571 |
+
AccessTypeBroadcast const *frag_Broadcast_ptr =
|
| 1572 |
+
reinterpret_cast<AccessTypeBroadcast const *>(&frag_Broadcast);
|
| 1573 |
+
|
| 1574 |
+
int const kOutputOpIterations =
|
| 1575 |
+
OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
|
| 1576 |
+
|
| 1577 |
+
CUTLASS_PRAGMA_UNROLL
|
| 1578 |
+
for (int i = 0; i < kOutputOpIterations; ++i) {
|
| 1579 |
+
output_op(
|
| 1580 |
+
frag_Z_ptr[i],
|
| 1581 |
+
frag_T_ptr[i],
|
| 1582 |
+
frag_AB_ptr[i],
|
| 1583 |
+
frag_C_ptr[i],
|
| 1584 |
+
frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]);
|
| 1585 |
+
}
|
| 1586 |
+
}
|
| 1587 |
+
|
| 1588 |
+
/// Helper to invoke the output functor over each vector of output
|
| 1589 |
+
CUTLASS_DEVICE
|
| 1590 |
+
void apply_output_operator_source_not_needed_(
|
| 1591 |
+
typename OutputTileIterator::Fragment &frag_Z,
|
| 1592 |
+
typename TensorTileIterator::Fragment &frag_T,
|
| 1593 |
+
OutputOp const &output_op,
|
| 1594 |
+
typename SharedLoadIterator::Fragment const &frag_AB,
|
| 1595 |
+
BroadcastFragment const &frag_Broadcast) {
|
| 1596 |
+
|
| 1597 |
+
using AccessTypeZ = Array<typename OutputTileIterator::Element, kElementsPerAccess>;
|
| 1598 |
+
using AccessTypeT = Array<typename TensorTileIterator::Element, kElementsPerAccess>;
|
| 1599 |
+
using AccessTypeBroadcast = Array<ElementCompute, kElementsPerAccess>;
|
| 1600 |
+
|
| 1601 |
+
AccessTypeZ *frag_Z_ptr = reinterpret_cast<AccessTypeZ *>(&frag_Z);
|
| 1602 |
+
AccessTypeT *frag_T_ptr = reinterpret_cast<AccessTypeT *>(&frag_T);
|
| 1603 |
+
|
| 1604 |
+
AccumulatorAccessType const *frag_AB_ptr =
|
| 1605 |
+
reinterpret_cast<AccumulatorAccessType const *>(&frag_AB);
|
| 1606 |
+
|
| 1607 |
+
AccessTypeBroadcast const *frag_Broadcast_ptr =
|
| 1608 |
+
reinterpret_cast<AccessTypeBroadcast const *>(&frag_Broadcast);
|
| 1609 |
+
|
| 1610 |
+
int const kOutputOpIterations =
|
| 1611 |
+
OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
|
| 1612 |
+
|
| 1613 |
+
CUTLASS_PRAGMA_UNROLL
|
| 1614 |
+
for (int i = 0; i < kOutputOpIterations; ++i) {
|
| 1615 |
+
|
| 1616 |
+
output_op(
|
| 1617 |
+
frag_Z_ptr[i],
|
| 1618 |
+
frag_T_ptr[i],
|
| 1619 |
+
frag_AB_ptr[i],
|
| 1620 |
+
frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]);
|
| 1621 |
+
}
|
| 1622 |
+
}
|
| 1623 |
+
|
| 1624 |
+
|
| 1625 |
+
public:
|
| 1626 |
+
/// Stream-K reduce helper. Applies the output operator to a single fragment
/// (selected by reduce_fragment_idx) whose accumulator data is already
/// resident in shared memory, then conditionally stores the Z and T outputs.
CUTLASS_DEVICE
void reduce(
  int reduce_fragment_idx,                    ///< Reduce fragment index
  OutputOp const &output_op,                  ///< Output operator
  ElementVector const * broadcast_ptr,        ///< Broadcast vector
  OutputTileIterator destination_iterator,    ///< Tile iterator for destination
  OutputTileIterator source_iterator,         ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
  TensorTileIterator tensor_iterator,         ///< Threadblock tile iterator for additional tensor operand
  MatrixCoord const &problem_size =           ///< Problem size needed to guard against out-of-bounds accesses
      MatrixCoord(Shape::kM, Shape::kN),
  MatrixCoord const &threadblock_offset =     ///< Threadblock's initial offset within the problem size space
      MatrixCoord())
{

  BroadcastFragment broadcast_fragment;
  load_broadcast_fragment_(broadcast_fragment, broadcast_ptr, problem_size, threadblock_offset);

  // Initialize/load source-fragment data
  typename OutputTileIterator::Fragment source_fragment;
  source_fragment.clear();

  if (output_op.is_source_needed())
  {
    // Advance to the fragment being reduced before loading the source.
    source_iterator += reduce_fragment_idx;
    source_iterator.load(source_fragment);
  }

  // Load fragment from shared memory
  typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];
  shared_load_iterator_.load(aligned_accum_fragment[0]);

  // Add fragments shared by other k partitions
  if (kPartitionsK > 1)
  {
    plus <typename SharedLoadIterator::Fragment> add_fragments;

    CUTLASS_PRAGMA_UNROLL
    for ( int i = 1; i < kPartitionsK; ++i) {
      shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
      shared_load_iterator_.load(aligned_accum_fragment[i]);
      aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
    }
    // NOTE(review): unlike compute_source_not_needed_, the iterator offset is
    // not rewound here — presumably acceptable because the iterator is not
    // reused afterwards; confirm against callers.
  }

  //
  // Apply output operation
  //

  typename OutputTileIterator::Fragment frag_Z;
  typename TensorTileIterator::Fragment frag_T;

  if (!output_op.is_source_needed()) {
    apply_output_operator_source_not_needed_(
      frag_Z,
      frag_T,
      output_op,
      aligned_accum_fragment[0],
      broadcast_fragment);
  } else {
    apply_output_operator_(
      frag_Z,
      frag_T,
      output_op,
      aligned_accum_fragment[0],
      source_fragment,
      broadcast_fragment);
  }

  //
  // Conditionally store fragments
  //

  if (OutputOp::kStoreZ) {
    destination_iterator.store(frag_Z);
    ++destination_iterator;
  }

  if (OutputOp::kStoreT) {
    tensor_iterator.store(frag_T);
    ++tensor_iterator;
  }
}
|
| 1709 |
+
};
|
| 1710 |
+
|
| 1711 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1712 |
+
|
| 1713 |
+
} // namespace threadblock
|
| 1714 |
+
} // namespace epilogue
|
| 1715 |
+
} // namespace cutlass
|
| 1716 |
+
|
| 1717 |
+
////////////////////////////////////////////////////////////////////////////////
|