Kernels
danieldk HF Staff committed on
Commit
639dc8e
·
verified ·
1 Parent(s): 0d682db

Build uploaded using `kernels` (batch 10/10).

Browse files
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50) hide show
  1. build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h +985 -0
  2. build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_bias_relu.h +610 -0
  3. build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_clamp.h +684 -0
  4. build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_dgelu.h +250 -0
  5. build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_drelu.h +452 -0
  6. build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_gelu.h +70 -0
  7. build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_generic.h +265 -0
  8. build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_generic_with_scaling.h +325 -0
  9. build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_hardswish.h +69 -0
  10. build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_leaky_relu.h +231 -0
  11. build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_params.h +75 -0
  12. build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_planar_complex.h +236 -0
  13. build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_relu.h +572 -0
  14. build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_relu0.h +543 -0
  15. build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_residual_block.h +301 -0
  16. build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_sigmoid.h +70 -0
  17. build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_silu.h +69 -0
  18. build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_tensor_broadcast.hpp +253 -0
  19. build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_with_elementwise.h +234 -0
  20. build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/reduction_op.h +97 -0
  21. build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/scale_type.h +66 -0
  22. build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op.h +255 -0
  23. build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op_blas3.h +264 -0
  24. build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_direct_store.h +74 -0
  25. build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_planar_complex.h +241 -0
  26. build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_simt.h +443 -0
  27. build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h +904 -0
  28. build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op_blas3.h +175 -0
  29. build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h +337 -0
  30. build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_with_absmax.h +126 -0
  31. build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h +376 -0
  32. build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_with_reduction.h +177 -0
  33. build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h +165 -0
  34. build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_simt.h +127 -0
  35. build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_tensor_op.h +208 -0
  36. build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_volta_tensor_op.h +228 -0
  37. build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_wmma_tensor_op.h +113 -0
  38. build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/direct_store_epilogue_iterator.h +142 -0
  39. build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/epilogue.h +548 -0
  40. build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/epilogue_base.h +234 -0
  41. build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/epilogue_base_streamk.h +197 -0
  42. build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/epilogue_depthwise.h +335 -0
  43. build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/epilogue_direct_store.h +347 -0
  44. build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/epilogue_gemm_k_reduction.h +206 -0
  45. build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/epilogue_planar_complex.h +401 -0
  46. build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/epilogue_smem_accumulator.h +224 -0
  47. build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/epilogue_streamk_with_broadcast.h +443 -0
  48. build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h +513 -0
  49. build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/epilogue_with_absmax.h +922 -0
  50. build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h +1717 -0
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h ADDED
@@ -0,0 +1,985 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ /*! \file
33
+ \brief Functor performing linear combination operations used by epilogues.
34
+ */
35
+
36
+ #pragma once
37
+
38
+ #include "cutlass/cutlass.h"
39
+ #include "cutlass/numeric_types.h"
40
+ #include "cutlass/array.h"
41
+ #include "cutlass/functional.h"
42
+ #include "cutlass/numeric_conversion.h"
43
+ #include "cutlass/platform/platform.h"
44
+
45
+ #include "cutlass/epilogue/thread/activation.h"
46
+ #include "cutlass/epilogue/thread/scale_type.h"
47
+
48
+ /////////////////////////////////////////////////////////////////////////////////////////////////
49
+
50
+ namespace cutlass {
51
+ namespace epilogue {
52
+ namespace thread {
53
+
54
+ /////////////////////////////////////////////////////////////////////////////////////////////////
55
+
56
+ namespace detail {
57
+
58
+ struct EmptyArguments {};
59
+
60
+ template<class T, class = void>
61
+ struct ElementwiseOpDispatcher {
62
+ using Arguments = EmptyArguments;
63
+
64
+ T op;
65
+
66
+ CUTLASS_HOST_DEVICE
67
+ ElementwiseOpDispatcher(Arguments) {}
68
+
69
+ template <typename ValueType>
70
+ CUTLASS_HOST_DEVICE
71
+ ValueType operator()(ValueType value) {
72
+ return op(value);
73
+ }
74
+ };
75
+
76
+ template<class T>
77
+ struct ElementwiseOpDispatcher<T, std::void_t<typename T::Arguments>> {
78
+ using Arguments = typename T::Arguments;
79
+
80
+ Arguments args;
81
+ T op;
82
+
83
+ CUTLASS_HOST_DEVICE
84
+ ElementwiseOpDispatcher(Arguments args_):args(args_) {}
85
+
86
+ template <typename ValueType>
87
+ CUTLASS_HOST_DEVICE
88
+ ValueType operator()(ValueType value) {
89
+ return op(value, args);
90
+ }
91
+ };
92
+
93
+ }
94
+
95
+ /////////////////////////////////////////////////////////////////////////////////////////////////
96
+
97
+ /// This base class is meant to define the concept required of the
98
+ /// EpilogueWithBroadcast::OutputOp
99
+ template <
100
+ typename ElementC_,
101
+ typename ElementAccumulator_,
102
+ typename ElementCompute_,
103
+ typename ElementZ_,
104
+ typename ElementT_,
105
+ int ElementsPerAccess,
106
+ typename ElementwiseOp_ = Identity<ElementCompute_>,
107
+ typename BinaryOp_ = plus<ElementCompute_>,
108
+ bool StoreT_ = true,
109
+ typename ElementVector_ = ElementC_
110
+ >
111
+ class LinearCombinationBiasElementwise {
112
+ public:
113
+
114
+ using ElementOutput = ElementC_;
115
+ using ElementD = ElementOutput;
116
+ using ElementC = ElementC_;
117
+ using ElementAccumulator = ElementAccumulator_;
118
+ using ElementCompute = ElementCompute_;
119
+ using ElementScalar = ElementCompute;
120
+ using ElementZ = ElementZ_;
121
+ using ElementT = ElementT_;
122
+ using ElementVector = ElementVector_;
123
+ static int const kElementsPerAccess = ElementsPerAccess;
124
+ static int const kCount = kElementsPerAccess;
125
+
126
+ /// Follow cutlass3x EVT aliases
127
+ static bool const IsEltActSupported = true;
128
+
129
+ using ElementwiseOp = ElementwiseOp_;
130
+ using BinaryOp = BinaryOp_;
131
+
132
+ using ElementwiseOpDispatcher = detail::ElementwiseOpDispatcher<ElementwiseOp>;
133
+ using ElementwiseArguments = typename ElementwiseOpDispatcher::Arguments;
134
+
135
+ // Indicates that this epilogue applies only one binary operation
136
+ static bool const kIsSingleSource = true;
137
+
138
+
139
+ using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;
140
+ using FragmentCompute = Array<ElementCompute, kElementsPerAccess>;
141
+ using FragmentC = Array<ElementC, kElementsPerAccess>;
142
+ using FragmentZ = Array<ElementZ, kElementsPerAccess>;
143
+ using FragmentT = Array<ElementT, kElementsPerAccess>;
144
+
145
+ // Definitions needed for collective epilogue
146
+ using FragmentSource = FragmentC;
147
+ using FragmentOutput = FragmentZ;
148
+ using ElementBias = ElementVector;
149
+ using FragmentBias = Array<ElementBias, kElementsPerAccess>;
150
+ using ActivationFn = ElementwiseOp;
151
+ static const ScaleType::Kind kScale = ScaleType::Default;
152
+
153
+ static bool const kIsHeavy = kIsHeavy_member_or_false<ElementwiseOp>::value;
154
+
155
+ /// If true, the 'Z' tensor is stored
156
+ static bool const kStoreZ = true;
157
+
158
+ /// If true, the 'T' tensor is stored
159
+ static bool const kStoreT = StoreT_;
160
+
161
+ /// Host-constructable parameters structure
162
+ struct Params {
163
+
164
+ ElementCompute alpha; ///< scales accumulators
165
+ ElementCompute beta; ///< scales source tensor
166
+ ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
167
+ ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
168
+ ElementwiseArguments elementwise; ///< Arguments for elementwise operation
169
+
170
+ //
171
+ // Methods
172
+ //
173
+
174
+ CUTLASS_HOST_DEVICE
175
+ Params():
176
+ alpha(ElementCompute(1)),
177
+ beta(ElementCompute(0)),
178
+ alpha_ptr(nullptr),
179
+ beta_ptr(nullptr) { }
180
+
181
+ CUTLASS_HOST_DEVICE
182
+ Params(
183
+ ElementCompute alpha,
184
+ ElementCompute beta,
185
+ ElementwiseArguments elementwise_ = ElementwiseArguments{}
186
+ ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr), elementwise(elementwise_) {
187
+
188
+ }
189
+
190
+ CUTLASS_HOST_DEVICE
191
+ Params(
192
+ ElementCompute alpha
193
+ ): alpha(alpha), beta(0), alpha_ptr(nullptr), beta_ptr(nullptr) {
194
+
195
+ }
196
+
197
+ CUTLASS_HOST_DEVICE
198
+ Params(
199
+ ElementCompute const *alpha_ptr,
200
+ ElementCompute const *beta_ptr,
201
+ ElementwiseArguments elementwise_ = ElementwiseArguments{}
202
+ ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr), elementwise(elementwise_) {
203
+
204
+ }
205
+
206
+ CUTLASS_HOST_DEVICE
207
+ Params(
208
+ ElementCompute const *alpha_ptr
209
+ ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(nullptr) {
210
+
211
+ }
212
+ };
213
+
214
+ private:
215
+
216
+ //
217
+ // Data members
218
+ //
219
+
220
+ ElementCompute alpha_;
221
+ ElementCompute beta_;
222
+ ElementwiseArguments const &elementwise_;
223
+ bool skip_elementwise_;
224
+
225
+ public:
226
+
227
+ //
228
+ // Methods
229
+ //
230
+
231
+ /// Constructor from Params
232
+ CUTLASS_HOST_DEVICE
233
+ LinearCombinationBiasElementwise(Params const &params): elementwise_(params.elementwise) {
234
+
235
+ alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
236
+ beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
237
+ skip_elementwise_ = false;
238
+ }
239
+
240
+ /// Returns true if source is needed
241
+ CUTLASS_HOST_DEVICE
242
+ bool is_source_needed() const {
243
+ return beta_ != ElementCompute(0);
244
+ }
245
+
246
+ /// Functionally required for serial reduction in the epilogue
247
+ CUTLASS_HOST_DEVICE
248
+ void set_k_partition(int k_partition, int k_partition_count) {
249
+ if (k_partition) {
250
+ beta_ = ElementCompute(1);
251
+ }
252
+
253
+ if (k_partition != k_partition_count - 1) {
254
+ skip_elementwise_ = true;
255
+ }
256
+ }
257
+
258
+ /// Applies the operation when elementwise_op require arguments and is_source_needed() is true
259
+ template <typename ElementwiseArgs>
260
+ CUTLASS_HOST_DEVICE
261
+ void operator()(
262
+ FragmentZ &frag_Z,
263
+ FragmentT &frag_T,
264
+ FragmentAccumulator const &AB,
265
+ FragmentC const &frag_C,
266
+ FragmentCompute const &V,
267
+ ElementwiseArgs const &elementwise_args) const {
268
+
269
+ ElementwiseOp elementwise_op;
270
+ BinaryOp binary_op;
271
+
272
+ FragmentCompute tmp_Accum = NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess>()(AB);
273
+ FragmentCompute tmp_C = NumericArrayConverter<ElementCompute, ElementC, kElementsPerAccess>()(frag_C);
274
+ FragmentCompute result_Z;
275
+ FragmentCompute result_T;
276
+
277
+ CUTLASS_PRAGMA_UNROLL
278
+ for (int i = 0; i < kElementsPerAccess; ++i) {
279
+ ElementCompute z = binary_op(alpha_ * tmp_Accum[i] + beta_ * tmp_C[i], V[i]);
280
+ result_T[i] = z;
281
+ result_Z[i] = skip_elementwise_ ? z : elementwise_op(z, elementwise_args);
282
+ }
283
+
284
+ NumericArrayConverter<ElementZ, ElementCompute, kElementsPerAccess> convert_z;
285
+ frag_Z = convert_z(result_Z);
286
+
287
+ if constexpr (kStoreT) {
288
+ NumericArrayConverter<ElementT, ElementCompute, kElementsPerAccess> convert_t;
289
+ frag_T = convert_t(result_T);
290
+ }
291
+ }
292
+
293
+ /// Applies the operation when elementwise_op require arguments and is_source_needed() is false
294
+ template <typename ElementwiseArgs>
295
+ CUTLASS_HOST_DEVICE
296
+ void operator()(
297
+ FragmentZ &frag_Z,
298
+ FragmentT &frag_T,
299
+ FragmentAccumulator const &AB,
300
+ FragmentCompute const &V,
301
+ ElementwiseArgs const &elementwise_args) const {
302
+
303
+ ElementwiseOp elementwise_op;
304
+ BinaryOp binary_op;
305
+
306
+ FragmentCompute tmp_Accum = NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess>()(AB);
307
+ FragmentCompute result_Z;
308
+ FragmentCompute result_T;
309
+
310
+ CUTLASS_PRAGMA_UNROLL
311
+ for (int i = 0; i < kElementsPerAccess; ++i) {
312
+ ElementCompute z = binary_op(alpha_ * tmp_Accum[i], V[i]);
313
+ result_T[i] = z;
314
+ result_Z[i] = skip_elementwise_ ? z : elementwise_op(z, elementwise_args);
315
+ }
316
+
317
+ NumericArrayConverter<ElementZ, ElementCompute, kElementsPerAccess> convert_z;
318
+ frag_Z = convert_z(result_Z);
319
+
320
+ if constexpr (kStoreT) {
321
+ NumericArrayConverter<ElementT, ElementCompute, kElementsPerAccess> convert_t;
322
+ frag_T = convert_t(result_T);
323
+ }
324
+ }
325
+
326
+ /// Applies the operation when is_source_needed() is true
327
+ CUTLASS_HOST_DEVICE
328
+ void operator()(
329
+ FragmentZ &frag_Z,
330
+ FragmentT &frag_T,
331
+ FragmentAccumulator const &AB,
332
+ FragmentC const &frag_C,
333
+ FragmentCompute const &V) const {
334
+
335
+ ElementwiseOpDispatcher elementwise_op(elementwise_);
336
+ BinaryOp binary_op;
337
+
338
+ FragmentCompute tmp_Accum = NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess>()(AB);
339
+ FragmentCompute tmp_C = NumericArrayConverter<ElementCompute, ElementC, kElementsPerAccess>()(frag_C);
340
+ FragmentCompute result_Z;
341
+ FragmentCompute result_T;
342
+
343
+ CUTLASS_PRAGMA_UNROLL
344
+ for (int i = 0; i < kElementsPerAccess; ++i) {
345
+ ElementCompute z = binary_op(alpha_ * tmp_Accum[i] + beta_ * tmp_C[i], V[i]);
346
+ result_T[i] = z;
347
+ result_Z[i] = skip_elementwise_ ? z : elementwise_op(z);
348
+ }
349
+
350
+ NumericArrayConverter<ElementZ, ElementCompute, kElementsPerAccess> convert_z;
351
+ frag_Z = convert_z(result_Z);
352
+
353
+ if constexpr (kStoreT) {
354
+ NumericArrayConverter<ElementT, ElementCompute, kElementsPerAccess> convert_t;
355
+ frag_T = convert_t(result_T);
356
+ }
357
+ }
358
+
359
+ /// Applies the operation when is_source_needed() is false
360
+ CUTLASS_HOST_DEVICE
361
+ void operator()(
362
+ FragmentZ &frag_Z,
363
+ FragmentT &frag_T,
364
+ FragmentAccumulator const &AB,
365
+ FragmentCompute const &V) const {
366
+
367
+ ElementwiseOpDispatcher elementwise_op(elementwise_);
368
+ BinaryOp binary_op;
369
+
370
+ FragmentCompute tmp_Accum = NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess>()(AB);
371
+ FragmentCompute result_Z;
372
+ FragmentCompute result_T;
373
+
374
+ CUTLASS_PRAGMA_UNROLL
375
+ for (int i = 0; i < kElementsPerAccess; ++i) {
376
+ ElementCompute z = binary_op(alpha_ * tmp_Accum[i], V[i]);
377
+ result_T[i] = z;
378
+ result_Z[i] = skip_elementwise_ ? z : elementwise_op(z);
379
+ }
380
+
381
+ NumericArrayConverter<ElementZ, ElementCompute, kElementsPerAccess> convert_z;
382
+ frag_Z = convert_z(result_Z);
383
+
384
+ if constexpr (kStoreT) {
385
+ NumericArrayConverter<ElementT, ElementCompute, kElementsPerAccess> convert_t;
386
+ frag_T = convert_t(result_T);
387
+ }
388
+ }
389
+
390
+ /// Applies the operation when elementwise_op require arguments and is_source_needed() is true
391
+ template <typename ElementwiseArgs>
392
+ CUTLASS_HOST_DEVICE
393
+ void operator()(
394
+ ElementZ &Z,
395
+ ElementT &T,
396
+ ElementAccumulator const &AB,
397
+ ElementC const &C,
398
+ ElementCompute const &V,
399
+ ElementwiseArgs const &elementwise_args) const {
400
+
401
+ ElementwiseOp elementwise_op;
402
+ BinaryOp binary_op;
403
+
404
+ ElementCompute tmp_Accum = NumericConverter<ElementCompute, ElementAccumulator>()(AB);
405
+ ElementCompute tmp_C = NumericConverter<ElementCompute, ElementC>()(C);
406
+
407
+ ElementCompute z = binary_op(alpha_ * tmp_Accum + beta_ * tmp_C, V);
408
+ ElementCompute result_Z = skip_elementwise_ ? z : elementwise_op(z, elementwise_args);
409
+
410
+ NumericConverter<ElementZ, ElementCompute> convert_z;
411
+ Z = convert_z(result_Z);
412
+
413
+ if constexpr (kStoreT) {
414
+ ElementCompute result_T = z;
415
+ NumericConverter<ElementT, ElementCompute> convert_t;
416
+ T = convert_t(result_T);
417
+ }
418
+ }
419
+
420
+ /// Applies the operation when elementwise_op require arguments and is_source_needed() is false
421
+ template <typename ElementwiseArgs>
422
+ CUTLASS_HOST_DEVICE
423
+ void operator()(
424
+ ElementZ &Z,
425
+ ElementT &T,
426
+ ElementAccumulator const &AB,
427
+ ElementCompute const &V,
428
+ ElementwiseArgs const &elementwise_args) const {
429
+
430
+ ElementwiseOp elementwise_op;
431
+ BinaryOp binary_op;
432
+
433
+ ElementCompute tmp_Accum = NumericConverter<ElementCompute, ElementAccumulator>()(AB);
434
+
435
+ ElementCompute z = binary_op(alpha_ * tmp_Accum, V);
436
+ ElementCompute result_Z = skip_elementwise_ ? z : elementwise_op(z, elementwise_args);
437
+
438
+ NumericConverter<ElementZ, ElementCompute> convert_z;
439
+ Z = convert_z(result_Z);
440
+
441
+ if constexpr (kStoreT) {
442
+ ElementCompute result_T = z;
443
+ NumericConverter<ElementT, ElementCompute> convert_t;
444
+ T = convert_t(result_T);
445
+ }
446
+ }
447
+
448
+ /// Applies the operation when is_source_needed() is true
449
+ CUTLASS_HOST_DEVICE
450
+ void operator()(
451
+ ElementZ &Z,
452
+ ElementT &T,
453
+ ElementAccumulator const &AB,
454
+ ElementC const &C,
455
+ ElementCompute const &V) const {
456
+
457
+ ElementwiseOpDispatcher elementwise_op(elementwise_);
458
+ BinaryOp binary_op;
459
+
460
+ ElementCompute tmp_Accum = NumericConverter<ElementCompute, ElementAccumulator>()(AB);
461
+ ElementCompute tmp_C = NumericConverter<ElementCompute, ElementC>()(C);
462
+
463
+ ElementCompute z = binary_op(alpha_ * tmp_Accum + beta_ * tmp_C, V);
464
+ ElementCompute result_Z = skip_elementwise_ ? z : elementwise_op(z);
465
+
466
+ NumericConverter<ElementZ, ElementCompute> convert_z;
467
+ Z = convert_z(result_Z);
468
+
469
+ if constexpr (kStoreT) {
470
+ ElementCompute result_T = z;
471
+ NumericConverter<ElementT, ElementCompute> convert_t;
472
+ T = convert_t(result_T);
473
+ }
474
+ }
475
+
476
+ /// Applies the operation when is_source_needed() is false
477
+ CUTLASS_HOST_DEVICE
478
+ void operator()(
479
+ ElementZ &Z,
480
+ ElementT &T,
481
+ ElementAccumulator const &AB,
482
+ ElementCompute const &V) const {
483
+
484
+ ElementwiseOpDispatcher elementwise_op(elementwise_);
485
+ BinaryOp binary_op;
486
+
487
+ ElementCompute tmp_Accum = NumericConverter<ElementCompute, ElementAccumulator>()(AB);
488
+
489
+ ElementCompute z = binary_op(alpha_ * tmp_Accum, V);
490
+ ElementCompute result_Z = skip_elementwise_ ? z : elementwise_op(z);
491
+
492
+ NumericConverter<ElementZ, ElementCompute> convert_z;
493
+ Z = convert_z(result_Z);
494
+
495
+ if constexpr (kStoreT) {
496
+ ElementCompute result_T = z;
497
+ NumericConverter<ElementT, ElementCompute> convert_t;
498
+ T = convert_t(result_T);
499
+ }
500
+ }
501
+ };
502
+
503
+
504
+ /// This base class is meant to define the concept required of the
505
+ /// EpilogueWithBroadcast::OutputOp
506
+ template <
507
+ typename ElementC_,
508
+ typename ElementAccumulator_,
509
+ typename ElementCompute_,
510
+ typename ElementZ_,
511
+ typename ElementT_,
512
+ int ElementsPerAccess,
513
+ typename ElementwiseOp_ = Identity<ElementCompute_>,
514
+ typename BinaryOp_ = plus<ElementCompute_>,
515
+ bool StoreT_ = true,
516
+ typename ElementVector_ = ElementC_
517
+ >
518
+ class LinearCombinationPerChannelScalingBiasElementwise {
519
+ public:
520
+
521
+ using ElementOutput = ElementC_;
522
+ using ElementD = ElementOutput;
523
+ using ElementC = ElementC_;
524
+ using ElementAccumulator = ElementAccumulator_;
525
+ using ElementCompute = ElementCompute_;
526
+ using ElementScalar = ElementCompute;
527
+ using ElementZ = ElementZ_;
528
+ using ElementT = ElementT_;
529
+ using ElementVector = ElementVector_;
530
+ static int const kElementsPerAccess = ElementsPerAccess;
531
+ static int const kCount = kElementsPerAccess;
532
+
533
+ /// Follow cutlass3x EVT aliases
534
+ static bool const IsEltActSupported = true;
535
+ static bool const IsPerChannelScalingSupported = true;
536
+
537
+ using ElementwiseOp = ElementwiseOp_;
538
+ using BinaryOp = BinaryOp_;
539
+
540
+ using ElementwiseOpDispatcher = detail::ElementwiseOpDispatcher<ElementwiseOp>;
541
+ using ElementwiseArguments = typename ElementwiseOpDispatcher::Arguments;
542
+
543
+ // Indicates that this epilogue applies only one binary operation
544
+ static bool const kIsSingleSource = true;
545
+
546
+
547
+ using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;
548
+ using FragmentCompute = Array<ElementCompute, kElementsPerAccess>;
549
+ using FragmentC = Array<ElementC, kElementsPerAccess>;
550
+ using FragmentZ = Array<ElementZ, kElementsPerAccess>;
551
+ using FragmentT = Array<ElementT, kElementsPerAccess>;
552
+
553
+ // Definitions needed for collective epilogue
554
+ using FragmentSource = FragmentC;
555
+ using FragmentOutput = FragmentZ;
556
+ using ElementBias = ElementVector;
557
+ using FragmentBias = Array<ElementBias, kElementsPerAccess>;
558
+ using ActivationFn = ElementwiseOp;
559
+ static const ScaleType::Kind kScale = ScaleType::PerChannelScaling;
560
+
561
+ static bool const kIsHeavy = kIsHeavy_member_or_false<ElementwiseOp>::value;
562
+
563
+ /// If true, the 'Z' tensor is stored
564
+ static bool const kStoreZ = true;
565
+
566
+ /// If true, the 'T' tensor is stored
567
+ static bool const kStoreT = StoreT_;
568
+
569
+ /// Host-constructable parameters structure
570
+ struct Params {
571
+ ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
572
+ ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
573
+ ElementCompute beta; ///< scales source tensor
574
+ ElementwiseArguments elementwise; ///< Arguments for elementwise operation
575
+
576
+ //
577
+ // Methods
578
+ //
579
+
580
+ CUTLASS_HOST_DEVICE
581
+ Params():
582
+ alpha_ptr(nullptr),
583
+ beta_ptr(nullptr),
584
+ beta(ElementCompute(0)) { }
585
+
586
+ CUTLASS_HOST_DEVICE
587
+ Params(
588
+ ElementCompute const *alpha_ptr,
589
+ ElementCompute const *beta_ptr,
590
+ ElementwiseArguments elementwise_ = ElementwiseArguments{}
591
+ ): beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr), elementwise(elementwise_) {
592
+
593
+ }
594
+
595
+ CUTLASS_HOST_DEVICE
596
+ Params(
597
+ ElementCompute const *alpha_ptr
598
+ ): beta(0), alpha_ptr(alpha_ptr), beta_ptr(nullptr) {
599
+
600
+ }
601
+ };
602
+
603
+ private:
604
+
605
+ //
606
+ // Data members
607
+ //
608
+
609
+ ElementCompute const* beta_ptr_ = nullptr;
610
+ ElementCompute beta_ = 0;
611
+ ElementwiseArguments const &elementwise_;
612
+ bool skip_elementwise_;
613
+
614
+ public:
615
+
616
+ //
617
+ // Methods
618
+ //
619
+
620
+ /// Constructor from Params
621
+ CUTLASS_HOST_DEVICE
622
+ LinearCombinationPerChannelScalingBiasElementwise(Params const &params): elementwise_(params.elementwise) {
623
+ if (params.beta_ptr) {
624
+ beta_ptr_ = params.beta_ptr;
625
+ }
626
+ else {
627
+ beta_ = params.beta;
628
+ }
629
+ skip_elementwise_ = false;
630
+ }
631
+
632
+ /// Returns true if source is needed
633
+ CUTLASS_HOST_DEVICE
634
+ bool is_source_needed() const {
635
+ return beta_ptr_ != nullptr || beta_ != ElementCompute(0);
636
+ }
637
+
638
+ CUTLASS_HOST_DEVICE
639
+ bool is_beta_vector() const {
640
+ return beta_ptr_ != nullptr;
641
+ }
642
+
643
+ /// Functionally required for serial reduction in the epilogue
644
+ CUTLASS_HOST_DEVICE
645
+ void set_k_partition(int k_partition, int k_partition_count) {
646
+ if (k_partition) {
647
+ beta_ = ElementCompute(1);
648
+ }
649
+
650
+ if (k_partition != k_partition_count - 1) {
651
+ skip_elementwise_ = true;
652
+ }
653
+ }
654
+
655
+ /// Applies the operation when elementwise_op require arguments and is_source_needed() is true
656
+ template <typename ElementwiseArgs>
657
+ CUTLASS_HOST_DEVICE
658
+ void operator()(
659
+ FragmentZ &frag_Z,
660
+ FragmentT &frag_T,
661
+ FragmentAccumulator const &AB,
662
+ FragmentC const &frag_C,
663
+ FragmentCompute const & valpha,
664
+ FragmentCompute const & vbias,
665
+ ElementwiseArgs const &elementwise_args) const {
666
+
667
+ ElementwiseOp elementwise_op;
668
+ BinaryOp binary_op;
669
+
670
+ FragmentCompute tmp_Accum = NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess>()(AB);
671
+ FragmentCompute tmp_C = NumericArrayConverter<ElementCompute, ElementC, kElementsPerAccess>()(frag_C);
672
+ FragmentCompute result_Z;
673
+ FragmentCompute result_T;
674
+
675
+ CUTLASS_PRAGMA_UNROLL
676
+ for (int i = 0; i < kElementsPerAccess; ++i) {
677
+ ElementCompute z = binary_op(valpha[i] * tmp_Accum[i] + beta_ * tmp_C[i], vbias[i]);
678
+ result_T[i] = z;
679
+ result_Z[i] = skip_elementwise_ ? z : elementwise_op(z, elementwise_args);
680
+ }
681
+
682
+ NumericArrayConverter<ElementZ, ElementCompute, kElementsPerAccess> convert_z;
683
+ frag_Z = convert_z(result_Z);
684
+
685
+ if constexpr (kStoreT) {
686
+ NumericArrayConverter<ElementT, ElementCompute, kElementsPerAccess> convert_t;
687
+ frag_T = convert_t(result_T);
688
+ }
689
+ }
690
+
691
+ /// Applies the operation when elementwise_op require arguments and is_source_needed() is true
692
+ /// D = elementwise_op(vector_alpha * accumulator + vector_beta * source + bias)
693
+ template <typename ElementwiseArgs>
694
+ CUTLASS_HOST_DEVICE
695
+ void operator()(
696
+ FragmentZ &frag_Z,
697
+ FragmentT &frag_T,
698
+ FragmentAccumulator const &AB,
699
+ FragmentC const &frag_C,
700
+ FragmentCompute const & valpha,
701
+ FragmentCompute const & vbeta,
702
+ FragmentCompute const & vbias,
703
+ ElementwiseArgs const &elementwise_args) const {
704
+
705
+ ElementwiseOp elementwise_op;
706
+ BinaryOp binary_op;
707
+
708
+ FragmentCompute tmp_Accum = NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess>()(AB);
709
+ FragmentCompute tmp_C = NumericArrayConverter<ElementCompute, ElementC, kElementsPerAccess>()(frag_C);
710
+ FragmentCompute result_Z;
711
+ FragmentCompute result_T;
712
+
713
+ CUTLASS_PRAGMA_UNROLL
714
+ for (int i = 0; i < kElementsPerAccess; ++i) {
715
+ ElementCompute z = binary_op(valpha[i] * tmp_Accum[i] + vbeta[i] * tmp_C[i], vbias[i]);
716
+ result_T[i] = z;
717
+ result_Z[i] = skip_elementwise_ ? z : elementwise_op(z, elementwise_args);
718
+ }
719
+
720
+ NumericArrayConverter<ElementZ, ElementCompute, kElementsPerAccess> convert_z;
721
+ frag_Z = convert_z(result_Z);
722
+
723
+ if constexpr (kStoreT) {
724
+ NumericArrayConverter<ElementT, ElementCompute, kElementsPerAccess> convert_t;
725
+ frag_T = convert_t(result_T);
726
+ }
727
+ }
728
+
729
+ /// Applies the operation when elementwise_op require arguments and is_source_needed() is false
730
+ template <typename ElementwiseArgs>
731
+ CUTLASS_HOST_DEVICE
732
+ void operator()(
733
+ FragmentZ &frag_Z,
734
+ FragmentT &frag_T,
735
+ FragmentAccumulator const &AB,
736
+ FragmentCompute const & valpha,
737
+ FragmentCompute const & vbias,
738
+ ElementwiseArgs const &elementwise_args) const {
739
+
740
+ ElementwiseOp elementwise_op;
741
+ BinaryOp binary_op;
742
+
743
+ FragmentCompute tmp_Accum = NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess>()(AB);
744
+ FragmentCompute result_Z;
745
+ FragmentCompute result_T;
746
+
747
+ CUTLASS_PRAGMA_UNROLL
748
+ for (int i = 0; i < kElementsPerAccess; ++i) {
749
+ ElementCompute z = binary_op(valpha[i] * tmp_Accum[i], vbias[i]);
750
+ result_T[i] = z;
751
+ result_Z[i] = skip_elementwise_ ? z : elementwise_op(z, elementwise_args);
752
+ }
753
+
754
+ NumericArrayConverter<ElementZ, ElementCompute, kElementsPerAccess> convert_z;
755
+ frag_Z = convert_z(result_Z);
756
+
757
+ if constexpr (kStoreT) {
758
+ NumericArrayConverter<ElementT, ElementCompute, kElementsPerAccess> convert_t;
759
+ frag_T = convert_t(result_T);
760
+ }
761
+ }
762
+
763
+ /// Applies the operation when is_source_needed() is true
764
+ CUTLASS_HOST_DEVICE
765
+ void operator()(
766
+ FragmentZ &frag_Z,
767
+ FragmentT &frag_T,
768
+ FragmentAccumulator const &AB,
769
+ FragmentC const &frag_C,
770
+ FragmentCompute const & valpha,
771
+ FragmentCompute const & vbias) const {
772
+
773
+ ElementwiseOpDispatcher elementwise_op(elementwise_);
774
+ BinaryOp binary_op;
775
+
776
+ FragmentCompute tmp_Accum = NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess>()(AB);
777
+ FragmentCompute tmp_C = NumericArrayConverter<ElementCompute, ElementC, kElementsPerAccess>()(frag_C);
778
+ FragmentCompute result_Z;
779
+ FragmentCompute result_T;
780
+
781
+ CUTLASS_PRAGMA_UNROLL
782
+ for (int i = 0; i < kElementsPerAccess; ++i) {
783
+ ElementCompute z = binary_op(valpha[i] * tmp_Accum[i] + beta_ * tmp_C[i], vbias[i]);
784
+ result_T[i] = z;
785
+ result_Z[i] = skip_elementwise_ ? z : elementwise_op(z);
786
+ }
787
+
788
+ NumericArrayConverter<ElementZ, ElementCompute, kElementsPerAccess> convert_z;
789
+ frag_Z = convert_z(result_Z);
790
+
791
+ if constexpr (kStoreT) {
792
+ NumericArrayConverter<ElementT, ElementCompute, kElementsPerAccess> convert_t;
793
+ frag_T = convert_t(result_T);
794
+ }
795
+ }
796
+
797
+ /// Applies the operation when is_source_needed() is false
798
+ CUTLASS_HOST_DEVICE
799
+ void operator()(
800
+ FragmentZ &frag_Z,
801
+ FragmentT &frag_T,
802
+ FragmentAccumulator const &AB,
803
+ FragmentCompute const & valpha,
804
+ FragmentCompute const & vbias) const {
805
+
806
+ ElementwiseOpDispatcher elementwise_op(elementwise_);
807
+ BinaryOp binary_op;
808
+
809
+ FragmentCompute tmp_Accum = NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess>()(AB);
810
+ FragmentCompute result_Z;
811
+ FragmentCompute result_T;
812
+
813
+ CUTLASS_PRAGMA_UNROLL
814
+ for (int i = 0; i < kElementsPerAccess; ++i) {
815
+ ElementCompute z = binary_op(valpha[i] * tmp_Accum[i], vbias[i]);
816
+ result_T[i] = z;
817
+ result_Z[i] = skip_elementwise_ ? z : elementwise_op(z);
818
+ }
819
+
820
+ NumericArrayConverter<ElementZ, ElementCompute, kElementsPerAccess> convert_z;
821
+ frag_Z = convert_z(result_Z);
822
+
823
+ if constexpr (kStoreT) {
824
+ NumericArrayConverter<ElementT, ElementCompute, kElementsPerAccess> convert_t;
825
+ frag_T = convert_t(result_T);
826
+ }
827
+ }
828
+
829
+ /// Applies the operation when elementwise_op require arguments and is_source_needed() is true
830
+ template <typename ElementwiseArgs>
831
+ CUTLASS_HOST_DEVICE
832
+ void operator()(
833
+ ElementZ &Z,
834
+ ElementT &T,
835
+ ElementAccumulator const &AB,
836
+ ElementC const &C,
837
+ ElementCompute const & valpha,
838
+ ElementCompute const & vbias,
839
+ ElementwiseArgs const &elementwise_args) const {
840
+
841
+ ElementwiseOp elementwise_op;
842
+ BinaryOp binary_op;
843
+
844
+ ElementCompute tmp_Accum = NumericConverter<ElementCompute, ElementAccumulator>()(AB);
845
+ ElementCompute tmp_C = NumericConverter<ElementCompute, ElementC>()(C);
846
+
847
+ ElementCompute z = binary_op(valpha * tmp_Accum + beta_ * tmp_C, vbias);
848
+ ElementCompute result_Z = skip_elementwise_ ? z : elementwise_op(z, elementwise_args);
849
+
850
+ NumericConverter<ElementZ, ElementCompute> convert_z;
851
+ Z = convert_z(result_Z);
852
+
853
+ if constexpr (kStoreT) {
854
+ ElementCompute result_T = z;
855
+ NumericConverter<ElementT, ElementCompute> convert_t;
856
+ T = convert_t(result_T);
857
+ }
858
+ }
859
+
860
+ /// Applies the operation when elementwise_op require arguments and is_source_needed() is true
861
+ /// D = elementwise_op(vector_alpha * accumulator + vector_beta * source + bias)
862
+ template <typename ElementwiseArgs>
863
+ CUTLASS_HOST_DEVICE
864
+ void operator()(
865
+ ElementZ &Z,
866
+ ElementT &T,
867
+ ElementAccumulator const &AB,
868
+ ElementC const &C,
869
+ ElementCompute const & valpha,
870
+ ElementCompute const & vbeta,
871
+ ElementCompute const & vbias,
872
+ ElementwiseArgs const &elementwise_args) const {
873
+
874
+ ElementwiseOp elementwise_op;
875
+ BinaryOp binary_op;
876
+
877
+ ElementCompute tmp_Accum = NumericConverter<ElementCompute, ElementAccumulator>()(AB);
878
+ ElementCompute tmp_C = NumericConverter<ElementCompute, ElementC>()(C);
879
+
880
+ ElementCompute z = binary_op(valpha * tmp_Accum + vbeta * tmp_C, vbias);
881
+ ElementCompute result_Z = skip_elementwise_ ? z : elementwise_op(z, elementwise_args);
882
+
883
+ NumericConverter<ElementZ, ElementCompute> convert_z;
884
+ Z = convert_z(result_Z);
885
+
886
+ if constexpr (kStoreT) {
887
+ ElementCompute result_T = z;
888
+ NumericConverter<ElementT, ElementCompute> convert_t;
889
+ T = convert_t(result_T);
890
+ }
891
+ }
892
+
893
+ /// Applies the operation when elementwise_op require arguments and is_source_needed() is false
894
+ template <typename ElementwiseArgs>
895
+ CUTLASS_HOST_DEVICE
896
+ void operator()(
897
+ ElementZ &Z,
898
+ ElementT &T,
899
+ ElementAccumulator const &AB,
900
+ ElementCompute const & valpha,
901
+ ElementCompute const & vbias,
902
+ ElementwiseArgs const &elementwise_args) const {
903
+
904
+ ElementwiseOp elementwise_op;
905
+ BinaryOp binary_op;
906
+
907
+ ElementCompute tmp_Accum = NumericConverter<ElementCompute, ElementAccumulator>()(AB);
908
+
909
+ ElementCompute z = binary_op(valpha * tmp_Accum, vbias);
910
+ ElementCompute result_Z = skip_elementwise_ ? z : elementwise_op(z, elementwise_args);
911
+
912
+ NumericConverter<ElementZ, ElementCompute> convert_z;
913
+ Z = convert_z(result_Z);
914
+
915
+ if constexpr (kStoreT) {
916
+ ElementCompute result_T = z;
917
+ NumericConverter<ElementT, ElementCompute> convert_t;
918
+ T = convert_t(result_T);
919
+ }
920
+ }
921
+
922
+ /// Applies the operation when is_source_needed() is true
923
+ CUTLASS_HOST_DEVICE
924
+ void operator()(
925
+ ElementZ &Z,
926
+ ElementT &T,
927
+ ElementAccumulator const &AB,
928
+ ElementC const &C,
929
+ ElementCompute const & valpha,
930
+ ElementCompute const & vbias) const {
931
+
932
+ ElementwiseOpDispatcher elementwise_op(elementwise_);
933
+ BinaryOp binary_op;
934
+
935
+ ElementCompute tmp_Accum = NumericConverter<ElementCompute, ElementAccumulator>()(AB);
936
+ ElementCompute tmp_C = NumericConverter<ElementCompute, ElementC>()(C);
937
+
938
+ ElementCompute z = binary_op(valpha * tmp_Accum + beta_ * tmp_C, vbias);
939
+ ElementCompute result_Z = skip_elementwise_ ? z : elementwise_op(z);
940
+
941
+ NumericConverter<ElementZ, ElementCompute> convert_z;
942
+ Z = convert_z(result_Z);
943
+
944
+ if constexpr (kStoreT) {
945
+ ElementCompute result_T = z;
946
+ NumericConverter<ElementT, ElementCompute> convert_t;
947
+ T = convert_t(result_T);
948
+ }
949
+ }
950
+
951
+ /// Applies the operation when is_source_needed() is false
952
+ CUTLASS_HOST_DEVICE
953
+ void operator()(
954
+ ElementZ &Z,
955
+ ElementT &T,
956
+ ElementAccumulator const &AB,
957
+ ElementCompute const & valpha,
958
+ ElementCompute const & vbias) const {
959
+
960
+ ElementwiseOpDispatcher elementwise_op(elementwise_);
961
+ BinaryOp binary_op;
962
+
963
+ ElementCompute tmp_Accum = NumericConverter<ElementCompute, ElementAccumulator>()(AB);
964
+
965
+ ElementCompute z = binary_op(valpha * tmp_Accum, vbias);
966
+ ElementCompute result_Z = skip_elementwise_ ? z : elementwise_op(z);
967
+
968
+ NumericConverter<ElementZ, ElementCompute> convert_z;
969
+ Z = convert_z(result_Z);
970
+
971
+ if constexpr (kStoreT) {
972
+ ElementCompute result_T = z;
973
+ NumericConverter<ElementT, ElementCompute> convert_t;
974
+ T = convert_t(result_T);
975
+ }
976
+ }
977
+ };
978
+
979
+ /////////////////////////////////////////////////////////////////////////////////////////////////
980
+
981
+ } // namespace thread
982
+ } // namespace epilogue
983
+ } // namespace cutlass
984
+
985
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_bias_relu.h ADDED
@@ -0,0 +1,610 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Functor performing linear combination operations used by epilogues.
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include <cuda_fp16.h>
38
+
39
+ #include "cutlass/cutlass.h"
40
+ #include "cutlass/numeric_types.h"
41
+ #include "cutlass/array.h"
42
+ #include "cutlass/functional.h"
43
+ #include "cutlass/numeric_conversion.h"
44
+ #include "cutlass/epilogue/thread/activation.h"
45
+
46
+ /////////////////////////////////////////////////////////////////////////////////////////////////
47
+
48
+ namespace cutlass {
49
+ namespace epilogue {
50
+ namespace thread {
51
+
52
+ /////////////////////////////////////////////////////////////////////////////////////////////////
53
+
54
+ namespace detail {
55
+
56
+ template <typename Element, int ElementsPerAccess>
57
+ struct ArrayMaximum {
58
+
59
+ CUTLASS_HOST_DEVICE
60
+ Array<Element, ElementsPerAccess> operator()(
61
+ Array<Element, ElementsPerAccess> const &lhs,
62
+ Array<Element, ElementsPerAccess> const &rhs) const {
63
+
64
+ Array<Element, ElementsPerAccess> result;
65
+
66
+ CUTLASS_PRAGMA_UNROLL
67
+ for (int i = 0; i < ElementsPerAccess; ++i) {
68
+ result[i] = platform::max(lhs[i].get(), rhs[i]);
69
+ }
70
+
71
+ return result;
72
+ }
73
+
74
+ CUTLASS_HOST_DEVICE
75
+ Array<Element, ElementsPerAccess> operator()(
76
+ Array<Element, ElementsPerAccess> const &lhs,
77
+ Element rhs) const {
78
+
79
+ Array<Element, ElementsPerAccess> result;
80
+
81
+ CUTLASS_PRAGMA_UNROLL
82
+ for (int i = 0; i < ElementsPerAccess; ++i) {
83
+ result[i] = platform::max(lhs[i].get(), rhs);
84
+ }
85
+
86
+ return result;
87
+ }
88
+ };
89
+
90
+
91
+ /// Partial specialization: Element=float
92
+ template <int ElementsPerAccess>
93
+ struct ArrayMaximum<float, ElementsPerAccess> {
94
+
95
+ CUTLASS_HOST_DEVICE
96
+ Array<float, ElementsPerAccess> operator()(
97
+ Array<float, ElementsPerAccess> const &lhs,
98
+ Array<float, ElementsPerAccess> const &rhs) const {
99
+
100
+ Array<float, ElementsPerAccess> result;
101
+
102
+ CUTLASS_PRAGMA_UNROLL
103
+ for (int i = 0; i < ElementsPerAccess; ++i) {
104
+ result[i] = fmax(lhs[i], rhs[i]);
105
+ }
106
+
107
+ return result;
108
+ }
109
+
110
+ CUTLASS_HOST_DEVICE
111
+ Array<float, ElementsPerAccess> operator()(
112
+ Array<float, ElementsPerAccess> const &lhs,
113
+ float rhs) const {
114
+
115
+ Array<float, ElementsPerAccess> result;
116
+
117
+ CUTLASS_PRAGMA_UNROLL
118
+ for (int i = 0; i < ElementsPerAccess; ++i) {
119
+ result[i] = fmax(lhs[i], rhs);
120
+ }
121
+
122
+ return result;
123
+ }
124
+ };
125
+
126
/// Partial specialization: Element=half
///
/// On SM80+ the arrays are reinterpreted as packed __half2 pairs and
/// reduced with the __hmax2 intrinsic (two lanes per instruction);
/// ElementsPerAccess must therefore be even on that path. Older
/// architectures fall back to a scalar compare-and-select per lane.
template <int ElementsPerAccess>
struct ArrayMaximum<half_t, ElementsPerAccess> {

  /// Elementwise max of two half arrays.
  CUTLASS_DEVICE
  Array<half_t, ElementsPerAccess> operator()(
    Array<half_t, ElementsPerAccess> const &lhs,
    Array<half_t, ElementsPerAccess> const &rhs) const {

    Array<half_t, ElementsPerAccess> result;

#if __CUDA_ARCH__ >= 800
    // Vectorized path: process two halves at a time via __hmax2.
    int const kVectorCount = ElementsPerAccess / 2;


    __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(lhs.raw_data());
    __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(rhs.raw_data());
    __half2 *res_ptr = reinterpret_cast<__half2 *>(result.raw_data());

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kVectorCount; ++i) {
      res_ptr[i] = __hmax2(lhs_ptr[i], rhs_ptr[i]);
    }

    static_assert(!(ElementsPerAccess % 2), "Output array must be divisible by vector length.");

#else
    // Scalar fallback: compare-and-select one __half at a time.
    __half const *lhs_ptr = reinterpret_cast<__half const *>(lhs.raw_data());
    __half const *rhs_ptr = reinterpret_cast<__half const *>(rhs.raw_data());
    __half *res_ptr = reinterpret_cast<__half *>(result.raw_data());

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < ElementsPerAccess; ++i) {
      res_ptr[i] = ((lhs_ptr[i] < rhs_ptr[i]) ? rhs_ptr[i] : lhs_ptr[i]);
    }

#endif

    return result;
  }

  /// Elementwise max of a half array against a scalar threshold
  /// (the scalar is broadcast across both lanes on the SM80+ path).
  CUTLASS_DEVICE
  Array<half_t, ElementsPerAccess> operator()(
    Array<half_t, ElementsPerAccess> const &lhs,
    half_t const &rhs) const {

    Array<half_t, ElementsPerAccess> result;

#if __CUDA_ARCH__ >= 800
    int const kVectorCount = ElementsPerAccess / 2;


    // Broadcast the scalar threshold into both halves of a __half2.
    __half rhs_raw = reinterpret_cast<__half const &>(rhs);
    __half2 rhs_pair = __half2half2(rhs_raw);

    __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(lhs.raw_data());
    __half2 *res_ptr = reinterpret_cast<__half2 *>(result.raw_data());

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kVectorCount; ++i) {
      res_ptr[i] = __hmax2(lhs_ptr[i], rhs_pair);
    }

    static_assert(!(ElementsPerAccess % 2), "Output array must be divisible by vector length.");

#else

    __half const *lhs_ptr = reinterpret_cast<__half const *>(lhs.raw_data());
    __half const rhs_raw = reinterpret_cast<__half const &>(rhs);
    __half *res_ptr = reinterpret_cast<__half *>(result.raw_data());

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < ElementsPerAccess; ++i) {
      res_ptr[i] = ((lhs_ptr[i] < rhs_raw) ? rhs_raw : lhs_ptr[i]);
    }

#endif

    return result;
  }
};
207
+
208
+ /// Partial specialization: Element=bfloat16_t
209
+ template <int ElementsPerAccess>
210
+ struct ArrayMaximum<bfloat16_t, ElementsPerAccess> {
211
+
212
+ using NvType = __nv_bfloat16;
213
+ using NvTypeV2 = __nv_bfloat162;
214
+
215
+ CUTLASS_DEVICE
216
+ Array<bfloat16_t, ElementsPerAccess> operator()(
217
+ Array<bfloat16_t, ElementsPerAccess> const &lhs,
218
+ Array<bfloat16_t, ElementsPerAccess> const &rhs) const {
219
+
220
+ Array<bfloat16_t, ElementsPerAccess> result;
221
+
222
+ #if __CUDA_ARCH__ >= 800
223
+ int const kVectorCount = ElementsPerAccess / 2;
224
+
225
+
226
+ NvTypeV2 const *lhs_ptr = reinterpret_cast<NvTypeV2 const *>(lhs.raw_data());
227
+ NvTypeV2 const *rhs_ptr = reinterpret_cast<NvTypeV2 const *>(rhs.raw_data());
228
+ NvTypeV2 *res_ptr = reinterpret_cast<NvTypeV2 *>(result.raw_data());
229
+
230
+ CUTLASS_PRAGMA_UNROLL
231
+ for (int i = 0; i < kVectorCount; ++i) {
232
+ res_ptr[i] = __hmax2(lhs_ptr[i], rhs_ptr[i]);
233
+ }
234
+
235
+ #else
236
+ NvType const *lhs_ptr = reinterpret_cast<NvType const *>(lhs.raw_data());
237
+ NvType const *rhs_ptr = reinterpret_cast<NvType const *>(rhs.raw_data());
238
+ NvType *res_ptr = reinterpret_cast<NvType *>(result.raw_data());
239
+
240
+ CUTLASS_PRAGMA_UNROLL
241
+ for (int i = 0; i < ElementsPerAccess; ++i) {
242
+ res_ptr[i] = ((lhs_ptr[i] < rhs_ptr[i]) ? rhs_ptr[i] : lhs_ptr[i]);
243
+ }
244
+
245
+ #endif
246
+
247
+ return result;
248
+ }
249
+
250
+ CUTLASS_DEVICE
251
+ Array<bfloat16_t, ElementsPerAccess> operator()(
252
+ Array<bfloat16_t, ElementsPerAccess> const &lhs,
253
+ bfloat16_t rhs) const {
254
+
255
+ Array<bfloat16_t, ElementsPerAccess> result;
256
+
257
+ #if __CUDA_ARCH__ >= 800
258
+ int const kVectorCount = ElementsPerAccess / 2;
259
+
260
+
261
+ NvType rhs_raw = reinterpret_cast<NvType const &>(rhs);
262
+ NvTypeV2 rhs_pair = __bfloat162bfloat162(rhs_raw);
263
+
264
+ NvTypeV2 const *lhs_ptr = reinterpret_cast<NvTypeV2 const *>(lhs.raw_data());
265
+ NvTypeV2 *res_ptr = reinterpret_cast<NvTypeV2 *>(result.raw_data());
266
+
267
+ CUTLASS_PRAGMA_UNROLL
268
+ for (int i = 0; i < kVectorCount; ++i) {
269
+ res_ptr[i] = __hmax2(lhs_ptr[i], rhs_pair);
270
+ }
271
+
272
+ static_assert(!(ElementsPerAccess % 2), "Output array must be divisible by vector length.");
273
+
274
+ #else
275
+
276
+ NvType const *lhs_ptr = reinterpret_cast<NvType const *>(lhs.raw_data());
277
+ NvType const rhs_raw = reinterpret_cast<NvType const &>(rhs);
278
+ NvType *res_ptr = reinterpret_cast<NvType *>(result.raw_data());
279
+
280
+ CUTLASS_PRAGMA_UNROLL
281
+ for (int i = 0; i < ElementsPerAccess; ++i) {
282
+ res_ptr[i] = ((lhs_ptr[i] < rhs_raw) ? rhs_raw : lhs_ptr[i]);
283
+ }
284
+
285
+ #endif
286
+
287
+ return result;
288
+ }
289
+ };
290
+
291
+
292
+ /////////////////////////////////////////////////////////////////////////////////////////////////
293
+
294
+ template <typename Element, int ElementsPerAccess>
295
+ struct ReluConditional {
296
+
297
+ CUTLASS_HOST_DEVICE
298
+ void operator()(
299
+ bool conditional[],
300
+ Array<Element, ElementsPerAccess> const &fragment,
301
+ Element threshold) const {
302
+
303
+ CUTLASS_PRAGMA_UNROLL
304
+ for (int i = 0; i < ElementsPerAccess; ++i) {
305
+ conditional[i] = !(fragment[i] < threshold);
306
+ }
307
+ }
308
+ };
309
+
310
+ template <int ElementsPerAccess>
311
+ struct ReluConditional<half_t, ElementsPerAccess> {
312
+
313
+ CUTLASS_DEVICE
314
+ void operator()(
315
+ bool conditional[],
316
+ Array<half_t, ElementsPerAccess> const &fragment,
317
+ half_t threshold) const {
318
+
319
+ __half y = reinterpret_cast<__half const &>(threshold);
320
+ __half const *x = reinterpret_cast<__half const *>(fragment.raw_data());
321
+
322
+ CUTLASS_PRAGMA_UNROLL
323
+ for (int i = 0; i < ElementsPerAccess; ++i) {
324
+ conditional[i] = !__hlt(x[i], y);
325
+ }
326
+ }
327
+ };
328
+
329
+ template <int ElementsPerAccess>
330
+ struct ReluConditional<bfloat16_t, ElementsPerAccess> {
331
+
332
+ CUTLASS_DEVICE
333
+ void operator()(
334
+ bool conditional[],
335
+ Array<bfloat16_t, ElementsPerAccess> const &fragment,
336
+ bfloat16_t threshold) const {
337
+
338
+ __nv_bfloat16 y = reinterpret_cast<__nv_bfloat16 const &>(threshold);
339
+ __nv_bfloat16 const *x = reinterpret_cast<__nv_bfloat16 const *>(fragment.raw_data());
340
+
341
+ CUTLASS_PRAGMA_UNROLL
342
+ for (int i = 0; i < ElementsPerAccess; ++i) {
343
+ conditional[i] = !__hlt(x[i], y);
344
+ }
345
+ }
346
+ };
347
+
348
+ } // namespace detail
349
+
350
+ /////////////////////////////////////////////////////////////////////////////////////////////////
351
+
352
+ /// This is a partial specialization for fused Bias and ReLU. It supports the option of packing
353
+ /// ReLU conditionals in a bit vector that may be used by backwards passes as an optimization.
354
+ ///
355
+ /// This class can only be used with cutlass::epilogue::threadblock::EpilogueWithBroadcast<>.
356
+ ///
357
+ /// This base class is meant to define the concept required of the
358
+ /// EpilogueWithBroadcast::OutputOp
359
+ template <
360
+ typename ElementC_,
361
+ typename ElementAccumulator_,
362
+ typename ElementCompute_,
363
+ typename ElementZ_,
364
+ int ElementsPerAccess,
365
+ bool StoreT_ = true,
366
+ typename ElementVector_ = ElementC_
367
+ >
368
+ class LinearCombinationBiasRelu {
369
+ public:
370
+
371
+ using ElementOutput = ElementC_;
372
+ using ElementC = ElementC_;
373
+ using ElementAccumulator = ElementAccumulator_;
374
+ using ElementCompute = ElementCompute_;
375
+ using ElementZ = ElementZ_;
376
+ using ElementVector = ElementVector_;
377
+
378
+ using ElementT = uint1b_t;
379
+
380
+ static int const kElementsPerAccess = ElementsPerAccess;
381
+ static int const kCount = kElementsPerAccess;
382
+
383
+ using ElementwiseOp = ReLu<ElementCompute>;
384
+ using BinaryOp = plus<ElementCompute>;
385
+
386
+ // Indicates that this epilogue applies only one binary operation
387
+ static bool const kIsSingleSource = true;
388
+
389
+ using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;
390
+ using FragmentCompute = Array<ElementCompute, kElementsPerAccess>;
391
+ using FragmentC = Array<ElementOutput, kElementsPerAccess>;
392
+ using FragmentZ = Array<ElementZ, kElementsPerAccess>;
393
+ using FragmentT = Array<ElementT, kElementsPerAccess>;
394
+
395
+ /// If true, the 'Z' tensor is stored
396
+ static bool const kStoreZ = true;
397
+
398
+ /// If true, the 'T' tensor is stored
399
+ static bool const kStoreT = StoreT_;
400
+
401
+ /// Host-constructable parameters structure
402
+ struct Params {
403
+
404
+ ElementCompute alpha; ///< scales accumulators
405
+ ElementCompute beta; ///< scales source tensor
406
+ ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
407
+ ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
408
+ ElementZ threshold; ///< ReLu threshold
409
+
410
+ //
411
+ // Methods
412
+ //
413
+ //
414
+ // Methods
415
+ //
416
+
417
+ CUTLASS_HOST_DEVICE
418
+ Params():
419
+ alpha(ElementCompute(1)),
420
+ beta(ElementCompute()),
421
+ alpha_ptr(nullptr),
422
+ beta_ptr(nullptr),
423
+ threshold(ElementCompute()) { }
424
+
425
+ CUTLASS_HOST_DEVICE
426
+ Params(
427
+ ElementCompute alpha,
428
+ ElementCompute beta,
429
+ ElementCompute threshold_ = ElementCompute()
430
+ ):
431
+ alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) {
432
+
433
+ NumericConverter<ElementZ, ElementCompute> convert_threshold;
434
+
435
+ threshold = convert_threshold(threshold_);
436
+ }
437
+
438
+ CUTLASS_HOST_DEVICE
439
+ Params(
440
+ ElementCompute alpha
441
+ ): alpha(alpha), beta(0), alpha_ptr(nullptr), beta_ptr(nullptr), threshold(ElementZ()) {
442
+
443
+ }
444
+
445
+ CUTLASS_HOST_DEVICE
446
+ Params(
447
+ ElementCompute const *alpha_ptr,
448
+ ElementCompute const *beta_ptr,
449
+ ElementCompute threshold_ = ElementCompute()
450
+ ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
451
+
452
+ NumericConverter<ElementZ, ElementCompute> convert_threshold;
453
+
454
+ threshold = convert_threshold(threshold_);
455
+ }
456
+
457
+ CUTLASS_HOST_DEVICE
458
+ Params(
459
+ ElementCompute const *alpha_ptr
460
+ ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(nullptr), threshold(ElementZ()) {
461
+ }
462
+
463
+ };
464
+
465
+ private:
466
+
467
+ //
468
+ // Data members
469
+ //
470
+
471
+ ElementCompute alpha_;
472
+ ElementCompute beta_;
473
+ ElementZ threshold_;
474
+
475
+ public:
476
+
477
+ //
478
+ // Methods
479
+ //
480
+
481
+ /// Constructor from Params
482
+ CUTLASS_HOST_DEVICE
483
+ LinearCombinationBiasRelu(Params const &params) {
484
+
485
+ alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
486
+ beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
487
+ threshold_ = params.threshold;
488
+ }
489
+
490
+ /// Returns true if source is needed
491
+ CUTLASS_HOST_DEVICE
492
+ bool is_source_needed() const {
493
+ return beta_ != ElementCompute(0);
494
+ }
495
+
496
+ /// Functionally required for serial reduction in the epilogue
497
+ CUTLASS_HOST_DEVICE
498
+ void set_k_partition(int k_partition, int k_partition_count) {
499
+ if (k_partition) {
500
+ beta_ = ElementCompute(1);
501
+ }
502
+
503
+ if (k_partition != k_partition_count - 1) {
504
+ // set to NaN to make ReLU no-op for all except last k partitions
505
+ int64_t allones = -1;
506
+ threshold_ = reinterpret_cast<ElementZ const &>(allones);
507
+ }
508
+ }
509
+
510
+ /// Applies the operation when is_source_needed() is true
511
+ CUTLASS_HOST_DEVICE
512
+ void operator()(
513
+ FragmentZ &frag_Z,
514
+ FragmentT &frag_T,
515
+ FragmentAccumulator const &AB,
516
+ FragmentC const &frag_C,
517
+ FragmentCompute const &V) const {
518
+
519
+ BinaryOp binary_op;
520
+
521
+ FragmentCompute tmp_Accum = NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess>()(AB);
522
+ FragmentCompute tmp_C = NumericArrayConverter<ElementCompute, ElementC, kElementsPerAccess>()(frag_C);
523
+ FragmentCompute result_Z;
524
+
525
+ bool conditions[kElementsPerAccess];
526
+
527
+ CUTLASS_PRAGMA_UNROLL
528
+ for (int i = 0; i < kElementsPerAccess; ++i) {
529
+
530
+ ElementCompute z = alpha_ * tmp_Accum[i];
531
+ z += beta_ * tmp_C[i];
532
+
533
+ z = binary_op(z, V[i]);
534
+ result_Z[i] = z;
535
+ }
536
+
537
+ NumericArrayConverter<ElementZ, ElementCompute, kElementsPerAccess> convert_z;
538
+ frag_Z = convert_z(result_Z);
539
+
540
+ //
541
+ // Compute condition
542
+ //
543
+
544
+ detail::ReluConditional<ElementZ, kElementsPerAccess> relu_conditional;
545
+ relu_conditional(conditions, frag_Z, threshold_);
546
+
547
+ detail::ArrayMaximum<ElementZ, kElementsPerAccess> maximum_op;
548
+ frag_Z = maximum_op(frag_Z, threshold_);
549
+
550
+ if (kStoreT) {
551
+ PackPredicates<kElementsPerAccess> pack_predicates;
552
+ frag_T = pack_predicates(conditions);
553
+ }
554
+ }
555
+
556
+ /// Applies the operation when is_source_needed() is false
557
+ CUTLASS_HOST_DEVICE
558
+ void operator()(
559
+ FragmentZ &frag_Z,
560
+ FragmentT &frag_T,
561
+ FragmentAccumulator const &AB,
562
+ FragmentCompute const &V) const {
563
+
564
+ BinaryOp binary_op;
565
+
566
+ FragmentCompute tmp_Accum = NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess>()(AB);
567
+ FragmentCompute result_Z;
568
+
569
+ bool conditions[kElementsPerAccess];
570
+
571
+ CUTLASS_PRAGMA_UNROLL
572
+ for (int i = 0; i < kElementsPerAccess; ++i) {
573
+ ElementCompute z = binary_op(alpha_ * tmp_Accum[i], V[i]);
574
+ result_Z[i] = z;
575
+ }
576
+
577
+ NumericArrayConverter<ElementZ, ElementCompute, kElementsPerAccess> convert_z;
578
+ frag_Z = convert_z(result_Z);
579
+
580
+ //
581
+ // Compute condition
582
+ //
583
+
584
+ detail::ReluConditional<ElementZ, kElementsPerAccess> relu_conditional;
585
+ relu_conditional(conditions, frag_Z, threshold_);
586
+
587
+ detail::ArrayMaximum<ElementZ, kElementsPerAccess> maximum_op;
588
+ frag_Z = maximum_op(frag_Z, threshold_);
589
+
590
+ //
591
+ // Compute conditions
592
+ //
593
+
594
+ //
595
+ // Store
596
+ //
597
+ if (kStoreT) {
598
+ PackPredicates<kElementsPerAccess> pack_predicates;
599
+ frag_T = pack_predicates(conditions);
600
+ }
601
+ }
602
+ };
603
+
604
+ /////////////////////////////////////////////////////////////////////////////////////////////////
605
+
606
+ } // namespace thread
607
+ } // namespace epilogue
608
+ } // namespace cutlass
609
+
610
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_clamp.h ADDED
@@ -0,0 +1,684 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Functor performing linear scaling operations used by epilogues. Values are clamped before
33
+ converting to the output element type.
34
+ */
35
+
36
+ #pragma once
37
+
38
+ #include "cutlass/cutlass.h"
39
+ #include "cutlass/numeric_types.h"
40
+ #include "cutlass/array.h"
41
+ #include "cutlass/functional.h"
42
+ #include "cutlass/numeric_conversion.h"
43
+ #include "cutlass/epilogue/thread/scale_type.h"
44
+
45
+ /////////////////////////////////////////////////////////////////////////////////////////////////
46
+
47
+ namespace cutlass {
48
+ namespace epilogue {
49
+ namespace thread {
50
+
51
+ /////////////////////////////////////////////////////////////////////////////////////////////////
52
+
53
+ namespace detail {
54
+
55
+ /// Single source of truth for whether to unroll for `LinearCombinationClamp()`
56
+ constexpr bool LinearCombinationClampIsHeavy() {
57
+ return false;
58
+ }
59
+
60
+ }
61
+
62
+ /////////////////////////////////////////////////////////////////////////////////////////////////
63
+
64
+ /// Applies a linear combination operator to an array of elements then clamps the output before
65
+ /// converting to the output element type.
66
+ ///
67
+ /// D = alpha * accumulator + beta * source + uniform
68
+ ///
69
+ template <
70
+ typename ElementOutput_, ///< Data type used to load and store tensors
71
+ int Count, ///< Number of elements computed per operation
72
+ ///< Usually it is 128/sizeof_bits<ElementOutput_>,
73
+ ///< but we use 64 or 32 sometimes when there are not enough data to store
74
+ typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type
75
+ typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination
76
+ ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling
77
+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
78
+ >
79
+ class LinearCombinationClamp {
80
+ public:
81
+
82
+ using ElementOutput = ElementOutput_;
83
+ using ElementAccumulator = ElementAccumulator_;
84
+ using ElementCompute = ElementCompute_;
85
+
86
+ static int const kCount = Count;
87
+
88
+ using FragmentOutput = Array<ElementOutput, kCount>;
89
+ using FragmentAccumulator = Array<ElementAccumulator, kCount>;
90
+ using ComputeFragment = Array<ElementCompute, kCount>;
91
+ using FragmentSource = Array<ElementOutput, kCount>;
92
+
93
+ static FloatRoundStyle const kRound = Round;
94
+
95
+ static bool const kIsHeavy = detail::LinearCombinationClampIsHeavy();
96
+
97
+ /// Host-constructable parameters structure
98
+ struct Params {
99
+
100
+ ElementCompute alpha; ///< scales accumulators
101
+ ElementCompute beta; ///< scales source tensor
102
+ ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
103
+ ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
104
+
105
+ //
106
+ // Methods
107
+ //
108
+
109
+ CUTLASS_HOST_DEVICE
110
+ Params():
111
+ alpha(ElementCompute(1)),
112
+ beta(ElementCompute(0)),
113
+ alpha_ptr(nullptr),
114
+ beta_ptr(nullptr) { }
115
+
116
+ CUTLASS_HOST_DEVICE
117
+ Params(
118
+ ElementCompute alpha,
119
+ ElementCompute beta
120
+ ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) {
121
+
122
+ }
123
+
124
+ CUTLASS_HOST_DEVICE
125
+ Params(
126
+ ElementCompute alpha
127
+ ): alpha(alpha), beta(0), alpha_ptr(nullptr), beta_ptr(nullptr) {
128
+
129
+ }
130
+
131
+ CUTLASS_HOST_DEVICE
132
+ Params(
133
+ ElementCompute const *alpha_ptr,
134
+ ElementCompute const *beta_ptr
135
+ ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
136
+
137
+ }
138
+
139
+ CUTLASS_HOST_DEVICE
140
+ Params(
141
+ ElementCompute const *alpha_ptr
142
+ ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(nullptr) {
143
+
144
+ }
145
+ };
146
+
147
+ private:
148
+
149
+ //
150
+ // Data members
151
+ //
152
+
153
+ ElementCompute alpha_;
154
+ ElementCompute beta_;
155
+
156
+ public:
157
+
158
+ /// Constructs the function object, possibly loading from pointers in host memory
159
+ CUTLASS_HOST_DEVICE
160
+ LinearCombinationClamp(Params const &params) {
161
+
162
+ alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
163
+ beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
164
+ }
165
+
166
+ /// Returns true if source is needed
167
+ CUTLASS_HOST_DEVICE
168
+ bool is_source_needed() const {
169
+ if (Scale == ScaleType::NoBetaScaling) return true;
170
+
171
+ if (Scale == ScaleType::OnlyAlphaScaling) return false;
172
+
173
+ if (Scale == ScaleType::Nothing) return false;
174
+
175
+ return beta_ != ElementCompute(0);
176
+ }
177
+
178
+ /// Functionally required for serial reduction in the epilogue
179
+ CUTLASS_HOST_DEVICE
180
+ void set_k_partition(int k_partition, int k_partition_count) {
181
+ if (k_partition) {
182
+ beta_ = ElementCompute(1);
183
+ }
184
+ }
185
+
186
+ /// Computes linear scaling: D = alpha * accumulator + beta * source
187
+ CUTLASS_HOST_DEVICE
188
+ FragmentOutput operator()(
189
+ FragmentAccumulator const &accumulator,
190
+ FragmentOutput const &source,
191
+ ElementCompute uniform = ElementCompute(0)) const {
192
+
193
+ // Convert source to interal compute numeric type
194
+ NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round> source_converter;
195
+ NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
196
+
197
+ ComputeFragment converted_source = source_converter(source);
198
+ ComputeFragment converted_accumulator = accumulator_converter(accumulator);
199
+
200
+ // Perform binary operations
201
+
202
+ ComputeFragment intermediate;
203
+
204
+ multiplies<ComputeFragment> mul_add_source;
205
+ multiply_add<ComputeFragment> mul_add_accumulator;
206
+
207
+ minimum<ComputeFragment> min_accumulator;
208
+ maximum<ComputeFragment> max_accumulator;
209
+
210
+ if (Scale == ScaleType::NoBetaScaling) {
211
+ intermediate = converted_source;
212
+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
213
+ } else if (Scale == ScaleType::Nothing) {
214
+ intermediate = converted_accumulator;
215
+ } else {
216
+ intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform
217
+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
218
+ }
219
+
220
+ /// Clamping constant value
221
+ ElementCompute const kClampMax =
222
+ ElementCompute(cutlass::platform::numeric_limits<ElementOutput>::max());
223
+
224
+ ElementCompute const kClampMin =
225
+ ElementCompute(cutlass::platform::numeric_limits<ElementOutput>::lowest());
226
+
227
+ intermediate = max_accumulator(intermediate, kClampMin);
228
+ intermediate = min_accumulator(intermediate, kClampMax);
229
+
230
+ // Convert to destination numeric type
231
+ NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
232
+
233
+ return destination_converter(intermediate);
234
+ }
235
+
236
+ /// Computes linear scaling: D = alpha * accumulator
237
+ CUTLASS_HOST_DEVICE
238
+ FragmentOutput operator()(
239
+ FragmentAccumulator const &accumulator) const {
240
+
241
+ // Convert source to interal compute numeric type
242
+ NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
243
+
244
+ ComputeFragment converted_accumulator = accumulator_converter(accumulator);
245
+
246
+ // Perform binary operations
247
+
248
+ ComputeFragment intermediate;
249
+
250
+ multiplies<ComputeFragment> mul_accumulator;
251
+
252
+ minimum<ComputeFragment> min_accumulator;
253
+ maximum<ComputeFragment> max_accumulator;
254
+
255
+ if (Scale == ScaleType::Nothing) {
256
+ intermediate = converted_accumulator;
257
+ } else {
258
+ intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
259
+ }
260
+
261
+ /// Clamping constant value
262
+ ElementCompute const kClampMax =
263
+ ElementCompute(cutlass::platform::numeric_limits<ElementOutput>::max());
264
+
265
+ ElementCompute const kClampMin =
266
+ ElementCompute(cutlass::platform::numeric_limits<ElementOutput>::lowest());
267
+
268
+ intermediate = max_accumulator(intermediate, kClampMin);
269
+ intermediate = min_accumulator(intermediate, kClampMax);
270
+
271
+ // Convert to destination numeric type
272
+ NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
273
+
274
+ return destination_converter(intermediate);
275
+ }
276
+ };
277
+
278
+ /////////////////////////////////////////////////////////////////////////////////////////////////
279
+
280
+ // Conditional guards to enable partial specialization for packed integers
281
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && ((__CUDACC_VER_MAJOR__ > 10) || ((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2)))
282
+
283
+ /// Applies a linear combination operator to an array of elements then clamps the output before
284
+ /// converting to the output element type.
285
+ ///
286
+ /// D = alpha * accumulator + beta * source + uniform
287
+ ///
288
+ template <
289
+ typename ElementOutput_, ///< Data type used to load and store tensors
290
+ int Count, ///< Number of elements computed per operation
291
+ ScaleType::Kind Scale, ///< Control Alpha and Beta scaling
292
+ FloatRoundStyle Round
293
+ >
294
+ class LinearCombinationClamp<ElementOutput_, Count, int, float, Scale, Round> {
295
+ public:
296
+
297
+ using ElementOutput = ElementOutput_;
298
+ using ElementAccumulator = int;
299
+ using ElementCompute = float;
300
+
301
+ static_assert(
302
+ cutlass::platform::numeric_limits<ElementOutput>::is_integer,
303
+ "This elementwise op expects the output to be int.");
304
+
305
+ static int const kCount = Count;
306
+
307
+ using FragmentOutput = Array<ElementOutput, kCount>;
308
+ using FragmentAccumulator = Array<ElementAccumulator, kCount>;
309
+ using ComputeFragment = Array<ElementCompute, kCount>;
310
+
311
+ static FloatRoundStyle const kRound = Round;
312
+
313
+ static bool const kIsHeavy = detail::LinearCombinationClampIsHeavy();
314
+
315
+ /// Host-constructable parameters structure
316
+ struct Params {
317
+
318
+ ElementCompute alpha; ///< scales accumulators
319
+ ElementCompute beta; ///< scales source tensor
320
+ ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
321
+ ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
322
+
323
+ //
324
+ // Methods
325
+ //
326
+
327
+ CUTLASS_HOST_DEVICE
328
+ Params():
329
+ alpha(ElementCompute(1)),
330
+ beta(ElementCompute(0)),
331
+ alpha_ptr(nullptr),
332
+ beta_ptr(nullptr) { }
333
+
334
+ CUTLASS_HOST_DEVICE
335
+ Params(
336
+ ElementCompute alpha,
337
+ ElementCompute beta
338
+ ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) {
339
+
340
+ }
341
+
342
+ CUTLASS_HOST_DEVICE
343
+ Params(
344
+ ElementCompute alpha
345
+ ): alpha(alpha), beta(0), alpha_ptr(nullptr), beta_ptr(nullptr) {
346
+
347
+ }
348
+
349
+ CUTLASS_HOST_DEVICE
350
+ Params(
351
+ ElementCompute const *alpha_ptr,
352
+ ElementCompute const *beta_ptr
353
+ ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
354
+
355
+ }
356
+
357
+ CUTLASS_HOST_DEVICE
358
+ Params(
359
+ ElementCompute const *alpha_ptr
360
+ ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(nullptr) {
361
+
362
+ }
363
+ };
364
+
365
+ private:
366
+
367
+ //
368
+ // Data members
369
+ //
370
+
371
+ ElementCompute alpha_;
372
+ ElementCompute beta_;
373
+
374
+ public:
375
+
376
+ /// Constructs the function object, possibly loading from pointers in host memory
377
+ CUTLASS_HOST_DEVICE
378
+ LinearCombinationClamp(Params const &params) {
379
+
380
+ alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
381
+ beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
382
+ }
383
+
384
+ /// Returns true if source is needed
385
+ CUTLASS_HOST_DEVICE
386
+ bool is_source_needed() const {
387
+ if (Scale == ScaleType::NoBetaScaling) return true;
388
+
389
+ if (Scale == ScaleType::OnlyAlphaScaling) return false;
390
+
391
+ if (Scale == ScaleType::Nothing) return false;
392
+
393
+ return beta_ != ElementCompute(0);
394
+ }
395
+
396
+ /// Functionally required for serial reduction in the epilogue
397
+ CUTLASS_HOST_DEVICE
398
+ void set_k_partition(int k_partition, int k_partition_count) {
399
+ if (k_partition) {
400
+ beta_ = ElementCompute(1);
401
+ }
402
+ }
403
+
404
+ /// Computes linear scaling: D = alpha * accumulator + beta * source
405
+ CUTLASS_HOST_DEVICE
406
+ FragmentOutput operator()(
407
+ FragmentAccumulator const &accumulator,
408
+ FragmentOutput const &source,
409
+ ElementCompute uniform = ElementCompute(0)) const {
410
+
411
+ // Convert source to interal compute numeric type
412
+ NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round> source_converter;
413
+ NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
414
+
415
+ ComputeFragment converted_source = source_converter(source);
416
+ ComputeFragment converted_accumulator = accumulator_converter(accumulator);
417
+
418
+ // Compute linear scaling in floating point
419
+ ComputeFragment intermediate;
420
+
421
+ multiplies<ComputeFragment> mul_add_source;
422
+ multiply_add<ComputeFragment> mul_add_accumulator;
423
+
424
+ // Float min-max
425
+ if (Scale == ScaleType::NoBetaScaling) {
426
+ intermediate = converted_source;
427
+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
428
+ } else if (Scale == ScaleType::Nothing) {
429
+ intermediate = converted_accumulator;
430
+ } else {
431
+ intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform
432
+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
433
+ }
434
+
435
+ //
436
+ // Convert float => ElementOutput_ with clamping
437
+ //
438
+ NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
439
+
440
+ return destination_converter(intermediate);
441
+ }
442
+
443
+ /// Computes linear scaling: D = alpha * accumulator
444
+ CUTLASS_HOST_DEVICE
445
+ FragmentOutput operator()(FragmentAccumulator const &accumulator) const {
446
+
447
+ // Convert source to interal compute numeric type
448
+ NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
449
+
450
+ ComputeFragment converted_accumulator = accumulator_converter(accumulator);
451
+
452
+ // Compute linear scaling in floating point
453
+ ComputeFragment intermediate;
454
+
455
+ multiplies<ComputeFragment> mul_add_accumulator;
456
+
457
+ // Float min-max
458
+ if (Scale == ScaleType::Nothing) {
459
+ intermediate = converted_accumulator;
460
+ } else {
461
+ intermediate = mul_add_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
462
+ }
463
+
464
+ //
465
+ // Convert float => ElementOutput_ with clamping
466
+ //
467
+ NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
468
+
469
+ return destination_converter(intermediate);
470
+ }
471
+ };
472
+
473
+ #endif // Conditional guards to enable partial specialization for packed integers
474
+
475
+ ////////////////////////////////////////////////////////////////////////////////
476
+
477
+ /// Applies a linear combination operator to an array of elements then clamps
478
+ /// the output before converting to the output element type.
479
+ ///
480
+ /// D = alpha * accumulator + beta * source + uniform
481
+ ///
482
+ /// Note: The below method only when problem_size_K <= 256 for signed int8 gemm
483
+ /// or problem_size_K <= 128 for unsigned int8 gemm. The default approach is
484
+ /// above.
485
+ template <
486
+ /// Data type used to load and store< tensors
487
+ typename ElementOutput_,
488
+ /// Number of elements computed per operation
489
+ int Count,
490
+ ///< Control Alpha and Beta scaling
491
+ ScaleType::Kind Scale = ScaleType::Default,
492
+ /// Rounding mode
493
+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest>
494
+ class FastLinearCombinationClamp {
495
+ public:
496
+ using ElementOutput = ElementOutput_;
497
+ using ElementAccumulator = int;
498
+ using ElementCompute = float;
499
+
500
+ static_assert(
501
+ cutlass::platform::numeric_limits<ElementOutput>::is_integer,
502
+ "This elementwise op expects the output to be int.");
503
+
504
+ static int const kCount = Count;
505
+
506
+ using FragmentOutput = Array<ElementOutput, kCount>;
507
+ using FragmentAccumulator = Array<ElementAccumulator, kCount>;
508
+ using ComputeFragment = Array<ElementCompute, kCount>;
509
+
510
+ static FloatRoundStyle const kRound = Round;
511
+
512
+ static bool const kIsHeavy = false;
513
+
514
+ /// Host-constructable parameters structure
515
+ struct Params {
516
+ /// scales accumulators
517
+ ElementCompute alpha;
518
+ /// scales source tensor
519
+ ElementCompute beta;
520
+ /// pointer to accumulator scalar - if not null, loads it from memory
521
+ ElementCompute const *alpha_ptr;
522
+ /// pointer to source scalar - if not null, loads it from memory
523
+ ElementCompute const *beta_ptr;
524
+
525
+ //
526
+ // Methods
527
+ //
528
+
529
+ CUTLASS_HOST_DEVICE
530
+ Params()
531
+ : alpha(ElementCompute(1)),
532
+ beta(ElementCompute(0)),
533
+ alpha_ptr(nullptr),
534
+ beta_ptr(nullptr) {}
535
+
536
+ CUTLASS_HOST_DEVICE
537
+ Params(ElementCompute alpha, ElementCompute beta)
538
+ : alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) {}
539
+
540
+ CUTLASS_HOST_DEVICE
541
+ Params(ElementCompute alpha)
542
+ : alpha(alpha), beta(0), alpha_ptr(nullptr), beta_ptr(nullptr) {}
543
+
544
+ CUTLASS_HOST_DEVICE
545
+ Params(ElementCompute const *alpha_ptr, ElementCompute const *beta_ptr)
546
+ : alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {}
547
+
548
+ CUTLASS_HOST_DEVICE
549
+ Params(ElementCompute const *alpha_ptr)
550
+ : alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(nullptr) {}
551
+ };
552
+
553
+ private:
554
+ //
555
+ // Data members
556
+ //
557
+
558
+ ElementCompute alpha_;
559
+ ElementCompute beta_;
560
+
561
+ public:
562
+ /// Constructs the function object, possibly loading from pointers in host
563
+ /// memory
564
+ CUTLASS_HOST_DEVICE
565
+ FastLinearCombinationClamp(Params const &params) {
566
+ alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
567
+ beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
568
+ }
569
+
570
+ /// Returns true if source is needed
571
+ CUTLASS_HOST_DEVICE
572
+ bool is_source_needed() const {
573
+ if (Scale == ScaleType::NoBetaScaling) return true;
574
+
575
+ if (Scale == ScaleType::OnlyAlphaScaling) return false;
576
+
577
+ if (Scale == ScaleType::Nothing) return false;
578
+
579
+ return beta_ != ElementCompute(0);
580
+ }
581
+
582
+ /// Functionally required for serial reduction in the epilogue
583
+ CUTLASS_HOST_DEVICE
584
+ void set_k_partition(int k_partition, int k_partition_count) {
585
+ if (k_partition) {
586
+ beta_ = ElementCompute(1);
587
+ }
588
+ }
589
+
590
+ /// Computes linear scaling: D = alpha * accumulator + beta * source
591
+ CUTLASS_HOST_DEVICE
592
+ FragmentOutput operator()(FragmentAccumulator const &accumulator,
593
+ FragmentOutput const &source,
594
+ ElementCompute uniform = ElementCompute(0)) const {
595
+ // Convert source to interal compute numeric type
596
+ FastNumericArrayConverter<ElementCompute, ElementOutput, kCount, Round>
597
+ source_converter;
598
+ FastNumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
599
+ accumulator_converter;
600
+
601
+ ComputeFragment converted_source = source_converter(source);
602
+ ComputeFragment converted_accumulator = accumulator_converter(accumulator);
603
+
604
+ // Compute linear scaling in floating point
605
+ ComputeFragment intermediate;
606
+
607
+ multiplies<ComputeFragment> mul_add_source;
608
+ multiply_add<ComputeFragment> mul_add_accumulator;
609
+
610
+ minimum<ComputeFragment> min_accumulator;
611
+ maximum<ComputeFragment> max_accumulator;
612
+
613
+ // Float min-max
614
+ if (Scale == ScaleType::NoBetaScaling) {
615
+ intermediate = converted_source;
616
+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
617
+ } else if (Scale == ScaleType::Nothing) {
618
+ intermediate = converted_accumulator;
619
+ } else {
620
+ intermediate =
621
+ mul_add_source(beta_, converted_source); // X = beta * C + uniform
622
+ intermediate = mul_add_accumulator(alpha_, converted_accumulator,
623
+ intermediate); // D = alpha * Accum + X
624
+ }
625
+
626
+ /// Clamping constant value
627
+ ElementCompute const kClamp =
628
+ ElementCompute(1 << (sizeof_bits<ElementOutput>::value - 1));
629
+
630
+ intermediate = max_accumulator(intermediate, -kClamp);
631
+ intermediate = min_accumulator(intermediate, kClamp - ElementCompute(1));
632
+
633
+ // Convert to destination numeric type
634
+ FastNumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
635
+ destination_converter;
636
+
637
+ return destination_converter(intermediate);
638
+ }
639
+
640
+ /// Computes linear scaling: D = alpha * accumulator + beta * source
641
+ CUTLASS_HOST_DEVICE
642
+ FragmentOutput operator()(FragmentAccumulator const &accumulator) const {
643
+
644
+ // Convert source to interal compute numeric type
645
+ FastNumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
646
+ accumulator_converter;
647
+
648
+ ComputeFragment converted_accumulator = accumulator_converter(accumulator);
649
+
650
+ // Compute linear scaling in floating point
651
+ ComputeFragment intermediate;
652
+
653
+ multiplies<ComputeFragment> mul_accumulator;
654
+
655
+ minimum<ComputeFragment> min_accumulator;
656
+ maximum<ComputeFragment> max_accumulator;
657
+
658
+ // Float min-max
659
+ if (Scale == ScaleType::Nothing) {
660
+ intermediate = converted_accumulator;
661
+ } else {
662
+ intermediate = mul_accumulator(alpha_, converted_accumulator);
663
+ }
664
+
665
+ /// Clamping constant value
666
+ ElementCompute const kClamp =
667
+ ElementCompute(1 << (sizeof_bits<ElementOutput>::value - 1));
668
+
669
+ intermediate = max_accumulator(intermediate, -kClamp);
670
+ intermediate = min_accumulator(intermediate, kClamp - ElementCompute(1));
671
+
672
+ // Convert to destination numeric type
673
+ FastNumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
674
+ destination_converter;
675
+
676
+ return destination_converter(intermediate);
677
+ }
678
+ };
679
+
680
+ ////////////////////////////////////////////////////////////////////////////////
681
+
682
+ } // namespace thread
683
+ } // namespace epilogue
684
+ } // namespace cutlass
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_dgelu.h ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+
33
+ \brief Functor performing linear combination followed by dGelu operation
34
+ */
35
+
36
+ #pragma once
37
+
38
+ #include "cutlass/half.h"
39
+ #include "cutlass/cutlass.h"
40
+ #include "cutlass/numeric_types.h"
41
+ #include "cutlass/array.h"
42
+ #include "cutlass/constants.h"
43
+ #include "cutlass/fast_math.h"
44
+ #include "cutlass/functional.h"
45
+ #include "cutlass/numeric_conversion.h"
46
+ #include "cutlass/epilogue/thread/activation.h"
47
+
48
+ /////////////////////////////////////////////////////////////////////////////////////////////////
49
+
50
+ namespace cutlass {
51
+ namespace epilogue {
52
+ namespace thread {
53
+
54
+ /////////////////////////////////////////////////////////////////////////////////////////////////
55
+
56
+ /// Applies a linear combination operator to an array of elements.
57
+ ///
58
+ /// D = alpha * accumulator + beta * source + uniform
59
+ ///
60
+ template <
61
+ typename ElementCompute_, ///< Data type returned by this functor
62
+ typename ElementAccumulator_, ///< Data type of accumulators
63
+ typename ElementSource_, ///< Data type of source tensor
64
+ typename ElementTensor_, ///< Data type of additional tensor
65
+ int Count, ///< Number of elements computed per operation
66
+ ///< Usually it is 128/sizeof_bits<ElementOutput_>,
67
+ ///< but we use 64 or 32 sometimes when there are not enough data to store
68
+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
69
+ >
70
+ class LinearCombinationDGelu {
71
+ public:
72
+
73
+ using ElementOutput = ElementSource_;
74
+ using ElementCompute = ElementCompute_;
75
+ using ElementAccumulator = ElementAccumulator_;
76
+ using ElementSource = ElementSource_;
77
+ using ElementTensor = ElementTensor_;
78
+
79
+ static bool const kIsHeavy = true;
80
+
81
+ static int const kCount = Count;
82
+
83
+ using FragmentCompute = Array<ElementCompute, kCount>;
84
+ using FragmentAccumulator = Array<ElementAccumulator, kCount>;
85
+ using FragmentSource = Array<ElementSource, kCount>;
86
+ using FragmentTensor = Array<ElementTensor, kCount>;
87
+
88
+ static FloatRoundStyle const kRound = Round;
89
+
90
+ /// Host-constructable parameters structure
91
+ struct Params {
92
+
93
+ ElementCompute alpha; ///< scales accumulators
94
+ ElementCompute beta; ///< scales source tensor
95
+ ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
96
+ ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
97
+ ElementCompute threshold; ///< minimum value that is output
98
+ //
99
+ // Methods
100
+ //
101
+
102
+ CUTLASS_HOST_DEVICE
103
+ Params():
104
+ alpha(ElementCompute(1)),
105
+ beta(ElementCompute(0)),
106
+ threshold(ElementCompute(0)),
107
+ alpha_ptr(nullptr),
108
+ beta_ptr(nullptr) { }
109
+
110
+ CUTLASS_HOST_DEVICE
111
+ Params(
112
+ ElementCompute alpha,
113
+ ElementCompute beta,
114
+ ElementCompute threshold = ElementCompute(0)
115
+ ): alpha(alpha), beta(beta), threshold(threshold), alpha_ptr(nullptr), beta_ptr(nullptr) {
116
+
117
+ }
118
+
119
+ CUTLASS_HOST_DEVICE
120
+ Params(
121
+ ElementCompute const *alpha_ptr,
122
+ ElementCompute const *beta_ptr,
123
+ ElementCompute threshold = ElementCompute(0)
124
+ ): alpha(0), beta(0), threshold(threshold), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
125
+
126
+ }
127
+ };
128
+
129
+ private:
130
+
131
+ //
132
+ // Data members
133
+ //
134
+
135
+ ElementCompute alpha_;
136
+ ElementCompute beta_;
137
+ ElementCompute threshold_;
138
+ bool participates_in_reduction_;
139
+
140
+ public:
141
+
142
+ /// Constructs the function object, possibly loading from pointers in host memory
143
+ CUTLASS_HOST_DEVICE
144
+ LinearCombinationDGelu(Params const &params) {
145
+
146
+ alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
147
+ beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
148
+ threshold_ = params.threshold;
149
+ participates_in_reduction_ = true;
150
+ }
151
+
152
+ /// Returns true if source is needed
153
+ CUTLASS_HOST_DEVICE
154
+ bool is_source_needed() const {
155
+ return beta_ != ElementCompute(0);
156
+ }
157
+
158
+ /// Returns true if the threadblock computes the reduction
159
+ CUTLASS_HOST_DEVICE
160
+ bool participates_in_reduction() const {
161
+ return participates_in_reduction_;
162
+ }
163
+
164
+ /// Functionally required for serial reduction in the epilogue
165
+ CUTLASS_HOST_DEVICE
166
+ void set_k_partition(int k_partition, int k_partition_count) {
167
+ if (k_partition) {
168
+ beta_ = ElementCompute(1);
169
+ }
170
+
171
+ if (k_partition != k_partition_count - 1) {
172
+ // set to NaN to make ReLU no-op for all except last k partitions
173
+ int64_t allones = -1;
174
+ threshold_ = reinterpret_cast<ElementCompute const &>(allones);
175
+ // Avoid computing the reduction if this isn't the final Split-K slice
176
+ participates_in_reduction_ = false;
177
+ }
178
+ }
179
+
180
+ /// Computes linear scaling: D = alpha * accumulator + beta * source
181
+ CUTLASS_HOST_DEVICE
182
+ FragmentCompute operator()(
183
+ FragmentAccumulator const &accumulator,
184
+ FragmentSource const &source,
185
+ FragmentTensor const &tensor) const {
186
+
187
+ // Convert source to interal compute numeric type
188
+ NumericArrayConverter<ElementCompute, ElementSource, kCount, Round> source_converter;
189
+ NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
190
+
191
+ FragmentCompute converted_source = source_converter(source);
192
+ FragmentCompute converted_accumulator = accumulator_converter(accumulator);
193
+
194
+ // Perform binary operations
195
+ FragmentCompute intermediate;
196
+
197
+ multiplies<FragmentCompute> mul_add_source;
198
+ multiply_add<FragmentCompute> mul_add_accumulator;
199
+
200
+ intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform
201
+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
202
+
203
+ dGELU<ElementCompute> gelu_op;
204
+
205
+ // dGelu
206
+ CUTLASS_PRAGMA_UNROLL
207
+ for (int i = 0; i < kCount; ++i) {
208
+ intermediate[i] = gelu_op(intermediate[i], ElementCompute(tensor[i]));
209
+ }
210
+
211
+ return intermediate;
212
+ }
213
+
214
+ /// Computes linear scaling: D = alpha * accumulator
215
+ CUTLASS_HOST_DEVICE
216
+ FragmentCompute operator()(
217
+ FragmentAccumulator const &accumulator,
218
+ FragmentTensor const &tensor) const {
219
+
220
+ // Convert source to interal compute numeric type
221
+ NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
222
+
223
+ FragmentCompute converted_accumulator = accumulator_converter(accumulator);
224
+
225
+ // Perform binary operations
226
+ FragmentCompute intermediate;
227
+
228
+ multiplies<FragmentCompute> mul_accumulator;
229
+
230
+ intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
231
+
232
+ dGELU<ElementCompute> gelu_op;
233
+
234
+ // dGelu with conversion
235
+ CUTLASS_PRAGMA_UNROLL
236
+ for (int i = 0; i < kCount; ++i) {
237
+ intermediate[i] = gelu_op(intermediate[i], ElementCompute(tensor[i]));
238
+ }
239
+
240
+ return intermediate;
241
+ }
242
+ };
243
+
244
+ /////////////////////////////////////////////////////////////////////////////////////////////////
245
+
246
+ } // namespace thread
247
+ } // namespace epilogue
248
+ } // namespace cutlass
249
+
250
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_drelu.h ADDED
@@ -0,0 +1,452 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Functor performing linear combination with a maximum operation used by epilogues.
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "cutlass/half.h"
38
+ #include "cutlass/cutlass.h"
39
+ #include "cutlass/numeric_types.h"
40
+ #include "cutlass/array.h"
41
+ #include "cutlass/functional.h"
42
+ #include "cutlass/numeric_conversion.h"
43
+ #include "cutlass/epilogue/thread/activation.h"
44
+
45
+ /////////////////////////////////////////////////////////////////////////////////////////////////
46
+
47
+ namespace cutlass {
48
+ namespace epilogue {
49
+ namespace thread {
50
+
51
+ /////////////////////////////////////////////////////////////////////////////////////////////////
52
+
53
+ /// Applies a linear combination operator to an array of elements.
54
+ ///
55
+ /// D = alpha * accumulator + beta * source + uniform
56
+ ///
57
+ template <
58
+ typename ElementCompute_, ///< Data type returned by this functor
59
+ typename ElementAccumulator_, ///< Data type of accumulators
60
+ typename ElementSource_, ///< Data type of source tensor
61
+ typename ElementTensor_, ///< Data type of additional tensor
62
+ int Count, ///< Number of elements computed per operation
63
+ ///< Usually it is 128/sizeof_bits<ElementOutput_>,
64
+ ///< but we use 64 or 32 sometimes when there are not enough data to store
65
+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
66
+ >
67
+ class LinearCombinationDRelu {
68
+ public:
69
+
70
+ using ElementOutput = ElementSource_;
71
+ using ElementCompute = ElementCompute_;
72
+ using ElementAccumulator = ElementAccumulator_;
73
+ using ElementSource = ElementSource_;
74
+ using ElementTensor = ElementTensor_;
75
+
76
+ static int const kCount = Count;
77
+
78
+ using FragmentCompute = Array<ElementCompute, kCount>;
79
+ using FragmentAccumulator = Array<ElementAccumulator, kCount>;
80
+ using FragmentSource = Array<ElementSource, kCount>;
81
+ using FragmentTensor = Array<ElementTensor, kCount>;
82
+
83
+ static FloatRoundStyle const kRound = Round;
84
+
85
+ /// Host-constructable parameters structure
86
+ struct Params {
87
+
88
+ ElementCompute alpha; ///< scales accumulators
89
+ ElementCompute beta; ///< scales source tensor
90
+ ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
91
+ ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
92
+ ElementCompute threshold; ///< minimum value that is output
93
+ //
94
+ // Methods
95
+ //
96
+
97
+ CUTLASS_HOST_DEVICE
98
+ Params():
99
+ alpha(ElementCompute(1)),
100
+ beta(ElementCompute(0)),
101
+ threshold(ElementCompute(0)),
102
+ alpha_ptr(nullptr),
103
+ beta_ptr(nullptr) { }
104
+
105
+ CUTLASS_HOST_DEVICE
106
+ Params(
107
+ ElementCompute alpha,
108
+ ElementCompute beta,
109
+ ElementCompute threshold = ElementCompute(0)
110
+ ): alpha(alpha), beta(beta), threshold(threshold), alpha_ptr(nullptr), beta_ptr(nullptr) {
111
+
112
+ }
113
+
114
+ CUTLASS_HOST_DEVICE
115
+ Params(
116
+ ElementCompute const *alpha_ptr,
117
+ ElementCompute const *beta_ptr,
118
+ ElementCompute threshold = ElementCompute(0)
119
+ ): alpha(0), beta(0), threshold(threshold), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
120
+
121
+ }
122
+ };
123
+
124
+ private:
125
+
126
+ //
127
+ // Data members
128
+ //
129
+
130
+ ElementCompute alpha_;
131
+ ElementCompute beta_;
132
+ ElementTensor threshold_;
133
+ bool participates_in_reduction_;
134
+
135
+ public:
136
+
137
+ /// Constructs the function object, possibly loading from pointers in host memory
138
+ CUTLASS_HOST_DEVICE
139
+ LinearCombinationDRelu(Params const &params) {
140
+
141
+ alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
142
+ beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
143
+ threshold_ = ElementTensor(params.threshold);
144
+ participates_in_reduction_ = true;
145
+ }
146
+
147
+ /// Returns true if source is needed
148
+ CUTLASS_HOST_DEVICE
149
+ bool is_source_needed() const {
150
+ return beta_ != ElementCompute(0);
151
+ }
152
+
153
+ /// Returns true if the threadblock computes the reduction
154
+ CUTLASS_HOST_DEVICE
155
+ bool participates_in_reduction() const {
156
+ return participates_in_reduction_;
157
+ }
158
+
159
+ /// Functionally required for serial reduction in the epilogue
160
+ CUTLASS_DEVICE
161
+ void set_k_partition(int k_partition, int k_partition_count) {
162
+ if (k_partition) {
163
+ beta_ = ElementCompute(1);
164
+ }
165
+
166
+ if (k_partition != k_partition_count - 1) {
167
+ // set to NaN to make ReLU no-op for all except last k partitions
168
+ int64_t allones = -1;
169
+ threshold_ = reinterpret_cast<ElementTensor const &>(allones);
170
+ participates_in_reduction_ = false;
171
+ }
172
+ }
173
+
174
+ /// Computes linear scaling: D = alpha * accumulator + beta * source
175
+ CUTLASS_HOST_DEVICE
176
+ FragmentCompute operator()(
177
+ FragmentAccumulator const &accumulator,
178
+ FragmentSource const &source,
179
+ FragmentTensor const &tensor) const {
180
+
181
+ // Convert source to interal compute numeric type
182
+ NumericArrayConverter<ElementCompute, ElementSource, kCount, Round> source_converter;
183
+ NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
184
+
185
+ FragmentCompute converted_source = source_converter(source);
186
+ FragmentCompute converted_accumulator = accumulator_converter(accumulator);
187
+
188
+ // Perform binary operations
189
+ FragmentCompute intermediate;
190
+
191
+ multiplies<FragmentCompute> mul_add_source;
192
+ multiply_add<FragmentCompute> mul_add_accumulator;
193
+
194
+ intermediate = mul_add_source(beta_, converted_source); // X = beta * C
195
+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
196
+
197
+ // dReLU = (cond ? dy : 0)
198
+ CUTLASS_PRAGMA_UNROLL
199
+ for (int i = 0; i < kCount; ++i) {
200
+ ElementTensor cond = tensor[i];
201
+ if (cond <= threshold_) {
202
+ intermediate[i] = ElementCompute();
203
+ }
204
+ }
205
+
206
+ return intermediate;
207
+ }
208
+
209
+ /// Computes linear scaling: D = alpha * accumulator
210
+ CUTLASS_HOST_DEVICE
211
+ FragmentCompute operator()(
212
+ FragmentAccumulator const &accumulator,
213
+ FragmentTensor const &tensor) const {
214
+
215
+ // Convert source to interal compute numeric type
216
+ NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
217
+
218
+ FragmentCompute converted_accumulator = accumulator_converter(accumulator);
219
+
220
+ // Perform binary operations
221
+ FragmentCompute intermediate;
222
+
223
+ multiplies<FragmentCompute> mul_accumulator;
224
+
225
+ intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
226
+
227
+ // dReLU = (cond ? dy : 0)
228
+ CUTLASS_PRAGMA_UNROLL
229
+ for (int i = 0; i < kCount; ++i) {
230
+ ElementTensor cond = tensor[i];
231
+ if (cond <= threshold_) {
232
+ intermediate[i] = ElementCompute();
233
+ }
234
+ }
235
+
236
+ return intermediate;
237
+ }
238
+ };
239
+
240
+
241
+ /////////////////////////////////////////////////////////////////////////////////////////////////
242
+
243
+ /// Applies a linear combination operator to an array of elements.
244
+ ///
245
+ /// D = alpha * accumulator + beta * source + uniform
246
+ ///
247
+ template <
248
+ typename ElementCompute_, ///< Data type returned by this functor
249
+ typename ElementAccumulator_, ///< Data type of accumulators
250
+ typename ElementSource_, ///< Data type of source tensor
251
+ int Count, ///< Number of elements computed per operation
252
+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
253
+ >
254
+ class LinearCombinationDReluConditionalBits {
255
+ public:
256
+
257
+ using ElementOutput = ElementSource_;
258
+ using ElementCompute = ElementCompute_;
259
+ using ElementAccumulator = ElementAccumulator_;
260
+ using ElementSource = ElementSource_;
261
+ using ElementTensor = uint1b_t;
262
+
263
+ static bool const kIsHeavy = false;
264
+
265
+ static int const kCount = Count;
266
+
267
+ using FragmentCompute = Array<ElementCompute, kCount>;
268
+ using FragmentAccumulator = Array<ElementAccumulator, kCount>;
269
+ using FragmentSource = Array<ElementSource, kCount>;
270
+ using FragmentTensor = Array<ElementTensor, kCount>;
271
+
272
+ static FloatRoundStyle const kRound = Round;
273
+
274
+ /// Host-constructable parameters structure
275
+ struct Params {
276
+
277
+ ElementCompute alpha; ///< scales accumulators
278
+ ElementCompute beta; ///< scales source tensor
279
+ ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
280
+ ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
281
+ //
282
+ // Methods
283
+ //
284
+
285
+ CUTLASS_HOST_DEVICE
286
+ Params():
287
+ alpha(ElementCompute(1)),
288
+ beta(ElementCompute(0)),
289
+ alpha_ptr(nullptr),
290
+ beta_ptr(nullptr) { }
291
+
292
+ CUTLASS_HOST_DEVICE
293
+ Params(
294
+ ElementCompute alpha,
295
+ ElementCompute beta
296
+ ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) {
297
+
298
+ }
299
+
300
+ CUTLASS_HOST_DEVICE
301
+ Params(
302
+ ElementCompute const *alpha_ptr,
303
+ ElementCompute const *beta_ptr
304
+ ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
305
+
306
+ }
307
+ };
308
+
309
+ private:
310
+
311
+ //
312
+ // Data members
313
+ //
314
+
315
+ ElementCompute alpha_;
316
+ ElementCompute beta_;
317
+ FragmentTensor predicate_mask_;
318
+ bool participates_in_reduction_;
319
+
320
+ public:
321
+
322
+ /// Constructs the function object, possibly loading from pointers in host memory
323
+ CUTLASS_HOST_DEVICE
324
+ LinearCombinationDReluConditionalBits(Params const &params) {
325
+
326
+ alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
327
+ beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
328
+ participates_in_reduction_ = true;
329
+ predicate_mask_.clear();
330
+ }
331
+
332
+ /// Returns true if source is needed
333
+ CUTLASS_HOST_DEVICE
334
+ bool is_source_needed() const {
335
+ return beta_ != ElementCompute(0);
336
+ }
337
+
338
+ /// Returns true if the threadblock computes the reduction
339
+ CUTLASS_HOST_DEVICE
340
+ bool participates_in_reduction() const {
341
+ return participates_in_reduction_;
342
+ }
343
+
344
+ /// Functionally required for serial reduction in the epilogue
345
+ CUTLASS_HOST_DEVICE
346
+ void set_k_partition(int k_partition, int k_partition_count) {
347
+ predicate_mask_.clear();
348
+
349
+ if (k_partition) {
350
+ beta_ = ElementCompute(1);
351
+ }
352
+
353
+ if (k_partition != k_partition_count - 1) {
354
+ // Avoid computing the reduction if this isn't the final Split-K slice
355
+ participates_in_reduction_ = false;
356
+
357
+ bit_not<FragmentTensor> not_op;
358
+ predicate_mask_ = not_op(predicate_mask_);
359
+ }
360
+ }
361
+
362
+ /// Computes linear scaling: D = alpha * accumulator + beta * source
363
+ CUTLASS_DEVICE
364
+ FragmentCompute operator()(
365
+ FragmentAccumulator const &accumulator,
366
+ FragmentSource const &source,
367
+ FragmentTensor const &tensor) const {
368
+
369
+ // Convert source to interal compute numeric type
370
+ NumericArrayConverter<ElementCompute, ElementSource, kCount, Round> source_converter;
371
+ NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
372
+
373
+ FragmentCompute converted_source = source_converter(source);
374
+ FragmentCompute converted_accumulator = accumulator_converter(accumulator);
375
+
376
+ // Perform binary operations
377
+ FragmentCompute intermediate;
378
+
379
+ multiplies<FragmentCompute> mul_add_source;
380
+ multiply_add<FragmentCompute> mul_add_accumulator;
381
+
382
+ intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform
383
+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
384
+
385
+ bit_or<FragmentTensor> or_op;
386
+
387
+ FragmentTensor predicates = or_op(tensor, predicate_mask_);
388
+
389
+ // Obtain from packed bits
390
+ bool conditions[kCount];
391
+ UnpackPredicates<kCount> unpack_predicates;
392
+
393
+ unpack_predicates(conditions, predicates);
394
+
395
+ // dReLU = (cond ? dy : 0)
396
+ CUTLASS_PRAGMA_UNROLL
397
+ for (int i = 0; i < kCount; ++i) {
398
+ if (!conditions[i]) {
399
+ intermediate[i] = ElementCompute();
400
+ }
401
+ }
402
+
403
+ return intermediate;
404
+ }
405
+
406
+ /// Computes linear scaling: D = alpha * accumulator
407
+ CUTLASS_HOST_DEVICE
408
+ FragmentCompute operator()(
409
+ FragmentAccumulator const &accumulator,
410
+ FragmentTensor const &tensor) const {
411
+
412
+ // Convert source to interal compute numeric type
413
+ NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
414
+
415
+ FragmentCompute converted_accumulator = accumulator_converter(accumulator);
416
+
417
+ // Perform binary operations
418
+ FragmentCompute intermediate;
419
+
420
+ multiplies<FragmentCompute> mul_accumulator;
421
+
422
+ intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
423
+
424
+ bit_or<FragmentTensor> or_op;
425
+
426
+ FragmentTensor predicates = or_op(tensor, predicate_mask_);
427
+
428
+ // Obtain from packed bits
429
+ bool conditions[kCount];
430
+ UnpackPredicates<kCount> unpack_predicates;
431
+
432
+ unpack_predicates(conditions, predicates);
433
+
434
+ // dReLU = (cond ? dy : 0)
435
+ CUTLASS_PRAGMA_UNROLL
436
+ for (int i = 0; i < kCount; ++i) {
437
+ if (!conditions[i]) {
438
+ intermediate[i] = ElementCompute();
439
+ }
440
+ }
441
+
442
+ return intermediate;
443
+ }
444
+ };
445
+
446
+ /////////////////////////////////////////////////////////////////////////////////////////////////
447
+
448
+ } // namespace thread
449
+ } // namespace epilogue
450
+ } // namespace cutlass
451
+
452
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_gelu.h ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Functor performing linear combination with GELU operations used by epilogues.
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "cutlass/cutlass.h"
38
+ #include "cutlass/epilogue/thread/activation.h"
39
+ #include "cutlass/epilogue/thread/linear_combination_generic.h"
40
+
41
+ /////////////////////////////////////////////////////////////////////////////////////////////////
42
+
43
+ namespace cutlass {
44
+ namespace epilogue {
45
+ namespace thread {
46
+
47
+ /////////////////////////////////////////////////////////////////////////////////////////////////
48
+
49
+ /// Applies a linear combination operator followed by the GELU activation to an array of elements.
50
+ ///
51
+ /// D = gelu(alpha * accumulator + beta * source + uniform)
52
+ ///
53
+ template <
54
+ typename ElementOutput_, ///< Data type used to load and store tensors
55
+ int Count, ///< Number of elements computed per operation
56
+ ///< Usually it is 128/sizeof_bits<ElementOutput_>,
57
+ ///< but we use 64 or 32 sometimes when there are not enough data to store
58
+ typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type
59
+ typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination
60
+ ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling
61
+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
62
+ >
63
+ using LinearCombinationGELU = LinearCombinationGeneric<GELU, ElementOutput_, Count, ElementAccumulator_,
64
+ ElementCompute_, Scale, Round, true>;
65
+
66
+ /////////////////////////////////////////////////////////////////////////////////////////////////
67
+
68
+ } // namespace thread
69
+ } // namespace epilogue
70
+ } // namespace cutlass
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_generic.h ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Functor performing linear combination operations used by epilogues.
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "cutlass/cutlass.h"
38
+ #include "cutlass/numeric_types.h"
39
+ #include "cutlass/array.h"
40
+ #include "cutlass/functional.h"
41
+ #include "cutlass/numeric_conversion.h"
42
+ #include "cutlass/epilogue/thread/scale_type.h"
43
+
44
+ /////////////////////////////////////////////////////////////////////////////////////////////////
45
+
46
+ namespace cutlass {
47
+ namespace epilogue {
48
+ namespace thread {
49
+
50
+ /////////////////////////////////////////////////////////////////////////////////////////////////
51
+
52
+ template <class Activation, class = void>
53
+ struct GenericActivationTraits {
54
+ static constexpr bool IsArgumentsNeeded = false;
55
+ struct Arguments {};
56
+ };
57
+
58
+ template <class Activation>
59
+ struct GenericActivationTraits<Activation, decltype(typename Activation::Arguments(), void())> {
60
+ static constexpr bool IsArgumentsNeeded = true;
61
+ using Arguments = typename Activation::Arguments;
62
+ };
63
+
64
+ template <typename T>
65
+ struct LinearCombinationGenericParams {
66
+ T alpha; ///< scales accumulators
67
+ T beta; ///< scales source tensor
68
+ T const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
69
+ T const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
70
+
71
+ //
72
+ // Methods
73
+ //
74
+
75
+ CUTLASS_HOST_DEVICE
76
+ LinearCombinationGenericParams():
77
+ alpha(T(1)),
78
+ beta(T(0)),
79
+ alpha_ptr(nullptr),
80
+ beta_ptr(nullptr) { }
81
+
82
+ CUTLASS_HOST_DEVICE
83
+ LinearCombinationGenericParams(
84
+ T alpha,
85
+ T beta = T(0)
86
+ ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) { }
87
+
88
+ CUTLASS_HOST_DEVICE
89
+ LinearCombinationGenericParams(
90
+ T const *alpha_ptr,
91
+ T const *beta_ptr = nullptr
92
+ ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { }
93
+ };
94
+
95
+ /////////////////////////////////////////////////////////////////////////////////////////////////
96
+
97
+ /// Applies a linear combination operator followed by an activation function to an array of elements.
98
+ ///
99
+ /// D = activation(alpha * accumulator + beta * source + uniform)
100
+ ///
101
+ template <
102
+ template<typename T> class ActivationFunctor,
103
+ typename ElementOutput_, ///< Data type used to load and store tensors
104
+ int Count, ///< Number of elements computed per operation
105
+ ///< Usually it is 128/sizeof_bits<ElementOutput_>,
106
+ ///< but we use 64 or 32 sometimes when there are not enough data to store
107
+ typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type
108
+ typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination
109
+ ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling
110
+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest,
111
+ bool IsHeavy = false
112
+ >
113
+ class LinearCombinationGeneric {
114
+ public:
115
+
116
+ using ElementOutput = ElementOutput_;
117
+ using ElementAccumulator = ElementAccumulator_;
118
+ using ElementCompute = ElementCompute_;
119
+
120
+ static bool const kIsHeavy = IsHeavy;
121
+ static int const kCount = Count;
122
+ static const ScaleType::Kind kScale = Scale;
123
+
124
+ using FragmentOutput = Array<ElementOutput, kCount>;
125
+ using FragmentAccumulator = Array<ElementAccumulator, kCount>;
126
+ using FragmentSource = Array<ElementOutput, kCount>;
127
+ using FragmentCompute = Array<ElementCompute, kCount>;
128
+
129
+ static FloatRoundStyle const kRound = Round;
130
+
131
+ /// Host-constructable parameters structure
132
+ struct Params
133
+ : LinearCombinationGenericParams<ElementCompute>,
134
+ GenericActivationTraits<ActivationFunctor<ElementCompute>>::Arguments {
135
+ using LinearCombinationGenericParams<ElementCompute>::LinearCombinationGenericParams;
136
+ };
137
+
138
+ private:
139
+
140
+ //
141
+ // Data members
142
+ //
143
+
144
+ Params params_;
145
+ bool skip_elementwise_;
146
+
147
+ public:
148
+
149
+ /// Constructs the function object, possibly loading from pointers in host memory
150
+ CUTLASS_HOST_DEVICE
151
+ LinearCombinationGeneric(Params const &params) {
152
+ params_ = params;
153
+ params_.alpha = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
154
+ params_.beta = (params.beta_ptr ? *params.beta_ptr : params.beta);
155
+ skip_elementwise_ = false;
156
+ }
157
+
158
+ /// Returns true if source is needed
159
+ CUTLASS_HOST_DEVICE
160
+ bool is_source_needed() const {
161
+ if (Scale == ScaleType::NoBetaScaling) return true;
162
+
163
+ if (Scale == ScaleType::OnlyAlphaScaling) return false;
164
+
165
+ if (Scale == ScaleType::Nothing) return false;
166
+
167
+ return params_.beta != ElementCompute(0);
168
+ }
169
+
170
+ /// Functionally required for serial reduction in the epilogue
171
+ CUTLASS_HOST_DEVICE
172
+ void set_k_partition(int k_partition, int k_partition_count) {
173
+ if (k_partition) {
174
+ params_.beta = ElementCompute(1);
175
+ }
176
+
177
+ if (k_partition != k_partition_count - 1) {
178
+ skip_elementwise_ = true;
179
+ }
180
+ }
181
+
182
+ /// Computes linear scaling: D = alpha * accumulator + beta * source
183
+ CUTLASS_HOST_DEVICE
184
+ FragmentOutput operator()(
185
+ FragmentAccumulator const &accumulator,
186
+ FragmentOutput const &source) const {
187
+
188
+ // Convert source to interal compute numeric type
189
+ NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round> source_converter;
190
+ NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
191
+
192
+ FragmentCompute converted_source = source_converter(source);
193
+ FragmentCompute converted_accumulator = accumulator_converter(accumulator);
194
+
195
+ // Perform binary operations
196
+
197
+ FragmentCompute intermediate;
198
+
199
+ multiplies<FragmentCompute> mul_add_source;
200
+ multiply_add<FragmentCompute> mul_add_accumulator;
201
+ ActivationFunctor<FragmentCompute> activation;
202
+
203
+ if (Scale == ScaleType::NoBetaScaling) {
204
+ intermediate = converted_source;
205
+ intermediate = mul_add_accumulator(params_.alpha, converted_accumulator, intermediate); // D = alpha * Accum + X
206
+ } else if (Scale == ScaleType::Nothing) {
207
+ intermediate = converted_accumulator;
208
+ } else {
209
+ intermediate = mul_add_source(params_.beta, converted_source); // X = beta * C + uniform
210
+ intermediate = mul_add_accumulator(params_.alpha, converted_accumulator, intermediate); // D = alpha * Accum + X
211
+ }
212
+
213
+ if constexpr (GenericActivationTraits<ActivationFunctor<ElementCompute>>::IsArgumentsNeeded) {
214
+ intermediate = skip_elementwise_ ? intermediate : activation(intermediate, params_);
215
+ } else {
216
+ intermediate = skip_elementwise_ ? intermediate : activation(intermediate);
217
+ }
218
+
219
+ // Convert to destination numeric type
220
+ NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
221
+
222
+ return destination_converter(intermediate);
223
+ }
224
+
225
+ /// Computes linear scaling: D = alpha * accumulator
226
+ CUTLASS_HOST_DEVICE
227
+ FragmentOutput operator()(
228
+ FragmentAccumulator const &accumulator) const {
229
+
230
+ // Convert source to interal compute numeric type
231
+ NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
232
+
233
+ FragmentCompute converted_accumulator = accumulator_converter(accumulator);
234
+
235
+ // Perform binary operations
236
+
237
+ FragmentCompute intermediate;
238
+
239
+ multiplies<FragmentCompute> mul_add_accumulator;
240
+ ActivationFunctor<FragmentCompute> activation;
241
+
242
+ if (Scale == ScaleType::Nothing) {
243
+ intermediate = converted_accumulator;
244
+ } else {
245
+ intermediate = mul_add_accumulator(params_.alpha, converted_accumulator); // D = alpha * Accum
246
+ }
247
+
248
+ if constexpr (GenericActivationTraits<ActivationFunctor<FragmentCompute>>::IsArgumentsNeeded) {
249
+ intermediate = skip_elementwise_ ? intermediate : activation(intermediate, params_);
250
+ } else {
251
+ intermediate = skip_elementwise_ ? intermediate : activation(intermediate);
252
+ }
253
+
254
+ // Convert to destination numeric type
255
+ NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
256
+
257
+ return destination_converter(intermediate);
258
+ }
259
+ };
260
+
261
+ /////////////////////////////////////////////////////////////////////////////////////////////////
262
+
263
+ } // namespace thread
264
+ } // namespace epilogue
265
+ } // namespace cutlass
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_generic_with_scaling.h ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ /*! \file
33
+ \brief Functor performing linear combination operations with a generic element-wise activation
34
+ function. Scaling factors are applied to operands A, B, and C. The pre-activation auxiliary
35
+ output is also returned.
36
+ */
37
+
38
+ #pragma once
39
+
40
+ #include "cutlass/cutlass.h"
41
+ #include "cutlass/numeric_types.h"
42
+ #include "cutlass/array.h"
43
+ #include "cutlass/functional.h"
44
+ #include "cutlass/numeric_conversion.h"
45
+ #include "cutlass/epilogue/thread/scale_type.h"
46
+ #include "cutlass/epilogue/thread/linear_combination_generic.h"
47
+
48
+ /////////////////////////////////////////////////////////////////////////////////////////////////
49
+
50
+ namespace cutlass {
51
+ namespace epilogue {
52
+ namespace thread {
53
+
54
+ /////////////////////////////////////////////////////////////////////////////////////////////////
55
+
56
/// Applies a linear combination operator to an array of elements.
///
/// Aux = ((alpha * scale_a * scale_b) * accumulator) + ((beta * scale_c) * source) + bias
/// D = activation(Aux)
///
/// The alpha/beta scalars are folded together with the operand scaling factors
/// once at construction time, so the per-element loops only see the combined
/// coefficients. Absolute-maximum (amax) bookkeeping for D and Aux is exposed
/// through the get_ptr_*_abs_max() accessors; this functor itself does not
/// compute the amax.
template <
  template<typename T> class ActivationFunctor,
  typename ElementOutput_,                        ///< Data type used to load and store tensors
  typename ElementAuxOutput_,                     ///< Data type used to store auxiliary output
  int Count,                                      ///< Number of elements computed per operation
                                                  ///< Usually it is 128/sizeof_bits<ElementOutput_>,
                                                  ///< but we use 64 or 32 sometimes when there are not enough data to store
  typename ElementAccumulator_ = ElementOutput_,  ///< Accumulator data type
  typename ElementCompute_ = ElementOutput_,      ///< Data type used to compute linear combination
  ScaleType::Kind Scale = ScaleType::Default,     ///< Control Alpha and Beta scaling
  FloatRoundStyle Round = FloatRoundStyle::round_to_nearest,
  bool IsHeavy = false
>
class LinearCombinationGenericWithScalingAndAbsMax {
public:

  using ElementOutput = ElementOutput_;
  using ElementAuxOutput = ElementAuxOutput_;
  using ElementAccumulator = ElementAccumulator_;
  using ElementCompute = ElementCompute_;
  // Scaling factors share the accumulator's numeric type.
  using ElementScalingFactor = ElementAccumulator_;

  /// Data type used for absolute maximum value
  using ElementAbsmax = float;

  // Scaling + amax tracking is only required when the (aux) output is a
  // narrow FP8 type (e4m3 / e5m2).
  static bool const kIsScalingAndAmaxAuxOutputNeeded = (platform::is_same<ElementAuxOutput, cutlass::float_e4m3_t>::value ||
                                                        platform::is_same<ElementAuxOutput, cutlass::float_e5m2_t>::value);
  static bool const kIsScalingAndAmaxOutputNeeded = (platform::is_same<ElementOutput, cutlass::float_e4m3_t>::value ||
                                                     platform::is_same<ElementOutput, cutlass::float_e5m2_t>::value);

  static bool const kIsHeavy = IsHeavy;
  static int const kCount = Count;
  static const ScaleType::Kind kScale = Scale;

  using FragmentOutput = Array<ElementOutput, kCount>;
  using FragmentAuxOutput = Array<ElementAuxOutput, kCount>;
  using FragmentAccumulator = Array<ElementAccumulator, kCount>;
  using FragmentCompute = Array<ElementCompute, kCount>;

  static FloatRoundStyle const kRound = Round;

  /// Host-constructable parameters structure
  struct Params {
    /// alpha/beta (with optional device pointers) plus any activation-specific
    /// arguments, inherited from the generic parameter structures.
    struct ActivationParams
      : LinearCombinationGenericParams<ElementCompute>,
        GenericActivationTraits<ActivationFunctor<ElementCompute>>::Arguments {
      using LinearCombinationGenericParams<ElementCompute>::LinearCombinationGenericParams;
    };

    ActivationParams activation;
    ElementScalingFactor const* scale_a_ptr = nullptr;    ///< pointer to a scalar - if not null, loads it from memory
    ElementScalingFactor const* scale_b_ptr = nullptr;    ///< pointer to b scalar - if not null, loads it from memory
    ElementScalingFactor const* scale_c_ptr = nullptr;    ///< pointer to c scalar - if not null, loads it from memory
    ElementScalingFactor const* scale_d_ptr = nullptr;    ///< pointer to d scalar - if not null, loads it from memory
    ElementScalingFactor const* scale_aux_ptr = nullptr;  ///< pointer to aux scalar - if not null, loads it from memory

    ElementAbsmax * abs_max_aux_ptr = nullptr;            ///< pointer to location to store amax of Aux
    ElementAbsmax * abs_max_D_ptr = nullptr;              ///< pointer to location to store amax of D

    CUTLASS_HOST_DEVICE
    Params() :
      scale_a_ptr(nullptr),
      scale_b_ptr(nullptr),
      scale_c_ptr(nullptr),
      scale_d_ptr(nullptr),
      scale_aux_ptr(nullptr),
      abs_max_aux_ptr(nullptr),
      abs_max_D_ptr(nullptr) {}

    CUTLASS_HOST_DEVICE
    Params(ActivationParams activation_params,
           ElementScalingFactor const* scale_a_ptr,
           ElementScalingFactor const* scale_b_ptr,
           ElementScalingFactor const* scale_c_ptr,
           ElementScalingFactor const* scale_d_ptr,
           ElementScalingFactor const* scale_aux_ptr,
           ElementAbsmax * abs_max_aux_ptr,
           ElementAbsmax * abs_max_D_ptr) :
      activation(activation_params),
      scale_a_ptr(scale_a_ptr),
      scale_b_ptr(scale_b_ptr),
      scale_c_ptr(scale_c_ptr),
      scale_d_ptr(scale_d_ptr),
      scale_aux_ptr(scale_aux_ptr),
      abs_max_aux_ptr(abs_max_aux_ptr),
      abs_max_D_ptr(abs_max_D_ptr) {}
  };

private:

  //
  // Data members
  //

  Params params_;
  // Set when this is a non-final split-K partition: the activation must only
  // run once, on the final partial result.
  bool skip_elementwise_;

  // Scaling factors for output and auxiliary output (resolved from pointers,
  // default 1 when the pointer is null)
  ElementCompute scale_d_;
  ElementCompute scale_aux_;

public:

  /// Constructs the function object, possibly loading from pointers in host memory.
  /// Folds scale_a * scale_b into alpha and scale_c into beta so the
  /// per-element operators only apply two coefficients.
  CUTLASS_HOST_DEVICE
  LinearCombinationGenericWithScalingAndAbsMax(Params const &params) :
    params_(params),
    skip_elementwise_(false),
    scale_d_(ElementCompute(params.scale_d_ptr ? *(params.scale_d_ptr) : ElementScalingFactor(1))),
    scale_aux_(ElementCompute(params.scale_aux_ptr ? *(params.scale_aux_ptr) : ElementScalingFactor(1)))
  {
    // Pointer-provided alpha/beta take precedence over the in-struct values.
    params_.activation.alpha = (params.activation.alpha_ptr ? *params.activation.alpha_ptr : params.activation.alpha);
    params_.activation.beta = (params.activation.beta_ptr ? *params.activation.beta_ptr : params.activation.beta);
    auto scale_a =
      ElementCompute(params.scale_a_ptr ? *(params.scale_a_ptr) : ElementScalingFactor(1));
    auto scale_b =
      ElementCompute(params.scale_b_ptr ? *(params.scale_b_ptr) : ElementScalingFactor(1));
    auto scale_c =
      ElementCompute(params.scale_c_ptr ? *(params.scale_c_ptr) : ElementScalingFactor(1));

    multiplies<ElementCompute> multiply;
    // NOTE: reads the *resolved* values just written into params_ via the
    // originals in params; both reflect the same resolved alpha/beta here.
    params_.activation.alpha = multiply(params.activation.alpha, multiply(scale_a, scale_b));
    params_.activation.beta = multiply(params.activation.beta, scale_c);
  }

  /// Returns true if the source operand C must be loaded
  CUTLASS_HOST_DEVICE
  bool is_source_needed() const {
    if (Scale == ScaleType::NoBetaScaling) return true;

    if (Scale == ScaleType::OnlyAlphaScaling) return false;

    if (Scale == ScaleType::Nothing) return false;

    return params_.activation.beta != ElementCompute(0);
  }

  /// Functionally required for serial reduction in the epilogue
  CUTLASS_HOST_DEVICE
  void set_k_partition(int k_partition, int k_partition_count) {
    if (k_partition) {
      params_.activation.beta = ElementCompute(1);
    }

    // Only the final partition should perform the activation function
    // and scale the output and auxiliary output values.
    if (k_partition != k_partition_count - 1) {
      skip_elementwise_ = true;
      scale_d_ = ElementCompute(1.);
      scale_aux_ = ElementCompute(1.);
    }
  }

  /// Computes linear scaling:
  ///   Aux = (alpha * scale_a * scale_b * accumulator) + (beta * scale_c * source) + bias
  ///   D = activation(Aux)
  /// Results are produced in compute precision; downstream code converts and
  /// applies scale_d_/scale_aux_.
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentCompute& output,
    FragmentCompute& aux_output,
    FragmentAccumulator const &accumulator,
    FragmentCompute const& bias,
    FragmentOutput const &source) {

    // Convert source to internal compute numeric type
    NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round> source_converter;
    NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;

    FragmentCompute converted_source = source_converter(source);
    FragmentCompute converted_accumulator = accumulator_converter(accumulator);

    // Perform binary operations

    FragmentCompute intermediate;

    multiplies<FragmentCompute> multiply;
    plus<FragmentCompute> add;
    multiply_add<FragmentCompute> mul_add_accumulator;
    ActivationFunctor<FragmentCompute> activation;

    if (Scale == ScaleType::NoBetaScaling) {
      intermediate = converted_source;
      intermediate = mul_add_accumulator(params_.activation.alpha, converted_accumulator, intermediate);
    } else if (Scale == ScaleType::Nothing) {
      intermediate = converted_accumulator;
    } else {
      intermediate = multiply(params_.activation.beta, converted_source);
      intermediate = mul_add_accumulator(params_.activation.alpha, converted_accumulator, intermediate);
    }

    intermediate = add(intermediate, bias);

    // Aux is the pre-activation value
    aux_output = intermediate;
    // NOTE(review): this overload queries GenericActivationTraits on
    // ActivationFunctor<ElementCompute> while the bias-only overload below
    // queries ActivationFunctor<FragmentCompute> — presumably equivalent for
    // all activations used here; confirm against activation.h.
    if constexpr (GenericActivationTraits<ActivationFunctor<ElementCompute>>::IsArgumentsNeeded) {
      output = skip_elementwise_ ? intermediate : activation(intermediate, params_.activation);
    } else {
      output = skip_elementwise_ ? intermediate : activation(intermediate);
    }
  }

  /// Computes linear scaling (no source operand):
  ///   Aux = (alpha * scale_a * scale_b * accumulator) + bias
  ///   D = activation(Aux)
  CUTLASS_DEVICE
  void operator()(
    FragmentCompute& output,
    FragmentCompute& aux_output,
    FragmentAccumulator const &accumulator,
    FragmentCompute const& bias) {

    // Convert accumulator to internal compute numeric type
    NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;

    FragmentCompute converted_accumulator = accumulator_converter(accumulator);

    // Perform binary operations

    FragmentCompute intermediate;

    multiplies<FragmentCompute> multiply;
    plus<FragmentCompute> add;
    ActivationFunctor<FragmentCompute> activation;

    if (Scale == ScaleType::Nothing) {
      intermediate = converted_accumulator;
    } else {
      intermediate = multiply(params_.activation.alpha, converted_accumulator);
    }

    intermediate = add(intermediate, bias);

    // Aux is the pre-activation value
    aux_output = intermediate;
    if constexpr (GenericActivationTraits<ActivationFunctor<FragmentCompute>>::IsArgumentsNeeded) {
      output = skip_elementwise_ ? intermediate : activation(intermediate, params_.activation);
    } else {
      output = skip_elementwise_ ? intermediate : activation(intermediate);
    }
  }

  /// Location to which the amax of D should be written (may be null)
  CUTLASS_HOST_DEVICE
  ElementAbsmax* get_ptr_output_abs_max() const {
    return params_.abs_max_D_ptr;
  }

  /// Location to which the amax of Aux should be written (may be null)
  CUTLASS_HOST_DEVICE
  ElementAbsmax* get_ptr_aux_output_abs_max() const {
    return params_.abs_max_aux_ptr;
  }

  /// Scaling factor to apply to D (1 on non-final split-K partitions)
  CUTLASS_HOST_DEVICE
  ElementCompute get_scale_d() const {
    return scale_d_;
  }

  /// Scaling factor to apply to Aux (1 on non-final split-K partitions)
  CUTLASS_HOST_DEVICE
  ElementCompute get_scale_aux() const {
    return scale_aux_;
  }
};
320
+
321
+ /////////////////////////////////////////////////////////////////////////////////////////////////
322
+
323
+ } // namespace thread
324
+ } // namespace epilogue
325
+ } // namespace cutlass
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_hardswish.h ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Functor performing linear combination with HardSwish operations used by epilogues.
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "cutlass/cutlass.h"
38
+ #include "cutlass/epilogue/thread/activation.h"
39
+ #include "cutlass/epilogue/thread/linear_combination_generic.h"
40
+
41
+ /////////////////////////////////////////////////////////////////////////////////////////////////
42
+
43
+ namespace cutlass {
44
+ namespace epilogue {
45
+ namespace thread {
46
+
47
+ /////////////////////////////////////////////////////////////////////////////////////////////////
48
+
49
/// Applies a linear combination operator followed by the HardSwish activation to an array of elements.
///
/// D = hardswish(alpha * accumulator + beta * source + uniform)
///
/// Convenience alias: all behavior (alpha/beta scaling modes, rounding,
/// source handling, split-K semantics) comes from LinearCombinationGeneric
/// instantiated with the HardSwish functor from
/// cutlass/epilogue/thread/activation.h.
template <
  typename ElementOutput_,                        ///< Data type used to load and store tensors
  int Count,                                      ///< Number of elements computed per operation
                                                  ///< Usually it is 128/sizeof_bits<ElementOutput_>,
                                                  ///< but we use 64 or 32 sometimes when there are not enough data to store
  typename ElementAccumulator_ = ElementOutput_,  ///< Accumulator data type
  typename ElementCompute_ = ElementOutput_,      ///< Data type used to compute linear combination
  ScaleType::Kind Scale = ScaleType::Default,     ///< Control Alpha and Beta scaling
  FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
>
using LinearCombinationHardSwish = LinearCombinationGeneric<HardSwish, ElementOutput_, Count, ElementAccumulator_,
                                                            ElementCompute_, Scale, Round>;
65
+ /////////////////////////////////////////////////////////////////////////////////////////////////
66
+
67
+ } // namespace thread
68
+ } // namespace epilogue
69
+ } // namespace cutlass
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_leaky_relu.h ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ #pragma once
33
+
34
+ #include "cutlass/cutlass.h"
35
+ #include "cutlass/numeric_types.h"
36
+ #include "cutlass/array.h"
37
+ #include "cutlass/functional.h"
38
+ #include "cutlass/numeric_conversion.h"
39
+ #include "cutlass/epilogue/thread/activation.h"
40
+ #include "cutlass/epilogue/thread/scale_type.h"
41
+
42
+ /////////////////////////////////////////////////////////////////////////////////////////////////
43
+
44
+ namespace cutlass {
45
+ namespace epilogue {
46
+ namespace thread {
47
+
48
+ /////////////////////////////////////////////////////////////////////////////////////////////////
49
+
50
+ /// Applies a linear combination operator to an array of elements.
51
+ ///
52
+ /// D = alpha * accumulator + beta * source + uniform
53
+ ///
54
+ template <
55
+ typename ElementOutput_, ///< Data type used to load and store tensors
56
+ int Count, ///< Number of elements computed per operation
57
+ typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type
58
+ typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination
59
+ ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling
60
+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
61
+ >
62
+ class LinearCombinationLeakyRelu {
63
+ public:
64
+
65
+ using ElementOutput = ElementOutput_;
66
+ using ElementAccumulator = ElementAccumulator_;
67
+ using ElementCompute = ElementCompute_;
68
+
69
+ static int const kCount = Count;
70
+ static const ScaleType::Kind kScale = Scale;
71
+
72
+ using FragmentOutput = Array<ElementOutput, kCount>;
73
+ using FragmentAccumulator = Array<ElementAccumulator, kCount>;
74
+ using ComputeFragment = Array<ElementCompute, kCount>;
75
+ using FragmentSource = Array<ElementOutput, kCount>;
76
+
77
+ static FloatRoundStyle const kRound = Round;
78
+
79
+ /// Host-constructable parameters structure
80
+ struct Params {
81
+
82
+ ElementCompute alpha; ///< scales accumulators
83
+ ElementCompute beta_bias; ///< scales bias tensor
84
+ ElementCompute leaky_alpha; ///< leaky_alpha
85
+ //
86
+ // Methods
87
+ //
88
+
89
+ CUTLASS_HOST_DEVICE
90
+ Params():
91
+ alpha(ElementCompute(1)),
92
+ beta_bias(ElementCompute(0)),
93
+ leaky_alpha(ElementCompute(1))
94
+ { }
95
+
96
+ CUTLASS_HOST_DEVICE
97
+ Params(
98
+ ElementCompute alpha,
99
+ ElementCompute beta_bias,
100
+ ElementCompute leaky_alpha = ElementCompute(1)
101
+ ): alpha(alpha), beta_bias(beta_bias), leaky_alpha(leaky_alpha) {
102
+
103
+ }
104
+
105
+ };
106
+
107
+ private:
108
+
109
+ //
110
+ // Data members
111
+ //
112
+
113
+ ElementCompute alpha_;
114
+ ElementCompute beta_bias_;
115
+ ElementCompute leaky_alpha_recip_;
116
+
117
+ public:
118
+
119
+ /// Constructs the function object, possibly loading from pointers in host memory
120
+ CUTLASS_HOST_DEVICE
121
+ LinearCombinationLeakyRelu(Params const &params) {
122
+ alpha_ = (params.alpha);
123
+ beta_bias_ = (params.beta_bias);
124
+ leaky_alpha_recip_ = (ElementCompute(params.leaky_alpha));
125
+ }
126
+
127
+ /// Returns true if source is needed
128
+ CUTLASS_HOST_DEVICE
129
+ bool is_source_needed() const {
130
+ if (Scale == ScaleType::NoBetaScaling) return true;
131
+
132
+ if (Scale == ScaleType::OnlyAlphaScaling) return false;
133
+
134
+ if (Scale == ScaleType::Nothing) return false;
135
+
136
+ return beta_bias_ != ElementCompute(0);
137
+ }
138
+
139
+ /// Functionally required for serial reduction in the epilogue
140
+ CUTLASS_HOST_DEVICE
141
+ void set_k_partition(int k_partition) {
142
+ if (k_partition) {
143
+ beta_bias_ = ElementCompute(1);
144
+ }
145
+ }
146
+ CUTLASS_HOST_DEVICE
147
+ void set_k_partition(int k_partition, int k_partition_count) {
148
+ if (k_partition) {
149
+ beta_bias_ = ElementCompute(1);
150
+ }
151
+ }
152
+
153
+ /// Computes linear scaling: D = alpha * accumulator + beta * source
154
+ CUTLASS_HOST_DEVICE
155
+ FragmentOutput operator()(
156
+ FragmentAccumulator const &accumulator,
157
+ FragmentOutput const &source) const {
158
+
159
+ // Convert source to interal compute numeric type
160
+ NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round> source_converter;
161
+ NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
162
+
163
+ ComputeFragment converted_source = source_converter(source);
164
+ ComputeFragment converted_accumulator = accumulator_converter(accumulator);
165
+
166
+ // Perform binary operations
167
+ ComputeFragment intermediate;
168
+
169
+ multiplies<ComputeFragment> mul_add_source;
170
+ multiply_add<ComputeFragment> mul_add_accumulator;
171
+
172
+ LeakyReLU<ComputeFragment> leakyrelu;
173
+
174
+ if (Scale == ScaleType::NoBetaScaling) {
175
+ intermediate = converted_source;
176
+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
177
+ } else if (Scale == ScaleType::Nothing) {
178
+ intermediate = converted_accumulator;
179
+ } else {
180
+ intermediate = mul_add_source(beta_bias_, converted_source); // X = beta * C + uniform
181
+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
182
+ }
183
+ // Compute threshold optionally
184
+ intermediate = leakyrelu(intermediate, leaky_alpha_recip_);
185
+
186
+ // Convert to destination numeric type
187
+ NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
188
+
189
+ return destination_converter(intermediate);
190
+ }
191
+
192
+ /// Computes linear scaling: D = alpha * accumulator
193
+ CUTLASS_HOST_DEVICE
194
+ FragmentOutput operator()(
195
+ FragmentAccumulator const &accumulator) const {
196
+
197
+ // Convert source to interal compute numeric type
198
+ NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
199
+
200
+ ComputeFragment converted_accumulator = accumulator_converter(accumulator);
201
+
202
+ // Perform binary operations
203
+ ComputeFragment intermediate;
204
+
205
+ multiplies<ComputeFragment> mul_accumulator;
206
+ LeakyReLU<ComputeFragment> leakyrelu;
207
+ //printf("in doing with bias");
208
+ if (Scale == ScaleType::Nothing) {
209
+ intermediate = converted_accumulator;
210
+ } else {
211
+ intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
212
+ }
213
+
214
+ // Compute threshold optionally
215
+ intermediate = leakyrelu(intermediate, leaky_alpha_recip_);
216
+
217
+
218
+ // Convert to destination numeric type
219
+ NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
220
+
221
+ return destination_converter(intermediate);
222
+ }
223
+ };
224
+
225
+ /////////////////////////////////////////////////////////////////////////////////////////////////
226
+
227
+ } // namespace thread
228
+ } // namespace epilogue
229
+ } // namespace cutlass
230
+
231
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_params.h ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief
33
+ */
34
+
35
+ #pragma once
36
+
37
+ /////////////////////////////////////////////////////////////////////////////////////////////////
38
+
39
+ namespace cutlass {
40
+ namespace epilogue {
41
+ namespace thread {
42
+
43
+ /////////////////////////////////////////////////////////////////////////////////////////////////
44
+
45
+ struct LinearCombinationParams {
46
+ uint64_t alpha_data[2];
47
+ uint64_t beta_data[2];
48
+
49
+ CUTLASS_HOST_DEVICE
50
+ LinearCombinationParams()
51
+ : alpha_data {0lu, 0lu}, beta_data {0lu, 0lu}
52
+ { }
53
+
54
+ template <typename ElementCompute>
55
+ CUTLASS_HOST_DEVICE
56
+ LinearCombinationParams(ElementCompute alpha, ElementCompute beta)
57
+ : alpha_data {0lu, 0lu}, beta_data {0lu, 0lu}
58
+ {
59
+ #if defined(__CUDA_ARCH__)
60
+ reinterpret_cast<ElementCompute&>(alpha_data) = alpha;
61
+ reinterpret_cast<ElementCompute&>(beta_data) = beta;
62
+ #else
63
+ memcpy( alpha_data, &alpha, sizeof(ElementCompute) );
64
+ memcpy( beta_data, &beta, sizeof(ElementCompute) );
65
+ #endif
66
+ }
67
+ };
68
+
69
+ /////////////////////////////////////////////////////////////////////////////////////////////////
70
+
71
+ } // namespace thread
72
+ } // namespace epilogue
73
+ } // namespace cutlass
74
+
75
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_planar_complex.h ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Functor performing linear combination operations on planar-complex arrays
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "cutlass/cutlass.h"
38
+ #include "cutlass/numeric_types.h"
39
+ #include "cutlass/complex.h"
40
+ #include "cutlass/array_planar_complex.h"
41
+ #include "cutlass/functional.h"
42
+ #include "cutlass/numeric_conversion.h"
43
+ #include "cutlass/epilogue/thread/scale_type.h"
44
+
45
+ /////////////////////////////////////////////////////////////////////////////////////////////////
46
+
47
+ namespace cutlass {
48
+ namespace epilogue {
49
+ namespace thread {
50
+
51
+ /////////////////////////////////////////////////////////////////////////////////////////////////
52
+
53
+ /// Applies a linear combination operator to arrays of planar-complex elements.
54
+ ///
55
+ /// D = alpha * accumulator + beta * source + uniform
56
+ ///
57
+ /// Note, as with most CUTLASS components for planar complex, the template arguments describe
58
+ /// the underlying real data type.
59
+ template <
60
+ typename ElementOutput_, ///< Data type used to load and store tensors
61
+ int Count, ///< Number of elements computed per operation
62
+ ///< Usually it is 128/sizeof_bits<ElementOutput_>,
63
+ ///< but we use 64 or 32 sometimes when there are not enough data to store
64
+ typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type
65
+ typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination
66
+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest,
67
+ ScaleType::Kind Scale = ScaleType::Default ///< Control Alpha and Beta scaling
68
+ >
69
+ class LinearCombinationPlanarComplex {
70
+ public:
71
+
72
+ using ElementOutput = ElementOutput_;
73
+ using ElementAccumulator = ElementAccumulator_;
74
+ using ElementCompute = ElementCompute_;
75
+ using ElementScalar = complex<ElementCompute>;
76
+
77
+ static int const kCount = Count;
78
+ static const ScaleType::Kind kScale = Scale;
79
+
80
+ using FragmentOutput = ArrayPlanarComplex<ElementOutput, kCount>;
81
+ using FragmentAccumulator = ArrayPlanarComplex<ElementAccumulator, kCount>;
82
+ using ComputeFragment = ArrayPlanarComplex<ElementCompute, kCount>;
83
+
84
+ static FloatRoundStyle const kRound = Round;
85
+
86
+ /// Host-constructable parameters structure
87
+ struct Params {
88
+
89
+ ElementScalar alpha{ElementCompute(1)}; ///< scales accumulators
90
+ ElementScalar beta{ElementCompute(0)}; ///< scales source tensor
91
+ ElementScalar const* alpha_ptr{nullptr}; ///< pointer to accumulator scalar - if not null, loads it from memory
92
+ ElementScalar const* beta_ptr{nullptr}; ///< pointer to source scalar - if not null, loads it from memory
93
+
94
+ //
95
+ // Methods
96
+ //
97
+
98
+ Params() = default;
99
+
100
+ CUTLASS_HOST_DEVICE
101
+ Params(
102
+ ElementScalar alpha,
103
+ ElementScalar beta
104
+ ): alpha(alpha), beta(beta)
105
+ {}
106
+
107
+ CUTLASS_HOST_DEVICE
108
+ Params(
109
+ ElementScalar const *alpha_ptr,
110
+ ElementScalar const *beta_ptr
111
+ ): alpha_ptr(alpha_ptr), beta_ptr(beta_ptr)
112
+ {}
113
+ };
114
+
115
+ private:
116
+
117
+ //
118
+ // Data members
119
+ //
120
+
121
+ ElementScalar alpha_;
122
+ ElementScalar beta_;
123
+
124
+ public:
125
+
126
+ /// Constructs the function object, possibly loading from pointers in host memory
127
+ CUTLASS_HOST_DEVICE
128
+ LinearCombinationPlanarComplex(Params const &params) {
129
+ alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
130
+ beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
131
+ }
132
+
133
+ /// Returns true if source is needed
134
+ CUTLASS_HOST_DEVICE
135
+ bool is_source_needed() const {
136
+ if (Scale == ScaleType::OnlyAlphaScaling) return false;
137
+
138
+ return beta_.real() != ElementCompute(0) || beta_.imag() != ElementCompute(0);
139
+ }
140
+
141
+ /// Functionally required for serial reduction in the epilogue
142
+ CUTLASS_HOST_DEVICE
143
+ void set_k_partition(int k_partition, int k_partition_count) {
144
+ if (k_partition) {
145
+ beta_ = ElementCompute(1);
146
+ }
147
+ }
148
+
149
+ /// Computes linear scaling: D = alpha * accumulator + beta * source
150
+ CUTLASS_HOST_DEVICE
151
+ FragmentOutput operator()(
152
+ FragmentAccumulator const &accumulator,
153
+ FragmentOutput const &source) const {
154
+
155
+ // Convert source to interal compute numeric type
156
+ NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round> source_converter;
157
+ NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
158
+
159
+ ComputeFragment converted_source{
160
+ source_converter(source.real),
161
+ source_converter(source.imag)};
162
+
163
+ ComputeFragment converted_accumulator{
164
+ accumulator_converter(accumulator.real),
165
+ accumulator_converter(accumulator.imag)};
166
+
167
+ multiplies<Array<ElementCompute, kCount> > mul_op;
168
+ multiply_add<Array<ElementCompute, kCount> > mul_add_op;
169
+
170
+ // Perform binary operations
171
+
172
+ // complex multiply: I = beta * C
173
+ ComputeFragment intermediate {
174
+ mul_op(beta_.real(), converted_source.real),
175
+ mul_op(beta_.real(), converted_source.imag)
176
+ };
177
+
178
+ intermediate.real = mul_add_op(-beta_.imag(), converted_source.imag, intermediate.real);
179
+ intermediate.imag = mul_add_op( beta_.imag(), converted_source.real, intermediate.imag);
180
+
181
+ // complex multiply-add: I = alpha * AB + I
182
+ intermediate.real = mul_add_op(alpha_.real(), converted_accumulator.real, intermediate.real);
183
+ intermediate.imag = mul_add_op(alpha_.real(), converted_accumulator.imag, intermediate.imag);
184
+
185
+ intermediate.real = mul_add_op(-alpha_.imag(), converted_accumulator.imag, intermediate.real);
186
+ intermediate.imag = mul_add_op( alpha_.imag(), converted_accumulator.real, intermediate.imag);
187
+
188
+ // Convert to destination numeric type
189
+ NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
190
+
191
+ return FragmentOutput{
192
+ destination_converter(intermediate.real),
193
+ destination_converter(intermediate.imag)};
194
+ }
195
+
196
+ /// Computes linear scaling: D = alpha * accumulator + beta * source
197
+ CUTLASS_HOST_DEVICE
198
+ FragmentOutput operator()(
199
+ FragmentAccumulator const &accumulator) const {
200
+
201
+ // Convert source to interal compute numeric type
202
+ NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
203
+
204
+ ComputeFragment converted_accumulator{
205
+ accumulator_converter(accumulator.real),
206
+ accumulator_converter(accumulator.imag)};
207
+
208
+ // Perform binary operations
209
+ multiplies<Array<ElementCompute, kCount> > mul_op;
210
+ multiply_add<Array<ElementCompute, kCount> > mul_add_op;
211
+
212
+ // complex multiply-add: I = alpha * AB + I
213
+ ComputeFragment intermediate {
214
+ mul_op(alpha_.real(), converted_accumulator.real),
215
+ mul_op(alpha_.real(), converted_accumulator.imag)
216
+ };
217
+
218
+ intermediate.real = mul_add_op(-alpha_.imag(), converted_accumulator.imag, intermediate.real);
219
+ intermediate.imag = mul_add_op( alpha_.imag(), converted_accumulator.real, intermediate.imag);
220
+
221
+ // Convert to destination numeric type
222
+ NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
223
+
224
+ return FragmentOutput{
225
+ destination_converter(intermediate.real),
226
+ destination_converter(intermediate.imag)};
227
+ }
228
+ };
229
+
230
+ /////////////////////////////////////////////////////////////////////////////////////////////////
231
+
232
+ } // namespace thread
233
+ } // namespace epilogue
234
+ } // namespace cutlass
235
+
236
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_relu.h ADDED
@@ -0,0 +1,572 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Functor performing linear combination with a maximum operation used by epilogues.
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "cutlass/half.h"
38
+ #include "cutlass/cutlass.h"
39
+ #include "cutlass/numeric_types.h"
40
+ #include "cutlass/array.h"
41
+ #include "cutlass/functional.h"
42
+ #include "cutlass/numeric_conversion.h"
43
+ #include "cutlass/epilogue/thread/activation.h"
44
+ #include "cutlass/epilogue/thread/scale_type.h"
45
+
46
+ /////////////////////////////////////////////////////////////////////////////////////////////////
47
+
48
+ namespace cutlass {
49
+ namespace epilogue {
50
+ namespace thread {
51
+
52
+ /////////////////////////////////////////////////////////////////////////////////////////////////
53
+
54
+ namespace detail {
55
+
56
/// Single source of truth for whether to unroll for `LinearCombinationRelu()`
/// (the original comment mistakenly referenced `LinearCombinationClamp()`).
constexpr bool LinearCombinationReluIsHeavy() {
  return false;
}
60
+
61
+ }
62
+
63
+ /////////////////////////////////////////////////////////////////////////////////////////////////
64
+
65
+ /// Applies a linear combination operator to an array of elements.
66
+ ///
67
+ /// D = alpha * accumulator + beta * source + uniform
68
+ ///
69
+ template <
70
+ typename ElementOutput_, ///< Data type used to load and store tensors
71
+ int Count, ///< Number of elements computed per operation
72
+ ///< Usually it is 128/sizeof_bits<ElementOutput_>,
73
+ ///< but we use 64 or 32 sometimes when there are not enough data to store
74
+ typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type
75
+ typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination
76
+ ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling
77
+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
78
+ >
79
+ class LinearCombinationRelu {
80
+ public:
81
+
82
+ using ElementOutput = ElementOutput_;
83
+ using ElementAccumulator = ElementAccumulator_;
84
+ using ElementCompute = ElementCompute_;
85
+
86
+ static int const kCount = Count;
87
+ static const ScaleType::Kind kScale = Scale;
88
+
89
+ using FragmentOutput = Array<ElementOutput, kCount>;
90
+ using FragmentAccumulator = Array<ElementAccumulator, kCount>;
91
+ using FragmentCompute = Array<ElementCompute, kCount>;
92
+ using FragmentScaleBias = Array<ElementCompute, kCount>;
93
+ using FragmentSource = Array<ElementOutput, kCount>;
94
+
95
+ static FloatRoundStyle const kRound = Round;
96
+
97
+ static bool const kIsHeavy = detail::LinearCombinationReluIsHeavy();
98
+
99
+ /// Host-constructable parameters structure
100
+ struct Params {
101
+
102
+ ElementCompute alpha; ///< scales accumulators
103
+ ElementCompute beta; ///< scales source tensor
104
+ ElementCompute threshold; ///< minimum value that is output
105
+ ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
106
+ ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
107
+ //
108
+ // Methods
109
+ //
110
+
111
+ CUTLASS_HOST_DEVICE
112
+ Params():
113
+ alpha(ElementCompute(1)),
114
+ beta(ElementCompute(0)),
115
+ threshold(ElementCompute(0)),
116
+ alpha_ptr(nullptr),
117
+ beta_ptr(nullptr) { }
118
+
119
+ CUTLASS_HOST_DEVICE
120
+ Params(
121
+ ElementCompute alpha,
122
+ ElementCompute beta = ElementCompute(0),
123
+ ElementCompute threshold = ElementCompute(0)
124
+ ): alpha(alpha), beta(beta), threshold(threshold), alpha_ptr(nullptr), beta_ptr(nullptr) {
125
+
126
+ }
127
+
128
+ CUTLASS_HOST_DEVICE
129
+ Params(
130
+ ElementCompute const *alpha_ptr,
131
+ ElementCompute const *beta_ptr = nullptr,
132
+ ElementCompute threshold = ElementCompute(0)
133
+ ): alpha(0), beta(0), threshold(threshold), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
134
+
135
+ }
136
+ };
137
+
138
+ private:
139
+
140
+ //
141
+ // Data members
142
+ //
143
+
144
+ ElementCompute alpha_;
145
+ ElementCompute beta_;
146
+ ElementCompute threshold_;
147
+
148
+ public:
149
+
150
+ /// Constructs the function object, possibly loading from pointers in host memory
151
+ CUTLASS_HOST_DEVICE
152
+ LinearCombinationRelu(Params const &params) {
153
+
154
+ alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
155
+ beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
156
+ threshold_ = params.threshold;
157
+ }
158
+
159
+ /// Returns true if source is needed
160
+ CUTLASS_HOST_DEVICE
161
+ bool is_source_needed() const {
162
+ if (Scale == ScaleType::NoBetaScaling) return true;
163
+
164
+ if (Scale == ScaleType::OnlyAlphaScaling) return false;
165
+
166
+ if (Scale == ScaleType::OnlyAlphaPerChannelScaling) return false;
167
+
168
+ if (Scale == ScaleType::Nothing) return false;
169
+
170
+ return beta_ != ElementCompute(0);
171
+ }
172
+
173
+ /// Functionally required for serial reduction in the epilogue
174
+ CUTLASS_HOST_DEVICE
175
+ void set_k_partition(int k_partition, int k_partition_count) {
176
+ if (k_partition) {
177
+ beta_ = ElementCompute(1);
178
+ }
179
+
180
+ if (k_partition != k_partition_count - 1) {
181
+ // set to NaN to make ReLU no-op for all except last k partitions
182
+ int64_t allones = -1;
183
+ threshold_ = reinterpret_cast<ElementCompute const &>(allones);
184
+ }
185
+ }
186
+
187
+ /// Computes linear scaling: D = alpha * accumulator + beta * source
188
+ CUTLASS_HOST_DEVICE
189
+ FragmentOutput operator()(
190
+ FragmentAccumulator const &accumulator,
191
+ FragmentOutput const &source) const {
192
+
193
+ // Convert source to interal compute numeric type
194
+ NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round> source_converter;
195
+ NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
196
+
197
+ FragmentCompute converted_source = source_converter(source);
198
+ FragmentCompute converted_accumulator = accumulator_converter(accumulator);
199
+
200
+ // Perform binary operations
201
+ FragmentCompute intermediate;
202
+
203
+ multiplies<FragmentCompute> mul_add_source;
204
+ multiply_add<FragmentCompute> mul_add_accumulator;
205
+ ReLu<FragmentCompute> relu;
206
+
207
+ if (Scale == ScaleType::NoBetaScaling) {
208
+ intermediate = converted_source;
209
+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
210
+ } else if (Scale == ScaleType::Nothing) {
211
+ intermediate = converted_accumulator;
212
+ } else {
213
+ intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform
214
+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
215
+ }
216
+
217
+ // Compute threshold optionally
218
+ intermediate = relu(threshold_, intermediate);
219
+
220
+ // Convert to destination numeric type
221
+ NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
222
+
223
+ return destination_converter(intermediate);
224
+ }
225
+
226
+ /// Computes linear scaling: D = alpha * accumulator
227
+ CUTLASS_HOST_DEVICE
228
+ FragmentOutput operator()(
229
+ FragmentAccumulator const &accumulator) const {
230
+
231
+ // Convert source to interal compute numeric type
232
+ NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
233
+
234
+ FragmentCompute converted_accumulator = accumulator_converter(accumulator);
235
+
236
+ // Perform binary operations
237
+ FragmentCompute intermediate;
238
+
239
+ multiplies<FragmentCompute> mul_accumulator;
240
+ ReLu<FragmentCompute> relu;
241
+
242
+ if (Scale == ScaleType::Nothing) {
243
+ intermediate = converted_accumulator;
244
+ } else {
245
+ intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
246
+ }
247
+
248
+ // Compute threshold optionally
249
+ intermediate = relu(threshold_, intermediate);
250
+
251
+ // Convert to destination numeric type
252
+ NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
253
+
254
+ return destination_converter(intermediate);
255
+ }
256
+
257
+ /// Computes per-channel linear scaling and bias : D = scale * accumulator + bias
258
+ /// Scale and Bias are from input Fragment
259
+ CUTLASS_HOST_DEVICE
260
+ FragmentOutput operator()(
261
+ FragmentAccumulator const &accumulator,
262
+ FragmentScaleBias const &scale,
263
+ FragmentScaleBias const &bias) const {
264
+
265
+ // Convert source to interal compute numeric type
266
+ NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
267
+
268
+ FragmentCompute converted_accumulator = accumulator_converter(accumulator);
269
+
270
+ // Perform per-channel scale and bias
271
+ FragmentCompute intermediate;
272
+
273
+ multiply_add<FragmentCompute> mul_add_accumulator;
274
+
275
+ if(Scale == ScaleType::OnlyAlphaPerChannelScaling)
276
+ intermediate = mul_add_accumulator(scale, converted_accumulator, bias); // D = scale * Accum + bias
277
+ else
278
+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, bias); // D = alpha * Accum + bias
279
+
280
+ ReLu<FragmentCompute> relu;
281
+
282
+ // Compute threshold optionally
283
+ intermediate = relu(threshold_, intermediate);
284
+
285
+ // Convert to destination numeric type
286
+ NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
287
+
288
+ return destination_converter(intermediate);
289
+ }
290
+ };
291
+
292
+ /////////////////////////////////////////////////////////////////////////////////////////////////
293
+
294
+ // Conditional guards to enable partial specialization for packed integers
295
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && ((__CUDACC_VER_MAJOR__ > 10) || ((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2)))
296
+
297
+ /// Applies a linear combination operator to an array of elements.
298
+ ///
299
+ /// D = alpha * accumulator + beta * source + uniform
300
+ ///
301
+ /// Special handling for int types
302
+
303
+ template <
304
+ typename ElementOutput_, ///< Data type used to load and store tensors
305
+ int Count, ///< Number of elements computed per operation
306
+ ScaleType::Kind Scale, ///< Control Alpha and Beta scaling
307
+ FloatRoundStyle Round
308
+ >
309
+ class LinearCombinationRelu <ElementOutput_, Count, int, float, Scale, Round> {
310
+ public:
311
+
312
+ using ElementOutput = ElementOutput_;
313
+ using ElementAccumulator = int;
314
+ using ElementCompute = float;
315
+
316
+ static bool const kIsHeavy = detail::LinearCombinationReluIsHeavy();
317
+
318
+ static int const kCount = Count;
319
+ static const ScaleType::Kind kScale = Scale;
320
+
321
+ using FragmentOutput = Array<ElementOutput, kCount>;
322
+ using FragmentAccumulator = Array<ElementAccumulator, kCount>;
323
+ using FragmentCompute = Array<ElementCompute, kCount>;
324
+ using FragmentScaleBias = Array<ElementCompute, kCount>;
325
+ using FragmentSource = Array<ElementOutput, kCount>;
326
+
327
+ static FloatRoundStyle const kRound = Round;
328
+
329
+ /// Host-constructable parameters structure
330
+ struct Params {
331
+
332
+ ElementCompute alpha; ///< scales accumulators
333
+ ElementCompute beta; ///< scales source tensor
334
+ ElementCompute threshold; ///< minimum value that is output
335
+ ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
336
+ ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
337
+ //
338
+ // Methods
339
+ //
340
+
341
+ CUTLASS_HOST_DEVICE
342
+ Params():
343
+ alpha(ElementCompute(1)),
344
+ beta(ElementCompute(0)),
345
+ threshold(ElementCompute(0)),
346
+ alpha_ptr(nullptr),
347
+ beta_ptr(nullptr) { }
348
+
349
+ CUTLASS_HOST_DEVICE
350
+ Params(
351
+ ElementCompute alpha,
352
+ ElementCompute beta = ElementCompute(0),
353
+ ElementCompute threshold = ElementCompute(0)
354
+ ): alpha(alpha), beta(beta), threshold(threshold), alpha_ptr(nullptr), beta_ptr(nullptr) {
355
+
356
+ }
357
+
358
+ CUTLASS_HOST_DEVICE
359
+ Params(
360
+ ElementCompute const *alpha_ptr,
361
+ ElementCompute const *beta_ptr = nullptr,
362
+ ElementCompute threshold = ElementCompute(0)
363
+ ): alpha(0), beta(0), threshold(threshold), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
364
+
365
+ }
366
+ };
367
+
368
+ private:
369
+
370
+ //
371
+ // Data members
372
+ //
373
+
374
+ ElementCompute alpha_;
375
+ ElementCompute beta_;
376
+ ElementCompute threshold_;
377
+
378
+ public:
379
+
380
+ /// Constructs the function object, possibly loading from pointers in host memory
381
+ CUTLASS_HOST_DEVICE
382
+ LinearCombinationRelu(Params const &params) {
383
+
384
+ alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
385
+ beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
386
+ threshold_ = params.threshold;
387
+ }
388
+
389
+ /// Returns true if source is needed
390
+ CUTLASS_HOST_DEVICE
391
+ bool is_source_needed() const {
392
+ if (Scale == ScaleType::NoBetaScaling) return true;
393
+
394
+ if (Scale == ScaleType::OnlyAlphaScaling) return false;
395
+
396
+ if (Scale == ScaleType::OnlyAlphaPerChannelScaling) return false;
397
+
398
+ if (Scale == ScaleType::Nothing) return false;
399
+
400
+ return beta_ != ElementCompute(0);
401
+ }
402
+
403
+ /// Functionally required for serial reduction in the epilogue
404
+ CUTLASS_HOST_DEVICE
405
+ void set_k_partition(int k_partition, int k_partition_count) {
406
+ if (k_partition) {
407
+ beta_ = ElementCompute(1);
408
+ }
409
+
410
+ if (k_partition != k_partition_count - 1) {
411
+ // set to NaN to make ReLU no-op for all except last k partitions
412
+ int64_t allones = -1;
413
+ threshold_ = reinterpret_cast<ElementCompute const &>(allones);
414
+ }
415
+ }
416
+
417
+ /// Computes linear scaling: D = alpha * accumulator + beta * source
418
+ CUTLASS_HOST_DEVICE
419
+ FragmentOutput operator()(
420
+ FragmentAccumulator const &accumulator,
421
+ FragmentOutput const &source) const {
422
+
423
+ // Convert source to interal compute numeric type
424
+ NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round> source_converter;
425
+ NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
426
+
427
+ FragmentCompute converted_source = source_converter(source);
428
+ FragmentCompute converted_accumulator = accumulator_converter(accumulator);
429
+
430
+ // Perform binary operations
431
+ FragmentCompute intermediate;
432
+
433
+ multiplies<FragmentCompute> mul_add_source;
434
+ multiply_add<FragmentCompute> mul_add_accumulator;
435
+ ReLu<FragmentCompute> relu;
436
+
437
+ if (Scale == ScaleType::NoBetaScaling) {
438
+ intermediate = converted_source;
439
+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
440
+ } else if (Scale == ScaleType::Nothing) {
441
+ intermediate = converted_accumulator;
442
+ } else {
443
+ intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform
444
+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
445
+ }
446
+
447
+ // Compute threshold optionally
448
+ intermediate = relu(threshold_, intermediate);
449
+
450
+ if (cutlass::platform::numeric_limits<ElementOutput>::is_integer) {
451
+ // Convert floats back to INT
452
+ FragmentAccumulator scaled_accumulator;
453
+
454
+ NumericArrayConverter<int, ElementCompute, kCount, Round> compute_converter;
455
+
456
+ scaled_accumulator = compute_converter(intermediate);
457
+
458
+ // Convert to destination numeric type
459
+ NumericArrayConverter<ElementOutput, int, kCount, Round>
460
+ destination_converter;
461
+
462
+ return destination_converter(scaled_accumulator);
463
+ } else {
464
+ NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
465
+ destination_converter;
466
+ return destination_converter(intermediate);
467
+ }
468
+ }
469
+
470
+ /// Computes linear scaling: D = alpha * accumulator
471
+ CUTLASS_HOST_DEVICE
472
+ FragmentOutput operator()(
473
+ FragmentAccumulator const &accumulator) const {
474
+
475
+ // Convert source to interal compute numeric type
476
+ NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
477
+
478
+ FragmentCompute converted_accumulator = accumulator_converter(accumulator);
479
+
480
+ // Perform binary operations
481
+ FragmentCompute intermediate;
482
+
483
+ multiplies<FragmentCompute> mul_accumulator;
484
+ ReLu<FragmentCompute> relu;
485
+
486
+ if (Scale == ScaleType::Nothing) {
487
+ intermediate = converted_accumulator;
488
+ } else {
489
+ intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
490
+ }
491
+
492
+ // Compute threshold optionally
493
+ intermediate = relu(threshold_, intermediate);
494
+
495
+ if (cutlass::platform::numeric_limits<ElementOutput>::is_integer) {
496
+ // Convert floats back to INT
497
+ FragmentAccumulator scaled_accumulator;
498
+
499
+ NumericArrayConverter<int, ElementCompute, kCount, Round> compute_converter;
500
+
501
+ scaled_accumulator = compute_converter(intermediate);
502
+
503
+ // Convert to destination numeric type
504
+ NumericArrayConverter<ElementOutput, int, kCount, Round>
505
+ destination_converter;
506
+
507
+ return destination_converter(scaled_accumulator);
508
+ } else {
509
+ NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
510
+ destination_converter;
511
+ return destination_converter(intermediate);
512
+ }
513
+ }
514
+
515
+ /// Computes per-channel linear scaling and bias : D = scale * accumulator + bias
516
+ /// Scale and Bias are from input Fragment
517
+ CUTLASS_HOST_DEVICE
518
+ FragmentOutput operator()(
519
+ FragmentAccumulator const &accumulator,
520
+ FragmentScaleBias const &scale,
521
+ FragmentScaleBias const &bias) const {
522
+
523
+ // Convert source to interal compute numeric type
524
+ NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
525
+
526
+ FragmentCompute converted_accumulator = accumulator_converter(accumulator);
527
+
528
+ // Perform per-channel scale and bias
529
+ FragmentCompute intermediate;
530
+
531
+ multiply_add<FragmentCompute> mul_add_accumulator;
532
+
533
+ if(Scale == ScaleType::OnlyAlphaPerChannelScaling)
534
+ intermediate = mul_add_accumulator(scale, converted_accumulator, bias); // D = scale * Accum + bias
535
+ else
536
+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, bias); // D = alpha * Accum + bias
537
+
538
+ ReLu<FragmentCompute> relu;
539
+
540
+ // Compute threshold optionally
541
+ intermediate = relu(threshold_, intermediate);
542
+
543
+ if (cutlass::platform::numeric_limits<ElementOutput>::is_integer) {
544
+ // Convert floats back to INT
545
+ FragmentAccumulator scaled_accumulator;
546
+
547
+ NumericArrayConverter<int, ElementCompute, kCount, Round> compute_converter;
548
+
549
+ scaled_accumulator = compute_converter(intermediate);
550
+
551
+ // Convert to destination numeric type
552
+ NumericArrayConverter<ElementOutput, int, kCount, Round>
553
+ destination_converter;
554
+
555
+ return destination_converter(scaled_accumulator);
556
+ } else {
557
+ NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
558
+ destination_converter;
559
+ return destination_converter(intermediate);
560
+ }
561
+ }
562
+ };
563
+
564
+ #endif // Conditional guards to enable partial specialization for packed integers
565
+
566
+ /////////////////////////////////////////////////////////////////////////////////////////////////
567
+
568
+ } // namespace thread
569
+ } // namespace epilogue
570
+ } // namespace cutlass
571
+
572
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_relu0.h ADDED
@@ -0,0 +1,543 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Functor performing linear combination with a relu operation used by epilogues.
33
+ This one only supports relu0 and tries to folding relu into other instructions. Thus,
34
+ serial splitk is not supported by this one. For example, relu can be folded into
35
+ hfma2/hmul2 for sm80+
36
+ */
37
+
38
+ #pragma once
39
+
40
+ #include "cutlass/half.h"
41
+ #include "cutlass/cutlass.h"
42
+ #include "cutlass/numeric_types.h"
43
+ #include "cutlass/array.h"
44
+ #include "cutlass/functional.h"
45
+ #include "cutlass/numeric_conversion.h"
46
+ #include "cutlass/epilogue/thread/activation.h"
47
+ #include "cutlass/epilogue/thread/scale_type.h"
48
+
49
+ /////////////////////////////////////////////////////////////////////////////////////////////////
50
+
51
+ namespace cutlass {
52
+ namespace epilogue {
53
+ namespace thread {
54
+
55
+ /////////////////////////////////////////////////////////////////////////////////////////////////
56
+
57
+ namespace detail {
58
+
59
+ /// Single source of truth for whether to unroll for `LinearCombinationClamp()`
60
+ constexpr bool LinearCombinationRelu0IsHeavy() {
61
+ return false;
62
+ }
63
+
64
+ }
65
+
66
+ /////////////////////////////////////////////////////////////////////////////////////////////////
67
+
68
+ /// Applies a linear combination operator to an array of elements.
69
+ ///
70
+ /// D = alpha * accumulator + beta * source + uniform
71
+ ///
72
+ template <
73
+ typename ElementOutput_, ///< Data type used to load and store tensors
74
+ int Count, ///< Number of elements computed per operation
75
+ ///< Usually it is 128/sizeof_bits<ElementOutput_>,
76
+ ///< but we use 64 or 32 sometimes when there are not enough data to store
77
+ typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type
78
+ typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination
79
+ ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling
80
+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
81
+ >
82
+ class LinearCombinationRelu0 {
83
+ public:
84
+
85
+ using ElementOutput = ElementOutput_;
86
+ using ElementAccumulator = ElementAccumulator_;
87
+ using ElementCompute = ElementCompute_;
88
+
89
+ static int const kCount = Count;
90
+ static const ScaleType::Kind kScale = Scale;
91
+
92
+ using FragmentOutput = Array<ElementOutput, kCount>;
93
+ using FragmentAccumulator = Array<ElementAccumulator, kCount>;
94
+ using FragmentCompute = Array<ElementCompute, kCount>;
95
+ using FragmentScaleBias = Array<ElementCompute, kCount>;
96
+ using FragmentSource = Array<ElementOutput, kCount>;
97
+
98
+ static FloatRoundStyle const kRound = Round;
99
+
100
+ static bool const kIsHeavy = detail::LinearCombinationRelu0IsHeavy();
101
+
102
+ /// Host-constructable parameters structure
103
+ struct Params {
104
+
105
+ ElementCompute alpha; ///< scales accumulators
106
+ ElementCompute beta; ///< scales source tensor
107
+ ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
108
+ ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
109
+ //
110
+ // Methods
111
+ //
112
+
113
+ CUTLASS_HOST_DEVICE
114
+ Params():
115
+ alpha(ElementCompute(1)),
116
+ beta(ElementCompute(0)),
117
+ alpha_ptr(nullptr),
118
+ beta_ptr(nullptr) { }
119
+
120
+ CUTLASS_HOST_DEVICE
121
+ Params(
122
+ ElementCompute alpha,
123
+ ElementCompute beta = ElementCompute(0)
124
+ ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) {
125
+
126
+ }
127
+
128
+ CUTLASS_HOST_DEVICE
129
+ Params(
130
+ ElementCompute const *alpha_ptr,
131
+ ElementCompute const *beta_ptr = nullptr
132
+ ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
133
+
134
+ }
135
+ };
136
+
137
+ private:
138
+
139
+ //
140
+ // Data members
141
+ //
142
+
143
+ ElementCompute alpha_;
144
+ ElementCompute beta_;
145
+
146
+ public:
147
+
148
+ /// Constructs the function object, possibly loading from pointers in host memory
149
+ CUTLASS_HOST_DEVICE
150
+ LinearCombinationRelu0(Params const &params) {
151
+
152
+ alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
153
+ beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
154
+ }
155
+
156
+ /// Returns true if source is needed
157
+ CUTLASS_HOST_DEVICE
158
+ bool is_source_needed() const {
159
+ if (Scale == ScaleType::NoBetaScaling) return true;
160
+
161
+ if (Scale == ScaleType::OnlyAlphaScaling) return false;
162
+
163
+ if (Scale == ScaleType::Nothing) return false;
164
+
165
+ return beta_ != ElementCompute(0);
166
+ }
167
+
168
+ /// This is used for serial reduction which is not supported by Relu0
169
+ CUTLASS_HOST_DEVICE
170
+ void set_k_partition(int k_partition, int k_partition_count) {
171
+ assert(k_partition == 0);
172
+ }
173
+
174
+ /// Computes linear scaling: D = alpha * accumulator + beta * source
175
+ CUTLASS_HOST_DEVICE
176
+ FragmentOutput operator()(
177
+ FragmentAccumulator const &accumulator,
178
+ FragmentOutput const &source) const {
179
+
180
+ // Convert source to interal compute numeric type
181
+ NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round> source_converter;
182
+ NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
183
+
184
+ FragmentCompute converted_source = source_converter(source);
185
+ FragmentCompute converted_accumulator = accumulator_converter(accumulator);
186
+
187
+ // Perform binary operations
188
+ FragmentCompute intermediate;
189
+
190
+ multiplies<FragmentCompute> mul_add_source;
191
+ multiply_add_relu0<FragmentCompute> mul_add_relu0_accumulator;
192
+ ReLu<FragmentCompute> relu;
193
+
194
+ if (Scale == ScaleType::NoBetaScaling) {
195
+ intermediate = converted_source;
196
+ intermediate = mul_add_relu0_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
197
+ } else if (Scale == ScaleType::Nothing) {
198
+ intermediate = converted_accumulator;
199
+
200
+ // Compute threshold optionally
201
+ intermediate = relu(intermediate);
202
+ } else {
203
+ intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform
204
+ intermediate = mul_add_relu0_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
205
+ }
206
+
207
+ // Convert to destination numeric type
208
+ NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
209
+
210
+ return destination_converter(intermediate);
211
+ }
212
+
213
+ /// Computes linear scaling: D = alpha * accumulator
214
+ CUTLASS_HOST_DEVICE
215
+ FragmentOutput operator()(
216
+ FragmentAccumulator const &accumulator) const {
217
+
218
+ // Convert source to interal compute numeric type
219
+ NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
220
+
221
+ FragmentCompute converted_accumulator = accumulator_converter(accumulator);
222
+
223
+ // Perform binary operations
224
+ FragmentCompute intermediate;
225
+
226
+ multiplies<FragmentCompute> mul_accumulator;
227
+ ReLu<FragmentCompute> relu;
228
+
229
+ if (Scale == ScaleType::Nothing) {
230
+ intermediate = converted_accumulator;
231
+ } else {
232
+ intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
233
+ }
234
+
235
+ // Compute threshold optionally
236
+ intermediate = relu(intermediate);
237
+
238
+ // Convert to destination numeric type
239
+ NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
240
+
241
+ return destination_converter(intermediate);
242
+ }
243
+
244
+ /// Computes per-channel linear scaling and bias : D = scale * accumulator + bias
245
+ /// Scale and Bias are from input Fragment
246
+ CUTLASS_HOST_DEVICE
247
+ FragmentOutput operator()(
248
+ FragmentAccumulator const &accumulator,
249
+ FragmentScaleBias const &scale,
250
+ FragmentScaleBias const &bias) const {
251
+
252
+ // Convert source to interal compute numeric type
253
+ NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
254
+
255
+ FragmentCompute converted_accumulator = accumulator_converter(accumulator);
256
+
257
+ // Perform per-channel scale and bias
258
+ FragmentCompute intermediate;
259
+
260
+ multiply_add<FragmentCompute> mul_add_accumulator;
261
+
262
+ if(Scale == ScaleType::OnlyAlphaPerChannelScaling)
263
+ intermediate = mul_add_accumulator(scale, converted_accumulator, bias); // D = scale * Accum + bias
264
+ else
265
+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, bias); // D = alpha * Accum + bias
266
+
267
+ ReLu<FragmentCompute> relu;
268
+
269
+ // Compute threshold optionally
270
+ intermediate = relu(intermediate);
271
+
272
+ // Convert to destination numeric type
273
+ NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
274
+
275
+ return destination_converter(intermediate);
276
+ }
277
+ };
278
+
279
+ /////////////////////////////////////////////////////////////////////////////////////////////////
280
+
281
+ // Conditional guards to enable partial specialization for packed integers
282
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && ((__CUDACC_VER_MAJOR__ > 10) || ((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2)))
283
+
284
+ /// Applies a linear combination operator to an array of elements.
285
+ ///
286
+ /// D = alpha * accumulator + beta * source + uniform
287
+ ///
288
+ /// Special handling for int types
289
+
290
+ template <
291
+ typename ElementOutput_, ///< Data type used to load and store tensors
292
+ int Count, ///< Number of elements computed per operation
293
+ ScaleType::Kind Scale, ///< Control Alpha and Beta scaling
294
+ FloatRoundStyle Round
295
+ >
296
+ class LinearCombinationRelu0 <ElementOutput_, Count, int, float, Scale, Round> {
297
+ public:
298
+
299
+ using ElementOutput = ElementOutput_;
300
+ using ElementAccumulator = int;
301
+ using ElementCompute = float;
302
+
303
+ static bool const kIsHeavy = detail::LinearCombinationRelu0IsHeavy();
304
+
305
+ static int const kCount = Count;
306
+ static const ScaleType::Kind kScale = Scale;
307
+
308
+ using FragmentOutput = Array<ElementOutput, kCount>;
309
+ using FragmentAccumulator = Array<ElementAccumulator, kCount>;
310
+ using FragmentCompute = Array<ElementCompute, kCount>;
311
+ using FragmentScaleBias = Array<ElementCompute, kCount>;
312
+ using FragmentSource = Array<ElementOutput, kCount>;
313
+
314
+ static FloatRoundStyle const kRound = Round;
315
+
316
+ /// Host-constructable parameters structure
317
+ struct Params {
318
+
319
+ ElementCompute alpha; ///< scales accumulators
320
+ ElementCompute beta; ///< scales source tensor
321
+ ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
322
+ ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
323
+ //
324
+ // Methods
325
+ //
326
+
327
+ CUTLASS_HOST_DEVICE
328
+ Params():
329
+ alpha(ElementCompute(1)),
330
+ beta(ElementCompute(0)),
331
+ alpha_ptr(nullptr),
332
+ beta_ptr(nullptr) { }
333
+
334
+ CUTLASS_HOST_DEVICE
335
+ Params(
336
+ ElementCompute alpha,
337
+ ElementCompute beta = ElementCompute(0)
338
+ ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) {
339
+
340
+ }
341
+
342
+ CUTLASS_HOST_DEVICE
343
+ Params(
344
+ ElementCompute const *alpha_ptr,
345
+ ElementCompute const *beta_ptr = nullptr
346
+ ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
347
+
348
+ }
349
+ };
350
+
351
+ private:
352
+
353
+ //
354
+ // Data members
355
+ //
356
+
357
+ ElementCompute alpha_;
358
+ ElementCompute beta_;
359
+
360
+ public:
361
+
362
+ /// Constructs the function object, possibly loading from pointers in host memory
363
+ CUTLASS_HOST_DEVICE
364
+ LinearCombinationRelu0(Params const &params) {
365
+
366
+ alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
367
+ beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
368
+ }
369
+
370
+ /// Returns true if source is needed
371
+ CUTLASS_HOST_DEVICE
372
+ bool is_source_needed() const {
373
+ if (Scale == ScaleType::NoBetaScaling) return true;
374
+
375
+ if (Scale == ScaleType::OnlyAlphaScaling) return false;
376
+
377
+ if (Scale == ScaleType::Nothing) return false;
378
+
379
+ return beta_ != ElementCompute(0);
380
+ }
381
+
382
+ /// This is used for serial reduction which is not supported by Relu0
383
+ CUTLASS_HOST_DEVICE
384
+ void set_k_partition(int k_partition, int k_partition_count) {
385
+ assert(k_partition == 0);
386
+ }
387
+
388
+ /// Computes linear scaling: D = alpha * accumulator + beta * source
389
+ CUTLASS_HOST_DEVICE
390
+ FragmentOutput operator()(
391
+ FragmentAccumulator const &accumulator,
392
+ FragmentOutput const &source) const {
393
+
394
+ // Convert source to interal compute numeric type
395
+ NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round> source_converter;
396
+ NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
397
+
398
+ FragmentCompute converted_source = source_converter(source);
399
+ FragmentCompute converted_accumulator = accumulator_converter(accumulator);
400
+
401
+ // Perform binary operations
402
+ FragmentCompute intermediate;
403
+
404
+ multiplies<FragmentCompute> mul_add_source;
405
+ multiply_add<FragmentCompute> mul_add_accumulator;
406
+ ReLu<FragmentCompute> relu;
407
+
408
+ if (Scale == ScaleType::NoBetaScaling) {
409
+ intermediate = converted_source;
410
+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
411
+ } else if (Scale == ScaleType::Nothing) {
412
+ intermediate = converted_accumulator;
413
+ } else {
414
+ intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform
415
+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
416
+ }
417
+
418
+ // Compute threshold optionally
419
+ intermediate = relu(intermediate);
420
+
421
+ if (cutlass::platform::numeric_limits<ElementOutput>::is_integer) {
422
+ // Convert floats back to INT
423
+ FragmentAccumulator scaled_accumulator;
424
+
425
+ NumericArrayConverter<int, ElementCompute, kCount, Round> compute_converter;
426
+
427
+ scaled_accumulator = compute_converter(intermediate);
428
+
429
+ // Convert to destination numeric type
430
+ NumericArrayConverter<ElementOutput, int, kCount, Round>
431
+ destination_converter;
432
+
433
+ return destination_converter(scaled_accumulator);
434
+ } else {
435
+ NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
436
+ destination_converter;
437
+ return destination_converter(intermediate);
438
+ }
439
+ }
440
+
441
+ /// Computes linear scaling: D = alpha * accumulator
442
+ CUTLASS_HOST_DEVICE
443
+ FragmentOutput operator()(
444
+ FragmentAccumulator const &accumulator) const {
445
+
446
+ // Convert source to interal compute numeric type
447
+ NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
448
+
449
+ FragmentCompute converted_accumulator = accumulator_converter(accumulator);
450
+
451
+ // Perform binary operations
452
+ FragmentCompute intermediate;
453
+
454
+ multiplies<FragmentCompute> mul_accumulator;
455
+ ReLu<FragmentCompute> relu;
456
+
457
+ if (Scale == ScaleType::Nothing) {
458
+ intermediate = converted_accumulator;
459
+ } else {
460
+ intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
461
+ }
462
+
463
+ // Compute threshold optionally
464
+ intermediate = relu(intermediate);
465
+
466
+ if (cutlass::platform::numeric_limits<ElementOutput>::is_integer) {
467
+ // Convert floats back to INT
468
+ FragmentAccumulator scaled_accumulator;
469
+
470
+ NumericArrayConverter<int, ElementCompute, kCount, Round> compute_converter;
471
+
472
+ scaled_accumulator = compute_converter(intermediate);
473
+
474
+ // Convert to destination numeric type
475
+ NumericArrayConverter<ElementOutput, int, kCount, Round>
476
+ destination_converter;
477
+
478
+ return destination_converter(scaled_accumulator);
479
+ } else {
480
+ NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
481
+ destination_converter;
482
+ return destination_converter(intermediate);
483
+ }
484
+ }
485
+
486
+ /// Computes per-channel linear scaling and bias : D = scale * accumulator + bias
487
+ /// Scale and Bias are from input Fragment
488
+ CUTLASS_HOST_DEVICE
489
+ FragmentOutput operator()(
490
+ FragmentAccumulator const &accumulator,
491
+ FragmentScaleBias const &scale,
492
+ FragmentScaleBias const &bias) const {
493
+
494
+ // Convert source to interal compute numeric type
495
+ NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
496
+
497
+ FragmentCompute converted_accumulator = accumulator_converter(accumulator);
498
+
499
+ // Perform per-channel scale and bias
500
+ FragmentCompute intermediate;
501
+
502
+ multiply_add<FragmentCompute> mul_add_accumulator;
503
+
504
+ if(Scale == ScaleType::OnlyAlphaPerChannelScaling)
505
+ intermediate = mul_add_accumulator(scale, converted_accumulator, bias); // D = scale * Accum + bias
506
+ else
507
+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, bias); // D = alpha * Accum + bias
508
+
509
+ ReLu<FragmentCompute> relu;
510
+
511
+ // Compute threshold optionally
512
+ intermediate = relu(intermediate);
513
+
514
+ if (cutlass::platform::numeric_limits<ElementOutput>::is_integer) {
515
+ // Convert floats back to INT
516
+ FragmentAccumulator scaled_accumulator;
517
+
518
+ NumericArrayConverter<int, ElementCompute, kCount, Round> compute_converter;
519
+
520
+ scaled_accumulator = compute_converter(intermediate);
521
+
522
+ // Convert to destination numeric type
523
+ NumericArrayConverter<ElementOutput, int, kCount, Round>
524
+ destination_converter;
525
+
526
+ return destination_converter(scaled_accumulator);
527
+ } else {
528
+ NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
529
+ destination_converter;
530
+ return destination_converter(intermediate);
531
+ }
532
+ }
533
+ };
534
+
535
+ #endif // Conditional guards to enable partial specialization for packed integers
536
+
537
+ /////////////////////////////////////////////////////////////////////////////////////////////////
538
+
539
+ } // namespace thread
540
+ } // namespace epilogue
541
+ } // namespace cutlass
542
+
543
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_residual_block.h ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ /*! \file
33
+ \brief Epilogue functor specialized for residual blocks in deep neural networks.
34
+ */
35
+
36
+ #pragma once
37
+
38
+ #include "cutlass/array.h"
39
+ #include "cutlass/functional.h"
40
+ #include "cutlass/numeric_conversion.h"
41
+ #include "cutlass/epilogue/thread/detail.hpp"
42
+
43
+ /////////////////////////////////////////////////////////////////////////////////////////////////
44
+
45
+ namespace cutlass {
46
+ namespace epilogue {
47
+ namespace thread {
48
+
49
/// Models a residual block of the form:
///   UnaryOp(BinaryOp2(BinaryOp1(ActivationOp(TensorOp(X) + bias), residual1), residual2))
///
/// Primary template: consumes TWO residual source tensors (residual1, residual2).
/// The single-source case is handled by the partial specialization selected when
/// BinaryOp2_ is detail::NoOp.
template <typename ElementOutput_, typename ElementAccumulator_,
          typename ElementCompute_, typename ElementC_, int ElementsPerAccess,
          template <typename T> class ActivationOp_,
          template <typename T> class BinaryOp1_,
          template <typename T> class UnaryOp_,
          template <typename T> class BinaryOp2_ = detail::NoOp,
          bool StoreT_ = false,
          typename ElementVector_ = ElementC_>
class LinearCombinationResidualBlock {
public:
  // Two residual sources are consumed by this (primary) template.
  static bool const kIsSingleSource = false;

  // NOTE(review): ElementOutput deliberately aliases ElementC_ (the residual /
  // source element type), not ElementOutput_; ElementOutput_ is surfaced as
  // ElementZ below.
  using ElementOutput = ElementC_;
  using ElementC = ElementC_;
  using ElementAccumulator = ElementAccumulator_;
  using ElementCompute = ElementCompute_;
  using ElementVector = ElementVector_;
  static int const kElementsPerAccess = ElementsPerAccess;
  static int const kCount = kElementsPerAccess;

  // Elementwise operators are instantiated over arrays of compute-type elements,
  // so each application processes a full fragment at once.
  using UnaryOp = UnaryOp_<Array<ElementCompute, kCount>>;
  using BinaryOp1 = BinaryOp1_<Array<ElementCompute, kCount>>;
  using BinaryOp2 = BinaryOp2_<Array<ElementCompute, kCount>>;
  using ActivationOp = ActivationOp_<Array<ElementCompute, kCount>>;

  using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;
  using FragmentCompute = Array<ElementCompute, kElementsPerAccess>;
  using FragmentC = Array<ElementC, kElementsPerAccess>;
  using FragmentOutput = Array<ElementOutput, kElementsPerAccess>;

  using ElementZ = ElementOutput_;
  using ElementT = ElementZ;
  using FragmentZ = Array<ElementZ, kElementsPerAccess>;
  using FragmentT = Array<ElementT, kElementsPerAccess>;

  static bool const kIsHeavy = true;
  static bool const kStoreZ = true;   // the Z result is always written out
  static bool const kStoreT = StoreT_;

  /// Host-constructable parameters structure
  struct Params {

    ElementCompute alpha;                     ///< scales accumulators
    ElementCompute beta;                      ///< scales residual input
    ElementCompute const *alpha_ptr{nullptr}; ///< pointer to accumulator scalar - if not null, loads it from memory
    ElementCompute const *beta_ptr{nullptr};  ///< pointer to residual scalar - if not null, loads it from memory

    CUTLASS_HOST_DEVICE
    Params() : alpha(ElementCompute(1)), beta(ElementCompute(1)) {}

    CUTLASS_HOST_DEVICE
    Params(ElementCompute alpha, ElementCompute beta)
        : alpha(alpha), beta(beta) {}

    CUTLASS_HOST_DEVICE
    Params(ElementCompute const *alpha_ptr, ElementCompute const *beta_ptr)
        : alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {}
  };

private:

  ElementCompute alpha_;
  ElementCompute beta_;
  // When true (set on non-final split-k partitions), the final UnaryOp is
  // skipped so partial results stay reducible across partitions.
  bool skip_elementwise_;

public:

  /// Constructor from Params; scalars are dereferenced from the pointers when provided.
  CUTLASS_HOST_DEVICE
  LinearCombinationResidualBlock(Params const &params) {
    alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
    beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
    skip_elementwise_ = false;
  }

  /// The "source" tensor corresponds to the residual input
  CUTLASS_HOST_DEVICE
  bool is_source_needed() const { return true; }

  /// Functionally required for serial reduction in the epilogue
  /// IMPORTANT: Split-k is supported only when ActivationOp is Identity.
  CUTLASS_HOST_DEVICE
  void set_k_partition(int k_partition, int k_partition_count) {
    // Partitions after the first accumulate onto the previous partial result,
    // which must be folded in with unit scale.
    if (k_partition) {
      beta_ = ElementCompute(1);
    }

    // Only the final partition applies the elementwise UnaryOp.
    if (k_partition != k_partition_count - 1) {
      skip_elementwise_ = true;
    }
  }

  /// Applies the operation UnaryOp(BinaryOp2(BinaryOp1(ActivationOp(AB + bias), residual1), residual2))
  CUTLASS_HOST_DEVICE
  void operator()(FragmentOutput &frag_Z, FragmentOutput &, FragmentAccumulator const &AB,
                  FragmentC const &residual1, FragmentC const &residual2,
                  FragmentCompute const &bias) const {
    UnaryOp unary_op;
    BinaryOp1 binary_op1;
    BinaryOp2 binary_op2;
    ActivationOp activation;

    // Convert every operand to the compute type before combining.
    FragmentCompute tmp_Accum =
        NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess>()(AB);
    FragmentCompute tmp_residual1 =
        NumericArrayConverter<ElementCompute, ElementC, kElementsPerAccess>()(residual1);
    FragmentCompute tmp_residual2 =
        NumericArrayConverter<ElementCompute, ElementC, kElementsPerAccess>()(residual2);

    FragmentCompute z =
        binary_op2(binary_op1(activation(alpha_ * tmp_Accum + bias), beta_ * tmp_residual1), beta_ * tmp_residual2);
    FragmentCompute result_Z = skip_elementwise_ ? z : unary_op(z);

    NumericArrayConverter<ElementOutput, ElementCompute, kElementsPerAccess> convert_z;
    frag_Z = convert_z(result_Z);
  }

  /// Should never be called
  CUTLASS_HOST_DEVICE
  void operator()(FragmentOutput &, FragmentOutput &, FragmentAccumulator const &,
                  FragmentCompute const &) const {}
};
172
+
173
/// Models a residual block of the form:
///   UnaryOp(BinaryOp(ActivationOp(TensorOp(X) + bias), residual))
///
/// Single-source partial specialization, selected when BinaryOp2_ is
/// detail::NoOp (i.e. only one residual tensor is consumed).
template <typename ElementOutput_, typename ElementAccumulator_,
          typename ElementCompute_, typename ElementC_, int ElementsPerAccess,
          template <typename T> class ActivationOp_,
          template <typename T> class BinaryOp1_,
          template <typename T> class UnaryOp_,
          bool StoreT_,
          typename ElementVector_>
class LinearCombinationResidualBlock<ElementOutput_, ElementAccumulator_,
                                     ElementCompute_, ElementC_, ElementsPerAccess,
                                     ActivationOp_, BinaryOp1_, UnaryOp_,
                                     detail::NoOp, StoreT_, ElementVector_> {
public:
  // Exactly one residual source is consumed by this specialization.
  static bool const kIsSingleSource = true;

  // NOTE(review): as in the primary template, ElementOutput aliases ElementC_
  // (the residual / source element type); ElementOutput_ is surfaced as ElementZ.
  using ElementOutput = ElementC_;
  using ElementC = ElementC_;
  using ElementAccumulator = ElementAccumulator_;
  using ElementCompute = ElementCompute_;
  using ElementVector = ElementVector_;
  static int const kElementsPerAccess = ElementsPerAccess;
  static int const kCount = kElementsPerAccess;

  // Elementwise operators are instantiated over arrays of compute-type elements.
  using UnaryOp = UnaryOp_<Array<ElementCompute, kCount>>;
  using BinaryOp = BinaryOp1_<Array<ElementCompute, kCount>>;
  using ActivationOp = ActivationOp_<Array<ElementCompute, kCount>>;

  using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;
  using FragmentCompute = Array<ElementCompute, kElementsPerAccess>;
  using FragmentC = Array<ElementC, kElementsPerAccess>;
  using FragmentOutput = Array<ElementOutput, kElementsPerAccess>;

  using ElementZ = ElementOutput_;
  using ElementT = ElementZ;
  using FragmentZ = Array<ElementZ, kElementsPerAccess>;
  using FragmentT = Array<ElementT, kElementsPerAccess>;

  static bool const kIsHeavy = true;
  static bool const kStoreZ = true;   // the Z result is always written out
  static bool const kStoreT = StoreT_;

  /// Host-constructable parameters structure
  struct Params {

    ElementCompute alpha;                     ///< scales accumulators
    ElementCompute beta;                      ///< scales residual input
    ElementCompute const *alpha_ptr{nullptr}; ///< pointer to accumulator scalar - if not null, loads it from memory
    ElementCompute const *beta_ptr{nullptr};  ///< pointer to residual scalar - if not null, loads it from memory

    CUTLASS_HOST_DEVICE
    Params() : alpha(ElementCompute(1)), beta(ElementCompute(1)) {}

    CUTLASS_HOST_DEVICE
    Params(ElementCompute alpha, ElementCompute beta)
        : alpha(alpha), beta(beta) {}

    CUTLASS_HOST_DEVICE
    Params(ElementCompute const *alpha_ptr, ElementCompute const *beta_ptr)
        : alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {}
  };

private:

  ElementCompute alpha_;
  ElementCompute beta_;
  // When true (set on non-final split-k partitions), the final UnaryOp is
  // skipped so partial results stay reducible across partitions.
  bool skip_elementwise_;

public:

  /// Constructor from Params; scalars are dereferenced from the pointers when provided.
  CUTLASS_HOST_DEVICE
  LinearCombinationResidualBlock(Params const &params) {
    alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
    beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
    skip_elementwise_ = false;
  }

  /// The "source" tensor corresponds to the residual input
  CUTLASS_HOST_DEVICE
  bool is_source_needed() const { return true; }

  /// Functionally required for serial reduction in the epilogue
  /// IMPORTANT: Split-k is supported only when ActivationOp is Identity.
  CUTLASS_HOST_DEVICE
  void set_k_partition(int k_partition, int k_partition_count) {
    // Partitions after the first accumulate onto the previous partial result,
    // which must be folded in with unit scale.
    if (k_partition) {
      beta_ = ElementCompute(1);
    }

    // Only the final partition applies the elementwise UnaryOp.
    if (k_partition != k_partition_count - 1) {
      skip_elementwise_ = true;
    }
  }

  /// Applies the operation UnaryOp(BinaryOp(ActivationOp(AB + bias), residual))
  CUTLASS_HOST_DEVICE
  void operator()(FragmentOutput &frag_Z, FragmentOutput &, FragmentAccumulator const &AB,
                  FragmentC const &residual,
                  FragmentCompute const &bias) const {
    UnaryOp unary_op;
    BinaryOp binary_op;
    ActivationOp activation;

    // Convert every operand to the compute type before combining.
    FragmentCompute tmp_Accum =
        NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess>()(AB);
    FragmentCompute tmp_residual =
        NumericArrayConverter<ElementCompute, ElementC, kElementsPerAccess>()(residual);

    FragmentCompute z =
        binary_op(activation(alpha_ * tmp_Accum + bias), beta_ * tmp_residual);
    FragmentCompute result_Z = skip_elementwise_ ? z : unary_op(z);

    NumericArrayConverter<ElementOutput, ElementCompute, kElementsPerAccess> convert_z;
    frag_Z = convert_z(result_Z);
  }

  /// Should never be called
  CUTLASS_HOST_DEVICE
  void operator()(FragmentOutput &, FragmentOutput &, FragmentAccumulator const &,
                  FragmentCompute const &) const {}
};
294
+
295
+ /////////////////////////////////////////////////////////////////////////////////////////////////
296
+
297
+ } // namespace thread
298
+ } // namespace epilogue
299
+ } // namespace cutlass
300
+
301
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_sigmoid.h ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Functor performing linear combination with Sigmoid operations used by epilogues.
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "cutlass/cutlass.h"
38
+ #include "cutlass/epilogue/thread/activation.h"
39
+ #include "cutlass/epilogue/thread/linear_combination_generic.h"
40
+
41
+ /////////////////////////////////////////////////////////////////////////////////////////////////
42
+
43
+ namespace cutlass {
44
+ namespace epilogue {
45
+ namespace thread {
46
+
47
+ /////////////////////////////////////////////////////////////////////////////////////////////////
48
+
49
/// Applies a linear combination operator followed by the Sigmoid activation, to an array of elements.
///
/// D = sigmoid(alpha * accumulator + beta * source + uniform)
///
template <
  typename ElementOutput_,                       ///< Data type used to load and store tensors
  int Count,                                     ///< Number of elements computed per operation
                                                 ///< Usually it is 128/sizeof_bits<ElementOutput_>,
                                                 ///< but we use 64 or 32 sometimes when there are not enough data to store
  typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type
  typename ElementCompute_ = ElementOutput_,     ///< Data type used to compute linear combination
  ScaleType::Kind Scale = ScaleType::Default,    ///< Control Alpha and Beta scaling
  FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
>
// Trailing `true` marks the functor as "heavy" (IsHeavy) for the generic epilogue.
using LinearCombinationSigmoid = LinearCombinationGeneric<Sigmoid, ElementOutput_, Count, ElementAccumulator_,
                                                          ElementCompute_, Scale, Round, true>;
65
+
66
+ /////////////////////////////////////////////////////////////////////////////////////////////////
67
+
68
+ } // namespace thread
69
+ } // namespace epilogue
70
+ } // namespace cutlass
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_silu.h ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Functor performing linear combination with SiLU operations used by epilogues.
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "cutlass/cutlass.h"
38
+ #include "cutlass/epilogue/thread/activation.h"
39
+ #include "cutlass/epilogue/thread/linear_combination_generic.h"
40
+
41
+ /////////////////////////////////////////////////////////////////////////////////////////////////
42
+
43
+ namespace cutlass {
44
+ namespace epilogue {
45
+ namespace thread {
46
+
47
+ /////////////////////////////////////////////////////////////////////////////////////////////////
48
+
49
/// Applies a linear combination operator followed by the SiLU activation to an array of elements.
///
/// D = silu(alpha * accumulator + beta * source + uniform)
///
template <
  typename ElementOutput_,                       ///< Data type used to load and store tensors
  int Count,                                     ///< Number of elements computed per operation
                                                 ///< Usually it is 128/sizeof_bits<ElementOutput_>,
                                                 ///< but we use 64 or 32 sometimes when there are not enough data to store
  typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type
  typename ElementCompute_ = ElementOutput_,     ///< Data type used to compute linear combination
  ScaleType::Kind Scale = ScaleType::Default,    ///< Control Alpha and Beta scaling
  FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
>
// Trailing `true` marks the functor as "heavy" (IsHeavy) for the generic epilogue.
using LinearCombinationSilu = LinearCombinationGeneric<SiLu, ElementOutput_, Count, ElementAccumulator_,
                                                       ElementCompute_, Scale, Round, true>;
65
+ /////////////////////////////////////////////////////////////////////////////////////////////////
66
+
67
+ } // namespace thread
68
+ } // namespace epilogue
69
+ } // namespace cutlass
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_tensor_broadcast.hpp ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ /*! \file
33
+ \brief Functor performing linear combination operation, bias addition, and tensor-tensor
34
+ elementwise operations
35
+ */
36
+
37
+ #pragma once
38
+
39
+ #include "cutlass/cutlass.h"
40
+ #include "cutlass/array.h"
41
+ #include "cutlass/functional.h"
42
+ #include "cutlass/numeric_conversion.h"
43
+ #include "cutlass/numeric_types.h"
44
+ #include "cutlass/epilogue/thread/activation.h"
45
+ #include "cutlass/epilogue/thread/detail.hpp"
46
+ #include "cutlass/epilogue/thread/scale_type.h"
47
+
48
+ /////////////////////////////////////////////////////////////////////////////////////////////////
49
+
50
+ namespace cutlass {
51
+ namespace epilogue {
52
+ namespace thread {
53
+
54
+ namespace detail {
55
+
56
+ /// Returns whether a source operand is needed for a combination of binary operation and scale
57
+ /// type. Simple specialized checks are made for cases in which 0 is an identity element of
58
+ /// the binary operation.
59
+ template <class BinaryOp, class ElementCompute, ScaleType::Kind Scale>
60
+ CUTLASS_HOST_DEVICE
61
+ bool is_binary_op_source_needed(ElementCompute scale) {
62
+ if constexpr (cute::is_same_v<BinaryOp, NoOp<ElementCompute>>) {
63
+ return false;
64
+ }
65
+ else if constexpr (cute::is_same_v<BinaryOp, plus<ElementCompute>> || cute::is_same_v<BinaryOp, minus<ElementCompute>>) {
66
+ // Cases for binary operators for which 0 is an identity element
67
+ if constexpr (Scale == ScaleType::NoBetaScaling) return true;
68
+
69
+ if constexpr (Scale == ScaleType::OnlyAlphaScaling) return false;
70
+
71
+ if constexpr (Scale == ScaleType::Nothing) return false;
72
+
73
+ return scale != ElementCompute(0);
74
+ }
75
+
76
+ return true;
77
+ }
78
+
79
+ } // namespace detail
80
+
81
+ /////////////////////////////////////////////////////////////////////////////////////////////////
82
+
83
/** Compute a tensor-tensor broadcast epilogue.
 *
 * @param ElementOutput_ Data type used to load and store tensors
 * @param ElementAccumulator_ Accumulator data type
 * @param ElementCompute_ Data type used to compute linear combination
 * @param ElementBias_ Data type of Bias elements
 * @param ActivationFunctor_ Fused Activation
 * @param BinaryOp0_ Binary operation to perform on O0 and C0. detail::NoOp means no operation
 * @param BinaryOp1_ Binary operation to perform on O1 and C1. detail::NoOp means no operation
 * @param UnaryOp_ Unary operation to perform on final result
 * @param Scale Controls the type of Alpha and Beta scaling to perform
 * @param Round How values should be rounded in conversions
 * @param ElementSource_ Data type used for source operands
 *
 * Computes the following:
 *   O0 = alpha * accumulator + bias
 *   O1 = BinaryOp0(O0, beta * C0)
 *   O2 = BinaryOp1(O1, beta * C1)
 *   D  = UnaryOp(O2)
 */
template <
  class ElementOutput_,
  class ElementAccumulator_ = ElementOutput_,
  class ElementCompute_ = ElementOutput_,
  class ElementBias_ = ElementCompute_,
  template <class T> class ActivationFunctor_ = Identity,
  template <class T> class BinaryOp0_ = plus,
  template <class T> class BinaryOp1_ = detail::NoOp,
  template <class T> class UnaryOp_ = Identity,
  ScaleType::Kind Scale = ScaleType::Default,
  FloatRoundStyle Round = FloatRoundStyle::round_to_nearest,
  class ElementSource_ = ElementOutput_
>
class LinearCombinationTensorBroadcast {
public:

  using ElementOutput = ElementOutput_;
  using ElementAccumulator = ElementAccumulator_;
  using ElementCompute = ElementCompute_;
  using ElementScalar = ElementCompute;
  using ElementBias = ElementBias_;
  using ElementC = ElementSource_;
  using ElementD = ElementOutput_;
  using ElementScalingFactor = ElementAccumulator_;

  // Unlike the fragment-based functors, the operators here act on single
  // compute-type elements (see kCount = 1 below).
  using UnaryOp = UnaryOp_<ElementCompute>;
  using BinaryOp0 = BinaryOp0_<ElementCompute>;
  using BinaryOp1 = BinaryOp1_<ElementCompute>;
  using ActivationFunctor = ActivationFunctor_<ElementCompute>;

  // Scalar (one element per invocation) epilogue.
  static constexpr int kCount = 1;
  static constexpr ScaleType::Kind kScale = Scale;

  using FragmentOutput = Array<ElementOutput, kCount>;
  using FragmentAccumulator = Array<ElementAccumulator, kCount>;
  using ComputeFragment = Array<ElementCompute, kCount>;
  using FragmentBias = Array<ElementBias, kCount>;

  static constexpr FloatRoundStyle kRound = Round;
  using NoOpType = detail::NoOp<ElementCompute>;
  // Compile-time flags used to elide unused stages in operator() below.
  static constexpr bool IsBinaryOp0Enabled = !cute::is_same_v<BinaryOp0, NoOpType>;
  static constexpr bool IsBinaryOp1Enabled = !cute::is_same_v<BinaryOp1, NoOpType>;
  // Identity is treated the same as NoOp for the final unary stage.
  static constexpr bool IsUnaryOpEnabled = !cute::is_same_v<UnaryOp, NoOpType> && !cute::is_same_v<UnaryOp, Identity<ElementCompute>>;

  /// Host-constructable parameters structure
  struct Params {

    ElementCompute alpha{};                    ///< scales accumulators
    ElementCompute beta{};                     ///< scales source tensor
    ElementCompute const* alpha_ptr = nullptr; ///< pointer to accumulator scalar - if not null, loads it from memory
    ElementCompute const* beta_ptr = nullptr;  ///< pointer to source scalar - if not null, loads it from memory

    //
    // Methods
    //
    Params() = default;

    CUTLASS_HOST_DEVICE
    Params(ElementCompute const* alpha_ptr, ElementCompute const* beta_ptr)
        : alpha_ptr(alpha_ptr),
          beta_ptr(beta_ptr) {}

    CUTLASS_HOST_DEVICE
    Params(ElementCompute const* alpha_ptr)
        : alpha_ptr(alpha_ptr) {}

    CUTLASS_HOST_DEVICE
    Params(ElementCompute alpha,
           ElementCompute beta)
        : alpha(alpha),
          beta(beta) {}
  };

private:
  //
  // Data members
  //

  ElementCompute alpha_;
  ElementCompute beta_;

public:

  /// Constructs the function object, possibly loading from pointers in host memory
  CUTLASS_HOST_DEVICE
  LinearCombinationTensorBroadcast(Params const& params)
      : alpha_(params.alpha_ptr ? *params.alpha_ptr : params.alpha),
        beta_(params.beta_ptr ? *params.beta_ptr : params.beta) {}

  /// Returns true if source 0 is needed
  CUTLASS_HOST_DEVICE
  bool is_source0_needed() const {
    return detail::is_binary_op_source_needed<BinaryOp0, ElementCompute, Scale>(beta_);
  }

  /// Returns true if source 1 is needed
  CUTLASS_HOST_DEVICE
  bool is_source1_needed() const {
    return detail::is_binary_op_source_needed<BinaryOp1, ElementCompute, Scale>(beta_);
  }

  //
  // Specialization for scalar
  //
  /// Computes D = UnaryOp(BinaryOp1(BinaryOp0(act(alpha * accumulator + bias), beta * source0), beta * source1))
  /// for a single element; disabled stages are compiled out via if constexpr.
  CUTLASS_HOST_DEVICE
  ElementD operator()(ElementAccumulator const accumulator, ElementC const source0, ElementC source1, ElementBias const bias) {
    // Convert everything to Compute type, do compute, and then store to output type
    NumericConverter<ElementCompute, ElementAccumulator, Round> accumulator_converter;
    NumericConverter<ElementCompute, ElementBias, Round> bias_converter;
    NumericConverter<ElementCompute, ElementC, Round> source_converter;
    NumericConverter<ElementD, ElementCompute, Round> destination_converter;

    ActivationFunctor act;
    multiplies<ElementCompute> mul;
    multiply_add<ElementCompute> madd;

    ElementCompute intermediate = accumulator_converter(accumulator);
    intermediate = madd(alpha_, intermediate, bias_converter(bias));
    intermediate = act(intermediate);

    // Apply BinaryOp0, if needed
    if constexpr (IsBinaryOp0Enabled) {
      BinaryOp0 bin0;
      ElementCompute converted_source = source_converter(source0);
      intermediate = bin0(intermediate, mul(beta_, converted_source));
    }

    // Apply BinaryOp1, if needed
    if constexpr (IsBinaryOp1Enabled) {
      BinaryOp1 bin1;
      ElementCompute converted_source = source_converter(source1);
      intermediate = bin1(intermediate, mul(beta_, converted_source));
    }

    // Apply UnaryOp, if needed
    if constexpr (IsUnaryOpEnabled) {
      UnaryOp unary;
      intermediate = unary(intermediate);
    }

    return destination_converter(intermediate);
  }
};
246
+
247
+ /////////////////////////////////////////////////////////////////////////////////////////////////
248
+
249
+ } // namespace thread
250
+ } // namespace epilogue
251
+ } // namespace cutlass
252
+
253
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/linear_combination_with_elementwise.h ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+
33
+ \brief Functor performing linear combination with elementwise
34
+ */
35
+
36
+ #pragma once
37
+
38
+ #include "cutlass/half.h"
39
+ #include "cutlass/cutlass.h"
40
+ #include "cutlass/numeric_types.h"
41
+ #include "cutlass/array.h"
42
+ #include "cutlass/constants.h"
43
+ #include "cutlass/fast_math.h"
44
+ #include "cutlass/functional.h"
45
+ #include "cutlass/numeric_conversion.h"
46
+ #include "cutlass/epilogue/thread/activation.h"
47
+
48
+ /////////////////////////////////////////////////////////////////////////////////////////////////
49
+
50
+ namespace cutlass {
51
+ namespace epilogue {
52
+ namespace thread {
53
+
54
+ /////////////////////////////////////////////////////////////////////////////////////////////////
55
+
56
+ /// Applies a linear combination operator to an array of elements.
57
+ ///
58
+ /// D = alpha * accumulator + beta * source + uniform
59
+ ///
60
+ template <
61
+ typename ElementCompute_, ///< Data type returned by this functor
62
+ typename ElementAccumulator_, ///< Data type of accumulators
63
+ typename ElementSource_, ///< Data type of source tensor
64
+ typename ElementTensor_, ///< Data type of additional tensor
65
+ int Count, ///< Number of elements computed per operation
66
+ ///< Usually it is 128/sizeof_bits<ElementOutput_>,
67
+ ///< but we use 64 or 32 sometimes when there are not enough data to store
68
+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
69
+ >
70
+ class LinearCombinationWithElementwise {
71
+ public:
72
+
73
+ using ElementOutput = ElementSource_;
74
+ using ElementCompute = ElementCompute_;
75
+ using ElementAccumulator = ElementAccumulator_;
76
+ using ElementSource = ElementSource_;
77
+ using ElementTensor = ElementTensor_;
78
+
79
+ static bool const kIsHeavy = true;
80
+
81
+ static int const kCount = Count;
82
+
83
+ using FragmentCompute = Array<ElementCompute, kCount>;
84
+ using FragmentAccumulator = Array<ElementAccumulator, kCount>;
85
+ using FragmentSource = Array<ElementSource, kCount>;
86
+ using FragmentTensor = Array<ElementTensor, kCount>;
87
+
88
+ static FloatRoundStyle const kRound = Round;
89
+
90
+ /// Host-constructable parameters structure
91
+ struct Params {
92
+
93
+ ElementCompute alpha; ///< scales accumulators
94
+ ElementCompute beta; ///< scales source tensor
95
+ ElementCompute threshold; ///< minimum value that is output
96
+ ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
97
+ ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
98
+ //
99
+ // Methods
100
+ //
101
+
102
+ CUTLASS_HOST_DEVICE
103
+ Params():
104
+ alpha(ElementCompute(1)),
105
+ beta(ElementCompute(0)),
106
+ threshold(ElementCompute(0)),
107
+ alpha_ptr(nullptr),
108
+ beta_ptr(nullptr) { }
109
+
110
+ CUTLASS_HOST_DEVICE
111
+ Params(
112
+ ElementCompute alpha,
113
+ ElementCompute beta,
114
+ ElementCompute threshold = ElementCompute(0)
115
+ ): alpha(alpha), beta(beta), threshold(threshold), alpha_ptr(nullptr), beta_ptr(nullptr) {
116
+
117
+ }
118
+
119
+ CUTLASS_HOST_DEVICE
120
+ Params(
121
+ ElementCompute const *alpha_ptr,
122
+ ElementCompute const *beta_ptr,
123
+ ElementCompute threshold = ElementCompute(0)
124
+ ): alpha(0), beta(0), threshold(threshold), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
125
+
126
+ }
127
+ };
128
+
129
+ private:
130
+
131
+ //
132
+ // Data members
133
+ //
134
+
135
+ ElementCompute alpha_;
136
+ ElementCompute beta_;
137
+ ElementCompute threshold_;
138
+ bool participates_in_reduction_;
139
+
140
+ public:
141
+
142
+ /// Constructs the function object, possibly loading from pointers in host memory
143
+ CUTLASS_HOST_DEVICE
144
+ LinearCombinationWithElementwise(Params const &params) {
145
+
146
+ alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
147
+ beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
148
+ threshold_ = params.threshold;
149
+ participates_in_reduction_ = true;
150
+ }
151
+
152
+ /// Returns true if source is needed
153
+ CUTLASS_HOST_DEVICE
154
+ bool is_source_needed() const {
155
+ return beta_ != ElementCompute(0);
156
+ }
157
+
158
+ /// Returns true if the threadblock computes the reduction
159
+ CUTLASS_HOST_DEVICE
160
+ bool participates_in_reduction() const {
161
+ return participates_in_reduction_;
162
+ }
163
+
164
+ /// Functionally required for serial reduction in the epilogue
165
+ CUTLASS_HOST_DEVICE
166
+ void set_k_partition(int k_partition, int k_partition_count) {
167
+ if (k_partition) {
168
+ beta_ = ElementCompute(1);
169
+ }
170
+
171
+ if (k_partition != k_partition_count - 1) {
172
+ // set to NaN to make ReLU no-op for all except last k partitions
173
+ int64_t allones = -1;
174
+ threshold_ = reinterpret_cast<ElementCompute const &>(allones);
175
+ // Avoid computing the reduction if this isn't the final Split-K slice
176
+ participates_in_reduction_ = false;
177
+ }
178
+ }
179
+
180
+ /// Computes linear scaling: D = alpha * accumulator + beta * source
181
+ CUTLASS_HOST_DEVICE
182
+ FragmentCompute operator()(
183
+ FragmentAccumulator const &accumulator,
184
+ FragmentSource const &source,
185
+ FragmentTensor const &tensor) const {
186
+
187
+ // Convert source to interal compute numeric type
188
+ NumericArrayConverter<ElementCompute, ElementSource, kCount, Round> source_converter;
189
+ NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
190
+
191
+ FragmentCompute converted_source = source_converter(source);
192
+ FragmentCompute converted_accumulator = accumulator_converter(accumulator);
193
+
194
+ // Perform binary operations
195
+ FragmentCompute intermediate;
196
+
197
+ multiplies<FragmentCompute> mul_add_source;
198
+ multiply_add<FragmentCompute> mul_add_accumulator;
199
+
200
+ intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform
201
+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
202
+
203
+ return intermediate;
204
+ }
205
+
206
+ /// Computes linear scaling: D = alpha * accumulator
207
+ CUTLASS_HOST_DEVICE
208
+ FragmentCompute operator()(
209
+ FragmentAccumulator const &accumulator,
210
+ FragmentTensor const &tensor) const {
211
+
212
+ // Convert source to interal compute numeric type
213
+ NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
214
+
215
+ FragmentCompute converted_accumulator = accumulator_converter(accumulator);
216
+
217
+ // Perform binary operations
218
+ FragmentCompute intermediate;
219
+
220
+ multiplies<FragmentCompute> mul_accumulator;
221
+
222
+ intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
223
+
224
+ return intermediate;
225
+ }
226
+ };
227
+
228
+ /////////////////////////////////////////////////////////////////////////////////////////////////
229
+
230
+ } // namespace thread
231
+ } // namespace epilogue
232
+ } // namespace cutlass
233
+
234
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/reduction_op.h ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Functor performing reduction operations used by epilogues.
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "cutlass/cutlass.h"
38
+ #include "cutlass/numeric_types.h"
39
+ #include "cutlass/array.h"
40
+ #include "cutlass/functional.h"
41
+ #include "cutlass/numeric_conversion.h"
42
+
43
+ /////////////////////////////////////////////////////////////////////////////////////////////////
44
+
45
+ namespace cutlass {
46
+ namespace epilogue {
47
+ namespace thread {
48
+
49
+ /////////////////////////////////////////////////////////////////////////////////////////////////
50
+
51
+ /// Applies a reduction sum to an array of elements.
52
+ ///
53
+ ///
54
+ template <
55
+ typename Element_, ///< Data type used to load and store tensors
56
+ int Count ///< Number of elements computed per operation
57
+ >
58
+ class ReductionOpPlus {
59
+ public:
60
+
61
+ using Element = Element_;
62
+ static int const kCount = Count;
63
+
64
+ using Fragment = Array<Element, kCount>;
65
+ using Operator = plus<Fragment>;
66
+
67
+ /// Host-constructable parameters structure
68
+ struct Params { };
69
+
70
+ private:
71
+
72
+ /// reduction operator
73
+ Operator operator_;
74
+
75
+ public:
76
+
77
+ /// Constructs the function object, possibly loading from pointers in host memory
78
+ CUTLASS_HOST_DEVICE
79
+ ReductionOpPlus(Params const &params) {
80
+
81
+ }
82
+
83
+ /// Computes Compute =>
84
+ CUTLASS_HOST_DEVICE
85
+ Fragment operator()(
86
+ Fragment const &lhs,
87
+ Fragment const &rhs) const {
88
+
89
+ return operator_(lhs, rhs);
90
+ }
91
+ };
92
+
93
+ /////////////////////////////////////////////////////////////////////////////////////////////////
94
+
95
+ } // namespace thread
96
+ } // namespace epilogue
97
+ } // namespace cutlass
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/scale_type.h ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Enum defines the behaviors of the epilogue.
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "cutlass/cutlass.h"
38
+
39
+ /////////////////////////////////////////////////////////////////////////////////////////////////
40
+
41
+ namespace cutlass {
42
+ namespace epilogue {
43
+ namespace thread {
44
+
45
+ /////////////////////////////////////////////////////////////////////////////////////////////////
46
+
47
+ /// Specifies internal data type for computation
48
+ /// Note :
49
+ /// 1. Scalar means alpha/beta is a single value from host(constant param) or device memory.
50
+ /// 2. Vector means alpha/beta is a vector always from device memory.
51
+ struct ScaleType {
52
+ enum Kind {
53
+ Default, // D = scalar_alpha x Acc + scalar_beta x C
54
+ NoBetaScaling, // D = scalar_alpha x Acc + C
55
+ OnlyAlphaScaling, // D = scalar_alpha x Acc
56
+ PerChannelScaling, // D = vector_alpha x Acc + vector_beta x C
57
+ OnlyAlphaPerChannelScaling, // D = vector_alpha x Acc
58
+ Nothing // D = Acc
59
+ };
60
+ };
61
+
62
+ /////////////////////////////////////////////////////////////////////////////////////////////////
63
+
64
+ } // namespace thread
65
+ } // namespace epilogue
66
+ } // namespace cutlass
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op.h ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Epilogue for threadblock scoped complex GEMMs using Tensor Ops.
33
+
34
+ The epilogue rearranges the result of a matrix product through shared memory to match canonical
35
+ tensor layouts in global memory. Epilogues support conversion and reduction operations.
36
+
37
+ */
38
+
39
+ #pragma once
40
+
41
+ #include "cutlass/cutlass.h"
42
+ #include "cutlass/numeric_types.h"
43
+ #include "cutlass/array.h"
44
+
45
+ #include "cutlass/gemm/gemm.h"
46
+
47
+ #include "cutlass/epilogue/thread/linear_combination.h"
48
+ #include "cutlass/epilogue/thread/linear_combination_relu.h"
49
+ #include "cutlass/epilogue/thread/linear_combination_gelu.h"
50
+ #include "cutlass/epilogue/thread/linear_combination_sigmoid.h"
51
+ #include "cutlass/epilogue/thread/linear_combination_planar_complex.h"
52
+
53
+ #include "cutlass/epilogue/thread/conversion_op.h"
54
+ #include "cutlass/epilogue/thread/reduction_op.h"
55
+
56
+ #include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h"
57
+
58
+ #include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h"
59
+ #include "cutlass/epilogue/warp/fragment_iterator_gaussian_complex_tensor_op.h"
60
+ #include "cutlass/epilogue/warp/tile_iterator_tensor_op.h"
61
+ #include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h"
62
+ #include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
63
+ #include "cutlass/epilogue/threadblock/shared_load_iterator.h"
64
+
65
+ #include "cutlass/epilogue/threadblock/epilogue.h"
66
+
67
+ ////////////////////////////////////////////////////////////////////////////////
68
+
69
+ namespace cutlass {
70
+ namespace epilogue {
71
+ namespace threadblock {
72
+
73
+ /////////////////////////////////////////////////////////////////////////////////////////////////
74
+ /// Specialization and defines sensible defaults for epilogues for complex*complex case
75
+ // 4 real-valued mma operations (Complex)
76
+ // A = (ar + j ai), B (br +j bi), D = AB
77
+ // D = dr + j di = (ar*br - ai*bi) + j (ar*bi + ai*br)
78
+ /////////////////////////////////////////////////////////////////////////////////////////////////
79
+ template <
80
+ /// Epilogue Shape
81
+ typename Shape_,
82
+ /// Warp-level mma operator
83
+ typename WarpMmaTensorOp_,
84
+ /// Number of k partitions
85
+ int PartitionsK,
86
+ /// Epilogue output operator
87
+ typename OutputOp_,
88
+ /// Elements accessed by inner-most loop of AccumulatorFragmentIterator::load()
89
+ int ElementsPerAccess,
90
+ /// Multiply-add operator
91
+ /// Selects between (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex)
92
+ typename Operator_ = arch::OpMultiplyAddComplex
93
+ >
94
+ struct DefaultEpilogueComplexTensorOp {
95
+
96
+ using Shape = Shape_;
97
+ using WarpMmaTensorOp = WarpMmaTensorOp_;
98
+ static int const kPartitionsK = PartitionsK;
99
+ using OutputOp = OutputOp_;
100
+ static int const kElementsPerAccess = ElementsPerAccess;
101
+ using Operator = Operator_;
102
+
103
+ using ElementOutput = typename OutputOp::ElementOutput;
104
+ using LayoutC = typename WarpMmaTensorOp::LayoutC;
105
+ using ElementAccumulator = typename WarpMmaTensorOp::ElementC;
106
+
107
+ //
108
+ // Thread map
109
+ //
110
+
111
+ using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp<
112
+ Shape,
113
+ typename WarpMmaTensorOp::Shape,
114
+ kPartitionsK,
115
+ ElementOutput,
116
+ kElementsPerAccess
117
+ >::Type;
118
+
119
+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
120
+ OutputTileThreadMap,
121
+ ElementOutput
122
+ >;
123
+
124
+ using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorComplexTensorOp<
125
+ typename WarpMmaTensorOp::Shape,
126
+ typename WarpMmaTensorOp::Policy::Operator::Shape,
127
+ typename WarpMmaTensorOp::Policy::Operator::ElementC,
128
+ typename WarpMmaTensorOp::Policy::Operator::FragmentC,
129
+ LayoutC
130
+ >;
131
+
132
+ using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp<
133
+ typename WarpMmaTensorOp::Shape,
134
+ typename WarpMmaTensorOp::Policy::Operator::Shape,
135
+ ElementAccumulator,
136
+ LayoutC
137
+ >;
138
+
139
+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<
140
+ typename OutputTileThreadMap::CompactedThreadMap,
141
+ ElementAccumulator
142
+ >;
143
+
144
+ /// Hard-coded padding elements added
145
+ using Padding = cutlass::MatrixShape<0, 0>;
146
+
147
+ //
148
+ // Define the epilogue
149
+ //
150
+ using Epilogue = cutlass::epilogue::threadblock::Epilogue<
151
+ Shape,
152
+ WarpMmaTensorOp,
153
+ kPartitionsK,
154
+ OutputTileIterator,
155
+ AccumulatorFragmentIterator,
156
+ WarpTileIterator,
157
+ SharedLoadIterator,
158
+ OutputOp,
159
+ Padding
160
+ >;
161
+ };
162
+
163
+ /////////////////////////////////////////////////////////////////////////////////////////////////
164
+ /// Partial specialization and defines sensible defaults for epilogues for complex*complex case
165
+ // 3 real-valued mma operations (Gaussian Complex)
166
+ // A = (ar + j ai), B = (br +j bi), D = AB
167
+ // P1 = (ar + ai) * br, P2 = - ar * (br - bi), P3 = ai * (br + bi)
168
+ // D = dr + j di = (P1 - P3) + j (P1 + P2)
169
+ /////////////////////////////////////////////////////////////////////////////////////////////////
170
+ template <
171
+ typename Shape_,
172
+ typename WarpMmaTensorOp_,
173
+ int PartitionsK,
174
+ typename OutputOp_,
175
+ int ElementsPerAccess
176
+ >
177
+ struct DefaultEpilogueComplexTensorOp <Shape_, WarpMmaTensorOp_, PartitionsK,
178
+ OutputOp_, ElementsPerAccess,
179
+ arch::OpMultiplyAddGaussianComplex
180
+ > {
181
+
182
+ using Shape = Shape_;
183
+ using WarpMmaTensorOp = WarpMmaTensorOp_;
184
+ static int const kPartitionsK = PartitionsK;
185
+ using OutputOp = OutputOp_;
186
+ static int const kElementsPerAccess = ElementsPerAccess;
187
+ using Operator = arch::OpMultiplyAddGaussianComplex;
188
+
189
+ using ElementOutput = typename OutputOp::ElementOutput;
190
+ using LayoutC = typename WarpMmaTensorOp::LayoutC;
191
+ using ElementAccumulator = typename WarpMmaTensorOp::ElementC;
192
+
193
+ //
194
+ // Thread map
195
+ //
196
+
197
+ using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp<
198
+ Shape,
199
+ typename WarpMmaTensorOp::Shape,
200
+ kPartitionsK,
201
+ ElementOutput,
202
+ kElementsPerAccess
203
+ >::Type;
204
+
205
+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
206
+ OutputTileThreadMap,
207
+ ElementOutput
208
+ >;
209
+
210
+ using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorGaussianComplexTensorOp<
211
+ typename WarpMmaTensorOp::Shape,
212
+ typename WarpMmaTensorOp::Policy::Operator::Shape,
213
+ typename WarpMmaTensorOp::Policy::Operator::ElementC,
214
+ typename WarpMmaTensorOp::Policy::Operator::FragmentC,
215
+ LayoutC
216
+ >;
217
+
218
+ using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp<
219
+ typename WarpMmaTensorOp::Shape,
220
+ typename WarpMmaTensorOp::Policy::Operator::Shape,
221
+ ElementAccumulator,
222
+ LayoutC
223
+ >;
224
+
225
+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<
226
+ typename OutputTileThreadMap::CompactedThreadMap,
227
+ ElementAccumulator
228
+ >;
229
+
230
+ /// Hard-coded padding elements added
231
+ using Padding = cutlass::MatrixShape<0, 0>;
232
+
233
+ //
234
+ // Define the epilogue
235
+ //
236
+ using Epilogue = cutlass::epilogue::threadblock::Epilogue<
237
+ Shape,
238
+ WarpMmaTensorOp,
239
+ kPartitionsK,
240
+ OutputTileIterator,
241
+ AccumulatorFragmentIterator,
242
+ WarpTileIterator,
243
+ SharedLoadIterator,
244
+ OutputOp,
245
+ Padding
246
+ >;
247
+ };
248
+
249
+ ////////////////////////////////////////////////////////////////////////////////
250
+
251
+ } // namespace threadblock
252
+ } // namespace epilogue
253
+ } // namespace cutlass
254
+
255
+ ////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op_blas3.h ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Epilogue for threadblock scoped complex GEMMs using Tensor Ops.
33
+
34
+ The epilogue rearranges the result of a matrix product through shared memory to match canonical
35
+ tensor layouts in global memory. Epilogues support conversion and reduction operations.
36
+
37
+
38
+ */
39
+
40
+ #pragma once
41
+
42
+ #include "cutlass/cutlass.h"
43
+ #include "cutlass/numeric_types.h"
44
+ #include "cutlass/array.h"
45
+
46
+ #include "cutlass/gemm/gemm.h"
47
+
48
+ #include "cutlass/epilogue/thread/linear_combination.h"
49
+ #include "cutlass/epilogue/thread/linear_combination_relu.h"
50
+ #include "cutlass/epilogue/thread/linear_combination_gelu.h"
51
+ #include "cutlass/epilogue/thread/linear_combination_sigmoid.h"
52
+ #include "cutlass/epilogue/thread/linear_combination_planar_complex.h"
53
+
54
+ #include "cutlass/epilogue/thread/conversion_op.h"
55
+ #include "cutlass/epilogue/thread/reduction_op.h"
56
+
57
+ #include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h"
58
+
59
+ #include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h"
60
+ #include "cutlass/epilogue/warp/fragment_iterator_gaussian_complex_tensor_op.h"
61
+ #include "cutlass/epilogue/warp/tile_iterator_tensor_op.h"
62
+ #include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h"
63
+ #include "cutlass/epilogue/threadblock/predicated_tile_iterator_blas3.h"
64
+ #include "cutlass/epilogue/threadblock/shared_load_iterator.h"
65
+
66
+ #include "cutlass/epilogue/threadblock/epilogue.h"
67
+
68
+ ////////////////////////////////////////////////////////////////////////////////
69
+
70
+ namespace cutlass {
71
+ namespace epilogue {
72
+ namespace threadblock {
73
+
74
+ /////////////////////////////////////////////////////////////////////////////////////////////////
75
+ /// Specialization and defines sensible defaults for epilogues for complex*complex case
76
+ // 4 real-valued mma operations (Complex)
77
+ // A = (ar + j ai), B (br +j bi), D = AB
78
+ // D = dr + j di = (ar*br - ai*bi) + j (ar*bi + ai*br)
79
+ /////////////////////////////////////////////////////////////////////////////////////////////////
80
+ template <
81
+ /// Epilogue Shape
82
+ typename Shape_,
83
+ /// Warp-level mma operator
84
+ typename WarpMmaTensorOp_,
85
+ /// Number of k partitions
86
+ int PartitionsK,
87
+ /// Epilogue output operator
88
+ typename OutputOp_,
89
+ /// Elements accessed by inner-most loop of AccumulatorFragmentIterator::load()
90
+ int ElementsPerAccess,
91
+ /// Multiply-add operator
92
+ /// Selects between (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex)
93
+ typename Operator_ = arch::OpMultiplyAddComplex,
94
+ /// Is for a symmetric kernel
95
+ BlasMode BlasMode_ = BlasMode::kGemm
96
+ >
97
+ struct DefaultEpilogueComplexTensorOpBlas3 {
98
+
99
+ using Shape = Shape_;
100
+ using WarpMmaTensorOp = WarpMmaTensorOp_;
101
+ static int const kPartitionsK = PartitionsK;
102
+ using OutputOp = OutputOp_;
103
+ static int const kElementsPerAccess = ElementsPerAccess;
104
+ using Operator = Operator_;
105
+ static BlasMode const kBlasMode = BlasMode_;
106
+
107
+ using ElementOutput = typename OutputOp::ElementOutput;
108
+ using LayoutC = typename WarpMmaTensorOp::LayoutC;
109
+ using ElementAccumulator = typename WarpMmaTensorOp::ElementC;
110
+
111
+ //
112
+ // Thread map
113
+ //
114
+
115
+ using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp<
116
+ Shape,
117
+ typename WarpMmaTensorOp::Shape,
118
+ kPartitionsK,
119
+ ElementOutput,
120
+ kElementsPerAccess
121
+ >::Type;
122
+
123
+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorBlas3<
124
+ OutputTileThreadMap,
125
+ ElementOutput
126
+ , kBlasMode
127
+ >;
128
+
129
+ using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorComplexTensorOp<
130
+ typename WarpMmaTensorOp::Shape,
131
+ typename WarpMmaTensorOp::Policy::Operator::Shape,
132
+ typename WarpMmaTensorOp::Policy::Operator::ElementC,
133
+ typename WarpMmaTensorOp::Policy::Operator::FragmentC,
134
+ LayoutC
135
+ >;
136
+
137
+ using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp<
138
+ typename WarpMmaTensorOp::Shape,
139
+ typename WarpMmaTensorOp::Policy::Operator::Shape,
140
+ ElementAccumulator,
141
+ LayoutC
142
+ >;
143
+
144
+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<
145
+ typename OutputTileThreadMap::CompactedThreadMap,
146
+ ElementAccumulator
147
+ >;
148
+
149
+ /// Hard-coded padding elements added
150
+ using Padding = cutlass::MatrixShape<0, 0>;
151
+
152
+ //
153
+ // Define the epilogue
154
+ //
155
+ using Epilogue = cutlass::epilogue::threadblock::Epilogue<
156
+ Shape,
157
+ WarpMmaTensorOp,
158
+ kPartitionsK,
159
+ OutputTileIterator,
160
+ AccumulatorFragmentIterator,
161
+ WarpTileIterator,
162
+ SharedLoadIterator,
163
+ OutputOp,
164
+ Padding
165
+ >;
166
+ };
167
+
168
+ /////////////////////////////////////////////////////////////////////////////////////////////////
169
+ /// Partial specialization and defines sensible defaults for epilogues for complex*complex case
170
+ // 3 real-valued mma operations (Gaussian Complex)
171
+ // A = (ar + j ai), B = (br +j bi), D = AB
172
+ // P1 = (ar + ai) * br, P2 = - ar * (br - bi), P3 = ai * (br + bi)
173
+ // D = dr + j di = (P1 - P3) + j (P1 + P2)
174
+ /////////////////////////////////////////////////////////////////////////////////////////////////
175
+ template <
176
+ typename Shape_,
177
+ typename WarpMmaTensorOp_,
178
+ int PartitionsK,
179
+ typename OutputOp_,
180
+ int ElementsPerAccess,
181
+ BlasMode BlasMode_
182
+ >
183
+ struct DefaultEpilogueComplexTensorOpBlas3 <Shape_, WarpMmaTensorOp_, PartitionsK,
184
+ OutputOp_, ElementsPerAccess,
185
+ arch::OpMultiplyAddGaussianComplex
186
+ , BlasMode_
187
+ > {
188
+
189
+ using Shape = Shape_;
190
+ using WarpMmaTensorOp = WarpMmaTensorOp_;
191
+ static int const kPartitionsK = PartitionsK;
192
+ using OutputOp = OutputOp_;
193
+ static int const kElementsPerAccess = ElementsPerAccess;
194
+ using Operator = arch::OpMultiplyAddGaussianComplex;
195
+ static BlasMode const kBlasMode = BlasMode_;
196
+
197
+ using ElementOutput = typename OutputOp::ElementOutput;
198
+ using LayoutC = typename WarpMmaTensorOp::LayoutC;
199
+ using ElementAccumulator = typename WarpMmaTensorOp::ElementC;
200
+
201
+ //
202
+ // Thread map
203
+ //
204
+
205
+ using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp<
206
+ Shape,
207
+ typename WarpMmaTensorOp::Shape,
208
+ kPartitionsK,
209
+ ElementOutput,
210
+ kElementsPerAccess
211
+ >::Type;
212
+
213
+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorBlas3<
214
+ OutputTileThreadMap,
215
+ ElementOutput,
216
+ kBlasMode
217
+ >;
218
+
219
+ using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorGaussianComplexTensorOp<
220
+ typename WarpMmaTensorOp::Shape,
221
+ typename WarpMmaTensorOp::Policy::Operator::Shape,
222
+ typename WarpMmaTensorOp::Policy::Operator::ElementC,
223
+ typename WarpMmaTensorOp::Policy::Operator::FragmentC,
224
+ LayoutC
225
+ >;
226
+
227
+ using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp<
228
+ typename WarpMmaTensorOp::Shape,
229
+ typename WarpMmaTensorOp::Policy::Operator::Shape,
230
+ ElementAccumulator,
231
+ LayoutC
232
+ >;
233
+
234
+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<
235
+ typename OutputTileThreadMap::CompactedThreadMap,
236
+ ElementAccumulator
237
+ >;
238
+
239
+ /// Hard-coded padding elements added
240
+ using Padding = cutlass::MatrixShape<0, 0>;
241
+
242
+ //
243
+ // Define the epilogue
244
+ //
245
+ using Epilogue = cutlass::epilogue::threadblock::Epilogue<
246
+ Shape,
247
+ WarpMmaTensorOp,
248
+ kPartitionsK,
249
+ OutputTileIterator,
250
+ AccumulatorFragmentIterator,
251
+ WarpTileIterator,
252
+ SharedLoadIterator,
253
+ OutputOp,
254
+ Padding
255
+ >;
256
+ };
257
+
258
+ ////////////////////////////////////////////////////////////////////////////////
259
+
260
+ } // namespace threadblock
261
+ } // namespace epilogue
262
+ } // namespace cutlass
263
+
264
+ ////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_direct_store.h ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Direct store epilogue
33
+ */
34
+
35
+ #pragma once
36
+
37
+ ////////////////////////////////////////////////////////////////////////////////
38
+
39
+ #include "cutlass/epilogue/threadblock/epilogue_direct_store.h"
40
+ #include "cutlass/epilogue/threadblock/direct_store_epilogue_iterator.h"
41
+
42
+ ////////////////////////////////////////////////////////////////////////////////
43
+
44
+ namespace cutlass {
45
+ namespace epilogue {
46
+ namespace threadblock {
47
+
48
+ ////////////////////////////////////////////////////////////////////////////////
49
+
50
+ /// Given a properly constructed epilogue, returns a direct store epilogue
51
+ template <typename EpilogueTensorOp>
52
+ struct DefaultEpilogueDirectStore {
53
+
54
+ using OutputTileIterator = DirectStoreEpilogueIterator<typename EpilogueTensorOp::OutputTileIterator::Element>;
55
+
56
+ using Epilogue = EpilogueDirectStore<
57
+ typename EpilogueTensorOp::Shape,
58
+ typename EpilogueTensorOp::WarpMmaOperator,
59
+ EpilogueTensorOp::kPartitionsK,
60
+ OutputTileIterator,
61
+ typename EpilogueTensorOp::AccumulatorFragmentIterator,
62
+ typename EpilogueTensorOp::WarpTileIterator,
63
+ typename EpilogueTensorOp::SharedLoadIterator,
64
+ typename EpilogueTensorOp::OutputOp
65
+ >;
66
+ };
67
+
68
+ ////////////////////////////////////////////////////////////////////////////////
69
+
70
+ } // namespace threadblock
71
+ } // namespace epilogue
72
+ } // namespace cutlass
73
+
74
+ ////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_planar_complex.h ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Constructs a default epilogue for planar complex outputs.
33
+
34
+ This template reuses components for real-valued epilogues and applies them to planar complex
35
+ output matrices.
36
+
37
+ */
38
+
39
+ #pragma once
40
+
41
+ #include "cutlass/cutlass.h"
42
+ #include "cutlass/numeric_types.h"
43
+ #include "cutlass/array.h"
44
+ #include "cutlass/array_planar_complex.h"
45
+
46
+ #include "cutlass/arch/arch.h"
47
+
48
+ #include "cutlass/epilogue/thread/linear_combination_planar_complex.h"
49
+ #include "cutlass/epilogue/threadblock/default_epilogue_simt.h"
50
+ #include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
51
+ #include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
52
+
53
+ #include "cutlass/epilogue/threadblock/epilogue_planar_complex.h"
54
+
55
+ /////////////////////////////////////////////////////////////////////////////////////////////////
56
+
57
+ namespace cutlass {
58
+ namespace epilogue {
59
+ namespace threadblock {
60
+
61
+ /////////////////////////////////////////////////////////////////////////////////////////////////
62
+
63
+ /// Defines sensible defaults for epilogues.
64
+ template <
65
+ typename ThreadblockShape_,
66
+ typename WarpMma_,
67
+ typename OpcodeClass_,
68
+ typename ArchTag_,
69
+ int PartitionsK,
70
+ typename OutputOp_,
71
+ int ElementsPerAccess
72
+ >
73
+ struct DefaultEpiloguePlanarComplex;
74
+
75
+ /////////////////////////////////////////////////////////////////////////////////////////////////
76
+
77
+ /// Defines sensible defaults for epilogues.
78
+ template <
79
+ typename ThreadblockShape_,
80
+ typename WarpMmaOperator_,
81
+ int PartitionsK,
82
+ typename OutputOp_,
83
+ int ElementsPerAccess
84
+ >
85
+ struct DefaultEpiloguePlanarComplex<
86
+ ThreadblockShape_,
87
+ WarpMmaOperator_,
88
+ arch::OpClassTensorOp,
89
+ arch::Sm70,
90
+ PartitionsK,
91
+ OutputOp_,
92
+ ElementsPerAccess> {
93
+
94
+ using RealEpilogue = DefaultEpilogueVoltaTensorOp<
95
+ ThreadblockShape_,
96
+ WarpMmaOperator_,
97
+ PartitionsK,
98
+ OutputOp_,
99
+ ElementsPerAccess
100
+ >;
101
+
102
+ using Epilogue = EpiloguePlanarComplex<
103
+ ThreadblockShape_,
104
+ WarpMmaOperator_,
105
+ PartitionsK,
106
+ typename RealEpilogue::OutputTileIterator,
107
+ typename RealEpilogue::AccumulatorFragmentIterator,
108
+ typename RealEpilogue::WarpTileIterator,
109
+ typename RealEpilogue::SharedLoadIterator,
110
+ OutputOp_,
111
+ typename RealEpilogue::Padding
112
+ >;
113
+ };
114
+
115
+ /////////////////////////////////////////////////////////////////////////////////////////////////
116
+
117
+ /// Defines sensible defaults for epilogues.
118
+ template <
119
+ typename ThreadblockShape_,
120
+ typename WarpMmaOperator_,
121
+ int PartitionsK,
122
+ typename OutputOp_,
123
+ int ElementsPerAccess
124
+ >
125
+ struct DefaultEpiloguePlanarComplex<
126
+ ThreadblockShape_,
127
+ WarpMmaOperator_,
128
+ arch::OpClassTensorOp,
129
+ arch::Sm75,
130
+ PartitionsK,
131
+ OutputOp_,
132
+ ElementsPerAccess> {
133
+
134
+ using RealEpilogue = DefaultEpilogueTensorOp<
135
+ ThreadblockShape_,
136
+ WarpMmaOperator_,
137
+ PartitionsK,
138
+ OutputOp_,
139
+ ElementsPerAccess
140
+ >;
141
+
142
+ using Epilogue = EpiloguePlanarComplex<
143
+ ThreadblockShape_,
144
+ WarpMmaOperator_,
145
+ PartitionsK,
146
+ typename RealEpilogue::OutputTileIterator,
147
+ typename RealEpilogue::AccumulatorFragmentIterator,
148
+ typename RealEpilogue::WarpTileIterator,
149
+ typename RealEpilogue::SharedLoadIterator,
150
+ OutputOp_,
151
+ typename RealEpilogue::Padding
152
+ >;
153
+ };
154
+
155
+ /////////////////////////////////////////////////////////////////////////////////////////////////
156
+
157
+ /// Defines sensible defaults for epilogues.
158
+ template <
159
+ typename ThreadblockShape_,
160
+ typename WarpMmaOperator_,
161
+ int PartitionsK,
162
+ typename OutputOp_,
163
+ int ElementsPerAccess
164
+ >
165
+ struct DefaultEpiloguePlanarComplex<
166
+ ThreadblockShape_,
167
+ WarpMmaOperator_,
168
+ arch::OpClassTensorOp,
169
+ arch::Sm80,
170
+ PartitionsK,
171
+ OutputOp_,
172
+ ElementsPerAccess> {
173
+
174
+ using RealEpilogue = DefaultEpilogueTensorOp<
175
+ ThreadblockShape_,
176
+ WarpMmaOperator_,
177
+ PartitionsK,
178
+ OutputOp_,
179
+ ElementsPerAccess
180
+ >;
181
+
182
+ using Epilogue = EpiloguePlanarComplex<
183
+ ThreadblockShape_,
184
+ WarpMmaOperator_,
185
+ PartitionsK,
186
+ typename RealEpilogue::OutputTileIterator,
187
+ typename RealEpilogue::AccumulatorFragmentIterator,
188
+ typename RealEpilogue::WarpTileIterator,
189
+ typename RealEpilogue::SharedLoadIterator,
190
+ OutputOp_,
191
+ typename RealEpilogue::Padding
192
+ >;
193
+ };
194
+
195
+ /////////////////////////////////////////////////////////////////////////////////////////////////
196
+
197
+ /// Defines sensible defaults for epilogues.
198
+ template <
199
+ typename ThreadblockShape_,
200
+ typename WarpMmaOperator_,
201
+ typename ArchTag_,
202
+ int PartitionsK,
203
+ typename OutputOp_,
204
+ int ElementsPerAccess
205
+ >
206
+ struct DefaultEpiloguePlanarComplex<
207
+ ThreadblockShape_,
208
+ WarpMmaOperator_,
209
+ arch::OpClassSimt,
210
+ ArchTag_,
211
+ PartitionsK,
212
+ OutputOp_,
213
+ ElementsPerAccess> {
214
+
215
+ using RealEpilogue = DefaultEpilogueSimt<
216
+ ThreadblockShape_,
217
+ WarpMmaOperator_,
218
+ OutputOp_,
219
+ ElementsPerAccess
220
+ >;
221
+
222
+ using Epilogue = EpiloguePlanarComplex<
223
+ ThreadblockShape_,
224
+ WarpMmaOperator_,
225
+ PartitionsK,
226
+ typename RealEpilogue::OutputTileIterator,
227
+ typename RealEpilogue::AccumulatorFragmentIterator,
228
+ typename RealEpilogue::WarpTileIterator,
229
+ typename RealEpilogue::SharedLoadIterator,
230
+ OutputOp_,
231
+ typename RealEpilogue::Padding
232
+ >;
233
+ };
234
+
235
+ /////////////////////////////////////////////////////////////////////////////////////////////////
236
+
237
+ } // namespace threadblock
238
+ } // namespace epilogue
239
+ } // namespace cutlass
240
+
241
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_simt.h ADDED
@@ -0,0 +1,443 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Epilogue for threadblock scoped GEMMs using SIMT.
33
+
34
+ The epilogue rearranges the result of a matrix product through shared memory to match canonical
35
+ tensor layouts in global memory. Epilogues support conversion and reduction operations.
36
+
37
+ */
38
+
39
+ #pragma once
40
+
41
+ #include "cutlass/cutlass.h"
42
+ #include "cutlass/numeric_types.h"
43
+ #include "cutlass/array.h"
44
+
45
+ #include "cutlass/arch/mma.h"
46
+
47
+ #include "cutlass/gemm/gemm.h"
48
+ #include "cutlass/gemm/warp/mma.h"
49
+
50
+ #include "cutlass/epilogue/thread/linear_combination.h"
51
+ #include "cutlass/epilogue/thread/linear_combination_clamp.h"
52
+ #include "cutlass/epilogue/thread/linear_combination_relu.h"
53
+ #include "cutlass/epilogue/thread/linear_combination_gelu.h"
54
+ #include "cutlass/epilogue/thread/linear_combination_sigmoid.h"
55
+ #include "cutlass/epilogue/thread/linear_combination_planar_complex.h"
56
+ #include "cutlass/epilogue/thread/conversion_op.h"
57
+ #include "cutlass/epilogue/thread/reduction_op.h"
58
+
59
+ #include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h"
60
+
61
+ #include "cutlass/epilogue/warp/fragment_iterator_simt.h"
62
+ #include "cutlass/epilogue/warp/tile_iterator_simt.h"
63
+ #include "cutlass/epilogue/threadblock/default_thread_map_simt.h"
64
+ #include "cutlass/transform/pitch_linear_thread_map.h"
65
+
66
+ #include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
67
+ #include "cutlass/epilogue/threadblock/predicated_tile_iterator_conv.h"
68
+ #include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h"
69
+ #include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h"
70
+ #include "cutlass/epilogue/threadblock/predicated_tile_iterator_direct_conv.h"
71
+ #include "cutlass/epilogue/threadblock/shared_load_iterator.h"
72
+ #include "cutlass/epilogue/threadblock/shared_load_iterator_pitch_linear.h"
73
+ #include "cutlass/epilogue/threadblock/epilogue.h"
74
+ #include "cutlass/epilogue/threadblock/epilogue_depthwise.h"
75
+
76
+ #include "cutlass/layout/permute.h"
77
+
78
+ /////////////////////////////////////////////////////////////////////////////////////////////////
79
+
80
+ namespace cutlass {
81
+ namespace epilogue {
82
+ namespace threadblock {
83
+
84
+ /////////////////////////////////////////////////////////////////////////////////////////////////
85
+
86
+ /// Defines sensible defaults for epilogues for SimtOps.
87
+ template <
88
+ typename Shape_,
89
+ typename WarpMmaSimt_,
90
+ typename OutputOp_,
91
+ int ElementsPerAccess,
92
+ bool ScatterD = false,
93
+ typename PermuteDLayout = layout::NoPermute,
94
+ conv::StrideSupport StrideSupport = conv::StrideSupport::kUnity,
95
+ int Rank = 4
96
+ >
97
+ struct DefaultEpilogueSimt {
98
+
99
+ using Shape = Shape_;
100
+ using WarpMmaSimt = WarpMmaSimt_;
101
+ using OutputOp = OutputOp_;
102
+ static int const kElementsPerAccess = ElementsPerAccess;
103
+ static const int kPartitionsK = Shape::kK / WarpMmaSimt::Shape::kK;
104
+
105
+ using ElementOutput = typename OutputOp::ElementOutput;
106
+ using LayoutC = typename WarpMmaSimt::LayoutC;
107
+ using ElementAccumulator = typename WarpMmaSimt::ElementC;
108
+ static conv::StrideSupport const kStrideSupport = StrideSupport;
109
+ static int const kRank = Rank;
110
+
111
+ //
112
+ // Thread map
113
+ //
114
+
115
+ using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapSimt<
116
+ Shape,
117
+ typename WarpMmaSimt::Shape,
118
+ typename WarpMmaSimt::Policy,
119
+ kPartitionsK,
120
+ ElementOutput,
121
+ kElementsPerAccess
122
+ >::Type;
123
+
124
+ static bool const UseCUDAStore = platform::is_same<ElementOutput, double>::value;
125
+
126
+ using PackedOutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
127
+ OutputTileThreadMap,
128
+ ElementOutput,
129
+ ScatterD,
130
+ PermuteDLayout,
131
+ UseCUDAStore
132
+ >;
133
+
134
+ using StridedOutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorConv<
135
+ OutputTileThreadMap,
136
+ ElementOutput,
137
+ ScatterD,
138
+ PermuteDLayout,
139
+ UseCUDAStore,
140
+ kRank
141
+ >;
142
+
143
+ using OutputTileIterator = typename platform::conditional<StrideSupport == cutlass::conv::StrideSupport::kUnity,
144
+ PackedOutputTileIterator,
145
+ StridedOutputTileIterator>::type;
146
+
147
+ using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorSimt<
148
+ typename WarpMmaSimt::Shape,
149
+ typename WarpMmaSimt::ThreadMma,
150
+ layout::RowMajor,
151
+ typename WarpMmaSimt::Policy
152
+ >;
153
+
154
+ using WarpTileIterator = cutlass::epilogue::warp::TileIteratorSimt<
155
+ typename WarpMmaSimt::Shape,
156
+ typename WarpMmaSimt::ThreadMma,
157
+ ElementAccumulator,
158
+ layout::RowMajor,
159
+ typename WarpMmaSimt::Policy
160
+ >;
161
+
162
+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<
163
+ typename OutputTileThreadMap::CompactedThreadMap,
164
+ ElementAccumulator
165
+ >;
166
+
167
+ /// Hard-coded padding elements added
168
+ using Padding = typename WarpTileIterator::Padding;
169
+
170
+ //
171
+ // Define the epilogue
172
+ //
173
+ using Epilogue = cutlass::epilogue::threadblock::Epilogue<
174
+ Shape,
175
+ WarpMmaSimt,
176
+ kPartitionsK,
177
+ OutputTileIterator,
178
+ AccumulatorFragmentIterator,
179
+ WarpTileIterator,
180
+ SharedLoadIterator,
181
+ OutputOp,
182
+ Padding
183
+ >;
184
+ };
185
+
186
+ /////////////////////////////////////////////////////////////////////////////////////////////////
187
+
188
+ /// Defines sensible defaults for epilogues for SimtOps.
189
+ template <
190
+ typename Shape_,
191
+ typename WarpMmaSimt_,
192
+ typename OutputOp_,
193
+ int ElementsPerAccess
194
+ >
195
+ struct DefaultEpilogueSimtStridedDgrad {
196
+
197
+ using Shape = Shape_;
198
+ using WarpMmaSimt = WarpMmaSimt_;
199
+ using OutputOp = OutputOp_;
200
+ static int const kElementsPerAccess = ElementsPerAccess;
201
+ static const int kPartitionsK = Shape::kK / WarpMmaSimt::Shape::kK;
202
+
203
+ using ElementOutput = typename OutputOp::ElementOutput;
204
+ using LayoutC = typename WarpMmaSimt::LayoutC;
205
+ using ElementAccumulator = typename WarpMmaSimt::ElementC;
206
+
207
+ //
208
+ // Thread map
209
+ //
210
+
211
+ using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapSimt<
212
+ Shape,
213
+ typename WarpMmaSimt::Shape,
214
+ typename WarpMmaSimt::Policy,
215
+ kPartitionsK,
216
+ ElementOutput,
217
+ kElementsPerAccess
218
+ >::Type;
219
+
220
+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorStridedDgrad<
221
+ OutputTileThreadMap,
222
+ ElementOutput
223
+ >;
224
+
225
+ using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorSimt<
226
+ typename WarpMmaSimt::Shape,
227
+ typename WarpMmaSimt::ThreadMma,
228
+ layout::RowMajor,
229
+ typename WarpMmaSimt::Policy
230
+ >;
231
+
232
+ using WarpTileIterator = cutlass::epilogue::warp::TileIteratorSimt<
233
+ typename WarpMmaSimt::Shape,
234
+ typename WarpMmaSimt::ThreadMma,
235
+ ElementAccumulator,
236
+ layout::RowMajor,
237
+ typename WarpMmaSimt::Policy
238
+ >;
239
+
240
+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<
241
+ typename OutputTileThreadMap::CompactedThreadMap,
242
+ ElementAccumulator
243
+ >;
244
+
245
+ /// Hard-coded padding elements added
246
+ using Padding = typename WarpTileIterator::Padding;
247
+
248
+ //
249
+ // Define the epilogue
250
+ //
251
+ using Epilogue = cutlass::epilogue::threadblock::Epilogue<
252
+ Shape,
253
+ WarpMmaSimt,
254
+ kPartitionsK,
255
+ OutputTileIterator,
256
+ AccumulatorFragmentIterator,
257
+ WarpTileIterator,
258
+ SharedLoadIterator,
259
+ OutputOp,
260
+ Padding
261
+ >;
262
+ };
263
+
264
+ /////////////////////////////////////////////////////////////////////////////////////////////////
265
+
266
+ /// Defines sensible defaults for epilogues for SimtOps.
267
+ template <
268
+ int Rank,
269
+ typename Shape_,
270
+ typename WarpMmaSimt_,
271
+ typename OutputOp_,
272
+ int ElementsPerAccess
273
+ >
274
+ struct DefaultEpilogueSimtAffineRankN {
275
+
276
+ using Shape = Shape_;
277
+ using WarpMmaSimt = WarpMmaSimt_;
278
+ using OutputOp = OutputOp_;
279
+ static int const kElementsPerAccess = ElementsPerAccess;
280
+ static const int kPartitionsK = Shape::kK / WarpMmaSimt::Shape::kK;
281
+
282
+ using ElementOutput = typename OutputOp::ElementOutput;
283
+ using LayoutC = typename WarpMmaSimt::LayoutC;
284
+ using ElementAccumulator = typename WarpMmaSimt::ElementC;
285
+
286
+ //
287
+ // Thread map
288
+ //
289
+
290
+ using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapSimt<
291
+ Shape,
292
+ typename WarpMmaSimt::Shape,
293
+ typename WarpMmaSimt::Policy,
294
+ kPartitionsK,
295
+ ElementOutput,
296
+ kElementsPerAccess
297
+ >::Type;
298
+
299
+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorAffineRankN<
300
+ OutputTileThreadMap,
301
+ ElementOutput,
302
+ Rank
303
+ >;
304
+
305
+ using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorSimt<
306
+ typename WarpMmaSimt::Shape,
307
+ typename WarpMmaSimt::ThreadMma,
308
+ layout::RowMajor,
309
+ typename WarpMmaSimt::Policy
310
+ >;
311
+
312
+ using WarpTileIterator = cutlass::epilogue::warp::TileIteratorSimt<
313
+ typename WarpMmaSimt::Shape,
314
+ typename WarpMmaSimt::ThreadMma,
315
+ ElementAccumulator,
316
+ layout::RowMajor,
317
+ typename WarpMmaSimt::Policy
318
+ >;
319
+
320
+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<
321
+ typename OutputTileThreadMap::CompactedThreadMap,
322
+ ElementAccumulator
323
+ >;
324
+
325
+ /// Hard-coded padding elements added
326
+ using Padding = typename WarpTileIterator::Padding;
327
+
328
+ //
329
+ // Define the epilogue
330
+ //
331
+ using Epilogue = cutlass::epilogue::threadblock::Epilogue<
332
+ Shape,
333
+ WarpMmaSimt,
334
+ kPartitionsK,
335
+ OutputTileIterator,
336
+ AccumulatorFragmentIterator,
337
+ WarpTileIterator,
338
+ SharedLoadIterator,
339
+ OutputOp,
340
+ Padding
341
+ >;
342
+ };
343
+
344
+ /////////////////////////////////////////////////////////////////////////////////////////////////
345
+ /////////////////////////////////////////////////////////////////////////////////////////////////
346
+
347
+ /// Defines sensible defaults for epilogues for SimtOps.
348
+ template <typename Shape_, // ThreadBlock Shape
349
+ typename WarpMmaSimt_, // mma_depthwise_simt
350
+ typename OutputOp_,
351
+ int ElementsPerAccess_,
352
+ typename ThreadOutputShape_ = cutlass::conv::TensorNHWCShape<1, 1, 1, 1>,
353
+ typename ThreadBlockOutputShape_ = cutlass::conv::TensorNHWCShape<1, 1, 1, 1> >
354
+ struct DefaultDirectConvEpilogueSimt {
355
+ using Shape = Shape_;
356
+ using WarpMmaSimt = WarpMmaSimt_;
357
+ using WarpShape = typename WarpMmaSimt::Shape;
358
+ using OutputOp = OutputOp_;
359
+ using ThreadOutputShape = ThreadOutputShape_;
360
+ using ThreadBlockOutputShape = ThreadBlockOutputShape_;
361
+ static int const kElementsPerAccess = ElementsPerAccess_;
362
+
363
+
364
+ using ElementOutput = typename OutputOp::ElementOutput;
365
+ using LayoutC = typename WarpMmaSimt::LayoutC;
366
+ using ElementAccumulator = typename WarpMmaSimt::ElementC;
367
+
368
+ /// Number of threads total
369
+ using WarpCount = gemm::GemmShape<
370
+ Shape::kM / WarpShape::kM,
371
+ Shape::kN / WarpShape::kN
372
+ >;
373
+
374
+ static int const kWarpSize = cutlass::gemm::warp::WarpSize<arch::OpClassSimt>::value;
375
+
376
+ static int const kThreads = WarpCount::kCount * kWarpSize;
377
+
378
+ //
379
+ // Thread map
380
+ //
381
+
382
+ using OutputTileThreadMap = cutlass::transform::PitchLinearStripminedThreadMap<
383
+ layout::PitchLinearShape<ThreadBlockOutputShape::kC, ThreadBlockOutputShape::kNHW>,
384
+ kThreads,
385
+ kElementsPerAccess
386
+ >;
387
+
388
+
389
+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorDirectConv<
390
+ OutputTileThreadMap,
391
+ ElementOutput,
392
+ ThreadOutputShape,
393
+ ThreadBlockOutputShape
394
+ >;
395
+
396
+ using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorSimt<
397
+ typename WarpMmaSimt::Shape,
398
+ typename WarpMmaSimt::ThreadMma,
399
+ layout::RowMajor,
400
+ typename WarpMmaSimt::Policy
401
+ >;
402
+
403
+ using WarpTileIterator = cutlass::epilogue::warp::TileIteratorSimtDirect2dConv<
404
+ typename WarpMmaSimt::Shape,
405
+ ThreadOutputShape,
406
+ ThreadBlockOutputShape,
407
+ typename WarpMmaSimt::ThreadMma,
408
+ ElementAccumulator,
409
+ layout::RowMajor,
410
+ typename WarpMmaSimt::Policy
411
+ >;
412
+
413
+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIteratorPitchLinear<
414
+ OutputTileThreadMap,
415
+ ElementAccumulator
416
+ >;
417
+
418
+ /// Hard-coded padding elements added
419
+ using Padding = typename WarpTileIterator::Padding;
420
+ //
421
+ // Define the epilogue
422
+ //
423
+ using Epilogue = cutlass::epilogue::threadblock::EpilogueDepthwise<
424
+ Shape,
425
+ ThreadOutputShape,
426
+ ThreadBlockOutputShape,
427
+ WarpMmaSimt,
428
+ OutputTileIterator,
429
+ AccumulatorFragmentIterator,
430
+ WarpTileIterator,
431
+ SharedLoadIterator,
432
+ OutputOp,
433
+ Padding
434
+ >;
435
+ };
436
+
437
+ /////////////////////////////////////////////////////////////////////////////////////////////////
438
+
439
+ } // namespace threadblock
440
+ } // namespace epilogue
441
+ } // namespace cutlass
442
+
443
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h ADDED
@@ -0,0 +1,904 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
33
+
34
+ The epilogue rearranges the result of a matrix product through shared memory to match canonical
35
+ tensor layouts in global memory. Epilogues support conversion and reduction operations.
36
+
37
+ */
38
+
39
+ #pragma once
40
+
41
+ #include "cutlass/cutlass.h"
42
+ #include "cutlass/numeric_types.h"
43
+ #include "cutlass/array.h"
44
+
45
+ #include "cutlass/platform/platform.h"
46
+
47
+ #include "cutlass/gemm/gemm.h"
48
+
49
+ #include "cutlass/epilogue/thread/linear_combination.h"
50
+ #include "cutlass/epilogue/thread/linear_combination_clamp.h"
51
+ #include "cutlass/epilogue/thread/linear_combination_relu.h"
52
+ #include "cutlass/epilogue/thread/linear_combination_relu0.h"
53
+ #include "cutlass/epilogue/thread/linear_combination_gelu.h"
54
+ #include "cutlass/epilogue/thread/linear_combination_sigmoid.h"
55
+ #include "cutlass/epilogue/thread/linear_combination_hardswish.h"
56
+ #include "cutlass/epilogue/thread/linear_combination_planar_complex.h"
57
+
58
+ #include "cutlass/epilogue/thread/conversion_op.h"
59
+ #include "cutlass/epilogue/thread/reduction_op.h"
60
+
61
+ #include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h"
62
+
63
+ #include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h"
64
+ #include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h"
65
+ #include "cutlass/epilogue/warp/tile_iterator_tensor_op.h"
66
+ #include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h"
67
+ #include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h"
68
+ #include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
69
+ #include "cutlass/epilogue/threadblock/predicated_tile_iterator_conv.h"
70
+ #include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h"
71
+ #include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h"
72
+ #include "cutlass/epilogue/threadblock/shared_load_iterator.h"
73
+ #include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h"
74
+
75
+ #include "cutlass/epilogue/threadblock/epilogue.h"
76
+ #include "cutlass/epilogue/threadblock/interleaved_epilogue.h"
77
+
78
+ #include "cutlass/layout/permute.h"
79
+
80
+ ////////////////////////////////////////////////////////////////////////////////
81
+
82
+ namespace cutlass {
83
+ namespace epilogue {
84
+ namespace threadblock {
85
+
86
+ ////////////////////////////////////////////////////////////////////////////////
87
+
88
+ namespace detail {
89
+
90
/// Selects the shared-memory staging iterators for a TensorOp epilogue.
///
/// Generic case: accumulators are written to and read back from shared
/// memory in the accumulator's native element type.
///
/// ElementOutput      - element type written to global memory
/// ElementAccumulator - element type produced by the warp-level MMA
/// ElementsPerAccess  - vector width of global-memory accesses
/// ThreadMap          - compacted output-tile thread map
template <
  typename ElementOutput,
  typename ElementAccumulator,
  int ElementsPerAccess,
  typename ThreadblockShape,
  typename WarpShape,
  typename InstructionShape,
  typename ThreadMap
>
struct DefaultIteratorsTensorOp {

  // Warp-level iterator that stores accumulator fragments to shared memory.
  using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp<
    WarpShape,
    InstructionShape,
    ElementAccumulator,
    layout::RowMajor
  >;

  // Threadblock-level iterator that loads the staged accumulators back.
  using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<
    ThreadMap,
    ElementAccumulator
  >;

  // Generic case stages one fragment per epilogue iteration (no double buffering).
  static int const kFragmentsPerIteration = 1;
};
115
+
116
+ /// Partial specialization for float <= float x 4
117
+ template <
118
+ typename ThreadblockShape,
119
+ typename WarpShape,
120
+ typename InstructionShape,
121
+ typename ThreadMap
122
+ >
123
+ struct DefaultIteratorsTensorOp<float, float, 4, ThreadblockShape, WarpShape, InstructionShape, ThreadMap> {
124
+
125
+ using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp<
126
+ WarpShape,
127
+ InstructionShape,
128
+ float,
129
+ layout::RowMajor
130
+ >;
131
+
132
+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<
133
+ ThreadMap,
134
+ float
135
+ >;
136
+
137
+ static int const kFragmentsPerIteration = 2;
138
+ };
139
+
140
+ /// Partial specialization for int32_t <= int32_t
141
+ template <
142
+ int ElementsPerAccess,
143
+ typename ThreadblockShape,
144
+ typename WarpShape,
145
+ typename InstructionShape,
146
+ typename ThreadMap
147
+ >
148
+ struct DefaultIteratorsTensorOp<int32_t, int32_t, ElementsPerAccess, ThreadblockShape, WarpShape, InstructionShape, ThreadMap> {
149
+
150
+ using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp<
151
+ WarpShape,
152
+ InstructionShape,
153
+ int32_t,
154
+ layout::RowMajor
155
+ >;
156
+
157
+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<
158
+ ThreadMap,
159
+ int32_t
160
+ >;
161
+
162
+ static int const kFragmentsPerIteration = 1;
163
+ };
164
+
165
+ /// Partial specialization for float <= int32_t
166
+ template <
167
+ int ElementsPerAccess,
168
+ typename ThreadblockShape,
169
+ typename WarpShape,
170
+ typename InstructionShape,
171
+ typename ThreadMap
172
+ >
173
+ struct DefaultIteratorsTensorOp<float, int32_t, ElementsPerAccess, ThreadblockShape, WarpShape, InstructionShape, ThreadMap> {
174
+
175
+ using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp<
176
+ WarpShape,
177
+ InstructionShape,
178
+ int32_t,
179
+ layout::RowMajor
180
+ >;
181
+
182
+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<
183
+ ThreadMap,
184
+ int32_t
185
+ >;
186
+
187
+ static int const kFragmentsPerIteration = 1;
188
+ };
189
+
190
/// Partial specialization for half <= float x 8 epilogues avoids shared memory bank conflicts.
template <
  typename ThreadblockShape,
  typename WarpShape,
  typename InstructionShape,
  typename ThreadMap
>
struct DefaultIteratorsTensorOp<
  half_t,
  float,
  8,
  ThreadblockShape,
  WarpShape,
  InstructionShape,
  ThreadMap> {

  // "Mixed" iterators stage 32-bit accumulators destined for 16-bit outputs
  // with an access pattern chosen to avoid shared-memory bank conflicts.
  // NOTE(review): the integer parameters (32, 16, 8, 8) presumably encode
  // accumulator bits, output bits, elements per access, and access count --
  // confirm against the TileIteratorTensorOpMixed template documentation.
  using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOpMixed<
    WarpShape,
    InstructionShape,
    float,
    32,
    16,
    8,
    8
  >;

  using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIteratorMixed<
    ThreadMap,
    float,
    32,
    16,
    8,
    8
  >;

  // Two fragments are staged per epilogue iteration for this configuration.
  static int const kFragmentsPerIteration = 2;
};
227
+
228
/// Partial specialization for bfloat16_t <= int32_t x 8 epilogues avoids shared memory bank conflicts.
/// (Comment fixed: this specialization is for bfloat16_t output, not half.)
template <
  typename ThreadblockShape,
  typename WarpShape,
  typename InstructionShape,
  typename ThreadMap
>
struct DefaultIteratorsTensorOp<
  bfloat16_t,
  int32_t,
  8,
  ThreadblockShape,
  WarpShape,
  InstructionShape,
  ThreadMap> {

  // Mixed iterators stage 32-bit integer accumulators destined for 16-bit
  // outputs with a bank-conflict-avoiding shared-memory layout.
  using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOpMixed<
    WarpShape,
    InstructionShape,
    int32_t,
    32,
    16,
    8,
    8
  >;

  using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIteratorMixed<
    ThreadMap,
    int32_t,
    32,
    16,
    8,
    8
  >;

  // Two fragments are staged per epilogue iteration for this configuration.
  static int const kFragmentsPerIteration = 2;
};
265
+
266
/// Partial specialization for half <= int32_t x 8 epilogues avoids shared memory bank conflicts.
template <
  typename ThreadblockShape,
  typename WarpShape,
  typename InstructionShape,
  typename ThreadMap
>
struct DefaultIteratorsTensorOp<
  half_t,
  int32_t,
  8,
  ThreadblockShape,
  WarpShape,
  InstructionShape,
  ThreadMap> {

  // Mixed iterators stage 32-bit integer accumulators destined for 16-bit
  // outputs with a bank-conflict-avoiding shared-memory layout.
  using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOpMixed<
    WarpShape,
    InstructionShape,
    int32_t,
    32,
    16,
    8,
    8
  >;

  using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIteratorMixed<
    ThreadMap,
    int32_t,
    32,
    16,
    8,
    8
  >;

  // Two fragments are staged per epilogue iteration for this configuration.
  static int const kFragmentsPerIteration = 2;
};
303
+
304
+ /// Partial specialization for int8/int4b_t <= int32 x 16/8 epilogues avoids shared memory bank conflicts.
305
+ /// Threadblock::kN = 256 still has bank conflicts.
306
+ template <
307
+ typename ElementOutput,
308
+ int ElementsPerAccess,
309
+ typename ThreadblockShape,
310
+ typename WarpShape,
311
+ typename InstructionShape,
312
+ typename ThreadMap
313
+ >
314
+ struct DefaultIteratorsTensorOp<
315
+ ElementOutput,
316
+ int32_t,
317
+ ElementsPerAccess,
318
+ ThreadblockShape,
319
+ WarpShape,
320
+ InstructionShape,
321
+ ThreadMap> {
322
+
323
+ static_assert(platform::is_same<ElementOutput, cutlass::int4b_t>::value ||
324
+ platform::is_same<ElementOutput, cutlass::uint4b_t>::value ||
325
+ platform::is_same<ElementOutput, int8_t>::value ||
326
+ platform::is_same<ElementOutput, uint8_t>::value,
327
+ "ElementOutput needs to be 4 or 8 bit (unsigned) int.");
328
+
329
+ static_assert((ElementsPerAccess == 16 || ElementsPerAccess == 8 || ElementsPerAccess == 4),
330
+ "ElementsPerAccess needs to be 16 or 8.");
331
+
332
+ using WarpTileIteratorMixed = cutlass::epilogue::warp::TileIteratorTensorOpMixed<
333
+ WarpShape,
334
+ InstructionShape,
335
+ int32_t,
336
+ 32,
337
+ cutlass::sizeof_bits<ElementOutput>::value,
338
+ ElementsPerAccess,
339
+ 8
340
+ >;
341
+
342
+ using WarpTileIteratorNotMixed = cutlass::epilogue::warp::TileIteratorTensorOp<
343
+ WarpShape,
344
+ InstructionShape,
345
+ int32_t,
346
+ layout::RowMajor
347
+ >;
348
+
349
+ using WarpTileIterator = typename platform::conditional<
350
+ (ThreadblockShape::kN == 256) || (ThreadblockShape::kN == 128 && ElementsPerAccess == 8) || (ElementsPerAccess == 4),
351
+ WarpTileIteratorNotMixed,
352
+ WarpTileIteratorMixed>::type;
353
+
354
+ using SharedLoadIteratorMixed = cutlass::epilogue::threadblock::SharedLoadIteratorMixed<
355
+ ThreadMap,
356
+ int32_t,
357
+ 32,
358
+ cutlass::sizeof_bits<ElementOutput>::value,
359
+ ElementsPerAccess,
360
+ 8
361
+ >;
362
+
363
+ using SharedLoadIteratorNotMixed = cutlass::epilogue::threadblock::SharedLoadIterator<
364
+ ThreadMap,
365
+ int32_t
366
+ >;
367
+
368
+ using SharedLoadIterator = typename platform::conditional<
369
+ (ThreadblockShape::kN == 256) || (ThreadblockShape::kN == 128 && ElementsPerAccess == 8) || (ElementsPerAccess == 4),
370
+ SharedLoadIteratorNotMixed,
371
+ SharedLoadIteratorMixed>::type;
372
+
373
+ static int const kFragmentsPerIteration = 1;
374
+ };
375
+
376
+ /// Partial specialization for float_e4m3_t <= float x 16/8 epilogues avoids shared memory bank conflicts.
377
+ /// Threadblock::kN = 256 still has bank conflicts.
378
+ template <
379
+ int ElementsPerAccess,
380
+ typename ThreadblockShape,
381
+ typename WarpShape,
382
+ typename InstructionShape,
383
+ typename ThreadMap
384
+ >
385
+ struct DefaultIteratorsTensorOp<
386
+ cutlass::float_e4m3_t,
387
+ float,
388
+ ElementsPerAccess,
389
+ ThreadblockShape,
390
+ WarpShape,
391
+ InstructionShape,
392
+ ThreadMap> {
393
+
394
+ using ElementOutput = cutlass::float_e4m3_t;
395
+
396
+ static_assert((ElementsPerAccess == 16 || ElementsPerAccess == 8 || ElementsPerAccess == 4),
397
+ "ElementsPerAccess needs to be 16 or 8.");
398
+
399
+ using WarpTileIteratorMixed = cutlass::epilogue::warp::TileIteratorTensorOpMixed<
400
+ WarpShape,
401
+ InstructionShape,
402
+ float,
403
+ 32,
404
+ cutlass::sizeof_bits<ElementOutput>::value,
405
+ ElementsPerAccess,
406
+ 8
407
+ >;
408
+
409
+ using WarpTileIteratorNotMixed = cutlass::epilogue::warp::TileIteratorTensorOp<
410
+ WarpShape,
411
+ InstructionShape,
412
+ float,
413
+ layout::RowMajor
414
+ >;
415
+
416
+ using WarpTileIterator = typename platform::conditional<
417
+ (ThreadblockShape::kN == 256) || (ThreadblockShape::kN == 128 && ElementsPerAccess == 8) || (ElementsPerAccess == 4),
418
+ WarpTileIteratorNotMixed,
419
+ WarpTileIteratorMixed>::type;
420
+
421
+ using SharedLoadIteratorMixed = cutlass::epilogue::threadblock::SharedLoadIteratorMixed<
422
+ ThreadMap,
423
+ float,
424
+ 32,
425
+ cutlass::sizeof_bits<ElementOutput>::value,
426
+ ElementsPerAccess,
427
+ 8
428
+ >;
429
+
430
+ using SharedLoadIteratorNotMixed = cutlass::epilogue::threadblock::SharedLoadIterator<
431
+ ThreadMap,
432
+ float
433
+ >;
434
+
435
+ using SharedLoadIterator = typename platform::conditional<
436
+ (ThreadblockShape::kN == 256) || (ThreadblockShape::kN == 128 && ElementsPerAccess == 8) || (ElementsPerAccess == 4),
437
+ SharedLoadIteratorNotMixed,
438
+ SharedLoadIteratorMixed>::type;
439
+
440
+ static int const kFragmentsPerIteration = 1;
441
+ };
442
+
443
+ /// Partial specialization for float_e5m2_t <= float x 16/8 epilogues avoids shared memory bank conflicts.
444
+ /// Threadblock::kN = 256 still has bank conflicts.
445
+ template <
446
+ int ElementsPerAccess,
447
+ typename ThreadblockShape,
448
+ typename WarpShape,
449
+ typename InstructionShape,
450
+ typename ThreadMap
451
+ >
452
+ struct DefaultIteratorsTensorOp<
453
+ cutlass::float_e5m2_t,
454
+ float,
455
+ ElementsPerAccess,
456
+ ThreadblockShape,
457
+ WarpShape,
458
+ InstructionShape,
459
+ ThreadMap> {
460
+
461
+ using ElementOutput = cutlass::float_e5m2_t;
462
+
463
+ static_assert((ElementsPerAccess == 16 || ElementsPerAccess == 8 || ElementsPerAccess == 4),
464
+ "ElementsPerAccess needs to be 16 or 8.");
465
+
466
+ using WarpTileIteratorMixed = cutlass::epilogue::warp::TileIteratorTensorOpMixed<
467
+ WarpShape,
468
+ InstructionShape,
469
+ float,
470
+ 32,
471
+ cutlass::sizeof_bits<ElementOutput>::value,
472
+ ElementsPerAccess,
473
+ 8
474
+ >;
475
+
476
+ using WarpTileIteratorNotMixed = cutlass::epilogue::warp::TileIteratorTensorOp<
477
+ WarpShape,
478
+ InstructionShape,
479
+ float,
480
+ layout::RowMajor
481
+ >;
482
+
483
+ using WarpTileIterator = typename platform::conditional<
484
+ (ThreadblockShape::kN == 256) || (ThreadblockShape::kN == 128 && ElementsPerAccess == 8) || (ElementsPerAccess == 4),
485
+ WarpTileIteratorNotMixed,
486
+ WarpTileIteratorMixed>::type;
487
+
488
+ using SharedLoadIteratorMixed = cutlass::epilogue::threadblock::SharedLoadIteratorMixed<
489
+ ThreadMap,
490
+ float,
491
+ 32,
492
+ cutlass::sizeof_bits<ElementOutput>::value,
493
+ ElementsPerAccess,
494
+ 8
495
+ >;
496
+
497
+ using SharedLoadIteratorNotMixed = cutlass::epilogue::threadblock::SharedLoadIterator<
498
+ ThreadMap,
499
+ float
500
+ >;
501
+
502
+ using SharedLoadIterator = typename platform::conditional<
503
+ (ThreadblockShape::kN == 256) || (ThreadblockShape::kN == 128 && ElementsPerAccess == 8) || (ElementsPerAccess == 4),
504
+ SharedLoadIteratorNotMixed,
505
+ SharedLoadIteratorMixed>::type;
506
+
507
+ static int const kFragmentsPerIteration = 1;
508
+ };
509
+
510
+ } // namespace detail
511
+
512
+ ////////////////////////////////////////////////////////////////////////////////
513
+
514
/// Defines sensible defaults for epilogues for TensorOps.
///
/// Composes the thread map, output-tile iterator, accumulator fragment
/// iterator, and shared-memory staging iterators into a complete
/// threadblock-scoped Epilogue type.
template <
  typename Shape_,
  typename WarpMmaTensorOp_,
  int PartitionsK,
  typename OutputOp_,
  int ElementsPerAccess,
  bool ScatterD = false,
  typename PermuteDLayout = layout::NoPermute,
  conv::StrideSupport StrideSupport = conv::StrideSupport::kUnity,
  int Rank = 4
>
struct DefaultEpilogueTensorOp {

  using Shape = Shape_;
  using WarpMmaTensorOp = WarpMmaTensorOp_;
  static int const kPartitionsK = PartitionsK;
  using OutputOp = OutputOp_;
  static int const kElementsPerAccess = ElementsPerAccess;

  using ElementOutput = typename OutputOp::ElementOutput;
  using LayoutC = typename WarpMmaTensorOp::LayoutC;
  using ElementAccumulator = typename WarpMmaTensorOp::ElementC;
  static conv::StrideSupport const kStrideSupport = StrideSupport;
  static int const kRank = Rank;

  //
  // Thread map
  //

  using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp<
    Shape,
    typename WarpMmaTensorOp::Shape,
    kPartitionsK,
    ElementOutput,
    kElementsPerAccess
  >::Type;

  // CUDA store path is enabled only for double-precision outputs.
  static bool const UseCUDAStore = platform::is_same<ElementOutput, double>::value;

  // Output iterator for packed (unit-stride) tensors.
  using PackedOutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
    OutputTileThreadMap,
    ElementOutput,
    ScatterD,
    PermuteDLayout,
    UseCUDAStore
  >;

  // Output iterator for strided (convolution-style, rank-N) tensors.
  using StridedOutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorConv<
    OutputTileThreadMap,
    ElementOutput,
    ScatterD,
    PermuteDLayout,
    UseCUDAStore,
    kRank
  >;

  // Select the packed iterator for unit-stride support, otherwise strided.
  using OutputTileIterator = typename platform::conditional<StrideSupport == cutlass::conv::StrideSupport::kUnity,
    PackedOutputTileIterator,
    StridedOutputTileIterator>::type;

  // Complex-valued outputs use the complex fragment iterator; real-valued
  // outputs use the plain one.
  using AccumulatorFragmentIterator = typename platform::conditional<is_complex<ElementOutput>::value,
                                    cutlass::epilogue::warp::FragmentIteratorComplexTensorOp<
                                        typename WarpMmaTensorOp::Shape,
                                        typename WarpMmaTensorOp::Policy::Operator::Shape,
                                        typename WarpMmaTensorOp::Policy::Operator::ElementC,
                                        typename WarpMmaTensorOp::Policy::Operator::FragmentC,
                                        LayoutC>,
                                    cutlass::epilogue::warp::FragmentIteratorTensorOp<
                                        typename WarpMmaTensorOp::Shape,
                                        typename WarpMmaTensorOp::Policy::Operator::Shape,
                                        typename WarpMmaTensorOp::Policy::Operator::ElementC,
                                        typename WarpMmaTensorOp::Policy::Operator::FragmentC,
                                        LayoutC> >::type;

  /// Support several implementations depending on structure of epilogue
  using DefaultIterators = detail::DefaultIteratorsTensorOp<
    ElementOutput,
    ElementAccumulator,
    kElementsPerAccess,
    Shape,
    typename WarpMmaTensorOp::Shape,
    typename WarpMmaTensorOp::Policy::Operator::Shape,
    typename OutputTileThreadMap::CompactedThreadMap
  >;

  using WarpTileIterator = typename DefaultIterators::WarpTileIterator;
  using SharedLoadIterator = typename DefaultIterators::SharedLoadIterator;

  /// Hard-coded padding elements added
  // NOTE(review): padding presumably mitigates shared-memory bank conflicts
  // when staging accumulators -- confirm against the Epilogue implementation.
  using Padding = cutlass::MatrixShape<0, 64 / sizeof_bits<ElementAccumulator>::value * 4>;

  // Double buffering is disabled whenever split-K uses more than one partition.
  static int const kFragmentsPerIteration = (kPartitionsK == 1 ? DefaultIterators::kFragmentsPerIteration : 1);

  //
  // Define the epilogue
  //
  using Epilogue = cutlass::epilogue::threadblock::Epilogue<
    Shape,
    WarpMmaTensorOp,
    kPartitionsK,
    OutputTileIterator,
    AccumulatorFragmentIterator,
    WarpTileIterator,
    SharedLoadIterator,
    OutputOp,
    Padding,
    kFragmentsPerIteration
  >;
};
624
+
625
+ ////////////////////////////////////////////////////////////////////////////////
626
+
627
/// Defines sensible defaults for epilogues for TensorOps used by strided
/// dgrad convolutions. Identical in structure to DefaultEpilogueTensorOp
/// except the output iterator is PredicatedTileIteratorStridedDgrad.
template <
  typename Shape_,
  typename WarpMmaTensorOp_,
  int PartitionsK,
  typename OutputOp_,
  int ElementsPerAccess
>
struct DefaultEpilogueTensorOpStridedDgrad {

  using Shape = Shape_;
  using WarpMmaTensorOp = WarpMmaTensorOp_;
  static int const kPartitionsK = PartitionsK;
  using OutputOp = OutputOp_;
  static int const kElementsPerAccess = ElementsPerAccess;

  using ElementOutput = typename OutputOp::ElementOutput;
  using LayoutC = typename WarpMmaTensorOp::LayoutC;
  using ElementAccumulator = typename WarpMmaTensorOp::ElementC;

  //
  // Thread map
  //

  using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp<
    Shape,
    typename WarpMmaTensorOp::Shape,
    kPartitionsK,
    ElementOutput,
    kElementsPerAccess
  >::Type;

  // Output iterator specialized for strided dgrad output tensors.
  using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorStridedDgrad<
    OutputTileThreadMap,
    ElementOutput
  >;

  // Complex-valued outputs use the complex fragment iterator; real-valued
  // outputs use the plain one.
  using AccumulatorFragmentIterator = typename platform::conditional<is_complex<ElementOutput>::value,
                                    cutlass::epilogue::warp::FragmentIteratorComplexTensorOp<
                                        typename WarpMmaTensorOp::Shape,
                                        typename WarpMmaTensorOp::Policy::Operator::Shape,
                                        typename WarpMmaTensorOp::Policy::Operator::ElementC,
                                        typename WarpMmaTensorOp::Policy::Operator::FragmentC,
                                        LayoutC>,
                                    cutlass::epilogue::warp::FragmentIteratorTensorOp<
                                        typename WarpMmaTensorOp::Shape,
                                        typename WarpMmaTensorOp::Policy::Operator::Shape,
                                        typename WarpMmaTensorOp::Policy::Operator::ElementC,
                                        typename WarpMmaTensorOp::Policy::Operator::FragmentC,
                                        LayoutC> >::type;

  /// Support several implementations depending on structure of epilogue
  using DefaultIterators = detail::DefaultIteratorsTensorOp<
    ElementOutput,
    ElementAccumulator,
    kElementsPerAccess,
    Shape,
    typename WarpMmaTensorOp::Shape,
    typename WarpMmaTensorOp::Policy::Operator::Shape,
    typename OutputTileThreadMap::CompactedThreadMap
  >;

  using WarpTileIterator = typename DefaultIterators::WarpTileIterator;
  using SharedLoadIterator = typename DefaultIterators::SharedLoadIterator;

  /// Hard-coded padding elements added
  using Padding = cutlass::MatrixShape<0, 64 / sizeof_bits<ElementAccumulator>::value * 4>;

  // Double buffering is disabled whenever split-K uses more than one partition.
  static int const kFragmentsPerIteration = (kPartitionsK == 1 ? DefaultIterators::kFragmentsPerIteration : 1);

  //
  // Define the epilogue
  //
  using Epilogue = cutlass::epilogue::threadblock::Epilogue<
    Shape,
    WarpMmaTensorOp,
    kPartitionsK,
    OutputTileIterator,
    AccumulatorFragmentIterator,
    WarpTileIterator,
    SharedLoadIterator,
    OutputOp,
    Padding,
    kFragmentsPerIteration
  >;
};
713
+
714
+
715
+ ////////////////////////////////////////////////////////////////////////////////
716
+
717
/// Defines sensible defaults for epilogues for TensorOps writing to
/// rank-N affine output tensors.
template <
  int Rank,
  typename Shape_,
  typename WarpMmaTensorOp_,
  int PartitionsK,
  typename OutputOp_,
  int ElementsPerAccess
>
struct DefaultEpilogueTensorOpAffineRankN {

  using Shape = Shape_;
  using WarpMmaTensorOp = WarpMmaTensorOp_;
  static int const kPartitionsK = PartitionsK;
  using OutputOp = OutputOp_;
  static int const kElementsPerAccess = ElementsPerAccess;

  using ElementOutput = typename OutputOp::ElementOutput;
  using LayoutC = typename WarpMmaTensorOp::LayoutC;
  using ElementAccumulator = typename WarpMmaTensorOp::ElementC;

  //
  // Thread map
  //

  using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp<
    Shape,
    typename WarpMmaTensorOp::Shape,
    kPartitionsK,
    ElementOutput,
    kElementsPerAccess
  >::Type;

  // Output iterator for rank-N affine layouts.
  using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorAffineRankN<
    OutputTileThreadMap,
    ElementOutput,
    Rank
  >;

  // Map to the row major iterator since the iterator selection for affineN is the same.
  using AccumulatorFragmentIterator = typename platform::conditional<is_complex<ElementOutput>::value,
                                    cutlass::epilogue::warp::FragmentIteratorComplexTensorOp<
                                        typename WarpMmaTensorOp::Shape,
                                        typename WarpMmaTensorOp::Policy::Operator::Shape,
                                        typename WarpMmaTensorOp::Policy::Operator::ElementC,
                                        typename WarpMmaTensorOp::Policy::Operator::FragmentC,
                                        layout::RowMajor>,
                                    cutlass::epilogue::warp::FragmentIteratorTensorOp<
                                        typename WarpMmaTensorOp::Shape,
                                        typename WarpMmaTensorOp::Policy::Operator::Shape,
                                        typename WarpMmaTensorOp::Policy::Operator::ElementC,
                                        typename WarpMmaTensorOp::Policy::Operator::FragmentC,
                                        layout::RowMajor> >::type;

  /// Support several implementations depending on structure of epilogue
  using DefaultIterators = detail::DefaultIteratorsTensorOp<
    ElementOutput,
    ElementAccumulator,
    kElementsPerAccess,
    Shape,
    typename WarpMmaTensorOp::Shape,
    typename WarpMmaTensorOp::Policy::Operator::Shape,
    typename OutputTileThreadMap::CompactedThreadMap
  >;

  using WarpTileIterator = typename DefaultIterators::WarpTileIterator;
  using SharedLoadIterator = typename DefaultIterators::SharedLoadIterator;

  /// Hard-coded padding elements added
  using Padding = cutlass::MatrixShape<0, 64 / sizeof_bits<ElementAccumulator>::value * 4>;

  // Double buffering is disabled whenever split-K uses more than one partition.
  static int const kFragmentsPerIteration = (kPartitionsK == 1 ? DefaultIterators::kFragmentsPerIteration : 1);

  //
  // Define the epilogue
  //
  using Epilogue = cutlass::epilogue::threadblock::Epilogue<
    Shape,
    WarpMmaTensorOp,
    kPartitionsK,
    OutputTileIterator,
    AccumulatorFragmentIterator,
    WarpTileIterator,
    SharedLoadIterator,
    OutputOp,
    Padding,
    kFragmentsPerIteration
  >;
};
806
+
807
+ ////////////////////////////////////////////////////////////////////////////////
808
/// Defines sensible defaults for epilogues for TensorOps which use an
/// interleaved output layout. For this case, shared memory is not needed.
template <typename Shape_, typename WarpMmaTensorOp_, int PartitionsK,
          typename OutputOp_, int ElementsPerAccess, int InterleavedK,
          bool isSplitK = false>
struct DefaultInterleavedEpilogueTensorOp {
  using Shape = Shape_;
  using WarpMmaTensorOp = WarpMmaTensorOp_;
  static int const kPartitionsK = PartitionsK;
  using OutputOp = OutputOp_;
  static int const kElementsPerAccess = ElementsPerAccess;

  using ElementOutput = typename OutputOp::ElementOutput;
  using LayoutC = typename WarpMmaTensorOp::LayoutC;
  using ElementAccumulator = typename WarpMmaTensorOp::ElementC;

  //
  // Thread map
  //
  using OutputTileThreadMap = typename cutlass::epilogue::threadblock::
      DefaultInterleavedThreadMapTensorOp<
          Shape, typename WarpMmaTensorOp::Shape, kPartitionsK, ElementOutput,
          kElementsPerAccess, InterleavedK>::Type;

  // Output iterator for the interleaved layout.
  using OutputTileIterator =
      cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator<
          OutputTileThreadMap, ElementOutput, InterleavedK>;

  using AccumulatorFragmentIterator =
      cutlass::epilogue::warp::FragmentIteratorTensorOp<
          typename WarpMmaTensorOp::Shape,
          typename WarpMmaTensorOp::Policy::Operator::Shape,
          typename WarpMmaTensorOp::Policy::Operator::ElementC,
          typename WarpMmaTensorOp::Policy::Operator::FragmentC,
          LayoutC>;

  //
  // Define the epilogue (InterleavedEpilogue writes accumulators directly,
  // without staging through shared memory)
  //
  using Epilogue = cutlass::epilogue::threadblock::InterleavedEpilogue<
      Shape, WarpMmaTensorOp, kPartitionsK, OutputTileIterator,
      AccumulatorFragmentIterator, OutputOp, InterleavedK>;
};
851
+
852
+ ////////////////////////////////////////////////////////////////////////////////
853
+
854
/// Defines sensible defaults for convolution epilogues for TensorOps which
/// use an interleaved output layout. For this case, shared memory is not needed.
template <typename Shape_, typename WarpMmaTensorOp_, int PartitionsK,
          typename OutputOp_, int ElementsPerAccess, int InterleavedK,
          bool isSplitK = false>
struct DefaultInterleavedConvEpilogue {
  using Shape = Shape_;
  using WarpMmaTensorOp = WarpMmaTensorOp_;
  static int const kPartitionsK = PartitionsK;
  using OutputOp = OutputOp_;
  static int const kElementsPerAccess = ElementsPerAccess;

  using ElementOutput = typename OutputOp::ElementOutput;
  using ElementAccumulator = typename WarpMmaTensorOp::ElementC;

  //
  // Thread map
  //
  using OutputTileThreadMap = typename cutlass::epilogue::threadblock::
      DefaultInterleavedConvThreadMapTensorOp<
          Shape, typename WarpMmaTensorOp::Shape, kPartitionsK, ElementOutput,
          kElementsPerAccess, InterleavedK>::Type;

  // Output iterator for interleaved convolution output tensors.
  using OutputTileIterator =
      cutlass::epilogue::threadblock::InterleavedConvPredicatedTileIterator<
          OutputTileThreadMap, ElementOutput, InterleavedK>;

  using AccumulatorFragmentIterator =
      cutlass::epilogue::warp::FragmentIteratorTensorOp<
          typename WarpMmaTensorOp::Shape,
          typename WarpMmaTensorOp::Policy::Operator::Shape,
          typename WarpMmaTensorOp::Policy::Operator::ElementC,
          typename WarpMmaTensorOp::Policy::Operator::FragmentC,
          // can reuse the gemm version here to do element selection
          layout::ColumnMajorInterleaved<InterleavedK>>;

  //
  // Define the epilogue
  //
  using Epilogue = cutlass::epilogue::threadblock::InterleavedEpilogue<
      Shape, WarpMmaTensorOp, kPartitionsK, OutputTileIterator,
      AccumulatorFragmentIterator, OutputOp, InterleavedK>;
};
897
+
898
+ ////////////////////////////////////////////////////////////////////////////////
899
+
900
+ } // namespace threadblock
901
+ } // namespace epilogue
902
+ } // namespace cutlass
903
+
904
+ ////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op_blas3.h ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
33
+
34
+ The epilogue rearranges the result of a matrix product through shared memory to match canonical
35
+ tensor layouts in global memory. Epilogues support conversion and reduction operations.
36
+
37
+
38
+ */
39
+
40
+ #pragma once
41
+
42
+ #include "cutlass/cutlass.h"
43
+ #include "cutlass/numeric_types.h"
44
+ #include "cutlass/array.h"
45
+
46
+ #include "cutlass/gemm/gemm.h"
47
+
48
+ #include "cutlass/epilogue/thread/linear_combination.h"
49
+ #include "cutlass/epilogue/thread/linear_combination_clamp.h"
50
+ #include "cutlass/epilogue/thread/linear_combination_relu.h"
51
+ #include "cutlass/epilogue/thread/linear_combination_gelu.h"
52
+ #include "cutlass/epilogue/thread/linear_combination_sigmoid.h"
53
+ #include "cutlass/epilogue/thread/linear_combination_planar_complex.h"
54
+
55
+ #include "cutlass/epilogue/thread/conversion_op.h"
56
+ #include "cutlass/epilogue/thread/reduction_op.h"
57
+
58
+ #include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h"
59
+
60
+ #include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h"
61
+ #include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h"
62
+ #include "cutlass/epilogue/warp/tile_iterator_tensor_op.h"
63
+ #include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h"
64
+ #include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h"
65
+ #include "cutlass/epilogue/threadblock/predicated_tile_iterator_blas3.h"
66
+ #include "cutlass/epilogue/threadblock/shared_load_iterator.h"
67
+ #include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h"
68
+
69
+ #include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
70
+ #include "cutlass/epilogue/threadblock/epilogue.h"
71
+ #include "cutlass/epilogue/threadblock/interleaved_epilogue.h"
72
+
73
+ ////////////////////////////////////////////////////////////////////////////////
74
+
75
+ namespace cutlass {
76
+ namespace epilogue {
77
+ namespace threadblock {
78
+
79
+ ////////////////////////////////////////////////////////////////////////////////
80
+
81
+ /// Defines sensible defaults for epilogues for TensorOps.
82
+ template <
83
+ typename Shape_,
84
+ typename WarpMmaTensorOp_,
85
+ int PartitionsK,
86
+ typename OutputOp_,
87
+ int ElementsPerAccess,
88
+ /// Is for a symmetric kernel
89
+ BlasMode BlasMode_ = BlasMode::kGemm
90
+ >
91
+ struct DefaultEpilogueTensorOpBlas3 {
92
+
93
+ using Shape = Shape_;
94
+ using WarpMmaTensorOp = WarpMmaTensorOp_;
95
+ static int const kPartitionsK = PartitionsK;
96
+ using OutputOp = OutputOp_;
97
+ static int const kElementsPerAccess = ElementsPerAccess;
98
+ static BlasMode const kBlasMode = BlasMode_;
99
+
100
+ using ElementOutput = typename OutputOp::ElementOutput;
101
+ using LayoutC = typename WarpMmaTensorOp::LayoutC;
102
+ using ElementAccumulator = typename WarpMmaTensorOp::ElementC;
103
+
104
+ //
105
+ // Thread map
106
+ //
107
+
108
+ using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp<
109
+ Shape,
110
+ typename WarpMmaTensorOp::Shape,
111
+ kPartitionsK,
112
+ ElementOutput,
113
+ kElementsPerAccess
114
+ >::Type;
115
+
116
+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorBlas3<
117
+ OutputTileThreadMap,
118
+ ElementOutput,
119
+ kBlasMode
120
+ >;
121
+
122
+ using AccumulatorFragmentIterator = typename platform::conditional<is_complex<ElementOutput>::value,
123
+ cutlass::epilogue::warp::FragmentIteratorComplexTensorOp<
124
+ typename WarpMmaTensorOp::Shape,
125
+ typename WarpMmaTensorOp::Policy::Operator::Shape,
126
+ typename WarpMmaTensorOp::Policy::Operator::ElementC,
127
+ typename WarpMmaTensorOp::Policy::Operator::FragmentC,
128
+ LayoutC>,
129
+ cutlass::epilogue::warp::FragmentIteratorTensorOp<
130
+ typename WarpMmaTensorOp::Shape,
131
+ typename WarpMmaTensorOp::Policy::Operator::Shape,
132
+ typename WarpMmaTensorOp::Policy::Operator::ElementC,
133
+ typename WarpMmaTensorOp::Policy::Operator::FragmentC,
134
+ LayoutC> >::type;
135
+
136
+ /// Support several implementations depending on structure of epilogue
137
+ using DefaultIterators = detail::DefaultIteratorsTensorOp<
138
+ ElementOutput,
139
+ ElementAccumulator,
140
+ kElementsPerAccess,
141
+ Shape,
142
+ typename WarpMmaTensorOp::Shape,
143
+ typename WarpMmaTensorOp::Policy::Operator::Shape,
144
+ typename OutputTileThreadMap::CompactedThreadMap
145
+ >;
146
+
147
+ using WarpTileIterator = typename DefaultIterators::WarpTileIterator;
148
+ using SharedLoadIterator = typename DefaultIterators::SharedLoadIterator;
149
+
150
+ /// Hard-coded padding elements added
151
+ using Padding = cutlass::MatrixShape<0, 64 / sizeof_bits<ElementAccumulator>::value * 4>;
152
+
153
+ //
154
+ // Define the epilogue
155
+ //
156
+ using Epilogue = cutlass::epilogue::threadblock::Epilogue<
157
+ Shape,
158
+ WarpMmaTensorOp,
159
+ kPartitionsK,
160
+ OutputTileIterator,
161
+ AccumulatorFragmentIterator,
162
+ WarpTileIterator,
163
+ SharedLoadIterator,
164
+ OutputOp,
165
+ Padding
166
+ >;
167
+ };
168
+
169
+ ////////////////////////////////////////////////////////////////////////////////
170
+
171
+ } // namespace threadblock
172
+ } // namespace epilogue
173
+ } // namespace cutlass
174
+
175
+ ////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h ADDED
@@ -0,0 +1,337 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops on Volta.
33
+
34
+ The epilogue rearranges the result of a matrix product through shared memory to match canonical
35
+ tensor layouts in global memory. Epilogues support conversion and reduction operations.
36
+
37
+ */
38
+
39
+ #pragma once
40
+
41
+ #include "cutlass/cutlass.h"
42
+ #include "cutlass/numeric_types.h"
43
+ #include "cutlass/array.h"
44
+
45
+ #include "cutlass/gemm/gemm.h"
46
+
47
+ #include "cutlass/epilogue/thread/linear_combination.h"
48
+ #include "cutlass/epilogue/thread/linear_combination_clamp.h"
49
+ #include "cutlass/epilogue/thread/linear_combination_relu.h"
50
+ #include "cutlass/epilogue/thread/linear_combination_gelu.h"
51
+ #include "cutlass/epilogue/thread/linear_combination_sigmoid.h"
52
+ #include "cutlass/epilogue/thread/linear_combination_planar_complex.h"
53
+
54
+ #include "cutlass/epilogue/thread/conversion_op.h"
55
+ #include "cutlass/epilogue/thread/reduction_op.h"
56
+
57
+ #include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h"
58
+ #include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h"
59
+ #include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
60
+ #include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h"
61
+ #include "cutlass/epilogue/threadblock/shared_load_iterator.h"
62
+
63
+ #include "cutlass/epilogue/warp/fragment_iterator_volta_tensor_op.h"
64
+ #include "cutlass/epilogue/warp/tile_iterator_volta_tensor_op.h"
65
+ #include "cutlass/epilogue/threadblock/default_thread_map_volta_tensor_op.h"
66
+
67
+ #include "cutlass/epilogue/threadblock/epilogue.h"
68
+
69
+ #include "cutlass/layout/permute.h"
70
+
71
+ /////////////////////////////////////////////////////////////////////////////////////////////////
72
+
73
+ namespace cutlass {
74
+ namespace epilogue {
75
+ namespace threadblock {
76
+
77
+ /////////////////////////////////////////////////////////////////////////////////////////////////
78
+
79
+ /// Defines sensible defaults for epilogues for TensorOps.
80
+ template <
81
+ typename Shape_,
82
+ typename WarpMmaTensorOp_,
83
+ int PartitionsK,
84
+ typename OutputOp_,
85
+ int ElementsPerAccess,
86
+ bool ScatterD = false,
87
+ typename PermuteDLayout = layout::NoPermute
88
+ >
89
+ struct DefaultEpilogueVoltaTensorOp {
90
+
91
+ using Shape = Shape_;
92
+ using WarpMmaTensorOp = WarpMmaTensorOp_;
93
+ static int const kPartitionsK = PartitionsK;
94
+ using OutputOp = OutputOp_;
95
+ static int const kElementsPerAccess = ElementsPerAccess;
96
+
97
+ using ElementOutput = typename OutputOp::ElementOutput;
98
+ using LayoutC = typename WarpMmaTensorOp::LayoutC;
99
+ using ElementAccumulator = typename WarpMmaTensorOp::ElementC;
100
+
101
+ //
102
+ // Thread map
103
+ //
104
+
105
+ using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp<
106
+ Shape,
107
+ typename WarpMmaTensorOp::Shape,
108
+ kPartitionsK,
109
+ ElementOutput,
110
+ kElementsPerAccess,
111
+ ElementAccumulator
112
+ >::Type;
113
+
114
+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
115
+ OutputTileThreadMap,
116
+ ElementOutput,
117
+ ScatterD,
118
+ PermuteDLayout
119
+ >;
120
+
121
+ using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorVoltaTensorOp<
122
+ typename WarpMmaTensorOp::Shape,
123
+ gemm::GemmShape<32, 32, 4>,
124
+ ElementAccumulator,
125
+ LayoutC
126
+ >;
127
+
128
+ using WarpTileIterator = cutlass::epilogue::warp::TileIteratorVoltaTensorOp<
129
+ typename WarpMmaTensorOp::Shape,
130
+ gemm::GemmShape<32, 32, 4>,
131
+ ElementAccumulator,
132
+ LayoutC
133
+ >;
134
+
135
+ static int const kSharedMemAlignment = sizeof_bits<ElementAccumulator>::value * WarpTileIterator::kElementsPerAccess / 8;
136
+
137
+ static_assert(kSharedMemAlignment == 8, "Shared memory alignment must be 8B");
138
+
139
+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<
140
+ typename OutputTileThreadMap::CompactedThreadMap,
141
+ ElementAccumulator,
142
+ kSharedMemAlignment
143
+ >;
144
+
145
+ /// Hard-coded padding elements added
146
+ using Padding = typename WarpTileIterator::Padding;
147
+
148
+ //
149
+ // Define the epilogue
150
+ //
151
+ using Epilogue = cutlass::epilogue::threadblock::Epilogue<
152
+ Shape,
153
+ WarpMmaTensorOp,
154
+ kPartitionsK,
155
+ OutputTileIterator,
156
+ AccumulatorFragmentIterator,
157
+ WarpTileIterator,
158
+ SharedLoadIterator,
159
+ OutputOp,
160
+ Padding
161
+ >;
162
+ };
163
+
164
+ /////////////////////////////////////////////////////////////////////////////////////////////////
165
+
166
+ /// Defines sensible defaults for epilogues for TensorOps.
167
+ template <
168
+ typename Shape_,
169
+ typename WarpMmaTensorOp_,
170
+ int PartitionsK,
171
+ typename OutputOp_,
172
+ int ElementsPerAccess
173
+ >
174
+ struct DefaultEpilogueVoltaTensorOpStridedDgrad {
175
+
176
+ using Shape = Shape_;
177
+ using WarpMmaTensorOp = WarpMmaTensorOp_;
178
+ static int const kPartitionsK = PartitionsK;
179
+ using OutputOp = OutputOp_;
180
+ static int const kElementsPerAccess = ElementsPerAccess;
181
+
182
+ using ElementOutput = typename OutputOp::ElementOutput;
183
+ using LayoutC = typename WarpMmaTensorOp::LayoutC;
184
+ using ElementAccumulator = typename WarpMmaTensorOp::ElementC;
185
+
186
+ //
187
+ // Thread map
188
+ //
189
+
190
+ using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp<
191
+ Shape,
192
+ typename WarpMmaTensorOp::Shape,
193
+ kPartitionsK,
194
+ ElementOutput,
195
+ kElementsPerAccess,
196
+ ElementAccumulator
197
+ >::Type;
198
+
199
+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorStridedDgrad<
200
+ OutputTileThreadMap,
201
+ ElementOutput
202
+ >;
203
+
204
+ using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorVoltaTensorOp<
205
+ typename WarpMmaTensorOp::Shape,
206
+ gemm::GemmShape<32, 32, 4>,
207
+ ElementAccumulator,
208
+ LayoutC
209
+ >;
210
+
211
+ using WarpTileIterator = cutlass::epilogue::warp::TileIteratorVoltaTensorOp<
212
+ typename WarpMmaTensorOp::Shape,
213
+ gemm::GemmShape<32, 32, 4>,
214
+ ElementAccumulator,
215
+ LayoutC
216
+ >;
217
+
218
+ static int const kSharedMemAlignment = sizeof_bits<ElementAccumulator>::value * WarpTileIterator::kElementsPerAccess / 8;
219
+
220
+ static_assert(kSharedMemAlignment == 8, "Shared memory alignment must be 8B");
221
+
222
+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<
223
+ typename OutputTileThreadMap::CompactedThreadMap,
224
+ ElementAccumulator,
225
+ kSharedMemAlignment
226
+ >;
227
+
228
+ /// Hard-coded padding elements added
229
+ using Padding = typename WarpTileIterator::Padding;
230
+
231
+ //
232
+ // Define the epilogue
233
+ //
234
+ using Epilogue = cutlass::epilogue::threadblock::Epilogue<
235
+ Shape,
236
+ WarpMmaTensorOp,
237
+ kPartitionsK,
238
+ OutputTileIterator,
239
+ AccumulatorFragmentIterator,
240
+ WarpTileIterator,
241
+ SharedLoadIterator,
242
+ OutputOp,
243
+ Padding
244
+ >;
245
+ };
246
+
247
+ /////////////////////////////////////////////////////////////////////////////////////////////////
248
+
249
+ /// Defines sensible defaults for epilogues for TensorOps.
250
+ template <
251
+ int Rank,
252
+ typename Shape_,
253
+ typename WarpMmaTensorOp_,
254
+ int PartitionsK,
255
+ typename OutputOp_,
256
+ int ElementsPerAccess
257
+ >
258
+ struct DefaultEpilogueVoltaTensorOpAffineRankN {
259
+
260
+ using Shape = Shape_;
261
+ using WarpMmaTensorOp = WarpMmaTensorOp_;
262
+ static int const kPartitionsK = PartitionsK;
263
+ using OutputOp = OutputOp_;
264
+ static int const kElementsPerAccess = ElementsPerAccess;
265
+
266
+ using ElementOutput = typename OutputOp::ElementOutput;
267
+ using LayoutC = typename WarpMmaTensorOp::LayoutC;
268
+ using ElementAccumulator = typename WarpMmaTensorOp::ElementC;
269
+
270
+ //
271
+ // Thread map
272
+ //
273
+
274
+ using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp<
275
+ Shape,
276
+ typename WarpMmaTensorOp::Shape,
277
+ kPartitionsK,
278
+ ElementOutput,
279
+ kElementsPerAccess,
280
+ ElementAccumulator
281
+ >::Type;
282
+
283
+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorAffineRankN<
284
+ OutputTileThreadMap,
285
+ ElementOutput,
286
+ Rank
287
+ >;
288
+
289
+ using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorVoltaTensorOp<
290
+ typename WarpMmaTensorOp::Shape,
291
+ gemm::GemmShape<32, 32, 4>,
292
+ ElementAccumulator,
293
+ LayoutC
294
+ >;
295
+
296
+ using WarpTileIterator = cutlass::epilogue::warp::TileIteratorVoltaTensorOp<
297
+ typename WarpMmaTensorOp::Shape,
298
+ gemm::GemmShape<32, 32, 4>,
299
+ ElementAccumulator,
300
+ LayoutC
301
+ >;
302
+
303
+ static int const kSharedMemAlignment = sizeof_bits<ElementAccumulator>::value * WarpTileIterator::kElementsPerAccess / 8;
304
+
305
+ static_assert(kSharedMemAlignment == 8, "Shared memory alignment must be 8B");
306
+
307
+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<
308
+ typename OutputTileThreadMap::CompactedThreadMap,
309
+ ElementAccumulator,
310
+ kSharedMemAlignment
311
+ >;
312
+
313
+ /// Hard-coded padding elements added
314
+ using Padding = typename WarpTileIterator::Padding;
315
+
316
+ //
317
+ // Define the epilogue
318
+ //
319
+ using Epilogue = cutlass::epilogue::threadblock::Epilogue<
320
+ Shape,
321
+ WarpMmaTensorOp,
322
+ kPartitionsK,
323
+ OutputTileIterator,
324
+ AccumulatorFragmentIterator,
325
+ WarpTileIterator,
326
+ SharedLoadIterator,
327
+ OutputOp,
328
+ Padding
329
+ >;
330
+ };
331
+
332
+ /////////////////////////////////////////////////////////////////////////////////////////////////
333
+ } // namespace threadblock
334
+ } // namespace epilogue
335
+ } // namespace cutlass
336
+
337
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_with_absmax.h ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ /*! \file
33
+ \brief Default configuration for epilogue computing absolute maximum of output and auxiliary outputs.
34
+ */
35
+
36
+ #pragma once
37
+
38
+ #include "cutlass/cutlass.h"
39
+ #include "cutlass/numeric_types.h"
40
+ #include "cutlass/array.h"
41
+
42
+ #include "cutlass/gemm/gemm.h"
43
+
44
+ #include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
45
+ #include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
46
+ #include "cutlass/epilogue/threadblock/epilogue.h"
47
+ #include "cutlass/epilogue/threadblock/epilogue_with_absmax.h"
48
+
49
+ #include "cutlass/layout/permute.h"
50
+
51
+ ////////////////////////////////////////////////////////////////////////////////
52
+
53
+ namespace cutlass {
54
+ namespace epilogue {
55
+ namespace threadblock {
56
+
57
+ ////////////////////////////////////////////////////////////////////////////////
58
+
59
+ /// Defines sensible defaults for absolute-maximum-computing epilogues with TensorOps
60
+ template <
61
+ typename Shape,
62
+ typename WarpMmaTensorOp,
63
+ int PartitionsK,
64
+ typename ElementOutput,
65
+ typename ElementAuxOutput,
66
+ typename ElementVector,
67
+ typename OutputOp,
68
+ int ElementsPerAccess,
69
+ bool ScatterD = false,
70
+ typename PermuteDLayout = layout::NoPermute
71
+ >
72
+ struct DefaultEpilogueWithAbsMax {
73
+
74
+ /// Use defaults related to the existing epilogue
75
+ using Base = DefaultEpilogueTensorOp<
76
+ Shape,
77
+ WarpMmaTensorOp,
78
+ PartitionsK,
79
+ OutputOp,
80
+ ElementsPerAccess
81
+ >;
82
+
83
+ //
84
+ // Stores the output
85
+ //
86
+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
87
+ typename Base::OutputTileThreadMap,
88
+ ElementOutput,
89
+ ScatterD,
90
+ PermuteDLayout
91
+ >;
92
+
93
+ //
94
+ // Stores the auxiliary output
95
+ //
96
+ using AuxOutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
97
+ typename Base::OutputTileThreadMap,
98
+ ElementAuxOutput,
99
+ ScatterD,
100
+ PermuteDLayout
101
+ >;
102
+
103
+ /// Define the epilogue
104
+ using Epilogue = EpilogueWithAbsMax<
105
+ Shape,
106
+ WarpMmaTensorOp,
107
+ PartitionsK,
108
+ OutputTileIterator,
109
+ AuxOutputTileIterator,
110
+ ElementVector,
111
+ typename Base::AccumulatorFragmentIterator,
112
+ typename Base::WarpTileIterator,
113
+ typename Base::SharedLoadIterator,
114
+ OutputOp,
115
+ typename Base::Padding,
116
+ Base::kFragmentsPerIteration
117
+ >;
118
+ };
119
+
120
+ ////////////////////////////////////////////////////////////////////////////////
121
+
122
+ } // namespace threadblock
123
+ } // namespace epilogue
124
+ } // namespace cutlass
125
+
126
+ ////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h ADDED
@@ -0,0 +1,376 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
33
+
34
+ The epilogue rearranges the result of a matrix product through shared memory to match canonical
35
+ tensor layouts in global memory. Epilogues support conversion and reduction operations.
36
+
37
+ */
38
+
39
+ #pragma once
40
+
41
+ #include "cutlass/cutlass.h"
42
+ #include "cutlass/numeric_types.h"
43
+ #include "cutlass/array.h"
44
+
45
+ #include "cutlass/gemm/gemm.h"
46
+
47
+ #include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
48
+ #include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
49
+ #include "cutlass/epilogue/threadblock/epilogue.h"
50
+ #include "cutlass/epilogue/threadblock/epilogue_with_broadcast.h"
51
+ #include "cutlass/epilogue/threadblock/epilogue_streamk_with_broadcast.h"
52
+
53
+ #include "cutlass/layout/permute.h"
54
+
55
+ ////////////////////////////////////////////////////////////////////////////////
56
+
57
+ namespace cutlass {
58
+ namespace epilogue {
59
+ namespace threadblock {
60
+ ////////////////////////////////////////////////////////////////////////////////
61
+
62
+ /// Defines sensible defaults for epilogues for SimtOps.
63
+ template <
64
+ typename Shape,
65
+ typename WarpMmaSimt,
66
+ typename ElementOutput,
67
+ typename ElementTensor,
68
+ typename ElementVector,
69
+ typename OutputOp,
70
+ int ElementsPerAccess,
71
+ bool ScatterD = false,
72
+ typename PermuteDLayout = layout::NoPermute,
73
+ conv::StrideSupport StrideSupport = conv::StrideSupport::kUnity,
74
+ int Rank = 4
75
+ >
76
+ struct DefaultEpilogueWithBroadcastSimt {
77
+
78
+ static conv::StrideSupport const kStrideSupport = StrideSupport;
79
+ static int const kRank = Rank;
80
+
81
+ static bool const UseCUDAStore = platform::is_same<ElementOutput, double>::value;
82
+
83
+ /// Use defaults related to the existing epilogue
84
+ using Base = DefaultEpilogueSimt<
85
+ Shape,
86
+ WarpMmaSimt,
87
+ OutputOp,
88
+ ElementsPerAccess
89
+ >;
90
+
91
+ using PackedOutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
92
+ typename Base::OutputTileThreadMap,
93
+ ElementOutput,
94
+ ScatterD,
95
+ PermuteDLayout,
96
+ UseCUDAStore
97
+ >;
98
+
99
+ using StridedOutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorConv<
100
+ typename Base::OutputTileThreadMap,
101
+ ElementOutput,
102
+ ScatterD,
103
+ PermuteDLayout,
104
+ UseCUDAStore,
105
+ kRank
106
+ >;
107
+
108
+ //
109
+ // Stores the result z = (y = GEMM(A, B, C), broadcast)
110
+ //
111
+ using OutputTileIterator = typename platform::conditional<StrideSupport == cutlass::conv::StrideSupport::kUnity,
112
+ PackedOutputTileIterator,
113
+ StridedOutputTileIterator>::type;
114
+
115
+ //
116
+ // Additional tensor tile iterator - stores t = Elementwise(z)
117
+ //
118
+ using TensorTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
119
+ typename Base::OutputTileThreadMap,
120
+ ElementTensor
121
+ >;
122
+ /// Define the epilogue
123
+ using Epilogue = EpilogueWithBroadcast<
124
+ Shape,
125
+ WarpMmaSimt,
126
+ Base::kPartitionsK,
127
+ OutputTileIterator,
128
+ TensorTileIterator,
129
+ ElementVector,
130
+ typename Base::AccumulatorFragmentIterator,
131
+ typename Base::WarpTileIterator,
132
+ typename Base::SharedLoadIterator,
133
+ OutputOp,
134
+ typename Base::Padding
135
+ >;
136
+ };
137
+ ////////////////////////////////////////////////////////////////////////////////
138
+
139
/// Defines sensible defaults for strided dgrad epilogues for SimtOps.
///
/// Metafunction only: composes the existing strided-dgrad SIMT epilogue
/// defaults with strided-dgrad tile iterators and exposes the result as
/// `Epilogue`. The composed epilogue writes a primary output z and an
/// additional elementwise tensor t (see alias comments below).
template <
  typename Shape,          ///< Threadblock-scoped tile shape
  typename WarpMmaSimt,    ///< Warp-level SIMT MMA operator
  typename ElementOutput,  ///< Element type of the primary output z
  typename ElementTensor,  ///< Element type of the auxiliary tensor t
  typename ElementVector,  ///< Element type of the broadcast vector
  typename OutputOp,       ///< Thread-level output operator
  int ElementsPerAccess,   ///< Vector width of global memory accesses
  bool ScatterD = false,   ///< NOTE(review): kept for interface parity; not forwarded to the iterators below
  typename PermuteDLayout = layout::NoPermute  ///< NOTE(review): likewise not forwarded here
>
struct DefaultEpilogueWithBroadcastSimtStridedDgrad {

  /// Use defaults related to the existing epilogue
  using Base = DefaultEpilogueSimtStridedDgrad<
    Shape,
    WarpMmaSimt,
    OutputOp,
    ElementsPerAccess
  >;

  //
  // Stores the result z = (y = GEMM(A, B, C), broadcast)
  //
  using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorStridedDgrad<
    typename Base::OutputTileThreadMap,
    ElementOutput
  >;

  //
  // Additional tensor tile iterator - stores t = Elementwise(z)
  //
  using TensorTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorStridedDgrad<
    typename Base::OutputTileThreadMap,
    ElementTensor
  >;

  /// Define the epilogue: reuses the fragment/warp/shared-memory iterators
  /// from Base and adds the broadcast vector element type.
  using Epilogue = EpilogueWithBroadcast<
    Shape,
    WarpMmaSimt,
    Base::kPartitionsK,
    OutputTileIterator,
    TensorTileIterator,
    ElementVector,
    typename Base::AccumulatorFragmentIterator,
    typename Base::WarpTileIterator,
    typename Base::SharedLoadIterator,
    OutputOp,
    typename Base::Padding
  >;
};
192
+ ////////////////////////////////////////////////////////////////////////////////
193
+
194
/// Defines sensible defaults for epilogues for TensorOps.
///
/// Metafunction only: composes the tensor-op epilogue defaults with predicated
/// tile iterators (optionally scattered/permuted) and exposes the result as
/// `Epilogue`, which writes a primary output z and an auxiliary tensor t.
template <
  typename Shape,            ///< Threadblock-scoped tile shape
  typename WarpMmaTensorOp,  ///< Warp-level tensor-op MMA operator
  int PartitionsK,           ///< Number of partitions of the K dimension
  typename ElementOutput,    ///< Element type of the primary output z
  typename ElementTensor,    ///< Element type of the auxiliary tensor t
  typename ElementVector,    ///< Element type of the broadcast vector
  typename OutputOp,         ///< Thread-level output operator
  int ElementsPerAccess,     ///< Vector width of global memory accesses
  bool ScatterD = false,     ///< If true, D is written through a scatter index
  typename PermuteDLayout = layout::NoPermute  ///< Optional permutation applied to D
>
struct DefaultEpilogueWithBroadcastTensorOp {

  /// Use defaults related to the existing epilogue
  using Base = DefaultEpilogueTensorOp<
    Shape,
    WarpMmaTensorOp,
    PartitionsK,
    OutputOp,
    ElementsPerAccess
  >;

  //
  // Stores the result z = (y = GEMM(A, B, C), broadcast)
  //
  using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
    typename Base::OutputTileThreadMap,
    ElementOutput,
    ScatterD,
    PermuteDLayout
  >;

  //
  // Additional tensor tile iterator - stores t = Elementwise(z)
  //
  using TensorTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
    typename Base::OutputTileThreadMap,
    ElementTensor
  >;

  /// Define the epilogue; forwards Base::kFragmentsPerIteration so the
  /// double-buffering behavior of the underlying default is preserved.
  using Epilogue = EpilogueWithBroadcast<
    Shape,
    WarpMmaTensorOp,
    PartitionsK,
    OutputTileIterator,
    TensorTileIterator,
    ElementVector,
    typename Base::AccumulatorFragmentIterator,
    typename Base::WarpTileIterator,
    typename Base::SharedLoadIterator,
    OutputOp,
    typename Base::Padding,
    Base::kFragmentsPerIteration
  >;
};
252
+
253
+ ////////////////////////////////////////////////////////////////////////////////
254
+
255
/// Defines sensible defaults for streamk epilogues for TensorOps.
///
/// Identical in composition to DefaultEpilogueWithBroadcastTensorOp except
/// that the final alias instantiates EpilogueStreamkWithBroadcast, the
/// stream-K variant of the broadcast epilogue.
template <
  typename Shape,            ///< Threadblock-scoped tile shape
  typename WarpMmaTensorOp,  ///< Warp-level tensor-op MMA operator
  int PartitionsK,           ///< Number of partitions of the K dimension
  typename ElementOutput,    ///< Element type of the primary output z
  typename ElementTensor,    ///< Element type of the auxiliary tensor t
  typename ElementVector,    ///< Element type of the broadcast vector
  typename OutputOp,         ///< Thread-level output operator
  int ElementsPerAccess,     ///< Vector width of global memory accesses
  bool ScatterD = false,     ///< If true, D is written through a scatter index
  typename PermuteDLayout = layout::NoPermute  ///< Optional permutation applied to D
>
struct DefaultStreamkEpilogueWithBroadcastTensorOp {

  /// Use defaults related to the existing epilogue
  using Base = DefaultEpilogueTensorOp<
    Shape,
    WarpMmaTensorOp,
    PartitionsK,
    OutputOp,
    ElementsPerAccess
  >;

  //
  // Stores the result z = (y = GEMM(A, B, C), broadcast)
  //
  using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
    typename Base::OutputTileThreadMap,
    ElementOutput,
    ScatterD,
    PermuteDLayout
  >;

  //
  // Additional tensor tile iterator - stores t = Elementwise(z)
  //
  using TensorTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
    typename Base::OutputTileThreadMap,
    ElementTensor
  >;

  /// Define the epilogue (stream-K variant)
  using Epilogue = EpilogueStreamkWithBroadcast<
    Shape,
    WarpMmaTensorOp,
    PartitionsK,
    OutputTileIterator,
    TensorTileIterator,
    ElementVector,
    typename Base::AccumulatorFragmentIterator,
    typename Base::WarpTileIterator,
    typename Base::SharedLoadIterator,
    OutputOp,
    typename Base::Padding,
    Base::kFragmentsPerIteration
  >;
};
313
+
314
+ ////////////////////////////////////////////////////////////////////////////////
315
+
316
/// Defines sensible defaults for epilogues for VoltaTensorOps.
///
/// Metafunction only: mirrors DefaultEpilogueWithBroadcastTensorOp but builds
/// on the Volta tensor-op defaults; note it exposes no ScatterD/PermuteDLayout
/// parameters and no kFragmentsPerIteration forwarding.
template <
  typename Shape,            ///< Threadblock-scoped tile shape
  typename WarpMmaTensorOp,  ///< Warp-level Volta tensor-op MMA operator
  int PartitionsK,           ///< Number of partitions of the K dimension
  typename ElementOutput,    ///< Element type of the primary output z
  typename ElementTensor,    ///< Element type of the auxiliary tensor t
  typename ElementVector,    ///< Element type of the broadcast vector
  typename OutputOp,         ///< Thread-level output operator
  int ElementsPerAccess      ///< Vector width of global memory accesses
>
struct DefaultEpilogueWithBroadcastVoltaTensorOp {

  /// Use defaults related to the existing epilogue
  using Base = DefaultEpilogueVoltaTensorOp<
    Shape,
    WarpMmaTensorOp,
    PartitionsK,
    OutputOp,
    ElementsPerAccess
  >;

  //
  // Stores the result z = (y = GEMM(A, B, C), broadcast)
  //
  using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
    typename Base::OutputTileThreadMap,
    ElementOutput
  >;

  //
  // Additional tensor tile iterator - stores t = Elementwise(z)
  //
  using TensorTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
    typename Base::OutputTileThreadMap,
    ElementTensor
  >;

  /// Define the epilogue
  using Epilogue = EpilogueWithBroadcast<
    Shape,
    WarpMmaTensorOp,
    PartitionsK,
    OutputTileIterator,
    TensorTileIterator,
    ElementVector,
    typename Base::AccumulatorFragmentIterator,
    typename Base::WarpTileIterator,
    typename Base::SharedLoadIterator,
    OutputOp,
    typename Base::Padding
  >;
};
369
+
370
+ ////////////////////////////////////////////////////////////////////////////////
371
+
372
+ } // namespace threadblock
373
+ } // namespace epilogue
374
+ } // namespace cutlass
375
+
376
+ ////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_with_reduction.h ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+
33
+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
34
+
35
+ The epilogue rearranges the result of a matrix product through shared memory to match canonical
36
+ tensor layouts in global memory. Epilogues support conversion and reduction operations.
37
+
38
+ */
39
+
40
+ #pragma once
41
+
42
+ #include "cutlass/cutlass.h"
43
+ #include "cutlass/numeric_types.h"
44
+ #include "cutlass/array.h"
45
+
46
+ #include "cutlass/gemm/gemm.h"
47
+
48
+ #include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
49
+ #include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
50
+ #include "cutlass/epilogue/threadblock/epilogue.h"
51
+ #include "cutlass/epilogue/threadblock/epilogue_with_reduction.h"
52
+
53
+ #include "cutlass/layout/permute.h"
54
+
55
+ ////////////////////////////////////////////////////////////////////////////////
56
+
57
+ namespace cutlass {
58
+ namespace epilogue {
59
+ namespace threadblock {
60
+
61
+ ////////////////////////////////////////////////////////////////////////////////
62
+
63
/// Defines sensible defaults for epilogues for TensorOps.
///
/// Metafunction only: composes the tensor-op epilogue defaults with an extra
/// tensor iterator (element type taken from OutputOp::ElementTensor) and a
/// caller-supplied ReductionOp, exposing the result as EpilogueWithReduction.
template <
  typename Shape,            ///< Threadblock-scoped tile shape
  typename WarpMmaTensorOp,  ///< Warp-level tensor-op MMA operator
  int PartitionsK,           ///< Number of partitions of the K dimension
  typename ElementOutput,    ///< Element type of the primary output D
  typename OutputOp,         ///< Thread-level output operator (supplies ElementTensor)
  typename ReductionOp,      ///< Reduction operator applied by the epilogue
  int ElementsPerAccess,     ///< Vector width of global memory accesses
  bool ScatterD = false,     ///< If true, D is written through a scatter index
  typename PermuteDLayout = layout::NoPermute  ///< Optional permutation applied to D
>
struct DefaultEpilogueWithReductionTensorOp {

  /// Use defaults related to the existing epilogue
  using Base = DefaultEpilogueTensorOp<
    Shape,
    WarpMmaTensorOp,
    PartitionsK,
    OutputOp,
    ElementsPerAccess
  >;

  /// Additional tensor tile iterator
  using TensorTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
    typename Base::OutputTileThreadMap,
    typename OutputOp::ElementTensor
  >;

  using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
    typename Base::OutputTileThreadMap,
    ElementOutput,
    ScatterD,
    PermuteDLayout
  >;

  /// Define the epilogue. The reduction element type is the warp MMA's
  /// accumulator element (WarpMmaTensorOp::ElementC); the output op is taken
  /// from Base rather than the template parameter directly.
  using Epilogue = EpilogueWithReduction<
    Shape,
    WarpMmaTensorOp,
    PartitionsK,
    OutputTileIterator,
    TensorTileIterator,
    typename WarpMmaTensorOp::ElementC,
    typename Base::AccumulatorFragmentIterator,
    typename Base::WarpTileIterator,
    typename Base::SharedLoadIterator,
    typename Base::OutputOp,
    ReductionOp,
    typename Base::Padding
  >;
};
115
+
116
+ ////////////////////////////////////////////////////////////////////////////////
117
+
118
/// Defines sensible defaults for epilogues for VoltaTensorOps.
/// (Comment corrected from "TensorOps": this is the Volta variant, built on
/// DefaultEpilogueVoltaTensorOp; otherwise identical in composition to
/// DefaultEpilogueWithReductionTensorOp.)
template <
  typename Shape,            ///< Threadblock-scoped tile shape
  typename WarpMmaTensorOp,  ///< Warp-level Volta tensor-op MMA operator
  int PartitionsK,           ///< Number of partitions of the K dimension
  typename ElementOutput,    ///< Element type of the primary output D
  typename OutputOp,         ///< Thread-level output operator (supplies ElementTensor)
  typename ReductionOp,      ///< Reduction operator applied by the epilogue
  int ElementsPerAccess,     ///< Vector width of global memory accesses
  bool ScatterD = false,     ///< If true, D is written through a scatter index
  typename PermuteDLayout = layout::NoPermute  ///< Optional permutation applied to D
>
struct DefaultEpilogueWithReductionVoltaTensorOp {

  /// Use defaults related to the existing epilogue
  using Base = DefaultEpilogueVoltaTensorOp<
    Shape,
    WarpMmaTensorOp,
    PartitionsK,
    OutputOp,
    ElementsPerAccess
  >;

  /// Additional tensor tile iterator
  using TensorTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
    typename Base::OutputTileThreadMap,
    typename OutputOp::ElementTensor
  >;

  using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
    typename Base::OutputTileThreadMap,
    ElementOutput,
    ScatterD,
    PermuteDLayout
  >;

  /// Define the epilogue (reduction element type is the accumulator element)
  using Epilogue = EpilogueWithReduction<
    Shape,
    WarpMmaTensorOp,
    PartitionsK,
    OutputTileIterator,
    TensorTileIterator,
    typename WarpMmaTensorOp::ElementC,
    typename Base::AccumulatorFragmentIterator,
    typename Base::WarpTileIterator,
    typename Base::SharedLoadIterator,
    typename Base::OutputOp,
    ReductionOp,
    typename Base::Padding
  >;
};
170
+
171
+ ////////////////////////////////////////////////////////////////////////////////
172
+
173
+ } // namespace threadblock
174
+ } // namespace epilogue
175
+ } // namespace cutlass
176
+
177
+ ////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Epilogue for threadblock scoped GEMMs using WMMA.
33
+
34
+ The epilogue rearranges the result of a matrix product through shared memory to match canonical
35
+ tensor layouts in global memory. Epilogues support conversion and reduction operations.
36
+
37
+ */
38
+
39
+ #pragma once
40
+
41
+ #include "cutlass/cutlass.h"
42
+ #include "cutlass/numeric_types.h"
43
+ #include "cutlass/array.h"
44
+
45
+ #include "cutlass/gemm/gemm.h"
46
+
47
+ #include "cutlass/epilogue/thread/linear_combination.h"
48
+ #include "cutlass/epilogue/thread/linear_combination_clamp.h"
49
+ #include "cutlass/epilogue/thread/linear_combination_relu.h"
50
+ #include "cutlass/epilogue/thread/linear_combination_gelu.h"
51
+ #include "cutlass/epilogue/thread/linear_combination_sigmoid.h"
52
+ #include "cutlass/epilogue/thread/linear_combination_planar_complex.h"
53
+
54
+ #include "cutlass/epilogue/thread/conversion_op.h"
55
+ #include "cutlass/epilogue/thread/reduction_op.h"
56
+
57
+ #include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h"
58
+
59
+ #include "cutlass/epilogue/warp/fragment_iterator_wmma_tensor_op.h"
60
+ #include "cutlass/epilogue/warp/tile_iterator_wmma_tensor_op.h"
61
+ #include "cutlass/epilogue/threadblock/default_thread_map_wmma_tensor_op.h"
62
+ #include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
63
+ #include "cutlass/epilogue/threadblock/shared_load_iterator.h"
64
+
65
+ #include "cutlass/epilogue/threadblock/epilogue.h"
66
+
67
+ #include "cutlass/layout/permute.h"
68
+
69
+ ////////////////////////////////////////////////////////////////////////////////
70
+
71
+ namespace cutlass {
72
+ namespace epilogue {
73
+ namespace threadblock {
74
+
75
+ ////////////////////////////////////////////////////////////////////////////////
76
+
77
/// Defines sensible defaults for epilogues for WMMA TensorOps.
///
/// Unlike the other defaults in this family, this metafunction assembles the
/// full component stack itself (thread map, fragment/warp/shared-load
/// iterators) from the WMMA-specific helpers rather than delegating to a Base.
template <
  typename Shape_,            ///< Threadblock-scoped tile shape
  typename WarpMmaTensorOp_,  ///< Warp-level WMMA MMA operator
  int PartitionsK,            ///< Number of partitions of the K dimension
  typename OutputOp_,         ///< Thread-level output operator (supplies ElementOutput)
  int ElementsPerAccess,      ///< Vector width of global memory accesses
  bool ScatterD = false,      ///< If true, D is written through a scatter index
  typename PermuteDLayout = layout::NoPermute  ///< Optional permutation applied to D
>
struct DefaultEpilogueWmmaTensorOp {

  using Shape = Shape_;
  using WarpMmaTensorOp = WarpMmaTensorOp_;
  static int const kPartitionsK = PartitionsK;
  using OutputOp = OutputOp_;
  static int const kElementsPerAccess = ElementsPerAccess;

  /// Output element comes from the output op; layout and accumulator element
  /// come from the warp-level operator.
  using ElementOutput = typename OutputOp::ElementOutput;
  using LayoutC = typename WarpMmaTensorOp::LayoutC;
  using ElementAccumulator = typename WarpMmaTensorOp::ElementC;

  //
  // Thread map
  //

  using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapWmmaTensorOp<
    Shape,
    typename WarpMmaTensorOp::Shape,
    typename WarpMmaTensorOp::Policy::Operator::Shape,
    kPartitionsK,
    ElementOutput,
    kElementsPerAccess
  >::Type;

  /// Iterator that writes the output tile to global memory
  using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
    OutputTileThreadMap,
    ElementOutput,
    ScatterD,
    PermuteDLayout
  >;

  /// Iterates over the warp-level accumulator fragment held in registers
  using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorWmmaTensorOp<
    typename WarpMmaTensorOp::Shape,
    typename WarpMmaTensorOp::Policy::Operator::Shape,
    typename WarpMmaTensorOp::Policy::Operator::ElementC,
    typename WarpMmaTensorOp::Policy::Operator::FragmentC,
    LayoutC
  >;

  /// Warp-scoped iterator used to stage accumulators through shared memory
  using WarpTileIterator = cutlass::epilogue::warp::TileIteratorWmmaTensorOp<
    typename WarpMmaTensorOp::Shape,
    typename WarpMmaTensorOp::Policy::Operator::Shape,
    typename WarpMmaTensorOp::Policy::Operator::FragmentC,
    LayoutC
  >;

  /// Loads staged accumulators back from shared memory (compacted thread map)
  using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<
    typename OutputTileThreadMap::CompactedThreadMap,
    ElementAccumulator
  >;

  /// Hard-coded padding elements added
  using Padding = typename WarpTileIterator::Padding;

  //
  // Define the epilogue
  //
  using Epilogue = cutlass::epilogue::threadblock::Epilogue<
    Shape,
    WarpMmaTensorOp,
    kPartitionsK,
    OutputTileIterator,
    AccumulatorFragmentIterator,
    WarpTileIterator,
    SharedLoadIterator,
    OutputOp,
    Padding
  >;
};
157
+
158
+
159
+ ////////////////////////////////////////////////////////////////////////////////
160
+
161
+ } // namespace threadblock
162
+ } // namespace epilogue
163
+ } // namespace cutlass
164
+
165
+ ////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_simt.h ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief
33
+
34
+ */
35
+
36
+ #pragma once
37
+
38
+ #include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
39
+ #include "cutlass/gemm/gemm.h"
40
+
41
+ /////////////////////////////////////////////////////////////////////////////////////////////////
42
+
43
+ namespace cutlass {
44
+ namespace epilogue {
45
+ namespace threadblock {
46
+
47
+ /////////////////////////////////////////////////////////////////////////////////////////////////
48
+
49
/// Defines the optimal thread map for SIMT accumulator layouts
///
/// Metafunction: derives warp/thread counts from the threadblock and warp
/// shapes plus the SIMT MMA policy, and exposes the resulting
/// OutputTileOptimalThreadMap as `Type`.
template <
  typename ThreadblockShape_,  ///< Threadblock-scoped tile shape
  typename WarpShape_,         ///< Warp-scoped tile shape
  typename MmaSimtPolicy_,     ///< SIMT MMA policy (warp arrangement and lane MMA shape)
  int PartitionsK,             ///< Number of partitions of the K dimension
  typename Element_,           ///< Element type being written
  int ElementsPerAccess        ///< Vector width of global memory accesses
>
struct DefaultThreadMapSimt {

  using ThreadblockShape = ThreadblockShape_;
  using WarpShape = WarpShape_;
  using MmaSimtPolicy = MmaSimtPolicy_;
  static int const kPartitionsK = PartitionsK;
  using Element = Element_;
  static int const kElementsPerAccess = ElementsPerAccess;

  //
  // Definitions
  //

  /// Internal compile-time quantities used to build the thread map
  struct Detail {

    static int const kWarpSize = 32;

    // Threadblock extents must be whole multiples of the warp tile
    static_assert(
      !(ThreadblockShape::kM % WarpShape::kM) &&
      !(ThreadblockShape::kN % WarpShape::kN), "Divisibility");

    /// Number of warps
    using WarpCount = gemm::GemmShape<
      ThreadblockShape::kM / WarpShape::kM,
      ThreadblockShape::kN / WarpShape::kN,
      kPartitionsK
    >;

    /// Computes number of thread-level matrix multiplies are needed to span a warp
    static int const kGroupCount =
      WarpShape::kM / (MmaSimtPolicy::WarpShape::kRow * MmaSimtPolicy::LaneMmaShape::kM);

    /// Number of participating threads
    static int const kThreads = WarpCount::kCount * kWarpSize;

    /// Number of iterations (rows per lane MMA times groups per warp)
    static int const kIterations = MmaSimtPolicy::LaneMmaShape::kM * kGroupCount;
  };

  //
  // ThreadMap
  //

  /// ThreadMap to be used by epilogue::PredicatedTileIterator satisfying concept OutputTileThreadMap
  using Type = OutputTileOptimalThreadMap<
    OutputTileShape<    // Shape
      ThreadblockShape::kN,
      1,
      MmaSimtPolicy::WarpShape::kRow,
      Detail::WarpCount::kM,
      1>,
    OutputTileShape<    // Count
      1,
      MmaSimtPolicy::LaneMmaShape::kM,
      Detail::kGroupCount,
      1,
      Detail::kIterations>,
    Detail::kThreads,
    kElementsPerAccess,
    sizeof_bits<Element>::value
  >;
};
120
+
121
+ /////////////////////////////////////////////////////////////////////////////////////////////////
122
+
123
+ } // namespace threadblock
124
+ } // namespace epilogue
125
+ } // namespace cutlass
126
+
127
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_tensor_op.h ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief
33
+
34
+ */
35
+
36
+ #pragma once
37
+
38
+ #include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
39
+ #include "cutlass/gemm/gemm.h"
40
+ #include "cutlass/layout/pitch_linear.h"
41
+
42
+ ////////////////////////////////////////////////////////////////////////////////
43
+
44
+ namespace cutlass {
45
+ namespace epilogue {
46
+ namespace threadblock {
47
+
48
+ ////////////////////////////////////////////////////////////////////////////////
49
+
50
/// Defines the optimal thread map for TensorOp accumulator layouts
///
/// Metafunction: derives warp and thread counts from the threadblock/warp
/// shapes and exposes the resulting OutputTileOptimalThreadMap as `Type`.
template <
  typename ThreadblockShape_,  ///< Threadblock-scoped tile shape
  typename WarpShape_,         ///< Warp-scoped tile shape
  int PartitionsK,             ///< Number of partitions of the K dimension
  typename Element_,           ///< Element type being written
  int ElementsPerAccess        ///< Vector width of global memory accesses
>
struct DefaultThreadMapTensorOp {

  using ThreadblockShape = ThreadblockShape_;
  using WarpShape = WarpShape_;
  static int const kPartitionsK = PartitionsK;
  using Element = Element_;
  static int const kElementsPerAccess = ElementsPerAccess;

  //
  // Definitions
  //

  /// Internal compile-time quantities used to build the thread map
  struct Detail {

    /// Tensor Operations fundamentally perform operations on 8 rows
    static int const kTensorOpRows = 8;
    static int const kWarpSize = 32;

    // Threadblock extents must be whole multiples of the warp tile
    static_assert(
      !(ThreadblockShape::kM % WarpShape::kM) &&
      !(ThreadblockShape::kN % WarpShape::kN), "Divisibility");

    /// Number of warps
    using WarpCount = gemm::GemmShape<
      ThreadblockShape::kM / WarpShape::kM,
      ThreadblockShape::kN / WarpShape::kN,
      kPartitionsK
    >;

    /// Number of participating threads
    static int const kThreads = WarpCount::kCount * kWarpSize;
  };

  //
  // ThreadMap
  //

  /// ThreadMap to be used by epilogue::PredicatedTileIterator satisfying concept OutputTileThreadMap
  using Type = OutputTileOptimalThreadMap <
    OutputTileShape<ThreadblockShape::kN, Detail::kTensorOpRows, Detail::WarpCount::kM, 1, 1>,
    OutputTileShape<1, WarpShape::kM / Detail::kTensorOpRows, 1, 1, WarpShape::kM / Detail::kTensorOpRows>,
    Detail::kThreads,
    kElementsPerAccess,
    sizeof_bits<Element>::value
  >;
};
104
+
105
+ ////////////////////////////////////////////////////////////////////////////////
106
+
107
/// Defines the optimal thread map for TensorOp accumulator layouts
/// (interleaved variant): produces an InterleavedOutputTileThreadMap for
/// outputs stored with an InterleavedK column interleave.
template <typename ThreadblockShape_, typename WarpShape_, int PartitionsK,
          typename Element_, int ElementsPerAccess, int InterleavedK>
struct DefaultInterleavedThreadMapTensorOp {
  using ThreadblockShape = ThreadblockShape_;
  using WarpShape = WarpShape_;
  static int const kPartitionsK = PartitionsK;
  using Element = Element_;
  static int const kElementsPerAccess = ElementsPerAccess;
  static int const kInterleavedK = InterleavedK;  // interleave factor of the output layout

  //
  // Definitions
  //

  /// Internal compile-time quantities used to build the thread map
  struct Detail {
    /// Tensor Operations fundamentally perform operations on 8 rows
    static int const kTensorOpRows = 8;
    static int const kWarpSize = 32;

    // Threadblock extents must be whole multiples of the warp tile
    static_assert(!(ThreadblockShape::kM % WarpShape::kM) &&
                      !(ThreadblockShape::kN % WarpShape::kN),
                  "Divisibility");

    /// Number of warps
    using WarpCount =
        gemm::GemmShape<ThreadblockShape::kM / WarpShape::kM,
                        ThreadblockShape::kN / WarpShape::kN, kPartitionsK>;

    /// Number of participating threads
    static int const kThreads = WarpCount::kCount * kWarpSize;
  };

  //
  // ThreadMap
  //

  /// ThreadMap to be used by epilogue::PredicatedTileIterator satisfying concept
  /// InterleavedOutputTileThreadMap
  using Type = InterleavedOutputTileThreadMap<
      layout::PitchLinearShape<Detail::WarpCount::kM, Detail::WarpCount::kN>,
      layout::PitchLinearShape<WarpShape::kM / Detail::kTensorOpRows,
                               WarpShape::kN / InterleavedK>,
      Detail::kThreads, kElementsPerAccess, sizeof_bits<Element>::value>;
};
152
+
153
+
154
+ ////////////////////////////////////////////////////////////////////////////////
155
+
156
/// Defines the optimal thread map for TensorOp accumulator layouts
/// (interleaved convolution variant): same derivation as
/// DefaultInterleavedThreadMapTensorOp but yields an
/// InterleavedConvOutputTileThreadMap parameterized by MatrixShape.
template <typename ThreadblockShape_, typename WarpShape_, int PartitionsK,
          typename Element_, int ElementsPerAccess, int InterleavedK>
struct DefaultInterleavedConvThreadMapTensorOp {
  using ThreadblockShape = ThreadblockShape_;
  using WarpShape = WarpShape_;
  static int const kPartitionsK = PartitionsK;
  using Element = Element_;
  static int const kElementsPerAccess = ElementsPerAccess;
  static int const kInterleavedK = InterleavedK;  // interleave factor of the output layout

  //
  // Definitions
  //

  /// Internal compile-time quantities used to build the thread map
  struct Detail {
    /// Tensor Operations fundamentally perform operations on 8 rows
    static int const kTensorOpRows = 8;
    static int const kWarpSize = 32;

    // Threadblock extents must be whole multiples of the warp tile
    static_assert(!(ThreadblockShape::kM % WarpShape::kM) &&
                      !(ThreadblockShape::kN % WarpShape::kN),
                  "Divisibility");

    /// Number of warps
    using WarpCount =
        gemm::GemmShape<ThreadblockShape::kM / WarpShape::kM,
                        ThreadblockShape::kN / WarpShape::kN, kPartitionsK>;

    /// Number of participating threads
    static int const kThreads = WarpCount::kCount * kWarpSize;
  };

  //
  // ThreadMap
  //

  /// ThreadMap to be used by epilogue::MaskedTileIterator satisfying concept
  /// InterleavedOutputTileThreadMap
  using Type = InterleavedConvOutputTileThreadMap<
      MatrixShape<Detail::WarpCount::kM, Detail::WarpCount::kN>,
      MatrixShape<WarpShape::kM / Detail::kTensorOpRows,
                  WarpShape::kN / InterleavedK>,
      Detail::kThreads, kElementsPerAccess, sizeof_bits<Element>::value>;
};
201
+
202
+ ////////////////////////////////////////////////////////////////////////////////
203
+
204
+ } // namespace threadblock
205
+ } // namespace epilogue
206
+ } // namespace cutlass
207
+
208
+ ////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_volta_tensor_op.h ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief
33
+
34
+ */
35
+
36
+ #pragma once
37
+
38
+ #include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
39
+ #include "cutlass/gemm/gemm.h"
40
+
41
+ /////////////////////////////////////////////////////////////////////////////////////////////////
42
+
43
+ namespace cutlass {
44
+ namespace epilogue {
45
+ namespace threadblock {
46
+
47
+ /////////////////////////////////////////////////////////////////////////////////////////////////
48
+
49
+ /// Defines the optimal thread map for TensorOp accumulator layouts
50
/// Defines the optimal thread map for Volta TensorOp accumulator layouts.
/// Partially specialized on the accumulator element type (half_t / float).
template <
  typename ThreadblockShape,    ///< Threadblock-scoped tile shape
  typename WarpShape,           ///< Warp-scoped tile shape
  int PartitionsK,              ///< Number of partitions along the K dimension
  typename ElementOutput,       ///< Output element data type
  int ElementsPerAccess,        ///< Number of elements per vectorized access
  typename ElementAccumulator   ///< Accumulator element data type
>
struct DefaultThreadMapVoltaTensorOp;
59
+
60
+ /////////////////////////////////////////////////////////////////////////////////////////////////
61
+
62
+ /// Defines the optimal thread map for TensorOp accumulator layouts
63
+ template <
64
+ typename ThreadblockShape_,
65
+ typename WarpShape_,
66
+ int PartitionsK,
67
+ typename ElementOutput_,
68
+ int ElementsPerAccess
69
+ >
70
+ struct DefaultThreadMapVoltaTensorOp<
71
+ ThreadblockShape_,
72
+ WarpShape_,
73
+ PartitionsK,
74
+ ElementOutput_,
75
+ ElementsPerAccess,
76
+ half_t> {
77
+
78
+ using ThreadblockShape = ThreadblockShape_;
79
+ using WarpShape = WarpShape_;
80
+ static int const kPartitionsK = PartitionsK;
81
+ using ElementOutput = ElementOutput_;
82
+ static int const kElementsPerAccess = ElementsPerAccess;
83
+ using ElementAccumulator = half_t;
84
+
85
+ //
86
+ // Definitions
87
+ //
88
+
89
+ struct Detail {
90
+
91
+ static int const kTensorOpRows = 16;
92
+ static int const kWarpSize = 32;
93
+ static int const kInterleavedTilesM = WarpShape::kM / 32;
94
+
95
+ static_assert(
96
+ !(ThreadblockShape::kM % WarpShape::kM) &&
97
+ !(ThreadblockShape::kN % WarpShape::kN), "Divisibility");
98
+
99
+ /// Number of warps
100
+ using WarpCount = gemm::GemmShape<
101
+ ThreadblockShape::kM / WarpShape::kM,
102
+ ThreadblockShape::kN / WarpShape::kN,
103
+ kPartitionsK
104
+ >;
105
+
106
+ /// Number of participating threads
107
+ static int const kThreads = WarpCount::kCount * kWarpSize;
108
+
109
+ using Shape = cutlass::epilogue::threadblock::OutputTileShape<
110
+ ThreadblockShape::kN, // column
111
+ 4, // row
112
+ 4, // group
113
+ WarpCount::kM, // cluster
114
+ 1 // tile
115
+ >;
116
+
117
+ /// Number of iterations per subspace
118
+ using Count = cutlass::epilogue::threadblock::OutputTileShape<
119
+ 1, // column
120
+ 2, // row
121
+ kInterleavedTilesM, // group
122
+ 1, // cluster
123
+ WarpShape::kM / kTensorOpRows // iterations
124
+ >;
125
+ };
126
+
127
+ //
128
+ // ThreadMap
129
+ //
130
+
131
+ /// ThreadMap to be used by epilogue::PredicatedTileIterator satisfying concept OutputTileThreadMap
132
+ using Type = OutputTileOptimalThreadMap <
133
+ typename Detail::Shape,
134
+ typename Detail::Count,
135
+ Detail::kThreads,
136
+ kElementsPerAccess,
137
+ sizeof_bits<ElementOutput>::value
138
+ >;
139
+ };
140
+
141
+ /////////////////////////////////////////////////////////////////////////////////////////////////
142
+
143
+ /// Defines the optimal thread map for TensorOp accumulator layouts
144
+ template <
145
+ typename ThreadblockShape_,
146
+ typename WarpShape_,
147
+ int PartitionsK,
148
+ typename ElementOutput_,
149
+ int ElementsPerAccess
150
+ >
151
+ struct DefaultThreadMapVoltaTensorOp<
152
+ ThreadblockShape_,
153
+ WarpShape_,
154
+ PartitionsK,
155
+ ElementOutput_,
156
+ ElementsPerAccess,
157
+ float> {
158
+
159
+ using ThreadblockShape = ThreadblockShape_;
160
+ using WarpShape = WarpShape_;
161
+ static int const kPartitionsK = PartitionsK;
162
+ using ElementOutput = ElementOutput_;
163
+ static int const kElementsPerAccess = ElementsPerAccess;
164
+ using ElementAccumulator = float;
165
+
166
+ //
167
+ // Definitions
168
+ //
169
+
170
+ struct Detail {
171
+
172
+ static int const kTensorOpRows = 16;
173
+ static int const kWarpSize = 32;
174
+ static int const kInterleavedTilesM = WarpShape::kM / 32;
175
+
176
+ static_assert(
177
+ !(ThreadblockShape::kM % WarpShape::kM) &&
178
+ !(ThreadblockShape::kN % WarpShape::kN), "Divisibility");
179
+
180
+ /// Number of warps
181
+ using WarpCount = gemm::GemmShape<
182
+ ThreadblockShape::kM / WarpShape::kM,
183
+ ThreadblockShape::kN / WarpShape::kN,
184
+ kPartitionsK
185
+ >;
186
+
187
+ /// Number of participating threads
188
+ static int const kThreads = WarpCount::kCount * kWarpSize;
189
+
190
+ using Shape = cutlass::epilogue::threadblock::OutputTileShape<
191
+ ThreadblockShape::kN, // column
192
+ 4, // row
193
+ 4, // group
194
+ WarpCount::kM, // cluster
195
+ 1 // tile
196
+ >;
197
+
198
+ /// Number of iterations per subspace
199
+ using Count = cutlass::epilogue::threadblock::OutputTileShape<
200
+ 1, // column
201
+ 2, // row
202
+ kInterleavedTilesM, // group
203
+ 1, // cluster
204
+ WarpShape::kM / kTensorOpRows // iterations
205
+ >;
206
+ };
207
+
208
+ //
209
+ // ThreadMap
210
+ //
211
+
212
+ /// ThreadMap to be used by epilogue::PredicatedTileIterator satisfying concept OutputTileThreadMap
213
+ using Type = OutputTileOptimalThreadMap <
214
+ typename Detail::Shape,
215
+ typename Detail::Count,
216
+ Detail::kThreads,
217
+ kElementsPerAccess,
218
+ sizeof_bits<ElementOutput>::value
219
+ >;
220
+ };
221
+
222
+ /////////////////////////////////////////////////////////////////////////////////////////////////
223
+
224
+ } // namespace threadblock
225
+ } // namespace epilogue
226
+ } // namespace cutlass
227
+
228
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_wmma_tensor_op.h ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief
33
+
34
+ */
35
+
36
+ #pragma once
37
+
38
+ #include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
39
+ #include "cutlass/gemm/gemm.h"
40
+ #include "cutlass/layout/pitch_linear.h"
41
+
42
+ ////////////////////////////////////////////////////////////////////////////////
43
+
44
+ namespace cutlass {
45
+ namespace epilogue {
46
+ namespace threadblock {
47
+
48
+ ////////////////////////////////////////////////////////////////////////////////
49
+
50
+ /// Defines the optimal thread map for Wmma TensorOp accumulator layouts
51
+ template <
52
+ typename ThreadblockShape_,
53
+ typename WarpShape_,
54
+ typename InstructionShape_,
55
+ int PartitionsK,
56
+ typename Element_,
57
+ int ElementsPerAccess
58
+ >
59
+ struct DefaultThreadMapWmmaTensorOp {
60
+
61
+ using ThreadblockShape = ThreadblockShape_;
62
+ using WarpShape = WarpShape_;
63
+ using InstructionShape = InstructionShape_;
64
+ static int const kPartitionsK = PartitionsK;
65
+ using Element = Element_;
66
+ static int const kElementsPerAccess = ElementsPerAccess;
67
+
68
+ //
69
+ // Definitions
70
+ //
71
+
72
+ struct Detail {
73
+
74
+ /// Wmma Tensor Operations fundamentally perform operations on InstructionShape::kM rows
75
+ static int const kTensorOpRows = InstructionShape::kM;
76
+ static int const kWarpSize = 32;
77
+
78
+ static_assert(
79
+ !(ThreadblockShape::kM % WarpShape::kM) &&
80
+ !(ThreadblockShape::kN % WarpShape::kN), "Divisibility");
81
+
82
+ /// Number of warps
83
+ using WarpCount = gemm::GemmShape<
84
+ ThreadblockShape::kM / WarpShape::kM,
85
+ ThreadblockShape::kN / WarpShape::kN,
86
+ kPartitionsK
87
+ >;
88
+
89
+ /// Number of participating threads
90
+ static int const kThreads = WarpCount::kCount * kWarpSize;
91
+ };
92
+
93
+ //
94
+ // ThreadMap
95
+ //
96
+
97
+ /// ThreadMap to be used by epilogue::PredicatedTileIterator satisfying concept OutputTileThreadMap
98
+ using Type = OutputTileOptimalThreadMap <
99
+ OutputTileShape<ThreadblockShape::kN, Detail::kTensorOpRows, Detail::WarpCount::kM, 1, 1>,
100
+ OutputTileShape<1, WarpShape::kM / Detail::kTensorOpRows, 1, 1, WarpShape::kM / Detail::kTensorOpRows>,
101
+ Detail::kThreads,
102
+ kElementsPerAccess,
103
+ sizeof_bits<Element>::value
104
+ >;
105
+ };
106
+
107
+ ////////////////////////////////////////////////////////////////////////////////
108
+
109
+ } // namespace threadblock
110
+ } // namespace epilogue
111
+ } // namespace cutlass
112
+
113
+ ////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/direct_store_epilogue_iterator.h ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
33
+
34
+ The epilogue rearranges the result of a matrix product through shared memory to match canonical
35
+ tensor layouts in global memory. Epilogues support conversion and reduction operations.
36
+
37
+ */
38
+
39
+ #pragma once
40
+
41
+ #include "cutlass/cutlass.h"
42
+ #include "cutlass/numeric_types.h"
43
+ #include "cutlass/array.h"
44
+ #include "cutlass/layout/matrix.h"
45
+ #include "cutlass/layout/tensor.h"
46
+ #include "cutlass/matrix_shape.h"
47
+ #include "cutlass/tensor_ref.h"
48
+ #include "cutlass/transform/pitch_linear_thread_map.h"
49
+ #include "cutlass/epilogue/threadblock/output_tile_thread_map.h"
50
+ #include "cutlass/arch/arch.h"
51
+ #include "cutlass/arch/memory.h"
52
+ #include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h"
53
+
54
+ ////////////////////////////////////////////////////////////////////////////////
55
+
56
+ namespace cutlass {
57
+
58
+ ////////////////////////////////////////////////////////////////////////////////
59
+
60
+ namespace epilogue {
61
+ namespace threadblock {
62
+
63
+ ////////////////////////////////////////////////////////////////////////////////
64
+
65
+ template <typename Element_>
66
+ class DirectStoreEpilogueIterator {
67
+ public:
68
+
69
+ using Element = Element_;
70
+
71
+ using Layout = layout::RowMajor;
72
+ using TensorRef = TensorRef<Element, Layout>;
73
+ using ConstTensorRef = typename TensorRef::ConstTensorRef;
74
+
75
+ using Index = typename Layout::Index;
76
+ using LongIndex = typename Layout::LongIndex;
77
+ using TensorCoord = MatrixCoord;
78
+
79
+ static int const kElementsPerAccess = 1;
80
+
81
+ /// Uses a non-template class
82
+ struct Params : PredicatedTileIteratorParams {
83
+ using Base = PredicatedTileIteratorParams;
84
+
85
+ CUTLASS_HOST_DEVICE
86
+ Params() { }
87
+
88
+ CUTLASS_HOST_DEVICE
89
+ Params(Layout const &layout) {
90
+ stride = layout.stride(0) * sizeof(Element);
91
+ }
92
+
93
+ CUTLASS_HOST_DEVICE
94
+ Params(Base const &base) :
95
+ Base(base) { }
96
+ };
97
+
98
+ public:
99
+
100
+ //
101
+ // Data members
102
+ //
103
+
104
+ Element *pointer; // pointer to the output matrix
105
+
106
+ LongIndex stride; // stride in elements between rows
107
+
108
+ TensorCoord extent; // extent of output matrix
109
+
110
+ int thread_idx; // thread index
111
+
112
+ TensorCoord threadblock_offset;
113
+
114
+ public:
115
+
116
+ /// Constructor
117
+ CUTLASS_DEVICE
118
+ DirectStoreEpilogueIterator(
119
+ PredicatedTileIteratorParams const & params,
120
+ Element *pointer_,
121
+ TensorCoord extent_,
122
+ int thread_idx_,
123
+ TensorCoord threadblock_offset_ = TensorCoord(),
124
+ int const * indices = nullptr
125
+ ):
126
+ pointer(pointer_),
127
+ stride(params.stride / sizeof(Element)),
128
+ extent(extent_),
129
+ thread_idx(thread_idx_),
130
+ threadblock_offset(threadblock_offset_)
131
+ {
132
+
133
+ }
134
+ };
135
+
136
+ ///////////////////////////////////////////////////////////////////////////////
137
+
138
+ } // namespace threadblock
139
+ } // namespace epilogue
140
+ } // namespace cutlass
141
+
142
+ ////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/epilogue.h ADDED
@@ -0,0 +1,548 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
33
+
34
+ The epilogue rearranges the result of a matrix product through shared memory to match canonical
35
+ tensor layouts in global memory. Epilogues support conversion and reduction operations.
36
+
37
+ The shared memory resource is time-sliced across warps.
38
+ */
39
+
40
+ #pragma once
41
+ #include "cutlass/cutlass.h"
42
+ #include CUDA_STD_HEADER(cassert)
43
+
44
+ #include "cutlass/numeric_types.h"
45
+ #include "cutlass/array.h"
46
+ #include "cutlass/layout/vector.h"
47
+ #include "cutlass/layout/tensor.h"
48
+ #include "cutlass/tensor_coord.h"
49
+ #include "cutlass/aligned_buffer.h"
50
+ #include "cutlass/functional.h"
51
+
52
+ #include "cutlass/gemm/gemm.h"
53
+
54
+ #include "cutlass/transform/pitch_linear_thread_map.h"
55
+ #include "cutlass/transform/threadblock/regular_tile_iterator.h"
56
+
57
+ #include "cutlass/epilogue/threadblock/epilogue_base.h"
58
+ #include "cutlass/epilogue/threadblock/epilogue_base_streamk.h"
59
+ #include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
60
+
61
+ ////////////////////////////////////////////////////////////////////////////////
62
+
63
+ namespace cutlass {
64
+ namespace epilogue {
65
+ namespace threadblock {
66
+
67
+
68
+ ////////////////////////////////////////////////////////////////////////////////
69
+
70
+ /// Epilogue operator
71
+ template <
72
+ typename Shape_, ///< Shape of threadblock tile (concept: GemmShape)
73
+ typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp)
74
+ int PartitionsK, ///< Number of partitions of the K dimension
75
+ typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors
76
+ typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators
77
+ typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM
78
+ typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM
79
+ typename OutputOp_, ///< Output operator
80
+ typename Padding_, ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape)
81
+ int FragmentsPerPartition = 1, ///< Used to coarsten the epilogue granularity
82
+ int IterationsUnroll = ///< Used to reduce binary size when epilogue op is large
83
+ (!IsEpilogueFunctorHeavy<OutputOp_>::value)
84
+ >
85
+ class Epilogue :
86
+ public EpilogueBase<
87
+ Shape_,
88
+ typename WarpMmaOperator_::Shape,
89
+ PartitionsK,
90
+ AccumulatorFragmentIterator_,
91
+ WarpTileIterator_,
92
+ Padding_,
93
+ FragmentsPerPartition>,
94
+ public EpilogueBaseStreamK<
95
+ Shape_,
96
+ PartitionsK,
97
+ WarpMmaOperator_,
98
+ AccumulatorFragmentIterator_>
99
+ {
100
+
101
+ public:
102
+
103
+ using Base = EpilogueBase<
104
+ Shape_,
105
+ typename WarpMmaOperator_::Shape,
106
+ PartitionsK,
107
+ AccumulatorFragmentIterator_,
108
+ WarpTileIterator_,
109
+ Padding_,
110
+ FragmentsPerPartition>;
111
+
112
+ using BaseStreamK = EpilogueBaseStreamK<
113
+ Shape_,
114
+ PartitionsK,
115
+ WarpMmaOperator_,
116
+ AccumulatorFragmentIterator_>;
117
+
118
+ using Shape = Shape_;
119
+ using WarpMmaOperator = WarpMmaOperator_;
120
+ static int const kPartitionsK = PartitionsK;
121
+ using OutputTileIterator = OutputTileIterator_;
122
+ using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
123
+ using WarpTileIterator = WarpTileIterator_;
124
+ using SharedLoadIterator = SharedLoadIterator_;
125
+ using OutputOp = OutputOp_;
126
+ using Padding = Padding_;
127
+ using Layout = layout::RowMajor;
128
+ using LongIndex = typename Layout::LongIndex;
129
+
130
+ /// Number of warps per block
131
+ using WarpCount = typename Base::WarpCount;
132
+
133
+ /// Number of threads per block
134
+ static int const kBlockThreads = 32 * WarpCount::kCount;
135
+
136
+ /// Per-thread accumulator tile type
137
+ using AccumulatorTile = typename Base::AccumulatorTile;
138
+
139
+ /// Numerical accumulation element type
140
+ using ElementAccumulator = typename WarpMmaOperator::ElementC;
141
+
142
+ /// Fragment type used by the accumulator tile's fragment iterator
143
+ using AccumulatorFragment = typename AccumulatorFragmentIterator::Fragment;
144
+
145
+ /// Output element
146
+ using ElementOutput = typename OutputTileIterator::Element;
147
+
148
+ /// Output access size
149
+ static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
150
+
151
+ /// Tensor reference to destination tensor
152
+ using TensorRef = typename OutputTileIterator::TensorRef;
153
+
154
+ /// Tensor reference to sync tensor
155
+ using SyncTensorRef = typename cutlass::TensorRef<int, cutlass::layout::PackedVectorLayout>;
156
+
157
+ /// Const tensor reference to source tensor
158
+ using ConstTensorRef = typename OutputTileIterator::ConstTensorRef;
159
+
160
+ /// Vector type used by the global output iterator
161
+ using OutputAccessType = Array<
162
+ typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
163
+
164
+ /// Vector type used by the shared output iterator
165
+ using AccumulatorAccessType = Array<typename WarpTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
166
+
167
+ static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 ? Base::kFragmentsPerIteration : kPartitionsK;
168
+
169
+ static int constexpr kSmemPointerOffset = Base::SharedStorage::StorageShape::kCount / kSmemTiles;
170
+
171
+
172
+ public:
173
+
174
+ static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements,
175
+ "Mismatch between shared load iterator and output tile iterator.");
176
+
177
+ static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero.");
178
+
179
+ static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess),
180
+ "Divisibility");
181
+
182
+ static_assert(kPartitionsK == 1 || Base::kFragmentsPerIteration == 1, "One of these must be exactly 1.");
183
+
184
+
185
+ public:
186
+
187
+ /// Aspect for when epilogue source is not needed
188
+ struct SourceAspectNotNeeded
189
+ {
190
+ /// Constructor
191
+ CUTLASS_DEVICE
192
+ SourceAspectNotNeeded()
193
+ {}
194
+
195
+ // No-op
196
+ CUTLASS_DEVICE
197
+ void load() { }
198
+
199
+ /// Invoke the output functor over each vector of output
200
+ CUTLASS_DEVICE
201
+ void apply_output_operator(
202
+ typename OutputTileIterator::Fragment &output_fragment,
203
+ OutputOp const &output_op,
204
+ typename SharedLoadIterator::Fragment const &aligned_accum_fragment)
205
+ {
206
+ OutputAccessType *output_frag_ptr =
207
+ reinterpret_cast<OutputAccessType *>(&output_fragment);
208
+
209
+ AccumulatorAccessType const *compute_frag_ptr =
210
+ reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment);
211
+
212
+ int const kOutputOpIterations =
213
+ OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
214
+
215
+ CUTLASS_PRAGMA_UNROLL
216
+ for (int i = 0; i < kOutputOpIterations; ++i)
217
+ {
218
+ // Call the output operator
219
+ output_frag_ptr[i] = output_op(compute_frag_ptr[i]);
220
+ }
221
+ }
222
+ };
223
+
224
+
225
+ /// Aspect for when epilogue source is needed
226
+ struct SourceAspectNeeded
227
+ {
228
+ OutputTileIterator source_iterator;
229
+
230
+ typename OutputTileIterator::Fragment source_fragment;
231
+
232
+ /// Invoke the output functor over each vector of output
233
+ CUTLASS_DEVICE
234
+ static void apply_output_operator(
235
+ typename OutputTileIterator::Fragment &output_fragment,
236
+ OutputOp const &output_op,
237
+ typename SharedLoadIterator::Fragment const &aligned_accum_fragment,
238
+ typename OutputTileIterator::Fragment const &source_fragment)
239
+ {
240
+ OutputAccessType *output_frag_ptr =
241
+ reinterpret_cast<OutputAccessType *>(&output_fragment);
242
+
243
+ AccumulatorAccessType const *compute_frag_ptr =
244
+ reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment);
245
+
246
+ OutputAccessType const *source_frag_ptr =
247
+ reinterpret_cast<OutputAccessType const *>(&source_fragment);
248
+
249
+ int const kOutputOpIterations =
250
+ OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
251
+
252
+ CUTLASS_PRAGMA_UNROLL
253
+ for (int i = 0; i < kOutputOpIterations; ++i)
254
+ {
255
+ // Call the output operator
256
+ output_frag_ptr[i] = output_op(compute_frag_ptr[i], source_frag_ptr[i]);
257
+ }
258
+ }
259
+
260
+ /// Constructor
261
+ CUTLASS_DEVICE
262
+ SourceAspectNeeded(OutputTileIterator source_iterator) :
263
+ source_iterator(source_iterator)
264
+ {
265
+ source_fragment.clear();
266
+ }
267
+
268
+ // Load addend source fragment from global memory
269
+ CUTLASS_DEVICE
270
+ void load() {
271
+ source_iterator.load(source_fragment);
272
+ ++source_iterator;
273
+ }
274
+
275
+ /// Invoke the output functor over each vector of output
276
+ CUTLASS_DEVICE
277
+ void apply_output_operator(
278
+ typename OutputTileIterator::Fragment &output_fragment,
279
+ OutputOp const &output_op,
280
+ typename SharedLoadIterator::Fragment const &aligned_accum_fragment)
281
+ {
282
+ apply_output_operator(output_fragment, output_op, aligned_accum_fragment, source_fragment);
283
+ }
284
+ };
285
+
286
+
287
+ private:
288
+
289
+ /// Loads fragment from shared memory aligned with output tensor
290
+ SharedLoadIterator shared_load_iterator_;
291
+
292
+ /// Thread index in the threadblock
293
+ int thread_idx;
294
+
295
+ /// Warp index in the threadblock
296
+ int warp_idx;
297
+
298
+ public:
299
+
300
+ /// Constructor
301
+ CUTLASS_DEVICE
302
+ Epilogue(
303
+ typename Base::SharedStorage &shared_storage, ///< Shared storage object
304
+ int thread_idx, ///< ID of a thread within the threadblock
305
+ int warp_idx, ///< ID of warp within threadblock
306
+ int lane_idx) ///< Id of thread within warp
307
+ :
308
+ Base(shared_storage, thread_idx, warp_idx, lane_idx),
309
+ BaseStreamK(thread_idx),
310
+ shared_load_iterator_(shared_storage.reference(), thread_idx),
311
+ thread_idx(thread_idx),
312
+ warp_idx(warp_idx)
313
+ {}
314
+
315
+
316
+ /// Aggregates the accumulator sets shared by peer blocks in the global workspace,
317
+ /// performing epilogue computations, writing to output
318
+ CUTLASS_DEVICE
319
+ void reduce(
320
+ int peer_idx_begin,
321
+ int peer_idx_end,
322
+ int reduce_fragment_idx,
323
+ void *element_workspace,
324
+ OutputOp const &output_op, ///< Output operator
325
+ OutputTileIterator destination_iterator, ///< Tile iterator for destination
326
+ OutputTileIterator source_iterator) ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
327
+ {
328
+ // Reduce peer accumulator fragments into one fragment
329
+ AccumulatorFragment accum_fragment;
330
+ BaseStreamK::reduce(accum_fragment, peer_idx_begin, peer_idx_end, reduce_fragment_idx, element_workspace);
331
+
332
+ // Store fragment to shared memory
333
+ this->warp_tile_iterator_.store(accum_fragment);
334
+
335
+ __syncthreads();
336
+
337
+ // Initialize/load source-fragment data
338
+ typename OutputTileIterator::Fragment source_fragment;
339
+ source_fragment.clear();
340
+
341
+ if (output_op.is_source_needed())
342
+ {
343
+ source_iterator += reduce_fragment_idx;
344
+ source_iterator.load(source_fragment);
345
+ }
346
+
347
+ // Load fragment from shared memory
348
+ typename SharedLoadIterator::Fragment aligned_accum_fragment;
349
+ shared_load_iterator_.load(aligned_accum_fragment);
350
+
351
+ // Add fragments shared by other k partitions
352
+ if (kPartitionsK > 1)
353
+ {
354
+ plus <typename SharedLoadIterator::Fragment> add_fragments;
355
+
356
+ CUTLASS_PRAGMA_UNROLL
357
+ for ( int i = 1; i < kPartitionsK; ++i) {
358
+ typename SharedLoadIterator::Fragment aligned_addend_fragment;
359
+ shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
360
+ shared_load_iterator_.load(aligned_addend_fragment);
361
+ aligned_accum_fragment = add_fragments(aligned_accum_fragment, aligned_addend_fragment);
362
+ }
363
+ }
364
+
365
+ // Compute the output result
366
+ typename OutputTileIterator::Fragment output_fragment;
367
+
368
+ // Apply the output operator
369
+ SourceAspectNeeded::apply_output_operator(
370
+ output_fragment,
371
+ output_op,
372
+ aligned_accum_fragment,
373
+ source_fragment);
374
+
375
+ // Store the final result
376
+ destination_iterator += reduce_fragment_idx;
377
+ destination_iterator.store(output_fragment);
378
+ }
379
+
380
+
381
+ /// Perform the epilogue computations and stream the result to global memory.
382
+ CUTLASS_DEVICE
383
+ void operator()(
384
+ OutputOp const &output_op, ///< Output operator
385
+ OutputTileIterator destination_iterator, ///< Tile iterator for destination
386
+ AccumulatorTile const &accumulators) ///< Complete warp-level accumulator tile
387
+ {
388
+ operator()(output_op, destination_iterator, accumulators, SourceAspectNotNeeded());
389
+ }
390
+
391
+
392
+ /// Perform the epilogue computations and stream the result to global memory. Implements
393
+ /// two alternative codepaths, depending on whether the output op requires addend data to be loaded.
394
+ CUTLASS_DEVICE
395
+ void operator()(
396
+ OutputOp const &output_op, ///< Output operator
397
+ OutputTileIterator destination_iterator, ///< Tile iterator for destination
398
+ AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
399
+ OutputTileIterator source_iterator ) ///< Tile iterator for addend source
400
+ {
401
+ if (output_op.is_source_needed())
402
+ {
403
+ operator()(output_op, destination_iterator, accumulators, SourceAspectNeeded(source_iterator));
404
+ }
405
+ else
406
+ {
407
+ operator()(output_op, destination_iterator, accumulators, SourceAspectNotNeeded());
408
+ }
409
+ }
410
+
411
+
412
+ /// Perform the epilogue computations and stream the result to global memory. Implements a
413
+ /// single codepath, regardless of whether the output op requires addend data to be loaded
414
+ CUTLASS_DEVICE
415
+ void unified(
416
+ OutputOp const &output_op, ///< Output operator
417
+ OutputTileIterator destination_iterator, ///< Tile iterator for destination
418
+ AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
419
+ OutputTileIterator source_iterator ) ///< Tile iterator for addend source
420
+ {
421
+ if (!output_op.is_source_needed())
422
+ {
423
+ source_iterator.clear_mask();
424
+ __syncthreads(); // Dummy (CUDA 11.0)
425
+ }
426
+
427
+ operator()(output_op, destination_iterator, accumulators, SourceAspectNeeded(source_iterator));
428
+ }
429
+
430
+ template<class Seq>
431
+ struct acc2smem;
432
+
433
+ template <size_t... Seq>
434
+ struct acc2smem<cutlass::index_sequence<Seq...>> {
435
+ template<int Advance>
436
+ CUTLASS_DEVICE
437
+ static void helper(AccumulatorFragmentIterator accum_fragment_iterator,
438
+ WarpTileIterator &warp_tile_iterator) {
439
+ CUTLASS_PRAGMA_UNROLL
440
+ for (int i = 0; i < Advance; i++) {
441
+ ++accum_fragment_iterator;
442
+ }
443
+
444
+ typename AccumulatorFragmentIterator::Fragment accum_fragment;
445
+
446
+ accum_fragment_iterator.load(accum_fragment);
447
+ ++accum_fragment_iterator;
448
+ warp_tile_iterator.store(accum_fragment);
449
+ }
450
+
451
+ CUTLASS_DEVICE
452
+ static void push(size_t pos,
453
+ AccumulatorFragmentIterator const &iterator_begin,
454
+ WarpTileIterator &warp_tile_iterator) {
455
+ int dummy[] = {(pos == Seq) && (helper<Seq>(iterator_begin, warp_tile_iterator), 0)...};
456
+ }
457
+ };
458
+
459
+
460
+ /// Streams the result to global memory
461
+ template <typename SourceAspect>
462
+ CUTLASS_DEVICE
463
+ void operator()(
464
+ OutputOp const &output_op, ///< Output operator
465
+ OutputTileIterator destination_iterator, ///< Tile iterator for destination
466
+ AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
467
+ SourceAspect source)
468
+ {
469
+ // Iterator over warp-level accumulator fragment
470
+ AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
471
+
472
+ //
473
+ // Iterate over accumulator tile
474
+ //
475
+
476
+ #ifdef __clang__
477
+ #pragma clang diagnostic push
478
+ #pragma clang diagnostic ignored "-Wcuda-compat"
479
+ // Turn off clangs warning about loop unroll argument using parens.
480
+ #endif
481
+
482
+ #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1)
483
+ for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter)
484
+ {
485
+ //
486
+ // Load the source
487
+ //
488
+
489
+ source.load();
490
+ //
491
+ // Convert and store fragment
492
+ //
493
+
494
+ __syncthreads();
495
+
496
+ acc2smem<cutlass::make_index_sequence<OutputTileIterator::kIterations>>::push(
497
+ iter, accum_fragment_iterator, this->warp_tile_iterator_);
498
+
499
+ __syncthreads();
500
+
501
+ //
502
+ // Load fragments from shared memory
503
+ //
504
+
505
+ typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];
506
+ shared_load_iterator_.load(aligned_accum_fragment[0]);
507
+
508
+ if (kPartitionsK > 1) {
509
+ plus <typename SharedLoadIterator::Fragment> add_fragments;
510
+
511
+ CUTLASS_PRAGMA_UNROLL
512
+ for ( int i = 1; i < kPartitionsK; ++i) {
513
+ shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
514
+ shared_load_iterator_.load(aligned_accum_fragment[i]);
515
+ aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
516
+ }
517
+
518
+ shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset);
519
+ }
520
+
521
+ //
522
+ // Compute the output result
523
+ //
524
+
525
+ typename OutputTileIterator::Fragment output_fragment;
526
+ source.apply_output_operator(output_fragment, output_op, aligned_accum_fragment[0]);
527
+
528
+ //
529
+ // Store the final result
530
+ //
531
+
532
+ destination_iterator.store(output_fragment);
533
+ ++destination_iterator;
534
+ }
535
+
536
+ #ifdef __clang__
537
+ #pragma clang diagnostic pop
538
+ #endif
539
+ }
540
+ };
541
+
542
+ ////////////////////////////////////////////////////////////////////////////////
543
+
544
+ } // namespace threadblock
545
+ } // namespace epilogue
546
+ } // namespace cutlass
547
+
548
+ ////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/epilogue_base.h ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
33
+
34
+ The epilogue rearranges the result of a matrix product through shared memory to match canonical
35
+ tensor layouts in global memory. Epilogues support conversion and reduction operations.
36
+
37
+ */
38
+
39
+ #pragma once
40
+ #include "cutlass/cutlass.h"
41
+ #if !defined(__CUDACC_RTC__)
42
+ #include <type_traits>
43
+ #include <utility>
44
+ #endif
45
+ #include CUDA_STD_HEADER(cassert)
46
+
47
+ #include "cutlass/matrix_shape.h"
48
+ #include "cutlass/numeric_types.h"
49
+ #include "cutlass/array.h"
50
+ #include "cutlass/layout/vector.h"
51
+ #include "cutlass/layout/tensor.h"
52
+ #include "cutlass/tensor_coord.h"
53
+ #include "cutlass/aligned_buffer.h"
54
+
55
+ #include "cutlass/gemm/gemm.h"
56
+
57
+ #include "cutlass/transform/pitch_linear_thread_map.h"
58
+
59
+ ////////////////////////////////////////////////////////////////////////////////
60
+
61
+ namespace cutlass {
62
+ namespace epilogue {
63
+ namespace threadblock {
64
+
65
+ ////////////////////////////////////////////////////////////////////////////////
66
+
67
+ //
68
+ // This is used for metaprogramming epilogue functors. If they define
69
+ // `static bool const kIsHeavy = true;`, then the epilogue functor itself is
70
+ // not inlined. This results in smaller code and is advantageous if the epilogue
71
+ // functor consists of many instructions.
72
+ //
73
+ // If the epilogue functor does not define `kIsHeavy` or if it is `false`, then
74
+ // the behavior from CUTLASS 2.5 and before is retained. The epilogue is fully
75
+ // unrolled and inlined.
76
+ //
77
+
78
+ template<class>
79
+ struct TypeSink { typedef void type; };
80
+
81
+ template<class T> using TypeSinkT = typename TypeSink<T>::type;
82
+
83
+ template<class T, class=void> struct IsEpilogueFunctorHeavy {
84
+ static bool const value = false;
85
+ };
86
+
87
+ template<class T> struct IsEpilogueFunctorHeavy<T, TypeSinkT< decltype( T::kIsHeavy ) > > {
88
+ static bool const value = T::kIsHeavy;
89
+ };
90
+
91
+ ////////////////////////////////////////////////////////////////////////////////
92
+
93
+ /// Base class for epilogues defining warp-level
94
+ template <
95
+ typename Shape_, ///< Shape of threadblock tile (concept: GemmShape)
96
+ typename WarpShape_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp)
97
+ int PartitionsK, ///< Number of partitions of the K dimension
98
+ typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators
99
+ typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM
100
+ typename Padding_, ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape)
101
+ int FragmentsPerIteration = 1
102
+ >
103
+ class EpilogueBase {
104
+ public:
105
+
106
+ using Shape = Shape_;
107
+ using WarpShape = WarpShape_;
108
+ static int const kPartitionsK = PartitionsK;
109
+ using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
110
+ using WarpTileIterator = WarpTileIterator_;
111
+ using Padding = Padding_;
112
+
113
+ /// Output layout is always row-major
114
+ using Layout = layout::RowMajor;
115
+
116
+ /// The complete warp-level accumulator tile
117
+ using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile;
118
+
119
+ /// Accumulator element
120
+ using ElementAccumulator = typename AccumulatorTile::Element;
121
+
122
+ /// Number of warps
123
+ using WarpCount = gemm::GemmShape<
124
+ Shape::kM / WarpShape::kM,
125
+ Shape::kN / WarpShape::kN,
126
+ kPartitionsK
127
+ >;
128
+
129
+ /// Use this to control the granularity of one epilogue 'iteration'
130
+ static int const kFragmentsPerIteration = FragmentsPerIteration;
131
+
132
+ public:
133
+
134
+ /// Shared storage allocation needed by the epilogue
135
+ struct SharedStorage {
136
+
137
+ //
138
+ // Type definitions
139
+ //
140
+
141
+ /// Element type of shared memory
142
+ using Element = typename WarpTileIterator::Element;
143
+
144
+ /// Tensor reference to shared memory allocation
145
+ using TensorRef = typename WarpTileIterator::TensorRef;
146
+
147
+ /// Layout of shared memory allocation
148
+ using Layout = typename WarpTileIterator::Layout;
149
+
150
+ /// Logical shape of the shared memory tile written to by all warps.
151
+ using Shape = MatrixShape<
152
+ WarpCount::kM * WarpTileIterator::Shape::kRow * WarpCount::kK,
153
+ WarpCount::kN * WarpTileIterator::Shape::kColumn
154
+ >;
155
+
156
+ /// Shape of the shared memory allocation for the epilogue
157
+ using StorageShape = MatrixShape<
158
+ (Shape::kRow + Padding::kRow) * kFragmentsPerIteration,
159
+ Shape::kColumn + Padding::kColumn
160
+ >;
161
+
162
+ //
163
+ // Data members
164
+ //
165
+
166
+ AlignedBuffer<Element, StorageShape::kCount> storage;
167
+
168
+ //
169
+ // Methods
170
+ //
171
+
172
+ /// Returns a pointer to the shared memory buffer
173
+ CUTLASS_DEVICE
174
+ Element *data() {
175
+ return storage.data();
176
+ }
177
+
178
+ /// Returns a tensor reference to the shared memory buffer
179
+ CUTLASS_DEVICE
180
+ TensorRef reference() {
181
+ return TensorRef(
182
+ storage.data(),
183
+ Layout::packed({StorageShape::kRow, StorageShape::kColumn}));
184
+ }
185
+ };
186
+
187
+ protected:
188
+
189
+ //
190
+ // Data members
191
+ //
192
+
193
+ SharedStorage &shared_storage_;
194
+
195
+ /// Stores a warp's fragment of accumulators to SMEM
196
+ WarpTileIterator warp_tile_iterator_;
197
+
198
+ public:
199
+
200
+ /// Constructor
201
+ CUTLASS_DEVICE
202
+ EpilogueBase(
203
+ SharedStorage &shared_storage, ///< Shared storage object
204
+ int thread_idx, ///< ID of a thread within the threadblock
205
+ int warp_idx, ///< ID of warp within threadblock
206
+ int lane_idx ///< Id of thread within warp
207
+ ):
208
+ shared_storage_(shared_storage),
209
+ warp_tile_iterator_(shared_storage.reference(), lane_idx) {
210
+
211
+ // Compute warp location within threadblock tile by mapping the warp_id to three coordinates:
212
+ //
213
+ // _m: the warp's position within the threadblock along the M dimension
214
+ // _n: the warp's position within the threadblock along the N dimension
215
+ // _k: the warp's position within the threadblock along the K dimension
216
+
217
+ int warp_k = warp_idx / (WarpCount::kM * WarpCount::kN);
218
+ int warp_mn = warp_idx % (WarpCount::kM * WarpCount::kN);
219
+ int warp_m = warp_mn % WarpCount::kM;
220
+ int warp_n = warp_mn / WarpCount::kM;
221
+
222
+ MatrixCoord warp_offset{warp_k * WarpCount::kM + warp_m, warp_n};
223
+
224
+ warp_tile_iterator_.add_tile_offset(warp_offset);
225
+ }
226
+ };
227
+
228
+ ////////////////////////////////////////////////////////////////////////////////
229
+
230
+ } // namespace threadblock
231
+ } // namespace epilogue
232
+ } // namespace cutlass
233
+
234
+ ////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/epilogue_base_streamk.h ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Basic subset of epilogue functionality for supporting StreamK decompositions
33
+ */
34
+
35
+
36
+ #pragma once
37
+
38
+ #include "cutlass/cutlass.h"
39
+ #include "cutlass/functional.h"
40
+ #include "cutlass/block_striped.h"
41
+
42
+ ////////////////////////////////////////////////////////////////////////////////
43
+
44
+ namespace cutlass {
45
+ namespace epilogue {
46
+ namespace threadblock {
47
+
48
+ ////////////////////////////////////////////////////////////////////////////////
49
+
50
+
51
+ /// StreamK epilogue functionality for cross-block accumulator fragment reduction
52
+ template <
53
+ typename Shape, ///< Shape of threadblock tile (concept: GemmShape)
54
+ int PartitionsK,
55
+ typename WarpMmaOperator, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp)
56
+ typename AccumulatorFragmentIterator> ///< Iterator for enumerating fragments within the per-thread tile of raw accumulators
57
+ class EpilogueBaseStreamK
58
+ {
59
+
60
+ protected:
61
+
62
+ /// The per-thread tile of raw accumulators
63
+ using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile;
64
+
65
+ /// Number of warps
66
+ using WarpCount = gemm::GemmShape<
67
+ Shape::kM / WarpMmaOperator::Shape::kM,
68
+ Shape::kN / WarpMmaOperator::Shape::kN,
69
+ PartitionsK>;
70
+
71
+ /// Number of threads per block
72
+ static int const kBlockThreads = 32 * WarpCount::kCount;
73
+
74
+ /// Numerical accumulation element type
75
+ using ElementAccumulator = typename WarpMmaOperator::ElementC;
76
+
77
+ /// Fragment type used by the accumulator tile's fragment iterator
78
+ using AccumulatorFragment = typename AccumulatorFragmentIterator::Fragment;
79
+
80
+ public:
81
+
82
+ /// Number of AccumulatorTile fragments per thread
83
+ static int const kAccumulatorFragments = AccumulatorFragmentIterator::Policy::kIterations;
84
+
85
+ protected:
86
+
87
+ /// Number of AccumulatorTile fragments per block output tile
88
+ static int const kOutputTileFragments = kBlockThreads * kAccumulatorFragments;
89
+
90
+ /// Block-striped transfer utility for sharing AccumulatorFragment
91
+ using BlockStripedT = BlockStriped<kBlockThreads, AccumulatorFragment>;
92
+
93
+ /// AccumulatorFragment stride in the shared workspace between different peer blocks (each thread block can share accumulators for up to two block output tiles)
94
+ static const int kPeerFragmentStride = kOutputTileFragments * 2;
95
+
96
+ public:
97
+
98
+ /// Workspace bytes per thread block
99
+ static size_t const kWorkspaceBytesPerBlock =sizeof(AccumulatorFragment) * kPeerFragmentStride;
100
+
101
+ public:
102
+
103
+ /// Thread index in the threadblock
104
+ int thread_idx;
105
+
106
+ public:
107
+
108
+ /// Constructor
109
+ CUTLASS_DEVICE
110
+ EpilogueBaseStreamK(
111
+ int thread_idx) ///< ID of a thread within the threadblock
112
+ :
113
+ thread_idx(thread_idx)
114
+ {}
115
+
116
+
117
+ /// Aggregates the accumulator sets shared by peer blocks in the global workspace
118
+ CUTLASS_DEVICE
119
+ void reduce(
120
+ AccumulatorFragment &accum_fragment, ///< [out] sum of all shared accumulator fragments for these peer partials
121
+ int peer_idx_begin,
122
+ int peer_idx_end,
123
+ int reduce_fragment_idx,
124
+ void *workspace_ptr)
125
+ {
126
+ plus<AccumulatorFragment> add_fragments;
127
+
128
+ AccumulatorFragment *fragment_workspace = reinterpret_cast<AccumulatorFragment *>(workspace_ptr);
129
+
130
+ int fragment_offset = (peer_idx_begin * kPeerFragmentStride) + (reduce_fragment_idx * kBlockThreads);
131
+
132
+ // Load first peer fragment
133
+ BlockStripedT::load(accum_fragment, fragment_workspace + fragment_offset, this->thread_idx);
134
+
135
+ fragment_offset += kPeerFragmentStride; // Move to next peer
136
+ fragment_offset += kOutputTileFragments; // Move to the set of fragments for this peer's "non-started" output tile
137
+
138
+ // Reduce fragments from additional peers
139
+ #pragma unroll 2
140
+ for (; fragment_offset < peer_idx_end * kPeerFragmentStride; fragment_offset += kPeerFragmentStride)
141
+ {
142
+ // Load peer fragment
143
+ AccumulatorFragment addend_fragment;
144
+ BlockStripedT::load(addend_fragment, fragment_workspace + fragment_offset, this->thread_idx);
145
+
146
+ // Add peer fragment
147
+ accum_fragment = add_fragments(accum_fragment, addend_fragment);
148
+ }
149
+ }
150
+
151
+
152
+ /// Shares the accumulator set with peers in the global workspace
153
+ CUTLASS_DEVICE
154
+ void share(
155
+ int peer_idx,
156
+ void *workspace_ptr,
157
+ AccumulatorTile const &accumulators,
158
+ bool started_tile) ///< Whether this thread block computed the first work volume for the current output tile
159
+ {
160
+ AccumulatorFragment *fragment_workspace = reinterpret_cast<AccumulatorFragment *>(workspace_ptr);
161
+
162
+ int fragment_offset = peer_idx * kPeerFragmentStride;
163
+
164
+ if (!started_tile) {
165
+ // Move to the set of fragments for the "non-started" output tile
166
+ fragment_offset += kOutputTileFragments;
167
+ }
168
+
169
+ AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
170
+
171
+ // Convert raw accumulator tile to fragments and store
172
+ CUTLASS_PRAGMA_UNROLL
173
+ for (int iter = 0; iter < kAccumulatorFragments; ++iter)
174
+ {
175
+ // Acquire reordered accumulator fragment
176
+ AccumulatorFragment accum_fragment;
177
+ accum_fragment_iterator.load(accum_fragment);
178
+ ++accum_fragment_iterator;
179
+
180
+ // Store accumulator fragment
181
+ BlockStripedT::store(fragment_workspace + fragment_offset, accum_fragment, this->thread_idx);
182
+
183
+ fragment_offset += kBlockThreads;
184
+ }
185
+ }
186
+
187
+ };
188
+
189
+
190
+
191
+ ////////////////////////////////////////////////////////////////////////////////
192
+
193
+ } // namespace threadblock
194
+ } // namespace epilogue
195
+ } // namespace cutlass
196
+
197
+ ////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/epilogue_depthwise.h ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Epilogue for Depthwise convoltuion
33
+
34
+ The epilogue rearranges the result of a matrix product through shared memory to match canonical
35
+ tensor layouts in global memory. Epilogues support conversion and reduction operations.
36
+
37
+ */
38
+
39
+ #pragma once
40
+
41
+ #include "cutlass/array.h"
42
+ #include "cutlass/cutlass.h"
43
+ #include "cutlass/epilogue/thread/conversion_op.h"
44
+ #include "cutlass/epilogue/thread/linear_combination.h"
45
+ #include "cutlass/epilogue/thread/reduction_op.h"
46
+ #include "cutlass/gemm/gemm.h"
47
+ #include "cutlass/numeric_types.h"
48
+
49
+ /////////////////////////////////////////////////////////////////////////////////////////////////
50
+
51
+ namespace cutlass {
52
+ namespace epilogue {
53
+ namespace threadblock {
54
+
55
+ ////////////////////////////////////////////////////////////////////////////////
56
+
57
+ /// Epilogue operator
58
/// Epilogue for depthwise convolution kernels.
///
/// Each warp first writes its accumulator tile to shared memory through
/// WarpTileIterator_; after a threadblock-wide barrier, the data is re-read via
/// SharedLoadIterator_ in the order expected by OutputTileIterator_, the
/// elementwise output operator is applied (optionally combining with a source
/// tensor), and the result is streamed to global memory.
template <typename Shape_, ///< Shape of threadblock tile (concept: GemmShape)
          typename ThreadOutputShape_, ///< Size of the matrix to load (concept: TensorNHWC)
          typename ThreadBlockOutputShape_, ///< Size of the matrix to load (concept: TensorNHWC)
          typename WarpMmaOperator_, ///< Warp-level MMA operator (concept:
                                     ///< gemm::warp::MmaTensorOp)
          typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors
          typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators
          typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM
          typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM
          typename OutputOp_, ///< Output operator
          typename Padding_ ///< Padding added to SMEM allocation to avoid bank conflicts (concept:
                            ///< MatrixShape)
          >
class EpilogueDepthwise {
 public:
  using Shape = Shape_;
  using WarpShape = typename WarpMmaOperator_::Shape;
  using ThreadOutputShape = ThreadOutputShape_;
  using ThreadBlockOutputShape = ThreadBlockOutputShape_;
  using WarpMmaOperator = WarpMmaOperator_;
  using OutputTileIterator = OutputTileIterator_;
  using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
  using WarpTileIterator = WarpTileIterator_;
  using SharedLoadIterator = SharedLoadIterator_;
  using OutputOp = OutputOp_;
  using Padding = Padding_;

  using Layout = layout::RowMajor;
  using LongIndex = typename Layout::LongIndex;

  /// The complete warp-level accumulator tile
  using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile;

  /// Accumulator element
  using ElementAccumulator = typename WarpTileIterator::Element;

  /// Output element
  using ElementOutput = typename OutputTileIterator::Element;

  /// Output access size
  static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;

  /// Tensor reference to destination tensor
  using TensorRef = typename OutputTileIterator::TensorRef;

  /// Tensor reference to sync tensor
  using SyncTensorRef = typename cutlass::TensorRef<int, cutlass::layout::PackedVectorLayout>;

  /// Const tensor reference to source tensor
  using ConstTensorRef = typename OutputTileIterator::ConstTensorRef;

  /// Array type used to output
  using OutputAccessType =
      Array<typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>;

  /// Array type used by output functor
  using AccumulatorAccessType =
      Array<typename WarpTileIterator::Element, OutputTileIterator::kElementsPerAccess>;

  /// Number of warps
  using WarpCount =
      gemm::GemmShape<Shape::kM / WarpShape::kM, Shape::kN / WarpShape::kN>;

 public:
  static_assert(SharedLoadIterator::Fragment::kElements ==
                    OutputTileIterator::Fragment::kElements,
                "Mismatch between shared load iterator and output tile iterator.");

  static_assert(OutputTileIterator::kElementsPerAccess,
                "OutputTileIterator::kElementsPerAccess must not be zero.");

  static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess),
                "Divisibility");

  /// Shared storage allocation needed by the epilogue
  struct SharedStorage {
    //
    // Type definitions
    //

    /// Element type of shared memory
    using Element = typename WarpTileIterator::Element;

    /// Tensor reference to shared memory allocation
    using TensorRef = typename WarpTileIterator::TensorRef;

    /// Layout of shared memory allocation
    using Layout = typename WarpTileIterator::Layout;

    /// Logical shape of the shared memory tile written to by all warps.
    using Shape = MatrixShape<ThreadBlockOutputShape::kNHW, ThreadBlockOutputShape::kC>;

    /// Shape of the shared memory allocation for the epilogue
    using StorageShape = MatrixShape<Shape::kRow, Shape::kColumn>;

    //
    // Data members
    //

    AlignedBuffer<Element, StorageShape::kCount> storage;

    //
    // Methods
    //

    /// Returns a pointer to the shared memory buffer
    CUTLASS_DEVICE
    Element *data() { return storage.data(); }

    /// Returns a tensor reference to the shared memory buffer
    CUTLASS_DEVICE
    TensorRef reference() {
      return TensorRef(storage.data(), Layout::packed({StorageShape::kRow, StorageShape::kColumn}));
    }
  };

 private:
  /// Loads fragment from shared memory aligned with output tensor
  SharedLoadIterator shared_load_iterator_;

  /// Stores a warp's fragment of accumulators to SMEM
  WarpTileIterator warp_tile_iterator_;

  LongIndex warp_offset;
  int thread_idx;
  int warp_idx;
  int lane_idx;
  int warp_m, warp_n; // warp coordinates within a cta
  int tid_m, tid_n;   // thread coordinates within a warp

 public:
  /// Constructor
  CUTLASS_DEVICE
  EpilogueDepthwise(SharedStorage &shared_storage, ///< Shared storage object
                    int thread_idx_, ///< ID of a thread within the threadblock
                    int warp_idx_,   ///< ID of warp within threadblock
                    int lane_idx_    ///< Id of thread within warp
                    )
      : thread_idx(thread_idx_),
        warp_idx(warp_idx_),
        lane_idx(lane_idx_),
        shared_load_iterator_(shared_storage.reference(), thread_idx_),
        warp_tile_iterator_(shared_storage.reference(), thread_idx_, lane_idx_) {}

  /// Streams the result to global memory
  CUTLASS_DEVICE
  void operator()(OutputOp const &output_op, ///< Output operator
                  OutputTileIterator destination_iterator, ///< Tile iterator for destination
                  AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
                  OutputTileIterator source_iterator, ///< Threadblock tile coordinate in GEMM (in
                                                      ///< units of threadblock tiles)
                  const int smem_base_offset) { ///< SMEM base offset for epilogue operation
    // initiate the smem base offset for different output tile.
    warp_tile_iterator_.set_smem_base_address(smem_base_offset);

    shared_load_iterator_.set_smem_base_address(smem_base_offset);

    // Skip reading the source tensor entirely when the output operator does not
    // consume it (e.g. beta == 0).
    if (!output_op.is_source_needed()) {
      compute_source_not_needed_(output_op, destination_iterator, accumulators);
    } else {
      compute_source_needed_(output_op, destination_iterator, accumulators, source_iterator);
    }
  }

 private:
  /// Streams the result to global memory, combining accumulators with the source tensor
  CUTLASS_DEVICE
  void compute_source_needed_(
      OutputOp const &output_op, ///< Output operator
      OutputTileIterator destination_iterator, ///< Tile iterator for destination
      AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
      OutputTileIterator source_iterator) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)

    typename OutputTileIterator::Fragment source_fragment;

    // Clear first so out-of-bounds (unloaded) elements hold a defined value.
    source_fragment.clear();

    source_iterator.load(source_fragment);

    // store to smem
    warp_tile_iterator_.store(accumulators);

    // Make every warp's accumulator data visible to the whole threadblock before
    // it is read back in output-tensor order.
    // NOTE(review): there is no barrier after the load below; this assumes the
    // caller varies smem_base_offset across invocations so a subsequent store
    // cannot overwrite data still being read -- TODO confirm with callers.
    __syncthreads();

    typename SharedLoadIterator::Fragment aligned_accum_fragment;

    // load from smem
    shared_load_iterator_.load(aligned_accum_fragment);

    typename OutputTileIterator::Fragment output_fragment;

    apply_output_operator_(output_fragment, output_op, aligned_accum_fragment, source_fragment);

    // Store to GMEM
    destination_iterator.store(output_fragment);
  }

  /// Streams the result to global memory without reading the source tensor
  CUTLASS_DEVICE
  void compute_source_not_needed_(
      OutputOp const &output_op, ///< Output operator
      OutputTileIterator destination_iterator, ///< Tile iterator for destination
      AccumulatorTile const &accumulators) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)

    // store to smem
    warp_tile_iterator_.store(accumulators);

    // Make every warp's accumulator data visible to the whole threadblock before
    // it is read back in output-tensor order.
    __syncthreads();

    typename SharedLoadIterator::Fragment aligned_accum_fragment;

    // load from smem
    shared_load_iterator_.load(aligned_accum_fragment);

    typename OutputTileIterator::Fragment output_fragment;

    apply_output_operator_source_not_needed_(output_fragment, output_op, aligned_accum_fragment);

    // Store to GMEM
    destination_iterator.store(output_fragment);
  }

  /// Helper to invoke the output functor over each vector of output
  CUTLASS_DEVICE
  void apply_output_operator_(
      typename OutputTileIterator::Fragment &output_fragment,
      OutputOp const &output_op, ///< Output operator
      typename SharedLoadIterator::Fragment const &aligned_accum_fragment,
      typename OutputTileIterator::Fragment const &source_fragment) {

    // Reinterpret the flat fragments as arrays of vector-width accesses so the
    // functor is applied once per output vector of kElementsPerAccess elements.
    OutputAccessType *output_frag_ptr =
        reinterpret_cast<OutputAccessType *>(&output_fragment);

    AccumulatorAccessType const *compute_frag_ptr =
        reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment);

    OutputAccessType const *source_frag_ptr =
        reinterpret_cast<OutputAccessType const *>(&source_fragment);

    int const kOutputOpIterations =
        OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kOutputOpIterations; ++i) {
      // Call the output operator
      output_frag_ptr[i] = output_op(compute_frag_ptr[i], source_frag_ptr[i]);
    }
  }

  /// Helper to invoke the output functor over each vector of output
  CUTLASS_DEVICE
  void apply_output_operator_source_not_needed_(
      typename OutputTileIterator::Fragment &output_fragment,
      OutputOp const &output_op, ///< Output operator
      typename SharedLoadIterator::Fragment const &aligned_accum_fragment) {
    // Same vectorized application as apply_output_operator_, without a source term.
    OutputAccessType *output_frag_ptr = reinterpret_cast<OutputAccessType *>(&output_fragment);

    AccumulatorAccessType const *compute_frag_ptr =
        reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment);

    int const kOutputOpIterations =
        OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kOutputOpIterations; ++i) {
      // Call the output operator
      output_frag_ptr[i] = output_op(compute_frag_ptr[i]);
    }
  }
};
328
+
329
+ /////////////////////////////////////////////////////////////////////////////////////////////////
330
+
331
+ } // namespace threadblock
332
+ } // namespace epilogue
333
+ } // namespace cutlass
334
+
335
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/epilogue_direct_store.h ADDED
@@ -0,0 +1,347 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Epilogue for threadblock scoped GEMMs and convolution using Tensor Ops.
33
+
34
+ The epilogue rearranges the result of a matrix product through shared memory to match canonical
35
+ tensor layouts in global memory. Epilogues support conversion and reduction operations.
36
+
37
+ */
38
+
39
+ #pragma once
40
+
41
+ #include "cutlass/cutlass.h"
42
+ #include "cutlass/numeric_types.h"
43
+ #include "cutlass/array.h"
44
+
45
+ #include "cutlass/gemm/gemm.h"
46
+
47
+ #include "cutlass/epilogue/thread/linear_combination.h"
48
+ #include "cutlass/epilogue/thread/conversion_op.h"
49
+ #include "cutlass/epilogue/thread/reduction_op.h"
50
+
51
+ /////////////////////////////////////////////////////////////////////////////////////////////////
52
+
53
+ namespace cutlass {
54
+ namespace epilogue {
55
+ namespace threadblock {
56
+
57
+ ////////////////////////////////////////////////////////////////////////////////
58
+
59
/// Epilogue operator
///
/// Writes accumulators directly to global memory without staging through shared
/// memory. Each thread computes its own global coordinates from its warp/lane
/// position in the tensor-op accumulator layout and performs guarded, vectorized
/// (2-element) loads of the source and stores of the result.
template <
  typename Shape_,                        ///< Shape of threadblock tile (concept: GemmShape)
  typename WarpMmaOperator_,              ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp)
  int PartitionsK,                        ///< Number of partitions of the K dimension
  typename OutputTileIterator_,           ///< Tile iterator reading and writing output tensors
  typename AccumulatorFragmentIterator_,  ///< Fragment iterator selecting accumulators
  typename WarpTileIterator_,             ///< Warp-scoped tile iterator writing accumulators to SMEM
  typename SharedLoadIterator_,           ///< Threadblock-scoped tile iterator loading from SMEM
  typename OutputOp_                      ///< Output operator
>
class EpilogueDirectStore {
public:

  using Shape = Shape_;
  using WarpMmaOperator = WarpMmaOperator_;
  using WarpShape = typename WarpMmaOperator_::Shape;
  static int const kPartitionsK = PartitionsK;
  using OutputTileIterator = OutputTileIterator_;
  using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
  using WarpTileIterator = WarpTileIterator_;
  using OutputOp = OutputOp_;
  using Padding = MatrixShape<0, 0>;

  using Layout = layout::RowMajor;
  using LongIndex = typename Layout::LongIndex;

  /// The complete warp-level accumulator tile
  using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile;

  /// Accumulator element
  using ElementAccumulator = typename WarpTileIterator::Element;

  /// Output element
  using ElementOutput = typename OutputTileIterator::Element;

  /// Output access size
  static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;

  /// Tensor reference to destination tensor
  using TensorRef = typename OutputTileIterator::TensorRef;

  /// Tensor reference to sync tensor
  using SyncTensorRef = typename cutlass::TensorRef<int, cutlass::layout::PackedVectorLayout>;

  /// Const tensor reference to source tensor
  using ConstTensorRef = typename OutputTileIterator::ConstTensorRef;

  /// Array type used to output
  using OutputAccessType = Array<
    typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>;

  /// Array type used by output functor
  using AccumulatorAccessType = Array<typename WarpTileIterator::Element, OutputTileIterator::kElementsPerAccess>;

  /// Number of warps
  using WarpCount = gemm::GemmShape<
    Shape::kM / WarpShape::kM,
    Shape::kN / WarpShape::kN,
    kPartitionsK
  >;

  /// Use this to control the granularity of one epilogue 'iteration'
  static int const kFragmentsPerIteration = 1;

  // No SMEM staging is performed, so no SMEM tiles or pointer offsets are needed.
  static int constexpr kSmemTiles = 1;
  static int constexpr kSmemPointerOffset = 0;

  /// Shared storage allocation needed by the epilogue -- empty, since this
  /// epilogue stores directly to global memory.
  struct SharedStorage { } ;

private:

  // Assume accumulator tile is multiple interleaved 32x32 tile.
  static int const kElementsPerPartial = 4;
  using EleShapePerPatial = typename platform::conditional<
      platform::is_same<ElementAccumulator, float>::value,
      MatrixShape<2, 2>,
      MatrixShape<1, 4> >::type;
  static int const kElementsPerMma = 8;
  static int const kAccumulatorPatials = 2;
  using QuadShapePerPatialMma = MatrixShape<4, 4>;

  static_assert(OutputOp::kCount >= 2,
    "The direct store epilogue for Tensor Ops requires the output functor have kCount >= 2.");

private:

  LongIndex warp_offset;
  int thread_idx;
  int warp_idx;
  int lane_idx;
  int warp_m, warp_n; // warp coordinates within a cta
  int tid_m, tid_n;   // thread coordinates within a warp

public:

  /// Constructor
  CUTLASS_DEVICE
  EpilogueDirectStore(
    SharedStorage &shared_storage, ///< Shared storage object (unused; kept for interface parity)
    int thread_idx_, ///< ID of a thread within the threadblock
    int warp_idx_,   ///< ID of warp within threadblock
    int lane_idx_    ///< Id of thread within warp
  ):
    thread_idx(thread_idx_),
    warp_idx(warp_idx_),
    lane_idx(lane_idx_)
  {

    // warp offsetting calculations
    warp_offset = warp_idx * WarpShape::kM * WarpShape::kN;
    // NOTE(review): the modulus mixes WarpCount::kM with WarpShape::kN; the total
    // number of MN-tile warps would be WarpCount::kM * WarpCount::kN -- verify
    // against upstream CUTLASS before relying on kPartitionsK > 1.
    int warp_id_mn = warp_idx % (WarpCount::kM * WarpShape::kN);
    warp_m = warp_id_mn % WarpCount::kM;
    warp_n = warp_id_mn / WarpCount::kM;
    MatrixCoord warp_offset_coord(warp_m*WarpShape::kM, warp_n*WarpShape::kN);

    // thread offsetting calculations: each warp is viewed as 8 quads of 4 lanes
    int quad = (lane_idx >> 2);
    int lane_in_quad = (lane_idx & 3);

    // this seems to be the correct layout
    tid_m = quad;             // quad index selects the row within the warp tile
    tid_n = 2 * lane_in_quad; // each lane owns 2 consecutive columns

  }

  /// Streams the result to global memory
  CUTLASS_DEVICE
  void operator()(
    OutputOp const &output_op,               ///< Output operator
    OutputTileIterator destination_iterator, ///< Tile iterator for destination
    AccumulatorTile const &accumulators,     ///< Complete warp-level accumulator tile
    OutputTileIterator source_iterator) {    ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)

    // Skip reading the source tensor when the output operator ignores it (e.g. beta == 0).
    if (!output_op.is_source_needed()) {
      compute_source_not_needed_(output_op, destination_iterator, accumulators);
    }
    else {
      compute_source_needed_(output_op, destination_iterator, accumulators, source_iterator);
    }
  }

private:

  /// Streams the result to global memory, combining accumulators with the source tensor
  CUTLASS_DEVICE
  void compute_source_needed_(
    OutputOp const &output_op,               ///< Output operator
    OutputTileIterator destination_iterator, ///< Tile iterator for destination
    AccumulatorTile const &accumulators,     ///< Complete warp-level accumulator tile
    OutputTileIterator source_iterator) {    ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)

    const int kAccumBlockN = 2; // accumulator columns handled per inner iteration
    const int kThreadsM = 8;    // quads per warp (rows)
    const int kThreadsN = 4;    // lanes per quad (columns)
    const int kBlockM = WarpShape::kM / kThreadsM;

    /// Array type used to output
    using OutputAccessType = AlignedArray<ElementOutput, kAccumBlockN>;

    /// Array type passed to the output operator - unused elements are optimized away
    using OutputFragmentType = Array<ElementOutput, OutputOp::kCount>;

    /// Array type used by output functor
    using AccumulatorAccessType = Array<ElementAccumulator, kAccumBlockN>;

    /// Array type used by output functor
    using AccumulatorFragmentType = Array<ElementAccumulator, OutputOp::kCount>;

    // View the accumulator tile as pairs of elements, matching this thread's 2-column ownership.
    AccumulatorAccessType const *accumulator_pair = reinterpret_cast<AccumulatorAccessType const *>(&accumulators);

    CUTLASS_PRAGMA_UNROLL
    for (int accum_m_idx = 0; accum_m_idx < WarpShape::kM / kThreadsM; accum_m_idx++) {

      int accum_m = kThreadsM * accum_m_idx;
      int mL = destination_iterator.threadblock_offset.row() + WarpShape::kM * warp_m + tid_m + accum_m;
      int nL_base = destination_iterator.threadblock_offset.column() + WarpShape::kN * warp_n + tid_n;

      ElementOutput *output_ptr = destination_iterator.pointer + mL * destination_iterator.stride;
      ElementOutput *source_ptr = source_iterator.pointer + mL * source_iterator.stride;

      int const kIterationsN = WarpShape::kN / kThreadsN / kAccumBlockN;

      CUTLASS_PRAGMA_UNROLL
      for (int accum_n_idx = 0; accum_n_idx < kIterationsN; accum_n_idx++) {

        int accum_idx = accum_m_idx + kBlockM * accum_n_idx;
        // Column stride per iteration is kThreadsN * kAccumBlockN == 8, numerically
        // equal to kThreadsM (also 8), which is the constant used here.
        int accum_n = kThreadsM * accum_n_idx;

        // mL and nL are logical coordinate in 2D mapping of epilogue's 4D output
        int nL = nL_base + accum_n;

        // Predicate both the source load and the result store against the tensor extent.
        bool guard = (mL < destination_iterator.extent.row()) && (nL < destination_iterator.extent.column());

        AccumulatorFragmentType accum_fragment;
        reinterpret_cast<AccumulatorAccessType &>(accum_fragment) = accumulator_pair[accum_idx];

        OutputFragmentType output_fragment;

        if(guard) {
          reinterpret_cast<OutputAccessType &>(output_fragment) =
              *reinterpret_cast<OutputAccessType const *>(source_ptr + nL);
        }

        // Perform output operator
        output_fragment = output_op(accum_fragment, output_fragment);

        if(guard) {
          // Store
          *reinterpret_cast<OutputAccessType *>(output_ptr + nL) = reinterpret_cast<OutputAccessType const &>(output_fragment);
        }
      }
    }
  }

  /// Streams the result to global memory without reading the source tensor
  CUTLASS_DEVICE
  void compute_source_not_needed_(
    OutputOp const &output_op,               ///< Output operator
    OutputTileIterator destination_iterator, ///< Tile iterator for destination
    AccumulatorTile const &accumulators) {   ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)

    const int kAccumBlockN = 2; // accumulator columns handled per inner iteration
    const int kThreadsM = 8;    // quads per warp (rows)
    const int kThreadsN = 4;    // lanes per quad (columns)
    const int kBlockM = WarpShape::kM / kThreadsM;

    /// Array type used to output
    using OutputAccessType = AlignedArray<ElementOutput, kAccumBlockN>;

    /// Array type passed to the output operator - unused elements are optimized away
    using OutputFragmentType = Array<ElementOutput, OutputOp::kCount>;

    /// Array type used by output functor
    using AccumulatorAccessType = Array<ElementAccumulator, kAccumBlockN>;

    /// Array type used by output functor
    using AccumulatorFragmentType = Array<ElementAccumulator, OutputOp::kCount>;

    // View the accumulator tile as pairs of elements, matching this thread's 2-column ownership.
    AccumulatorAccessType const *accumulator_pair = reinterpret_cast<AccumulatorAccessType const *>(&accumulators);

    CUTLASS_PRAGMA_UNROLL
    for (int accum_m_idx = 0; accum_m_idx < WarpShape::kM / kThreadsM; accum_m_idx++) {

      int accum_m = kThreadsM * accum_m_idx;
      int mL = destination_iterator.threadblock_offset.row() + WarpShape::kM * warp_m + tid_m + accum_m;
      int nL_base = destination_iterator.threadblock_offset.column() + WarpShape::kN * warp_n + tid_n;

      ElementOutput *output_ptr = destination_iterator.pointer + mL * destination_iterator.stride;

      int const kIterationsN = WarpShape::kN / kThreadsN / kAccumBlockN;

      CUTLASS_PRAGMA_UNROLL
      for (int accum_n_idx = 0; accum_n_idx < kIterationsN; accum_n_idx++) {

        int accum_idx = accum_m_idx + kBlockM * accum_n_idx;
        // Column stride per iteration is kThreadsN * kAccumBlockN == 8, numerically
        // equal to kThreadsM (also 8), which is the constant used here.
        int accum_n = kThreadsM * accum_n_idx;

        // mL and nL are logical coordinate in 2D mapping of epilogue's 4D output
        int nL = nL_base + accum_n;

        // Predicate the result store against the tensor extent.
        bool guard = (mL < destination_iterator.extent.row()) && (nL < destination_iterator.extent.column());

        AccumulatorFragmentType accum_fragment;
        reinterpret_cast<AccumulatorAccessType &>(accum_fragment) = accumulator_pair[accum_idx];

        OutputFragmentType output_fragment;

        // Perform output operator
        output_fragment = output_op(accum_fragment);

        if(guard) {

          // Store
          *reinterpret_cast<OutputAccessType *>(output_ptr + nL) =
              reinterpret_cast<OutputAccessType const &>(output_fragment);
        }
      }
    }
  }
};
340
+
341
+ /////////////////////////////////////////////////////////////////////////////////////////////////
342
+
343
+ } // namespace threadblock
344
+ } // namespace epilogue
345
+ } // namespace cutlass
346
+
347
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/epilogue_gemm_k_reduction.h ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
33
+
34
+ The epilogue rearranges the result of a matrix product through shared memory to match canonical
35
+ tensor layouts in global memory. Epilogues support conversion and reduction operations.
36
+
37
+ */
38
+
39
+ #pragma once
40
+ #include "cutlass/cutlass.h"
41
+ #include CUDA_STD_HEADER(cassert)
42
+ #include "cutlass/numeric_types.h"
43
+ #include "cutlass/array.h"
44
+ #include "cutlass/layout/vector.h"
45
+ #include "cutlass/layout/tensor.h"
46
+ #include "cutlass/tensor_coord.h"
47
+ #include "cutlass/aligned_buffer.h"
48
+ #include "cutlass/functional.h"
49
+
50
+ #include "cutlass/gemm/gemm.h"
51
+
52
+ #include "cutlass/transform/pitch_linear_thread_map.h"
53
+ #include "cutlass/transform/threadblock/regular_tile_iterator.h"
54
+
55
+ #include "cutlass/epilogue/threadblock/epilogue_base.h"
56
+ #include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
57
+ #include "cutlass/numeric_types.h"
58
+
59
+ namespace cutlass {
60
+ namespace epilogue {
61
+ namespace threadblock {
62
+
63
+ ////////////////////////////////////////////////////////////////////////////////
64
+
65
/// Epilogue operator
///
/// Writes out a reduction along the GEMM K dimension (for either operand A's rows
/// or operand B's columns, selected by ReduceKForA_). Each thread combines its
/// per-thread partial sums with those of the other lanes in its quad via warp
/// shuffles, optionally adds a previously stored partial result (serial split-K),
/// and stores the reduced vector to global memory with bounds guards.
template <
  typename ElementAccumulator_,
  typename ElementOutput_,
  typename ThreadBlockShape_,  ///< Shape of threadblock tile (concept: GemmShape)
  typename WarpMmaOperator_,   ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp)
  bool ReduceKForA_
>
class EpilogueGemmKReduction {

public:

  using ThreadBlockShape = ThreadBlockShape_;
  using WarpMmaOperator = WarpMmaOperator_;
  using WarpShape = typename WarpMmaOperator::Shape;
  using Layout = layout::RowMajor;
  using LongIndex = typename Layout::LongIndex;

  /// Accumulator element
  using ElementAccumulator = ElementAccumulator_;

  /// Output element
  using ElementOutput = ElementOutput_;

  /// Output access size
  static int const kElementsPerAccess = 1;

  static bool const kReduceKForA = ReduceKForA_;

  // Reduction extent follows the M dimension when reducing for A, else the N dimension.
  static int const kThreadBlockSize = kReduceKForA ? ThreadBlockShape::kM : ThreadBlockShape::kN;

  static int const kWarpSize = kReduceKForA ? WarpShape::kM : WarpShape::kN;

  static int const kIterations = kWarpSize / 8;

  using FragmentAccumulator = Array<ElementAccumulator, kIterations>;

private:

  int thread_offset_;      // this thread's starting index in the reduced output vector
  ElementOutput* pointer_; // global-memory pointer offset by thread_offset_
  int col_;                // lane position within its quad (lane_idx % 4)
public:

  /// Constructor
  CUTLASS_DEVICE
  EpilogueGemmKReduction(
      int thread_idx,           ///< ID of a thread within the threadblock
      int warp_idx,             ///< ID of warp within threadblock
      int lane_idx,             ///< Id of thread within warp
      int threadblock_offset,   ///< Threadblock position along the reduced dimension
      ElementOutput* pointer    ///< Base pointer of the reduction output vector
  )
  {
    col_ = lane_idx % 4;
    thread_offset_ = threadblock_offset * kThreadBlockSize
                       + warp_idx * kWarpSize
                       + lane_idx / 4 + col_ * 8;

    pointer_ = pointer + LongIndex(thread_offset_);
  }

  /// Streams the result to global memory
  CUTLASS_DEVICE
  void operator()(
      int size,  ///< Extent of the reduced dimension; bounds-checks the stores
      FragmentAccumulator &gemm_k_with_reduction_accumulation,
      bool LoadForSerialSplitK  ///< When true, accumulate onto the previously stored partial
  ) {
    // NOTE(review): arrays below are sized kIterations / 4 = kWarpSize / 32;
    // this assumes kWarpSize is a multiple of 32 -- TODO confirm.
    bool guard[kIterations / 4];

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kIterations / 4; ++i) {
      guard[i] = ((thread_offset_ + i * 32) < size);
    }

    // Previously stored partials (serial split-K); zero when not loaded so the
    // addition below is a no-op.
    Array<ElementOutput, kIterations / 4> source;
    source.clear();

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kIterations / 4; ++i) {
      ElementOutput *source_ptr = reinterpret_cast<ElementOutput *>(&source);
      cutlass::arch::global_load<ElementOutput, sizeof(ElementOutput)>(
            source_ptr[i],
            (void *)(pointer_ + i * 32),
            guard[i] && LoadForSerialSplitK);

    }

    FragmentAccumulator sum = gemm_k_with_reduction_accumulation;

    // Butterfly reduction across the 4 lanes of each quad (xor over lane bits 0 and 1)
    // so every lane holds the quad-wide sum.
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kIterations; ++i) {
      sum[i] += __shfl_xor_sync(0xffffffff, sum[i], 1);
      sum[i] += __shfl_xor_sync(0xffffffff, sum[i], 2);
    }

    Array<ElementAccumulator, kIterations / 4> intermediate;

    // Each lane keeps only the partial belonging to its own quad column (col_),
    // de-interleaving the 4-way redundant sums.
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kIterations / 4; ++i) {
      if (col_ == 0) {
        intermediate[i] = sum[0 + i * 4];
      }

      if (col_ == 1) {
        intermediate[i] = sum[1 + i * 4];
      }

      if (col_ == 2) {
        intermediate[i] = sum[2 + i * 4];
      }

      if (col_ == 3) {
        intermediate[i] = sum[3 + i * 4];
      }
    }

    // Accumulate the previously stored partial (zeros when not loaded).
    NumericArrayConverter<ElementAccumulator, ElementOutput, kIterations / 4> source_converter;
    Array<ElementAccumulator, kIterations / 4> converted_source = source_converter(source);

    plus<Array<ElementAccumulator, kIterations / 4>> plus_source;
    intermediate = plus_source(intermediate, converted_source);

    NumericArrayConverter<ElementOutput, ElementAccumulator, kIterations / 4> converter;
    Array<ElementOutput, kIterations / 4> result = converter(intermediate);

    // Guarded scalar stores, one element per 32-column stride.
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kIterations / 4; ++i) {
      cutlass::arch::global_store<ElementOutput, sizeof(ElementOutput)>(result[i],
                 (void *)(pointer_ + i * 32), guard[i]);
    }
  }
};
199
+
200
+ ////////////////////////////////////////////////////////////////////////////////
201
+
202
+ } // namespace threadblock
203
+ } // namespace epilogue
204
+ } // namespace cutlass
205
+
206
+ ////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/epilogue_planar_complex.h ADDED
@@ -0,0 +1,401 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
33
+
34
+ The epilogue rearranges the result of a matrix product through shared memory to match canonical
35
+ tensor layouts in global memory. Epilogues support conversion and reduction operations.
36
+
37
+ */
38
+
39
+ #pragma once
40
+
41
+ #include "cutlass/cutlass.h"
42
+ #include "cutlass/numeric_types.h"
43
+ #include "cutlass/array.h"
44
+ #include "cutlass/array_planar_complex.h"
45
+ #include "cutlass/layout/vector.h"
46
+ #include "cutlass/layout/tensor.h"
47
+ #include "cutlass/tensor_coord.h"
48
+ #include "cutlass/aligned_buffer.h"
49
+ #include "cutlass/functional.h"
50
+
51
+ #include "cutlass/gemm/gemm.h"
52
+
53
+ #include "cutlass/transform/pitch_linear_thread_map.h"
54
+ #include "cutlass/transform/threadblock/regular_tile_iterator.h"
55
+
56
+ #include "cutlass/epilogue/threadblock/epilogue_base.h"
57
+ #include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
58
+
59
+ ////////////////////////////////////////////////////////////////////////////////
60
+
61
+ namespace cutlass {
62
+ namespace epilogue {
63
+ namespace threadblock {
64
+
65
+ ////////////////////////////////////////////////////////////////////////////////
66
+
67
+ /// Epilogue operator for planar-complex output representations.
68
+ ///
69
+ /// Note, as with most CUTLASS components for planar complex, the template arguments describe
70
+ /// the underlying real data type.
71
template <
  typename Shape_,                          ///< Shape of threadblock tile (concept: GemmShape)
  typename WarpMmaOperator_,                ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp)
  int PartitionsK,                          ///< Number of partitions of the K dimension
  typename OutputTileIterator_,             ///< Tile iterator reading and writing output tensors
  typename AccumulatorFragmentIterator_,    ///< Fragment iterator selecting accumulators
  typename WarpTileIterator_,               ///< Warp-scoped tile iterator writing accumulators to SMEM
  typename SharedLoadIterator_,             ///< Threadblock-scoped tile iterator loading from SMEM
  typename OutputOp_,                       ///< Output operator
  typename Padding_                         ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape)
>
class EpiloguePlanarComplex {
public:

  using Shape = Shape_;
  using WarpMmaOperator = WarpMmaOperator_;
  static int const kPartitionsK = PartitionsK;
  using OutputTileIterator = OutputTileIterator_;
  using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
  using WarpTileIterator = WarpTileIterator_;
  using SharedLoadIterator = SharedLoadIterator_;
  using OutputOp = OutputOp_;
  using Padding = Padding_;

  /// Output layout is always row-major
  using Layout = layout::RowMajor;
  using LongIndex = typename Layout::LongIndex;

  /// The complete warp-level accumulator tile: a (real, imaginary) pair of fragments
  /// of the underlying real-valued MMA operator.
  using AccumulatorTile = ArrayPlanarComplex<
    typename WarpMmaOperator::FragmentC::Element,
    WarpMmaOperator::FragmentC::kElements
  >;

  /// Accumulator element
  using ElementAccumulator = typename WarpTileIterator::Element;

  /// Output element
  using ElementOutput = typename OutputTileIterator::Element;

  /// Output access size (elements per vectorized global-memory access)
  static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;

  /// Tensor reference to destination tensor
  using TensorRef = typename OutputTileIterator::TensorRef;

  /// Tensor reference to sync tensor
  using SyncTensorRef = typename cutlass::TensorRef<int, cutlass::layout::PackedVectorLayout>;

  /// Const tensor reference to source tensor
  using ConstTensorRef = typename OutputTileIterator::ConstTensorRef;

  /// Array type used to output
  using OutputAccessType = Array<
    typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>;

  /// Array type used by output functor
  using AccumulatorAccessType = Array<typename WarpTileIterator::Element, OutputTileIterator::kElementsPerAccess>;

  /// Shape of each warp-level operation
  using WarpShape = typename WarpMmaOperator::Shape;

  /// Number of warps per threadblock in each dimension
  using WarpCount = gemm::GemmShape<
    Shape::kM / WarpShape::kM,
    Shape::kN / WarpShape::kN,
    kPartitionsK
  >;

  /// Shared memory allocation
  struct SharedStorage {

    //
    // Type definitions
    //

    /// Element type of shared memory
    using Element = typename WarpTileIterator::Element;

    /// Tensor reference to shared memory allocation
    using TensorRef = typename WarpTileIterator::TensorRef;

    /// Layout of shared memory allocation
    using Layout = typename WarpTileIterator::Layout;

    /// Logical shape of the shared memory tile written to by all warps.
    using Shape = MatrixShape<
      WarpCount::kM * WarpTileIterator::Shape::kRow * WarpCount::kK,
      WarpCount::kN * WarpTileIterator::Shape::kColumn
    >;

    /// Shape of the shared memory allocation for the epilogue (logical shape plus padding)
    using StorageShape = MatrixShape<
      Shape::kRow + Padding::kRow,
      Shape::kColumn + Padding::kColumn
    >;

    /// Element offset between the real-part plane and the imaginary-part plane in SMEM
    static int const kImaginaryStride = StorageShape::kCount;

    //
    // Data members
    //

    // Two back-to-back planes: elements [0, kImaginaryStride) hold the real part,
    // [kImaginaryStride, 2 * kImaginaryStride) hold the imaginary part.
    AlignedBuffer<Element, kImaginaryStride * 2> storage;

    //
    // Methods
    //

    /// Returns a pointer to the shared memory buffer
    CUTLASS_DEVICE
    Element *data() {
      return storage.data();
    }

    /// Returns a tensor reference to the shared memory buffer (real-part plane;
    /// the imaginary plane is addressed via a kImaginaryStride pointer offset)
    CUTLASS_DEVICE
    TensorRef reference() {
      return TensorRef(
        storage.data(),
        Layout::packed({StorageShape::kRow, StorageShape::kColumn}));
    }
  };

private:

  //
  // Data members
  //

  SharedStorage &shared_storage_;

  /// Loads fragment from shared memory aligned with output tensor
  SharedLoadIterator shared_load_iterator_;

  /// Stores a warp's fragment of accumulators to SMEM
  WarpTileIterator warp_tile_iterator_;

public:

  /// Constructor
  CUTLASS_DEVICE
  EpiloguePlanarComplex(
    SharedStorage &shared_storage,      ///< Shared storage object
    int thread_idx,                     ///< ID of a thread within the threadblock
    int warp_idx,                       ///< ID of warp within threadblock
    int lane_idx                        ///< Id of thread within warp
  ):
    shared_storage_(shared_storage),
    shared_load_iterator_(shared_storage.reference(), thread_idx),
    warp_tile_iterator_(shared_storage.reference(), lane_idx) {

    // Compute warp location within threadblock tile by mapping the warp_id to three coordinates:
    //
    //   warp_m: the warp's position within the threadblock along the M dimension
    //   warp_n: the warp's position within the threadblock along the N dimension
    //   warp_k: the warp's position within the threadblock along the K dimension
    //
    // warp_idx is decomposed as warp_idx = (warp_k * WarpCount::kN + warp_n) * WarpCount::kM + warp_m.

    int warp_k = warp_idx / (WarpCount::kM * WarpCount::kN);
    int warp_mn = warp_idx % (WarpCount::kM * WarpCount::kN);
    int warp_m = warp_mn % WarpCount::kM;
    int warp_n = warp_mn / WarpCount::kM;

    // K-partitions are stacked along the row dimension of the SMEM tile (see SharedStorage::Shape)
    MatrixCoord warp_offset{warp_k * WarpCount::kM + warp_m, warp_n};

    warp_tile_iterator_.add_tile_offset(warp_offset);
  }

  /// Streams the result to global memory
  CUTLASS_DEVICE
  void operator()(
    OutputOp const &output_op,                    ///< Output operator
    OutputTileIterator destination_iterator_real, ///< Tile iterator for destination (real part)
    OutputTileIterator destination_iterator_imag, ///< Tile iterator for destination (imaginary part)
    AccumulatorTile const &accumulators,          ///< Complete warp-level accumulator tile
    OutputTileIterator source_iterator_real,      ///< Tile iterator for the source tensor (real part)
    OutputTileIterator source_iterator_imag) {    ///< Tile iterator for the source tensor (imaginary part)

    typename OutputTileIterator::Fragment source_fragment_real;
    typename OutputTileIterator::Fragment source_fragment_imag;

    // If the output operator ignores the source (e.g. beta == 0), disable the
    // source loads entirely so no global-memory reads are issued.
    if (!output_op.is_source_needed()) {
      source_iterator_real.clear_mask();
      source_iterator_imag.clear_mask();
    }

    source_fragment_real.clear();
    source_fragment_imag.clear();

    //
    // Iterator over warp-level accumulator fragment
    //

    AccumulatorFragmentIterator accum_fragment_iterator_real(accumulators.real);
    AccumulatorFragmentIterator accum_fragment_iterator_imag(accumulators.imag);

    //
    // Iterate over accumulator tile
    //

    CUTLASS_PRAGMA_UNROLL
    for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {

      //
      // Load the source
      //

      source_iterator_real.load(source_fragment_real);
      source_iterator_imag.load(source_fragment_imag);

      ++source_iterator_real;
      ++source_iterator_imag;

      //
      // Convert and store fragment
      //

      // Barrier: ensure the SMEM tile from the previous iteration has been fully
      // consumed before any warp overwrites it.
      __syncthreads();

      typename AccumulatorFragmentIterator::Fragment accum_fragment_real;
      typename AccumulatorFragmentIterator::Fragment accum_fragment_imag;

      accum_fragment_iterator_real.load(accum_fragment_real);
      accum_fragment_iterator_imag.load(accum_fragment_imag);

      ++accum_fragment_iterator_real;
      ++accum_fragment_iterator_imag;

      // Real part goes to the first SMEM plane; imaginary part to the second plane
      // at kImaginaryStride elements' offset.
      this->warp_tile_iterator_.store(accum_fragment_real);
      this->warp_tile_iterator_.store_with_pointer_offset(accum_fragment_imag, SharedStorage::kImaginaryStride);

      // Barrier: all warps must finish writing SMEM before any thread reads it back.
      __syncthreads();

      //
      // Load fragments from shared memory
      //

      typename SharedLoadIterator::Fragment aligned_accum_fragment_real[kPartitionsK];
      typename SharedLoadIterator::Fragment aligned_accum_fragment_imag[kPartitionsK];

      shared_load_iterator_.load(aligned_accum_fragment_real[0]);
      shared_load_iterator_.load_with_pointer_offset(aligned_accum_fragment_imag[0], SharedStorage::kImaginaryStride);

      // If the number of k-slices is > 1 - perform a reduction amongst the k-slices
      static_assert(kPartitionsK == 1, "Sliced-K not supported for planar complex at this time");

      //
      // Compute the output result
      //

      typename OutputTileIterator::Fragment output_fragment_real;
      typename OutputTileIterator::Fragment output_fragment_imag;

      apply_output_operator_(
        output_fragment_real,
        output_fragment_imag,
        output_op,
        aligned_accum_fragment_real[0],
        aligned_accum_fragment_imag[0],
        source_fragment_real,
        source_fragment_imag);

      //
      // Store the final result
      //

      destination_iterator_real.store(output_fragment_real);
      destination_iterator_imag.store(output_fragment_imag);

      ++destination_iterator_real;
      ++destination_iterator_imag;
    }
  }

private:

  /// Helper to invoke the output functor over each vector of output
  CUTLASS_DEVICE
  void apply_output_operator_(
    typename OutputTileIterator::Fragment &output_fragment_real,
    typename OutputTileIterator::Fragment &output_fragment_imag,
    OutputOp const &output_op,                    ///< Output operator
    typename SharedLoadIterator::Fragment const &aligned_accum_fragment_real,
    typename SharedLoadIterator::Fragment const &aligned_accum_fragment_imag,
    typename OutputTileIterator::Fragment const &source_fragment_real,
    typename OutputTileIterator::Fragment const &source_fragment_imag) {

    // Reinterpret each fragment as an array of vectorized accesses so the output
    // operator is invoked once per access rather than once per element.

    OutputAccessType *output_frag_real_ptr =
      reinterpret_cast<OutputAccessType *>(&output_fragment_real);

    OutputAccessType *output_frag_imag_ptr =
      reinterpret_cast<OutputAccessType *>(&output_fragment_imag);

    AccumulatorAccessType const *compute_frag_real_ptr =
      reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment_real);

    AccumulatorAccessType const *compute_frag_imag_ptr =
      reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment_imag);

    OutputAccessType const *source_frag_real_ptr =
      reinterpret_cast<OutputAccessType const *>(&source_fragment_real);

    OutputAccessType const *source_frag_imag_ptr =
      reinterpret_cast<OutputAccessType const *>(&source_fragment_imag);

    int const kOutputOpIterations =
      OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kOutputOpIterations; ++i) {

      // Call the output operator on a planar-complex pair of accesses
      auto result_fragment = output_op(
        make_ArrayPlanarComplex(compute_frag_real_ptr[i], compute_frag_imag_ptr[i]),
        make_ArrayPlanarComplex(source_frag_real_ptr[i], source_frag_imag_ptr[i])
      );

      output_frag_real_ptr[i] = result_fragment.real;
      output_frag_imag_ptr[i] = result_fragment.imag;
    }
  }

};
394
+
395
+ ////////////////////////////////////////////////////////////////////////////////
396
+
397
+ } // namespace threadblock
398
+ } // namespace epilogue
399
+ } // namespace cutlass
400
+
401
+ ////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/epilogue_smem_accumulator.h ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Epilogue for threadblock scoped GEMM/CONV to store accumulator in shared memory after
33
+ applying scale, bias loaded from global memory and element-wise operations.
34
+
35
+ This Epilogue is typically used in fused GEMM/CONV to stage the intermediate accumulator.
36
+
37
+ */
38
+
39
+ #pragma once
40
+ #include "cutlass/cutlass.h"
41
+ #include CUDA_STD_HEADER(cassert)
42
+ #include "cutlass/numeric_types.h"
43
+ #include "cutlass/array.h"
44
+ #include "cutlass/layout/vector.h"
45
+ #include "cutlass/layout/tensor.h"
46
+ #include "cutlass/tensor_coord.h"
47
+ #include "cutlass/aligned_buffer.h"
48
+ #include "cutlass/functional.h"
49
+
50
+ #include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h"
51
+ #include "cutlass/epilogue/warp/tile_iterator_tensor_op.h"
52
+ ////////////////////////////////////////////////////////////////////////////////
53
+
54
+ namespace cutlass {
55
+ namespace epilogue {
56
+ namespace threadblock {
57
+
58
+ ////////////////////////////////////////////////////////////////////////////////
59
+
60
+ /// Epilogue operator
61
template <
  typename SmemTileIterator_,               ///< Shared memory Tile iterator to output to shared memory
  typename AccumulatorFragmentIterator_,    ///< Fragment iterator selecting accumulators
  typename ScaleBiasIterator_,              ///< Iterator to load scale and bias from global memory
  typename OutputOp_                        ///< Output operator
>
class EpilogueSmemAccumulator {

public:

  using SmemTileIterator = SmemTileIterator_;

  using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;

  using ScaleBiasIterator = ScaleBiasIterator_;

  using OutputOp = OutputOp_;

  /// Fragment of accumulator tile
  using FragmentAccumulator = typename AccumulatorFragmentIterator::Fragment;

  /// The complete warp-level accumulator tile
  using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile;

  /// Fragment of Scale and Bias loaded from global memory
  using FragmentScaleBias = typename ScaleBiasIterator::Fragment;

  // True when the output op applies a distinct alpha scale per channel; in that
  // case a scale vector must be loaded alongside the bias vector.
  static const bool PerChannelScale = (OutputOp::kScale ==
      epilogue::thread::ScaleType::OnlyAlphaPerChannelScaling);

  /// Constructor
  CUTLASS_DEVICE
  EpilogueSmemAccumulator() {}

  /// Streams the result to shared memory, applying scale and bias vectors
  /// loaded from global memory via the output operator.
  CUTLASS_DEVICE
  void operator()(
    OutputOp const &output_op,          ///< Output operator
    SmemTileIterator smem_iterator,     ///< Tile iterator for destination in shared memory
    AccumulatorTile const &accumulator, ///< Complete warp-level accumulator tile
    ScaleBiasIterator scale_iterator,   ///< iterator for scale vector in global memory
    ScaleBiasIterator bias_iterator) {  ///< iterator for bias vector in global memory


    // Fragment to load scale bias from global memory
    FragmentScaleBias tb_frag_scale;
    FragmentScaleBias tb_frag_bias;

    /// Fragment Iterator to load slice of accumulator tile
    AccumulatorFragmentIterator frag_iterator_accum(accumulator);
    FragmentAccumulator tb_frag_accum;

    /// Epilogue output fragment
    typename SmemTileIterator::Fragment tb_frag_smem;

    /// Load scale and bias from global memory
    // The scale vector is only needed (and only loaded) for per-channel scaling.
    if(PerChannelScale)
      scale_iterator.load(tb_frag_scale);

    bias_iterator.load(tb_frag_bias);

    /// Iterate over the accumulator tile and store to shared memory
    CUTLASS_PRAGMA_UNROLL
    for (int rid = 0; rid < AccumulatorFragmentIterator::TileIterations::kRow; ++rid) {

      CUTLASS_PRAGMA_UNROLL
      for (int cid = 0; cid < AccumulatorFragmentIterator::TileIterations::kColumn; ++cid) {

        using AccumulatorAccessType = typename OutputOp::FragmentAccumulator;
        using ScaleBiasAccessType = typename OutputOp::FragmentScaleBias;
        using FragmentSmemAccessType = typename OutputOp::FragmentOutput;

        // Reinterpret the fragments as arrays of operator-sized accesses so the
        // output op is applied one access at a time.
        ScaleBiasAccessType const * scale_frag_ptr =
            reinterpret_cast<ScaleBiasAccessType const *>(&tb_frag_scale);
        ScaleBiasAccessType const * bias_frag_ptr =
            reinterpret_cast<ScaleBiasAccessType const *>(&tb_frag_bias);

        FragmentSmemAccessType * smem_frag_ptr =
          reinterpret_cast<FragmentSmemAccessType *>(&tb_frag_smem);

        CUTLASS_PRAGMA_UNROLL
        for (int idx = 0; idx < AccumulatorFragmentIterator::kIterationsPerTile; ++idx) {
          frag_iterator_accum.load(tb_frag_accum);
          ++frag_iterator_accum;

          AccumulatorAccessType const * accumulator_frag_ptr =
              reinterpret_cast<AccumulatorAccessType const *>(&tb_frag_accum);
          const int kOutputIterations = FragmentAccumulator::kElements / OutputOp::kCount;

          // Scale/bias are indexed by the tile's column position (cid); the same
          // per-column values apply to every row of the tile.
          CUTLASS_PRAGMA_UNROLL
          for (int it = 0; it < kOutputIterations; it++) {
            smem_frag_ptr[idx * kOutputIterations + it] = output_op(accumulator_frag_ptr[it],
                scale_frag_ptr[cid * kOutputIterations + it], bias_frag_ptr[cid * kOutputIterations + it]);
          }
        }

        smem_iterator.store(tb_frag_smem);
        ++smem_iterator;

      }
    }
  }

  /// Streams the result to shared memory without scale/bias vectors; the output
  /// operator is applied to the accumulator fragments alone.
  CUTLASS_DEVICE
  void operator()(
    OutputOp const &output_op,            ///< Output operator
    SmemTileIterator smem_iterator,       ///< Tile iterator for destination in shared memory
    AccumulatorTile const &accumulator) { ///< Complete warp-level accumulator tile

    /// Fragment Iterator to load slice of accumulator tile
    AccumulatorFragmentIterator frag_iterator_accum(accumulator);
    FragmentAccumulator tb_frag_accum;

    /// Epilogue output fragment
    typename SmemTileIterator::Fragment tb_frag_smem;

    /// Iterate over the accumulator tile and store to shared memory
    CUTLASS_PRAGMA_UNROLL
    for (int rid = 0; rid < AccumulatorFragmentIterator::TileIterations::kRow; ++rid) {

      CUTLASS_PRAGMA_UNROLL
      for (int cid = 0; cid < AccumulatorFragmentIterator::TileIterations::kColumn; ++cid) {

        using AccumulatorAccessType = typename OutputOp::FragmentAccumulator;
        using FragmentSmemAccessType = typename OutputOp::FragmentOutput;

        FragmentSmemAccessType * smem_frag_ptr =
          reinterpret_cast<FragmentSmemAccessType *>(&tb_frag_smem);

        CUTLASS_PRAGMA_UNROLL
        for (int idx = 0; idx < AccumulatorFragmentIterator::kIterationsPerTile; ++idx) {
          frag_iterator_accum.load(tb_frag_accum);
          ++frag_iterator_accum;

          AccumulatorAccessType const * accumulator_frag_ptr =
              reinterpret_cast<AccumulatorAccessType const *>(&tb_frag_accum);
          const int kOutputIterations = FragmentAccumulator::kElements / OutputOp::kCount;

          CUTLASS_PRAGMA_UNROLL
          for (int it = 0; it < kOutputIterations; it++) {
            smem_frag_ptr[idx * kOutputIterations + it] = output_op(accumulator_frag_ptr[it]);
          }
        }

        smem_iterator.store(tb_frag_smem);
        ++smem_iterator;

      }
    }
  }

};
216
+
217
+ ////////////////////////////////////////////////////////////////////////////////
218
+
219
+ } // namespace threadblock
220
+ } // namespace epilogue
221
+ } // namespace cutlass
222
+
223
+ ////////////////////////////////////////////////////////////////////////////////
224
+
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/epilogue_streamk_with_broadcast.h ADDED
@@ -0,0 +1,443 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+
33
+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
34
+
35
+ The epilogue rearranges the result of a matrix product through shared memory to match canonical
36
+ tensor layouts in global memory. Epilogues support conversion and reduction operations.
37
+
38
+ */
39
+
40
+ #pragma once
41
+ #include "cutlass/cutlass.h"
42
+
43
+ #include CUDA_STD_HEADER(cassert)
44
+
45
+ #if defined(__CUDACC_RTC__)
46
+ #include CUDA_STD_HEADER(utility)
47
+ #else
48
+ #include <utility>
49
+ #endif
50
+
51
+ #include "cutlass/array.h"
52
+ #include "cutlass/numeric_types.h"
53
+ #include "cutlass/numeric_conversion.h"
54
+ #include "cutlass/tensor_coord.h"
55
+ #include "cutlass/aligned_buffer.h"
56
+ #include "cutlass/functional.h"
57
+ #include "cutlass/fast_math.h"
58
+ #include "cutlass/layout/vector.h"
59
+ #include "cutlass/layout/tensor.h"
60
+
61
+ #include "cutlass/gemm/gemm.h"
62
+
63
+ #include "cutlass/transform/pitch_linear_thread_map.h"
64
+ #include "cutlass/transform/threadblock/regular_tile_iterator.h"
65
+
66
+ #include "cutlass/epilogue/threadblock/epilogue_base.h"
67
+ #include "cutlass/epilogue/threadblock/epilogue_base_streamk.h"
68
+ #include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
69
+
70
+ #include "cutlass/numeric_types.h"
71
+
72
+ /////////////////////////////////////////////////////////////////////////////////////////////////
73
+
74
+ namespace cutlass {
75
+ namespace epilogue {
76
+ namespace threadblock {
77
+
78
+ /////////////////////////////////////////////////////////////////////////////////////////////////
79
+
80
+ /// This base class is meant to define the concept required of the
81
+ /// EpilogueStreamkWithBroadcast::OutputOp
82
template <
  typename ElementC_,             ///< Element type of the C (source) operand
  typename ElementAccumulator_,   ///< Element type of the accumulator
  typename ElementCompute_,       ///< Element type used for internal computation
  typename ElementZ_,             ///< Element type of the Z output tensor
  typename ElementT_,             ///< Element type of the auxiliary T output tensor
  int ElementsPerAccess,          ///< Number of elements per vectorized access
  bool StoreZ = true,             ///< Whether the Z tensor is written out
  bool StoreT = true              ///< Whether the T tensor is written out
>
struct EpilogueStreamkWithBroadcastOpBase : EpilogueWithBroadcastOpBase<
  ElementC_,
  ElementAccumulator_,
  ElementCompute_,
  ElementZ_,
  ElementT_,
  ElementsPerAccess,
  StoreZ,
  StoreT
>
{

  /// Parameters structure - required
  struct Params { };

  //
  // Methods
  //

  /// Constructor from Params
  // Params carries no state; the constructor exists only to satisfy the
  // OutputOp concept required by EpilogueStreamkWithBroadcast.
  EpilogueStreamkWithBroadcastOpBase(Params const &params_) { }
};
114
+
115
+ ////////////////////////////////////////////////////////////////////////////////
116
+
117
/// Epilogue operator with bias vector broadcast over columns.
///
/// Computes the following:
///
///
///  Z, T = OutputOp(AB, C, Broadcast)
///
///  if (ElementwiseOp::kStoreZ) {
///    store(converted_u);
///  }
///
///  if (ElementwiseOp::kStoreT) {
///    store(v);
///  }
///
/// Primary template declaration only: the two partial specializations below
/// select between the two-source and single-source variants based on
/// OutputOp_::kIsSingleSource.
template <
  typename Shape_,                        ///< Shape of threadblock tile (concept: GemmShape)
  typename WarpMmaOperator_,              ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp)
  int PartitionsK,                        ///< Number of partitions of the K dimension
  typename OutputTileIterator_,           ///< Tile iterator reading and writing output tensors (z)
  typename TensorTileIterator_,           ///< Additional tile iterator for tensor-valued operands (t)
  typename ElementVector_,                ///< Pointer to broadcast vector
  typename AccumulatorFragmentIterator_,  ///< Fragment iterator selecting accumulators
  typename WarpTileIterator_,             ///< Warp-scoped tile iterator writing accumulators to SMEM
  typename SharedLoadIterator_,           ///< Threadblock-scoped tile iterator loading from SMEM
  typename OutputOp_,                     ///< Output operator - concept is EpilogueWithBroadcastOp
  typename Padding_,                      ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape)
  int FragmentsPerPartition = 1,          ///< Used to coarsten the epilogue granularity
  int IterationsUnroll =                  ///< Used to reduce binary size when epilogue op is large
    (!IsEpilogueFunctorHeavy<OutputOp_>::value),
  bool IsSingleSource = OutputOp_::kIsSingleSource  ///< Selects the partial specialization below
>
class EpilogueStreamkWithBroadcast;
150
+
151
+
152
+ /////////////////////////////////////////////////////////////////////////////////////////////////
153
+
154
/// EpilogueStreamkWithBroadcast: Two sources
///
/// Stream-K variant of the broadcast epilogue for output operators that
/// consume two source matrices. Combines EpilogueWithBroadcast (the epilogue
/// computation itself) with EpilogueBaseStreamK (reduction of peer
/// threadblocks' partial accumulators from the global workspace).
template <
  typename Shape_,
  typename WarpMmaOperator_,
  int PartitionsK,
  typename OutputTileIterator_,
  typename TensorTileIterator_,
  typename ElementVector_,
  typename AccumulatorFragmentIterator_,
  typename WarpTileIterator_,
  typename SharedLoadIterator_,
  typename OutputOp_,
  typename Padding_,
  int FragmentsPerPartition,
  int IterationsUnroll
>
class EpilogueStreamkWithBroadcast<
  Shape_,
  WarpMmaOperator_,
  PartitionsK,
  OutputTileIterator_,
  TensorTileIterator_,
  ElementVector_,
  AccumulatorFragmentIterator_,
  WarpTileIterator_,
  SharedLoadIterator_,
  OutputOp_,
  Padding_,
  FragmentsPerPartition,
  IterationsUnroll,
  false
> :
  public EpilogueWithBroadcast<
    Shape_,
    WarpMmaOperator_,
    PartitionsK,
    OutputTileIterator_,
    TensorTileIterator_,
    ElementVector_,
    AccumulatorFragmentIterator_,
    WarpTileIterator_,
    SharedLoadIterator_,
    OutputOp_,
    Padding_,
    FragmentsPerPartition,
    IterationsUnroll,
    false>,
  public EpilogueBaseStreamK<
    Shape_,
    PartitionsK,
    WarpMmaOperator_,
    AccumulatorFragmentIterator_>
{

public:

  /// The underlying (non-stream-K) broadcast epilogue
  using Base = EpilogueWithBroadcast<
    Shape_,
    WarpMmaOperator_,
    PartitionsK,
    OutputTileIterator_,
    TensorTileIterator_,
    ElementVector_,
    AccumulatorFragmentIterator_,
    WarpTileIterator_,
    SharedLoadIterator_,
    OutputOp_,
    Padding_,
    FragmentsPerPartition,
    IterationsUnroll,
    false>;

  /// Stream-K base providing cross-threadblock accumulator reduction
  using BaseStreamK = EpilogueBaseStreamK<
    Shape_,
    PartitionsK,
    WarpMmaOperator_,
    AccumulatorFragmentIterator_>;

  using Shape = Shape_;
  static int const kPartitionsK = PartitionsK;
  using OutputTileIterator = OutputTileIterator_;
  using TensorTileIterator = TensorTileIterator_;
  using ElementVector = ElementVector_;
  using SharedLoadIterator = SharedLoadIterator_;
  using OutputOp = OutputOp_;

  /// Fragment type used by the accumulator tile's fragment iterator
  using AccumulatorFragment = typename Base::AccumulatorFragmentIterator::Fragment;

  /// Shared storage structure (shadows base) with additional SMEM buffer for reduction
  using SharedStorage = typename Base::SharedStorage;

public:

  /// Constructor
  CUTLASS_DEVICE
  EpilogueStreamkWithBroadcast(
    SharedStorage &shared_storage,  ///< Shared storage object
    int thread_idx,                 ///< ID of a thread within the threadblock
    int warp_idx,                   ///< ID of warp within threadblock
    int lane_idx                    ///< Id of thread within warp
  ):
    Base(shared_storage, thread_idx, warp_idx, lane_idx),
    BaseStreamK(thread_idx)
  { }


  /// Aggregates the accumulator sets shared by peer blocks in the global workspace,
  /// performing epilogue computations, writing to output
  CUTLASS_DEVICE
  void reduce(
    int peer_idx_begin,                       ///< First peer block index contributing to this tile
    int peer_idx_end,                         ///< One past the last contributing peer block index
    int reduce_fragment_idx,                  ///< Index of the fragment to reduce
    void *element_workspace,                  ///< Global workspace holding peer partial accumulators
    OutputOp const &output_op,                ///< Output operator
    ElementVector const * broadcast_ptr,      ///< Broadcast vector
    OutputTileIterator destination_iterator,  ///< Tile iterator for destination
    OutputTileIterator source_iterator1,      ///< Tile iterator for first source accumulator matrix
    OutputTileIterator source_iterator2,      ///< Tile iterator for second source accumulator matrix
    TensorTileIterator tensor_iterator,       ///< Threadblock tile iterator for additional tensor operand
    MatrixCoord const &problem_size =         ///< Problem size needed to guard against out-of-bounds accesses
        MatrixCoord(Shape::kM, Shape::kN),
    MatrixCoord const &threadblock_offset =   ///< Threadblock's initial offset within the problem size space
        MatrixCoord())
  {
    // Reduce peer accumulator fragments into one fragment
    AccumulatorFragment accum_fragment;
    BaseStreamK::reduce(accum_fragment, peer_idx_begin, peer_idx_end, reduce_fragment_idx, element_workspace);

    // Store fragment to shared memory
    this->warp_tile_iterator_.store(accum_fragment);

    // Ensure all warps have committed their fragments to SMEM before the
    // base epilogue reads them back
    __syncthreads();

    // Delegate the actual epilogue computation and output write to the base
    Base::reduce(reduce_fragment_idx, output_op, broadcast_ptr, destination_iterator, source_iterator1, source_iterator2, tensor_iterator, problem_size, threadblock_offset);

  }
};
294
+
295
+ /////////////////////////////////////////////////////////////////////////////////////////////////
296
+
297
/// EpilogueStreamkWithBroadcast: Single source
///
/// Stream-K variant of the broadcast epilogue for output operators that
/// consume a single source matrix. Combines EpilogueWithBroadcast (the
/// epilogue computation itself) with EpilogueBaseStreamK (reduction of peer
/// threadblocks' partial accumulators from the global workspace).
template <
  typename Shape_,
  typename WarpMmaOperator_,
  int PartitionsK,
  typename OutputTileIterator_,
  typename TensorTileIterator_,
  typename ElementVector_,
  typename AccumulatorFragmentIterator_,
  typename WarpTileIterator_,
  typename SharedLoadIterator_,
  typename OutputOp_,
  typename Padding_,
  int FragmentsPerPartition,
  int IterationsUnroll
>
class EpilogueStreamkWithBroadcast<
  Shape_,
  WarpMmaOperator_,
  PartitionsK,
  OutputTileIterator_,
  TensorTileIterator_,
  ElementVector_,
  AccumulatorFragmentIterator_,
  WarpTileIterator_,
  SharedLoadIterator_,
  OutputOp_,
  Padding_,
  FragmentsPerPartition,
  IterationsUnroll,
  true
> :
  public EpilogueWithBroadcast<
    Shape_,
    WarpMmaOperator_,
    PartitionsK,
    OutputTileIterator_,
    TensorTileIterator_,
    ElementVector_,
    AccumulatorFragmentIterator_,
    WarpTileIterator_,
    SharedLoadIterator_,
    OutputOp_,
    Padding_,
    FragmentsPerPartition,
    IterationsUnroll,
    true>,
  public EpilogueBaseStreamK<
    Shape_,
    PartitionsK,
    WarpMmaOperator_,
    AccumulatorFragmentIterator_>
{

public:

  /// The underlying (non-stream-K) broadcast epilogue
  using Base = EpilogueWithBroadcast<
    Shape_,
    WarpMmaOperator_,
    PartitionsK,
    OutputTileIterator_,
    TensorTileIterator_,
    ElementVector_,
    AccumulatorFragmentIterator_,
    WarpTileIterator_,
    SharedLoadIterator_,
    OutputOp_,
    Padding_,
    FragmentsPerPartition,
    IterationsUnroll,
    true>;

  /// Stream-K base providing cross-threadblock accumulator reduction
  using BaseStreamK = EpilogueBaseStreamK<
    Shape_,
    PartitionsK,
    WarpMmaOperator_,
    AccumulatorFragmentIterator_>;

  using Shape = Shape_;
  static int const kPartitionsK = PartitionsK;
  using OutputTileIterator = OutputTileIterator_;
  using TensorTileIterator = TensorTileIterator_;
  using ElementVector = ElementVector_;
  using SharedLoadIterator = SharedLoadIterator_;
  using OutputOp = OutputOp_;

  /// Fragment type used by the accumulator tile's fragment iterator
  using AccumulatorFragment = typename Base::AccumulatorFragmentIterator::Fragment;

  /// Shared storage structure (shadows base) with additional SMEM buffer for reduction
  using SharedStorage = typename Base::SharedStorage;

public:

  /// Constructor
  CUTLASS_DEVICE
  EpilogueStreamkWithBroadcast(
    SharedStorage &shared_storage,  ///< Shared storage object
    int thread_idx,                 ///< ID of a thread within the threadblock
    int warp_idx,                   ///< ID of warp within threadblock
    int lane_idx                    ///< Id of thread within warp
  ):
    Base(shared_storage, thread_idx, warp_idx, lane_idx),
    BaseStreamK(thread_idx)
  { }


  /// Aggregates the accumulator sets shared by peer blocks in the global workspace,
  /// performing epilogue computations, writing to output
  CUTLASS_DEVICE
  void reduce(
    int peer_idx_begin,                       ///< First peer block index contributing to this tile
    int peer_idx_end,                         ///< One past the last contributing peer block index
    int reduce_fragment_idx,                  ///< Index of the fragment to reduce
    void *element_workspace,                  ///< Global workspace holding peer partial accumulators
    OutputOp const &output_op,                ///< Output operator
    ElementVector const * broadcast_ptr,      ///< Broadcast vector
    OutputTileIterator destination_iterator,  ///< Tile iterator for destination
    OutputTileIterator source_iterator,       ///< Tile iterator for the source accumulator matrix
    TensorTileIterator tensor_iterator,       ///< Threadblock tile iterator for additional tensor operand
    MatrixCoord const &problem_size =         ///< Problem size needed to guard against out-of-bounds accesses
        MatrixCoord(Shape::kM, Shape::kN),
    MatrixCoord const &threadblock_offset =   ///< Threadblock's initial offset within the problem size space
        MatrixCoord())
  {
    // Reduce peer accumulator fragments into one fragment
    AccumulatorFragment accum_fragment;
    BaseStreamK::reduce(accum_fragment, peer_idx_begin, peer_idx_end, reduce_fragment_idx, element_workspace);

    // Store fragment to shared memory
    this->warp_tile_iterator_.store(accum_fragment);

    // Ensure all warps have committed their fragments to SMEM before the
    // base epilogue reads them back
    __syncthreads();

    // Delegate the actual epilogue computation and output write to the base
    Base::reduce(reduce_fragment_idx, output_op, broadcast_ptr, destination_iterator, source_iterator, tensor_iterator, problem_size, threadblock_offset);

  }
};
436
+
437
+ ////////////////////////////////////////////////////////////////////////////////
438
+
439
+ } // namespace threadblock
440
+ } // namespace epilogue
441
+ } // namespace cutlass
442
+
443
+ ////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h ADDED
@@ -0,0 +1,513 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Epilogue visitor for threadblock scoped GEMMs that process softmax computations in epilogue.
33
+
34
+ The epilogue finds max values in each row of the row-major output matrix and stores them.
35
+ The max values are also used for a further round of threadblock scoped reduction operation, where
36
+ the partial reduction results are stored in a pre-allocated array and used for further full reduction.
37
+
38
+ */
39
+
40
+ #pragma once
41
+
42
+ /////////////////////////////////////////////////////////////////////////////////////////////////
43
+
44
+ #include "cutlass/cutlass.h"
45
+ #include "cutlass/arch/memory.h"
46
+ #include "cutlass/arch/memory_sm75.h"
47
+ #include "cutlass/numeric_conversion.h"
48
+ #include "cutlass/fast_math.h"
49
+
50
+ namespace cutlass {
51
+ namespace epilogue {
52
+ namespace threadblock {
53
+
54
/// Epilogue visitor for threadblock-scoped GEMMs that performs softmax
/// bookkeeping while the linear-combination result is written out.
///
/// For each output row it maintains a running maximum and a running sum of
/// exp(x - max) using the online-softmax update (see $3.1 of
/// https://arxiv.org/pdf/2205.14135v1.pdf). Per-row partials are combined
/// across the threads covering a row via warp shuffles and stored to the
/// Max / Sum workspaces; a later pass presumably uses them for the full
/// normalization (not visible here).
template <
  typename ThreadblockShape_,       ///< Threadblock tile shape (concept: GemmShape)
  int ThreadCount,                  ///< Number of participating threads
  typename OutputTileIterator_,     ///< Iterator reading source (C) and writing output (D) tiles
  typename ElementAccumulator_,     ///< Accumulator element type
  typename ElementNorm_,            ///< Element type of the stored row maxima
  typename ElementSum_,             ///< Element type of the stored row sums
  typename ElementSoftmaxCompute_,  ///< Element type used for softmax arithmetic
  typename ElementwiseFunctor_,     ///< Elementwise functor applied to accumulators (e.g. linear combination)
  bool UseMasking_ = false          ///< If true, columns beyond extent_real_ are masked to -infinity_
>
class EpilogueVisitorSoftmax {
public:

  using ThreadblockShape = ThreadblockShape_;
  static int const kThreadCount = ThreadCount;

  using OutputTileIterator = OutputTileIterator_;
  using ElementwiseFunctor = ElementwiseFunctor_;

  static int const kIterations = OutputTileIterator::kIterations;
  static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;

  using ElementOutput = typename OutputTileIterator::Element;
  using LayoutOutput = cutlass::layout::RowMajor;
  using ElementAccumulator = ElementAccumulator_;

  using ElementNorm = ElementNorm_;
  using ElementSum = ElementSum_;
  using ElementSoftmaxCompute = ElementSoftmaxCompute_;

  using AccumulatorFragment = Array<ElementAccumulator, kElementsPerAccess>;
  using SoftmaxFragment = Array<ElementSoftmaxCompute, kElementsPerAccess>;
  using OutputVector = Array<ElementOutput, kElementsPerAccess>;
  using TensorRefD = TensorRef<ElementOutput, LayoutOutput>;

  /// Number of adjacent threads that cooperate on one output row (the warp
  /// shuffle reductions below operate over this many lanes)
  static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth;

  /// True when a thread visits more than one column fragment per row, which
  /// requires the online running-sum update in visit()
  static bool const kHasMultiStepsInRow = (OutputTileIterator::ThreadMap::Iterations::kColumn > 1);
  static bool const kUseMasking = UseMasking_;

  /// Argument structure
  struct Arguments {

    typename ElementwiseFunctor::Params elementwise;  ///< Parameters for the elementwise functor
    int64_t batch_stride_C;                           ///< Stride between batches of the source matrix
    int64_t batch_stride_D;                           ///< Stride between batches of the output matrix
    int64_t batch_stride_Max;                         ///< Stride between batches of the row-max workspace
    int64_t batch_stride_Sum;                         ///< Stride between batches of the row-sum workspace

    //
    // Methods
    //

    /// Default constructor: zero batch strides
    Arguments():
      batch_stride_C(0),
      batch_stride_D(0),
      batch_stride_Max(0),
      batch_stride_Sum(0)
    {

    }

    /// Non-batched construction: elementwise parameters only
    Arguments(
      typename ElementwiseFunctor::Params elementwise_
    ):
      elementwise(elementwise_),
      batch_stride_C(0),
      batch_stride_D(0),
      batch_stride_Max(0),
      batch_stride_Sum(0)
    {

    }

    /// Batched construction with explicit strides
    Arguments(
      typename ElementwiseFunctor::Params elementwise_,
      int64_t batch_stride_C_,
      int64_t batch_stride_D_,
      int64_t batch_stride_Max_,
      int64_t batch_stride_Sum_
    ):
      elementwise(elementwise_),
      batch_stride_C(batch_stride_C_),
      batch_stride_D(batch_stride_D_),
      batch_stride_Max(batch_stride_Max_),
      batch_stride_Sum(batch_stride_Sum_)
    {

    }

  };

  /// Kernel-side parameters, constructed from Arguments
  struct Params {

    typename ElementwiseFunctor::Params elementwise;
    int64_t batch_stride_C;
    int64_t batch_stride_D;
    int64_t batch_stride_Max;
    int64_t batch_stride_Sum;
    //
    // Methods
    //
    CUTLASS_HOST_DEVICE
    Params()
    {

    }

    CUTLASS_HOST_DEVICE
    Params(Arguments const &args):
      elementwise(args.elementwise),
      batch_stride_C(args.batch_stride_C),
      batch_stride_D(args.batch_stride_D),
      batch_stride_Max(args.batch_stride_Max),
      batch_stride_Sum(args.batch_stride_Sum)
    {

    }
  };

  /// Shared storage (this visitor needs none of its own)
  struct SharedStorage {

  };

private:

  Params const & params_;           // Kernel parameters (strides, functor params)
  SharedStorage & shared_storage_;  // Held by reference; unused by this visitor
  MatrixCoord extent_;              // Logical problem extent used for guards
  MatrixCoord extent_real_;         // Real extent used only when kUseMasking is set
  ElementwiseFunctor elementwise_;  // Elementwise functor instance

  OutputTileIterator iterator_C_;   // Source (C) tile iterator
  OutputTileIterator iterator_D_;   // Output (D) tile iterator
  typename OutputTileIterator::Fragment fragment_C_;
  typename OutputTileIterator::Fragment fragment_D_;

  ElementAccumulator alpha_;        // Resolved scalar alpha (pointer dereferenced if provided)
  ElementAccumulator beta_;         // Resolved scalar beta (pointer dereferenced if provided)

  ElementNorm *ptr_Max_;            // Per-row maxima workspace (written in end_row)
  ElementSum *ptr_Sum_;             // Per-row sums workspace (written in end_row)

  int column_offset_;               // Offset added to the row index when addressing Max/Sum

  ElementSoftmaxCompute accum_max_; // Running row maximum
  ElementSoftmaxCompute accum_sum_; // Running row sum of exp(x - max)

  MatrixCoord thread_offset_;       // This thread's most recent output coordinate

  float infinity_;                  // Magnitude used as "-infinity" for masked elements

public:

  /// Constructor
  ///
  /// Resolves alpha/beta from pointers if provided and disables source loads
  /// when beta is zero.
  /// NOTE(review): members are initialized in declaration order, so
  /// extent_real_ is initialized before elementwise_ despite its position in
  /// this initializer list.
  CUTLASS_DEVICE
  EpilogueVisitorSoftmax(
    Params const &params,
    SharedStorage &shared_storage,
    cutlass::MatrixCoord const &problem_size,
    int thread_idx,
    int warp_idx,
    int lane_idx,
    typename OutputTileIterator::Params params_C,
    typename OutputTileIterator::Params params_D,
    typename OutputTileIterator::Element *ptr_C,
    typename OutputTileIterator::Element *ptr_D,
    ElementNorm *ptr_Max = nullptr,
    ElementSum *ptr_Sum = nullptr,
    cutlass::MatrixCoord const &threadblock_offset = cutlass::MatrixCoord(0, 0),
    int column_offset = 0,
    cutlass::MatrixCoord const &problem_size_real = cutlass::MatrixCoord(0, 0),
    float infinity = 10000.0f
  ):
    params_(params),
    shared_storage_(shared_storage),
    extent_(problem_size),
    elementwise_(params.elementwise),
    iterator_C_(params_C, ptr_C, problem_size, thread_idx, threadblock_offset),
    iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset),
    ptr_Max_(ptr_Max),
    ptr_Sum_(ptr_Sum),
    column_offset_(column_offset),
    extent_real_(problem_size_real),
    infinity_(infinity)
  {
    alpha_ = (params.elementwise.alpha_ptr ? *params.elementwise.alpha_ptr : params.elementwise.alpha);
    beta_ = (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr : params.elementwise.beta);

    // beta == 0 means the source is never read; mask out C loads entirely
    if (beta_ == ElementAccumulator()) {
      iterator_C_.clear_mask();
    }
  }

  /// Helper to indicate split-K behavior (no-op for this visitor)
  CUTLASS_DEVICE
  void set_k_partition(
    int split_k_index,    ///< Index of this threadblock within split-K partitioned scheme
    int split_k_slices) { ///< Total number of split-K slices

  }

  /// Called to set the batch index; advances C/D iterators by the batch strides
  CUTLASS_DEVICE
  void set_batch_index(int batch_idx) {
    iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C);
    iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D);
  }

  /// Called at the start of the epilogue just before iterating over accumulator slices
  CUTLASS_DEVICE
  void begin_epilogue() {

  }

  /// Called at the start of one step before starting accumulator exchange.
  /// Loads the source fragment unless the functor only applies alpha scaling.
  CUTLASS_DEVICE
  void begin_step(int step_idx) {
    fragment_D_.clear();
    fragment_C_.clear();

    if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) {
      iterator_C_.load(fragment_C_);
      ++iterator_C_;
    }

  }

  /// Called at the start of a row
  CUTLASS_DEVICE
  void begin_row(int row_idx) {
    // Clear accumulators for max and sum when starting a whole row
    clear_accum_();

  }

  /// Called after accumulators have been exchanged for each accumulator vector.
  ///
  /// Applies the elementwise functor, then folds the fragment into the running
  /// row max/sum using the online-softmax recurrence. Note that the value
  /// written to D is the (converted) elementwise result itself, not the
  /// exponentiated value — the stored max/sum enable normalization later.
  CUTLASS_DEVICE
  void visit(
    int iter_idx,
    int row_idx,
    int column_idx,
    int frag_idx,
    AccumulatorFragment const &accum) {

    using Mul = cutlass::multiplies<SoftmaxFragment>;  // declared but unused below
    using Minus = cutlass::minus<SoftmaxFragment>;
    using Exp = cutlass::fast_exp_op<SoftmaxFragment>;

    Minus minus;
    Exp exponential;

    SoftmaxFragment result;

    NumericArrayConverter<ElementSoftmaxCompute, ElementOutput, kElementsPerAccess> source_converter;
    OutputVector &source_vector = reinterpret_cast<OutputVector *>(&fragment_C_)[frag_idx];

    // Apply the elementwise functor, with or without the source operand
    if (elementwise_.kScale == cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) {
      result = source_converter(elementwise_(accum));
    } else {
      result = source_converter(elementwise_(accum, source_vector));
    }

    // Coordinate of this fragment in the output matrix
    thread_offset_ =
      iterator_D_.thread_start() +
      OutputTileIterator::ThreadMap::iteration_offset(frag_idx);

    bool column_guard = (thread_offset_.column() < extent_.column());

    if (kUseMasking) {
      // Replace out-of-bounds elements with -infinity_ so they cannot
      // influence the row maximum or sum
      int elements_in_boundary = extent_real_.column() - thread_offset_.column();
      elements_in_boundary = (elements_in_boundary > kElementsPerAccess) ? kElementsPerAccess : elements_in_boundary;
      elementwise_padding_(result, elements_in_boundary);
    }

    ElementSoftmaxCompute accum_max_prev = accum_max_;

    // Compute the maximum within one row
    if (!column_idx) {
      // This is the first fragment in a new row
      if (column_guard) {
        accum_max_ = maximum_accumulator_(result);
      }
    }
    else {
      // This is an additional fragment in the same row
      if (column_guard) {
        accum_max_ = maximum_accumulator_(result, accum_max_);
      }
    }

    // proactively compute max in warps
    accum_max_ = warp_reduce_max_(accum_max_);

    // Rescale factor for the previously accumulated sum: exp(old_max - new_max)
    ElementSoftmaxCompute updater = fast_exp(accum_max_prev - accum_max_);

    SoftmaxFragment intermediate = exponential(minus(result, accum_max_));

    if (kHasMultiStepsInRow) {
      if (!column_idx) {
        accum_sum_ = (column_guard) ? \
          sum_accumulator_(intermediate) : ElementSoftmaxCompute(0);
      } else {
        // Algorithm in $3.1, https://arxiv.org/pdf/2205.14135v1.pdf
        // S* = S* x updater + sum_row(P'), where updater = exp(M* - M_row)
        accum_sum_ = (column_guard) ? \
          sum_accumulator_(intermediate, accum_sum_ * updater) : accum_sum_ * updater;
      }
    } else {
      accum_sum_ = (column_guard) ? sum_accumulator_(intermediate, accum_sum_) : ElementSoftmaxCompute(0);
    }

    // Convert to the output
    NumericArrayConverter<ElementOutput, ElementSoftmaxCompute, kElementsPerAccess> output_converter;
    OutputVector &output = reinterpret_cast<OutputVector *>(&fragment_D_)[frag_idx];
    output = output_converter(result);
  }

  /// Called at the end of a row.
  ///
  /// Finishes the warp-level sum reduction and has the first thread of each
  /// row-group store the row max and row sum to the workspaces.
  /// NOTE(review): curr_ptr_max/curr_ptr_sum are computed unconditionally;
  /// this assumes ptr_Max_/ptr_Sum_ are valid whenever end_row runs (the
  /// stores themselves are predicated on is_write_thread).
  CUTLASS_DEVICE
  void end_row(int row_idx) {

    using ConvertSumOutput = cutlass::NumericConverter<ElementSum, ElementSoftmaxCompute>;
    using ConvertNormOutput = cutlass::NumericConverter<ElementNorm, ElementSoftmaxCompute>;

    ConvertSumOutput convert_sum_output;
    ConvertNormOutput convert_norm_output;

    // Compute accumulate sum only in the last step
    accum_sum_ = warp_reduce_sum_(accum_sum_);

    bool is_first_thread_in_tile = ((threadIdx.x % kThreadsPerRow) == 0);
    bool row_guard = thread_offset_.row() < extent_.row();
    bool is_write_thread = row_guard && is_first_thread_in_tile;

    int block_batch = blockIdx.z;

    ElementNorm *curr_ptr_max = ptr_Max_ + thread_offset_.row() + column_offset_ + block_batch * params_.batch_stride_Max;
    ElementSum *curr_ptr_sum = ptr_Sum_ + thread_offset_.row() + column_offset_ + block_batch * params_.batch_stride_Sum;

    // Predicated stores: only the designated thread of each valid row writes
    arch::global_store<ElementNorm, sizeof(ElementNorm)>(
              convert_norm_output(accum_max_),
              (void *)curr_ptr_max,
              is_write_thread);

    arch::global_store<ElementSum, sizeof(ElementSum)>(
              convert_sum_output(accum_sum_),
              (void *)curr_ptr_sum,
              is_write_thread);

    // Clear accumulators for max and sum when finishing a whole row
    clear_accum_();

  }

  /// Called after all accumulator elements have been visited; writes D
  CUTLASS_DEVICE
  void end_step(int step_idx) {

    iterator_D_.store(fragment_D_);
    ++iterator_D_;
  }

  /// Called after all steps have been completed
  CUTLASS_DEVICE
  void end_epilogue() {

  }

private:

  /// Overwrites elements past the valid boundary with -infinity_ so masked
  /// columns never win the max or contribute to the sum
  CUTLASS_DEVICE
  void elementwise_padding_(SoftmaxFragment &result, int elements_in_boundary) {
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < SoftmaxFragment::kElements; ++i) {
      result[i] = (i < elements_in_boundary) ? result[i] : ElementSoftmaxCompute(-infinity_);
    }
  }

  /// Butterfly (xor-shuffle) reduction of the row sum across the
  /// kThreadsPerRow lanes covering one row
  CUTLASS_DEVICE
  ElementSoftmaxCompute warp_reduce_sum_(ElementSoftmaxCompute sum_) {
    int half_thread_in_row = (kThreadsPerRow >> 1);
    CUTLASS_PRAGMA_UNROLL
    for (int i = half_thread_in_row; i > 0; i >>= 1) {
      ElementSoftmaxCompute tmp = __shfl_xor_sync(0xFFFFFFFF, sum_, i);
      sum_ += tmp;
    }
    return sum_;
  }

  /// Butterfly (xor-shuffle) reduction of the row max across the
  /// kThreadsPerRow lanes covering one row
  CUTLASS_DEVICE
  ElementSoftmaxCompute warp_reduce_max_(ElementSoftmaxCompute max_) {
    int half_thread_in_row = (kThreadsPerRow >> 1);
    CUTLASS_PRAGMA_UNROLL
    for (int i = half_thread_in_row; i > 0; i >>= 1) {
      ElementSoftmaxCompute tmp = __shfl_xor_sync(0xFFFFFFFF, max_, i);
      max_ = fast_max(max_, tmp);
    }
    return max_;
  }

  /// Resets the running accumulators: max to -FLT_MAX, sum to zero
  CUTLASS_DEVICE
  void clear_accum_() {

    uint32_t float_max_bits = 0xff7fffff;   // -FLT_MAX
    float min_float = reinterpret_cast<float const &>(float_max_bits);
    accum_max_ = ElementSoftmaxCompute(min_float);
    accum_sum_ = ElementSoftmaxCompute(0);
  }

  /// Sum of all elements of a fragment
  CUTLASS_DEVICE
  ElementSoftmaxCompute sum_accumulator_(SoftmaxFragment const &accum) {
    ElementSoftmaxCompute sum_ = ElementSoftmaxCompute(0);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < SoftmaxFragment::kElements; ++i) {
      sum_ += ElementSoftmaxCompute(accum[i]);
    }

    return sum_;
  }

  /// Sum of all elements of a fragment, added to a running sum
  CUTLASS_DEVICE
  ElementSoftmaxCompute sum_accumulator_(SoftmaxFragment const &accum, ElementSoftmaxCompute sum_) {
    // ElementSoftmaxCompute sum_ = ElementSoftmaxCompute(0);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < SoftmaxFragment::kElements; ++i) {
      sum_ += ElementSoftmaxCompute(accum[i]);
    }

    return sum_;
  }

  /// Maximum over the elements of a fragment
  CUTLASS_DEVICE
  ElementSoftmaxCompute maximum_accumulator_(SoftmaxFragment const &accum) {
    ElementSoftmaxCompute max_ = accum[0];

    CUTLASS_PRAGMA_UNROLL
    for (int i = 1; i < SoftmaxFragment::kElements; ++i) {
      max_ = fast_max(max_, ElementSoftmaxCompute(accum[i]));
    }

    return max_;
  }

  /// Maximum over the elements of a fragment, combined with a running maximum
  CUTLASS_DEVICE
  ElementSoftmaxCompute maximum_accumulator_(SoftmaxFragment const &accum, ElementSoftmaxCompute max_) {

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < SoftmaxFragment::kElements; ++i) {
      max_ = fast_max(max_, ElementSoftmaxCompute(accum[i]));
    }

    return max_;
  }
};
510
+
511
+ } // namespace threadblock
512
+ } // namespace epilogue
513
+ } // namespace cutlass
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/epilogue_with_absmax.h ADDED
@@ -0,0 +1,922 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ /*! \file
33
+
34
+ \brief Threadblock-level epilogue computing:
35
+ Aux = ((alpha * scale_a * scale_b) * accumulator) + ((beta * scale_c) * source) + bias
36
+ D = activation(Aux)
37
+
38
+ if Aux is fp8 type:
39
+ abs_max_output = max( abs(aux) | (for every aux in Aux))
40
+ Aux = scale_aux * Aux
41
+ endif
42
+
43
+ if D is fp8 type:
44
+ abs_max_output = max( abs(d) | (for every d in D))
45
+ D = scale_d * D
46
+ endif
47
+
48
+ Parameter Aux is optionally stored to global memory
49
+ */
50
+
51
+ #pragma once
52
+ #include "cutlass/cutlass.h"
53
+ #include CUDA_STD_HEADER(cassert)
54
+
55
+ #if defined(__CUDACC_RTC__)
56
+ #include CUDA_STD_HEADER(utility)
57
+ #else
58
+ #include <utility>
59
+ #endif
60
+
61
+ #include "cutlass/array.h"
62
+ #include "cutlass/numeric_types.h"
63
+ #include "cutlass/numeric_conversion.h"
64
+ #include "cutlass/tensor_coord.h"
65
+ #include "cutlass/aligned_buffer.h"
66
+ #include "cutlass/functional.h"
67
+ #include "cutlass/fast_math.h"
68
+ #include "cutlass/layout/vector.h"
69
+ #include "cutlass/layout/tensor.h"
70
+
71
+ #include "cutlass/gemm/gemm.h"
72
+
73
+ #include "cutlass/transform/pitch_linear_thread_map.h"
74
+ #include "cutlass/transform/threadblock/regular_tile_iterator.h"
75
+
76
+ #include "cutlass/epilogue/threadblock/epilogue_base.h"
77
+ #include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
78
+
79
+ #include "cutlass/numeric_types.h"
80
+
81
+ /////////////////////////////////////////////////////////////////////////////////////////////////
82
+
83
+ namespace cutlass {
84
+ namespace epilogue {
85
+ namespace threadblock {
86
+
87
+ /////////////////////////////////////////////////////////////////////////////////////////////////
88
+
89
+ namespace detail {
90
+
91
+ /// Helper class for keeping track of absolute maximums and performing scaling
92
+ template <
93
+ typename Iterator, // Iterator type used for storing the data for which absolute maximum and scaling
94
+ // will be computed. This type is used for predicating absolute maximum calculations.
95
+ typename Fragment, // Type of input to be computed on
96
+ bool ScalingAndAmaxNeeded // Whether to perform absolute maximum and scaling operations
97
+ >
98
+ struct ScalingAndAmaxHelper;
99
+
100
+ /// Partial specialization that does not perform scaling or calculate an absolute maximum
101
+ template <typename Iterator, typename Fragment>
102
+ struct ScalingAndAmaxHelper<Iterator, Fragment, false> {
103
+ using Element = typename Fragment::Element;
104
+
105
+ CUTLASS_HOST_DEVICE
106
+ ScalingAndAmaxHelper(Element scale) { }
107
+
108
+ CUTLASS_DEVICE
109
+ Fragment operator()(const Iterator& iterator, const Fragment& inp) {
110
+ return inp;
111
+ }
112
+
113
+ CUTLASS_HOST_DEVICE
114
+ Element get_abs_max() const {
115
+ return Element(0.);
116
+ }
117
+
118
+ CUTLASS_HOST_DEVICE
119
+ void set_scaling_factor(Element scale_) { }
120
+ };
121
+
122
+ /// Partial specialization that keeps track of an absolute maximum value of inputs seen
123
+ /// and scales inputs
124
+ template <typename Iterator, typename Fragment>
125
+ struct ScalingAndAmaxHelper<Iterator, Fragment, true> {
126
+ using Element = typename Fragment::Element;
127
+ using AccessType = typename Iterator::AccessType;
128
+ using ThreadMap = typename Iterator::ThreadMap;
129
+
130
+ Element abs_max;
131
+ Element scale;
132
+
133
+ // Operators
134
+ maximum_with_nan_propogation<Element> max_op;
135
+ absolute_value_op<Element> abs_op;
136
+ multiplies<Fragment> multiply;
137
+
138
+ CUTLASS_HOST_DEVICE
139
+ ScalingAndAmaxHelper(Element scale_) : abs_max(0.), scale(scale_) { }
140
+
141
+ // Compute the absolute maximum value between `abs_max` and the entries
142
+ // of `frag` for predicated-on entries of `iterator`. Return a scaled
143
+ // version of `inp`.
144
+ CUTLASS_DEVICE
145
+ Fragment operator()(const Iterator& iterator, const Fragment& frag) {
146
+ using PredicateGroup = Array<Element, Iterator::ThreadMap::kElementsPerAccess>;
147
+ PredicateGroup const *frag_ptr = reinterpret_cast<PredicateGroup const *>(&frag);
148
+
149
+ typename Iterator::Mask mask;
150
+ iterator.get_mask(mask);
151
+
152
+ CUTLASS_PRAGMA_UNROLL
153
+ for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) {
154
+
155
+ CUTLASS_PRAGMA_UNROLL
156
+ for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
157
+
158
+ CUTLASS_PRAGMA_UNROLL
159
+ for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
160
+ int frag_row_idx =
161
+ (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster));
162
+
163
+ int row_offset = row * ThreadMap::Delta::kRow
164
+ + group * ThreadMap::Delta::kGroup
165
+ + cluster * ThreadMap::Delta::kCluster;
166
+
167
+ bool row_guard = ((row_offset + iterator.thread_start_row()) < iterator.extent_row());
168
+
169
+ CUTLASS_PRAGMA_UNROLL
170
+ for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) {
171
+ bool guard = row_guard && mask.predicates[column];
172
+
173
+ if (guard) {
174
+ int access_idx = frag_row_idx * ThreadMap::Iterations::kColumn + column;
175
+ CUTLASS_PRAGMA_UNROLL
176
+ for (int i = 0; i < PredicateGroup::kElements; ++i) {
177
+ abs_max = max_op(abs_max, abs_op(frag_ptr[access_idx][i]));
178
+ }
179
+ }
180
+ }
181
+ }
182
+ }
183
+ }
184
+
185
+ // Perform scaling
186
+ return multiply(scale, frag);
187
+ }
188
+
189
+ CUTLASS_HOST_DEVICE
190
+ Element get_abs_max() const {
191
+ return abs_max;
192
+ }
193
+
194
+ CUTLASS_HOST_DEVICE
195
+ void set_scaling_factor(Element scale_) {
196
+ scale = scale_;
197
+ }
198
+ };
199
+
200
+ } // namespace detail
201
+
202
+ /////////////////////////////////////////////////////////////////////////////////////////////////
203
+
204
+ template <
205
+ typename Shape_, ///< Shape of threadblock tile (concept: GemmShape)
206
+ typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp)
207
+ int PartitionsK, ///< Number of partitions of the K dimension
208
+ typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors
209
+ typename AuxOutputTileIterator_, ///< Tile iterator writing auxiliary output tensors
210
+ typename ElementVector_, ///< Data type of bias vector
211
+ typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators
212
+ typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM
213
+ typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM
214
+ typename OutputOp_, ///< Output operator
215
+ typename Padding_, ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape)
216
+ int FragmentsPerPartition = 1, ///< Used to coarsen the epilogue granularity
217
+ int IterationsUnroll = ///< Used to reduce binary size when epilogue op is large
218
+ (!IsEpilogueFunctorHeavy<OutputOp_>::value)
219
+ >
220
+ class EpilogueWithAbsMax :
221
+ public EpilogueBase<
222
+ Shape_,
223
+ typename WarpMmaOperator_::Shape,
224
+ PartitionsK,
225
+ AccumulatorFragmentIterator_,
226
+ WarpTileIterator_,
227
+ Padding_,
228
+ FragmentsPerPartition> {
229
+
230
+ public:
231
+
232
+ using Base = EpilogueBase<
233
+ Shape_,
234
+ typename WarpMmaOperator_::Shape,
235
+ PartitionsK,
236
+ AccumulatorFragmentIterator_,
237
+ WarpTileIterator_,
238
+ Padding_,
239
+ FragmentsPerPartition>;
240
+
241
+ static bool const kIsSingleSource = true;
242
+ using Shape = Shape_;
243
+ using WarpMmaOperator = WarpMmaOperator_;
244
+ static int const kPartitionsK = PartitionsK;
245
+ using OutputTileIterator = OutputTileIterator_;
246
+ using AuxOutputTileIterator = AuxOutputTileIterator_;
247
+ using ElementVector = ElementVector_;
248
+ using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
249
+ using WarpTileIterator = WarpTileIterator_;
250
+ using SharedLoadIterator = SharedLoadIterator_;
251
+ using OutputOp = OutputOp_;
252
+ using Padding = Padding_;
253
+
254
+ using Layout = layout::RowMajor;
255
+ using LongIndex = typename Layout::LongIndex;
256
+
257
+ /// The complete warp-level accumulator tile
258
+ using AccumulatorTile = typename Base::AccumulatorTile;
259
+
260
+ /// Accumulator element
261
+ using ElementAccumulator = typename WarpTileIterator::Element;
262
+
263
+ /// Data type used for absolute maximum value
264
+ using ElementAbsmax = typename OutputOp::ElementAbsmax;
265
+
266
+ /// Compute data type produced by the output op
267
+ using ElementCompute = typename OutputOp::ElementCompute;
268
+
269
+ /// Compute fragment
270
+ using FragmentCompute = Array<ElementCompute, OutputTileIterator::Fragment::kElements>;
271
+
272
+ /// Helpers for (optionally) computing absolute maximums and scaling output and auxiliary output
273
+ using OutputScaler = detail::ScalingAndAmaxHelper<OutputTileIterator,
274
+ FragmentCompute,
275
+ OutputOp::kIsScalingAndAmaxOutputNeeded>;
276
+
277
+ using AuxOutputScaler = detail::ScalingAndAmaxHelper<AuxOutputTileIterator,
278
+ FragmentCompute,
279
+ OutputOp::kIsScalingAndAmaxAuxOutputNeeded>;
280
+
281
+ /// Thread map used by output tile iterators
282
+ using ThreadMap = typename OutputTileIterator::ThreadMap;
283
+
284
+ /// Fragment object used to store the broadcast values
285
+ using BroadcastFragment = Array<
286
+ ElementCompute,
287
+ ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess>;
288
+
289
+ /// Output element
290
+ using ElementOutput = typename OutputTileIterator::Element;
291
+
292
+ /// Data type of auxiliary output
293
+ using ElementAuxOutput = typename AuxOutputTileIterator::Element;
294
+
295
+ /// Output access size
296
+ static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
297
+
298
+ /// Tensor reference to destination tensor
299
+ using TensorRef = typename OutputTileIterator::TensorRef;
300
+
301
+ /// Tensor reference to sync tensor
302
+ using SyncTensorRef = typename cutlass::TensorRef<int, cutlass::layout::PackedVectorLayout>;
303
+
304
+ /// Const tensor reference to source tensor
305
+ using ConstTensorRef = typename OutputTileIterator::ConstTensorRef;
306
+
307
+ /// Array type used to output
308
+ using OutputAccessType = Array<
309
+ typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
310
+
311
+ /// Array type used by output functor
312
+ using AccumulatorAccessType = Array<typename WarpTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
313
+
314
+ /// Array type used by output functor
315
+ using ComputeAccessType = Array<ElementCompute, OutputTileIterator::kElementsPerAccess>;
316
+
317
+ /// Auxiliary output access type
318
+ using AuxAccessType = Array<ElementAuxOutput, OutputTileIterator::kElementsPerAccess>;
319
+
320
+ /// Number of warps
321
+ using WarpCount = typename Base::WarpCount;
322
+
323
+ /// Shared memory allocation from epilogue base class
324
+ using BaseSharedStorage = typename Base::SharedStorage;
325
+
326
+ static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 ? Base::kFragmentsPerIteration : kPartitionsK;
327
+ static int constexpr kSmemPointerOffset = Base::SharedStorage::StorageShape::kCount / kSmemTiles;
328
+
329
+ /// Used for the broadcast
330
+ struct BroadcastDetail {
331
+
332
+ /// Number of threads per warp
333
+ static int const kWarpSize = 32;
334
+
335
+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
336
+
337
+ /// Number of distinct scalar column indices handled by each thread
338
+ static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess;
339
+
340
+ /// Number of distinct scalar row indices handled by each thread
341
+ static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn;
342
+
343
+ /// Number of threads per threadblock
344
+ static int const kThreadCount = kWarpSize * WarpCount::kCount;
345
+
346
+ /// Number of distinct threads per row of output tile
347
+ static int const kThreadsPerRow = (Shape::kN / kColumnsPerThread);
348
+
349
+ /// Number of distinct threads which must be reduced during the final reduction phase within the threadblock.
350
+ static int const kThreadRows = kThreadCount / kThreadsPerRow;
351
+
352
+ /// I'm not sure what I meant here.
353
+ static int const kThreadAccessesPerRow = const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount);
354
+
355
+ /// Shape of the shared memory allocation for the epilogue
356
+ using StorageShape = MatrixShape<
357
+ kThreadRows,
358
+ Shape::kN
359
+ >;
360
+
361
+ /// Debug printing
362
+ CUTLASS_DEVICE
363
+ static void print() {
364
+ #if 0
365
+ printf("BroadcastDetail {\n");
366
+ printf(
367
+ " kColumnsPerThread: %d\nkRowsPerThread: %d\n,kThreadCount: %d\nkThreadsPerRow: %d\n"
368
+ "kThreadRows: %d\nThreadAccessesPerRow: %d\nStorageShape: %d x %d (count: %d)\n",
369
+ kColumnsPerThread,
370
+ kRowsPerThread,
371
+ kThreadCount,
372
+ kThreadsPerRow,
373
+ kThreadRows,
374
+ kThreadAccessesPerRow,
375
+ StorageShape::kRow,
376
+ StorageShape::kColumn,
377
+ StorageShape::kCount
378
+ );
379
+ printf("};\n");
380
+ #endif
381
+ }
382
+ };
383
+
384
+ /// Shared storage structure (shadows base) with additional SMEM buffer for reduction
385
+ struct SharedStorage {
386
+ union {
387
+ BaseSharedStorage base;
388
+ };
389
+
390
+ CUTLASS_HOST_DEVICE
391
+ SharedStorage() { }
392
+ };
393
+
394
+ public:
395
+
396
+
397
+ static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements,
398
+ "Mismatch between shared load iterator and output tile iterator.");
399
+
400
+ static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero.");
401
+
402
+ static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess),
403
+ "Divisibility");
404
+
405
+ private:
406
+
407
+ /// Loads fragment from shared memory aligned with output tensor
408
+ SharedLoadIterator shared_load_iterator_;
409
+
410
+ /// Thread index within the threadblock
411
+ int thread_idx_;
412
+
413
+ public:
414
+
415
+ /// Constructor
416
+ CUTLASS_DEVICE
417
+ EpilogueWithAbsMax(
418
+ SharedStorage &shared_storage, ///< Shared storage object
419
+ int thread_idx, ///< ID of a thread within the threadblock
420
+ int warp_idx, ///< ID of warp within threadblock
421
+ int lane_idx ///< Id of thread within warp
422
+ ):
423
+ Base(shared_storage.base, thread_idx, warp_idx, lane_idx),
424
+ shared_load_iterator_(shared_storage.base.reference(), thread_idx),
425
+ thread_idx_(thread_idx)
426
+ {
427
+
428
+ }
429
+
430
+ /// Streams the result to global memory
431
+ CUTLASS_DEVICE
432
+ void operator()(
433
+ OutputOp &output_op, ///< Output operator
434
+ ElementVector const * broadcast_ptr, ///< Broadcast vector
435
+ OutputTileIterator destination_iterator, ///< Tile iterator for destination
436
+ AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
437
+ OutputTileIterator source_iterator, ///< Tile iterator for source accumulator matrix
438
+ AuxOutputTileIterator aux_iterator, ///< Tile iterator for destination auxiliary output
439
+ MatrixCoord const &problem_size = ///< Problem size needed to guard against out-of-bounds accesses
440
+ MatrixCoord(Shape::kM, Shape::kN),
441
+ MatrixCoord const &threadblock_offset = ///< Threadblock's initial offset within the problem size space
442
+ MatrixCoord()) {
443
+
444
+ BroadcastFragment broadcast_fragment;
445
+
446
+ load_broadcast_fragment_(broadcast_fragment, broadcast_ptr, problem_size, threadblock_offset);
447
+
448
+ OutputScaler output_scaler(output_op.get_scale_d());
449
+
450
+ AuxOutputScaler aux_scaler(output_op.get_scale_aux());
451
+
452
+ if (!output_op.is_source_needed()) {
453
+ compute_source_not_needed_(
454
+ output_op,
455
+ broadcast_fragment,
456
+ destination_iterator,
457
+ accumulators,
458
+ aux_iterator,
459
+ output_scaler,
460
+ aux_scaler);
461
+ }
462
+ else {
463
+ compute_source_needed_(
464
+ output_op,
465
+ broadcast_fragment,
466
+ destination_iterator,
467
+ accumulators,
468
+ source_iterator,
469
+ aux_iterator,
470
+ output_scaler,
471
+ aux_scaler);
472
+ }
473
+
474
+ // Store the absolute maximum values of the output and auxiliar tensors, if needed.
475
+ if (output_op.get_ptr_output_abs_max() != nullptr) {
476
+ ElementAbsmax local_abs_max =
477
+ NumericConverter<ElementAbsmax, ElementCompute, OutputOp::kRound>{}(output_scaler.get_abs_max());
478
+ atomic_maximum<ElementAbsmax>{}(
479
+ output_op.get_ptr_output_abs_max(), local_abs_max);
480
+ }
481
+
482
+ if (output_op.get_ptr_aux_output_abs_max() != nullptr) {
483
+ ElementAbsmax local_abs_max =
484
+ NumericConverter<ElementAbsmax, ElementCompute, OutputOp::kRound>{}(aux_scaler.get_abs_max());
485
+ atomic_maximum<ElementAbsmax>{}(
486
+ output_op.get_ptr_aux_output_abs_max(), local_abs_max);
487
+ }
488
+ }
489
+
490
+ private:
491
+
492
+ CUTLASS_DEVICE
493
+ void load_broadcast_fragment_(
494
+ BroadcastFragment & broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns
495
+ ElementVector const * broadcast_ptr, ///< Broadcast vector
496
+ MatrixCoord const &problem_size, ///< Problem size needed to guard against out-of-bounds accesses
497
+ MatrixCoord const &threadblock_offset ///< Threadblock's initial offset within the problem size space
498
+ ) {
499
+
500
+ broadcast_fragment.clear();
501
+
502
+ // If no pointer is supplied, set with all zeros and avoid memory accesses
503
+ if (!broadcast_ptr) {
504
+ return;
505
+ }
506
+
507
+ int thread_initial_column = ThreadMap::initial_offset(thread_idx_).column();
508
+
509
+ int thread_column_idx = threadblock_offset.column() + thread_initial_column;
510
+ broadcast_ptr += thread_initial_column;
511
+
512
+ NumericArrayConverter<ElementCompute, ElementVector, BroadcastDetail::kElementsPerAccess> converter;
513
+ using AccessType = AlignedArray<ElementVector, BroadcastDetail::kElementsPerAccess>;
514
+ using ComputeFragmentType = Array<ElementCompute, BroadcastDetail::kElementsPerAccess>;
515
+
516
+ ComputeFragmentType *frag_ptr = reinterpret_cast<ComputeFragmentType *>(&broadcast_fragment);
517
+
518
+ CUTLASS_PRAGMA_UNROLL
519
+ for (int j = 0; j < ThreadMap::Iterations::kColumn; ++j) {
520
+
521
+ AccessType loaded;
522
+
523
+ loaded.clear();
524
+
525
+ if (thread_column_idx < problem_size.column()) {
526
+ loaded = *reinterpret_cast<AccessType const *>(broadcast_ptr);
527
+ }
528
+
529
+ ComputeFragmentType cvt = converter(loaded);
530
+ frag_ptr[j] = cvt;
531
+
532
+ thread_column_idx += ThreadMap::Delta::kColumn;
533
+ broadcast_ptr += ThreadMap::Delta::kColumn;
534
+ }
535
+ }
536
+
537
+ template <class Seq>
538
+ struct acc2smem_source_not_needed;
539
+
540
+ template <size_t... Seq>
541
+ struct acc2smem_source_not_needed<cutlass::index_sequence<Seq...>> {
542
+ template <int Advance>
543
+ CUTLASS_DEVICE static void helper(AccumulatorFragmentIterator accum_fragment_iterator,
544
+ WarpTileIterator &warp_tile_iterator) {
545
+ CUTLASS_PRAGMA_UNROLL
546
+ for (int i = 0; i < Advance; i++) {
547
+ ++accum_fragment_iterator;
548
+ }
549
+
550
+ CUTLASS_PRAGMA_UNROLL
551
+ for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {
552
+ typename AccumulatorFragmentIterator::Fragment accum_fragment;
553
+
554
+ accum_fragment_iterator.load(accum_fragment);
555
+ ++accum_fragment_iterator;
556
+
557
+ warp_tile_iterator.store(accum_fragment);
558
+ if (p < Base::kFragmentsPerIteration - 1) {
559
+ warp_tile_iterator.add_pointer_offset(kSmemPointerOffset);
560
+ }
561
+ }
562
+
563
+ if (Base::kFragmentsPerIteration > 1) {
564
+ warp_tile_iterator.add_pointer_offset(kSmemPointerOffset *
565
+ (1 - Base::kFragmentsPerIteration));
566
+ }
567
+ }
568
+
569
+ CUTLASS_DEVICE
570
+ static void push(size_t pos,
571
+ AccumulatorFragmentIterator const &iterator_begin,
572
+ WarpTileIterator &warp_tile_iterator) {
573
+ int dummy[] = {
574
+ (pos == (Seq * Base::kFragmentsPerIteration)) &&
575
+ (helper<Seq * Base::kFragmentsPerIteration>(iterator_begin, warp_tile_iterator), 0)...};
576
+
577
+ CUTLASS_UNUSED(dummy[0]);
578
+ }
579
+ };
580
+
581
+ /// Streams the result to global memory
582
+ CUTLASS_DEVICE
583
+ void compute_source_not_needed_(
584
+ OutputOp &output_op, ///< Output operator
585
+ BroadcastFragment const &broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns
586
+ OutputTileIterator destination_iterator, ///< Tile iterator for destination
587
+ AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
588
+ AuxOutputTileIterator aux_iterator, ///< Tile iterator for destination auxiliary output
589
+ OutputScaler& output_scaler, ///< Helper for (optionally) computing the absolute maximum and scaling output
590
+ AuxOutputScaler& aux_scaler ///< Helper for (optionally) computing the absolute maximum and scaling the auxiliary output
591
+ ) {
592
+
593
+ //
594
+ // Iterator over warp-level accumulator fragment
595
+ //
596
+
597
+ AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
598
+
599
+ //
600
+ // Iterate over accumulator tile
601
+ //
602
+
603
+ // CUTLASS_PRAGMA_UNROLL
604
+ #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations / Base::kFragmentsPerIteration : 1)
605
+ for (int iter = 0; iter < OutputTileIterator::kIterations; iter += Base::kFragmentsPerIteration) {
606
+
607
+ //
608
+ // Convert and store fragment
609
+ //
610
+
611
+
612
+ __syncthreads();
613
+
614
+ acc2smem_source_not_needed<
615
+ cutlass::make_index_sequence<OutputTileIterator::kIterations /
616
+ Base::kFragmentsPerIteration>>::push(iter,
617
+ accum_fragment_iterator,
618
+ this->warp_tile_iterator_);
619
+
620
+ __syncthreads();
621
+
622
+ //
623
+ // Load fragments from shared memory
624
+ //
625
+
626
+ CUTLASS_PRAGMA_UNROLL
627
+ for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {
628
+
629
+
630
+ typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];
631
+
632
+ shared_load_iterator_.load(aligned_accum_fragment[0]);
633
+
634
+ if (p < Base::kFragmentsPerIteration - 1) {
635
+ shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
636
+ }
637
+ else if (kPartitionsK > 1) {
638
+
639
+ plus <typename SharedLoadIterator::Fragment> add_fragments;
640
+
641
+ CUTLASS_PRAGMA_UNROLL
642
+ for ( int i = 1; i < kPartitionsK; ++i) {
643
+ shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
644
+ shared_load_iterator_.load(aligned_accum_fragment[i]);
645
+ aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
646
+ }
647
+
648
+ shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset);
649
+ }
650
+
651
+ //
652
+ // Apply output operation
653
+ //
654
+
655
+ FragmentCompute frag_Z_compute;
656
+ FragmentCompute frag_Aux_compute;
657
+
658
+ apply_output_operator_source_not_needed_(
659
+ frag_Z_compute,
660
+ frag_Aux_compute,
661
+ output_op,
662
+ aligned_accum_fragment[0],
663
+ broadcast_fragment);
664
+
665
+ //
666
+ // Conditionally store fragments
667
+ //
668
+
669
+ // (Optionally) compute the absolute maximum of frag_Z and scale frag_Z
670
+ frag_Z_compute = output_scaler(destination_iterator, frag_Z_compute);
671
+ NumericArrayConverter<typename OutputTileIterator::Fragment::Element, ElementCompute,
672
+ OutputTileIterator::Fragment::kElements> cvt_to_dst;
673
+ typename OutputTileIterator::Fragment frag_Z = cvt_to_dst(frag_Z_compute);
674
+
675
+ // Always store the output
676
+ destination_iterator.store(frag_Z);
677
+ ++destination_iterator;
678
+
679
+ // Only store the auxiliary output if scaling and absolute-maximum calculation were needed
680
+ if (OutputOp::kIsScalingAndAmaxAuxOutputNeeded) {
681
+ frag_Aux_compute = aux_scaler(aux_iterator, frag_Aux_compute);
682
+
683
+ NumericArrayConverter<typename AuxOutputTileIterator::Fragment::Element, ElementCompute,
684
+ AuxOutputTileIterator::Fragment::kElements> cvt_to_aux;
685
+ typename AuxOutputTileIterator::Fragment frag_Aux = cvt_to_aux(frag_Aux_compute);
686
+ aux_iterator.store(frag_Aux);
687
+ ++aux_iterator;
688
+ }
689
+ }
690
+
691
+ if (Base::kFragmentsPerIteration > 1) {
692
+ shared_load_iterator_.add_pointer_offset(kSmemPointerOffset * (1 - Base::kFragmentsPerIteration));
693
+ }
694
+ }
695
+ }
696
+
697
+
698
+ template<class Seq>
699
+ struct acc2smem_source_needed;
700
+
701
+ template <size_t... Seq>
702
+ struct acc2smem_source_needed<cutlass::index_sequence<Seq...>> {
703
+ template<int Advance>
704
+ CUTLASS_DEVICE
705
+ static void helper(AccumulatorFragmentIterator accum_fragment_iterator,
706
+ WarpTileIterator &warp_tile_iterator) {
707
+ CUTLASS_PRAGMA_UNROLL
708
+ for (int i = 0; i < Advance; i++) {
709
+ ++accum_fragment_iterator;
710
+ }
711
+
712
+ typename AccumulatorFragmentIterator::Fragment accum_fragment;
713
+ accum_fragment_iterator.load(accum_fragment);
714
+ warp_tile_iterator.store(accum_fragment);
715
+ }
716
+
717
+ CUTLASS_DEVICE
718
+ static void push(size_t pos,
719
+ AccumulatorFragmentIterator const &iterator_begin,
720
+ WarpTileIterator &warp_tile_iterator) {
721
+ int dummy[] = {(pos == Seq) && (helper<Seq>(iterator_begin, warp_tile_iterator), 0)...};
722
+ }
723
+ };
724
+
725
+
726
+ /// Streams the result to global memory
727
+ CUTLASS_DEVICE
728
+ void compute_source_needed_(
729
+ OutputOp &output_op, ///< Output operator
730
+ BroadcastFragment const &broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns
731
+ OutputTileIterator destination_iterator, ///< Tile iterator for destination
732
+ AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
733
+ OutputTileIterator source_iterator, ///< Tile iterator for source accumulator matrix
734
+ AuxOutputTileIterator aux_iterator, ///< Tile iterator for destination auxiliary output
735
+ OutputScaler& output_scaler, ///< Helper for (optionally) computing the absolute maximum and scaling output
736
+ AuxOutputScaler& aux_scaler ///< Helper for (optionally) computing the absolute maximum and scaling the auxiliary output
737
+ ) {
738
+
739
+ typename OutputTileIterator::Fragment source_fragment;
740
+ source_fragment.clear();
741
+
742
+ //
743
+ // Iterator over warp-level accumulator fragment
744
+ //
745
+
746
+ AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
747
+
748
+ //
749
+ // Iterate over accumulator tile
750
+ //
751
+
752
+ #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1)
753
+ for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {
754
+
755
+ //
756
+ // Load the source
757
+ //
758
+
759
+ source_iterator.load(source_fragment);
760
+ ++source_iterator;
761
+
762
+ //
763
+ // Convert and store fragment
764
+ //
765
+
766
+ __syncthreads();
767
+
768
+ acc2smem_source_needed<cutlass::make_index_sequence<OutputTileIterator::kIterations>>::push(
769
+ iter, accum_fragment_iterator, this->warp_tile_iterator_);
770
+
771
+ __syncthreads();
772
+
773
+ //
774
+ // Load fragments from shared memory
775
+ //
776
+
777
+ typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];
778
+
779
+ shared_load_iterator_.load(aligned_accum_fragment[0]);
780
+
781
+ // If the number of k-slices is > 1 - perform a reduction amongst the k-slices
782
+ if (kPartitionsK > 1)
783
+ {
784
+ plus <typename SharedLoadIterator::Fragment> add_fragments;
785
+ const int tile_row_offset = Base::SharedStorage::StorageShape::kRow / PartitionsK;
786
+
787
+ CUTLASS_PRAGMA_UNROLL
788
+ for ( int i = 1; i < kPartitionsK; ++i) {
789
+ shared_load_iterator_.add_tile_offset({tile_row_offset , 0});
790
+ shared_load_iterator_.load(aligned_accum_fragment[i]);
791
+ aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
792
+ }
793
+
794
+ shared_load_iterator_.add_tile_offset({-1 * (kPartitionsK-1) * tile_row_offset, 0});
795
+ }
796
+
797
+ //
798
+ // Apply output operation
799
+ //
800
+
801
+ FragmentCompute frag_Z_compute;
802
+ FragmentCompute frag_Aux_compute;
803
+
804
+ apply_output_operator_(
805
+ frag_Z_compute,
806
+ frag_Aux_compute,
807
+ output_op,
808
+ aligned_accum_fragment[0],
809
+ source_fragment,
810
+ broadcast_fragment);
811
+
812
+ //
813
+ // Conditionally store fragments
814
+ //
815
+
816
+ // (Optionally) compute the absolute maximum of frag_Z and scale frag_Z
817
+ frag_Z_compute = output_scaler(destination_iterator, frag_Z_compute);
818
+ NumericArrayConverter<typename OutputTileIterator::Fragment::Element, ElementCompute,
819
+ OutputTileIterator::Fragment::kElements> cvt_to_dst;
820
+ typename OutputTileIterator::Fragment frag_Z = cvt_to_dst(frag_Z_compute);
821
+
822
+ // Always store the output
823
+ destination_iterator.store(frag_Z);
824
+ ++destination_iterator;
825
+
826
+ // Only store the auxiliary output if scaling and absolute-maximum calculation were needed
827
+ if (OutputOp::kIsScalingAndAmaxAuxOutputNeeded) {
828
+ frag_Aux_compute = aux_scaler(aux_iterator, frag_Aux_compute);
829
+
830
+ NumericArrayConverter<typename AuxOutputTileIterator::Fragment::Element, ElementCompute,
831
+ AuxOutputTileIterator::Fragment::kElements> cvt_to_aux;
832
+ typename AuxOutputTileIterator::Fragment frag_Aux = cvt_to_aux(frag_Aux_compute);
833
+ aux_iterator.store(frag_Aux);
834
+ ++aux_iterator;
835
+ }
836
+ }
837
+ }
838
+
839
+ /// Helper to invoke the output functor over each vector of output
840
+ CUTLASS_DEVICE
841
+ void apply_output_operator_(
842
+ FragmentCompute &frag_Z,
843
+ FragmentCompute &frag_Aux,
844
+ OutputOp &output_op,
845
+ typename SharedLoadIterator::Fragment const &frag_AB,
846
+ typename OutputTileIterator::Fragment const &frag_C,
847
+ BroadcastFragment const &frag_Broadcast) {
848
+
849
+ using AccessTypeZ = Array<ElementCompute, kElementsPerAccess>;
850
+ using AccessTypeAux = Array<ElementCompute, kElementsPerAccess>;
851
+ using AccessTypeBroadcast = Array<ElementCompute, kElementsPerAccess>;
852
+
853
+ AccessTypeZ *frag_Z_ptr = reinterpret_cast<AccessTypeZ *>(&frag_Z);
854
+ AccessTypeAux *frag_Aux_ptr = reinterpret_cast<AccessTypeAux *>(&frag_Aux);
855
+
856
+ AccumulatorAccessType const *frag_AB_ptr =
857
+ reinterpret_cast<AccumulatorAccessType const *>(&frag_AB);
858
+
859
+ OutputAccessType const *frag_C_ptr =
860
+ reinterpret_cast<OutputAccessType const *>(&frag_C);
861
+
862
+ AccessTypeBroadcast const *frag_Broadcast_ptr =
863
+ reinterpret_cast<AccessTypeBroadcast const *>(&frag_Broadcast);
864
+
865
+ int const kOutputOpIterations =
866
+ OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
867
+
868
+ CUTLASS_PRAGMA_UNROLL
869
+ for (int i = 0; i < kOutputOpIterations; ++i) {
870
+ output_op(
871
+ frag_Z_ptr[i],
872
+ frag_Aux_ptr[i],
873
+ frag_AB_ptr[i],
874
+ frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn],
875
+ frag_C_ptr[i]);
876
+ }
877
+ }
878
+
879
+ /// Helper to invoke the output functor over each vector of output
880
+ CUTLASS_DEVICE
881
+ void apply_output_operator_source_not_needed_(
882
+ FragmentCompute &frag_Z,
883
+ FragmentCompute &frag_Aux,
884
+ OutputOp &output_op,
885
+ typename SharedLoadIterator::Fragment const &frag_AB,
886
+ BroadcastFragment const &frag_Broadcast) {
887
+
888
+ using AccessTypeZ = Array<ElementCompute, kElementsPerAccess>;
889
+ using AccessTypeAux = Array<ElementCompute, kElementsPerAccess>;
890
+ using AccessTypeBroadcast = Array<ElementCompute, kElementsPerAccess>;
891
+
892
+ AccessTypeZ *frag_Z_ptr = reinterpret_cast<AccessTypeZ *>(&frag_Z);
893
+ AccessTypeAux *frag_Aux_ptr = reinterpret_cast<AccessTypeAux *>(&frag_Aux);
894
+
895
+ AccumulatorAccessType const *frag_AB_ptr =
896
+ reinterpret_cast<AccumulatorAccessType const *>(&frag_AB);
897
+
898
+ AccessTypeBroadcast const *frag_Broadcast_ptr =
899
+ reinterpret_cast<AccessTypeBroadcast const *>(&frag_Broadcast);
900
+
901
+ int const kOutputOpIterations =
902
+ OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
903
+
904
+ CUTLASS_PRAGMA_UNROLL
905
+ for (int i = 0; i < kOutputOpIterations; ++i) {
906
+
907
+ output_op(
908
+ frag_Z_ptr[i],
909
+ frag_Aux_ptr[i],
910
+ frag_AB_ptr[i],
911
+ frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]);
912
+ }
913
+ }
914
+ };
915
+
916
+ ////////////////////////////////////////////////////////////////////////////////
917
+
918
+ } // namespace threadblock
919
+ } // namespace epilogue
920
+ } // namespace cutlass
921
+
922
+ ////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu130-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h ADDED
@@ -0,0 +1,1717 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+
33
+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
34
+
35
+ The epilogue rearranges the result of a matrix product through shared memory to match canonical
36
+ tensor layouts in global memory. Epilogues support conversion and reduction operations.
37
+
38
+ */
39
+
40
+ #pragma once
41
+ #include "cutlass/cutlass.h"
42
+ #include CUDA_STD_HEADER(cassert)
43
+
44
+ #if defined(__CUDACC_RTC__)
45
+ #include CUDA_STD_HEADER(utility)
46
+ #else
47
+ #include <utility>
48
+ #endif
49
+
50
+ #include "cutlass/array.h"
51
+ #include "cutlass/numeric_types.h"
52
+ #include "cutlass/numeric_conversion.h"
53
+ #include "cutlass/tensor_coord.h"
54
+ #include "cutlass/aligned_buffer.h"
55
+ #include "cutlass/functional.h"
56
+ #include "cutlass/fast_math.h"
57
+ #include "cutlass/layout/vector.h"
58
+ #include "cutlass/layout/tensor.h"
59
+
60
+ #include "cutlass/gemm/gemm.h"
61
+
62
+ #include "cutlass/transform/pitch_linear_thread_map.h"
63
+ #include "cutlass/transform/threadblock/regular_tile_iterator.h"
64
+
65
+ #include "cutlass/epilogue/threadblock/epilogue_base.h"
66
+ #include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
67
+
68
+ #include "cutlass/numeric_types.h"
69
+
70
+ /////////////////////////////////////////////////////////////////////////////////////////////////
71
+
72
+ namespace cutlass {
73
+ namespace epilogue {
74
+ namespace threadblock {
75
+
76
+ /////////////////////////////////////////////////////////////////////////////////////////////////
77
+
78
/// This base class is meant to define the concept required of the
/// EpilogueWithBroadcast::OutputOp. It documents the required nested types,
/// constants, and operator() overloads; the method bodies here are empty
/// placeholders — concrete output functors supply the real implementations.
template <
  typename ElementC_,
  typename ElementAccumulator_,
  typename ElementCompute_,
  typename ElementZ_,
  typename ElementT_,
  int ElementsPerAccess,
  bool StoreZ = true,
  bool StoreT = true
>
struct EpilogueWithBroadcastOpBase {

  using ElementOutput = ElementC_;
  using ElementAccumulator = ElementAccumulator_;
  using ElementCompute = ElementCompute_;
  using ElementZ = ElementZ_;
  using ElementT = ElementT_;
  static int const kElementsPerAccess = ElementsPerAccess;

  // Per-access fragment types seen by the functor's operator().
  using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;
  using FragmentCompute = Array<ElementCompute, kElementsPerAccess>;
  using FragmentC = Array<ElementOutput, kElementsPerAccess>;
  using FragmentZ = Array<ElementZ, kElementsPerAccess>;
  using FragmentT = Array<ElementT, kElementsPerAccess>;

  /// If true, the 'Z' tensor is stored
  static bool const kStoreZ = StoreZ;

  /// If true, the 'T' tensor is stored
  static bool const kStoreT = StoreT;

  /// Parameters structure - required
  struct Params { };

  //
  // Methods
  //

  /// Constructor from Params
  EpilogueWithBroadcastOpBase(Params const &params_) { }

  /// Determine if the source is needed. May return false if
  bool is_source_needed() const {
    return true;
  }

  /// Sets the split-K partition index/count (no-op in this base concept).
  CUTLASS_HOST_DEVICE
  void set_k_partition(int k_partition, int k_partition_count) { }

  /// Applies the operation when is_source_needed() is true
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentZ &frag_Z,
    FragmentT &frag_T,
    FragmentAccumulator const &AB,
    FragmentC const &frag_C1,
    FragmentC const &frag_C2,
    FragmentCompute const &V) const {

  }

  /// Applies the operation when is_source_needed() is false
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentZ &frag_Z,
    FragmentT &frag_T,
    FragmentAccumulator const &AB,
    FragmentCompute const &V) const {

  }
};
151
+
152
+ ////////////////////////////////////////////////////////////////////////////////
153
+
154
/// Epilogue operator with bias vector broadcast over columns.
///
/// Computes the following:
///
///
///   Z, T = OutputOp(AB, C, Broadcast)
///
///   if (ElementwiseOp::kStoreZ) {
///     store(converted_u);
///   }
///
///   if (ElementwiseOp::kStoreT) {
///     store(v);
///   }
///
/// Primary template declaration; partial specializations (selected on
/// IsSingleSource) provide the implementations.
template <
  typename Shape_,                          ///< Shape of threadblock tile (concept: GemmShape)
  typename WarpMmaOperator_,                ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp)
  int PartitionsK,                          ///< Number of partitions of the K dimension
  typename OutputTileIterator_,             ///< Tile iterator reading and writing output tensors (z)
  typename TensorTileIterator_,             ///< Additional tile iterator for tensor-valued operands (t)
  typename ElementVector_,                  ///< Pointer to broadcast vector
  typename AccumulatorFragmentIterator_,    ///< Fragment iterator selecting accumulators
  typename WarpTileIterator_,               ///< Warp-scoped tile iterator writing accumulators to SMEM
  typename SharedLoadIterator_,             ///< Threadblock-scoped tile iterator loading from SMEM
  typename OutputOp_,                       ///< Output operator - concept is EpilogueWithBroadcastOp
  typename Padding_,                        ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape)
  int FragmentsPerPartition = 1,            ///< Used to coarsten the epilogue granularity
  int IterationsUnroll =                    ///< Used to reduce binary size when epilogue op is large
    (!IsEpilogueFunctorHeavy<OutputOp_>::value),
  bool IsSingleSource = OutputOp_::kIsSingleSource
>
class EpilogueWithBroadcast;
187
+
188
+ template <
189
+ typename Shape_,
190
+ typename WarpMmaOperator_,
191
+ int PartitionsK,
192
+ typename OutputTileIterator_,
193
+ typename TensorTileIterator_,
194
+ typename ElementVector_,
195
+ typename AccumulatorFragmentIterator_,
196
+ typename WarpTileIterator_,
197
+ typename SharedLoadIterator_,
198
+ typename OutputOp_,
199
+ typename Padding_,
200
+ int FragmentsPerPartition,
201
+ int IterationsUnroll
202
+ >
203
+ class EpilogueWithBroadcast<
204
+ Shape_,
205
+ WarpMmaOperator_,
206
+ PartitionsK,
207
+ OutputTileIterator_,
208
+ TensorTileIterator_,
209
+ ElementVector_,
210
+ AccumulatorFragmentIterator_,
211
+ WarpTileIterator_,
212
+ SharedLoadIterator_,
213
+ OutputOp_,
214
+ Padding_,
215
+ FragmentsPerPartition,
216
+ IterationsUnroll,
217
+ false
218
+ > :
219
+ public EpilogueBase<
220
+ Shape_,
221
+ typename WarpMmaOperator_::Shape,
222
+ PartitionsK,
223
+ AccumulatorFragmentIterator_,
224
+ WarpTileIterator_,
225
+ Padding_,
226
+ FragmentsPerPartition> {
227
+
228
+ public:
229
+
230
+ using Base = EpilogueBase<
231
+ Shape_,
232
+ typename WarpMmaOperator_::Shape,
233
+ PartitionsK,
234
+ AccumulatorFragmentIterator_,
235
+ WarpTileIterator_,
236
+ Padding_,
237
+ FragmentsPerPartition>;
238
+
239
+ static bool const kIsSingleSource = false;
240
+ using Shape = Shape_;
241
+ using WarpMmaOperator = WarpMmaOperator_;
242
+ static int const kPartitionsK = PartitionsK;
243
+ using OutputTileIterator = OutputTileIterator_;
244
+ using TensorTileIterator = TensorTileIterator_;
245
+ using ElementVector = ElementVector_;
246
+ using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
247
+ using WarpTileIterator = WarpTileIterator_;
248
+ using SharedLoadIterator = SharedLoadIterator_;
249
+ using OutputOp = OutputOp_;
250
+ using Padding = Padding_;
251
+
252
+ using Layout = layout::RowMajor;
253
+ using LongIndex = typename Layout::LongIndex;
254
+
255
+ /// The complete warp-level accumulator tile
256
+ using AccumulatorTile = typename Base::AccumulatorTile;
257
+
258
+ /// Accumulator element
259
+ using ElementAccumulator = typename WarpTileIterator::Element;
260
+
261
+ /// Compute data type produced by the output op
262
+ using ElementCompute = typename OutputOp::ElementCompute;
263
+
264
+ /// Compute fragment
265
+ using FragmentCompute = Array<ElementCompute, OutputTileIterator::Fragment::kElements>;
266
+
267
+ /// Thread map used by output tile iterators
268
+ using ThreadMap = typename OutputTileIterator::ThreadMap;
269
+
270
+ /// Fragment object used to store the broadcast values
271
+ using BroadcastFragment = Array<
272
+ ElementCompute,
273
+ ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess>;
274
+
275
+ /// Output element
276
+ using ElementOutput = typename OutputTileIterator::Element;
277
+
278
+ /// Data type of additional tensor
279
+ using ElementTensor = typename TensorTileIterator::Element;
280
+
281
+ /// Output access size
282
+ static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
283
+
284
+ /// Tensor reference to destination tensor
285
+ using TensorRef = typename OutputTileIterator::TensorRef;
286
+
287
+ /// Tensor reference to sync tensor
288
+ using SyncTensorRef = typename cutlass::TensorRef<int, cutlass::layout::PackedVectorLayout>;
289
+
290
+ /// Const tensor reference to source tensor
291
+ using ConstTensorRef = typename OutputTileIterator::ConstTensorRef;
292
+
293
+ /// Array type used to output
294
+ using OutputAccessType = Array<
295
+ typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
296
+
297
+ /// Array type used by output functor
298
+ using AccumulatorAccessType = Array<typename WarpTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
299
+
300
+ /// Array type used by output functor
301
+ using ComputeAccessType = Array<ElementCompute, OutputTileIterator::kElementsPerAccess>;
302
+
303
+ /// Tensor access type
304
+ using TensorAccessType = Array<ElementTensor, OutputTileIterator::kElementsPerAccess>;
305
+
306
+ /// Number of warps
307
+ using WarpCount = typename Base::WarpCount;
308
+
309
+ /// Shared memory allocation from epilogue base class
310
+ using BaseSharedStorage = typename Base::SharedStorage;
311
+
312
+ static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 ? Base::kFragmentsPerIteration : kPartitionsK;
313
+ static int constexpr kSmemPointerOffset = Base::SharedStorage::StorageShape::kCount / kSmemTiles;
314
+
315
  /// Used for the broadcast. Compile-time geometry describing how the
  /// threadblock's threads map onto columns of the broadcast vector.
  struct BroadcastDetail {

    /// Number of threads per warp
    static int const kWarpSize = 32;

    static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;

    /// Number of distinct scalar column indices handled by each thread
    static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess;

    /// Number of distinct scalar row indices handled by each thread
    static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn;

    /// Number of threads per threadblock
    static int const kThreadCount = kWarpSize * WarpCount::kCount;

    /// Number of distinct threads per row of output tile
    static int const kThreadsPerRow = (Shape::kN / kColumnsPerThread);

    /// Number of distinct threads which must be reduced during the final reduction phase within the threadblock.
    static int const kThreadRows = kThreadCount / kThreadsPerRow;

    /// I'm not sure what I meant here.
    // NOTE(review): upstream comment preserved verbatim — appears to be the
    // number of accesses each thread makes when sweeping a full row; confirm
    // against callers before relying on it.
    static int const kThreadAccessesPerRow = const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount);

    /// Shape of the shared memory allocation for the epilogue
    using StorageShape = MatrixShape<
      kThreadRows,
      Shape::kN
    >;

    /// Debug printing (compiled out; enable the #if to dump the geometry)
    CUTLASS_DEVICE
    static void print() {
#if 0
      printf("BroadcastDetail {\n");
      printf(
        "  kColumnsPerThread: %d\nkRowsPerThread: %d\n,kThreadCount: %d\nkThreadsPerRow: %d\n"
        "kThreadRows: %d\nThreadAccessesPerRow: %d\nStorageShape: %d x %d (count: %d)\n",
        kColumnsPerThread,
        kRowsPerThread,
        kThreadCount,
        kThreadsPerRow,
        kThreadRows,
        kThreadAccessesPerRow,
        StorageShape::kRow,
        StorageShape::kColumn,
        StorageShape::kCount
      );
      printf("};\n");
#endif
    }
  };
369
+
370
  /// Shared storage structure (shadows base) with additional SMEM buffer for reduction.
  // NOTE(review): the union currently wraps only the base storage; the union
  // form allows additional aliased buffers to be added without growing SMEM.
  struct SharedStorage {
    union {
      BaseSharedStorage base;   // storage used by the EpilogueBase accumulator staging
    };

    CUTLASS_HOST_DEVICE
    SharedStorage() { }
  };
379
+
380
+ public:
381
+
382
+
383
+ static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements,
384
+ "Mismatch between shared load iterator and output tile iterator.");
385
+
386
+ static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero.");
387
+
388
+ static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess),
389
+ "Divisibility");
390
+
391
+ private:
392
+
393
+ /// Loads fragment from shared memory aligned with output tensor
394
+ SharedLoadIterator shared_load_iterator_;
395
+
396
+ /// Thread index within the threadblock
397
+ int thread_idx_;
398
+
399
+ public:
400
+
401
  /// Constructor
  CUTLASS_DEVICE
  EpilogueWithBroadcast(
    SharedStorage &shared_storage,                    ///< Shared storage object
    int thread_idx,                                   ///< ID of a thread within the threadblock
    int warp_idx,                                     ///< ID of warp within threadblock
    int lane_idx                                      ///< Id of thread within warp
  ):
    Base(shared_storage.base, thread_idx, warp_idx, lane_idx),
    // The shared-memory load iterator reads back from the same base storage
    // that the base class's warp tile iterator writes accumulators into.
    shared_load_iterator_(shared_storage.base.reference(), thread_idx),
    thread_idx_(thread_idx)
  {

  }
415
+
416
  /// Streams the result to global memory.
  /// Loads the per-column broadcast fragment once, then dispatches to either
  /// the source-free or source-consuming compute path depending on
  /// output_op.is_source_needed().
  CUTLASS_DEVICE
  void operator()(
    OutputOp const &output_op,                        ///< Output operator
    ElementVector const * broadcast_ptr,              ///< Broadcast vector
    OutputTileIterator destination_iterator,          ///< Tile iterator for destination
    AccumulatorTile const &accumulators,              ///< Complete warp-level accumulator tile
    OutputTileIterator source_iterator1,              ///< Tile iterator for first source accumulator matrix
    OutputTileIterator source_iterator2,              ///< Tile iterator for second source accumulator matrix
    TensorTileIterator tensor_iterator,               ///< Threadblock tile iterator for additional tensor operand
    MatrixCoord const &problem_size =                 ///< Problem size needed to guard against out-of-bounds accesses
        MatrixCoord(Shape::kM, Shape::kN),
    MatrixCoord const &threadblock_offset =           ///< Threadblock's initial offset within the problem size space
        MatrixCoord()) {

    BroadcastFragment broadcast_fragment;

    // Load (and convert) this thread's slice of the broadcast vector up front;
    // it is reused for every output tile iteration below.
    load_broadcast_fragment_(broadcast_fragment, broadcast_ptr, problem_size, threadblock_offset);

    if (!output_op.is_source_needed()) {
      // Source tensors are not read: skip loading C1/C2 entirely.
      compute_source_not_needed_(
        output_op,
        broadcast_fragment,
        destination_iterator,
        accumulators,
        tensor_iterator);
    }
    else {
      compute_source_needed_(
        output_op,
        broadcast_fragment,
        destination_iterator,
        accumulators,
        source_iterator1,
        source_iterator2,
        tensor_iterator);
    }
  }
454
+
455
+ private:
456
+
457
  /// Loads this thread's slice of the broadcast vector from global memory,
  /// converting it to ElementCompute. Out-of-bounds columns and a null
  /// broadcast pointer both yield zeros.
  CUTLASS_DEVICE
  void load_broadcast_fragment_(
    BroadcastFragment & broadcast_fragment,      ///< Fragment containing the accumulated partial reduction over columns
    ElementVector const * broadcast_ptr,         ///< Broadcast vector
    MatrixCoord const &problem_size,             ///< Problem size needed to guard against out-of-bounds accesses
    MatrixCoord const &threadblock_offset        ///< Threadblock's initial offset within the problem size space
  ) {

    broadcast_fragment.clear();

    // If no pointer is supplied, set with all zeros and avoid memory accesses
    if (!broadcast_ptr) {
      return;
    }

    // Column offset of this thread's first access within the threadblock tile.
    int thread_initial_column = ThreadMap::initial_offset(thread_idx_).column();

    int thread_column_idx = threadblock_offset.column() + thread_initial_column;
    broadcast_ptr += thread_initial_column;

    NumericArrayConverter<ElementCompute, ElementVector, BroadcastDetail::kElementsPerAccess> converter;
    using AccessType = AlignedArray<ElementVector, BroadcastDetail::kElementsPerAccess>;
    using ComputeFragmentType = Array<ElementCompute, BroadcastDetail::kElementsPerAccess>;

    ComputeFragmentType *frag_ptr = reinterpret_cast<ComputeFragmentType *>(&broadcast_fragment);

    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < ThreadMap::Iterations::kColumn; ++j) {

      AccessType loaded;

      loaded.clear();

      // Guard against reading past the problem's column extent; out-of-range
      // accesses keep the cleared (zero) value.
      if (thread_column_idx < problem_size.column()) {
        loaded = *reinterpret_cast<AccessType const *>(broadcast_ptr);
      }

      // Convert from the vector's element type to the compute type.
      ComputeFragmentType cvt = converter(loaded);
      frag_ptr[j] = cvt;

      // Advance by the thread map's column stride for the next access.
      thread_column_idx += ThreadMap::Delta::kColumn;
      broadcast_ptr += ThreadMap::Delta::kColumn;
    }
  }
501
+
502
  /// Helper that copies accumulator fragments into shared memory for the
  /// source-not-needed path. The index_sequence specialization turns the
  /// runtime iteration index into a compile-time Advance so the fragment
  /// iterator offset is fully unrolled.
  template <class Seq>
  struct acc2smem_source_not_needed;

  template <size_t... Seq>
  struct acc2smem_source_not_needed<cutlass::index_sequence<Seq...>> {
    /// Advances the accumulator fragment iterator by Advance positions, then
    /// stores kFragmentsPerIteration fragments into SMEM, stepping the warp
    /// tile iterator's pointer between fragments and restoring it afterwards.
    template <int Advance>
    CUTLASS_DEVICE static void helper(AccumulatorFragmentIterator accum_fragment_iterator,
                                      WarpTileIterator &warp_tile_iterator) {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < Advance; i++) {
        ++accum_fragment_iterator;
      }

      CUTLASS_PRAGMA_UNROLL
      for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {
        typename AccumulatorFragmentIterator::Fragment accum_fragment;

        accum_fragment_iterator.load(accum_fragment);
        ++accum_fragment_iterator;

        warp_tile_iterator.store(accum_fragment);
        if (p < Base::kFragmentsPerIteration - 1) {
          warp_tile_iterator.add_pointer_offset(kSmemPointerOffset);
        }
      }

      // Rewind the pointer so the iterator is unchanged on exit.
      if (Base::kFragmentsPerIteration > 1) {
        warp_tile_iterator.add_pointer_offset(kSmemPointerOffset *
                                              (1 - Base::kFragmentsPerIteration));
      }
    }

    /// Dispatches to helper<Advance> where Advance matches the runtime 'pos'.
    /// The array-initializer pack expansion instantiates one comparison per
    /// sequence element; exactly one (pos == Seq * kFragmentsPerIteration)
    /// evaluates true and invokes its helper via the comma operator.
    CUTLASS_DEVICE
    static void push(size_t pos,
                     AccumulatorFragmentIterator const &iterator_begin,
                     WarpTileIterator &warp_tile_iterator) {
      int dummy[] = {
          (pos == (Seq * Base::kFragmentsPerIteration)) &&
          (helper<Seq * Base::kFragmentsPerIteration>(iterator_begin, warp_tile_iterator), 0)...};

      CUTLASS_UNUSED(dummy[0]);
    }
  };
545
+
546
  /// Streams the result to global memory.
  /// Source-free path: for each output tile iteration, stages accumulator
  /// fragments through shared memory (with __syncthreads barriers around the
  /// store/load), reduces across K partitions if split-K is in use, applies the
  /// output operator with the broadcast fragment, and conditionally stores the
  /// Z and T fragments.
  CUTLASS_DEVICE
  void compute_source_not_needed_(
    OutputOp const &output_op,                    ///< Output operator
    BroadcastFragment const &broadcast_fragment,  ///< Fragment containing the accumulated partial reduction over columns
    OutputTileIterator destination_iterator,      ///< Tile iterator for destination
    AccumulatorTile const &accumulators,          ///< Complete warp-level accumulator tile
    TensorTileIterator tensor_iterator            ///< Threadblock tile iterator for additional tensor operand
  ) {

    //
    // Iterator over warp-level accumulator fragment
    //

    AccumulatorFragmentIterator accum_fragment_iterator(accumulators);

    //
    // Iterate over accumulator tile
    //

    // Unrolling is disabled (factor 1) when the epilogue functor is heavy to
    // limit binary size; otherwise the loop is fully unrolled.
    // CUTLASS_PRAGMA_UNROLL
    #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations / Base::kFragmentsPerIteration : 1)
    for (int iter = 0; iter < OutputTileIterator::kIterations; iter += Base::kFragmentsPerIteration) {

      //
      // Convert and store fragment
      //

      // Barrier before reusing SMEM: prior loads must complete before overwrite.
      __syncthreads();

      acc2smem_source_not_needed<
          cutlass::make_index_sequence<OutputTileIterator::kIterations /
                                       Base::kFragmentsPerIteration>>::push(iter,
                                                                            accum_fragment_iterator,
                                                                            this->warp_tile_iterator_);

      // Barrier after store: SMEM must be fully written before any thread loads.
      __syncthreads();

      //
      // Load fragments from shared memory
      //

      CUTLASS_PRAGMA_UNROLL
      for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {

        typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];

        shared_load_iterator_.load(aligned_accum_fragment[0]);

        if (p < Base::kFragmentsPerIteration - 1) {
          // More fragments remain this iteration: advance to the next SMEM tile.
          shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
        }
        else if (kPartitionsK > 1) {

          // Split-K: accumulate the partial sums from the other K partitions.
          plus <typename SharedLoadIterator::Fragment> add_fragments;

          CUTLASS_PRAGMA_UNROLL
          for ( int i = 1; i < kPartitionsK; ++i) {
            shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
            shared_load_iterator_.load(aligned_accum_fragment[i]);
            aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
          }

          // Rewind the iterator to its position before the partition sweep.
          shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset);
        }

        //
        // Apply output operation
        //

        typename OutputTileIterator::Fragment frag_Z;
        typename TensorTileIterator::Fragment frag_T;

        apply_output_operator_source_not_needed_(
          frag_Z,
          frag_T,
          output_op,
          aligned_accum_fragment[0],
          broadcast_fragment);

        //
        // Conditionally store fragments
        //

        if (OutputOp::kStoreZ) {
          destination_iterator.store(frag_Z);
          ++destination_iterator;
        }

        if (OutputOp::kStoreT) {
          tensor_iterator.store(frag_T);
          ++tensor_iterator;
        }
      }

      // Rewind the SMEM load iterator for the next outer iteration.
      if (Base::kFragmentsPerIteration > 1) {
        shared_load_iterator_.add_pointer_offset(kSmemPointerOffset * (1 - Base::kFragmentsPerIteration));
      }
    }
  }
648
+
649
+
650
+ template<class Seq>
651
+ struct acc2smem_source_needed;
652
+
653
+ template <size_t... Seq>
654
+ struct acc2smem_source_needed<cutlass::index_sequence<Seq...>> {
655
+ template<int Advance>
656
+ CUTLASS_DEVICE
657
+ static void helper(AccumulatorFragmentIterator accum_fragment_iterator,
658
+ WarpTileIterator &warp_tile_iterator) {
659
+ CUTLASS_PRAGMA_UNROLL
660
+ for (int i = 0; i < Advance; i++) {
661
+ ++accum_fragment_iterator;
662
+ }
663
+
664
+ typename AccumulatorFragmentIterator::Fragment accum_fragment;
665
+ accum_fragment_iterator.load(accum_fragment);
666
+ warp_tile_iterator.store(accum_fragment);
667
+ }
668
+
669
+ CUTLASS_DEVICE
670
+ static void push(size_t pos,
671
+ AccumulatorFragmentIterator const &iterator_begin,
672
+ WarpTileIterator &warp_tile_iterator) {
673
+ int dummy[] = {(pos == Seq) && (helper<Seq>(iterator_begin, warp_tile_iterator), 0)...};
674
+ }
675
+ };
676
+
677
+
678
  /// Streams the result to global memory.
  /// Source-consuming path: per output tile iteration, loads both source (C)
  /// fragments, stages the accumulator fragment through shared memory (with
  /// __syncthreads barriers around the store/load), reduces across K
  /// partitions if split-K is in use, applies the output operator with the
  /// broadcast fragment, and conditionally stores the Z and T fragments.
  CUTLASS_DEVICE
  void compute_source_needed_(
    OutputOp const &output_op,                    ///< Output operator
    BroadcastFragment const &broadcast_fragment,  ///< Fragment containing the accumulated partial reduction over columns
    OutputTileIterator destination_iterator,      ///< Tile iterator for destination
    AccumulatorTile const &accumulators,          ///< Complete warp-level accumulator tile
    OutputTileIterator source_iterator1,          ///< Tile iterator for first source accumulator matrix
    OutputTileIterator source_iterator2,          ///< Tile iterator for second source accumulator matrix
    TensorTileIterator tensor_iterator            ///< Threadblock tile iterator for additional tensor operand
  ) {

    // Cleared up front so masked-off (out-of-bounds) accesses contribute zeros.
    typename OutputTileIterator::Fragment source_fragment1;
    source_fragment1.clear();
    typename OutputTileIterator::Fragment source_fragment2;
    source_fragment2.clear();

    //
    // Iterator over warp-level accumulator fragment
    //

    AccumulatorFragmentIterator accum_fragment_iterator(accumulators);

    //
    // Iterate over accumulator tile
    //

    // Unrolling is disabled (factor 1) when the epilogue functor is heavy.
    #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1)
    for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {

      //
      // Load the source
      //

      source_iterator1.load(source_fragment1);
      ++source_iterator1;

      source_iterator2.load(source_fragment2);
      ++source_iterator2;

      //
      // Convert and store fragment
      //

      // Barrier before reusing SMEM: prior loads must complete before overwrite.
      __syncthreads();

      acc2smem_source_needed<cutlass::make_index_sequence<OutputTileIterator::kIterations>>::push(
          iter, accum_fragment_iterator, this->warp_tile_iterator_);

      // Barrier after store: SMEM must be fully written before any thread loads.
      __syncthreads();

      //
      // Load fragments from shared memory
      //

      typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];

      shared_load_iterator_.load(aligned_accum_fragment[0]);

      // If the number of k-slices is > 1 - perform a reduction amongst the k-slices
      if (kPartitionsK > 1)
      {
        plus <typename SharedLoadIterator::Fragment> add_fragments;
        const int tile_row_offset = Base::SharedStorage::StorageShape::kRow / PartitionsK;

        CUTLASS_PRAGMA_UNROLL
        for ( int i = 1; i < kPartitionsK; ++i) {
          shared_load_iterator_.add_tile_offset({tile_row_offset , 0});
          shared_load_iterator_.load(aligned_accum_fragment[i]);
          aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
        }

        // Rewind the iterator to its position before the partition sweep.
        shared_load_iterator_.add_tile_offset({-1 * (kPartitionsK-1) * tile_row_offset, 0});
      }

      //
      // Apply output operation
      //

      typename OutputTileIterator::Fragment frag_Z;
      typename TensorTileIterator::Fragment frag_T;

      apply_output_operator_(
        frag_Z,
        frag_T,
        output_op,
        aligned_accum_fragment[0],
        source_fragment1,
        source_fragment2,
        broadcast_fragment);

      //
      // Conditionally store fragments
      //

      if (OutputOp::kStoreZ) {
        destination_iterator.store(frag_Z);
        ++destination_iterator;
      }

      if (OutputOp::kStoreT) {
        tensor_iterator.store(frag_T);
        ++tensor_iterator;
      }
    }
  }
784
+
785
  /// Helper to invoke the output functor over each vector of output.
  /// Source-consuming variant: the functor sees the accumulator fragment, both
  /// source (C1/C2) fragments, and the per-column broadcast fragment, writing
  /// its results into frag_Z and frag_T.
  CUTLASS_DEVICE
  void apply_output_operator_(
    typename OutputTileIterator::Fragment &frag_Z,
    typename TensorTileIterator::Fragment &frag_T,
    OutputOp const &output_op,
    typename SharedLoadIterator::Fragment const &frag_AB,
    typename OutputTileIterator::Fragment const &frag_C1,
    typename OutputTileIterator::Fragment const &frag_C2,
    BroadcastFragment const &frag_Broadcast) {

    // Per-access vector views of each fragment's element type.
    using AccessTypeZ = Array<typename OutputTileIterator::Element, kElementsPerAccess>;
    using AccessTypeT = Array<typename TensorTileIterator::Element, kElementsPerAccess>;
    using AccessTypeBroadcast = Array<ElementCompute, kElementsPerAccess>;

    // Reinterpret the flat fragments as arrays of per-access vectors so the
    // output functor can be applied one vector at a time.
    AccessTypeZ *frag_Z_ptr = reinterpret_cast<AccessTypeZ *>(&frag_Z);
    AccessTypeT *frag_T_ptr = reinterpret_cast<AccessTypeT *>(&frag_T);

    AccumulatorAccessType const *frag_AB_ptr =
      reinterpret_cast<AccumulatorAccessType const *>(&frag_AB);

    OutputAccessType const *frag_C1_ptr =
      reinterpret_cast<OutputAccessType const *>(&frag_C1);

    OutputAccessType const *frag_C2_ptr =
      reinterpret_cast<OutputAccessType const *>(&frag_C2);

    AccessTypeBroadcast const *frag_Broadcast_ptr =
      reinterpret_cast<AccessTypeBroadcast const *>(&frag_Broadcast);

    // One functor invocation per output access vector.
    int const kOutputOpIterations =
      OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kOutputOpIterations; ++i) {
      // The broadcast fragment spans only one row of column iterations, so it
      // is reused (mod kColumn) across the row iterations of the fragment.
      output_op(
        frag_Z_ptr[i],
        frag_T_ptr[i],
        frag_AB_ptr[i],
        frag_C1_ptr[i],
        frag_C2_ptr[i],
        frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]);
    }
  }
829
+
830
+ /// Helper to invoke the output functor over each vector of output
831
+ CUTLASS_DEVICE
832
+ void apply_output_operator_source_not_needed_(
833
+ typename OutputTileIterator::Fragment &frag_Z,
834
+ typename TensorTileIterator::Fragment &frag_T,
835
+ OutputOp const &output_op,
836
+ typename SharedLoadIterator::Fragment const &frag_AB,
837
+ BroadcastFragment const &frag_Broadcast) {
838
+
839
+ using AccessTypeZ = Array<typename OutputTileIterator::Element, kElementsPerAccess>;
840
+ using AccessTypeT = Array<typename TensorTileIterator::Element, kElementsPerAccess>;
841
+ using AccessTypeBroadcast = Array<ElementCompute, kElementsPerAccess>;
842
+
843
+ AccessTypeZ *frag_Z_ptr = reinterpret_cast<AccessTypeZ *>(&frag_Z);
844
+ AccessTypeT *frag_T_ptr = reinterpret_cast<AccessTypeT *>(&frag_T);
845
+
846
+ AccumulatorAccessType const *frag_AB_ptr =
847
+ reinterpret_cast<AccumulatorAccessType const *>(&frag_AB);
848
+
849
+ AccessTypeBroadcast const *frag_Broadcast_ptr =
850
+ reinterpret_cast<AccessTypeBroadcast const *>(&frag_Broadcast);
851
+
852
+ int const kOutputOpIterations =
853
+ OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
854
+
855
+ CUTLASS_PRAGMA_UNROLL
856
+ for (int i = 0; i < kOutputOpIterations; ++i) {
857
+
858
+ output_op(
859
+ frag_Z_ptr[i],
860
+ frag_T_ptr[i],
861
+ frag_AB_ptr[i],
862
+ frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]);
863
+ }
864
+ }
865
+
866
+ public:
867
+ /// Stream-K reduce helper
868
+ CUTLASS_DEVICE
869
+ void reduce(
870
+ int reduce_fragment_idx, ///< Reduce fragment index
871
+ OutputOp const &output_op, ///< Output operator
872
+ ElementVector const * broadcast_ptr, ///< Broadcast vector
873
+ OutputTileIterator destination_iterator, ///< Tile iterator for destination
874
+ OutputTileIterator source_iterator1, ///< Tile iterator for first source accumulator matrix
875
+ OutputTileIterator source_iterator2, ///< Tile iterator for second source accumulator matrix
876
+ TensorTileIterator tensor_iterator, ///< Threadblock tile iterator for additional tensor operand
877
+ MatrixCoord const &problem_size = ///< Problem size needed to guard against out-of-bounds accesses
878
+ MatrixCoord(Shape::kM, Shape::kN),
879
+ MatrixCoord const &threadblock_offset = ///< Threadblock's initial offset within the problem size space
880
+ MatrixCoord())
881
+ {
882
+
883
+ BroadcastFragment broadcast_fragment;
884
+ load_broadcast_fragment_(broadcast_fragment, broadcast_ptr, problem_size, threadblock_offset);
885
+
886
+ // Initialize/load source-fragment data
887
+ typename OutputTileIterator::Fragment source_fragment1;
888
+ source_fragment1.clear();
889
+ typename OutputTileIterator::Fragment source_fragment2;
890
+ source_fragment2.clear();
891
+
892
+ if (output_op.is_source_needed())
893
+ {
894
+ source_iterator1 += reduce_fragment_idx;
895
+ source_iterator1.load(source_fragment1);
896
+
897
+ source_iterator2 += reduce_fragment_idx;
898
+ source_iterator2.load(source_fragment2);
899
+ }
900
+
901
+ // Load fragment from shared memory
902
+ typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];
903
+ shared_load_iterator_.load(aligned_accum_fragment[0]);
904
+
905
+ // Add fragments shared by other k partitions
906
+ if (kPartitionsK > 1)
907
+ {
908
+ plus <typename SharedLoadIterator::Fragment> add_fragments;
909
+
910
+ CUTLASS_PRAGMA_UNROLL
911
+ for ( int i = 1; i < kPartitionsK; ++i) {
912
+ shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
913
+ shared_load_iterator_.load(aligned_accum_fragment[i]);
914
+ aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
915
+ }
916
+ }
917
+
918
+ //
919
+ // Apply output operation
920
+ //
921
+
922
+ typename OutputTileIterator::Fragment frag_Z;
923
+ typename TensorTileIterator::Fragment frag_T;
924
+
925
+ if (!output_op.is_source_needed()) {
926
+ apply_output_operator_source_not_needed_(
927
+ frag_Z,
928
+ frag_T,
929
+ output_op,
930
+ aligned_accum_fragment[0],
931
+ broadcast_fragment);
932
+ } else {
933
+ apply_output_operator_(
934
+ frag_Z,
935
+ frag_T,
936
+ output_op,
937
+ aligned_accum_fragment[0],
938
+ source_fragment1,
939
+ source_fragment2,
940
+ broadcast_fragment);
941
+ }
942
+
943
+ //
944
+ // Conditionally store fragments
945
+ //
946
+
947
+ if (OutputOp::kStoreZ) {
948
+ destination_iterator += reduce_fragment_idx;
949
+ destination_iterator.store(frag_Z);
950
+ }
951
+
952
+ if (OutputOp::kStoreT) {
953
+ tensor_iterator += reduce_fragment_idx;
954
+ tensor_iterator.store(frag_T);
955
+ }
956
+ }
957
+ };
958
+
959
+
960
+ template <
961
+ typename Shape_,
962
+ typename WarpMmaOperator_,
963
+ int PartitionsK,
964
+ typename OutputTileIterator_,
965
+ typename TensorTileIterator_,
966
+ typename ElementVector_,
967
+ typename AccumulatorFragmentIterator_,
968
+ typename WarpTileIterator_,
969
+ typename SharedLoadIterator_,
970
+ typename OutputOp_,
971
+ typename Padding_,
972
+ int FragmentsPerPartition,
973
+ int IterationsUnroll
974
+ >
975
+ class EpilogueWithBroadcast<
976
+ Shape_,
977
+ WarpMmaOperator_,
978
+ PartitionsK,
979
+ OutputTileIterator_,
980
+ TensorTileIterator_,
981
+ ElementVector_,
982
+ AccumulatorFragmentIterator_,
983
+ WarpTileIterator_,
984
+ SharedLoadIterator_,
985
+ OutputOp_,
986
+ Padding_,
987
+ FragmentsPerPartition,
988
+ IterationsUnroll,
989
+ true
990
+ > :
991
+ public EpilogueBase<
992
+ Shape_,
993
+ typename WarpMmaOperator_::Shape,
994
+ PartitionsK,
995
+ AccumulatorFragmentIterator_,
996
+ WarpTileIterator_,
997
+ Padding_,
998
+ FragmentsPerPartition> {
999
+
1000
+ public:
1001
+
1002
+ using Base = EpilogueBase<
1003
+ Shape_,
1004
+ typename WarpMmaOperator_::Shape,
1005
+ PartitionsK,
1006
+ AccumulatorFragmentIterator_,
1007
+ WarpTileIterator_,
1008
+ Padding_,
1009
+ FragmentsPerPartition>;
1010
+
1011
+ static bool const kIsSingleSource = true;
1012
+ using Shape = Shape_;
1013
+ using WarpMmaOperator = WarpMmaOperator_;
1014
+ static int const kPartitionsK = PartitionsK;
1015
+ using OutputTileIterator = OutputTileIterator_;
1016
+ using TensorTileIterator = TensorTileIterator_;
1017
+ using ElementVector = ElementVector_;
1018
+ using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
1019
+ using WarpTileIterator = WarpTileIterator_;
1020
+ using SharedLoadIterator = SharedLoadIterator_;
1021
+ using OutputOp = OutputOp_;
1022
+ using Padding = Padding_;
1023
+
1024
+ using Layout = layout::RowMajor;
1025
+ using LongIndex = typename Layout::LongIndex;
1026
+
1027
+ /// The complete warp-level accumulator tile
1028
+ using AccumulatorTile = typename Base::AccumulatorTile;
1029
+
1030
+ /// Accumulator element
1031
+ using ElementAccumulator = typename WarpTileIterator::Element;
1032
+
1033
+ /// Compute data type produced by the output op
1034
+ using ElementCompute = typename OutputOp::ElementCompute;
1035
+
1036
+ /// Compute fragment
1037
+ using FragmentCompute = Array<ElementCompute, OutputTileIterator::Fragment::kElements>;
1038
+
1039
+ /// Thread map used by output tile iterators
1040
+ using ThreadMap = typename OutputTileIterator::ThreadMap;
1041
+
1042
+ /// Fragment object used to store the broadcast values
1043
+ using BroadcastFragment = Array<
1044
+ ElementCompute,
1045
+ ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess>;
1046
+
1047
+ /// Output element
1048
+ using ElementOutput = typename OutputTileIterator::Element;
1049
+
1050
+ /// Data type of additional tensor
1051
+ using ElementTensor = typename TensorTileIterator::Element;
1052
+
1053
+ /// Output access size
1054
+ static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
1055
+
1056
+ /// Tensor reference to destination tensor
1057
+ using TensorRef = typename OutputTileIterator::TensorRef;
1058
+
1059
+ /// Tensor reference to sync tensor
1060
+ using SyncTensorRef = typename cutlass::TensorRef<int, cutlass::layout::PackedVectorLayout>;
1061
+
1062
+ /// Const tensor reference to source tensor
1063
+ using ConstTensorRef = typename OutputTileIterator::ConstTensorRef;
1064
+
1065
+ /// Array type used to output
1066
+ using OutputAccessType = Array<
1067
+ typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
1068
+
1069
+ /// Array type used by output functor
1070
+ using AccumulatorAccessType = Array<typename WarpTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
1071
+
1072
+ /// Array type used by output functor
1073
+ using ComputeAccessType = Array<ElementCompute, OutputTileIterator::kElementsPerAccess>;
1074
+
1075
+ /// Tensor access type
1076
+ using TensorAccessType = Array<ElementTensor, OutputTileIterator::kElementsPerAccess>;
1077
+
1078
+ /// Number of warps
1079
+ using WarpCount = typename Base::WarpCount;
1080
+
1081
+ /// Shared memory allocation from epilogue base class
1082
+ using BaseSharedStorage = typename Base::SharedStorage;
1083
+
1084
+ static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 ? Base::kFragmentsPerIteration : kPartitionsK;
1085
+ static int constexpr kSmemPointerOffset = Base::SharedStorage::StorageShape::kCount / kSmemTiles;
1086
+
1087
+ /// Used for the broadcast
1088
+ struct BroadcastDetail {
1089
+
1090
+ /// Number of threads per warp
1091
+ static int const kWarpSize = 32;
1092
+
1093
+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
1094
+
1095
+ /// Number of distinct scalar column indices handled by each thread
1096
+ static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess;
1097
+
1098
+ /// Number of distinct scalar row indices handled by each thread
1099
+ static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn;
1100
+
1101
+ /// Number of threads per threadblock
1102
+ static int const kThreadCount = kWarpSize * WarpCount::kCount;
1103
+
1104
+ /// Number of distinct threads per row of output tile
1105
+ static int const kThreadsPerRow = (Shape::kN / kColumnsPerThread);
1106
+
1107
+ /// Number of distinct threads which must be reduced during the final reduction phase within the threadblock.
1108
+ static int const kThreadRows = kThreadCount / kThreadsPerRow;
1109
+
1110
+ /// I'm not sure what I meant here.
1111
+ static int const kThreadAccessesPerRow = const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount);
1112
+
1113
+ /// Shape of the shared memory allocation for the epilogue
1114
+ using StorageShape = MatrixShape<
1115
+ kThreadRows,
1116
+ Shape::kN
1117
+ >;
1118
+
1119
+ /// Debug printing
1120
+ CUTLASS_DEVICE
1121
+ static void print() {
1122
+ #if 0
1123
+ printf("BroadcastDetail {\n");
1124
+ printf(
1125
+ " kColumnsPerThread: %d\nkRowsPerThread: %d\n,kThreadCount: %d\nkThreadsPerRow: %d\n"
1126
+ "kThreadRows: %d\nThreadAccessesPerRow: %d\nStorageShape: %d x %d (count: %d)\n",
1127
+ kColumnsPerThread,
1128
+ kRowsPerThread,
1129
+ kThreadCount,
1130
+ kThreadsPerRow,
1131
+ kThreadRows,
1132
+ kThreadAccessesPerRow,
1133
+ StorageShape::kRow,
1134
+ StorageShape::kColumn,
1135
+ StorageShape::kCount
1136
+ );
1137
+ printf("};\n");
1138
+ #endif
1139
+ }
1140
+ };
1141
+
1142
+ /// Shared storage structure (shadows base) with additional SMEM buffer for reduction
1143
+ struct SharedStorage {
1144
+ union {
1145
+ BaseSharedStorage base;
1146
+ };
1147
+
1148
+ CUTLASS_HOST_DEVICE
1149
+ SharedStorage() { }
1150
+ };
1151
+
1152
+ public:
1153
+
1154
+
1155
+ static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements,
1156
+ "Mismatch between shared load iterator and output tile iterator.");
1157
+
1158
+ static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero.");
1159
+
1160
+ static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess),
1161
+ "Divisibility");
1162
+
1163
+ private:
1164
+
1165
+ /// Loads fragment from shared memory aligned with output tensor
1166
+ SharedLoadIterator shared_load_iterator_;
1167
+
1168
+ /// Thread index within the threadblock
1169
+ int thread_idx_;
1170
+
1171
+ public:
1172
+
1173
+ /// Constructor
1174
+ CUTLASS_DEVICE
1175
+ EpilogueWithBroadcast(
1176
+ SharedStorage &shared_storage, ///< Shared storage object
1177
+ int thread_idx, ///< ID of a thread within the threadblock
1178
+ int warp_idx, ///< ID of warp within threadblock
1179
+ int lane_idx ///< Id of thread within warp
1180
+ ):
1181
+ Base(shared_storage.base, thread_idx, warp_idx, lane_idx),
1182
+ shared_load_iterator_(shared_storage.base.reference(), thread_idx),
1183
+ thread_idx_(thread_idx)
1184
+ {
1185
+
1186
+ }
1187
+
1188
+ /// Streams the result to global memory
1189
+ CUTLASS_DEVICE
1190
+ void operator()(
1191
+ OutputOp const &output_op, ///< Output operator
1192
+ ElementVector const * broadcast_ptr, ///< Broadcast vector
1193
+ OutputTileIterator destination_iterator, ///< Tile iterator for destination
1194
+ AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
1195
+ OutputTileIterator source_iterator, ///< Tile iterator for source accumulator matrix
1196
+ TensorTileIterator tensor_iterator, ///< Threadblock tile iterator for additional tensor operand
1197
+ MatrixCoord const &problem_size = ///< Problem size needed to guard against out-of-bounds accesses
1198
+ MatrixCoord(Shape::kM, Shape::kN),
1199
+ MatrixCoord const &threadblock_offset = ///< Threadblock's initial offset within the problem size space
1200
+ MatrixCoord()) {
1201
+
1202
+ BroadcastFragment broadcast_fragment;
1203
+
1204
+ load_broadcast_fragment_(broadcast_fragment, broadcast_ptr, problem_size, threadblock_offset);
1205
+
1206
+ if (!output_op.is_source_needed()) {
1207
+ compute_source_not_needed_(
1208
+ output_op,
1209
+ broadcast_fragment,
1210
+ destination_iterator,
1211
+ accumulators,
1212
+ tensor_iterator);
1213
+ }
1214
+ else {
1215
+ compute_source_needed_(
1216
+ output_op,
1217
+ broadcast_fragment,
1218
+ destination_iterator,
1219
+ accumulators,
1220
+ source_iterator,
1221
+ tensor_iterator);
1222
+ }
1223
+ }
1224
+
1225
+ private:
1226
+
1227
+ CUTLASS_DEVICE
1228
+ void load_broadcast_fragment_(
1229
+ BroadcastFragment & broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns
1230
+ ElementVector const * broadcast_ptr, ///< Broadcast vector
1231
+ MatrixCoord const &problem_size, ///< Problem size needed to guard against out-of-bounds accesses
1232
+ MatrixCoord const &threadblock_offset ///< Threadblock's initial offset within the problem size space
1233
+ ) {
1234
+
1235
+ broadcast_fragment.clear();
1236
+
1237
+ // If no pointer is supplied, set with all zeros and avoid memory accesses
1238
+ if (!broadcast_ptr) {
1239
+ return;
1240
+ }
1241
+
1242
+ int thread_initial_column = ThreadMap::initial_offset(thread_idx_).column();
1243
+
1244
+ int thread_column_idx = threadblock_offset.column() + thread_initial_column;
1245
+ broadcast_ptr += thread_initial_column;
1246
+
1247
+ NumericArrayConverter<ElementCompute, ElementVector, BroadcastDetail::kElementsPerAccess> converter;
1248
+ using AccessType = AlignedArray<ElementVector, BroadcastDetail::kElementsPerAccess>;
1249
+ using ComputeFragmentType = Array<ElementCompute, BroadcastDetail::kElementsPerAccess>;
1250
+
1251
+ ComputeFragmentType *frag_ptr = reinterpret_cast<ComputeFragmentType *>(&broadcast_fragment);
1252
+
1253
+ CUTLASS_PRAGMA_UNROLL
1254
+ for (int j = 0; j < ThreadMap::Iterations::kColumn; ++j) {
1255
+
1256
+ AccessType loaded;
1257
+
1258
+ loaded.clear();
1259
+
1260
+ if (thread_column_idx < problem_size.column()) {
1261
+ loaded = *reinterpret_cast<AccessType const *>(broadcast_ptr);
1262
+ }
1263
+
1264
+ ComputeFragmentType cvt = converter(loaded);
1265
+ frag_ptr[j] = cvt;
1266
+
1267
+ thread_column_idx += ThreadMap::Delta::kColumn;
1268
+ broadcast_ptr += ThreadMap::Delta::kColumn;
1269
+ }
1270
+ }
1271
+
1272
+ template <class Seq>
1273
+ struct acc2smem_source_not_needed;
1274
+
1275
+ template <size_t... Seq>
1276
+ struct acc2smem_source_not_needed<cutlass::index_sequence<Seq...>> {
1277
+ template <int Advance>
1278
+ CUTLASS_DEVICE static void helper(AccumulatorFragmentIterator accum_fragment_iterator,
1279
+ WarpTileIterator &warp_tile_iterator) {
1280
+ CUTLASS_PRAGMA_UNROLL
1281
+ for (int i = 0; i < Advance; i++) {
1282
+ ++accum_fragment_iterator;
1283
+ }
1284
+
1285
+ CUTLASS_PRAGMA_UNROLL
1286
+ for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {
1287
+ typename AccumulatorFragmentIterator::Fragment accum_fragment;
1288
+
1289
+ accum_fragment_iterator.load(accum_fragment);
1290
+ ++accum_fragment_iterator;
1291
+
1292
+ warp_tile_iterator.store(accum_fragment);
1293
+ if (p < Base::kFragmentsPerIteration - 1) {
1294
+ warp_tile_iterator.add_pointer_offset(kSmemPointerOffset);
1295
+ }
1296
+ }
1297
+
1298
+ if (Base::kFragmentsPerIteration > 1) {
1299
+ warp_tile_iterator.add_pointer_offset(kSmemPointerOffset *
1300
+ (1 - Base::kFragmentsPerIteration));
1301
+ }
1302
+ }
1303
+
1304
+ CUTLASS_DEVICE
1305
+ static void push(size_t pos,
1306
+ AccumulatorFragmentIterator const &iterator_begin,
1307
+ WarpTileIterator &warp_tile_iterator) {
1308
+ int dummy[] = {
1309
+ (pos == (Seq * Base::kFragmentsPerIteration)) &&
1310
+ (helper<Seq * Base::kFragmentsPerIteration>(iterator_begin, warp_tile_iterator), 0)...};
1311
+
1312
+ CUTLASS_UNUSED(dummy[0]);
1313
+ }
1314
+ };
1315
+
1316
+ /// Streams the result to global memory
1317
+ CUTLASS_DEVICE
1318
+ void compute_source_not_needed_(
1319
+ OutputOp const &output_op, ///< Output operator
1320
+ BroadcastFragment const &broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns
1321
+ OutputTileIterator destination_iterator, ///< Tile iterator for destination
1322
+ AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
1323
+ TensorTileIterator tensor_iterator ///< Threadblock tile iterator for additioanl tensor operand
1324
+ ) {
1325
+
1326
+ //
1327
+ // Iterator over warp-level accumulator fragment
1328
+ //
1329
+
1330
+ AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
1331
+
1332
+ //
1333
+ // Iterate over accumulator tile
1334
+ //
1335
+
1336
+ // CUTLASS_PRAGMA_UNROLL
1337
+ #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations / Base::kFragmentsPerIteration : 1)
1338
+ for (int iter = 0; iter < OutputTileIterator::kIterations; iter += Base::kFragmentsPerIteration) {
1339
+
1340
+ //
1341
+ // Convert and store fragment
1342
+ //
1343
+
1344
+
1345
+ __syncthreads();
1346
+
1347
+ acc2smem_source_not_needed<
1348
+ cutlass::make_index_sequence<OutputTileIterator::kIterations /
1349
+ Base::kFragmentsPerIteration>>::push(iter,
1350
+ accum_fragment_iterator,
1351
+ this->warp_tile_iterator_);
1352
+
1353
+ __syncthreads();
1354
+
1355
+ //
1356
+ // Load fragments from shared memory
1357
+ //
1358
+
1359
+ CUTLASS_PRAGMA_UNROLL
1360
+ for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {
1361
+
1362
+
1363
+ typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];
1364
+
1365
+ shared_load_iterator_.load(aligned_accum_fragment[0]);
1366
+
1367
+ if (p < Base::kFragmentsPerIteration - 1) {
1368
+ shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
1369
+ }
1370
+ else if (kPartitionsK > 1) {
1371
+
1372
+ plus <typename SharedLoadIterator::Fragment> add_fragments;
1373
+
1374
+ CUTLASS_PRAGMA_UNROLL
1375
+ for ( int i = 1; i < kPartitionsK; ++i) {
1376
+ shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
1377
+ shared_load_iterator_.load(aligned_accum_fragment[i]);
1378
+ aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
1379
+ }
1380
+
1381
+ shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset);
1382
+ }
1383
+
1384
+ //
1385
+ // Apply output operation
1386
+ //
1387
+
1388
+ typename OutputTileIterator::Fragment frag_Z;
1389
+ typename TensorTileIterator::Fragment frag_T;
1390
+
1391
+ apply_output_operator_source_not_needed_(
1392
+ frag_Z,
1393
+ frag_T,
1394
+ output_op,
1395
+ aligned_accum_fragment[0],
1396
+ broadcast_fragment);
1397
+
1398
+ //
1399
+ // Conditionally store fragments
1400
+ //
1401
+
1402
+ if (OutputOp::kStoreZ) {
1403
+ destination_iterator.store(frag_Z);
1404
+ ++destination_iterator;
1405
+ }
1406
+
1407
+ if (OutputOp::kStoreT) {
1408
+ tensor_iterator.store(frag_T);
1409
+ ++tensor_iterator;
1410
+ }
1411
+ }
1412
+
1413
+ if (Base::kFragmentsPerIteration > 1) {
1414
+ shared_load_iterator_.add_pointer_offset(kSmemPointerOffset * (1 - Base::kFragmentsPerIteration));
1415
+ }
1416
+ }
1417
+ }
1418
+
1419
+
1420
+ template<class Seq>
1421
+ struct acc2smem_source_needed;
1422
+
1423
+ template <size_t... Seq>
1424
+ struct acc2smem_source_needed<cutlass::index_sequence<Seq...>> {
1425
+ template<int Advance>
1426
+ CUTLASS_DEVICE
1427
+ static void helper(AccumulatorFragmentIterator accum_fragment_iterator,
1428
+ WarpTileIterator &warp_tile_iterator) {
1429
+ CUTLASS_PRAGMA_UNROLL
1430
+ for (int i = 0; i < Advance; i++) {
1431
+ ++accum_fragment_iterator;
1432
+ }
1433
+
1434
+ typename AccumulatorFragmentIterator::Fragment accum_fragment;
1435
+ accum_fragment_iterator.load(accum_fragment);
1436
+ warp_tile_iterator.store(accum_fragment);
1437
+ }
1438
+
1439
+ CUTLASS_DEVICE
1440
+ static void push(size_t pos,
1441
+ AccumulatorFragmentIterator const &iterator_begin,
1442
+ WarpTileIterator &warp_tile_iterator) {
1443
+ int dummy[] = {(pos == Seq) && (helper<Seq>(iterator_begin, warp_tile_iterator), 0)...};
1444
+ }
1445
+ };
1446
+
1447
+
1448
+ /// Streams the result to global memory
1449
+ CUTLASS_DEVICE
1450
+ void compute_source_needed_(
1451
+ OutputOp const &output_op, ///< Output operator
1452
+ BroadcastFragment const &broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns
1453
+ OutputTileIterator destination_iterator, ///< Tile iterator for destination
1454
+ AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
1455
+ OutputTileIterator source_iterator, ///< Tile iterator for source accumulator matrix
1456
+ TensorTileIterator tensor_iterator ///< Threadblock tile iterator for additioanl tensor operand
1457
+ ) {
1458
+
1459
+ typename OutputTileIterator::Fragment source_fragment;
1460
+ source_fragment.clear();
1461
+
1462
+ //
1463
+ // Iterator over warp-level accumulator fragment
1464
+ //
1465
+
1466
+ AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
1467
+
1468
+ //
1469
+ // Iterate over accumulator tile
1470
+ //
1471
+
1472
+ #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1)
1473
+ for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {
1474
+
1475
+ //
1476
+ // Load the source
1477
+ //
1478
+
1479
+ source_iterator.load(source_fragment);
1480
+ ++source_iterator;
1481
+
1482
+ //
1483
+ // Convert and store fragment
1484
+ //
1485
+
1486
+ __syncthreads();
1487
+
1488
+ acc2smem_source_needed<cutlass::make_index_sequence<OutputTileIterator::kIterations>>::push(
1489
+ iter, accum_fragment_iterator, this->warp_tile_iterator_);
1490
+
1491
+ __syncthreads();
1492
+
1493
+ //
1494
+ // Load fragments from shared memory
1495
+ //
1496
+
1497
+ typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];
1498
+
1499
+ shared_load_iterator_.load(aligned_accum_fragment[0]);
1500
+
1501
+ // If the number of k-slices is > 1 - perform a reduction amongst the k-slices
1502
+ if (kPartitionsK > 1)
1503
+ {
1504
+ plus <typename SharedLoadIterator::Fragment> add_fragments;
1505
+ const int tile_row_offset = Base::SharedStorage::StorageShape::kRow / PartitionsK;
1506
+
1507
+ CUTLASS_PRAGMA_UNROLL
1508
+ for ( int i = 1; i < kPartitionsK; ++i) {
1509
+ shared_load_iterator_.add_tile_offset({tile_row_offset , 0});
1510
+ shared_load_iterator_.load(aligned_accum_fragment[i]);
1511
+ aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
1512
+ }
1513
+
1514
+ shared_load_iterator_.add_tile_offset({-1 * (kPartitionsK-1) * tile_row_offset, 0});
1515
+ }
1516
+
1517
+ //
1518
+ // Apply output operation
1519
+ //
1520
+
1521
+ typename OutputTileIterator::Fragment frag_Z;
1522
+ typename TensorTileIterator::Fragment frag_T;
1523
+
1524
+ apply_output_operator_(
1525
+ frag_Z,
1526
+ frag_T,
1527
+ output_op,
1528
+ aligned_accum_fragment[0],
1529
+ source_fragment,
1530
+ broadcast_fragment);
1531
+
1532
+ //
1533
+ // Conditionally store fragments
1534
+ //
1535
+
1536
+ if (OutputOp::kStoreZ) {
1537
+ destination_iterator.store(frag_Z);
1538
+ ++destination_iterator;
1539
+ }
1540
+
1541
+ if (OutputOp::kStoreT) {
1542
+ tensor_iterator.store(frag_T);
1543
+ ++tensor_iterator;
1544
+ }
1545
+ }
1546
+ }
1547
+
1548
+ /// Helper to invoke the output functor over each vector of output
1549
+ CUTLASS_DEVICE
1550
+ void apply_output_operator_(
1551
+ typename OutputTileIterator::Fragment &frag_Z,
1552
+ typename TensorTileIterator::Fragment &frag_T,
1553
+ OutputOp const &output_op,
1554
+ typename SharedLoadIterator::Fragment const &frag_AB,
1555
+ typename OutputTileIterator::Fragment const &frag_C,
1556
+ BroadcastFragment const &frag_Broadcast) {
1557
+
1558
+ using AccessTypeZ = Array<typename OutputTileIterator::Element, kElementsPerAccess>;
1559
+ using AccessTypeT = Array<typename TensorTileIterator::Element, kElementsPerAccess>;
1560
+ using AccessTypeBroadcast = Array<ElementCompute, kElementsPerAccess>;
1561
+
1562
+ AccessTypeZ *frag_Z_ptr = reinterpret_cast<AccessTypeZ *>(&frag_Z);
1563
+ AccessTypeT *frag_T_ptr = reinterpret_cast<AccessTypeT *>(&frag_T);
1564
+
1565
+ AccumulatorAccessType const *frag_AB_ptr =
1566
+ reinterpret_cast<AccumulatorAccessType const *>(&frag_AB);
1567
+
1568
+ OutputAccessType const *frag_C_ptr =
1569
+ reinterpret_cast<OutputAccessType const *>(&frag_C);
1570
+
1571
+ AccessTypeBroadcast const *frag_Broadcast_ptr =
1572
+ reinterpret_cast<AccessTypeBroadcast const *>(&frag_Broadcast);
1573
+
1574
+ int const kOutputOpIterations =
1575
+ OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
1576
+
1577
+ CUTLASS_PRAGMA_UNROLL
1578
+ for (int i = 0; i < kOutputOpIterations; ++i) {
1579
+ output_op(
1580
+ frag_Z_ptr[i],
1581
+ frag_T_ptr[i],
1582
+ frag_AB_ptr[i],
1583
+ frag_C_ptr[i],
1584
+ frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]);
1585
+ }
1586
+ }
1587
+
1588
+ /// Helper to invoke the output functor over each vector of output
1589
+ CUTLASS_DEVICE
1590
+ void apply_output_operator_source_not_needed_(
1591
+ typename OutputTileIterator::Fragment &frag_Z,
1592
+ typename TensorTileIterator::Fragment &frag_T,
1593
+ OutputOp const &output_op,
1594
+ typename SharedLoadIterator::Fragment const &frag_AB,
1595
+ BroadcastFragment const &frag_Broadcast) {
1596
+
1597
+ using AccessTypeZ = Array<typename OutputTileIterator::Element, kElementsPerAccess>;
1598
+ using AccessTypeT = Array<typename TensorTileIterator::Element, kElementsPerAccess>;
1599
+ using AccessTypeBroadcast = Array<ElementCompute, kElementsPerAccess>;
1600
+
1601
+ AccessTypeZ *frag_Z_ptr = reinterpret_cast<AccessTypeZ *>(&frag_Z);
1602
+ AccessTypeT *frag_T_ptr = reinterpret_cast<AccessTypeT *>(&frag_T);
1603
+
1604
+ AccumulatorAccessType const *frag_AB_ptr =
1605
+ reinterpret_cast<AccumulatorAccessType const *>(&frag_AB);
1606
+
1607
+ AccessTypeBroadcast const *frag_Broadcast_ptr =
1608
+ reinterpret_cast<AccessTypeBroadcast const *>(&frag_Broadcast);
1609
+
1610
+ int const kOutputOpIterations =
1611
+ OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
1612
+
1613
+ CUTLASS_PRAGMA_UNROLL
1614
+ for (int i = 0; i < kOutputOpIterations; ++i) {
1615
+
1616
+ output_op(
1617
+ frag_Z_ptr[i],
1618
+ frag_T_ptr[i],
1619
+ frag_AB_ptr[i],
1620
+ frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]);
1621
+ }
1622
+ }
1623
+
1624
+
1625
+ public:
1626
+ /// Stream-K reduce helper
1627
+ CUTLASS_DEVICE
1628
+ void reduce(
1629
+ int reduce_fragment_idx, ///< Reduce fragment index
1630
+ OutputOp const &output_op, ///< Output operator
1631
+ ElementVector const * broadcast_ptr, ///< Broadcast vector
1632
+ OutputTileIterator destination_iterator, ///< Tile iterator for destination
1633
+ OutputTileIterator source_iterator, ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
1634
+ TensorTileIterator tensor_iterator, ///< Threadblock tile iterator for additional tensor operand
1635
+ MatrixCoord const &problem_size = ///< Problem size needed to guard against out-of-bounds accesses
1636
+ MatrixCoord(Shape::kM, Shape::kN),
1637
+ MatrixCoord const &threadblock_offset = ///< Threadblock's initial offset within the problem size space
1638
+ MatrixCoord())
1639
+ {
1640
+
1641
+ BroadcastFragment broadcast_fragment;
1642
+ load_broadcast_fragment_(broadcast_fragment, broadcast_ptr, problem_size, threadblock_offset);
1643
+
1644
+ // Initialize/load source-fragment data
1645
+ typename OutputTileIterator::Fragment source_fragment;
1646
+ source_fragment.clear();
1647
+
1648
+ if (output_op.is_source_needed())
1649
+ {
1650
+ source_iterator += reduce_fragment_idx;
1651
+ source_iterator.load(source_fragment);
1652
+ }
1653
+
1654
+ // Load fragment from shared memory
1655
+ typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];
1656
+ shared_load_iterator_.load(aligned_accum_fragment[0]);
1657
+
1658
+ // Add fragments shared by other k partitions
1659
+ if (kPartitionsK > 1)
1660
+ {
1661
+ plus <typename SharedLoadIterator::Fragment> add_fragments;
1662
+
1663
+ CUTLASS_PRAGMA_UNROLL
1664
+ for ( int i = 1; i < kPartitionsK; ++i) {
1665
+ shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
1666
+ shared_load_iterator_.load(aligned_accum_fragment[i]);
1667
+ aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
1668
+ }
1669
+ }
1670
+
1671
+ //
1672
+ // Apply output operation
1673
+ //
1674
+
1675
+ typename OutputTileIterator::Fragment frag_Z;
1676
+ typename TensorTileIterator::Fragment frag_T;
1677
+
1678
+ if (!output_op.is_source_needed()) {
1679
+ apply_output_operator_source_not_needed_(
1680
+ frag_Z,
1681
+ frag_T,
1682
+ output_op,
1683
+ aligned_accum_fragment[0],
1684
+ broadcast_fragment);
1685
+ } else {
1686
+ apply_output_operator_(
1687
+ frag_Z,
1688
+ frag_T,
1689
+ output_op,
1690
+ aligned_accum_fragment[0],
1691
+ source_fragment,
1692
+ broadcast_fragment);
1693
+ }
1694
+
1695
+ //
1696
+ // Conditionally store fragments
1697
+ //
1698
+
1699
+ if (OutputOp::kStoreZ) {
1700
+ destination_iterator.store(frag_Z);
1701
+ ++destination_iterator;
1702
+ }
1703
+
1704
+ if (OutputOp::kStoreT) {
1705
+ tensor_iterator.store(frag_T);
1706
+ ++tensor_iterator;
1707
+ }
1708
+ }
1709
+ };
1710
+
1711
+ ////////////////////////////////////////////////////////////////////////////////
1712
+
1713
+ } // namespace threadblock
1714
+ } // namespace epilogue
1715
+ } // namespace cutlass
1716
+
1717
+ ////////////////////////////////////////////////////////////////////////////////