kernels-bot committed on
Commit
b3e9a67
·
verified ·
1 Parent(s): 1b1ada2

Uploaded using `kernel-builder`.

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. build/torch-cuda/__init__.py +24 -0
  2. build/torch-cuda/_ops.py +8 -0
  3. build/torch-cuda/_ops_compat.py +10 -0
  4. build/torch-cuda/enums.py +30 -0
  5. build/torch-cuda/functional/__init__.py +554 -0
  6. build/torch-cuda/functional/backward.py +682 -0
  7. build/torch-cuda/functional/forward.py +238 -0
  8. build/torch-cuda/functional/grouped_gemm.py +0 -0
  9. build/torch-cuda/functional/moe_config.py +581 -0
  10. build/torch-cuda/functional/reduction_over_k_gather.py +164 -0
  11. build/torch-cuda/functional/tile_scheduler.py +91 -0
  12. build/torch-cuda/functional/topk_softmax.py +195 -0
  13. build/torch-cuda/functional/triton_kernels/__init__.py +351 -0
  14. build/torch-cuda/functional/triton_kernels/bitmatrix.py +147 -0
  15. build/torch-cuda/functional/utils.py +25 -0
  16. build/torch-cuda/jit.py +159 -0
  17. build/torch-cuda/metadata.json +10 -0
  18. build/torch-cuda/moe.py +368 -0
  19. build/torch-cuda/quack/__init__.py +8 -0
  20. build/torch-cuda/quack/_ops_compat.py +4 -0
  21. build/torch-cuda/quack/activation.py +524 -0
  22. build/torch-cuda/quack/autotuner.py +369 -0
  23. build/torch-cuda/quack/broadcast_utils.py +29 -0
  24. build/torch-cuda/quack/compile_utils.py +19 -0
  25. build/torch-cuda/quack/copy_utils.py +614 -0
  26. build/torch-cuda/quack/cute_dsl_ptxas.py +151 -0
  27. build/torch-cuda/quack/cute_dsl_utils.py +104 -0
  28. build/torch-cuda/quack/fast_math.py +80 -0
  29. build/torch-cuda/quack/gemm.py +194 -0
  30. build/torch-cuda/quack/gemm_act.py +510 -0
  31. build/torch-cuda/quack/gemm_config.py +95 -0
  32. build/torch-cuda/quack/gemm_dact.py +215 -0
  33. build/torch-cuda/quack/gemm_default_epi.py +259 -0
  34. build/torch-cuda/quack/gemm_interface.py +1058 -0
  35. build/torch-cuda/quack/gemm_sm100.py +0 -0
  36. build/torch-cuda/quack/gemm_sm90.py +2070 -0
  37. build/torch-cuda/quack/gemm_symmetric.py +330 -0
  38. build/torch-cuda/quack/gemm_wrapper_utils.py +317 -0
  39. build/torch-cuda/quack/layout_utils.py +295 -0
  40. build/torch-cuda/quack/pipeline.py +324 -0
  41. build/torch-cuda/quack/reduce.py +279 -0
  42. build/torch-cuda/quack/reduction_base.py +83 -0
  43. build/torch-cuda/quack/sm100_utils.py +62 -0
  44. build/torch-cuda/quack/sm90_utils.py +157 -0
  45. build/torch-cuda/quack/sort/__init__.py +1 -0
  46. build/torch-cuda/quack/sort/bitonic_sort.py +129 -0
  47. build/torch-cuda/quack/sort/generate_sorting_networks.py +326 -0
  48. build/torch-cuda/quack/sort/sorting_networks.py +120 -0
  49. build/torch-cuda/quack/sort/utils.py +31 -0
  50. build/torch-cuda/quack/tensormap_manager.py +115 -0
build/torch-cuda/__init__.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ********************************************************************************
2
+ # Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao
3
+ # ********************************************************************************
4
+
5
+ from functools import lru_cache
6
+
7
+ __version__ = "0.1.1"
8
+
9
+ from .enums import KernelBackendMoE
10
+
11
+ from .moe import MoE
12
+ from .functional import (
13
+ enable_quack_gemm,
14
+ moe_general_routing_inputs,
15
+ moe_TC_softmax_topk_layer,
16
+ )
17
+
18
+ __all__ = [
19
+ "KernelBackendMoE",
20
+ "MoE",
21
+ "enable_quack_gemm",
22
+ "moe_general_routing_inputs",
23
+ "moe_TC_softmax_topk_layer",
24
+ ]
build/torch-cuda/_ops.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ ops = torch.ops._sonic_moe_57a1b31
3
+
4
+ def add_op_namespace_prefix(op_name: str):
5
+ """
6
+ Prefix op by namespace.
7
+ """
8
+ return f"_sonic_moe_57a1b31::{op_name}"
build/torch-cuda/_ops_compat.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ """Compatibility helpers for op namespacing in source and built layouts."""
2
+
3
+ try:
4
+ from ._ops import add_op_namespace_prefix as _generated_add_op_namespace_prefix
5
+ except ImportError:
6
+ def _generated_add_op_namespace_prefix(name: str) -> str:
7
+ return name if "::" in name else f"sonicmoe::{name}"
8
+
9
+ def add_op_namespace_prefix(name: str) -> str:
10
+ return _generated_add_op_namespace_prefix(name)
build/torch-cuda/enums.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ********************************************************************************
2
+ # Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao
3
+ # ********************************************************************************
4
+
5
+ from enum import Enum
6
+
7
+
8
+ LIBRARY_NAME = "sonicmoe"
9
+ TENSORMAP = "tensormap"
10
+
11
+
12
+ class KernelBackendMoE(Enum):
13
+ scattermoe = "scattermoe"
14
+ torch = "torch"
15
+ sonicmoe = "sonicmoe"
16
+
17
+
18
+ class ActivationType(Enum):
19
+ SWIGLU = "swiglu"
20
+ GEGLU = "geglu"
21
+ REGLU = "reglu"
22
+
23
+ RELU_SQ = "relu_sq"
24
+ RELU = "relu"
25
+ GELU = "gelu"
26
+ SILU = "silu"
27
+
28
+
29
+ def is_glu(activation_type: ActivationType):
30
+ return activation_type in [ActivationType.SWIGLU, ActivationType.REGLU, ActivationType.GEGLU]
build/torch-cuda/functional/__init__.py ADDED
@@ -0,0 +1,554 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ********************************************************************************
2
+ # Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao
3
+ # ********************************************************************************
4
+
5
+ import os
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from ..quack.gemm_interface import gemm
10
+
11
+ from ..enums import ActivationType, is_glu
12
+ from ..quack_utils import gemm_dgated, gemm_gated
13
+ from .backward import (
14
+ _down_projection_backward_act,
15
+ _down_projection_backward_weight,
16
+ _softmax_topk_bwd,
17
+ _token_broadcast_backward,
18
+ _up_projection_backward_act,
19
+ _up_projection_backward_weight,
20
+ )
21
+ from .forward import _down_projection_forward, _router_forward, _softmax_topk_fwd, _up_projection_forward
22
+ from .triton_kernels import TC_topk_router_metadata_triton, general_routing_router_metadata_triton
23
+ from .utils import enable_quack_gemm, is_using_quack_gemm
24
+
25
+
26
+ class TC_Softmax_Topk_Router_Function(torch.autograd.Function):
27
+ @staticmethod
28
+ def forward(ctx, router_logits: torch.Tensor, E: int, K: int) -> tuple[torch.Tensor, torch.Tensor]:
29
+ T = router_logits.size(0)
30
+
31
+ # change this to router_logits.dtype (bfloat16) increase another 5 tflops at fwd at the cost of numerical accuracy
32
+ topk_router_score = torch.empty(T, K, dtype=torch.float32, device=router_logits.device)
33
+ topk_router_indices = torch.empty(T, K, dtype=torch.int32, device=router_logits.device)
34
+
35
+ _softmax_topk_fwd(router_logits, topk_router_score, topk_router_indices, E, K)
36
+
37
+ ctx.save_for_backward(topk_router_score, topk_router_indices)
38
+ ctx.E = E
39
+ ctx.dtype = router_logits.dtype
40
+
41
+ return topk_router_score, topk_router_indices
42
+
43
+ @staticmethod
44
+ def backward(ctx, dtopk_score: torch.Tensor, _: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
45
+ T, K = dtopk_score.size()
46
+
47
+ topk_router_score, topk_router_indices = ctx.saved_tensors
48
+ dlogits = torch.zeros(T, ctx.E, dtype=ctx.dtype, device=topk_router_score.device)
49
+
50
+ _softmax_topk_bwd(dlogits, None, dtopk_score, topk_router_score, topk_router_indices, K)
51
+
52
+ return dlogits, None, None
53
+
54
+
55
+ class _UpProjection(torch.autograd.Function):
56
+ @staticmethod
57
+ def forward(
58
+ ctx,
59
+ x: torch.Tensor,
60
+ w1: torch.Tensor,
61
+ b1: torch.Tensor | None,
62
+ expert_frequency_offset: torch.Tensor,
63
+ total_expert_freq: int,
64
+ K: int,
65
+ stream_id: int,
66
+ x_gather_idx: torch.Tensor,
67
+ s_scatter_idx: torch.Tensor,
68
+ s_reverse_scatter_idx: torch.Tensor,
69
+ num_activated_expert_per_token_offset: torch.Tensor,
70
+ is_varlen_K: bool,
71
+ activation_type: ActivationType,
72
+ is_inference_mode_enabled: bool,
73
+ ) -> torch.Tensor:
74
+ T, H = x.shape
75
+ I, H, E = w1.shape
76
+ is_glu_activation = is_glu(activation_type)
77
+ if is_glu_activation:
78
+ I //= 2
79
+ TK = total_expert_freq
80
+
81
+ if is_using_quack_gemm():
82
+ assert not torch.compiler.is_compiling()
83
+ assert is_glu_activation, "QuACK GEMM does not support non GLU activation yet"
84
+ z, y1 = gemm_gated(
85
+ x,
86
+ w1.permute(2, 1, 0),
87
+ activation="swiglu",
88
+ cu_seqlens_m=expert_frequency_offset,
89
+ A_idx=x_gather_idx,
90
+ dynamic_scheduler=False,
91
+ )
92
+ else:
93
+ z = torch.empty(TK, (2 * I if is_glu_activation else I), dtype=x.dtype, device=x.device)
94
+ y1 = torch.empty(TK, I, dtype=x.dtype, device=x.device)
95
+ _up_projection_forward(
96
+ x=x,
97
+ w1=w1,
98
+ z=z,
99
+ y1=y1,
100
+ b1=b1,
101
+ expert_frequency_offset=expert_frequency_offset,
102
+ expert_schedule_order=None,
103
+ x_gather_idx=x_gather_idx,
104
+ stream_id=stream_id,
105
+ activation_type=activation_type.value,
106
+ is_glu_activation=is_glu_activation,
107
+ is_inference_mode_enabled=is_inference_mode_enabled,
108
+ )
109
+
110
+ ctx.T = T
111
+ ctx.TK = TK
112
+ ctx.E = E
113
+ ctx.K = K
114
+ ctx.H = H
115
+ ctx.I = I
116
+ ctx.is_varlen_K = is_varlen_K
117
+ ctx.is_glu_activation = is_glu_activation
118
+ ctx.stream_id = stream_id
119
+
120
+ ctx.save_for_backward(
121
+ x,
122
+ w1,
123
+ b1,
124
+ expert_frequency_offset,
125
+ x_gather_idx,
126
+ s_scatter_idx,
127
+ s_reverse_scatter_idx,
128
+ num_activated_expert_per_token_offset,
129
+ )
130
+
131
+ ctx.mark_non_differentiable(y1)
132
+ ctx.set_materialize_grads(False)
133
+
134
+ return y1, z
135
+
136
+ @staticmethod
137
+ def backward(ctx, _: None, dz: torch.Tensor):
138
+ is_compiling = torch.compiler.is_compiling()
139
+
140
+ if not is_compiling:
141
+ assert _ is None
142
+
143
+ T = ctx.T
144
+ TK = ctx.TK
145
+ E = ctx.E
146
+ K = ctx.K
147
+ H = ctx.H
148
+ is_glu_activation = ctx.is_glu_activation
149
+ is_varlen_K = ctx.is_varlen_K
150
+ stream_id = ctx.stream_id
151
+
152
+ (
153
+ x,
154
+ w1,
155
+ b1,
156
+ expert_frequency_offset,
157
+ x_gather_idx,
158
+ s_scatter_idx,
159
+ s_reverse_scatter_idx,
160
+ num_activated_expert_per_token_offset,
161
+ ) = ctx.saved_tensors
162
+
163
+ dw1 = torch.empty_like(w1)
164
+ db1 = None if b1 is None else torch.empty_like(b1)
165
+
166
+ if is_using_quack_gemm():
167
+ assert not is_compiling
168
+
169
+ gemm(
170
+ x.T,
171
+ dz,
172
+ out=dw1.permute(2, 1, 0),
173
+ cu_seqlens_k=expert_frequency_offset,
174
+ A_idx=x_gather_idx,
175
+ batch_idx_permute=None,
176
+ dynamic_scheduler=False,
177
+ )
178
+ dx_expanded = gemm(dz, w1.permute(2, 0, 1), cu_seqlens_m=expert_frequency_offset, dynamic_scheduler=False)
179
+ else:
180
+ dx_expanded = torch.empty(TK, H, dtype=dz.dtype, device=dz.device)
181
+
182
+ _up_projection_backward_act(
183
+ w1=w1,
184
+ dx_expanded=dx_expanded,
185
+ dz=dz,
186
+ db1=db1,
187
+ expert_frequency_offset=expert_frequency_offset,
188
+ expert_schedule_order=None,
189
+ x_gather_idx=x_gather_idx,
190
+ s_scatter_idx=s_scatter_idx,
191
+ is_glu_activation=is_glu_activation,
192
+ stream_id=stream_id,
193
+ )
194
+
195
+ _up_projection_backward_weight(
196
+ x=x,
197
+ dw1=dw1,
198
+ dz=dz,
199
+ expert_frequency_offset=expert_frequency_offset,
200
+ expert_schedule_order=None,
201
+ x_gather_idx=x_gather_idx,
202
+ is_glu_activation=is_glu_activation,
203
+ stream_id=stream_id,
204
+ )
205
+
206
+ dx_reduced = torch.empty(T, H, dtype=dz.dtype, device=dz.device)
207
+
208
+ _token_broadcast_backward(
209
+ dx_reduced=dx_reduced,
210
+ dx_expanded=dx_expanded,
211
+ s_reverse_scatter_idx=s_reverse_scatter_idx,
212
+ num_activated_expert_per_token_offset=num_activated_expert_per_token_offset,
213
+ varlen_K_max=(E if is_varlen_K else K),
214
+ H=H,
215
+ is_varlen_K=is_varlen_K,
216
+ )
217
+
218
+ return dx_reduced, dw1, db1, *[None] * 12
219
+
220
+
221
+ class _DownProjection(torch.autograd.Function):
222
+ @staticmethod
223
+ def forward(
224
+ ctx,
225
+ y1: torch.Tensor,
226
+ z: torch.Tensor,
227
+ w2: torch.Tensor,
228
+ b2: torch.Tensor | None,
229
+ topk_scores: torch.Tensor,
230
+ expert_frequency_offset: torch.Tensor,
231
+ T: int,
232
+ K: int,
233
+ stream_id: int,
234
+ x_gather_idx: torch.Tensor,
235
+ s_scatter_idx: torch.Tensor,
236
+ s_reverse_scatter_idx: torch.Tensor,
237
+ num_activated_expert_per_token_offset: torch.Tensor,
238
+ is_varlen_K: bool,
239
+ activation_type: ActivationType,
240
+ ) -> torch.Tensor:
241
+ TK = y1.size(0)
242
+ H, I, E = w2.shape
243
+
244
+ if is_using_quack_gemm():
245
+ assert not torch.compiler.is_compiling()
246
+
247
+ assert b2 is None
248
+ y2 = gemm(y1, w2.permute(2, 1, 0), cu_seqlens_m=expert_frequency_offset)
249
+ else:
250
+ y2 = torch.empty(TK, H, dtype=y1.dtype, device=y1.device)
251
+ _down_projection_forward(
252
+ w2=w2,
253
+ y1=y1,
254
+ y2=y2,
255
+ b2=b2,
256
+ expert_frequency_offset=expert_frequency_offset,
257
+ expert_schedule_order=None,
258
+ x_gather_idx=x_gather_idx,
259
+ stream_id=stream_id,
260
+ )
261
+
262
+ o = torch.empty(T, H, device=z.device, dtype=z.dtype)
263
+ topk_scores = topk_scores.flatten()
264
+
265
+ _router_forward(
266
+ y2=y2,
267
+ o=o,
268
+ topk_scores=topk_scores,
269
+ s_reverse_scatter_idx=s_reverse_scatter_idx,
270
+ num_activated_expert_per_token_offset=num_activated_expert_per_token_offset,
271
+ varlen_K_max=(E if is_varlen_K else K),
272
+ H=H,
273
+ is_varlen_K=is_varlen_K,
274
+ )
275
+
276
+ ctx.T = T
277
+ ctx.K = K
278
+ ctx.is_varlen_K = is_varlen_K
279
+ ctx.activation_type = activation_type
280
+ ctx.stream_id = stream_id
281
+
282
+ ctx.save_for_backward(
283
+ z,
284
+ w2,
285
+ b2,
286
+ topk_scores,
287
+ expert_frequency_offset,
288
+ x_gather_idx,
289
+ s_scatter_idx,
290
+ s_reverse_scatter_idx,
291
+ )
292
+
293
+ return o
294
+
295
+ @staticmethod
296
+ def backward(ctx, dout: torch.Tensor):
297
+ T = ctx.T
298
+ K = ctx.K
299
+ stream_id = ctx.stream_id
300
+ is_varlen_K = ctx.is_varlen_K
301
+ activation_type = ctx.activation_type
302
+
303
+ (
304
+ z,
305
+ w2,
306
+ b2,
307
+ topk_scores,
308
+ expert_frequency_offset,
309
+ x_gather_idx,
310
+ s_scatter_idx,
311
+ s_reverse_scatter_idx,
312
+ ) = ctx.saved_tensors
313
+
314
+ dw2 = torch.empty_like(w2)
315
+ db2 = None if b2 is None else torch.empty_like(b2)
316
+ dz = torch.empty_like(z)
317
+
318
+ if is_using_quack_gemm():
319
+ assert not torch.compiler.is_compiling()
320
+ assert is_glu(activation_type), "QuACK GEMM does not support non GLU activation yet"
321
+
322
+ s = topk_scores[s_scatter_idx]
323
+ _, y1s, ds = gemm_dgated(
324
+ dout,
325
+ w2.permute(2, 0, 1),
326
+ PreAct=z,
327
+ activation="swiglu",
328
+ dx_out=dz,
329
+ colvec_scale=s,
330
+ colvec_reduce=True,
331
+ cu_seqlens_m=expert_frequency_offset,
332
+ A_idx=x_gather_idx,
333
+ dynamic_scheduler=False,
334
+ )
335
+ gemm(
336
+ dout.T,
337
+ y1s,
338
+ out=dw2.permute(2, 0, 1),
339
+ cu_seqlens_k=expert_frequency_offset,
340
+ A_idx=x_gather_idx,
341
+ batch_idx_permute=None,
342
+ dynamic_scheduler=False,
343
+ )
344
+
345
+ ds = ds[s_reverse_scatter_idx]
346
+ else:
347
+ ds = torch.empty_like(topk_scores)
348
+
349
+ I = w2.size(1)
350
+ TK = x_gather_idx.size(0)
351
+
352
+ y1s = torch.empty(TK, I, dtype=z.dtype, device=z.device)
353
+ is_glu_activation = is_glu(activation_type)
354
+
355
+ _down_projection_backward_act(
356
+ dout=dout,
357
+ z=z,
358
+ w2=w2,
359
+ dz=dz,
360
+ ds=ds,
361
+ b2=b2,
362
+ db2=db2,
363
+ y1s=y1s,
364
+ topk_scores=topk_scores,
365
+ expert_frequency_offset=expert_frequency_offset,
366
+ expert_schedule_order=None,
367
+ x_gather_idx=x_gather_idx,
368
+ s_scatter_idx=s_scatter_idx,
369
+ is_glu_activation=is_glu_activation,
370
+ activation_type=activation_type.value,
371
+ stream_id=stream_id,
372
+ )
373
+
374
+ _down_projection_backward_weight(
375
+ dout=dout,
376
+ y1s=y1s,
377
+ dw2=dw2,
378
+ expert_frequency_offset=expert_frequency_offset,
379
+ expert_schedule_order=None,
380
+ x_gather_idx=x_gather_idx,
381
+ stream_id=stream_id,
382
+ )
383
+
384
+ # TC top-K routing
385
+ if not is_varlen_K:
386
+ ds = ds.view(T, K)
387
+
388
+ return None, dz, dw2, db2, ds, *[None] * 10
389
+
390
+
391
def moe_TC_softmax_topk_layer(
    x: torch.Tensor,
    router_w: torch.Tensor,
    w1: torch.Tensor,
    b1: torch.Tensor | None,
    w2: torch.Tensor,
    b2: torch.Tensor | None,
    K: int,
    stream_id: int,
    activation_type: ActivationType | str = ActivationType.SWIGLU,
    is_inference_mode_enabled: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Full MoE layer (router -> up projection -> down projection) using the
    fused softmax top-K routing kernel.

    Args:
        x: (T, H) input tokens.
        router_w: (E, H) router weight; logits are x @ router_w.T.
        w1 / b1: up-projection weight and optional bias.
        w2 / b2: down-projection weight and optional bias.
        K: experts activated per token.
        stream_id: CUDA stream handle forwarded to the kernels.
        activation_type: activation between the projections; a plain string is
            coerced to ActivationType.
        is_inference_mode_enabled: forwarded to the up-projection kernel.

    Returns:
        (o, router_logits, expert_frequency): layer output (T, H), raw router
        logits, and the per-expert token counts filled by the metadata kernel.
    """
    assert ((b1 is None) and (b2 is None)) or (
        (b1 is not None) and (b2 is not None)
    ), "b1 and b2 has to be None or not None at the same time!"
    E = router_w.size(0)
    router_logits = F.linear(x, router_w)
    topk_scores, topk_indices = TC_Softmax_Topk_Router_Function.apply(router_logits, E, K)

    T, K = topk_indices.size()
    TK = T * K
    device = topk_indices.device

    # Routing metadata buffers, filled in-place by the triton kernel below.
    s_scatter_idx = torch.empty(TK, dtype=torch.int32, device=device)
    s_reverse_scatter_idx = torch.empty(TK, dtype=torch.int32, device=device)
    expert_frequency = torch.empty(E, dtype=torch.int32, device=device)
    expert_frequency_offset = torch.empty(E + 1, dtype=torch.int32, device=device)
    x_gather_idx = torch.empty(TK, dtype=torch.int32, device=device)

    TC_topk_router_metadata_triton(
        topk_indices, E, expert_frequency, expert_frequency_offset, x_gather_idx, s_scatter_idx, s_reverse_scatter_idx
    )

    T = x.size(0)

    # Fix: use isinstance instead of `type(...) == str` so str subclasses are
    # accepted as well (idiomatic type check).
    if isinstance(activation_type, str):
        activation_type = ActivationType(activation_type)

    y1, z = _UpProjection.apply(
        x,
        w1,
        b1,
        expert_frequency_offset,
        T * K,
        K,
        stream_id,
        x_gather_idx,
        s_scatter_idx,
        s_reverse_scatter_idx,
        None,
        False,  # is_varlen_K
        activation_type,
        is_inference_mode_enabled,
    )

    o = _DownProjection.apply(
        y1,
        z,
        w2,
        b2,
        topk_scores,
        expert_frequency_offset,
        T,
        K,
        stream_id,
        x_gather_idx,
        s_scatter_idx,
        s_reverse_scatter_idx,
        None,
        False,  # is_varlen_K
        activation_type,
    )

    return o, router_logits, expert_frequency
465
+
466
+
467
+ # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
468
+ # Weight format requirements:
469
+ # - w1_weight: Shape (2*I, H, E), stride order (2, 0, 1), must be interleaved [gate_row0, up_row0, gate_row1, up_row1, ...]
470
+ # - w2_weight: Shape (H, I, E), stride order (2, 0, 1)
471
+
472
+
473
+ # We assume token_indices is already SORTED ascendingly !!!
474
+ # and len(token_indices) = len(expert_indices) = len(router_scores)
475
+ # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
476
+ def moe_general_routing_inputs(
477
+ x: torch.Tensor,
478
+ router_scores: torch.Tensor,
479
+ token_indices: torch.Tensor,
480
+ expert_indices: torch.Tensor,
481
+ w1: torch.Tensor,
482
+ b1: torch.Tensor | None,
483
+ w2: torch.Tensor,
484
+ b2: torch.Tensor | None,
485
+ E: int,
486
+ stream_id: int,
487
+ activation_type: ActivationType,
488
+ is_inference_mode_enabled: bool = False,
489
+ ) -> tuple[torch.Tensor, torch.Tensor]:
490
+ assert ((b1 is None) and (b2 is None)) or (
491
+ (b1 is not None) and (b2 is not None)
492
+ ), "b1 and b2 has to be None or not None at the same time!"
493
+
494
+ T = x.size(0)
495
+ TK = router_scores.size(0)
496
+ E = w2.size(-1)
497
+ device = router_scores.device
498
+
499
+ s_scatter_idx = torch.empty(TK, dtype=torch.int32, device=device)
500
+ s_reverse_scatter_idx = torch.empty(TK, dtype=torch.int32, device=device)
501
+ expert_frequency = torch.empty(E, dtype=torch.int32, device=device)
502
+ expert_frequency_offset = torch.empty(E + 1, dtype=torch.int32, device=device)
503
+ x_gather_idx = torch.empty(TK, dtype=torch.int32, device=device)
504
+ num_activated_expert_per_token_offset = torch.empty(T + 1, dtype=torch.int32, device=device)
505
+
506
+ general_routing_router_metadata_triton(
507
+ token_indices,
508
+ expert_indices,
509
+ T,
510
+ E,
511
+ expert_frequency,
512
+ expert_frequency_offset,
513
+ x_gather_idx,
514
+ s_scatter_idx,
515
+ s_reverse_scatter_idx,
516
+ num_activated_expert_per_token_offset,
517
+ )
518
+
519
+ y1, z = _UpProjection.apply(
520
+ x,
521
+ w1,
522
+ b1,
523
+ expert_frequency_offset,
524
+ TK,
525
+ None, # K, not needed
526
+ stream_id,
527
+ x_gather_idx,
528
+ s_scatter_idx,
529
+ s_reverse_scatter_idx,
530
+ num_activated_expert_per_token_offset,
531
+ True, # is_varlen_K
532
+ activation_type,
533
+ is_inference_mode_enabled,
534
+ )
535
+
536
+ o = _DownProjection.apply(
537
+ y1,
538
+ z,
539
+ w2,
540
+ b2,
541
+ router_scores,
542
+ expert_frequency_offset,
543
+ T,
544
+ None, # K, not needed
545
+ stream_id,
546
+ x_gather_idx,
547
+ s_scatter_idx,
548
+ s_reverse_scatter_idx,
549
+ num_activated_expert_per_token_offset,
550
+ True, # is_varlen_K
551
+ activation_type,
552
+ )
553
+
554
+ return o, expert_frequency
build/torch-cuda/functional/backward.py ADDED
@@ -0,0 +1,682 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ********************************************************************************
2
+ # Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao
3
+ # ********************************************************************************
4
+
5
+ from typing import Optional
6
+
7
+ import cuda.bindings.driver as cuda
8
+ import cutlass.cute as cute
9
+ import torch
10
+ import triton
11
+ import triton.language as tl
12
+
13
+ from .._ops_compat import add_op_namespace_prefix
14
+ from ..enums import LIBRARY_NAME, TENSORMAP, ActivationType
15
+ from ..utils import ceil_divide, convert_torch_tensor_to_cute_tensor, get_powers_of_2
16
+ from .moe_config import (
17
+ HopperWgmma_MoE_Down_proj_ActGrad_Bwd,
18
+ HopperWgmma_MoE_Down_proj_WeightGrad_Bwd,
19
+ HopperWgmma_MoE_Up_proj_ActGrad_Bwd,
20
+ HopperWgmma_MoE_Up_proj_WeightGrad_Bwd,
21
+ )
22
+ from .reduction_over_k_gather import token_gather_and_sum_varlen_K_triton
23
+
24
+
25
+ def _get_autotune_configs_for_db2_and_ds() -> list[triton.Config]:
26
+ configs = []
27
+ for BLOCK_TK in get_powers_of_2(4, 32):
28
+ configs.append(triton.Config({"BLOCK_TK": BLOCK_TK}, num_warps=8, num_stages=4))
29
+ return configs
30
+
31
+
32
+ @triton.autotune(
33
+ configs=_get_autotune_configs_for_db2_and_ds(),
34
+ key=["H", "E"],
35
+ )
36
+ @triton.jit
37
+ def db2_and_ds_kernel(
38
+ dout_ptr, # (T, H)
39
+ s_ptr, # (TK,)
40
+ new_ds_partial_ptr, # (TK, n_h_blocks)
41
+ old_ds_partial_ptr, # (TK, OLD_DS_PARTIAL_N)
42
+ b2_ptr, # (E, H),
43
+ db2_ptr, # (E, H),
44
+ x_gather_idx_ptr, # (TK,), maps grouped -> token index
45
+ s_scatter_idx_ptr, # (TK,), maps grouped -> scatter index
46
+ expert_offset_ptr, # (E+1,), offsets in grouped layout
47
+ H: tl.constexpr,
48
+ E: tl.constexpr,
49
+ OLD_DS_PARTIAL_N: tl.constexpr,
50
+ BLOCK_H: tl.constexpr, # Block size for H dimension
51
+ BLOCK_TK: tl.constexpr, # Block size for token dimension
52
+ BLOCK_OLD_DS_PARTIAL_N: tl.constexpr,
53
+ ):
54
+ Eidx = tl.program_id(0) # expert id
55
+ Hidx = tl.program_id(1) # h-block id
56
+ NUM_H_BLOCKS: tl.constexpr = tl.num_programs(1)
57
+
58
+ # Hidden dimension indices for this block
59
+ h_offsets = Hidx * BLOCK_H + tl.arange(0, BLOCK_H)
60
+ h_mask = h_offsets < H
61
+
62
+ E_count_start = tl.load(expert_offset_ptr + Eidx)
63
+ E_count_end = tl.load(expert_offset_ptr + Eidx + 1)
64
+ n_tokens = E_count_end - E_count_start
65
+
66
+ b2 = tl.load(b2_ptr + Eidx * H + h_offsets, mask=h_mask, other=0.0).to(tl.float32)
67
+
68
+ db2_acc = tl.zeros([BLOCK_H], dtype=tl.float32)
69
+
70
+ # Process tokens in blocks of BLOCK_TK
71
+ for block_start in tl.range(0, n_tokens, BLOCK_TK):
72
+ # Token offsets within this block
73
+ tk_offsets = block_start + tl.arange(0, BLOCK_TK)
74
+ tk_mask = tk_offsets < n_tokens
75
+ tk_grouped = E_count_start + tk_offsets
76
+
77
+ # Gather token indices: [BLOCK_TK]
78
+ token_indices = tl.load(x_gather_idx_ptr + tk_grouped, mask=tk_mask, other=0).to(tl.uint32)
79
+
80
+ # Get scatter indices: [BLOCK_TK]
81
+ scatter_indices = tl.load(s_scatter_idx_ptr + tk_grouped, mask=tk_mask, other=0).to(tl.uint32)
82
+
83
+ s = tl.load(s_ptr + scatter_indices, mask=tk_mask, other=0.0).to(tl.float32)
84
+
85
+ # Gather dout: [BLOCK_TK, BLOCK_H]
86
+ dout_offsets = token_indices[:, None] * H + h_offsets[None, :]
87
+ dout_mask = tk_mask[:, None] & h_mask[None, :]
88
+ dout = tl.load(dout_ptr + dout_offsets, mask=dout_mask, other=0.0).to(tl.float32)
89
+
90
+ # Accumulate db2: sum over tokens of (dout * s)
91
+ db2_acc += tl.sum(dout * s[:, None], axis=0) # Sum over BLOCK_TK dimension
92
+
93
+ # Compute ds: dot(dout, b2) for this H-block
94
+ ds_partial = tl.sum(dout * b2[None, :], axis=1) # [BLOCK_TK]
95
+
96
+ # On first H-block, add old_ds_partial.sum(dim=1)
97
+ if Hidx == 0:
98
+ n_offsets = tl.arange(0, BLOCK_OLD_DS_PARTIAL_N)
99
+ old_ds_partial_offsets = scatter_indices[:, None] * OLD_DS_PARTIAL_N + n_offsets[None, :]
100
+ old_ds_partial_mask = tk_mask[:, None] & (n_offsets[None, :] < OLD_DS_PARTIAL_N)
101
+ old_ds_partial_vals = tl.load(
102
+ old_ds_partial_ptr + old_ds_partial_offsets, mask=old_ds_partial_mask, other=0.0
103
+ ).to(tl.float32)
104
+ ds_partial += tl.sum(old_ds_partial_vals, axis=1)
105
+
106
+ tl.store(new_ds_partial_ptr + scatter_indices * NUM_H_BLOCKS + Hidx, ds_partial, mask=tk_mask)
107
+
108
+ tl.store(db2_ptr + Eidx * H + h_offsets, db2_acc, mask=h_mask)
109
+
110
+
111
+ def _get_autotune_configs_for_db1() -> list[triton.Config]:
112
+ configs = []
113
+ for BLOCK_TK in get_powers_of_2(4, 128):
114
+ for BLOCK_I in get_powers_of_2(64, 4096):
115
+ if 4096 <= BLOCK_I * BLOCK_TK <= 16384:
116
+ configs.append(triton.Config({"BLOCK_I": BLOCK_I, "BLOCK_TK": BLOCK_TK}, num_warps=8, num_stages=4))
117
+ return configs
118
+
119
+
120
+ def _prune_triton_autotune_config(configs, nargs, **kw):
121
+ pruned_configs = []
122
+ for c in configs:
123
+ if c.kwargs["BLOCK_I"] <= triton.next_power_of_2(nargs["I"]):
124
+ pruned_configs.append(c)
125
+ return pruned_configs
126
+
127
+
128
+ @triton.autotune(
129
+ configs=_get_autotune_configs_for_db1(),
130
+ key=["I", "E"],
131
+ prune_configs_by={"early_config_prune": _prune_triton_autotune_config},
132
+ )
133
+ @triton.jit
134
+ def db1_kernel(
135
+ dz_ptr, # (T, H)
136
+ db1_ptr, # (E, H),
137
+ expert_offset_ptr, # (E+1,), offsets in grouped layout
138
+ I: tl.constexpr,
139
+ E: tl.constexpr,
140
+ BLOCK_I: tl.constexpr, # Block size for H dimension
141
+ BLOCK_TK: tl.constexpr, # Block size for token dimension
142
+ ):
143
+ Eidx = tl.program_id(0) # expert id
144
+
145
+ E_count_start = tl.load(expert_offset_ptr + Eidx).to(tl.int64)
146
+ E_count_end = tl.load(expert_offset_ptr + Eidx + 1).to(tl.int64)
147
+ n_tokens = E_count_end - E_count_start
148
+
149
+ NUM_I_BLOCKS: tl.constexpr = triton.cdiv(I, BLOCK_I)
150
+ for Iidx in tl.static_range(0, NUM_I_BLOCKS, 1):
151
+ i_offsets = Iidx * BLOCK_I + tl.arange(0, BLOCK_I)
152
+ i_mask = i_offsets < I
153
+
154
+ db1_acc = tl.zeros([BLOCK_I], dtype=tl.float32)
155
+
156
+ # Process tokens in blocks of BLOCK_TK
157
+ for block_start in tl.range(0, n_tokens, BLOCK_TK):
158
+ # Token offsets within this block
159
+ tk_offsets = block_start + tl.arange(0, BLOCK_TK)
160
+ tk_mask = tk_offsets < n_tokens
161
+ tk_grouped = E_count_start + tk_offsets
162
+
163
+ dz_offsets = tk_grouped[:, None] * I + i_offsets[None, :]
164
+ dz_mask = tk_mask[:, None] & i_mask[None, :]
165
+ dz = tl.load(dz_ptr + dz_offsets, mask=dz_mask, other=0.0).to(tl.float32)
166
+
167
+ db1_acc += tl.sum(dz, axis=0) # Sum over BLOCK_TK dimension
168
+
169
+ db1_offsets = Eidx.to(tl.int64) * I + i_offsets
170
+ tl.store(db1_ptr + db1_offsets, db1_acc, mask=i_mask)
171
+
172
+
173
+ @triton.jit
174
+ def _colsum_smallN_kernel(
175
+ y_ptr, # *mut T, shape [M]
176
+ x_ptr, # *const T, shape [M, N]
177
+ stride_xm: tl.constexpr,
178
+ stride_xn: tl.constexpr, # strides of X
179
+ stride_y: tl.constexpr, # stride of Y (usually 1)
180
+ N: tl.constexpr, # sizes
181
+ BLOCK_N: tl.constexpr, # tile size along N
182
+ ):
183
+ row = tl.program_id(0)
184
+
185
+ # assume BLOCK_N >= N
186
+ offs = tl.arange(0, BLOCK_N)
187
+ mask = offs < N
188
+ # Load a tile from the row; cast to fp32 for the reduction
189
+ x = tl.load(x_ptr + row * stride_xm + offs * stride_xn, mask=mask, other=0).to(tl.float32)
190
+ # Reduce this tile to a scalar and add
191
+ acc = tl.sum(x, axis=0)
192
+
193
+ # Store the row-sum (cast back to y dtype)
194
+ tl.store(y_ptr + row * stride_y, acc)
195
+
196
+
197
+ @torch.library.custom_op(add_op_namespace_prefix("_up_projection_backward_act"), mutates_args={"dx_expanded", "db1"})
198
+ def _up_projection_backward_act(
199
+ w1: torch.Tensor,
200
+ dx_expanded: torch.Tensor,
201
+ dz: torch.Tensor,
202
+ db1: torch.Tensor | None,
203
+ expert_frequency_offset: torch.Tensor,
204
+ expert_schedule_order: torch.Tensor | None,
205
+ x_gather_idx: torch.Tensor,
206
+ s_scatter_idx: torch.Tensor,
207
+ is_glu_activation: bool,
208
+ stream_id: int,
209
+ ) -> None:
210
+ I, H, E = w1.size()
211
+ if is_glu_activation:
212
+ I //= 2
213
+
214
+ # db1 computation
215
+ if db1 is not None:
216
+ db1_kernel[(E,)](dz, db1, expert_frequency_offset, (2 * I if is_glu_activation else I), E)
217
+
218
+ mE_offset = convert_torch_tensor_to_cute_tensor(expert_frequency_offset, (0,), 0, 4, 1, stream=stream_id)
219
+ mX_gather = convert_torch_tensor_to_cute_tensor(x_gather_idx, (0,), 0, 4, 1, stream=stream_id)
220
+ mS_scatter = convert_torch_tensor_to_cute_tensor(s_scatter_idx, (0,), 0, 4, 1, stream=stream_id)
221
+ mDz = convert_torch_tensor_to_cute_tensor(dz, (0, 1), 1, 16, 8, stream=stream_id)
222
+ mDx_expanded = convert_torch_tensor_to_cute_tensor(dx_expanded, (0, 1), 1, 16, 8, stream=stream_id)
223
+ mW1_trans = convert_torch_tensor_to_cute_tensor(w1.permute(1, 0, 2), (2, 1, 0), 0, 16, 8, stream=stream_id)
224
+
225
+ if expert_schedule_order is None:
226
+ mE_permute_order = None
227
+ else:
228
+ mE_permute_order = convert_torch_tensor_to_cute_tensor(expert_schedule_order, (0,), 0, 4, 1, stream=stream_id)
229
+ current_stream = cuda.CUstream(stream_id)
230
+
231
+ compile_dx_key = ("dx", E, H, I, is_glu_activation, dx_expanded.dtype)
232
+ if compile_dx_key not in _up_projection_backward_act.compile_cache:
233
+ dx_module = HopperWgmma_MoE_Up_proj_ActGrad_Bwd(E, H, I, is_glu_activation)
234
+ tensormaps = [dx_module.module.generate_tensormap(None, None, None) for _ in range(2)]
235
+ _up_projection_backward_act.compile_cache[compile_dx_key] = cute.compile(
236
+ dx_module,
237
+ mDz,
238
+ mW1_trans,
239
+ mDx_expanded,
240
+ mE_offset,
241
+ mX_gather,
242
+ mS_scatter,
243
+ tensormaps,
244
+ mE_permute_order,
245
+ current_stream,
246
+ )
247
+ _up_projection_backward_act.compile_cache[f"dx-{TENSORMAP}"] = tensormaps
248
+
249
+ dx_tensormaps = _up_projection_backward_act.compile_cache[f"dx-{TENSORMAP}"]
250
+ _up_projection_backward_act.compile_cache[compile_dx_key](
251
+ mDz,
252
+ mW1_trans,
253
+ mDx_expanded,
254
+ mE_offset,
255
+ mX_gather,
256
+ mS_scatter,
257
+ dx_tensormaps,
258
+ mE_permute_order,
259
+ current_stream,
260
+ )
261
+
262
+
263
+ _up_projection_backward_act.compile_cache = {}
264
+
265
+
266
+ @torch.library.custom_op(add_op_namespace_prefix("_up_projection_backward_weight"), mutates_args={"dw1"})
267
+ def _up_projection_backward_weight(
268
+ x: torch.Tensor,
269
+ dw1: torch.Tensor,
270
+ dz: torch.Tensor,
271
+ expert_frequency_offset: torch.Tensor,
272
+ expert_schedule_order: torch.Tensor | None,
273
+ x_gather_idx: torch.Tensor,
274
+ is_glu_activation: bool,
275
+ stream_id: int,
276
+ ) -> None:
277
+ I, H, E = dw1.size()
278
+ if is_glu_activation:
279
+ I //= 2
280
+
281
+ x = x.detach()
282
+
283
+ mDz_trans = convert_torch_tensor_to_cute_tensor(dz.T, (1, 0), 0, 16, 8, stream=stream_id)
284
+ mDw1_trans = convert_torch_tensor_to_cute_tensor(dw1.permute(1, 0, 2), (2, 1, 0), 0, 16, 8, stream=stream_id)
285
+
286
+ mX_trans = convert_torch_tensor_to_cute_tensor(x.T, (1, 0), 0, 16, 8, stream=stream_id)
287
+ mE_offset = convert_torch_tensor_to_cute_tensor(expert_frequency_offset, (0,), 0, 4, 1, stream=stream_id)
288
+ mX_gather = convert_torch_tensor_to_cute_tensor(x_gather_idx, (0,), 0, 4, 1, stream=stream_id)
289
+
290
+ if expert_schedule_order is None:
291
+ mE_permute_order = None
292
+ else:
293
+ mE_permute_order = convert_torch_tensor_to_cute_tensor(expert_schedule_order, (0,), 0, 4, 1, stream=stream_id)
294
+ current_stream = cuda.CUstream(stream_id)
295
+
296
+ compile_dw1_key = ("dw1", E, H, I, is_glu_activation, x.dtype)
297
+ if compile_dw1_key not in _up_projection_backward_weight.compile_cache:
298
+ dw1_module = HopperWgmma_MoE_Up_proj_WeightGrad_Bwd(E, H, I, is_glu_activation)
299
+ tensormaps = [dw1_module.module.generate_tensormap(None, None, None) for _ in range(1)]
300
+ _up_projection_backward_weight.compile_cache[compile_dw1_key] = cute.compile(
301
+ dw1_module,
302
+ mX_trans,
303
+ mDz_trans,
304
+ mDw1_trans,
305
+ mE_offset,
306
+ mX_gather,
307
+ tensormaps,
308
+ mE_permute_order,
309
+ current_stream,
310
+ )
311
+ _up_projection_backward_weight.compile_cache[f"dw1-{TENSORMAP}"] = tensormaps
312
+
313
+ dw1_tensormaps = _up_projection_backward_weight.compile_cache[f"dw1-{TENSORMAP}"]
314
+ _up_projection_backward_weight.compile_cache[compile_dw1_key](
315
+ mX_trans,
316
+ mDz_trans,
317
+ mDw1_trans,
318
+ mE_offset,
319
+ mX_gather,
320
+ dw1_tensormaps,
321
+ mE_permute_order,
322
+ current_stream,
323
+ )
324
+
325
+
326
+ _up_projection_backward_weight.compile_cache = {}
327
+
328
+
329
+ @torch.library.custom_op(add_op_namespace_prefix("_down_projection_backward_act"), mutates_args={"dz", "ds", "db2", "y1s"})
330
+ def _down_projection_backward_act(
331
+ dout: torch.Tensor,
332
+ z: torch.Tensor,
333
+ w2: torch.Tensor,
334
+ dz: torch.Tensor,
335
+ ds: torch.Tensor,
336
+ b2: torch.Tensor | None,
337
+ db2: torch.Tensor | None,
338
+ y1s: torch.Tensor,
339
+ topk_scores: torch.Tensor,
340
+ expert_frequency_offset: torch.Tensor,
341
+ expert_schedule_order: torch.Tensor | None,
342
+ x_gather_idx: torch.Tensor,
343
+ s_scatter_idx: torch.Tensor,
344
+ is_glu_activation: bool,
345
+ activation_type: str,
346
+ stream_id: int,
347
+ ) -> None:
348
+ H, I, E = w2.size()
349
+ TK = x_gather_idx.size(0)
350
+
351
+ dout = dout.detach()
352
+ w2 = w2.detach()
353
+ topk_scores = topk_scores.detach()
354
+
355
+ mDout = convert_torch_tensor_to_cute_tensor(dout, (0, 1), 1, 16, 8, stream=stream_id)
356
+ mW2_trans = convert_torch_tensor_to_cute_tensor(w2.permute(1, 0, 2), (2, 1, 0), 0, 16, 8, stream=stream_id)
357
+ mS = convert_torch_tensor_to_cute_tensor(topk_scores, (0,), 0, 4, 1, stream=stream_id)
358
+ if is_glu_activation:
359
+ mDz_kernel_input = convert_torch_tensor_to_cute_tensor(
360
+ dz.view(torch.float32), (0, 1), 1, 16, 8, stream=stream_id
361
+ )
362
+ mZ_kernel_input = convert_torch_tensor_to_cute_tensor(
363
+ z.view(torch.float32), (0, 1), 1, 16, 8, stream=stream_id
364
+ )
365
+ else:
366
+ mDz_kernel_input = convert_torch_tensor_to_cute_tensor(dz.detach(), (0, 1), 1, 16, 8, stream=stream_id)
367
+ mZ_kernel_input = convert_torch_tensor_to_cute_tensor(z.detach(), (0, 1), 1, 16, 8, stream=stream_id)
368
+
369
+ mY1S = convert_torch_tensor_to_cute_tensor(y1s, (0, 1), 1, 16, 8, stream=stream_id)
370
+ mE_offset = convert_torch_tensor_to_cute_tensor(expert_frequency_offset, (0,), 0, 4, 1, stream=stream_id)
371
+ mX_gather = convert_torch_tensor_to_cute_tensor(x_gather_idx, (0,), 0, 4, 1, stream=stream_id)
372
+ mS_scatter = convert_torch_tensor_to_cute_tensor(s_scatter_idx, (0,), 0, 4, 1, stream=stream_id)
373
+
374
+ if expert_schedule_order is None:
375
+ mE_permute_order = None
376
+ else:
377
+ mE_permute_order = convert_torch_tensor_to_cute_tensor(expert_schedule_order, (0,), 0, 4, 1, stream=stream_id)
378
+ current_stream = cuda.CUstream(stream_id)
379
+ ds_partial = None
380
+
381
+ compile_dz_key = ("dz", E, H, I, z.dtype, activation_type)
382
+ if compile_dz_key not in _down_projection_backward_act.compile_cache:
383
+ # I don't know why but this sync appears to fix a mysterious initialization bug??
384
+ torch.cuda.synchronize()
385
+ dz_module = HopperWgmma_MoE_Down_proj_ActGrad_Bwd(E, H, I, ActivationType(activation_type))
386
+ tensormaps = [dz_module.module.generate_tensormap(None, None, None) for _ in range(3)]
387
+
388
+ ds_partial_N = max(ceil_divide(I, dz_module.module.tile_shape_mnk[1]), 1)
389
+ ds_partial = torch.empty(TK, ds_partial_N, dtype=torch.float32, device=topk_scores.device)
390
+ mDS_partial = convert_torch_tensor_to_cute_tensor(ds_partial, (0, 1), 1, 4, 1, stream=stream_id)
391
+
392
+ _down_projection_backward_act.compile_cache["ds_partial_N"] = ds_partial_N
393
+ _down_projection_backward_act.compile_cache[compile_dz_key] = cute.compile(
394
+ dz_module,
395
+ mDout,
396
+ mW2_trans,
397
+ mZ_kernel_input,
398
+ mDz_kernel_input,
399
+ mY1S,
400
+ mS,
401
+ mDS_partial,
402
+ mE_offset,
403
+ mX_gather,
404
+ mS_scatter,
405
+ tensormaps,
406
+ mE_permute_order,
407
+ current_stream,
408
+ )
409
+ _down_projection_backward_act.compile_cache[f"dz-{TENSORMAP}"] = tensormaps
410
+
411
+ if ds_partial is None:
412
+ ds_partial_N = _down_projection_backward_act.compile_cache["ds_partial_N"]
413
+ ds_partial = torch.empty(TK, ds_partial_N, dtype=torch.float32, device=topk_scores.device)
414
+ mDS_partial = convert_torch_tensor_to_cute_tensor(ds_partial, (0, 1), 1, 4, 1, stream=stream_id)
415
+
416
+ dz_tensormaps = _down_projection_backward_act.compile_cache[f"dz-{TENSORMAP}"]
417
+ _down_projection_backward_act.compile_cache[compile_dz_key](
418
+ mDout,
419
+ mW2_trans,
420
+ mZ_kernel_input,
421
+ mDz_kernel_input,
422
+ mY1S,
423
+ mS,
424
+ mDS_partial,
425
+ mE_offset,
426
+ mX_gather,
427
+ mS_scatter,
428
+ dz_tensormaps,
429
+ mE_permute_order,
430
+ current_stream,
431
+ )
432
+
433
+ if db2 is None:
434
+ # we don't need to update ds
435
+ if ds_partial.size(1) == 1:
436
+ ds.copy_(ds_partial.view(-1).to(dtype=ds.dtype))
437
+ elif ds_partial.size(1) <= 32:
438
+ ds.copy_(ds_partial.sum(dim=-1, dtype=ds.dtype))
439
+ else:
440
+ M, N = ds_partial.size()
441
+
442
+ _colsum_smallN_kernel[M,](
443
+ y_ptr=ds,
444
+ x_ptr=ds_partial,
445
+ stride_xm=ds_partial.stride(0),
446
+ stride_xn=ds_partial.stride(1),
447
+ stride_y=1,
448
+ N=N,
449
+ BLOCK_N=triton.next_power_of_2(N),
450
+ )
451
+ else:
452
+ # db2 and ds update
453
+ BLOCK_H = min(triton.next_power_of_2(H), 2048)
454
+ NUM_H_BLOCKS = triton.cdiv(H, BLOCK_H)
455
+
456
+ new_ds_partial = torch.empty(TK, NUM_H_BLOCKS, device=ds.device, dtype=torch.float32)
457
+
458
+ db2_and_ds_kernel[(E, NUM_H_BLOCKS)](
459
+ dout,
460
+ topk_scores,
461
+ new_ds_partial,
462
+ ds_partial,
463
+ b2,
464
+ db2,
465
+ x_gather_idx,
466
+ s_scatter_idx,
467
+ expert_frequency_offset,
468
+ H,
469
+ E,
470
+ ds_partial_N,
471
+ BLOCK_H=BLOCK_H,
472
+ BLOCK_OLD_DS_PARTIAL_N=triton.next_power_of_2(ds_partial_N),
473
+ )
474
+
475
+ if NUM_H_BLOCKS == 1:
476
+ ds.copy_(new_ds_partial.view(-1).to(dtype=ds.dtype))
477
+ else:
478
+ ds.copy_(new_ds_partial.sum(dim=-1, dtype=ds.dtype))
479
+
480
+
481
+ _down_projection_backward_act.compile_cache = {}
482
+
483
+
484
+ @torch.library.custom_op(add_op_namespace_prefix("_down_projection_backward_weight"), mutates_args={"dw2"})
485
+ def _down_projection_backward_weight(
486
+ dout: torch.Tensor,
487
+ y1s: torch.Tensor,
488
+ dw2: torch.Tensor,
489
+ expert_frequency_offset: torch.Tensor,
490
+ expert_schedule_order: torch.Tensor | None,
491
+ x_gather_idx: torch.Tensor,
492
+ stream_id: int,
493
+ ) -> None:
494
+ H, I, E = dw2.size()
495
+
496
+ mDout_trans = convert_torch_tensor_to_cute_tensor(dout.T, (1, 0), 0, 16, 8, stream=stream_id)
497
+ mDw2 = convert_torch_tensor_to_cute_tensor(dw2, (2, 0, 1), 1, 16, 8, stream=stream_id)
498
+ mY1S_trans = convert_torch_tensor_to_cute_tensor(y1s.T, (1, 0), 0, 16, 8, stream=stream_id)
499
+ mE_offset = convert_torch_tensor_to_cute_tensor(expert_frequency_offset, (0,), 0, 4, 1, stream=stream_id)
500
+ mX_gather = convert_torch_tensor_to_cute_tensor(x_gather_idx, (0,), 0, 4, 1, stream=stream_id)
501
+
502
+ if expert_schedule_order is None:
503
+ mE_permute_order = None
504
+ else:
505
+ mE_permute_order = convert_torch_tensor_to_cute_tensor(expert_schedule_order, (0,), 0, 4, 1, stream=stream_id)
506
+ current_stream = cuda.CUstream(stream_id)
507
+
508
+ compile_dw2_key = ("dw2", E, H, I, dw2.dtype)
509
+ if compile_dw2_key not in _down_projection_backward_weight.compile_cache:
510
+ dw2_module = HopperWgmma_MoE_Down_proj_WeightGrad_Bwd(E, H, I)
511
+ tensormaps = [dw2_module.module.generate_tensormap(None, None, None) for _ in range(1)]
512
+ _down_projection_backward_weight.compile_cache[compile_dw2_key] = cute.compile(
513
+ dw2_module,
514
+ mDout_trans,
515
+ mY1S_trans,
516
+ mDw2,
517
+ mE_offset,
518
+ mX_gather,
519
+ tensormaps,
520
+ mE_permute_order,
521
+ current_stream,
522
+ )
523
+ _down_projection_backward_weight.compile_cache[f"dw2-{TENSORMAP}"] = tensormaps
524
+
525
+ dw2_tensormaps = _down_projection_backward_weight.compile_cache[f"dw2-{TENSORMAP}"]
526
+ _down_projection_backward_weight.compile_cache[compile_dw2_key](
527
+ mDout_trans, mY1S_trans, mDw2, mE_offset, mX_gather, dw2_tensormaps, mE_permute_order, current_stream
528
+ )
529
+
530
+
531
+ _down_projection_backward_weight.compile_cache = {}
532
+
533
+
534
+ @torch.library.custom_op(add_op_namespace_prefix("_token_broadcast_backward"), mutates_args={"dx_reduced"})
535
+ def _token_broadcast_backward(
536
+ dx_reduced: torch.Tensor,
537
+ dx_expanded: torch.Tensor,
538
+ s_reverse_scatter_idx: torch.Tensor,
539
+ num_activated_expert_per_token_offset: Optional[torch.Tensor],
540
+ varlen_K_max: int,
541
+ H: int,
542
+ is_varlen_K: bool,
543
+ ) -> None:
544
+ if num_activated_expert_per_token_offset is None:
545
+ assert not is_varlen_K, "`num_activated_expert_per_token_offset` as None requires fixed top-K routing"
546
+ token_gather_and_sum_varlen_K_triton(
547
+ dx_expanded,
548
+ None,
549
+ dx_reduced,
550
+ s_reverse_scatter_idx,
551
+ num_activated_expert_per_token_offset,
552
+ dx_reduced.size(0),
553
+ varlen_K_max,
554
+ H,
555
+ is_varlen_K,
556
+ )
557
+
558
+
559
+ @triton.jit
560
+ def _softmax_bwd_scatter_small_kernel(
561
+ dlogits_ptr,
562
+ dlogits_full_ptr,
563
+ score_ptr,
564
+ dscore_ptr,
565
+ idx_ptr,
566
+ stride_dm: tl.constexpr,
567
+ stride_dn: tl.constexpr,
568
+ stride_sm: tl.constexpr,
569
+ stride_sn: tl.constexpr,
570
+ stride_gm: tl.constexpr,
571
+ stride_gk: tl.constexpr,
572
+ stride_im: tl.constexpr,
573
+ stride_ik: tl.constexpr,
574
+ K: tl.constexpr,
575
+ BLOCK_K: tl.constexpr,
576
+ dlogits_is_none: tl.constexpr,
577
+ ):
578
+ row = tl.program_id(axis=0)
579
+
580
+ # tl.assume(K <= BLOCK_K)
581
+ k_offs = tl.arange(0, BLOCK_K)
582
+ k_mask = k_offs < K
583
+
584
+ idx = tl.load(idx_ptr + row * stride_im + k_offs * stride_ik, mask=k_mask, other=0).to(tl.int32)
585
+ s_sel = tl.load(score_ptr + row * stride_sm + k_offs * stride_sn, mask=k_mask, other=0).to(tl.float32)
586
+ g_sel = tl.load(dscore_ptr + row * stride_gm + k_offs * stride_gk, mask=k_mask, other=0).to(tl.float32)
587
+
588
+ # dot = sum_j g_j * y_j over selected columns
589
+ dot = tl.sum(g_sel * s_sel, axis=0)
590
+
591
+ # scatter-only: dx[idx] += y_sel * (g_sel - dot)
592
+ add_vals = s_sel * (g_sel - dot)
593
+
594
+ indices = row * stride_dm + idx * stride_dn
595
+ if not dlogits_is_none:
596
+ add_vals += tl.load(dlogits_ptr + indices, mask=k_mask)
597
+ tl.store(dlogits_full_ptr + indices, add_vals, mask=k_mask)
598
+
599
+
600
+ @torch.library.custom_op(add_op_namespace_prefix("_softmax_topk_bwd"), mutates_args={"dlogits_full"})
601
+ def _softmax_topk_bwd(
602
+ dlogits_full: torch.Tensor,
603
+ dlogits: Optional[torch.Tensor],
604
+ dtopk_score: torch.Tensor,
605
+ topk_router_score: torch.Tensor,
606
+ topk_router_indices: torch.Tensor,
607
+ K: int,
608
+ ) -> None:
609
+ T = dtopk_score.shape[0]
610
+
611
+ _softmax_bwd_scatter_small_kernel[T,](
612
+ dlogits,
613
+ dlogits_full,
614
+ topk_router_score,
615
+ dtopk_score,
616
+ topk_router_indices,
617
+ dlogits_full.stride(0),
618
+ dlogits_full.stride(1),
619
+ topk_router_score.stride(0),
620
+ topk_router_score.stride(1),
621
+ dtopk_score.stride(0),
622
+ dtopk_score.stride(1),
623
+ topk_router_indices.stride(0),
624
+ topk_router_indices.stride(1),
625
+ K,
626
+ triton.next_power_of_2(K),
627
+ (dlogits is None),
628
+ )
629
+
630
+
631
+ @triton.jit
632
+ def _topk_bwd_scatter_small_kernel(
633
+ dlogits_full_ptr,
634
+ dscore_ptr,
635
+ idx_ptr,
636
+ stride_dm: tl.constexpr,
637
+ stride_dn: tl.constexpr,
638
+ stride_gm: tl.constexpr,
639
+ stride_gk: tl.constexpr,
640
+ stride_im: tl.constexpr,
641
+ stride_ik: tl.constexpr,
642
+ K: tl.constexpr,
643
+ BLOCK_K: tl.constexpr,
644
+ ):
645
+ row = tl.program_id(axis=0)
646
+
647
+ # tl.assume(K <= BLOCK_K)
648
+ k_offs = tl.arange(0, BLOCK_K)
649
+ k_mask = k_offs < K
650
+
651
+ idx = tl.load(idx_ptr + row * stride_im + k_offs * stride_ik, mask=k_mask, other=0).to(tl.int32)
652
+ g_sel = tl.load(dscore_ptr + row * stride_gm + k_offs * stride_gk, mask=k_mask, other=0).to(tl.float32)
653
+
654
+ # scatter-only: dx[idx] += y_sel * (g_sel - dot)
655
+ add_vals = g_sel
656
+
657
+ indices = row * stride_dm + idx * stride_dn
658
+ tl.store(dlogits_full_ptr + indices, add_vals, mask=k_mask)
659
+
660
+
661
+ @torch.library.custom_op(add_op_namespace_prefix("_topk_bwd"), mutates_args={"dlogits_full"})
662
+ def _topk_bwd(
663
+ dlogits_full: torch.Tensor,
664
+ dtopk_values: torch.Tensor,
665
+ topk_indices: torch.Tensor,
666
+ K: int,
667
+ ) -> None:
668
+ T = dtopk_values.shape[0]
669
+
670
+ _topk_bwd_scatter_small_kernel[T,](
671
+ dlogits_full,
672
+ dtopk_values,
673
+ topk_indices,
674
+ dlogits_full.stride(0),
675
+ dlogits_full.stride(1),
676
+ dtopk_values.stride(0),
677
+ dtopk_values.stride(1),
678
+ topk_indices.stride(0),
679
+ topk_indices.stride(1),
680
+ K,
681
+ triton.next_power_of_2(K),
682
+ )
build/torch-cuda/functional/forward.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ********************************************************************************
2
+ # Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao
3
+ # ********************************************************************************
4
+
5
+ import cuda.bindings.driver as cuda
6
+ import cutlass.cute as cute
7
+ import torch
8
+ import triton
9
+ import triton.language as tl
10
+ from cutlass.cute.runtime import from_dlpack
11
+ from ..quack.cute_dsl_utils import torch2cute_dtype_map
12
+
13
+ from ..enums import LIBRARY_NAME, TENSORMAP, ActivationType
14
+ from .._ops_compat import add_op_namespace_prefix
15
+ from ..utils import convert_torch_tensor_to_cute_tensor
16
+ from .moe_config import HopperWgmma_MoE_Down_proj_Fwd, HopperWgmma_MoE_Up_proj_Fwd
17
+ from .reduction_over_k_gather import token_gather_and_sum_varlen_K_triton
18
+ from .topk_softmax import TopK_Softmax
19
+
20
+
@torch.library.custom_op(add_op_namespace_prefix("_topk_fwd"), mutates_args={"values", "indices"})
def _topk_fwd(
    x: torch.Tensor, k: int, values: torch.Tensor, indices: torch.Tensor, require_softmax_fusion: bool = True
) -> None:
    """Top-k forward pass.
    Args:
        x: Input tensor of shape (M, N)
        k: Number of top elements to return
        values: Output buffer of shape (M, k), written in place
        indices: Output buffer of shape (M, k), written in place
        require_softmax_fusion: If True, the kernel also applies softmax to
            the selected values.
    """
    N = x.size(1)

    input_dtype = torch2cute_dtype_map[x.dtype]
    output_dtype = torch2cute_dtype_map[values.dtype]
    # Wrap as CuTe tensors; mode 0 (the row count) stays dynamic so one
    # compiled kernel serves any batch size.
    convert_from_dlpack = lambda tensor: (
        from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(mode=0, stride_order=(0, 1))
    )

    x_tensor, values_tensor, indices_tensor = [convert_from_dlpack(tensor) for tensor in (x, values, indices)]
    current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
    compile_key = (input_dtype, output_dtype, N, k, require_softmax_fusion)
    if compile_key not in _topk_fwd.compile_cache:
        topk_op = TopK_Softmax(input_dtype, output_dtype, N, k, require_softmax_fusion)
        _topk_fwd.compile_cache[compile_key] = cute.compile(
            topk_op, x_tensor, values_tensor, indices_tensor, current_stream
        )
    _topk_fwd.compile_cache[compile_key](x_tensor, values_tensor, indices_tensor, current_stream)


# Per-process memoization of compiled top-k kernels.
_topk_fwd.compile_cache = {}
+
53
+
@torch.library.custom_op(add_op_namespace_prefix("_up_projection_forward"), mutates_args={"z", "y1"})
def _up_projection_forward(
    x: torch.Tensor,
    w1: torch.Tensor,
    z: torch.Tensor,
    y1: torch.Tensor,
    b1: torch.Tensor | None,
    expert_frequency_offset: torch.Tensor,
    expert_schedule_order: torch.Tensor,
    x_gather_idx: torch.Tensor,
    stream_id: int,
    activation_type: str,
    is_glu_activation: bool,
    is_inference_mode_enabled: bool = False,
) -> None:
    """MoE up-projection forward (in-place op).

    Grouped GEMM with token gather: writes the pre-activation Z = X @ W1
    (+ b1) into ``z`` and the activated output into ``y1``.  Compiled
    kernels are memoized on ``_up_projection_forward.compile_cache``.
    """
    I, H, E = w1.size()
    if is_glu_activation:
        # GLU weights stack gate and up halves; logical I is half.
        I //= 2

    mX = convert_torch_tensor_to_cute_tensor(x.detach(), (0, 1), 1, 16, 8, stream=stream_id)
    mW1 = convert_torch_tensor_to_cute_tensor(w1.detach(), (2, 0, 1), 1, 16, 8, stream=stream_id)
    mZ = convert_torch_tensor_to_cute_tensor(z, (0, 1), 1, 16, 8, stream=stream_id)
    mY1 = convert_torch_tensor_to_cute_tensor(y1, (0, 1), 1, 16, 8, stream=stream_id)
    mE_offset = convert_torch_tensor_to_cute_tensor(expert_frequency_offset, (0,), 0, 4, 1, stream=stream_id)
    mX_gather = convert_torch_tensor_to_cute_tensor(x_gather_idx, (0,), 0, 4, 1, stream=stream_id)

    if expert_schedule_order is None:
        mE_permute_order = None
    else:
        mE_permute_order = convert_torch_tensor_to_cute_tensor(expert_schedule_order, (0,), 0, 4, 1, stream=stream_id)

    if b1 is None:
        mB1 = None
    else:
        mB1 = convert_torch_tensor_to_cute_tensor(b1.detach(), (0, 1), 1, 16, 8, stream=stream_id)

    current_stream = cuda.CUstream(stream_id)

    compile_w1_key = (E, H, I, (b1 is None), x.dtype, activation_type, is_inference_mode_enabled)
    if compile_w1_key not in _up_projection_forward.compile_cache:
        w1_module = HopperWgmma_MoE_Up_proj_Fwd(
            E, H, I, activation_type=ActivationType(activation_type), inference_mode=is_inference_mode_enabled
        )
        tensormaps = [w1_module.module.generate_tensormap(None, None, None) for _ in range(2)]
        _up_projection_forward.compile_cache[compile_w1_key] = cute.compile(
            w1_module,
            mX,
            mW1,
            mZ,
            mY1,
            mB1,
            mE_offset,
            mX_gather,
            tensormaps[0],
            tensormaps[1],
            mE_permute_order,
            current_stream,
        )
        # NOTE(review): the tensormaps are cached under the single shared
        # TENSORMAP key, so compiling a second distinct config overwrites
        # the first config's tensormaps — confirm they are interchangeable.
        _up_projection_forward.compile_cache[TENSORMAP] = tensormaps

    w1_tensormaps = _up_projection_forward.compile_cache[TENSORMAP]
    _up_projection_forward.compile_cache[compile_w1_key](
        mX,
        mW1,
        mZ,
        mY1,
        mB1,
        mE_offset,
        mX_gather,
        w1_tensormaps[0],
        w1_tensormaps[1],
        mE_permute_order,
        current_stream,
    )


# Per-process memoization of compiled kernels and tensormap scratch buffers.
_up_projection_forward.compile_cache = {}
+
132
+
@torch.library.custom_op(add_op_namespace_prefix("_down_projection_forward"), mutates_args={"y2"})
def _down_projection_forward(
    w2: torch.Tensor,
    y1: torch.Tensor,
    y2: torch.Tensor,
    b2: torch.Tensor | None,
    expert_frequency_offset: torch.Tensor,
    expert_schedule_order: torch.Tensor,
    x_gather_idx: torch.Tensor,
    stream_id: int,
) -> None:
    """MoE down-projection forward (in-place op).

    Grouped GEMM: writes Y2 = Y1 @ W2 (+ b2) into ``y2``.  Compiled kernels
    are memoized on ``_down_projection_forward.compile_cache``.
    """
    H, I, E = w2.size()

    mW2 = convert_torch_tensor_to_cute_tensor(w2.detach(), (2, 0, 1), 1, 16, 8, stream=stream_id)
    mY1 = convert_torch_tensor_to_cute_tensor(y1.detach(), (0, 1), 1, 16, 8, stream=stream_id)
    mY2 = convert_torch_tensor_to_cute_tensor(y2, (0, 1), 1, 16, 8, stream=stream_id)
    mE_offset = convert_torch_tensor_to_cute_tensor(expert_frequency_offset, (0,), 0, 4, 1, stream=stream_id)
    mX_gather = convert_torch_tensor_to_cute_tensor(x_gather_idx, (0,), 0, 4, 1, stream=stream_id)

    if expert_schedule_order is None:
        mE_permute_order = None
    else:
        mE_permute_order = convert_torch_tensor_to_cute_tensor(expert_schedule_order, (0,), 0, 4, 1, stream=stream_id)

    if b2 is None:
        mB2 = None
    else:
        mB2 = convert_torch_tensor_to_cute_tensor(b2.detach(), (0, 1), 1, 16, 8, stream=stream_id)

    current_stream = cuda.CUstream(stream_id)

    compile_w2_key = (E, H, I, (b2 is None), w2.dtype)
    if compile_w2_key not in _down_projection_forward.compile_cache:
        w2_module = HopperWgmma_MoE_Down_proj_Fwd(E, H, I)
        tensormaps = [w2_module.module.generate_tensormap(None, None, None) for _ in range(1)]
        _down_projection_forward.compile_cache[compile_w2_key] = cute.compile(
            w2_module, mY1, mW2, mY2, mB2, mE_offset, mX_gather, tensormaps[0], mE_permute_order, current_stream
        )
        # NOTE(review): tensormaps cached under the single shared TENSORMAP
        # key across all compile keys — confirm interchangeable.
        _down_projection_forward.compile_cache[TENSORMAP] = tensormaps

    w2_tensormaps = _down_projection_forward.compile_cache[TENSORMAP]
    _down_projection_forward.compile_cache[compile_w2_key](
        mY1, mW2, mY2, mB2, mE_offset, mX_gather, w2_tensormaps[0], mE_permute_order, current_stream
    )


# Per-process memoization of compiled kernels and tensormap scratch buffers.
_down_projection_forward.compile_cache = {}
+
181
+
@torch.library.custom_op(add_op_namespace_prefix("_router_forward"), mutates_args={"o"})
def _router_forward(
    y2: torch.Tensor,
    o: torch.Tensor,
    topk_scores: torch.Tensor,
    s_reverse_scatter_idx: torch.Tensor,
    num_activated_expert_per_token_offset: torch.Tensor,
    varlen_K_max: int,
    H: int,
    is_varlen_K: bool,
) -> None:
    """Combine step of the MoE forward: gather each token's K expert outputs
    from ``y2``, scale them by ``topk_scores``, and sum into ``o`` (in place).
    """
    token_gather_and_sum_varlen_K_triton(
        y2,
        topk_scores,
        o,
        s_reverse_scatter_idx,
        num_activated_expert_per_token_offset,
        o.size(0),
        varlen_K_max,
        H,
        is_varlen_K,
    )
+
@triton.jit
def _softmax_fwd_small_kernel(
    logits_ptr, stride_lm: tl.constexpr, stride_ln: tl.constexpr, K: tl.constexpr, BLOCK_K: tl.constexpr
):
    """In-place numerically-stable softmax over a small row (K <= BLOCK_K).

    One program per row; the row is loaded, shifted by its max, exponentiated,
    normalized in fp32, and stored back over the input.
    """
    row = tl.program_id(axis=0)

    # tl.assume(K <= BLOCK_K)
    k_offs = tl.arange(0, BLOCK_K)
    k_mask = k_offs < K

    # load full row (all columns) in one go (N is small)
    x = tl.load(logits_ptr + row * stride_lm + k_offs * stride_ln, mask=k_mask, other=-float("inf")).to(tl.float32)
    x = x - tl.max(x, axis=0)  # subtract row max for numerical stability
    ex = tl.exp(x)
    y = ex / tl.sum(ex, axis=0)

    tl.store(logits_ptr + row * stride_lm + k_offs * stride_ln, y, mask=k_mask)
+
@torch.library.custom_op(
    add_op_namespace_prefix("_softmax_topk_fwd"), mutates_args={"topk_router_score", "topk_router_indices"}
)
def _softmax_topk_fwd(
    router_logits: torch.Tensor, topk_router_score: torch.Tensor, topk_router_indices: torch.Tensor, E: int, K: int
) -> None:
    """Router forward: select top-K logits per token and softmax-normalize
    the selected values, writing scores/indices in place.

    Uses the fused CuTe top-k+softmax kernel when the shape is supported,
    otherwise falls back to plain torch ``topk`` + ``softmax``.
    """
    # T = router_logits.shape[0]
    if E <= 4096 and K <= 16 and E % 8 == 0:
        # fast topk-softmax fusion that covers most common MoE configs
        _topk_fwd(router_logits, K, topk_router_score, topk_router_indices, require_softmax_fusion=True)
    else:
        # Fallback: softmax over the selected values only, computed in fp32.
        topk_results = router_logits.topk(K, dim=-1)
        topk_router_score.copy_(topk_results.values.softmax(dim=-1, dtype=torch.float32).to(topk_router_score.dtype))
        topk_router_indices.copy_(topk_results.indices.to(topk_router_indices.dtype))
build/torch-cuda/functional/grouped_gemm.py ADDED
The diff for this file is too large to render. See raw diff
 
build/torch-cuda/functional/moe_config.py ADDED
@@ -0,0 +1,581 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ********************************************************************************
2
+ # Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao
3
+ # ********************************************************************************
4
+
5
+ import math
6
+ from dataclasses import dataclass
7
+
8
+ import cuda.bindings.driver as cuda
9
+ import cutlass
10
+ import cutlass.cute as cute
11
+ import torch
12
+ from cutlass import const_expr
13
+ from ..quack.tile_scheduler import RasterOrderOption
14
+
15
+ from ..enums import ActivationType, is_glu
16
+ from .grouped_gemm import HopperWgmma_MoE_kernel
17
+
18
+
19
+ LIBRARY_NAME = "cutedsl_kernels"
20
+
21
+
22
+ def ceil_div(a: int, b: int):
23
+ return int(math.ceil(a / b))
24
+
25
+
@dataclass
class HopperGEMMConfig:
    """Static tile/scheduling configuration for a Hopper WGMMA grouped GEMM."""

    # CTA tile shape (M, N, K).
    tile_shape_mnk: cutlass.Constexpr[cute.Shape] = (128, 256, 64)
    # NOTE(review): despite the "mnk" suffix this is a 2-tuple; callers
    # append the K cluster dim as 1 (see `(*cluster_shape_mnk, 1)`) — confirm.
    cluster_shape_mnk: cutlass.Constexpr[cute.Shape] = (2, 1)
    # Epilogue sub-tile size.
    epi_tile_size: cutlass.Constexpr[int] = 32
    ## assume we always use persistent kernel
    # is_persistent: cutlass.Constexpr[bool] = True
    # Ping-pong warp-group scheduling (vs cooperative).
    is_pingpong: cutlass.Constexpr[bool] = False
    # CTA rasterization order for L2 locality.
    raster_order: RasterOrderOption = RasterOrderOption.Heuristic
    L2_group_size: int = 8
    # Initial number of D-epilogue pipeline stages.
    initial_d_epi_stage: cutlass.Constexpr[int] = 4
+
39
+ class HopperWgmma_MoE_Up_proj_Fwd:
40
    def __init__(self, E: int, H: int, I: int, activation_type: ActivationType, inference_mode=False):
        """Build the fused up-projection forward grouped-GEMM for a Hopper MoE layer.

        Args:
            E: number of experts.
            H: hidden (model) dimension.
            I: intermediate (FFN) dimension per expert.
            activation_type: activation fused into the GEMM epilogue.
            inference_mode: if True, selects larger epilogue tiles tuned for inference.
        """
        super().__init__()
        is_glu_activation = is_glu(activation_type)
        # GLU variants accept I % 64 == 0; non-GLU requires I % 128 == 0.
        if is_glu_activation:
            assert (
                H % 64 == 0 and H >= 512 and I % 64 == 0
            ), f"{LIBRARY_NAME} only supports GLU MoE with H % 64 == 0 (H >= 512) and I % 64 == 0"
        else:
            assert (
                H % 64 == 0 and H >= 512 and I % 128 == 0
            ), f"{LIBRARY_NAME} only supports non-GLU MoE with H % 64 == 0 (H >= 512) and I % 128 == 0"
        # TODO: this assertion does not mean that the MoE impl prohibits such config.
        # Instead, we just do not search for the best configs manually yet for small-shaped MoE
        if (I >= 128 and is_glu_activation) or (I >= 256 and not is_glu_activation):
            # Large intermediate dim: cooperative (non-pingpong) schedule, 2x1 cluster.
            up_config = HopperGEMMConfig(
                tile_shape_mnk=(128, 256, 64),
                cluster_shape_mnk=(2, 1),
                epi_tile_size=(32 if not inference_mode else 64),
                is_pingpong=False,
                initial_d_epi_stage=2,
                raster_order=RasterOrderOption.AlongM,
            )
        elif (I == 64 and is_glu_activation) or (I == 128 and not is_glu_activation):
            # Smallest supported intermediate dim: pingpong schedule, single-CTA cluster.
            up_config = HopperGEMMConfig(
                tile_shape_mnk=(192, 128, 64),
                cluster_shape_mnk=(1, 1),
                epi_tile_size=(32 if not inference_mode else 64),
                is_pingpong=True,
                initial_d_epi_stage=8,
                raster_order=RasterOrderOption.AlongM,
            )
        else:
            raise NotImplementedError()

        # Exactly one of the flags below is set; each is forwarded to the kernel
        # as a compile-time switch selecting the fused epilogue activation.
        compute_swiglu = False
        compute_geglu = False
        compute_reglu = False

        compute_relu_sq = False
        compute_silu = False
        compute_relu = False
        compute_gelu = False

        if activation_type == ActivationType.SWIGLU:
            compute_swiglu = True
        elif activation_type == ActivationType.GEGLU:
            compute_geglu = True
        elif activation_type == ActivationType.REGLU:
            compute_reglu = True

        elif activation_type == ActivationType.RELU_SQ:
            compute_relu_sq = True
        elif activation_type == ActivationType.RELU:
            compute_relu = True
        elif activation_type == ActivationType.SILU:
            compute_silu = True
        elif activation_type == ActivationType.GELU:
            compute_gelu = True

        else:
            raise NotImplementedError(f"Activation function {activation_type} not supported yet!")

        # is_A_gather=True: rows of the A operand (tokens) are gathered through
        # an index tensor so each expert sees only its routed tokens.
        self.module = HopperWgmma_MoE_kernel(
            E,
            cutlass.Float32,
            up_config.tile_shape_mnk,
            (*up_config.cluster_shape_mnk, 1),
            pingpong=up_config.is_pingpong,
            is_persistent=True,
            compute_swiglu=compute_swiglu,
            compute_reglu=compute_reglu,
            compute_geglu=compute_geglu,
            compute_relu_sq=compute_relu_sq,
            compute_relu=compute_relu,
            compute_silu=compute_silu,
            compute_gelu=compute_gelu,
            is_A_gather=True,
            epi_tile_size=up_config.epi_tile_size,
            initial_d_epi_stage=up_config.initial_d_epi_stage,
            inference_mode=inference_mode,
        )
        # Occupancy bound for the persistent grid, in clusters.
        self.max_active_clusters = cutlass.utils.HardwareInfo().get_max_active_clusters(
            up_config.cluster_shape_mnk[0] * up_config.cluster_shape_mnk[1]
        )
        self.current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+
126
    @cute.jit
    def __call__(
        self, mX, mW1, mZ, mY1, mB1, mE_offset, mX_gather, mD_tensormap, mY1_tensormap, mE_permute_order, stream
    ):
        """Launch the up-projection forward kernel.

        The underlying grouped-GEMM kernel has one shared signature across all
        MoE variants; slots unused by this projection are passed as None so the
        traced call site stays static.
        """
        return self.module(
            mX,
            mW1,
            None,
            mB1,
            mZ,
            mY1,
            None,
            None,
            mE_offset,
            mX_gather,
            None,
            None,
            None,
            None,
            None,
            mD_tensormap,
            mY1_tensormap,
            None,
            mE_permute_order,
            const_expr(self.max_active_clusters),
            stream,
        )
+
154
+
155
class HopperWgmma_MoE_Down_proj_Fwd:
    """Down-projection forward grouped GEMM (Y2 = Y1 @ W2 + B2) for a Hopper MoE layer.

    No activation is fused (the activation is applied in the up projection),
    and the A operand is read contiguously (is_A_gather=False) since Y1 is
    already expert-permuted.
    """

    def __init__(self, E: int, H: int, I: int):
        super().__init__()
        assert (
            H % 64 == 0 and H >= 512 and I % 64 == 0
        ), f"{LIBRARY_NAME} only supports MoE with H % 64 == 0 (H >= 512) and I % 64 == 0"
        # Hand-tuned configs keyed on the intermediate dim I (the GEMM K dim here).
        if I >= 1024:
            down_config = HopperGEMMConfig(
                tile_shape_mnk=(128, 256, 64),
                cluster_shape_mnk=(2, 1),
                epi_tile_size=32,
                is_pingpong=False,
                initial_d_epi_stage=4,
                raster_order=RasterOrderOption.AlongN,
            )
        elif I >= 256:
            down_config = HopperGEMMConfig(
                tile_shape_mnk=(128, 192, 64),
                cluster_shape_mnk=(2, 1),
                # Prefer a 96-wide epilogue tile when H divides evenly by 96.
                epi_tile_size=(96 if H % 96 == 0 else 64),
                is_pingpong=True,
                initial_d_epi_stage=5,
                raster_order=RasterOrderOption.AlongN,
            )
        elif I >= 64:
            down_config = HopperGEMMConfig(
                tile_shape_mnk=(128, 192, 64),
                cluster_shape_mnk=(1, 2),
                epi_tile_size=64,
                is_pingpong=True,
                initial_d_epi_stage=8,
                raster_order=RasterOrderOption.AlongN,
            )
        else:
            raise NotImplementedError()

        self.module = HopperWgmma_MoE_kernel(
            E,
            cutlass.Float32,
            down_config.tile_shape_mnk,
            (*down_config.cluster_shape_mnk, 1),
            pingpong=down_config.is_pingpong,
            is_persistent=True,
            compute_swiglu=False,
            is_A_gather=False,
            epi_tile_size=down_config.epi_tile_size,
            initial_d_epi_stage=down_config.initial_d_epi_stage,
        )
        # Occupancy bound for the persistent grid, in clusters.
        self.max_active_clusters = cutlass.utils.HardwareInfo().get_max_active_clusters(
            down_config.cluster_shape_mnk[0] * down_config.cluster_shape_mnk[1]
        )

    @cute.jit
    def __call__(self, mY1, mW2, mY2, mB2, mE_offset, mX_gather, mD_tensormap, mE_permute_order, stream):
        # we are not really using mX_gather in the Grouped GEMM,
        # but CuTe-DSL compiler disallows dynamic flow so we still need to pass this argument
        return self.module(
            mY1,
            mW2,
            None,
            mB2,
            mY2,
            None,
            None,
            None,
            mE_offset,
            mX_gather,
            None,
            None,
            None,
            None,
            None,
            mD_tensormap,
            None,
            None,
            mE_permute_order,
            const_expr(self.max_active_clusters),
            stream,
        )
+
235
+
236
class HopperWgmma_MoE_Down_proj_ActGrad_Bwd:
    """Backward grouped GEMM through the down projection with fused activation gradient.

    Computes dZ (gradient w.r.t. the up-projection output), a partial dS
    (routing-score gradient) and Y1S in a single fused kernel
    (compute_dz_and_partial_ds_and_y1s=True).
    """

    def __init__(self, E: int, H: int, I: int, activation_type: ActivationType):
        super().__init__()
        is_glu_activation = is_glu(activation_type)
        # GLU variants accept I % 64 == 0; non-GLU requires I % 128 == 0.
        if is_glu_activation:
            assert (
                H % 64 == 0 and H >= 512 and I % 64 == 0
            ), f"{LIBRARY_NAME} only supports GLU MoE with H % 64 == 0 (H >= 512) and I % 64 == 0"
        else:
            assert (
                H % 64 == 0 and H >= 512 and I % 128 == 0
            ), f"{LIBRARY_NAME} only supports non-GLU MoE with H % 64 == 0 (H >= 512) and I % 128 == 0"

        # heavy register pressure due to pingpong + heavy epilogue
        # effectively no alternatives to this config
        dz_partial_ds_config = HopperGEMMConfig(
            tile_shape_mnk=(128, 128, 64),
            cluster_shape_mnk=(2, 1),
            epi_tile_size=32,
            initial_d_epi_stage=4,
            is_pingpong=True,
            raster_order=RasterOrderOption.Heuristic,
        )

        # Exactly one of the flags below is set; each is forwarded to the kernel
        # as a compile-time switch selecting which activation gradient to fuse.
        compute_swiglu = False
        compute_geglu = False
        compute_reglu = False

        compute_relu_sq = False
        compute_silu = False
        compute_relu = False
        compute_gelu = False

        if activation_type == ActivationType.SWIGLU:
            compute_swiglu = True
        elif activation_type == ActivationType.GEGLU:
            compute_geglu = True
        elif activation_type == ActivationType.REGLU:
            compute_reglu = True

        elif activation_type == ActivationType.RELU_SQ:
            compute_relu_sq = True
        elif activation_type == ActivationType.RELU:
            compute_relu = True
        elif activation_type == ActivationType.SILU:
            compute_silu = True
        elif activation_type == ActivationType.GELU:
            compute_gelu = True

        else:
            raise NotImplementedError(f"Activation function {activation_type} not supported yet!")

        self.module = HopperWgmma_MoE_kernel(
            E,
            cutlass.Float32,
            dz_partial_ds_config.tile_shape_mnk,
            (*dz_partial_ds_config.cluster_shape_mnk, 1),
            pingpong=dz_partial_ds_config.is_pingpong,
            is_persistent=True,
            compute_swiglu=compute_swiglu,
            compute_reglu=compute_reglu,
            compute_geglu=compute_geglu,
            compute_relu_sq=compute_relu_sq,
            compute_relu=compute_relu,
            compute_silu=compute_silu,
            compute_gelu=compute_gelu,
            compute_dz_and_partial_ds_and_y1s=True,
            is_A_gather=True,
            epi_tile_size=dz_partial_ds_config.epi_tile_size,
            initial_d_epi_stage=dz_partial_ds_config.initial_d_epi_stage,
        )
        # Occupancy bound for the persistent grid, in clusters.
        self.max_active_clusters = cutlass.utils.HardwareInfo().get_max_active_clusters(
            dz_partial_ds_config.cluster_shape_mnk[0] * dz_partial_ds_config.cluster_shape_mnk[1]
        )

    @cute.jit
    def __call__(
        self,
        mDout,
        mW2_trans,
        mZ_FP32_if_GLU_else_BF16,
        mDz_FP32_if_GLU_else_BF16,
        mY1S,
        mS,
        mDS_partial,
        mE_offset,
        mX_gather,
        mS_scatter,
        tensormaps,
        mE_permute_order,
        stream,
    ):
        """Launch the fused dZ / partial-dS / Y1S backward kernel.

        Unused slots of the shared kernel signature are passed as None so the
        traced call site stays static.
        """
        return self.module(
            mDout,
            mW2_trans,
            mZ_FP32_if_GLU_else_BF16,
            None,
            mDz_FP32_if_GLU_else_BF16,
            mY1S,
            mS,
            mDS_partial,
            mE_offset,
            mX_gather,
            None,
            mS_scatter,
            None,
            None,
            tensormaps[0],
            tensormaps[1],
            tensormaps[2],
            None,
            mE_permute_order,
            const_expr(self.max_active_clusters),
            stream,
        )
+
352
+
353
class HopperWgmma_MoE_Down_proj_WeightGrad_Bwd:
    """Weight-gradient grouped GEMM for the down projection (dW2 = Dout^T @ Y1S).

    Uses compute_weight_gradient=True so the kernel accumulates per-expert
    weight gradients; token rows are gathered (is_A_gather=True).
    """

    def __init__(self, E: int, H: int, I: int):
        super().__init__()
        assert (
            H % 64 == 0 and H >= 512 and I % 64 == 0
        ), f"{LIBRARY_NAME} only supports MoE with H % 64 == 0 (H >= 512) and I % 64 == 0"

        # Hand-tuned configs keyed on the intermediate dim I.
        if I >= 128:
            dw2_config = HopperGEMMConfig(
                tile_shape_mnk=(128, 256, 64),
                cluster_shape_mnk=(2, 1),
                epi_tile_size=16,
                is_pingpong=False,
                initial_d_epi_stage=6,
                raster_order=RasterOrderOption.AlongN,
            )
        elif I == 64:
            dw2_config = HopperGEMMConfig(
                tile_shape_mnk=(64, 192, 64),
                cluster_shape_mnk=(2, 1),
                epi_tile_size=32,
                is_pingpong=True,
                initial_d_epi_stage=6,
                raster_order=RasterOrderOption.AlongN,
            )
        else:
            raise NotImplementedError()

        self.module = HopperWgmma_MoE_kernel(
            E,
            cutlass.Float32,
            dw2_config.tile_shape_mnk,
            (*dw2_config.cluster_shape_mnk, 1),
            pingpong=dw2_config.is_pingpong,
            is_persistent=True,
            compute_swiglu=False,
            compute_weight_gradient=True,
            compute_dz_and_partial_ds_and_y1s=False,
            is_A_gather=True,
            epi_tile_size=dw2_config.epi_tile_size,
            initial_d_epi_stage=dw2_config.initial_d_epi_stage,
        )
        # Occupancy bound for the persistent grid, in clusters.
        self.max_active_clusters = cutlass.utils.HardwareInfo().get_max_active_clusters(
            dw2_config.cluster_shape_mnk[0] * dw2_config.cluster_shape_mnk[1]
        )

    @cute.jit
    def __call__(self, mDout_trans, mY1S_trans, mDw2, mE_offset, mX_gather, tensormaps, mE_permute_order, stream):
        """Launch the dW2 weight-gradient kernel.

        Unused slots of the shared kernel signature are passed as None so the
        traced call site stays static.
        """
        return self.module(
            mDout_trans,
            mY1S_trans,
            None,
            None,
            mDw2,
            None,
            None,
            None,
            mE_offset,
            mX_gather,
            None,
            None,
            None,
            tensormaps[0],
            None,
            None,
            None,
            None,
            mE_permute_order,
            const_expr(self.max_active_clusters),
            stream,
        )
+
425
+
426
class HopperWgmma_MoE_Up_proj_ActGrad_Bwd:
    """Backward grouped GEMM through the up projection (dX_expanded = dZ @ W1^T).

    Rows are scattered back to token order via mS_scatter; no activation is
    fused here (the activation gradient is handled in the down-proj backward).
    """

    def __init__(self, E: int, H: int, I: int, is_glu_activation: bool):
        super().__init__()
        # GLU variants accept I % 64 == 0; non-GLU requires I % 128 == 0.
        if is_glu_activation:
            assert (
                H % 64 == 0 and H >= 512 and I % 64 == 0
            ), f"{LIBRARY_NAME} only supports GLU MoE with H % 64 == 0 (H >= 512) and I % 64 == 0"
        else:
            assert (
                H % 64 == 0 and H >= 512 and I % 128 == 0
            ), f"{LIBRARY_NAME} only supports non-GLU MoE with H % 64 == 0 (H >= 512) and I % 128 == 0"

        if (I >= 512 and is_glu_activation) or (I >= 1024 and not is_glu_activation):
            dx_config = HopperGEMMConfig(
                tile_shape_mnk=(128, 256, 64),
                cluster_shape_mnk=(2, 1),
                epi_tile_size=32,
                is_pingpong=False,
                initial_d_epi_stage=4,
                raster_order=RasterOrderOption.AlongN,
            )
        elif (I >= 64 and is_glu_activation) or (I >= 128 and not is_glu_activation):
            dx_config = HopperGEMMConfig(
                tile_shape_mnk=(128, 192, 64),
                cluster_shape_mnk=(2, 1),
                epi_tile_size=64,
                is_pingpong=True,
                initial_d_epi_stage=8,
                raster_order=RasterOrderOption.AlongN,
            )
        else:
            raise NotImplementedError()

        # NOTE(review): dx_config.initial_d_epi_stage is set above but not
        # forwarded to the kernel here, unlike the other projection classes —
        # confirm whether the kernel default is intended for this path.
        self.module = HopperWgmma_MoE_kernel(
            E,
            cutlass.Float32,
            dx_config.tile_shape_mnk,
            (*dx_config.cluster_shape_mnk, 1),
            pingpong=dx_config.is_pingpong,
            is_persistent=True,
            compute_swiglu=False,
            compute_dz_and_partial_ds_and_y1s=False,
            is_A_gather=False,
            epi_tile_size=dx_config.epi_tile_size,
        )

        # Occupancy bound for the persistent grid, in clusters.
        self.max_active_clusters = cutlass.utils.HardwareInfo().get_max_active_clusters(
            dx_config.cluster_shape_mnk[0] * dx_config.cluster_shape_mnk[1]
        )
        self.current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    @cute.jit
    def __call__(
        self, mDz, mW1_trans, mDx_expanded, mE_offset, mX_gather, mS_scatter, tensormaps, mE_permute_order, stream
    ):
        """Launch the dX backward kernel.

        Unused slots of the shared kernel signature are passed as None so the
        traced call site stays static.
        """
        return self.module(
            mDz,
            mW1_trans,
            None,
            None,
            mDx_expanded,
            None,
            None,
            None,
            mE_offset,
            mX_gather,
            None,
            mS_scatter,
            None,
            None,
            None,
            tensormaps[0],
            tensormaps[1],
            None,
            mE_permute_order,
            const_expr(self.max_active_clusters),
            stream,
        )
+
505
+
506
class HopperWgmma_MoE_Up_proj_WeightGrad_Bwd:
    """Weight-gradient grouped GEMM for the up projection (dW1^T = X^T @ dZ).

    Uses compute_weight_gradient=True with gathered token rows
    (is_A_gather=True) to accumulate per-expert W1 gradients.
    """

    def __init__(self, E: int, H: int, I: int, is_glu_activation: bool):
        super().__init__()
        # GLU variants accept I % 64 == 0; non-GLU requires I % 128 == 0.
        if is_glu_activation:
            assert (
                H % 64 == 0 and H >= 512 and I % 64 == 0
            ), f"{LIBRARY_NAME} only supports GLU MoE with H % 64 == 0 (H >= 512) and I % 64 == 0"
        else:
            assert (
                H % 64 == 0 and H >= 512 and I % 128 == 0
            ), f"{LIBRARY_NAME} only supports non-GLU MoE with H % 64 == 0 (H >= 512) and I % 128 == 0"

        if (I >= 128 and is_glu_activation) or (I >= 256 and not is_glu_activation):
            dw1_config = HopperGEMMConfig(
                tile_shape_mnk=(128, 256, 64),
                cluster_shape_mnk=(2, 1),
                epi_tile_size=16,
                is_pingpong=False,
                initial_d_epi_stage=6,
                raster_order=RasterOrderOption.Heuristic,
            )
        elif (I == 64 and is_glu_activation) or (I == 128 and not is_glu_activation):
            dw1_config = HopperGEMMConfig(
                tile_shape_mnk=(256, 128, 64),
                cluster_shape_mnk=(2, 1),
                epi_tile_size=16,
                is_pingpong=False,
                initial_d_epi_stage=6,
                raster_order=RasterOrderOption.AlongN,
            )
        else:
            raise NotImplementedError()

        # NOTE(review): dw1_config.initial_d_epi_stage is set above but not
        # forwarded to the kernel here, unlike the down-proj weight-grad class —
        # confirm whether the kernel default is intended for this path.
        self.module = HopperWgmma_MoE_kernel(
            E,
            cutlass.Float32,
            dw1_config.tile_shape_mnk,
            (*dw1_config.cluster_shape_mnk, 1),
            pingpong=dw1_config.is_pingpong,
            is_persistent=True,
            compute_swiglu=False,
            compute_weight_gradient=True,
            compute_dz_and_partial_ds_and_y1s=False,
            is_A_gather=True,
            epi_tile_size=dw1_config.epi_tile_size,
        )

        # Occupancy bound for the persistent grid, in clusters.
        self.max_active_clusters = cutlass.utils.HardwareInfo().get_max_active_clusters(
            dw1_config.cluster_shape_mnk[0] * dw1_config.cluster_shape_mnk[1]
        )

    @cute.jit
    def __call__(self, mX_trans, mDz_trans, mDw1_trans, mE_offset, mX_gather, tensormaps, mE_permute_order, stream):
        """Launch the dW1 weight-gradient kernel.

        Unused slots of the shared kernel signature are passed as None so the
        traced call site stays static.
        """
        return self.module(
            mX_trans,
            mDz_trans,
            None,
            None,
            mDw1_trans,
            None,
            None,
            None,
            mE_offset,
            mX_gather,
            None,
            None,
            None,
            tensormaps[0],
            None,
            None,
            None,
            None,
            mE_permute_order,
            const_expr(self.max_active_clusters),
            stream,
        )
build/torch-cuda/functional/reduction_over_k_gather.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ********************************************************************************
2
+ # Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao
3
+ # ********************************************************************************
4
+
5
+ from typing import Optional
6
+
7
+ import torch
8
+ import triton
9
+ import triton.language as tl
10
+
11
+ from ..utils import get_powers_of_2
12
+
13
+
14
+ ### This triton impl is equivalent as the cute-dsl impl shown above,
15
+ # and also achieves similar memory bandwidth on H100 for large K and H.
16
+ # However, for small K and H, this impl is better by autotuning so we use it as the default.
17
def _get_triton_autotune_configs() -> list[triton.Config]:
    """Enumerate candidate (BLOCK_H, BLOCK_K, num_warps) tilings for autotuning.

    Tiles larger than 32K elements are excluded to bound register/shared-memory
    pressure. Configurations are yielded with num_warps varying fastest.
    """
    return [
        triton.Config({"BLOCK_H": block_h, "BLOCK_K": block_k}, num_warps=warps, num_stages=4)
        for block_h in get_powers_of_2(256, 4096)
        for block_k in get_powers_of_2(1, 128)
        for warps in (4, 8)
        if block_k * block_h <= 32768
    ]
27
+
28
+
29
def _prune_triton_autotune_config(configs, nargs, **kw):
    """Drop autotune configs whose tile is useless for the actual problem size.

    A config survives when its BLOCK_H/BLOCK_K do not exceed the padded problem
    dims and the tile is not too small (at least min(H * MAX_K, 1024) elements).
    If pruning would remove everything, the full config list is returned so the
    autotuner always has candidates.
    """
    h_cap = triton.next_power_of_2(kw["H"])
    k_cap = triton.next_power_of_2(kw["MAX_K"])
    min_tile_elems = min(kw["H"] * kw["MAX_K"], 1024)

    survivors = [
        cfg
        for cfg in configs
        if cfg.kwargs["BLOCK_H"] <= h_cap
        and cfg.kwargs["BLOCK_K"] <= k_cap
        and min_tile_elems <= cfg.kwargs["BLOCK_H"] * cfg.kwargs["BLOCK_K"]
    ]
    return survivors if survivors else configs
47
+
48
+
49
# Autotuned over tile shapes; the key includes the structural flags so each
# (H, MAX_K, weighted-or-not, varlen-or-not) combination tunes independently.
@triton.autotune(
    configs=_get_triton_autotune_configs(),
    key=["H", "MAX_K", "w_is_None", "is_varlen_K"],
    prune_configs_by={"early_config_prune": _prune_triton_autotune_config},
)
@triton.jit
def token_gather_sum_kernel(
    x_ptr,  # (Mtotal, H)
    w_ptr,  # (Mtotal,)
    M_perm_ptr,  # (Mtotal,) int32
    M_offset_ptr,  # (T+1,) int32
    out_ptr,  # (T, H)
    T,
    H: tl.constexpr,
    MAX_K: tl.constexpr,  # max experts per token (exact K when not varlen)
    K: tl.constexpr,  # actual number of experts per token
    # strides
    stride_xM: tl.constexpr,
    stride_xH: tl.constexpr,
    stride_outT: tl.constexpr,
    stride_outH: tl.constexpr,
    # tile sizes
    BLOCK_H: tl.constexpr,
    BLOCK_K: tl.constexpr,
    w_is_None: tl.constexpr,  # True -> unweighted sum (skip w loads)
    is_varlen_K: tl.constexpr,  # True -> per-token K from M_offset
):
    # One program per output token: out[t, :] = sum over the token's routed
    # rows of x[M_perm[row], :] (optionally weighted by w[row]).
    pid_t = tl.program_id(axis=0)
    t_idx = pid_t.to(tl.uint32)

    # Segment [Ms, Me) of M_perm/w belonging to this token.
    if is_varlen_K:
        Ms = tl.load(M_offset_ptr + t_idx).to(tl.uint32)
        Me = tl.load(M_offset_ptr + t_idx + 1).to(tl.uint32)
        K_this_token = Me - Ms  # actual K for this token
    else:
        # Fixed K: segments are uniform, no offset table read needed.
        Ms = MAX_K * t_idx
        K_this_token: tl.constexpr = MAX_K

    # Outer loop over H tiles
    for h_tile in tl.static_range(triton.cdiv(H, BLOCK_H)):
        h_idx = (h_tile * BLOCK_H + tl.arange(0, BLOCK_H)).to(tl.uint32)  # [BLOCK_H]
        m_h = h_idx < H

        # Initialize accumulator for this H tile (fp32 regardless of x dtype)
        acc = tl.zeros([BLOCK_H], dtype=tl.float32)  # [BLOCK_H]

        # Inner loop over K tiles
        for k_tile in tl.range(tl.cdiv(K_this_token, BLOCK_K)):
            k_offset = k_tile * BLOCK_K

            k_idx = (k_offset + tl.arange(0, BLOCK_K)).to(tl.uint32)  # [BLOCK_K]

            # Mask for valid K indices
            m_k = k_idx < K_this_token  # [BLOCK_K]

            # Absolute positions into M_perm and w
            m_abs = Ms + k_idx  # [BLOCK_K]

            # Gather permuted indices
            perm_idx = tl.load(M_perm_ptr + m_abs, mask=m_k, other=0).to(tl.uint32)  # [BLOCK_K]

            # Load x values: [BLOCK_K, BLOCK_H]
            x_ptrs = x_ptr + perm_idx[:, None] * stride_xM + h_idx[None, :] * stride_xH
            x_mask = m_k[:, None] & m_h[None, :]
            x_vals = tl.load(x_ptrs, mask=x_mask, other=0.0).to(tl.float32)

            # Reduce along K dimension and add to accumulator
            if w_is_None:
                acc += tl.sum(x_vals, axis=0)  # [BLOCK_H]
            else:
                # other=0.0 makes masked lanes contribute nothing to the sum.
                w_vals = tl.load(w_ptr + m_abs, mask=m_k, other=0.0).to(tl.float32)  # [BLOCK_K]
                acc += tl.sum(x_vals * w_vals[:, None], axis=0)  # [BLOCK_H]

        # Store final result for this H tile (only once!)
        out_ptrs = out_ptr + t_idx * stride_outT + h_idx * stride_outH
        tl.store(out_ptrs, acc, mask=m_h)
126
+
127
+
128
def token_gather_and_sum_varlen_K_triton(
    x: torch.Tensor,  # (Mtotal, H)
    w: Optional[torch.Tensor],  # (Mtotal,)
    out: torch.Tensor,  # (T, H)
    M_perm: torch.Tensor,  # (Mtotal,) int32
    M_offset: torch.Tensor,  # (T+1,) int32, variable K per token
    T: int,
    MAX_K: int,  # maximum K across all tokens
    H: int,
    is_varlen_K: bool,
):
    """
    1D parallelization over T, with iterative accumulation over K tiles and H tiles.
    Supports variable K per token.

    out[i, :] = sum_{j=0..K[i]-1} x[M_perm[M_offset[i] + j], :] * w[M_offset[i] + j]

    where K[i] = M_offset[i+1] - M_offset[i] can vary per token.

    When w is None the sum is unweighted; when is_varlen_K is False the kernel
    assumes a fixed K == MAX_K per token and ignores M_offset. BLOCK_H/BLOCK_K
    are chosen by the autotuner on the kernel.
    """

    # 1D grid over T only
    token_gather_sum_kernel[(T,)](
        x,
        w,
        M_perm,
        M_offset,
        out,
        T=T,
        H=H,
        MAX_K=MAX_K,
        stride_xM=x.stride(0),
        stride_xH=x.stride(1),
        stride_outT=out.stride(0),
        stride_outH=out.stride(1),
        w_is_None=(w is None),
        is_varlen_K=is_varlen_K,
    )
build/torch-cuda/functional/tile_scheduler.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ********************************************************************************
2
+ # Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao
3
+ # ********************************************************************************
4
+
5
+ from __future__ import annotations
6
+
7
+ import cutlass
8
+ import cutlass.cute as cute
9
+ from cutlass import Boolean, Int32, const_expr
10
+ from ..quack.pipeline import PipelineStateWAdvance
11
+ from ..quack.tile_scheduler import TileScheduler, VarlenMTileScheduler
12
+
13
+
14
class SonicMoETileScheduler(TileScheduler):
    """TileScheduler specialization for SonicMoE grouped GEMMs.

    Differs from the quack base scheduler in how the initial linear work index
    is derived (linearized cluster id for non-persistent launches, blockIdx.z
    for persistent ones) and in supporting non-destructive work prefetch for
    persistent kernels.
    """

    @staticmethod
    @cute.jit
    def create(
        params: TileScheduler.Params,
        tile_count: cute.Tensor | None = None,
        scheduler_pipeline: cutlass.pipeline.PipelineAsync | None = None,
        is_scheduler_warp: bool | Boolean = False,
        *,
        loc=None,
        ip=None,
    ) -> SonicMoETileScheduler:
        """is_scheduler_warp should only be true for one warp in the whole cluster"""
        stages = 0
        if const_expr(not params.is_persistent):
            # Non-persistent launch: each cluster handles one work item,
            # identified by its linearized cluster id.
            cidx, cidy, _ = cute.arch.cluster_idx()
            cdimx, _, _ = cute.arch.cluster_dim()
            cluster_id = cidx + cidy * cdimx
            current_work_linear_idx = Int32(cluster_id)
        else:
            # Persistent launch: work items are striped across blockIdx.z.
            _, _, bidz = cute.arch.block_idx()
            current_work_linear_idx = Int32(bidz)
        if const_expr(params.tile_count_semaphore is not None):
            # Dynamic scheduling: tile_count buffer + pipeline must be supplied.
            assert tile_count is not None
            assert scheduler_pipeline is not None
            stages = const_expr(cute.size(tile_count))
        return SonicMoETileScheduler(
            current_work_linear_idx,
            Int32(0),  # num_tiles_executed
            tile_count,
            scheduler_pipeline,
            # Only the designated scheduler warp advances the pipeline state.
            PipelineStateWAdvance(stages, Int32(0), Int32(0), Int32(1 if is_scheduler_warp else 0)),
            params,
            loc=loc,
            ip=ip,
        )

    def prefetch_next_work(self, *, advance_count: int = 1, loc=None, ip=None):
        """Peek at the work item `advance_count` steps ahead without consuming it.

        The linear index is advanced temporarily, the corresponding tile
        coordinate is computed, then the index is restored.
        """
        old_current_work_linear_idx = self._current_work_linear_idx
        if const_expr(self.params.is_persistent):
            # Persistent kernels stride through work by the number of resident
            # clusters (gridDim.z).
            num_persistent_clusters = cute.arch.grid_dim()[2]
            self._current_work_linear_idx += advance_count * Int32(num_persistent_clusters)
        # NOTE(review): in the non-persistent case the index is not advanced,
        # so this returns the *current* work item — presumably prefetch is only
        # meaningful for persistent launches; confirm with callers.
        future_tile_coord_mnkl = self.get_current_work()
        self._current_work_linear_idx = old_current_work_linear_idx
        return future_tile_coord_mnkl
60
+
61
class SonicMoEVarlenMTileScheduler(VarlenMTileScheduler, SonicMoETileScheduler):
    """Variable-M tile scheduler for SonicMoE (per-expert row counts vary).

    Combines VarlenMTileScheduler's batch-aware work mapping with the SonicMoE
    pipeline-state handling; always derives the initial index from blockIdx.z
    (persistent-style launch).
    """

    @staticmethod
    @cute.jit
    def create(
        params: VarlenMTileScheduler.Params,
        tile_count: cute.Tensor | None = None,
        scheduler_pipeline: cutlass.pipeline.PipelineAsync | None = None,
        is_scheduler_warp: bool | Boolean = False,
        *,
        loc=None,
        ip=None,
    ) -> SonicMoEVarlenMTileScheduler:
        stages = 0
        _, _, bidz = cute.arch.block_idx()
        current_work_linear_idx = Int32(bidz)
        if const_expr(params.tile_count_semaphore is not None):
            # Dynamic scheduling: tile_count buffer + pipeline must be supplied.
            assert tile_count is not None
            assert scheduler_pipeline is not None
            stages = const_expr(cute.size(tile_count))
        return SonicMoEVarlenMTileScheduler(
            current_work_linear_idx,
            Int32(0),  # num_tiles_executed
            Int32(0),  # current_batch_idx
            Int32(0),  # num_work_idx_before_cur_batch
            tile_count,
            scheduler_pipeline,
            # Only the designated scheduler warp advances the pipeline state.
            PipelineStateWAdvance(stages, Int32(0), Int32(0), Int32(1 if is_scheduler_warp else 0)),
            params,
            loc=loc,
            ip=ip,
        )
build/torch-cuda/functional/topk_softmax.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ********************************************************************************
2
+ # Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao
3
+ # ********************************************************************************
4
+
5
+ # this impl is adapted from QuACK's topk https://github.com/Dao-AILab/quack/blob/main/quack/topk.py
6
+ import math
7
+ from typing import Type
8
+
9
+ import cuda.bindings.driver as cuda
10
+ import cutlass
11
+ import cutlass.cute as cute
12
+ from ..quack import utils
13
+ from cutlass import const_expr
14
+ from ..quack.sort.bitonic_sort import bitonic_topk
15
+ from triton import next_power_of_2
16
+
17
+ from ..utils import domain_offset_i64
18
+
19
+
20
class TopK_Softmax:
    """Fused row-wise top-k (+ optional softmax over the k values) CuTe-DSL kernel.

    For each row of an (M, N) input, selects the k largest values and their
    column indices via a bitonic top-k, optionally applies softmax over the
    selected values, and writes (values, indices) out. Indices are encoded
    into the low bits of the fp32 values so the sort carries them along;
    ties are broken toward the earlier column.
    """

    def __init__(
        self,
        input_dtype: Type[cutlass.Numeric],
        output_dtype: Type[cutlass.Numeric],
        N: int,
        k: int,
        require_softmax_fusion: bool = True,
    ):
        """Configure the kernel for a fixed row width N and selection size k.

        Args:
            input_dtype: element type of the input matrix.
            output_dtype: element type of the output values.
            N: row width (number of candidates per row).
            k: number of top elements to select per row.
            require_softmax_fusion: if True, softmax is applied to the k values.
        """
        self.input_dtype = input_dtype
        self.output_dtype = output_dtype
        self.N = N
        # 128-bit vectorized copies: elements per copy depends on dtype width.
        self.input_vecsize = 128 // input_dtype.width
        self.output_vecsize = 128 // output_dtype.width
        self.k = k
        self.next_power_of_2_N = next_power_of_2(N)
        self.next_power_of_2_K = next_power_of_2(k)
        assert k <= 128 and k <= N
        assert N <= 4096 and N % 8 == 0
        assert input_dtype.width <= output_dtype.width, "input bitwidth must <= output bitwidth"

        self.require_softmax_fusion = require_softmax_fusion

    def _calculate_threads_per_row(self):
        """Pick how many threads cooperate on one row."""
        # we want num_elems_per_thread >= self.k
        # and each thread can handle at most 64 elements
        N = self.next_power_of_2_N
        num_threads_per_row = max(min(N // self.k, 32, N // 64), 1)
        return num_threads_per_row

    def _get_tv_layout(self, vecsize):
        """Build the (tile, thread-value layout) pair for a given vector width."""
        N = self.next_power_of_2_N
        num_threads = 128 if N <= 16384 else 256
        threads_per_row = self._calculate_threads_per_row()
        cols_per_block = num_threads // threads_per_row
        num_blocks_N = cute.ceil_div(min(N, 16384) // vecsize, threads_per_row)
        tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)
        tv_layout = cute.make_layout(
            ((threads_per_row, cols_per_block), (vecsize, num_blocks_N)),
            stride=(
                (vecsize * cols_per_block, 1),
                (cols_per_block, cols_per_block * vecsize * threads_per_row),
            ),
        )
        return tiler_mn, tv_layout

    @cute.jit
    def __call__(
        self,
        mX: cute.Tensor,
        mValues: cute.Tensor,
        mIndices: cute.Tensor,
        stream: cuda.CUstream,
    ):
        """Launch the kernel: mValues/mIndices[i, :] = top-k of mX[i, :]."""
        assert mX.element_type == self.input_dtype
        assert mValues.element_type == self.output_dtype
        assert mIndices.element_type == cutlass.Int32
        input_tiler_mn, input_tv_layout = self._get_tv_layout(self.input_vecsize)
        output_tiler_mn, output_tv_layout = self._get_tv_layout(self.output_vecsize)

        num_threads = cute.size(input_tv_layout, mode=[0])
        self.kernel(mX, mValues, mIndices, input_tv_layout, input_tiler_mn, output_tv_layout, output_tiler_mn).launch(
            grid=[cute.ceil_div(mX.shape[0], input_tiler_mn[0]), 1, 1],
            block=[num_threads, 1, 1],
            stream=stream,
        )

    @cute.kernel
    def kernel(
        self,
        mX: cute.Tensor,
        mValues: cute.Tensor,
        mIndices: cute.Tensor,
        input_tv_layout: cute.Layout,
        input_tiler_mn: cute.Shape,
        output_tv_layout: cute.Layout,
        output_tiler_mn: cute.Shape,
    ):
        tidx, _, _ = cute.arch.thread_idx()
        bidx, _, _ = cute.arch.block_idx()

        shape = mX.shape
        idX = cute.make_identity_tensor(shape)
        # slice for CTAs
        # We use domain_offset_i64 to deal with tensors larger than 2^31 elements
        mX = domain_offset_i64((bidx * input_tiler_mn[0], 0), mX)
        gX = cute.local_tile(mX, input_tiler_mn, (0, 0))
        cX = cute.local_tile(idX, input_tiler_mn, (bidx, 0))

        # declare the atoms which will be used later for memory copy
        copy_atom_load_X = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gX.element_type, num_bits_per_copy=128)
        thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, input_tv_layout, input_tiler_mn).get_slice(tidx)
        tXgX = thr_copy_X.partition_S(gX)
        tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]

        # allocate fragments for gmem->rmem
        tXrX = cute.make_rmem_tensor_like(tXgX)

        is_even_N = const_expr(shape[1] == input_tiler_mn[1])
        # Predicate loads along N when the row is not tile-aligned or padded.
        tXpX = (
            utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
            if const_expr((not is_even_N) or (self.N != self.next_power_of_2_N))
            else None
        )
        if tXcX[0][0] < shape[0]:
            cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
        tXrX_f32 = cute.make_rmem_tensor(tXrX.shape, cutlass.Float32)
        tXrX_f32.store(tXrX.load().to(cutlass.Float32))

        # Encode the indices into the bottom bits of values.
        log_N = int(math.log2(self.next_power_of_2_N))
        idx_mask = const_expr((1 << log_N) - 1)
        input_vecsize = cutlass.const_expr(input_tv_layout.shape[1][0])
        tXrX_u32 = cute.recast_tensor(tXrX_f32, cutlass.Uint32)
        # Encode indices into the last log_N bits of tXrX_u32
        for i in cutlass.range(cute.size(tXrX_u32), unroll_full=True):
            # tXcX only keeps track of the indices for every @vecsize elements
            col_idx = cutlass.Uint32(tXcX[i // input_vecsize][1] + i % input_vecsize)
            # If positive, invert the bits of the index, so that if there's a tie,
            # indices coming from a earlier column will win.
            encoded_idx = ~col_idx if tXrX_f32[i] >= 0 else col_idx
            # Mask to keep only the last log_N bits of the encoded index
            encoded_idx = encoded_idx & idx_mask
            # Clear the last log_N bits and set them to our encoded index
            tXrX_u32[i] = (tXrX_u32[i] & ~idx_mask) | encoded_idx

        # Fill OOB values with -inf for top-k
        if const_expr((not is_even_N) or (self.N != self.next_power_of_2_N)):
            utils.fill_oob(tXrX_f32, tXpX, -tXrX_f32.element_type.inf)

        threads_per_row = input_tv_layout.shape[0][0]
        topk_vals = bitonic_topk(tXrX_f32, self.next_power_of_2_K, warp_width=threads_per_row)

        # Extract indices and clean values
        topk_vals_u32 = cute.recast_tensor(topk_vals, cutlass.Uint32)
        topk_indices = cute.make_rmem_tensor(self.k, cutlass.Int32)
        for i in cutlass.range_constexpr(self.k):
            # Extract the encoded index from the last log_N bits
            encoded_idx = topk_vals_u32[i] & idx_mask
            # Check if original value was positive by looking at the cleaned value
            topk_vals_u32[i] = topk_vals_u32[i] & ~idx_mask  # Clear last log_N bits
            # If positive, we need to invert the bits back to get original index
            col_idx = ~encoded_idx if topk_vals[i] >= 0 else encoded_idx
            topk_indices[i] = cutlass.Int32(col_idx & idx_mask)

        if const_expr(self.require_softmax_fusion):
            # Numerically stable softmax over the k selected values.
            topk_vals_max = -cutlass.Float32.inf
            for i in cutlass.range_constexpr(self.k):
                topk_vals_max = cute.arch.fmax(topk_vals[i], topk_vals_max)

            # FIX: the exp-sum accumulator must be Float32; it was previously
            # initialized as cutlass.Int32(0.0) even though it accumulates
            # cute.math.exp outputs in [0, 1].
            topk_exp_sum = cutlass.Float32(0.0)
            for i in cutlass.range_constexpr(self.k):
                topk_vals[i] = cute.math.exp(topk_vals[i] - topk_vals_max)
                topk_exp_sum = topk_exp_sum + topk_vals[i]

            for i in cutlass.range_constexpr(self.k):
                topk_vals[i] = topk_vals[i] / topk_exp_sum

        # Convert cleaned values to output type
        topk_vals_out = cute.make_rmem_tensor_like(topk_indices, mValues.element_type)
        for i in cutlass.range_constexpr(self.k):
            topk_vals_out[i] = topk_vals[i].to(mValues.element_type)

        row = tXcX[0][0]
        # Only the 1st thread in this row writes the top-k values and indices
        output_vecsize = cutlass.const_expr(output_tv_layout.shape[1][0])
        if row < shape[0] and tXcX[0][1] == 0:
            # Vectorized write
            elems_per_store = const_expr(math.gcd(output_vecsize, self.k))
            mValues_store = cute.tiled_divide(mValues[row, None], (elems_per_store,))
            mIndices_store = cute.tiled_divide(mIndices[row, None], (elems_per_store,))
            topk_vals_out_store = cute.tiled_divide(topk_vals_out, (elems_per_store,))
            topk_indices_store = cute.tiled_divide(topk_indices, (elems_per_store,))
            for i in cutlass.range_constexpr(cute.size(topk_vals_out_store.shape, [1])):
                cute.autovec_copy(topk_vals_out_store[None, i], mValues_store[None, i])
                cute.autovec_copy(topk_indices_store[None, i], mIndices_store[None, i])
build/torch-cuda/functional/triton_kernels/__init__.py ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ import triton
5
+ import triton.language as tl
6
+
7
+ from ..._ops_compat import add_op_namespace_prefix
8
+ from .bitmatrix import _bitmatrix_metadata_compute_stage1, _bitmatrix_metadata_compute_stage2, _keyed_add
9
+
10
+
11
@triton.jit
def _compute_col_partial_sum_kernel(
    topk_indices_ptr,  # [T, K] int: expert id chosen for each (token, k-slot)
    partial_sum_ptr,  # [E, n_tiles] int32 output: per-expert, per-tile counts
    T,  # number of tokens
    E: tl.constexpr,  # number of experts
    n_tiles,  # number of token tiles (= launch grid size)
    TOKENS_PER_TILE: tl.constexpr,
    K_POW2: tl.constexpr,  # next_power_of_2(K)
    K: tl.constexpr,  # actual number of experts per token
    E_POW2: tl.constexpr,  # next_power_of_2(E)
):
    # One CTA per tile. Tile `t` covers tokens [t * TOKENS_PER_TILE, (t+1) * TOKENS_PER_TILE).
    # Produces partial_sum[e, tile_id] = number of entries in this tile routed to expert e.
    # Layout: partial_sum is [E, n_tiles] (row-major), so partial_sum[e, t] = partial_sum_ptr + e * n_tiles + t.
    # Caller transposes to [n_tiles, E] before passing to stage1/stage2.
    tile_id = tl.program_id(0)

    # Zero this tile's column in partial_sum[*, tile_id].
    # Chunked by E_POW2 to keep vector width a power of 2.
    for e_start in tl.static_range(0, E, E_POW2):
        e_offs = e_start + tl.arange(0, E_POW2)
        tl.store(
            partial_sum_ptr + e_offs * n_tiles + tile_id,
            tl.zeros([E_POW2], tl.int32),
            mask=e_offs < E,
        )

    # Load expert ids for this tile: shape [TOKENS_PER_TILE, K_POW2].
    # Tokens beyond T and k-slots beyond K are masked out (other=-1).
    tok_offs = tile_id * TOKENS_PER_TILE + tl.arange(0, TOKENS_PER_TILE)
    k_offs = tl.arange(0, K_POW2)
    tok_mask = tok_offs < T

    load_mask = tok_mask[:, None] & (k_offs[None, :] < K)
    safe_k = tl.minimum(k_offs, K - 1)  # avoid OOB when k_offs >= K
    expert_ids = tl.load(
        topk_indices_ptr + tok_offs[:, None] * K + safe_k[None, :],
        mask=load_mask,
        other=-1,
    )

    # Flatten to [TOKENS_PER_TILE * K_POW2] and histogram into partial_sum.
    # safe_experts remaps masked (-1) entries to expert 0 (harmless: flat_mask=False).
    flat_experts = tl.reshape(expert_ids, [TOKENS_PER_TILE * K_POW2])
    flat_mask = tl.reshape(load_mask, [TOKENS_PER_TILE * K_POW2])
    safe_experts = tl.where(flat_mask, flat_experts, 0)

    # Atomic because lanes within this CTA may hit the same expert bin;
    # cross-CTA conflicts cannot occur since each CTA owns column tile_id.
    tl.atomic_add(
        partial_sum_ptr + safe_experts * n_tiles + tile_id,
        tl.full([TOKENS_PER_TILE * K_POW2], 1, dtype=tl.int32),
        mask=flat_mask,
    )
64
+
65
+
66
@torch.library.custom_op(
    add_op_namespace_prefix("triton_kernels__TC_topk_router_metadata"),
    mutates_args={
        "expert_frequency",
        "expert_frequency_offset",
        "x_gather_idx",
        "s_scatter_idx",
        "s_reverse_scatter_idx",
    },
)
def TC_topk_router_metadata_triton(
    topk_router_indices: torch.Tensor,
    E: int,
    expert_frequency: torch.Tensor,
    expert_frequency_offset: torch.Tensor,
    x_gather_idx: torch.Tensor,
    s_scatter_idx: torch.Tensor,
    s_reverse_scatter_idx: torch.Tensor,
) -> None:
    """Build MoE routing metadata from dense [T, K] top-k expert indices (in place).

    Fills:
      - expert_frequency[e]: number of (token, k-slot) entries routed to expert e
      - expert_frequency_offset: exclusive prefix sums of expert_frequency,
        with the total T*K written at index E
      - s_reverse_scatter_idx[entry]: position of each entry in expert-sorted order
      - s_scatter_idx[pos]: inverse permutation of s_reverse_scatter_idx
      - x_gather_idx[pos]: source token index for each expert-sorted position
    """
    T, K = topk_router_indices.size()
    TK = T * K
    device = topk_router_indices.device
    E_POW2 = triton.next_power_of_2(E)
    K_POW2 = triton.next_power_of_2(K)
    TOKENS_PER_BLOCK = 1024 // K_POW2
    n_tiles = triton.cdiv(T, TOKENS_PER_BLOCK)

    # -- Kernel 1: tiled histogram --------------------------------------------
    # col_partial_sum_trans[E, n_tiles]: raw per-expert-per-tile counts.
    # Stored transposed so each CTA writes to its own column (tile_id), avoiding
    # cross-CTA write conflicts. Transposed back to [n_tiles, E] for stage1/stage2.
    col_partial_sum_trans = torch.empty(E, n_tiles, dtype=torch.int32, device=device)
    _compute_col_partial_sum_kernel[(n_tiles,)](
        topk_router_indices,
        col_partial_sum_trans,
        T,
        E,
        n_tiles,
        TOKENS_PER_TILE=TOKENS_PER_BLOCK,
        K_POW2=K_POW2,
        K=K,
        E_POW2=E_POW2,
    )

    expert_frequency.copy_(col_partial_sum_trans.sum(dim=1, dtype=torch.int32))
    col_partial_sum = col_partial_sum_trans.T  # [n_tiles, E]

    # -- Kernel 2: stage1 -----------------------------------------------------
    # - For each expert e (pid < E): convert col_partial_sum[*, e] from raw
    #   counts to exclusive prefix sums over tiles in-place.
    # - pid == E: write the exclusive cumsum of expert_frequency into
    #   expert_frequency_offset[0:E]; pid == E + 1 writes the total TK at index E.
    _bitmatrix_metadata_compute_stage1[(E + 2,)](
        expert_frequency,
        expert_frequency_offset,
        E,
        col_partial_sum,
        n_tiles,
        TK,
        BLOCK_M=128,
        BLOCK_N=E_POW2,
    )

    # -- Kernel 3: stage2 -----------------------------------------------------
    # For each tile: sort entries by expert, compute output positions, scatter.
    _bitmatrix_metadata_compute_stage2[(n_tiles,)](
        s_scatter_idx,
        s_reverse_scatter_idx,
        x_gather_idx,
        topk_router_indices,
        T,
        col_partial_sum,
        n_tiles,
        expert_frequency_offset[:E],
        K_POW2=K_POW2,
        TOKENS_PER_BLOCK=TOKENS_PER_BLOCK,
        K=K,
    )
145
+
146
+
147
+ # ── general_routing_router_metadata_triton --- Kernel 1: tiled histogram over flat selected_E ────────────────────────────
148
@triton.jit
def _general_compute_col_partial_sum_kernel(
    selected_E_ptr,  # [TK] int: expert id per routed entry (flat)
    partial_sum_ptr,  # [E, n_tiles], column-major per tile
    TK,  # total number of routed entries
    E: tl.constexpr,
    n_tiles,
    BLOCK_SIZE: tl.constexpr,  # flat entries covered per tile
    E_POW2: tl.constexpr,  # next_power_of_2(E)
):
    # One CTA per tile of BLOCK_SIZE flat entries. Produces
    # partial_sum[e, tile_id] = number of entries in this tile routed to expert e.
    tile_id = tl.program_id(0)

    # Zero this tile's column in partial_sum[*, tile_id].
    for e_start in tl.static_range(0, E, E_POW2):
        e_offs = e_start + tl.arange(0, E_POW2)
        tl.store(
            partial_sum_ptr + e_offs * n_tiles + tile_id,
            tl.zeros([E_POW2], tl.int32),
            mask=e_offs < E,
        )

    # Load expert ids for this tile (flat indexing into selected_E).
    offs = tile_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < TK
    expert_ids = tl.load(selected_E_ptr + offs, mask=mask, other=-1)

    # Masked (OOB) lanes are remapped to expert 0; harmless since mask=False.
    safe_experts = tl.where(mask, expert_ids, 0)
    tl.atomic_add(
        partial_sum_ptr + safe_experts * n_tiles + tile_id,
        tl.full([BLOCK_SIZE], 1, dtype=tl.int32),
        mask=mask,
    )
180
+
181
+
182
+ # ── general_routing_router_metadata_triton --- Kernel 3: sort entries by expert within each tile, scatter ────────────────
183
@triton.jit
def _general_metadata_compute_stage2(
    s_scatter_idx_ptr,  # [TK] output: inverse permutation (output pos -> entry)
    s_reverse_scatter_idx_ptr,  # [TK] output: entry -> output pos
    x_gather_idx_ptr,  # [TK] output: output pos -> source token index
    selected_E_ptr,  # [TK] int: expert id per entry
    sorted_selected_T_ptr,  # [TK] int: token index per entry
    TK,
    partial_sum_ptr,  # [n_tiles, E] with strides (1, n_tiles)
    n_tiles,
    expert_offs_ptr,  # [E] global start offset of each expert's output segment
    BLOCK_SIZE: tl.constexpr,
):
    # One CTA per tile of BLOCK_SIZE flat entries; sorts this tile's entries by
    # expert and scatters them to their global expert-sorted positions.
    # The 16/16 bit packing below requires BLOCK_SIZE (and expert ids) < 2**16.
    tl.static_assert(BLOCK_SIZE <= 32768)

    pid_m = tl.program_id(0)
    offs_local = tl.arange(0, BLOCK_SIZE)
    offs_global = pid_m * BLOCK_SIZE + offs_local
    mask = offs_global < TK

    # Load expert id for each entry in this tile.
    # OOB lanes load other=-1, which wraps to 0xFFFF... as uint32.
    expert = tl.load(selected_E_ptr + offs_global, mask=mask, other=-1).to(tl.uint32)

    # Pack (expert, local_offset) into uint32 and sort by expert.
    # Upper 16 bits = expert id, lower 16 bits = pre-sort local offset.
    kv_pairs = tl.sort(((expert << 16) | offs_local).to(tl.uint32), 0)
    expert = kv_pairs >> 16
    mask = expert != 0xFFFF  # drop OOB lanes, which sorted to the end

    # Segmented scan for within-expert rank.
    scan_input = (kv_pairs & 0xFFFF0000) | 0x00000001
    inclusive_run_lengths = tl.associative_scan(scan_input, 0, _keyed_add)
    within_expert_rank = (inclusive_run_lengths - 1) & 0xFFFF

    # Output position = expert_offs[e] + partial_sum[tile, e] + within_expert_rank.
    s_reverse_scatter_val = tl.load(partial_sum_ptr + pid_m + expert * n_tiles, mask=mask)
    s_reverse_scatter_val += tl.load(expert_offs_ptr + expert, mask=mask)
    s_reverse_scatter_val += within_expert_rank

    # Recover pre-sort entry index and look up the token index.
    presort_offs = kv_pairs & 0xFFFF
    entry_idx = pid_m * BLOCK_SIZE + presort_offs
    token_idx = tl.load(sorted_selected_T_ptr + entry_idx, mask=mask)

    tl.store(s_reverse_scatter_idx_ptr + entry_idx, s_reverse_scatter_val, mask=mask)
    tl.store(s_scatter_idx_ptr + s_reverse_scatter_val, entry_idx, mask=mask)
    tl.store(x_gather_idx_ptr + s_reverse_scatter_val, token_idx, mask=mask)
230
+
231
+
232
+ # ── general_routing_router_metadata_triton --- Kernel 4: parallel binary search for token offset ─────────────────────────
233
+ # Since sorted_selected_T is sorted ascending, num_activated_expert_per_token_offset[t]
234
+ # is exactly searchsorted_left(sorted_selected_T, t): the index of the first entry
235
+ # with token index >= t. We compute this via parallel binary search over T+1 queries,
236
+ # replacing the PyTorch bincount + cumsum path.
237
@triton.jit
def _token_offset_searchsorted_kernel(
    sorted_T_ptr,  # [TK] int32, sorted ascending
    offset_ptr,  # [T+1] int32, output
    T,  # number of tokens
    TK,  # length of sorted_T
    BLOCK_SIZE: tl.constexpr,
    N_ITERS: tl.constexpr,  # ceil(log2(TK + 1)), controls binary search depth
):
    # offset[t] = searchsorted_left(sorted_T, t): index of the first entry with
    # token index >= t. One lane per query t, O(log TK) loads each.
    pid = tl.program_id(0)
    t_offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = t_offs <= T  # T+1 total values: offset[0], ..., offset[T]

    t_vals = t_offs.to(tl.int32)

    # Binary search: find smallest i such that sorted_T[i] >= t_vals
    lo = tl.zeros([BLOCK_SIZE], dtype=tl.int32)
    hi = tl.full([BLOCK_SIZE], TK, dtype=tl.int32)

    # Fixed iteration count keeps all lanes in lockstep; the invariant
    # lo <= answer <= hi shrinks by half each step.
    for _ in tl.static_range(0, N_ITERS):
        mid = (lo + hi) >> 1
        # When mid >= TK, treat the value as +inf (>= any t), so hi = mid.
        safe_mid = tl.where(mid < TK, mid, 0)
        val = tl.load(sorted_T_ptr + safe_mid, mask=mask & (TK > 0), other=T)
        go_right = (val < t_vals) & (mid < TK)
        lo = tl.where(go_right, mid + 1, lo)
        hi = tl.where(go_right, hi, mid)

    tl.store(offset_ptr + t_offs, lo, mask=mask)
266
+
267
+
268
@torch.library.custom_op(
    add_op_namespace_prefix("triton_kernels__general_routing_router_metadata"),
    mutates_args={
        "expert_frequency",
        "expert_frequency_offset",
        "x_gather_idx",
        "s_scatter_idx",
        "s_reverse_scatter_idx",
        "num_activated_expert_per_token_offset",
    },
)
def general_routing_router_metadata_triton(
    sorted_selected_T: torch.Tensor,
    selected_E: torch.Tensor,
    T: int,
    E: int,
    expert_frequency: torch.Tensor,
    expert_frequency_offset: torch.Tensor,
    x_gather_idx: torch.Tensor,
    s_scatter_idx: torch.Tensor,
    s_reverse_scatter_idx: torch.Tensor,
    num_activated_expert_per_token_offset: torch.Tensor,
) -> None:
    """Build MoE routing metadata from flat (token, expert) routing pairs (in place).

    sorted_selected_T / selected_E are parallel [TK] arrays: entry i routes token
    sorted_selected_T[i] to expert selected_E[i]. In addition to the outputs of
    TC_topk_router_metadata_triton, this fills
    num_activated_expert_per_token_offset[0..T] with per-token entry offsets.
    """
    TK = selected_E.size(0)
    device = selected_E.device
    E_POW2 = triton.next_power_of_2(E)
    BLOCK_SIZE = 1024
    n_tiles = triton.cdiv(TK, BLOCK_SIZE)

    # -- Kernel 1: tiled histogram ----------------------------------------
    # Transposed [E, n_tiles] layout: each CTA owns one column (tile_id).
    col_partial_sum_trans = torch.empty(E, n_tiles, dtype=torch.int32, device=device)
    _general_compute_col_partial_sum_kernel[(n_tiles,)](
        selected_E,
        col_partial_sum_trans,
        TK,
        E,
        n_tiles,
        BLOCK_SIZE=BLOCK_SIZE,
        E_POW2=E_POW2,
    )

    expert_frequency.copy_(col_partial_sum_trans.sum(dim=1, dtype=torch.int32))
    col_partial_sum = col_partial_sum_trans.T  # [n_tiles, E], strides (1, n_tiles)

    # -- Kernel 2: stage1 -------------------------------------------------
    # Per-expert exclusive tile prefix sums + expert offsets (see bitmatrix.py).
    _bitmatrix_metadata_compute_stage1[(E + 2,)](
        expert_frequency,
        expert_frequency_offset,
        E,
        col_partial_sum,
        n_tiles,
        TK,
        BLOCK_M=128,
        BLOCK_N=E_POW2,
    )

    # -- Kernel 3: stage2 -------------------------------------------------
    # Sort entries by expert within each tile and scatter to global positions.
    _general_metadata_compute_stage2[(n_tiles,)](
        s_scatter_idx,
        s_reverse_scatter_idx,
        x_gather_idx,
        selected_E,
        sorted_selected_T,
        TK,
        col_partial_sum,
        n_tiles,
        expert_frequency_offset[:E],
        BLOCK_SIZE=BLOCK_SIZE,
    )

    # -- Kernel 4: num_activated_expert_per_token_offset via searchsorted --
    # sorted_selected_T is sorted ascending, so offset[t] = searchsorted_left(sorted_T, t).
    # Parallel binary search: each lane handles one token index, O(log TK) work.
    N_ITERS = max(1, math.ceil(math.log2(TK + 1)))
    TOKEN_BLOCK = 1024
    n_token_blocks = triton.cdiv(T + 1, TOKEN_BLOCK)
    _token_offset_searchsorted_kernel[(n_token_blocks,)](
        sorted_selected_T,
        num_activated_expert_per_token_offset,
        T,
        TK,
        BLOCK_SIZE=TOKEN_BLOCK,
        N_ITERS=N_ITERS,
    )
build/torch-cuda/functional/triton_kernels/bitmatrix.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import triton
2
+ import triton.language as tl
3
+
4
+
5
+ # https://github.com/triton-lang/triton/blob/434aecbe933af6a8d49595d4197bfc3df7618748/python/triton_kernels/triton_kernels/tensor_details/bitmatrix.py#L33
6
@triton.jit
def _keyed_add(x, y):
    # Segmented-add combiner for tl.associative_scan. Each uint32 operand packs
    # a 16-bit segment key (upper bits) and a 16-bit running count (lower bits).
    # we keep the key in the upper 16 bits of a uint32:
    key_mask: tl.constexpr = 0xFFFF0000

    kx = x & key_mask
    ky = y & key_mask
    # Same key: add counts (x + y - kx keeps a single copy of the shared key).
    # Key boundary: restart the running count from y.
    z = tl.where(kx == ky, x + y - kx, y)
    return z
15
+
16
+
17
+ # Adapted from https://github.com/triton-lang/triton/blob/434aecbe933af6a8d49595d4197bfc3df7618748/python/triton_kernels/triton_kernels/tensor_details/bitmatrix.py#L44
18
@triton.jit
def _bitmatrix_metadata_compute_stage1(
    expert_freq_ptr,  # [E] int32: total entries routed to each expert
    expert_freq_offs_ptr,  # [E + 1] int32 output: exclusive offsets + total
    E: tl.constexpr,
    partial_sum_ptr,  # [n_tiles, E] view with strides (1, n_tiles)
    n_tiles,
    TK,  # total number of routed entries
    BLOCK_M: tl.constexpr,  # chunk size for iterating over tiles per expert
    BLOCK_N: tl.constexpr,  # chunk size for iterating over experts in cumsum
):
    # Assumes grid size == E + 2: pids 0..E-1 handle one expert each,
    # pid == E builds the expert offsets, pid == E + 1 writes the total.

    pid = tl.program_id(0)
    if pid < E:
        # convert partial_sum[e, *] from raw counts to exclusive prefix
        # sums over tiles. After this kernel, partial_sum[e, t] =
        # number of entries for expert e in tiles 0..t-1.

        # This is read by stage2 to locate each entry's position within expert e's contiguous output segment.
        expert_partial_sum_ptr = partial_sum_ptr + pid * n_tiles
        curr_sum = 0
        for start in range(0, n_tiles, BLOCK_M):
            offs = start + tl.arange(0, BLOCK_M)
            tile_counts = tl.load(expert_partial_sum_ptr + offs, mask=offs < n_tiles, other=0)
            excl_cumsum = tl.cumsum(tile_counts, 0) - tile_counts + curr_sum
            curr_sum += tl.sum(tile_counts, 0)
            tl.store(expert_partial_sum_ptr + offs, excl_cumsum, mask=offs < n_tiles)
    elif pid == E:
        # Exclusive prefix sum of per-expert total counts -> expert_freq_offs[e].
        # expert_freq[e] = total entries routed to expert e.
        # expert_freq_offs[e] = sum of expert_freq[0..e-1] = global start of expert e.
        curr_sum = 0
        for start in tl.static_range(0, E, BLOCK_N):
            offs = start + tl.arange(0, BLOCK_N)
            expert_freq = tl.load(expert_freq_ptr + offs, mask=offs < E, other=0)
            excl_cumsum = tl.cumsum(expert_freq, 0) - expert_freq + curr_sum
            curr_sum += tl.sum(expert_freq, 0)
            tl.store(expert_freq_offs_ptr + offs, excl_cumsum, mask=offs < E)
    elif pid == E + 1:
        # expert_freq_offs[E] = TK (total number of entries)
        tl.store(expert_freq_offs_ptr + E, TK)
60
+
61
+
62
+ # Adapted from https://github.com/triton-lang/triton/blob/434aecbe933af6a8d49595d4197bfc3df7618748/python/triton_kernels/triton_kernels/tensor_details/bitmatrix.py#L44
63
@triton.jit
def _bitmatrix_metadata_compute_stage2(
    s_scatter_idx_ptr,  # [T*K] output: inverse permutation (output pos -> entry)
    s_reverse_scatter_idx_ptr,  # [T*K] output: entry -> output pos
    x_gather_idx_ptr,  # [T*K] output: output pos -> source token index
    topk_indices_ptr,  # [T, K] int: expert id per (token, k-slot)
    T,
    partial_sum_ptr,  # [n_tiles, E] exclusive per-tile prefix sums (after stage1)
    n_tiles,
    expert_offs_ptr,  # [E] global start offset of each expert's output segment
    K_POW2: tl.constexpr,  # padded K, == BLOCK_SIZE / TOKENS_PER_BLOCK
    K: tl.constexpr,  # actual experts per token
    TOKENS_PER_BLOCK: tl.constexpr,  # tokens per tile
):
    # One CTA per tile, same tiling as _compute_col_partial_sum_kernel.
    # For each entry (token t, k-slot k) in this tile:
    #   s_reverse_scatter_idx[entry_idx] = output position in expert-sorted order
    #   s_scatter_idx[output_pos] = entry_idx (inverse permutation)
    #   x_gather_idx[output_pos] = token index (= entry_idx // K)
    #
    # Output position = expert_offs[e] (global start of expert e)
    #   + partial_sum[tile, e] (entries for e in earlier tiles, after stage1)
    #   + within_expert_rank (position within this tile's group for e)
    BLOCK_SIZE: tl.constexpr = TOKENS_PER_BLOCK * K_POW2
    IS_POW2_K: tl.constexpr = K == K_POW2  # fast path: no padding waste
    # 16/16 bit packing below requires local offsets and expert ids < 2**16.
    tl.static_assert(BLOCK_SIZE <= 32768)

    pid_m = tl.program_id(0)
    offs_local = tl.arange(0, BLOCK_SIZE)  # position within this tile's flat [TOKENS_PER_BLOCK*K_POW2] space
    offs_global = pid_m * BLOCK_SIZE + offs_local
    mask = offs_global < T * K_POW2

    # Load expert id for each slot. IS_POW2_K fast path reads topk_indices as a
    # flat 1D array (no padding gaps). Non-pow2 path reads 2D with k_slot masking.
    if IS_POW2_K:
        expert = tl.load(topk_indices_ptr + offs_global, mask=mask, other=-1).to(tl.uint32)
    else:
        token_i_local = offs_local // K_POW2
        k_slot = offs_local % K_POW2
        token_i_global = pid_m * TOKENS_PER_BLOCK + token_i_local
        load_mask = mask & (k_slot < K)
        safe_k = tl.minimum(k_slot, K - 1)
        expert = tl.load(
            topk_indices_ptr + token_i_global * K + safe_k,
            mask=load_mask,
            other=-1,
        ).to(tl.uint32)

    # Pack (expert, presort_offs) into a uint32 kv pair and sort by expert.
    # Upper 16 bits = expert id (sort key), lower 16 bits = pre-sort local offset.
    # Invalid slots have expert=0xffff (from other=-1 cast to uint32 >> 16).
    kv_pairs = tl.sort(((expert << 16) | offs_local).to(tl.uint32), 0)
    expert = kv_pairs >> 16
    mask = expert != 0xFFFF  # exclude padding/OOB slots

    # Segmented scan to compute within-expert rank (0-based exclusive count).
    # scan_input packs expert id in upper 16 bits and count=1 in lower 16 bits.
    # _keyed_add resets the count at each expert boundary.
    scan_input = (kv_pairs & 0xFFFF0000) | 0x00000001
    inclusive_run_lengths = tl.associative_scan(scan_input, 0, _keyed_add)
    within_expert_rank = (inclusive_run_lengths - 1) & 0xFFFF  # exclusive = inclusive - 1

    # Output position for this entry in the expert-sorted output array.
    # partial_sum layout after stage1: [n_tiles, E], stride (1, n_tiles).
    # So partial_sum[pid_m, expert] = partial_sum_ptr + pid_m*1 + expert*n_tiles.
    s_reverse_scatter_idx = tl.load(partial_sum_ptr + pid_m + expert * n_tiles, mask=mask)
    s_reverse_scatter_idx += tl.load(expert_offs_ptr + expert, mask=mask)
    s_reverse_scatter_idx += within_expert_rank

    if IS_POW2_K:
        # presort_offs == offs_local before sort; entry_idx is the flat index into
        # topk_router_indices.view(-1), i.e. token * K + k_slot.
        presort_offs = kv_pairs & 0xFFFF
        entry_idx = pid_m * BLOCK_SIZE + presort_offs
        tl.store(s_reverse_scatter_idx_ptr + entry_idx, s_reverse_scatter_idx, mask=mask)
        tl.store(s_scatter_idx_ptr + s_reverse_scatter_idx, entry_idx, mask=mask)
        tl.store(x_gather_idx_ptr + s_reverse_scatter_idx, entry_idx // K_POW2, mask=mask)
    else:
        # presort_offs is in K_POW2-padded space; convert to unpadded entry_idx.
        presort_offs = kv_pairs & 0xFFFF
        token_i_global_s = pid_m * TOKENS_PER_BLOCK + presort_offs // K_POW2
        entry_idx = token_i_global_s * K + presort_offs % K_POW2
        tl.store(s_reverse_scatter_idx_ptr + entry_idx, s_reverse_scatter_idx, mask=mask)
        tl.store(s_scatter_idx_ptr + s_reverse_scatter_idx, entry_idx, mask=mask)
        tl.store(x_gather_idx_ptr + s_reverse_scatter_idx, token_i_global_s, mask=mask)
build/torch-cuda/functional/utils.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ********************************************************************************
2
+ # Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao
3
+ # ********************************************************************************
4
+
5
+ import os
6
+ from contextlib import contextmanager
7
+
8
+
9
+ _IS_USING_QUACK_GEMM = os.getenv("USE_QUACK_GEMM", "0") == "1"
10
+
11
+
12
@contextmanager
def enable_quack_gemm(enable: bool = True):
    """Temporarily toggle the QuACK GEMM backend flag within a ``with`` block.

    Args:
        enable: value the flag takes inside the block. Defaults to True.

    The previous value is restored on exit — including when the body raises,
    which the previous implementation did not guarantee (a raising body would
    leak the override into subsequent calls).
    """
    global _IS_USING_QUACK_GEMM

    previous_value = _IS_USING_QUACK_GEMM
    _IS_USING_QUACK_GEMM = enable

    try:
        yield
    finally:
        # try/finally so an exception inside the block cannot leave the
        # global flag stuck at the overridden value.
        _IS_USING_QUACK_GEMM = previous_value
22
+
23
+
24
def is_using_quack_gemm() -> bool:
    """Return True when the QuACK GEMM backend is currently enabled.

    Reflects the module-level flag, seeded from the USE_QUACK_GEMM env var and
    temporarily overridable via enable_quack_gemm().
    """
    return _IS_USING_QUACK_GEMM
build/torch-cuda/jit.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ********************************************************************************
2
+ # Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao
3
+ # ********************************************************************************
4
+
5
+ import inspect
6
+ import os
7
+ from shutil import rmtree
8
+ from typing import Callable
9
+ from uuid import uuid4
10
+
11
+ import torch
12
+ from torch.utils.cpp_extension import load as load_cpp_extension
13
+
14
+
15
+ _CPP_MODULE_PREFIX = "sonicmoe"
16
+ _GLOBAL_RANK = int(os.getenv("RANK", 0))
17
+ _WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1))
18
+
19
+ _ALL_COMPILED_MODULES = {}
20
+
21
+
22
@torch.compiler.disable
def _get_cpp_function(function_name: str, module_name: str, source_files: list[str], build_directory: str) -> Callable:
    """Compile (or fetch from the in-process cache) a C++/CUDA extension and
    return the named function from it.

    With torch.distributed initialized, rank 0 compiles first and the other
    ranks attach to the same build directory after a barrier, so they reuse
    rank 0's artifacts instead of racing on the compiler. Without distributed
    init, each process builds independently; when WORLD_SIZE > 1 the build goes
    into a uuid-suffixed directory that is removed after loading, to avoid
    on-disk collisions between processes.
    """
    module_name = f"{_CPP_MODULE_PREFIX}_{module_name}"

    extra_cflags = ["-O3", "-Wall", "-shared", "-fPIC", "-fdiagnostics-color"]
    extra_cuda_cflags = ["-O3", "-lineinfo"]
    extra_include_paths = [
        os.path.dirname(__file__),  # sonicmoe/include
        os.path.dirname(os.path.dirname(__file__)) + "/cutlass/include",  # cutlass
        os.path.dirname(os.path.dirname(__file__)) + "/cutlass/tools/util/include",  # cutlass
    ]

    # Cache hit: the extension was already loaded in this process.
    module = _ALL_COMPILED_MODULES.get(module_name, None)

    if module is None:
        if torch.distributed.is_initialized():
            os.makedirs(build_directory, exist_ok=True)

            # Rank 0 compiles; everyone else waits at the barrier below.
            if _GLOBAL_RANK == 0:
                module = load_cpp_extension(
                    module_name,
                    sources=source_files,
                    with_cuda=True,
                    extra_cflags=extra_cflags,
                    extra_cuda_cflags=extra_cuda_cflags,
                    extra_include_paths=extra_include_paths,
                    build_directory=build_directory,
                    verbose=True,
                )

            torch.distributed.barrier()

            # Non-zero ranks load the already-built artifacts (verbose=False).
            if _GLOBAL_RANK != 0:
                module = load_cpp_extension(
                    module_name,
                    sources=source_files,
                    with_cuda=True,
                    extra_cflags=extra_cflags,
                    extra_cuda_cflags=extra_cuda_cflags,
                    extra_include_paths=extra_include_paths,
                    build_directory=build_directory,
                    verbose=False,
                )
        else:
            # No distributed init but multiple processes (e.g. launcher spawned
            # workers before init): isolate each build in a unique directory.
            if _WORLD_SIZE > 1:
                build_directory = os.path.join(build_directory, str(uuid4()))

            os.makedirs(build_directory, exist_ok=True)

            module = load_cpp_extension(
                module_name,
                sources=source_files,
                with_cuda=True,
                extra_cflags=extra_cflags,
                extra_cuda_cflags=extra_cuda_cflags,
                extra_include_paths=extra_include_paths,
                build_directory=build_directory,
                verbose=True,
            )

            # The uuid-suffixed artifacts are no longer needed once loaded.
            if _WORLD_SIZE > 1:
                rmtree(build_directory, ignore_errors=True)

        _ALL_COMPILED_MODULES[module_name] = module

    return getattr(module, function_name)
88
+
89
+
90
+ def cpp_jit(
91
+ function_name: str | None = None,
92
+ extra_source_files: list[str] = [],
93
+ build_directory: str | None = None,
94
+ depth: int = 0,
95
+ ) -> Callable:
96
+ """wrapper to compile C++/CUDA source code at runtime.
97
+
98
+ Args:
99
+ function_name (str | None, optional): name of the function to expose from the C++ file, the python function
100
+ name should match the funcion name in the C++ file if this is not specified. Defaults to None.
101
+ extra_source_files (list[str], optional): any extra files to use for compilation, by default it scans the
102
+ directory of the python stub file. Defaults to [].
103
+ build_directory (str | None, optional): directory in which to place the build artifacts. Defaults to None.
104
+ depth (int, optional): number of times dirname is called to get the build path. Defaults to 2.
105
+
106
+ Returns:
107
+ Callable: returns the wrapped function that can be used to call the C++ functions from python
108
+ """
109
+ cpp_function = None
110
+ args_spec = None
111
+
112
+ source_files = []
113
+ source_files.extend(extra_source_files)
114
+
115
+ calling_filename = inspect.stack()[1].filename
116
+ calling_directory = os.path.dirname(calling_filename)
117
+
118
+ for dirname, _, filenames in os.walk(calling_directory):
119
+ filenames = [os.path.join(dirname, f) for f in filenames]
120
+ filenames = filter(lambda f: os.path.splitext(f)[1] in [".cu", ".cpp"], filenames)
121
+ source_files.extend(filenames)
122
+
123
+ if build_directory is None:
124
+ module_name = calling_directory
125
+ for _ in range(depth):
126
+ module_name = os.path.dirname(module_name)
127
+ module_name = os.path.basename(module_name)
128
+
129
+ build_directory = os.path.join(os.path.dirname(os.path.dirname(__file__)), "build", module_name)
130
+
131
+ def _run(*args, **kwargs):
132
+ nonlocal cpp_function
133
+
134
+ if cpp_function is None:
135
+ cpp_function = _get_cpp_function(
136
+ function_name=_run.__name__,
137
+ module_name=module_name,
138
+ source_files=source_files,
139
+ build_directory=build_directory,
140
+ )
141
+
142
+ full_args = []
143
+ full_args.extend(args)
144
+ for variable_name in args_spec.args[len(args) :]:
145
+ full_args.append(kwargs[variable_name])
146
+
147
+ return cpp_function(*full_args)
148
+
149
+ def _wrapper(function: Callable) -> Callable:
150
+ nonlocal args_spec
151
+ args_spec = inspect.getfullargspec(function)
152
+
153
+ _run.__doc__ = function.__doc__
154
+ _run.__name__ = function.__name__ if function_name is None else function_name
155
+ _run.__signature__ = inspect.signature(function)
156
+
157
+ return _run
158
+
159
+ return _wrapper
build/torch-cuda/metadata.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "version": 1,
3
+ "license": "Apache-2.0",
4
+ "python-depends": [
5
+ "nvidia-cutlass-dsl"
6
+ ],
7
+ "backend": {
8
+ "type": "cuda"
9
+ }
10
+ }
build/torch-cuda/moe.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ********************************************************************************
2
+ # Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao
3
+ # ********************************************************************************
4
+
5
+ from typing import Callable
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ from .enums import ActivationType, KernelBackendMoE, is_glu
12
+ from .functional import moe_TC_softmax_topk_layer
13
+
14
+
15
+ try:
16
+ from xma.modules.moe import scattered_experts
17
+
18
+ _IS_XMA_AVAILABLE = True
19
+ except ImportError:
20
+ _IS_XMA_AVAILABLE = False
21
+
22
+
23
+ def _swiglu(x: torch.Tensor) -> torch.Tensor:
24
+ u = x[..., 1::2]
25
+ g = x[..., ::2]
26
+ return u * F.silu(g)
27
+
28
+
29
+ def _geglu(x: torch.Tensor) -> torch.Tensor:
30
+ u = x[..., 1::2]
31
+ g = x[..., ::2]
32
+ return (F.gelu(g.to(dtype=torch.float32)) * u).to(dtype=g.dtype)
33
+
34
+
35
+ def _gelu(x: torch.Tensor) -> torch.Tensor:
36
+ return F.gelu(x.to(dtype=torch.float32)).to(dtype=x.dtype)
37
+
38
+
39
+ def _reglu(x: torch.Tensor) -> torch.Tensor:
40
+ u = x[..., 1::2]
41
+ g = x[..., ::2]
42
+ return (F.relu(g) * u).to(dtype=g.dtype)
43
+
44
+
45
+ def _relu(x: torch.Tensor) -> torch.Tensor:
46
+ return F.relu(x)
47
+
48
+
49
+ def _relu_sq(x: torch.Tensor) -> torch.Tensor:
50
+ return F.relu(x) ** 2
51
+
52
+
53
+ def _silu(x: torch.Tensor) -> torch.Tensor:
54
+ return F.silu(x)
55
+
56
+
57
class Experts(nn.Module):
    """A bank of per-expert linear layers for MoE.

    weight[e] is an (out_features, in_features) matrix applied to the tokens
    routed to expert e; bias[e] (optional) is its (out_features,) bias.
    """

    def __init__(
        self, num_experts: int, in_features: int, out_features: int, add_bias: bool = True, std: float | None = None
    ) -> None:
        super().__init__()

        # One (out_features, in_features) weight matrix per expert.
        self.weight = nn.Parameter(torch.empty(num_experts, out_features, in_features))

        self.bias = None
        if add_bias:
            self.bias = nn.Parameter(torch.empty(num_experts, out_features))

        # Std-dev used by reset_parameters for normal weight init.
        # NOTE(review): the default None is passed straight to nn.init.normal_,
        # which requires a float — confirm callers always supply std.
        self.std = std

        self.num_experts = num_experts
        self.in_features = in_features
        self.out_features = out_features

        self.reset_parameters()

    def up_projection_scattermoe_forward(
        self,
        input: torch.Tensor,
        num_experts_per_token: int | None = None,
        sorted_expert_idxs: torch.Tensor | None = None,
        sorted_scattered_idxs: torch.Tensor | None = None,
        expert_offsets: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Up-projection via xma's scattered_experts kernel.

        Takes scattered (token-order) input and emits expert-grouped output
        (grouped_in=False, grouped_out=True). Bias is unsupported; requires the
        accelerated-model-architectures (xma) package.
        """
        assert self.bias is None

        if not _IS_XMA_AVAILABLE:
            raise ImportError(
                "install accelerated-model-architectures from https://github.com/open-lm-engine/accelerated-model-architectures"
            )

        # scattered_experts expects (E, in, out); weight is stored (E, out, in).
        input = scattered_experts(
            inputs=input,
            expert_weights=self.weight.permute(0, 2, 1),
            k=num_experts_per_token,
            sorted_expert_idxs=sorted_expert_idxs,
            sorted_scattered_idxs=sorted_scattered_idxs,
            expert_offsets=expert_offsets,
            gates=None,
            grouped_in=False,
            grouped_out=True,
        )

        return input

    def down_projection_scattermoe_forward(
        self,
        input: torch.Tensor,
        num_experts_per_token: int | None = None,
        sorted_expert_idxs: torch.Tensor | None = None,
        sorted_scattered_idxs: torch.Tensor | None = None,
        expert_offsets: torch.Tensor | None = None,
        gates: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Down-projection via xma's scattered_experts kernel.

        Takes expert-grouped input and emits scattered output, optionally
        applying per-entry router gates (grouped_in=True, grouped_out=False).
        Bias is unsupported; requires the xma package.
        """
        assert self.bias is None

        if not _IS_XMA_AVAILABLE:
            raise ImportError(
                "install accelerated-model-architectures from https://github.com/open-lm-engine/accelerated-model-architectures"
            )

        input = scattered_experts(
            inputs=input,
            expert_weights=self.weight.permute(0, 2, 1),
            k=num_experts_per_token,
            sorted_expert_idxs=sorted_expert_idxs,
            sorted_scattered_idxs=sorted_scattered_idxs,
            expert_offsets=expert_offsets,
            gates=gates,
            grouped_in=True,
            grouped_out=False,
        )

        return input

    def torch_forward(
        self, input: torch.Tensor, expert_frequency: torch.Tensor | None, return_list: bool = False
    ) -> list[torch.Tensor] | torch.Tensor:
        """Reference (pure PyTorch) grouped forward.

        A tensor `input` is split along dim 0 into per-expert chunks of size
        expert_frequency[e]; a list input is assumed to be pre-split (and
        expert_frequency must then be None). Returns a list of per-expert
        outputs, or their dim-0 concatenation when return_list is False.
        """
        if isinstance(input, torch.Tensor):
            input = input.split(expert_frequency.tolist(), dim=0)
        else:
            assert expert_frequency is None

        input = [
            F.linear(input[i], self.weight[i], None if self.bias is None else self.bias[i])
            for i in range(self.num_experts)
        ]

        if not return_list:
            input = torch.cat(input, dim=0)

        return input

    def extra_repr(self):
        """Shown in repr(module) alongside the class name."""
        return "num_experts={}, in_features={}, out_features={}".format(
            self.num_experts, self.in_features, self.out_features
        )

    @torch.no_grad()
    def reset_parameters(self) -> None:
        """Initialize weights ~ N(0, std) and zero the biases (if present)."""
        nn.init.normal_(self.weight, mean=0, std=self.std)
        if hasattr(self, "bias") and self.bias is not None:
            self.bias.zero_()
164
+
165
+
166
class MoE(nn.Module):
    """Mixture-of-Experts MLP with a linear router and top-k expert dispatch.

    Supports three compute backends: a fused `sonicmoe` kernel, the
    `scattermoe` kernels, and a pure-PyTorch reference path.
    """

    def __init__(
        self,
        num_experts: int,
        num_experts_per_tok: int,
        hidden_size: int,
        intermediate_size: int,
        activation_function: ActivationType,
        add_bias: bool,
        std: float,
    ) -> None:
        super().__init__()

        self.num_experts = num_experts
        self.top_k = num_experts_per_tok

        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size

        # Router produces one logit per expert; no bias.
        self.router = nn.Linear(in_features=self.hidden_size, out_features=num_experts, bias=False)

        self.activation_function = activation_function

        # Up projection doubles the width for GLU-style activations (gate + up halves).
        self.c_fc = Experts(
            num_experts=num_experts,
            in_features=self.hidden_size,
            out_features=2 * self.intermediate_size if is_glu(activation_function) else self.intermediate_size,
            add_bias=add_bias,
            std=std,
        )

        self.c_proj = Experts(
            num_experts=num_experts,
            in_features=self.intermediate_size,
            out_features=self.hidden_size,
            add_bias=add_bias,
            std=std,
        )

        # NOTE(review): the CUDA stream handle is captured once at construction
        # time and later handed to the fused kernel — confirm the module is
        # always run on the same stream it was built on.
        self.stream_id = torch.cuda.current_stream().cuda_stream

    def forward(
        self,
        hidden_states: torch.Tensor,
        kernel_backend_moe: KernelBackendMoE = KernelBackendMoE.sonicmoe,
        is_inference_mode: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Run routing + experts; returns (output, aux_loss).

        `aux_loss` is None in inference mode. Output has the same shape as the
        input `hidden_states`.
        """
        original_shape = hidden_states.shape

        # hidden_states -> (batch_size, query_length, hidden_size)
        hidden_states = hidden_states.view(-1, self.hidden_size)

        # presumably 32768 is the fused kernel's expert-count limit — TODO confirm
        if kernel_backend_moe == KernelBackendMoE.sonicmoe and self.num_experts <= 32768:
            # Fused path: routing, both projections, activation and gather in one call.
            hidden_states, router_logits, expert_frequency = moe_TC_softmax_topk_layer(
                hidden_states,
                self.router.weight,
                self.c_fc.weight.permute(1, 2, 0),
                self.c_fc.bias,
                self.c_proj.weight.permute(1, 2, 0),
                self.c_proj.bias,
                self.top_k,
                self.stream_id,
                self.activation_function,
                is_inference_mode or not self.training,
            )
        else:
            # hidden_states -> (total_q, hidden_size)
            router_logits, router_weights, selected_experts = self._compute_routing_weights(hidden_states)

            # router_logits -> (total_q, num_experts)
            # router_weights -> (total_q, top_k)
            # selected_experts -> (total_q, top_k)

            hidden_states, expert_frequency = self._compute_experts(
                hidden_states,
                router_weights,
                selected_experts,
                kernel_backend_moe=kernel_backend_moe,
            )

        hidden_states = hidden_states.view(original_shape)

        # hidden_states -> (batch_size, query_length, hidden_size)

        if is_inference_mode:
            aux_loss = None
        else:
            aux_loss = self._compute_switch_loss(
                logits=router_logits,
                probs=F.softmax(router_logits, dim=-1, dtype=torch.float32),
                expert_frequency=expert_frequency,
            )

        return hidden_states, aux_loss

    # copied from https://github.com/open-lm-engine/lm-engine/blob/1447883df709727839bbbb367ce727fa56962a6a/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py#L432-L455
    # NOTE we don't do all_reduce here for expert frequency for simplicity across data parallel workers
    def _compute_switch_loss(
        self, logits: torch.Tensor, probs: torch.Tensor, expert_frequency: torch.Tensor
    ) -> torch.Tensor:
        """Switch-Transformer style load-balancing loss: num_experts times the
        dot product of the L1-normalized mean router probabilities and the
        L1-normalized per-expert assignment counts."""
        logits = logits.view(-1, logits.size(-1))
        probs = probs.view(-1, probs.size(-1))

        num_experts = logits.size(1)
        acc_probs = probs.sum(0)

        expert_frequency = expert_frequency.float()

        aux_loss = num_experts * (F.normalize(acc_probs, p=1, dim=0) * F.normalize(expert_frequency, p=1, dim=0)).sum()

        return aux_loss

    def _compute_routing_weights(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return (router_logits, top-k softmax weights, top-k expert indices).

        Softmax is applied over the selected top-k logits only (in float32),
        then cast back to the activation dtype.
        """
        # hidden_states -> (total_q, hidden_size)
        router_logits = self.router(hidden_states)
        # router_logits -> (total_q, num_experts)

        router_weights, selected_experts = self._get_topk(router_logits)

        # router_weights -> (total_q, top_k)
        # selected_experts -> (total_q, top_k)

        router_weights = F.softmax(router_weights.float(), dim=-1)
        router_weights = router_weights.type_as(hidden_states)

        return router_logits, router_weights, selected_experts

    def _compute_experts(
        self,
        hidden_states: torch.Tensor,
        router_weights: torch.Tensor,
        selected_experts: torch.Tensor,
        kernel_backend_moe: KernelBackendMoE,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Dispatch tokens to their selected experts and combine gated outputs.

        Returns (combined hidden states of shape (total_q, hidden_size),
        int32 per-expert token counts).
        """
        selected_experts = selected_experts.flatten()

        with torch.no_grad():
            # Stable grouping of token slots by expert id.
            sorted_expert_idxs, sorted_scattered_idxs = selected_experts.sort()

            expert_frequency = selected_experts.bincount(minlength=self.num_experts).to(torch.int32)
            expert_offsets = expert_frequency.cumsum(-1).to(torch.int32)

        act_func = {
            ActivationType.SWIGLU: _swiglu,
            ActivationType.GEGLU: _geglu,
            ActivationType.REGLU: _reglu,
            ActivationType.GELU: _gelu,
            ActivationType.RELU: _relu,
            ActivationType.SILU: _silu,
            ActivationType.RELU_SQ: _relu_sq,
        }[self.activation_function]

        T = hidden_states.size(0)

        if kernel_backend_moe == KernelBackendMoE.scattermoe:
            hidden_states = self.c_fc.up_projection_scattermoe_forward(
                input=hidden_states,
                num_experts_per_token=self.top_k,
                sorted_expert_idxs=sorted_expert_idxs,
                sorted_scattered_idxs=sorted_scattered_idxs,
                expert_offsets=expert_offsets,
            )
            hidden_states = act_func(hidden_states)
            # num_experts_per_token=1 here because the input is already grouped
            # (one row per (token, expert) slot); gates apply the routing weights.
            hidden_states = self.c_proj.down_projection_scattermoe_forward(
                input=hidden_states,
                num_experts_per_token=1,
                sorted_expert_idxs=sorted_expert_idxs,
                sorted_scattered_idxs=sorted_scattered_idxs,
                expert_offsets=expert_offsets,
                gates=router_weights,
            )
        elif kernel_backend_moe == KernelBackendMoE.torch:
            # sort and group input tokens according to expert assignment
            fan_in_index = sorted_scattered_idxs // self.top_k

            # gather the gate values for grouped input tokens
            router_weights = router_weights.flatten()
            batch_gates = router_weights[sorted_scattered_idxs]

            hidden_states = hidden_states[fan_in_index]

            hidden_states = self.c_fc.torch_forward(
                input=hidden_states, expert_frequency=expert_frequency, return_list=True
            )

            hidden_states = [act_func(i) for i in hidden_states]
            hidden_states = self.c_proj.torch_forward(input=hidden_states, expert_frequency=None, return_list=False)

            hidden_states = hidden_states * batch_gates.unsqueeze(-1)
            # NOTE(review): the accumulator is float32 while hidden_states may be
            # a lower-precision dtype; index_add requires matching dtypes —
            # confirm this path is only exercised with float32 activations.
            zeros = torch.zeros((T, self.hidden_size), dtype=torch.float32, device=hidden_states.device)
            hidden_states = zeros.index_add(0, fan_in_index, hidden_states)
        else:
            raise ValueError(f"unexpected kernel_backend_moe ({kernel_backend_moe})")

        return hidden_states, expert_frequency

    def _get_topk(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Top-k values/indices along the last dim; uses max() for k == 1 to
        avoid the heavier topk kernel."""
        if self.top_k == 1:
            x, indices = x.max(dim=-1, keepdim=True)
        else:
            x, indices = x.topk(self.top_k, dim=-1)

        return x, indices
build/torch-cuda/quack/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
"""Package initializer: exposes the version and optionally patches the CUTE DSL ptxas path."""

__version__ = "0.2.5"

import os

# When the user points at a custom ptxas binary, install the patch at import time.
if "CUTE_DSL_PTXAS_PATH" in os.environ:
    from . import cute_dsl_ptxas

    cute_dsl_ptxas.patch()
build/torch-cuda/quack/_ops_compat.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .._ops_compat import add_op_namespace_prefix
2
+
3
def add_quack_op_namespace_prefix(name: str) -> str:
    """Qualify *name* under the ``quack__`` sub-namespace, then apply the shared op prefix."""
    qualified = "quack__" + name
    return add_op_namespace_prefix(qualified)
build/torch-cuda/quack/activation.py ADDED
@@ -0,0 +1,524 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Tri Dao.
2
+
3
+ import math
4
+ from typing import Tuple
5
+
6
+ import cutlass.cute as cute
7
+ from cutlass import Float32, Boolean, const_expr
8
+ from cutlass.cutlass_dsl import T, dsl_user_op
9
+ from cutlass._mlir.dialects import llvm
10
+
11
+ from . import utils as utils
12
+
13
+
14
+ F32_or_F32x2 = Float32 | Tuple[Float32, Float32]
15
+
16
+
17
@dsl_user_op
def tanh(a: float | Float32, *, loc=None, ip=None) -> Float32:
    """Hardware-approximate tanh via the ``tanh.approx.f32`` PTX instruction.

    Faster than a polynomial expansion but not correctly rounded; used as the
    building block for sigmoid/silu/gelu below.
    """
    return Float32(
        llvm.inline_asm(
            T.f32(),
            [Float32(a).ir_value(loc=loc, ip=ip)],
            "tanh.approx.f32 $0, $1;",
            "=f,f",
            has_side_effects=False,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    )
30
+
31
+
32
@dsl_user_op
def sigmoid(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
    """sigmoid(x) = 0.5 + 0.5 * tanh(0.5 * x), for a scalar or a packed f32x2 pair."""
    if const_expr(not isinstance(x, tuple)):
        # return 0.5 + 0.5 * cute.math.tanh(0.5 * x, fastmath=True)
        return 0.5 + 0.5 * tanh(0.5 * x)
    else:
        # Packed path: one FMUL.f32x2, two MUFU.TANH, one packed FFMA.
        x_half = utils.mul_packed_f32x2((0.5, 0.5), x)
        tanh_x_half = (tanh(x_half[0]), tanh(x_half[1]))
        return utils.fma_packed_f32x2(tanh_x_half, (0.5, 0.5), (0.5, 0.5))
41
+
42
+
43
@dsl_user_op
def dsigmoid_from_output(out: Float32, dout: Float32, *, loc=None, ip=None) -> Float32:
    """Sigmoid backward from the forward *output*: dout * out * (1 - out),
    rearranged as dout * (out - out*out) to lower to FMUL + FFMA."""
    # return dout * out * (1.0 - out)
    return dout * (out - out * out)
47
+
48
+
49
@dsl_user_op
def relu(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
    """max(x, 0) for a scalar or a packed f32x2 pair."""
    if const_expr(not isinstance(x, tuple)):
        return cute.arch.fmax(x, Float32(0.0))
    else:
        return cute.arch.fmax(x[0], Float32(0.0)), cute.arch.fmax(x[1], Float32(0.0))
55
+
56
+
57
@dsl_user_op
@cute.jit
def drelu(
    x: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
) -> Tuple[F32_or_F32x2, F32_or_F32x2]:
    """ReLU backward + forward recompute.

    Returns (dx, relu(x)) where dx = dout if x > 0 else 0.
    """
    if const_expr(not isinstance(x, tuple)):
        x_pos = Boolean(x > 0)
        return dout if x_pos else Float32(0.0), cute.arch.fmax(x, Float32(0.0))
    else:
        x0_pos = Boolean(x[0] > 0)
        x1_pos = Boolean(x[1] > 0)
        dx = (dout[0] if x0_pos else Float32(0.0), dout[1] if x1_pos else Float32(0.0))
        return dx, relu(x)
70
+
71
+
72
@dsl_user_op
def relu_sq(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
    """ReLU squared: max(x, 0) * x (equals relu(x)**2 since the factor is 0 or x)."""
    if const_expr(not isinstance(x, tuple)):
        return cute.arch.fmax(x, Float32(0.0)) * x
    else:
        relu_x = (cute.arch.fmax(x[0], Float32(0.0)), cute.arch.fmax(x[1], Float32(0.0)))
        return utils.mul_packed_f32x2(relu_x, x)
79
+
80
+
81
@dsl_user_op
@cute.jit
def drelu_sq(
    x: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
) -> Tuple[F32_or_F32x2, F32_or_F32x2]:
    """
    ReLU squared backward pass: computes gradient w.r.t. x and recomputes forward
    Given: relu_sq_out = max(x, 0) * x, and dout = grad w.r.t. relu_sq_out
    Returns: (dx, relu_sq_out) where:
        - dx = dout * 2 * x if x > 0, else 0
        - relu_sq_out = max(x, 0) * x
    """
    if const_expr(not isinstance(x, tuple)):
        relu_x = relu(x)
        relu_sq_out = relu_x * x
        # Derivative: d/dx[max(x,0) * x] = 2*x if x > 0, else 0
        # (dout * relu_x is 0 whenever x <= 0, so no explicit predicate is needed)
        dx = 2.0 * (dout * relu_x)
        return dx, relu_sq_out
    else:
        # Packed f32x2 path mirrors the scalar computation.
        relu_x = relu(x)
        relu_sq_out = utils.mul_packed_f32x2(relu_x, x)
        dx = utils.mul_packed_f32x2((2.0, 2.0), utils.mul_packed_f32x2(dout, relu_x))
        return dx, relu_sq_out
104
+
105
+
106
@dsl_user_op
def gelu_tanh_approx(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
    """
    gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
            = 0.5 * x * (1 + tanh(x * (0.797885 + 0.0356774 * x * x)))
    """
    # Constants are folded at trace time (plain Python floats).
    sqrt_2_over_pi = math.sqrt(2 / math.pi)  # ~0.797885
    sqrt_2_over_pi_coeff = 0.044715 * sqrt_2_over_pi  # ~0.0356774
    if const_expr(not isinstance(x, tuple)):
        return 0.5 * (
            x
            # Currently cute.math.tanh(x, fastmath=True) generates very slow code
            # * (1 + cute.math.tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * (x * x)), fastmath=True))
            * (1.0 + tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * (x * x))))
        )
    else:
        # Packed f32x2 path: same polynomial argument, element-wise tanh.
        x_sq = utils.mul_packed_f32x2(x, x)
        x_sq_scaled = utils.fma_packed_f32x2(
            x_sq, (sqrt_2_over_pi_coeff, sqrt_2_over_pi_coeff), (sqrt_2_over_pi, sqrt_2_over_pi)
        )
        z = utils.mul_packed_f32x2(x, x_sq_scaled)
        tanh_z = (tanh(z[0]), tanh(z[1]))
        x_tanh_z = utils.fma_packed_f32x2(tanh_z, x, x)
        return utils.mul_packed_f32x2((0.5, 0.5), x_tanh_z)
130
+
131
+
132
@dsl_user_op
def dgelu_tanh_approx(
    x: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
) -> Tuple[F32_or_F32x2, F32_or_F32x2]:
    """
    GELU tanh approximation backward pass: computes gradient w.r.t. x and recomputes forward
    Given: gelu_out = 0.5 * x * (1 + tanh(x * (c1 + c2 * x^2))), and dout = grad w.r.t. gelu_out
    Returns: (dx, gelu_out)

    Derivative uses the chain rule:
    d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
    where z = x * (c1 + c2 * x^2), dz/dx = c1 + 3 * c2 * x^2
    and sech^2(z) = 1 - tanh^2(z)
    """
    sqrt_2_over_pi = math.sqrt(2 / math.pi)  # c1 ~0.797885
    sqrt_2_over_pi_coeff = 0.044715 * sqrt_2_over_pi  # c2 ~0.0356774
    sqrt_2_over_pi_coeff_3 = 3.0 * sqrt_2_over_pi_coeff  # c3 ~0.01070322

    if const_expr(not isinstance(x, tuple)):
        # Compute z = x * (c1 + c2 * x^2)
        x_sq = x * x
        # tanh_z = cute.math.tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * x_sq), fastmath=True)
        tanh_z = tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * x_sq))
        half_tanh_z_plus_one = 0.5 + 0.5 * tanh_z
        gelu_out = x * half_tanh_z_plus_one

        # Compute gradient
        # sech^2(z) = 1 - tanh^2(z)
        sech2_z = 1 - tanh_z * tanh_z
        # dz/dx = c1 + 3 * c2 * x^2
        dz_dx = sqrt_2_over_pi + sqrt_2_over_pi_coeff_3 * x_sq
        # d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
        dgelu = half_tanh_z_plus_one + x * (0.5 * (sech2_z * dz_dx))

        dx = dout * dgelu
        return dx, gelu_out
    else:
        # Packed f32x2 path: identical math, two-lane packed mul/fma helpers.
        # Compute z = x * (c1 + c2 * x^2)
        x_sq = utils.mul_packed_f32x2(x, x)
        x_sq_scaled = utils.fma_packed_f32x2(
            x_sq, (sqrt_2_over_pi_coeff, sqrt_2_over_pi_coeff), (sqrt_2_over_pi, sqrt_2_over_pi)
        )
        z = utils.mul_packed_f32x2(x, x_sq_scaled)
        tanh_z = (tanh(z[0]), tanh(z[1]))
        half_tanh_z_plus_one = utils.fma_packed_f32x2(tanh_z, (0.5, 0.5), (0.5, 0.5))
        gelu_out = utils.mul_packed_f32x2(x, half_tanh_z_plus_one)

        # Compute gradient
        # sech^2(z) = 1 - tanh^2(z)
        sech2_z = utils.fma_packed_f32x2(tanh_z, (-tanh_z[0], -tanh_z[1]), (1.0, 1.0))
        # dz/dx = c1 + 3 * c2 * x^2
        dz_dx = utils.fma_packed_f32x2(
            x_sq, (sqrt_2_over_pi_coeff_3, sqrt_2_over_pi_coeff_3), (sqrt_2_over_pi, sqrt_2_over_pi)
        )
        # d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
        sech2_dz_dx = utils.mul_packed_f32x2(sech2_z, dz_dx)
        x_sech2_dz_dx = utils.mul_packed_f32x2(x, sech2_dz_dx)
        dgelu = utils.fma_packed_f32x2(x_sech2_dz_dx, (0.5, 0.5), half_tanh_z_plus_one)

        dx = utils.mul_packed_f32x2(dout, dgelu)
        return dx, gelu_out
193
+
194
+
195
@dsl_user_op
@cute.jit
def softplus(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
    """softplus(x) = log(1 + exp(x)), switching to the identity for x > 20
    where the two agree to float32 precision (and exp would overflow)."""
    if const_expr(not isinstance(x, tuple)):
        use_linear = Boolean(x > 20.0)
        return (
            cute.math.log(Float32(cute.math.exp(x, fastmath=True)) + 1.0, fastmath=True)
            if not use_linear
            else x
        )
    else:
        # Packed path computes exp via base-2: scale by log2(e), exponentiate, then
        # log2(...) * ln(2) to get back a natural log.
        # NOTE(review): the scalar branch calls cute.math.exp on the raw x, while
        # this branch calls cute.math.exp on x * log2(e) — that is only consistent
        # if fastmath exp lowers to ex2 on the pre-scaled input; confirm this is
        # not meant to be cute.math.exp2.
        log2_e = math.log2(math.e)
        x_log2e = utils.mul_packed_f32x2(x, (log2_e, log2_e))
        x_exp = (cute.math.exp(x_log2e[0], fastmath=True), cute.math.exp(x_log2e[1], fastmath=True))
        x_exp_p1 = utils.add_packed_f32x2(x_exp, (1.0, 1.0))
        log_x_exp_p1 = (
            cute.math.log2(x_exp_p1[0], fastmath=True),
            cute.math.log2(x_exp_p1[1], fastmath=True),
        )
        ln2 = math.log(2.0)
        softplus_x = utils.mul_packed_f32x2(log_x_exp_p1, (ln2, ln2))
        use_linear_0 = Boolean(x[0] > 20.0)
        use_linear_1 = Boolean(x[1] > 20.0)
        return (
            softplus_x[0] if not use_linear_0 else x[0],
            softplus_x[1] if not use_linear_1 else x[1],
        )
222
+
223
+
224
@dsl_user_op
@cute.jit
def dsoftplus_from_output(out: Float32, dout: Float32, *, loc=None, ip=None) -> Float32:
    """Softplus backward from the forward *output*: dout * (1 - exp(-out)).

    In the linear region (out > 20, where forward returned x directly) the
    gradient is just dout.
    """
    use_linear = Boolean(out > 20.0)
    # dx = dout * (1.0 - cute.math.exp(-out, fastmath=True)) if not use_linear else dout
    dx = dout - dout * cute.math.exp(-out, fastmath=True)
    return dx if not use_linear else dout
231
+
232
+
233
@dsl_user_op
def silu(x: F32_or_F32x2, *, already_halved: bool = False, loc=None, ip=None) -> F32_or_F32x2:
    """
    silu(x) = x * sigmoid(x) = x * (1 + tanh(x / 2)) / 2 = (0.5 * x) * tanh(0.5 * x) + (0.5 * x)
    This compiles down to 3 SASS instructions: FMUL to get 0.5 * x, MUFU.TANH, and FFMA.

    `already_halved` lets a caller that has precomputed 0.5 * x skip the FMUL.
    """
    if const_expr(not isinstance(x, tuple)):
        x_half = 0.5 * x if const_expr(not already_halved) else x
        # return x_half * cute.math.tanh(x_half, fastmath=True) + x_half
        return x_half * tanh(x_half) + x_half
    else:
        x_half = utils.mul_packed_f32x2((0.5, 0.5), x) if const_expr(not already_halved) else x
        tanh_x_half = (tanh(x_half[0]), tanh(x_half[1]))
        return utils.fma_packed_f32x2(x_half, tanh_x_half, x_half)
247
+
248
+
249
@dsl_user_op
def swiglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
    """SwiGLU forward: silu(x) * y, where x is the gate and y the up projection."""
    if const_expr(not isinstance(x, tuple)):
        return silu(x) * y
    else:
        return utils.mul_packed_f32x2(silu(x), y)
255
+
256
+
257
@dsl_user_op
def dswiglu(
    x: F32_or_F32x2,
    y: F32_or_F32x2,
    dout: F32_or_F32x2,
    *,
    already_halved: bool = False,
    loc=None,
    ip=None,
) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
    """
    SwiGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
    Given: swiglu_out = silu(x) * y, and dout = grad w.r.t. swiglu_out
    Returns: (dx, dy, swiglu_out) where dx = dout * y * d_silu(x), dy = dout * silu(x)

    d_silu(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x)))

    This has been optimized to use fewer instructions (i.e. we expand things out
    to use FFMA instead of FADD and FMUL).

    When `already_halved` is set, x is actually 0.5 * x and the tanh-based
    identities below absorb the halving for free.
    """
    if const_expr(not isinstance(x, tuple)):
        # Compute sigmoid(x) using tanh: sigmoid(x) = 0.5 * (1 + tanh(0.5 * x))
        # FMUL, MUFU.TANH, then FFMA
        if const_expr(not already_halved):
            sigmoid_x = sigmoid(x)
            silu_x = x * sigmoid_x  # FMUL
        else:
            tanh_x = tanh(x)  # MUFU.TANH
            sigmoid_x = 0.5 * tanh_x + 0.5  # FFMA
            silu_x = x * tanh_x + x  # FFMA
        silu_x_dout = silu_x * dout  # FMUL
        # d_silu(x) * dout
        # = sigmoid_x * (1 + x * (1 - sigmoid_x)) * dout
        # = (sigmoid_x + sigmoid_x * x * (1 - sigmoid_x)) * dout
        # = (sigmoid_x + silu_x * (1 - sigmoid_x)) * dout
        # = (sigmoid_x + silu_x - silu_x * sigmoid_x) * dout
        # = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x * dout
        d_silu_x_dout = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x_dout  # FFMA, FFMA
        dx = d_silu_x_dout * y  # FMUL
        dy = silu_x_dout
        swiglu_out = silu_x * y  # FMUL
        # Overall it's 1 MUFU.TANH, 5 FMUL, 3 FFMA
        return dx, dy, swiglu_out
    else:
        # Packed f32x2 path: same algebra on two lanes at a time.
        # Compute sigmoid(x) and silu(x)
        if const_expr(not already_halved):
            sigmoid_x = sigmoid(x)
            silu_x = utils.mul_packed_f32x2(x, sigmoid_x)
        else:
            tanh_x = (tanh(x[0]), tanh(x[1]))
            sigmoid_x = utils.fma_packed_f32x2(tanh_x, (0.5, 0.5), (0.5, 0.5))
            silu_x = utils.fma_packed_f32x2(x, tanh_x, x)
        silu_x_dout = utils.mul_packed_f32x2(silu_x, dout)
        # d_silu(x) * dout = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x * dout
        sigmoid_x_minus_silu_x_sigmoid_x = utils.fma_packed_f32x2(
            sigmoid_x, (-silu_x[0], -silu_x[1]), sigmoid_x
        )
        d_silu_x_dout = utils.fma_packed_f32x2(sigmoid_x_minus_silu_x_sigmoid_x, dout, silu_x_dout)
        dx = utils.mul_packed_f32x2(d_silu_x_dout, y)
        dy = silu_x_dout
        swiglu_out = utils.mul_packed_f32x2(silu_x, y)
        return dx, dy, swiglu_out
319
+
320
+
321
@dsl_user_op
def swiglu_oai(
    x: F32_or_F32x2, y: F32_or_F32x2, alpha: float = 1.702, *, loc=None, ip=None
) -> F32_or_F32x2:
    """The swiglu variant used in gpt-oss, which has a scaling factor on x and bias of 1 to y.
    https://github.com/openai/gpt-oss/blob/7be9334950053a888e24887a57dac797a17d6e00/gpt_oss/torch/model.py#L249
    x * sigmoid(alpha * x) * (y + 1)
    Compile down to FMUL, FMUL, TANH, FFMA, FFMA
    """
    # Compute sigmoid(alpha * x) using tanh: sigmoid(z) = 0.5 * (1 + tanh(z/2))
    if const_expr(not isinstance(x, tuple)):
        x_half = 0.5 * x
        # silu_x = x_half * cute.math.tanh(alpha * x_half, fastmath=True) + x_half
        silu_x = x_half * tanh(alpha * x_half) + x_half
        # silu_x * (y + 1) expanded to a single FFMA: silu_x * y + silu_x
        return silu_x * y + silu_x
    else:
        x_half = utils.mul_packed_f32x2((0.5, 0.5), x)
        alpha_x_half = utils.mul_packed_f32x2((alpha, alpha), x_half)
        tanh_alpha_x_half = (tanh(alpha_x_half[0]), tanh(alpha_x_half[1]))
        silu_x = utils.fma_packed_f32x2(x_half, tanh_alpha_x_half, x_half)
        return utils.fma_packed_f32x2(silu_x, y, silu_x)
342
+
343
+
344
@dsl_user_op
def dswiglu_oai(
    x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, alpha: float = 1.702, *, loc=None, ip=None
) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
    """
    Swiglu OAI backward pass: computes gradients w.r.t. x and y
    Given: swiglu_oai_out = x * sigmoid(alpha * x) * (y + 1), and dout = grad w.r.t. swiglu_oai_out
    Returns: (dx, dy, swiglu_oai_out)

    Derivative of x * sigmoid(alpha * x) w.r.t. x:
    d/dx[x * sigmoid(alpha * x)] = sigmoid(alpha * x) + alpha * x * sigmoid(alpha * x) * (1 - sigmoid(alpha * x))
    """
    if const_expr(not isinstance(x, tuple)):
        # Compute sigmoid(alpha * x) using tanh: sigmoid(z) = 0.5 * (1 + tanh(z/2))
        alpha_x_half = (0.5 * alpha) * x  # FMUL
        # MUFU.TANH, then FFMA
        # sigmoid_alpha_x = 0.5 + 0.5 * cute.math.tanh(alpha_x_half, fastmath=True)
        sigmoid_alpha_x = 0.5 + 0.5 * tanh(alpha_x_half)
        silu_x = x * sigmoid_alpha_x  # FMUL
        silu_x_dout = silu_x * dout  # FMUL
        # FFMA, FFMA, FMUL
        d_silu_x_dout = (sigmoid_alpha_x + alpha * (silu_x - silu_x * sigmoid_alpha_x)) * dout
        dx = d_silu_x_dout * y + d_silu_x_dout  # FFMA, instead of multiply by y + 1
        dy = silu_x_dout
        swiglu_out = silu_x * y + silu_x  # FFMA, instead of multiply by y + 1
        # Overall it's 1 MUFU.TANH, 4 FMUL, 5 FFMA
        return dx, dy, swiglu_out
    else:
        # Packed f32x2 path: same algebra, two lanes per helper call.
        # Compute sigmoid(alpha * x)
        alpha_x_half = utils.mul_packed_f32x2(((0.5 * alpha), (0.5 * alpha)), x)
        tanh_alpha_x_half = (tanh(alpha_x_half[0]), tanh(alpha_x_half[1]))
        sigmoid_alpha_x = utils.fma_packed_f32x2(tanh_alpha_x_half, (0.5, 0.5), (0.5, 0.5))
        silu_x = utils.mul_packed_f32x2(x, sigmoid_alpha_x)
        silu_x_dout = utils.mul_packed_f32x2(silu_x, dout)
        # d_silu_x_dout = (sigmoid_alpha_x + alpha * (silu_x - silu_x * sigmoid_alpha_x)) * dout
        silu_x_minus_product = utils.fma_packed_f32x2(
            silu_x, (-sigmoid_alpha_x[0], -sigmoid_alpha_x[1]), silu_x
        )
        sigmoid_plus_alpha_diff = utils.fma_packed_f32x2(
            (alpha, alpha), silu_x_minus_product, sigmoid_alpha_x
        )
        d_silu_x_dout = utils.mul_packed_f32x2(sigmoid_plus_alpha_diff, dout)
        dx = utils.fma_packed_f32x2(d_silu_x_dout, y, d_silu_x_dout)
        dy = silu_x_dout
        swiglu_out = utils.fma_packed_f32x2(silu_x, y, silu_x)
        return dx, dy, swiglu_out
390
+
391
+
392
@dsl_user_op
def glu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
    """GLU: Gated Linear Unit
    glu(x, y) = sigmoid(x) * y
    Using tanh to compute sigmoid: sigmoid(x) = 0.5 * (1 + tanh(x/2))
    """
    if const_expr(not isinstance(x, tuple)):
        sigmoid_x = sigmoid(x)  # FMUL, MUFU.TANH, then FFMA
        return sigmoid_x * y  # FMUL
    else:
        # Packed f32x2 path.
        sigmoid_x = sigmoid(x)
        return utils.mul_packed_f32x2(sigmoid_x, y)
404
+
405
+
406
@dsl_user_op
def dglu(
    x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
    """
    GLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
    Given: glu_out = sigmoid(x) * y, and dout = grad w.r.t. glu_out
    Returns: (dx, dy, glu_out) where:
        - dx = dout * y * sigmoid(x) * (1 - sigmoid(x))
        - dy = dout * sigmoid(x)
        - glu_out = sigmoid(x) * y
    """
    if const_expr(not isinstance(x, tuple)):
        # Compute sigmoid(x) using tanh: sigmoid(x) = 0.5 * (1 + tanh(x/2))
        sigmoid_x = sigmoid(x)  # FMUL, MUFU.TANH, then FFMA
        sigmoid_x_dout = sigmoid_x * dout  # FMUL
        glu_out = sigmoid_x * y  # FMUL
        # dx = y * sigmoid(x) * (1 - sigmoid(x)) * dout
        #    = y * (1 - sigmoid(x)) * sigmoid_x_dout
        #    = (y - y * sigmoid(x)) * sigmoid_x_dout
        #    = (y - glu_out) * sigmoid_x_dout
        dx = (y - glu_out) * sigmoid_x_dout  # FADD, FMUL
        dy = sigmoid_x_dout
        # Total: 1 MUFU.TANH, 4 FMUL, 1 FADD, 1 FFMA
        return dx, dy, glu_out
    else:
        # Packed f32x2 path mirrors the scalar rewrite (y - glu_out) * sigmoid_x_dout.
        sigmoid_x = sigmoid(x)
        sigmoid_x_dout = utils.mul_packed_f32x2(sigmoid_x, dout)
        glu_out = utils.mul_packed_f32x2(sigmoid_x, y)
        # dx = (y - glu_out) * sigmoid_x_dout
        y_minus_glu_out = utils.sub_packed_f32x2(y, glu_out)
        dx = utils.mul_packed_f32x2(y_minus_glu_out, sigmoid_x_dout)
        dy = sigmoid_x_dout
        return dx, dy, glu_out
440
+
441
+
442
@dsl_user_op
def reglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
    """ReGLU: ReLU Gated Linear Unit
    reglu(x, y) = relu(x) * y = max(x, 0) * y
    """
    if const_expr(not isinstance(x, tuple)):
        return cute.arch.fmax(x, Float32(0.0)) * y
    else:
        # Packed f32x2 path.
        relu_x = relu(x)
        return utils.mul_packed_f32x2(relu_x, y)
452
+
453
+
454
@dsl_user_op
@cute.jit
def dreglu(
    x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
    """
    ReGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
    Given: reglu_out = relu(x) * y, and dout = grad w.r.t. reglu_out
    Returns: (dx, dy, reglu_out) where:
        - dx = dout * y if x > 0, else 0
        - dy = dout * relu(x)
        - reglu_out = relu(x) * y
    """
    if const_expr(not isinstance(x, tuple)):
        x_pos = Boolean(x > 0)
        relu_x = cute.arch.fmax(x, Float32(0.0))
        dx = (dout * y) if x_pos else Float32(0.0)
        dy = dout * relu_x
        reglu_out = relu_x * y
        return dx, dy, reglu_out
    else:
        # Packed f32x2 path: per-lane predicates select the gradient.
        x0_pos = Boolean(x[0] > 0)
        x1_pos = Boolean(x[1] > 0)
        relu_x = relu(x)
        dout_y = utils.mul_packed_f32x2(dout, y)
        dx = ((dout_y[0] if x0_pos else Float32(0.0)), (dout_y[1] if x1_pos else Float32(0.0)))
        dy = utils.mul_packed_f32x2(dout, relu_x)
        reglu_out = utils.mul_packed_f32x2(relu_x, y)
        return dx, dy, reglu_out
483
+
484
+
485
@dsl_user_op
def geglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
    """GeGLU: GELU Gated Linear Unit
    geglu(x, y) = gelu(x) * y
    Uses the tanh approximation of GELU
    """
    if const_expr(not isinstance(x, tuple)):
        return gelu_tanh_approx(x) * y
    else:
        # Packed f32x2 path.
        return utils.mul_packed_f32x2(gelu_tanh_approx(x), y)
495
+
496
+
497
@dsl_user_op
def dgeglu(
    x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
    """
    GeGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
    Given: geglu_out = gelu(x) * y, and dout = grad w.r.t. geglu_out
    Returns: (dx, dy, geglu_out) where:
        - dx = dout * y * d_gelu(x)
        - dy = dout * gelu(x)
        - geglu_out = gelu(x) * y
    """
    if const_expr(not isinstance(x, tuple)):
        # Reuse dgelu_tanh_approx to compute d_gelu(x) * dout and gelu(x)
        dgelu_x_dout, gelu_x = dgelu_tanh_approx(x, dout)
        # Compute gradients for geglu
        dx = dgelu_x_dout * y
        dy = gelu_x * dout
        geglu_out = gelu_x * y
        return dx, dy, geglu_out
    else:
        # Packed f32x2 path: dgelu_tanh_approx handles tuples natively.
        # Reuse dgelu_tanh_approx to compute d_gelu(x) * dout and gelu(x)
        dgelu_x_dout, gelu_x = dgelu_tanh_approx(x, dout)
        # Compute gradients for geglu
        dx = utils.mul_packed_f32x2(dgelu_x_dout, y)
        dy = utils.mul_packed_f32x2(gelu_x, dout)
        geglu_out = utils.mul_packed_f32x2(gelu_x, y)
        return dx, dy, geglu_out
build/torch-cuda/quack/autotuner.py ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/triton-lang/triton/blob/main/python/triton/runtime/autotuner.py
2
+ # Copyright (C) 2025, Tri Dao.
3
+ from __future__ import annotations
4
+
5
+ import builtins
6
+ import os
7
+ import time
8
+ import inspect
9
+ import base64
10
+ import hashlib
11
+ import json
12
+ from pathlib import Path
13
+ from functools import cached_property, partial
14
+ from typing import Dict, Tuple, List, Optional, Any
15
+
16
+ import torch
17
+ from torch import Tensor
18
+
19
+ import triton
20
+
21
+ from . import __version__
22
+
23
+
24
+ PACKAGE_NAME = "quack"
25
+ VERSION = __version__
26
+
27
+
28
def get_home_dir():
    """Return the package home directory: $QUACK_HOME if set, else the user's home."""
    return os.getenv(f"{PACKAGE_NAME.upper()}_HOME", Path.home())
30
+
31
+
32
def default_cache_dir():
    """Return the default on-disk cache location, e.g. ~/.quack/cache."""
    return os.path.join(get_home_dir(), f".{PACKAGE_NAME}", "cache")
34
+
35
+
36
class FileCacheManager(triton.runtime.cache.FileCacheManager):
    """Triton file-cache manager redirected to this package's cache directory.

    Overrides the cache root with $QUACK_CACHE_DIR (or the package default)
    so autotune results do not mix with triton's own kernel cache.
    """

    def __init__(self, key):
        super().__init__(key)
        # Recompute cache_dir after the parent init: env override wins, empty
        # string falls back to the package default.
        self.cache_dir = (
            os.getenv(f"{PACKAGE_NAME.upper()}_CACHE_DIR", "").strip() or default_cache_dir()
        )
        if self.cache_dir:
            self.cache_dir = os.path.join(self.cache_dir, self.key)
            self.lock_path = os.path.join(self.cache_dir, "lock")
            os.makedirs(self.cache_dir, exist_ok=True)
        else:
            raise RuntimeError("Could not create or locate cache dir")
48
+
49
+
50
+ def _base32(key):
51
+ # Assume key is a hex string.
52
+ return base64.b32encode(bytes.fromhex(key)).decode("utf-8").rstrip("=")
53
+
54
+
55
class Autotuner:
    """Benchmarks a function over a set of AutotuneConfigs and caches the winner.

    Adapted from triton's runtime autotuner. Results are keyed on the values of
    the user-supplied ``key`` argument names plus the shapes/strides/dtypes of
    tensor arguments, and can optionally be persisted to disk.
    """

    def __init__(
        self,
        fn,
        key,
        configs,
        restore_value=None,
        prune_configs_by: Optional[Dict] = None,
        do_bench=None,
        cache_results=False,
    ):
        """
        :param prune_configs_by: a dict of functions that are used to prune configs, fields:
            'perf_model': performance model used to predicate running time with different configs, returns running time
            'top_k': number of configs to bench
            'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs.
        """
        if not configs:
            # No configs supplied: fall back to a single empty (default) config.
            self.configs = [AutotuneConfig()]
        else:
            self.configs = configs
        signature = inspect.signature(fn)
        self.keys = key
        # Maps the tuning key tuple -> best AutotuneConfig found so far.
        self.cache: Dict[Tuple, AutotuneConfig] = {}
        self.arg_names = list(signature.parameters.keys())
        # Disk caching can also be forced on via the environment.
        self.cache_results = (
            cache_results or os.getenv(f"{PACKAGE_NAME.upper()}_CACHE_AUTOTUNING", None) == "1"
        )

        self.restore_value = []
        if restore_value is not None:
            self.restore_value = list(restore_value)

        # pre/post hooks snapshot and restore in-place-mutated tensor args so
        # that repeated benchmark runs see identical inputs.
        if len(self.restore_value) > 0:

            def _pre_hook(kwargs):
                self.restore_copies = {name: kwargs[name].clone() for name in self.restore_value}

            self.pre_hook = _pre_hook
        else:
            self.pre_hook = None

        if len(self.restore_value) > 0:

            def _post_hook(kwargs, exception):
                for name in self.restore_value:
                    kwargs[name].copy_(self.restore_copies[name])
                self.restore_copies = {}

            self.post_hook = _post_hook
        else:
            self.post_hook = None

        self.perf_model = None
        self.configs_top_k = 1.0
        self.early_config_prune = None
        if prune_configs_by:
            self.perf_model = prune_configs_by.get("perf_model", self.perf_model)
            self.configs_top_k = prune_configs_by.get("top_k", self.configs_top_k)
            self.early_config_prune = prune_configs_by.get(
                "early_config_prune", self.early_config_prune
            )

        self.fn = fn
        self._do_bench = do_bench

    @cached_property
    def do_bench(self):
        # Lazily default to triton's do_bench with short warmup/rep times.
        if self._do_bench is None:
            return partial(triton.testing.do_bench, warmup=5, rep=25)
        return self._do_bench

    def _bench(self, *args, config, **meta):
        """Benchmark ``self.fn`` under one config; returns (median, p20, p80) times.

        Returns a list of infs if the config fails to run, so a bad config
        simply loses the min() instead of aborting tuning.
        """
        verbose = os.environ.get(f"{PACKAGE_NAME.upper()}_PRINT_AUTOTUNING", None) == "1"
        if verbose:
            print(f"Autotuning kernel {self.fn.__name__} with config {config}")

        # check for conflicts, i.e. meta-parameters both provided
        # as kwargs and by the autotuner
        conflicts = meta.keys() & config.kwargs.keys()
        if conflicts:
            raise ValueError(
                f"Conflicting meta-parameters: {', '.join(conflicts)}."
                " Make sure that you don't re-define auto-tuned symbols."
            )
        # augment meta-parameters with tunable ones
        current = dict(meta, **config.all_kwargs())
        full_nargs = {**self.nargs, **current}

        def kernel_call():
            if self.pre_hook is not None:
                self.pre_hook(full_nargs)
            try:
                self.fn.__call__(
                    *args,
                    **current,
                )
            except Exception as e:
                try:
                    if self.post_hook is not None:
                        self.post_hook(full_nargs, exception=e)
                finally:
                    # Throw exception raised by `self.fn.run`
                    raise

            if self.post_hook is not None:
                self.post_hook(full_nargs, exception=None)

        try:
            return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8))
        except Exception as e:
            if verbose:
                print(f"Autotuning failed with {e}")
            return [float("inf"), float("inf"), float("inf")]

    @torch.compiler.disable
    def check_disk_cache(self, tuning_key, configs, bench_fn):
        """Load timings for ``tuning_key`` from disk if present, else run ``bench_fn`` and persist.

        NOTE(review): the force-update env var is compared truthily, so
        QUACK_FORCE_CACHE_UPDATE="0" still forces an update — confirm intended.
        """
        if not tuning_key:
            # Empty key: nothing to key the cache on, just benchmark.
            bench_fn()
            return

        fn = self.fn
        config_str_list = [str(c) for c in configs]
        assert len(config_str_list) == len(set(config_str_list)), "Config strings must be unique"
        # Cache key covers package version, tuning key, and the config set.
        cache_key = [VERSION, str(tuning_key)] + config_str_list
        cache_key = hashlib.sha256("-".join(cache_key).encode("utf-8")).hexdigest()
        cache = FileCacheManager(_base32(cache_key))
        file_name = f"{fn.__name__[:150]}.autotune.json"
        path = cache.get_file(file_name)
        # There's an environment variable to force cache update
        if path and not os.environ.get(f"{PACKAGE_NAME.upper()}_FORCE_CACHE_UPDATE", False):
            str2config = {s: c for s, c in zip(config_str_list, configs)}
            with open(path, "r") as cached_configs:
                timings = json.load(cached_configs)["configs_timings"]
                # Configs were serialized as strings; map them back to objects.
                timings = {str2config[config]: timing for config, timing in timings}
            self.cache[tuning_key] = builtins.min(timings, key=timings.get)
            self.configs_timings = timings
            self.bench_time = 0
            return

        bench_fn()
        cache.put(
            json.dumps(
                {
                    "key": tuning_key,
                    "configs_timings": [
                        (str(config), timings) for config, timings in self.configs_timings.items()
                    ],
                }
            ),
            file_name,
            binary=False,
        )

    def __call__(self, *args, **kwargs):
        """Dispatch to ``self.fn`` with the best config, tuning first on a cache miss."""
        self.nargs = dict(zip(self.arg_names, args))
        used_cached_result = True
        if len(self.configs) > 1:
            all_args = {**self.nargs, **kwargs}
            _args = {k: v for (k, v) in all_args.items() if k in self.arg_names}
            # Need "str" to make it json-serializable
            key = [str(_args[key]) for key in self.keys if key in _args]
            for _, arg in _args.items():
                if isinstance(arg, Tensor):
                    key.append(str(arg.shape))
                    # If stride != 0, 1, we just cache it as 2
                    key.append(str([s if s in {0, 1} else 2 for s in arg.stride()]))
                    key.append(str(arg.dtype))
            key = tuple(key)
            if key not in self.cache:
                used_cached_result = False
                pruned_configs = self.prune_configs(kwargs)

                @torch.compiler.disable  # Don't want any tracing here
                def benchmark():
                    bench_start = time.time()
                    timings = {
                        config: self._bench(*args, config=config, **kwargs)
                        for config in pruned_configs
                    }
                    bench_end = time.time()
                    if os.getenv(f"{PACKAGE_NAME.upper()}_PRINT_AUTOTUNING", None) == "1":
                        for config, time_ in timings.items():
                            print(f"[{config}] -> {time_[0]:.3f}ms")
                    self.bench_time = bench_end - bench_start
                    self.cache[key] = builtins.min(timings, key=timings.get)
                    self.configs_timings = timings

                if self.cache_results:
                    self.check_disk_cache(key, pruned_configs, benchmark)
                else:
                    benchmark()

            config = self.cache[key]
        else:
            # Single config: nothing to tune.
            config = self.configs[0]
        self.best_config = config
        if (
            os.getenv(f"{PACKAGE_NAME.upper()}_PRINT_AUTOTUNING", None) == "1"
            and not used_cached_result
        ):
            print(
                f"{PACKAGE_NAME} autotuning for function {self.fn.__name__} finished after "
                f"{self.bench_time:.2f}s; best config selected: {self.best_config};"
            )
        ret = self.fn.__call__(
            *args,
            **kwargs,
            **config.all_kwargs(),
        )
        self.nargs = None
        return ret

    def prune_configs(self, kwargs: Dict) -> List[Any]:
        """Narrow the config list via ``early_config_prune`` and/or the perf model top-k."""
        pruned_configs = self.configs
        if self.early_config_prune:
            pruned_configs = self.early_config_prune(self.configs, self.nargs, **kwargs)
        if self.perf_model:
            top_k = self.configs_top_k
            if isinstance(top_k, float) and top_k <= 1.0:
                # Fractional top_k: interpret as a fraction of the full config list.
                top_k = int(len(self.configs) * top_k)
            elif not isinstance(top_k, int):
                # Slice index must be an integer
                raise TypeError(
                    "Error while pruning configs, top_k must be either 1) a float <= 1.0 or 2) an int"
                )

            if len(pruned_configs) > top_k:
                est_timing = {
                    config: self.perf_model(
                        **self.nargs,
                        **kwargs,
                        **config.all_kwargs(),
                    )
                    for config in pruned_configs
                }
                pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
        return pruned_configs
293
+
294
+
295
class AutotuneConfig:
    """
    An object that represents a possible kernel configuration for the auto-tuner to try.

    :ivar kwargs: a dictionary of meta-parameters to pass to the kernel as keyword arguments.
    :type kwargs: dict[Str, Any]
    """

    def __init__(self, **kwargs):
        self.kwargs = kwargs

    def __setstate__(self, state):
        # Unpickling support: restore kwargs, defaulting to an empty config.
        self.kwargs = state.get("kwargs", {})

    def all_kwargs(self):
        """Return every meta-parameter this config contributes to a kernel call."""
        return self.kwargs

    def __str__(self):
        res = []
        for k, v in self.kwargs.items():
            res.append(f"{k}: {v}")
        return ", ".join(res)

    def __hash__(self):
        # BUG FIX: previously `hash(tuple(*self.all_kwargs().items()))`, which
        # unpacks the items as positional args to tuple() and raises TypeError
        # for any config with more than one kwarg. Hash the ordered (key, value)
        # pairs instead.
        return hash(tuple(self.all_kwargs().items()))

    def __eq__(self, other):
        # BUG FIX: same tuple(*...) unpacking error as __hash__; also return
        # NotImplemented for foreign types instead of raising AttributeError.
        if not isinstance(other, AutotuneConfig):
            return NotImplemented
        return tuple(self.all_kwargs().items()) == tuple(other.all_kwargs().items())
325
+
326
+
327
def autotune(
    configs, key=None, prune_configs_by=None, restore_value=None, do_bench=None, cache_results=True
):
    """
    Decorator for auto-tuning a function.

    .. highlight:: python

    If the environment variable :code:`QUACK_PRINT_AUTOTUNING` is set to
    :code:`"1"`, we will print a message to stdout after autotuning each
    kernel, including the time spent autotuning and the best configuration.

    :param configs: a list of :code:`AutotuneConfig` objects
    :type configs: list[AutotuneConfig]
    :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
    :type key: list[str]
    :param prune_configs_by: a dict of functions that are used to prune configs, fields:
        'perf_model': performance model used to predicate running time with different configs, returns running time
        'top_k': number of configs to bench
        'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs.
    :param restore_value: a list of argument names whose value will be restored after evaluating any configs.
    :type restore_value: list[str]
    :param do_bench: a benchmark function to measure the time of each run.
    :type do_bench: lambda fn, quantiles
    :param cache_results: whether to cache autotune timings to disk. Defaults to True.
    :type cache_results: bool
    """
    # Fixes vs. previous version: the docstring was an f-string (f-strings are
    # not stored in __doc__), claimed the cache_results default was False while
    # the signature default is True, and had a `"type` typo for `:type`.
    if key is None:
        key = []

    def decorator(fn):
        return Autotuner(
            fn,
            key,
            configs,
            restore_value=restore_value,
            prune_configs_by=prune_configs_by,
            do_bench=do_bench,
            cache_results=cache_results,
        )

    return decorator
build/torch-cuda/quack/broadcast_utils.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Tri Dao.
2
+ from typing import Callable
3
+
4
+ import cutlass
5
+ import cutlass.cute as cute
6
+ from cutlass import Float32, const_expr
7
+
8
+ from .layout_utils import make_acc_tensor_mn_view
9
+
10
+
11
@cute.jit
def vec_op(tCrC: cute.Tensor, tCrVec: cute.Tensor, op: Callable, is_colvec: bool) -> None:
    """Apply ``op`` between a register fragment and a broadcast vector, in place.

    The fragment is viewed in (M, N) form; ``tCrVec`` is broadcast along rows
    (is_colvec=True) or columns (is_colvec=False). Arithmetic is done in f32:
    non-f32 fragments are converted up front and written back at the end.
    """
    if const_expr(tCrC.element_type != Float32):  # Convert to f32
        tCrC_f32 = cute.make_fragment(tCrC.shape, Float32)
        tCrC_f32.store(tCrC.load().to(Float32))
    else:
        tCrC_f32 = tCrC
    # this happens to work for frgA layout too, not just acc layout
    tCrC_f32_mn = make_acc_tensor_mn_view(tCrC_f32)
    if const_expr(is_colvec):
        # One vector element per fragment row.
        assert cute.size(tCrC_f32_mn, mode=[0]) == cute.size(tCrVec)
        for r in cutlass.range(cute.size(tCrC_f32_mn, mode=[0]), unroll_full=True):
            tCrC_f32_mn[r, None].store(op(tCrC_f32_mn[r, None].load(), tCrVec[r]))
    else:
        # One vector element per fragment column.
        assert cute.size(tCrC_f32_mn, mode=[1]) == cute.size(tCrVec)
        for c in cutlass.range(cute.size(tCrC_f32_mn, mode=[1]), unroll_full=True):
            tCrC_f32_mn[None, c].store(op(tCrC_f32_mn[None, c].load(), tCrVec[c]))
    if const_expr(tCrC.element_type != Float32):  # Convert back to original dtype
        tCrC.store(tCrC_f32.load().to(tCrC.element_type))
build/torch-cuda/quack/compile_utils.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
2
+
3
+ from typing import Optional
4
+
5
+ import cutlass.cute as cute
6
+
7
+
8
def make_fake_tensor(dtype, shape, divisibility=1, leading_dim=-1) -> Optional[cute.Tensor]:
    """Build a fake cute tensor for compilation, or None when dtype is None.

    The leading (contiguous) dimension gets stride 1; every other dimension
    gets a fresh symbolic int64 stride carrying the given divisibility.
    """
    if dtype is None:
        return None
    if leading_dim < 0:
        # Normalize a negative index the way Python sequences do.
        leading_dim = len(shape) + leading_dim
    stride = tuple(
        1 if dim == leading_dim else cute.sym_int64(divisibility=divisibility)
        for dim in range(len(shape))
    )
    return cute.runtime.make_fake_tensor(
        dtype, shape, stride=stride, assumed_align=divisibility * dtype.width // 8
    )
build/torch-cuda/quack/copy_utils.py ADDED
@@ -0,0 +1,614 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
2
+
3
+ import re
4
+ from typing import Optional, Type, Tuple, Callable
5
+
6
+ import cutlass
7
+ import cutlass.cute as cute
8
+
9
+ from cutlass import Int32, Boolean, const_expr
10
+ from cutlass.cute.nvgpu import cpasync, warpgroup
11
+ from cutlass.cutlass_dsl import dsl_user_op
12
+ import cutlass.pipeline
13
+
14
+
15
@dsl_user_op
def cvt_copy(
    tiled_copy: cute.TiledCopy,
    src: cute.Tensor,
    dst: cute.Tensor,
    *,
    pred: Optional[cute.Tensor] = None,
    retile: bool = False,
    loc=None,
    ip=None,
    **kwargs,
) -> None:
    """Copy ``src`` to ``dst``, converting element type first if they differ.

    ``src`` must be a register fragment: the conversion writes into a fresh
    rmem fragment of ``dst``'s dtype. With ``retile=True`` the (possibly
    converted) source is retiled through ``tiled_copy`` before the copy.
    """
    assert isinstance(src.iterator, cute.Pointer) and src.memspace == cute.AddressSpace.rmem
    if const_expr(src.element_type != dst.element_type):
        src_cvt = cute.make_fragment_like(src, dst.element_type)
        src_cvt.store(src.load().to(dst.element_type))
        src = src_cvt
    if const_expr(retile):
        src = tiled_copy.retile(src)
    cute.copy(tiled_copy, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)
35
+
36
+
37
@dsl_user_op
def load_s2r(src: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
    """Load ``src`` into a new register fragment (auto-vectorized copy) and return it."""
    dst = cute.make_fragment_like(src, src.element_type, loc=loc, ip=ip)
    cute.autovec_copy(src, dst, loc=loc, ip=ip)
    return dst
42
+
43
+
44
@dsl_user_op
def load_s2r_retile(
    tiled_copy: cute.TiledCopy,
    src: cute.Tensor,
    dst_shape: cute.Tensor | cute.Shape,
    *,
    loc=None,
    ip=None,
) -> cute.Tensor:
    """Copy ``src`` into a register fragment, retiled for ``tiled_copy``.

    ``dst_shape`` is either a shape (a new fragment is allocated) or an
    existing tensor to write into; the destination is returned either way.
    """
    # Will also accept dst_shape being a tensor, in which case we write into that tensor
    if const_expr(not isinstance(dst_shape, cute.Tensor)):
        dst = cute.make_fragment(dst_shape, src.element_type, loc=loc, ip=ip)
    else:
        dst = dst_shape
    cute.copy(tiled_copy, src, tiled_copy.retile(dst), loc=loc, ip=ip)
    return dst
60
+
61
+
62
@dsl_user_op
def get_copy_atom(
    dtype: Type[cutlass.Numeric], num_copy_elems: int, is_async: bool = False, *, loc=None, ip=None
) -> cute.CopyAtom:
    """Make a copy atom vectorized over ``num_copy_elems`` elements, capped at 128 bits.

    ``is_async=True`` selects the cp.async G2S op; otherwise a universal copy.
    """
    num_copy_bits = const_expr(min(128, num_copy_elems * dtype.width))
    copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
    return cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)
69
+
70
+
71
@dsl_user_op
def copy(
    src: cute.Tensor,
    dst: cute.Tensor,
    *,
    pred: Optional[cute.Tensor] = None,
    is_async: bool = False,
    loc=None,
    ip=None,
    **kwargs,
) -> None:
    """Copy ``src`` to ``dst`` with a copy atom vectorized to the partitioned value width.

    NOTE(review): reads src.shape[0][0] as the per-thread vector length —
    assumes src is partitioned as ((elems, rest), ...); confirm at call sites.
    """
    num_copy_elems = src.shape[0][0]
    copy_atom = get_copy_atom(src.element_type, num_copy_elems, is_async)
    cute.copy(copy_atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)
85
+
86
+
87
def tiled_copy_1d(
    dtype: Type[cutlass.Numeric], num_threads: int, num_copy_elems: int = 1, is_async: bool = False
) -> cute.TiledCopy:
    """Build a 1-D tiled copy: ``num_threads`` threads, each moving ``num_copy_elems`` elements."""
    num_copy_bits = num_copy_elems * dtype.width
    copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
    copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)
    thr_layout = cute.make_layout(num_threads)
    val_layout = cute.make_layout(num_copy_elems)
    return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout)
96
+
97
+
98
def tiled_copy_2d(
    dtype: Type[cutlass.Numeric],
    threads_per_row: int,
    num_threads: int,
    num_copy_elems: int = 1,
    is_async: bool = False,
) -> cute.TiledCopy:
    """Build a 2-D tiled copy with a row-major thread arrangement.

    Threads are laid out as (num_threads // threads_per_row, threads_per_row)
    with the row dimension fastest; each thread copies ``num_copy_elems``
    contiguous elements along the row.
    """
    num_copy_bits = num_copy_elems * dtype.width
    copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
    copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)
    assert num_threads % threads_per_row == 0
    thr_layout = cute.make_ordered_layout(
        (num_threads // threads_per_row, threads_per_row),
        order=(1, 0),
    )
    val_layout = cute.make_layout((1, num_copy_elems))
    return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout)
115
+
116
+
117
@cute.jit
def predicate_k(tAcA: cute.Tensor, limit: Int32) -> cute.Tensor:
    """Build a boolean predicate fragment masking out-of-bounds k coordinates.

    Only the k dimension is predicated (stride 0 over the mn mode), so the mn
    bound is expected to be handled by the caller with an "if".
    """
    # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if"
    tApA = cute.make_fragment(
        cute.make_layout(
            (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])),
            stride=(cute.size(tAcA, mode=[2]), 0, 1),
        ),
        Boolean,
    )
    for rest_v in cutlass.range_constexpr(tApA.shape[0]):
        for rest_k in cutlass.range_constexpr(tApA.shape[2]):
            # Compare the k coordinate (index [1]) of each element against the limit.
            tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit)
    return tApA
131
+
132
+
133
+ # def tiled_copy_2d(
134
+ # dtype: Type[cutlass.Numeric], major_mode_size: int, num_threads: int, is_async: bool = False
135
+ # ) -> cute.TiledCopy:
136
+ # num_copy_bits = math.gcd(major_mode_size, 128 // dtype.width) * dtype.width
137
+ # copy_elems = num_copy_bits // dtype.width
138
+ # copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
139
+ # copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)
140
+ # gmem_threads_per_row = major_mode_size // copy_elems
141
+ # assert num_threads % gmem_threads_per_row == 0
142
+ # thr_layout = cute.make_ordered_layout(
143
+ # (num_threads // gmem_threads_per_row, gmem_threads_per_row),
144
+ # order=(1, 0),
145
+ # )
146
+ # val_layout = cute.make_layout((1, copy_elems))
147
+ # return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout)
148
+
149
+
150
def parse_swizzle_from_pointer(ptr: cute.Pointer) -> Tuple[int, int, int]:
    """Extract swizzle parameters from a pointer's swizzle_type.

    The swizzle_type string has the form '!cute.swizzle<"S<b,m,s>">' where
    b, m, s are the swizzle parameters (bits, base, shift).

    Returns:
        The (b, m, s) swizzle parameters as a tuple of ints.

    Raises:
        ValueError: If the swizzle_type string cannot be parsed
    """
    # Ideally there should be a better API to get swizzle parameters, but we'll just parse
    # the string here.
    swizzle_str = str(ptr.type.swizzle_type)
    # Extract the inner part "S<b,m,s>"
    match = re.search(r"S<(\d+),(\d+),(\d+)>", swizzle_str)
    if match:
        b, m, s = int(match.group(1)), int(match.group(2)), int(match.group(3))
        return b, m, s
    else:
        raise ValueError(f"Could not parse swizzle_type: {swizzle_str}")
172
+
173
+
174
def swizzle_int(ptr_int: Int32, b: int, m: int, s: int) -> Int32:
    """Apply an S<b, m, s> swizzle to a raw integer address.

    XORs the b "YYY" bits sitting above the m base bits (offset by shift s)
    down into the lower bits, matching cute's swizzle functor semantics.
    """
    yyy_msk = ((1 << b) - 1) << (m + s)
    return ptr_int ^ ((ptr_int & yyy_msk) >> s)
178
+
179
+
180
def swizzle_ptr(ptr: cute.Pointer):
    """Return a new plain pointer whose address has the pointer's own swizzle applied."""
    b, m, s = parse_swizzle_from_pointer(ptr)
    ptr_int = swizzle_int(ptr.toint(), b, m, s)
    return cute.make_ptr(ptr.dtype, ptr_int, ptr.memspace, assumed_align=ptr.alignment)
184
+
185
+
186
def as_position_independent_swizzle_tensor(tensor: cute.Tensor) -> cute.Tensor:
    """Move a pointer-level swizzle into the tensor's layout (composed layout).

    The resulting tensor has an unswizzled pointer and a layout that performs
    the swizzle, so slicing/partitioning stays position independent.
    """
    outer = tensor.layout
    width = tensor.element_type.width
    inner = cute.make_swizzle(*parse_swizzle_from_pointer(tensor.iterator))
    # Need to recast the swizzle from byte (e.g. <3, 4, 3> to element units (e.g. <3, 3, 3> for
    # for 16 bits and <3, 2, 3> for 32 bits)
    new_layout = cute.recast_layout(
        width, 8, cute.make_composed_layout(inner, 0, cute.recast_layout(8, width, outer))
    )
    # recast_ptr to remove the pointer swizzle
    return cute.make_tensor(cute.recast_ptr(tensor.iterator, dtype=tensor.element_type), new_layout)
197
+
198
+
199
def partition_D_position_independent(
    thr_copy: cute.core.ThrCopy, tensor: cute.Tensor
) -> cute.Tensor:
    """partition_D of a swizzled tensor, with the swizzle applied to the base pointer.

    Combines the swizzled base address of the ordinary partition with the
    layout of the position-independent (layout-swizzled) partition.
    """
    return cute.make_tensor(
        swizzle_ptr(thr_copy.partition_D(tensor).iterator),
        thr_copy.partition_D(as_position_independent_swizzle_tensor(tensor)).layout,
    )
206
+
207
+
208
def partition_S_position_independent(
    thr_copy: cute.core.ThrCopy, tensor: cute.Tensor
) -> cute.Tensor:
    """partition_S counterpart of partition_D_position_independent (source-side)."""
    return cute.make_tensor(
        swizzle_ptr(thr_copy.partition_S(tensor).iterator),
        thr_copy.partition_S(as_position_independent_swizzle_tensor(tensor)).layout,
    )
215
+
216
+
217
@dsl_user_op
def sm90_get_smem_load_op(
    layout_c: cutlass.utils.LayoutEnum,
    elem_ty_c: Type[cutlass.Numeric],
    *,
    loc=None,
    ip=None,
) -> cute.CopyAtom:
    """
    Selects the largest vectorized smem load atom available subject to constraint of gmem layout.

    Parameters:
    -----------
    layout_c : LayoutEnum
        The layout enum of the output tensor.

    elem_ty_c : Type[Numeric]
        The element type for the output tensor.

    Returns:
    --------
    A copy atom: ldmatrix (x4) for 16-bit types, universal copy otherwise.
    """

    if not isinstance(elem_ty_c, cutlass.cutlass_dsl.NumericMeta):
        raise TypeError(f"elem_ty_c must be a Numeric, but got {elem_ty_c}")
    is_m_major = layout_c.is_m_major_c()
    if elem_ty_c.width == 16:
        # 16-bit element: use ldmatrix with 4 matrices, transposed if m-major.
        return cute.make_copy_atom(
            cute.nvgpu.warp.LdMatrix8x8x16bOp(is_m_major, 4), elem_ty_c, loc=loc, ip=ip
        )
    else:
        # No ldmatrix variant applies; fall back to a universal copy.
        return cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), elem_ty_c, loc=loc, ip=ip)
250
+
251
+
252
def get_smem_store_atom(
    arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric], transpose: bool = False
) -> cute.CopyAtom:
    """Pick the r2s store atom: stmatrix on SM90+ for 16-bit types, else universal copy.

    The universal fallback stores 2 elements per copy (1 when transposed).
    """
    if const_expr(arch < 90 or element_type.width != 16):
        return cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(),
            element_type,
            num_bits_per_copy=(2 if not transpose else 1) * element_type.width,
        )
    else:
        return cute.make_copy_atom(
            cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=transpose, num_matrices=4),
            element_type,
        )
266
+
267
+
268
def get_smem_load_atom(
    arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric], transpose: bool = False
) -> cute.CopyAtom:
    """Pick the s2r load atom: ldmatrix on SM90+ for 16-bit types, else universal copy.

    Mirrors get_smem_store_atom so loads and stores tile identically.
    """
    if const_expr(arch < 90 or element_type.width != 16):
        return cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(),
            element_type,
            num_bits_per_copy=(2 if not transpose else 1) * element_type.width,
        )
    else:
        return cute.make_copy_atom(
            cute.nvgpu.warp.LdMatrix8x8x16bOp(transpose=transpose, num_matrices=4),
            element_type,
        )
282
+
283
+
284
def get_smem_store_C(
    tiled_mma: cute.TiledMma,
    sC: cute.Tensor,
    tidx: Int32,
    arch: int,
    transpose: bool = False,
    position_independent=False,
) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]:
    """Set up a register->smem store for the C operand of ``tiled_mma``.

    Returns (copy_fn, thr_copy, tRS_sC): copy_fn(src, dst_idx) converts (if
    needed), retiles, and stores ``src`` into stage ``dst_idx`` of sC.
    """
    dtype = sC.element_type
    copy_atom = get_smem_store_atom(arch, dtype, transpose)
    tiled_copy = cute.make_tiled_copy_C(copy_atom, tiled_mma)
    thr_copy = tiled_copy.get_slice(tidx)
    if const_expr(not position_independent):
        tRS_sC = thr_copy.partition_D(sC)
    else:
        tRS_sC = partition_D_position_independent(thr_copy, sC)

    def copy_fn(src: cute.Tensor, dst_idx: Int32, **new_kwargs):
        # Last mode of tRS_sC indexes the pipeline stage.
        cvt_copy(tiled_copy, src, tRS_sC[None, None, None, dst_idx], retile=True, **new_kwargs)

    return copy_fn, thr_copy, tRS_sC
305
+
306
+
307
def get_smem_load_C(
    tiled_mma: cute.TiledMma,
    sC: cute.Tensor,
    tidx: Int32,
    arch: int,
    transpose: bool = False,
    position_independent=False,
) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]:
    """Set up a smem->register load for the C operand of ``tiled_mma``.

    Returns (copy_fn, thr_copy, tSR_sC): copy_fn(src_idx) loads stage
    ``src_idx`` of sC into a fresh fragment shaped to match the store-side
    (r2s) partitioning, so loaded data lines up with what the store path wrote.
    """
    dtype = sC.element_type
    copy_atom = get_smem_load_atom(arch, dtype, transpose)
    tiled_copy = cute.make_tiled_copy_C(copy_atom, tiled_mma)
    thr_copy = tiled_copy.get_slice(tidx)
    if const_expr(not position_independent):
        tSR_sC = thr_copy.partition_S(sC)
    else:
        tSR_sC = partition_S_position_independent(thr_copy, sC)
    # Destination fragment shape comes from the store-side partitioning.
    copy_atom_RS = get_smem_store_atom(arch, dtype, transpose)
    thr_copy_RS = cute.make_tiled_copy_C(copy_atom_RS, tiled_mma).get_slice(tidx)
    tRS_shape = thr_copy_RS.partition_S(cute.make_identity_tensor(sC.shape[:2])).shape

    def copy_fn(src_idx: Int32, **new_kwargs):
        return load_s2r_retile(
            tiled_copy, tSR_sC[None, None, None, src_idx], dst_shape=tRS_shape, **new_kwargs
        )

    return copy_fn, thr_copy, tSR_sC
333
+
334
+
335
def get_smem_store_A(
    tiled_mma: cute.TiledMma, sA: cute.Tensor, tidx: Int32, arch: int, position_independent=False
) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]:
    """Set up a register->smem store for the A operand of ``tiled_mma``.

    Transposition is inferred from the MMA's A-operand major mode. Returns
    (copy_fn, thr_copy, tRS_sA) analogous to get_smem_store_C.
    """
    dtype = sA.element_type
    transpose = tiled_mma.op.a_major_mode == warpgroup.OperandMajorMode.MN
    copy_atom = get_smem_store_atom(arch, dtype, transpose)
    tiled_copy = cute.make_tiled_copy_A(copy_atom, tiled_mma)
    thr_copy = tiled_copy.get_slice(tidx)
    if const_expr(not position_independent):
        tRS_sA = thr_copy.partition_D(sA)
    else:
        tRS_sA = partition_D_position_independent(thr_copy, sA)

    def copy_fn(src: cute.Tensor, dst_idx: Int32, **new_kwargs):
        # Last mode of tRS_sA indexes the pipeline stage.
        cvt_copy(tiled_copy, src, tRS_sA[None, None, None, dst_idx], retile=True, **new_kwargs)

    return copy_fn, thr_copy, tRS_sA
352
+
353
+
354
def get_smem_load_A(
    tiled_mma: cute.TiledMma,
    sA: cute.Tensor,
    tidx: Int32,
    arch: int,
    with_dst_tensor: bool = False,
    position_independent=False,
) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]:
    """Set up a smem->register load for the A operand of ``tiled_mma``.

    With ``with_dst_tensor=False`` the returned copy_fn(src_idx) allocates and
    returns a fragment shaped by partition_shape_A; with True it is
    copy_fn(src_idx, dst) writing into a caller-provided tensor.
    """
    dtype = sA.element_type
    transpose = tiled_mma.op.a_major_mode == warpgroup.OperandMajorMode.MN
    copy_atom = get_smem_load_atom(arch, dtype, transpose)
    tiled_copy = cute.make_tiled_copy_A(copy_atom, tiled_mma)
    thr_copy = tiled_copy.get_slice(tidx)
    if const_expr(not position_independent):
        tSR_sA = thr_copy.partition_S(sA)
    else:
        tSR_sA = partition_S_position_independent(thr_copy, sA)
    # NOTE(review): copy_atom_RS / thr_copy_RS are built but unused here (shape
    # comes from partition_shape_A) — possibly vestigial; confirm before removing.
    copy_atom_RS = get_smem_store_atom(arch, dtype, transpose)
    thr_copy_RS = cute.make_tiled_copy_C(copy_atom_RS, tiled_mma).get_slice(tidx)
    tRS_shape = tiled_mma.partition_shape_A(sA.shape[:2])

    def copy_fn(src_idx: Int32, **new_kwargs):
        return load_s2r_retile(
            tiled_copy, tSR_sA[None, None, None, src_idx], dst_shape=tRS_shape, **new_kwargs
        )

    def copy_fn_w_dst_tensor(src_idx: Int32, dst: cute.Tensor, **new_kwargs):
        return load_s2r_retile(tiled_copy, tSR_sA[None, None, None, src_idx], dst, **new_kwargs)

    return copy_fn if not with_dst_tensor else copy_fn_w_dst_tensor, thr_copy, tSR_sA
384
+
385
+
386
def tma_get_copy_fn(
    atom: cute.CopyAtom,
    cta_coord: cute.Coord,
    cta_layout: cute.Layout,
    src_tensor: cute.Tensor,
    dst_tensor: cute.Tensor,
    filter_zeros: bool = False,
    single_stage: bool = False,
    **kwargs,
) -> Callable:
    """Partition src/dst for TMA and return a copy closure.

    Direction (G2S vs S2G) is inferred from which tensor lives in smem; the
    last mode of each tensor is treated as the stage/rest index unless
    ``single_stage``. NOTE(review): despite the ``-> Callable`` annotation this
    returns a 3-tuple (copy_fn, s, g) — callers unpack all three.
    """
    src_is_smem = const_expr(
        isinstance(src_tensor.iterator, cute.Pointer)
        and src_tensor.memspace == cute.AddressSpace.smem
    )
    smem_tensor, gmem_tensor = (src_tensor, dst_tensor) if src_is_smem else (dst_tensor, src_tensor)
    group_rank_smem = const_expr(cute.rank(smem_tensor) - (1 if not single_stage else 0))
    group_rank_gmem = const_expr(cute.rank(gmem_tensor) - (1 if not single_stage else 0))
    # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK)
    s, g = cpasync.tma_partition(
        atom,
        cta_coord,
        cta_layout,
        cute.group_modes(smem_tensor, 0, group_rank_smem),
        cute.group_modes(gmem_tensor, 0, group_rank_gmem),
    )
    if const_expr(filter_zeros):
        s = cute.filter_zeros(s)
        g = cute.filter_zeros(g)
    src, dst = (s, g) if src_is_smem else (g, s)

    def copy_tma(src_idx, dst_idx, **new_kwargs):
        cute.copy(atom, src[None, src_idx], dst[None, dst_idx], **new_kwargs, **kwargs)

    def copy_tma_single_stage(**new_kwargs):
        cute.copy(atom, src, dst, **new_kwargs, **kwargs)

    return (copy_tma if const_expr(not single_stage) else copy_tma_single_stage), s, g
423
+
424
+
425
def tma_producer_copy_fn(copy: Callable, pipeline: cutlass.pipeline.PipelineAsync):
    """Bind a TMA copy closure to a producer pipeline.

    The returned closure derives the destination stage index and the TMA
    transaction barrier from the given producer pipeline state, then
    delegates to ``copy``.
    """

    def producer_copy(src_idx, producer_state: cutlass.pipeline.PipelineState, **new_kwargs):
        # Both the destination stage and the mbarrier come from the
        # producer-side pipeline state.
        barrier = pipeline.producer_get_barrier(producer_state)
        copy(
            src_idx=src_idx,
            dst_idx=producer_state.index,
            tma_bar_ptr=barrier,
            **new_kwargs,
        )

    return producer_copy
435
+
436
+
437
@cute.jit
def gather_m_get_copy_fn(
    thr_copy_A: cute.ThrCopy,
    mA: cute.Tensor,  # (whatever, K)
    sA: cute.Tensor,  # (tile_M, tile_N, STAGE)
    gsAIdx: cute.Tensor,  # (tile_M), either gmem or smem
    limit_m: Int32,
    limit_k: Int32,
) -> Callable:
    """Build a gmem->smem copy closure for A with row gathering along M.

    Row indices are read once from ``gsAIdx`` and cached in registers
    (out-of-bounds rows map to row 0, which is safe to load). The returned
    ``copy_fn(src_idx, dst_idx, pred=False)`` copies the ``src_idx``-th
    K-tile of the gathered rows into smem stage ``dst_idx``, optionally
    predicating the K dimension against ``limit_k``.
    """
    tile_shape_mk = (cute.size(sA, mode=[0]), cute.size(sA, mode=[1]))
    tAsA = thr_copy_A.partition_D(sA)
    # k-major
    assert tAsA.shape[2] == 1
    tAsA = cute.group_modes(cute.slice_(tAsA, (None, None, 0, None)), 0, 2)

    is_even_m_smem = tile_shape_mk[0] % thr_copy_A.tiler_mn[0].shape == 0
    if const_expr(not is_even_m_smem):
        limit_m = min(limit_m, tile_shape_mk[0])
    elems_per_load = cute.size(tAsA.shape[0][0])
    cA = cute.make_identity_tensor(tile_shape_mk)
    tAcA = thr_copy_A.partition_S(cA)
    t0AcA = thr_copy_A.get_slice(0).partition_S(cA)
    # Instead of comparing tAcA to limit_m, we instead compare t0AcA to limit_m - tAcA[0][0]
    # since we know that tAcA[m][0] = t0AcA[m][0] + tAcA[0][0].
    # This is so that when we do the comparison, t0AcA is known at compile time.
    limit_m = limit_m - tAcA[0][0]
    limit_k = limit_k - tAcA[0][1]
    # Read and cache indices for A
    rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1]))
    cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2]))
    # Per-row in-bounds predicate, computed once up front.
    tApA_m = cute.make_fragment(rows_per_thread, Boolean)
    for m in cutlass.range(rows_per_thread, unroll_full=True):
        tApA_m[m] = t0AcA[0, m, 0][0] < limit_m
    # Gathered row indices, cached in registers for the lifetime of the closure.
    m_idx = cute.make_fragment(rows_per_thread, Int32)
    for m in cutlass.range(rows_per_thread, unroll_full=True):
        row_idx = tAcA[0, m, 0][0]
        if tApA_m[m]:
            m_idx[m] = gsAIdx[row_idx]
        else:
            m_idx[m] = 0  # It's ok to load row 0 in the case of OOB

    # Split the K mode into tiles of size tile_shape_mk[1] so src_idx selects a K-tile.
    mA_k = cute.logical_divide(mA, (None, tile_shape_mk[1]))

    def copy_fn(src_idx, dst_idx, pred: bool = False):
        tApA_k = None
        if const_expr(pred):
            # K-dimension predicate for the current K-tile.
            tApA_k = cute.make_fragment(cols_per_thread, Boolean)
            limit_k_cur = limit_k - src_idx * tile_shape_mk[1]
            for k in cutlass.range(cols_per_thread, unroll_full=True):
                tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
        mA_cur = mA_k[None, (None, src_idx)]
        for m in cutlass.range_constexpr(tAcA.shape[1]):
            # cute.tiled_divide(mA_cur[m_idx[m], None], (elems_per_load,)) would give shape
            # ((elems_per_load), thread_per_row)
            # But we actually want shape ((elems_per_load, 1), thread_per_row) to match tAsA
            # So we append 1s to the last dimension and then do tiled_divide, then slice.
            mA_row = cute.tiled_divide(
                cute.append_ones(mA_cur[m_idx[m], None], up_to_rank=2), (elems_per_load, 1)
            )[None, None, 0]
            if const_expr(is_even_m_smem) or tApA_m[m]:
                # There's only 1 load per row
                assert cute.size(tAcA.shape, mode=[2]) == 1
                ki = tAcA[0, 0, 0][1] // elems_per_load
                cute.copy(thr_copy_A, mA_row[None, ki], tAsA[(None, m), dst_idx], pred=tApA_k)

    return copy_fn
503
+
504
+
505
@cute.jit
def gather_k_get_copy_fn(
    thr_copy_A: cute.ThrCopy,
    mA: cute.Tensor,  # (tile_M, whatever)
    sA: cute.Tensor,  # (tile_M, tile_N, STAGE)
    gsAIdx: cute.Tensor,  # (tile_K, RestK), either gmem or smem
    limit_m: Int32,
    limit_k: Int32,
) -> Callable:
    """Build a gmem->smem copy closure for A with column gathering along K.

    Unlike the M-gather variant, the K indices change per K-tile, so they are
    prefetched per tile by one of two helpers: directly from gmem
    (``prefetch_from_gmem_fn``) when ``gsAIdx`` is a gmem tensor, or through
    a producer/consumer pipeline (``prefetch_from_smem_fn``) when the indices
    are staged in smem. Returns ``(copy_fn, prefetch_fn)``; ``copy_fn`` takes
    the ``(k_idx, tApA_k)`` pair produced by the prefetch function.
    """
    # Dispatch on where the K indices live; exactly one of gAIdx/sAIdx is set.
    gAIdx, sAIdx = None, None
    if const_expr(gsAIdx.memspace == cute.AddressSpace.gmem):
        gAIdx = gsAIdx
    else:
        assert gsAIdx.memspace == cute.AddressSpace.smem
        sAIdx = gsAIdx
    tile_shape_mk = (cute.size(sA, mode=[0]), cute.size(sA, mode=[1]))
    # (atom_v, CPY_M, 1, STAGE)
    tAsA = thr_copy_A.partition_D(sA)
    # m-major
    tAsA = cute.group_modes(tAsA, 0, 3)

    is_even_m_smem = tile_shape_mk[0] % thr_copy_A.tiler_mn[0].shape == 0
    if const_expr(not is_even_m_smem):
        limit_m = min(limit_m, tile_shape_mk[0])
    elems_per_load = cute.size(tAsA.shape[0][0])
    cA = cute.make_identity_tensor(tile_shape_mk)
    tAcA = thr_copy_A.partition_S(cA)
    t0AcA = thr_copy_A.get_slice(0).partition_S(cA)
    # Instead of comparing tAcA to limit_m, we instead compare t0AcA to limit_m - tAcA[0][0]
    # since we know that tAcA[m][0] = t0AcA[m][0] + tAcA[0][0].
    # This is so that when we do the comparison, t0AcA is known at compile time.
    limit_m = limit_m - tAcA[0][0]
    limit_k = limit_k - tAcA[0][1]
    # Read and cache indices for A
    rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1]))
    cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2]))
    # Per-row in-bounds predicate (M dimension), computed once.
    tApA_m = cute.make_fragment(rows_per_thread, Boolean)
    for m in cutlass.range(rows_per_thread, unroll_full=True):
        tApA_m[m] = t0AcA[0, m, 0][0] < limit_m
    threads_per_col = const_expr(thr_copy_A.tiler_mn[0].shape // elems_per_load)
    # This is very convoluted but idk a better way
    # for tile_M=128, flat_divide gives (8, 16, K),
    # then logical_divide gives ((8, 1), (8, 2), K).
    tidx = thr_copy_A.thr_idx
    tAmA = cute.logical_divide(
        cute.flat_divide(mA, (elems_per_load,)), (elems_per_load, threads_per_col)
    )[None, (tidx % threads_per_col, None), None]  # ((8, 1), 2, K)

    def prefetch_from_gmem_fn(src_idx, pred: bool = False) -> Tuple[cute.Tensor, cute.Tensor]:
        # Prefetch mAIdx early, even before smem is free
        tApA_k = None
        if const_expr(pred):
            tApA_k = cute.make_fragment(cols_per_thread, Boolean)
            limit_k_cur = limit_k - src_idx * tile_shape_mk[1]
            for k in cutlass.range(cols_per_thread, unroll_full=True):
                tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
        gAIdx_cur = gAIdx[None, src_idx]
        k_idx = cute.make_fragment(cols_per_thread, Int32)
        for k in cutlass.range(cols_per_thread):
            col_idx = tAcA[0, 0, k][1]
            if const_expr(not pred):
                k_idx[k] = gAIdx_cur[col_idx]
            else:
                if tApA_k[k]:
                    k_idx[k] = gAIdx_cur[col_idx]
                else:
                    # OOB columns get a sentinel; copy_fn still predicates them out.
                    k_idx[k] = -1
        return k_idx, tApA_k

    def prefetch_from_smem_fn(
        a_prefetch_pipeline, src_idx, dst_idx, a_prefetch_consumer_state, pred: bool = False
    ) -> Tuple[cute.Tensor, cute.Tensor]:
        tApA_k = None
        if const_expr(pred):
            tApA_k = cute.make_fragment(cols_per_thread, Boolean)
            limit_k_cur = limit_k - src_idx * tile_shape_mk[1]
            for k in cutlass.range(cols_per_thread, unroll_full=True):
                tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
        # Wait for the producer to have filled this stage of the index buffer.
        a_prefetch_pipeline.consumer_wait(a_prefetch_consumer_state)
        sAIdx_cur = sAIdx[None, dst_idx]
        k_idx = cute.make_fragment(cols_per_thread, Int32)
        for k in cutlass.range(cols_per_thread):
            col_idx = tAcA[0, 0, k][1]
            k_idx[k] = sAIdx_cur[col_idx]
        # Ensure the whole warp has read the indices before one lane releases the stage.
        cute.arch.sync_warp()
        with cute.arch.elect_one():
            a_prefetch_pipeline.consumer_release(a_prefetch_consumer_state)
        return k_idx, tApA_k

    def copy_fn(
        src_idx, dst_idx, k_idx_tApA_k: Tuple[cute.Tensor, cute.Tensor], pred: bool = False
    ):
        k_idx, tApA_k = k_idx_tApA_k
        tApA_k_pred = None
        if const_expr(pred):
            tApA_k_pred = cute.prepend_ones(tApA_k, up_to_rank=2)  # (1, cols_per_thread)
        for k in cutlass.range_constexpr(tAcA.shape[2]):
            # copy_A(tAmA[None, None, k_idx[k]], tAsA[(None, None, k), smem_idx], pred=cute.prepend_ones(tApA_m, up_to_rank=2))
            for m in cutlass.range_constexpr(tAcA.shape[1]):
                if tApA_m[m]:
                    cute.copy(
                        thr_copy_A,
                        tAmA[None, m, k_idx[k]],
                        tAsA[(None, m, k), dst_idx],
                        pred=None if const_expr(tApA_k_pred is None) else tApA_k_pred[None, k],
                    )

    return copy_fn, prefetch_from_gmem_fn if const_expr(
        gAIdx is not None
    ) else prefetch_from_smem_fn
build/torch-cuda/quack/cute_dsl_ptxas.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ System ptxas replacement for CUTLASS DSL.
3
+ Environment variables:
4
+ CUTE_DSL_PTXAS_PATH - Path to ptxas (e.g., /usr/local/cuda/bin/ptxas)
5
+ CUTE_DSL_PTXAS_VERBOSE - Set to 1 for verbose output
6
+ """
7
+
8
+ import os
9
+ import sys
10
+ import re
11
+ import ctypes
12
+ import subprocess
13
+ from pathlib import Path
14
+
15
+ import cutlass
16
+
17
+
18
+ CUTE_DSL_PTXAS_PATH = os.environ.get("CUTE_DSL_PTXAS_PATH", None)
19
+ VERBOSE = os.environ.get("CUTE_DSL_PTXAS_VERBOSE", "0") == "1"
20
+
21
+ _original_load_cuda_library = None
22
+ _user_wanted_ptx = False # True if user originally set CUTE_DSL_KEEP_PTX=1
23
+
24
+
25
def _log(msg):
    """Emit a diagnostic line to stderr when CUTE_DSL_PTXAS_VERBOSE=1."""
    if not VERBOSE:
        return
    print(f"[ptxas] {msg}", file=sys.stderr)
28
+
29
+
30
def _get_ptx(compiled_func) -> tuple[str, Path] | None:
    """Locate the dumped PTX for *compiled_func* and return ``(text, path)``.

    Scans CUTE_DSL_DUMP_DIR (defaulting to the current directory) for files
    matching ``*<function_name>*.ptx``, strips trailing NUL padding, and only
    accepts a file that looks like a complete module: it must contain an
    ``.entry`` directive and end with ``}``. Returns None when no usable PTX
    is found.
    """
    func_name = getattr(compiled_func, "function_name", None)
    if not func_name:
        return None

    search_root = Path(os.environ.get("CUTE_DSL_DUMP_DIR", Path.cwd()))
    for candidate in search_root.glob(f"*{func_name}*.ptx"):
        text = candidate.read_text().rstrip("\x00")
        looks_complete = ".entry " in text and text.rstrip().endswith("}")
        if looks_complete:
            _log(f"Found PTX: {candidate}")
            return text, candidate
    return None
43
+
44
+
45
def _compile_ptx(ptx_path: Path, ptx_content: str) -> bytes:
    """Compile the PTX at *ptx_path* to a cubin with the system ptxas.

    The target architecture is parsed from the PTX's ``.target`` directive
    (falling back to sm_90a). Returns the raw cubin bytes.

    Raises:
        RuntimeError: if ptxas exits with a non-zero status.
    """
    arch_match = re.search(r"\.target\s+(sm_\d+[a-z]?)", ptx_content)
    arch = "sm_90a" if arch_match is None else arch_match.group(1)

    # Persist the NUL-stripped content so ptxas reads the cleaned file.
    if ptx_path.read_text() != ptx_content:
        ptx_path.write_text(ptx_content)

    cubin_tmp = ptx_path.with_suffix(".cubin.tmp")
    try:
        assert CUTE_DSL_PTXAS_PATH is not None
        cmd = [CUTE_DSL_PTXAS_PATH, f"-arch={arch}", "-O3", "-o", str(cubin_tmp), str(ptx_path)]
        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            raise RuntimeError(f"ptxas failed: {result.stderr}")

        cubin_data = cubin_tmp.read_bytes()
        _log(f"Compiled {ptx_path.name} -> {len(cubin_data)} bytes ({arch})")

        # Optionally keep the compiled cubin next to the PTX.
        if os.environ.get("CUTE_DSL_KEEP_CUBIN", "0") == "1":
            cubin_out = ptx_path.with_suffix(".cubin")
            cubin_out.write_bytes(cubin_data)
            _log(f"Saved: {cubin_out}")

        return cubin_data
    finally:
        # Always drop the temp file, even when compilation fails.
        cubin_tmp.unlink(missing_ok=True)
79
+
80
+
81
def _patched_load_cuda_library(self):
    """Replacement for _load_cuda_library that uses system ptxas.

    Falls back to the original (embedded-ptxas) implementation whenever any
    step fails: PTX lookup, ptxas compilation, cubin loading, or per-device
    registration.
    """

    result = _get_ptx(self)
    if not result:
        _log("PTX not found, falling back to embedded ptxas")
        return _original_load_cuda_library(self)

    ptx_content, ptx_path = result

    try:
        cubin = _compile_ptx(ptx_path, ptx_content)
    except Exception as e:
        _log(f"Compilation failed ({e}), falling back to embedded ptxas")
        return _original_load_cuda_library(self)

    # Load cubin
    import cuda.bindings.runtime as cuda_runtime

    err, library = cuda_runtime.cudaLibraryLoadData(cubin, None, None, 0, None, None, 0)
    if err != cuda_runtime.cudaError_t.cudaSuccess:
        _log(f"cudaLibraryLoadData failed ({err}), falling back to embedded ptxas")
        return _original_load_cuda_library(self)

    # Register kernels on all devices
    # NOTE(review): the three-pointer args layout (library handle, device id,
    # error out-param) presumably matches the ABI expected by the helper
    # returned from self._get_cuda_init_and_load() — confirm against the DSL.
    _, cuda_load_to_device = self._get_cuda_init_and_load()
    lib_ptr = ctypes.c_void_p(int(library))
    dev_id = ctypes.c_int32(0)
    err_val = ctypes.c_int32(0)
    args = (ctypes.c_void_p * 3)(
        ctypes.cast(ctypes.pointer(lib_ptr), ctypes.c_void_p),
        ctypes.cast(ctypes.pointer(dev_id), ctypes.c_void_p),
        ctypes.cast(ctypes.pointer(err_val), ctypes.c_void_p),
    )

    for dev in range(self.num_devices):
        # dev_id is read through the pointer packed above, so mutating the
        # ctypes value re-targets the same args tuple for each device.
        dev_id.value = dev
        cuda_load_to_device(args)
        if err_val.value != 0:
            _log("cuda_load_to_device failed, falling back to embedded ptxas")
            return _original_load_cuda_library(self)

    _log(f"Loaded kernel from {ptx_path.name}")

    # Delete PTX if user didn't originally want it kept
    if not _user_wanted_ptx:
        ptx_path.unlink(missing_ok=True)

    return [cuda_runtime.cudaLibrary_t(lib_ptr.value)]
130
+
131
+
132
def patch():
    """Install the system-ptxas hook. Call before compiling any kernels.

    Replaces ``CudaDialectJitCompiledFunction._load_cuda_library`` with a
    version that recompiles the dumped PTX using the ptxas named by
    CUTE_DSL_PTXAS_PATH instead of the DSL's embedded one.

    Raises:
        RuntimeError: if CUTE_DSL_PTXAS_PATH is unset or not an executable
            file, or if CUTE_DSL_KEEP_PTX is not set to "1" (the hook needs
            the DSL to dump PTX so it has something to recompile).
    """
    global _original_load_cuda_library, _user_wanted_ptx

    # Explicit raises instead of asserts: these are genuine runtime
    # configuration checks and asserts are stripped under ``python -O``.
    if CUTE_DSL_PTXAS_PATH is None:
        raise RuntimeError("CUTE_DSL_PTXAS_PATH is not set")
    if not os.path.isfile(CUTE_DSL_PTXAS_PATH) or not os.access(CUTE_DSL_PTXAS_PATH, os.X_OK):
        raise RuntimeError(f"ptxas not found: {CUTE_DSL_PTXAS_PATH}")

    # Track if user originally wanted PTX kept, so the patched loader knows
    # whether to delete the PTX file after loading.
    _user_wanted_ptx = os.environ.get("CUTE_DSL_KEEP_PTX", "0") == "1"
    # os.environ['CUTE_DSL_KEEP_PTX'] = '1'
    if os.environ.get("CUTE_DSL_KEEP_PTX", "0") != "1":
        raise RuntimeError("Require CUTE_DSL_KEEP_PTX=1 to use system's ptxas")

    cls = cutlass.cutlass_dsl.cuda_jit_executor.CudaDialectJitCompiledFunction
    _original_load_cuda_library = cls._load_cuda_library
    cls._load_cuda_library = _patched_load_cuda_library
    _log("Patch applied")
build/torch-cuda/quack/cute_dsl_utils.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Tri Dao.
2
+
3
+ from typing import Tuple
4
+ from functools import lru_cache
5
+ from dataclasses import dataclass, fields
6
+
7
+ import torch
8
+
9
+ try:
10
+ from triton.tools.disasm import extract
11
+ except ImportError:
12
+ extract = None
13
+
14
+ import cutlass
15
+ import cutlass.cute as cute
16
+ from cutlass import Int32, Int64, Float16, BFloat16, Float32
17
+ from cutlass.base_dsl.typing import JitArgument
18
+ from cutlass.cutlass_dsl import NumericMeta
19
+
20
+
21
+ StaticTypes = (cutlass.Constexpr, NumericMeta, int, bool, str, float, type(None))
22
+
23
+
24
+ load_cubin_module_data_og = cutlass.base_dsl.runtime.cuda.load_cubin_module_data
25
+ cute_compile_og = cute.compile
26
+
27
+
28
+ torch2cute_dtype_map = {
29
+ torch.float16: Float16,
30
+ torch.bfloat16: BFloat16,
31
+ torch.float32: Float32,
32
+ torch.int32: Int32,
33
+ torch.int64: Int64,
34
+ }
35
+
36
+
37
@lru_cache
def get_max_active_clusters(cluster_size):
    """Return the max number of concurrently active clusters of *cluster_size*.

    Cached: the underlying hardware query is repeated with the same argument
    across kernel launches, so memoizing avoids re-querying.
    """
    return cutlass.utils.HardwareInfo().get_max_active_clusters(cluster_size=cluster_size)
40
+
41
+
42
@lru_cache
def get_device_capacity(device: torch.device = None) -> Tuple[int, int]:
    """Return the CUDA compute capability (major, minor) of *device*, cached.

    None selects the current device per torch.cuda.get_device_capability.
    NOTE(review): the annotation should arguably be Optional[torch.device]
    since None is the default.
    """
    return torch.cuda.get_device_capability(device)
45
+
46
+
47
@dataclass
class ParamsBase:
    """Base dataclass implementing the DSL's MLIR value (de)serialization.

    Fields whose values are static (``StaticTypes``: constexprs, numeric
    metatypes, plain Python scalars/strings/None) are skipped during value
    extraction and passed through unchanged on reconstruction; all other
    fields are flattened to MLIR values via ``cutlass.extract_mlir_values``.
    """

    def __extract_mlir_values__(self):
        """Flatten all non-static fields into a single list of MLIR values.

        Also records, in self._values_pos, how many values each field
        contributed so __new_from_mlir_values__ can split the list back up.
        """
        all_fields = [getattr(self, field.name) for field in fields(self)]
        non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)]
        values, self._values_pos = [], []
        for obj in non_constexpr_fields:
            obj_values = cutlass.extract_mlir_values(obj)
            values += obj_values
            self._values_pos.append(len(obj_values))
        return values

    def __new_from_mlir_values__(self, values):
        """Rebuild an instance of the same class from a flat MLIR value list.

        Static fields are copied from self; each non-static field consumes
        the number of values recorded for it by __extract_mlir_values__.
        """
        all_fields = {field.name: getattr(self, field.name) for field in fields(self)}
        constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, StaticTypes)}
        non_constexpr_fields = {
            n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes)
        }
        for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos):
            non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items])
            values = values[n_items:]
        return self.__class__(**non_constexpr_fields, **constexpr_fields)
69
+
70
+
71
@dataclass
class ArgumentsBase(JitArgument):
    """Base dataclass adapting host-side kernel arguments to the JIT ABI.

    Like ParamsBase, static fields (``StaticTypes``) are excluded from the
    runtime argument stream; the remaining fields are asked for their C
    pointers / MLIR types field-by-field, delegating to each value's own
    protocol methods when present.
    """

    def __c_pointers__(self):
        """Collect C pointers from every non-static field that provides them."""
        all_fields = [getattr(self, field.name) for field in fields(self)]
        non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)]
        c_ptrs = []
        for obj in non_constexpr_fields:
            if hasattr(obj, "__c_pointers__"):
                c_ptrs.extend(obj.__c_pointers__())
        return c_ptrs

    def __get_mlir_types__(self):
        """Collect MLIR types from non-static fields.

        Fields without the protocol contribute zero types; the per-field
        counts are recorded in self._values_pos for reconstruction.
        """
        all_fields = [getattr(self, field.name) for field in fields(self)]
        non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)]
        types, self._values_pos = [], []
        for obj in non_constexpr_fields:
            if hasattr(obj, "__get_mlir_types__"):
                obj_types = obj.__get_mlir_types__()
                types.extend(obj_types)
                self._values_pos.append(len(obj_types))
            else:
                self._values_pos.append(0)
        return types

    def __new_from_mlir_values__(self, values):
        """Rebuild an instance from a flat MLIR value list (see ParamsBase)."""
        all_fields = {field.name: getattr(self, field.name) for field in fields(self)}
        constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, StaticTypes)}
        non_constexpr_fields = {
            n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes)
        }
        for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos):
            non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items])
            values = values[n_items:]
        return self.__class__(**non_constexpr_fields, **constexpr_fields)
build/torch-cuda/quack/fast_math.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Tri Dao.
2
+
3
+ from typing import Tuple
4
+ from dataclasses import dataclass
5
+
6
+ import cutlass
7
+ import cutlass.cute as cute
8
+ from cutlass import Int32, Uint32
9
+ from cutlass.cutlass_dsl import T, dsl_user_op
10
+ from cutlass._mlir.dialects import llvm
11
+
12
+ from .cute_dsl_utils import ParamsBase
13
+
14
+
15
@cute.jit
def clz(x: Int32) -> Int32:
    """Count leading zero bits of a 32-bit value (returns 32 for x == 0).

    Written as a full 32-iteration scan because the DSL does not support
    early exit from a range loop; the `done` flag makes only the first set
    bit win.
    """
    # for i in cutlass.range_constexpr(32):
    #     if (1 << (31 - i)) & x:
    #         return Int32(i)
    # return Int32(32)
    # Early exit is not supported yet
    res = Int32(32)
    done = False
    for i in cutlass.range(32):
        if ((1 << (31 - i)) & x) and not done:
            res = Int32(i)
            done = True
    return res
29
+
30
+
31
def find_log2(x: Int32) -> Int32:
    """Return ceil(log2(x)): floor(log2) via 31 - clz(x), +1 for non-powers of 2."""
    a: Int32 = Int32(31 - clz(x))
    return a + ((x & (x - 1)) != 0)  # Round up, add 1 if not a power of 2.
34
+
35
+
36
@dsl_user_op
def umulhi(a: Int32, b: Int32, *, loc=None, ip=None) -> Uint32:
    """Return the high 32 bits of the unsigned 32x32-bit product of a and b.

    Emitted as PTX ``mul.hi.u32`` via inline asm; marked side-effect free so
    the compiler may CSE/reorder it.
    """
    return Uint32(
        llvm.inline_asm(
            T.i32(),
            [Int32(a).ir_value(loc=loc, ip=ip), Int32(b).ir_value(loc=loc, ip=ip)],
            "mul.hi.u32 $0, $1, $2;",
            "=r,r,r",
            has_side_effects=False,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    )
49
+
50
+
51
@dataclass
class FastDivmod(ParamsBase):
    """Fast integer division/modulo by a runtime-constant divisor.

    Precomputes a magic multiplier and shift on the host (``create``) so the
    device-side ``div``/``divmod`` need only a high-half multiply and shift
    instead of an integer divide.
    """

    # Runtime divisor (the value we divide by).
    divisor: Int32
    # Magic multiplier: ceil(2^p / divisor) with p = 31 + ceil(log2(divisor)).
    multiplier: Uint32
    # Post-multiply shift: p - 32 (applied to the high 32 bits of the product).
    shift_right: Uint32

    # called by host
    @staticmethod
    def create(divisor: Int32) -> "FastDivmod":
        """Construct the FastDivmod object, in host code.
        This precomputes some values based on the divisor and is computationally expensive.
        """
        p = Uint32(31 + find_log2(divisor))
        divisor_u32 = Uint32(divisor)
        multiplier = Uint32(((cutlass.Uint64(1) << p) + divisor_u32 - 1) // divisor_u32)
        shift_right = Uint32(p - 32)
        return FastDivmod(divisor, multiplier, shift_right)

    @cute.jit
    def div(self, dividend: Int32) -> Int32:
        """Return dividend // divisor; divisor == 1 short-circuits to the input."""
        return (
            Int32(umulhi(dividend, self.multiplier) >> self.shift_right)
            if self.divisor != 1
            else dividend
        )

    def divmod(self, dividend: Int32) -> Tuple[Int32, Int32]:
        """Return (dividend // divisor, dividend % divisor) using one fast div."""
        quotient = self.div(dividend)
        remainder = dividend - quotient * self.divisor
        return quotient, remainder
build/torch-cuda/quack/gemm.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+ from functools import partial
3
+
4
+ from torch import Tensor
5
+
6
+ import cutlass.cute as cute
7
+ import cutlass.torch as cutlass_torch
8
+ from cutlass import Float32
9
+ from cutlass.cute.runtime import from_dlpack, make_ptr
10
+
11
+ from .cute_dsl_utils import get_device_capacity, get_max_active_clusters
12
+ from .gemm_wrapper_utils import GemmWrapperBase
13
+ from .gemm_default_epi import GemmDefaultSm90, GemmDefaultSm100
14
+
15
+
16
def gemm(
    # (l, m, k) or (total_m, k) if varlen_m or (m, total_k) if varlen_k or (whatever, k) if gather_A_varlen_m or (m, whatever) if gather_A_varlen_k
    A: Tensor,
    B: Tensor,  # (l, n, k) or (n, total_k) if varlen_k
    D: Tensor,  # (l, m, n) or (total_m, n) if varlen_m
    C: Optional[Tensor],  # (l, m, n) or (total_m, n) if varlen_m
    tile_count_semaphore: Optional[Tensor],  # (1,)
    tile_M: int,
    tile_N: int,
    cluster_M: int,
    cluster_N: int,
    pingpong: bool = False,
    persistent: bool = True,
    max_swizzle_size: int = 8,
    rowvec_bias: Optional[Tensor] = None,  # (l, n)
    colvec_bias: Optional[Tensor] = None,  # (l, m), or (total_m,) if varlen_m
    alpha: float | Tensor = 1.0,
    beta: float | Tensor = 1.0,
    cu_seqlens_m: Optional[Tensor] = None,  # (l+1,) cumulative sum of m values for variable length
    cu_seqlens_k: Optional[Tensor] = None,  # (l+1,) cumulative sum of k values for variable length
    A_idx: Optional[Tensor] = None,  # (total_m,) or (total_k,) indices for gather_A when varlen
    batch_idx_permute: Optional[Tensor] = None,  # (l,) permutation of batch indices for scheduler
    add_to_output: bool = False,
) -> None:
    """Run D = act(alpha * A @ B + beta * C + biases) on SM90/SM100, in place.

    Validates layouts for the chosen varlen/gather mode, selects the SM90 or
    SM100 kernel class from the device's compute capability, compiles the
    kernel on first use (memoized in ``gemm.compile_cache`` keyed by every
    behavior-affecting knob), then launches it on the current stream.
    Writes into D; returns None.
    """
    varlen = cu_seqlens_m is not None or cu_seqlens_k is not None
    assert not (cu_seqlens_m is not None and cu_seqlens_k is not None), (
        "Only one of cu_seqlens_m and cu_seqlens_k can be specified"
    )
    gather_A = A_idx is not None
    if gather_A:
        assert varlen, "gather_A requires varlen (cu_seqlens_m or cu_seqlens_k must be specified)"
        assert cluster_N == 1, "gather_A requires cluster_N=1"
    if varlen:
        assert persistent, "varlen requires persistent=True"
    if add_to_output:
        assert cu_seqlens_m is None, "Add to output not supported with varlen_m"
    if cu_seqlens_m is not None:
        assert A.stride(-1) == 1, "varlen_m requires A to be k-major"
        assert D.stride(-1) == 1, "varlen_m requires D to be n-major"
    if cu_seqlens_k is not None:
        assert A.stride(-2) == 1, "varlen_k requires A to be m-major"
        assert B.stride(-2) == 1, "varlen_k requires B to be n-major"

    L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(
        A, B, D, C, cu_seqlens_m=cu_seqlens_m, cu_seqlens_k=cu_seqlens_k, A_idx=A_idx
    )
    GemmWrapperBase.permute_tensors(
        tensor_infos, varlen_m=cu_seqlens_m is not None, varlen_k=cu_seqlens_k is not None
    )
    GemmWrapperBase.extract_dtypes(tensor_infos)
    # Expected (mode0, mode1, batch) order for each operand.
    major_configs = {
        "A": ("m", "k", "l"),
        "B": ("n", "k", "l"),
        "D": ("m", "n", "l"),
        "C": ("m", "n", "l"),
    }
    GemmWrapperBase.determine_major_orders(tensor_infos, major_configs)

    device_capacity = get_device_capacity(A.device)
    assert device_capacity[0] in [9, 10], "Only SM90 and SM100 are supported"
    GemmCls = GemmDefaultSm100 if device_capacity[0] > 9 else GemmDefaultSm90

    acc_dtype = Float32
    tile_shape_mn = (tile_M, tile_N)
    cluster_shape_mnk = (cluster_M, cluster_N, 1)
    if not GemmCls.is_valid_dtypes(
        tensor_infos["A"].dtype,
        tensor_infos["B"].dtype,
        acc_dtype,
        tensor_infos["D"].dtype,
        tensor_infos["A"].major,
        tensor_infos["B"].major,
    ):
        raise TypeError("Skipping due to unsupported combination of types and majors")

    max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
    GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs)

    def scalar_arg(scalar: float | Tensor):
        # Scalars: 1.0 means "identity, skip"; other floats are baked in;
        # a Tensor is passed as a device pointer read at kernel time.
        if isinstance(scalar, float):
            return Float32(scalar) if scalar != 1.0 else None
        else:
            assert isinstance(scalar, Tensor)
            return make_ptr(Float32, scalar.data_ptr(), cute.AddressSpace.gmem, assumed_align=4)

    epi_args = GemmCls.EpilogueArguments(
        scalar_arg(alpha),
        scalar_arg(beta),
        mRowVecBroadcast=from_dlpack(rowvec_bias.detach(), assumed_align=4).mark_layout_dynamic(
            leading_dim=1
        )
        if rowvec_bias is not None
        else None,
        mColVecBroadcast=from_dlpack(colvec_bias.detach(), assumed_align=4).mark_layout_dynamic(
            leading_dim=1 if cu_seqlens_m is None else 0
        )
        if colvec_bias is not None
        else None,
        add_to_output=add_to_output,
    )
    scheduler_args = GemmWrapperBase.create_scheduler_args(
        max_active_clusters,
        tile_count_semaphore,
        batch_idx_permute,
        max_swizzle_size,
    )

    # Create varlen arguments if needed (assumes persistent=True when varlen)
    varlen_args = GemmWrapperBase.create_varlen_args(
        cu_seqlens_m,
        cu_seqlens_k,
        A_idx,
        max_active_clusters,
        cluster_shape_mnk,
        tensor_infos,
        GemmCls.num_epi_tensormaps,
        pingpong,
    )

    current_stream = cutlass_torch.current_stream()
    compile_key = GemmWrapperBase.get_compile_key(
        tensor_infos,
        None,  # activation
        tile_shape_mn,
        cluster_shape_mnk,
        pingpong,
        persistent,
        tile_count_semaphore is not None,
        device_capacity,
        # Technically we don't need to recompile for different max_swizzle_size, but currently
        # not recompiling will skew the autotuning results due to power throttling.
        # Effectively we're recompiling as a way to pause between benchmarks during autotuning.
        max_swizzle_size,
        rowvec_bias.dtype if rowvec_bias is not None else None,
        colvec_bias.dtype if colvec_bias is not None else None,
        # Encode each scalar as 2 (tensor), 1 (identity 1.0), or 0 (other float).
        2 if isinstance(alpha, Tensor) else (1 if alpha == 1.0 else 0),
        2 if isinstance(beta, Tensor) else (1 if beta == 1.0 else 0),
        add_to_output,
        cu_seqlens_m is not None,
        cu_seqlens_k is not None,
        gather_A,
        batch_idx_permute is not None,
        key_tensor_names=("A", "B", "D", "C"),
    )
    cache = gemm.compile_cache
    if compile_key not in cache:
        if device_capacity[0] == 9:
            # SM90 kernels take pingpong/persistence as constructor options.
            GemmCls = partial(GemmCls, pingpong=pingpong, is_persistent=persistent)
        gemm_obj = GemmCls(
            acc_dtype,
            tensor_infos["A"].dtype,
            tile_shape_mn,
            cluster_shape_mnk,
            gather_A=gather_A,
        )
        cache[compile_key] = cute.compile(
            gemm_obj,
            tensor_infos["A"].cute_tensor,
            tensor_infos["B"].cute_tensor,
            tensor_infos["D"].cute_tensor,
            tensor_infos["C"].cute_tensor,
            epi_args,
            scheduler_args,
            varlen_args,
            current_stream,
        )
    cache[compile_key](
        tensor_infos["A"].cute_tensor,
        tensor_infos["B"].cute_tensor,
        tensor_infos["D"].cute_tensor,
        tensor_infos["C"].cute_tensor,
        epi_args,
        scheduler_args,
        varlen_args,
        current_stream,
    )


# Process-wide cache of compiled kernels, keyed by get_compile_key(...).
gemm.compile_cache = {}
build/torch-cuda/quack/gemm_act.py ADDED
@@ -0,0 +1,510 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Wentao Guo, Tri Dao.
2
+ from typing import Tuple, Optional, Callable
3
+ from functools import partial
4
+ from dataclasses import dataclass
5
+
6
+ from torch import Tensor
7
+
8
+ import cutlass
9
+ import cutlass.cute as cute
10
+ import cutlass.utils.hopper_helpers as sm90_utils_og
11
+ import cutlass.utils.blackwell_helpers as sm100_utils
12
+ from cutlass import Int32, Float32, Boolean, const_expr
13
+ from cutlass.cutlass_dsl import if_generate
14
+ import cutlass.torch as cutlass_torch
15
+ from cutlass.cute.runtime import from_dlpack
16
+
17
+ from .cute_dsl_utils import ArgumentsBase, ParamsBase
18
+ from .varlen_utils import VarlenManager
19
+ from .gemm_sm90 import GemmSm90
20
+ from .gemm_sm100 import GemmSm100
21
+ from .gemm_default_epi import GemmDefaultEpiMixin
22
+ from .cute_dsl_utils import get_device_capacity, get_max_active_clusters
23
+ from .gemm_wrapper_utils import GemmWrapperBase
24
+ from . import sm90_utils as sm90_utils
25
+ from . import copy_utils as copy_utils
26
+ from . import activation
27
+
28
+
29
class GemmActMixin(GemmDefaultEpiMixin):
    """Epilogue mixin that fuses an elementwise activation into the GEMM epilogue.

    On top of the default epilogue (alpha/beta scaling, row/col vector broadcast,
    optional C addition) it writes a second output tensor ``PostAct = act_fn(D)``
    through its own TMA store path. Mixed into GemmSm90 / GemmSm100 below.
    """

    # One extra TMA tensormap for the PostAct output (on top of D's).
    num_epi_tensormaps: int = 1

    @dataclass
    class EpilogueArguments(ArgumentsBase):
        # Host-side arguments, converted by epi_to_underlying_arguments below.
        mPostAct: cute.Tensor  # activation output, same (m, n, l) extent as D
        act_fn: cutlass.Constexpr[Optional[Callable]] = None  # None => identity (pure cast)
        alpha: Optional[Float32 | cute.Tensor] = None
        beta: Optional[Float32 | cute.Tensor] = None
        mRowVecBroadcast: Optional[cute.Tensor] = None  # (l, n) bias added per row
        mColVecBroadcast: Optional[cute.Tensor] = None  # (l, m) or (total_m,) bias per column

    @dataclass
    class EpilogueParams(ParamsBase):
        # Device-side parameters derived from EpilogueArguments.
        tma_atom_postact: cute.CopyAtom
        mPostAct_mnl: cute.Tensor
        epi_postact_smem_layout_staged: cute.ComposedLayout
        epi_tile_postact: cute.Tile
        act_fn: cutlass.Constexpr[Optional[Callable]] = None
        alpha: Optional[Float32 | cute.Tensor] = None
        beta: Optional[Float32 | cute.Tensor] = None
        mRowVecBroadcast: Optional[cute.Tensor] = None
        mColVecBroadcast: Optional[cute.Tensor] = None

    def epi_to_underlying_arguments(
        self, args: EpilogueArguments, *, loc=None, ip=None
    ) -> EpilogueParams:
        """Build device params: PostAct smem layout, TMA store atom, and stride-annotated
        broadcast tensors. Also records postact dtype/layout on ``self`` for later stages."""
        self.postact_dtype = args.mPostAct.element_type
        self.postact_layout = cutlass.utils.LayoutEnum.from_tensor(args.mPostAct)

        # PostAct uses the same CTA tile and epi tile as D.
        self.cta_tile_shape_postact_mn = self.cta_tile_shape_mnk[:2]
        epi_tile_postact = self.epi_tile
        utils_cls = sm100_utils if self.arch == 100 else sm90_utils
        epi_postact_smem_layout_staged = utils_cls.make_smem_layout_epi(
            self.postact_dtype, self.postact_layout, epi_tile_postact, self.epi_stage
        )
        tma_atom_postact, tma_tensor_postact = self._make_tma_epi_atoms_and_tensors(
            args.mPostAct,
            epi_postact_smem_layout_staged,
            epi_tile_postact,
            op_type="store",
        )
        # Assume all strides are divisible by 32 bits except the last stride
        new_stride = lambda t: tuple(
            cute.assume(s, divby=32 // t.element_type.width) if not cute.is_static(s) else s
            for s in t.stride
        )
        mRowVecBroadcast, mColVecBroadcast = [
            cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
            if t is not None
            else None
            for t in (args.mRowVecBroadcast, args.mColVecBroadcast)
        ]
        return self.EpilogueParams(
            tma_atom_postact,
            tma_tensor_postact,
            epi_postact_smem_layout_staged,
            epi_tile_postact,
            args.act_fn,
            alpha=args.alpha,
            beta=args.beta,
            mRowVecBroadcast=mRowVecBroadcast,
            mColVecBroadcast=mColVecBroadcast,
        )

    def epi_get_tma_atoms(
        self, params: EpilogueParams, *, loc=None, ip=None
    ) -> list[cute.CopyAtom]:
        """Return the epilogue TMA atoms owned by this mixin (just PostAct's store atom)."""
        return [params.tma_atom_postact]

    def epi_get_tensormap_update_shapes_orders(
        self,
        params: EpilogueParams,
        cu_seqlens_m: Optional[cute.Tensor],
        batch_idx: Int32,
        *,
        loc=None,
        ip=None,
    ) -> tuple[list[Int32], list[int]]:
        """For varlen-m, each batch's PostAct tensormap must be resized to the running
        total of rows; the order index picks which tensormap dim holds M (0 if PostAct
        is m-major, 1 otherwise)."""
        shapes = [cu_seqlens_m[batch_idx + 1] if cu_seqlens_m is not None else None]
        orders = [0 if const_expr(self.postact_layout.is_m_major_c()) else 1]
        return shapes, orders

    @staticmethod
    def epi_smem_bytes_per_stage(
        args: EpilogueArguments, cta_tile_shape_mnk: Tuple[int, int, int], epi_tile: cute.Tile
    ) -> int:
        """Shared-memory budget per epilogue stage: PostAct staging buffer plus the
        base mixin's row/col broadcast vectors."""
        postact_dtype = args.mPostAct.element_type
        postact_bytes_per_stage = cute.size(cute.shape(epi_tile)) * (postact_dtype.width // 8)
        rowvec_colvec_bytes = GemmDefaultEpiMixin.epi_smem_bytes_per_stage(
            args, cta_tile_shape_mnk, epi_tile
        )
        return postact_bytes_per_stage + rowvec_colvec_bytes

    def epi_get_smem_struct(self, params: EpilogueParams):
        """Declare the epilogue smem struct: optional row/col broadcast vectors
        (zero-sized when the corresponding tensor is absent) and the staged PostAct buffer."""
        row_vec_smem_size = 0 if params.mRowVecBroadcast is None else self.cta_tile_shape_mnk[1]
        col_vec_smem_size = 0 if params.mColVecBroadcast is None else self.cta_tile_shape_mnk[0]
        row_vec_dtype = (
            params.mRowVecBroadcast.element_type if params.mRowVecBroadcast is not None else Float32
        )
        col_vec_dtype = (
            params.mColVecBroadcast.element_type if params.mColVecBroadcast is not None else Float32
        )

        @cute.struct
        class EpiSharedStorage:
            sRowVec: cute.struct.Align[cute.struct.MemRange[row_vec_dtype, row_vec_smem_size], 16]
            sColVec: cute.struct.Align[cute.struct.MemRange[col_vec_dtype, col_vec_smem_size], 16]
            sPostAct: cute.struct.Align[
                cute.struct.MemRange[
                    self.postact_dtype, cute.cosize(params.epi_postact_smem_layout_staged)
                ],
                self.buffer_align_bytes,
            ]

        return EpiSharedStorage

    def epi_get_smem_tensors(self, params: EpilogueParams, storage) -> Tuple[cute.Tensor, ...]:
        """Materialize smem tensors: base (sRowVec, sColVec) plus the swizzled PostAct buffer."""
        sRowVec, sColVec = super().epi_get_smem_tensors(params, storage)
        sPostAct = storage.epi.sPostAct.get_tensor(
            params.epi_postact_smem_layout_staged.outer,
            swizzle=params.epi_postact_smem_layout_staged.inner,
        )
        return (sRowVec, sColVec, sPostAct)

    @cute.jit
    def epilogue(
        self,
        params: EpilogueParams,
        epi_smem_tensors: Tuple[cute.Tensor, ...],
        tma_desc_epi_ptrs: list[Optional[cute.Pointer]],
        epi_pipeline: cutlass.pipeline.PipelineAsync,
        epi_store_pipeline: cutlass.pipeline.PipelineAsync,
        epi_read_state: cutlass.pipeline.PipelineState,
        epi_producer_state: cutlass.pipeline.PipelineState,
        epi_tile: cute.Tile,
        load_acc_subtile: Callable,
        tRS_rD: cute.Tensor,
        tRS_rC: Optional[cute.Tensor],
        tiled_copy_t2r: Optional[cute.TiledCopy],  # Only for Sm100
        tiled_copy_r2s: cute.TiledCopy,
        tRS_sD: cute.Tensor,
        tiled_copy_s2r: Optional[cute.TiledCopy],
        tSR_rC: Optional[cute.Tensor],
        tSR_sC: Optional[cute.Tensor],
        copy_D: Optional[Callable],
        copy_C: Optional[Callable],
        tile_coord_mnkl: cute.Coord,
        varlen_manager: VarlenManager,
        epilogue_barrier: cutlass.pipeline.NamedBarrier,
        tile_scheduler,
        tidx: Int32,
        is_tma_warp: Boolean,
    ) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState]:
        """Device epilogue: iterate epi subtiles, compute D and PostAct = act_fn(D),
        and store both via TMA. Returns the advanced C-load pipeline states.

        NOTE(review): kernel-side code traced by the CuTe DSL; control flow on
        `const_expr` values is resolved at compile time.
        """
        has_C = const_expr(tRS_rC is not None)
        has_D = const_expr(copy_D is not None)

        tma_atom_postact = params.tma_atom_postact
        mPostAct_mnl = params.mPostAct_mnl
        sRowVec, sColVec, sPostAct = epi_smem_tensors
        # Pick the register->smem store op builder for the current architecture.
        get_smem_store_op = (
            partial(sm100_utils.get_smem_store_op, tiled_tmem_load=tiled_copy_t2r)
            if self.arch == 100
            else sm90_utils_og.sm90_get_smem_store_op
        )
        copy_atom_postact_r2s = get_smem_store_op(
            self.postact_layout, self.postact_dtype, self.acc_dtype
        )
        # tiled_copy_C_atom = self.epilog_smem_copy_atom(tiled_mma)
        # tiled_copy_postact_r2s = cute.make_tiled_copy_S(copy_atom_postact_r2s, tiled_copy_C_atom)
        tiled_copy_postact_r2s = cute.make_tiled_copy_S(copy_atom_postact_r2s, tiled_copy_r2s)
        tRS_sPostAct = tiled_copy_postact_r2s.get_slice(tidx).partition_D(sPostAct)
        (tma_desc_postact_ptr,) = tma_desc_epi_ptrs
        batch_idx = tile_coord_mnkl[3]
        copy_postact, _, _ = self.epilog_gmem_copy_and_partition(
            tma_atom_postact,
            varlen_manager.offset_batch_epi(mPostAct_mnl, batch_idx),
            self.cta_tile_shape_postact_mn,
            params.epi_tile_postact,
            sPostAct,
            tile_coord_mnkl,
            tma_desc_ptr=tma_desc_postact_ptr,
        )

        # We iterate over epi tiles in the N dimension first before the M dimension
        epi_tile_shape = cute.zipped_divide(
            cute.make_layout(self.cta_tile_shape_mnk[:2]), epi_tile
        ).shape[1]
        epi_tile_layout = cute.make_layout(epi_tile_shape, stride=(epi_tile_shape[1], 1))
        epi_tile_num = cute.size(epi_tile_shape)
        # Used to round-robin smem stage buffers across persistent tiles.
        num_prev_subtiles = tile_scheduler.num_tiles_executed * epi_tile_num

        epi_tensors = self.epi_begin(
            params,
            epi_smem_tensors,
            epi_tile,
            tiled_copy_t2r,
            tiled_copy_r2s,
            tile_coord_mnkl,
            varlen_manager,
            epilogue_barrier,
            tidx,
        )

        # Prefetch the first epi_c_stage C subtiles (TMA warp issues; every
        # thread tracks the producer state).
        if const_expr(copy_C is not None):
            for epi_idx in cutlass.range(min(epi_tile_num, self.epi_c_stage), unroll=1):
                gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx)
                if is_tma_warp:
                    epi_pipeline.producer_acquire(epi_producer_state)
                    copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state)
                    epi_pipeline.producer_commit(epi_producer_state)
                epi_producer_state.advance()

        def tma_store_fn(src_idx, dst_idx):
            # Fence and barrier to make sure shared memory store is visible to TMA store
            cute.arch.fence_proxy(
                cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
            )
            epilogue_barrier.arrive_and_wait()
            # Copy from shared memory to global memory
            if is_tma_warp:
                if const_expr(has_D):
                    copy_D(src_idx=src_idx, dst_idx=dst_idx)
                copy_postact(src_idx=src_idx, dst_idx=dst_idx)
            # Can't use if statement here, epi_store_pipeline object isn't captured somehow
            if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_commit())
            if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_acquire())
            epilogue_barrier.arrive_and_wait()

        # Delay each subtile's TMA store by one iteration so the smem write of the
        # next subtile overlaps the store of the previous one.
        delay_tma_store = True

        src_idx_prev, dst_idx_prev = None, None
        for epi_idx in cutlass.range_constexpr(epi_tile_num):
            # The global memory coordinate for the current epi tile
            gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
            # Copy from acc to D registers
            load_acc_subtile(tRS_rD, epi_idx)
            epi_loop_tensors = self.epi_begin_loop(params, epi_tensors, gmem_coord)
            if const_expr(has_C):
                epi_pipeline.consumer_wait(epi_read_state)
                cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC)
                # Fence to make sure shared memory read is visible to TMA load
                cute.arch.fence_proxy(
                    cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
                )
                cute.arch.sync_warp()
                with cute.arch.elect_one():
                    epi_pipeline.consumer_release(epi_read_state)
                epi_read_state.advance()
            # Keep the C prefetch pipeline epi_c_stage subtiles ahead.
            if const_expr(copy_C is not None and epi_idx + self.epi_c_stage < epi_tile_num):
                gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx + self.epi_c_stage)
                if is_tma_warp:
                    epi_pipeline.producer_acquire(epi_producer_state)
                    copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state)
                    epi_pipeline.producer_commit(epi_producer_state)
                epi_producer_state.advance()
            tRS_rPostAct = self.epi_visit_subtile(params, epi_loop_tensors, tRS_rD, tRS_rC)
            epi_buffer = (num_prev_subtiles + epi_idx) % self.epi_stage
            if const_expr(delay_tma_store):
                if const_expr(epi_idx > 0):
                    tma_store_fn(src_idx=src_idx_prev, dst_idx=dst_idx_prev)
                src_idx_prev, dst_idx_prev = epi_buffer, gmem_coord
            # Copy from D registers to shared memory
            if const_expr(has_D):
                copy_utils.cvt_copy(tiled_copy_r2s, tRS_rD, tRS_sD[None, None, None, epi_buffer])
            cute.copy(
                tiled_copy_postact_r2s,
                tiled_copy_postact_r2s.retile(tRS_rPostAct),
                tRS_sPostAct[None, None, None, epi_buffer],
            )
            if const_expr(not delay_tma_store):
                tma_store_fn(src_idx=epi_buffer, dst_idx=gmem_coord)

        # Flush the store for the final subtile.
        if const_expr(delay_tma_store):
            tma_store_fn(src_idx=src_idx_prev, dst_idx=dst_idx_prev)

        self.epi_end(
            params,
            epi_tensors,
            epi_tile,
            tiled_copy_t2r,
            tiled_copy_r2s,
            tile_coord_mnkl,
            varlen_manager,
            tidx,
        )

        return epi_read_state, epi_producer_state

    @cute.jit
    def epi_visit_subtile(
        self,
        params: EpilogueParams,
        epi_loop_tensors: Tuple[cute.Tensor, ...],
        tRS_rD: cute.Tensor,
        tRS_rC: Optional[cute.Tensor] = None,
    ) -> Optional[cute.Tensor]:
        """Run the default epilogue on tRS_rD in place, then compute the activation
        into a fresh fragment and convert it to postact_dtype. Returns the converted
        PostAct fragment."""
        GemmDefaultEpiMixin.epi_visit_subtile(self, params, epi_loop_tensors, tRS_rD, tRS_rC)
        # Apply activation function if provided
        # If we don't have .shape here, the compiler generates local stores and loads
        if const_expr(params.act_fn is not None):
            tRS_rPostAct = cute.make_fragment(tRS_rD.layout.shape, self.acc_dtype)
            if const_expr(self.arch < 100):
                for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True):
                    tRS_rPostAct[i] = params.act_fn(tRS_rD[i])
            else:
                # On Sm100 the activation is applied two elements at a time
                # (act_fn takes/returns pairs there).
                for i in cutlass.range(cute.size(tRS_rPostAct) // 2, unroll_full=True):
                    tRS_rPostAct[2 * i], tRS_rPostAct[2 * i + 1] = params.act_fn(
                        (tRS_rD[2 * i], tRS_rD[2 * i + 1])
                    )
        else:
            tRS_rPostAct = tRS_rD
        # Type conversion
        tRS_rPostAct_out = cute.make_fragment_like(tRS_rPostAct, self.postact_dtype)
        tRS_rPostAct_out.store(tRS_rPostAct.load().to(self.postact_dtype))
        return tRS_rPostAct_out
345
+
346
+
347
class GemmActSm90(GemmActMixin, GemmSm90):
    """SM90 (Hopper) GEMM with a fused activation epilogue."""

    pass
349
+
350
+
351
class GemmActSm100(GemmActMixin, GemmSm100):
    """SM100 (Blackwell) GEMM with a fused activation epilogue."""

    pass
353
+
354
+
355
# Maps the user-facing activation name (the `activation` argument of gemm_act)
# to the device-side elementwise function; None means identity (cast only).
act_fn_map = {
    None: None,
    "relu": activation.relu,
    "relu_sq": activation.relu_sq,
    "gelu_tanh_approx": activation.gelu_tanh_approx,
}
361
+
362
+
363
def gemm_act(
    A: Tensor,  # (l, m, k) or (total_m, k) if varlen_m or (whatever, k) if gather_A with varlen_m
    B: Tensor,  # (l, n, k)
    D: Optional[Tensor],  # (l, m, n) or (total_m, n) if varlen_m
    C: Optional[Tensor],  # (l, m, n) or (total_m, n) if varlen_m
    PostAct: Tensor,  # (l, m, n) or (total_m, n) if varlen_m
    tile_count_semaphore: Optional[Tensor],  # (1,)
    activation: Optional[str],
    tile_M: int,
    tile_N: int,
    cluster_M: int,
    cluster_N: int,
    pingpong: bool = False,
    persistent: bool = True,
    max_swizzle_size: int = 8,
    rowvec_bias: Optional[Tensor] = None,  # (l, n)
    colvec_bias: Optional[Tensor] = None,  # (l, m), or (total_m,) if varlen_m
    cu_seqlens_m: Optional[Tensor] = None,  # (l+1,) cumulative sum of m values for variable length
    A_idx: Optional[Tensor] = None,  # (total_m,) if gather_A with varlen_m
) -> None:
    """Compute D = epilogue(A @ B^T [+ C]) and PostAct = act_fn(D), fused in one kernel.

    Results are written in place into ``D`` (if given) and ``PostAct``.
    Compiled kernels are memoized in ``gemm_act.compile_cache`` keyed on dtypes,
    layouts, tiling, and feature flags.

    Raises:
        TypeError: if the dtype/major combination is unsupported by the kernel.
        AssertionError: on invalid varlen/gather/activation configurations.
    """
    if cu_seqlens_m is not None:
        assert persistent, "varlen_m requires persistent=True"
        assert A.stride(-1) == 1, "varlen_m requires A to be k-major"
        if D is not None:
            assert D.stride(-1) == 1, "varlen_m requires D to be n-major"
        assert PostAct.stride(-1) == 1, "varlen_m requires PostAct to be n-major"
    gather_A = A_idx is not None
    if gather_A:
        assert cu_seqlens_m is not None, "gather_A requires varlen (cu_seqlens_m must be specified)"
        assert cluster_N == 1, "gather_A requires cluster_N=1"
    assert activation in act_fn_map, f"Unsupported activation {activation}"

    L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(
        A, B, D, C, additional_tensors={"PostAct": PostAct}, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx
    )
    GemmWrapperBase.permute_tensors(tensor_infos, varlen_m=cu_seqlens_m is not None)
    GemmWrapperBase.extract_dtypes(tensor_infos)
    # Dimension-name order used to derive each tensor's major-ness.
    major_configs = {
        "A": ("m", "k", "l"),
        "B": ("n", "k", "l"),
        "D": ("m", "n", "l"),
        "C": ("m", "n", "l"),
        "PostAct": ("m", "n", "l"),
    }
    GemmWrapperBase.determine_major_orders(tensor_infos, major_configs)

    device_capacity = get_device_capacity(A.device)
    assert device_capacity[0] in [9, 10], "Only SM90 and SM100 are supported"
    GemmCls = GemmActSm100 if device_capacity[0] > 9 else GemmActSm90

    acc_dtype = Float32
    tile_shape_mn = (tile_M, tile_N)
    cluster_shape_mnk = (cluster_M, cluster_N, 1)
    if not GemmCls.is_valid_dtypes(
        tensor_infos["A"].dtype,
        tensor_infos["B"].dtype,
        acc_dtype,
        tensor_infos["D"].dtype,
        tensor_infos["A"].major,
        tensor_infos["B"].major,
    ):
        raise TypeError("Skipping due to unsupported combination of types and majors")

    max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
    GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs)
    act_fn = act_fn_map[activation]
    epi_args = GemmCls.EpilogueArguments(
        tensor_infos["PostAct"].cute_tensor,
        act_fn,
        mRowVecBroadcast=from_dlpack(rowvec_bias.detach(), assumed_align=4).mark_layout_dynamic(
            leading_dim=1
        )
        if rowvec_bias is not None
        else None,
        # colvec bias is 1D (total_m,) under varlen_m, hence the different leading dim.
        mColVecBroadcast=from_dlpack(colvec_bias.detach(), assumed_align=4).mark_layout_dynamic(
            leading_dim=1 if cu_seqlens_m is None else 0
        )
        if colvec_bias is not None
        else None,
    )
    scheduler_args = GemmWrapperBase.create_scheduler_args(
        max_active_clusters, tile_count_semaphore, max_swizzle_size=max_swizzle_size
    )

    # Create varlen arguments if needed (assumes persistent=True when varlen_m)
    varlen_args = GemmWrapperBase.create_varlen_args(
        cu_seqlens_m,
        None,  # cu_seqlens_k
        A_idx,
        max_active_clusters,
        cluster_shape_mnk,
        tensor_infos,
        GemmCls.num_epi_tensormaps,
        pingpong,
    )

    current_stream = cutlass_torch.current_stream()
    # Every flag that changes the generated kernel must be part of this key.
    compile_key = GemmWrapperBase.get_compile_key(
        tensor_infos,
        activation,
        tile_shape_mn,
        cluster_shape_mnk,
        pingpong,
        persistent,
        tile_count_semaphore is not None,
        device_capacity,
        max_swizzle_size,
        rowvec_bias.dtype if rowvec_bias is not None else None,
        colvec_bias.dtype if colvec_bias is not None else None,
        cu_seqlens_m is not None,
        A_idx is not None,
        key_tensor_names=("A", "B", "D", "PostAct", "C"),
    )
    cache = gemm_act.compile_cache
    if compile_key not in cache:
        if device_capacity[0] == 9:
            # pingpong/persistent are SM90-only constructor options.
            GemmCls = partial(GemmCls, pingpong=pingpong, is_persistent=persistent)
        gemm_obj = GemmCls(
            acc_dtype,
            tensor_infos["A"].dtype,
            tile_shape_mn,
            cluster_shape_mnk,
            gather_A=gather_A,
        )
        cache[compile_key] = cute.compile(
            gemm_obj,
            tensor_infos["A"].cute_tensor,
            tensor_infos["B"].cute_tensor,
            tensor_infos["D"].cute_tensor,
            tensor_infos["C"].cute_tensor,
            epi_args,
            scheduler_args,
            varlen_args,
            current_stream,
        )
    cache[compile_key](
        tensor_infos["A"].cute_tensor,
        tensor_infos["B"].cute_tensor,
        tensor_infos["D"].cute_tensor,
        tensor_infos["C"].cute_tensor,
        epi_args,
        scheduler_args,
        varlen_args,
        current_stream,
    )


# Per-process cache of compiled kernels, keyed by get_compile_key above.
gemm_act.compile_cache = {}
build/torch-cuda/quack/gemm_config.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2025, Fri Dao.
2
+ import itertools
3
+ from typing import Optional, List, Literal
4
+ from functools import partial
5
+ from dataclasses import dataclass
6
+
7
+
8
@dataclass(frozen=True)
class GemmConfig:
    """One candidate GEMM tuning configuration (hashable, so usable as a cache key)."""

    tile_m: int = 128  # CTA tile size along M
    tile_n: int = 192  # CTA tile size along N
    pingpong: bool = True  # SM90 pingpong schedule (False = cooperative)
    cluster_m: int = 2  # thread-block cluster size along M
    cluster_n: int = 1  # thread-block cluster size along N
    swap_ab: bool = False  # compute B^T @ A^T instead of A @ B
    # raster_order: int = 1
    max_swizzle_size: int = 8  # tile-scheduler rasterization swizzle width
18
+
19
+
20
def get_all_configs(
    device_capacity: Literal[9, 10] = 9,
    epilogue: Optional[str] = None,
    tune_coop: bool = True,
    # tune_raster_order=True,
) -> List[GemmConfig]:
    """Enumerate candidate :class:`GemmConfig` settings for autotuning.

    Args:
        device_capacity: SM major version, 9 (Hopper) or 10 (Blackwell).
        epilogue: optional epilogue kind ("gated", "lse", ...) used to prune
            configurations known to be invalid or unprofitable for it.
        tune_coop: on SM90, whether to include cooperative (non-pingpong) tiles.

    Returns:
        List of GemmConfig candidates (cartesian product of tile shapes,
        cluster shapes, and flags, after epilogue-specific pruning).

    Raises:
        AssertionError: if device_capacity is not 9 or 10.
    """
    assert device_capacity in [9, 10]
    if device_capacity == 9:
        tile_n_vals = [128, 144, 160, 176, 192, 208]
        tile_mn_coop_vals = [(256, tile_n) for tile_n in tile_n_vals] + [
            (128, 224),
            (128, 256),
            # (192, 256),  # Getting IOT instruction (core dumped) in the bwd
        ]
        tile_mn_pingpong_vals = [(128, tile_n) for tile_n in tile_n_vals] + [(192, 128)]
        if epilogue in ["gated"]:
            # Gated epilogues need tile_n divisible by 32 and don't support tile_m=192.
            tile_mn_coop_vals = [(m, n) for m, n in tile_mn_coop_vals if n % 32 == 0 and m != 192]
            tile_mn_pingpong_vals = [(m, n) for m, n in tile_mn_pingpong_vals if n % 32 == 0]
        elif epilogue in ["lse"]:
            tile_mn_coop_vals = [(m, n) for m, n in tile_mn_coop_vals if m != 192]
        tile_mn_vals = []
        if tune_coop:
            tile_mn_vals += [(m, n, False) for m, n in tile_mn_coop_vals]
        tile_mn_vals += [(m, n, True) for m, n in tile_mn_pingpong_vals]
        cluster = [(1, 2), (2, 1)]
        # cluster = [(1, 1), (1, 2), (2, 1)]
        swap_ab_vals = [False, True]
        if epilogue in ["lse", "gated"]:
            # These epilogues are only implemented for the non-swapped layout.
            swap_ab_vals = [False]
        # raster_swizzle = (
        #     [(0, 1)]
        #     if not tune_raster_order
        #     else [(1, 1), (1, 2), (1, 4), (1, 8), (2, 1), (2, 2), (2, 4), (2, 8)]
        # )
        return [
            GemmConfig(
                tile_m=tile_m,
                tile_n=tile_n,
                pingpong=pingpong,
                cluster_m=cluster_m,
                cluster_n=cluster_n,
                swap_ab=swap_ab,
                # raster_order=raster_order,
                # max_swizzle_size=max_swizzle_size,
            )
            for (tile_m, tile_n, pingpong), (cluster_m, cluster_n), swap_ab in itertools.product(
                tile_mn_vals,
                cluster,
                swap_ab_vals,
                # raster_swizzle,
            )
        ]
    elif device_capacity == 10:
        tile_n_vals = [128, 160, 192, 224, 256]
        tile_mn_cluster_vals = (
            [(128, tile_n, (1, 2)) for tile_n in tile_n_vals]
            + [(128, tile_n, (2, 1)) for tile_n in tile_n_vals]
            + [(256, tile_n, (2, 1)) for tile_n in tile_n_vals]
        )
        swap_ab_vals = [False, True]
        if epilogue in ["lse", "gated"]:
            swap_ab_vals = [False]
        max_swizzle_size_vals = [4, 8, 16]
        GemmConfigCls = partial(GemmConfig, pingpong=False)  # There's no pingpong on Sm100
        return [
            GemmConfigCls(
                tile_m=m, tile_n=n, cluster_m=cm, cluster_n=cn, swap_ab=sab, max_swizzle_size=ms
            )
            for (m, n, (cm, cn)), sab, ms in itertools.product(
                tile_mn_cluster_vals, swap_ab_vals, max_swizzle_size_vals
            )
        ]
build/torch-cuda/quack/gemm_dact.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Tri Dao.
2
+ from typing import Optional, Tuple
3
+ from functools import partial
4
+
5
+ from torch import Tensor
6
+
7
+ import cutlass
8
+ import cutlass.cute as cute
9
+ from cutlass import Float32, const_expr
10
+ import cutlass.torch as cutlass_torch
11
+
12
+ from .gemm_sm90 import GemmSm90
13
+ from .gemm_sm100 import GemmSm100
14
+ from .gemm_default_epi import GemmDefaultEpiMixin
15
+ from .gemm_act import GemmActMixin
16
+ from .cute_dsl_utils import get_device_capacity, get_max_active_clusters
17
+ from .gemm_wrapper_utils import GemmWrapperBase
18
+ from . import activation
19
+
20
+
21
class GemmDActMixin(GemmActMixin):
    """Epilogue mixin for the activation *backward* pass.

    Reuses GemmActMixin's arguments/params/stores, but the fused function is an
    activation gradient: it consumes the matmul result (dout contribution, in
    tRS_rD) together with the pre-activation input (loaded via C), and produces
    both dx (written to D) and the recomputed activation output (written to PostAct).
    """

    # Different from GemmActSm90, here act_bwd_fn must take in 2 arguments (x, dout)
    # and return 2 arguments (dx, out)
    EpilogueArguments = GemmActMixin.EpilogueArguments
    EpilogueParams = GemmActMixin.EpilogueParams

    @cute.jit
    def epi_visit_subtile(
        self,
        params: EpilogueParams,
        epi_loop_tensors: Tuple[cute.Tensor, ...],
        tRS_rD: cute.Tensor,
        tRS_rC: Optional[cute.Tensor] = None,
    ) -> Optional[cute.Tensor]:
        """Apply the activation-backward function in place on tRS_rD (-> dx) and
        return the recomputed activation output converted to postact_dtype."""
        assert tRS_rC is not None
        # We don't add C to the accumulator
        GemmDefaultEpiMixin.epi_visit_subtile(self, params, epi_loop_tensors, tRS_rD, tRS_rC=None)
        # C holds the pre-activation values; promote to accumulator precision.
        tRS_rC_acc = cute.make_fragment_like(tRS_rC, self.acc_dtype)
        tRS_rC_acc.store(tRS_rC.load().to(self.acc_dtype))
        # If we don't have .shape here, the compiler generates local stores and loads
        if const_expr(params.act_fn is not None):
            tRS_rPostAct = cute.make_fragment(tRS_rD.layout.shape, self.acc_dtype)
            if const_expr(self.arch < 100):
                for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True):
                    tRS_rD[i], tRS_rPostAct[i] = params.act_fn(tRS_rC_acc[i], tRS_rD[i])
            else:
                # Sm100 variant processes element pairs per call.
                for i in cutlass.range(cute.size(tRS_rPostAct) // 2, unroll_full=True):
                    (
                        (tRS_rD[2 * i], tRS_rD[2 * i + 1]),
                        (tRS_rPostAct[2 * i], tRS_rPostAct[2 * i + 1]),
                    ) = params.act_fn(
                        (tRS_rC_acc[2 * i], tRS_rC_acc[2 * i + 1]),
                        (tRS_rD[2 * i], tRS_rD[2 * i + 1]),
                    )
        else:
            # Identity activation: out == pre-activation input, dx stays as tRS_rD.
            tRS_rPostAct = tRS_rC_acc
        # Type conversion
        tRS_rPostAct_out = cute.make_fragment_like(tRS_rPostAct, self.postact_dtype)
        tRS_rPostAct_out.store(tRS_rPostAct.load().to(self.postact_dtype))
        return tRS_rPostAct_out
61
+
62
+
63
class GemmDActSm90(GemmDActMixin, GemmSm90):
    """SM90 (Hopper) GEMM with a fused activation-backward epilogue."""

    pass
65
+
66
+
67
class GemmDActSm100(GemmDActMixin, GemmSm100):
    """SM100 (Blackwell) GEMM with a fused activation-backward epilogue."""

    pass
69
+
70
+
71
# Maps the activation name to its backward function (takes (x, dout), returns
# (dx, out)); None means identity.
dact_fn_map = {
    None: None,
    "relu": activation.drelu,
    "relu_sq": activation.drelu_sq,
    "gelu_tanh_approx": activation.dgelu_tanh_approx,
}
77
+
78
+
79
def gemm_dact(
    A: Tensor,  # (l, m, k) or (total_m, k) if varlen_m or (whatever, k) if gather_A with varlen_m
    B: Tensor,  # (l, n, k)
    Out: Tensor,  # (l, m, n) or (total_m, n) if varlen_m
    PreAct: Tensor,  # (l, m, n) or (total_m, n) if varlen_m
    PostAct: Tensor,  # (l, m, n) or (total_m, n) if varlen_m
    tile_count_semaphore: Optional[Tensor],  # (1,)
    activation: Optional[str],
    tile_M: int,
    tile_N: int,
    cluster_M: int,
    cluster_N: int,
    pingpong: bool = True,
    persistent: bool = True,
    max_swizzle_size: int = 8,
    cu_seqlens_m: Optional[Tensor] = None,  # (l+1,) cumulative sum of m values for variable length
    A_idx: Optional[Tensor] = None,  # (total_m,) if gather_A with varlen_m
) -> None:
    """Fused activation-backward GEMM: given dout-side matmul A @ B^T and the saved
    pre-activation values ``PreAct``, write dx into ``Out`` and the recomputed
    activation into ``PostAct`` — all in one kernel.

    ``PreAct`` rides through the kernel's C input path; results are written in place.
    Compiled kernels are memoized in ``gemm_dact.compile_cache``.

    Raises:
        TypeError: if the dtype/major combination is unsupported by the kernel.
        AssertionError: on invalid varlen/gather/activation configurations.
    """
    if cu_seqlens_m is not None:
        assert persistent, "varlen_m requires persistent=True"
        assert A.stride(-1) == 1, "varlen_m requires A to be k-major"
        assert Out.stride(-1) == 1, "varlen_m requires Out to be n-major"
        assert PreAct.stride(-1) == 1, "varlen_m requires PreAct to be n-major"
        assert PostAct.stride(-1) == 1, "varlen_m requires PostAct to be n-major"
    gather_A = A_idx is not None
    if gather_A:
        assert cu_seqlens_m is not None, "gather_A requires varlen (cu_seqlens_m must be specified)"
        assert cluster_N == 1, "gather_A requires cluster_N=1"
    assert activation in dact_fn_map, f"Unsupported activation {activation}"

    # Out plays the role of D, PreAct the role of C.
    L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(
        A,
        B,
        Out,
        PreAct,
        additional_tensors={"PostAct": PostAct},
        cu_seqlens_m=cu_seqlens_m,
        A_idx=A_idx,
    )
    GemmWrapperBase.permute_tensors(tensor_infos, varlen_m=cu_seqlens_m is not None)
    GemmWrapperBase.extract_dtypes(tensor_infos)
    # Dimension-name order used to derive each tensor's major-ness.
    major_configs = {
        "A": ("m", "k", "l"),
        "B": ("n", "k", "l"),
        "D": ("m", "n", "l"),
        "C": ("m", "n", "l"),
        "PostAct": ("m", "n", "l"),
    }
    GemmWrapperBase.determine_major_orders(tensor_infos, major_configs)

    device_capacity = get_device_capacity(A.device)
    assert device_capacity[0] in [9, 10], "Only SM90 and SM100 are supported"
    GemmCls = GemmDActSm100 if device_capacity[0] > 9 else GemmDActSm90

    acc_dtype = Float32
    tile_shape_mn = (tile_M, tile_N)
    cluster_shape_mnk = (cluster_M, cluster_N, 1)
    if not GemmCls.is_valid_dtypes(
        tensor_infos["A"].dtype,
        tensor_infos["B"].dtype,
        acc_dtype,
        tensor_infos["D"].dtype,
        tensor_infos["A"].major,
        tensor_infos["B"].major,
    ):
        raise TypeError("Skipping due to unsupported combination of types and majors")

    max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
    GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs)
    act_fn = dact_fn_map[activation]
    epi_args = GemmCls.EpilogueArguments(tensor_infos["PostAct"].cute_tensor, act_fn)
    scheduler_args = GemmWrapperBase.create_scheduler_args(
        max_active_clusters, tile_count_semaphore, max_swizzle_size=max_swizzle_size
    )

    # Create varlen arguments if needed (assumes persistent=True when varlen_m)
    varlen_args = GemmWrapperBase.create_varlen_args(
        cu_seqlens_m,
        None,  # cu_seqlens_k
        A_idx,
        max_active_clusters,
        cluster_shape_mnk,
        tensor_infos,
        GemmCls.num_epi_tensormaps,
        pingpong,
    )

    current_stream = cutlass_torch.current_stream()
    # Every flag that changes the generated kernel must be part of this key.
    compile_key = GemmWrapperBase.get_compile_key(
        tensor_infos,
        activation,
        tile_shape_mn,
        cluster_shape_mnk,
        pingpong,
        persistent,
        tile_count_semaphore is not None,
        device_capacity,
        max_swizzle_size,
        cu_seqlens_m is not None,
        A_idx is not None,
        key_tensor_names=("A", "B", "D", "PostAct", "C"),
    )
    cache = gemm_dact.compile_cache
    if compile_key not in cache:
        if device_capacity[0] == 9:
            # pingpong/persistent are SM90-only constructor options.
            GemmCls = partial(GemmCls, pingpong=pingpong, is_persistent=persistent)
        gemm = GemmCls(
            acc_dtype,
            tensor_infos["A"].dtype,
            tile_shape_mn,
            cluster_shape_mnk,
            gather_A=gather_A,
        )
        cache[compile_key] = cute.compile(
            gemm,
            tensor_infos["A"].cute_tensor,
            tensor_infos["B"].cute_tensor,
            tensor_infos["D"].cute_tensor,
            tensor_infos["C"].cute_tensor,
            epi_args,
            scheduler_args,
            varlen_args,
            current_stream,
        )
    cache[compile_key](
        tensor_infos["A"].cute_tensor,
        tensor_infos["B"].cute_tensor,
        tensor_infos["D"].cute_tensor,
        tensor_infos["C"].cute_tensor,
        epi_args,
        scheduler_args,
        varlen_args,
        current_stream,
    )


# Per-process cache of compiled kernels, keyed by get_compile_key above.
gemm_dact.compile_cache = {}
build/torch-cuda/quack/gemm_default_epi.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Wentao Guo, Tri Dao.
2
+ from typing import Optional, Tuple
3
+ from functools import partial
4
+ from dataclasses import dataclass
5
+
6
+
7
+ import cutlass
8
+ import cutlass.cute as cute
9
+ from cutlass import Int32, Float32, Boolean, const_expr
10
+
11
+ from .cute_dsl_utils import ArgumentsBase, ParamsBase
12
+ from .gemm_sm90 import GemmSm90
13
+ from .gemm_sm100 import GemmSm100
14
+ from .sm90_utils import partition_for_epilogue
15
+ from . import utils as utils
16
+ from . import copy_utils as copy_utils
17
+ from .varlen_utils import VarlenManager
18
+
19
+
20
class GemmDefaultEpiMixin:
    """Default GEMM epilogue: optional alpha/beta scaling, optional row- and
    column-vector bias broadcast, optional accumulate-into-output."""

    # No extra TMA tensormaps beyond the base GEMM's.
    num_epi_tensormaps: int = 0

    @dataclass
    class EpilogueArguments(ArgumentsBase):
        # Host-side arguments for the default epilogue.
        alpha: Optional[Float32 | cute.Tensor] = None  # scalar or 1-element tensor scale on acc
        beta: Optional[Float32 | cute.Tensor] = None  # scalar or 1-element tensor scale on C
        mRowVecBroadcast: Optional[cute.Tensor] = None  # (l, n) bias broadcast over rows
        mColVecBroadcast: Optional[cute.Tensor] = None  # (l, m) / (total_m,) bias over columns
        add_to_output: bool = False  # accumulate into existing D instead of overwriting

    @dataclass
    class EpilogueParams(ParamsBase):
        # Device-side parameters (broadcast tensors carry stride divisibility assumptions).
        alpha: Optional[Float32 | cute.Tensor] = None
        beta: Optional[Float32 | cute.Tensor] = None
        mRowVecBroadcast: Optional[cute.Tensor] = None
        mColVecBroadcast: Optional[cute.Tensor] = None
37
+
38
    def epi_to_underlying_arguments(
        self, args: EpilogueArguments, *, loc=None, ip=None
    ) -> EpilogueParams:
        """Convert host EpilogueArguments into device EpilogueParams, annotating
        dynamic strides of the broadcast tensors with divisibility assumptions
        so the compiler can vectorize their copies."""
        # Assume all strides are divisible by 32 bits except the last stride
        new_stride = lambda t: tuple(
            cute.assume(s, divby=32 // t.element_type.width) if not cute.is_static(s) else s
            for s in t.stride
        )
        mRowVecBroadcast, mColVecBroadcast = [
            cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
            if t is not None
            else None
            for t in (args.mRowVecBroadcast, args.mColVecBroadcast)
        ]
        return self.EpilogueParams(
            alpha=args.alpha,
            beta=args.beta,
            mRowVecBroadcast=mRowVecBroadcast,
            mColVecBroadcast=mColVecBroadcast,
        )
58
+
59
+ @cute.jit
60
+ def epi_begin(
61
+ self,
62
+ params: EpilogueParams,
63
+ epi_smem_tensors: Tuple[cute.Tensor, ...],
64
+ epi_tile: cute.Tile,
65
+ tiled_copy_t2r: Optional[cute.TiledCopy],
66
+ tiled_copy_r2s: cute.TiledCopy,
67
+ tile_coord_mnkl: cute.Coord,
68
+ varlen_manager: VarlenManager,
69
+ epilogue_barrier: cutlass.pipeline.NamedBarrier,
70
+ tidx: Int32,
71
+ ):
72
+ alpha, beta = None, None
73
+ if const_expr(hasattr(params, "alpha") and params.alpha is not None):
74
+ alpha = utils.load_scalar_or_pointer(params.alpha)
75
+ if const_expr(hasattr(params, "beta") and params.beta is not None):
76
+ beta = utils.load_scalar_or_pointer(params.beta)
77
+ sRowVec, sColVec, *rest = epi_smem_tensors
78
+ tile_M, tile_N = self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[1]
79
+ batch_idx = tile_coord_mnkl[3]
80
+ num_epi_threads = self.num_epi_warps * cute.arch.WARP_SIZE
81
+ # Don't need sync as we assume the previous epilogue has finished
82
+
83
+ partition_for_epilogue_fn = partial(
84
+ partition_for_epilogue,
85
+ epi_tile=epi_tile,
86
+ tiled_copy=tiled_copy_t2r if tiled_copy_t2r is not None else tiled_copy_r2s,
87
+ tidx=tidx,
88
+ reference_src=tiled_copy_t2r is None,
89
+ )
90
+
91
+ tDsRowVec = None
92
+ if const_expr(params.mRowVecBroadcast is not None):
93
+ rowvec_dtype = params.mRowVecBroadcast.element_type
94
+ num_copy_elems = const_expr(max(32, rowvec_dtype.width)) // rowvec_dtype.width
95
+ thr_copy_RV = copy_utils.tiled_copy_1d(
96
+ params.mRowVecBroadcast.element_type, num_epi_threads, num_copy_elems, is_async=True
97
+ ).get_slice(tidx)
98
+ mRowVec = params.mRowVecBroadcast[batch_idx, None]
99
+ gRowVec = cute.local_tile(mRowVec, (tile_N,), (tile_coord_mnkl[1],))
100
+ tRVgRV = thr_copy_RV.partition_S(gRowVec)
101
+ tRVsRV = thr_copy_RV.partition_D(sRowVec)
102
+ tRVcRV = thr_copy_RV.partition_S(cute.make_identity_tensor(tile_N))
103
+ limit_n = min(mRowVec.shape[0] - tile_coord_mnkl[1] * tile_N, tile_N)
104
+ tRVpRV = cute.make_fragment((1, cute.size(tRVsRV.shape[1])), Boolean)
105
+ for m in cutlass.range(cute.size(tRVsRV.shape[1]), unroll_full=True):
106
+ tRVpRV[0, m] = tRVcRV[0, m] < limit_n
107
+ cute.copy(thr_copy_RV, tRVgRV, tRVsRV, pred=tRVpRV)
108
+ # (CPY, CPY_M, CPY_N, EPI_M, EPI_N)
109
+ tDsRowVec = partition_for_epilogue_fn(
110
+ cute.make_tensor(
111
+ sRowVec.iterator, cute.make_layout((tile_M, tile_N), stride=(0, 1))
112
+ )
113
+ )
114
+ if const_expr(tiled_copy_t2r is not None):
115
+ tDsRowVec = tiled_copy_r2s.retile(tDsRowVec)
116
+
117
+ tDsColVec = None
118
+ if const_expr(params.mColVecBroadcast is not None):
119
+ colvec_dtype = params.mColVecBroadcast.element_type
120
+ num_copy_elems = const_expr(max(32, colvec_dtype.width)) // colvec_dtype.width
121
+ thr_copy_CV = copy_utils.tiled_copy_1d(
122
+ params.mColVecBroadcast.element_type, num_epi_threads, num_copy_elems, is_async=True
123
+ ).get_slice(tidx)
124
+ if const_expr(not varlen_manager.varlen_m):
125
+ mColVec = params.mColVecBroadcast[batch_idx, None]
126
+ else:
127
+ mColVec = cute.domain_offset(
128
+ (varlen_manager.params.cu_seqlens_m[batch_idx],), params.mColVecBroadcast
129
+ )
130
+ gColVec = cute.local_tile(mColVec, (tile_M,), (tile_coord_mnkl[0],))
131
+ tCVgCV = thr_copy_CV.partition_S(gColVec)
132
+ tCVsCV = thr_copy_CV.partition_D(sColVec)
133
+ tCVcCV = thr_copy_CV.partition_S(cute.make_identity_tensor(tile_M))
134
+ limit_m = min(varlen_manager.len_m(batch_idx) - tile_coord_mnkl[0] * tile_M, tile_M)
135
+ tCVpCV = cute.make_fragment((1, cute.size(tCVsCV.shape[1])), Boolean)
136
+ for m in cutlass.range(cute.size(tCVsCV.shape[1]), unroll_full=True):
137
+ tCVpCV[0, m] = tCVcCV[0, m] < limit_m
138
+ cute.copy(thr_copy_CV, tCVgCV, tCVsCV, pred=tCVpCV)
139
+ tDsColVec = partition_for_epilogue_fn(
140
+ cute.make_tensor(
141
+ sColVec.iterator, cute.make_layout((tile_M, tile_N), stride=(1, 0))
142
+ )
143
+ )
144
+ if const_expr(tiled_copy_t2r is not None):
145
+ tDsColVec = tiled_copy_r2s.retile(tDsColVec)
146
+
147
+ if const_expr(params.mRowVecBroadcast is not None or params.mColVecBroadcast is not None):
148
+ cute.arch.cp_async_commit_group()
149
+ cute.arch.cp_async_wait_group(0)
150
+ epilogue_barrier.arrive_and_wait()
151
+ return alpha, beta, tDsRowVec, tDsColVec
152
+
153
+ def epi_begin_loop(self, params: EpilogueParams, epi_tensors, epi_coord: cute.Coord):
154
+ alpha, beta, tDsRowVec, tDsColVec = epi_tensors
155
+ tDrRowVec_cvt = None
156
+ if const_expr(tDsRowVec is not None):
157
+ tDsRowVec_cur = cute.group_modes(tDsRowVec, 3, cute.rank(tDsRowVec))[
158
+ None, None, None, epi_coord
159
+ ]
160
+ # tDrRowVec = cute.make_fragment_like(tDsRowVec_cur)
161
+ tDrRowVec = cute.make_fragment(tDsRowVec_cur.layout, tDsRowVec_cur.element_type)
162
+ cute.autovec_copy(cute.filter_zeros(tDsRowVec_cur), cute.filter_zeros(tDrRowVec))
163
+ tDrRowVec_cvt = cute.make_fragment_like(tDrRowVec, self.acc_dtype)
164
+ tDrRowVec_cvt.store(tDrRowVec.load().to(self.acc_dtype))
165
+ tDrColVec_cvt = None
166
+ if const_expr(tDsColVec is not None):
167
+ tDsColVec_cur = cute.group_modes(tDsColVec, 3, cute.rank(tDsColVec))[
168
+ None, None, None, epi_coord
169
+ ]
170
+ # This somehow doesn't work, some dim with stride 0 turns to non-zero stride
171
+ # tDrRowVec = cute.make_fragment_like(tDsRowVec_cur)
172
+ tDrColVec = cute.make_fragment(tDsColVec_cur.layout, tDsColVec_cur.element_type)
173
+ cute.autovec_copy(cute.filter_zeros(tDsColVec_cur), cute.filter_zeros(tDrColVec))
174
+ tDrColVec_cvt = cute.make_fragment_like(tDrColVec, self.acc_dtype)
175
+ tDrColVec_cvt.store(tDrColVec.load().to(self.acc_dtype))
176
+ return alpha, beta, tDrRowVec_cvt, tDrColVec_cvt
177
+
178
+ @cute.jit
179
+ def epi_visit_subtile(
180
+ self,
181
+ params: EpilogueParams,
182
+ epi_loop_tensors: Tuple[cute.Tensor, ...],
183
+ tRS_rD: cute.Tensor,
184
+ tRS_rC: Optional[cute.Tensor] = None,
185
+ ) -> Optional[cute.Tensor]:
186
+ alpha, beta, tDrRowVec, tDrColVec = epi_loop_tensors
187
+ rD = tRS_rD.load()
188
+ # Apply alpha scaling to accumulator if alpha is provided (not None)
189
+ if const_expr(hasattr(params, "alpha") and params.alpha is not None):
190
+ alpha = utils.load_scalar_or_pointer(params.alpha)
191
+ rD *= alpha
192
+ # Apply C with beta scaling
193
+ if const_expr(tRS_rC is not None):
194
+ if const_expr(not hasattr(params, "beta") or params.beta is None):
195
+ # beta is None, default behavior: add C (beta=1.0)
196
+ rD += tRS_rC.load().to(tRS_rD.element_type)
197
+ else:
198
+ beta = utils.load_scalar_or_pointer(params.beta)
199
+ rD += beta * tRS_rC.load().to(tRS_rD.element_type)
200
+ tRS_rD.store(rD)
201
+ if const_expr(tDrRowVec is not None):
202
+ for i in cutlass.range(cute.size(tDrRowVec), unroll_full=True):
203
+ tRS_rD[i] += tDrRowVec[i]
204
+ if const_expr(tDrColVec is not None):
205
+ for i in cutlass.range(cute.size(tDrColVec), unroll_full=True):
206
+ tRS_rD[i] += tDrColVec[i]
207
+ return None
208
+
209
+ @staticmethod
210
+ def epi_smem_bytes_per_stage(
211
+ args: Optional[EpilogueArguments],
212
+ cta_tile_shape_mnk: Tuple[int, int, int],
213
+ epi_tile: cute.Tile,
214
+ ) -> int:
215
+ row_vec_smem_size = 0 if args.mRowVecBroadcast is None else cta_tile_shape_mnk[1]
216
+ col_vec_smem_size = 0 if args.mColVecBroadcast is None else cta_tile_shape_mnk[0]
217
+ row_vec_dtype = (
218
+ args.mRowVecBroadcast.element_type if args.mRowVecBroadcast is not None else Float32
219
+ )
220
+ col_vec_dtype = (
221
+ args.mColVecBroadcast.element_type if args.mColVecBroadcast is not None else Float32
222
+ )
223
+ return (
224
+ row_vec_smem_size * row_vec_dtype.width + col_vec_smem_size * col_vec_dtype.width
225
+ ) // 8
226
+
227
+ def epi_get_smem_struct(self, params: EpilogueParams):
228
+ row_vec_smem_size = 0 if params.mRowVecBroadcast is None else self.cta_tile_shape_mnk[1]
229
+ col_vec_smem_size = 0 if params.mColVecBroadcast is None else self.cta_tile_shape_mnk[0]
230
+ row_vec_dtype = (
231
+ params.mRowVecBroadcast.element_type if params.mRowVecBroadcast is not None else Float32
232
+ )
233
+ col_vec_dtype = (
234
+ params.mColVecBroadcast.element_type if params.mColVecBroadcast is not None else Float32
235
+ )
236
+
237
+ @cute.struct
238
+ class EpiSharedStorage:
239
+ sRowVec: cute.struct.Align[cute.struct.MemRange[row_vec_dtype, row_vec_smem_size], 16]
240
+ sColVec: cute.struct.Align[cute.struct.MemRange[col_vec_dtype, col_vec_smem_size], 16]
241
+
242
+ return EpiSharedStorage
243
+
244
+ def epi_get_smem_tensors(self, params: EpilogueParams, storage) -> Tuple[cute.Tensor, ...]:
245
+ sRowVec = None
246
+ if const_expr(params.mRowVecBroadcast is not None):
247
+ sRowVec = storage.epi.sRowVec.get_tensor(cute.make_layout(self.cta_tile_shape_mnk[1]))
248
+ sColVec = None
249
+ if const_expr(params.mColVecBroadcast is not None):
250
+ sColVec = storage.epi.sColVec.get_tensor(cute.make_layout(self.cta_tile_shape_mnk[0]))
251
+ return (sRowVec, sColVec)
252
+
253
+
254
class GemmDefaultSm90(GemmDefaultEpiMixin, GemmSm90):
    """SM90 GEMM kernel with the default alpha/beta + broadcast epilogue."""

    pass
256
+
257
+
258
class GemmDefaultSm100(GemmDefaultEpiMixin, GemmSm100):
    """SM100 GEMM kernel with the default alpha/beta + broadcast epilogue."""

    pass
build/torch-cuda/quack/gemm_interface.py ADDED
@@ -0,0 +1,1058 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Tri Dao
2
+ from typing import Optional, Tuple, Literal
3
+ from functools import partial
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import Tensor
8
+ from ._ops_compat import add_quack_op_namespace_prefix
9
+
10
+ from .gemm_config import GemmConfig, get_all_configs
11
+
12
+ from .autotuner import autotune, AutotuneConfig
13
+ from .cute_dsl_utils import get_device_capacity
14
+ from .gemm import gemm as gemm_sm90_sm100
15
+ from .gemm_act import gemm_act as gemm_act_sm90_sm100
16
+ from .gemm_dact import gemm_dact as gemm_dact_sm90_sm100
17
+ from .gemm_symmetric import gemm_symmetric as gemm_symmetric_sm90_sm100
18
+
19
+
20
# Dictionary mapping activation names to PyTorch functions.
# None maps to the identity so callers can index unconditionally.
def _act_identity(x):
    return x


def _act_relu_squared(x):
    return F.relu(x).square()


act_to_pytorch_fn_map = {
    None: _act_identity,
    "relu": F.relu,
    "relu_sq": _act_relu_squared,
    "gelu_tanh_approx": partial(F.gelu, approximate="tanh"),
}
27
+
28
+
29
# Dictionary mapping gated activation names to their forward functions.
# Each entry takes (gate, up) and returns the post-activation value.
def _swiglu(gate, up):
    return F.silu(gate) * up


def _swiglu_oai(gate, up):
    return gate * torch.sigmoid(1.702 * gate) * (up + 1)


def _reglu(gate, up):
    return F.relu(gate) * up


def _geglu(gate, up):
    return F.gelu(gate, approximate="tanh") * up


def _glu(gate, up):
    return torch.sigmoid(gate) * up


gated_to_pytorch_fn_map = {
    "swiglu": _swiglu,
    "swiglu_oai": _swiglu_oai,
    "reglu": _reglu,
    "geglu": _geglu,
    "glu": _glu,
}
38
+
39
+
40
+ def _get_default_device_capacity():
41
+ if not torch.cuda.is_available():
42
+ return (9, 0)
43
+ cap = get_device_capacity(torch.device("cuda"))
44
+ if cap[0] not in (9, 10):
45
+ return (9, 0)
46
+ return cap
47
+
48
+
49
+ class _LazyDeviceCapacity:
50
+ """Defer torch.cuda.get_device_capability until first access so the
51
+ module can be imported in environments without a GPU (e.g. nix build)."""
52
+ _value = None
53
+ def __getitem__(self, idx):
54
+ if self._value is None:
55
+ self._value = _get_default_device_capacity()
56
+ return self._value[idx]
57
+
58
+
59
# Module-level singleton, indexed like a (major, minor) capability tuple.
# Lazy so that importing this module does not require a visible GPU.
default_device_capacity = _LazyDeviceCapacity()
60
+
61
+
62
def default_config(device):
    """Fallback GemmConfig used when autotuning is skipped or has no entry."""
    if get_device_capacity(device)[0] == 10:
        # SM100: bigger tiles, no pingpong scheduling.
        return GemmConfig(tile_m=256, tile_n=256, cluster_m=2, cluster_n=1, pingpong=False)
    # SM90 (and anything else): pingpong with smaller tiles.
    return GemmConfig(tile_m=128, tile_n=192, cluster_m=2, cluster_n=1, pingpong=True)
67
+
68
+
69
def prune_invalid_gemm_configs(configs, named_args: dict, **kwargs):
    """Early-prune autotune candidates that the requested GEMM variant cannot run."""
    merged = {**named_args, **kwargs}
    uses_gather_A = merged.get("A_idx") is not None
    uses_varlen_m = merged.get("cu_seqlens_m") is not None
    if uses_varlen_m or uses_gather_A:
        # Neither variant supports the swap_ab trick.
        configs = [c for c in configs if not c.kwargs["config"].swap_ab]
    if uses_gather_A:
        if get_device_capacity(merged["A"].device)[0] == 9:
            # tile_n == 208 causes register spills, as gather_A requires more
            # registers for the producer; gather_A also needs cluster_n == 1.
            configs = [
                c
                for c in configs
                if c.kwargs["config"].cluster_n == 1 and c.kwargs["config"].tile_n != 208
            ]
    return configs
84
+
85
+
86
@autotune(
    configs=[AutotuneConfig(config=c) for c in get_all_configs(default_device_capacity[0])],
    key=["dynamic_scheduler"],
    prune_configs_by={"early_config_prune": prune_invalid_gemm_configs},
)
def gemm_tuned(
    # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
    A: Tensor,
    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
    out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    C: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    bias: Optional[Tensor] = None,  # (N,) or (L, N)
    alpha: float | Tensor = 1.0,  # (1,)
    beta: float | Tensor = 1.0,  # (1,)
    cu_seqlens_m: Optional[Tensor] = None,  # (L+1), int32
    cu_seqlens_k: Optional[Tensor] = None,  # (L+1), int32
    A_idx: Optional[Tensor] = None,  # (total_M,) or (total_K,) indices for gather_A when varlen
    batch_idx_permute: Optional[Tensor] = None,  # (L,) permutation of batch indices for scheduler
    add_to_output: bool = False,
    dynamic_scheduler: bool = False,
    config: Optional[GemmConfig] = None,
) -> None:
    """Autotuned entry: normalize shapes and dispatch to the CuTe GEMM kernel.

    Writes the result into the pre-allocated ``out``. When ``config.swap_ab``
    is set, operands and outputs are swapped/transposed so the kernel computes
    the same product with A and B exchanged.
    """
    if config is None:
        config = default_config(A.device)
    varlen_m = cu_seqlens_m is not None
    varlen_k = cu_seqlens_k is not None
    varlen = varlen_m or varlen_k
    gather_A = A_idx is not None
    if gather_A:
        assert varlen, "gather_A requires either varlen_m or varlen_k"
        assert config.cluster_n == 1, "gather_A requires cluster_n=1"
    if varlen_m:
        assert not config.swap_ab, "Variable-length sequences not supported with swap_ab"
    # Promote 2-D inputs to a batch of one (varlen layouts stay packed 2-D).
    if A.ndim == 2 and not varlen:
        A = A.unsqueeze(0)  # (1, M, K)
    B = B.mT  # (N, K) or (L, N, K) or (N, total_K)
    if B.ndim == 2 and not varlen_k:
        B = B.unsqueeze(0)  # (1, N, K)
    if C is not None and C.ndim == 2 and not varlen_m:
        C = C.unsqueeze(0)  # (1, M, N)
    if out.ndim == 2 and not varlen_m:
        out = out.unsqueeze(0)
    if bias is not None and bias.ndim == 1:
        bias = bias.unsqueeze(0)  # (L, N)
    batch_size = B.shape[0] if not varlen_k else cu_seqlens_k.shape[0] - 1
    if varlen_m:
        # If gather_A (A_idx provided), use its length; otherwise use A.shape[0]
        total_m = A_idx.shape[0] if A_idx is not None else A.shape[0]
        out_shape = (total_m, B.shape[-2])
    else:
        out_shape = (batch_size, A.shape[-2], B.shape[-2])
    assert out.shape == out_shape, f"out shape mismatch: {out.shape} vs {out_shape}"
    # Dynamic scheduling needs a device counter for tile handout.
    tile_count_semaphore = (
        torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None
    )
    gemm_sm90_sm100(
        A if not config.swap_ab else B,
        B if not config.swap_ab else A,
        out if not config.swap_ab else out.mT,
        (C if not config.swap_ab else C.mT) if C is not None else None,
        tile_count_semaphore,
        config.tile_m,
        config.tile_n,
        config.cluster_m,
        config.cluster_n,
        config.pingpong,
        persistent=True,
        max_swizzle_size=config.max_swizzle_size,
        # After swap_ab the bias broadcasts along the other axis.
        rowvec_bias=bias if not config.swap_ab else None,
        colvec_bias=bias if config.swap_ab else None,
        alpha=alpha,
        beta=beta,
        cu_seqlens_m=cu_seqlens_m,
        cu_seqlens_k=cu_seqlens_k,
        A_idx=A_idx,
        batch_idx_permute=batch_idx_permute,
        add_to_output=add_to_output,
    )
164
+
165
+
166
@autotune(
    configs=[AutotuneConfig(config=c) for c in get_all_configs(default_device_capacity[0])],
    key=["activation", "dynamic_scheduler"],
    prune_configs_by={"early_config_prune": prune_invalid_gemm_configs},
)
def gemm_act_tuned(
    # (M, K) or or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
    A: Tensor,
    B: Tensor,  # (K, N) or (L, K, N)
    # (M, N) or (L, M, N) or (total_M, N) if varlen_m - None if not storing preact
    preact_out: Optional[Tensor],
    postact_out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    C: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    bias: Optional[Tensor] = None,  # (N,) or (L, N)
    activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
    cu_seqlens_m: Optional[Tensor] = None,  # (L+1), int32
    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
    dynamic_scheduler: bool = False,
    config: Optional[GemmConfig] = None,
) -> None:
    """Autotuned GEMM + fused activation epilogue.

    Writes the activated result into ``postact_out``; additionally stores the
    pre-activation matmul result into ``preact_out`` when that is not None.
    """
    if config is None:
        config = default_config(A.device)
    varlen_m = cu_seqlens_m is not None
    if varlen_m:
        assert not config.swap_ab, "Variable-length sequences not supported with swap_ab"
    # Promote 2-D inputs to a batch of one (varlen layouts stay packed 2-D).
    if A.ndim == 2 and not varlen_m:
        A = A.unsqueeze(0)  # (1, M, K)
    B = B.mT  # (N, K) or (L, N, K)
    if B.ndim == 2:
        B = B.unsqueeze(0)  # (1, N, K)
    if C is not None and C.ndim == 2 and not varlen_m:
        C = C.unsqueeze(0)  # (1, M, N)
    if preact_out is not None and preact_out.ndim == 2 and not varlen_m:
        D = preact_out.unsqueeze(0)
    else:
        D = preact_out
    if postact_out.ndim == 2 and not varlen_m:
        PostAct = postact_out.unsqueeze(0)
    else:
        PostAct = postact_out
    if bias is not None and bias.ndim == 1:
        bias = bias.unsqueeze(0)  # (L, N)
    # Dynamic scheduling needs a device counter for tile handout.
    tile_count_semaphore = (
        torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None
    )
    gemm_act_sm90_sm100(
        A if not config.swap_ab else B,
        B if not config.swap_ab else A,
        (D if not config.swap_ab else D.mT) if D is not None else None,
        (C if not config.swap_ab else C.mT) if C is not None else None,
        PostAct if not config.swap_ab else PostAct.mT,
        tile_count_semaphore,
        activation,
        config.tile_m,
        config.tile_n,
        config.cluster_m,
        config.cluster_n,
        config.pingpong,
        persistent=True,
        max_swizzle_size=config.max_swizzle_size,
        # After swap_ab the bias broadcasts along the other axis.
        rowvec_bias=bias if not config.swap_ab else None,
        colvec_bias=bias if config.swap_ab else None,
        cu_seqlens_m=cu_seqlens_m,
        A_idx=A_idx,
    )
231
+
232
+
233
@autotune(
    configs=[AutotuneConfig(config=c) for c in get_all_configs(default_device_capacity[0])],
    key=["activation", "dynamic_scheduler"],
    prune_configs_by={"early_config_prune": prune_invalid_gemm_configs},
)
def gemm_dact_tuned(
    # (M, K) or or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
    A: Tensor,
    B: Tensor,  # (K, N) or (L, K, N)
    PreAct: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    dx_out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    postact_out: Tensor,  # (M, N) or (L, N, N) or (total_M, N) if varlen_m
    activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
    cu_seqlens_m: Optional[Tensor] = None,  # (L+1), int32
    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
    dynamic_scheduler: bool = True,
    config: Optional[GemmConfig] = None,
) -> None:
    """Autotuned GEMM fused with an activation-backward epilogue.

    Given the saved ``PreAct`` values, writes the activation gradient into
    ``dx_out`` and the recomputed post-activation into ``postact_out``.
    """
    if config is None:
        config = default_config(A.device)
    varlen_m = cu_seqlens_m is not None
    if varlen_m:
        assert not config.swap_ab, "Variable-length sequences not supported with swap_ab"
    # Promote 2-D inputs to a batch of one (varlen layouts stay packed 2-D).
    if A.ndim == 2 and not varlen_m:
        A = A.unsqueeze(0)  # (1, M, K)
    B = B.mT  # (N, K) or (L, N, K)
    if B.ndim == 2:
        B = B.unsqueeze(0)  # (1, N, K)
    if PreAct.ndim == 2 and not varlen_m:
        PreAct = PreAct.unsqueeze(0)  # (1, M, N)
    if dx_out.ndim == 2 and not varlen_m:
        D = dx_out.unsqueeze(0)
    else:
        D = dx_out
    if postact_out.ndim == 2 and not varlen_m:
        PostAct = postact_out.unsqueeze(0)
    else:
        PostAct = postact_out
    # Dynamic scheduling needs a device counter for tile handout.
    tile_count_semaphore = (
        torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None
    )
    gemm_dact_sm90_sm100(
        A if not config.swap_ab else B,
        B if not config.swap_ab else A,
        D if not config.swap_ab else D.mT,
        PreAct if not config.swap_ab else PreAct.mT,
        PostAct if not config.swap_ab else PostAct.mT,
        tile_count_semaphore,
        activation,
        config.tile_m,
        config.tile_n,
        config.cluster_m,
        config.cluster_n,
        config.pingpong,
        persistent=True,
        max_swizzle_size=config.max_swizzle_size,
        cu_seqlens_m=cu_seqlens_m,
        A_idx=A_idx,
    )
292
+
293
+
294
def gemm(
    # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
    A: Tensor,
    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
    out: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    bias: Optional[Tensor] = None,  # (N,) or (L, N)
    alpha: float | Tensor = 1.0,
    out_dtype: Optional[torch.dtype] = None,
    cu_seqlens_m: Optional[Tensor] = None,
    cu_seqlens_k: Optional[Tensor] = None,
    A_idx: Optional[Tensor] = None,  # (total_M,) or (total_K,) indices for gather_A when varlen
    batch_idx_permute: Optional[Tensor] = None,  # (L,) permutation of batch indices for scheduler
    dynamic_scheduler: bool = False,
    tuned: bool = True,
) -> Tensor:
    """GEMM with optional output tensor and tuning control.

    Allocates ``out`` when not provided (shape depends on the varlen mode)
    and dispatches to the ``gemm_out`` custom op. Returns ``out``.
    """
    if out is None:
        out_dtype = A.dtype if out_dtype is None else out_dtype
        varlen_m = cu_seqlens_m is not None
        varlen_k = cu_seqlens_k is not None
        if varlen_m:
            # With gather_A the output row count comes from A_idx, not A.
            total_m = A_idx.shape[0] if A_idx is not None else A.shape[0]
            out_shape = (total_m, B.shape[-1])
        elif varlen_k:
            L = cu_seqlens_k.shape[0] - 1
            # For varlen_k, the first dimension is always A.shape[0] (M dimension)
            out_shape = (L, A.shape[0], B.shape[-1])
        else:
            out_shape = (
                (A.shape[0], B.shape[-1]) if A.ndim == 2 else (A.shape[0], A.shape[-2], B.shape[-1])
            )
        out = torch.empty(out_shape, dtype=out_dtype, device=A.device)
    # torch.library needs one fixed type per argument, so a tensor-valued
    # alpha travels separately while the float slot falls back to 1.0.
    alpha_tensor = alpha if not isinstance(alpha, float) else None
    alpha = alpha if isinstance(alpha, float) else 1.0
    gemm_out(
        A,
        B,
        out,
        bias=bias,
        alpha=alpha,
        alpha_tensor=alpha_tensor,
        cu_seqlens_m=cu_seqlens_m,
        cu_seqlens_k=cu_seqlens_k,
        A_idx=A_idx,
        batch_idx_permute=batch_idx_permute,
        dynamic_scheduler=dynamic_scheduler,
        tuned=tuned,
    )
    return out
343
+
344
+
345
@torch.library.custom_op(
    add_quack_op_namespace_prefix("gemm_out"),
    mutates_args=("out",),
    device_types="cuda",
    # We have to split out alpha and alpha_tensor since torch.library requires
    # each argument to have a fixed type
    # schema="(Tensor A, Tensor B, Tensor(a2!) out, Tensor? bias, float alpha=1.0, Tensor? alpha_tensor=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
)
def gemm_out(
    # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
    A: Tensor,
    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
    out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    bias: Optional[Tensor] = None,  # (N,) or (L, N)
    alpha: float = 1.0,
    alpha_tensor: Optional[Tensor] = None,
    cu_seqlens_m: Optional[Tensor] = None,
    cu_seqlens_k: Optional[Tensor] = None,
    A_idx: Optional[Tensor] = None,  # (total_M,) or (total_K,) indices for gather_A when varlen
    batch_idx_permute: Optional[Tensor] = None,  # (L,) permutation of batch indices for scheduler
    dynamic_scheduler: bool = False,
    tuned: bool = True,
) -> None:
    """GEMM with pre-allocated output tensor (registered as a custom op)."""
    # gemm_tuned.fn is the undecorated function: bypasses the autotuner cache.
    fn = gemm_tuned if tuned else partial(gemm_tuned.fn, config=None)
    # A tensor alpha (if provided) takes precedence over the float slot.
    alpha = alpha_tensor if alpha_tensor is not None else alpha
    fn(
        A,
        B,
        out,
        C=None,
        bias=bias,
        alpha=alpha,
        cu_seqlens_m=cu_seqlens_m,
        cu_seqlens_k=cu_seqlens_k,
        A_idx=A_idx,
        batch_idx_permute=batch_idx_permute,
        dynamic_scheduler=dynamic_scheduler,
    )
384
+
385
+
386
+ def gemm_ref(
387
+ # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
388
+ A: Tensor,
389
+ B: Tensor, # (K, N) or (L, K, N) or (total_K, N) if varlen_k
390
+ out: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
391
+ bias: Optional[Tensor] = None, # (N,) or (L, N)
392
+ alpha: float | Tensor = 1.0,
393
+ cu_seqlens_m: Optional[Tensor] = None,
394
+ cu_seqlens_k: Optional[Tensor] = None,
395
+ A_idx: Optional[Tensor] = None, # (total_M,) or (total_K,) indices for gather_A when varlen
396
+ out_dtype: Optional[torch.dtype] = None,
397
+ ) -> Tensor:
398
+ """Reference implementation for GEMM with pre-allocated output."""
399
+ # The out_dtype argument requires torch >= 2.8
400
+ out_dtype = A.dtype if out_dtype is None else out_dtype
401
+ if cu_seqlens_m is None and cu_seqlens_k is None:
402
+ fn = torch.bmm if A.ndim == 3 else torch.mm
403
+ out = fn(A, B, out_dtype=out_dtype, out=out)
404
+ if not isinstance(alpha, float) or alpha != 1.0:
405
+ out *= alpha
406
+ if bias is not None:
407
+ bias = bias if A.ndim == 2 else bias.unsqueeze(1)
408
+ out += bias
409
+ elif cu_seqlens_m is not None:
410
+ # Handle varlen_m case
411
+ if out is None:
412
+ # When gather_A (A_idx provided), output size is determined by A_idx length
413
+ total_m = A_idx.shape[0] if A_idx is not None else A.shape[0]
414
+ out = torch.empty((total_m, B.shape[-1]), dtype=out_dtype, device=A.device)
415
+ for i in range(cu_seqlens_m.shape[0] - 1):
416
+ A_slice = (
417
+ A[A_idx[cu_seqlens_m[i] : cu_seqlens_m[i + 1]]]
418
+ if A_idx is not None
419
+ else A[cu_seqlens_m[i] : cu_seqlens_m[i + 1]]
420
+ )
421
+ torch.mm(A_slice, B[i], out=out[cu_seqlens_m[i] : cu_seqlens_m[i + 1]])
422
+ if not isinstance(alpha, float) or alpha != 1.0:
423
+ out[cu_seqlens_m[i] : cu_seqlens_m[i + 1]] *= alpha
424
+ if bias is not None:
425
+ out[cu_seqlens_m[i] : cu_seqlens_m[i + 1]] += bias[i]
426
+ else: # cu_seqlens_k is not None
427
+ L = cu_seqlens_k.shape[0] - 1
428
+ if out is None:
429
+ out = torch.empty((L, A.shape[0], B.shape[1]), dtype=out_dtype, device=A.device)
430
+ for i in range(L):
431
+ A_slice = (
432
+ A[:, A_idx[cu_seqlens_k[i] : cu_seqlens_k[i + 1]]]
433
+ if A_idx is not None
434
+ else A[:, cu_seqlens_k[i] : cu_seqlens_k[i + 1]]
435
+ )
436
+ torch.mm(A_slice, B[cu_seqlens_k[i] : cu_seqlens_k[i + 1], :], out=out[i])
437
+ if not isinstance(alpha, float) or alpha != 1.0:
438
+ out *= alpha
439
+ if bias is not None:
440
+ out += bias
441
+ return out
442
+
443
+
444
def gemm_add(
    # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
    A: Tensor,
    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
    C: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m or (L, M, N) if varlen_k
    out: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    alpha: float | Tensor = 1.0,
    beta: float | Tensor = 1.0,
    out_dtype: Optional[torch.dtype] = None,
    cu_seqlens_m: Optional[Tensor] = None,
    cu_seqlens_k: Optional[Tensor] = None,
    A_idx: Optional[Tensor] = None,  # (total_M,) or (total_K,) indices for gather_A when varlen
    batch_idx_permute: Optional[Tensor] = None,  # (L,) permutation of batch indices for scheduler
    dynamic_scheduler: bool = False,
    tuned: bool = True,
) -> Tensor:
    """GEMM with addition and optional output tensor.

    Computes ``out = alpha * (A @ B) + beta * C`` via the ``gemm_add_out``
    custom op, allocating ``out`` when not provided. Returns ``out``.
    """
    if out is None:
        out_dtype = A.dtype if out_dtype is None else out_dtype
        varlen_m = cu_seqlens_m is not None
        varlen_k = cu_seqlens_k is not None
        if varlen_m:
            # If A_idx is provided (gather_A), use its length; otherwise use A.shape[0]
            total_m = A_idx.shape[0] if A_idx is not None else A.shape[0]
            out_shape = (total_m, B.shape[-1])
        elif varlen_k:
            L = cu_seqlens_k.shape[0] - 1
            # For varlen_k, the first dimension is always A.shape[0] (M dimension)
            out_shape = (L, A.shape[0], B.shape[-1])
        else:
            out_shape = (
                (A.shape[0], B.shape[-1]) if A.ndim == 2 else (A.shape[0], A.shape[-2], B.shape[-1])
            )
        out = torch.empty(out_shape, dtype=out_dtype, device=A.device)
    # When C *is* out (identity) with beta == 1.0, the kernel can accumulate
    # into out directly instead of reading C as a separate operand.
    add_to_output = C is out and isinstance(beta, float) and beta == 1.0 and cu_seqlens_m is None
    # torch.library needs one fixed type per argument, hence the scalar/tensor split.
    alpha_tensor = alpha if not isinstance(alpha, float) else None
    alpha = alpha if isinstance(alpha, float) else 1.0
    beta_tensor = beta if not isinstance(beta, float) else None
    beta = beta if isinstance(beta, float) else 1.0
    gemm_add_out(
        A,
        B,
        C if not add_to_output else None,
        out,
        alpha,
        beta,
        alpha_tensor,
        beta_tensor,
        cu_seqlens_m=cu_seqlens_m,
        cu_seqlens_k=cu_seqlens_k,
        A_idx=A_idx,
        batch_idx_permute=batch_idx_permute,
        add_to_output=add_to_output,
        dynamic_scheduler=dynamic_scheduler,
        tuned=tuned,
    )
    return out
501
+
502
+
503
@torch.library.custom_op(
    add_quack_op_namespace_prefix("gemm_add_out"),
    mutates_args=("out",),
    device_types="cuda",
    # We have to split out alpha and alpha_tensor since torch.library requires
    # each argument to have a fixed type
    # schema="(Tensor A, Tensor B, Tensor C, Tensor(a3!) out, float alpha=1.0, float beta=1.0, Tensor? alpha_tensor=None, Tensor? beta_tensor=None, Tensor? cu_seqlens_m=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
)
def gemm_add_out(
    # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
    A: Tensor,
    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
    C: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m or (L, M, N) if varlen_k
    out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    alpha: float = 1.0,
    beta: float = 1.0,
    alpha_tensor: Optional[Tensor] = None,
    beta_tensor: Optional[Tensor] = None,
    cu_seqlens_m: Optional[Tensor] = None,
    cu_seqlens_k: Optional[Tensor] = None,
    A_idx: Optional[Tensor] = None,  # (total_M,) or (total_K,) indices for gather_A when varlen
    batch_idx_permute: Optional[Tensor] = None,  # (L,) permutation of batch indices for scheduler
    add_to_output: bool = False,
    dynamic_scheduler: bool = False,
    tuned: bool = True,
) -> None:
    """GEMM with addition and pre-allocated output tensor (custom op)."""
    # gemm_tuned.fn is the undecorated function: bypasses the autotuner cache.
    fn = gemm_tuned if tuned else partial(gemm_tuned.fn, config=None)
    # Tensor-valued alpha/beta (if provided) take precedence over the floats.
    alpha = alpha_tensor if alpha_tensor is not None else alpha
    beta = beta_tensor if beta_tensor is not None else beta
    fn(
        A,
        B,
        out,
        C,
        alpha=alpha,
        beta=beta,
        cu_seqlens_m=cu_seqlens_m,
        cu_seqlens_k=cu_seqlens_k,
        A_idx=A_idx,
        batch_idx_permute=batch_idx_permute,
        add_to_output=add_to_output,
        dynamic_scheduler=dynamic_scheduler,
    )
547
+
548
+
549
+ def gemm_add_ref(
550
+ # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
551
+ A: Tensor,
552
+ B: Tensor, # (K, N) or (L, K, N) or (total_K, N) if varlen_k
553
+ C: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
554
+ bias: Optional[Tensor] = None, # (N,) or (L, N)
555
+ out: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
556
+ alpha: float | Tensor = 1.0,
557
+ beta: float | Tensor = 1.0,
558
+ cu_seqlens_m: Optional[Tensor] = None,
559
+ cu_seqlens_k: Optional[Tensor] = None,
560
+ A_idx: Optional[Tensor] = None, # (total_M,) or (total_K,) indices for gather_A when varlen
561
+ out_dtype: Optional[torch.dtype] = None,
562
+ ) -> Tensor:
563
+ """Reference implementation for GEMM with addition and pre-allocated output."""
564
+ if cu_seqlens_m is None and cu_seqlens_k is None:
565
+ if isinstance(alpha, float) and isinstance(beta, float):
566
+ out = torch.addmm(C, A, B, out_dtype=out_dtype, alpha=alpha, beta=beta, out=out)
567
+ else:
568
+ out_dtype = (
569
+ out.dtype if out is not None else (out_dtype if out_dtype is not None else A.dtype)
570
+ )
571
+ result = (alpha * (A @ B) + beta * C).to(out_dtype)
572
+ if out is not None:
573
+ out.copy_(result)
574
+ if bias is not None:
575
+ bias = bias if A.ndim == 2 else bias.unsqueeze(1)
576
+ out += bias
577
+ elif cu_seqlens_m is not None:
578
+ # Handle varlen_m case
579
+ if out is None:
580
+ # When gather_A (A_idx provided), output size is determined by A_idx length
581
+ total_m = A_idx.shape[0] if A_idx is not None else A.shape[0]
582
+ out_dtype = out_dtype if out_dtype is not None else A.dtype
583
+ out = torch.empty((total_m, B.shape[-1]), dtype=out_dtype, device=A.device)
584
+ for i in range(cu_seqlens_m.shape[0] - 1):
585
+ A_slice = (
586
+ A[A_idx[cu_seqlens_m[i] : cu_seqlens_m[i + 1]]]
587
+ if A_idx is not None
588
+ else A[cu_seqlens_m[i] : cu_seqlens_m[i + 1]]
589
+ )
590
+ C_slice = C[cu_seqlens_m[i] : cu_seqlens_m[i + 1]]
591
+ out_slice = out[cu_seqlens_m[i] : cu_seqlens_m[i + 1]]
592
+ result = alpha * torch.mm(A_slice, B[i]) + beta * C_slice
593
+ if bias is not None:
594
+ result += bias[i]
595
+ out_slice.copy_(result)
596
+ else: # cu_seqlens_k is not None
597
+ # Handle varlen_k case
598
+ L = cu_seqlens_k.shape[0] - 1
599
+ out_dtype = out_dtype if out_dtype is not None else A.dtype
600
+ if out is None:
601
+ out = torch.empty((L, A.shape[0], B.shape[1]), dtype=out_dtype, device=A.device)
602
+ for i in range(L):
603
+ A_slice = (
604
+ A[:, A_idx[cu_seqlens_k[i] : cu_seqlens_k[i + 1]]]
605
+ if A_idx is not None
606
+ else A[:, cu_seqlens_k[i] : cu_seqlens_k[i + 1]]
607
+ )
608
+ B_slice = B[cu_seqlens_k[i] : cu_seqlens_k[i + 1], :]
609
+ result = alpha * torch.mm(A_slice, B_slice) + beta * C[i]
610
+ out[i].copy_(result)
611
+ if bias is not None:
612
+ out += bias
613
+ return out
614
+
615
+
616
def gemm_add_inplace(
    # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
    A: Tensor,
    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
    out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m or (L, M, N) if varlen_k
    alpha: float | Tensor = 1.0,
    beta: float | Tensor = 1.0,
    cu_seqlens_m: Optional[Tensor] = None,
    cu_seqlens_k: Optional[Tensor] = None,
    A_idx: Optional[Tensor] = None,  # (total_M,) or (total_K,) indices for gather_A when varlen
    batch_idx_permute: Optional[Tensor] = None,  # (L,) permutation of batch indices for scheduler
    dynamic_scheduler: bool = False,
    tuned: bool = True,
) -> None:
    """In-place GEMM with addition: out = alpha * A @ B + beta * out.

    Args:
        A: (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k - input tensor
        B: (K, N) or (L, K, N) or (total_K, N) if varlen_k - input tensor
        out: (M, N) or (L, M, N) or (total_M, N) if varlen_m or (L, M, N) if varlen_k - tensor accumulated into (modified in-place)
        alpha: Scalar (or tensor) multiplier for A @ B
        beta: Scalar (or tensor) multiplier for out
        cu_seqlens_m: Optional cumulative sequence lengths for variable M
        cu_seqlens_k: Optional cumulative sequence lengths for variable K
        dynamic_scheduler: Whether to use the dynamic scheduler
        tuned: Whether to use the autotuned configuration
    """
    # Split each scale factor into its scalar and tensor form, since the
    # registered custom op needs a fixed type per argument.
    if isinstance(alpha, float):
        alpha_scalar, alpha_t = alpha, None
    else:
        alpha_scalar, alpha_t = 1.0, alpha
    if isinstance(beta, float):
        beta_scalar, beta_t = beta, None
    else:
        beta_scalar, beta_t = 1.0, beta
    gemm_add_inplace_op(
        A,
        B,
        out,
        alpha_scalar,
        beta_scalar,
        alpha_t,
        beta_t,
        cu_seqlens_m,
        cu_seqlens_k,
        A_idx=A_idx,
        batch_idx_permute=batch_idx_permute,
        dynamic_scheduler=dynamic_scheduler,
        tuned=tuned,
    )
661
+
662
+
663
@torch.library.custom_op(
    add_quack_op_namespace_prefix("gemm_add_inplace"),
    mutates_args=("out",),
    device_types="cuda",
    # torch.library requires every argument to have a single fixed type, so the
    # scalar (float) and tensor forms of alpha/beta travel as separate arguments.
    # schema="(Tensor A, Tensor B, Tensor(a2!) out, float alpha=1.0, float beta=1.0, Tensor? alpha_tensor=None, Tensor? beta_tensor=None, Tensor? cu_seqlens_m=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
)
def gemm_add_inplace_op(
    # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
    A: Tensor,
    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
    out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m or (L, M, N) if varlen_k
    alpha: float = 1.0,
    beta: float = 1.0,
    alpha_tensor: Optional[Tensor] = None,
    beta_tensor: Optional[Tensor] = None,
    cu_seqlens_m: Optional[Tensor] = None,
    cu_seqlens_k: Optional[Tensor] = None,
    A_idx: Optional[Tensor] = None,  # (total_M,) or (total_K,) indices for gather_A when varlen
    batch_idx_permute: Optional[Tensor] = None,  # (L,) permutation of batch indices for scheduler
    dynamic_scheduler: bool = False,
    tuned: bool = True,
) -> None:
    """Registered op backing gemm_add_inplace: out = alpha * A @ B + beta * out."""
    runner = gemm_tuned if tuned else partial(gemm_tuned.fn, config=None)
    if alpha_tensor is not None:
        alpha = alpha_tensor
    if beta_tensor is not None:
        beta = beta_tensor
    # With scalar beta == 1.0 (and no varlen_m) the kernel can accumulate
    # directly into `out`; otherwise `out` is also passed as the C operand.
    accumulate = isinstance(beta, float) and beta == 1.0 and cu_seqlens_m is None
    runner(
        A,
        B,
        out,
        None if accumulate else out,
        alpha=alpha,
        beta=beta,
        cu_seqlens_m=cu_seqlens_m,
        cu_seqlens_k=cu_seqlens_k,
        A_idx=A_idx,
        batch_idx_permute=batch_idx_permute,
        add_to_output=accumulate,
        dynamic_scheduler=dynamic_scheduler,
    )
706
+
707
+
708
def gemm_act(
    A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
    B: Tensor,  # (K, N) or (L, K, N)
    C: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    bias: Optional[Tensor] = None,  # (N,) or (L, N)
    activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
    preact_out: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    postact_out: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    out_dtype: Optional[torch.dtype] = None,
    postact_dtype: Optional[torch.dtype] = None,
    cu_seqlens_m: Optional[Tensor] = None,
    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
    store_preact: bool = True,
    dynamic_scheduler: bool = False,
    tuned: bool = True,
) -> Tuple[Optional[Tensor], Tensor]:
    """GEMM with activation, allocating output tensors when not supplied."""
    if out_dtype is None:
        out_dtype = A.dtype
    if postact_dtype is None:
        postact_dtype = A.dtype
    # Output row count: with gather_A (A_idx given) rows follow A_idx,
    # otherwise they follow A itself.
    if cu_seqlens_m is not None:
        rows = A_idx.shape[0] if A_idx is not None else A.shape[0]
        shape = (rows, B.shape[-1])
    elif A.ndim == 2:
        shape = (A.shape[0], B.shape[-1])
    else:
        shape = (A.shape[0], A.shape[-2], B.shape[-1])
    if store_preact and preact_out is None:
        preact_out = torch.empty(shape, dtype=out_dtype, device=A.device)
    if postact_out is None:
        postact_out = torch.empty(shape, dtype=postact_dtype, device=A.device)
    gemm_act_out(
        A,
        B,
        preact_out,
        postact_out,
        C,
        bias,
        activation,
        cu_seqlens_m,
        A_idx,
        dynamic_scheduler,
        tuned,
    )
    return preact_out, postact_out
754
+
755
+
756
@torch.library.custom_op(
    add_quack_op_namespace_prefix("gemm_act_out"),
    mutates_args=("preact_out", "postact_out"),
    device_types="cuda",
    schema="(Tensor A, Tensor B, Tensor(a2!)? preact_out, Tensor(a3!) postact_out, Tensor? C=None, Tensor? bias=None, str? activation=None, Tensor? cu_seqlens_m=None, Tensor? A_idx=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
)
def gemm_act_out(
    A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
    B: Tensor,  # (K, N) or (L, K, N)
    preact_out: Optional[Tensor],  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    postact_out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    C: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    bias: Optional[Tensor] = None,  # (N,) or (L, N)
    activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
    cu_seqlens_m: Optional[Tensor] = None,
    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
    dynamic_scheduler: bool = False,
    tuned: bool = True,
) -> None:
    """GEMM with activation writing into pre-allocated output tensors."""
    # Dispatch to the autotuned entry point, or its raw kernel with no config.
    runner = gemm_act_tuned if tuned else partial(gemm_act_tuned.fn, config=None)
    runner(A, B, preact_out, postact_out, C, bias, activation, cu_seqlens_m, A_idx, dynamic_scheduler)
778
+
779
+
780
def gemm_act_ref(
    A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A
    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
    C: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    bias: Optional[Tensor] = None,  # (N,) or (L, N)
    activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
    cu_seqlens_m: Optional[Tensor] = None,
    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
    out_dtype: Optional[torch.dtype] = None,
    postact_dtype: Optional[torch.dtype] = None,
    store_preact: bool = True,
) -> Tuple[Optional[Tensor], Tensor]:
    """Reference implementation of GEMM followed by an activation.

    Returns (preact, postact); preact is None when store_preact is False.
    """
    if out_dtype is None:
        out_dtype = A.dtype
    if postact_dtype is None:
        postact_dtype = A.dtype
    # With a C operand this is a fused add; otherwise a plain GEMM.
    if C is not None:
        preact = gemm_add_ref(A, B, C, bias=bias, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx)
    else:
        preact = gemm_ref(A, B, bias=bias, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx)
    postact = act_to_pytorch_fn_map[activation](preact).to(postact_dtype)
    return (preact.to(out_dtype) if store_preact else None), postact
800
+
801
+
802
def gemm_dact(
    A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
    B: Tensor,  # (K, N) or (L, K, N)
    PreAct: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
    dx_out: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    postact_out: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    out_dtype: Optional[torch.dtype] = None,
    postact_dtype: Optional[torch.dtype] = None,
    cu_seqlens_m: Optional[Tensor] = None,
    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
    dynamic_scheduler: bool = True,
    tuned: bool = True,
) -> Tuple[Tensor, Tensor]:
    """GEMM with activation gradient, allocating outputs when not supplied."""
    if out_dtype is None:
        out_dtype = A.dtype
    if postact_dtype is None:
        postact_dtype = PreAct.dtype
    # Output row count: with gather_A (A_idx given) rows follow A_idx,
    # otherwise they follow A itself.
    if cu_seqlens_m is not None:
        rows = A_idx.shape[0] if A_idx is not None else A.shape[0]
        shape = (rows, B.shape[-1])
    elif A.ndim == 2:
        shape = (A.shape[0], B.shape[-1])
    else:
        shape = (A.shape[0], A.shape[-2], B.shape[-1])
    if dx_out is None:
        dx_out = torch.empty(shape, dtype=out_dtype, device=A.device)
    if postact_out is None:
        postact_out = torch.empty(shape, dtype=postact_dtype, device=A.device)
    gemm_dact_out(
        A, B, PreAct, dx_out, postact_out, activation, cu_seqlens_m, A_idx, dynamic_scheduler, tuned
    )
    return dx_out, postact_out
836
+
837
+
838
@torch.library.custom_op(
    add_quack_op_namespace_prefix("gemm_dact_out"),
    mutates_args=("dx_out", "postact_out"),
    device_types="cuda",
    schema="(Tensor A, Tensor B, Tensor PreAct, Tensor(a3!) dx_out, Tensor(a4!) postact_out, str? activation=None, Tensor? cu_seqlens_m=None, Tensor? A_idx=None, bool dynamic_scheduler=True, bool tuned=True) -> ()",
)
def gemm_dact_out(
    A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
    B: Tensor,  # (K, N) or (L, K, N)
    PreAct: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    dx_out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    postact_out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
    cu_seqlens_m: Optional[Tensor] = None,
    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
    dynamic_scheduler: bool = True,
    tuned: bool = True,
) -> None:
    """GEMM with activation gradient writing into pre-allocated outputs."""
    # Dispatch to the autotuned entry point, or its raw kernel with no config.
    runner = gemm_dact_tuned if tuned else partial(gemm_dact_tuned.fn, config=None)
    runner(A, B, PreAct, dx_out, postact_out, activation, cu_seqlens_m, A_idx, dynamic_scheduler)
859
+
860
+
861
def gemm_dact_ref(
    A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A
    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
    PreAct: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
    cu_seqlens_m: Optional[Tensor] = None,
    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
    out_dtype: Optional[torch.dtype] = None,
    postact_dtype: Optional[torch.dtype] = None,
) -> Tuple[Tensor, Tensor]:
    """Reference implementation for GEMM with activation gradient.

    Returns (dx, postact) where dx backpropagates dout = A @ B through the
    activation at PreAct.
    """
    if out_dtype is None:
        out_dtype = A.dtype
    if postact_dtype is None:
        postact_dtype = PreAct.dtype
    dout = gemm_ref(A, B, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx).to(out_dtype)
    act_fn = act_to_pytorch_fn_map[activation]
    postact = act_fn(PreAct)
    if activation is None:
        # Identity activation: the gradient passes straight through.
        dx = dout
    else:
        # Differentiate through the activation via autograd, restoring
        # PreAct's original requires_grad flag afterwards.
        saved_flag = PreAct.requires_grad
        PreAct.requires_grad_(True)
        dx = torch.autograd.grad(act_fn(PreAct), PreAct, dout, create_graph=False)[0]
        PreAct.requires_grad_(saved_flag)
    return dx.to(out_dtype), postact.to(postact_dtype)
886
+
887
+
888
def gemm_gated_ref(
    A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A
    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
    C: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    bias: Optional[Tensor] = None,  # (N,) or (L, N)
    activation: Literal["glu", "swiglu", "swiglu_oai", "reglu", "geglu"] = "swiglu",
    cu_seqlens_m: Optional[Tensor] = None,
    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
    out_dtype: Optional[torch.dtype] = None,
    postact_dtype: Optional[torch.dtype] = None,
    store_preact: bool = True,
) -> Tuple[Optional[Tensor], Tensor]:
    """Reference implementation for GEMM with gated activation forward.

    Args:
        A: (M, K) - input tensor
        B: (K, N) - weight tensor with gate and up projections interleaved
        C: (M, N) - optional bias tensor
        activation: Type of gated activation
        out_dtype: Output dtype for preact
        postact_dtype: Output dtype for postact
        store_preact: Whether to return the pre-activation

    Returns:
        (preact, postact) where:
        - preact: (M, N) pre-activation (if store_preact=True, else None)
        - postact: (M, N // 2) post-activation output
    """
    if out_dtype is None:
        out_dtype = A.dtype
    if postact_dtype is None:
        postact_dtype = A.dtype
    preact = (
        gemm_ref(A, B, bias=bias, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx)
        if C is None
        else gemm_add_ref(A, B, C, bias=bias, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx)
    )
    # Gate and up projections are interleaved along the last dimension.
    gate, up = preact[..., ::2], preact[..., 1::2]
    postact = gated_to_pytorch_fn_map[activation](gate, up)
    return (preact.to(out_dtype) if store_preact else None), postact.to(postact_dtype)
927
+
928
+
929
def gemm_dgated_ref(
    A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A
    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
    PreAct: Tensor,  # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m
    activation: Literal["glu", "swiglu", "swiglu_oai", "reglu", "geglu"],
    cu_seqlens_m: Optional[Tensor] = None,
    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
    out_dtype: Optional[torch.dtype] = None,
    postact_dtype: Optional[torch.dtype] = None,
) -> Tuple[Tensor, Tensor]:
    """Reference implementation for GEMM with gated activation gradient.

    Args:
        A: (M, K) - dout input tensor
        B: (K, N) - weight tensor
        PreAct: (M, 2*N) - pre-activation tensor with gate and up projections interleaved
        activation: Type of gated activation
        out_dtype: Output dtype for dx
        postact_dtype: Output dtype for postact

    Returns:
        (dx, postact) where:
        - dx: (M, 2*N) gradient w.r.t. PreAct
        - postact: (M, N) post-activation output
    """
    if out_dtype is None:
        out_dtype = A.dtype
    if postact_dtype is None:
        postact_dtype = PreAct.dtype
    dout = gemm_ref(A, B, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx).to(out_dtype)
    # Gate and up projections are interleaved along the last dimension.
    gate, up = PreAct[..., ::2], PreAct[..., 1::2]
    # Differentiate through the gated activation via autograd, restoring the
    # original requires_grad flags afterwards.
    saved_gate_flag, saved_up_flag = gate.requires_grad, up.requires_grad
    gate.requires_grad_(True)
    up.requires_grad_(True)
    postact = gated_to_pytorch_fn_map[activation](gate, up)
    dgate, dup = torch.autograd.grad(postact, [gate, up], dout, create_graph=False)
    gate.requires_grad_(saved_gate_flag)
    up.requires_grad_(saved_up_flag)
    # Re-interleave the two gradients back into PreAct's layout.
    dx = torch.stack([dgate, dup], dim=-1).reshape(PreAct.shape)
    return dx.to(out_dtype), postact.to(postact_dtype)
971
+
972
+
973
@torch.library.custom_op(
    add_quack_op_namespace_prefix("gemm_symmetric_out"),
    mutates_args=("out",),
    device_types="cuda",
    schema="(Tensor A, Tensor B, Tensor(a2!) out, Tensor? C=None, bool dynamic_scheduler=False, float alpha=1.0, float beta=1.0) -> ()",
)
def gemm_symmetric_out(
    A: Tensor,  # (M, K) or (L, M, K)
    B: Tensor,  # (K, M) or (L, K, M)
    out: Tensor,  # (M, M) or (L, M, M)
    C: Optional[Tensor] = None,  # (M, M) or (L, M, M)
    dynamic_scheduler: bool = False,
    alpha: float = 1.0,
    beta: float = 1.0,
) -> None:
    """GEMM with guaranteed symmetric output, writing into `out`.

    Normalizes all operands to batched (L, ...) layouts and dispatches to the
    SM90/SM100 symmetric-GEMM kernel with a fixed tile configuration.
    """
    # Normalize every operand to a 3D (L, ...) layout expected by the kernel.
    if A.ndim == 2:
        A = A.unsqueeze(0)  # (1, M, K)
    B = B.mT  # view B as (M, K) or (L, M, K)
    if B.ndim == 2:
        B = B.unsqueeze(0)  # (1, M, K)
    if C is not None and C.ndim == 2:
        C = C.unsqueeze(0)  # (1, M, M)
    if out.ndim == 2:
        out = out.unsqueeze(0)
    tile_count_semaphore = (
        torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None
    )
    # Cleaned up: `out` is a required Tensor and `C` is already Optional, so the
    # former `x if x is not None else None` wrappers were no-ops.
    gemm_symmetric_sm90_sm100(
        A,
        B,
        out,
        C,
        tile_count_semaphore,
        tile_M=128,
        tile_N=256,
        cluster_M=2,
        cluster_N=1,
        pingpong=False,
        persistent=True,
        max_swizzle_size=8,
        alpha=alpha,
        beta=beta,
    )
1019
+
1020
+
1021
def gemm_symmetric(
    A: Tensor,  # (M, K) or (L, M, K)
    B: Tensor,  # (K, M) or (L, K, M)
    C: Optional[Tensor] = None,  # (M, M) or (L, M, M)
    out: Optional[Tensor] = None,  # (M, M) or (L, M, M)
    out_dtype: Optional[torch.dtype] = None,
    dynamic_scheduler: bool = False,
    alpha: float | Tensor = 1.0,
    beta: float | Tensor = 1.0,
) -> Tensor:
    """GEMM with symmetric output: out = alpha * A @ B + beta * C.

    Returns the (M, M) or (L, M, M) output tensor (allocated when `out` is
    not provided). Fixed return annotation: this function returns a single
    Tensor, not a tuple.

    NOTE: tensor-valued alpha/beta are accepted for interface compatibility
    but are not forwarded to the underlying op; they fall back to 1.0.
    """
    out_dtype = A.dtype if out_dtype is None else out_dtype
    # Output is (M, M) / (L, M, M); B.shape[-1] == M for a symmetric product.
    if A.ndim == 2:
        out_shape = (A.shape[0], B.shape[-1])
    else:
        out_shape = (A.shape[0], A.shape[-2], B.shape[-1])
    if out is None:
        out = torch.empty(out_shape, dtype=out_dtype, device=A.device)

    alpha_val = alpha if isinstance(alpha, float) else 1.0
    beta_val = beta if isinstance(beta, float) else 1.0

    gemm_symmetric_out(
        A, B, out, C, dynamic_scheduler=dynamic_scheduler, alpha=alpha_val, beta=beta_val
    )
    return out
1048
+
1049
+
1050
+ # TODO: this is not quite right, do we need to register gemm_add not gemm_add_out?
1051
+ # try:
1052
+ # from torch._inductor.fx_passes.reinplace import InplaceableOp
1053
+ # torch._inductor.fx_passes.reinplace.inplaceable_ops.update({
1054
+ # torch.ops.quack.gemm_add_out.default:
1055
+ # InplaceableOp(torch.ops.quack.gemm_add_inplace.default, mutated_arg=2)
1056
+ # })
1057
+ # except ImportError:
1058
+ # pass
build/torch-cuda/quack/gemm_sm100.py ADDED
The diff for this file is too large to render. See raw diff
 
build/torch-cuda/quack/gemm_sm90.py ADDED
@@ -0,0 +1,2070 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Based on the cute-dsl example:
2
+ # https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/hopper/dense_gemm.py
3
+
4
+ import enum
5
+ from typing import Tuple, Type, Callable, Optional, Union, Literal
6
+ from functools import partial
7
+ import math
8
+
9
+
10
+ import cuda.bindings.driver as cuda
11
+
12
+ import cutlass
13
+ import cutlass.cute as cute
14
+ import cutlass.pipeline as pipeline
15
+ from cutlass.cute.nvgpu import cpasync, warp, warpgroup
16
+ import cutlass.utils.hopper_helpers as sm90_utils
17
+ from cutlass import Int32, Float32, Float16, Boolean, const_expr
18
+ from cutlass.cutlass_dsl import if_generate
19
+ from cutlass.utils import LayoutEnum
20
+
21
+
22
+ from .cute_dsl_utils import ParamsBase, ArgumentsBase
23
+ from .tile_scheduler import (
24
+ TileSchedulerOptions,
25
+ TileSchedulerArguments,
26
+ TileScheduler,
27
+ VarlenMTileSchedulerArguments,
28
+ VarlenMTileScheduler,
29
+ )
30
+ from .varlen_utils import VarlenArguments, VarlenManager
31
+
32
+ # return PipelineStateWAdvance instead of PipelineState
33
+ from .pipeline import make_pipeline_state, PipelineTmaCpAsync
34
+ from . import copy_utils as copy_utils
35
+ from . import sm90_utils as quack_sm90_utils
36
+
37
+ """
38
+ A high-performance batched dense GEMM (C = A * B) example for the NVIDIA Hopper architecture
39
+ using CUTE DSL.
40
+ - Matrix A is MxKxL, L is batch dimension, A can be row-major("K") or column-major("M")
41
+ - Matrix B is NxKxL, L is batch dimension, B can be row-major("N") or column-major("K")
42
+ - Matrix C is MxNxL, L is batch dimension, C can be row-major("N") or column-major("M")
43
+
44
+ This GEMM kernel supports the following features:
45
+ - Utilizes Tensor Memory Access (TMA) for efficient memory operations
46
+ - Utilizes Hopper's WGMMA for matrix multiply-accumulate (MMA) operations
47
+ - Implements TMA multicast with cluster to reduce L2 memory traffic
48
+ - Supports multi-stage pipeline to overlap computation and memory access
49
+
50
+ This GEMM works as follows:
51
+ 1. Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using TMA operations.
52
+ 2. Perform matrix multiply-accumulate (MMA) operations using WGMMA instruction.
53
+ 3. Store results from registers (RMEM) to shared memory (SMEM), then to global memory (GMEM) with TMA operations.
54
+
55
+ Hopper WGMMA instructions operate as follows:
56
+ - Read matrix A from SMEM
57
+ - Read matrix B from SMEM
58
+ - Perform MMA operation and store the result in Accumulator(register)
59
+
60
+ Constraints:
61
+ * Supported input data types: fp16, fp8 (e4m3fn, e5m2)
62
+ * For fp16 types, A and B must have the same data type
63
+ * For fp8 types, A and B can have different types (e4m3fn or e5m2) but both must be 8-bit
64
+ * Fp8 types only support k-major layout
65
+ * Only fp32 accumulation is supported in this example
66
+ * CTA tile shape M must be 64/128
67
+ * CTA tile shape N must be 64/128/256
68
+ * CTA tile shape K must be 64
69
+ * Cluster shape M/N must be positive and power of 2, total cluster size <= 4
70
+ * The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned,
71
+       i.e., the number of elements must be a multiple of 8 for Float16 and 16 for Float8, respectively.
72
+ """
73
+
74
+
75
class NamedBarrierGemm(enum.IntEnum):
    """Named-barrier IDs for intra-CTA synchronization in the GEMM kernel.

    Member order matters: ``enum.auto()`` assigns consecutive values starting
    at 1, and barrier 0 is reserved for ``sync_threads()``. Do not reorder.
    """

    Epilogue = enum.auto()  # starts from 1 as barrier 0 is reserved for sync_threads()
    # For mainloop load warps to signal that the epilogue load warp can start.
    # This is to avoid loading C too early, interfering with loading A and B.
    EpilogueLoad = enum.auto()
    # Per-warp-group handoff barriers; presumably used by the pingpong scheme's
    # "mma"/"epi" stages (see pingpong_barrier_arrive/sync calls) — confirm.
    MmaWG0 = enum.auto()
    MmaWG1 = enum.auto()
    EpiWG0 = enum.auto()
    EpiWG1 = enum.auto()
    # NOTE(review): looks like a tensor-memory pointer broadcast barrier for a
    # newer-arch path; not referenced in the SM90 code visible here — confirm.
    TmemPtr = enum.auto()
+
86
+
87
+ class GemmSm90:
88
+ """
89
+ This class implements batched matrix multiplication (C = A x B) with support for various data types
90
+ and architectural features specific to Hopper GPUs with persistent tile scheduling and warp specialization.
91
+
92
+ :param acc_dtype: Data type for accumulation during computation
93
+ :type acc_dtype: type[cutlass.Numeric]
94
+ :param tile_shape_mn: Shape of the CTA tile (M,N)
95
+     :type tile_shape_mn: Tuple[int, int]
96
+ :param cluster_shape_mnk: Cluster dimensions (M,N,K) for parallel processing
97
+ :type cluster_shape_mnk: Tuple[int, int, int]
98
+
99
+ :note: Data type requirements:
100
+ - For 16-bit types: A and B must have the same data type
101
+ - For 8-bit types: A and B can have different types (Float8E4M3FN/Float8E5M2) as long as both are 8-bit
102
+ - Float8 types only support k-major layout
103
+
104
+ :note: Supported data types:
105
+ - Float16
106
+ - BFloat16
107
+ - Float8E4M3FN/Float8E5M2
108
+
109
+ :note: Supported accumulation types:
110
+ - Float32 (for all floating point inputs)
111
+
112
+ :note: Constraints:
113
+ - Cluster shape M/N must be positive and power of 2, total cluster size <= 4
114
+
115
+ Example:
116
+ >>> gemm = GemmSm90(
117
+ ... acc_dtype=Float32,
118
+ ... tile_shape_mn=(128, 256),
119
+ ... cluster_shape_mnk=(1, 1, 1)
120
+ ... )
121
+ >>> gemm(a_tensor, b_tensor, c_tensor, stream)
122
+ """
123
+
124
+ arch = 90
125
+ num_epi_tensormaps: int = 0
126
+
127
+ EpilogueArguments = ArgumentsBase
128
+ EpilogueParams = ParamsBase
129
+
130
+ def __init__(
131
+ self,
132
+ acc_dtype: Type[cutlass.Numeric],
133
+ a_dtype: Type[cutlass.Numeric],
134
+ tile_shape_mn: Tuple[int, int],
135
+ cluster_shape_mnk: Tuple[int, int, int],
136
+ pingpong: bool = False,
137
+ is_persistent: bool = True,
138
+ fp8_fast_accum: bool = False,
139
+ gather_A: bool = False,
140
+ ):
141
+ """
142
+ Initializes the configuration for a Hopper dense GEMM kernel.
143
+
144
+ This configuration includes data types for operands, tile shape, cluster configuration,
145
+ and thread layout.
146
+
147
+ :param acc_dtype: Data type for accumulation during computation
148
+ :type acc_dtype: type[cutlass.Numeric]
149
+ :param tile_shape_mn: Shape of the CTA tile (M,N)
150
+ :type tile_shape_mn: Tuple[int, int]
151
+ :param cluster_shape_mnk: Cluster dimensions (M,N,K) for parallel processing
152
+ :type cluster_shape_mnk: Tuple[int, int, int]
153
+ """
154
+
155
+ self.acc_dtype = acc_dtype
156
+ self.pingpong = pingpong
157
+ self.is_persistent = is_persistent
158
+ if self.pingpong:
159
+ assert self.is_persistent, "Pingpong gemm requires persistent scheduler"
160
+ self.fp8_slow_accum = not fp8_fast_accum and a_dtype.width == 8
161
+ self.gather_A = gather_A
162
+ if gather_A:
163
+ assert cluster_shape_mnk[1] == 1, "Cluster shape N must be 1 for gather A "
164
+
165
+ self.cluster_shape_mnk = cluster_shape_mnk
166
+ # K dimension is deferred in _setup_attributes
167
+ self.cta_tile_shape_mnk = (*tile_shape_mn, 1)
168
+ tile_M, tile_N = self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[1]
169
+ # check the cta tile shape
170
+ if not self.pingpong:
171
+ if tile_M not in [64, 128, 192, 256, 320]:
172
+ raise ValueError("CTA tile shape M must be 64/128/192/256/320")
173
+ if tile_M in [192, 320]: # special case
174
+ tile_N_max = 256 if tile_M == 192 else 160
175
+ if not (tile_N % 32 == 0 and tile_N <= tile_N_max):
176
+ raise ValueError(
177
+ f"If tile_m == {tile_M}, CTA tile shape N must be divisible by 32 and <= {tile_N_max}"
178
+ )
179
+ else:
180
+ if not (
181
+ (tile_N % 16 == 0 and tile_N <= 256) or (tile_N % 32 == 0 and tile_N <= 512)
182
+ ):
183
+ raise ValueError(
184
+ "CTA tile shape N must be divisible by 16 and <= 256, or divisible by 32 and <= 512"
185
+ )
186
+ else:
187
+ if tile_M not in [64, 128, 192]:
188
+ raise ValueError("CTA tile shape M must be 64/128/192 if pingpong")
189
+ tile_N_max = 256 if tile_M == 64 else (208 if tile_M == 128 else 128)
190
+ if not (tile_N % 16 == 0 and tile_N <= tile_N_max):
191
+ raise ValueError(f"CTA tile shape N must be divisible by 16 and <= {tile_N_max}")
192
+
193
+ if not self.pingpong:
194
+ if tile_M == 320: # tile_M / 64 is not even so we have to split along N
195
+ atom_layout_m, atom_layout_n = 1, 2
196
+ elif tile_M == 192:
197
+ if tile_N <= 128:
198
+ atom_layout_m, atom_layout_n = 3, 1
199
+ else:
200
+ atom_layout_m, atom_layout_n = 1, 2
201
+ else:
202
+ atom_layout_m = (
203
+ self.cta_tile_shape_mnk[0] // 64 if self.cta_tile_shape_mnk[0] < 256 else 2
204
+ )
205
+ atom_layout_n = 1
206
+ assert atom_layout_m in [1, 2, 3] and atom_layout_n in [1, 2]
207
+ else:
208
+ atom_layout_m, atom_layout_n = 1, 1
209
+ self.atom_layout_mnk = (atom_layout_m, atom_layout_n, 1)
210
+
211
+ self.num_mcast_ctas_a = self.cluster_shape_mnk[1]
212
+ if self.gather_A:
213
+ assert self.num_mcast_ctas_a == 1
214
+ self.num_mcast_ctas_b = self.cluster_shape_mnk[0]
215
+ self.is_a_mcast = self.num_mcast_ctas_a > 1
216
+ self.is_b_mcast = self.num_mcast_ctas_b > 1
217
+
218
+ self.occupancy = 1
219
+ self.mma_warp_groups = math.prod(self.atom_layout_mnk) * (1 if not self.pingpong else 2)
220
+ if self.pingpong:
221
+ assert self.mma_warp_groups == 2
222
+ assert self.mma_warp_groups in [1, 2, 3]
223
+ self.num_threads_per_warp_group = 128
224
+ self.threads_per_cta = (self.mma_warp_groups + 1) * self.num_threads_per_warp_group
225
+ self.smem_capacity = cutlass.utils.get_smem_capacity_in_bytes("sm_90")
226
+ self.num_epi_warps = (self.mma_warp_groups if not self.pingpong else 1) * 4
227
+ self.num_ab_load_warps = 1 if not self.gather_A else 4
228
+ self.ab_load_warp_id = self.mma_warp_groups * 4
229
+ # self.num_epi_load_threads = cute.arch.WARP_SIZE * 1
230
+ # self.epi_load_warp_id = self.ab_load_warp_id + self.num_ab_load_warps
231
+
232
+ regs_per_thread = math.prod(self.cta_tile_shape_mnk[:2]) // (
233
+ math.prod(self.atom_layout_mnk) * self.num_threads_per_warp_group
234
+ )
235
+ if self.fp8_slow_accum:
236
+ regs_per_thread *= 2
237
+ if not self.gather_A:
238
+ if self.mma_warp_groups == 3:
239
+ self.num_regs_load, self.num_regs_mma = 32, 160
240
+ else:
241
+ heavy_register_pressure = regs_per_thread >= 208
242
+ self.num_regs_load, self.num_regs_mma = (
243
+ (40, 232) if not heavy_register_pressure else (24, 240)
244
+ )
245
+ else:
246
+ if self.mma_warp_groups == 3:
247
+ self.num_regs_load, self.num_regs_mma = 56, 152
248
+ else:
249
+ self.num_regs_load, self.num_regs_mma = (56, 224)
250
+
251
+ self.ab_stage = None
252
+ self.epi_stage = None
253
+
254
+ self.a_smem_layout_staged = None
255
+ self.b_smem_layout_staged = None
256
+ self.epi_smem_layout_staged = None
257
+ self.epi_tile = None
258
+
259
+ self.shared_storage = None
260
+ self.buffer_align_bytes = 1024
261
+
262
    def _setup_attributes(self, epilogue_args: EpilogueArguments):
        """Set up configurations that are dependent on GEMM inputs.

        Must run after the per-call dtype/layout attributes (``self.a_dtype``,
        ``self.a_layout``, ``self.d_dtype``, ...) have been assigned (done in
        ``__call__``), since everything below is derived from them:
        - Configuring tiled MMA
        - Computing MMA/cluster/tile shapes (fills in the deferred K extent of
          ``self.cta_tile_shape_mnk``)
        - Computing cluster layout
        - Computing multicast CTAs for A/B
        - Computing epilogue subtile
        - Setting up A/B/C stage counts in shared memory
        - Computing A/B/C shared memory layout

        :param epilogue_args: epilogue arguments; passed to ``_compute_stages``
            so the epilogue's smem footprint is accounted for.
        :type epilogue_args: EpilogueArguments
        """

        self.tiled_mma = sm90_utils.make_trivial_tiled_mma(
            self.a_dtype,
            self.b_dtype,
            self.a_layout.sm90_mma_major_mode(),
            self.b_layout.sm90_mma_major_mode(),
            self.acc_dtype,
            self.atom_layout_mnk,
            tiler_mn=(64, self.cta_tile_shape_mnk[1] // self.atom_layout_mnk[1]),
        )
        if const_expr(self.atom_layout_mnk[1] > 1):
            # If N dimension is split among 2 WGs, we need to permute the N dimension so
            # that in the epilogue, WG0 and WG1 can write to epi smem of size e.g. (64, 32)
            # containing accumulators that are next to each other in the N dimension.
            # Without permutation WG0 would write to epi smem of size (64, 16) and
            # WG1 would write to a separate epi smem of size (64, 16) that's far away.
            atom_n = self.atom_layout_mnk[1]
            permutation_n = cute.make_ordered_layout(
                (8, self.cta_tile_shape_mnk[1] // atom_n // 8, atom_n), order=(0, 2, 1)
            )
            self.tiled_mma = cute.make_tiled_mma(
                cute.make_mma_atom(self.tiled_mma.op),
                self.atom_layout_mnk,
                permutation_mnk=(None, permutation_n, None),
            )
        # CTA tile K = 4 MMA instructions' worth of K (replaces the placeholder
        # K=1 set in __init__).
        mma_inst_shape_k = cute.size(self.tiled_mma.shape_mnk, mode=[2])
        mma_inst_tile_k = 4
        self.cta_tile_shape_mnk = (
            self.cta_tile_shape_mnk[0],
            self.cta_tile_shape_mnk[1],
            mma_inst_shape_k * mma_inst_tile_k,
        )

        self.cluster_layout_mnk = cute.make_layout(self.cluster_shape_mnk)

        self.epi_tile = self._sm90_compute_tile_shape_or_override(
            self.cta_tile_shape_mnk,
            self.atom_layout_mnk,
            self.d_dtype,
        )

        # Compute stage before compute smem layout
        self.ab_stage, self.epi_stage, self.epi_c_stage = self._compute_stages(
            self.cta_tile_shape_mnk,
            self.epi_tile,
            self.a_dtype,
            self.b_dtype,
            self.d_dtype,
            self.c_dtype,
            epilogue_args,
            cutlass.utils.get_smem_capacity_in_bytes(f"sm_{self.arch}"),  # smem_capacity
            self.occupancy,
            # epi_smem will reuse smem ab if not persistent.
            overlap_sD_sA=not self.is_persistent,
        )
        # Pingpong needs one scheduler pipeline stage per math warp group.
        self.sched_stage = 2 if self.pingpong else 1

        (
            self.a_smem_layout_staged,
            self.b_smem_layout_staged,
            self.epi_smem_layout_staged,
            self.epi_c_smem_layout_staged,
        ) = self._make_smem_layouts(
            self.cta_tile_shape_mnk,
            self.epi_tile,
            self.a_dtype,
            self.a_layout,
            self.b_dtype,
            self.b_layout,
            self.ab_stage,
            self.d_dtype,
            self.d_layout,
            self.epi_stage,
            self.c_dtype,
            self.c_layout,
            self.epi_c_stage,
        )
+
353
    @cute.jit
    def __call__(
        self,
        mA: cute.Tensor,
        mB: cute.Tensor,
        mD: Optional[cute.Tensor],
        mC: Optional[cute.Tensor],
        epilogue_args: ArgumentsBase,
        scheduler_args: TileSchedulerOptions,
        varlen_args: Optional[VarlenArguments],
        stream: cuda.CUstream,
    ):
        """Execute the GEMM operation in steps:
        - Setup static attributes (dtypes/layouts read from the tensors)
        - Setup TMA load/store atoms and tensors
        - Compute grid size
        - Define shared storage for kernel
        - Launch the kernel synchronously

        :param mA: Input tensor A
        :type mA: cute.Tensor
        :param mB: Input tensor B
        :type mB: cute.Tensor
        :param mD: Output tensor D, or None if no D output is produced
        :type mD: Optional[cute.Tensor]
        :param mC: Optional source operand C loaded in the epilogue, or None
        :type mC: Optional[cute.Tensor]
        :param epilogue_args: epilogue-specific arguments (converted to params
            via ``epi_to_underlying_arguments``)
        :type epilogue_args: ArgumentsBase
        :param scheduler_args: tile scheduler options (e.g. max active clusters)
        :type scheduler_args: TileSchedulerOptions
        :param varlen_args: variable-length / gather arguments; None means a
            default (non-varlen) ``VarlenArguments()``
        :type varlen_args: Optional[VarlenArguments]
        :param stream: CUDA stream for asynchronous execution
        :type stream: cuda.CUstream
        :raises TypeError: if A/B dtypes are inconsistent or unsupported
        """

        # setup static attributes before smem/grid/tma computation
        self.a_dtype = mA.element_type
        self.b_dtype = mB.element_type
        self.d_dtype = mD.element_type if mD is not None else None
        self.c_dtype = mC.element_type if mC is not None else None
        self.a_layout = LayoutEnum.from_tensor(mA)
        self.b_layout = LayoutEnum.from_tensor(mB)
        self.d_layout = LayoutEnum.from_tensor(mD) if mD is not None else None
        self.c_layout = LayoutEnum.from_tensor(mC) if mC is not None else None

        # Dtype constraints: 16-bit A/B must match exactly; 8-bit A/B may mix
        # fp8 flavors but must have equal width; nothing else is supported.
        if const_expr(self.a_dtype.width == 16 and self.a_dtype != self.b_dtype):
            raise TypeError(f"Type mismatch: {self.a_dtype} != {self.b_dtype}")
        if const_expr(self.a_dtype.width != self.b_dtype.width):
            raise TypeError(f"Type width mismatch: {self.a_dtype.width} != {self.b_dtype.width}")
        if const_expr(self.a_dtype.width != 16 and self.a_dtype.width != 8):
            raise TypeError("a_dtype should be float16 or float8")

        if const_expr(varlen_args is None):
            varlen_args = VarlenArguments()
        # Gather-A mode requires (and non-gather mode forbids) an A index tensor.
        assert (varlen_args.mAIdx is not None) == self.gather_A

        # Assume all strides are divisible by 128 bits except the last stride
        new_stride = lambda t: tuple(
            cute.assume(s, divby=128 // t.element_type.width) if not cute.is_static(s) else s
            for s in t.stride
        )
        mA, mD = [
            cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
            if t is not None
            else None
            for t in (mA, mD)
        ]

        # Derive tiled MMA, stage counts, and smem layouts from the dtypes/layouts above.
        self._setup_attributes(epilogue_args)

        a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, 0))
        b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, 0))
        tma_atom_a, tma_tensor_a = None, None
        # In gather-A mode, A is loaded with cp.async instead of TMA.
        if const_expr(not self.gather_A):
            tma_atom_a, tma_tensor_a = self._make_tma_atoms_and_tensors(
                mA,
                a_smem_layout,
                (self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[2]),
                self.cluster_shape_mnk[1],
            )
        tma_atom_b, tma_tensor_b = self._make_tma_atoms_and_tensors(
            mB,
            b_smem_layout,
            (self.cta_tile_shape_mnk[1], self.cta_tile_shape_mnk[2]),
            self.cluster_shape_mnk[0],
        )

        # Bytes per k-tile tracked by the AB pipeline's TMA transaction barrier.
        self.num_tma_load_bytes = cute.size_in_bytes(self.b_dtype, b_smem_layout)
        if const_expr(not self.gather_A):
            self.num_tma_load_bytes += cute.size_in_bytes(self.a_dtype, a_smem_layout)

        tma_atom_d, tma_tensor_d = None, None
        if const_expr(mD is not None):
            tma_atom_d, tma_tensor_d = self._make_tma_epi_atoms_and_tensors(
                mD,
                self.epi_smem_layout_staged,
                self.epi_tile,
                # "add" makes the TMA store accumulate into D (D += result)
                # when the epilogue requests add_to_output.
                op_type="store"
                if not (hasattr(epilogue_args, "add_to_output") and epilogue_args.add_to_output)
                else "add",
            )
        tma_atom_c, tma_tensor_c = None, None
        if const_expr(mC is not None):
            tma_atom_c, tma_tensor_c = self._make_tma_epi_atoms_and_tensors(
                mC, self.epi_c_smem_layout_staged, self.epi_tile, op_type="load"
            )

        epilogue_params = self.epi_to_underlying_arguments(epilogue_args)
        varlen_params = VarlenManager.to_underlying_arguments(varlen_args)

        TileSchedulerCls = self.get_scheduler_class(varlen_m=varlen_args.mCuSeqlensM is not None)
        tile_sched_args = self.get_scheduler_arguments(mA, mB, mD, scheduler_args, varlen_args)
        tile_sched_params = TileSchedulerCls.to_underlying_arguments(tile_sched_args)
        grid = TileSchedulerCls.get_grid_shape(
            tile_sched_params, scheduler_args.max_active_clusters
        )

        # When not persistent, D's epilogue smem aliases sA (see the kernel's
        # recast of sA), so no dedicated sD allocation is needed.
        epi_smem_size = (
            cute.cosize(self.epi_smem_layout_staged) if self.is_persistent and mD is not None else 0
        )
        epi_c_smem_size = cute.cosize(self.epi_c_smem_layout_staged) if mC is not None else 0

        # Per-CTA shared memory: pipeline mbarriers, scheduler tile counts,
        # epilogue buffers, then the A/B mainloop buffers.
        @cute.struct
        class SharedStorage:
            ab_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.ab_stage * 2]
            epi_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.epi_c_stage * 2]
            sched_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.sched_stage * 2]
            tile_count: cute.struct.MemRange[Int32, self.sched_stage]
            sD: cute.struct.Align[
                cute.struct.MemRange[
                    self.d_dtype if self.d_dtype is not None else Int32, epi_smem_size
                ],
                self.buffer_align_bytes,
            ]
            sC: cute.struct.Align[
                cute.struct.MemRange[
                    self.c_dtype if self.c_dtype is not None else Int32, epi_c_smem_size
                ],
                self.buffer_align_bytes,
            ]
            epi: self.epi_get_smem_struct(epilogue_params)
            sA: cute.struct.Align[
                cute.struct.MemRange[self.a_dtype, cute.cosize(self.a_smem_layout_staged)],
                self.buffer_align_bytes,
            ]
            sB: cute.struct.Align[
                cute.struct.MemRange[self.b_dtype, cute.cosize(self.b_smem_layout_staged)],
                self.buffer_align_bytes,
            ]

        self.shared_storage = SharedStorage

        # Launch the kernel synchronously
        self.kernel(
            self.tiled_mma,
            tma_atom_a,
            tma_tensor_a if const_expr(not self.gather_A) else mA,
            tma_atom_b,
            tma_tensor_b,
            tma_atom_d,
            tma_tensor_d,
            tma_atom_c,
            tma_tensor_c,
            epilogue_params,
            varlen_params,
            self.cluster_layout_mnk,
            self.a_smem_layout_staged,
            self.b_smem_layout_staged,
            self.epi_smem_layout_staged,
            self.epi_c_smem_layout_staged,
            tile_sched_params,
            TileSchedulerCls,
        ).launch(
            grid=grid,
            block=[self.threads_per_cta, 1, 1],
            cluster=self.cluster_shape_mnk,
            stream=stream,
            min_blocks_per_mp=1,
        )
        return
+
528
+ # GPU device kernel
529
+ @cute.kernel
530
+ def kernel(
531
+ self,
532
+ tiled_mma: cute.TiledMma,
533
+ tma_atom_a: Optional[cute.CopyAtom],
534
+ mA_mkl: cute.Tensor,
535
+ tma_atom_b: cute.CopyAtom,
536
+ mB_nkl: cute.Tensor,
537
+ tma_atom_d: Optional[cute.CopyAtom],
538
+ mD_mnl: Optional[cute.Tensor],
539
+ tma_atom_c: Optional[cute.CopyAtom],
540
+ mC_mnl: Optional[cute.Tensor],
541
+ epilogue_params: ParamsBase,
542
+ varlen_params: VarlenManager.Params,
543
+ cluster_layout_mnk: cute.Layout,
544
+ a_smem_layout: cute.ComposedLayout,
545
+ b_smem_layout: cute.ComposedLayout,
546
+ epi_smem_layout: cute.ComposedLayout,
547
+ epi_c_smem_layout: cute.ComposedLayout,
548
+ tile_sched_params: ParamsBase,
549
+ TileSchedulerCls: cutlass.Constexpr[Callable],
550
+ ):
551
+ """
552
+ GPU device kernel performing the batched GEMM computation.
553
+
554
+ :param tma_atom_a: TMA copy atom for A tensor
555
+ :type tma_atom_a: cute.CopyAtom
556
+ :param mA_mkl: Input tensor A
557
+ :type mA_mkl: cute.Tensor
558
+ :param tma_atom_b: TMA copy atom for B tensor
559
+ :type tma_atom_b: cute.CopyAtom
560
+ :param mB_nkl: Input tensor B
561
+ :type mB_nkl: cute.Tensor
562
+ :param tma_atom_d: TMA copy atom for D tensor
563
+ :type tma_atom_d: cute.CopyAtom
564
+ :param mD_mnl: Output tensor D
565
+ :type mD_mnl: cute.Tensor
566
+ :param tiled_mma: Tiled MMA object
567
+ :type tiled_mma: cute.TiledMma
568
+ :param cluster_layout_mnk: CTA layout
569
+ :type cluster_layout_mnk: cute.Layout
570
+ :param a_smem_layout: Shared memory layout for A
571
+ :type a_smem_layout: cute.ComposedLayout
572
+ :param b_smem_layout: Shared memory layout for B
573
+ :type b_smem_layout: cute.ComposedLayout
574
+ :param epi_smem_layout: Shared memory layout for epilogue
575
+ :type epi_smem_layout: cute.ComposedLayout
576
+ """
577
+
578
+ varlen_m = const_expr(varlen_params.cu_seqlens_m is not None)
579
+ varlen_k = const_expr(varlen_params.cu_seqlens_k is not None)
580
+ assert not (varlen_m and varlen_k)
581
+ if const_expr(self.gather_A):
582
+ assert varlen_m or varlen_k
583
+ has_D = const_expr(mD_mnl is not None)
584
+ has_C = const_expr(mC_mnl is not None)
585
+
586
+ warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
587
+
588
+ # /////////////////////////////////////////////////////////////////////////////
589
+ # Prefetch Tma desc
590
+ # /////////////////////////////////////////////////////////////////////////////
591
+ if warp_idx == self.ab_load_warp_id:
592
+ for tma_atom in (tma_atom_a, tma_atom_b, tma_atom_d, tma_atom_c):
593
+ if const_expr(tma_atom is not None):
594
+ cpasync.prefetch_descriptor(tma_atom)
595
+
596
+ # /////////////////////////////////////////////////////////////////////////////
597
+ # Alloc and init AB full/empty + ACC full mbar (pipeline)
598
+ # /////////////////////////////////////////////////////////////////////////////
599
+ smem = cutlass.utils.SmemAllocator()
600
+ storage = smem.allocate(self.shared_storage)
601
+
602
+ ab_pipeline = self.make_ab_pipeline(
603
+ tiled_mma=tiled_mma,
604
+ cluster_layout_vmnk=cute.make_layout((1, *cluster_layout_mnk.shape)),
605
+ ab_pipeline_mbar_ptr=storage.ab_pipeline_array_ptr.data_ptr(),
606
+ )
607
+ epi_pipeline = None
608
+ if const_expr(has_C):
609
+ epi_pipeline = self.make_epi_pipeline(
610
+ c_smem_layout=cute.slice_(epi_c_smem_layout, (None, None, 0)),
611
+ epi_pipeline_mbar_ptr=storage.epi_pipeline_array_ptr.data_ptr(),
612
+ )
613
+ sched_pipeline = None
614
+ tile_count = None
615
+ if const_expr(tile_sched_params.tile_count_semaphore is not None):
616
+ # Dynamic persistent scheduler
617
+ sched_pipeline = self.make_sched_pipeline(
618
+ cluster_layout_mnk,
619
+ sched_pipeline_mbar_ptr=storage.sched_pipeline_array_ptr.data_ptr(),
620
+ varlen_k=varlen_k,
621
+ )
622
+ tile_count = storage.tile_count.get_tensor((self.sched_stage,))
623
+
624
+ # ///////////////////////////////////////////////////////////////////////////////
625
+ # Generate smem tensor A/B
626
+ # ///////////////////////////////////////////////////////////////////////////////
627
+ sA = storage.sA.get_tensor(a_smem_layout.outer, swizzle=a_smem_layout.inner)
628
+ sB = storage.sB.get_tensor(b_smem_layout.outer, swizzle=b_smem_layout.inner)
629
+ sD = None
630
+ if const_expr(has_D):
631
+ if const_expr(not self.is_persistent):
632
+ sD_ptr = cute.recast_ptr(sA.iterator, epi_smem_layout.inner, dtype=self.d_dtype)
633
+ sD = cute.make_tensor(sD_ptr, epi_smem_layout.outer)
634
+ else:
635
+ sD = storage.sD.get_tensor(epi_smem_layout.outer, swizzle=epi_smem_layout.inner)
636
+ sC = None
637
+ if const_expr(has_C):
638
+ sC = storage.sC.get_tensor(epi_c_smem_layout.outer, swizzle=epi_c_smem_layout.inner)
639
+ epi_smem_tensors = self.epi_get_smem_tensors(epilogue_params, storage)
640
+
641
+ varlen_manager = VarlenManager.create(
642
+ varlen_params,
643
+ has_D,
644
+ self.num_epi_tensormaps,
645
+ # Only used if not varlen_m
646
+ len_m_static=Int32(
647
+ mA_mkl.shape[0]
648
+ if varlen_k or varlen_params.mAIdx is None
649
+ else varlen_params.mAIdx.shape[0]
650
+ ),
651
+ len_k_static=Int32(mA_mkl.shape[1]),
652
+ pingpong=self.pingpong,
653
+ warp_idx=warp_idx,
654
+ )
655
+
656
+ TileSchedulerCls = partial(
657
+ TileSchedulerCls.create, tile_sched_params, tile_count, sched_pipeline
658
+ )
659
+
660
+ if warp_idx >= self.ab_load_warp_id:
661
+ cute.arch.warpgroup_reg_dealloc(self.num_regs_load)
662
+ if (
663
+ warp_idx >= self.ab_load_warp_id
664
+ and warp_idx < self.ab_load_warp_id + self.num_ab_load_warps
665
+ ):
666
+ is_tma_warp = self.num_ab_load_warps == 1 or warp_idx == self.ab_load_warp_id
667
+ # initialize tensormap for A & B
668
+ varlen_manager.init_tensormap_AB(tma_atom_a, tma_atom_b, is_tma_warp)
669
+ tma_desc_a_ptr = varlen_manager.get_tma_desc_a_ptr()
670
+ tma_desc_b_ptr = varlen_manager.get_tma_desc_b_ptr()
671
+ # ///////////////////////////////////////////////////////////////////////////////
672
+ # Get mcast mask
673
+ # ///////////////////////////////////////////////////////////////////////////////
674
+ cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())
675
+ block_in_cluster_coord_mnk = cluster_layout_mnk.get_flat_coord(cta_rank_in_cluster)
676
+ a_mcast_mask = cute.make_layout_image_mask(
677
+ cluster_layout_mnk, block_in_cluster_coord_mnk, mode=1
678
+ )
679
+ b_mcast_mask = cute.make_layout_image_mask(
680
+ cluster_layout_mnk, block_in_cluster_coord_mnk, mode=0
681
+ )
682
+ a_mcast_mask = a_mcast_mask if self.is_a_mcast else 0
683
+ b_mcast_mask = b_mcast_mask if self.is_b_mcast else 0
684
+
685
+ # Persistent tile scheduling loop
686
+ is_scheduler_warp = self.num_ab_load_warps == 1 or warp_idx == self.ab_load_warp_id
687
+ if const_expr(cute.size(cluster_layout_mnk) > 1):
688
+ is_scheduler_warp = is_scheduler_warp and cute.arch.block_idx_in_cluster() == 0
689
+ tile_scheduler = TileSchedulerCls(is_scheduler_warp=is_scheduler_warp)
690
+ work_tile = tile_scheduler.initial_work_tile_info()
691
+ ab_producer_state = make_pipeline_state(
692
+ pipeline.PipelineUserType.Producer, self.ab_stage
693
+ )
694
+ if const_expr(varlen_k):
695
+ # wait tensormap initialization complete before update
696
+ varlen_manager.fence_tensormap_init()
697
+ while work_tile.is_valid_tile:
698
+ tile_coord_mnkl = work_tile.tile_idx
699
+ batch_idx = tile_coord_mnkl[3]
700
+ varlen_manager.update_tensormap_AB(
701
+ batch_idx,
702
+ self.a_layout,
703
+ self.b_layout,
704
+ is_tma_warp,
705
+ )
706
+ # ///////////////////////////////////////////////////////////////////////////
707
+ # Local_tile partition global tensors
708
+ # ///////////////////////////////////////////////////////////////////////////
709
+ if const_expr(not self.gather_A):
710
+ mA_mk = varlen_manager.offset_batch_A(mA_mkl, batch_idx)
711
+ # (bM, bK, RestK)
712
+ gA_mk = cute.local_tile(
713
+ mA_mk,
714
+ cute.select(self.cta_tile_shape_mnk, [0, 2]),
715
+ (tile_coord_mnkl[0], None),
716
+ )
717
+ else:
718
+ mAIdx_mk = varlen_manager.offset_batch_AIdx(batch_idx)
719
+ if const_expr(varlen_m):
720
+ gAIdx = cute.local_tile(
721
+ mAIdx_mk, (self.cta_tile_shape_mnk[0],), (tile_coord_mnkl[0],)
722
+ )
723
+ # (M, K)
724
+ mA_mk = mA_mkl
725
+ else:
726
+ assert varlen_k
727
+ # (tile_K, RestK)
728
+ gAIdx = cute.flat_divide(mAIdx_mk, (self.cta_tile_shape_mnk[2],))
729
+ # (tile_M, K)
730
+ mA_mk = cute.local_tile(
731
+ mA_mkl, (self.cta_tile_shape_mnk[0],), (tile_coord_mnkl[0], None)
732
+ )
733
+ # (bN, bK, RestK)
734
+ gB_nk = cute.local_tile(
735
+ varlen_manager.offset_batch_B(mB_nkl, batch_idx),
736
+ cute.select(self.cta_tile_shape_mnk, [1, 2]),
737
+ (tile_coord_mnkl[1], None),
738
+ )
739
+ # //////////////////////////////////////////////////////////////////////////
740
+ # Partition shared tensor for TMA load A/B
741
+ # //////////////////////////////////////////////////////////////////////////
742
+ varlen_manager.fence_tensormap_update_AB(is_tma_warp)
743
+ len_m = varlen_manager.len_m(batch_idx)
744
+ len_k = varlen_manager.len_k(batch_idx)
745
+ # TMA load A partition_S/D
746
+ copy_A = None
747
+ if const_expr(not self.gather_A):
748
+ copy_A, _, _ = copy_utils.tma_get_copy_fn(
749
+ tma_atom_a,
750
+ cta_coord=block_in_cluster_coord_mnk[1],
751
+ cta_layout=cute.make_layout(
752
+ cute.slice_(cluster_layout_mnk, (0, None, 0)).shape
753
+ ),
754
+ src_tensor=gA_mk,
755
+ dst_tensor=sA,
756
+ mcast_mask=a_mcast_mask,
757
+ tma_desc_ptr=tma_desc_a_ptr,
758
+ )
759
+ else:
760
+ tiled_copy_A = self._make_gmem_tiled_copy_A(
761
+ mA_mkl.element_type, self.a_layout, self.num_ab_load_warps * 32
762
+ )
763
+ tidx = (
764
+ cute.arch.thread_idx()[0] - cute.arch.WARP_SIZE * self.ab_load_warp_id
765
+ )
766
+ thr_copy_A = tiled_copy_A.get_slice(tidx)
767
+ copy_A, prefetch_A = None, None
768
+ if const_expr(varlen_m):
769
+ copy_A = copy_utils.gather_m_get_copy_fn(
770
+ thr_copy_A,
771
+ mA_mk,
772
+ sA,
773
+ gAIdx,
774
+ limit_m=len_m - tile_coord_mnkl[0] * self.cta_tile_shape_mnk[0],
775
+ limit_k=len_k,
776
+ )
777
+ else:
778
+ copy_A, prefetch_A = copy_utils.gather_k_get_copy_fn(
779
+ thr_copy_A,
780
+ mA_mk,
781
+ sA,
782
+ gAIdx,
783
+ limit_m=len_m - tile_coord_mnkl[0] * self.cta_tile_shape_mnk[0],
784
+ limit_k=len_k,
785
+ )
786
+ # TMA load B partition_S/D
787
+ copy_B, _, _ = copy_utils.tma_get_copy_fn(
788
+ tma_atom_b,
789
+ cta_coord=block_in_cluster_coord_mnk[0],
790
+ cta_layout=cute.make_layout(
791
+ cute.slice_(cluster_layout_mnk, (None, 0, 0)).shape
792
+ ),
793
+ src_tensor=gB_nk,
794
+ dst_tensor=sB,
795
+ mcast_mask=b_mcast_mask,
796
+ tma_desc_ptr=tma_desc_b_ptr,
797
+ )
798
+ k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
799
+ if const_expr(not self.gather_A):
800
+ ab_producer_state = self.load_AB(
801
+ ab_pipeline, ab_producer_state, copy_A, copy_B, k_tile_cnt
802
+ )
803
+ else:
804
+ ab_producer_state = self.load_AB_gather_A(
805
+ ab_pipeline,
806
+ ab_producer_state,
807
+ copy_A,
808
+ prefetch_A,
809
+ copy_B,
810
+ k_tile_cnt,
811
+ varlen_m=varlen_m,
812
+ )
813
+ tile_scheduler.fetch_next_work(is_scheduler_warp=is_scheduler_warp)
814
+ tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
815
+ work_tile = tile_scheduler.get_current_work()
816
+ # End of persistent scheduler loop
817
+ if const_expr(self.pingpong and not varlen_k):
818
+ # Need to write the tile_idx to smem for the next WG in the pingpong mode
819
+ tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
820
+ ab_pipeline.producer_tail(ab_producer_state)
821
+ if is_scheduler_warp:
822
+ tile_scheduler.producer_tail()
823
+
824
+ if warp_idx < self.ab_load_warp_id:
825
+ cute.arch.warpgroup_reg_alloc(self.num_regs_mma)
826
+ is_tma_warp = Boolean(
827
+ (not self.pingpong and warp_idx == 0)
828
+ or (self.pingpong and (warp_idx == 0 or warp_idx == 4))
829
+ )
830
+ varlen_manager.init_tensormap_epi(
831
+ tma_atom_d, self.epi_get_tma_atoms(epilogue_params), is_tma_warp
832
+ )
833
+ tma_desc_d_ptr = varlen_manager.get_tma_desc_d_ptr()
834
+ tma_desc_epi_ptrs = varlen_manager.get_tma_desc_epi_ptrs()
835
+ # //////////////////////////////////////////////////////////////////////////////
836
+ # Partition global tensor for TiledMMA_A/B/C
837
+ # //////////////////////////////////////////////////////////////////////////////
838
+ tidx, _, _ = cute.arch.thread_idx()
839
+ warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group)
840
+ if const_expr(self.pingpong):
841
+ tidx = tidx % self.num_threads_per_warp_group
842
+ warp_group_thread_layout = cute.make_layout(
843
+ self.mma_warp_groups if not self.pingpong else 1,
844
+ stride=self.num_threads_per_warp_group,
845
+ )
846
+ thr_mma = tiled_mma.get_slice(
847
+ warp_group_thread_layout(warp_group_idx if not self.pingpong else 0)
848
+ )
849
+
850
+ # //////////////////////////////////////////////////////////////////////////////
851
+ # Make fragments
852
+ # //////////////////////////////////////////////////////////////////////////////
853
+ tCrA = tiled_mma.make_fragment_A(thr_mma.partition_A(sA))
854
+ tCrB = tiled_mma.make_fragment_B(thr_mma.partition_B(sB))
855
+
856
+ acc_shape = tiled_mma.partition_shape_C(
857
+ cute.select(self.cta_tile_shape_mnk, mode=[0, 1])
858
+ )
859
+ acc = cute.make_fragment(acc_shape, self.acc_dtype)
860
+ acc_slow = None
861
+ if const_expr(self.fp8_slow_accum):
862
+ acc_slow = cute.make_fragment(acc_shape, self.acc_dtype)
863
+
864
+ if const_expr(self.pingpong):
865
+ if warp_group_idx == 0:
866
+ # WG0 needs a start signal at the very beginning
867
+ self.pingpong_barrier_arrive(warp_group_idx=0, stage="mma")
868
+ self.pingpong_barrier_arrive(warp_group_idx=0, stage="epi")
869
+
870
+ k_tile_cnt_static = cute.ceil_div(mA_mkl.shape[1], self.cta_tile_shape_mnk[2])
871
+ c_tile_cnt = cute.size(cute.ceil_div(self.cta_tile_shape_mnk[:2], self.epi_tile))
872
+
873
+ ab_read_state = make_pipeline_state(pipeline.PipelineUserType.Consumer, self.ab_stage)
874
+ epi_store_pipeline = self.make_epi_store_pipeline()
875
+ epi_read_state = make_pipeline_state(
876
+ pipeline.PipelineUserType.Consumer, self.epi_c_stage
877
+ )
878
+ epi_producer_state = make_pipeline_state(
879
+ pipeline.PipelineUserType.Producer, self.epi_c_stage
880
+ )
881
+ tile_scheduler = TileSchedulerCls()
882
+ work_tile = None
883
+ if const_expr(self.pingpong):
884
+ if const_expr(varlen_k):
885
+ work_tile = tile_scheduler.initial_work_tile_info()
886
+ if warp_idx >= 4:
887
+ # Advance 2nd Math WG pipeline states to the end of 1st Math WG
888
+ epi_read_state.advance_iters(c_tile_cnt)
889
+ epi_producer_state.advance_iters(c_tile_cnt)
890
+ if const_expr(not varlen_k):
891
+ ab_read_state.advance_iters(k_tile_cnt_static)
892
+ else:
893
+ len_k = varlen_manager.len_k(batch_idx=work_tile.tile_idx[3])
894
+ k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
895
+ ab_read_state.advance_iters(k_tile_cnt)
896
+ tile_scheduler.advance_to_next_work()
897
+ if const_expr(varlen_k):
898
+ work_tile = tile_scheduler.get_current_work()
899
+ if const_expr(not varlen_k):
900
+ work_tile = tile_scheduler.initial_work_tile_info()
901
+ else:
902
+ work_tile = tile_scheduler.initial_work_tile_info()
903
+ if const_expr(varlen_m):
904
+ # wait tensormap initialization complete before update
905
+ varlen_manager.fence_tensormap_init()
906
+ while work_tile.is_valid_tile:
907
+ tile_coord_mnkl = work_tile.tile_idx
908
+ batch_idx = tile_coord_mnkl[3]
909
+ epi_shapes, epi_orders = self.epi_get_tensormap_update_shapes_orders(
910
+ epilogue_params, varlen_params.cu_seqlens_m, batch_idx
911
+ )
912
+ varlen_manager.update_tensormap_epi(
913
+ batch_idx,
914
+ self.d_layout,
915
+ epi_shapes,
916
+ epi_orders,
917
+ is_tma_warp,
918
+ )
919
+ len_k = varlen_manager.len_k(batch_idx)
920
+ k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
921
+ ab_read_state, tiled_mma = self.mma(
922
+ ab_pipeline,
923
+ ab_read_state,
924
+ tiled_mma,
925
+ tCrA,
926
+ tCrB,
927
+ acc,
928
+ acc_slow,
929
+ k_tile_cnt,
930
+ warp_group_idx,
931
+ )
932
+ if const_expr(varlen_k):
933
+ if k_tile_cnt == 0:
934
+ acc.fill(0.0)
935
+
936
+ # /////////////////////////////////////////////////////////////////////////////
937
+ # EPILOGUE
938
+ # /////////////////////////////////////////////////////////////////////////////
939
+ if const_expr(self.pingpong):
940
+ self.pingpong_barrier_sync(warp_group_idx, "epi")
941
+
942
+ epilogue_barrier = pipeline.NamedBarrier(
943
+ barrier_id=int(NamedBarrierGemm.Epilogue),
944
+ num_threads=self.num_epi_warps * cute.arch.WARP_SIZE,
945
+ )
946
+
947
+ varlen_manager.fence_tensormap_update_epi(is_tma_warp)
948
+
949
+ copy_D = None
950
+ if const_expr(has_D):
951
+ copy_D, _, _ = self.epilog_gmem_copy_and_partition(
952
+ tma_atom_d,
953
+ varlen_manager.offset_batch_epi(mD_mnl, batch_idx),
954
+ self.cta_tile_shape_mnk[:2],
955
+ self.epi_tile,
956
+ sD,
957
+ tile_coord_mnkl,
958
+ tma_desc_ptr=tma_desc_d_ptr,
959
+ )
960
+ copy_C = None
961
+ if const_expr(has_C):
962
+ copy_C_fn, _, _ = self.epilog_gmem_copy_and_partition(
963
+ tma_atom_c,
964
+ varlen_manager.offset_batch_epi(mC_mnl, batch_idx),
965
+ self.cta_tile_shape_mnk[:2],
966
+ self.epi_tile,
967
+ sC,
968
+ tile_coord_mnkl,
969
+ )
970
+ copy_C = copy_utils.tma_producer_copy_fn(copy_C_fn, epi_pipeline)
971
+
972
+ d_dtype_for_layout = self.d_dtype if self.d_dtype is not None else cutlass.BFloat16
973
+ tiled_copy_r2s, tRS_rD, tRS_sD = self.epilog_smem_store_and_partition(
974
+ tiled_mma, self.d_layout, d_dtype_for_layout, sD, tidx
975
+ )
976
+ # (R2S, R2S_M, R2S_N)
977
+ tRS_rAcc = tiled_copy_r2s.retile(acc)
978
+ load_acc_subtile = partial(self.epi_load_acc_subtile, tRS_rAcc)
979
+ if const_expr(has_C):
980
+ tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC = self.epilog_smem_load_and_partition(
981
+ tiled_mma, self.c_layout, self.c_dtype, sC, tRS_rD.layout, tidx
982
+ )
983
+ else:
984
+ tiled_copy_s2r, tSR_sC, tRS_rC, tSR_rC = None, None, None, None
985
+
986
+ # Wait for all warp groups in the thread block to finish, because smem for tensor
987
+ # A in the mainloop is reused in the epilogue if not persistent.
988
+ if const_expr(not self.is_persistent):
989
+ epilogue_barrier.arrive_and_wait()
990
+
991
+ self.epi_visit_acc(epilogue_params, acc, tiled_mma, tile_coord_mnkl, tidx)
992
+
993
+ epi_read_state, epi_producer_state = self.epilogue(
994
+ epilogue_params,
995
+ epi_smem_tensors,
996
+ tma_desc_epi_ptrs,
997
+ epi_pipeline,
998
+ epi_store_pipeline,
999
+ epi_read_state,
1000
+ epi_producer_state,
1001
+ self.epi_tile,
1002
+ load_acc_subtile,
1003
+ tRS_rD,
1004
+ tRS_rC,
1005
+ None, # tiled_copy_t2r, for Sm100 only
1006
+ tiled_copy_r2s,
1007
+ tRS_sD,
1008
+ tiled_copy_s2r,
1009
+ tSR_rC,
1010
+ tSR_sC,
1011
+ copy_D,
1012
+ copy_C,
1013
+ tile_coord_mnkl,
1014
+ varlen_manager,
1015
+ epilogue_barrier,
1016
+ tile_scheduler,
1017
+ tidx,
1018
+ is_tma_warp,
1019
+ )
1020
+
1021
+ if const_expr(self.pingpong):
1022
+ # With pingpong, 2 WGs write two different output tiles to the same smem,
1023
+ # so we have to make sure the smem content is done reading before signaling
1024
+ # the next WG's epilogue.
1025
+ if is_tma_warp:
1026
+ epi_store_pipeline.producer_tail()
1027
+ self.pingpong_barrier_arrive(1 - warp_group_idx, stage="epi")
1028
+
1029
+ if const_expr(not self.pingpong):
1030
+ tile_scheduler.advance_to_next_work()
1031
+ work_tile = tile_scheduler.get_current_work()
1032
+ else: # Skip a tile for pingpong
1033
+ # Update starting load/store pipeline states for the next tile
1034
+ epi_read_state.advance_iters(c_tile_cnt)
1035
+ epi_producer_state.advance_iters(c_tile_cnt)
1036
+ # Update starting mainloop pipeline state for the next tile
1037
+ if const_expr(not varlen_k):
1038
+ ab_read_state.advance_iters(k_tile_cnt_static)
1039
+ tile_scheduler.advance_to_next_work(advance_count=self.mma_warp_groups)
1040
+ work_tile = tile_scheduler.get_current_work()
1041
+ else:
1042
+ tile_scheduler.advance_to_next_work()
1043
+ work_tile = tile_scheduler.get_current_work()
1044
+ if work_tile.is_valid_tile:
1045
+ len_k = varlen_manager.len_k(batch_idx=work_tile.tile_idx[3])
1046
+ k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
1047
+ ab_read_state.advance_iters(k_tile_cnt)
1048
+ tile_scheduler.advance_to_next_work()
1049
+ work_tile = tile_scheduler.get_current_work()
1050
+ # End of persistent scheduler loop
1051
+
1052
+ # Wait for D store complete
1053
+ if const_expr(not self.pingpong):
1054
+ if is_tma_warp:
1055
+ epi_store_pipeline.producer_tail()
1056
+
1057
    @cute.jit
    def load_AB(
        self,
        ab_pipeline: cutlass.pipeline.PipelineAsync,
        ab_producer_state: cutlass.pipeline.PipelineState,
        copy_A: Optional[Callable],
        copy_B: Callable,
        k_tile_cnt: Int32,
        # These are for Sm100 blockscaled gemm
        copy_SFA: Optional[Callable] = None,
        copy_SFB: Optional[Callable] = None,
    ) -> cutlass.pipeline.PipelineState:
        """Producer loop: issue TMA loads of A/B (and SFA/SFB scale factors when
        block-scaled) for ``k_tile_cnt`` K-tiles into the staged smem buffers.

        Each copy callback is invoked as ``copy(k_tile, smem_stage_idx,
        tma_bar_ptr=...)``.  ``producer_try_acquire`` is issued one iteration
        ahead so ``producer_acquire`` can skip the wait when the stage is
        already empty.  Returns the advanced producer pipeline state.
        """
        blockscaled = const_expr(copy_SFA is not None)
        if const_expr(blockscaled):
            assert copy_SFB is not None
        # Peek (try_wait) AB buffer empty for k_block = prefetch_k_tile_cnt
        peek_ab_empty_status = Boolean(True)
        if 0 < k_tile_cnt:
            peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
        # /////////////////////////////////////////////////////////////////////////
        # TMA load
        # /////////////////////////////////////////////////////////////////////////
        for k_tile in cutlass.range(k_tile_cnt, unroll=1):
            # Wait for A/B buffers to be empty before loading into them
            # Also sets the transaction barrier for the A/B buffers
            ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status)
            tma_bar_ptr = ab_pipeline.producer_get_barrier(ab_producer_state)
            smem_idx = ab_producer_state.index
            if const_expr(copy_A is not None):
                copy_A(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr)
            copy_B(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr)
            if const_expr(blockscaled):
                copy_SFA(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr)
                copy_SFB(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr)
            # Mainloop pipeline's producer commit is a NOP
            ab_pipeline.producer_commit(ab_producer_state)
            ab_producer_state.advance()
            peek_ab_empty_status = Boolean(True)
            if k_tile + 1 < k_tile_cnt:
                peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
        return ab_producer_state
1098
+
1099
    @cute.jit
    def load_AB_gather_A(
        self,
        ab_pipeline: cutlass.pipeline.PipelineAsync,
        ab_producer_state: cutlass.pipeline.PipelineState,
        copy_A: Callable,
        prefetch_A: Optional[Callable],
        copy_B: Callable,
        k_tile_cnt: Int32,
        varlen_m: bool = True,
    ) -> cutlass.pipeline.PipelineState:
        """Producer loop for gather-A: B is loaded with TMA while A is gathered
        with cp.async (``copy_A``), optionally prefetching A's row indices via
        ``prefetch_A``.

        The last K-tile is peeled out of the loop so it can run with
        ``pred=True`` for bounds checking in K.  Returns the advanced producer
        pipeline state.
        """
        warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
        # Peek (try_wait) AB buffer empty for k_block = prefetch_k_tile_cnt
        peek_ab_empty_status = Boolean(True)
        if 0 < k_tile_cnt:
            peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
        # /////////////////////////////////////////////////////////////////////////
        # TMA load on B and cp.async on A
        # /////////////////////////////////////////////////////////////////////////
        for k_tile in cutlass.range(k_tile_cnt - 1, unroll=1):
            prefetch_out = ()
            if const_expr(prefetch_A is not None):  # Prefetch early, even before smem is free
                prefetch_out = (prefetch_A(k_tile),)
            # Wait for A/B buffers to be empty before loading into them
            # Also sets the transaction barrier for the A/B buffers
            # A tiny bit faster to rotate the warp that does TMA
            # However, for varlen_k, we must use the warp_idx == self.ab_load_warp_id
            # since that's the warp that does the tensormap update.
            is_tma_warp = warp_idx == self.ab_load_warp_id + (
                (k_tile % self.num_ab_load_warps) if const_expr(varlen_m) else 0
            )
            ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status, is_tma_warp)
            smem_idx = ab_producer_state.index
            # A bit faster to load B first while we calculate the indices for A
            if is_tma_warp:
                tma_bar_ptr = ab_pipeline.producer_get_barrier(ab_producer_state)
                copy_B(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr)
            copy_A(k_tile, smem_idx, *prefetch_out)
            # This tells mbarrier to track the completion of cp.async
            ab_pipeline.producer_cpasync_commit(ab_producer_state)
            ab_producer_state.advance()
            peek_ab_empty_status = Boolean(True)
            if k_tile + 1 < k_tile_cnt:
                peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
        # bound checking in the K dimension on the last k_tile
        if 0 < k_tile_cnt:
            k_tile = k_tile_cnt - 1
            prefetch_out = ()
            if const_expr(prefetch_A is not None):  # Prefetch early, even before smem is free
                prefetch_out = (prefetch_A(k_tile, pred=True),)
            is_tma_warp = warp_idx == self.ab_load_warp_id + (
                (k_tile % self.num_ab_load_warps) if const_expr(varlen_m) else 0
            )
            ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status, is_tma_warp)
            smem_idx = ab_producer_state.index
            if is_tma_warp:
                tma_bar_ptr = ab_pipeline.producer_get_barrier(ab_producer_state)
                copy_B(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr)
            copy_A(k_tile, smem_idx, *prefetch_out, pred=True)
            ab_pipeline.producer_cpasync_commit(ab_producer_state)
            ab_producer_state.advance()
        return ab_producer_state
1161
+
1162
    @cute.jit
    def mma(
        self,
        ab_pipeline: cutlass.pipeline.PipelineAsync,
        ab_read_state: cutlass.pipeline.PipelineState,
        tiled_mma: cute.TiledMma,
        tCrA: cute.Tensor,
        tCrB: cute.Tensor,
        acc: cute.Tensor,
        acc_slow: Optional[cute.Tensor],
        k_tile_cnt: Int32,
        warp_group_idx: Int32,
    ) -> Tuple[cutlass.pipeline.PipelineState, cute.TiledMma]:
        """Consumer loop: run WGMMA over ``k_tile_cnt`` K-tiles, accumulating
        into ``acc``.

        Keeps ``k_pipe_mmas`` (=1) wgmma commit groups in flight, so each smem
        stage is released one tile behind the issue point (via
        ``ab_release_state``).  With ``fp8_slow_accum`` the fast accumulator is
        drained into ``acc_slow`` each tile and copied back at the end.  With
        pingpong, the two math warp groups serialize their MMA phases through
        named barriers.
        """
        # /////////////////////////////////////////////////////////////////////////////
        # Prologue MMAs
        # /////////////////////////////////////////////////////////////////////////////
        k_pipe_mmas = 1
        ab_release_state = ab_read_state.clone()
        num_prologue_mma = min(k_pipe_mmas, k_tile_cnt)
        if const_expr(self.pingpong):
            self.pingpong_barrier_sync(warp_group_idx, stage="mma")
        peek_ab_full_status = Boolean(True)
        if 0 < k_tile_cnt:
            peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_read_state)
        tiled_mma.set(warpgroup.Field.ACCUMULATE, False)
        num_k_blocks = cute.size(tCrA, mode=[2])
        for k_tile in cutlass.range(num_prologue_mma):
            # Wait for A/B buffer to be ready
            ab_pipeline.consumer_wait(ab_read_state, peek_ab_full_status)
            warpgroup.fence()
            for k_blk_idx in cutlass.range(num_k_blocks, unroll_full=True):
                k_blk_coord = (None, None, k_blk_idx, ab_read_state.index)
                cute.gemm(tiled_mma, acc, tCrA[k_blk_coord], tCrB[k_blk_coord], acc)
                # Only the very first gemm runs with ACCUMULATE=False
                tiled_mma.set(warpgroup.Field.ACCUMULATE, True)
            warpgroup.commit_group()
            ab_read_state.advance()
            peek_ab_full_status = Boolean(True)
            if k_tile + 1 < k_tile_cnt:
                peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_read_state)
        # If k_tile_cnt == 0, this is not correct. But we will set acc to 0 in the mainloop
        # in that case.
        if const_expr(self.fp8_slow_accum):
            warpgroup.wait_group(0)
            acc_slow.store(acc.load())

        # /////////////////////////////////////////////////////////////////////////////
        # MAINLOOP
        # /////////////////////////////////////////////////////////////////////////////
        for k_tile in cutlass.range(num_prologue_mma, k_tile_cnt, unroll=1):
            # Wait for TMA copies to complete
            ab_pipeline.consumer_wait(ab_read_state, peek_ab_full_status)
            # WGMMA
            warpgroup.fence()
            if const_expr(self.fp8_slow_accum):
                # Restart accumulation in the fast accumulator each tile
                tiled_mma.set(warpgroup.Field.ACCUMULATE, False)
            for k_blk_idx in cutlass.range(num_k_blocks, unroll_full=True):
                k_blk_coord = (None, None, k_blk_idx, ab_read_state.index)
                cute.gemm(tiled_mma, acc, tCrA[k_blk_coord], tCrB[k_blk_coord], acc)
                tiled_mma.set(warpgroup.Field.ACCUMULATE, True)
            warpgroup.commit_group()
            # Wait on the wgmma barrier for previous k_pipe_mmas wgmmas to complete
            if const_expr(not self.fp8_slow_accum):
                warpgroup.wait_group(k_pipe_mmas)
            else:
                warpgroup.wait_group(0)
                acc_slow.store(acc_slow.load() + acc.load())
            ab_pipeline.consumer_release(ab_release_state)
            ab_read_state.advance()
            ab_release_state.advance()
            peek_ab_full_status = Boolean(True)
            if k_tile + 1 < k_tile_cnt:
                peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_read_state)
        if const_expr(self.pingpong):
            # Cue for next WG's MMA to start
            self.pingpong_barrier_arrive(1 - warp_group_idx, stage="mma")
        if const_expr(not self.fp8_slow_accum):
            # fp8_slow_accum would already called wait_group(0) inside the loop
            warpgroup.wait_group(0)
        for k_tile in cutlass.range(num_prologue_mma, unroll=1):
            # Release the stages still held by the in-flight prologue wgmmas
            ab_pipeline.consumer_release(ab_release_state)
            ab_release_state.advance()
        if const_expr(self.fp8_slow_accum):
            acc.store(acc_slow.load())
        # If we don't return the tiled_mma, we get compiler error
        # "operand #0 does not dominate this use"
        return ab_read_state, tiled_mma
1248
+
1249
    @cute.jit
    def epilogue(
        self,
        params: EpilogueParams,
        epi_smem_tensors: Tuple[cute.Tensor, ...],
        tma_desc_epi_ptrs: list[Optional[cute.Pointer]],
        epi_pipeline: cutlass.pipeline.PipelineAsync,
        epi_store_pipeline: cutlass.pipeline.PipelineAsync,
        epi_read_state: cutlass.pipeline.PipelineState,
        epi_producer_state: Optional[cutlass.pipeline.PipelineState],
        epi_tile: cute.Tile,
        load_acc_subtile: Callable,
        tRS_rD: cute.Tensor,
        tRS_rC: Optional[cute.Tensor],
        tiled_copy_t2r: Optional[cute.TiledCopy],  # Only for Sm100
        tiled_copy_r2s: cute.TiledCopy,
        tRS_sD: cute.Tensor,
        tiled_copy_s2r: Optional[cute.ThrCopy],
        tSR_rC: Optional[cute.Tensor],
        tSR_sC: Optional[cute.Tensor],
        copy_D: Optional[Callable],
        copy_C: Optional[Callable],
        tile_coord_mnkl: cute.Coord,
        varlen_manager: VarlenManager,
        epilogue_barrier: cutlass.pipeline.NamedBarrier,
        tile_scheduler,
        tidx: Int32,
        is_tma_warp: Boolean,
    ) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState]:
        """Epilogue for one output tile: for each epi subtile, move accumulator
        registers to smem (optionally consuming a TMA-fed C and applying the
        ``epi_visit_subtile`` hook), then TMA-store smem to gmem.

        Returns the advanced epi read/producer pipeline states so callers (e.g.
        the pingpong path) can keep both warp groups' states in sync.
        """
        has_C = const_expr(tRS_rC is not None)
        has_D = const_expr(copy_D is not None)
        epi_tile_shape = cute.zipped_divide(
            cute.make_layout(self.cta_tile_shape_mnk[:2]), epi_tile
        ).shape[1]
        # We iterate over epi tiles in the N dimension first before the M dimension
        epi_tile_layout = cute.make_ordered_layout(epi_tile_shape, order=(1, 0))
        epi_tile_num = cute.size(epi_tile_shape)
        num_prev_subtiles = tile_scheduler.num_tiles_executed * epi_tile_num

        epi_tensors = self.epi_begin(
            params,
            epi_smem_tensors,
            epi_tile,
            tiled_copy_t2r,
            tiled_copy_r2s,
            tile_coord_mnkl,
            varlen_manager,
            epilogue_barrier,
            tidx,
        )

        # Prime the C pipeline: issue up to epi_c_stage loads ahead of consumption
        if const_expr(copy_C is not None):
            for epi_idx in cutlass.range(min(epi_tile_num, self.epi_c_stage), unroll=1):
                gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx)
                if is_tma_warp:
                    epi_pipeline.producer_acquire(epi_producer_state)
                    copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state)
                    epi_pipeline.producer_commit(epi_producer_state)
                epi_producer_state.advance()

        def tma_store_fn(src_idx, dst_idx):
            # Fence and barrier to make sure shared memory store is visible to TMA store
            cute.arch.fence_proxy(
                cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
            )
            epilogue_barrier.arrive_and_wait()
            # Copy from shared memory to global memory
            if is_tma_warp:
                if const_expr(has_D):
                    copy_D(src_idx=src_idx, dst_idx=dst_idx)
            # Can't use if statement here, epi_store_pipeline object isn't captured somehow
            if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_commit())
            if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_acquire())
            epilogue_barrier.arrive_and_wait()

        # We could delay the TMA store by 1 epi tile to better overlap the non-TMA ops
        # with the TMA store. However, currently this doesn't seem to improve perf.
        delay_tma_store = False

        src_idx_prev, dst_idx_prev = None, None
        for epi_idx in cutlass.range_constexpr(epi_tile_num):
            # The global memory coordinate for the current epi tile
            gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
            # Copy from acc to D registers
            load_acc_subtile(tRS_rD, epi_idx)
            epi_loop_tensors = self.epi_begin_loop(params, epi_tensors, gmem_coord)
            if const_expr(has_C):
                epi_pipeline.consumer_wait(epi_read_state)
                cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC)
                # Fence to make sure shared memory read is visible to TMA load
                cute.arch.fence_proxy(
                    cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
                )
                cute.arch.sync_warp()
                with cute.arch.elect_one():
                    epi_pipeline.consumer_release(epi_read_state)
                epi_read_state.advance()
            # Refill the C pipeline epi_c_stage tiles ahead
            if const_expr(copy_C is not None and epi_idx + self.epi_c_stage < epi_tile_num):
                gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx + self.epi_c_stage)
                if is_tma_warp:
                    epi_pipeline.producer_acquire(epi_producer_state)
                    copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state)
                    epi_pipeline.producer_commit(epi_producer_state)
                epi_producer_state.advance()
            tRS_rEpi = self.epi_visit_subtile(params, epi_loop_tensors, tRS_rD, tRS_rC)
            epi_buffer = (num_prev_subtiles + epi_idx) % self.epi_stage
            if const_expr(delay_tma_store):
                if const_expr(epi_idx > 0):
                    tma_store_fn(src_idx=src_idx_prev, dst_idx=dst_idx_prev)
                src_idx_prev, dst_idx_prev = epi_buffer, gmem_coord
            # Copy from D registers to shared memory
            if const_expr(has_D):
                copy_utils.cvt_copy(tiled_copy_r2s, tRS_rD, tRS_sD[None, None, None, epi_buffer])
            if const_expr(not delay_tma_store):
                tma_store_fn(src_idx=epi_buffer, dst_idx=gmem_coord)

        if const_expr(delay_tma_store):
            # Flush the last deferred store
            tma_store_fn(src_idx=src_idx_prev, dst_idx=dst_idx_prev)

        self.epi_end(
            params,
            epi_tensors,
            epi_tile,
            tiled_copy_t2r,
            tiled_copy_r2s,
            tile_coord_mnkl,
            varlen_manager,
            tidx,
        )

        return epi_read_state, epi_producer_state
1380
+
1381
+ def get_scheduler_class(self, varlen_m: bool = False):
1382
+ """Return the scheduler class to use. Override in subclasses for custom schedulers."""
1383
+ return TileScheduler if not varlen_m else VarlenMTileScheduler
1384
+
1385
    def get_scheduler_arguments(
        self,
        mA: cute.Tensor,
        mB: cute.Tensor,
        mD: Optional[cute.Tensor],
        scheduler_args,
        varlen_args,
    ):
        """Create scheduler arguments. Override in subclasses for custom schedulers."""
        if const_expr(varlen_args.mCuSeqlensM is None):
            # Batch count: prefer D's batch dim; otherwise B's, unless K is
            # varlen, in which case it comes from the cu_seqlens_k vector.
            num_problems = (
                mD.shape[2]
                if mD is not None
                else (
                    mB.shape[2]
                    if varlen_args.mCuSeqlensK is None
                    else varlen_args.mCuSeqlensK.shape[0] - 1
                )
            )
            problem_shape_ntile_mnl = (
                cute.ceil_div(mA.shape[0], self.cta_tile_shape_mnk[0]),
                cute.ceil_div(mB.shape[0], self.cta_tile_shape_mnk[1]),
                num_problems,
            )
            tile_sched_args = TileSchedulerArguments(
                problem_shape_ntile_mnl=problem_shape_ntile_mnl,
                raster_order=scheduler_args.raster_order,
                group_size=scheduler_args.max_swizzle_size,
                cluster_shape_mnk=self.cluster_shape_mnk,
                tile_count_semaphore=scheduler_args.tile_count_semaphore,
                batch_idx_permute=scheduler_args.batch_idx_permute,
                is_persistent=self.is_persistent,
            )
        else:
            assert mD is not None or not self.gather_A
            # With varlen M the M-tile count is not static (None); the varlen
            # scheduler derives it per batch from cu_seqlens_m.
            problem_shape_ntile_mnl = (
                None,
                cute.ceil_div(mB.shape[0], self.cta_tile_shape_mnk[1]),
                varlen_args.mCuSeqlensM.shape[0] - 1,
            )
            tile_sched_args = VarlenMTileSchedulerArguments(
                problem_shape_ntile_mnl=problem_shape_ntile_mnl,
                total_m=mD.shape[0] if mD is not None else varlen_args.mAIdx.shape[0],
                cu_seqlens_m=varlen_args.mCuSeqlensM,
                raster_order=scheduler_args.raster_order,
                group_size=scheduler_args.max_swizzle_size,
                tile_shape_mn=self.cta_tile_shape_mnk[:2],
                cluster_shape_mnk=self.cluster_shape_mnk,
                tile_count_semaphore=scheduler_args.tile_count_semaphore,
                is_persistent=self.is_persistent,
            )
        return tile_sched_args
1437
+
1438
    @cute.jit
    def epi_load_acc_subtile(self, tRS_rAcc: cute.Tensor, tRS_rD: cute.Tensor, epi_idx: int):
        """Copy the ``epi_idx``-th contiguous chunk of the accumulator fragment
        ``tRS_rAcc`` into the D register fragment ``tRS_rD``."""
        for epi_v in cutlass.range_constexpr(cute.size(tRS_rD)):
            tRS_rD[epi_v] = tRS_rAcc[epi_idx * cute.size(tRS_rD) + epi_v]
1442
+
1443
+ @cute.jit
1444
+ def epi_begin(
1445
+ self,
1446
+ params: EpilogueParams,
1447
+ epi_smem_tensors: Tuple[cute.Tensor, ...],
1448
+ epi_tile: cute.Tile,
1449
+ tiled_copy_t2r: Optional[cute.TiledCopy],
1450
+ tiled_copy_r2s: cute.TiledCopy,
1451
+ tile_coord_mnkl: cute.Coord,
1452
+ varlen_manager: VarlenManager,
1453
+ epilogue_barrier: cutlass.pipeline.NamedBarrier,
1454
+ tidx: Int32,
1455
+ ) -> Tuple[cute.Tensor, ...]:
1456
+ return ()
1457
+
1458
+ def epi_begin_loop(
1459
+ self, params: EpilogueParams, epi_tensors: Tuple[cute.Tensor, ...], epi_coord: cute.Coord
1460
+ ) -> Tuple[cute.Tensor, ...]:
1461
+ return ()
1462
+
1463
+ def epi_visit_subtile(
1464
+ self,
1465
+ params: EpilogueParams,
1466
+ epi_loop_tensors: Tuple[cute.Tensor, ...],
1467
+ tRS_rD: cute.Tensor,
1468
+ tRS_rC: Optional[cute.Tensor] = None,
1469
+ ) -> Optional[cute.Tensor]:
1470
+ return None
1471
+
1472
+ def epi_visit_acc(
1473
+ self,
1474
+ params: EpilogueParams,
1475
+ acc: cute.Tensor,
1476
+ tiled_mma: cute.TiledMma,
1477
+ tile_coord_mnkl: cute.Coord,
1478
+ tidx: Int32,
1479
+ ) -> None:
1480
+ pass
1481
+
1482
+ @cute.jit
1483
+ def epi_end(
1484
+ self,
1485
+ params: EpilogueParams,
1486
+ epi_tensors: Tuple[cute.Tensor, ...],
1487
+ epi_tile: cute.Tile,
1488
+ tiled_copy_t2r: Optional[cute.TiledCopy],
1489
+ tiled_copy_r2s: cute.TiledCopy,
1490
+ tile_coord_mnkl: cute.Coord,
1491
+ varlen_manager,
1492
+ tidx,
1493
+ ) -> None:
1494
+ pass
1495
+
1496
+ def epi_to_underlying_arguments(
1497
+ self, args: EpilogueArguments, *, loc=None, ip=None
1498
+ ) -> EpilogueParams:
1499
+ return self.EpilogueParams()
1500
+
1501
+ def epi_get_tma_atoms(
1502
+ self, params: EpilogueParams, *, loc=None, ip=None
1503
+ ) -> list[cute.CopyAtom]:
1504
+ """Subclasses can override this"""
1505
+ return []
1506
+
1507
+ def epi_get_tensormap_update_shapes_orders(
1508
+ self,
1509
+ params: EpilogueParams,
1510
+ cu_seqlens_m: cute.Tensor,
1511
+ batch_idx: Int32,
1512
+ *,
1513
+ loc=None,
1514
+ ip=None,
1515
+ ) -> tuple[list[Int32], list[int]]:
1516
+ """Subclasses can override this"""
1517
+ return [], []
1518
+
1519
+ @staticmethod
1520
+ def epi_smem_bytes_per_stage(
1521
+ args: Optional[EpilogueArguments],
1522
+ cta_tile_shape_mnk: Tuple[int, int, int],
1523
+ epi_tile: cute.Tile,
1524
+ ) -> int:
1525
+ return 0
1526
+
1527
+ def epi_get_smem_struct(self, params: EpilogueParams):
1528
+ return cute.struct.MemRange[Int32, 0] # Dummy struct
1529
+
1530
+ def epi_get_smem_tensors(self, params: EpilogueParams, storage) -> Tuple[cute.Tensor, ...]:
1531
+ return tuple()
1532
+
1533
+ def pingpong_barrier_sync(self, warp_group_idx: Int32, stage: Literal["mma", "epi"]):
1534
+ assert stage in ["mma", "epi"]
1535
+ barrier = NamedBarrierGemm.MmaWG0 if stage == "mma" else NamedBarrierGemm.EpiWG0
1536
+ cute.arch.barrier(
1537
+ barrier_id=int(barrier) + warp_group_idx,
1538
+ number_of_threads=2 * self.num_threads_per_warp_group,
1539
+ )
1540
+
1541
+ def pingpong_barrier_arrive(self, warp_group_idx: Int32, stage: Literal["mma", "epi"]):
1542
+ assert stage in ["mma", "epi"]
1543
+ barrier = NamedBarrierGemm.MmaWG0 if stage == "mma" else NamedBarrierGemm.EpiWG0
1544
+ cute.arch.barrier_arrive(
1545
+ barrier_id=int(barrier) + warp_group_idx,
1546
+ number_of_threads=2 * self.num_threads_per_warp_group,
1547
+ )
1548
+
1549
+ def epilog_smem_copy_atom(self, tiled_mma: cute.TiledMma) -> cute.TiledCopy:
1550
+ copy_atom_C = cute.make_copy_atom(
1551
+ warp.StMatrix8x8x16bOp(
1552
+ self.d_layout.is_m_major_c() if self.d_layout is not None else False,
1553
+ num_matrices=4 if self.epi_tile[1] % 16 == 0 else 2,
1554
+ ),
1555
+ Float16, # this is just to get the right source layout
1556
+ )
1557
+ tiled_copy_C_atom = cute.make_tiled_copy_C_atom(copy_atom_C, tiled_mma)
1558
+ return tiled_copy_C_atom
1559
+
1560
    def epilog_smem_store_and_partition(
        self,
        tiled_mma: cute.TiledMma,
        d_layout: Optional[LayoutEnum],
        dtype: Type[cutlass.Numeric],
        sD: Optional[cute.Tensor],
        tidx: Int32,
    ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
        """Set up the register->smem store of D for the epilogue.

        Returns ``(tiled_copy_r2s, tRS_rD, tRS_sD)``: the R2S tiled copy, a
        fresh D register fragment (in acc dtype), and this thread's partition
        of ``sD`` (None when ``sD`` is None).
        """
        if d_layout is None:
            d_layout = LayoutEnum.ROW_MAJOR
        tiled_copy_C_atom = self.epilog_smem_copy_atom(tiled_mma)
        # Doesn't work with tile_N % 8 == 0 but tile_n % 16 != since this always
        # get st.matrix with num_matrices=4
        copy_atom_r2s = sm90_utils.sm90_get_smem_store_op(
            d_layout, elem_ty_d=dtype, elem_ty_acc=self.acc_dtype
        )
        tiled_copy_r2s = cute.make_tiled_copy_S(copy_atom_r2s, tiled_copy_C_atom)
        # (R2S, R2S_M, R2S_N, PIPE_D)
        thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
        tRS_sD = thr_copy_r2s.partition_D(sD) if sD is not None else None
        # Derive the register fragment shape from an identity tensor when sD is absent
        sD_shape = sD.shape[:2] if sD is not None else self.epi_tile
        tRS_rD_shape = thr_copy_r2s.partition_S(cute.make_identity_tensor(sD_shape)).shape
        tRS_rD = cute.make_fragment(tRS_rD_shape, self.acc_dtype)
        return tiled_copy_r2s, tRS_rD, tRS_sD
1584
+
1585
+ def epilog_smem_load_and_partition(
1586
+ self,
1587
+ tiled_mma: cute.TiledMma,
1588
+ c_layout: LayoutEnum,
1589
+ dtype: Type[cutlass.Numeric],
1590
+ sC: cute.Tensor,
1591
+ tRS_rD_layout: cutlass.Layout,
1592
+ tidx: Int32,
1593
+ ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
1594
+ tiled_copy_C_atom = self.epilog_smem_copy_atom(tiled_mma)
1595
+ copy_atom_s2r = copy_utils.sm90_get_smem_load_op(c_layout, dtype)
1596
+ tiled_copy_s2r = cute.make_tiled_copy_S(copy_atom_s2r, tiled_copy_C_atom)
1597
+ thr_copy_s2r = tiled_copy_s2r.get_slice(tidx)
1598
+ tSR_sC = thr_copy_s2r.partition_S(sC)
1599
+ tRS_rC = cute.make_fragment(tRS_rD_layout, dtype)
1600
+ tSR_rC = thr_copy_s2r.retile(tRS_rC)
1601
+ return tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC
1602
+
1603
+ def epilog_gmem_copy_and_partition(
1604
+ self,
1605
+ atom: Union[cute.CopyAtom, cute.TiledCopy],
1606
+ mD_mn: cute.Tensor,
1607
+ tile_shape_mn: cute.Tile,
1608
+ epi_tile: cute.Tile,
1609
+ sD: cute.Tensor,
1610
+ tile_coord_mnkl: cute.Coord,
1611
+ tma_desc_ptr: Optional[cute.Pointer] = None,
1612
+ ) -> Tuple[cute.Tensor, cute.Tensor]:
1613
+ # (bM, bN)
1614
+ gD = cute.local_tile(mD_mn, tile_shape_mn, tile_coord_mnkl[:2])
1615
+ tDgD_for_tma_partition = cute.zipped_divide(gD, epi_tile)
1616
+ is_s2g = isinstance(
1617
+ atom.op, (cpasync.CopyBulkTensorTileS2GOp, cpasync.CopyReduceBulkTensorTileS2GOp)
1618
+ )
1619
+ src_tensor, dst_tensor = (
1620
+ (sD, tDgD_for_tma_partition) if is_s2g else (tDgD_for_tma_partition, sD)
1621
+ )
1622
+ return copy_utils.tma_get_copy_fn(
1623
+ atom,
1624
+ cta_coord=0,
1625
+ cta_layout=cute.make_layout(1),
1626
+ src_tensor=src_tensor,
1627
+ dst_tensor=dst_tensor,
1628
+ tma_desc_ptr=tma_desc_ptr,
1629
+ )
1630
+
1631
    def make_ab_pipeline(
        self,
        tiled_mma: cute.TiledMma,
        cluster_layout_vmnk: cute.Layout,
        ab_pipeline_mbar_ptr: cute.Pointer,
    ):
        """Create the mainloop A/B pipeline (TMA, or TMA+cp.async for gather_A).

        Producer count is 1 (TMA) or 1 + all A-load threads when gathering A;
        each consumer warp arrives once per CTA it multicasts to.
        """
        # Threads/warps participating in this pipeline
        producer_cnt = 1 if const_expr(not self.gather_A) else 1 + self.num_ab_load_warps * 32
        ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, producer_cnt)
        # Each warp will contribute to the arrive count with the number of mcast size
        mcast_size = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
        consumer_arrive_cnt = mcast_size * tiled_mma.size // cute.arch.WARP_SIZE
        ab_pipeline_consumer_group = pipeline.CooperativeGroup(
            pipeline.Agent.Thread, consumer_arrive_cnt
        )
        pipeline_cls = pipeline.PipelineTmaAsync if not self.gather_A else PipelineTmaCpAsync
        return pipeline_cls.create(
            barrier_storage=ab_pipeline_mbar_ptr,
            num_stages=self.ab_stage,
            producer_group=ab_pipeline_producer_group,
            consumer_group=ab_pipeline_consumer_group,
            tx_count=self.num_tma_load_bytes,
            cta_layout_vmnk=cluster_layout_vmnk,
        )
1655
+
1656
+ def make_epi_pipeline(
1657
+ self, c_smem_layout: cute.Layout | cute.ComposedLayout, epi_pipeline_mbar_ptr: cute.Pointer
1658
+ ):
1659
+ # Threads/warps participating in this pipeline
1660
+ epi_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
1661
+ # Each warp will contribute 1 to the arrive count
1662
+ consumer_arrive_cnt = self.num_epi_warps
1663
+ epi_pipeline_consumer_group = pipeline.CooperativeGroup(
1664
+ pipeline.Agent.Thread, consumer_arrive_cnt
1665
+ )
1666
+ tma_copy_c_bytes = cute.size_in_bytes(self.c_dtype, c_smem_layout)
1667
+ return pipeline.PipelineTmaAsync.create(
1668
+ barrier_storage=epi_pipeline_mbar_ptr,
1669
+ num_stages=self.epi_c_stage,
1670
+ producer_group=epi_pipeline_producer_group,
1671
+ consumer_group=epi_pipeline_consumer_group,
1672
+ tx_count=tma_copy_c_bytes,
1673
+ )
1674
+
1675
+ def make_epi_store_pipeline(self):
1676
+ # Threads/warps participating in tma store pipeline
1677
+ num_epi_threads = self.num_epi_warps * cute.arch.WARP_SIZE
1678
+ epi_store_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, num_epi_threads)
1679
+ return pipeline.PipelineTmaStore.create(
1680
+ num_stages=self.epi_stage, producer_group=epi_store_producer_group
1681
+ )
1682
+
1683
+ def make_sched_pipeline(
1684
+ self, cluster_layout_mnk: cute.Layout, sched_pipeline_mbar_ptr: cute.Pointer, varlen_k: bool
1685
+ ):
1686
+ # Threads/warps participating in this pipeline
1687
+ sched_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
1688
+ cluster_size = cute.size(cluster_layout_mnk)
1689
+ # Each warp that are not the scheduler warp will contribute 1 to the arrive count
1690
+ # If pingpong and varlen_k, then all 8 mma warps will participate in the scheduler barrier
1691
+ # at each round. If pingpong and not varlen_k, then only 4 mma warp will participate.
1692
+ consumer_arrive_cnt = (
1693
+ (self.mma_warp_groups if not (self.pingpong and not varlen_k) else 1) * 4
1694
+ + self.num_ab_load_warps
1695
+ ) * cluster_size - 1
1696
+ sched_pipeline_consumer_group = pipeline.CooperativeGroup(
1697
+ pipeline.Agent.Thread, consumer_arrive_cnt
1698
+ )
1699
+ return pipeline.PipelineAsync.create(
1700
+ barrier_storage=sched_pipeline_mbar_ptr,
1701
+ num_stages=self.sched_stage,
1702
+ producer_group=sched_pipeline_producer_group,
1703
+ consumer_group=sched_pipeline_consumer_group,
1704
+ # If there's cluster, the consumers must arrive at the mbar of CTA 0 in the cluster.
1705
+ consumer_mask=None if const_expr(cluster_size == 1) else 0,
1706
+ )
1707
+
1708
+ @classmethod
1709
+ def _compute_stages(
1710
+ cls,
1711
+ cta_tile_shape_mnk: Tuple[int, int, int],
1712
+ epi_tile: Tuple[int, int],
1713
+ a_dtype: Type[cutlass.Numeric],
1714
+ b_dtype: Type[cutlass.Numeric],
1715
+ d_dtype: Optional[Type[cutlass.Numeric]],
1716
+ c_dtype: Optional[Type[cutlass.Numeric]],
1717
+ epilogue_args: EpilogueArguments,
1718
+ smem_capacity: int,
1719
+ occupancy: int,
1720
+ overlap_sD_sA: bool = False,
1721
+ ) -> Tuple[int, int]:
1722
+ """Computes the number of stages for A/B/C operands based on heuristics.
1723
+
1724
+ :param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile.
1725
+ :type cta_tile_shape_mnk: Tuple[int, int, int]
1726
+ :param a_dtype: Data type of operand A.
1727
+ :type a_dtype: type[cutlass.Numeric]
1728
+ :param b_dtype: Data type of operand B.
1729
+ :type b_dtype: type[cutlass.Numeric]
1730
+ :param smem_capacity: Total available shared memory capacity in bytes.
1731
+ :type smem_capacity: int
1732
+ :param occupancy: Target number of CTAs per SM (occupancy).
1733
+ :type occupancy: int
1734
+
1735
+ :return: A tuple containing the computed number of stages for:
1736
+ (A/B operand stages, epilogue stages)
1737
+ :rtype: Tuple[int, int]
1738
+ """
1739
+
1740
+ epi_stage = 4 if epi_tile[1] <= 16 else 2
1741
+ if overlap_sD_sA:
1742
+ epi_bytes = 0
1743
+ else:
1744
+ d_bytes_per_stage = (
1745
+ cute.size(epi_tile) * d_dtype.width // 8 if d_dtype is not None else 0
1746
+ )
1747
+ epi_bytes_per_stage = d_bytes_per_stage + cls.epi_smem_bytes_per_stage(
1748
+ epilogue_args, cta_tile_shape_mnk, epi_tile
1749
+ )
1750
+ epi_bytes = epi_bytes_per_stage * epi_stage
1751
+ epi_c_stage = 0 if c_dtype is None else (4 if epi_tile[1] <= 16 else 2)
1752
+ if c_dtype is not None:
1753
+ epi_bytes += cute.size(epi_tile) * c_dtype.width // 8 * epi_c_stage
1754
+
1755
+ a_shape = cute.slice_(cta_tile_shape_mnk, (None, 0, None))
1756
+ b_shape = cute.slice_(cta_tile_shape_mnk, (0, None, None))
1757
+ ab_bytes_per_stage = (
1758
+ cute.size(a_shape) * a_dtype.width // 8 + cute.size(b_shape) * b_dtype.width // 8
1759
+ )
1760
+ mbar_helpers_bytes = 1024
1761
+
1762
+ remaining_bytes = smem_capacity // occupancy - mbar_helpers_bytes - epi_bytes
1763
+ ab_stage = remaining_bytes // ab_bytes_per_stage
1764
+
1765
+ # Refine epilogue stages:
1766
+ # Calculate remaining smem after allocating for A/B stages and reserved bytes
1767
+ # Add remaining unused smem to epilogue
1768
+ if not overlap_sD_sA and epi_bytes_per_stage > 0:
1769
+ epi_stage += (remaining_bytes - ab_bytes_per_stage * ab_stage) // epi_bytes_per_stage
1770
+ return ab_stage, epi_stage, epi_c_stage
1771
+
1772
+ @staticmethod
1773
+ def _sm90_compute_tile_shape_or_override(
1774
+ cta_tile_shape_mnk: Tuple[int, int, int],
1775
+ atom_layout_mnk: Tuple[int, int, int],
1776
+ element_type: Optional[Type[cutlass.Numeric]] = None,
1777
+ epi_tile_override: Tuple[int, int] | None = None,
1778
+ ) -> Tuple[int, int]:
1779
+ """Compute the epilogue tile shape or use override if provided.
1780
+
1781
+ :param cta_tile_shape_mnk: CTA tile shape (M,N,K)
1782
+ :type cta_tile_shape_mnk: Tuple[int, int, int]
1783
+ :param element_type: Data type of elements
1784
+ :type element_type: type[cutlass.Numeric]
1785
+ :param is_cooperative: Whether to use cooperative approach
1786
+ :type is_cooperative: bool
1787
+ :param epi_tile_override: Optional override for epilogue tile shape
1788
+ :type epi_tile_override: Tuple[int, int] or None
1789
+
1790
+ :return: Computed epilogue tile shape
1791
+ :rtype: Tuple[int, int]
1792
+ """
1793
+ if epi_tile_override is not None:
1794
+ return epi_tile_override
1795
+ if cta_tile_shape_mnk[0] % 128 == 0 and atom_layout_mnk[0] > 1:
1796
+ tile_m = math.gcd(128, cute.size(cta_tile_shape_mnk, mode=[0]))
1797
+ tile_n = math.gcd(32, cute.size(cta_tile_shape_mnk, mode=[1]))
1798
+ elif cta_tile_shape_mnk[0] % 192 == 0 and atom_layout_mnk[0] > 1:
1799
+ tile_m = math.gcd(192, cute.size(cta_tile_shape_mnk, mode=[0]))
1800
+ tile_n = math.gcd(32, cute.size(cta_tile_shape_mnk, mode=[1]))
1801
+ else:
1802
+ # In the case of tile shape 128 x N but atom_layout 1 x 2, we need to set
1803
+ # epi_tile_m = 64. If epi_tile_m = 128, the epilogue would iterate along the
1804
+ # M dimension first, then move to the N dimension. But the accumulator in registers
1805
+ # iterate along the N dimension first, then move to the M dimension.
1806
+ # We could change the epilogue to accommodate this,
1807
+ # but it's easier to just set epi_tile_m = 64.
1808
+ n_perf = 64 if element_type is not None and element_type.width == 8 else 32
1809
+ tile_m = math.gcd(64, cute.size(cta_tile_shape_mnk, mode=[0]))
1810
+ tile_n = math.gcd(n_perf, cute.size(cta_tile_shape_mnk, mode=[1]))
1811
+ return (tile_m, tile_n)
1812
+
1813
+ @staticmethod
1814
+ def _make_smem_layouts(
1815
+ cta_tile_shape_mnk: Tuple[int, int, int],
1816
+ epi_tile: Tuple[int, int],
1817
+ a_dtype: Type[cutlass.Numeric],
1818
+ a_layout: LayoutEnum,
1819
+ b_dtype: Type[cutlass.Numeric],
1820
+ b_layout: LayoutEnum,
1821
+ ab_stage: int,
1822
+ d_dtype: Optional[Type[cutlass.Numeric]],
1823
+ d_layout: LayoutEnum,
1824
+ epi_stage: int,
1825
+ c_dtype: Optional[Type[cutlass.Numeric]],
1826
+ c_layout: Optional[LayoutEnum],
1827
+ epi_c_stage: int,
1828
+ ) -> Tuple[
1829
+ cute.ComposedLayout, cute.ComposedLayout, cute.ComposedLayout, Optional[cute.ComposedLayout]
1830
+ ]:
1831
+ """Create shared memory layouts for A, B, and C tensors.
1832
+
1833
+ :param cta_tile_shape_mnk: CTA tile shape (M,N,K)
1834
+ :type cta_tile_shape_mnk: Tuple[int, int, int]
1835
+ :param epi_tile: Epilogue tile shape
1836
+ :type epi_tile: Tuple[int, int]
1837
+ :param a_dtype: Data type for matrix A
1838
+ :type a_dtype: type[cutlass.Numeric]
1839
+ :param a_layout: Layout enum for matrix A
1840
+ :type a_layout: LayoutEnum
1841
+ :param b_dtype: Data type for matrix B
1842
+ :type b_dtype: type[cutlass.Numeric]
1843
+ :param b_layout: Layout enum for matrix B
1844
+ :type b_layout: LayoutEnum
1845
+ :param ab_stage: Number of stages for A/B tensors
1846
+ :type ab_stage: int
1847
+ :param d_dtype: Data type for output matrix D
1848
+ :type d_dtype: type[cutlass.Numeric]
1849
+ :param d_layout: Layout enum for the output matrix C
1850
+ :type d_layout: LayoutEnum
1851
+ :param epi_stage: Number of epilogue stages
1852
+ :type epi_stage: int
1853
+
1854
+ :return: Tuple of shared memory layouts for A, B, and C
1855
+ :rtype: Tuple[cute.ComposedLayout, cute.ComposedLayout, cute.ComposedLayout]
1856
+ """
1857
+ a_smem_shape = cute.slice_(cta_tile_shape_mnk, (None, 0, None))
1858
+
1859
+ a_is_k_major = a_layout.sm90_mma_major_mode() == warpgroup.OperandMajorMode.K
1860
+ b_is_k_major = b_layout.sm90_mma_major_mode() == warpgroup.OperandMajorMode.K
1861
+ a_major_mode_size = cta_tile_shape_mnk[2 if a_is_k_major else 0]
1862
+ a_smem_layout_atom = warpgroup.make_smem_layout_atom(
1863
+ sm90_utils.get_smem_layout_atom(a_layout, a_dtype, a_major_mode_size),
1864
+ a_dtype,
1865
+ )
1866
+ a_smem_layout_staged = cute.tile_to_shape(
1867
+ a_smem_layout_atom,
1868
+ cute.append(a_smem_shape, ab_stage),
1869
+ order=(0, 1, 2) if a_is_k_major else (1, 0, 2),
1870
+ )
1871
+
1872
+ b_smem_shape = cute.slice_(cta_tile_shape_mnk, (0, None, None))
1873
+
1874
+ b_major_mode_size = cta_tile_shape_mnk[2 if b_is_k_major else 1]
1875
+ b_smem_layout_atom = warpgroup.make_smem_layout_atom(
1876
+ sm90_utils.get_smem_layout_atom(b_layout, b_dtype, b_major_mode_size),
1877
+ b_dtype,
1878
+ )
1879
+ b_smem_layout_staged = cute.tile_to_shape(
1880
+ b_smem_layout_atom,
1881
+ cute.append(b_smem_shape, ab_stage),
1882
+ order=(0, 1, 2) if b_is_k_major else (1, 0, 2),
1883
+ )
1884
+
1885
+ epi_smem_layout_staged = None
1886
+ if d_dtype is not None:
1887
+ epi_smem_layout_staged = quack_sm90_utils.make_smem_layout_epi(
1888
+ d_dtype, d_layout, epi_tile, epi_stage
1889
+ )
1890
+
1891
+ epi_c_smem_layout_staged = None
1892
+ if c_dtype is not None:
1893
+ assert c_layout is not None
1894
+ epi_c_smem_layout_staged = quack_sm90_utils.make_smem_layout_epi(
1895
+ c_dtype, c_layout, epi_tile, epi_c_stage
1896
+ )
1897
+
1898
+ return (
1899
+ a_smem_layout_staged,
1900
+ b_smem_layout_staged,
1901
+ epi_smem_layout_staged,
1902
+ epi_c_smem_layout_staged,
1903
+ )
1904
+
1905
+ @staticmethod
1906
+ def _make_tma_epi_atoms_and_tensors(
1907
+ tensor_d: cute.Tensor,
1908
+ epi_smem_layout_staged: cute.ComposedLayout,
1909
+ epi_tile: Tuple[int, int],
1910
+ op_type: Literal["store", "load", "add"],
1911
+ ) -> Tuple[cute.CopyAtom, cute.Tensor]:
1912
+ """Create TMA atoms and tensors for storing D or loading C.
1913
+
1914
+ :param tensor_d: Output tensor D
1915
+ :type tensor_d: cute.Tensor
1916
+ :param epi_smem_layout_staged: Shared memory layout for epilogue
1917
+ :type epi_smem_layout_staged: cute.ComposedLayout
1918
+ :param epi_tile: Epilogue tile shape
1919
+ :type epi_tile: Tuple[int, int]
1920
+
1921
+ :return: TMA atom and tensor for C
1922
+ :rtype: Tuple[cute.CopyAtom, cute.Tensor]
1923
+ """
1924
+ assert op_type in ["load", "store", "add"]
1925
+ epi_smem_layout = cute.slice_(epi_smem_layout_staged, (None, None, 0))
1926
+ d_cta_v_layout = cute.composition(cute.make_identity_layout(tensor_d.shape), epi_tile)
1927
+ op = (
1928
+ cpasync.CopyBulkTensorTileG2SOp()
1929
+ if op_type == "load"
1930
+ else cpasync.CopyBulkTensorTileS2GOp()
1931
+ if op_type == "store"
1932
+ else cpasync.CopyReduceBulkTensorTileS2GOp(cute.ReductionOp.ADD)
1933
+ )
1934
+ tma_atom_d, tma_tensor_d = cpasync.make_tiled_tma_atom(
1935
+ op, tensor_d, epi_smem_layout, d_cta_v_layout
1936
+ )
1937
+ return tma_atom_d, tma_tensor_d
1938
+
1939
+ @staticmethod
1940
+ def _make_tma_atoms_and_tensors(
1941
+ tensor: cute.Tensor,
1942
+ smem_layout: cute.ComposedLayout,
1943
+ smem_tile: Tuple[int, int],
1944
+ mcast_dim: int,
1945
+ ) -> Tuple[cute.CopyAtom, cute.Tensor]:
1946
+ """Create TMA atoms and tensors for input tensors.
1947
+
1948
+ :param tensor: Input tensor (A or B)
1949
+ :type tensor: cute.Tensor
1950
+ :param smem_layout: Shared memory layout for the tensor
1951
+ :type smem_layout: cute.ComposedLayout
1952
+ :param smem_tile: Shared memory tile shape
1953
+ :type smem_tile: Tuple[int, int]
1954
+ :param mcast_dim: Multicast dimension
1955
+ :type mcast_dim: int
1956
+
1957
+ :return: TMA atom and tensor
1958
+ :rtype: Tuple[cute.CopyAtom, cute.Tensor]
1959
+ """
1960
+ op = (
1961
+ cpasync.CopyBulkTensorTileG2SOp()
1962
+ if mcast_dim == 1
1963
+ else cpasync.CopyBulkTensorTileG2SMulticastOp()
1964
+ )
1965
+ tma_atom, tma_tensor = cpasync.make_tiled_tma_atom(
1966
+ op,
1967
+ tensor,
1968
+ smem_layout,
1969
+ smem_tile,
1970
+ num_multicast=mcast_dim,
1971
+ )
1972
+ return tma_atom, tma_tensor
1973
+
1974
+ def _make_gmem_tiled_copy_A(self, dtype, major_mode, num_threads, copy_bits=128):
1975
+ atom_async_copy = cute.make_copy_atom(
1976
+ cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),
1977
+ dtype,
1978
+ num_bits_per_copy=copy_bits,
1979
+ )
1980
+ copy_elems = copy_bits // dtype.width
1981
+ loads_per_cache_line = 128 * 8 // copy_bits # 128 bytes per cache line
1982
+ shape_dim_1 = cute.size(self.cta_tile_shape_mnk[2]) // copy_elems
1983
+ if shape_dim_1 > loads_per_cache_line:
1984
+ shape_dim_1 = math.gcd(shape_dim_1, loads_per_cache_line)
1985
+ # thread layout for copy
1986
+ thread_layout = cute.make_layout(
1987
+ (num_threads // shape_dim_1, shape_dim_1), stride=(shape_dim_1, 1)
1988
+ )
1989
+ if major_mode != LayoutEnum.ROW_MAJOR:
1990
+ shape_dim_0 = cute.size(self.cta_tile_shape_mnk[0]) // copy_elems
1991
+ if shape_dim_0 > loads_per_cache_line:
1992
+ shape_dim_0 = math.gcd(shape_dim_0, loads_per_cache_line)
1993
+ thread_layout = cute.make_layout(
1994
+ (shape_dim_0, num_threads // shape_dim_0), stride=(1, shape_dim_0)
1995
+ )
1996
+ # Value layout for copy
1997
+ value_layout = (
1998
+ cute.make_layout((1, copy_elems))
1999
+ if major_mode == LayoutEnum.ROW_MAJOR
2000
+ else cute.make_layout((copy_elems, 1))
2001
+ )
2002
+ return cute.make_tiled_copy_tv(atom_async_copy, thread_layout, value_layout)
2003
+
2004
+ @staticmethod
2005
+ def is_valid_dtypes(
2006
+ a_dtype: Type[cutlass.Numeric],
2007
+ b_dtype: Type[cutlass.Numeric],
2008
+ acc_dtype: Type[cutlass.Numeric],
2009
+ d_dtype: Optional[Type[cutlass.Numeric]],
2010
+ a_major: str,
2011
+ b_major: str,
2012
+ ) -> bool:
2013
+ """
2014
+ Check if the dtypes are valid
2015
+
2016
+ :param a_dtype: The data type of tensor A
2017
+ :type a_dtype: Type[cutlass.Numeric]
2018
+ :param b_dtype: The data type of tensor B
2019
+ :type b_dtype: Type[cutlass.Numeric]
2020
+ :param acc_dtype: The data type of the accumulator
2021
+ :type acc_dtype: Type[cutlass.Numeric]
2022
+ :param d_dtype: The data type of the output tensor
2023
+ :type d_dtype: Type[cutlass.Numeric]
2024
+ :param a_major: major mode of tensor A
2025
+ :type a_major: str
2026
+ :param b_major: major mode of tensor B
2027
+ :type b_major: str
2028
+
2029
+ :return: True if the dtypes are valid, False otherwise
2030
+ :rtype: bool
2031
+ """
2032
+ is_valid = True
2033
+ if a_dtype not in {
2034
+ Float16,
2035
+ cutlass.BFloat16,
2036
+ cutlass.Float8E4M3FN,
2037
+ cutlass.Float8E5M2,
2038
+ }:
2039
+ is_valid = False
2040
+ # tested b_dtype
2041
+ if b_dtype not in {
2042
+ Float16,
2043
+ cutlass.BFloat16,
2044
+ cutlass.Float8E4M3FN,
2045
+ cutlass.Float8E5M2,
2046
+ }:
2047
+ is_valid = False
2048
+ if acc_dtype not in {Float32, Float16}:
2049
+ is_valid = False
2050
+ # tested d_dtype
2051
+ if d_dtype not in {
2052
+ None,
2053
+ Float32,
2054
+ Float16,
2055
+ cutlass.BFloat16,
2056
+ cutlass.Float8E4M3FN,
2057
+ cutlass.Float8E5M2,
2058
+ }:
2059
+ is_valid = False
2060
+ # make sure a_dtype == b_dtype for Float16
2061
+ if a_dtype.width == 16 and a_dtype != b_dtype:
2062
+ is_valid = False
2063
+ # make sure a_dtype.width == b_dtype.width (i.e, Float8E4M3FN or Float8E5M2)
2064
+ if a_dtype.width != b_dtype.width:
2065
+ is_valid = False
2066
+
2067
+ # for Float8 types, this implementation only supports k-major layout
2068
+ if (a_dtype.width == 8 and a_major != "k") or (b_dtype.width == 8 and b_major != "k"):
2069
+ is_valid = False
2070
+ return is_valid
build/torch-cuda/quack/gemm_symmetric.py ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple, Optional, Callable
2
+ from functools import partial
3
+ from torch import Tensor
4
+ from .gemm_act import GemmActMixin, act_fn_map, gemm_act
5
+ from .gemm_sm90 import GemmSm90
6
+ from .gemm_sm100 import GemmSm100
7
+ from .tile_scheduler import TriangularTileScheduler
8
+ from .gemm_wrapper_utils import GemmWrapperBase
9
+ from .cute_dsl_utils import get_device_capacity, get_max_active_clusters
10
+ from .varlen_utils import VarlenManager
11
+ from . import copy_utils as copy_utils
12
+ import cutlass
13
+ import cutlass.cute as cute
14
+ import cutlass.torch as cutlass_torch
15
+ from cutlass.cute.runtime import make_ptr
16
+ from cutlass import Int32, Float32, Boolean, const_expr
17
+ import cutlass.utils.hopper_helpers as sm90_utils_og
18
+ import cutlass.utils.blackwell_helpers as sm100_utils
19
+ from cutlass.cutlass_dsl import if_generate
20
+
21
+
22
+ class GemmSymmetricMixin(GemmActMixin, GemmSm90):
23
+ def get_scheduler_class(self, varlen_m: bool = False):
24
+ return TriangularTileScheduler
25
+
26
+ @cute.jit
27
+ def epilogue(
28
+ self,
29
+ params: GemmActMixin.EpilogueParams,
30
+ epi_smem_tensors: Tuple[cute.Tensor, ...],
31
+ tma_desc_epi_ptrs: list[Optional[cute.Pointer]],
32
+ epi_pipeline: cutlass.pipeline.PipelineAsync,
33
+ epi_store_pipeline: cutlass.pipeline.PipelineAsync,
34
+ epi_read_state: cutlass.pipeline.PipelineState,
35
+ epi_producer_state: cutlass.pipeline.PipelineState,
36
+ epi_tile: cute.Tile,
37
+ load_acc_subtile: Callable,
38
+ tRS_rD: cute.Tensor,
39
+ tRS_rC: Optional[cute.Tensor],
40
+ tiled_copy_t2r: Optional[cute.TiledCopy], # Only for Sm100
41
+ tiled_copy_r2s: cute.TiledCopy,
42
+ tRS_sD: cute.Tensor,
43
+ tiled_copy_s2r: Optional[cute.TiledCopy],
44
+ tSR_rC: Optional[cute.Tensor],
45
+ tSR_sC: Optional[cute.Tensor],
46
+ copy_D: Optional[Callable],
47
+ copy_C: Optional[Callable],
48
+ tile_coord_mnkl: cute.Coord,
49
+ varlen_manager: VarlenManager,
50
+ epilogue_barrier: cutlass.pipeline.NamedBarrier,
51
+ tile_scheduler,
52
+ tidx: Int32,
53
+ is_tma_warp: Boolean,
54
+ ) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState]:
55
+ has_C = const_expr(tRS_rC is not None)
56
+ has_D = const_expr(copy_D is not None)
57
+
58
+ tma_atom_postact = params.tma_atom_postact
59
+ mPostAct_mnl = params.mPostAct_mnl
60
+ sRowVec, sColVec, sPostAct = epi_smem_tensors
61
+ get_smem_store_op = (
62
+ partial(sm100_utils.get_smem_store_op, tiled_tmem_load=tiled_copy_t2r)
63
+ if self.arch == 100
64
+ else sm90_utils_og.sm90_get_smem_store_op
65
+ )
66
+ copy_atom_postact_r2s = get_smem_store_op(
67
+ self.postact_layout, self.postact_dtype, self.acc_dtype
68
+ )
69
+ # tiled_copy_C_atom = self.epilog_smem_copy_atom(tiled_mma)
70
+ # tiled_copy_postact_r2s = cute.make_tiled_copy_S(copy_atom_postact_r2s, tiled_copy_C_atom)
71
+ tiled_copy_postact_r2s = cute.make_tiled_copy_S(copy_atom_postact_r2s, tiled_copy_r2s)
72
+ tRS_sPostAct = tiled_copy_postact_r2s.get_slice(tidx).partition_D(sPostAct)
73
+ (tma_desc_postact_ptr,) = tma_desc_epi_ptrs
74
+ batch_idx = tile_coord_mnkl[3]
75
+ copy_postact, _, _ = self.epilog_gmem_copy_and_partition(
76
+ tma_atom_postact,
77
+ varlen_manager.offset_batch_epi(mPostAct_mnl, batch_idx),
78
+ self.cta_tile_shape_postact_mn,
79
+ params.epi_tile_postact,
80
+ sPostAct,
81
+ tile_coord_mnkl,
82
+ tma_desc_ptr=tma_desc_postact_ptr,
83
+ )
84
+
85
+ # We iterate over epi tiles in the N dimension first before the M dimension
86
+ epi_tile_shape = cute.zipped_divide(
87
+ cute.make_layout(self.cta_tile_shape_mnk[:2]), epi_tile
88
+ ).shape[1]
89
+ epi_tile_layout = cute.make_layout(epi_tile_shape, stride=(epi_tile_shape[1], 1))
90
+ epi_tile_num = cute.size(epi_tile_shape)
91
+ num_prev_subtiles = tile_scheduler.num_tiles_executed * epi_tile_num
92
+
93
+ epi_tensors = self.epi_begin(
94
+ params,
95
+ epi_smem_tensors,
96
+ epi_tile,
97
+ tiled_copy_t2r,
98
+ tiled_copy_r2s,
99
+ tile_coord_mnkl,
100
+ varlen_manager,
101
+ epilogue_barrier,
102
+ tidx,
103
+ )
104
+
105
+ if const_expr(copy_C is not None):
106
+ for epi_idx in cutlass.range(min(epi_tile_num, self.epi_c_stage), unroll=1):
107
+ gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx)
108
+ if is_tma_warp:
109
+ epi_pipeline.producer_acquire(epi_producer_state)
110
+ copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state)
111
+ epi_pipeline.producer_commit(epi_producer_state)
112
+ epi_producer_state.advance()
113
+
114
+ def tma_store_fn(src_idx, dst_idx, tile_coord_mnkl):
115
+ pid_m = tile_coord_mnkl[0]
116
+ pid_n = tile_coord_mnkl[1]
117
+ # Fence and barrier to make sure shared memory store is visible to TMA store
118
+ cute.arch.fence_proxy(
119
+ cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
120
+ )
121
+ epilogue_barrier.arrive_and_wait()
122
+ # Copy from shared memory to global memory
123
+ if is_tma_warp:
124
+ square_tile_m = pid_m // self.cluster_shape_mnk[0]
125
+ square_tile_n = pid_n // self.cluster_shape_mnk[1]
126
+ if const_expr(has_D):
127
+ copy_D(src_idx=src_idx, dst_idx=dst_idx)
128
+ if square_tile_m != square_tile_n: # don't write twice to the same tile
129
+ copy_postact(src_idx=src_idx, dst_idx=dst_idx)
130
+ # Can't use if statement here, epi_store_pipeline object isn't captured somehow
131
+ if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_commit())
132
+ if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_acquire())
133
+ epilogue_barrier.arrive_and_wait()
134
+
135
+ delay_tma_store = True
136
+
137
+ src_idx_prev, dst_idx_prev = None, None
138
+ for epi_idx in cutlass.range_constexpr(epi_tile_num):
139
+ # The global memory coordinate for the current epi tile
140
+ gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
141
+ # Copy from acc to D registers
142
+ load_acc_subtile(tRS_rD, epi_idx)
143
+ epi_loop_tensors = self.epi_begin_loop(params, epi_tensors, gmem_coord)
144
+ if const_expr(has_C):
145
+ epi_pipeline.consumer_wait(epi_read_state)
146
+ cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC)
147
+ # Fence to make sure shared memory read is visible to TMA load
148
+ cute.arch.fence_proxy(
149
+ cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
150
+ )
151
+ cute.arch.sync_warp()
152
+ with cute.arch.elect_one():
153
+ epi_pipeline.consumer_release(epi_read_state)
154
+ epi_read_state.advance()
155
+ if const_expr(copy_C is not None and epi_idx + self.epi_c_stage < epi_tile_num):
156
+ gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx + self.epi_c_stage)
157
+ if is_tma_warp:
158
+ epi_pipeline.producer_acquire(epi_producer_state)
159
+ copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state)
160
+ epi_pipeline.producer_commit(epi_producer_state)
161
+ epi_producer_state.advance()
162
+ tRS_rPostAct = self.epi_visit_subtile(params, epi_loop_tensors, tRS_rD, tRS_rC)
163
+ epi_buffer = (num_prev_subtiles + epi_idx) % self.epi_stage
164
+ if const_expr(delay_tma_store):
165
+ if const_expr(epi_idx > 0):
166
+ tma_store_fn(
167
+ src_idx=src_idx_prev, dst_idx=dst_idx_prev, tile_coord_mnkl=tile_coord_mnkl
168
+ )
169
+ src_idx_prev, dst_idx_prev = epi_buffer, gmem_coord
170
+ # Copy from D registers to shared memory
171
+ if const_expr(has_D):
172
+ copy_utils.cvt_copy(tiled_copy_r2s, tRS_rD, tRS_sD[None, None, None, epi_buffer])
173
+ cute.copy(
174
+ tiled_copy_postact_r2s,
175
+ tiled_copy_postact_r2s.retile(tRS_rPostAct),
176
+ tRS_sPostAct[None, None, None, epi_buffer],
177
+ )
178
+ if const_expr(not delay_tma_store):
179
+ tma_store_fn(
180
+ src_idx=epi_buffer, dst_idx=gmem_coord, tile_coord_mnkl=tile_coord_mnkl
181
+ )
182
+
183
+ if const_expr(delay_tma_store):
184
+ tma_store_fn(
185
+ src_idx=src_idx_prev, dst_idx=dst_idx_prev, tile_coord_mnkl=tile_coord_mnkl
186
+ )
187
+
188
+ self.epi_end(
189
+ params,
190
+ epi_tensors,
191
+ epi_tile,
192
+ tiled_copy_t2r,
193
+ tiled_copy_r2s,
194
+ tile_coord_mnkl,
195
+ varlen_manager,
196
+ tidx,
197
+ )
198
+
199
+ return epi_read_state, epi_producer_state
200
+
201
+
202
+ class GemmSymmetricSm90(GemmSymmetricMixin, GemmSm90):
203
+ pass
204
+
205
+
206
+ class GemmSymmetricSm100(GemmSymmetricMixin, GemmSm100):
207
+ pass
208
+
209
+
210
+ def gemm_symmetric(
211
+ A: Tensor, # (l, m, k)
212
+ B: Tensor, # (l, m, k)
213
+ D: Optional[Tensor], # (l, m, m)
214
+ C: Optional[Tensor], # (l, m, m)
215
+ tile_count_semaphore: Optional[Tensor], # (1,)
216
+ tile_M: int,
217
+ tile_N: int,
218
+ cluster_M: int,
219
+ cluster_N: int,
220
+ pingpong: bool = False,
221
+ persistent: bool = True,
222
+ max_swizzle_size: int = 8,
223
+ alpha: float | Tensor = 1.0,
224
+ beta: float | Tensor = 1.0,
225
+ ) -> None:
226
+ # Tranpose D so the "activation" is a write to the mirrored tile
227
+ PostAct = D.mT
228
+
229
+ L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(
230
+ A, B, D, C, additional_tensors={"PostAct": PostAct}
231
+ )
232
+ assert M == N, "M and N must be the same; symmetric gemm only supports square matrices"
233
+ GemmWrapperBase.permute_tensors(tensor_infos)
234
+ GemmWrapperBase.extract_dtypes(tensor_infos)
235
+ major_configs = {
236
+ "A": ("m", "k", "l"),
237
+ "B": ("n", "k", "l"),
238
+ "D": ("m", "n", "l"),
239
+ "C": ("m", "n", "l"),
240
+ "PostAct": ("m", "n", "l"),
241
+ }
242
+ GemmWrapperBase.determine_major_orders(tensor_infos, major_configs)
243
+
244
+ device_capacity = get_device_capacity(A.device)
245
+ assert device_capacity[0] in [9, 10], "Only SM90 and SM100 are supported"
246
+ GemmCls = GemmSymmetricSm90 if device_capacity[0] == 9 else GemmSymmetricSm100
247
+
248
+ acc_dtype = Float32
249
+ tile_shape_mn = (tile_M, tile_N)
250
+ cluster_shape_mnk = (cluster_M, cluster_N, 1)
251
+ if not GemmCls.is_valid_dtypes(
252
+ tensor_infos["A"].dtype,
253
+ tensor_infos["B"].dtype,
254
+ acc_dtype,
255
+ tensor_infos["D"].dtype,
256
+ tensor_infos["A"].major,
257
+ tensor_infos["B"].major,
258
+ ):
259
+ raise TypeError("Skipping due to unsupported combination of types and majors")
260
+
261
+ max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
262
+ GemmWrapperBase.create_cute_tensors({k: v for k, v in tensor_infos.items()}, major_configs)
263
+
264
+ def scalar_arg(scalar: float | Tensor):
265
+ if isinstance(scalar, float):
266
+ return Float32(scalar) if scalar != 1.0 else None
267
+ else:
268
+ assert isinstance(scalar, Tensor)
269
+ return make_ptr(Float32, scalar.data_ptr(), cute.AddressSpace.gmem, assumed_align=4)
270
+
271
+ activation = None # Equivalent to identity
272
+ act_fn = act_fn_map[activation]
273
+ epi_args = GemmCls.EpilogueArguments(
274
+ tensor_infos["PostAct"].cute_tensor, act_fn, scalar_arg(alpha), scalar_arg(beta)
275
+ )
276
+ scheduler_args = GemmWrapperBase.create_scheduler_args(
277
+ max_active_clusters, tile_count_semaphore, max_swizzle_size=max_swizzle_size
278
+ )
279
+ varlen_args = None
280
+
281
+ current_stream = cutlass_torch.current_stream()
282
+ compile_key = GemmWrapperBase.get_compile_key(
283
+ tensor_infos,
284
+ activation,
285
+ tile_shape_mn,
286
+ cluster_shape_mnk,
287
+ pingpong,
288
+ persistent,
289
+ tile_count_semaphore is not None,
290
+ device_capacity,
291
+ max_swizzle_size,
292
+ 2 if isinstance(alpha, Tensor) else (1 if alpha == 1.0 else 0),
293
+ 2 if isinstance(beta, Tensor) else (1 if beta == 1.0 else 0),
294
+ key_tensor_names=("A", "B", "D", "PostAct", "C"),
295
+ )
296
+ cache = gemm_act.compile_cache
297
+ if compile_key not in cache:
298
+ if device_capacity[0] == 9:
299
+ GemmCls = partial(GemmCls, pingpong=pingpong, is_persistent=persistent)
300
+ gemm_obj = GemmCls(
301
+ acc_dtype,
302
+ tensor_infos["A"].dtype,
303
+ tile_shape_mn,
304
+ cluster_shape_mnk,
305
+ gather_A=False,
306
+ )
307
+ cache[compile_key] = cute.compile(
308
+ gemm_obj,
309
+ tensor_infos["A"].cute_tensor,
310
+ tensor_infos["B"].cute_tensor,
311
+ tensor_infos["D"].cute_tensor,
312
+ tensor_infos["C"].cute_tensor,
313
+ epi_args,
314
+ scheduler_args,
315
+ varlen_args,
316
+ current_stream,
317
+ )
318
+ cache[compile_key](
319
+ tensor_infos["A"].cute_tensor,
320
+ tensor_infos["B"].cute_tensor,
321
+ tensor_infos["D"].cute_tensor,
322
+ tensor_infos["C"].cute_tensor,
323
+ epi_args,
324
+ scheduler_args,
325
+ varlen_args,
326
+ current_stream,
327
+ )
328
+
329
+
330
+ gemm_act.compile_cache = {}
build/torch-cuda/quack/gemm_wrapper_utils.py ADDED
@@ -0,0 +1,317 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Tri Dao.
2
+ from typing import Optional, Tuple, Dict, Any
3
+ from dataclasses import dataclass
4
+
5
+ import torch
6
+ from torch import Tensor
7
+
8
+ import cutlass.cute as cute
9
+ from cutlass import Int32
10
+ from cutlass.cute.runtime import from_dlpack, make_ptr
11
+
12
+ from .cute_dsl_utils import torch2cute_dtype_map
13
+ from .varlen_utils import VarlenArguments
14
+ from .tile_scheduler import TileSchedulerOptions
15
+
16
+
17
@dataclass
class GemmTensorInfo:
    """Per-operand state threaded through the GemmWrapperBase helpers.

    Starts out holding only the torch tensor; the remaining fields are filled
    in by ``extract_dtypes``, ``determine_major_orders`` and
    ``create_cute_tensors`` respectively.
    """

    tensor: Optional[Tensor]  # the (possibly permuted) torch tensor; None if operand absent
    dtype: Optional[Any] = None  # cute dtype, filled in by extract_dtypes
    major: Optional[str] = None  # name of the contiguous logical dim, from determine_major_orders
    cute_tensor: Optional[cute.Tensor] = None  # DLPack view, from create_cute_tensors
23
+
24
+
25
class GemmWrapperBase:
    """Static helpers shared by the GEMM wrapper entry points.

    Responsibilities: operand validation (fixed-shape and varlen modes),
    permutation from (L, *, *) to (*, *, L), dtype / major-order extraction,
    cute tensor construction via DLPack, and packing of the scheduler /
    varlen argument structs plus the compile-cache key.
    """

    @staticmethod
    def validate_tensor(tensor: Tensor, name: str, ndim: int) -> None:
        """Assert ``tensor`` is a rank-``ndim`` CUDA tensor with a cute-mappable dtype."""
        assert tensor.dim() == ndim and tensor.is_cuda, f"{name} must be a {ndim}D CUDA tensor"
        assert tensor.dtype in torch2cute_dtype_map, f"Unsupported dtype for {name}"

    @staticmethod
    def validate_shape(tensor: Tensor, expected_shape: Tuple[int, ...], name: str) -> None:
        """Assert ``tensor`` has exactly ``expected_shape``."""
        assert tensor.shape == expected_shape, (
            f"{name} must have shape {expected_shape}, got {tensor.shape}"
        )

    @staticmethod
    def get_major_order(tensor: Tensor, dims: Tuple[str, str, str]) -> str:
        """Return the name of the contiguous (innermost) logical dimension."""
        # Tensor is already permuted to (dims[0], dims[1], dims[2])
        # stride(1) == 1 means dims[1] is contiguous (innermost)
        return dims[1] if tensor.stride(1) == 1 else dims[0]

    @staticmethod
    def create_cute_tensor(
        tensor: Optional[Tensor],
        major: Optional[str],
        dims: Tuple[str, str, str],
        assumed_align: int = 16,
    ) -> Optional[cute.Tensor]:
        """Wrap ``tensor`` as a cute tensor with a dynamic layout along its leading dim.

        Returns None when ``tensor`` is None so optional operands pass through.
        """
        if tensor is None:
            return None
        # Tensor is already permuted to (dims[0], dims[1], dims[2]) or (dim[0], dim[1])
        # If major is dims[1], leading_dim is 1; if major is dims[0], leading_dim is 0
        leading_dim = 1 if major == dims[1] else 0
        return from_dlpack(tensor.detach(), assumed_align=assumed_align).mark_layout_dynamic(
            leading_dim=leading_dim
        )

    @staticmethod
    def validate_and_prepare_tensors(
        A: Tensor,
        B: Tensor,
        D: Optional[Tensor] = None,
        C: Optional[Tensor] = None,
        additional_tensors: Optional[Dict[str, Tensor]] = None,
        cu_seqlens_m: Optional[Tensor] = None,
        cu_seqlens_k: Optional[Tensor] = None,
        A_idx: Optional[Tensor] = None,
    ) -> Tuple[int, int, int, int, Dict[str, GemmTensorInfo]]:
        """Validate operand shapes for the selected mode and bundle them as GemmTensorInfo.

        Three mutually exclusive modes:
        - fixed: A is (L, M, K), B is (L, N, K), D/C are (L, M, N)
        - varlen_m (cu_seqlens_m given): A is (total_M, K), B is (L, N, K), D/C (total_M, N)
        - varlen_k (cu_seqlens_k given): A is (M, total_K), B is (N, total_K), D/C (L, M, N)
        With ``A_idx`` (gather_A), the varlen dim of A is taken from A_idx instead.

        Returns ``(L, M, K, N, tensors)`` where ``tensors`` maps operand name to
        GemmTensorInfo (only the ``tensor`` field populated at this point).
        """
        assert not (cu_seqlens_m is not None and cu_seqlens_k is not None), (
            "Only one of cu_seqlens_m and cu_seqlens_k can be specified"
        )
        assert B.dtype == A.dtype, "A and B must have the same dtype"

        # Validate A_idx if provided (for gather_A case)
        gather_A = A_idx is not None
        if gather_A:
            assert cu_seqlens_m is not None or cu_seqlens_k is not None, (
                "gather_A requires either varlen_m or varlen_k"
            )
            assert A_idx.dtype == torch.int32, f"A_idx must be int32, got {A_idx.dtype}"
            assert A_idx.dim() == 1, f"A_idx must be 1D, got {A_idx.dim()}D"

        # Determine mode and extract dimensions
        if cu_seqlens_m is not None:
            # varlen_m: A is (total_m, k) or (whatever, k) if gather_A, B is (l, n, k), D/C are (total_m, n)
            assert A.dim() == 2, f"A must be 2D when using varlen_m, got {A.dim()}D"
            assert B.dim() == 3, f"B must be 3D with varlen_m, got {B.dim()}D"

            if gather_A:
                # When gather_A, A can have any number of rows, we use A_idx.shape[0] as total_M
                total_M = A_idx.shape[0]
                _, K = A.shape
            else:
                total_M, K = A.shape

            L, N, K_B = B.shape
            assert K == K_B, f"K dimension mismatch: A has {K}, B has {K_B}"
            assert cu_seqlens_m.shape == (L + 1,), (
                f"cu_seqlens_m must have shape ({L + 1},), got {cu_seqlens_m.shape}"
            )
            M = total_M
            dc_shape = (total_M, N)
            dc_ndim = 2
        elif cu_seqlens_k is not None:
            # varlen_k: A is (m, total_k) or (m, whatever) if gather_A, B is (n, total_k), D/C are (l, m, n)
            assert A.dim() == 2, f"A must be 2D when using varlen_k, got {A.dim()}D"
            assert B.dim() == 2, f"B must be 2D with varlen_k, got {B.dim()}D"

            if gather_A:
                # When gather_A with varlen_k, A can have any number of columns, we use A_idx.shape[0] as total_K
                M, _ = A.shape
                total_K = A_idx.shape[0]
            else:
                M, total_K = A.shape

            N, K_B = B.shape
            assert total_K == K_B, f"K dimension mismatch: expected {total_K}, B has {K_B}"
            L = cu_seqlens_k.shape[0] - 1
            # NOTE(review): L is derived from cu_seqlens_k.shape just above, so this
            # assert is tautological; it only guards against a non-1D cu_seqlens_k.
            assert cu_seqlens_k.shape == (L + 1,), (
                f"cu_seqlens_k must have shape ({L + 1},), got {cu_seqlens_k.shape}"
            )
            K = total_K
            dc_shape = (L, M, N)
            dc_ndim = 3
        else:
            # Normal case - all tensors must be 3D
            GemmWrapperBase.validate_tensor(A, "A", 3)
            GemmWrapperBase.validate_tensor(B, "B", 3)
            L, M, K = A.shape
            _, N, K_B = B.shape
            assert K == K_B, f"K dimension mismatch: A has {K}, B has {K_B}"
            GemmWrapperBase.validate_shape(B, (L, N, K), "B")
            dc_shape = (L, M, N)
            dc_ndim = 3

        # Validate D and C shapes uniformly
        for tensor, name in [(D, "D"), (C, "C")]:
            if tensor is not None:
                assert tensor.dim() == dc_ndim, (
                    f"{name} must be {dc_ndim}D for this mode, got {tensor.dim()}D"
                )
                assert tensor.shape == dc_shape, (
                    f"{name} shape {tensor.shape} doesn't match expected {dc_shape}"
                )

        tensors = {
            "A": GemmTensorInfo(A),
            "B": GemmTensorInfo(B),
            "D": GemmTensorInfo(D),
            "C": GemmTensorInfo(C),
        }

        if additional_tensors:
            # Extra epilogue outputs (e.g. "PostAct") must match D/C's shape.
            for name, tensor in additional_tensors.items():
                if tensor is not None:
                    assert tensor.dim() == dc_ndim, (
                        f"{name} must be {dc_ndim}D for this mode, got {tensor.dim()}D"
                    )
                    assert tensor.shape == dc_shape, (
                        f"{name} shape {tensor.shape} doesn't match expected {dc_shape}"
                    )
                    tensors[name] = GemmTensorInfo(tensor)

        return L, M, K, N, tensors

    @staticmethod
    def permute_tensors(
        tensors: Dict[str, GemmTensorInfo], varlen_m: bool = False, varlen_k: bool = False
    ) -> None:
        """Permute 3D operands in place from (L, *, *) to (*, *, L).

        In the varlen modes some operands are 2D and are left untouched.
        """
        # Determine which tensors need permutation
        if varlen_m:
            # Only B needs permutation (3D tensor)
            tensors_to_permute = ["B"]
        elif varlen_k:
            # Only D and C need permutation (3D tensors)
            tensors_to_permute = ["D", "C"]
        else:
            # All tensors need permutation
            tensors_to_permute = None

        # Apply permutation from (L, *, *) -> (*, *, L) for selected tensors
        for name, info in tensors.items():
            if info.tensor is not None and info.tensor.ndim == 3:
                if tensors_to_permute is None or name in tensors_to_permute:
                    info.tensor = info.tensor.permute(1, 2, 0)

    @staticmethod
    def extract_dtypes(tensors: Dict[str, GemmTensorInfo]) -> None:
        """Fill each info's ``dtype`` with the cute dtype for its torch dtype."""
        for name, info in tensors.items():
            if info.tensor is not None:
                info.dtype = torch2cute_dtype_map[info.tensor.dtype]

    @staticmethod
    def determine_major_orders(
        tensors: Dict[str, GemmTensorInfo], major_configs: Dict[str, Tuple[str, str, str]]
    ) -> None:
        """Fill each info's ``major`` from its strides, per the operand's dim names."""
        for name, dims in major_configs.items():
            if name in tensors and tensors[name].tensor is not None:
                tensors[name].major = GemmWrapperBase.get_major_order(tensors[name].tensor, dims)

    @staticmethod
    def create_cute_tensors(
        tensors: Dict[str, GemmTensorInfo], major_configs: Dict[str, Tuple[str, str, str]]
    ) -> None:
        """Fill each info's ``cute_tensor`` via DLPack (only operands in major_configs)."""
        for name, info in tensors.items():
            if info.tensor is not None and name in major_configs:
                info.cute_tensor = GemmWrapperBase.create_cute_tensor(
                    info.tensor, info.major, major_configs[name]
                )

    @staticmethod
    def create_scheduler_args(
        max_active_clusters: int,
        tile_count_semaphore: Optional[Tensor] = None,
        batch_idx_permute: Optional[Tensor] = None,
        max_swizzle_size: int = 8,
    ) -> TileSchedulerOptions:
        """Pack tile-scheduler options, wrapping optional device buffers for cute."""
        return TileSchedulerOptions(
            Int32(max_active_clusters),
            # Raw gmem pointer: the semaphore is a single int32 counter.
            tile_count_semaphore=make_ptr(
                Int32, tile_count_semaphore.data_ptr(), cute.AddressSpace.gmem, assumed_align=4
            )
            if tile_count_semaphore is not None
            else None,
            batch_idx_permute=(
                from_dlpack(batch_idx_permute, assumed_align=4).mark_layout_dynamic(leading_dim=0)
            )
            if batch_idx_permute is not None
            else None,
            max_swizzle_size=Int32(max_swizzle_size),
        )

    @staticmethod
    def create_varlen_args(
        cu_seqlens_m: Optional[Tensor],
        cu_seqlens_k: Optional[Tensor],
        A_idx: Optional[Tensor],
        max_active_clusters: int,
        cluster_shape_mnk: Tuple[int, int, int],
        tensors: Dict[str, GemmTensorInfo],
        num_epi_tensormaps: int = 0,
        pingpong: bool = False,
    ) -> Optional[Any]:
        """Build VarlenArguments (cu_seqlens views, per-CTA tensormap scratch, A_idx).

        Returns None in the fixed-shape case. Allocates a device buffer holding
        one set of tensormaps per potentially-resident CTA.
        """
        if cu_seqlens_m is None and cu_seqlens_k is None:
            return None
        # When varlen_m, we assume persistent=True
        # Grid size depends on num_active_clusters and cluster size
        cluster_size = cluster_shape_mnk[0] * cluster_shape_mnk[1]
        num_blocks = max_active_clusters * cluster_size
        # Calculate number of tensormaps needed
        if cu_seqlens_m is not None:
            # For varlen_m: need tensormaps for D and epilogue tensors
            # pingpong runs two alternating math warpgroups, each needing its own set
            num_tensormaps = num_epi_tensormaps * (1 if not pingpong else 2)
            if tensors["D"].tensor is not None:
                num_tensormaps += 1 if not pingpong else 2  # D tensormap
        else:
            # For varlen_k: need tensormaps for A & B
            # (when gather_A, A is loaded via A_idx rather than TMA, so only B needs one)
            num_tensormaps = 2 if A_idx is None else 1
        # Create tensormap buffer (each tensormap is 128 bytes = 16 int64s)
        tensormap_size = 128 // 8  # 16 int64s
        if num_tensormaps > 0:
            device = cu_seqlens_m.device if cu_seqlens_m is not None else cu_seqlens_k.device
            tensormaps = torch.empty(
                (num_blocks, num_tensormaps, tensormap_size),
                dtype=torch.int64,
                device=device,
            )
            tensormaps_cute = from_dlpack(tensormaps, assumed_align=128).mark_compact_shape_dynamic(
                mode=0, stride_order=(0, 1, 2)
            )
        else:
            tensormaps_cute = None

        return VarlenArguments(
            mCuSeqlensM=(
                from_dlpack(cu_seqlens_m, assumed_align=4).mark_layout_dynamic(leading_dim=0)
                if cu_seqlens_m is not None
                else None
            ),
            mCuSeqlensK=(
                from_dlpack(cu_seqlens_k, assumed_align=4).mark_layout_dynamic(leading_dim=0)
                if cu_seqlens_k is not None
                else None
            ),
            mTensormaps=tensormaps_cute,
            mAIdx=(
                from_dlpack(A_idx, assumed_align=4).mark_layout_dynamic(leading_dim=0)
                if A_idx is not None
                else None
            ),
        )

    @staticmethod
    def get_compile_key(
        tensors: Dict[str, GemmTensorInfo],
        activation: Optional[str],
        tile_shape_mn: Tuple[int, int],
        cluster_shape_mnk: Tuple[int, int, int],
        pingpong: bool,
        persistent: bool,
        has_semaphore: bool,
        *args,
        key_tensor_names: Tuple[str, ...] = ("A", "B", "D", "C"),
    ) -> Tuple:
        """Build a hashable key of everything the compiled kernel specializes on.

        Callers may append extra discriminators through ``*args``.
        """
        key_parts = []
        for name in key_tensor_names:
            if name in tensors:
                key_parts.append(tensors[name].dtype)
        key_parts.append(activation)
        key_parts.extend([tile_shape_mn, cluster_shape_mnk])
        for name in key_tensor_names:
            if name in tensors:
                key_parts.append(tensors[name].major)
        key_parts.extend([pingpong, persistent, has_semaphore])
        key_parts.extend(args)
        return tuple(key_parts)
build/torch-cuda/quack/layout_utils.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
2
+
3
+
4
+ import cutlass
5
+ import cutlass.cute as cute
6
+
7
+ from cutlass import Int32, const_expr
8
+
9
+ from .utils import prmt
10
+
11
+
12
def transpose_view(a: cute.Tensor) -> cute.Tensor:
    """Return a view of ``a`` with its first two modes swapped (no data movement)."""
    swapped_shape = (a.shape[1], a.shape[0], *a.shape[2:])
    swapped_order = (1, 0, *range(2, cute.rank(a)))
    transposed_layout = cute.make_ordered_layout(swapped_shape, order=swapped_order)
    return cute.composition(a, transposed_layout)
17
+
18
+
19
def select(a: cute.Tensor, mode: list[int]) -> cute.Tensor:
    """View ``a`` restricted to the layout modes listed in ``mode``."""
    sub_layout = cute.select(a.layout, mode)
    return cute.make_tensor(a.iterator, sub_layout)
21
+
22
+
23
def expand(a: cute.Tensor, dim: int, size: Int32 | int) -> cute.Tensor:
    """Insert a broadcast mode of extent ``size`` (stride 0) at position ``dim``."""
    expanded_shape = (*a.shape[:dim], size, *a.shape[dim:])
    expanded_stride = (*a.layout.stride[:dim], 0, *a.layout.stride[dim:])
    expanded_layout = cute.make_layout(expanded_shape, stride=expanded_stride)
    return cute.make_tensor(a.iterator, expanded_layout)
27
+
28
+
29
@cute.jit
def permute_gated_Cregs_b16(t: cute.Tensor) -> None:
    """Permute pairs of 16-bit C-register values across each quad of lanes.

    Works in place on ``t`` (a per-thread register fragment of 16-bit
    elements), recast as packed 32-bit words. For every pair of u32 words it
    exchanges halves between the four lanes of a quad via ``shuffle_sync`` and
    then reassembles them with byte-level ``prmt`` selectors.
    Presumably this interleaves the two halves of a gated activation so they
    land adjacently for the epilogue — TODO confirm against callers.
    """
    assert t.element_type.width == 16
    assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b16 permutation"
    t_u32 = cute.recast_tensor(t, Int32)

    quad_idx = cute.arch.lane_idx() % 4
    # Lanes 0 and 3 keep their halves in the original order; lanes 1 and 2 swap.
    lane_03 = quad_idx == 0 or quad_idx == 3
    selector_upper = Int32(0x5410) if lane_03 else Int32(0x1054)
    selector_lower = Int32(0x7632) if lane_03 else Int32(0x3276)
    # upper_map = [0, 3, 1, 2]
    # lower_map = [1, 2, 0, 3]
    # upper_idx = upper_map[quad_idx]
    # indexing isn't supported so we have to do arithmetic
    upper_idx = quad_idx // 2 if quad_idx % 2 == 0 else 3 - quad_idx // 2
    lower_idx = upper_idx ^ 1

    # shfl.sync mask_and_clamp encoding: segment mask in bits 8..12, clamp in bits 0..4.
    # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000
    width = 4
    mask = cute.arch.WARP_SIZE - width
    clamp = cute.arch.WARP_SIZE - 1
    mask_and_clamp = mask << 8 | clamp

    for i in cutlass.range(cute.size(t_u32.shape) // 2, unroll_full=True):
        upper, lower = t_u32[i * 2 + 0], t_u32[i * 2 + 1]
        # Pre-swap halves on lanes 1/2 so the shuffles below fetch the right word.
        upper0 = upper if lane_03 else lower
        lower0 = lower if lane_03 else upper
        upper0 = cute.arch.shuffle_sync(upper0, offset=upper_idx, mask_and_clamp=mask_and_clamp)
        lower0 = cute.arch.shuffle_sync(lower0, offset=lower_idx, mask_and_clamp=mask_and_clamp)
        # Byte-permute the shuffled words back into the two output u32 slots.
        t_u32[i * 2 + 0] = prmt(upper0, lower0, selector_upper)
        t_u32[i * 2 + 1] = prmt(upper0, lower0, selector_lower)
60
+
61
+
62
@cute.jit
def permute_Cregs_b32_for_stsm(t: cute.Tensor) -> None:
    """Permute and shuffle within 4 threads to change the layout from
    T0 | T1 | T2 | T3
    a b | c d | e f | g h
    to
    T0 | T1 | T2 | T3 | T0 | T1 | T2 | T3
    a | b | c | d | e | f | g | h
    This is so that we can use STSM (instead of STS.64) to store C registers without bank conflict.
    """

    assert t.element_type.width == 32
    assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b32 permutation"

    quad_idx = cute.arch.lane_idx() % 4
    # left_map = [0, 2, 1, 3]
    # right_map = [2, 0, 3, 1]
    # indexing isn't supported so we have to do arithmetic
    left_idx = quad_idx // 2 if quad_idx % 2 == 0 else 2 + quad_idx // 2
    right_idx = left_idx ^ 0b10

    # shfl.sync mask_and_clamp encoding: segment mask in bits 8..12, clamp in bits 0..4.
    # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000
    width = 4
    mask = cute.arch.WARP_SIZE - width
    clamp = cute.arch.WARP_SIZE - 1
    mask_and_clamp = mask << 8 | clamp

    # Process 4 registers at a time; each step is annotated with the quad-wide
    # value arrangement it produces (values a..h distributed over lanes T0..T3).
    for i in cutlass.range(cute.size(t.shape) // 4, unroll_full=True):
        for r in cutlass.range(2, unroll_full=True):
            left, right = t[i * 4 + r * 2 + 0], t[i * 4 + r * 2 + 1]
            # a b | c d | e f | g h -> a b | c d | f e | h g
            left0 = left if quad_idx < 2 else right
            right0 = right if quad_idx < 2 else left
            # a b | c d | f e | h g -> a b | f d | c e | h g
            left0 = cute.arch.shuffle_sync(left0, offset=left_idx, mask_and_clamp=mask_and_clamp)
            # a b | f d | c e | h g -> a e | f b | c g | h d
            right0 = cute.arch.shuffle_sync(right0, offset=right_idx, mask_and_clamp=mask_and_clamp)
            # a e | f b | c g | h d -> a e | b f | c g | d h
            t[i * 4 + r * 2 + 0] = left0 if quad_idx % 2 == 0 else right0
            t[i * 4 + r * 2 + 1] = right0 if quad_idx % 2 == 0 else left0
        # Swap the middle pair so the four registers are in STSM order.
        t[i * 4 + 1], t[i * 4 + 2] = t[i * 4 + 2], t[i * 4 + 1]
103
+
104
+
105
@cute.jit
def permute_Cregs_b32_for_ldsm(t: cute.Tensor) -> None:
    """Permute and shuffle within 4 threads to change the layout from
    T0 | T1 | T2 | T3 | T0 | T1 | T2 | T3
    a | b | c | d | e | f | g | h
    to
    T0 | T1 | T2 | T3
    a b | c d | e f | g h
    This is so that we can use LDSM (instead of LDS.64) to store C registers without bank conflict.
    """

    assert t.element_type.width == 32
    assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b32 permutation"

    quad_idx = cute.arch.lane_idx() % 4
    # left_map = [0, 2, 1, 3]
    # right_map = [1, 3, 0, 2]
    # indexing isn't supported so we have to do arithmetic
    left_idx = quad_idx // 2 if quad_idx % 2 == 0 else 2 + quad_idx // 2
    right_idx = left_idx ^ 0b01

    # shfl.sync mask_and_clamp encoding: segment mask in bits 8..12, clamp in bits 0..4.
    # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000
    width = 4
    mask = cute.arch.WARP_SIZE - width
    clamp = cute.arch.WARP_SIZE - 1
    mask_and_clamp = mask << 8 | clamp

    # This is just the inverse of permute_Cregs_b32_for_stsm
    # (same steps in reverse order; annotations track the quad-wide arrangement).
    for i in cutlass.range(cute.size(t.shape) // 4, unroll_full=True):
        t[i * 4 + 1], t[i * 4 + 2] = t[i * 4 + 2], t[i * 4 + 1]
        for r in cutlass.range(2, unroll_full=True):
            left, right = t[i * 4 + r * 2 + 0], t[i * 4 + r * 2 + 1]
            # a e | b f | c g | d h -> a e | f b | c g | h d
            left0 = left if quad_idx % 2 == 0 else right
            right0 = right if quad_idx % 2 == 0 else left
            # a e | f b | c g | h d -> a b | f d | c e | h g
            right0 = cute.arch.shuffle_sync(right0, offset=right_idx, mask_and_clamp=mask_and_clamp)
            # a b | f d | c e | h g -> a b | c d | f e | h g
            left0 = cute.arch.shuffle_sync(left0, offset=left_idx, mask_and_clamp=mask_and_clamp)
            # a b | c d | f e | h g -> a b | c d | e f | g h
            t[i * 4 + r * 2 + 0] = left0 if quad_idx < 2 else right0
            t[i * 4 + r * 2 + 1] = right0 if quad_idx < 2 else left0
147
+
148
+
149
@cute.jit
def concat_layout(*layouts: cute.Layout) -> cute.Layout:
    """Combine several layouts into one whose top-level modes are the inputs, in order."""
    combined_shapes = tuple(layout.shape for layout in layouts)
    combined_strides = tuple(layout.stride for layout in layouts)
    return cute.make_layout(combined_shapes, stride=combined_strides)
155
+
156
+
157
def convert_layout_acc_mn(acc_layout: cute.Layout) -> cute.Layout:
    """
    For Sm80, convert ((2, 2), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, MMA_N), ...).
    For Sm90, convert ((2, 2, V), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, V, MMA_N), ...).
    """
    # Build a col-major reference layout with the same shape, regroup its modes
    # into (M, N, ...) form, then compose so the regrouping applies to the
    # original strides.
    acc_layout_col_major = cute.make_layout(acc_layout.shape)
    acc_layout_mn = cute.make_layout(
        (
            (acc_layout_col_major.shape[0][1], acc_layout_col_major.shape[1]),  # MMA_M
            (
                acc_layout_col_major.shape[0][0],
                *acc_layout_col_major.shape[0][2:],
                acc_layout_col_major.shape[2],
            ),  # MMA_N
            *acc_layout_col_major.shape[3:],  # any remaining trailing modes pass through
        ),
        stride=(
            (acc_layout_col_major.stride[0][1], acc_layout_col_major.stride[1]),  # MMA_M
            (
                acc_layout_col_major.stride[0][0],
                *acc_layout_col_major.stride[0][2:],
                acc_layout_col_major.stride[2],
            ),  # MMA_N
            *acc_layout_col_major.stride[3:],
        ),
    )
    return cute.composition(acc_layout, acc_layout_mn)
184
+
185
+
186
def make_acc_tensor_mn_view(acc: cute.Tensor) -> cute.Tensor:
    """View an accumulator fragment with its layout regrouped into (M, N, ...) modes."""
    mn_layout = convert_layout_acc_mn(acc.layout)
    return cute.make_tensor(acc.iterator, mn_layout)
188
+
189
+
190
def reshape_acc_to_mn(acc: cute.Tensor) -> cute.Tensor:
    """Alias of :func:`make_acc_tensor_mn_view`, kept for API compatibility.

    The original body duplicated make_acc_tensor_mn_view line for line;
    delegating instead keeps the two names from drifting apart.
    """
    return make_acc_tensor_mn_view(acc)
192
+
193
+
194
@cute.jit
def convert_layout_acc_frgA(acc_layout: cute.Layout) -> cute.Layout:
    """Regroup an accumulator layout so it can feed a second GEMM as the A fragment."""
    # For back to back gemm, convert layout of acc0 to gemm 1 accept layout.
    # For Sm80, as the mma instruction shape is 16x8x16, we need to convert from (4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
    # For Sm90, FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N))
    # TODO: Sm90 FP8
    # Rank-3 first mode distinguishes the Sm90 accumulator shape from Sm80's.
    if const_expr(cute.rank(acc_layout.shape[0]) == 3):  # Sm90
        l = cute.logical_divide(
            acc_layout, ((None, None, 2), None, None)
        )  # ((2, 2, (2, N / 16)), MMA_M, MMA_N)
        rA_mma_view = cute.make_layout(
            (
                (l.shape[0][0], l.shape[0][1], l.shape[0][2][0]),
                l.shape[1],
                (l.shape[0][2][1], l.shape[2]),
            ),
            stride=(
                (l.stride[0][0], l.stride[0][1], l.stride[0][2][0]),
                l.stride[1],
                (l.stride[0][2][1], l.stride[2]),
            ),
        )
    else:  # Sm80
        # (4, MMA_M, MMA_N) -> (4, MMA_M, (2, MMA_N / 2))
        l = cute.logical_divide(acc_layout, (None, None, 2))
        rA_mma_view = cute.make_layout(
            (
                (l.shape[0], l.shape[2][0]),
                l.shape[1],
                l.shape[2][1],
            ),
            stride=(
                (l.stride[0], l.stride[2][0]),
                l.stride[1],
                l.stride[2][1],
            ),
        )
    return rA_mma_view
232
+
233
+
234
def reshape_acc_to_frgA(acc: cute.Tensor) -> cute.Tensor:
    """View an accumulator fragment with the layout expected by a follow-up GEMM's A operand."""
    frgA_layout = convert_layout_acc_frgA(acc.layout)
    return cute.make_tensor(acc.iterator, frgA_layout)
236
+
237
+
238
def convert_layout_zero_stride(
    input: cute.Tensor | cute.Layout, ref_layout: cute.Layout
) -> cute.Tensor | cute.Layout:
    """Regroup ``input``'s flat modes into (broadcast-free, broadcast) halves.

    Modes of ``ref_layout`` with stride 0 (broadcast modes) are collected into
    the second top-level mode; all other modes go into the first. Returns a
    tensor when given a tensor, a layout when given a layout (the original
    ``-> cute.Layout`` annotation missed the tensor case).
    """
    layout = input.layout if const_expr(isinstance(input, cute.Tensor)) else input
    # Group the modes with non-zero stride in the ref_layout together,
    # and the modes with zero stride together
    layout_flat = cute.flatten(layout)
    ref_layout_flat = cute.flatten(ref_layout)
    nonzero_modes = [i for i in range(cute.rank(layout_flat)) if ref_layout_flat[i].stride != 0]
    zero_modes = [i for i in range(cute.rank(layout_flat)) if ref_layout_flat[i].stride == 0]
    # There's an edge case when all modes are zero stride
    new_shape = (
        tuple(layout_flat[i].shape for i in nonzero_modes) if len(nonzero_modes) > 0 else (1,),
        tuple(layout_flat[i].shape for i in zero_modes),
    )
    new_stride = (
        tuple(layout_flat[i].stride for i in nonzero_modes) if len(nonzero_modes) > 0 else (0,),
        tuple(layout_flat[i].stride for i in zero_modes),
    )
    out_layout = cute.make_layout(new_shape, stride=new_stride)
    if const_expr(isinstance(input, cute.Tensor)):
        return cute.make_tensor(input.iterator, out_layout)
    else:
        return out_layout
262
+
263
+
264
def mma_partition_C_vec(
    sVec: cute.Tensor, thr_mma: cute.core.ThrMma, expand_shape: int, is_colvec: bool
) -> cute.Tensor:
    """Partition a staged smem vector as a broadcast C-operand for ``thr_mma``.

    ``sVec`` is (length, stage). A stride-0 mode of extent ``expand_shape`` is
    inserted so the vector broadcasts along columns (is_colvec) or rows
    (rowvec), then the result is partitioned with ``partition_C`` and collapsed
    back to the per-thread vector (keeping the stage mode).
    """
    assert cute.rank(sVec) == 2
    assert sVec.stride[0] == 1
    stage = sVec.shape[1]
    shape = (
        (sVec.shape[0], expand_shape, stage)
        if const_expr(is_colvec)
        else (expand_shape, sVec.shape[0], stage)
    )
    # Broadcast mode gets stride 0; the stage mode keeps its original stride.
    stride = (1, 0, sVec.stride[1]) if const_expr(is_colvec) else (0, 1, sVec.stride[1])
    sVec_mma = cute.make_tensor(sVec.iterator, cute.make_layout(shape, stride=stride))
    tC_sVec = make_acc_tensor_mn_view(thr_mma.partition_C(sVec_mma))
    # Drop the broadcast mode again; only the vector dim and stage remain.
    return tC_sVec[None, 0, None] if const_expr(is_colvec) else tC_sVec[0, None, None]
279
+
280
+
281
def mma_partition_A_vec(
    sVec: cute.Tensor, thr_mma: cute.core.ThrMma, expand_shape: int, is_colvec: bool
) -> cute.Tensor:
    """Partition a staged smem vector as a broadcast A-operand for ``thr_mma``.

    ``sVec`` is (length, stage); a stride-0 mode of extent ``expand_shape`` is
    inserted so the vector broadcasts along the other dimension, the expanded
    tensor is partitioned via ``partition_A``, and the broadcast mode is then
    collapsed away again.
    """
    assert cute.rank(sVec) == 2
    assert sVec.stride[0] == 1
    num_stages = sVec.shape[1]
    bcast_shape = (
        (sVec.shape[0], expand_shape, num_stages)
        if const_expr(is_colvec)
        else (expand_shape, sVec.shape[0], num_stages)
    )
    bcast_stride = (
        (1, 0, sVec.stride[1]) if const_expr(is_colvec) else (0, 1, sVec.stride[1])
    )
    bcast_layout = cute.make_layout(bcast_shape, stride=bcast_stride)
    sVec_bcast = cute.make_tensor(sVec.iterator, bcast_layout)
    tA_sVec = make_acc_tensor_mn_view(thr_mma.partition_A(sVec_bcast))
    return tA_sVec[None, 0, None] if const_expr(is_colvec) else tA_sVec[0, None, None]
build/torch-cuda/quack/pipeline.py ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Tri Dao.
2
+
3
+ from typing import Optional
4
+ from dataclasses import dataclass
5
+
6
+ import cutlass.cute as cute
7
+ from cutlass import Boolean, Int32, const_expr
8
+ from cutlass.cutlass_dsl import if_generate, and_, dsl_user_op
9
+ from cutlass.pipeline import MbarrierArray, CooperativeGroup, PipelineOp, pipeline_init_wait
10
+ from cutlass.pipeline import PipelineAsync, PipelineTmaAsync, PipelineState, PipelineUserType
11
+ from cutlass.pipeline import PipelineTmaUmma
12
+
13
+
14
class PipelineStateWAdvance(PipelineState):
    """PipelineState extended with O(1) advancement by multiple iterations."""

    @dsl_user_op
    def advance_iters(self, num_iterations: Int32, *, loc=None, ip=None):
        """Advance count/index/phase as if ``advance()`` were called ``num_iterations`` times."""
        self._count += Int32(num_iterations)
        new_index = self._index + Int32(num_iterations)
        # How many times did we cross the stages boundary
        num_crossings = new_index // self.stages
        # The phase is a single bit that flips once per crossing. XOR-ing with
        # the raw crossing count corrupts it whenever the advance wraps the
        # stage ring more than once (num_crossings >= 2), so reduce the
        # crossing count to its parity first.
        self._phase ^= num_crossings % 2
        self._index = new_index % self.stages

    # This can be overridden by derived classes
    def __new_from_mlir_values__(self, values):
        return PipelineStateWAdvance(
            self.stages, Int32(values[0]), Int32(values[1]), Int32(values[2])
        )
29
+
30
+
31
def make_pipeline_state(type: PipelineUserType, stages: int):
    """
    Creates a pipeline state. Producers are assumed to start with an empty buffer and have a flipped phase bit of 1.
    """
    if type is PipelineUserType.Producer:
        initial_phase = Int32(1)
    elif type is PipelineUserType.Consumer:
        initial_phase = Int32(0)
    else:
        assert False, "Error: invalid PipelineUserType specified for make_pipeline_state."
    return PipelineStateWAdvance(stages, Int32(0), Int32(0), initial_phase)
51
+
52
+
53
@dataclass(frozen=True)
class PipelineTmaCpAsync(PipelineTmaAsync):
    """
    PipelineTmaCpAsync is used for CpAsync + TMA producers and AsyncThread consumers
    """

    @staticmethod
    def create(
        *,
        num_stages: int,
        producer_group: CooperativeGroup,
        consumer_group: CooperativeGroup,
        tx_count: int,
        barrier_storage: cute.Pointer = None,
        cta_layout_vmnk: Optional[cute.Layout] = None,
        tidx: Optional[Int32] = None,
    ):
        """
        This helper function computes any necessary attributes and returns an instance of PipelineTmaAsync.
        :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers
        :type barrier_storage: cute.Pointer
        :param num_stages: Number of buffer stages for this pipeline
        :type num_stages: Int32
        :param producer_group: CooperativeGroup for the producer agent
        :type producer_group: CooperativeGroup
        :param consumer_group: CooperativeGroup for the consumer agent
        :type consumer_group: CooperativeGroup
        :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage
        :type tx_count: int
        :param cta_layout_vmnk: Layout of the cluster shape
        :type cta_layout_vmnk: cute.Layout | None
        :param tidx: thread index to consumer async threads
        :type tidx: Int32 | None
        """
        if not isinstance(barrier_storage, cute.Pointer):
            raise ValueError(
                f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}"
            )

        producer_type = PipelineOp.TmaLoad
        consumer_type = PipelineOp.AsyncThread

        producer = (producer_type, producer_group)
        consumer = (consumer_type, consumer_group)

        # Two mbarrier arrays laid out back to back in smem: full (transaction)
        # barriers first, then empty barriers.
        sync_object_full = PipelineAsync._make_sync_object(
            barrier_storage.align(min_align=8), num_stages, producer, tx_count
        )
        sync_object_empty = PipelineAsync._make_sync_object(
            barrier_storage.align(min_align=8) + num_stages, num_stages, consumer
        )
        if tidx is None:
            tidx, _, _ = cute.arch.thread_idx()
        if cta_layout_vmnk is None:
            cta_layout_vmnk = cute.make_layout((1, 1, 1, 1))
        (
            dst_rank,
            is_signalling_thread,
        ) = PipelineTmaAsync.init_empty_barrier_arrive_signal(cta_layout_vmnk, tidx)
        # For a single-CTA cluster, empty-barrier arrives are local (no dst rank).
        # NOTE(review): the else branch is a no-op (dst_rank keeps its value).
        if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1:
            dst_rank = None
        else:
            dst_rank = dst_rank

        producer_mask = None

        pipeline_init_wait(cta_layout_vmnk)

        return PipelineTmaCpAsync(
            sync_object_full,
            sync_object_empty,
            num_stages,
            producer_mask,
            dst_rank,
            is_signalling_thread,
        )

    @dsl_user_op
    def producer_acquire(
        self,
        state: PipelineState,
        try_acquire_token: Optional[Boolean] = None,
        is_tma_warp: Optional[Boolean] = True,
        *,
        loc=None,
        ip=None,
    ):
        """
        TMA producer commit conditionally waits on buffer empty and sets the transaction barrier.

        :param try_acquire_token: if nonzero, a prior try_acquire already
            observed the buffer empty and the wait is skipped
        :param is_tma_warp: only the warp issuing TMA arrives on the full barrier
        """
        if_generate(
            try_acquire_token is None or try_acquire_token == 0,
            lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip),
        )
        # This is the difference between this and PipelineTmaAsync: we could have multiple
        # warps calling this, but only 1 warp should do the arrive on the full barrier
        if_generate(
            is_tma_warp,
            lambda: self.sync_object_full.arrive(state.index, self.producer_mask, loc=loc, ip=ip),
        )

    @dsl_user_op
    def producer_cpasync_commit(self, state: PipelineState, *, loc=None, ip=None):
        """
        We need the mbarrier to track the completion of cp.async
        """
        cute.arch.cp_async_mbarrier_arrive_noinc(self.producer_get_barrier(state, loc=loc, ip=ip), loc=loc, ip=ip)
160
+
161
+
162
class MbarrierArrayWDropCount(MbarrierArray):
    """MbarrierArray variant whose expected arrive count can be reduced at runtime.

    ``drop_count`` (a dynamic Int32) is subtracted from the cooperative
    group's static size before the mbarriers are initialized — presumably to
    exclude warps that opt out of participating; confirm against callers.
    """

    @dsl_user_op
    def __init__(
        self,
        barrier_storage: cute.Pointer,
        num_stages: int,
        agent: tuple[PipelineOp, CooperativeGroup],
        tx_count: int = 0,
        drop_count: Optional[Int32] = None,
        *,
        loc=None,
        ip=None,
    ) -> None:
        # Mirrors MbarrierArray.__init__ but applies drop_count before init;
        # deliberately does not call super().__init__().
        self.barrier_storage = barrier_storage
        self.tx_count = tx_count
        self.num_stages = num_stages
        self.op_type, self.cg = agent
        self.arrive_count = self.cg.size
        self.drop_count = drop_count

        if self.num_stages <= 0:
            raise ValueError("Error: Mbarrier stage count must be greater than 0.")
        if self.arrive_count <= 0:
            raise ValueError("Error: Mbarrier arrive count must be greater than 0.")
        if self.op_type is PipelineOp.TmaLoad and self.tx_count < 0:
            raise ValueError("Error: Mbarrier tx count must not be less than 0 for TMA ops.")

        if const_expr(drop_count is not None):
            self.arrive_count = self.arrive_count - drop_count

        # Store mbarrier base pointer
        self.mbarrier_base = self.barrier_storage

        # Mbarrier initialization in constructor
        self.mbarrier_init(loc=loc, ip=ip)

    def __extract_mlir_values__(self):
        # NOTE(review): assumes drop_count is not None here — a None drop_count
        # would put a non-MLIR value in the list; confirm callers always pass one.
        return [self.barrier_storage, self.drop_count]

    def __new_from_mlir_values__(self, values):
        return MbarrierArrayWDropCount(
            values[0], self.num_stages, (self.op_type, self.cg), self.tx_count, values[1]
        )
204
+ )
205
+
206
+
207
+ @dataclass(frozen=True)
208
+ class PipelineTmaCpAsyncUmma(PipelineTmaUmma):
209
+ """
210
+ PipelineTmaCpAsync is used for CpAsync + TMA producers and UMMA consumers
211
+ (e.g. Blackwell mainloops)
212
+ """
213
+
214
+ @staticmethod
215
+ def create(
216
+ *,
217
+ num_stages: int,
218
+ producer_group: CooperativeGroup,
219
+ consumer_group: CooperativeGroup,
220
+ tx_count: int,
221
+ barrier_storage: cute.Pointer = None,
222
+ cta_layout_vmnk: Optional[cute.Layout] = None,
223
+ producer_drop_count: Optional[Int32] = None,
224
+ mcast_mode_mn: tuple[int, int] = (1, 1),
225
+ ):
226
+ """
227
+ This helper function computes any necessary attributes and returns an instance of PipelineTmaUmma.
228
+ :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers
229
+ :type barrier_storage: cute.Pointer
230
+ :param num_stages: Number of buffer stages for this pipeline
231
+ :type num_stages: Int32
232
+ :param producer_group: `CooperativeGroup` for the producer agent
233
+ :type producer_group: CooperativeGroup
234
+ :param consumer_group: `CooperativeGroup` for the consumer agent
235
+ :type consumer_group: CooperativeGroup
236
+ :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage
237
+ :type tx_count: int
238
+ :param cta_layout_vmnk: Layout of the cluster shape
239
+ :type cta_layout_vmnk: cute.Layout | None
240
+ :param mcast_mode_mn: Tuple specifying multicast modes for m and n dimensions (each 0 or 1)
241
+ :type mcast_mode_mn: tuple[int, int], optional
242
+ """
243
+ if not isinstance(barrier_storage, cute.Pointer):
244
+ raise ValueError(
245
+ f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}"
246
+ )
247
+
248
+ producer_type = PipelineOp.TmaLoad
249
+ consumer_type = PipelineOp.TCGen05Mma
250
+
251
+ producer = (producer_type, producer_group)
252
+ consumer = (consumer_type, consumer_group)
253
+
254
+ sync_object_full = MbarrierArrayWDropCount(
255
+ barrier_storage.align(min_align=8),
256
+ num_stages,
257
+ producer,
258
+ tx_count,
259
+ drop_count=producer_drop_count,
260
+ )
261
+ sync_object_empty = PipelineTmaUmma._make_sync_object(
262
+ barrier_storage.align(min_align=8) + num_stages, num_stages, consumer
263
+ )
264
+
265
+ if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1:
266
+ # No mcast mask if not using clusters
267
+ producer_mask = None
268
+ # All threadblocks are leaders if not using clusters
269
+ is_leader_cta = True
270
+ else:
271
+ producer_mask = PipelineTmaUmma._compute_mcast_arrival_mask(cta_layout_vmnk, mcast_mode_mn)
272
+ is_leader_cta = PipelineTmaUmma._compute_is_leader_cta(cta_layout_vmnk)
273
+
274
+ cta_group = (
275
+ cute.nvgpu.tcgen05.CtaGroup.ONE
276
+ if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1
277
+ else cute.nvgpu.tcgen05.CtaGroup.TWO
278
+ )
279
+
280
+ consumer_mask = producer_mask
281
+
282
+ pipeline_init_wait(cta_layout_vmnk)
283
+
284
+ return PipelineTmaCpAsyncUmma(
285
+ sync_object_full,
286
+ sync_object_empty,
287
+ num_stages,
288
+ producer_mask,
289
+ consumer_mask,
290
+ is_leader_cta,
291
+ cta_group,
292
+ )
293
+
294
+ @dsl_user_op
295
+ def producer_acquire(
296
+ self,
297
+ state: PipelineState,
298
+ try_acquire_token: Optional[Boolean] = None,
299
+ is_tma_warp: Optional[Boolean] = True,
300
+ *,
301
+ loc=None,
302
+ ip=None,
303
+ ):
304
+ """
305
+ TMA producer commit conditionally waits on buffer empty and sets the
306
+ transaction barrier for leader threadblocks.
307
+ """
308
+ if_generate(
309
+ try_acquire_token is None or try_acquire_token == 0,
310
+ lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip),
311
+ )
312
+ # This is the difference between this and PipelineTmaAsync: we could have multiple
313
+ # warps calling this, but only 1 warp should do the arrive on the full barrier
314
+ if_generate(
315
+ and_(self.is_leader_cta, is_tma_warp),
316
+ lambda: self.sync_object_full.arrive(state.index, self.producer_mask, loc=loc, ip=ip),
317
+ )
318
+
319
+ @dsl_user_op
320
+ def producer_cpasync_commit(self, state: PipelineState, *, loc=None, ip=None):
321
+ """
322
+ We need the mbarrier to track the completion of cp.async
323
+ """
324
+ cute.arch.cp_async_mbarrier_arrive_noinc(self.producer_get_barrier(state, loc=loc, ip=ip), loc=loc, ip=ip)
build/torch-cuda/quack/reduce.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Tri Dao.
2
+
3
+ import math
4
+ import operator
5
+ from typing import Callable, Optional
6
+
7
+ import cutlass
8
+ import cutlass.cute as cute
9
+ from cutlass import Int32, Int64, Float32, Boolean, const_expr
10
+
11
+ from . import utils as utils
12
+
13
+
14
+ @cute.jit
15
+ def block_reduce(
16
+ val: cute.Numeric, op: Callable, reduction_buffer: cute.Tensor, init_val: cute.Numeric = 0.0
17
+ ) -> cute.Numeric:
18
+ """reduction_buffer has shape (num_warps / warp_per_row, warps_per_row)"""
19
+ lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
20
+ warps_per_row = cute.size(reduction_buffer.shape[1])
21
+ row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
22
+ if lane_idx == 0:
23
+ reduction_buffer[row_idx, col_idx] = val
24
+ cute.arch.barrier()
25
+ block_reduce_val = init_val
26
+ if lane_idx < warps_per_row:
27
+ block_reduce_val = reduction_buffer[row_idx, lane_idx]
28
+ return cute.arch.warp_reduction(block_reduce_val, op)
29
+
30
+
31
+ @cute.jit
32
+ def cluster_reduce(
33
+ val: cute.Numeric,
34
+ op: Callable,
35
+ reduction_buffer: cute.Tensor,
36
+ mbar_ptr: cute.Pointer,
37
+ init_val: cute.Numeric = 0.0,
38
+ phase: Optional[Int32] = None,
39
+ ) -> cute.Numeric:
40
+ """reduction_buffer has shape (num_warps / warps_per_row, (warps_per_row, cluster_n))"""
41
+ cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
42
+ lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
43
+ rows_per_block, (warps_per_row, cluster_n) = reduction_buffer.shape
44
+ row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
45
+ if warp_idx == 0:
46
+ with cute.arch.elect_one():
47
+ num_warps = rows_per_block * warps_per_row
48
+ cute.arch.mbarrier_arrive_and_expect_tx(
49
+ mbar_ptr,
50
+ num_warps * cluster_n * reduction_buffer.element_type.width // 8,
51
+ )
52
+ if lane_idx < cluster_n:
53
+ utils.store_shared_remote(
54
+ val,
55
+ utils.elem_pointer(reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster))),
56
+ mbar_ptr,
57
+ peer_cta_rank_in_cluster=lane_idx,
58
+ )
59
+ cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0)
60
+ block_reduce_val = init_val
61
+ num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
62
+ for i in cutlass.range_constexpr(num_iter):
63
+ idx = lane_idx + i * cute.arch.WARP_SIZE
64
+ if idx < cute.size(reduction_buffer, mode=[1]):
65
+ block_reduce_val = op(block_reduce_val, reduction_buffer[row_idx, idx])
66
+ return cute.arch.warp_reduction(block_reduce_val, op)
67
+
68
+
69
+ @cute.jit
70
+ def block_or_cluster_reduce(
71
+ val: cute.Numeric,
72
+ op: Callable,
73
+ reduction_buffer: cute.Tensor,
74
+ mbar_ptr: Optional[cute.Pointer],
75
+ phase: Optional[Int32] = None,
76
+ init_val: cute.Numeric = 0.0,
77
+ ) -> cute.Numeric:
78
+ """Perform either block or cluster reduction based on whether mbar_ptr is provided."""
79
+ if const_expr(mbar_ptr is None):
80
+ return block_reduce(val, op, reduction_buffer, init_val=init_val)
81
+ else:
82
+ return cluster_reduce(val, op, reduction_buffer, mbar_ptr, phase=phase, init_val=init_val)
83
+
84
+
85
+ @cute.jit
86
+ def row_reduce(
87
+ x: cute.TensorSSA | cute.Numeric,
88
+ op: cute.ReductionOp,
89
+ threads_per_row: cutlass.Constexpr[int],
90
+ reduction_buffer: Optional[cute.Tensor] = None,
91
+ mbar_ptr: Optional[cute.Pointer] = None,
92
+ phase: Optional[Int32] = None,
93
+ init_val: cute.Numeric = 0.0,
94
+ hook_fn: Optional[Callable] = None,
95
+ ) -> cute.Numeric:
96
+ """reduction_buffer must have shape (num_warps / warps_per_row, (warps_per_row, cluster_n))"""
97
+ if const_expr(isinstance(x, cute.TensorSSA)):
98
+ val = x.reduce(op, init_val=init_val, reduction_profile=0)
99
+ else:
100
+ val = x
101
+ warp_op = {
102
+ cute.ReductionOp.ADD: operator.add,
103
+ cute.ReductionOp.MAX: cute.arch.fmax if const_expr(x.dtype == Float32) else max,
104
+ cute.ReductionOp.MIN: min,
105
+ cute.ReductionOp.MUL: operator.mul,
106
+ }[op]
107
+ val = cute.arch.warp_reduction(
108
+ val,
109
+ warp_op,
110
+ threads_in_group=min(threads_per_row, cute.arch.WARP_SIZE),
111
+ )
112
+ if const_expr(hook_fn is not None):
113
+ hook_fn()
114
+ if const_expr(reduction_buffer is not None):
115
+ warps_per_row, cluster_n = reduction_buffer.shape[1]
116
+ assert cluster_n == 1 or mbar_ptr is not None, (
117
+ "mbar_ptr must be provided for cluster reduction"
118
+ )
119
+ if const_expr(warps_per_row > 1 or cluster_n > 1):
120
+ val = block_or_cluster_reduce(
121
+ val, warp_op, reduction_buffer, mbar_ptr, phase=phase, init_val=init_val
122
+ )
123
+ return val
124
+
125
+
126
+ @cute.jit
127
+ def online_softmax_reduce(
128
+ x: cute.TensorSSA,
129
+ threads_per_row: cutlass.Constexpr[int],
130
+ reduction_buffer: Optional[cute.Tensor] = None,
131
+ mbar_ptr: Optional[cute.Pointer] = None,
132
+ hook_fn: Optional[Callable] = None,
133
+ phase: Optional[Int32] = None,
134
+ return_exp_x: bool = False,
135
+ ) -> [Float32, Float32, Optional[cute.TensorSSA]]:
136
+ assert x.dtype == Float32, "x must be of type Float32"
137
+ """reduction_buffer must have shape (num_warps / warps_per_row, (warps_per_row, cluster_n), 2)"""
138
+ max_x = cute.arch.warp_reduction(
139
+ x.reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0),
140
+ cute.arch.fmax,
141
+ threads_in_group=min(threads_per_row, cute.arch.WARP_SIZE),
142
+ )
143
+ log2_e = math.log2(math.e)
144
+ exp_x = cute.math.exp2(x * log2_e - (max_x * log2_e), fastmath=True)
145
+ sum_exp_x = cute.arch.warp_reduction(
146
+ exp_x.reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0),
147
+ operator.add,
148
+ threads_in_group=min(threads_per_row, cute.arch.WARP_SIZE),
149
+ )
150
+ if const_expr(hook_fn is not None):
151
+ hook_fn()
152
+ if const_expr(reduction_buffer is not None):
153
+ rows_per_block, (warps_per_row, cluster_n) = reduction_buffer.shape
154
+ assert cluster_n == 1 or mbar_ptr is not None, (
155
+ "mbar_ptr must be provided for cluster reduction"
156
+ )
157
+ if const_expr(warps_per_row > 1 or cluster_n > 1):
158
+ assert reduction_buffer.element_type == Int64, (
159
+ "reduction_buffer must be of type cute.Int64"
160
+ )
161
+ lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
162
+ row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
163
+ if const_expr(mbar_ptr is None):
164
+ if lane_idx == 0:
165
+ reduction_buffer[row_idx, col_idx] = utils.f32x2_to_i64(max_x, sum_exp_x)
166
+ cute.arch.barrier()
167
+ max_x_single_warp = -Float32.inf
168
+ sum_exp_x = 0.0
169
+ if lane_idx < warps_per_row:
170
+ max_x_single_warp, sum_exp_x = utils.i64_to_f32x2(
171
+ reduction_buffer[row_idx, lane_idx]
172
+ )
173
+ max_x_final = cute.arch.warp_reduction(max_x_single_warp, cute.arch.fmax)
174
+ sum_exp_x *= cute.math.exp(max_x_single_warp - max_x_final, fastmath=True)
175
+ sum_exp_x = cute.arch.warp_reduction(sum_exp_x, operator.add)
176
+ if const_expr(return_exp_x):
177
+ exp_x *= cute.math.exp(max_x - max_x_final, fastmath=True)
178
+ max_x = max_x_final
179
+ else:
180
+ cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
181
+ if warp_idx == 0:
182
+ with cute.arch.elect_one():
183
+ num_warps = rows_per_block * warps_per_row
184
+ cute.arch.mbarrier_arrive_and_expect_tx(
185
+ mbar_ptr,
186
+ num_warps * cluster_n * reduction_buffer.element_type.width // 8,
187
+ )
188
+ if lane_idx < cluster_n:
189
+ utils.store_shared_remote(
190
+ utils.f32x2_to_i64(max_x, sum_exp_x),
191
+ utils.elem_pointer(
192
+ reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster))
193
+ ),
194
+ mbar_ptr,
195
+ peer_cta_rank_in_cluster=lane_idx,
196
+ )
197
+ cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0)
198
+ num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
199
+ max_x_single_warp = cute.make_fragment(num_iter, Float32)
200
+ max_x_single_warp.fill(-Float32.inf)
201
+ sum_exp_x_single_warp = cute.make_fragment(num_iter, Float32)
202
+ sum_exp_x_single_warp.fill(0.0)
203
+ for i in cutlass.range_constexpr(num_iter):
204
+ idx = lane_idx + i * cute.arch.WARP_SIZE
205
+ if idx < cute.size(reduction_buffer, mode=[1]):
206
+ max_x_single_warp[i], sum_exp_x_single_warp[i] = utils.i64_to_f32x2(
207
+ reduction_buffer[row_idx, idx]
208
+ )
209
+ max_x_final = max_x_single_warp.load().reduce(
210
+ cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0
211
+ )
212
+ max_x_final = cute.arch.warp_reduction(max_x_final, cute.arch.fmax)
213
+ sum_exp_x = 0.0
214
+ for i in cutlass.range_constexpr(num_iter):
215
+ sum_exp_x += sum_exp_x_single_warp[i] * cute.math.exp(
216
+ max_x_single_warp[i] - max_x_final, fastmath=True
217
+ )
218
+ sum_exp_x = cute.arch.warp_reduction(sum_exp_x, operator.add)
219
+ if const_expr(return_exp_x):
220
+ exp_x *= cute.math.exp(max_x - max_x_final, fastmath=True)
221
+ max_x = max_x_final
222
+ return max_x, sum_exp_x, (exp_x if const_expr(return_exp_x) else None)
223
+
224
+
225
+ @cute.jit
226
+ def sum_swap_shuffle(
227
+ X: cute.Tensor, elem_per_lane: int = 1, subwarp_size: int = 1, warp_size: int = 32
228
+ ) -> cute.Tensor:
229
+ """
230
+ For warp reduction, we use Swap Shuffle
231
+ The normal way to reduction among threads:
232
+ use shuffle to let *** the first half of threads *** have *** whole data *** from the second half of threads.
233
+ After each step of reduction, a half of threads won't work in the following steps.
234
+ That is, as the reduction progresses, the efficiency of shuffle & reduction instructions gradually change from 1/2, 1/4 to 1/32 (the worst case).
235
+ To overcome this shortcoming, for a NxN matrix to be reduced among N threads as a 1XN vectors,
236
+ we use swap & shuffle aiming to let *** each half of threads *** have *** a half of data *** from the other half of threads.
237
+ After reduction, each half of threads should deal with a (N/2)x(N/2) sub-matrix independently in the following step.
238
+ We can recursively do this until the problem size is 1.
239
+ """
240
+ assert (
241
+ subwarp_size >= 1
242
+ and subwarp_size <= 32
243
+ and subwarp_size == 1 << int(math.log2(subwarp_size))
244
+ )
245
+ assert (
246
+ warp_size <= 32
247
+ and warp_size % subwarp_size == 0
248
+ and warp_size == 1 << int(math.log2(warp_size))
249
+ )
250
+ lane_idx = cute.arch.lane_idx() // subwarp_size
251
+ X = cute.logical_divide(X, cute.make_layout(elem_per_lane)) # (elem_per_lane, M)
252
+ numvec = cute.size(X, mode=[1])
253
+ assert numvec <= 32 // subwarp_size
254
+ # If X has more values than warp_size // subwarp_size, we first do a normal warp reduction
255
+ # to sum up values held by lanes further than size(X) away
256
+ for i in cutlass.range(
257
+ int(math.log2(numvec)), int(math.log2(warp_size // subwarp_size)), unroll_full=True
258
+ ):
259
+ for v in cutlass.range(cute.size(X), unroll_full=True):
260
+ shfl_val = cute.arch.shuffle_sync_bfly(X[v], offset=(1 << i) * subwarp_size)
261
+ X[v] = X[v] + shfl_val
262
+ for logm in cutlass.range_constexpr(int(math.log2(cute.size(X, mode=[1]))) - 1, -1, -1):
263
+ m = 1 << logm
264
+ for r in cutlass.range(m, unroll_full=True):
265
+ frg_A = X[None, r]
266
+ frg_B = X[None, r + m]
267
+ # First half of threads swap fragments from the first half of data to the second
268
+ should_swap = not Boolean(lane_idx & m)
269
+ for v in cutlass.range(cute.size(frg_A), unroll_full=True):
270
+ # Step 1: swap
271
+ lower, upper = frg_A[v], frg_B[v]
272
+ frg_A[v] = upper if should_swap else lower
273
+ frg_B[v] = lower if should_swap else upper
274
+ # Step 2: shuffle
275
+ # each half of threads get a half of data from the other half of threads
276
+ shfl_val = cute.arch.shuffle_sync_bfly(frg_A[v], offset=m * subwarp_size)
277
+ # Step 3: reduction
278
+ frg_A[v] = frg_B[v] + shfl_val
279
+ return X[None, 0]
build/torch-cuda/quack/reduction_base.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
2
+
3
+ from typing import Type, Tuple, Optional
4
+
5
+ import cutlass
6
+ import cutlass.cute as cute
7
+ from cutlass import Int32, Int64, Float32, const_expr
8
+
9
+ from . import copy_utils as copy_utils
10
+
11
+
12
+ class ReductionBase:
13
+ def __init__(self, dtype: Type[cutlass.Numeric], N: int, stage: int, reduction_dtype=Float32):
14
+ self.dtype = dtype
15
+ self.N = N
16
+ self.stage = stage
17
+ self.reduction_dtype = reduction_dtype
18
+
19
+ def _threads_per_row(self):
20
+ raise NotImplementedError()
21
+
22
+ def _num_threads(self):
23
+ return 128 if self.N <= 16384 else 256
24
+
25
+ def _set_cluster_n(self):
26
+ self.cluster_n = 1
27
+
28
+ def _get_tiled_copy(self, vecsize: int = 1):
29
+ assert self.N % vecsize == 0, f"Input N {self.N} is not divisible by vector size {vecsize}"
30
+ threads_per_row = self._threads_per_row()
31
+ num_threads = self._num_threads()
32
+ assert num_threads % cute.arch.WARP_SIZE == 0
33
+ num_blocks_N = cute.ceil_div(self.N // vecsize, threads_per_row * self.cluster_n)
34
+ tiler_mn = (num_threads // threads_per_row, vecsize * num_blocks_N * threads_per_row)
35
+ tiled_copy = copy_utils.tiled_copy_2d(self.dtype, threads_per_row, num_threads, vecsize)
36
+ return tiled_copy, tiler_mn, threads_per_row
37
+
38
+ def _get_reduction_buffer_layout(self, tv_layout: cute.Layout, cluster_n: int):
39
+ num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
40
+ warps_per_row = (
41
+ num_warps
42
+ if cute.rank(tv_layout.shape[0]) == 1
43
+ else max(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1)
44
+ )
45
+ return cute.make_ordered_layout(
46
+ (num_warps // warps_per_row, (warps_per_row, cluster_n), self.stage),
47
+ order=(1, 0, 2),
48
+ )
49
+
50
+ def _allocate_reduction_buffer_and_mbar(
51
+ self, smem: cutlass.utils.SmemAllocator, tv_layout: cute.Layout, is_persistent: bool = False
52
+ ) -> Tuple[cute.Tensor, Optional[cute.Pointer]]:
53
+ reduction_buffer = smem.allocate_tensor(
54
+ self.reduction_dtype,
55
+ self._get_reduction_buffer_layout(tv_layout, self.cluster_n),
56
+ byte_alignment=8,
57
+ )
58
+ if const_expr(self.cluster_n > 1):
59
+ mbar_ptr = smem.allocate_array(
60
+ Int64, num_elems=self.stage if not is_persistent else self.stage * 2
61
+ )
62
+ else:
63
+ mbar_ptr = None
64
+ return reduction_buffer, mbar_ptr
65
+
66
+ @cute.jit
67
+ def _initialize_cluster(
68
+ self,
69
+ tidx: Int32,
70
+ mbar_ptr: cute.Pointer,
71
+ num_warps: int,
72
+ is_persistent: bool = False,
73
+ ):
74
+ if const_expr(self.cluster_n > 1):
75
+ if tidx < self.stage: # Initialize full barrier
76
+ cute.arch.mbarrier_init(mbar_ptr + tidx, 1)
77
+ if const_expr(is_persistent): # Initialize empty barrier
78
+ cute.arch.mbarrier_init(
79
+ mbar_ptr + self.stage + tidx, num_warps * self.cluster_n
80
+ )
81
+ cute.arch.mbarrier_init_fence()
82
+ # Cluster arrive after barrier init
83
+ cute.arch.cluster_arrive_relaxed()
build/torch-cuda/quack/sm100_utils.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Tri Dao.
2
+
3
+ from typing import Type, Union
4
+
5
+ import cutlass.cute as cute
6
+ import cutlass.utils.blackwell_helpers as sm100_utils_og
7
+ from cutlass.cute.nvgpu.tcgen05 import OperandMajorMode
8
+ from cutlass.cutlass_dsl import Numeric, dsl_user_op
9
+
10
+
11
+ @dsl_user_op
12
+ def make_smem_layout_cpasync_a(
13
+ tiled_mma: cute.TiledMma,
14
+ mma_tiler_mnk: cute.Tile,
15
+ a_dtype: Type[Numeric],
16
+ num_stages: int,
17
+ *,
18
+ loc=None,
19
+ ip=None,
20
+ ) -> Union[cute.Layout, cute.ComposedLayout]:
21
+ """
22
+ :param tiled_mma: The tiled MMA used to partition tensor A
23
+ :type tiled_mma: cute.TiledMma
24
+ :param mma_tiler_mnk: The MMA tile shape
25
+ :type mma_tiler_mnk: cute.cute.Tile
26
+ :param a_dtype: The element type for tensor A
27
+ :type a_dtype: Type[Numeric]
28
+ :param num_stages: The number of pipeline stages for tensor A
29
+ :type num_stages: int
30
+
31
+ :return: SMEM layout for tensor A
32
+ :rtype: Union[cute.Layout, cute.ComposedLayout]
33
+ """
34
+
35
+ is_k_major = tiled_mma.op.a_major_mode == OperandMajorMode.K
36
+ a_smem_shape = tiled_mma.partition_shape_A(
37
+ cute.dice(mma_tiler_mnk, (1, None, 1), loc=loc, ip=ip)
38
+ )
39
+ a_smem_shape_mn_k = (
40
+ cute.size(a_smem_shape[0][0], loc=loc, ip=ip) * a_smem_shape[1],
41
+ cute.size(a_smem_shape[0][1], loc=loc, ip=ip) * a_smem_shape[2],
42
+ )
43
+ a_smem_layout_atom = sm100_utils_og.make_smem_layout_atom(
44
+ sm100_utils_og.get_smem_layout_atom_ab(
45
+ tiled_mma.op.a_major_mode,
46
+ a_dtype,
47
+ a_smem_shape_mn_k,
48
+ loc=loc,
49
+ ip=ip,
50
+ ),
51
+ a_dtype,
52
+ loc=loc,
53
+ ip=ip,
54
+ )
55
+ a_smem_layout_staged = cute.tile_to_shape(
56
+ a_smem_layout_atom,
57
+ cute.append(a_smem_shape_mn_k, num_stages, loc=loc, ip=ip),
58
+ order=((1, 0, 2) if not is_k_major else (0, 1, 2)),
59
+ loc=loc,
60
+ ip=ip,
61
+ )
62
+ return a_smem_layout_staged
build/torch-cuda/quack/sm90_utils.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Tri Dao.
2
+
3
+ from typing import Type, Union, Optional
4
+
5
+ import cutlass
6
+ import cutlass.cute as cute
7
+ import cutlass.utils.hopper_helpers as sm90_utils_og
8
+ from cutlass.cute.nvgpu import warpgroup
9
+ from cutlass.cutlass_dsl import Numeric, dsl_user_op
10
+ from cutlass import Float32, Int32, Boolean, const_expr
11
+ from cutlass.utils import LayoutEnum
12
+
13
+
14
+ @dsl_user_op
15
+ def make_smem_layout(
16
+ dtype: Type[Numeric],
17
+ layout: LayoutEnum,
18
+ tile: cute.Tile,
19
+ stage: Optional[int] = None,
20
+ *,
21
+ loc=None,
22
+ ip=None,
23
+ ) -> Union[cute.Layout, cute.ComposedLayout]:
24
+ shape = cute.product_each(cute.shape(tile, loc=loc, ip=ip), loc=loc, ip=ip)
25
+ major_mode_size = shape[1] if layout.is_n_major_c() else shape[0]
26
+ smem_layout_atom = warpgroup.make_smem_layout_atom(
27
+ sm90_utils_og.get_smem_layout_atom(layout, dtype, major_mode_size),
28
+ dtype,
29
+ )
30
+ order = (1, 0, 2) if const_expr(layout.is_m_major_c()) else (0, 1, 2)
31
+ smem_layout_staged = cute.tile_to_shape(
32
+ smem_layout_atom,
33
+ cute.append(shape, stage) if const_expr(stage is not None) else shape,
34
+ order=order if const_expr(stage is not None) else order[:2],
35
+ )
36
+ return smem_layout_staged
37
+
38
+
39
+ # For compatibility with blackwell_helpers.py
40
+ make_smem_layout_epi = make_smem_layout
41
+
42
+
43
+ @dsl_user_op
44
+ def partition_for_epilogue(
45
+ cT: cute.Tensor,
46
+ epi_tile: cute.Tile,
47
+ tiled_copy: cute.TiledCopy,
48
+ tidx: Int32,
49
+ reference_src: bool, # do register tensors reference the src or dst layout of the tiled copy
50
+ *,
51
+ loc=None,
52
+ ip=None,
53
+ ) -> cute.Tensor:
54
+ thr_copy = tiled_copy.get_slice(tidx)
55
+ cT_epi = cute.flat_divide(cT, epi_tile)
56
+ # (CPY, CPY_M, CPY_N, EPI_M, EPI_N)
57
+ if const_expr(reference_src):
58
+ return thr_copy.partition_S(cT_epi, loc=loc, ip=ip)
59
+ else:
60
+ return thr_copy.partition_D(cT_epi, loc=loc, ip=ip)
61
+
62
+
63
+ @cute.jit
64
+ def gemm(
65
+ tiled_mma: cute.TiledMma,
66
+ acc: cute.Tensor,
67
+ tCrA: cute.Tensor,
68
+ tCrB: cute.Tensor,
69
+ zero_init: cutlass.Constexpr[bool] = False,
70
+ wg_wait: cutlass.Constexpr[int] = 0,
71
+ # A_in_regs: cutlass.Constexpr[bool] = False,
72
+ swap_AB: cutlass.Constexpr[bool] = False,
73
+ ) -> None:
74
+ if const_expr(swap_AB):
75
+ gemm(tiled_mma, acc, tCrB, tCrA, zero_init=zero_init, wg_wait=wg_wait, swap_AB=False)
76
+ else:
77
+ warpgroup.fence()
78
+ # We make a new mma_atom since we'll be modifying its attribute (accumulate).
79
+ # Otherwise the compiler complains "operand #0 does not dominate this use"
80
+ mma_atom = cute.make_mma_atom(tiled_mma.op)
81
+ mma_atom.set(warpgroup.Field.ACCUMULATE, not zero_init)
82
+ for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])):
83
+ cute.gemm(mma_atom, acc, tCrA[None, None, k], tCrB[None, None, k], acc)
84
+ mma_atom.set(warpgroup.Field.ACCUMULATE, True)
85
+ warpgroup.commit_group()
86
+ if const_expr(wg_wait >= 0):
87
+ warpgroup.wait_group(wg_wait)
88
+
89
+
90
+ def gemm_zero_init(
91
+ tiled_mma: cute.TiledMma,
92
+ shape: cute.Shape,
93
+ tCrA: cute.Tensor,
94
+ tCrB: cute.Tensor,
95
+ A_idx: Optional[Int32] = None,
96
+ B_idx: Optional[Int32] = None,
97
+ wg_wait: int = -1,
98
+ swap_AB: bool = False,
99
+ ) -> cute.Tensor:
100
+ if const_expr(swap_AB):
101
+ return gemm_zero_init(
102
+ tiled_mma, shape[::-1], tCrB, tCrA, B_idx, A_idx, wg_wait, swap_AB=False
103
+ )
104
+ else:
105
+ acc = cute.make_fragment(tiled_mma.partition_shape_C(shape), Float32)
106
+ rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx]
107
+ rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx]
108
+ gemm(tiled_mma, acc, rA, rB, zero_init=True, wg_wait=wg_wait)
109
+ return acc
110
+
111
+
112
+ def gemm_w_idx(
113
+ tiled_mma: cute.TiledMma,
114
+ acc: cute.Tensor,
115
+ tCrA: cute.Tensor,
116
+ tCrB: cute.Tensor,
117
+ zero_init: Boolean,
118
+ A_idx: Optional[Int32] = None,
119
+ B_idx: Optional[Int32] = None,
120
+ wg_wait: int = -1,
121
+ swap_AB: bool = False,
122
+ ) -> None:
123
+ if const_expr(swap_AB):
124
+ gemm_w_idx(tiled_mma, acc, tCrB, tCrA, zero_init, B_idx, A_idx, wg_wait, swap_AB=False)
125
+ else:
126
+ rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx]
127
+ rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx]
128
+ gemm(tiled_mma, acc, rA, rB, zero_init=zero_init, wg_wait=wg_wait)
129
+
130
+
131
+ def partition_fragment_ABC(
132
+ thr_mma: cute.ThrMma,
133
+ shape_mnk: cute.Shape,
134
+ sA: Optional[cute.Tensor],
135
+ sB: Optional[cute.Tensor],
136
+ swap_AB: bool = False,
137
+ ):
138
+ is_rs = thr_mma.op.a_src == warpgroup.OperandSource.RMEM
139
+ if const_expr(not swap_AB):
140
+ acc = cute.make_fragment(thr_mma.partition_shape_C(shape_mnk[:2]), Float32)
141
+ if const_expr(not is_rs):
142
+ assert sA is not None
143
+ tCrA = thr_mma.make_fragment_A(thr_mma.partition_A(sA))
144
+ else:
145
+ tCrA = thr_mma.make_fragment_A(thr_mma.partition_shape_A((shape_mnk[0], shape_mnk[2])))
146
+ assert sB is not None
147
+ tCrB = thr_mma.make_fragment_B(thr_mma.partition_B(sB))
148
+ else:
149
+ acc = cute.make_fragment(thr_mma.partition_shape_C((shape_mnk[1], shape_mnk[0])), Float32)
150
+ if const_expr(not is_rs):
151
+ assert sB is not None
152
+ tCrB = thr_mma.make_fragment_A(thr_mma.partition_A(sB))
153
+ else: # B in rmem
154
+ tCrB = thr_mma.make_fragment_A(thr_mma.partition_shape_A((shape_mnk[1], shape_mnk[2])))
155
+ assert sA is not None
156
+ tCrA = thr_mma.make_fragment_B(thr_mma.partition_B(sA))
157
+ return acc, tCrA, tCrB
build/torch-cuda/quack/sort/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
build/torch-cuda/quack/sort/bitonic_sort.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Wentao Guo, Mayank Mishra, Tri Dao.
2
+
3
+ import math
4
+ from typing import Optional
5
+
6
+ import cutlass
7
+ import cutlass.cute as cute
8
+ from cutlass import Int32, Float32, const_expr
9
+
10
+ from .. import utils
11
+ from .utils import compare_and_swap
12
+ from .sorting_networks import optimal_sort
13
+
14
+
15
+ @cute.jit
16
+ def bitonic_merge(
17
+ arr: cute.Tensor,
18
+ n: Optional[cutlass.Constexpr[int]] = None,
19
+ start: cutlass.Constexpr[int] = 0,
20
+ ascending: cutlass.Constexpr[bool] = True,
21
+ ) -> None:
22
+ """Merge a bitonic sequence into a sorted sequence using iterative approach."""
23
+ if const_expr(n is None):
24
+ n = cute.size(arr.shape)
25
+ if const_expr(n > 1):
26
+ num_levels = int(math.log2(n))
27
+ assert n == 2**num_levels, "n must be a power of 2"
28
+ # This one must be range_constexpr otherwise it's very slow for n = 128
29
+ for level in cutlass.range_constexpr(num_levels):
30
+ length = n >> level # n // (2^level)
31
+ step = length // 2
32
+ for i in cutlass.range(n // length, unroll_full=True):
33
+ start_i = start + i * length
34
+ for j in cutlass.range(step, unroll_full=True):
35
+ compare_and_swap(arr, start_i + j, start_i + j + step, ascending)
36
+
37
+
38
+ @cute.jit
39
+ def bitonic_sort(
40
+ arr: cute.Tensor,
41
+ n: Optional[cutlass.Constexpr[int]] = None,
42
+ start: cutlass.Constexpr[int] = 0,
43
+ ascending: cutlass.Constexpr[bool] = True,
44
+ ) -> None:
45
+ """
46
+ Bitonic sort for small arrays of size N (power of 2, N <= 128).
47
+
48
+ Args:
49
+ arr: Array to sort
50
+ n: Size of array (must be power of 2 and <= 128)
51
+ start: Starting index (default 0)
52
+ ascending: Sort in ascending order (default True)
53
+ """
54
+ if const_expr(n is None):
55
+ n = cute.size(arr.shape)
56
+ assert n <= 128
57
+ if const_expr(n > 1):
58
+ if const_expr(n in [2, 4, 8, 16, 32, 64]):
59
+ optimal_sort(arr, n, start, ascending)
60
+ else: # Fall back to bitonic sort
61
+ assert n % 2 == 0
62
+ # Sort first half in ascending order
63
+ bitonic_sort(arr, n // 2, start, True)
64
+ # Sort second half in descending order
65
+ bitonic_sort(arr, n // 2, start + n // 2, False)
66
+ # Merge the whole sequence
67
+ bitonic_merge(arr, n, start, ascending)
68
+
69
+
70
+ @cute.jit
71
+ def bitonic_topk_merge(
72
+ arr0: cute.Tensor,
73
+ arr1: cute.Tensor,
74
+ k: Optional[cutlass.Constexpr[int]] = None,
75
+ start0: cutlass.Constexpr[int] = 0,
76
+ start1: cutlass.Constexpr[int] = 0,
77
+ ascending: cutlass.Constexpr[bool] = False,
78
+ ) -> None:
79
+ if const_expr(k is None):
80
+ k = cute.size(arr0.shape)
81
+ if const_expr(arr0.element_type == Float32):
82
+ minmax_fn = utils.fmin if ascending else cute.arch.fmax
83
+ else:
84
+ minmax_fn = min if ascending else max
85
+ # Write the top k elements to the first half of the array
86
+ for i in cutlass.range(k, unroll_full=True):
87
+ arr0[start0 + i] = minmax_fn(arr0[start0 + i], arr1[start1 + k - 1 - i])
88
+ # Now the 1st half is bitonic, we just need to merge it
89
+ bitonic_merge(arr0, k, start0, ascending)
90
+
91
+
92
+ @cute.jit
93
+ def bitonic_topk(
94
+ arr: cute.Tensor,
95
+ k: cutlass.Constexpr[int],
96
+ ascending: cutlass.Constexpr[bool] = False,
97
+ warp_width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
98
+ ) -> cute.Tensor:
99
+ """
100
+ Bitonic top-k for small arrays of size N (power of 2, N <= 128).
101
+
102
+ Args:
103
+ arr: Array to sort
104
+ k: must be power of 2 and <= 128
105
+ ascending: Sort in ascending order (default False)
106
+ """
107
+ assert arr.element_type in [Float32, Int32]
108
+ n = cute.size(arr.shape)
109
+ assert k == 1 << int(math.log2(k)), "k must be a power of 2"
110
+ assert n % k == 0, "n must be divisible by k"
111
+ topk_vals = cute.make_fragment(k, arr.element_type)
112
+ for v in cutlass.range(k, unroll_full=True):
113
+ topk_vals[v] = arr[v]
114
+ bitonic_sort(topk_vals, ascending=ascending)
115
+ for i in cutlass.range(1, n // k, unroll_full=True):
116
+ other_vals = cute.make_fragment(k, arr.element_type)
117
+ for v in cutlass.range(k, unroll_full=True):
118
+ other_vals[v] = arr[i * k + v]
119
+ bitonic_sort(other_vals, ascending=ascending)
120
+ # Merge 2 sorted top-k sequences to get a new top-k sequence
121
+ bitonic_topk_merge(topk_vals, other_vals, ascending=ascending)
122
+ # TODO: this is not efficient for large k (e.g. >= 16) since threads in the same warps
123
+ # do duplicate work.
124
+ for i in cutlass.range(int(math.log2(warp_width)), unroll_full=True):
125
+ other_vals = cute.make_fragment(k, arr.element_type)
126
+ for v in cutlass.range(k, unroll_full=True):
127
+ other_vals[v] = cute.arch.shuffle_sync_bfly(topk_vals[v], offset=1 << i)
128
+ bitonic_topk_merge(topk_vals, other_vals, ascending=ascending)
129
+ return topk_vals
build/torch-cuda/quack/sort/generate_sorting_networks.py ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Generate optimized sorting network code from the optimal sorting network data.
4
+ Based on data from: https://bertdobbelaere.github.io/sorting_networks.html
5
+
6
+ This script generates CUTE DSL functions for optimal sorting networks of various sizes.
7
+ """
8
+
9
+ import argparse
10
+ import os
11
+ import re
12
+ from typing import List, Tuple, Dict
13
+
14
# Network strings from bertdobbelaere.github.io/sorting_networks.html
# Copy-paste network strings here, then run initialize_networks() to parse them.
# Each string holds one layer per line; a layer is a "[(i,j),(k,l),...]" list of
# compare-exchange index pairs.  Whitespace and indentation inside the strings
# are irrelevant: parse_network_string() extracts the pairs with a regex.
NETWORK_STRINGS = {
    # Size 2: 1 CE, depth 1
    2: """
    [(0,1)]
    """,
    # Size 4: 5 CEs, depth 3
    4: """
    [(0,2),(1,3)]
    [(0,1),(2,3)]
    [(1,2)]
    """,
    # Size 8: 19 CEs, depth 6
    8: """
    [(0,2),(1,3),(4,6),(5,7)]
    [(0,4),(1,5),(2,6),(3,7)]
    [(0,1),(2,3),(4,5),(6,7)]
    [(2,4),(3,5)]
    [(1,4),(3,6)]
    [(1,2),(3,4),(5,6)]
    """,
    # Size 16: 60 CEs, depth 10
    16: """
    [(0,13),(1,12),(2,15),(3,14),(4,8),(5,6),(7,11),(9,10)]
    [(0,5),(1,7),(2,9),(3,4),(6,13),(8,14),(10,15),(11,12)]
    [(0,1),(2,3),(4,5),(6,8),(7,9),(10,11),(12,13),(14,15)]
    [(0,2),(1,3),(4,10),(5,11),(6,7),(8,9),(12,14),(13,15)]
    [(1,2),(3,12),(4,6),(5,7),(8,10),(9,11),(13,14)]
    [(1,4),(2,6),(5,8),(7,10),(9,13),(11,14)]
    [(2,4),(3,6),(9,12),(11,13)]
    [(3,5),(6,8),(7,9),(10,12)]
    [(3,4),(5,6),(7,8),(9,10),(11,12)]
    [(6,7),(8,9)]
    """,
    # Size 32: 185 CEs, depth 14
    32: """
    [(0,1),(2,3),(4,5),(6,7),(8,9),(10,11),(12,13),(14,15),(16,17),(18,19),(20,21),(22,23),(24,25),(26,27),(28,29),(30,31)]
    [(0,2),(1,3),(4,6),(5,7),(8,10),(9,11),(12,14),(13,15),(16,18),(17,19),(20,22),(21,23),(24,26),(25,27),(28,30),(29,31)]
    [(0,4),(1,5),(2,6),(3,7),(8,12),(9,13),(10,14),(11,15),(16,20),(17,21),(18,22),(19,23),(24,28),(25,29),(26,30),(27,31)]
    [(0,8),(1,9),(2,10),(3,11),(4,12),(5,13),(6,14),(7,15),(16,24),(17,25),(18,26),(19,27),(20,28),(21,29),(22,30),(23,31)]
    [(0,16),(1,8),(2,4),(3,12),(5,10),(6,9),(7,14),(11,13),(15,31),(17,24),(18,20),(19,28),(21,26),(22,25),(23,30),(27,29)]
    [(1,2),(3,5),(4,8),(6,22),(7,11),(9,25),(10,12),(13,14),(17,18),(19,21),(20,24),(23,27),(26,28),(29,30)]
    [(1,17),(2,18),(3,19),(4,20),(5,10),(7,23),(8,24),(11,27),(12,28),(13,29),(14,30),(21,26)]
    [(3,17),(4,16),(5,21),(6,18),(7,9),(8,20),(10,26),(11,23),(13,25),(14,28),(15,27),(22,24)]
    [(1,4),(3,8),(5,16),(7,17),(9,21),(10,22),(11,19),(12,20),(14,24),(15,26),(23,28),(27,30)]
    [(2,5),(7,8),(9,18),(11,17),(12,16),(13,22),(14,20),(15,19),(23,24),(26,29)]
    [(2,4),(6,12),(9,16),(10,11),(13,17),(14,18),(15,22),(19,25),(20,21),(27,29)]
    [(5,6),(8,12),(9,10),(11,13),(14,16),(15,17),(18,20),(19,23),(21,22),(25,26)]
    [(3,5),(6,7),(8,9),(10,12),(11,14),(13,16),(15,18),(17,20),(19,21),(22,23),(24,25),(26,28)]
    [(3,4),(5,6),(7,8),(9,10),(11,12),(13,14),(15,16),(17,18),(19,20),(21,22),(23,24),(25,26),(27,28)]
    """,
    # Size 64: 521 CEs, depth 21
    # (the layers below sum to 521 compare-exchanges, matching the generated
    # sorting_networks.py; an earlier version of this comment said 512)
    64: """
    [(0,2),(1,3),(4,6),(5,7),(8,10),(9,11),(12,14),(13,15),(16,18),(17,19),(20,22),(21,23),(24,26),(25,27),(28,30),(29,31),(32,34),(33,35),(36,38),(37,39),(40,42),(41,43),(44,46),(45,47),(48,50),(49,51),(52,54),(53,55),(56,58),(57,59),(60,62),(61,63)]
    [(0,1),(2,3),(4,5),(6,7),(8,9),(10,11),(12,13),(14,15),(16,17),(18,19),(20,21),(22,23),(24,25),(26,27),(28,29),(30,31),(32,33),(34,35),(36,37),(38,39),(40,41),(42,43),(44,45),(46,47),(48,49),(50,51),(52,53),(54,55),(56,57),(58,59),(60,61),(62,63)]
    [(0,52),(1,2),(3,55),(4,48),(5,6),(7,51),(8,60),(9,10),(11,63),(12,56),(13,14),(15,59),(16,32),(17,18),(19,35),(20,24),(21,22),(23,27),(25,26),(28,44),(29,30),(31,47),(33,34),(36,40),(37,38),(39,43),(41,42),(45,46),(49,50),(53,54),(57,58),(61,62)]
    [(0,20),(1,53),(2,54),(3,23),(4,28),(5,49),(6,50),(7,31),(8,36),(9,61),(10,62),(11,39),(12,16),(13,57),(14,58),(15,19),(17,33),(18,34),(21,25),(22,26),(24,52),(27,55),(29,45),(30,46),(32,56),(35,59),(37,41),(38,42),(40,60),(43,63),(44,48),(47,51)]
    [(0,4),(1,21),(2,22),(3,7),(5,29),(6,30),(8,12),(9,37),(10,38),(11,15),(13,17),(14,18),(16,20),(19,23),(24,32),(25,53),(26,54),(27,35),(28,36),(31,39),(33,57),(34,58),(40,44),(41,61),(42,62),(43,47),(45,49),(46,50),(48,52),(51,55),(56,60),(59,63)]
    [(0,8),(1,5),(2,6),(3,11),(4,12),(7,15),(9,13),(10,14),(16,40),(17,21),(18,22),(19,43),(20,44),(23,47),(24,28),(25,33),(26,34),(27,31),(29,37),(30,38),(32,36),(35,39),(41,45),(42,46),(48,56),(49,53),(50,54),(51,59),(52,60),(55,63),(57,61),(58,62)]
    [(1,9),(2,10),(4,8),(5,13),(6,14),(7,11),(12,48),(15,51),(16,24),(17,41),(18,42),(19,27),(20,28),(21,45),(22,46),(23,31),(25,29),(26,30),(32,40),(33,37),(34,38),(35,43),(36,44),(39,47),(49,57),(50,58),(52,56),(53,61),(54,62),(55,59)]
    [(4,16),(5,9),(6,10),(7,19),(8,24),(11,27),(13,49),(14,50),(17,25),(18,26),(20,32),(21,29),(22,30),(23,35),(28,40),(31,43),(33,41),(34,42),(36,52),(37,45),(38,46),(39,55),(44,56),(47,59),(53,57),(54,58)]
    [(1,4),(5,17),(6,18),(8,16),(9,25),(10,26),(11,19),(12,24),(15,27),(21,33),(22,34),(29,41),(30,42),(36,48),(37,53),(38,54),(39,51),(44,52),(45,57),(46,58),(47,55),(59,62)]
    [(2,8),(9,17),(10,18),(12,20),(13,25),(14,26),(15,23),(24,32),(27,35),(28,36),(31,39),(37,49),(38,50),(40,48),(43,51),(45,53),(46,54),(55,61)]
    [(2,4),(12,16),(13,21),(14,22),(15,19),(20,24),(23,27),(25,33),(26,34),(28,32),(29,37),(30,38),(31,35),(36,40),(39,43),(41,49),(42,50),(44,48),(47,51),(59,61)]
    [(4,16),(5,20),(10,40),(13,17),(14,18),(21,25),(22,26),(23,53),(24,28),(27,31),(29,33),(30,34),(32,36),(35,39),(37,41),(38,42),(43,58),(45,49),(46,50),(47,59)]
    [(3,17),(6,36),(7,21),(8,32),(9,24),(11,41),(13,28),(14,44),(15,45),(18,48),(19,49),(22,52),(25,29),(26,30),(27,57),(31,55),(33,37),(34,38),(35,50),(39,54),(42,56),(46,60)]
    [(6,20),(8,16),(10,24),(11,25),(14,28),(15,29),(17,33),(18,32),(21,37),(22,36),(26,42),(27,41),(30,46),(31,45),(34,48),(35,49),(38,52),(39,53),(43,57),(47,55)]
    [(3,18),(5,8),(6,12),(7,22),(15,21),(17,32),(19,33),(23,37),(26,40),(30,44),(31,46),(41,56),(42,48),(45,60),(51,57),(55,58)]
    [(3,16),(7,20),(11,26),(18,24),(19,25),(22,28),(23,29),(27,33),(30,36),(34,40),(35,41),(37,52),(38,44),(39,45),(43,56),(47,60)]
    [(3,9),(7,13),(10,16),(11,17),(14,20),(15,30),(19,34),(21,36),(23,38),(25,40),(26,32),(27,42),(29,44),(31,37),(33,48),(43,49),(46,52),(47,53),(50,56),(54,60)]
    [(3,8),(7,10),(9,12),(11,18),(13,14),(15,24),(17,22),(19,28),(21,26),(23,25),(27,34),(29,36),(30,32),(31,33),(35,44),(37,42),(38,40),(39,48),(41,46),(45,52),(49,50),(51,54),(53,56),(55,60)]
    [(3,6),(7,12),(11,16),(15,17),(18,20),(19,24),(21,22),(23,30),(25,32),(26,28),(27,29),(31,38),(33,40),(34,36),(35,37),(39,44),(41,42),(43,45),(46,48),(47,52),(51,56),(57,60)]
    [(3,5),(6,8),(7,9),(10,12),(11,13),(14,16),(15,18),(17,20),(19,21),(22,24),(23,26),(25,28),(27,30),(29,32),(31,34),(33,36),(35,38),(37,40),(39,41),(42,44),(43,46),(45,48),(47,49),(50,52),(51,53),(54,56),(55,57),(58,60)]
    [(3,4),(7,8),(11,12),(13,14),(15,16),(17,18),(19,20),(21,22),(23,24),(25,26),(27,28),(29,30),(31,32),(33,34),(35,36),(37,38),(39,40),(41,42),(43,44),(45,46),(47,48),(49,50),(51,52),(55,56),(59,60)]
    """,
}

# Populated from NETWORK_STRINGS by initialize_networks(); maps
# size -> (depth, num_comparisons, layers).
OPTIMAL_NETWORKS: Dict[int, Tuple[int, int, List[List[Tuple[int, int]]]]] = {}
94
+
95
+
96
def parse_network_string(network_str: str) -> List[List[Tuple[int, int]]]:
    """
    Parse a sorting network string from bertdobbelaere.github.io format.

    Examples:
        Input: "[(0,2),(1,3)], [(0,1),(2,3)], [(1,2)]"
        Output: [[(0, 2), (1, 3)], [(0, 1), (2, 3)], [(1, 2)]]

        Input: "[(0,1)], [(1,2)], [(0,1)]"
        Output: [[(0, 1)], [(1, 2)], [(0, 1)]]
    """
    text = network_str.strip()
    if not text:
        return []

    # A layer is a bracketed, comma-separated run of "(i,j)" pairs (possibly empty).
    layer_re = r"\[((?:\(\d+,\d+\)(?:,\(\d+,\d+\))*)?)\]"
    pair_re = re.compile(r"\((\d+),(\d+)\)")

    return [
        [(int(lo), int(hi)) for lo, hi in pair_re.findall(layer_body)]
        for layer_body in re.findall(layer_re, text)
    ]
133
+
134
+
135
def calculate_network_stats(layers: List[List[Tuple[int, int]]]) -> Tuple[int, int, int]:
    """Return (depth, total_comparisons, network_size) for a parsed network.

    Depth is the number of layers, total_comparisons the number of
    compare-exchange pairs, and network_size is max index + 1 (indices are
    0-based); an empty network reports size 1.
    """
    total_comparisons = sum(len(layer) for layer in layers)
    max_index = max(
        (idx for layer in layers for pair in layer for idx in pair),
        default=0,
    )
    return len(layers), total_comparisons, max_index + 1
148
+
149
+
150
def add_network_from_string(size: int, network_str: str, description: str = ""):
    """Parse a network string and register it in OPTIMAL_NETWORKS.

    Args:
        size: Expected number of elements the network sorts.
        network_str: Network string in bertdobbelaere.github.io format.
        description: Optional label echoed on success for debugging.

    Returns:
        True if the network parsed and its detected size matched `size`,
        False on a size mismatch or a parse error (both reported on stdout).
    """
    try:
        parsed_layers = parse_network_string(network_str)
        net_depth, net_ces, found_size = calculate_network_stats(parsed_layers)

        if found_size != size:
            print(f"Warning: Network size mismatch! Expected {size}, detected {found_size}")
            print(f"Network string: {network_str[:100]}...")
            return False

        OPTIMAL_NETWORKS[size] = (net_depth, net_ces, parsed_layers)

        if description:
            print(f"Added network for size {size}: {description}")
            print(f"  Depth: {net_depth}, Comparisons: {net_ces}")
        return True

    except Exception as exc:
        # Best-effort CLI tool: report the failure and keep going.
        print(f"Error parsing network for size {size}: {exc}")
        print(f"Network string: {network_str[:100]}...")
        return False
179
+
180
+
181
def generate_networks_dict(
    networks_data: Dict[int, Tuple[int, int, List[List[Tuple[int, int]]]]]
) -> str:
    """Render the module-level `networks = {...}` source text.

    Sizes are emitted in ascending order; each entry is preceded by a
    summary comment and followed by a blank line.  Single-layer networks
    are rendered on one line, multi-layer ones on one line per layer.
    """
    out = ["networks = {"]

    for size in sorted(networks_data):
        depth, num_comparisons, layers = networks_data[size]
        if len(layers) == 1:
            rendered = f"[{layers[0]}]"
        else:
            body = ",\n    ".join(str(layer) for layer in layers)
            rendered = "[\n    " + body + "\n    ]"
        out.append(f"    # Size {size}: {num_comparisons} CEs, depth {depth}")
        out.append(f"    {size}: {rendered},")
        out.append("")

    out.append("}")
    return "\n".join(out)
207
+
208
+
209
def generate_optimal_sort_function() -> str:
    """Generate the single optimal_sort function that looks up networks by size.

    Returns the verbatim source text of `optimal_sort` that is written into
    the generated sorting_networks.py; it must stay in sync with the
    `networks` dict emitted by generate_networks_dict().
    """
    # The escaped \"\"\" below become the docstring quotes of the emitted function.
    return """@cute.jit
def optimal_sort(
    arr: cute.Tensor,
    n: cutlass.Constexpr[int],
    start: cutlass.Constexpr[int] = 0,
    ascending: cutlass.Constexpr[bool] = True
) -> None:
    \"\"\"
    Optimal sorting network dispatcher.

    Args:
        arr: Array to sort
        n: Size of array (must be power of 2 and available in networks)
        start: Starting index (default 0)
        ascending: Sort in ascending order (default True)

    Source: https://bertdobbelaere.github.io/sorting_networks.html
    \"\"\"
    assert n in networks
    for level in networks[n]:
        for i, j in level:
            compare_and_swap(arr, start + i, start + j, ascending)
"""
234
+
235
+
236
def generate_sorting_networks_file(max_size: int = 64):
    """Generate quack/sort/sorting_networks.py with optimal networks up to max_size.

    Args:
        max_size: Largest network size to include in the generated file.

    Returns:
        The sorted list of network sizes that were written.
    """

    output_file = os.path.join(os.path.dirname(__file__), "sorting_networks.py")

    # Header
    header = '''# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Tri Dao.
"""
Optimal sorting networks generated from: https://bertdobbelaere.github.io/sorting_networks.html

This file was auto-generated by quack/sort/generate_sorting_networks.py. Do not edit it directly.
"""

# fmt: off
# ruff: noqa
# isort: skip_file

import cutlass
import cutlass.cute as cute

from .utils import compare_and_swap


'''

    # Only emit networks within the requested size budget.  Previously the full
    # OPTIMAL_NETWORKS dict was passed to generate_networks_dict() regardless of
    # max_size, so the --max-size flag had no effect on the generated file.
    selected = {n: net for n, net in OPTIMAL_NETWORKS.items() if n <= max_size}
    sizes = sorted(selected)
    networks_dict = generate_networks_dict(selected)
    optimal_sort_func = generate_optimal_sort_function()

    # Combine everything
    content = header + networks_dict + "\n\n\n" + optimal_sort_func

    with open(output_file, "w") as f:
        f.write(content)

    print(f"Generated optimal sorting networks for sizes {sizes}")
    print(f"Output written to: {output_file}")
    return sizes
275
+
276
+
277
def initialize_networks():
    """Reset OPTIMAL_NETWORKS and repopulate it by parsing NETWORK_STRINGS."""
    OPTIMAL_NETWORKS.clear()

    for size, raw_network in NETWORK_STRINGS.items():
        ok = add_network_from_string(size, raw_network, f"Size {size} optimal network")
        if not ok:
            print(f"Warning: Failed to parse network for size {size}")
286
+
287
+
288
def main():
    """CLI entry point: parse arguments, build the networks, and regenerate the file."""
    parser = argparse.ArgumentParser(
        description="Generate optimal sorting network code from bertdobbelaere.github.io data"
    )
    parser.add_argument(
        "--max-size",
        "-m",
        type=int,
        default=64,
        # %(default)s keeps the help text in sync with the actual default.
        # The previous text hard-coded "(default: 32)" while default was 64.
        help="Maximum sorting network size to generate (default: %(default)s)",
    )
    parser.add_argument(
        "--stats", "-s", action="store_true", help="Print statistics about the optimal networks"
    )

    args = parser.parse_args()

    # Initialize networks from strings
    initialize_networks()

    if args.stats:
        print("Optimal Sorting Network Statistics:")
        print("Size\tDepth\tComparisons\tLayers")
        print("-" * 35)
        for n in sorted(OPTIMAL_NETWORKS.keys()):
            if n <= args.max_size:
                depth, comparisons, layers = OPTIMAL_NETWORKS[n]
                print(f"{n}\t{depth}\t{comparisons}\t\t{len(layers)}")

    # Generate the sorting networks file
    sizes = generate_sorting_networks_file(args.max_size)

    print(f"\nGenerated optimal sorting networks for {len(sizes)} sizes")
    print(f"Total networks: {len(sizes)}")
    print(f"Max network size: {max(sizes)}")


if __name__ == "__main__":
    main()
build/torch-cuda/quack/sort/sorting_networks.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Wentao Guo, Mayank Mishra, Tri Dao.
2
+ """
3
+ Optimal sorting networks generated from: https://bertdobbelaere.github.io/sorting_networks.html
4
+
5
+ This file was auto-generated by quack/sort/generate_sorting_networks.py. Do not edit it directly.
6
+ """
7
+
8
+ # fmt: off
9
+ # ruff: noqa
10
+ # isort: skip_file
11
+
12
+ import cutlass
13
+ import cutlass.cute as cute
14
+
15
+ from .utils import compare_and_swap
16
+
17
+
18
# Maps network size -> list of layers; each layer is a list of (i, j)
# compare-exchange index pairs that are data-independent within the layer.
# Auto-generated by generate_sorting_networks.py — regenerate rather than edit.
networks = {
    # Size 2: 1 CEs, depth 1
    2: [[(0, 1)]],

    # Size 4: 5 CEs, depth 3
    4: [
    [(0, 2), (1, 3)],
    [(0, 1), (2, 3)],
    [(1, 2)]
    ],

    # Size 8: 19 CEs, depth 6
    8: [
    [(0, 2), (1, 3), (4, 6), (5, 7)],
    [(0, 4), (1, 5), (2, 6), (3, 7)],
    [(0, 1), (2, 3), (4, 5), (6, 7)],
    [(2, 4), (3, 5)],
    [(1, 4), (3, 6)],
    [(1, 2), (3, 4), (5, 6)]
    ],

    # Size 16: 60 CEs, depth 10
    16: [
    [(0, 13), (1, 12), (2, 15), (3, 14), (4, 8), (5, 6), (7, 11), (9, 10)],
    [(0, 5), (1, 7), (2, 9), (3, 4), (6, 13), (8, 14), (10, 15), (11, 12)],
    [(0, 1), (2, 3), (4, 5), (6, 8), (7, 9), (10, 11), (12, 13), (14, 15)],
    [(0, 2), (1, 3), (4, 10), (5, 11), (6, 7), (8, 9), (12, 14), (13, 15)],
    [(1, 2), (3, 12), (4, 6), (5, 7), (8, 10), (9, 11), (13, 14)],
    [(1, 4), (2, 6), (5, 8), (7, 10), (9, 13), (11, 14)],
    [(2, 4), (3, 6), (9, 12), (11, 13)],
    [(3, 5), (6, 8), (7, 9), (10, 12)],
    [(3, 4), (5, 6), (7, 8), (9, 10), (11, 12)],
    [(6, 7), (8, 9)]
    ],

    # Size 32: 185 CEs, depth 14
    32: [
    [(0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11), (12, 13), (14, 15), (16, 17), (18, 19), (20, 21), (22, 23), (24, 25), (26, 27), (28, 29), (30, 31)],
    [(0, 2), (1, 3), (4, 6), (5, 7), (8, 10), (9, 11), (12, 14), (13, 15), (16, 18), (17, 19), (20, 22), (21, 23), (24, 26), (25, 27), (28, 30), (29, 31)],
    [(0, 4), (1, 5), (2, 6), (3, 7), (8, 12), (9, 13), (10, 14), (11, 15), (16, 20), (17, 21), (18, 22), (19, 23), (24, 28), (25, 29), (26, 30), (27, 31)],
    [(0, 8), (1, 9), (2, 10), (3, 11), (4, 12), (5, 13), (6, 14), (7, 15), (16, 24), (17, 25), (18, 26), (19, 27), (20, 28), (21, 29), (22, 30), (23, 31)],
    [(0, 16), (1, 8), (2, 4), (3, 12), (5, 10), (6, 9), (7, 14), (11, 13), (15, 31), (17, 24), (18, 20), (19, 28), (21, 26), (22, 25), (23, 30), (27, 29)],
    [(1, 2), (3, 5), (4, 8), (6, 22), (7, 11), (9, 25), (10, 12), (13, 14), (17, 18), (19, 21), (20, 24), (23, 27), (26, 28), (29, 30)],
    [(1, 17), (2, 18), (3, 19), (4, 20), (5, 10), (7, 23), (8, 24), (11, 27), (12, 28), (13, 29), (14, 30), (21, 26)],
    [(3, 17), (4, 16), (5, 21), (6, 18), (7, 9), (8, 20), (10, 26), (11, 23), (13, 25), (14, 28), (15, 27), (22, 24)],
    [(1, 4), (3, 8), (5, 16), (7, 17), (9, 21), (10, 22), (11, 19), (12, 20), (14, 24), (15, 26), (23, 28), (27, 30)],
    [(2, 5), (7, 8), (9, 18), (11, 17), (12, 16), (13, 22), (14, 20), (15, 19), (23, 24), (26, 29)],
    [(2, 4), (6, 12), (9, 16), (10, 11), (13, 17), (14, 18), (15, 22), (19, 25), (20, 21), (27, 29)],
    [(5, 6), (8, 12), (9, 10), (11, 13), (14, 16), (15, 17), (18, 20), (19, 23), (21, 22), (25, 26)],
    [(3, 5), (6, 7), (8, 9), (10, 12), (11, 14), (13, 16), (15, 18), (17, 20), (19, 21), (22, 23), (24, 25), (26, 28)],
    [(3, 4), (5, 6), (7, 8), (9, 10), (11, 12), (13, 14), (15, 16), (17, 18), (19, 20), (21, 22), (23, 24), (25, 26), (27, 28)]
    ],

    # Size 64: 521 CEs, depth 21
    64: [
    [(0, 2), (1, 3), (4, 6), (5, 7), (8, 10), (9, 11), (12, 14), (13, 15), (16, 18), (17, 19), (20, 22), (21, 23), (24, 26), (25, 27), (28, 30), (29, 31), (32, 34), (33, 35), (36, 38), (37, 39), (40, 42), (41, 43), (44, 46), (45, 47), (48, 50), (49, 51), (52, 54), (53, 55), (56, 58), (57, 59), (60, 62), (61, 63)],
    [(0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11), (12, 13), (14, 15), (16, 17), (18, 19), (20, 21), (22, 23), (24, 25), (26, 27), (28, 29), (30, 31), (32, 33), (34, 35), (36, 37), (38, 39), (40, 41), (42, 43), (44, 45), (46, 47), (48, 49), (50, 51), (52, 53), (54, 55), (56, 57), (58, 59), (60, 61), (62, 63)],
    [(0, 52), (1, 2), (3, 55), (4, 48), (5, 6), (7, 51), (8, 60), (9, 10), (11, 63), (12, 56), (13, 14), (15, 59), (16, 32), (17, 18), (19, 35), (20, 24), (21, 22), (23, 27), (25, 26), (28, 44), (29, 30), (31, 47), (33, 34), (36, 40), (37, 38), (39, 43), (41, 42), (45, 46), (49, 50), (53, 54), (57, 58), (61, 62)],
    [(0, 20), (1, 53), (2, 54), (3, 23), (4, 28), (5, 49), (6, 50), (7, 31), (8, 36), (9, 61), (10, 62), (11, 39), (12, 16), (13, 57), (14, 58), (15, 19), (17, 33), (18, 34), (21, 25), (22, 26), (24, 52), (27, 55), (29, 45), (30, 46), (32, 56), (35, 59), (37, 41), (38, 42), (40, 60), (43, 63), (44, 48), (47, 51)],
    [(0, 4), (1, 21), (2, 22), (3, 7), (5, 29), (6, 30), (8, 12), (9, 37), (10, 38), (11, 15), (13, 17), (14, 18), (16, 20), (19, 23), (24, 32), (25, 53), (26, 54), (27, 35), (28, 36), (31, 39), (33, 57), (34, 58), (40, 44), (41, 61), (42, 62), (43, 47), (45, 49), (46, 50), (48, 52), (51, 55), (56, 60), (59, 63)],
    [(0, 8), (1, 5), (2, 6), (3, 11), (4, 12), (7, 15), (9, 13), (10, 14), (16, 40), (17, 21), (18, 22), (19, 43), (20, 44), (23, 47), (24, 28), (25, 33), (26, 34), (27, 31), (29, 37), (30, 38), (32, 36), (35, 39), (41, 45), (42, 46), (48, 56), (49, 53), (50, 54), (51, 59), (52, 60), (55, 63), (57, 61), (58, 62)],
    [(1, 9), (2, 10), (4, 8), (5, 13), (6, 14), (7, 11), (12, 48), (15, 51), (16, 24), (17, 41), (18, 42), (19, 27), (20, 28), (21, 45), (22, 46), (23, 31), (25, 29), (26, 30), (32, 40), (33, 37), (34, 38), (35, 43), (36, 44), (39, 47), (49, 57), (50, 58), (52, 56), (53, 61), (54, 62), (55, 59)],
    [(4, 16), (5, 9), (6, 10), (7, 19), (8, 24), (11, 27), (13, 49), (14, 50), (17, 25), (18, 26), (20, 32), (21, 29), (22, 30), (23, 35), (28, 40), (31, 43), (33, 41), (34, 42), (36, 52), (37, 45), (38, 46), (39, 55), (44, 56), (47, 59), (53, 57), (54, 58)],
    [(1, 4), (5, 17), (6, 18), (8, 16), (9, 25), (10, 26), (11, 19), (12, 24), (15, 27), (21, 33), (22, 34), (29, 41), (30, 42), (36, 48), (37, 53), (38, 54), (39, 51), (44, 52), (45, 57), (46, 58), (47, 55), (59, 62)],
    [(2, 8), (9, 17), (10, 18), (12, 20), (13, 25), (14, 26), (15, 23), (24, 32), (27, 35), (28, 36), (31, 39), (37, 49), (38, 50), (40, 48), (43, 51), (45, 53), (46, 54), (55, 61)],
    [(2, 4), (12, 16), (13, 21), (14, 22), (15, 19), (20, 24), (23, 27), (25, 33), (26, 34), (28, 32), (29, 37), (30, 38), (31, 35), (36, 40), (39, 43), (41, 49), (42, 50), (44, 48), (47, 51), (59, 61)],
    [(4, 16), (5, 20), (10, 40), (13, 17), (14, 18), (21, 25), (22, 26), (23, 53), (24, 28), (27, 31), (29, 33), (30, 34), (32, 36), (35, 39), (37, 41), (38, 42), (43, 58), (45, 49), (46, 50), (47, 59)],
    [(3, 17), (6, 36), (7, 21), (8, 32), (9, 24), (11, 41), (13, 28), (14, 44), (15, 45), (18, 48), (19, 49), (22, 52), (25, 29), (26, 30), (27, 57), (31, 55), (33, 37), (34, 38), (35, 50), (39, 54), (42, 56), (46, 60)],
    [(6, 20), (8, 16), (10, 24), (11, 25), (14, 28), (15, 29), (17, 33), (18, 32), (21, 37), (22, 36), (26, 42), (27, 41), (30, 46), (31, 45), (34, 48), (35, 49), (38, 52), (39, 53), (43, 57), (47, 55)],
    [(3, 18), (5, 8), (6, 12), (7, 22), (15, 21), (17, 32), (19, 33), (23, 37), (26, 40), (30, 44), (31, 46), (41, 56), (42, 48), (45, 60), (51, 57), (55, 58)],
    [(3, 16), (7, 20), (11, 26), (18, 24), (19, 25), (22, 28), (23, 29), (27, 33), (30, 36), (34, 40), (35, 41), (37, 52), (38, 44), (39, 45), (43, 56), (47, 60)],
    [(3, 9), (7, 13), (10, 16), (11, 17), (14, 20), (15, 30), (19, 34), (21, 36), (23, 38), (25, 40), (26, 32), (27, 42), (29, 44), (31, 37), (33, 48), (43, 49), (46, 52), (47, 53), (50, 56), (54, 60)],
    [(3, 8), (7, 10), (9, 12), (11, 18), (13, 14), (15, 24), (17, 22), (19, 28), (21, 26), (23, 25), (27, 34), (29, 36), (30, 32), (31, 33), (35, 44), (37, 42), (38, 40), (39, 48), (41, 46), (45, 52), (49, 50), (51, 54), (53, 56), (55, 60)],
    [(3, 6), (7, 12), (11, 16), (15, 17), (18, 20), (19, 24), (21, 22), (23, 30), (25, 32), (26, 28), (27, 29), (31, 38), (33, 40), (34, 36), (35, 37), (39, 44), (41, 42), (43, 45), (46, 48), (47, 52), (51, 56), (57, 60)],
    [(3, 5), (6, 8), (7, 9), (10, 12), (11, 13), (14, 16), (15, 18), (17, 20), (19, 21), (22, 24), (23, 26), (25, 28), (27, 30), (29, 32), (31, 34), (33, 36), (35, 38), (37, 40), (39, 41), (42, 44), (43, 46), (45, 48), (47, 49), (50, 52), (51, 53), (54, 56), (55, 57), (58, 60)],
    [(3, 4), (7, 8), (11, 12), (13, 14), (15, 16), (17, 18), (19, 20), (21, 22), (23, 24), (25, 26), (27, 28), (29, 30), (31, 32), (33, 34), (35, 36), (37, 38), (39, 40), (41, 42), (43, 44), (45, 46), (47, 48), (49, 50), (51, 52), (55, 56), (59, 60)]
    ],

}
97
+
98
+
99
@cute.jit
def optimal_sort(
    arr: cute.Tensor,
    n: cutlass.Constexpr[int],
    start: cutlass.Constexpr[int] = 0,
    ascending: cutlass.Constexpr[bool] = True
) -> None:
    """
    Optimal sorting network dispatcher.

    Args:
        arr: Array to sort
        n: Size of array (must be power of 2 and available in networks)
        start: Starting index (default 0)
        ascending: Sort in ascending order (default True)

    Source: https://bertdobbelaere.github.io/sorting_networks.html
    """
    assert n in networks
    # networks[n] is a plain Python list and n is a Constexpr, so these loops
    # run at trace time and emit a fixed sequence of compare_and_swap calls.
    for level in networks[n]:
        for i, j in level:
            compare_and_swap(arr, start + i, start + j, ascending)
build/torch-cuda/quack/sort/utils.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cutlass.cute as cute
2
+ from cutlass import Float32, const_expr
3
+
4
+ from .. import utils
5
+
6
+
7
@cute.jit
def compare_and_swap(
    arr: cute.Tensor, i: int, j: int, ascending: bool = True, use_selection: bool = False
) -> None:
    """Compare and swap elements at indices i and j in ascending or descending order.

    Args:
        arr: fragment/tensor whose elements are compared in place.
        i, j: indices of the pair; on return arr[i] <= arr[j] when ascending,
            arr[i] >= arr[j] otherwise.
        use_selection: if True, use an explicit compare-and-conditional-swap;
            otherwise express the pair as (min, max) / (max, min).
    """
    if const_expr(use_selection):
        a, b = arr[i], arr[j]
        # XOR with `not ascending` folds both directions into one predicate:
        # ascending swaps when a > b, descending swaps when a <= b... wait —
        # when a > b is False and ascending is False, the pair is swapped,
        # which also covers the a < b descending case (equal elements may
        # swap, which is harmless for sorting).
        if (a > b) ^ (not ascending):
            arr[i] = b
            arr[j] = a
    else:
        # For Float32, route through utils.fmin / cute.arch.fmax rather than
        # Python min/max — presumably to get the hardware float min/max
        # instructions (NaN semantics differ from Python's min/max); confirm
        # in quack.utils.
        min_fn = min if const_expr(arr.element_type != Float32) else utils.fmin
        max_fn = max if const_expr(arr.element_type != Float32) else cute.arch.fmax
        if const_expr(ascending):
            arr[i], arr[j] = min_fn(arr[i], arr[j]), max_fn(arr[i], arr[j])
        else:
            arr[i], arr[j] = max_fn(arr[i], arr[j]), min_fn(arr[i], arr[j])
build/torch-cuda/quack/tensormap_manager.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Tri Dao.
2
+
3
+ from typing import Tuple
4
+ from dataclasses import dataclass
5
+
6
+ import cutlass
7
+ import cutlass.cute as cute
8
+ from cutlass.cutlass_dsl import Boolean, const_expr, Int32
9
+ from cutlass.utils import TensorMapUpdateMode, TensorMapManager
10
+ from cutlass._mlir.dialects import llvm
11
+
12
+
13
@dataclass(frozen=True)
class TensorMapManagerSm90(TensorMapManager):
    """
    We have to subclass cutlass.utils.TensorMapManager bc it takes in warp_id and only
    perform the operation if warp_id matches the current warp.
    But for Hopper pingpong gemm we want to call it with warp_id 0 and 4.
    So we take in a boolean `is_manager_warp` to determine whether to perform the operation or not.
    """

    @cute.jit
    def init_tensormap_from_atom(
        self, copy_atom: cute.CopyAtom, dst_ptr: cute.Pointer, is_manager_warp: Boolean
    ) -> None:
        """Copy the TMA descriptor held by `copy_atom` to `dst_ptr`.

        Only threads for which `is_manager_warp` is true participate; one
        elected thread performs the copy and the warp syncs afterwards.
        """
        if is_manager_warp:
            with cute.arch.elect_one():
                cute.nvgpu.cpasync.copy_tensormap(copy_atom, dst_ptr)
            cute.arch.sync_warp()
        return

    @cute.jit
    def update_tensormap(
        self,
        tensor_gmem: Tuple[cute.Tensor, ...],
        tma_copy_atom: Tuple[cute.CopyAtom, ...],
        tensormap_gmem_ptr: Tuple[cute.Pointer, ...],
        is_manager_warp: Boolean,
        tensormap_smem_ptr: Tuple[cute.Pointer, ...],
    ) -> None:
        """Rewrite one TMA descriptor per (tensor, atom, ptr) triple.

        In SMEM mode the descriptor is staged in shared memory and then copied
        to global memory with release semantics; in GMEM mode it is patched
        directly in global memory and fenced.
        """
        # updates before touching tensormap in global memory
        if is_manager_warp:
            if const_expr(self.tensormap_update_mode == TensorMapUpdateMode.SMEM):
                for copy_atom, tensor, smem_ptr in zip(
                    tma_copy_atom, tensor_gmem, tensormap_smem_ptr
                ):
                    cute.nvgpu.cpasync.update_tma_descriptor(copy_atom, tensor, smem_ptr)
            # wait until it's safe to update tensormap in global memory
            with cute.arch.elect_one():
                cute.arch.cp_async_bulk_commit_group()
                cute.arch.cp_async_bulk_wait_group(0, read=True)
            cute.arch.sync_warp()
            # updates to tensormap in global memory
            if const_expr(self.tensormap_update_mode == TensorMapUpdateMode.SMEM):
                # cp_fence_tma_desc_release both copies smem -> gmem and
                # provides the release fence for subsequent TMA use.
                for gmem_ptr, smem_ptr in zip(tensormap_gmem_ptr, tensormap_smem_ptr):
                    cute.nvgpu.cpasync.cp_fence_tma_desc_release(gmem_ptr, smem_ptr)
            else:
                for copy_atom, tensor, gmem_ptr in zip(
                    tma_copy_atom, tensor_gmem, tensormap_gmem_ptr
                ):
                    cute.nvgpu.cpasync.update_tma_descriptor(copy_atom, tensor, gmem_ptr)
                cute.arch.sync_warp()
                cute.nvgpu.cpasync.fence_tma_desc_release()

    @cute.jit
    def update_tensormap_shape(
        self,
        tensormap_gmem_ptr: Tuple[cute.Pointer, ...],
        is_manager_warp: Boolean,
        tensormap_smem_ptr: Tuple[cute.Pointer, ...],
        shapes: Tuple[Int32, ...],
        orders: cutlass.Constexpr[Tuple[int, ...]],
    ) -> None:
        """Patch only the global_dim field of each descriptor via PTX
        `tensormap.replace.tile.global_dim`, instead of rebuilding the whole
        descriptor.  `orders` (which dimension to replace) must be
        compile-time constants since they are interpolated into the asm text.
        """
        # updates before touching tensormap in global memory
        if is_manager_warp:
            if const_expr(self.tensormap_update_mode == TensorMapUpdateMode.SMEM):
                for smem_ptr, shape, order in zip(tensormap_smem_ptr, shapes, orders):
                    smem_ptr_i32 = smem_ptr.toint().ir_value()
                    # NOTE(review): three operands are passed but the asm only
                    # references $0/$1 with constraints "r,r"; `order` is already
                    # interpolated as an immediate, so the Int32(order) operand
                    # looks redundant — verify against the MLIR inline_asm API.
                    llvm.inline_asm(
                        None,
                        [smem_ptr_i32, Int32(shape).ir_value(), Int32(order).ir_value()],
                        "{\n\t"
                        ".reg .b64 smem_ptr_i64;\n\t"
                        "cvt.u64.u32 smem_ptr_i64, $0;\n\t"
                        f"tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [smem_ptr_i64], {order}, $1;\n\t"
                        "}\n",
                        "r,r",
                        has_side_effects=True,
                        is_align_stack=False,
                        asm_dialect=llvm.AsmDialect.AD_ATT,
                    )
            # wait until it's safe to update tensormap in global memory
            with cute.arch.elect_one():
                cute.arch.cp_async_bulk_commit_group()
                cute.arch.cp_async_bulk_wait_group(0, read=True)
            cute.arch.sync_warp()
            # updates to tensormap in global memory
            if const_expr(self.tensormap_update_mode == TensorMapUpdateMode.SMEM):
                for gmem_ptr, smem_ptr in zip(tensormap_gmem_ptr, tensormap_smem_ptr):
                    cute.nvgpu.cpasync.cp_fence_tma_desc_release(gmem_ptr, smem_ptr)
            else:
                assert len(shapes) == len(orders) == len(tensormap_gmem_ptr)
                for gmem_ptr, shape, order in zip(tensormap_gmem_ptr, shapes, orders):
                    gmem_ptr_i64 = gmem_ptr.toint().ir_value()
                    # NOTE(review): same operand/constraint-count question as the
                    # SMEM branch above ("l,r" with three operands).
                    llvm.inline_asm(
                        None,
                        [gmem_ptr_i64, Int32(shape).ir_value(), Int32(order).ir_value()],
                        f"tensormap.replace.tile.global_dim.global.b1024.b32 [$0], {order}, $1;",
                        "l,r",
                        has_side_effects=True,
                        is_align_stack=False,
                        asm_dialect=llvm.AsmDialect.AD_ATT,
                    )
                cute.arch.sync_warp()
                cute.nvgpu.cpasync.fence_tma_desc_release()