Uploaded using `kernel-builder`.
Browse files — This view is limited to 50 files because it contains too many changes. See the raw diff.
- build/torch-cuda/__init__.py +24 -0
- build/torch-cuda/_ops.py +8 -0
- build/torch-cuda/_ops_compat.py +10 -0
- build/torch-cuda/enums.py +30 -0
- build/torch-cuda/functional/__init__.py +554 -0
- build/torch-cuda/functional/backward.py +682 -0
- build/torch-cuda/functional/forward.py +238 -0
- build/torch-cuda/functional/grouped_gemm.py +0 -0
- build/torch-cuda/functional/moe_config.py +581 -0
- build/torch-cuda/functional/reduction_over_k_gather.py +164 -0
- build/torch-cuda/functional/tile_scheduler.py +91 -0
- build/torch-cuda/functional/topk_softmax.py +195 -0
- build/torch-cuda/functional/triton_kernels/__init__.py +351 -0
- build/torch-cuda/functional/triton_kernels/bitmatrix.py +147 -0
- build/torch-cuda/functional/utils.py +25 -0
- build/torch-cuda/jit.py +159 -0
- build/torch-cuda/metadata.json +10 -0
- build/torch-cuda/moe.py +368 -0
- build/torch-cuda/quack/__init__.py +8 -0
- build/torch-cuda/quack/_ops_compat.py +4 -0
- build/torch-cuda/quack/activation.py +524 -0
- build/torch-cuda/quack/autotuner.py +369 -0
- build/torch-cuda/quack/broadcast_utils.py +29 -0
- build/torch-cuda/quack/compile_utils.py +19 -0
- build/torch-cuda/quack/copy_utils.py +614 -0
- build/torch-cuda/quack/cute_dsl_ptxas.py +151 -0
- build/torch-cuda/quack/cute_dsl_utils.py +104 -0
- build/torch-cuda/quack/fast_math.py +80 -0
- build/torch-cuda/quack/gemm.py +194 -0
- build/torch-cuda/quack/gemm_act.py +510 -0
- build/torch-cuda/quack/gemm_config.py +95 -0
- build/torch-cuda/quack/gemm_dact.py +215 -0
- build/torch-cuda/quack/gemm_default_epi.py +259 -0
- build/torch-cuda/quack/gemm_interface.py +1058 -0
- build/torch-cuda/quack/gemm_sm100.py +0 -0
- build/torch-cuda/quack/gemm_sm90.py +2070 -0
- build/torch-cuda/quack/gemm_symmetric.py +330 -0
- build/torch-cuda/quack/gemm_wrapper_utils.py +317 -0
- build/torch-cuda/quack/layout_utils.py +295 -0
- build/torch-cuda/quack/pipeline.py +324 -0
- build/torch-cuda/quack/reduce.py +279 -0
- build/torch-cuda/quack/reduction_base.py +83 -0
- build/torch-cuda/quack/sm100_utils.py +62 -0
- build/torch-cuda/quack/sm90_utils.py +157 -0
- build/torch-cuda/quack/sort/__init__.py +1 -0
- build/torch-cuda/quack/sort/bitonic_sort.py +129 -0
- build/torch-cuda/quack/sort/generate_sorting_networks.py +326 -0
- build/torch-cuda/quack/sort/sorting_networks.py +120 -0
- build/torch-cuda/quack/sort/utils.py +31 -0
- build/torch-cuda/quack/tensormap_manager.py +115 -0
build/torch-cuda/__init__.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ********************************************************************************
|
| 2 |
+
# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao
|
| 3 |
+
# ********************************************************************************
|
| 4 |
+
|
| 5 |
+
from functools import lru_cache
|
| 6 |
+
|
| 7 |
+
__version__ = "0.1.1"
|
| 8 |
+
|
| 9 |
+
from .enums import KernelBackendMoE
|
| 10 |
+
|
| 11 |
+
from .moe import MoE
|
| 12 |
+
from .functional import (
|
| 13 |
+
enable_quack_gemm,
|
| 14 |
+
moe_general_routing_inputs,
|
| 15 |
+
moe_TC_softmax_topk_layer,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
__all__ = [
|
| 19 |
+
"KernelBackendMoE",
|
| 20 |
+
"MoE",
|
| 21 |
+
"enable_quack_gemm",
|
| 22 |
+
"moe_general_routing_inputs",
|
| 23 |
+
"moe_TC_softmax_topk_layer",
|
| 24 |
+
]
|
build/torch-cuda/_ops.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
ops = torch.ops._sonic_moe_57a1b31
|
| 3 |
+
|
| 4 |
+
def add_op_namespace_prefix(op_name: str):
    """Return *op_name* qualified with this build's generated op namespace."""
    # Must match the library name of the compiled extension bound to
    # ``torch.ops`` at the top of this module.
    namespace = "_sonic_moe_57a1b31"
    return f"{namespace}::{op_name}"
|
build/torch-cuda/_ops_compat.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Compatibility helpers for op namespacing in source and built layouts."""
|
| 2 |
+
|
| 3 |
+
try:
    # Built layout: reuse the namespace helper generated next to the extension.
    from ._ops import add_op_namespace_prefix as _generated_add_op_namespace_prefix
except ImportError:
    # Source layout: no generated _ops module. Fall back to the default
    # "sonicmoe" namespace; names that are already qualified pass through.
    def _generated_add_op_namespace_prefix(name: str) -> str:
        if "::" in name:
            return name
        return f"sonicmoe::{name}"


def add_op_namespace_prefix(name: str) -> str:
    """Qualify *name* with the op namespace appropriate for this layout."""
    return _generated_add_op_namespace_prefix(name)
|
build/torch-cuda/enums.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ********************************************************************************
|
| 2 |
+
# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao
|
| 3 |
+
# ********************************************************************************
|
| 4 |
+
|
| 5 |
+
from enum import Enum
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
LIBRARY_NAME = "sonicmoe"
|
| 9 |
+
TENSORMAP = "tensormap"
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class KernelBackendMoE(Enum):
    """Selects which backend implementation runs the MoE layer."""

    scattermoe = "scattermoe"  # scattermoe backend
    torch = "torch"  # plain-PyTorch reference path
    sonicmoe = "sonicmoe"  # this package's fused kernels
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class ActivationType(Enum):
    """MLP activation functions supported by the MoE kernels.

    The first three are gated (GLU-family) variants: for these the
    up-projection weight packs gate and value halves (2*I rows), see
    ``is_glu`` and the forward kernels' ``I //= 2`` handling.
    """

    SWIGLU = "swiglu"
    GEGLU = "geglu"
    REGLU = "reglu"

    # Non-gated, element-wise activations.
    RELU_SQ = "relu_sq"
    RELU = "relu"
    GELU = "gelu"
    SILU = "silu"
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def is_glu(activation_type: ActivationType):
    """Return True when *activation_type* belongs to the gated (GLU) family."""
    glu_variants = (ActivationType.SWIGLU, ActivationType.REGLU, ActivationType.GEGLU)
    return activation_type in glu_variants
|
build/torch-cuda/functional/__init__.py
ADDED
|
@@ -0,0 +1,554 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ********************************************************************************
|
| 2 |
+
# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao
|
| 3 |
+
# ********************************************************************************
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from ..quack.gemm_interface import gemm
|
| 10 |
+
|
| 11 |
+
from ..enums import ActivationType, is_glu
|
| 12 |
+
from ..quack_utils import gemm_dgated, gemm_gated
|
| 13 |
+
from .backward import (
|
| 14 |
+
_down_projection_backward_act,
|
| 15 |
+
_down_projection_backward_weight,
|
| 16 |
+
_softmax_topk_bwd,
|
| 17 |
+
_token_broadcast_backward,
|
| 18 |
+
_up_projection_backward_act,
|
| 19 |
+
_up_projection_backward_weight,
|
| 20 |
+
)
|
| 21 |
+
from .forward import _down_projection_forward, _router_forward, _softmax_topk_fwd, _up_projection_forward
|
| 22 |
+
from .triton_kernels import TC_topk_router_metadata_triton, general_routing_router_metadata_triton
|
| 23 |
+
from .utils import enable_quack_gemm, is_using_quack_gemm
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class TC_Softmax_Topk_Router_Function(torch.autograd.Function):
    """Autograd wrapper for the fused softmax + top-K routing kernels.

    forward consumes router logits with T rows and produces the top-K router
    scores (float32) and expert indices (int32), each of shape (T, K).
    """

    @staticmethod
    def forward(ctx, router_logits: torch.Tensor, E: int, K: int) -> tuple[torch.Tensor, torch.Tensor]:
        # T: number of tokens (rows of the logits matrix).
        T = router_logits.size(0)

        # change this to router_logits.dtype (bfloat16) increase another 5 tflops at fwd at the cost of numerical accuracy
        topk_router_score = torch.empty(T, K, dtype=torch.float32, device=router_logits.device)
        topk_router_indices = torch.empty(T, K, dtype=torch.int32, device=router_logits.device)

        # Fused kernel fills scores and indices in place.
        _softmax_topk_fwd(router_logits, topk_router_score, topk_router_indices, E, K)

        ctx.save_for_backward(topk_router_score, topk_router_indices)
        ctx.E = E
        # Remember the logits dtype so dlogits is allocated to match it.
        ctx.dtype = router_logits.dtype

        return topk_router_score, topk_router_indices

    @staticmethod
    def backward(ctx, dtopk_score: torch.Tensor, _: torch.Tensor) -> tuple[torch.Tensor, None, None]:
        # The second incoming gradient (for the int32 index output) is unused.
        T, K = dtopk_score.size()

        topk_router_score, topk_router_indices = ctx.saved_tensors
        # zeros (not empty): presumably the backward kernel writes only the K
        # selected columns per row, leaving the rest at zero — TODO confirm
        # against _softmax_topk_bwd.
        dlogits = torch.zeros(T, ctx.E, dtype=ctx.dtype, device=topk_router_score.device)

        _softmax_topk_bwd(dlogits, None, dtopk_score, topk_router_score, topk_router_indices, K)

        # One gradient slot per forward input (router_logits, E, K).
        return dlogits, None, None
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class _UpProjection(torch.autograd.Function):
    """Grouped up-projection (w1) over tokens grouped by expert.

    Rows are processed in an expert-grouped layout of length TK;
    ``x_gather_idx`` maps each grouped row back to its source token so the
    kernels gather from ``x`` on the fly.  Returns the activated output
    ``y1`` (marked non-differentiable — its gradient is reconstructed by the
    down projection) and the pre-activation ``z``.
    """

    @staticmethod
    def forward(
        ctx,
        x: torch.Tensor,  # (T, H) token activations
        w1: torch.Tensor,  # (I or 2*I, H, E) expert up-projection weights
        b1: torch.Tensor | None,  # optional per-expert bias
        expert_frequency_offset: torch.Tensor,  # (E+1,) prefix offsets in the grouped layout
        total_expert_freq: int,  # TK: total number of (token, expert) rows
        K: int,
        stream_id: int,
        x_gather_idx: torch.Tensor,  # (TK,) grouped row -> token index
        s_scatter_idx: torch.Tensor,  # (TK,) grouped row -> scatter index
        s_reverse_scatter_idx: torch.Tensor,
        num_activated_expert_per_token_offset: torch.Tensor,
        is_varlen_K: bool,  # True when tokens activate a variable number of experts
        activation_type: ActivationType,
        is_inference_mode_enabled: bool,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        T, H = x.shape
        I, H, E = w1.shape
        is_glu_activation = is_glu(activation_type)
        if is_glu_activation:
            # GLU variants pack gate and value halves into w1's first dim.
            I //= 2
        TK = total_expert_freq

        if is_using_quack_gemm():
            assert not torch.compiler.is_compiling()
            assert is_glu_activation, "QuACK GEMM does not support non GLU activation yet"
            z, y1 = gemm_gated(
                x,
                w1.permute(2, 1, 0),
                activation="swiglu",
                cu_seqlens_m=expert_frequency_offset,
                A_idx=x_gather_idx,
                dynamic_scheduler=False,
            )
        else:
            # z: pre-activation (both halves for GLU); y1: activated output.
            z = torch.empty(TK, (2 * I if is_glu_activation else I), dtype=x.dtype, device=x.device)
            y1 = torch.empty(TK, I, dtype=x.dtype, device=x.device)
            _up_projection_forward(
                x=x,
                w1=w1,
                z=z,
                y1=y1,
                b1=b1,
                expert_frequency_offset=expert_frequency_offset,
                expert_schedule_order=None,
                x_gather_idx=x_gather_idx,
                stream_id=stream_id,
                activation_type=activation_type.value,
                is_glu_activation=is_glu_activation,
                is_inference_mode_enabled=is_inference_mode_enabled,
            )

        ctx.T = T
        ctx.TK = TK
        ctx.E = E
        ctx.K = K
        ctx.H = H
        ctx.I = I
        ctx.is_varlen_K = is_varlen_K
        ctx.is_glu_activation = is_glu_activation
        ctx.stream_id = stream_id

        ctx.save_for_backward(
            x,
            w1,
            b1,
            expert_frequency_offset,
            x_gather_idx,
            s_scatter_idx,
            s_reverse_scatter_idx,
            num_activated_expert_per_token_offset,
        )

        # y1 carries no gradient of its own; backward asserts it arrives None
        # (set_materialize_grads(False) keeps autograd from zero-filling it).
        ctx.mark_non_differentiable(y1)
        ctx.set_materialize_grads(False)

        return y1, z

    @staticmethod
    def backward(ctx, _: None, dz: torch.Tensor):
        """Compute dx (reduced over experts), dw1 and db1 from dz."""
        is_compiling = torch.compiler.is_compiling()

        if not is_compiling:
            # y1 was marked non-differentiable and grads are not materialized.
            assert _ is None

        T = ctx.T
        TK = ctx.TK
        E = ctx.E
        K = ctx.K
        H = ctx.H
        is_glu_activation = ctx.is_glu_activation
        is_varlen_K = ctx.is_varlen_K
        stream_id = ctx.stream_id

        (
            x,
            w1,
            b1,
            expert_frequency_offset,
            x_gather_idx,
            s_scatter_idx,
            s_reverse_scatter_idx,
            num_activated_expert_per_token_offset,
        ) = ctx.saved_tensors

        dw1 = torch.empty_like(w1)
        db1 = None if b1 is None else torch.empty_like(b1)

        if is_using_quack_gemm():
            assert not is_compiling

            # dw1 = x^T @ dz accumulated per expert segment.
            gemm(
                x.T,
                dz,
                out=dw1.permute(2, 1, 0),
                cu_seqlens_k=expert_frequency_offset,
                A_idx=x_gather_idx,
                batch_idx_permute=None,
                dynamic_scheduler=False,
            )
            # dx, still in the expert-grouped layout.
            dx_expanded = gemm(dz, w1.permute(2, 0, 1), cu_seqlens_m=expert_frequency_offset, dynamic_scheduler=False)
        else:
            dx_expanded = torch.empty(TK, H, dtype=dz.dtype, device=dz.device)

            _up_projection_backward_act(
                w1=w1,
                dx_expanded=dx_expanded,
                dz=dz,
                db1=db1,
                expert_frequency_offset=expert_frequency_offset,
                expert_schedule_order=None,
                x_gather_idx=x_gather_idx,
                s_scatter_idx=s_scatter_idx,
                is_glu_activation=is_glu_activation,
                stream_id=stream_id,
            )

            _up_projection_backward_weight(
                x=x,
                dw1=dw1,
                dz=dz,
                expert_frequency_offset=expert_frequency_offset,
                expert_schedule_order=None,
                x_gather_idx=x_gather_idx,
                is_glu_activation=is_glu_activation,
                stream_id=stream_id,
            )

        dx_reduced = torch.empty(T, H, dtype=dz.dtype, device=dz.device)

        # Sum the per-(token, expert) gradient rows back into one row per token.
        _token_broadcast_backward(
            dx_reduced=dx_reduced,
            dx_expanded=dx_expanded,
            s_reverse_scatter_idx=s_reverse_scatter_idx,
            num_activated_expert_per_token_offset=num_activated_expert_per_token_offset,
            varlen_K_max=(E if is_varlen_K else K),
            H=H,
            is_varlen_K=is_varlen_K,
        )

        # One gradient slot per forward input: x, w1, b1 + 12 non-tensor args.
        return dx_reduced, dw1, db1, *[None] * 12
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
class _DownProjection(torch.autograd.Function):
    """Grouped down-projection (w2) plus router-score-weighted reduction.

    Consumes the up-projection outputs ``y1`` (activated) and ``z``
    (pre-activation, saved for backward), applies w2 per expert segment,
    scales each output row by its router score and reduces every token's
    expert rows into the final output ``o`` of shape (T, H).
    """

    @staticmethod
    def forward(
        ctx,
        y1: torch.Tensor,  # (TK, I) activated up-projection output
        z: torch.Tensor,  # pre-activation, needed to recompute the activation in backward
        w2: torch.Tensor,  # (H, I, E) expert down-projection weights
        b2: torch.Tensor | None,  # optional bias (unsupported on the QuACK path)
        topk_scores: torch.Tensor,  # router scores; flattened to (TK,) below
        expert_frequency_offset: torch.Tensor,  # (E+1,) prefix offsets in the grouped layout
        T: int,
        K: int,
        stream_id: int,
        x_gather_idx: torch.Tensor,  # (TK,) grouped row -> token index
        s_scatter_idx: torch.Tensor,
        s_reverse_scatter_idx: torch.Tensor,
        num_activated_expert_per_token_offset: torch.Tensor,
        is_varlen_K: bool,
        activation_type: ActivationType,
    ) -> torch.Tensor:
        TK = y1.size(0)
        H, I, E = w2.shape

        if is_using_quack_gemm():
            assert not torch.compiler.is_compiling()

            # QuACK path does not take a bias for the down projection.
            assert b2 is None
            y2 = gemm(y1, w2.permute(2, 1, 0), cu_seqlens_m=expert_frequency_offset)
        else:
            y2 = torch.empty(TK, H, dtype=y1.dtype, device=y1.device)
            _down_projection_forward(
                w2=w2,
                y1=y1,
                y2=y2,
                b2=b2,
                expert_frequency_offset=expert_frequency_offset,
                expert_schedule_order=None,
                x_gather_idx=x_gather_idx,
                stream_id=stream_id,
            )

        o = torch.empty(T, H, device=z.device, dtype=z.dtype)
        topk_scores = topk_scores.flatten()

        # Scale each expert-output row by its router score and reduce the rows
        # belonging to the same token into o.
        _router_forward(
            y2=y2,
            o=o,
            topk_scores=topk_scores,
            s_reverse_scatter_idx=s_reverse_scatter_idx,
            num_activated_expert_per_token_offset=num_activated_expert_per_token_offset,
            varlen_K_max=(E if is_varlen_K else K),
            H=H,
            is_varlen_K=is_varlen_K,
        )

        ctx.T = T
        ctx.K = K
        ctx.is_varlen_K = is_varlen_K
        ctx.activation_type = activation_type
        ctx.stream_id = stream_id

        ctx.save_for_backward(
            z,
            w2,
            b2,
            topk_scores,
            expert_frequency_offset,
            x_gather_idx,
            s_scatter_idx,
            s_reverse_scatter_idx,
        )

        return o

    @staticmethod
    def backward(ctx, dout: torch.Tensor):
        """Compute dz, dw2, db2 and the router-score gradient ds from dout."""
        T = ctx.T
        K = ctx.K
        stream_id = ctx.stream_id
        is_varlen_K = ctx.is_varlen_K
        activation_type = ctx.activation_type

        (
            z,
            w2,
            b2,
            topk_scores,
            expert_frequency_offset,
            x_gather_idx,
            s_scatter_idx,
            s_reverse_scatter_idx,
        ) = ctx.saved_tensors

        dw2 = torch.empty_like(w2)
        db2 = None if b2 is None else torch.empty_like(b2)
        dz = torch.empty_like(z)

        if is_using_quack_gemm():
            assert not torch.compiler.is_compiling()
            assert is_glu(activation_type), "QuACK GEMM does not support non GLU activation yet"

            # Router scores reordered into the expert-grouped row order.
            s = topk_scores[s_scatter_idx]
            # NOTE(review): y1s appears to be the activation output scaled by
            # the router scores (reused below for dw2), ds the per-row score
            # gradient — confirm against gemm_dgated's contract.
            _, y1s, ds = gemm_dgated(
                dout,
                w2.permute(2, 0, 1),
                PreAct=z,
                activation="swiglu",
                dx_out=dz,
                colvec_scale=s,
                colvec_reduce=True,
                cu_seqlens_m=expert_frequency_offset,
                A_idx=x_gather_idx,
                dynamic_scheduler=False,
            )
            gemm(
                dout.T,
                y1s,
                out=dw2.permute(2, 0, 1),
                cu_seqlens_k=expert_frequency_offset,
                A_idx=x_gather_idx,
                batch_idx_permute=None,
                dynamic_scheduler=False,
            )

            # Back from grouped order to (token, expert)-flat order.
            ds = ds[s_reverse_scatter_idx]
        else:
            ds = torch.empty_like(topk_scores)

            I = w2.size(1)
            TK = x_gather_idx.size(0)

            y1s = torch.empty(TK, I, dtype=z.dtype, device=z.device)
            is_glu_activation = is_glu(activation_type)

            _down_projection_backward_act(
                dout=dout,
                z=z,
                w2=w2,
                dz=dz,
                ds=ds,
                b2=b2,
                db2=db2,
                y1s=y1s,
                topk_scores=topk_scores,
                expert_frequency_offset=expert_frequency_offset,
                expert_schedule_order=None,
                x_gather_idx=x_gather_idx,
                s_scatter_idx=s_scatter_idx,
                is_glu_activation=is_glu_activation,
                activation_type=activation_type.value,
                stream_id=stream_id,
            )

            _down_projection_backward_weight(
                dout=dout,
                y1s=y1s,
                dw2=dw2,
                expert_frequency_offset=expert_frequency_offset,
                expert_schedule_order=None,
                x_gather_idx=x_gather_idx,
                stream_id=stream_id,
            )

        # TC top-K routing
        if not is_varlen_K:
            ds = ds.view(T, K)

        # One gradient slot per forward input: y1 (non-differentiable), z, w2,
        # b2, topk_scores + 10 non-tensor args.
        return None, dz, dw2, db2, ds, *[None] * 10
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
def moe_TC_softmax_topk_layer(
    x: torch.Tensor,
    router_w: torch.Tensor,
    w1: torch.Tensor,
    b1: torch.Tensor | None,
    w2: torch.Tensor,
    b2: torch.Tensor | None,
    K: int,
    stream_id: int,
    activation_type: ActivationType | str = ActivationType.SWIGLU,
    is_inference_mode_enabled: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Full MoE layer using fused softmax top-K routing.

    Computes router logits with ``router_w``, selects the top-K experts per
    token, runs the grouped up/down projections and returns the reduced
    output, the raw router logits and the per-expert token counts.

    Args:
        x: (T, H) token activations.
        router_w: (E, H) router weight; E is taken from its first dim.
        w1 / b1: expert up-projection weight and optional bias.
        w2 / b2: expert down-projection weight and optional bias.
        K: number of experts activated per token.
        stream_id: CUDA stream id forwarded to the kernels.
        activation_type: activation, as an ``ActivationType`` or its string value.
        is_inference_mode_enabled: forwarded to the up-projection kernel.

    Returns:
        (o, router_logits, expert_frequency).
    """
    assert ((b1 is None) and (b2 is None)) or (
        (b1 is not None) and (b2 is not None)
    ), "b1 and b2 has to be None or not None at the same time!"
    E = router_w.size(0)
    router_logits = F.linear(x, router_w)
    topk_scores, topk_indices = TC_Softmax_Topk_Router_Function.apply(router_logits, E, K)

    T, K = topk_indices.size()
    TK = T * K
    device = topk_indices.device

    # Routing metadata buffers, filled in place by the Triton kernel below.
    s_scatter_idx = torch.empty(TK, dtype=torch.int32, device=device)
    s_reverse_scatter_idx = torch.empty(TK, dtype=torch.int32, device=device)
    expert_frequency = torch.empty(E, dtype=torch.int32, device=device)
    expert_frequency_offset = torch.empty(E + 1, dtype=torch.int32, device=device)
    x_gather_idx = torch.empty(TK, dtype=torch.int32, device=device)

    TC_topk_router_metadata_triton(
        topk_indices, E, expert_frequency, expert_frequency_offset, x_gather_idx, s_scatter_idx, s_reverse_scatter_idx
    )

    T = x.size(0)

    # Accept the string form of the enum (idiomatic isinstance check; the
    # original used `type(...) == str`, which also rejects str subclasses).
    if isinstance(activation_type, str):
        activation_type = ActivationType(activation_type)

    y1, z = _UpProjection.apply(
        x,
        w1,
        b1,
        expert_frequency_offset,
        T * K,
        K,
        stream_id,
        x_gather_idx,
        s_scatter_idx,
        s_reverse_scatter_idx,
        None,
        False,  # is_varlen_K
        activation_type,
        is_inference_mode_enabled,
    )

    o = _DownProjection.apply(
        y1,
        z,
        w2,
        b2,
        topk_scores,
        expert_frequency_offset,
        T,
        K,
        stream_id,
        x_gather_idx,
        s_scatter_idx,
        s_reverse_scatter_idx,
        None,
        False,  # is_varlen_K
        activation_type,
    )

    return o, router_logits, expert_frequency
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
# !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
|
| 468 |
+
# Weight format requirements:
|
| 469 |
+
# - w1_weight: Shape (2*I, H, E), stride order (2, 0, 1), must be interleaved [gate_row0, up_row0, gate_row1, up_row1, ...]
|
| 470 |
+
# - w2_weight: Shape (H, I, E), stride order (2, 0, 1)
|
| 471 |
+
|
| 472 |
+
|
| 473 |
+
# We assume token_indices is already SORTED ascendingly !!!
|
| 474 |
+
# and len(token_indices) = len(expert_indices) = len(router_scores)
|
| 475 |
+
# !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
|
| 476 |
+
def moe_general_routing_inputs(
    x: torch.Tensor,
    router_scores: torch.Tensor,
    token_indices: torch.Tensor,
    expert_indices: torch.Tensor,
    w1: torch.Tensor,
    b1: torch.Tensor | None,
    w2: torch.Tensor,
    b2: torch.Tensor | None,
    E: int,
    stream_id: int,
    activation_type: ActivationType,
    is_inference_mode_enabled: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """MoE layer for externally computed (general, variable-K) routing.

    ``token_indices``, ``expert_indices`` and ``router_scores`` all have
    length TK, one entry per activated (token, expert) pair, and
    ``token_indices`` must already be sorted ascending (see the note above).
    Returns the reduced output (T, H) and the per-expert token counts.
    """
    assert ((b1 is None) and (b2 is None)) or (
        (b1 is not None) and (b2 is not None)
    ), "b1 and b2 has to be None or not None at the same time!"

    T = x.size(0)
    TK = router_scores.size(0)
    # NOTE(review): the E argument is immediately overwritten from w2's last
    # dim, so the parameter is effectively unused — confirm before relying
    # on it (or on passing a different value).
    E = w2.size(-1)
    device = router_scores.device

    # Routing metadata buffers, filled in place by the Triton kernel below.
    s_scatter_idx = torch.empty(TK, dtype=torch.int32, device=device)
    s_reverse_scatter_idx = torch.empty(TK, dtype=torch.int32, device=device)
    expert_frequency = torch.empty(E, dtype=torch.int32, device=device)
    expert_frequency_offset = torch.empty(E + 1, dtype=torch.int32, device=device)
    x_gather_idx = torch.empty(TK, dtype=torch.int32, device=device)
    num_activated_expert_per_token_offset = torch.empty(T + 1, dtype=torch.int32, device=device)

    general_routing_router_metadata_triton(
        token_indices,
        expert_indices,
        T,
        E,
        expert_frequency,
        expert_frequency_offset,
        x_gather_idx,
        s_scatter_idx,
        s_reverse_scatter_idx,
        num_activated_expert_per_token_offset,
    )

    y1, z = _UpProjection.apply(
        x,
        w1,
        b1,
        expert_frequency_offset,
        TK,
        None,  # K, not needed
        stream_id,
        x_gather_idx,
        s_scatter_idx,
        s_reverse_scatter_idx,
        num_activated_expert_per_token_offset,
        True,  # is_varlen_K
        activation_type,
        is_inference_mode_enabled,
    )

    o = _DownProjection.apply(
        y1,
        z,
        w2,
        b2,
        router_scores,
        expert_frequency_offset,
        T,
        None,  # K, not needed
        stream_id,
        x_gather_idx,
        s_scatter_idx,
        s_reverse_scatter_idx,
        num_activated_expert_per_token_offset,
        True,  # is_varlen_K
        activation_type,
    )

    return o, expert_frequency
|
build/torch-cuda/functional/backward.py
ADDED
|
@@ -0,0 +1,682 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ********************************************************************************
|
| 2 |
+
# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao
|
| 3 |
+
# ********************************************************************************
|
| 4 |
+
|
| 5 |
+
from typing import Optional
|
| 6 |
+
|
| 7 |
+
import cuda.bindings.driver as cuda
|
| 8 |
+
import cutlass.cute as cute
|
| 9 |
+
import torch
|
| 10 |
+
import triton
|
| 11 |
+
import triton.language as tl
|
| 12 |
+
|
| 13 |
+
from .._ops_compat import add_op_namespace_prefix
|
| 14 |
+
from ..enums import LIBRARY_NAME, TENSORMAP, ActivationType
|
| 15 |
+
from ..utils import ceil_divide, convert_torch_tensor_to_cute_tensor, get_powers_of_2
|
| 16 |
+
from .moe_config import (
|
| 17 |
+
HopperWgmma_MoE_Down_proj_ActGrad_Bwd,
|
| 18 |
+
HopperWgmma_MoE_Down_proj_WeightGrad_Bwd,
|
| 19 |
+
HopperWgmma_MoE_Up_proj_ActGrad_Bwd,
|
| 20 |
+
HopperWgmma_MoE_Up_proj_WeightGrad_Bwd,
|
| 21 |
+
)
|
| 22 |
+
from .reduction_over_k_gather import token_gather_and_sum_varlen_K_triton
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _get_autotune_configs_for_db2_and_ds() -> list[triton.Config]:
    """Autotune search space for ``db2_and_ds_kernel``.

    Sweeps the token-block size BLOCK_TK over powers of two in [4, 32];
    warp/stage counts are held fixed.
    """
    return [
        triton.Config({"BLOCK_TK": block_tk}, num_warps=8, num_stages=4)
        for block_tk in get_powers_of_2(4, 32)
    ]
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# Fused backward kernel for the down-projection bias (db2) and the routing-score
# gradient partials (ds). Grid is (E, n_h_blocks): one program per
# (expert, hidden-block) pair. Tokens are visited in the expert-grouped layout
# described by `expert_offset_ptr`.
@triton.autotune(
    configs=_get_autotune_configs_for_db2_and_ds(),
    key=["H", "E"],
)
@triton.jit
def db2_and_ds_kernel(
    dout_ptr,  # (T, H) upstream gradient, indexed via x_gather_idx
    s_ptr,  # (TK,) per-(token, expert) routing scores, indexed via s_scatter_idx
    new_ds_partial_ptr,  # (TK, n_h_blocks) output: per-H-block partials of ds
    old_ds_partial_ptr,  # (TK, OLD_DS_PARTIAL_N) partials produced by the dz kernel; folded in on Hidx == 0
    b2_ptr,  # (E, H) down-projection bias
    db2_ptr,  # (E, H) output: bias gradient
    x_gather_idx_ptr,  # (TK,), maps grouped -> token index
    s_scatter_idx_ptr,  # (TK,), maps grouped -> scatter index
    expert_offset_ptr,  # (E+1,), offsets in grouped layout
    H: tl.constexpr,
    E: tl.constexpr,
    OLD_DS_PARTIAL_N: tl.constexpr,
    BLOCK_H: tl.constexpr,  # Block size for H dimension
    BLOCK_TK: tl.constexpr,  # Block size for token dimension
    BLOCK_OLD_DS_PARTIAL_N: tl.constexpr,  # padded (power-of-2) width for loading old_ds_partial rows
):
    Eidx = tl.program_id(0)  # expert id
    Hidx = tl.program_id(1)  # h-block id
    NUM_H_BLOCKS: tl.constexpr = tl.num_programs(1)

    # Hidden dimension indices for this block
    h_offsets = Hidx * BLOCK_H + tl.arange(0, BLOCK_H)
    h_mask = h_offsets < H

    # Token range owned by this expert in the grouped layout.
    E_count_start = tl.load(expert_offset_ptr + Eidx)
    E_count_end = tl.load(expert_offset_ptr + Eidx + 1)
    n_tokens = E_count_end - E_count_start

    # This expert's bias slice, needed for ds = <dout, b2>.
    b2 = tl.load(b2_ptr + Eidx * H + h_offsets, mask=h_mask, other=0.0).to(tl.float32)

    db2_acc = tl.zeros([BLOCK_H], dtype=tl.float32)

    # Process tokens in blocks of BLOCK_TK
    for block_start in tl.range(0, n_tokens, BLOCK_TK):
        # Token offsets within this block
        tk_offsets = block_start + tl.arange(0, BLOCK_TK)
        tk_mask = tk_offsets < n_tokens
        tk_grouped = E_count_start + tk_offsets

        # Gather token indices: [BLOCK_TK]
        token_indices = tl.load(x_gather_idx_ptr + tk_grouped, mask=tk_mask, other=0).to(tl.uint32)

        # Get scatter indices: [BLOCK_TK]
        scatter_indices = tl.load(s_scatter_idx_ptr + tk_grouped, mask=tk_mask, other=0).to(tl.uint32)

        s = tl.load(s_ptr + scatter_indices, mask=tk_mask, other=0.0).to(tl.float32)

        # Gather dout: [BLOCK_TK, BLOCK_H]
        dout_offsets = token_indices[:, None] * H + h_offsets[None, :]
        dout_mask = tk_mask[:, None] & h_mask[None, :]
        dout = tl.load(dout_ptr + dout_offsets, mask=dout_mask, other=0.0).to(tl.float32)

        # Accumulate db2: sum over tokens of (dout * s)
        db2_acc += tl.sum(dout * s[:, None], axis=0)  # Sum over BLOCK_TK dimension

        # Compute ds: dot(dout, b2) for this H-block
        ds_partial = tl.sum(dout * b2[None, :], axis=1)  # [BLOCK_TK]

        # On first H-block, add old_ds_partial.sum(dim=1) so the pre-existing
        # partials from the dz kernel are folded in exactly once.
        if Hidx == 0:
            n_offsets = tl.arange(0, BLOCK_OLD_DS_PARTIAL_N)
            old_ds_partial_offsets = scatter_indices[:, None] * OLD_DS_PARTIAL_N + n_offsets[None, :]
            old_ds_partial_mask = tk_mask[:, None] & (n_offsets[None, :] < OLD_DS_PARTIAL_N)
            old_ds_partial_vals = tl.load(
                old_ds_partial_ptr + old_ds_partial_offsets, mask=old_ds_partial_mask, other=0.0
            ).to(tl.float32)
            ds_partial += tl.sum(old_ds_partial_vals, axis=1)

        # Each (token, H-block) pair writes its own column; the caller reduces
        # over NUM_H_BLOCKS afterwards.
        tl.store(new_ds_partial_ptr + scatter_indices * NUM_H_BLOCKS + Hidx, ds_partial, mask=tk_mask)

    tl.store(db2_ptr + Eidx * H + h_offsets, db2_acc, mask=h_mask)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def _get_autotune_configs_for_db1() -> list[triton.Config]:
    """Autotune search space for ``db1_kernel``.

    Enumerates (BLOCK_I, BLOCK_TK) power-of-two pairs whose tile area
    BLOCK_I * BLOCK_TK lies in [4096, 16384].
    """
    return [
        triton.Config({"BLOCK_I": block_i, "BLOCK_TK": block_tk}, num_warps=8, num_stages=4)
        for block_tk in get_powers_of_2(4, 128)
        for block_i in get_powers_of_2(64, 4096)
        if 4096 <= block_i * block_tk <= 16384
    ]
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def _prune_triton_autotune_config(configs, nargs, **kw):
    """Early-prune autotune configs whose BLOCK_I exceeds the padded problem size.

    Triton pads the I dimension to the next power of two, so any config with a
    larger BLOCK_I would only waste compile time.
    """
    limit = triton.next_power_of_2(nargs["I"])
    return [cfg for cfg in configs if cfg.kwargs["BLOCK_I"] <= limit]
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
# Backward kernel for the up-projection bias: db1[e] = sum of dz rows routed to
# expert e. Grid is (E,): one program per expert; I is tiled with a static loop.
@triton.autotune(
    configs=_get_autotune_configs_for_db1(),
    key=["I", "E"],
    prune_configs_by={"early_config_prune": _prune_triton_autotune_config},
)
@triton.jit
def db1_kernel(
    dz_ptr,  # (T, I) up-projection output gradient, rows in expert-grouped order
    db1_ptr,  # (E, I) output: bias gradient
    expert_offset_ptr,  # (E+1,), offsets in grouped layout
    I: tl.constexpr,
    E: tl.constexpr,
    BLOCK_I: tl.constexpr,  # Block size for I dimension
    BLOCK_TK: tl.constexpr,  # Block size for token dimension
):
    Eidx = tl.program_id(0)  # expert id

    # int64 to keep the dz offsets below from overflowing on large T * I.
    E_count_start = tl.load(expert_offset_ptr + Eidx).to(tl.int64)
    E_count_end = tl.load(expert_offset_ptr + Eidx + 1).to(tl.int64)
    n_tokens = E_count_end - E_count_start

    # Compile-time trip count: the I loop is fully unrolled.
    NUM_I_BLOCKS: tl.constexpr = triton.cdiv(I, BLOCK_I)
    for Iidx in tl.static_range(0, NUM_I_BLOCKS, 1):
        i_offsets = Iidx * BLOCK_I + tl.arange(0, BLOCK_I)
        i_mask = i_offsets < I

        db1_acc = tl.zeros([BLOCK_I], dtype=tl.float32)

        # Process tokens in blocks of BLOCK_TK
        for block_start in tl.range(0, n_tokens, BLOCK_TK):
            # Token offsets within this block
            tk_offsets = block_start + tl.arange(0, BLOCK_TK)
            tk_mask = tk_offsets < n_tokens
            tk_grouped = E_count_start + tk_offsets

            dz_offsets = tk_grouped[:, None] * I + i_offsets[None, :]
            dz_mask = tk_mask[:, None] & i_mask[None, :]
            dz = tl.load(dz_ptr + dz_offsets, mask=dz_mask, other=0.0).to(tl.float32)

            db1_acc += tl.sum(dz, axis=0)  # Sum over BLOCK_TK dimension

        db1_offsets = Eidx.to(tl.int64) * I + i_offsets
        tl.store(db1_ptr + db1_offsets, db1_acc, mask=i_mask)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
# Row-wise reduction for small N: y[m] = sum_n x[m, n], one program per row.
# Requires BLOCK_N >= N so that a single tile covers the whole row.
@triton.jit
def _colsum_smallN_kernel(
    y_ptr,  # *mut T, shape [M]
    x_ptr,  # *const T, shape [M, N]
    stride_xm: tl.constexpr,
    stride_xn: tl.constexpr,  # strides of X
    stride_y: tl.constexpr,  # stride of Y (usually 1)
    N: tl.constexpr,  # sizes
    BLOCK_N: tl.constexpr,  # tile size along N
):
    row = tl.program_id(0)

    # assume BLOCK_N >= N
    offs = tl.arange(0, BLOCK_N)
    mask = offs < N
    # Load a tile from the row; cast to fp32 for the reduction
    x = tl.load(x_ptr + row * stride_xm + offs * stride_xn, mask=mask, other=0).to(tl.float32)
    # Reduce this tile to a scalar and add
    acc = tl.sum(x, axis=0)

    # Store the row-sum (cast back to y dtype)
    tl.store(y_ptr + row * stride_y, acc)
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
@torch.library.custom_op(add_op_namespace_prefix("_up_projection_backward_act"), mutates_args={"dx_expanded", "db1"})
def _up_projection_backward_act(
    w1: torch.Tensor,
    dx_expanded: torch.Tensor,
    dz: torch.Tensor,
    db1: torch.Tensor | None,
    expert_frequency_offset: torch.Tensor,
    expert_schedule_order: torch.Tensor | None,
    x_gather_idx: torch.Tensor,
    s_scatter_idx: torch.Tensor,
    is_glu_activation: bool,
    stream_id: int,
) -> None:
    """Backward of the up-projection w.r.t. the activations (and optionally the bias).

    Writes dx_expanded = dz @ w1^T via the grouped-GEMM cute kernel, and, when
    ``db1`` is given, fills it via ``db1_kernel``. Results are written in place
    (mutates_args); nothing is returned.

    Args:
        w1: up-projection weights of shape (I, H, E) (2I when is_glu_activation).
        dx_expanded: output buffer for the per-(token, expert) input gradient.
        dz: gradient of the up-projection output.
        db1: optional output buffer for the bias gradient.
        expert_frequency_offset: (E+1,) offsets of each expert's token range.
        expert_schedule_order: optional expert permutation for tile scheduling.
        x_gather_idx / s_scatter_idx: grouped-layout gather/scatter index maps.
        is_glu_activation: whether w1 packs gate+up (doubles the I dimension).
        stream_id: raw CUDA stream handle the kernels are launched on.
    """
    I, H, E = w1.size()
    if is_glu_activation:
        I //= 2

    # db1 computation
    if db1 is not None:
        db1_kernel[(E,)](dz, db1, expert_frequency_offset, (2 * I if is_glu_activation else I), E)

    # Wrap the torch tensors as cute tensors (layout, alignment, stream).
    mE_offset = convert_torch_tensor_to_cute_tensor(expert_frequency_offset, (0,), 0, 4, 1, stream=stream_id)
    mX_gather = convert_torch_tensor_to_cute_tensor(x_gather_idx, (0,), 0, 4, 1, stream=stream_id)
    mS_scatter = convert_torch_tensor_to_cute_tensor(s_scatter_idx, (0,), 0, 4, 1, stream=stream_id)
    mDz = convert_torch_tensor_to_cute_tensor(dz, (0, 1), 1, 16, 8, stream=stream_id)
    mDx_expanded = convert_torch_tensor_to_cute_tensor(dx_expanded, (0, 1), 1, 16, 8, stream=stream_id)
    mW1_trans = convert_torch_tensor_to_cute_tensor(w1.permute(1, 0, 2), (2, 1, 0), 0, 16, 8, stream=stream_id)

    if expert_schedule_order is None:
        mE_permute_order = None
    else:
        mE_permute_order = convert_torch_tensor_to_cute_tensor(expert_schedule_order, (0,), 0, 4, 1, stream=stream_id)
    current_stream = cuda.CUstream(stream_id)

    # Compile once per (shape, activation, dtype) and reuse the artifact.
    compile_dx_key = ("dx", E, H, I, is_glu_activation, dx_expanded.dtype)
    if compile_dx_key not in _up_projection_backward_act.compile_cache:
        dx_module = HopperWgmma_MoE_Up_proj_ActGrad_Bwd(E, H, I, is_glu_activation)
        tensormaps = [dx_module.module.generate_tensormap(None, None, None) for _ in range(2)]
        _up_projection_backward_act.compile_cache[compile_dx_key] = cute.compile(
            dx_module,
            mDz,
            mW1_trans,
            mDx_expanded,
            mE_offset,
            mX_gather,
            mS_scatter,
            tensormaps,
            mE_permute_order,
            current_stream,
        )
        # NOTE(review): tensormaps are cached under a single key regardless of
        # compile_dx_key, so the latest compilation's maps are shared — confirm
        # this is intentional if multiple shapes/dtypes are compiled.
        _up_projection_backward_act.compile_cache[f"dx-{TENSORMAP}"] = tensormaps

    dx_tensormaps = _up_projection_backward_act.compile_cache[f"dx-{TENSORMAP}"]
    _up_projection_backward_act.compile_cache[compile_dx_key](
        mDz,
        mW1_trans,
        mDx_expanded,
        mE_offset,
        mX_gather,
        mS_scatter,
        dx_tensormaps,
        mE_permute_order,
        current_stream,
    )


# Per-process cache of compiled cute kernels and their tensormaps.
_up_projection_backward_act.compile_cache = {}
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
@torch.library.custom_op(add_op_namespace_prefix("_up_projection_backward_weight"), mutates_args={"dw1"})
def _up_projection_backward_weight(
    x: torch.Tensor,
    dw1: torch.Tensor,
    dz: torch.Tensor,
    expert_frequency_offset: torch.Tensor,
    expert_schedule_order: torch.Tensor | None,
    x_gather_idx: torch.Tensor,
    is_glu_activation: bool,
    stream_id: int,
) -> None:
    """Backward of the up-projection w.r.t. the weights.

    Computes per-expert dw1 = x^T @ dz over each expert's token group via the
    grouped-GEMM cute kernel. ``dw1`` is written in place; nothing is returned.

    Args:
        x: input activations (gathered per expert via x_gather_idx).
        dw1: output buffer of shape (I, H, E) (2I when is_glu_activation).
        dz: gradient of the up-projection output.
        expert_frequency_offset: (E+1,) offsets of each expert's token range.
        expert_schedule_order: optional expert permutation for tile scheduling.
        x_gather_idx: grouped-layout gather index map.
        is_glu_activation: whether dw1 packs gate+up (doubles the I dimension).
        stream_id: raw CUDA stream handle the kernel is launched on.
    """
    I, H, E = dw1.size()
    if is_glu_activation:
        I //= 2

    x = x.detach()

    # Wrap the torch tensors as cute tensors (layout, alignment, stream).
    mDz_trans = convert_torch_tensor_to_cute_tensor(dz.T, (1, 0), 0, 16, 8, stream=stream_id)
    mDw1_trans = convert_torch_tensor_to_cute_tensor(dw1.permute(1, 0, 2), (2, 1, 0), 0, 16, 8, stream=stream_id)

    mX_trans = convert_torch_tensor_to_cute_tensor(x.T, (1, 0), 0, 16, 8, stream=stream_id)
    mE_offset = convert_torch_tensor_to_cute_tensor(expert_frequency_offset, (0,), 0, 4, 1, stream=stream_id)
    mX_gather = convert_torch_tensor_to_cute_tensor(x_gather_idx, (0,), 0, 4, 1, stream=stream_id)

    if expert_schedule_order is None:
        mE_permute_order = None
    else:
        mE_permute_order = convert_torch_tensor_to_cute_tensor(expert_schedule_order, (0,), 0, 4, 1, stream=stream_id)
    current_stream = cuda.CUstream(stream_id)

    # Compile once per (shape, activation, dtype) and reuse the artifact.
    compile_dw1_key = ("dw1", E, H, I, is_glu_activation, x.dtype)
    if compile_dw1_key not in _up_projection_backward_weight.compile_cache:
        dw1_module = HopperWgmma_MoE_Up_proj_WeightGrad_Bwd(E, H, I, is_glu_activation)
        tensormaps = [dw1_module.module.generate_tensormap(None, None, None) for _ in range(1)]
        _up_projection_backward_weight.compile_cache[compile_dw1_key] = cute.compile(
            dw1_module,
            mX_trans,
            mDz_trans,
            mDw1_trans,
            mE_offset,
            mX_gather,
            tensormaps,
            mE_permute_order,
            current_stream,
        )
        # NOTE(review): tensormaps are cached under a single key shared across
        # all compile keys — confirm intentional if multiple shapes compile.
        _up_projection_backward_weight.compile_cache[f"dw1-{TENSORMAP}"] = tensormaps

    dw1_tensormaps = _up_projection_backward_weight.compile_cache[f"dw1-{TENSORMAP}"]
    _up_projection_backward_weight.compile_cache[compile_dw1_key](
        mX_trans,
        mDz_trans,
        mDw1_trans,
        mE_offset,
        mX_gather,
        dw1_tensormaps,
        mE_permute_order,
        current_stream,
    )


# Per-process cache of compiled cute kernels and their tensormaps.
_up_projection_backward_weight.compile_cache = {}
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
@torch.library.custom_op(add_op_namespace_prefix("_down_projection_backward_act"), mutates_args={"dz", "ds", "db2", "y1s"})
def _down_projection_backward_act(
    dout: torch.Tensor,
    z: torch.Tensor,
    w2: torch.Tensor,
    dz: torch.Tensor,
    ds: torch.Tensor,
    b2: torch.Tensor | None,
    db2: torch.Tensor | None,
    y1s: torch.Tensor,
    topk_scores: torch.Tensor,
    expert_frequency_offset: torch.Tensor,
    expert_schedule_order: torch.Tensor | None,
    x_gather_idx: torch.Tensor,
    s_scatter_idx: torch.Tensor,
    is_glu_activation: bool,
    activation_type: str,
    stream_id: int,
) -> None:
    """Backward of the down-projection w.r.t. activations, routing scores, and (optionally) bias.

    Runs the fused cute kernel producing dz (and y1s, and per-K-tile ds
    partials), then reduces the ds partials into ``ds`` — either directly, or
    fused with the db2 computation when ``db2`` is requested. All results are
    written in place into the mutated arguments; nothing is returned.

    Args:
        dout: upstream gradient of the MoE output, shape (T, H).
        z: pre-activation up-projection output (viewed as float32 for GLU).
        w2: down-projection weights of shape (H, I, E).
        dz / ds / db2 / y1s: in-place output buffers.
        b2: down-projection bias, required when db2 is requested.
        topk_scores: flattened per-(token, expert) routing scores.
        expert_frequency_offset: (E+1,) offsets of each expert's token range.
        expert_schedule_order: optional expert permutation for tile scheduling.
        x_gather_idx / s_scatter_idx: grouped-layout gather/scatter index maps.
        is_glu_activation / activation_type: activation configuration.
        stream_id: raw CUDA stream handle the kernels are launched on.
    """
    H, I, E = w2.size()
    TK = x_gather_idx.size(0)

    dout = dout.detach()
    w2 = w2.detach()
    topk_scores = topk_scores.detach()

    # Wrap the torch tensors as cute tensors (layout, alignment, stream).
    mDout = convert_torch_tensor_to_cute_tensor(dout, (0, 1), 1, 16, 8, stream=stream_id)
    mW2_trans = convert_torch_tensor_to_cute_tensor(w2.permute(1, 0, 2), (2, 1, 0), 0, 16, 8, stream=stream_id)
    mS = convert_torch_tensor_to_cute_tensor(topk_scores, (0,), 0, 4, 1, stream=stream_id)
    if is_glu_activation:
        # GLU path: the kernel reads z/dz reinterpreted as float32.
        mDz_kernel_input = convert_torch_tensor_to_cute_tensor(
            dz.view(torch.float32), (0, 1), 1, 16, 8, stream=stream_id
        )
        mZ_kernel_input = convert_torch_tensor_to_cute_tensor(
            z.view(torch.float32), (0, 1), 1, 16, 8, stream=stream_id
        )
    else:
        mDz_kernel_input = convert_torch_tensor_to_cute_tensor(dz.detach(), (0, 1), 1, 16, 8, stream=stream_id)
        mZ_kernel_input = convert_torch_tensor_to_cute_tensor(z.detach(), (0, 1), 1, 16, 8, stream=stream_id)

    mY1S = convert_torch_tensor_to_cute_tensor(y1s, (0, 1), 1, 16, 8, stream=stream_id)
    mE_offset = convert_torch_tensor_to_cute_tensor(expert_frequency_offset, (0,), 0, 4, 1, stream=stream_id)
    mX_gather = convert_torch_tensor_to_cute_tensor(x_gather_idx, (0,), 0, 4, 1, stream=stream_id)
    mS_scatter = convert_torch_tensor_to_cute_tensor(s_scatter_idx, (0,), 0, 4, 1, stream=stream_id)

    if expert_schedule_order is None:
        mE_permute_order = None
    else:
        mE_permute_order = convert_torch_tensor_to_cute_tensor(expert_schedule_order, (0,), 0, 4, 1, stream=stream_id)
    current_stream = cuda.CUstream(stream_id)
    ds_partial = None  # allocated below once ds_partial_N is known

    # Compile once per (shape, dtype, activation) and reuse the artifact.
    compile_dz_key = ("dz", E, H, I, z.dtype, activation_type)
    if compile_dz_key not in _down_projection_backward_act.compile_cache:
        # I don't know why but this sync appears to fix a mysterious initialization bug??
        torch.cuda.synchronize()
        dz_module = HopperWgmma_MoE_Down_proj_ActGrad_Bwd(E, H, I, ActivationType(activation_type))
        tensormaps = [dz_module.module.generate_tensormap(None, None, None) for _ in range(3)]

        # One ds partial per K-tile of the GEMM; remembered for later calls.
        ds_partial_N = max(ceil_divide(I, dz_module.module.tile_shape_mnk[1]), 1)
        ds_partial = torch.empty(TK, ds_partial_N, dtype=torch.float32, device=topk_scores.device)
        mDS_partial = convert_torch_tensor_to_cute_tensor(ds_partial, (0, 1), 1, 4, 1, stream=stream_id)

        _down_projection_backward_act.compile_cache["ds_partial_N"] = ds_partial_N
        _down_projection_backward_act.compile_cache[compile_dz_key] = cute.compile(
            dz_module,
            mDout,
            mW2_trans,
            mZ_kernel_input,
            mDz_kernel_input,
            mY1S,
            mS,
            mDS_partial,
            mE_offset,
            mX_gather,
            mS_scatter,
            tensormaps,
            mE_permute_order,
            current_stream,
        )
        _down_projection_backward_act.compile_cache[f"dz-{TENSORMAP}"] = tensormaps

    # Cache hit: allocate ds_partial from the remembered width.
    if ds_partial is None:
        ds_partial_N = _down_projection_backward_act.compile_cache["ds_partial_N"]
        ds_partial = torch.empty(TK, ds_partial_N, dtype=torch.float32, device=topk_scores.device)
        mDS_partial = convert_torch_tensor_to_cute_tensor(ds_partial, (0, 1), 1, 4, 1, stream=stream_id)

    dz_tensormaps = _down_projection_backward_act.compile_cache[f"dz-{TENSORMAP}"]
    _down_projection_backward_act.compile_cache[compile_dz_key](
        mDout,
        mW2_trans,
        mZ_kernel_input,
        mDz_kernel_input,
        mY1S,
        mS,
        mDS_partial,
        mE_offset,
        mX_gather,
        mS_scatter,
        dz_tensormaps,
        mE_permute_order,
        current_stream,
    )

    if db2 is None:
        # we don't need to update ds
        if ds_partial.size(1) == 1:
            ds.copy_(ds_partial.view(-1).to(dtype=ds.dtype))
        elif ds_partial.size(1) <= 32:
            ds.copy_(ds_partial.sum(dim=-1, dtype=ds.dtype))
        else:
            # Wide partials: reduce each row with the dedicated triton kernel.
            M, N = ds_partial.size()

            _colsum_smallN_kernel[M,](
                y_ptr=ds,
                x_ptr=ds_partial,
                stride_xm=ds_partial.stride(0),
                stride_xn=ds_partial.stride(1),
                stride_y=1,
                N=N,
                BLOCK_N=triton.next_power_of_2(N),
            )
    else:
        # db2 and ds update
        BLOCK_H = min(triton.next_power_of_2(H), 2048)
        NUM_H_BLOCKS = triton.cdiv(H, BLOCK_H)

        new_ds_partial = torch.empty(TK, NUM_H_BLOCKS, device=ds.device, dtype=torch.float32)

        db2_and_ds_kernel[(E, NUM_H_BLOCKS)](
            dout,
            topk_scores,
            new_ds_partial,
            ds_partial,
            b2,
            db2,
            x_gather_idx,
            s_scatter_idx,
            expert_frequency_offset,
            H,
            E,
            ds_partial_N,
            BLOCK_H=BLOCK_H,
            BLOCK_OLD_DS_PARTIAL_N=triton.next_power_of_2(ds_partial_N),
        )

        if NUM_H_BLOCKS == 1:
            ds.copy_(new_ds_partial.view(-1).to(dtype=ds.dtype))
        else:
            ds.copy_(new_ds_partial.sum(dim=-1, dtype=ds.dtype))


# Per-process cache of compiled cute kernels, tensormaps, and ds_partial width.
_down_projection_backward_act.compile_cache = {}
|
| 482 |
+
|
| 483 |
+
|
| 484 |
+
@torch.library.custom_op(add_op_namespace_prefix("_down_projection_backward_weight"), mutates_args={"dw2"})
def _down_projection_backward_weight(
    dout: torch.Tensor,
    y1s: torch.Tensor,
    dw2: torch.Tensor,
    expert_frequency_offset: torch.Tensor,
    expert_schedule_order: torch.Tensor | None,
    x_gather_idx: torch.Tensor,
    stream_id: int,
) -> None:
    """Backward of the down-projection w.r.t. the weights.

    Computes per-expert dw2 = y1s^T @ dout over each expert's token group via
    the grouped-GEMM cute kernel. ``dw2`` is written in place; nothing is
    returned.

    Args:
        dout: upstream gradient of the MoE output.
        y1s: activated (score-scaled) up-projection output.
        dw2: output buffer of shape (H, I, E).
        expert_frequency_offset: (E+1,) offsets of each expert's token range.
        expert_schedule_order: optional expert permutation for tile scheduling.
        x_gather_idx: grouped-layout gather index map.
        stream_id: raw CUDA stream handle the kernel is launched on.
    """
    H, I, E = dw2.size()

    # Wrap the torch tensors as cute tensors (layout, alignment, stream).
    mDout_trans = convert_torch_tensor_to_cute_tensor(dout.T, (1, 0), 0, 16, 8, stream=stream_id)
    mDw2 = convert_torch_tensor_to_cute_tensor(dw2, (2, 0, 1), 1, 16, 8, stream=stream_id)
    mY1S_trans = convert_torch_tensor_to_cute_tensor(y1s.T, (1, 0), 0, 16, 8, stream=stream_id)
    mE_offset = convert_torch_tensor_to_cute_tensor(expert_frequency_offset, (0,), 0, 4, 1, stream=stream_id)
    mX_gather = convert_torch_tensor_to_cute_tensor(x_gather_idx, (0,), 0, 4, 1, stream=stream_id)

    if expert_schedule_order is None:
        mE_permute_order = None
    else:
        mE_permute_order = convert_torch_tensor_to_cute_tensor(expert_schedule_order, (0,), 0, 4, 1, stream=stream_id)
    current_stream = cuda.CUstream(stream_id)

    # Compile once per (shape, dtype) and reuse the artifact.
    compile_dw2_key = ("dw2", E, H, I, dw2.dtype)
    if compile_dw2_key not in _down_projection_backward_weight.compile_cache:
        dw2_module = HopperWgmma_MoE_Down_proj_WeightGrad_Bwd(E, H, I)
        tensormaps = [dw2_module.module.generate_tensormap(None, None, None) for _ in range(1)]
        _down_projection_backward_weight.compile_cache[compile_dw2_key] = cute.compile(
            dw2_module,
            mDout_trans,
            mY1S_trans,
            mDw2,
            mE_offset,
            mX_gather,
            tensormaps,
            mE_permute_order,
            current_stream,
        )
        # NOTE(review): tensormaps are cached under a single key shared across
        # all compile keys — confirm intentional if multiple shapes compile.
        _down_projection_backward_weight.compile_cache[f"dw2-{TENSORMAP}"] = tensormaps

    dw2_tensormaps = _down_projection_backward_weight.compile_cache[f"dw2-{TENSORMAP}"]
    _down_projection_backward_weight.compile_cache[compile_dw2_key](
        mDout_trans, mY1S_trans, mDw2, mE_offset, mX_gather, dw2_tensormaps, mE_permute_order, current_stream
    )


# Per-process cache of compiled cute kernels and their tensormaps.
_down_projection_backward_weight.compile_cache = {}
|
| 532 |
+
|
| 533 |
+
|
| 534 |
+
@torch.library.custom_op(add_op_namespace_prefix("_token_broadcast_backward"), mutates_args={"dx_reduced"})
def _token_broadcast_backward(
    dx_reduced: torch.Tensor,
    dx_expanded: torch.Tensor,
    s_reverse_scatter_idx: torch.Tensor,
    num_activated_expert_per_token_offset: Optional[torch.Tensor],
    varlen_K_max: int,
    H: int,
    is_varlen_K: bool,
) -> None:
    """Backward of the top-K token broadcast: sum each token's K expert gradients.

    Gathers the per-(token, expert) rows of ``dx_expanded`` via
    ``s_reverse_scatter_idx`` and reduces them into ``dx_reduced`` (in place).

    Args:
        dx_reduced: (T, H) output buffer for the reduced gradient.
        dx_expanded: per-(token, expert) gradient rows in grouped layout.
        s_reverse_scatter_idx: maps (token, k) back to a grouped row.
        num_activated_expert_per_token_offset: per-token K offsets for
            variable-K routing; None is only valid with fixed top-K.
        varlen_K_max: upper bound on K per token.
        H: hidden size.
        is_varlen_K: whether routing uses a variable number of experts per token.
    """
    if num_activated_expert_per_token_offset is None:
        assert not is_varlen_K, "`num_activated_expert_per_token_offset` as None requires fixed top-K routing"
    token_gather_and_sum_varlen_K_triton(
        dx_expanded,
        None,
        dx_reduced,
        s_reverse_scatter_idx,
        num_activated_expert_per_token_offset,
        dx_reduced.size(0),
        varlen_K_max,
        H,
        is_varlen_K,
    )
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
# Backward of softmax restricted to the top-K selected columns, scattered into
# the full-logits gradient. One program per row; assumes K <= BLOCK_K.
@triton.jit
def _softmax_bwd_scatter_small_kernel(
    dlogits_ptr,  # optional (T, E) pre-existing logits gradient to accumulate onto
    dlogits_full_ptr,  # (T, E) output gradient (only selected columns are written)
    score_ptr,  # (T, K) softmax values at the selected columns
    dscore_ptr,  # (T, K) upstream gradient w.r.t. the selected scores
    idx_ptr,  # (T, K) selected column indices
    stride_dm: tl.constexpr,
    stride_dn: tl.constexpr,
    stride_sm: tl.constexpr,
    stride_sn: tl.constexpr,
    stride_gm: tl.constexpr,
    stride_gk: tl.constexpr,
    stride_im: tl.constexpr,
    stride_ik: tl.constexpr,
    K: tl.constexpr,
    BLOCK_K: tl.constexpr,
    dlogits_is_none: tl.constexpr,  # compile-time flag: skip the accumulate load
):
    row = tl.program_id(axis=0)

    # tl.assume(K <= BLOCK_K)
    k_offs = tl.arange(0, BLOCK_K)
    k_mask = k_offs < K

    idx = tl.load(idx_ptr + row * stride_im + k_offs * stride_ik, mask=k_mask, other=0).to(tl.int32)
    s_sel = tl.load(score_ptr + row * stride_sm + k_offs * stride_sn, mask=k_mask, other=0).to(tl.float32)
    g_sel = tl.load(dscore_ptr + row * stride_gm + k_offs * stride_gk, mask=k_mask, other=0).to(tl.float32)

    # dot = sum_j g_j * y_j over selected columns
    dot = tl.sum(g_sel * s_sel, axis=0)

    # scatter-only: dx[idx] += y_sel * (g_sel - dot)
    add_vals = s_sel * (g_sel - dot)

    indices = row * stride_dm + idx * stride_dn
    if not dlogits_is_none:
        add_vals += tl.load(dlogits_ptr + indices, mask=k_mask)
    tl.store(dlogits_full_ptr + indices, add_vals, mask=k_mask)
|
| 598 |
+
|
| 599 |
+
|
| 600 |
+
@torch.library.custom_op(add_op_namespace_prefix("_softmax_topk_bwd"), mutates_args={"dlogits_full"})
def _softmax_topk_bwd(
    dlogits_full: torch.Tensor,
    dlogits: Optional[torch.Tensor],
    dtopk_score: torch.Tensor,
    topk_router_score: torch.Tensor,
    topk_router_indices: torch.Tensor,
    K: int,
) -> None:
    """Backward of fused softmax + top-K: scatter gradients into the full logits.

    Launches one program per row; ``dlogits_full`` is written in place at the
    selected columns only (other columns are left untouched). When ``dlogits``
    is given, its values at those columns are accumulated in.

    Args:
        dlogits_full: (T, E) output buffer for the logits gradient.
        dlogits: optional pre-existing logits gradient to add.
        dtopk_score: (T, K) gradient w.r.t. the selected softmax scores.
        topk_router_score: (T, K) selected softmax values.
        topk_router_indices: (T, K) selected column indices.
        K: number of selected experts per row.
    """
    T = dtopk_score.shape[0]

    _softmax_bwd_scatter_small_kernel[T,](
        dlogits,
        dlogits_full,
        topk_router_score,
        dtopk_score,
        topk_router_indices,
        dlogits_full.stride(0),
        dlogits_full.stride(1),
        topk_router_score.stride(0),
        topk_router_score.stride(1),
        dtopk_score.stride(0),
        dtopk_score.stride(1),
        topk_router_indices.stride(0),
        topk_router_indices.stride(1),
        K,
        triton.next_power_of_2(K),  # BLOCK_K must be a power of 2 and >= K
        (dlogits is None),
    )
|
| 629 |
+
|
| 630 |
+
|
| 631 |
+
# Backward of plain top-K selection (no softmax): scatter the upstream gradient
# straight back to the selected columns. One program per row; K <= BLOCK_K.
@triton.jit
def _topk_bwd_scatter_small_kernel(
    dlogits_full_ptr,  # (T, E) output gradient (only selected columns written)
    dscore_ptr,  # (T, K) upstream gradient w.r.t. the selected values
    idx_ptr,  # (T, K) selected column indices
    stride_dm: tl.constexpr,
    stride_dn: tl.constexpr,
    stride_gm: tl.constexpr,
    stride_gk: tl.constexpr,
    stride_im: tl.constexpr,
    stride_ik: tl.constexpr,
    K: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    row = tl.program_id(axis=0)

    # tl.assume(K <= BLOCK_K)
    k_offs = tl.arange(0, BLOCK_K)
    k_mask = k_offs < K

    idx = tl.load(idx_ptr + row * stride_im + k_offs * stride_ik, mask=k_mask, other=0).to(tl.int32)
    g_sel = tl.load(dscore_ptr + row * stride_gm + k_offs * stride_gk, mask=k_mask, other=0).to(tl.float32)

    # scatter-only: dlogits_full[idx] = g_sel (top-K is an identity on the
    # selected entries, so the gradient passes through unchanged)
    add_vals = g_sel

    indices = row * stride_dm + idx * stride_dn
    tl.store(dlogits_full_ptr + indices, add_vals, mask=k_mask)
|
| 659 |
+
|
| 660 |
+
|
| 661 |
+
@torch.library.custom_op(add_op_namespace_prefix("_topk_bwd"), mutates_args={"dlogits_full"})
def _topk_bwd(
    dlogits_full: torch.Tensor,
    dtopk_values: torch.Tensor,
    topk_indices: torch.Tensor,
    K: int,
) -> None:
    """Backward of plain top-K: scatter the selected-value gradients into the logits.

    Launches one program per row; ``dlogits_full`` is written in place at the
    selected columns only (other columns are left untouched).

    Args:
        dlogits_full: (T, E) output buffer for the logits gradient.
        dtopk_values: (T, K) gradient w.r.t. the selected values.
        topk_indices: (T, K) selected column indices.
        K: number of selected experts per row.
    """
    T = dtopk_values.shape[0]

    _topk_bwd_scatter_small_kernel[T,](
        dlogits_full,
        dtopk_values,
        topk_indices,
        dlogits_full.stride(0),
        dlogits_full.stride(1),
        dtopk_values.stride(0),
        dtopk_values.stride(1),
        topk_indices.stride(0),
        topk_indices.stride(1),
        K,
        triton.next_power_of_2(K),  # BLOCK_K must be a power of 2 and >= K
    )
|
build/torch-cuda/functional/forward.py
ADDED
|
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ********************************************************************************
|
| 2 |
+
# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao
|
| 3 |
+
# ********************************************************************************
|
| 4 |
+
|
| 5 |
+
import cuda.bindings.driver as cuda
|
| 6 |
+
import cutlass.cute as cute
|
| 7 |
+
import torch
|
| 8 |
+
import triton
|
| 9 |
+
import triton.language as tl
|
| 10 |
+
from cutlass.cute.runtime import from_dlpack
|
| 11 |
+
from ..quack.cute_dsl_utils import torch2cute_dtype_map
|
| 12 |
+
|
| 13 |
+
from ..enums import LIBRARY_NAME, TENSORMAP, ActivationType
|
| 14 |
+
from .._ops_compat import add_op_namespace_prefix
|
| 15 |
+
from ..utils import convert_torch_tensor_to_cute_tensor
|
| 16 |
+
from .moe_config import HopperWgmma_MoE_Down_proj_Fwd, HopperWgmma_MoE_Up_proj_Fwd
|
| 17 |
+
from .reduction_over_k_gather import token_gather_and_sum_varlen_K_triton
|
| 18 |
+
from .topk_softmax import TopK_Softmax
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@torch.library.custom_op(add_op_namespace_prefix("_topk_fwd"), mutates_args={"values", "indices"})
def _topk_fwd(
    x: torch.Tensor, k: int, values: torch.Tensor, indices: torch.Tensor, require_softmax_fusion: bool = True
) -> None:
    """Top-k forward pass (optionally fused with softmax).

    Writes the top-k results in place into ``values`` and ``indices``; returns
    None (custom ops with mutated args must not return the mutated tensors).

    Args:
        x: Input tensor of shape (M, N)
        k: Number of top elements to return
        values: (M, k) output buffer for the top-k values
        indices: (M, k) output buffer for the top-k indices
        require_softmax_fusion: apply softmax to the selected values in-kernel
    """
    N = x.size(1)

    input_dtype = torch2cute_dtype_map[x.dtype]
    output_dtype = torch2cute_dtype_map[values.dtype]
    # Wrap as cute tensors with a dynamic leading (row) dimension so one
    # compiled kernel serves any M.
    convert_from_dlpack = lambda tensor: (
        from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(mode=0, stride_order=(0, 1))
    )

    x_tensor, values_tensor, indices_tensor = [convert_from_dlpack(tensor) for tensor in (x, values, indices)]
    current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
    # Compile once per (dtype, N, k, fusion) and reuse the artifact.
    compile_key = (input_dtype, output_dtype, N, k, require_softmax_fusion)
    if compile_key not in _topk_fwd.compile_cache:
        topk_op = TopK_Softmax(input_dtype, output_dtype, N, k, require_softmax_fusion)
        _topk_fwd.compile_cache[compile_key] = cute.compile(
            topk_op, x_tensor, values_tensor, indices_tensor, current_stream
        )
    _topk_fwd.compile_cache[compile_key](x_tensor, values_tensor, indices_tensor, current_stream)


# Per-process cache of compiled top-k kernels.
_topk_fwd.compile_cache = {}
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
@torch.library.custom_op(add_op_namespace_prefix("_up_projection_forward"), mutates_args={"z", "y1"})
def _up_projection_forward(
    x: torch.Tensor,
    w1: torch.Tensor,
    z: torch.Tensor,
    y1: torch.Tensor,
    b1: torch.Tensor | None,
    expert_frequency_offset: torch.Tensor,
    expert_schedule_order: torch.Tensor,
    x_gather_idx: torch.Tensor,
    stream_id: int,
    activation_type: str,
    is_glu_activation: bool,
    is_inference_mode_enabled: bool = False,
) -> None:
    """MoE up-projection grouped GEMM forward: computes into ``z`` and ``y1`` in place.

    ``w1`` has shape (I, H, E) per the unpack below; for GLU activations the logical
    intermediate width is half of ``w1``'s first dimension. The input rows of ``x`` are
    gathered via ``x_gather_idx`` (the kernel is built with ``is_A_gather=True``).

    NOTE(review): ``expert_schedule_order`` is annotated as ``torch.Tensor`` but is
    checked for ``None`` below — the annotation presumably should be optional; confirm
    against the custom-op schema before changing it.
    """
    I, H, E = w1.size()
    if is_glu_activation:
        # GLU variants pack gate and up projections together; the activation output
        # width is half of w1's leading dimension.
        I //= 2

    mX = convert_torch_tensor_to_cute_tensor(x.detach(), (0, 1), 1, 16, 8, stream=stream_id)
    mW1 = convert_torch_tensor_to_cute_tensor(w1.detach(), (2, 0, 1), 1, 16, 8, stream=stream_id)
    mZ = convert_torch_tensor_to_cute_tensor(z, (0, 1), 1, 16, 8, stream=stream_id)
    mY1 = convert_torch_tensor_to_cute_tensor(y1, (0, 1), 1, 16, 8, stream=stream_id)
    mE_offset = convert_torch_tensor_to_cute_tensor(expert_frequency_offset, (0,), 0, 4, 1, stream=stream_id)
    mX_gather = convert_torch_tensor_to_cute_tensor(x_gather_idx, (0,), 0, 4, 1, stream=stream_id)

    if expert_schedule_order is None:
        mE_permute_order = None
    else:
        mE_permute_order = convert_torch_tensor_to_cute_tensor(expert_schedule_order, (0,), 0, 4, 1, stream=stream_id)

    if b1 is None:
        mB1 = None
    else:
        mB1 = convert_torch_tensor_to_cute_tensor(b1.detach(), (0, 1), 1, 16, 8, stream=stream_id)

    current_stream = cuda.CUstream(stream_id)

    # Compile once per problem/config signature; reuse the compiled module afterwards.
    compile_w1_key = (E, H, I, (b1 is None), x.dtype, activation_type, is_inference_mode_enabled)
    if compile_w1_key not in _up_projection_forward.compile_cache:
        w1_module = HopperWgmma_MoE_Up_proj_Fwd(
            E, H, I, activation_type=ActivationType(activation_type), inference_mode=is_inference_mode_enabled
        )
        # Two tensormaps: one for each of the kernel's TMA-described outputs.
        tensormaps = [w1_module.module.generate_tensormap(None, None, None) for _ in range(2)]
        _up_projection_forward.compile_cache[compile_w1_key] = cute.compile(
            w1_module,
            mX,
            mW1,
            mZ,
            mY1,
            mB1,
            mE_offset,
            mX_gather,
            tensormaps[0],
            tensormaps[1],
            mE_permute_order,
            current_stream,
        )
        # NOTE(review): the tensormaps are stored under the single TENSORMAP sentinel key,
        # so a recompile for a different compile_w1_key overwrites them for all keys —
        # verify this is intended if multiple MoE shapes are used in one process.
        _up_projection_forward.compile_cache[TENSORMAP] = tensormaps

    w1_tensormaps = _up_projection_forward.compile_cache[TENSORMAP]
    _up_projection_forward.compile_cache[compile_w1_key](
        mX,
        mW1,
        mZ,
        mY1,
        mB1,
        mE_offset,
        mX_gather,
        w1_tensormaps[0],
        w1_tensormaps[1],
        mE_permute_order,
        current_stream,
    )


# Per-process cache of compiled up-projection kernels (plus their tensormaps under TENSORMAP).
_up_projection_forward.compile_cache = {}
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
@torch.library.custom_op(add_op_namespace_prefix("_down_projection_forward"), mutates_args={"y2"})
def _down_projection_forward(
    w2: torch.Tensor,
    y1: torch.Tensor,
    y2: torch.Tensor,
    b2: torch.Tensor | None,
    expert_frequency_offset: torch.Tensor,
    expert_schedule_order: torch.Tensor,
    x_gather_idx: torch.Tensor,
    stream_id: int,
) -> None:
    """MoE down-projection grouped GEMM forward: writes the result into ``y2`` in place.

    ``w2`` has shape (H, I, E) per the unpack below. Mirrors ``_up_projection_forward``
    but uses a single output tensormap and no activation fusion.

    NOTE(review): ``expert_schedule_order`` is annotated as ``torch.Tensor`` but is
    checked for ``None`` below — the annotation presumably should be optional.
    """
    H, I, E = w2.size()

    mW2 = convert_torch_tensor_to_cute_tensor(w2.detach(), (2, 0, 1), 1, 16, 8, stream=stream_id)
    mY1 = convert_torch_tensor_to_cute_tensor(y1.detach(), (0, 1), 1, 16, 8, stream=stream_id)
    mY2 = convert_torch_tensor_to_cute_tensor(y2, (0, 1), 1, 16, 8, stream=stream_id)
    mE_offset = convert_torch_tensor_to_cute_tensor(expert_frequency_offset, (0,), 0, 4, 1, stream=stream_id)
    mX_gather = convert_torch_tensor_to_cute_tensor(x_gather_idx, (0,), 0, 4, 1, stream=stream_id)

    if expert_schedule_order is None:
        mE_permute_order = None
    else:
        mE_permute_order = convert_torch_tensor_to_cute_tensor(expert_schedule_order, (0,), 0, 4, 1, stream=stream_id)

    if b2 is None:
        mB2 = None
    else:
        mB2 = convert_torch_tensor_to_cute_tensor(b2.detach(), (0, 1), 1, 16, 8, stream=stream_id)

    current_stream = cuda.CUstream(stream_id)

    # Compile once per problem/config signature; reuse the compiled module afterwards.
    compile_w2_key = (E, H, I, (b2 is None), w2.dtype)
    if compile_w2_key not in _down_projection_forward.compile_cache:
        w2_module = HopperWgmma_MoE_Down_proj_Fwd(E, H, I)
        # One tensormap: the down projection has a single TMA-described output.
        tensormaps = [w2_module.module.generate_tensormap(None, None, None) for _ in range(1)]
        _down_projection_forward.compile_cache[compile_w2_key] = cute.compile(
            w2_module, mY1, mW2, mY2, mB2, mE_offset, mX_gather, tensormaps[0], mE_permute_order, current_stream
        )
        # NOTE(review): stored under the shared TENSORMAP sentinel key; overwritten on
        # every recompile (see the matching note in _up_projection_forward).
        _down_projection_forward.compile_cache[TENSORMAP] = tensormaps

    w2_tensormaps = _down_projection_forward.compile_cache[TENSORMAP]
    _down_projection_forward.compile_cache[compile_w2_key](
        mY1, mW2, mY2, mB2, mE_offset, mX_gather, w2_tensormaps[0], mE_permute_order, current_stream
    )


# Per-process cache of compiled down-projection kernels (plus their tensormaps under TENSORMAP).
_down_projection_forward.compile_cache = {}
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
@torch.library.custom_op(add_op_namespace_prefix("_router_forward"), mutates_args={"o"})
def _router_forward(
    y2: torch.Tensor,
    o: torch.Tensor,
    topk_scores: torch.Tensor,
    s_reverse_scatter_idx: torch.Tensor,
    num_activated_expert_per_token_offset: torch.Tensor,
    varlen_K_max: int,
    H: int,
    is_varlen_K: bool,
) -> None:
    """Gather per-expert outputs back per token, weight by router scores, and sum into ``o``.

    Thin custom-op wrapper around ``token_gather_and_sum_varlen_K_triton``; ``o`` is
    mutated in place. ``o.size(0)`` supplies the token count to the Triton launcher.
    """
    token_gather_and_sum_varlen_K_triton(
        y2,
        topk_scores,
        o,
        s_reverse_scatter_idx,
        num_activated_expert_per_token_offset,
        o.size(0),
        varlen_K_max,
        H,
        is_varlen_K,
    )
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
@triton.jit
def _softmax_fwd_small_kernel(
    logits_ptr, stride_lm: tl.constexpr, stride_ln: tl.constexpr, K: tl.constexpr, BLOCK_K: tl.constexpr
):
    # In-place numerically-stable softmax over the first K columns of one row per program.
    # Intended for small K (the whole row fits in a single BLOCK_K-wide load).
    row = tl.program_id(axis=0)

    # tl.assume(K <= BLOCK_K)
    k_offs = tl.arange(0, BLOCK_K)
    k_mask = k_offs < K

    # load full row (all columns) in one go (N is small)
    # Masked lanes get -inf so they contribute exp(-inf) = 0 to the sum.
    x = tl.load(logits_ptr + row * stride_lm + k_offs * stride_ln, mask=k_mask, other=-float("inf")).to(tl.float32)
    # Subtract the row max before exponentiating for numerical stability.
    x = x - tl.max(x, axis=0)
    ex = tl.exp(x)
    y = ex / tl.sum(ex, axis=0)

    # Write the normalized probabilities back over the logits (in-place update).
    tl.store(logits_ptr + row * stride_lm + k_offs * stride_ln, y, mask=k_mask)
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
@torch.library.custom_op(
    add_op_namespace_prefix("_softmax_topk_fwd"), mutates_args={"topk_router_score", "topk_router_indices"}
)
def _softmax_topk_fwd(
    router_logits: torch.Tensor, topk_router_score: torch.Tensor, topk_router_indices: torch.Tensor, E: int, K: int
) -> None:
    """Select the top-K router logits per token and softmax-normalize the selection.

    Results are written in place into ``topk_router_score`` / ``topk_router_indices``.
    Dispatches to the fused CuTe kernel when (E, K) fall in its supported range,
    otherwise falls back to plain ``torch.topk`` + float32 softmax.
    """
    # T = router_logits.shape[0]
    fused_kernel_supported = E <= 4096 and K <= 16 and E % 8 == 0
    if fused_kernel_supported:
        # fast topk-softmax fusion that covers most common MoE configs
        _topk_fwd(router_logits, K, topk_router_score, topk_router_indices, require_softmax_fusion=True)
        return

    # Generic eager fallback: top-k, then softmax in float32 for accuracy.
    top_values, top_indices = router_logits.topk(K, dim=-1)
    normalized = top_values.softmax(dim=-1, dtype=torch.float32)
    topk_router_score.copy_(normalized.to(topk_router_score.dtype))
    topk_router_indices.copy_(top_indices.to(topk_router_indices.dtype))
|
build/torch-cuda/functional/grouped_gemm.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
build/torch-cuda/functional/moe_config.py
ADDED
|
@@ -0,0 +1,581 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ********************************************************************************
|
| 2 |
+
# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao
|
| 3 |
+
# ********************************************************************************
|
| 4 |
+
|
| 5 |
+
import math
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
|
| 8 |
+
import cuda.bindings.driver as cuda
|
| 9 |
+
import cutlass
|
| 10 |
+
import cutlass.cute as cute
|
| 11 |
+
import torch
|
| 12 |
+
from cutlass import const_expr
|
| 13 |
+
from ..quack.tile_scheduler import RasterOrderOption
|
| 14 |
+
|
| 15 |
+
from ..enums import ActivationType, is_glu
|
| 16 |
+
from .grouped_gemm import HopperWgmma_MoE_kernel
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
LIBRARY_NAME = "cutedsl_kernels"


def ceil_div(a: int, b: int) -> int:
    """Return the ceiling of a / b using exact integer arithmetic.

    The previous ``int(math.ceil(a / b))`` routes through a float division,
    which loses precision once the operands exceed 2**53; negated floor
    division is exact for integers of any magnitude and agrees with the
    float path (including for negative operands) wherever that path was exact.
    """
    return -(-a // b)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@dataclass
class HopperGEMMConfig:
    # Tuning knobs for one Hopper WGMMA grouped-GEMM kernel instantiation.
    # CTA tile (M, N, K) in elements for the mainloop.
    tile_shape_mnk: cutlass.Constexpr[cute.Shape] = (128, 256, 64)
    # Thread-block cluster shape — only two entries here despite the "mnk" name;
    # callers append 1 for the K mode when constructing the kernel.
    cluster_shape_mnk: cutlass.Constexpr[cute.Shape] = (2, 1)
    # Epilogue tile size (elements).
    epi_tile_size: cutlass.Constexpr[int] = 32
    ## assume we always use persistent kernel
    # is_persistent: cutlass.Constexpr[bool] = True
    # Whether to use the pingpong (two-consumer-warpgroup) schedule.
    is_pingpong: cutlass.Constexpr[bool] = False
    # Tile rasterization order for the scheduler.
    raster_order: RasterOrderOption = RasterOrderOption.Heuristic
    # Tile-group size used for L2-friendly swizzling.
    L2_group_size: int = 8
    # Initial number of D-epilogue pipeline stages.
    initial_d_epi_stage: cutlass.Constexpr[int] = 4
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class HopperWgmma_MoE_Up_proj_Fwd:
    """Config-selecting wrapper around HopperWgmma_MoE_kernel for the MoE up-projection forward.

    Picks a GEMM tiling based on the intermediate width I (and whether the activation
    is a GLU variant), translates ``activation_type`` into the kernel's boolean fusion
    flags, and builds the kernel with input gathering enabled (``is_A_gather=True``).
    """

    def __init__(self, E: int, H: int, I: int, activation_type: ActivationType, inference_mode=False):
        super().__init__()
        is_glu_activation = is_glu(activation_type)
        # Shape constraints differ because GLU activations halve the effective I.
        if is_glu_activation:
            assert (
                H % 64 == 0 and H >= 512 and I % 64 == 0
            ), f"{LIBRARY_NAME} only supports GLU MoE with H % 64 == 0 (H >= 512) and I % 64 == 0"
        else:
            assert (
                H % 64 == 0 and H >= 512 and I % 128 == 0
            ), f"{LIBRARY_NAME} only supports non-GLU MoE with H % 64 == 0 (H >= 512) and I % 128 == 0"
        # TODO: this assertion does not mean that the MoE impl prohibits such config.
        # Instead, we just do not search for the best configs manually yet for small-shaped MoE
        if (I >= 128 and is_glu_activation) or (I >= 256 and not is_glu_activation):
            # Large-I path: big cooperative tile with a 2-wide cluster.
            up_config = HopperGEMMConfig(
                tile_shape_mnk=(128, 256, 64),
                cluster_shape_mnk=(2, 1),
                epi_tile_size=(32 if not inference_mode else 64),
                is_pingpong=False,
                initial_d_epi_stage=2,
                raster_order=RasterOrderOption.AlongM,
            )
        elif (I == 64 and is_glu_activation) or (I == 128 and not is_glu_activation):
            # Small-I path: narrower N tile with the pingpong schedule.
            up_config = HopperGEMMConfig(
                tile_shape_mnk=(192, 128, 64),
                cluster_shape_mnk=(1, 1),
                epi_tile_size=(32 if not inference_mode else 64),
                is_pingpong=True,
                initial_d_epi_stage=8,
                raster_order=RasterOrderOption.AlongM,
            )
        else:
            raise NotImplementedError()

        # Exactly one of these fusion flags is switched on below.
        compute_swiglu = False
        compute_geglu = False
        compute_reglu = False

        compute_relu_sq = False
        compute_silu = False
        compute_relu = False
        compute_gelu = False

        if activation_type == ActivationType.SWIGLU:
            compute_swiglu = True
        elif activation_type == ActivationType.GEGLU:
            compute_geglu = True
        elif activation_type == ActivationType.REGLU:
            compute_reglu = True

        elif activation_type == ActivationType.RELU_SQ:
            compute_relu_sq = True
        elif activation_type == ActivationType.RELU:
            compute_relu = True
        elif activation_type == ActivationType.SILU:
            compute_silu = True
        elif activation_type == ActivationType.GELU:
            compute_gelu = True

        else:
            raise NotImplementedError(f"Activation function {activation_type} not supported yet!")

        self.module = HopperWgmma_MoE_kernel(
            E,
            cutlass.Float32,
            up_config.tile_shape_mnk,
            (*up_config.cluster_shape_mnk, 1),
            pingpong=up_config.is_pingpong,
            is_persistent=True,
            compute_swiglu=compute_swiglu,
            compute_reglu=compute_reglu,
            compute_geglu=compute_geglu,
            compute_relu_sq=compute_relu_sq,
            compute_relu=compute_relu,
            compute_silu=compute_silu,
            compute_gelu=compute_gelu,
            is_A_gather=True,
            epi_tile_size=up_config.epi_tile_size,
            initial_d_epi_stage=up_config.initial_d_epi_stage,
            inference_mode=inference_mode,
        )
        self.max_active_clusters = cutlass.utils.HardwareInfo().get_max_active_clusters(
            up_config.cluster_shape_mnk[0] * up_config.cluster_shape_mnk[1]
        )
        self.current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    @cute.jit
    def __call__(
        self, mX, mW1, mZ, mY1, mB1, mE_offset, mX_gather, mD_tensormap, mY1_tensormap, mE_permute_order, stream
    ):
        # Forward to the generic kernel; unused slots of its wide positional
        # signature are filled with None. The argument order below is load-bearing.
        return self.module(
            mX,
            mW1,
            None,
            mB1,
            mZ,
            mY1,
            None,
            None,
            mE_offset,
            mX_gather,
            None,
            None,
            None,
            None,
            None,
            mD_tensormap,
            mY1_tensormap,
            None,
            mE_permute_order,
            const_expr(self.max_active_clusters),
            stream,
        )
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
class HopperWgmma_MoE_Down_proj_Fwd:
    """Config-selecting wrapper around HopperWgmma_MoE_kernel for the MoE down-projection forward.

    Tiling is chosen from the intermediate width I; no activation fusion and no
    input gathering (``is_A_gather=False``) — rows are already expert-ordered here.
    """

    def __init__(self, E: int, H: int, I: int):
        super().__init__()
        assert (
            H % 64 == 0 and H >= 512 and I % 64 == 0
        ), f"{LIBRARY_NAME} only supports MoE with H % 64 == 0 (H >= 512) and I % 64 == 0"
        if I >= 1024:
            # Large-I: cooperative schedule, big N tile.
            down_config = HopperGEMMConfig(
                tile_shape_mnk=(128, 256, 64),
                cluster_shape_mnk=(2, 1),
                epi_tile_size=32,
                is_pingpong=False,
                initial_d_epi_stage=4,
                raster_order=RasterOrderOption.AlongN,
            )
        elif I >= 256:
            # Mid-I: pingpong schedule; epi tile adapts to H's divisibility by 96.
            down_config = HopperGEMMConfig(
                tile_shape_mnk=(128, 192, 64),
                cluster_shape_mnk=(2, 1),
                epi_tile_size=(96 if H % 96 == 0 else 64),
                is_pingpong=True,
                initial_d_epi_stage=5,
                raster_order=RasterOrderOption.AlongN,
            )
        elif I >= 64:
            # Small-I: cluster along N instead of M.
            down_config = HopperGEMMConfig(
                tile_shape_mnk=(128, 192, 64),
                cluster_shape_mnk=(1, 2),
                epi_tile_size=64,
                is_pingpong=True,
                initial_d_epi_stage=8,
                raster_order=RasterOrderOption.AlongN,
            )
        else:
            raise NotImplementedError()

        self.module = HopperWgmma_MoE_kernel(
            E,
            cutlass.Float32,
            down_config.tile_shape_mnk,
            (*down_config.cluster_shape_mnk, 1),
            pingpong=down_config.is_pingpong,
            is_persistent=True,
            compute_swiglu=False,
            is_A_gather=False,
            epi_tile_size=down_config.epi_tile_size,
            initial_d_epi_stage=down_config.initial_d_epi_stage,
        )
        self.max_active_clusters = cutlass.utils.HardwareInfo().get_max_active_clusters(
            down_config.cluster_shape_mnk[0] * down_config.cluster_shape_mnk[1]
        )

    @cute.jit
    def __call__(self, mY1, mW2, mY2, mB2, mE_offset, mX_gather, mD_tensormap, mE_permute_order, stream):
        # we are not really using mX_gather in the Grouped GEMM,
        # but CuTe-DSL compiler disallows dynamic flow so we still need to pass this argument
        return self.module(
            mY1,
            mW2,
            None,
            mB2,
            mY2,
            None,
            None,
            None,
            mE_offset,
            mX_gather,
            None,
            None,
            None,
            None,
            None,
            mD_tensormap,
            None,
            None,
            mE_permute_order,
            const_expr(self.max_active_clusters),
            stream,
        )
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
class HopperWgmma_MoE_Down_proj_ActGrad_Bwd:
    """Backward wrapper: activation gradient through the down projection.

    Builds the grouped-GEMM kernel with ``compute_dz_and_partial_ds_and_y1s=True``
    so one pass produces dz, partial dS, and the recomputed activation y1s.
    """

    def __init__(self, E: int, H: int, I: int, activation_type: ActivationType):
        super().__init__()
        is_glu_activation = is_glu(activation_type)
        # Shape constraints differ because GLU activations halve the effective I.
        if is_glu_activation:
            assert (
                H % 64 == 0 and H >= 512 and I % 64 == 0
            ), f"{LIBRARY_NAME} only supports GLU MoE with H % 64 == 0 (H >= 512) and I % 64 == 0"
        else:
            assert (
                H % 64 == 0 and H >= 512 and I % 128 == 0
            ), f"{LIBRARY_NAME} only supports non-GLU MoE with H % 64 == 0 (H >= 512) and I % 128 == 0"

        # heavy register pressure due to pingpong + heavy epilogue
        # effectively no alternatives to this config
        dz_partial_ds_config = HopperGEMMConfig(
            tile_shape_mnk=(128, 128, 64),
            cluster_shape_mnk=(2, 1),
            epi_tile_size=32,
            initial_d_epi_stage=4,
            is_pingpong=True,
            raster_order=RasterOrderOption.Heuristic,
        )

        # Exactly one of these fusion flags is switched on below.
        compute_swiglu = False
        compute_geglu = False
        compute_reglu = False

        compute_relu_sq = False
        compute_silu = False
        compute_relu = False
        compute_gelu = False

        if activation_type == ActivationType.SWIGLU:
            compute_swiglu = True
        elif activation_type == ActivationType.GEGLU:
            compute_geglu = True
        elif activation_type == ActivationType.REGLU:
            compute_reglu = True

        elif activation_type == ActivationType.RELU_SQ:
            compute_relu_sq = True
        elif activation_type == ActivationType.RELU:
            compute_relu = True
        elif activation_type == ActivationType.SILU:
            compute_silu = True
        elif activation_type == ActivationType.GELU:
            compute_gelu = True

        else:
            raise NotImplementedError(f"Activation function {activation_type} not supported yet!")

        self.module = HopperWgmma_MoE_kernel(
            E,
            cutlass.Float32,
            dz_partial_ds_config.tile_shape_mnk,
            (*dz_partial_ds_config.cluster_shape_mnk, 1),
            pingpong=dz_partial_ds_config.is_pingpong,
            is_persistent=True,
            compute_swiglu=compute_swiglu,
            compute_reglu=compute_reglu,
            compute_geglu=compute_geglu,
            compute_relu_sq=compute_relu_sq,
            compute_relu=compute_relu,
            compute_silu=compute_silu,
            compute_gelu=compute_gelu,
            compute_dz_and_partial_ds_and_y1s=True,
            is_A_gather=True,
            epi_tile_size=dz_partial_ds_config.epi_tile_size,
            initial_d_epi_stage=dz_partial_ds_config.initial_d_epi_stage,
        )
        self.max_active_clusters = cutlass.utils.HardwareInfo().get_max_active_clusters(
            dz_partial_ds_config.cluster_shape_mnk[0] * dz_partial_ds_config.cluster_shape_mnk[1]
        )

    @cute.jit
    def __call__(
        self,
        mDout,
        mW2_trans,
        mZ_FP32_if_GLU_else_BF16,
        mDz_FP32_if_GLU_else_BF16,
        mY1S,
        mS,
        mDS_partial,
        mE_offset,
        mX_gather,
        mS_scatter,
        tensormaps,
        mE_permute_order,
        stream,
    ):
        # Forward to the generic kernel; unused slots of its wide positional
        # signature are filled with None. The argument order below is load-bearing.
        return self.module(
            mDout,
            mW2_trans,
            mZ_FP32_if_GLU_else_BF16,
            None,
            mDz_FP32_if_GLU_else_BF16,
            mY1S,
            mS,
            mDS_partial,
            mE_offset,
            mX_gather,
            None,
            mS_scatter,
            None,
            None,
            tensormaps[0],
            tensormaps[1],
            tensormaps[2],
            None,
            mE_permute_order,
            const_expr(self.max_active_clusters),
            stream,
        )
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
class HopperWgmma_MoE_Down_proj_WeightGrad_Bwd:
    """Backward wrapper: weight gradient (dW2) for the down projection.

    Builds the grouped-GEMM kernel with ``compute_weight_gradient=True``; the
    A operand is gathered (``is_A_gather=True``).
    """

    def __init__(self, E: int, H: int, I: int):
        super().__init__()
        assert (
            H % 64 == 0 and H >= 512 and I % 64 == 0
        ), f"{LIBRARY_NAME} only supports MoE with H % 64 == 0 (H >= 512) and I % 64 == 0"

        if I >= 128:
            # Large-I: cooperative schedule with a small (16) epilogue tile.
            dw2_config = HopperGEMMConfig(
                tile_shape_mnk=(128, 256, 64),
                cluster_shape_mnk=(2, 1),
                epi_tile_size=16,
                is_pingpong=False,
                initial_d_epi_stage=6,
                raster_order=RasterOrderOption.AlongN,
            )
        elif I == 64:
            # Smallest supported I: shorter M tile with the pingpong schedule.
            dw2_config = HopperGEMMConfig(
                tile_shape_mnk=(64, 192, 64),
                cluster_shape_mnk=(2, 1),
                epi_tile_size=32,
                is_pingpong=True,
                initial_d_epi_stage=6,
                raster_order=RasterOrderOption.AlongN,
            )
        else:
            raise NotImplementedError()

        self.module = HopperWgmma_MoE_kernel(
            E,
            cutlass.Float32,
            dw2_config.tile_shape_mnk,
            (*dw2_config.cluster_shape_mnk, 1),
            pingpong=dw2_config.is_pingpong,
            is_persistent=True,
            compute_swiglu=False,
            compute_weight_gradient=True,
            compute_dz_and_partial_ds_and_y1s=False,
            is_A_gather=True,
            epi_tile_size=dw2_config.epi_tile_size,
            initial_d_epi_stage=dw2_config.initial_d_epi_stage,
        )
        self.max_active_clusters = cutlass.utils.HardwareInfo().get_max_active_clusters(
            dw2_config.cluster_shape_mnk[0] * dw2_config.cluster_shape_mnk[1]
        )

    @cute.jit
    def __call__(self, mDout_trans, mY1S_trans, mDw2, mE_offset, mX_gather, tensormaps, mE_permute_order, stream):
        # Forward to the generic kernel; unused slots of its wide positional
        # signature are filled with None. The argument order below is load-bearing.
        return self.module(
            mDout_trans,
            mY1S_trans,
            None,
            None,
            mDw2,
            None,
            None,
            None,
            mE_offset,
            mX_gather,
            None,
            None,
            None,
            tensormaps[0],
            None,
            None,
            None,
            None,
            mE_permute_order,
            const_expr(self.max_active_clusters),
            stream,
        )
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
class HopperWgmma_MoE_Up_proj_ActGrad_Bwd:
    """Backward wrapper: input gradient (dX, expanded per expert-token) through the up projection.

    Unlike the other wrappers this takes ``is_glu_activation`` directly instead of an
    ``ActivationType``; no activation fusion is enabled and A is not gathered —
    the dX rows are scattered via ``mS_scatter`` instead.
    """

    def __init__(self, E: int, H: int, I: int, is_glu_activation: bool):
        super().__init__()
        # Shape constraints differ because GLU activations halve the effective I.
        if is_glu_activation:
            assert (
                H % 64 == 0 and H >= 512 and I % 64 == 0
            ), f"{LIBRARY_NAME} only supports GLU MoE with H % 64 == 0 (H >= 512) and I % 64 == 0"
        else:
            assert (
                H % 64 == 0 and H >= 512 and I % 128 == 0
            ), f"{LIBRARY_NAME} only supports non-GLU MoE with H % 64 == 0 (H >= 512) and I % 128 == 0"

        if (I >= 512 and is_glu_activation) or (I >= 1024 and not is_glu_activation):
            # Large-I: cooperative schedule, big N tile.
            dx_config = HopperGEMMConfig(
                tile_shape_mnk=(128, 256, 64),
                cluster_shape_mnk=(2, 1),
                epi_tile_size=32,
                is_pingpong=False,
                initial_d_epi_stage=4,
                raster_order=RasterOrderOption.AlongN,
            )
        elif (I >= 64 and is_glu_activation) or (I >= 128 and not is_glu_activation):
            # Small/mid-I: pingpong schedule.
            dx_config = HopperGEMMConfig(
                tile_shape_mnk=(128, 192, 64),
                cluster_shape_mnk=(2, 1),
                epi_tile_size=64,
                is_pingpong=True,
                initial_d_epi_stage=8,
                raster_order=RasterOrderOption.AlongN,
            )
        else:
            raise NotImplementedError()

        # NOTE(review): initial_d_epi_stage from dx_config is not forwarded here
        # (the kernel falls back to its own default) — confirm this is intentional.
        self.module = HopperWgmma_MoE_kernel(
            E,
            cutlass.Float32,
            dx_config.tile_shape_mnk,
            (*dx_config.cluster_shape_mnk, 1),
            pingpong=dx_config.is_pingpong,
            is_persistent=True,
            compute_swiglu=False,
            compute_dz_and_partial_ds_and_y1s=False,
            is_A_gather=False,
            epi_tile_size=dx_config.epi_tile_size,
        )

        self.max_active_clusters = cutlass.utils.HardwareInfo().get_max_active_clusters(
            dx_config.cluster_shape_mnk[0] * dx_config.cluster_shape_mnk[1]
        )
        self.current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    @cute.jit
    def __call__(
        self, mDz, mW1_trans, mDx_expanded, mE_offset, mX_gather, mS_scatter, tensormaps, mE_permute_order, stream
    ):
        # Forward to the generic kernel; unused slots of its wide positional
        # signature are filled with None. The argument order below is load-bearing.
        return self.module(
            mDz,
            mW1_trans,
            None,
            None,
            mDx_expanded,
            None,
            None,
            None,
            mE_offset,
            mX_gather,
            None,
            mS_scatter,
            None,
            None,
            None,
            tensormaps[0],
            tensormaps[1],
            None,
            mE_permute_order,
            const_expr(self.max_active_clusters),
            stream,
        )
|
| 504 |
+
|
| 505 |
+
|
| 506 |
+
class HopperWgmma_MoE_Up_proj_WeightGrad_Bwd:
    """Backward weight-gradient pass (dW1) of the MoE up-projection on Hopper.

    Selects a WGMMA GEMM tile configuration based on the intermediate size
    ``I`` and whether the MLP uses a GLU activation, then instantiates a
    persistent ``HopperWgmma_MoE_kernel`` with A-gather enabled
    (``compute_weight_gradient=True``).
    """

    def __init__(self, E: int, H: int, I: int, is_glu_activation: bool):
        # E: number of experts; H: hidden size; I: intermediate size.
        super().__init__()
        # Alignment constraints come from the 64-wide K tiles used below;
        # GLU halves the per-matrix intermediate width, hence I % 64 vs I % 128.
        if is_glu_activation:
            assert (
                H % 64 == 0 and H >= 512 and I % 64 == 0
            ), f"{LIBRARY_NAME} only supports GLU MoE with H % 64 == 0 (H >= 512) and I % 64 == 0"
        else:
            assert (
                H % 64 == 0 and H >= 512 and I % 128 == 0
            ), f"{LIBRARY_NAME} only supports non-GLU MoE with H % 64 == 0 (H >= 512) and I % 128 == 0"

        # Tile-shape selection: larger I gets a wider N tile (128x256),
        # the smallest supported I gets a taller M tile (256x128).
        # NOTE(review): thresholds look hand-tuned for H100 — confirm before editing.
        if (I >= 128 and is_glu_activation) or (I >= 256 and not is_glu_activation):
            dw1_config = HopperGEMMConfig(
                tile_shape_mnk=(128, 256, 64),
                cluster_shape_mnk=(2, 1),
                epi_tile_size=16,
                is_pingpong=False,
                initial_d_epi_stage=6,
                raster_order=RasterOrderOption.Heuristic,
            )
        elif (I == 64 and is_glu_activation) or (I == 128 and not is_glu_activation):
            dw1_config = HopperGEMMConfig(
                tile_shape_mnk=(256, 128, 64),
                cluster_shape_mnk=(2, 1),
                epi_tile_size=16,
                is_pingpong=False,
                initial_d_epi_stage=6,
                raster_order=RasterOrderOption.AlongN,
            )
        else:
            raise NotImplementedError()

        # Persistent grouped-GEMM kernel; accumulates dW1 in Float32.
        self.module = HopperWgmma_MoE_kernel(
            E,
            cutlass.Float32,
            dw1_config.tile_shape_mnk,
            (*dw1_config.cluster_shape_mnk, 1),
            pingpong=dw1_config.is_pingpong,
            is_persistent=True,
            compute_swiglu=False,
            compute_weight_gradient=True,
            compute_dz_and_partial_ds_and_y1s=False,
            is_A_gather=True,
            epi_tile_size=dw1_config.epi_tile_size,
        )

        # Grid size for the persistent kernel: how many clusters of this shape
        # can be resident on the device at once.
        self.max_active_clusters = cutlass.utils.HardwareInfo().get_max_active_clusters(
            dw1_config.cluster_shape_mnk[0] * dw1_config.cluster_shape_mnk[1]
        )

    @cute.jit
    def __call__(self, mX_trans, mDz_trans, mDw1_trans, mE_offset, mX_gather, tensormaps, mE_permute_order, stream):
        # Forward the named operands into HopperWgmma_MoE_kernel's positional
        # argument list; unused slots are None. mDw1_trans sits in the output
        # slot (5th positional), matching the sibling backward wrappers.
        return self.module(
            mX_trans,
            mDz_trans,
            None,
            None,
            mDw1_trans,
            None,
            None,
            None,
            mE_offset,
            mX_gather,
            None,
            None,
            None,
            tensormaps[0],  # only one tensormap workspace is needed for this pass
            None,
            None,
            None,
            None,
            mE_permute_order,
            const_expr(self.max_active_clusters),
            stream,
        )
|
build/torch-cuda/functional/reduction_over_k_gather.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ********************************************************************************
|
| 2 |
+
# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao
|
| 3 |
+
# ********************************************************************************
|
| 4 |
+
|
| 5 |
+
from typing import Optional
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import triton
|
| 9 |
+
import triton.language as tl
|
| 10 |
+
|
| 11 |
+
from ..utils import get_powers_of_2
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
### This triton impl is equivalent as the cute-dsl impl shown above,
|
| 15 |
+
# and also achieves similar memory bandwidth on H100 for large K and H.
|
| 16 |
+
# However, for small K and H, this impl is better by autotuning so we use it as the default.
|
| 17 |
+
def _get_triton_autotune_configs() -> list[triton.Config]:
    """Enumerate the autotuning search space for the gather-sum kernel.

    Every (BLOCK_H, BLOCK_K, num_warps) combination is kept as long as the
    tile holds at most 32768 elements, so register/shared-memory pressure
    stays bounded.
    """
    return [
        triton.Config({"BLOCK_H": block_h, "BLOCK_K": block_k}, num_warps=warps, num_stages=4)
        for block_h in get_powers_of_2(256, 4096)
        for block_k in get_powers_of_2(1, 128)
        for warps in [4, 8]
        if block_k * block_h <= 32768
    ]
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _prune_triton_autotune_config(configs, nargs, **kw):
    """Early-prune autotune configs against the actual problem size.

    Keeps configs whose tile does not exceed the (power-of-2 rounded) problem
    extents and whose tile is not pathologically small relative to the work
    available. Falls back to the full list if everything would be pruned.
    """
    h_bound = triton.next_power_of_2(kw["H"])
    k_bound = triton.next_power_of_2(kw["MAX_K"])
    min_tile_elems = min(kw["H"] * kw["MAX_K"], 1024)

    kept = [
        c
        for c in configs
        if c.kwargs["BLOCK_H"] <= h_bound
        and c.kwargs["BLOCK_K"] <= k_bound
        and min_tile_elems <= c.kwargs["BLOCK_H"] * c.kwargs["BLOCK_K"]
    ]
    # Never return an empty candidate set — autotuning needs at least one config.
    return kept if kept else configs
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@triton.autotune(
    configs=_get_triton_autotune_configs(),
    key=["H", "MAX_K", "w_is_None", "is_varlen_K"],
    prune_configs_by={"early_config_prune": _prune_triton_autotune_config},
)
@triton.jit
def token_gather_sum_kernel(
    x_ptr,  # (Mtotal, H)
    w_ptr,  # (Mtotal,)
    M_perm_ptr,  # (Mtotal,) int32
    M_offset_ptr,  # (T+1,) int32
    out_ptr,  # (T, H)
    T,
    H: tl.constexpr,
    MAX_K: tl.constexpr,
    # strides
    stride_xM: tl.constexpr,
    stride_xH: tl.constexpr,
    stride_outT: tl.constexpr,
    stride_outH: tl.constexpr,
    # tile sizes
    BLOCK_H: tl.constexpr,
    BLOCK_K: tl.constexpr,
    w_is_None: tl.constexpr,
    is_varlen_K: tl.constexpr,
):
    """Per-token gather + (optionally weighted) sum over K expert rows.

    One program per token t computes
        out[t, :] = sum_j x[M_perm[Ms + j], :] * (w[Ms + j] if w given else 1)
    where Ms/Me come from M_offset when is_varlen_K, else Ms = t * MAX_K.
    Accumulation is in float32 regardless of the input dtype.
    """
    # 1D tiling over T only
    pid_t = tl.program_id(axis=0)
    t_idx = pid_t.to(tl.uint32)

    # Load segment starts and ends for this token
    if is_varlen_K:
        Ms = tl.load(M_offset_ptr + t_idx).to(tl.uint32)
        Me = tl.load(M_offset_ptr + t_idx + 1).to(tl.uint32)
        K_this_token = Me - Ms  # actual K for this token
    else:
        # Fixed-K layout: token t owns rows [t*MAX_K, (t+1)*MAX_K).
        Ms = MAX_K * t_idx
        K_this_token: tl.constexpr = MAX_K

    # Outer loop over H tiles (unrolled at compile time: H, BLOCK_H are constexpr)
    for h_tile in tl.static_range(triton.cdiv(H, BLOCK_H)):
        h_idx = (h_tile * BLOCK_H + tl.arange(0, BLOCK_H)).to(tl.uint32)  # [BLOCK_H]
        m_h = h_idx < H

        # Initialize accumulator for this H tile
        acc = tl.zeros([BLOCK_H], dtype=tl.float32)  # [BLOCK_H]

        # Inner loop over K tiles
        for k_tile in tl.range(tl.cdiv(K_this_token, BLOCK_K)):
            k_offset = k_tile * BLOCK_K

            k_idx = (k_offset + tl.arange(0, BLOCK_K)).to(tl.uint32)  # [BLOCK_K]

            # Mask for valid K indices
            m_k = k_idx < K_this_token  # [BLOCK_K]

            # Absolute positions into M_perm and w
            m_abs = Ms + k_idx  # [BLOCK_K]

            # Gather permuted indices
            perm_idx = tl.load(M_perm_ptr + m_abs, mask=m_k, other=0).to(tl.uint32)  # [BLOCK_K]

            # Load x values: [BLOCK_K, BLOCK_H]
            x_ptrs = x_ptr + perm_idx[:, None] * stride_xM + h_idx[None, :] * stride_xH
            x_mask = m_k[:, None] & m_h[None, :]
            x_vals = tl.load(x_ptrs, mask=x_mask, other=0.0).to(tl.float32)

            # Reduce along K dimension and add to accumulator.
            # Masked lanes contribute 0.0 (other=0.0 above), so no correction needed.
            if w_is_None:
                acc += tl.sum(x_vals, axis=0)  # [BLOCK_H]
            else:
                w_vals = tl.load(w_ptr + m_abs, mask=m_k, other=0.0).to(tl.float32)  # [BLOCK_K]
                acc += tl.sum(x_vals * w_vals[:, None], axis=0)  # [BLOCK_H]

        # Store final result for this H tile (only once!)
        out_ptrs = out_ptr + t_idx * stride_outT + h_idx * stride_outH
        tl.store(out_ptrs, acc, mask=m_h)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def token_gather_and_sum_varlen_K_triton(
    x: torch.Tensor,  # (Mtotal, H)
    w: Optional[torch.Tensor],  # (Mtotal,)
    out: torch.Tensor,  # (T, H)
    M_perm: torch.Tensor,  # (Mtotal,) int32
    M_offset: torch.Tensor,  # (T+1,) int32, variable K per token
    T: int,
    MAX_K: int,  # maximum K across all tokens
    H: int,
    is_varlen_K: bool,
):
    """Launch ``token_gather_sum_kernel`` with one program per token.

    1D parallelization over T, with iterative accumulation over K tiles and H tiles.
    Supports variable K per token.

    out[i, :] = sum_{j=0..K[i]-1} x[M_perm[M_offset[i] + j], :] * w[M_offset[i] + j]

    where K[i] = M_offset[i+1] - M_offset[i] can vary per token. When ``w`` is
    None the sum is unweighted; when ``is_varlen_K`` is False, M_offset is
    ignored and K[i] == MAX_K for every token. BLOCK_H/BLOCK_K/num_warps are
    chosen by the kernel's autotuner.
    """

    # 1D grid over T only
    token_gather_sum_kernel[(T,)](
        x,
        w,
        M_perm,
        M_offset,
        out,
        T=T,
        H=H,
        MAX_K=MAX_K,
        stride_xM=x.stride(0),
        stride_xH=x.stride(1),
        stride_outT=out.stride(0),
        stride_outH=out.stride(1),
        w_is_None=(w is None),
        is_varlen_K=is_varlen_K,
    )
|
build/torch-cuda/functional/tile_scheduler.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ********************************************************************************
|
| 2 |
+
# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao
|
| 3 |
+
# ********************************************************************************
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
import cutlass
|
| 8 |
+
import cutlass.cute as cute
|
| 9 |
+
from cutlass import Boolean, Int32, const_expr
|
| 10 |
+
from ..quack.pipeline import PipelineStateWAdvance
|
| 11 |
+
from ..quack.tile_scheduler import TileScheduler, VarlenMTileScheduler
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class SonicMoETileScheduler(TileScheduler):
    """Tile scheduler for the MoE kernels, extending QuACK's ``TileScheduler``.

    Adds a ``create`` factory that derives the initial linear work index from
    the launch geometry, and a ``prefetch_next_work`` that peeks at a future
    tile without committing the advance.
    """

    @staticmethod
    @cute.jit
    def create(
        params: TileScheduler.Params,
        tile_count: cute.Tensor | None = None,
        scheduler_pipeline: cutlass.pipeline.PipelineAsync | None = None,
        is_scheduler_warp: bool | Boolean = False,
        *,
        loc=None,
        ip=None,
    ) -> SonicMoETileScheduler:
        """is_scheduler_warp should only be true for one warp in the whole cluster"""
        stages = 0
        if const_expr(not params.is_persistent):
            # Non-persistent launch: each cluster handles exactly one work item,
            # identified by its linearized cluster coordinate.
            cidx, cidy, _ = cute.arch.cluster_idx()
            cdimx, _, _ = cute.arch.cluster_dim()
            cluster_id = cidx + cidy * cdimx
            current_work_linear_idx = Int32(cluster_id)
        else:
            # Persistent launch: clusters are laid out along grid z.
            _, _, bidz = cute.arch.block_idx()
            current_work_linear_idx = Int32(bidz)
        if const_expr(params.tile_count_semaphore is not None):
            # Dynamic scheduling path: tile counts are streamed through a pipeline.
            assert tile_count is not None
            assert scheduler_pipeline is not None
            stages = const_expr(cute.size(tile_count))
        return SonicMoETileScheduler(
            current_work_linear_idx,
            Int32(0),  # num_tiles_executed
            tile_count,
            scheduler_pipeline,
            # Only the designated scheduler warp starts with a non-zero phase.
            PipelineStateWAdvance(stages, Int32(0), Int32(0), Int32(1 if is_scheduler_warp else 0)),
            params,
            loc=loc,
            ip=ip,
        )

    def prefetch_next_work(self, *, advance_count: int = 1, loc=None, ip=None):
        """Peek at the tile `advance_count` steps ahead without advancing state.

        Temporarily bumps the linear work index, reads the corresponding tile
        coordinate, then restores the index so the scheduler's position is
        unchanged.
        """
        old_current_work_linear_idx = self._current_work_linear_idx
        if const_expr(self.params.is_persistent):
            # Persistent kernels stride by the number of resident clusters (grid z).
            num_persistent_clusters = cute.arch.grid_dim()[2]
            self._current_work_linear_idx += advance_count * Int32(num_persistent_clusters)
        future_tile_coord_mnkl = self.get_current_work()
        self._current_work_linear_idx = old_current_work_linear_idx
        return future_tile_coord_mnkl
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class SonicMoEVarlenMTileScheduler(VarlenMTileScheduler, SonicMoETileScheduler):
    """Variable-length-M variant of :class:`SonicMoETileScheduler`.

    Always uses the persistent launch layout (work index = grid-z block index)
    and carries two extra per-scheduler counters for the varlen batching:
    the current batch index and the number of work items preceding it.
    """

    @staticmethod
    @cute.jit
    def create(
        params: VarlenMTileScheduler.Params,
        tile_count: cute.Tensor | None = None,
        scheduler_pipeline: cutlass.pipeline.PipelineAsync | None = None,
        is_scheduler_warp: bool | Boolean = False,
        *,
        loc=None,
        ip=None,
    ) -> SonicMoEVarlenMTileScheduler:
        # is_scheduler_warp should only be true for one warp in the whole cluster
        # (same contract as SonicMoETileScheduler.create).
        stages = 0
        _, _, bidz = cute.arch.block_idx()
        current_work_linear_idx = Int32(bidz)
        if const_expr(params.tile_count_semaphore is not None):
            # Dynamic scheduling path: tile counts are streamed through a pipeline.
            assert tile_count is not None
            assert scheduler_pipeline is not None
            stages = const_expr(cute.size(tile_count))
        return SonicMoEVarlenMTileScheduler(
            current_work_linear_idx,
            Int32(0),  # num_tiles_executed
            Int32(0),  # current_batch_idx
            Int32(0),  # num_work_idx_before_cur_batch
            tile_count,
            scheduler_pipeline,
            PipelineStateWAdvance(stages, Int32(0), Int32(0), Int32(1 if is_scheduler_warp else 0)),
            params,
            loc=loc,
            ip=ip,
        )
|
build/torch-cuda/functional/topk_softmax.py
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ********************************************************************************
|
| 2 |
+
# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao
|
| 3 |
+
# ********************************************************************************
|
| 4 |
+
|
| 5 |
+
# this impl is adapted from QuACK's topk https://github.com/Dao-AILab/quack/blob/main/quack/topk.py
|
| 6 |
+
import math
|
| 7 |
+
from typing import Type
|
| 8 |
+
|
| 9 |
+
import cuda.bindings.driver as cuda
|
| 10 |
+
import cutlass
|
| 11 |
+
import cutlass.cute as cute
|
| 12 |
+
from ..quack import utils
|
| 13 |
+
from cutlass import const_expr
|
| 14 |
+
from ..quack.sort.bitonic_sort import bitonic_topk
|
| 15 |
+
from triton import next_power_of_2
|
| 16 |
+
|
| 17 |
+
from ..utils import domain_offset_i64
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class TopK_Softmax:
    """Fused row-wise top-k (+ optional softmax over the k values) CuTe kernel.

    Adapted from QuACK's topk. For each row of an (M, N) input it returns the
    k largest values (softmax-normalized when ``require_softmax_fusion``) and
    their column indices. Indices are encoded into the low bits of the float32
    values so a single bitonic top-k sorts values and indices together.

    Fix vs. original: the softmax denominator accumulator was initialized as
    ``cutlass.Int32(0.0)`` even though it accumulates Float32 ``exp`` terms;
    it is now ``cutlass.Float32(0.0)`` so no integer type is ever involved.
    """

    def __init__(
        self,
        input_dtype: Type[cutlass.Numeric],
        output_dtype: Type[cutlass.Numeric],
        N: int,
        k: int,
        require_softmax_fusion: bool = True,
    ):
        # N: row width; k: number of top elements returned per row.
        self.input_dtype = input_dtype
        self.output_dtype = output_dtype
        self.N = N
        # 128-bit vectorized accesses: elements per load/store per dtype.
        self.input_vecsize = 128 // input_dtype.width
        self.output_vecsize = 128 // output_dtype.width
        self.k = k
        self.next_power_of_2_N = next_power_of_2(N)
        self.next_power_of_2_K = next_power_of_2(k)
        assert k <= 128 and k <= N
        assert N <= 4096 and N % 8 == 0
        assert input_dtype.width <= output_dtype.width, "input bitwidth must <= output bitwidth"

        self.require_softmax_fusion = require_softmax_fusion

    def _calculate_threads_per_row(self):
        """Pick how many threads cooperate on one row (power of 2, >= 1)."""
        # we want num_elems_per_thread >= self.k
        # and each thread can handle at most 64 elements
        # NOTE(review): min(..., N // 64) *caps* threads_per_row, which makes
        # elements-per-thread at least 64 for small N — confirm intent vs. the
        # "at most 64" comment above before changing.
        N = self.next_power_of_2_N
        num_threads_per_row = max(min(N // self.k, 32, N // 64), 1)
        return num_threads_per_row

    def _get_tv_layout(self, vecsize):
        """Build the (tiler, thread-value layout) pair for the given vector width.

        Rows are split across ``threads_per_row`` threads; the remaining threads
        of the CTA cover ``cols_per_block`` rows in parallel.
        """
        N = self.next_power_of_2_N
        num_threads = 128 if N <= 16384 else 256
        threads_per_row = self._calculate_threads_per_row()
        cols_per_block = num_threads // threads_per_row
        num_blocks_N = cute.ceil_div(min(N, 16384) // vecsize, threads_per_row)
        tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)
        tv_layout = cute.make_layout(
            ((threads_per_row, cols_per_block), (vecsize, num_blocks_N)),
            stride=(
                (vecsize * cols_per_block, 1),
                (cols_per_block, cols_per_block * vecsize * threads_per_row),
            ),
        )
        return tiler_mn, tv_layout

    @cute.jit
    def __call__(
        self,
        mX: cute.Tensor,
        mValues: cute.Tensor,
        mIndices: cute.Tensor,
        stream: cuda.CUstream,
    ):
        """Launch the kernel: one CTA per ``input_tiler_mn[0]`` rows of mX.

        mX: (M, N) input; mValues: (M, k) output values; mIndices: (M, k)
        int32 output column indices.
        """
        assert mX.element_type == self.input_dtype
        assert mValues.element_type == self.output_dtype
        assert mIndices.element_type == cutlass.Int32
        input_tiler_mn, input_tv_layout = self._get_tv_layout(self.input_vecsize)
        output_tiler_mn, output_tv_layout = self._get_tv_layout(self.output_vecsize)

        num_threads = cute.size(input_tv_layout, mode=[0])
        self.kernel(mX, mValues, mIndices, input_tv_layout, input_tiler_mn, output_tv_layout, output_tiler_mn).launch(
            grid=[cute.ceil_div(mX.shape[0], input_tiler_mn[0]), 1, 1],
            block=[num_threads, 1, 1],
            stream=stream,
        )

    @cute.kernel
    def kernel(
        self,
        mX: cute.Tensor,
        mValues: cute.Tensor,
        mIndices: cute.Tensor,
        input_tv_layout: cute.Layout,
        input_tiler_mn: cute.Shape,
        output_tv_layout: cute.Layout,
        output_tiler_mn: cute.Shape,
    ):
        """Device kernel: load rows, encode indices into value bits, bitonic
        top-k, optional softmax over the k survivors, vectorized write-out."""
        tidx, _, _ = cute.arch.thread_idx()
        bidx, _, _ = cute.arch.block_idx()

        shape = mX.shape
        idX = cute.make_identity_tensor(shape)
        # slice for CTAs
        # We use domain_offset_i64 to deal with tensors larger than 2^31 elements
        mX = domain_offset_i64((bidx * input_tiler_mn[0], 0), mX)
        gX = cute.local_tile(mX, input_tiler_mn, (0, 0))
        cX = cute.local_tile(idX, input_tiler_mn, (bidx, 0))

        # declare the atoms which will be used later for memory copy
        copy_atom_load_X = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gX.element_type, num_bits_per_copy=128)
        thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, input_tv_layout, input_tiler_mn).get_slice(tidx)
        tXgX = thr_copy_X.partition_S(gX)
        tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]

        # allocate fragments for gmem->rmem
        tXrX = cute.make_rmem_tensor_like(tXgX)

        # Column predicate is only needed when the tile over-covers N.
        is_even_N = const_expr(shape[1] == input_tiler_mn[1])
        tXpX = (
            utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
            if const_expr((not is_even_N) or (self.N != self.next_power_of_2_N))
            else None
        )
        if tXcX[0][0] < shape[0]:
            cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
        # Widen to float32 so the bit-encoding below is dtype-independent.
        tXrX_f32 = cute.make_rmem_tensor(tXrX.shape, cutlass.Float32)
        tXrX_f32.store(tXrX.load().to(cutlass.Float32))

        # Encode the indices into the bottom bits of values.
        log_N = int(math.log2(self.next_power_of_2_N))
        idx_mask = const_expr((1 << log_N) - 1)
        input_vecsize = cutlass.const_expr(input_tv_layout.shape[1][0])
        tXrX_u32 = cute.recast_tensor(tXrX_f32, cutlass.Uint32)
        # Encode indices into the last log_N bits of tXrX_u32
        for i in cutlass.range(cute.size(tXrX_u32), unroll_full=True):
            # tXcX only keeps track of the indices for every @vecsize elements
            col_idx = cutlass.Uint32(tXcX[i // input_vecsize][1] + i % input_vecsize)
            # If positive, invert the bits of the index, so that if there's a tie,
            # indices coming from a earlier column will win.
            encoded_idx = ~col_idx if tXrX_f32[i] >= 0 else col_idx
            # Mask to keep only the last log_N bits of the encoded index
            encoded_idx = encoded_idx & idx_mask
            # Clear the last log_N bits and set them to our encoded index
            tXrX_u32[i] = (tXrX_u32[i] & ~idx_mask) | encoded_idx

        # Fill OOB values with -inf for top-k
        if const_expr((not is_even_N) or (self.N != self.next_power_of_2_N)):
            utils.fill_oob(tXrX_f32, tXpX, -tXrX_f32.element_type.inf)

        threads_per_row = input_tv_layout.shape[0][0]
        topk_vals = bitonic_topk(tXrX_f32, self.next_power_of_2_K, warp_width=threads_per_row)

        # Extract indices and clean values
        topk_vals_u32 = cute.recast_tensor(topk_vals, cutlass.Uint32)
        topk_indices = cute.make_rmem_tensor(self.k, cutlass.Int32)
        for i in cutlass.range_constexpr(self.k):
            # Extract the encoded index from the last log_N bits
            encoded_idx = topk_vals_u32[i] & idx_mask
            # Check if original value was positive by looking at the cleaned value
            topk_vals_u32[i] = topk_vals_u32[i] & ~idx_mask  # Clear last log_N bits
            # If positive, we need to invert the bits back to get original index
            col_idx = ~encoded_idx if topk_vals[i] >= 0 else encoded_idx
            topk_indices[i] = cutlass.Int32(col_idx & idx_mask)

        if const_expr(self.require_softmax_fusion):
            # Numerically stable softmax over the k surviving values.
            topk_vals_max = -cutlass.Float32.inf
            for i in cutlass.range_constexpr(self.k):
                topk_vals_max = cute.arch.fmax(topk_vals[i], topk_vals_max)

            # BUGFIX: accumulator must be Float32 — it sums exp() terms.
            # (Was cutlass.Int32(0.0), which only worked via implicit promotion.)
            topk_exp_sum = cutlass.Float32(0.0)
            for i in cutlass.range_constexpr(self.k):
                topk_vals[i] = cute.math.exp(topk_vals[i] - topk_vals_max)
                topk_exp_sum = topk_exp_sum + topk_vals[i]

            for i in cutlass.range_constexpr(self.k):
                topk_vals[i] = topk_vals[i] / topk_exp_sum

        # Convert cleaned values to output type
        topk_vals_out = cute.make_rmem_tensor_like(topk_indices, mValues.element_type)
        for i in cutlass.range_constexpr(self.k):
            topk_vals_out[i] = topk_vals[i].to(mValues.element_type)

        row = tXcX[0][0]
        # Only the 1st thread in this row writes the top-k values and indices
        output_vecsize = cutlass.const_expr(output_tv_layout.shape[1][0])
        if row < shape[0] and tXcX[0][1] == 0:
            # Vectorized write
            elems_per_store = const_expr(math.gcd(output_vecsize, self.k))
            mValues_store = cute.tiled_divide(mValues[row, None], (elems_per_store,))
            mIndices_store = cute.tiled_divide(mIndices[row, None], (elems_per_store,))
            topk_vals_out_store = cute.tiled_divide(topk_vals_out, (elems_per_store,))
            topk_indices_store = cute.tiled_divide(topk_indices, (elems_per_store,))
            for i in cutlass.range_constexpr(cute.size(topk_vals_out_store.shape, [1])):
                cute.autovec_copy(topk_vals_out_store[None, i], mValues_store[None, i])
                cute.autovec_copy(topk_indices_store[None, i], mIndices_store[None, i])
|
build/torch-cuda/functional/triton_kernels/__init__.py
ADDED
|
@@ -0,0 +1,351 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import triton
|
| 5 |
+
import triton.language as tl
|
| 6 |
+
|
| 7 |
+
from ..._ops_compat import add_op_namespace_prefix
|
| 8 |
+
from .bitmatrix import _bitmatrix_metadata_compute_stage1, _bitmatrix_metadata_compute_stage2, _keyed_add
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@triton.jit
def _compute_col_partial_sum_kernel(
    topk_indices_ptr,
    partial_sum_ptr,
    T,
    E: tl.constexpr,
    n_tiles,
    TOKENS_PER_TILE: tl.constexpr,
    K_POW2: tl.constexpr,  # next_power_of_2(K),
    K: tl.constexpr,  # actual number of experts per token
    E_POW2: tl.constexpr,  # next_power_of_2(E)
):
    """Per-tile histogram of top-k expert assignments.

    Counts, for each expert, how many (token, k) routing entries in this tile
    were assigned to it. The zero-init and the atomic adds stay within one
    CTA's own column of partial_sum, so no cross-CTA ordering is required.
    """
    # One CTA per tile. Tile `t` covers tokens [t * TOKENS_PER_TILE, (t+1) * TOKENS_PER_TILE).
    # Produces partial_sum[e, tile_id] = number of entries in this tile routed to expert e.
    # Layout: partial_sum is [E, n_tiles] (row-major), so partial_sum[e, t] = partial_sum_ptr + e * n_tiles + t.
    # Caller transposes to [n_tiles, E] before passing to stage1/stage2.
    tile_id = tl.program_id(0)

    # Zero this tile's column in partial_sum[*, tile_id].
    # Chunked by E_POW2 to keep vector width a power of 2.
    for e_start in tl.static_range(0, E, E_POW2):
        e_offs = e_start + tl.arange(0, E_POW2)
        tl.store(
            partial_sum_ptr + e_offs * n_tiles + tile_id,
            tl.zeros([E_POW2], tl.int32),
            mask=e_offs < E,
        )

    # Load expert ids for this tile: shape [TOKENS_PER_TILE, K_POW2].
    # Tokens beyond T and k-slots beyond K are masked out (other=-1).
    tok_offs = tile_id * TOKENS_PER_TILE + tl.arange(0, TOKENS_PER_TILE)
    k_offs = tl.arange(0, K_POW2)
    tok_mask = tok_offs < T

    load_mask = tok_mask[:, None] & (k_offs[None, :] < K)
    safe_k = tl.minimum(k_offs, K - 1)  # avoid OOB when k_offs >= K
    expert_ids = tl.load(
        topk_indices_ptr + tok_offs[:, None] * K + safe_k[None, :],
        mask=load_mask,
        other=-1,
    )

    # Flatten to [TOKENS_PER_TILE * K_POW2] and histogram into partial_sum.
    # safe_experts remaps masked (-1) entries to expert 0 (harmless: flat_mask=False).
    flat_experts = tl.reshape(expert_ids, [TOKENS_PER_TILE * K_POW2])
    flat_mask = tl.reshape(load_mask, [TOKENS_PER_TILE * K_POW2])
    safe_experts = tl.where(flat_mask, flat_experts, 0)

    # Atomic: lanes within this CTA may hit the same expert counter.
    tl.atomic_add(
        partial_sum_ptr + safe_experts * n_tiles + tile_id,
        tl.full([TOKENS_PER_TILE * K_POW2], 1, dtype=tl.int32),
        mask=flat_mask,
    )
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@torch.library.custom_op(
|
| 67 |
+
add_op_namespace_prefix("triton_kernels__TC_topk_router_metadata"),
|
| 68 |
+
mutates_args={
|
| 69 |
+
"expert_frequency",
|
| 70 |
+
"expert_frequency_offset",
|
| 71 |
+
"x_gather_idx",
|
| 72 |
+
"s_scatter_idx",
|
| 73 |
+
"s_reverse_scatter_idx",
|
| 74 |
+
},
|
| 75 |
+
)
|
| 76 |
+
def TC_topk_router_metadata_triton(
|
| 77 |
+
topk_router_indices: torch.Tensor,
|
| 78 |
+
E: int,
|
| 79 |
+
expert_frequency: torch.Tensor,
|
| 80 |
+
expert_frequency_offset: torch.Tensor,
|
| 81 |
+
x_gather_idx: torch.Tensor,
|
| 82 |
+
s_scatter_idx: torch.Tensor,
|
| 83 |
+
s_reverse_scatter_idx: torch.Tensor,
|
| 84 |
+
) -> None:
|
| 85 |
+
T, K = topk_router_indices.size()
|
| 86 |
+
TK = T * K
|
| 87 |
+
device = topk_router_indices.device
|
| 88 |
+
E_POW2 = triton.next_power_of_2(E)
|
| 89 |
+
K_POW2 = triton.next_power_of_2(K)
|
| 90 |
+
TOKENS_PER_BLOCK = 1024 // K_POW2
|
| 91 |
+
n_tiles = triton.cdiv(T, TOKENS_PER_BLOCK)
|
| 92 |
+
|
| 93 |
+
# ── Kernel 1: tiled histogram ─────────────────────────────────────────────
|
| 94 |
+
# col_partial_sum_trans[E, n_tiles]: raw per-expert-per-tile counts.
|
| 95 |
+
# Stored transposed so each CTA writes to its own column (tile_id), avoiding
|
| 96 |
+
# cross-CTA write conflicts. Transposed back to [n_tiles, E] for stage1/stage2.
|
| 97 |
+
col_partial_sum_trans = torch.empty(E, n_tiles, dtype=torch.int32, device=device)
|
| 98 |
+
_compute_col_partial_sum_kernel[(n_tiles,)](
|
| 99 |
+
topk_router_indices,
|
| 100 |
+
col_partial_sum_trans,
|
| 101 |
+
T,
|
| 102 |
+
E,
|
| 103 |
+
n_tiles,
|
| 104 |
+
TOKENS_PER_TILE=TOKENS_PER_BLOCK,
|
| 105 |
+
K_POW2=K_POW2,
|
| 106 |
+
K=K,
|
| 107 |
+
E_POW2=E_POW2,
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
expert_frequency.copy_(col_partial_sum_trans.sum(dim=1, dtype=torch.int32))
|
| 111 |
+
col_partial_sum = col_partial_sum_trans.T # [n_tiles, E]
|
| 112 |
+
|
| 113 |
+
# ── Kernel 2: stage1 ─────────────────────────────────────────────────────
|
| 114 |
+
# - For each expert e (pid < E): convert col_partial_sum[*, e] from raw
|
| 115 |
+
# counts to exclusive prefix sums over tiles in-place.
|
| 116 |
+
# - For pid == E: write exclusive cumsum of expert_freq_offset into
|
| 117 |
+
# expert_freq_off[0:E] (= col_offs, a view into expert_freq_off).
|
| 118 |
+
|
| 119 |
+
_bitmatrix_metadata_compute_stage1[(E + 2,)](
|
| 120 |
+
expert_frequency,
|
| 121 |
+
expert_frequency_offset,
|
| 122 |
+
E,
|
| 123 |
+
col_partial_sum,
|
| 124 |
+
n_tiles,
|
| 125 |
+
TK,
|
| 126 |
+
BLOCK_M=128,
|
| 127 |
+
BLOCK_N=E_POW2,
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
# ── Kernel 3: stage2 ─────────────────────────────────────────���───────────
|
| 131 |
+
# For each tile: sort entries by expert, compute output positions, scatter.
|
| 132 |
+
_bitmatrix_metadata_compute_stage2[(n_tiles,)](
|
| 133 |
+
s_scatter_idx,
|
| 134 |
+
s_reverse_scatter_idx,
|
| 135 |
+
x_gather_idx,
|
| 136 |
+
topk_router_indices,
|
| 137 |
+
T,
|
| 138 |
+
col_partial_sum,
|
| 139 |
+
n_tiles,
|
| 140 |
+
expert_frequency_offset[:E],
|
| 141 |
+
K_POW2=K_POW2,
|
| 142 |
+
TOKENS_PER_BLOCK=TOKENS_PER_BLOCK,
|
| 143 |
+
K=K,
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
# ── general_routing_router_metadata_triton --- Kernel 1: tiled histogram over flat selected_E ────────────────────────────
|
| 148 |
+
@triton.jit
|
| 149 |
+
def _general_compute_col_partial_sum_kernel(
|
| 150 |
+
selected_E_ptr,
|
| 151 |
+
partial_sum_ptr, # [E, n_tiles], column-major per tile
|
| 152 |
+
TK,
|
| 153 |
+
E: tl.constexpr,
|
| 154 |
+
n_tiles,
|
| 155 |
+
BLOCK_SIZE: tl.constexpr,
|
| 156 |
+
E_POW2: tl.constexpr,
|
| 157 |
+
):
|
| 158 |
+
tile_id = tl.program_id(0)
|
| 159 |
+
|
| 160 |
+
# Zero this tile's column in partial_sum[*, tile_id].
|
| 161 |
+
for e_start in tl.static_range(0, E, E_POW2):
|
| 162 |
+
e_offs = e_start + tl.arange(0, E_POW2)
|
| 163 |
+
tl.store(
|
| 164 |
+
partial_sum_ptr + e_offs * n_tiles + tile_id,
|
| 165 |
+
tl.zeros([E_POW2], tl.int32),
|
| 166 |
+
mask=e_offs < E,
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
# Load expert ids for this tile (flat indexing into selected_E).
|
| 170 |
+
offs = tile_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
| 171 |
+
mask = offs < TK
|
| 172 |
+
expert_ids = tl.load(selected_E_ptr + offs, mask=mask, other=-1)
|
| 173 |
+
|
| 174 |
+
safe_experts = tl.where(mask, expert_ids, 0)
|
| 175 |
+
tl.atomic_add(
|
| 176 |
+
partial_sum_ptr + safe_experts * n_tiles + tile_id,
|
| 177 |
+
tl.full([BLOCK_SIZE], 1, dtype=tl.int32),
|
| 178 |
+
mask=mask,
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
# ── general_routing_router_metadata_triton --- Kernel 3: sort entries by expert within each tile, scatter ────────────────
|
| 183 |
+
@triton.jit
|
| 184 |
+
def _general_metadata_compute_stage2(
|
| 185 |
+
s_scatter_idx_ptr,
|
| 186 |
+
s_reverse_scatter_idx_ptr,
|
| 187 |
+
x_gather_idx_ptr,
|
| 188 |
+
selected_E_ptr,
|
| 189 |
+
sorted_selected_T_ptr,
|
| 190 |
+
TK,
|
| 191 |
+
partial_sum_ptr, # [n_tiles, E] with strides (1, n_tiles)
|
| 192 |
+
n_tiles,
|
| 193 |
+
expert_offs_ptr,
|
| 194 |
+
BLOCK_SIZE: tl.constexpr,
|
| 195 |
+
):
|
| 196 |
+
tl.static_assert(BLOCK_SIZE <= 32768)
|
| 197 |
+
|
| 198 |
+
pid_m = tl.program_id(0)
|
| 199 |
+
offs_local = tl.arange(0, BLOCK_SIZE)
|
| 200 |
+
offs_global = pid_m * BLOCK_SIZE + offs_local
|
| 201 |
+
mask = offs_global < TK
|
| 202 |
+
|
| 203 |
+
# Load expert id for each entry in this tile.
|
| 204 |
+
expert = tl.load(selected_E_ptr + offs_global, mask=mask, other=-1).to(tl.uint32)
|
| 205 |
+
|
| 206 |
+
# Pack (expert, local_offset) into uint32 and sort by expert.
|
| 207 |
+
# Upper 16 bits = expert id, lower 16 bits = pre-sort local offset.
|
| 208 |
+
kv_pairs = tl.sort(((expert << 16) | offs_local).to(tl.uint32), 0)
|
| 209 |
+
expert = kv_pairs >> 16
|
| 210 |
+
mask = expert != 0xFFFF
|
| 211 |
+
|
| 212 |
+
# Segmented scan for within-expert rank.
|
| 213 |
+
scan_input = (kv_pairs & 0xFFFF0000) | 0x00000001
|
| 214 |
+
inclusive_run_lengths = tl.associative_scan(scan_input, 0, _keyed_add)
|
| 215 |
+
within_expert_rank = (inclusive_run_lengths - 1) & 0xFFFF
|
| 216 |
+
|
| 217 |
+
# Output position = expert_offs[e] + partial_sum[tile, e] + within_expert_rank.
|
| 218 |
+
s_reverse_scatter_val = tl.load(partial_sum_ptr + pid_m + expert * n_tiles, mask=mask)
|
| 219 |
+
s_reverse_scatter_val += tl.load(expert_offs_ptr + expert, mask=mask)
|
| 220 |
+
s_reverse_scatter_val += within_expert_rank
|
| 221 |
+
|
| 222 |
+
# Recover pre-sort entry index and look up the token index.
|
| 223 |
+
presort_offs = kv_pairs & 0xFFFF
|
| 224 |
+
entry_idx = pid_m * BLOCK_SIZE + presort_offs
|
| 225 |
+
token_idx = tl.load(sorted_selected_T_ptr + entry_idx, mask=mask)
|
| 226 |
+
|
| 227 |
+
tl.store(s_reverse_scatter_idx_ptr + entry_idx, s_reverse_scatter_val, mask=mask)
|
| 228 |
+
tl.store(s_scatter_idx_ptr + s_reverse_scatter_val, entry_idx, mask=mask)
|
| 229 |
+
tl.store(x_gather_idx_ptr + s_reverse_scatter_val, token_idx, mask=mask)
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
# ── general_routing_router_metadata_triton --- Kernel 4: parallel binary search for token offset ─────────────────────────
|
| 233 |
+
# Since sorted_selected_T is sorted ascending, num_activated_expert_per_token_offset[t]
|
| 234 |
+
# is exactly searchsorted_left(sorted_selected_T, t): the index of the first entry
|
| 235 |
+
# with token index >= t. We compute this via parallel binary search over T+1 queries,
|
| 236 |
+
# replacing the PyTorch bincount + cumsum path.
|
| 237 |
+
@triton.jit
|
| 238 |
+
def _token_offset_searchsorted_kernel(
|
| 239 |
+
sorted_T_ptr, # [TK] int32, sorted ascending
|
| 240 |
+
offset_ptr, # [T+1] int32, output
|
| 241 |
+
T, # number of tokens
|
| 242 |
+
TK, # length of sorted_T
|
| 243 |
+
BLOCK_SIZE: tl.constexpr,
|
| 244 |
+
N_ITERS: tl.constexpr, # ceil(log2(TK + 1)), controls binary search depth
|
| 245 |
+
):
|
| 246 |
+
pid = tl.program_id(0)
|
| 247 |
+
t_offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
| 248 |
+
mask = t_offs <= T # T+1 total values: offset[0], ..., offset[T]
|
| 249 |
+
|
| 250 |
+
t_vals = t_offs.to(tl.int32)
|
| 251 |
+
|
| 252 |
+
# Binary search: find smallest i such that sorted_T[i] >= t_vals
|
| 253 |
+
lo = tl.zeros([BLOCK_SIZE], dtype=tl.int32)
|
| 254 |
+
hi = tl.full([BLOCK_SIZE], TK, dtype=tl.int32)
|
| 255 |
+
|
| 256 |
+
for _ in tl.static_range(0, N_ITERS):
|
| 257 |
+
mid = (lo + hi) >> 1
|
| 258 |
+
# When mid >= TK, treat the value as +inf (>= any t), so hi = mid.
|
| 259 |
+
safe_mid = tl.where(mid < TK, mid, 0)
|
| 260 |
+
val = tl.load(sorted_T_ptr + safe_mid, mask=mask & (TK > 0), other=T)
|
| 261 |
+
go_right = (val < t_vals) & (mid < TK)
|
| 262 |
+
lo = tl.where(go_right, mid + 1, lo)
|
| 263 |
+
hi = tl.where(go_right, hi, mid)
|
| 264 |
+
|
| 265 |
+
tl.store(offset_ptr + t_offs, lo, mask=mask)
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
@torch.library.custom_op(
|
| 269 |
+
add_op_namespace_prefix("triton_kernels__general_routing_router_metadata"),
|
| 270 |
+
mutates_args={
|
| 271 |
+
"expert_frequency",
|
| 272 |
+
"expert_frequency_offset",
|
| 273 |
+
"x_gather_idx",
|
| 274 |
+
"s_scatter_idx",
|
| 275 |
+
"s_reverse_scatter_idx",
|
| 276 |
+
"num_activated_expert_per_token_offset",
|
| 277 |
+
},
|
| 278 |
+
)
|
| 279 |
+
def general_routing_router_metadata_triton(
|
| 280 |
+
sorted_selected_T: torch.Tensor,
|
| 281 |
+
selected_E: torch.Tensor,
|
| 282 |
+
T: int,
|
| 283 |
+
E: int,
|
| 284 |
+
expert_frequency: torch.Tensor,
|
| 285 |
+
expert_frequency_offset: torch.Tensor,
|
| 286 |
+
x_gather_idx: torch.Tensor,
|
| 287 |
+
s_scatter_idx: torch.Tensor,
|
| 288 |
+
s_reverse_scatter_idx: torch.Tensor,
|
| 289 |
+
num_activated_expert_per_token_offset: torch.Tensor,
|
| 290 |
+
) -> None:
|
| 291 |
+
TK = selected_E.size(0)
|
| 292 |
+
device = selected_E.device
|
| 293 |
+
E_POW2 = triton.next_power_of_2(E)
|
| 294 |
+
BLOCK_SIZE = 1024
|
| 295 |
+
n_tiles = triton.cdiv(TK, BLOCK_SIZE)
|
| 296 |
+
|
| 297 |
+
# ── Kernel 1: tiled histogram ─────────────────────────────────────────
|
| 298 |
+
col_partial_sum_trans = torch.empty(E, n_tiles, dtype=torch.int32, device=device)
|
| 299 |
+
_general_compute_col_partial_sum_kernel[(n_tiles,)](
|
| 300 |
+
selected_E,
|
| 301 |
+
col_partial_sum_trans,
|
| 302 |
+
TK,
|
| 303 |
+
E,
|
| 304 |
+
n_tiles,
|
| 305 |
+
BLOCK_SIZE=BLOCK_SIZE,
|
| 306 |
+
E_POW2=E_POW2,
|
| 307 |
+
)
|
| 308 |
+
|
| 309 |
+
expert_frequency.copy_(col_partial_sum_trans.sum(dim=1, dtype=torch.int32))
|
| 310 |
+
col_partial_sum = col_partial_sum_trans.T # [n_tiles, E], strides (1, n_tiles)
|
| 311 |
+
|
| 312 |
+
# ── Kernel 2: stage1 ─────────────────────────────────────────────────
|
| 313 |
+
_bitmatrix_metadata_compute_stage1[(E + 2,)](
|
| 314 |
+
expert_frequency,
|
| 315 |
+
expert_frequency_offset,
|
| 316 |
+
E,
|
| 317 |
+
col_partial_sum,
|
| 318 |
+
n_tiles,
|
| 319 |
+
TK,
|
| 320 |
+
BLOCK_M=128,
|
| 321 |
+
BLOCK_N=E_POW2,
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
+
# ── Kernel 3: stage2 ─────────────────────────────────────────────────
|
| 325 |
+
_general_metadata_compute_stage2[(n_tiles,)](
|
| 326 |
+
s_scatter_idx,
|
| 327 |
+
s_reverse_scatter_idx,
|
| 328 |
+
x_gather_idx,
|
| 329 |
+
selected_E,
|
| 330 |
+
sorted_selected_T,
|
| 331 |
+
TK,
|
| 332 |
+
col_partial_sum,
|
| 333 |
+
n_tiles,
|
| 334 |
+
expert_frequency_offset[:E],
|
| 335 |
+
BLOCK_SIZE=BLOCK_SIZE,
|
| 336 |
+
)
|
| 337 |
+
|
| 338 |
+
# ── Kernel 4: num_activated_expert_per_token_offset via searchsorted ──
|
| 339 |
+
# sorted_selected_T is sorted ascending, so offset[t] = searchsorted_left(sorted_T, t).
|
| 340 |
+
# Parallel binary search: each thread handles one token index, O(log TK) work.
|
| 341 |
+
N_ITERS = max(1, math.ceil(math.log2(TK + 1)))
|
| 342 |
+
TOKEN_BLOCK = 1024
|
| 343 |
+
n_token_blocks = triton.cdiv(T + 1, TOKEN_BLOCK)
|
| 344 |
+
_token_offset_searchsorted_kernel[(n_token_blocks,)](
|
| 345 |
+
sorted_selected_T,
|
| 346 |
+
num_activated_expert_per_token_offset,
|
| 347 |
+
T,
|
| 348 |
+
TK,
|
| 349 |
+
BLOCK_SIZE=TOKEN_BLOCK,
|
| 350 |
+
N_ITERS=N_ITERS,
|
| 351 |
+
)
|
build/torch-cuda/functional/triton_kernels/bitmatrix.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import triton
|
| 2 |
+
import triton.language as tl
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
# https://github.com/triton-lang/triton/blob/434aecbe933af6a8d49595d4197bfc3df7618748/python/triton_kernels/triton_kernels/tensor_details/bitmatrix.py#L33
|
| 6 |
+
@triton.jit
|
| 7 |
+
def _keyed_add(x, y):
|
| 8 |
+
# we keep the key in the upper 16 bits of a uint32:
|
| 9 |
+
key_mask: tl.constexpr = 0xFFFF0000
|
| 10 |
+
|
| 11 |
+
kx = x & key_mask
|
| 12 |
+
ky = y & key_mask
|
| 13 |
+
z = tl.where(kx == ky, x + y - kx, y)
|
| 14 |
+
return z
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# Adapted from https://github.com/triton-lang/triton/blob/434aecbe933af6a8d49595d4197bfc3df7618748/python/triton_kernels/triton_kernels/tensor_details/bitmatrix.py#L44
|
| 18 |
+
@triton.jit
|
| 19 |
+
def _bitmatrix_metadata_compute_stage1(
|
| 20 |
+
expert_freq_ptr,
|
| 21 |
+
expert_freq_offs_ptr,
|
| 22 |
+
E: tl.constexpr,
|
| 23 |
+
partial_sum_ptr,
|
| 24 |
+
n_tiles,
|
| 25 |
+
TK,
|
| 26 |
+
BLOCK_M: tl.constexpr, # chunk size for iterating over tiles per expert
|
| 27 |
+
BLOCK_N: tl.constexpr, # chunk size for iterating over experts in cumsum
|
| 28 |
+
):
|
| 29 |
+
# Assume grid size == E + 1
|
| 30 |
+
|
| 31 |
+
pid = tl.program_id(0)
|
| 32 |
+
if pid < E:
|
| 33 |
+
# convert partial_sum[e, *] from raw counts to exclusive prefix
|
| 34 |
+
# sums over tiles. After this kernel, partial_sum[e, t] =
|
| 35 |
+
# number of entries for expert e in tiles 0..t-1.
|
| 36 |
+
|
| 37 |
+
# This is read by stage2 to locate each entry's position within expert e's contiguous output segment.
|
| 38 |
+
expert_partial_sum_ptr = partial_sum_ptr + pid * n_tiles
|
| 39 |
+
curr_sum = 0
|
| 40 |
+
for start in range(0, n_tiles, BLOCK_M):
|
| 41 |
+
offs = start + tl.arange(0, BLOCK_M)
|
| 42 |
+
tile_counts = tl.load(expert_partial_sum_ptr + offs, mask=offs < n_tiles, other=0)
|
| 43 |
+
excl_cumsum = tl.cumsum(tile_counts, 0) - tile_counts + curr_sum
|
| 44 |
+
curr_sum += tl.sum(tile_counts, 0)
|
| 45 |
+
tl.store(expert_partial_sum_ptr + offs, excl_cumsum, mask=offs < n_tiles)
|
| 46 |
+
elif pid == E:
|
| 47 |
+
# Exclusive prefix sum of per-expert total counts → expert_offs[e].
|
| 48 |
+
# expert_freq_offset[e] = total entries routed to expert e (from A.sum(dim=1)).
|
| 49 |
+
# expert_offs[e] = sum of expert_freq_offset[0..e-1] = global start of expert e.
|
| 50 |
+
curr_sum = 0
|
| 51 |
+
for start in tl.static_range(0, E, BLOCK_N):
|
| 52 |
+
offs = start + tl.arange(0, BLOCK_N)
|
| 53 |
+
expert_freq = tl.load(expert_freq_ptr + offs, mask=offs < E, other=0)
|
| 54 |
+
excl_cumsum = tl.cumsum(expert_freq, 0) - expert_freq + curr_sum
|
| 55 |
+
curr_sum += tl.sum(expert_freq, 0)
|
| 56 |
+
tl.store(expert_freq_offs_ptr + offs, excl_cumsum, mask=offs < E)
|
| 57 |
+
elif pid == E + 1:
|
| 58 |
+
# expert_freq_off[E] = TK (total number of entries)
|
| 59 |
+
tl.store(expert_freq_offs_ptr + E, TK)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# Adapted from https://github.com/triton-lang/triton/blob/434aecbe933af6a8d49595d4197bfc3df7618748/python/triton_kernels/triton_kernels/tensor_details/bitmatrix.py#L44
|
| 63 |
+
@triton.jit
|
| 64 |
+
def _bitmatrix_metadata_compute_stage2(
|
| 65 |
+
s_scatter_idx_ptr,
|
| 66 |
+
s_reverse_scatter_idx_ptr,
|
| 67 |
+
x_gather_idx_ptr,
|
| 68 |
+
topk_indices_ptr,
|
| 69 |
+
T,
|
| 70 |
+
partial_sum_ptr,
|
| 71 |
+
n_tiles,
|
| 72 |
+
expert_offs_ptr,
|
| 73 |
+
K_POW2: tl.constexpr, # padded K, == BLOCK_SIZE / BLOCK
|
| 74 |
+
K: tl.constexpr, # actual experts per token
|
| 75 |
+
TOKENS_PER_BLOCK: tl.constexpr, # tokens per tile
|
| 76 |
+
):
|
| 77 |
+
# One CTA per tile, same tiling as _compute_col_partial_sum_kernel.
|
| 78 |
+
# For each entry (token t, k-slot k) in this tile:
|
| 79 |
+
# s_reverse_scatter_idx[entry_idx] = output position in expert-sorted order
|
| 80 |
+
# s_scatter_idx[output_pos] = entry_idx (inverse permutation)
|
| 81 |
+
# x_gather_idx[output_pos] = token index (= entry_idx // K)
|
| 82 |
+
#
|
| 83 |
+
# Output position = expert_offs[e] (global start of expert e)
|
| 84 |
+
# + partial_sum[tile, e] (entries for e in earlier tiles, after stage1)
|
| 85 |
+
# + within_expert_rank (position within this tile's group for e)
|
| 86 |
+
BLOCK_SIZE: tl.constexpr = TOKENS_PER_BLOCK * K_POW2
|
| 87 |
+
IS_POW2_K: tl.constexpr = K == K_POW2 # fast path: no padding waste
|
| 88 |
+
tl.static_assert(BLOCK_SIZE <= 32768)
|
| 89 |
+
|
| 90 |
+
pid_m = tl.program_id(0)
|
| 91 |
+
offs_local = tl.arange(0, BLOCK_SIZE) # position within this tile's flat [BLOCK*K_POW2] space
|
| 92 |
+
offs_global = pid_m * BLOCK_SIZE + offs_local
|
| 93 |
+
mask = offs_global < T * K_POW2
|
| 94 |
+
|
| 95 |
+
# Load expert id for each slot. IS_POW2_K fast path reads topk_indices as a
|
| 96 |
+
# flat 1D array (no padding gaps). Non-pow2 path reads 2D with k_slot masking.
|
| 97 |
+
if IS_POW2_K:
|
| 98 |
+
expert = tl.load(topk_indices_ptr + offs_global, mask=mask, other=-1).to(tl.uint32)
|
| 99 |
+
else:
|
| 100 |
+
token_i_local = offs_local // K_POW2
|
| 101 |
+
k_slot = offs_local % K_POW2
|
| 102 |
+
token_i_global = pid_m * TOKENS_PER_BLOCK + token_i_local
|
| 103 |
+
load_mask = mask & (k_slot < K)
|
| 104 |
+
safe_k = tl.minimum(k_slot, K - 1)
|
| 105 |
+
expert = tl.load(
|
| 106 |
+
topk_indices_ptr + token_i_global * K + safe_k,
|
| 107 |
+
mask=load_mask,
|
| 108 |
+
other=-1,
|
| 109 |
+
).to(tl.uint32)
|
| 110 |
+
|
| 111 |
+
# Pack (expert, presort_offs) into a uint32 kv pair and sort by expert.
|
| 112 |
+
# Upper 16 bits = expert id (sort key), lower 16 bits = pre-sort local offset.
|
| 113 |
+
# Invalid slots have expert=0xffff (from other=-1 cast to uint32 >> 16).
|
| 114 |
+
kv_pairs = tl.sort(((expert << 16) | offs_local).to(tl.uint32), 0)
|
| 115 |
+
expert = kv_pairs >> 16
|
| 116 |
+
mask = expert != 0xFFFF # exclude padding/OOB slots
|
| 117 |
+
|
| 118 |
+
# Segmented scan to compute within-expert rank (0-based exclusive count).
|
| 119 |
+
# scan_input packs expert id in upper 16 bits and count=1 in lower 16 bits.
|
| 120 |
+
# _keyed_add resets the count at each expert boundary.
|
| 121 |
+
scan_input = (kv_pairs & 0xFFFF0000) | 0x00000001
|
| 122 |
+
inclusive_run_lengths = tl.associative_scan(scan_input, 0, _keyed_add)
|
| 123 |
+
within_expert_rank = (inclusive_run_lengths - 1) & 0xFFFF # exclusive = inclusive - 1
|
| 124 |
+
|
| 125 |
+
# Output position for this entry in the expert-sorted output array.
|
| 126 |
+
# partial_sum layout after stage1: [n_tiles, E], stride (1, n_tiles).
|
| 127 |
+
# So partial_sum[pid_m, expert] = partial_sum_ptr + pid_m*1 + expert*n_tiles.
|
| 128 |
+
s_reverse_scatter_idx = tl.load(partial_sum_ptr + pid_m + expert * n_tiles, mask=mask)
|
| 129 |
+
s_reverse_scatter_idx += tl.load(expert_offs_ptr + expert, mask=mask)
|
| 130 |
+
s_reverse_scatter_idx += within_expert_rank
|
| 131 |
+
|
| 132 |
+
if IS_POW2_K:
|
| 133 |
+
# presort_offs == offs_local before sort; entry_idx is the flat index into
|
| 134 |
+
# topk_router_indices.view(-1), i.e. token * K + k_slot.
|
| 135 |
+
presort_offs = kv_pairs & 0xFFFF
|
| 136 |
+
entry_idx = pid_m * BLOCK_SIZE + presort_offs
|
| 137 |
+
tl.store(s_reverse_scatter_idx_ptr + entry_idx, s_reverse_scatter_idx, mask=mask)
|
| 138 |
+
tl.store(s_scatter_idx_ptr + s_reverse_scatter_idx, entry_idx, mask=mask)
|
| 139 |
+
tl.store(x_gather_idx_ptr + s_reverse_scatter_idx, entry_idx // K_POW2, mask=mask)
|
| 140 |
+
else:
|
| 141 |
+
# presort_offs is in K_POW2-padded space; convert to unpadded entry_idx.
|
| 142 |
+
presort_offs = kv_pairs & 0xFFFF
|
| 143 |
+
token_i_global_s = pid_m * TOKENS_PER_BLOCK + presort_offs // K_POW2
|
| 144 |
+
entry_idx = token_i_global_s * K + presort_offs % K_POW2
|
| 145 |
+
tl.store(s_reverse_scatter_idx_ptr + entry_idx, s_reverse_scatter_idx, mask=mask)
|
| 146 |
+
tl.store(s_scatter_idx_ptr + s_reverse_scatter_idx, entry_idx, mask=mask)
|
| 147 |
+
tl.store(x_gather_idx_ptr + s_reverse_scatter_idx, token_i_global_s, mask=mask)
|
build/torch-cuda/functional/utils.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ********************************************************************************
|
| 2 |
+
# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao
|
| 3 |
+
# ********************************************************************************
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
from contextlib import contextmanager
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
_IS_USING_QUACK_GEMM = os.getenv("USE_QUACK_GEMM", "0") == "1"
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@contextmanager
|
| 13 |
+
def enable_quack_gemm(enable: bool = True):
|
| 14 |
+
global _IS_USING_QUACK_GEMM
|
| 15 |
+
|
| 16 |
+
previous_value = _IS_USING_QUACK_GEMM
|
| 17 |
+
_IS_USING_QUACK_GEMM = enable
|
| 18 |
+
|
| 19 |
+
yield
|
| 20 |
+
|
| 21 |
+
_IS_USING_QUACK_GEMM = previous_value
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def is_using_quack_gemm() -> bool:
|
| 25 |
+
return _IS_USING_QUACK_GEMM
|
build/torch-cuda/jit.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ********************************************************************************
|
| 2 |
+
# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao
|
| 3 |
+
# ********************************************************************************
|
| 4 |
+
|
| 5 |
+
import inspect
|
| 6 |
+
import os
|
| 7 |
+
from shutil import rmtree
|
| 8 |
+
from typing import Callable
|
| 9 |
+
from uuid import uuid4
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from torch.utils.cpp_extension import load as load_cpp_extension
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
_CPP_MODULE_PREFIX = "sonicmoe"
|
| 16 |
+
_GLOBAL_RANK = int(os.getenv("RANK", 0))
|
| 17 |
+
_WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1))
|
| 18 |
+
|
| 19 |
+
_ALL_COMPILED_MODULES = {}
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@torch.compiler.disable
|
| 23 |
+
def _get_cpp_function(function_name: str, module_name: str, source_files: list[str], build_directory: str) -> Callable:
|
| 24 |
+
module_name = f"{_CPP_MODULE_PREFIX}_{module_name}"
|
| 25 |
+
|
| 26 |
+
extra_cflags = ["-O3", "-Wall", "-shared", "-fPIC", "-fdiagnostics-color"]
|
| 27 |
+
extra_cuda_cflags = ["-O3", "-lineinfo"]
|
| 28 |
+
extra_include_paths = [
|
| 29 |
+
os.path.dirname(__file__), # sonicmoe/include
|
| 30 |
+
os.path.dirname(os.path.dirname(__file__)) + "/cutlass/include", # cutlass
|
| 31 |
+
os.path.dirname(os.path.dirname(__file__)) + "/cutlass/tools/util/include", # cutlass
|
| 32 |
+
]
|
| 33 |
+
|
| 34 |
+
module = _ALL_COMPILED_MODULES.get(module_name, None)
|
| 35 |
+
|
| 36 |
+
if module is None:
|
| 37 |
+
if torch.distributed.is_initialized():
|
| 38 |
+
os.makedirs(build_directory, exist_ok=True)
|
| 39 |
+
|
| 40 |
+
if _GLOBAL_RANK == 0:
|
| 41 |
+
module = load_cpp_extension(
|
| 42 |
+
module_name,
|
| 43 |
+
sources=source_files,
|
| 44 |
+
with_cuda=True,
|
| 45 |
+
extra_cflags=extra_cflags,
|
| 46 |
+
extra_cuda_cflags=extra_cuda_cflags,
|
| 47 |
+
extra_include_paths=extra_include_paths,
|
| 48 |
+
build_directory=build_directory,
|
| 49 |
+
verbose=True,
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
torch.distributed.barrier()
|
| 53 |
+
|
| 54 |
+
if _GLOBAL_RANK != 0:
|
| 55 |
+
module = load_cpp_extension(
|
| 56 |
+
module_name,
|
| 57 |
+
sources=source_files,
|
| 58 |
+
with_cuda=True,
|
| 59 |
+
extra_cflags=extra_cflags,
|
| 60 |
+
extra_cuda_cflags=extra_cuda_cflags,
|
| 61 |
+
extra_include_paths=extra_include_paths,
|
| 62 |
+
build_directory=build_directory,
|
| 63 |
+
verbose=False,
|
| 64 |
+
)
|
| 65 |
+
else:
|
| 66 |
+
if _WORLD_SIZE > 1:
|
| 67 |
+
build_directory = os.path.join(build_directory, str(uuid4()))
|
| 68 |
+
|
| 69 |
+
os.makedirs(build_directory, exist_ok=True)
|
| 70 |
+
|
| 71 |
+
module = load_cpp_extension(
|
| 72 |
+
module_name,
|
| 73 |
+
sources=source_files,
|
| 74 |
+
with_cuda=True,
|
| 75 |
+
extra_cflags=extra_cflags,
|
| 76 |
+
extra_cuda_cflags=extra_cuda_cflags,
|
| 77 |
+
extra_include_paths=extra_include_paths,
|
| 78 |
+
build_directory=build_directory,
|
| 79 |
+
verbose=True,
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
if _WORLD_SIZE > 1:
|
| 83 |
+
rmtree(build_directory, ignore_errors=True)
|
| 84 |
+
|
| 85 |
+
_ALL_COMPILED_MODULES[module_name] = module
|
| 86 |
+
|
| 87 |
+
return getattr(module, function_name)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def cpp_jit(
|
| 91 |
+
function_name: str | None = None,
|
| 92 |
+
extra_source_files: list[str] = [],
|
| 93 |
+
build_directory: str | None = None,
|
| 94 |
+
depth: int = 0,
|
| 95 |
+
) -> Callable:
|
| 96 |
+
"""wrapper to compile C++/CUDA source code at runtime.
|
| 97 |
+
|
| 98 |
+
Args:
|
| 99 |
+
function_name (str | None, optional): name of the function to expose from the C++ file, the python function
|
| 100 |
+
name should match the funcion name in the C++ file if this is not specified. Defaults to None.
|
| 101 |
+
extra_source_files (list[str], optional): any extra files to use for compilation, by default it scans the
|
| 102 |
+
directory of the python stub file. Defaults to [].
|
| 103 |
+
build_directory (str | None, optional): directory in which to place the build artifacts. Defaults to None.
|
| 104 |
+
depth (int, optional): number of times dirname is called to get the build path. Defaults to 2.
|
| 105 |
+
|
| 106 |
+
Returns:
|
| 107 |
+
Callable: returns the wrapped function that can be used to call the C++ functions from python
|
| 108 |
+
"""
|
| 109 |
+
cpp_function = None
|
| 110 |
+
args_spec = None
|
| 111 |
+
|
| 112 |
+
source_files = []
|
| 113 |
+
source_files.extend(extra_source_files)
|
| 114 |
+
|
| 115 |
+
calling_filename = inspect.stack()[1].filename
|
| 116 |
+
calling_directory = os.path.dirname(calling_filename)
|
| 117 |
+
|
| 118 |
+
for dirname, _, filenames in os.walk(calling_directory):
|
| 119 |
+
filenames = [os.path.join(dirname, f) for f in filenames]
|
| 120 |
+
filenames = filter(lambda f: os.path.splitext(f)[1] in [".cu", ".cpp"], filenames)
|
| 121 |
+
source_files.extend(filenames)
|
| 122 |
+
|
| 123 |
+
if build_directory is None:
|
| 124 |
+
module_name = calling_directory
|
| 125 |
+
for _ in range(depth):
|
| 126 |
+
module_name = os.path.dirname(module_name)
|
| 127 |
+
module_name = os.path.basename(module_name)
|
| 128 |
+
|
| 129 |
+
build_directory = os.path.join(os.path.dirname(os.path.dirname(__file__)), "build", module_name)
|
| 130 |
+
|
| 131 |
+
def _run(*args, **kwargs):
|
| 132 |
+
nonlocal cpp_function
|
| 133 |
+
|
| 134 |
+
if cpp_function is None:
|
| 135 |
+
cpp_function = _get_cpp_function(
|
| 136 |
+
function_name=_run.__name__,
|
| 137 |
+
module_name=module_name,
|
| 138 |
+
source_files=source_files,
|
| 139 |
+
build_directory=build_directory,
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
full_args = []
|
| 143 |
+
full_args.extend(args)
|
| 144 |
+
for variable_name in args_spec.args[len(args) :]:
|
| 145 |
+
full_args.append(kwargs[variable_name])
|
| 146 |
+
|
| 147 |
+
return cpp_function(*full_args)
|
| 148 |
+
|
| 149 |
+
def _wrapper(function: Callable) -> Callable:
|
| 150 |
+
nonlocal args_spec
|
| 151 |
+
args_spec = inspect.getfullargspec(function)
|
| 152 |
+
|
| 153 |
+
_run.__doc__ = function.__doc__
|
| 154 |
+
_run.__name__ = function.__name__ if function_name is None else function_name
|
| 155 |
+
_run.__signature__ = inspect.signature(function)
|
| 156 |
+
|
| 157 |
+
return _run
|
| 158 |
+
|
| 159 |
+
return _wrapper
|
build/torch-cuda/metadata.json
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"version": 1,
|
| 3 |
+
"license": "Apache-2.0",
|
| 4 |
+
"python-depends": [
|
| 5 |
+
"nvidia-cutlass-dsl"
|
| 6 |
+
],
|
| 7 |
+
"backend": {
|
| 8 |
+
"type": "cuda"
|
| 9 |
+
}
|
| 10 |
+
}
|
build/torch-cuda/moe.py
ADDED
|
@@ -0,0 +1,368 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ********************************************************************************
|
| 2 |
+
# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao
|
| 3 |
+
# ********************************************************************************
|
| 4 |
+
|
| 5 |
+
from typing import Callable
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
|
| 11 |
+
from .enums import ActivationType, KernelBackendMoE, is_glu
|
| 12 |
+
from .functional import moe_TC_softmax_topk_layer
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
from xma.modules.moe import scattered_experts
|
| 17 |
+
|
| 18 |
+
_IS_XMA_AVAILABLE = True
|
| 19 |
+
except ImportError:
|
| 20 |
+
_IS_XMA_AVAILABLE = False
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _swiglu(x: torch.Tensor) -> torch.Tensor:
    """SwiGLU over interleaved channels: SiLU of the even channels gates the odd channels."""
    gate, up = x[..., ::2], x[..., 1::2]
    return F.silu(gate) * up
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _geglu(x: torch.Tensor) -> torch.Tensor:
    """GEGLU over interleaved channels: GELU (evaluated in fp32) of the even channels times the odd channels, cast back to the input dtype."""
    gate = x[..., ::2]
    up = x[..., 1::2]
    activated = F.gelu(gate.to(dtype=torch.float32))
    return (activated * up).to(dtype=gate.dtype)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _gelu(x: torch.Tensor) -> torch.Tensor:
    """Exact (erf-based) GELU evaluated in float32, cast back to the input dtype."""
    y = F.gelu(x.float())
    return y.to(dtype=x.dtype)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _reglu(x: torch.Tensor) -> torch.Tensor:
    """ReGLU over interleaved channels: ReLU of the even channels times the odd channels."""
    gate = x[..., ::2]
    return (F.relu(gate) * x[..., 1::2]).to(dtype=gate.dtype)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _relu(x: torch.Tensor) -> torch.Tensor:
    """Elementwise ReLU: max(x, 0)."""
    return torch.relu(x)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def _relu_sq(x: torch.Tensor) -> torch.Tensor:
    """Squared ReLU: relu(x)^2."""
    return F.relu(x).square()
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _silu(x: torch.Tensor) -> torch.Tensor:
    """Elementwise SiLU (swish): x * sigmoid(x)."""
    return F.silu(x)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class Experts(nn.Module):
    """A bank of ``num_experts`` independent linear layers stored as one 3-D weight tensor."""

    def __init__(
        self, num_experts: int, in_features: int, out_features: int, add_bias: bool = True, std: float | None = None
    ) -> None:
        super().__init__()

        # (num_experts, out_features, in_features): one weight matrix per expert.
        self.weight = nn.Parameter(torch.empty(num_experts, out_features, in_features))

        self.bias = None
        if add_bias:
            # (num_experts, out_features): one bias vector per expert.
            self.bias = nn.Parameter(torch.empty(num_experts, out_features))

        # Std-dev used by reset_parameters.
        # NOTE(review): with the default std=None, nn.init.normal_ in reset_parameters
        # fails — callers appear to be expected to always pass a float; confirm.
        self.std = std

        self.num_experts = num_experts
        self.in_features = in_features
        self.out_features = out_features

        self.reset_parameters()

    def up_projection_scattermoe_forward(
        self,
        input: torch.Tensor,
        num_experts_per_token: int | None = None,
        sorted_expert_idxs: torch.Tensor | None = None,
        sorted_scattered_idxs: torch.Tensor | None = None,
        expert_offsets: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Scatter tokens into expert-grouped order and apply the per-expert up projection.

        Uses xma's ``scattered_experts`` kernel; bias is not supported on this path.
        """
        assert self.bias is None

        if not _IS_XMA_AVAILABLE:
            raise ImportError(
                "install accelerated-model-architectures from https://github.com/open-lm-engine/accelerated-model-architectures"
            )

        # weight is (E, out, in); the kernel takes it transposed per expert as (E, in, out).
        input = scattered_experts(
            inputs=input,
            expert_weights=self.weight.permute(0, 2, 1),
            k=num_experts_per_token,
            sorted_expert_idxs=sorted_expert_idxs,
            sorted_scattered_idxs=sorted_scattered_idxs,
            expert_offsets=expert_offsets,
            gates=None,
            grouped_in=False,  # tokens arrive in original (ungrouped) order
            grouped_out=True,  # outputs stay grouped by expert for the next stage
        )

        return input

    def down_projection_scattermoe_forward(
        self,
        input: torch.Tensor,
        num_experts_per_token: int | None = None,
        sorted_expert_idxs: torch.Tensor | None = None,
        sorted_scattered_idxs: torch.Tensor | None = None,
        expert_offsets: torch.Tensor | None = None,
        gates: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Apply the per-expert down projection to expert-grouped tokens and gather them back.

        ``gates`` (router weights) are applied inside the kernel; bias is not supported.
        """
        assert self.bias is None

        if not _IS_XMA_AVAILABLE:
            raise ImportError(
                "install accelerated-model-architectures from https://github.com/open-lm-engine/accelerated-model-architectures"
            )

        input = scattered_experts(
            inputs=input,
            expert_weights=self.weight.permute(0, 2, 1),
            k=num_experts_per_token,
            sorted_expert_idxs=sorted_expert_idxs,
            sorted_scattered_idxs=sorted_scattered_idxs,
            expert_offsets=expert_offsets,
            gates=gates,
            grouped_in=True,  # inputs are grouped by expert (output of the up projection)
            grouped_out=False,  # scatter results back to original token order
        )

        return input

    def torch_forward(
        self, input: torch.Tensor | list[torch.Tensor], expert_frequency: torch.Tensor | None, return_list: bool = False
    ) -> list[torch.Tensor] | torch.Tensor:
        """Reference implementation: per-expert ``F.linear`` on expert-grouped tokens.

        ``input`` is either a single tensor (split into per-expert chunks by
        ``expert_frequency``) or a list with one tensor per expert, in which case
        ``expert_frequency`` must be None.
        """
        if isinstance(input, torch.Tensor):
            input = input.split(expert_frequency.tolist(), dim=0)
        else:
            assert expert_frequency is None

        input = [
            F.linear(input[i], self.weight[i], None if self.bias is None else self.bias[i])
            for i in range(self.num_experts)
        ]

        if not return_list:
            input = torch.cat(input, dim=0)

        return input

    def extra_repr(self):
        # Shown by repr(module) for debugging.
        return "num_experts={}, in_features={}, out_features={}".format(
            self.num_experts, self.in_features, self.out_features
        )

    @torch.no_grad()
    def reset_parameters(self) -> None:
        # Normal-initialize weights with the configured std; zero the bias if present.
        nn.init.normal_(self.weight, mean=0, std=self.std)
        if hasattr(self, "bias") and self.bias is not None:
            self.bias.zero_()
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class MoE(nn.Module):
    """Mixture-of-Experts MLP block: linear router + top-k dispatch into gated expert FFNs."""

    def __init__(
        self,
        num_experts: int,
        num_experts_per_tok: int,
        hidden_size: int,
        intermediate_size: int,
        activation_function: ActivationType,
        add_bias: bool,
        std: float,
    ) -> None:
        super().__init__()

        self.num_experts = num_experts
        self.top_k = num_experts_per_tok

        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size

        # Router producing one logit per expert for every token.
        self.router = nn.Linear(in_features=self.hidden_size, out_features=num_experts, bias=False)

        self.activation_function = activation_function

        # Up projection; GLU-style activations need 2x channels (gate/up interleaved).
        self.c_fc = Experts(
            num_experts=num_experts,
            in_features=self.hidden_size,
            out_features=2 * self.intermediate_size if is_glu(activation_function) else self.intermediate_size,
            add_bias=add_bias,
            std=std,
        )

        # Down projection back to the model width.
        self.c_proj = Experts(
            num_experts=num_experts,
            in_features=self.intermediate_size,
            out_features=self.hidden_size,
            add_bias=add_bias,
            std=std,
        )

        # Raw CUDA stream handle passed to the fused sonicmoe kernel.
        # NOTE(review): captured once at construction — assumes the module runs on the
        # stream that was current when it was built; confirm.
        self.stream_id = torch.cuda.current_stream().cuda_stream

    def forward(
        self,
        hidden_states: torch.Tensor,
        kernel_backend_moe: KernelBackendMoE = KernelBackendMoE.sonicmoe,
        is_inference_mode: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Run the MoE block. Returns ``(output, aux_loss)``; ``aux_loss`` is None in inference mode."""
        original_shape = hidden_states.shape

        # hidden_states -> (batch_size, query_length, hidden_size)
        hidden_states = hidden_states.view(-1, self.hidden_size)

        # Fused path: router + top-k softmax + both expert GEMMs in one kernel call.
        if kernel_backend_moe == KernelBackendMoE.sonicmoe and self.num_experts <= 32768:
            hidden_states, router_logits, expert_frequency = moe_TC_softmax_topk_layer(
                hidden_states,
                self.router.weight,
                self.c_fc.weight.permute(1, 2, 0),
                self.c_fc.bias,
                self.c_proj.weight.permute(1, 2, 0),
                self.c_proj.bias,
                self.top_k,
                self.stream_id,
                self.activation_function,
                is_inference_mode or not self.training,
            )
        else:
            # hidden_states -> (total_q, hidden_size)
            router_logits, router_weights, selected_experts = self._compute_routing_weights(hidden_states)

            # router_logits -> (total_q, num_experts)
            # router_weights -> (total_q, top_k)
            # selected_experts -> (total_q, top_k)

            hidden_states, expert_frequency = self._compute_experts(
                hidden_states,
                router_weights,
                selected_experts,
                kernel_backend_moe=kernel_backend_moe,
            )

        hidden_states = hidden_states.view(original_shape)

        # hidden_states -> (batch_size, query_length, hidden_size)

        if is_inference_mode:
            aux_loss = None
        else:
            aux_loss = self._compute_switch_loss(
                logits=router_logits,
                probs=F.softmax(router_logits, dim=-1, dtype=torch.float32),
                expert_frequency=expert_frequency,
            )

        return hidden_states, aux_loss

    # copied from https://github.com/open-lm-engine/lm-engine/blob/1447883df709727839bbbb367ce727fa56962a6a/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py#L432-L455
    # NOTE we don't do all_reduce here for expert frequency for simplicity across data parallel workers
    def _compute_switch_loss(
        self, logits: torch.Tensor, probs: torch.Tensor, expert_frequency: torch.Tensor
    ) -> torch.Tensor:
        """Switch-Transformer style load-balancing loss from router probs and expert counts."""
        logits = logits.view(-1, logits.size(-1))
        probs = probs.view(-1, probs.size(-1))

        num_experts = logits.size(1)
        acc_probs = probs.sum(0)

        expert_frequency = expert_frequency.float()

        # L1-normalize both factors so the loss is independent of token count.
        aux_loss = num_experts * (F.normalize(acc_probs, p=1, dim=0) * F.normalize(expert_frequency, p=1, dim=0)).sum()

        return aux_loss

    def _compute_routing_weights(
        self, hidden_states: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return router logits plus softmax-normalized top-k weights and expert indices."""
        # hidden_states -> (total_q, hidden_size)
        router_logits = self.router(hidden_states)
        # router_logits -> (total_q, num_experts)

        router_weights, selected_experts = self._get_topk(router_logits)

        # router_weights -> (total_q, top_k)
        # selected_experts -> (total_q, top_k)

        # Softmax over the selected top-k logits only (not over all experts).
        router_weights = F.softmax(router_weights.float(), dim=-1)
        router_weights = router_weights.type_as(hidden_states)

        return router_logits, router_weights, selected_experts

    def _compute_experts(
        self,
        hidden_states: torch.Tensor,
        router_weights: torch.Tensor,
        selected_experts: torch.Tensor,
        kernel_backend_moe: KernelBackendMoE,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Dispatch each token to its selected experts and combine the gate-weighted outputs.

        Returns (combined hidden states, per-expert token counts).
        """
        selected_experts = selected_experts.flatten()

        with torch.no_grad():
            # Sort the flattened (token, expert) assignments by expert id; indices need no grads.
            sorted_expert_idxs, sorted_scattered_idxs = selected_experts.sort()

            expert_frequency = selected_experts.bincount(minlength=self.num_experts).to(torch.int32)
            expert_offsets = expert_frequency.cumsum(-1).to(torch.int32)

        # Select the torch activation matching the configured enum.
        act_func = {
            ActivationType.SWIGLU: _swiglu,
            ActivationType.GEGLU: _geglu,
            ActivationType.REGLU: _reglu,
            ActivationType.GELU: _gelu,
            ActivationType.RELU: _relu,
            ActivationType.SILU: _silu,
            ActivationType.RELU_SQ: _relu_sq,
        }[self.activation_function]

        T = hidden_states.size(0)  # total number of tokens

        if kernel_backend_moe == KernelBackendMoE.scattermoe:
            hidden_states = self.c_fc.up_projection_scattermoe_forward(
                input=hidden_states,
                num_experts_per_token=self.top_k,
                sorted_expert_idxs=sorted_expert_idxs,
                sorted_scattered_idxs=sorted_scattered_idxs,
                expert_offsets=expert_offsets,
            )
            hidden_states = act_func(hidden_states)
            hidden_states = self.c_proj.down_projection_scattermoe_forward(
                input=hidden_states,
                num_experts_per_token=1,  # rows were already expanded to top_k copies per token
                sorted_expert_idxs=sorted_expert_idxs,
                sorted_scattered_idxs=sorted_scattered_idxs,
                expert_offsets=expert_offsets,
                gates=router_weights,
            )
        elif kernel_backend_moe == KernelBackendMoE.torch:
            # sort and group input tokens according to expert assignment
            fan_in_index = sorted_scattered_idxs // self.top_k

            # gather the gate values for grouped input tokens
            router_weights = router_weights.flatten()
            batch_gates = router_weights[sorted_scattered_idxs]

            hidden_states = hidden_states[fan_in_index]

            hidden_states = self.c_fc.torch_forward(
                input=hidden_states, expert_frequency=expert_frequency, return_list=True
            )

            hidden_states = [act_func(i) for i in hidden_states]
            hidden_states = self.c_proj.torch_forward(input=hidden_states, expert_frequency=None, return_list=False)

            # Weight each expert output by its gate, then scatter-add back per source token.
            hidden_states = hidden_states * batch_gates.unsqueeze(-1)
            # NOTE(review): the accumulator is float32; index_add with non-fp32
            # hidden_states would raise a dtype mismatch — confirm inputs are fp32 here.
            zeros = torch.zeros((T, self.hidden_size), dtype=torch.float32, device=hidden_states.device)
            hidden_states = zeros.index_add(0, fan_in_index, hidden_states)
        else:
            raise ValueError(f"unexpected kernel_backend_moe ({kernel_backend_moe})")

        return hidden_states, expert_frequency

    def _get_topk(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Top-k values and indices along the last dim, with a cheaper max() path for k == 1."""
        if self.top_k == 1:
            x, indices = x.max(dim=-1, keepdim=True)
        else:
            x, indices = x.topk(self.top_k, dim=-1)

        return x, indices
|
build/torch-cuda/quack/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__version__ = "0.2.5"

import os

# Optionally patch the CUTE DSL toolchain to use a custom ptxas binary,
# selected via the CUTE_DSL_PTXAS_PATH environment variable.
if os.environ.get("CUTE_DSL_PTXAS_PATH", None) is not None:
    from . import cute_dsl_ptxas

    cute_dsl_ptxas.patch()
|
build/torch-cuda/quack/_ops_compat.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .._ops_compat import add_op_namespace_prefix
|
| 2 |
+
|
| 3 |
+
def add_quack_op_namespace_prefix(name: str) -> str:
    """Prefix *name* with ``quack__`` and then apply the shared op-namespace prefix."""
    prefixed = "quack__" + name
    return add_op_namespace_prefix(prefixed)
|
build/torch-cuda/quack/activation.py
ADDED
|
@@ -0,0 +1,524 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Tri Dao.
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
from typing import Tuple
|
| 5 |
+
|
| 6 |
+
import cutlass.cute as cute
|
| 7 |
+
from cutlass import Float32, Boolean, const_expr
|
| 8 |
+
from cutlass.cutlass_dsl import T, dsl_user_op
|
| 9 |
+
from cutlass._mlir.dialects import llvm
|
| 10 |
+
|
| 11 |
+
from . import utils as utils
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
F32_or_F32x2 = Float32 | Tuple[Float32, Float32]
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@dsl_user_op
def tanh(a: float | Float32, *, loc=None, ip=None) -> Float32:
    """Approximate hyperbolic tangent via the PTX ``tanh.approx.f32`` instruction."""
    return Float32(
        llvm.inline_asm(
            T.f32(),
            [Float32(a).ir_value(loc=loc, ip=ip)],
            "tanh.approx.f32 $0, $1;",
            "=f,f",  # one f32 output, one f32 input
            has_side_effects=False,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    )
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@dsl_user_op
def sigmoid(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
    """sigmoid(x) = 0.5 * (1 + tanh(x / 2)), for a scalar Float32 or a packed f32x2 pair."""
    if const_expr(not isinstance(x, tuple)):
        # return 0.5 + 0.5 * cute.math.tanh(0.5 * x, fastmath=True)
        return 0.5 + 0.5 * tanh(0.5 * x)
    else:
        # Packed path: same identity evaluated lane-wise with packed FMUL/FFMA.
        x_half = utils.mul_packed_f32x2((0.5, 0.5), x)
        tanh_x_half = (tanh(x_half[0]), tanh(x_half[1]))
        return utils.fma_packed_f32x2(tanh_x_half, (0.5, 0.5), (0.5, 0.5))
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@dsl_user_op
def dsigmoid_from_output(out: Float32, dout: Float32, *, loc=None, ip=None) -> Float32:
    """Sigmoid backward from the forward *output*: dout * out * (1 - out)."""
    # return dout * out * (1.0 - out)
    # Expanded form avoids the separate (1 - out) subtraction.
    return dout * (out - out * out)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@dsl_user_op
def relu(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
    """relu(x) = max(x, 0), for a scalar Float32 or a packed f32x2 pair."""
    if const_expr(not isinstance(x, tuple)):
        return cute.arch.fmax(x, Float32(0.0))
    else:
        return cute.arch.fmax(x[0], Float32(0.0)), cute.arch.fmax(x[1], Float32(0.0))
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
@dsl_user_op
@cute.jit
def drelu(
    x: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
) -> Tuple[F32_or_F32x2, F32_or_F32x2]:
    """ReLU backward pass that also recomputes the forward output.

    Returns (dx, relu(x)) where dx = dout if x > 0 else 0.
    """
    if const_expr(not isinstance(x, tuple)):
        x_pos = Boolean(x > 0)
        return dout if x_pos else Float32(0.0), cute.arch.fmax(x, Float32(0.0))
    else:
        # Packed path: select per lane, then recompute the forward via relu().
        x0_pos = Boolean(x[0] > 0)
        x1_pos = Boolean(x[1] > 0)
        dx = (dout[0] if x0_pos else Float32(0.0), dout[1] if x1_pos else Float32(0.0))
        return dx, relu(x)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
@dsl_user_op
def relu_sq(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
    """relu_sq(x) = max(x, 0) * x, i.e. relu(x)^2 (x <= 0 contributes 0)."""
    if const_expr(not isinstance(x, tuple)):
        return cute.arch.fmax(x, Float32(0.0)) * x
    else:
        relu_x = (cute.arch.fmax(x[0], Float32(0.0)), cute.arch.fmax(x[1], Float32(0.0)))
        return utils.mul_packed_f32x2(relu_x, x)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
@dsl_user_op
@cute.jit
def drelu_sq(
    x: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
) -> Tuple[F32_or_F32x2, F32_or_F32x2]:
    """
    ReLU squared backward pass: computes gradient w.r.t. x and recomputes forward
    Given: relu_sq_out = max(x, 0) * x, and dout = grad w.r.t. relu_sq_out
    Returns: (dx, relu_sq_out) where:
    - dx = dout * 2 * x if x > 0, else 0
    - relu_sq_out = max(x, 0) * x
    """
    if const_expr(not isinstance(x, tuple)):
        relu_x = relu(x)
        relu_sq_out = relu_x * x
        # Derivative: d/dx[max(x,0) * x] = 2*x if x > 0, else 0
        # (relu_x == x when x > 0, so dout * relu_x covers both cases)
        dx = 2.0 * (dout * relu_x)
        return dx, relu_sq_out
    else:
        # Packed f32x2 path, mirroring the scalar computation lane-wise.
        relu_x = relu(x)
        relu_sq_out = utils.mul_packed_f32x2(relu_x, x)
        dx = utils.mul_packed_f32x2((2.0, 2.0), utils.mul_packed_f32x2(dout, relu_x))
        return dx, relu_sq_out
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
@dsl_user_op
def gelu_tanh_approx(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
    """
    gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
            = 0.5 * x * (1 + tanh(x * (0.797885 + 0.0356774 * x * x)))
    """
    sqrt_2_over_pi = math.sqrt(2 / math.pi)  # ~0.797885
    sqrt_2_over_pi_coeff = 0.044715 * sqrt_2_over_pi  # ~0.0356774
    if const_expr(not isinstance(x, tuple)):
        return 0.5 * (
            x
            # Currently cute.math.tanh(x, fastmath=True) generates very slow code
            # * (1 + cute.math.tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * (x * x)), fastmath=True))
            * (1.0 + tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * (x * x))))
        )
    else:
        # Packed f32x2 path: same polynomial evaluated with packed FMUL/FFMA.
        x_sq = utils.mul_packed_f32x2(x, x)
        x_sq_scaled = utils.fma_packed_f32x2(
            x_sq, (sqrt_2_over_pi_coeff, sqrt_2_over_pi_coeff), (sqrt_2_over_pi, sqrt_2_over_pi)
        )
        z = utils.mul_packed_f32x2(x, x_sq_scaled)
        tanh_z = (tanh(z[0]), tanh(z[1]))
        # x * (1 + tanh(z)) computed as fma(tanh(z), x, x)
        x_tanh_z = utils.fma_packed_f32x2(tanh_z, x, x)
        return utils.mul_packed_f32x2((0.5, 0.5), x_tanh_z)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
@dsl_user_op
def dgelu_tanh_approx(
    x: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
) -> Tuple[F32_or_F32x2, F32_or_F32x2]:
    """
    GELU tanh approximation backward pass: computes gradient w.r.t. x and recomputes forward
    Given: gelu_out = 0.5 * x * (1 + tanh(x * (c1 + c2 * x^2))), and dout = grad w.r.t. gelu_out
    Returns: (dx, gelu_out)

    Derivative uses the chain rule:
    d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
    where z = x * (c1 + c2 * x^2), dz/dx = c1 + 3 * c2 * x^2
    and sech^2(z) = 1 - tanh^2(z)
    """
    sqrt_2_over_pi = math.sqrt(2 / math.pi)  # c1 ~0.797885
    sqrt_2_over_pi_coeff = 0.044715 * sqrt_2_over_pi  # c2 ~0.0356774
    sqrt_2_over_pi_coeff_3 = 3.0 * sqrt_2_over_pi_coeff  # c3 ~0.01070322

    if const_expr(not isinstance(x, tuple)):
        # Compute z = x * (c1 + c2 * x^2)
        x_sq = x * x
        # tanh_z = cute.math.tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * x_sq), fastmath=True)
        tanh_z = tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * x_sq))
        half_tanh_z_plus_one = 0.5 + 0.5 * tanh_z
        gelu_out = x * half_tanh_z_plus_one

        # Compute gradient
        # sech^2(z) = 1 - tanh^2(z)
        sech2_z = 1 - tanh_z * tanh_z
        # dz/dx = c1 + 3 * c2 * x^2
        dz_dx = sqrt_2_over_pi + sqrt_2_over_pi_coeff_3 * x_sq
        # d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
        dgelu = half_tanh_z_plus_one + x * (0.5 * (sech2_z * dz_dx))

        dx = dout * dgelu
        return dx, gelu_out
    else:
        # Packed f32x2 path, mirroring the scalar computation lane-wise.
        # Compute z = x * (c1 + c2 * x^2)
        x_sq = utils.mul_packed_f32x2(x, x)
        x_sq_scaled = utils.fma_packed_f32x2(
            x_sq, (sqrt_2_over_pi_coeff, sqrt_2_over_pi_coeff), (sqrt_2_over_pi, sqrt_2_over_pi)
        )
        z = utils.mul_packed_f32x2(x, x_sq_scaled)
        tanh_z = (tanh(z[0]), tanh(z[1]))
        half_tanh_z_plus_one = utils.fma_packed_f32x2(tanh_z, (0.5, 0.5), (0.5, 0.5))
        gelu_out = utils.mul_packed_f32x2(x, half_tanh_z_plus_one)

        # Compute gradient
        # sech^2(z) = 1 - tanh^2(z)
        sech2_z = utils.fma_packed_f32x2(tanh_z, (-tanh_z[0], -tanh_z[1]), (1.0, 1.0))
        # dz/dx = c1 + 3 * c2 * x^2
        dz_dx = utils.fma_packed_f32x2(
            x_sq, (sqrt_2_over_pi_coeff_3, sqrt_2_over_pi_coeff_3), (sqrt_2_over_pi, sqrt_2_over_pi)
        )
        # d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
        sech2_dz_dx = utils.mul_packed_f32x2(sech2_z, dz_dx)
        x_sech2_dz_dx = utils.mul_packed_f32x2(x, sech2_dz_dx)
        dgelu = utils.fma_packed_f32x2(x_sech2_dz_dx, (0.5, 0.5), half_tanh_z_plus_one)

        dx = utils.mul_packed_f32x2(dout, dgelu)
        return dx, gelu_out
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
@dsl_user_op
@cute.jit
def softplus(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
    """softplus(x) = log(1 + exp(x)); switches to the identity for x > 20 to avoid overflow."""
    if const_expr(not isinstance(x, tuple)):
        use_linear = Boolean(x > 20.0)
        return (
            cute.math.log(Float32(cute.math.exp(x, fastmath=True)) + 1.0, fastmath=True)
            if not use_linear
            else x
        )
    else:
        # Packed path evaluated via base-2 identities:
        # softplus(x) = ln(2) * log2(1 + 2^(x * log2(e)))
        # NOTE(review): this pre-scales by log2(e) before cute.math.exp, which assumes the
        # fastmath exp lowers to the base-2 ex2 instruction — confirm against cute.math docs.
        log2_e = math.log2(math.e)
        x_log2e = utils.mul_packed_f32x2(x, (log2_e, log2_e))
        x_exp = (cute.math.exp(x_log2e[0], fastmath=True), cute.math.exp(x_log2e[1], fastmath=True))
        x_exp_p1 = utils.add_packed_f32x2(x_exp, (1.0, 1.0))
        log_x_exp_p1 = (
            cute.math.log2(x_exp_p1[0], fastmath=True),
            cute.math.log2(x_exp_p1[1], fastmath=True),
        )
        ln2 = math.log(2.0)
        softplus_x = utils.mul_packed_f32x2(log_x_exp_p1, (ln2, ln2))
        use_linear_0 = Boolean(x[0] > 20.0)
        use_linear_1 = Boolean(x[1] > 20.0)
        return (
            softplus_x[0] if not use_linear_0 else x[0],
            softplus_x[1] if not use_linear_1 else x[1],
        )
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
@dsl_user_op
@cute.jit
def dsoftplus_from_output(out: Float32, dout: Float32, *, loc=None, ip=None) -> Float32:
    """Softplus backward from the forward *output*: dout * sigmoid(x) = dout * (1 - exp(-out))."""
    # Threshold on the output: for x > 20, softplus(x) ~= x, so this matches the
    # forward's linear-regime cutoff.
    use_linear = Boolean(out > 20.0)
    # return dout * (1.0 - cute.math.exp(-out, fastmath=True)) if not use_linear else dout
    dx = dout - dout * cute.math.exp(-out, fastmath=True)
    return dx if not use_linear else dout
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
@dsl_user_op
def silu(x: F32_or_F32x2, *, already_halved: bool = False, loc=None, ip=None) -> F32_or_F32x2:
    """
    silu(x) = x * sigmoid(x) = x * (1 + tanh(x / 2)) / 2 = (0.5 * x) * tanh(0.5 * x) + (0.5 * x)
    This compiles down to 3 SASS instructions: FMUL to get 0.5 * x, MUFU.TANH, and FFMA.
    If already_halved, the input is assumed to already be 0.5 * x (skips the FMUL).
    """
    if const_expr(not isinstance(x, tuple)):
        x_half = 0.5 * x if const_expr(not already_halved) else x
        # return x_half * cute.math.tanh(x_half, fastmath=True) + x_half
        return x_half * tanh(x_half) + x_half
    else:
        # Packed f32x2 path, mirroring the scalar computation lane-wise.
        x_half = utils.mul_packed_f32x2((0.5, 0.5), x) if const_expr(not already_halved) else x
        tanh_x_half = (tanh(x_half[0]), tanh(x_half[1]))
        return utils.fma_packed_f32x2(x_half, tanh_x_half, x_half)
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
@dsl_user_op
def swiglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
    """swiglu(x, y) = silu(x) * y, where x is the gate and y the up projection."""
    if const_expr(not isinstance(x, tuple)):
        return silu(x) * y
    else:
        return utils.mul_packed_f32x2(silu(x), y)
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
@dsl_user_op
def dswiglu(
    x: F32_or_F32x2,
    y: F32_or_F32x2,
    dout: F32_or_F32x2,
    *,
    already_halved: bool = False,
    loc=None,
    ip=None,
) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
    """
    SwiGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
    Given: swiglu_out = silu(x) * y, and dout = grad w.r.t. swiglu_out
    Returns: (dx, dy, swiglu_out) where dx = dout * y * d_silu(x), dy = dout * silu(x)

    d_silu(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x)))

    If already_halved, x is assumed to already be 0.5 * x.

    This has been optimized to use fewer instructions (i.e. we expand things out
    to use FFMA instead of FADD and FMUL).
    """
    if const_expr(not isinstance(x, tuple)):
        # Compute sigmoid(x) using tanh: sigmoid(x) = 0.5 * (1 + tanh(0.5 * x))
        # FMUL, MUFU.TANH, then FFMA
        if const_expr(not already_halved):
            sigmoid_x = sigmoid(x)
            silu_x = x * sigmoid_x  # FMUL
        else:
            tanh_x = tanh(x)  # MUFU.TANH
            sigmoid_x = 0.5 * tanh_x + 0.5  # FFMA
            silu_x = x * tanh_x + x  # FFMA
        silu_x_dout = silu_x * dout  # FMUL
        # d_silu(x) * dout
        # = sigmoid_x * (1 + x * (1 - sigmoid_x)) * dout
        # = (sigmoid_x + sigmoid_x * x * (1 - sigmoid_x)) * dout
        # = (sigmoid_x + silu_x * (1 - sigmoid_x)) * dout
        # = (sigmoid_x + silu_x - silu_x * sigmoid_x) * dout
        # = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x * dout
        d_silu_x_dout = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x_dout  # FFMA, FFMA
        dx = d_silu_x_dout * y  # FMUL
        dy = silu_x_dout
        swiglu_out = silu_x * y  # FMUL
        # Overall it's 1 MUFU.TANH, 5 FMUL, 3 FFMA
        return dx, dy, swiglu_out
    else:
        # Packed f32x2 path, mirroring the scalar computation lane-wise.
        # Compute sigmoid(x) and silu(x)
        if const_expr(not already_halved):
            sigmoid_x = sigmoid(x)
            silu_x = utils.mul_packed_f32x2(x, sigmoid_x)
        else:
            tanh_x = (tanh(x[0]), tanh(x[1]))
            sigmoid_x = utils.fma_packed_f32x2(tanh_x, (0.5, 0.5), (0.5, 0.5))
            silu_x = utils.fma_packed_f32x2(x, tanh_x, x)
        silu_x_dout = utils.mul_packed_f32x2(silu_x, dout)
        # d_silu(x) * dout = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x * dout
        sigmoid_x_minus_silu_x_sigmoid_x = utils.fma_packed_f32x2(
            sigmoid_x, (-silu_x[0], -silu_x[1]), sigmoid_x
        )
        d_silu_x_dout = utils.fma_packed_f32x2(sigmoid_x_minus_silu_x_sigmoid_x, dout, silu_x_dout)
        dx = utils.mul_packed_f32x2(d_silu_x_dout, y)
        dy = silu_x_dout
        swiglu_out = utils.mul_packed_f32x2(silu_x, y)
        return dx, dy, swiglu_out
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
@dsl_user_op
def swiglu_oai(
    x: F32_or_F32x2, y: F32_or_F32x2, alpha: float = 1.702, *, loc=None, ip=None
) -> F32_or_F32x2:
    """The swiglu variant used in gpt-oss, which has a scaling factor on x and bias of 1 to y.
    https://github.com/openai/gpt-oss/blob/7be9334950053a888e24887a57dac797a17d6e00/gpt_oss/torch/model.py#L249
    Computes: x * sigmoid(alpha * x) * (y + 1)
    Compile down to FMUL, FMUL, TANH, FFMA, FFMA

    :param x: gate input, either a scalar f32 or a 2-tuple treated as a packed f32x2 pair
    :param y: up-projection input, same form as ``x``
    :param alpha: scaling factor applied to x inside the sigmoid (gpt-oss uses 1.702)
    :return: same form as the inputs (scalar or packed pair)
    """
    # Compute sigmoid(alpha * x) using tanh: sigmoid(z) = 0.5 * (1 + tanh(z/2))
    if const_expr(not isinstance(x, tuple)):
        # Scalar path. Factoring through x_half = x/2 lets the whole expression
        # silu(x) = x_half * tanh(alpha * x_half) + x_half lower to one TANH plus FMA ops.
        x_half = 0.5 * x
        # silu_x = x_half * cute.math.tanh(alpha * x_half, fastmath=True) + x_half
        silu_x = x_half * tanh(alpha * x_half) + x_half
        # silu_x * (y + 1) rewritten as one FFMA: silu_x * y + silu_x
        return silu_x * y + silu_x
    else:
        # Packed f32x2 path: same math element-wise via packed mul/fma helpers,
        # except tanh which is applied per lane.
        x_half = utils.mul_packed_f32x2((0.5, 0.5), x)
        alpha_x_half = utils.mul_packed_f32x2((alpha, alpha), x_half)
        tanh_alpha_x_half = (tanh(alpha_x_half[0]), tanh(alpha_x_half[1]))
        silu_x = utils.fma_packed_f32x2(x_half, tanh_alpha_x_half, x_half)
        return utils.fma_packed_f32x2(silu_x, y, silu_x)
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
@dsl_user_op
def dswiglu_oai(
    x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, alpha: float = 1.702, *, loc=None, ip=None
) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
    """
    Swiglu OAI backward pass: computes gradients w.r.t. x and y.
    Given: swiglu_oai_out = x * sigmoid(alpha * x) * (y + 1), and dout = grad w.r.t. swiglu_oai_out
    Returns: (dx, dy, swiglu_oai_out) — the forward output is recomputed here so the
    caller gets it for free alongside the gradients.

    Derivative of x * sigmoid(alpha * x) w.r.t. x:
    d/dx[x * sigmoid(alpha * x)] = sigmoid(alpha * x) + alpha * x * sigmoid(alpha * x) * (1 - sigmoid(alpha * x))

    Inputs are either scalar f32 or 2-tuples treated as packed f32x2 pairs; all three
    returned values have the same form as the inputs.
    """
    if const_expr(not isinstance(x, tuple)):
        # Compute sigmoid(alpha * x) using tanh: sigmoid(z) = 0.5 * (1 + tanh(z/2))
        alpha_x_half = (0.5 * alpha) * x  # FMUL
        # MUFU.TANH, then FFMA
        # sigmoid_alpha_x = 0.5 + 0.5 * cute.math.tanh(alpha_x_half, fastmath=True)
        sigmoid_alpha_x = 0.5 + 0.5 * tanh(alpha_x_half)
        silu_x = x * sigmoid_alpha_x  # FMUL
        silu_x_dout = silu_x * dout  # FMUL
        # FFMA, FFMA, FMUL — note x * sigmoid * (1 - sigmoid) = silu_x - silu_x * sigmoid
        d_silu_x_dout = (sigmoid_alpha_x + alpha * (silu_x - silu_x * sigmoid_alpha_x)) * dout
        dx = d_silu_x_dout * y + d_silu_x_dout  # FFMA, instead of multiply by y + 1
        dy = silu_x_dout
        swiglu_out = silu_x * y + silu_x  # FFMA, instead of multiply by y + 1
        # Overall it's 1 MUFU.TANH, 4 FMUL, 5 FFMA
        return dx, dy, swiglu_out
    else:
        # Packed f32x2 path: same algebra, per-lane tanh, packed mul/fma everywhere else.
        # Compute sigmoid(alpha * x)
        alpha_x_half = utils.mul_packed_f32x2(((0.5 * alpha), (0.5 * alpha)), x)
        tanh_alpha_x_half = (tanh(alpha_x_half[0]), tanh(alpha_x_half[1]))
        sigmoid_alpha_x = utils.fma_packed_f32x2(tanh_alpha_x_half, (0.5, 0.5), (0.5, 0.5))
        silu_x = utils.mul_packed_f32x2(x, sigmoid_alpha_x)
        silu_x_dout = utils.mul_packed_f32x2(silu_x, dout)
        # d_silu_x_dout = (sigmoid_alpha_x + alpha * (silu_x - silu_x * sigmoid_alpha_x)) * dout
        # silu_x - silu_x * sigmoid expressed as one packed FMA with negated sigmoid lanes
        silu_x_minus_product = utils.fma_packed_f32x2(
            silu_x, (-sigmoid_alpha_x[0], -sigmoid_alpha_x[1]), silu_x
        )
        sigmoid_plus_alpha_diff = utils.fma_packed_f32x2(
            (alpha, alpha), silu_x_minus_product, sigmoid_alpha_x
        )
        d_silu_x_dout = utils.mul_packed_f32x2(sigmoid_plus_alpha_diff, dout)
        dx = utils.fma_packed_f32x2(d_silu_x_dout, y, d_silu_x_dout)
        dy = silu_x_dout
        swiglu_out = utils.fma_packed_f32x2(silu_x, y, silu_x)
        return dx, dy, swiglu_out
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
@dsl_user_op
def glu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
    """GLU: Gated Linear Unit
    glu(x, y) = sigmoid(x) * y
    Using tanh to compute sigmoid: sigmoid(x) = 0.5 * (1 + tanh(x/2))

    :param x: gate input, scalar f32 or packed f32x2 pair (2-tuple)
    :param y: up-projection input, same form as ``x``
    :return: same form as the inputs
    """
    if const_expr(not isinstance(x, tuple)):
        sigmoid_x = sigmoid(x)  # FMUL, MUFU.TANH, then FFMA
        return sigmoid_x * y  # FMUL
    else:
        # Packed path: sigmoid handles the pair, then one packed multiply.
        sigmoid_x = sigmoid(x)
        return utils.mul_packed_f32x2(sigmoid_x, y)
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
@dsl_user_op
def dglu(
    x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
    """
    GLU backward pass: computes gradients w.r.t. x (gate) and y (up projection).
    Given: glu_out = sigmoid(x) * y, and dout = grad w.r.t. glu_out
    Returns: (dx, dy, glu_out) where:
    - dx = dout * y * sigmoid(x) * (1 - sigmoid(x))
    - dy = dout * sigmoid(x)
    - glu_out = sigmoid(x) * y

    The forward output is recomputed and returned alongside the gradients.
    Inputs are scalar f32 or packed f32x2 pairs (2-tuples); outputs match.
    """
    if const_expr(not isinstance(x, tuple)):
        # Compute sigmoid(x) using tanh: sigmoid(x) = 0.5 * (1 + tanh(x/2))
        sigmoid_x = sigmoid(x)  # FMUL, MUFU.TANH, then FFMA
        sigmoid_x_dout = sigmoid_x * dout  # FMUL
        glu_out = sigmoid_x * y  # FMUL
        # dx = y * sigmoid(x) * (1 - sigmoid(x)) * dout
        #    = y * (1 - sigmoid(x)) * sigmoid_x_dout
        #    = (y - y * sigmoid(x)) * sigmoid_x_dout
        #    = (y - glu_out) * sigmoid_x_dout
        dx = (y - glu_out) * sigmoid_x_dout  # FADD, FMUL
        dy = sigmoid_x_dout
        # Total: 1 MUFU.TANH, 4 FMUL, 1 FADD, 1 FFMA
        return dx, dy, glu_out
    else:
        # Packed f32x2 path: identical algebra using packed helpers.
        sigmoid_x = sigmoid(x)
        sigmoid_x_dout = utils.mul_packed_f32x2(sigmoid_x, dout)
        glu_out = utils.mul_packed_f32x2(sigmoid_x, y)
        # dx = (y - glu_out) * sigmoid_x_dout
        y_minus_glu_out = utils.sub_packed_f32x2(y, glu_out)
        dx = utils.mul_packed_f32x2(y_minus_glu_out, sigmoid_x_dout)
        dy = sigmoid_x_dout
        return dx, dy, glu_out
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
@dsl_user_op
def reglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
    """ReGLU: ReLU Gated Linear Unit
    reglu(x, y) = relu(x) * y = max(x, 0) * y

    :param x: gate input, scalar f32 or packed f32x2 pair (2-tuple)
    :param y: up-projection input, same form as ``x``
    :return: same form as the inputs
    """
    if const_expr(not isinstance(x, tuple)):
        # Scalar path: fmax against 0 implements relu directly.
        return cute.arch.fmax(x, Float32(0.0)) * y
    else:
        # Packed path: relu handles the pair, then one packed multiply.
        relu_x = relu(x)
        return utils.mul_packed_f32x2(relu_x, y)
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
@dsl_user_op
@cute.jit
def dreglu(
    x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
    """
    ReGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection).
    Given: reglu_out = relu(x) * y, and dout = grad w.r.t. reglu_out
    Returns: (dx, dy, reglu_out) where:
    - dx = dout * y if x > 0, else 0
    - dy = dout * relu(x)
    - reglu_out = relu(x) * y

    The forward output is recomputed and returned alongside the gradients.
    Inputs are scalar f32 or packed f32x2 pairs (2-tuples); outputs match.
    """
    if const_expr(not isinstance(x, tuple)):
        # relu'(x) is the indicator x > 0; select dout*y or 0 accordingly.
        x_pos = Boolean(x > 0)
        relu_x = cute.arch.fmax(x, Float32(0.0))
        dx = (dout * y) if x_pos else Float32(0.0)
        dy = dout * relu_x
        reglu_out = relu_x * y
        return dx, dy, reglu_out
    else:
        # Packed path: per-lane sign tests, packed multiplies for the rest.
        x0_pos = Boolean(x[0] > 0)
        x1_pos = Boolean(x[1] > 0)
        relu_x = relu(x)
        dout_y = utils.mul_packed_f32x2(dout, y)
        dx = ((dout_y[0] if x0_pos else Float32(0.0)), (dout_y[1] if x1_pos else Float32(0.0)))
        dy = utils.mul_packed_f32x2(dout, relu_x)
        reglu_out = utils.mul_packed_f32x2(relu_x, y)
        return dx, dy, reglu_out
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
@dsl_user_op
def geglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
    """GeGLU: GELU Gated Linear Unit
    geglu(x, y) = gelu(x) * y
    Uses the tanh approximation of GELU.

    :param x: gate input, scalar f32 or packed f32x2 pair (2-tuple)
    :param y: up-projection input, same form as ``x``
    :return: same form as the inputs
    """
    if const_expr(not isinstance(x, tuple)):
        return gelu_tanh_approx(x) * y
    else:
        # Packed path: gelu_tanh_approx handles the pair, then one packed multiply.
        return utils.mul_packed_f32x2(gelu_tanh_approx(x), y)
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
@dsl_user_op
def dgeglu(
    x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
    """
    GeGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection).
    Given: geglu_out = gelu(x) * y, and dout = grad w.r.t. geglu_out
    Returns: (dx, dy, geglu_out) where:
    - dx = dout * y * d_gelu(x)
    - dy = dout * gelu(x)
    - geglu_out = gelu(x) * y

    Delegates the GELU derivative (tanh approximation) to dgelu_tanh_approx, which
    returns both d_gelu(x) * dout and gelu(x) so the forward value is reused.
    Inputs are scalar f32 or packed f32x2 pairs (2-tuples); outputs match.
    """
    if const_expr(not isinstance(x, tuple)):
        # Reuse dgelu_tanh_approx to compute d_gelu(x) * dout and gelu(x)
        dgelu_x_dout, gelu_x = dgelu_tanh_approx(x, dout)
        # Compute gradients for geglu
        dx = dgelu_x_dout * y
        dy = gelu_x * dout
        geglu_out = gelu_x * y
        return dx, dy, geglu_out
    else:
        # Reuse dgelu_tanh_approx to compute d_gelu(x) * dout and gelu(x)
        dgelu_x_dout, gelu_x = dgelu_tanh_approx(x, dout)
        # Compute gradients for geglu (packed f32x2 path)
        dx = utils.mul_packed_f32x2(dgelu_x_dout, y)
        dy = utils.mul_packed_f32x2(gelu_x, dout)
        geglu_out = utils.mul_packed_f32x2(gelu_x, y)
        return dx, dy, geglu_out
|
build/torch-cuda/quack/autotuner.py
ADDED
|
@@ -0,0 +1,369 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from https://github.com/triton-lang/triton/blob/main/python/triton/runtime/autotuner.py
|
| 2 |
+
# Copyright (C) 2025, Tri Dao.
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import builtins
|
| 6 |
+
import os
|
| 7 |
+
import time
|
| 8 |
+
import inspect
|
| 9 |
+
import base64
|
| 10 |
+
import hashlib
|
| 11 |
+
import json
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from functools import cached_property, partial
|
| 14 |
+
from typing import Dict, Tuple, List, Optional, Any
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
from torch import Tensor
|
| 18 |
+
|
| 19 |
+
import triton
|
| 20 |
+
|
| 21 |
+
from . import __version__
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
PACKAGE_NAME = "quack"
|
| 25 |
+
VERSION = __version__
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def get_home_dir():
    """Return the base directory for package state: ``$QUACK_HOME`` if set, else the user's home."""
    env_var = f"{PACKAGE_NAME.upper()}_HOME"
    override = os.getenv(env_var)
    return override if override is not None else Path.home()
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def default_cache_dir():
    """Return the default on-disk cache location, ``<home>/.quack/cache``."""
    hidden_dir = f".{PACKAGE_NAME}"
    return os.path.join(get_home_dir(), hidden_dir, "cache")
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class FileCacheManager(triton.runtime.cache.FileCacheManager):
    """Triton file cache manager redirected to this package's own cache directory.

    After the base-class init, the cache directory is overridden to
    ``$QUACK_CACHE_DIR`` (if set and non-blank) or the package default, with the
    cache key appended as a subdirectory.
    """

    def __init__(self, key):
        # Base class sets self.key (and its own cache_dir, which we override below).
        super().__init__(key)
        self.cache_dir = (
            os.getenv(f"{PACKAGE_NAME.upper()}_CACHE_DIR", "").strip() or default_cache_dir()
        )
        if self.cache_dir:
            # Namespace the cache by key; lock file lives alongside the entries.
            self.cache_dir = os.path.join(self.cache_dir, self.key)
            self.lock_path = os.path.join(self.cache_dir, "lock")
            os.makedirs(self.cache_dir, exist_ok=True)
        else:
            # NOTE(review): with default_cache_dir() as fallback this branch looks
            # unreachable in practice — kept as a defensive guard.
            raise RuntimeError("Could not create or locate cache dir")
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _base32(key):
|
| 51 |
+
# Assume key is a hex string.
|
| 52 |
+
return base64.b32encode(bytes.fromhex(key)).decode("utf-8").rstrip("=")
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class Autotuner:
    """Runtime autotuner: benchmarks a function over a set of AutotuneConfigs and
    caches the fastest config per input signature (in memory, optionally on disk).

    Adapted from triton's autotuner; the wrapped ``fn`` is called as
    ``fn(*args, **kwargs, **config.all_kwargs())``.
    """

    def __init__(
        self,
        fn,
        key,
        configs,
        restore_value=None,
        prune_configs_by: Optional[Dict] = None,
        do_bench=None,
        cache_results=False,
    ):
        """
        :param fn: the function to tune; its signature supplies the argument names.
        :param key: argument names whose values participate in the cache key.
        :param configs: candidate AutotuneConfig objects (empty -> one default config).
        :param restore_value: argument names whose tensors are cloned before and
            restored after each benchmark run (for in-place kernels).
        :param prune_configs_by: a dict of functions that are used to prune configs, fields:
            'perf_model': performance model used to predicate running time with different configs, returns running time
            'top_k': number of configs to bench
            'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs.
        :param do_bench: benchmark callable; defaults to triton.testing.do_bench.
        :param cache_results: persist timings to disk (also enabled by
            QUACK_CACHE_AUTOTUNING=1).
        """
        if not configs:
            self.configs = [AutotuneConfig()]
        else:
            self.configs = configs
        signature = inspect.signature(fn)
        self.keys = key
        # Maps tuple cache key -> best AutotuneConfig found so far.
        self.cache: Dict[Tuple, AutotuneConfig] = {}
        self.arg_names = list(signature.parameters.keys())
        self.cache_results = (
            cache_results or os.getenv(f"{PACKAGE_NAME.upper()}_CACHE_AUTOTUNING", None) == "1"
        )

        self.restore_value = []
        if restore_value is not None:
            self.restore_value = list(restore_value)

        if len(self.restore_value) > 0:

            def _pre_hook(kwargs):
                # Snapshot tensors that the kernel mutates in place.
                self.restore_copies = {name: kwargs[name].clone() for name in self.restore_value}

            self.pre_hook = _pre_hook
        else:
            self.pre_hook = None

        if len(self.restore_value) > 0:

            def _post_hook(kwargs, exception):
                # Restore the snapshots taken by _pre_hook, then drop them.
                for name in self.restore_value:
                    kwargs[name].copy_(self.restore_copies[name])
                self.restore_copies = {}

            self.post_hook = _post_hook
        else:
            self.post_hook = None

        self.perf_model = None
        self.configs_top_k = 1.0
        self.early_config_prune = None
        if prune_configs_by:
            self.perf_model = prune_configs_by.get("perf_model", self.perf_model)
            self.configs_top_k = prune_configs_by.get("top_k", self.configs_top_k)
            self.early_config_prune = prune_configs_by.get(
                "early_config_prune", self.early_config_prune
            )

        self.fn = fn
        self._do_bench = do_bench

    @cached_property
    def do_bench(self):
        # Lazily resolve the benchmark function so importing this module does not
        # require triton benchmarking to be usable.
        if self._do_bench is None:
            return partial(triton.testing.do_bench, warmup=5, rep=25)
        return self._do_bench

    def _bench(self, *args, config, **meta):
        """Benchmark one config; returns timing quantiles, or infinities on failure."""
        verbose = os.environ.get(f"{PACKAGE_NAME.upper()}_PRINT_AUTOTUNING", None) == "1"
        if verbose:
            print(f"Autotuning kernel {self.fn.__name__} with config {config}")

        # check for conflicts, i.e. meta-parameters both provided
        # as kwargs and by the autotuner
        conflicts = meta.keys() & config.kwargs.keys()
        if conflicts:
            raise ValueError(
                f"Conflicting meta-parameters: {', '.join(conflicts)}."
                " Make sure that you don't re-define auto-tuned symbols."
            )
        # augment meta-parameters with tunable ones
        current = dict(meta, **config.all_kwargs())
        full_nargs = {**self.nargs, **current}

        def kernel_call():
            if self.pre_hook is not None:
                self.pre_hook(full_nargs)
            try:
                self.fn.__call__(
                    *args,
                    **current,
                )
            except Exception as e:
                try:
                    if self.post_hook is not None:
                        self.post_hook(full_nargs, exception=e)
                finally:
                    # Throw exception raised by `self.fn.run`
                    raise

            if self.post_hook is not None:
                self.post_hook(full_nargs, exception=None)

        try:
            return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8))
        except Exception as e:
            if verbose:
                print(f"Autotuning failed with {e}")
            # Infinite timings make a failing config lose every min() comparison.
            return [float("inf"), float("inf"), float("inf")]

    @torch.compiler.disable
    def check_disk_cache(self, tuning_key, configs, bench_fn):
        """Load timings from the disk cache if present; otherwise run bench_fn and persist them."""
        if not tuning_key:
            bench_fn()
            return

        fn = self.fn
        config_str_list = [str(c) for c in configs]
        assert len(config_str_list) == len(set(config_str_list)), "Config strings must be unique"
        # Cache key covers package version, the tuning key, and every candidate config.
        cache_key = [VERSION, str(tuning_key)] + config_str_list
        cache_key = hashlib.sha256("-".join(cache_key).encode("utf-8")).hexdigest()
        cache = FileCacheManager(_base32(cache_key))
        file_name = f"{fn.__name__[:150]}.autotune.json"
        path = cache.get_file(file_name)
        # There's an environment variable to force cache update
        if path and not os.environ.get(f"{PACKAGE_NAME.upper()}_FORCE_CACHE_UPDATE", False):
            str2config = {s: c for s, c in zip(config_str_list, configs)}
            with open(path, "r") as cached_configs:
                timings = json.load(cached_configs)["configs_timings"]
                # Stored keys are config strings; map them back to config objects.
                timings = {str2config[config]: timing for config, timing in timings}
            self.cache[tuning_key] = builtins.min(timings, key=timings.get)
            self.configs_timings = timings
            self.bench_time = 0
            return

        bench_fn()
        cache.put(
            json.dumps(
                {
                    "key": tuning_key,
                    "configs_timings": [
                        (str(config), timings) for config, timings in self.configs_timings.items()
                    ],
                }
            ),
            file_name,
            binary=False,
        )

    def __call__(self, *args, **kwargs):
        """Select (benchmarking if needed) the best config, then call fn with it."""
        self.nargs = dict(zip(self.arg_names, args))
        used_cached_result = True
        if len(self.configs) > 1:
            all_args = {**self.nargs, **kwargs}
            _args = {k: v for (k, v) in all_args.items() if k in self.arg_names}
            # Need "str" to make it json-serializable
            key = [str(_args[key]) for key in self.keys if key in _args]
            for _, arg in _args.items():
                if isinstance(arg, Tensor):
                    key.append(str(arg.shape))
                    # If stride != 0, 1, we just cache it as 2
                    key.append(str([s if s in {0, 1} else 2 for s in arg.stride()]))
                    key.append(str(arg.dtype))
            key = tuple(key)
            if key not in self.cache:
                used_cached_result = False
                pruned_configs = self.prune_configs(kwargs)

                @torch.compiler.disable  # Don't want any tracing here
                def benchmark():
                    bench_start = time.time()
                    timings = {
                        config: self._bench(*args, config=config, **kwargs)
                        for config in pruned_configs
                    }
                    bench_end = time.time()
                    if os.getenv(f"{PACKAGE_NAME.upper()}_PRINT_AUTOTUNING", None) == "1":
                        for config, time_ in timings.items():
                            print(f"[{config}] -> {time_[0]:.3f}ms")
                    self.bench_time = bench_end - bench_start
                    self.cache[key] = builtins.min(timings, key=timings.get)
                    self.configs_timings = timings

                if self.cache_results:
                    self.check_disk_cache(key, pruned_configs, benchmark)
                else:
                    benchmark()

            config = self.cache[key]
        else:
            # Single candidate: nothing to tune.
            config = self.configs[0]
        self.best_config = config
        if (
            os.getenv(f"{PACKAGE_NAME.upper()}_PRINT_AUTOTUNING", None) == "1"
            and not used_cached_result
        ):
            print(
                f"{PACKAGE_NAME} autotuning for function {self.fn.__name__} finished after "
                f"{self.bench_time:.2f}s; best config selected: {self.best_config};"
            )
        ret = self.fn.__call__(
            *args,
            **kwargs,
            **config.all_kwargs(),
        )
        self.nargs = None
        return ret

    def prune_configs(self, kwargs: Dict) -> List[Any]:
        """Narrow the candidate configs via early pruning and/or a perf model's top-k."""
        pruned_configs = self.configs
        if self.early_config_prune:
            pruned_configs = self.early_config_prune(self.configs, self.nargs, **kwargs)
        if self.perf_model:
            top_k = self.configs_top_k
            if isinstance(top_k, float) and top_k <= 1.0:
                # Fractional top_k means "this fraction of all configs".
                top_k = int(len(self.configs) * top_k)
            elif not isinstance(top_k, int):
                # Slice index must be an integer
                raise TypeError(
                    "Error while pruning configs, top_k must be either 1) a float <= 1.0 or 2) an int"
                )

            if len(pruned_configs) > top_k:
                est_timing = {
                    config: self.perf_model(
                        **self.nargs,
                        **kwargs,
                        **config.all_kwargs(),
                    )
                    for config in pruned_configs
                }
                # Keep the top_k configs the model predicts to be fastest.
                pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
        return pruned_configs
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
class AutotuneConfig:
    """
    An object that represents a possible kernel configuration for the auto-tuner to try.

    Instances are hashable and comparable so they can serve as dict keys in the
    autotuner's timing tables.

    :ivar kwargs: a dictionary of meta-parameters to pass to the kernel as keyword arguments.
    :type kwargs: dict[str, Any]
    """

    def __init__(self, **kwargs):
        self.kwargs = kwargs

    def __setstate__(self, state):
        # Support unpickling from a plain state dict (used by disk caching / spawn).
        self.kwargs = state.get("kwargs", {})

    def all_kwargs(self):
        """Return every meta-parameter this config contributes to the kernel call."""
        return self.kwargs

    def __str__(self):
        return ", ".join(f"{k}: {v}" for k, v in self.kwargs.items())

    def __hash__(self):
        # BUG FIX: previously `hash(tuple(*self.all_kwargs().items()))` — the `*`
        # unpacked the items as positional arguments to tuple(), which raises
        # TypeError for any config with more than one meta-parameter. tuple()
        # takes a single iterable.
        return hash(tuple(self.all_kwargs().items()))

    def __eq__(self, other):
        # Same unpacking bug fixed here; also defer politely on foreign types
        # instead of failing on a missing all_kwargs attribute.
        if not isinstance(other, AutotuneConfig):
            return NotImplemented
        return tuple(self.all_kwargs().items()) == tuple(other.all_kwargs().items())
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def autotune(
    configs, key=None, prune_configs_by=None, restore_value=None, do_bench=None, cache_results=True
):
    """
    Decorator for auto-tuning a function.

    .. highlight:: python

    If the environment variable :code:`QUACK_PRINT_AUTOTUNING` is set to
    :code:`"1"`, we will print a message to stdout after autotuning each
    kernel, including the time spent autotuning and the best configuration.

    :param configs: a list of :code:`AutotuneConfig` objects
    :type configs: list[AutotuneConfig]
    :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
    :type key: list[str]
    :param prune_configs_by: a dict of functions that are used to prune configs, fields:
        'perf_model': performance model used to predicate running time with different configs, returns running time
        'top_k': number of configs to bench
        'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs.
    :param restore_value: a list of argument names whose value will be restored after evaluating any configs.
    :type restore_value: list[str]
    :param do_bench: a benchmark function to measure the time of each run.
    :type do_bench: lambda fn, quantiles
    :param cache_results: whether to cache autotune timings to disk. Defaults to True.
    :type cache_results: bool
    """
    # BUG FIX: the doc above was previously an f-string literal (f"""...""") —
    # f-strings are expressions, not docstrings, so it was never stored in
    # __doc__ and was re-evaluated (and discarded) on every call. It also
    # claimed cache_results "Defaults to False" while the default is True.
    if key is None:
        key = []

    def decorator(fn):
        return Autotuner(
            fn,
            key,
            configs,
            restore_value=restore_value,
            prune_configs_by=prune_configs_by,
            do_bench=do_bench,
            cache_results=cache_results,
        )

    return decorator
|
build/torch-cuda/quack/broadcast_utils.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Tri Dao.
|
| 2 |
+
from typing import Callable
|
| 3 |
+
|
| 4 |
+
import cutlass
|
| 5 |
+
import cutlass.cute as cute
|
| 6 |
+
from cutlass import Float32, const_expr
|
| 7 |
+
|
| 8 |
+
from .layout_utils import make_acc_tensor_mn_view
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@cute.jit
def vec_op(tCrC: cute.Tensor, tCrVec: cute.Tensor, op: Callable, is_colvec: bool) -> None:
    """Apply a broadcasted binary op between an accumulator fragment and a vector, in place.

    ``tCrC`` is updated as ``tCrC = op(tCrC, tCrVec)`` with ``tCrVec`` broadcast
    along rows (``is_colvec=True``) or columns (``is_colvec=False``) of the MN view.
    Non-f32 fragments are converted to f32 for the op and converted back at the end.

    :param tCrC: register fragment to update in place
    :param tCrVec: per-row or per-column vector of operands (f32 view is taken of tCrC)
    :param op: elementwise binary callable applied to (fragment slice, vector element)
    :param is_colvec: True to broadcast tCrVec down columns (one value per row)
    """
    if const_expr(tCrC.element_type != Float32):  # Convert to f32
        tCrC_f32 = cute.make_fragment(tCrC.shape, Float32)
        tCrC_f32.store(tCrC.load().to(Float32))
    else:
        tCrC_f32 = tCrC
    # this happens to work for frgA layout too, not just acc layout
    tCrC_f32_mn = make_acc_tensor_mn_view(tCrC_f32)
    if const_expr(is_colvec):
        # One vector element per row of the MN view.
        assert cute.size(tCrC_f32_mn, mode=[0]) == cute.size(tCrVec)
        for r in cutlass.range(cute.size(tCrC_f32_mn, mode=[0]), unroll_full=True):
            tCrC_f32_mn[r, None].store(op(tCrC_f32_mn[r, None].load(), tCrVec[r]))
    else:
        # One vector element per column of the MN view.
        assert cute.size(tCrC_f32_mn, mode=[1]) == cute.size(tCrVec)
        for c in cutlass.range(cute.size(tCrC_f32_mn, mode=[1]), unroll_full=True):
            tCrC_f32_mn[None, c].store(op(tCrC_f32_mn[None, c].load(), tCrVec[c]))
    if const_expr(tCrC.element_type != Float32):  # Convert back to original dtype
        tCrC.store(tCrC_f32.load().to(tCrC.element_type))
|
build/torch-cuda/quack/compile_utils.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
|
| 2 |
+
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
import cutlass.cute as cute
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def make_fake_tensor(dtype, shape, divisibility=1, leading_dim=-1) -> Optional[cute.Tensor]:
    """Create a fake cute tensor with symbolic strides for ahead-of-time compilation.

    The leading dimension gets stride 1; every other dimension gets a symbolic
    int64 stride carrying the given divisibility. Returns None when dtype is None.

    :param dtype: cute element dtype, or None to skip creating a tensor
    :param shape: static tensor shape
    :param divisibility: divisibility assumption for symbolic strides and alignment
    :param leading_dim: index of the contiguous dimension (negative counts from the end)
    """
    if dtype is None:
        return None
    ndim = len(shape)
    if leading_dim < 0:
        leading_dim += ndim
    strides = []
    for axis in range(ndim):
        strides.append(1 if axis == leading_dim else cute.sym_int64(divisibility=divisibility))
    # assumed_align is in bytes: divisibility elements times element width.
    return cute.runtime.make_fake_tensor(
        dtype, shape, stride=tuple(strides), assumed_align=divisibility * dtype.width // 8
    )
|
build/torch-cuda/quack/copy_utils.py
ADDED
|
@@ -0,0 +1,614 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
|
| 2 |
+
|
| 3 |
+
import re
|
| 4 |
+
from typing import Optional, Type, Tuple, Callable
|
| 5 |
+
|
| 6 |
+
import cutlass
|
| 7 |
+
import cutlass.cute as cute
|
| 8 |
+
|
| 9 |
+
from cutlass import Int32, Boolean, const_expr
|
| 10 |
+
from cutlass.cute.nvgpu import cpasync, warpgroup
|
| 11 |
+
from cutlass.cutlass_dsl import dsl_user_op
|
| 12 |
+
import cutlass.pipeline
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dsl_user_op
def cvt_copy(
    tiled_copy: cute.TiledCopy,
    src: cute.Tensor,
    dst: cute.Tensor,
    *,
    pred: Optional[cute.Tensor] = None,
    retile: bool = False,
    loc=None,
    ip=None,
    **kwargs,
) -> None:
    """Copy ``src`` (a register fragment) into ``dst``, converting element type if needed.

    If the element types differ, ``src`` is first converted into a fresh fragment of
    ``dst``'s element type before the copy. When ``retile`` is set, ``src`` is retiled
    through ``tiled_copy`` so its layout matches the tiled-copy partitioning of ``dst``.

    Args:
        tiled_copy: The tiled copy used to perform the copy (and retiling, if requested).
        src: Source tensor; must live in register memory (asserted below).
        dst: Destination tensor; determines the target element type.
        pred: Optional predicate tensor forwarded to ``cute.copy``.
        retile: If True, retile ``src`` via ``tiled_copy.retile`` before copying.
        loc, ip: DSL source-location / insertion-point plumbing, forwarded downstream.
    """
    # src must be a register-memory fragment: the conversion below loads/stores it.
    assert isinstance(src.iterator, cute.Pointer) and src.memspace == cute.AddressSpace.rmem
    if const_expr(src.element_type != dst.element_type):
        # Convert through a temporary fragment of the destination element type.
        src_cvt = cute.make_fragment_like(src, dst.element_type)
        src_cvt.store(src.load().to(dst.element_type))
        src = src_cvt
    if const_expr(retile):
        src = tiled_copy.retile(src)
    cute.copy(tiled_copy, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@dsl_user_op
def load_s2r(src: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
    """Load ``src`` into a newly allocated register fragment of the same type and return it."""
    frag = cute.make_fragment_like(src, src.element_type, loc=loc, ip=ip)
    cute.autovec_copy(src, frag, loc=loc, ip=ip)
    return frag
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@dsl_user_op
def load_s2r_retile(
    tiled_copy: cute.TiledCopy,
    src: cute.Tensor,
    dst_shape: cute.Tensor | cute.Shape,
    *,
    loc=None,
    ip=None,
) -> cute.Tensor:
    """Copy ``src`` into a register fragment through ``tiled_copy``'s retiling.

    Args:
        tiled_copy: Tiled copy whose ``retile`` maps the destination fragment's layout
            onto the copy partitioning.
        src: Source tensor (partitioned for ``tiled_copy``).
        dst_shape: Either a shape (a fresh fragment is allocated) or an existing tensor
            to write into directly.

    Returns:
        The destination fragment (newly created, or ``dst_shape`` itself if it was a tensor).
    """
    # Will also accept dst_shape being a tensor, in which case we write into that tensor
    if const_expr(not isinstance(dst_shape, cute.Tensor)):
        dst = cute.make_fragment(dst_shape, src.element_type, loc=loc, ip=ip)
    else:
        dst = dst_shape
    cute.copy(tiled_copy, src, tiled_copy.retile(dst), loc=loc, ip=ip)
    return dst
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
@dsl_user_op
def get_copy_atom(
    dtype: Type[cutlass.Numeric], num_copy_elems: int, is_async: bool = False, *, loc=None, ip=None
) -> cute.CopyAtom:
    """Build a copy atom for ``num_copy_elems`` elements of ``dtype``.

    The per-copy width is capped at 128 bits. ``is_async`` selects the cp.async
    gmem->smem op instead of the universal copy op.
    """
    if is_async:
        op = cpasync.CopyG2SOp()
    else:
        op = cute.nvgpu.CopyUniversalOp()
    bits = const_expr(min(128, num_copy_elems * dtype.width))
    return cute.make_copy_atom(op, dtype, num_bits_per_copy=bits)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@dsl_user_op
def copy(
    src: cute.Tensor,
    dst: cute.Tensor,
    *,
    pred: Optional[cute.Tensor] = None,
    is_async: bool = False,
    loc=None,
    ip=None,
    **kwargs,
) -> None:
    """Copy ``src`` to ``dst`` with an automatically sized copy atom.

    The vector width is taken from ``src.shape[0][0]`` — i.e. ``src`` is assumed to be
    partitioned so that mode (0, 0) is the contiguous per-thread vector
    (TODO confirm against callers).

    Args:
        pred: Optional predicate tensor forwarded to ``cute.copy``.
        is_async: If True, use the cp.async gmem->smem copy op.
    """
    num_copy_elems = src.shape[0][0]
    copy_atom = get_copy_atom(src.element_type, num_copy_elems, is_async)
    cute.copy(copy_atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def tiled_copy_1d(
    dtype: Type[cutlass.Numeric], num_threads: int, num_copy_elems: int = 1, is_async: bool = False
) -> cute.TiledCopy:
    """Make a 1-D tiled copy: ``num_threads`` threads, ``num_copy_elems`` elements each."""
    op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
    atom = cute.make_copy_atom(op, dtype, num_bits_per_copy=num_copy_elems * dtype.width)
    thread_layout = cute.make_layout(num_threads)
    value_layout = cute.make_layout(num_copy_elems)
    return cute.make_tiled_copy_tv(atom, thread_layout, value_layout)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def tiled_copy_2d(
    dtype: Type[cutlass.Numeric],
    threads_per_row: int,
    num_threads: int,
    num_copy_elems: int = 1,
    is_async: bool = False,
) -> cute.TiledCopy:
    """Make a 2-D tiled copy with a row-major thread arrangement.

    Threads are laid out as (num_threads // threads_per_row, threads_per_row) with the
    row dimension fastest (order=(1, 0)); each thread copies ``num_copy_elems``
    contiguous elements along the second mode.
    """
    op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
    atom = cute.make_copy_atom(op, dtype, num_bits_per_copy=num_copy_elems * dtype.width)
    assert num_threads % threads_per_row == 0
    thread_layout = cute.make_ordered_layout(
        (num_threads // threads_per_row, threads_per_row),
        order=(1, 0),
    )
    value_layout = cute.make_layout((1, num_copy_elems))
    return cute.make_tiled_copy_tv(atom, thread_layout, value_layout)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
@cute.jit
def predicate_k(tAcA: cute.Tensor, limit: Int32) -> cute.Tensor:
    """Build a boolean predicate fragment for the k dimension of a partitioned coord tensor.

    Args:
        tAcA: Per-thread partition of an identity (coordinate) tensor; coordinate
            component [1] is compared against ``limit``.
        limit: Exclusive upper bound for the k coordinate.

    Returns:
        A Boolean fragment broadcast over the m mode (stride 0 in mode 1), so one
        predicate per (rest_v, rest_k) pair is shared across all m.
    """
    # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if"
    tApA = cute.make_fragment(
        cute.make_layout(
            (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])),
            # Stride 0 in the middle mode broadcasts each predicate across the m dimension.
            stride=(cute.size(tAcA, mode=[2]), 0, 1),
        ),
        Boolean,
    )
    for rest_v in cutlass.range_constexpr(tApA.shape[0]):
        for rest_k in cutlass.range_constexpr(tApA.shape[2]):
            # Coordinate component [1] is the k coordinate.
            tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit)
    return tApA
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
# def tiled_copy_2d(
|
| 134 |
+
# dtype: Type[cutlass.Numeric], major_mode_size: int, num_threads: int, is_async: bool = False
|
| 135 |
+
# ) -> cute.TiledCopy:
|
| 136 |
+
# num_copy_bits = math.gcd(major_mode_size, 128 // dtype.width) * dtype.width
|
| 137 |
+
# copy_elems = num_copy_bits // dtype.width
|
| 138 |
+
# copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
|
| 139 |
+
# copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)
|
| 140 |
+
# gmem_threads_per_row = major_mode_size // copy_elems
|
| 141 |
+
# assert num_threads % gmem_threads_per_row == 0
|
| 142 |
+
# thr_layout = cute.make_ordered_layout(
|
| 143 |
+
# (num_threads // gmem_threads_per_row, gmem_threads_per_row),
|
| 144 |
+
# order=(1, 0),
|
| 145 |
+
# )
|
| 146 |
+
# val_layout = cute.make_layout((1, copy_elems))
|
| 147 |
+
# return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def parse_swizzle_from_pointer(ptr: cute.Pointer) -> Tuple[int, int, int]:
    """Extract swizzle parameters from a pointer's swizzle_type.

    The swizzle_type string has the form '!cute.swizzle<"S<b,m,s>">' where
    b, m, s are the swizzle parameters (bits, base, shift).

    Returns:
        The parsed ``(b, m, s)`` integers.

    Raises:
        ValueError: If the swizzle_type string cannot be parsed
    """
    # Ideally there should be a better API to get swizzle parameters, but we'll just parse
    # the string here.
    swizzle_str = str(ptr.type.swizzle_type)
    # Extract the inner part "S<b,m,s>"
    match = re.search(r"S<(\d+),(\d+),(\d+)>", swizzle_str)
    if match:
        b, m, s = int(match.group(1)), int(match.group(2)), int(match.group(3))
        return b, m, s
    else:
        raise ValueError(f"Could not parse swizzle_type: {swizzle_str}")
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def swizzle_int(ptr_int: Int32, b: int, m: int, s: int) -> Int32:
    """Apply an S<b,m,s> swizzle to an integer address.

    XORs bits [m+s, m+s+b) of ``ptr_int`` down into bits [m, m+b).
    """
    selected_bits = ((1 << b) - 1) << (m + s)
    return ptr_int ^ ((ptr_int & selected_bits) >> s)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def swizzle_ptr(ptr: cute.Pointer):
    """Return a new pointer whose raw address has the pointer's own swizzle applied."""
    b, m, s = parse_swizzle_from_pointer(ptr)
    swizzled_addr = swizzle_int(ptr.toint(), b, m, s)
    return cute.make_ptr(ptr.dtype, swizzled_addr, ptr.memspace, assumed_align=ptr.alignment)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def as_position_independent_swizzle_tensor(tensor: cute.Tensor) -> cute.Tensor:
    """Rebuild ``tensor`` so its swizzle lives in the layout instead of the pointer.

    The pointer's swizzle (expressed in byte units) is recast to element units and
    composed into the layout; the returned tensor uses an unswizzled pointer with a
    composed (position-independent) swizzle layout.
    """
    outer = tensor.layout
    width = tensor.element_type.width
    inner = cute.make_swizzle(*parse_swizzle_from_pointer(tensor.iterator))
    # Need to recast the swizzle from byte (e.g. <3, 4, 3> to element units (e.g. <3, 3, 3> for
    # for 16 bits and <3, 2, 3> for 32 bits)
    new_layout = cute.recast_layout(
        width, 8, cute.make_composed_layout(inner, 0, cute.recast_layout(8, width, outer))
    )
    # recast_ptr to remove the pointer swizzle
    return cute.make_tensor(cute.recast_ptr(tensor.iterator, dtype=tensor.element_type), new_layout)
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def partition_D_position_independent(
    thr_copy: cute.core.ThrCopy, tensor: cute.Tensor
) -> cute.Tensor:
    """Partition ``tensor`` as destination, keeping the swizzle position-independent.

    The pointer keeps the swizzled base address while the layout carries the swizzle,
    so the partitioned tensor can be indexed independently of its absolute position.
    """
    swizzled_iter = swizzle_ptr(thr_copy.partition_D(tensor).iterator)
    pi_layout = thr_copy.partition_D(as_position_independent_swizzle_tensor(tensor)).layout
    return cute.make_tensor(swizzled_iter, pi_layout)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def partition_S_position_independent(
    thr_copy: cute.core.ThrCopy, tensor: cute.Tensor
) -> cute.Tensor:
    """Partition ``tensor`` as source, keeping the swizzle position-independent.

    Mirror of ``partition_D_position_independent`` for the source side of a copy.
    """
    swizzled_iter = swizzle_ptr(thr_copy.partition_S(tensor).iterator)
    pi_layout = thr_copy.partition_S(as_position_independent_swizzle_tensor(tensor)).layout
    return cute.make_tensor(swizzled_iter, pi_layout)
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
@dsl_user_op
def sm90_get_smem_load_op(
    layout_c: cutlass.utils.LayoutEnum,
    elem_ty_c: Type[cutlass.Numeric],
    *,
    loc=None,
    ip=None,
) -> cute.CopyAtom:
    """
    Selects the largest vectorized smem load atom available subject to constraint of gmem layout.

    For 16-bit element types an ldmatrix (8x8x16b, 4 matrices) atom is returned, with
    transpose chosen from the layout's m-majorness; otherwise a universal copy atom.

    Parameters:
    -----------
    layout_c : LayoutEnum
        The layout enum of the output tensor D.

    elem_ty_c : Type[Numeric]
        The element type for output tensor D.

    Returns:
    --------
    A ldmatrix-based copy atom for 16-bit types, else a universal copy atom.

    Raises:
    -------
    TypeError: If ``elem_ty_c`` is not a Numeric type.
    """

    if not isinstance(elem_ty_c, cutlass.cutlass_dsl.NumericMeta):
        raise TypeError(f"elem_ty_c must be a Numeric, but got {elem_ty_c}")
    is_m_major = layout_c.is_m_major_c()
    if elem_ty_c.width == 16:
        return cute.make_copy_atom(
            cute.nvgpu.warp.LdMatrix8x8x16bOp(is_m_major, 4), elem_ty_c, loc=loc, ip=ip
        )
    else:
        return cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), elem_ty_c, loc=loc, ip=ip)
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def get_smem_store_atom(
    arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric], transpose: bool = False
) -> cute.CopyAtom:
    """Pick the smem store atom: stmatrix on SM90+ for 16-bit types, else universal copy."""
    if const_expr(arch >= 90 and element_type.width == 16):
        # stmatrix (8x8x16b, 4 matrices) is available from Hopper onward for 16-bit data.
        return cute.make_copy_atom(
            cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=transpose, num_matrices=4),
            element_type,
        )
    # Fallback: a plain store of 2 elements (1 when transposed) per copy.
    return cute.make_copy_atom(
        cute.nvgpu.CopyUniversalOp(),
        element_type,
        num_bits_per_copy=(1 if transpose else 2) * element_type.width,
    )
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def get_smem_load_atom(
    arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric], transpose: bool = False
) -> cute.CopyAtom:
    """Pick the smem load atom: ldmatrix on SM90+ for 16-bit types, else universal copy."""
    if const_expr(arch >= 90 and element_type.width == 16):
        # ldmatrix (8x8x16b, 4 matrices) mirrors the stmatrix choice in get_smem_store_atom.
        return cute.make_copy_atom(
            cute.nvgpu.warp.LdMatrix8x8x16bOp(transpose=transpose, num_matrices=4),
            element_type,
        )
    # Fallback: a plain load of 2 elements (1 when transposed) per copy.
    return cute.make_copy_atom(
        cute.nvgpu.CopyUniversalOp(),
        element_type,
        num_bits_per_copy=(1 if transpose else 2) * element_type.width,
    )
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def get_smem_store_C(
    tiled_mma: cute.TiledMma,
    sC: cute.Tensor,
    tidx: Int32,
    arch: int,
    transpose: bool = False,
    position_independent=False,
) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]:
    """Set up the register->smem store path for accumulator C.

    Args:
        tiled_mma: The MMA whose C partitioning the store must match.
        sC: Staged smem tensor for C; the last mode is assumed to be the stage index
            (indexed by ``dst_idx`` below) — TODO confirm against callers.
        tidx: Thread index used to slice the tiled copy.
        arch: Target SM architecture, selects stmatrix vs universal store.
        transpose: Forwarded to the store-atom selection.
        position_independent: Use position-independent swizzle partitioning.

    Returns:
        (copy_fn, thr_copy, tRS_sC) where ``copy_fn(src, dst_idx)`` converts-and-stores
        ``src`` into stage ``dst_idx`` of smem.
    """
    dtype = sC.element_type
    copy_atom = get_smem_store_atom(arch, dtype, transpose)
    tiled_copy = cute.make_tiled_copy_C(copy_atom, tiled_mma)
    thr_copy = tiled_copy.get_slice(tidx)
    if const_expr(not position_independent):
        tRS_sC = thr_copy.partition_D(sC)
    else:
        tRS_sC = partition_D_position_independent(thr_copy, sC)

    def copy_fn(src: cute.Tensor, dst_idx: Int32, **new_kwargs):
        # Convert src to sC's dtype if needed and store into the chosen stage.
        cvt_copy(tiled_copy, src, tRS_sC[None, None, None, dst_idx], retile=True, **new_kwargs)

    return copy_fn, thr_copy, tRS_sC
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
def get_smem_load_C(
    tiled_mma: cute.TiledMma,
    sC: cute.Tensor,
    tidx: Int32,
    arch: int,
    transpose: bool = False,
    position_independent=False,
) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]:
    """Set up the smem->register load path for C.

    The destination fragment shape is derived from the *store* tiled copy's source
    partitioning, so loaded fragments line up with what ``get_smem_store_C`` would
    produce.

    Returns:
        (copy_fn, thr_copy, tSR_sC) where ``copy_fn(src_idx)`` loads stage ``src_idx``
        of smem into a fresh register fragment and returns it.
    """
    dtype = sC.element_type
    copy_atom = get_smem_load_atom(arch, dtype, transpose)
    tiled_copy = cute.make_tiled_copy_C(copy_atom, tiled_mma)
    thr_copy = tiled_copy.get_slice(tidx)
    if const_expr(not position_independent):
        tSR_sC = thr_copy.partition_S(sC)
    else:
        tSR_sC = partition_S_position_independent(thr_copy, sC)
    # The register fragment shape must match the store-side partitioning, so build a
    # throwaway store tiled copy just to query the per-thread shape.
    copy_atom_RS = get_smem_store_atom(arch, dtype, transpose)
    thr_copy_RS = cute.make_tiled_copy_C(copy_atom_RS, tiled_mma).get_slice(tidx)
    tRS_shape = thr_copy_RS.partition_S(cute.make_identity_tensor(sC.shape[:2])).shape

    def copy_fn(src_idx: Int32, **new_kwargs):
        return load_s2r_retile(
            tiled_copy, tSR_sC[None, None, None, src_idx], dst_shape=tRS_shape, **new_kwargs
        )

    return copy_fn, thr_copy, tSR_sC
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
def get_smem_store_A(
    tiled_mma: cute.TiledMma, sA: cute.Tensor, tidx: Int32, arch: int, position_independent=False
) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]:
    """Set up the register->smem store path for operand A.

    Transpose is inferred from the MMA's A-operand major mode (MN-major => transposed
    stmatrix). Otherwise mirrors ``get_smem_store_C``.

    Returns:
        (copy_fn, thr_copy, tRS_sA) where ``copy_fn(src, dst_idx)`` converts-and-stores
        ``src`` into stage ``dst_idx`` of smem.
    """
    dtype = sA.element_type
    # MN-major A needs the transposed store atom.
    transpose = tiled_mma.op.a_major_mode == warpgroup.OperandMajorMode.MN
    copy_atom = get_smem_store_atom(arch, dtype, transpose)
    tiled_copy = cute.make_tiled_copy_A(copy_atom, tiled_mma)
    thr_copy = tiled_copy.get_slice(tidx)
    if const_expr(not position_independent):
        tRS_sA = thr_copy.partition_D(sA)
    else:
        tRS_sA = partition_D_position_independent(thr_copy, sA)

    def copy_fn(src: cute.Tensor, dst_idx: Int32, **new_kwargs):
        cvt_copy(tiled_copy, src, tRS_sA[None, None, None, dst_idx], retile=True, **new_kwargs)

    return copy_fn, thr_copy, tRS_sA
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
def get_smem_load_A(
    tiled_mma: cute.TiledMma,
    sA: cute.Tensor,
    tidx: Int32,
    arch: int,
    with_dst_tensor: bool = False,
    position_independent=False,
) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]:
    """Set up the smem->register load path for operand A.

    Transpose is inferred from the MMA's A-operand major mode. Unlike
    ``get_smem_load_C``, the register fragment shape comes straight from
    ``tiled_mma.partition_shape_A``.

    Args:
        with_dst_tensor: If True, the returned copy function takes an explicit
            destination tensor instead of allocating a fragment per call.

    Returns:
        (copy_fn, thr_copy, tSR_sA). ``copy_fn(src_idx)`` loads stage ``src_idx`` into
        a fresh fragment; with ``with_dst_tensor`` it is ``copy_fn(src_idx, dst)``.
    """
    dtype = sA.element_type
    # MN-major A needs the transposed load atom.
    transpose = tiled_mma.op.a_major_mode == warpgroup.OperandMajorMode.MN
    copy_atom = get_smem_load_atom(arch, dtype, transpose)
    tiled_copy = cute.make_tiled_copy_A(copy_atom, tiled_mma)
    thr_copy = tiled_copy.get_slice(tidx)
    if const_expr(not position_independent):
        tSR_sA = thr_copy.partition_S(sA)
    else:
        tSR_sA = partition_S_position_independent(thr_copy, sA)
    # NOTE: the fragment shape is given directly by the MMA partitioning; no store-side
    # tiled copy is needed here (previous code built one and never used it).
    tRS_shape = tiled_mma.partition_shape_A(sA.shape[:2])

    def copy_fn(src_idx: Int32, **new_kwargs):
        return load_s2r_retile(
            tiled_copy, tSR_sA[None, None, None, src_idx], dst_shape=tRS_shape, **new_kwargs
        )

    def copy_fn_w_dst_tensor(src_idx: Int32, dst: cute.Tensor, **new_kwargs):
        return load_s2r_retile(tiled_copy, tSR_sA[None, None, None, src_idx], dst, **new_kwargs)

    return copy_fn if not with_dst_tensor else copy_fn_w_dst_tensor, thr_copy, tSR_sA
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
def tma_get_copy_fn(
    atom: cute.CopyAtom,
    cta_coord: cute.Coord,
    cta_layout: cute.Layout,
    src_tensor: cute.Tensor,
    dst_tensor: cute.Tensor,
    filter_zeros: bool = False,
    single_stage: bool = False,
    **kwargs,
) -> Callable:
    """Partition a TMA copy between gmem and smem and return a closure that issues it.

    Direction is inferred: whichever of ``src_tensor``/``dst_tensor`` lives in smem is
    treated as the smem side. When ``single_stage`` is False, the last mode of each
    tensor is kept outside the partition grouping and is indexed per call
    (stage index on the smem side, tile index on the gmem side).

    NOTE(review): despite the ``-> Callable`` annotation, this returns a 3-tuple
    (copy_fn, s, g) — callers unpack all three.

    Returns:
        (copy_fn, s, g): copy_fn(src_idx, dst_idx) (or copy_fn() when single_stage)
        issues the TMA copy; s and g are the partitioned smem/gmem tensors.
    """
    src_is_smem = const_expr(
        isinstance(src_tensor.iterator, cute.Pointer)
        and src_tensor.memspace == cute.AddressSpace.smem
    )
    smem_tensor, gmem_tensor = (src_tensor, dst_tensor) if src_is_smem else (dst_tensor, src_tensor)
    # Group all but the last mode (the stage/tile mode) unless single_stage.
    group_rank_smem = const_expr(cute.rank(smem_tensor) - (1 if not single_stage else 0))
    group_rank_gmem = const_expr(cute.rank(gmem_tensor) - (1 if not single_stage else 0))
    # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK)
    s, g = cpasync.tma_partition(
        atom,
        cta_coord,
        cta_layout,
        cute.group_modes(smem_tensor, 0, group_rank_smem),
        cute.group_modes(gmem_tensor, 0, group_rank_gmem),
    )
    if const_expr(filter_zeros):
        s = cute.filter_zeros(s)
        g = cute.filter_zeros(g)
    src, dst = (s, g) if src_is_smem else (g, s)

    def copy_tma(src_idx, dst_idx, **new_kwargs):
        cute.copy(atom, src[None, src_idx], dst[None, dst_idx], **new_kwargs, **kwargs)

    def copy_tma_single_stage(**new_kwargs):
        cute.copy(atom, src, dst, **new_kwargs, **kwargs)

    return (copy_tma if const_expr(not single_stage) else copy_tma_single_stage), s, g
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
def tma_producer_copy_fn(copy: Callable, pipeline: cutlass.pipeline.PipelineAsync):
    """Wrap a TMA ``copy`` closure so it targets a producer pipeline stage.

    The returned function maps the producer state's stage index to ``dst_idx`` and
    passes the corresponding pipeline barrier as ``tma_bar_ptr``.
    """

    def copy_fn(src_idx, producer_state: cutlass.pipeline.PipelineState, **new_kwargs):
        copy(
            src_idx=src_idx,
            dst_idx=producer_state.index,
            tma_bar_ptr=pipeline.producer_get_barrier(producer_state),
            **new_kwargs,
        )

    return copy_fn
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
@cute.jit
def gather_m_get_copy_fn(
    thr_copy_A: cute.ThrCopy,
    mA: cute.Tensor,  # (whatever, K)
    sA: cute.Tensor,  # (tile_M, tile_N, STAGE)
    gsAIdx: cute.Tensor,  # (tile_M), either gmem or smem
    limit_m: Int32,
    limit_k: Int32,
) -> Callable:
    """Build a gmem->smem copy closure that gathers rows of A via an index tensor.

    Row indices are read from ``gsAIdx`` once up front (out-of-bounds rows map to
    row 0, which is safe because their predicate masks the store). The returned
    ``copy_fn(src_idx, dst_idx, pred)`` copies k-tile ``src_idx`` of the gathered rows
    into smem stage ``dst_idx``, optionally predicating the k dimension.
    """
    tile_shape_mk = (cute.size(sA, mode=[0]), cute.size(sA, mode=[1]))
    tAsA = thr_copy_A.partition_D(sA)
    # k-major
    assert tAsA.shape[2] == 1
    tAsA = cute.group_modes(cute.slice_(tAsA, (None, None, 0, None)), 0, 2)

    is_even_m_smem = tile_shape_mk[0] % thr_copy_A.tiler_mn[0].shape == 0
    if const_expr(not is_even_m_smem):
        limit_m = min(limit_m, tile_shape_mk[0])
    elems_per_load = cute.size(tAsA.shape[0][0])
    cA = cute.make_identity_tensor(tile_shape_mk)
    tAcA = thr_copy_A.partition_S(cA)
    t0AcA = thr_copy_A.get_slice(0).partition_S(cA)
    # Instead of comparing tAcA to limit_m, we instead compare t0AcA to limit_m - tAcA[0][0]
    # since we know that tAcA[m][0] = t0AcA[m][0] + tAcA[0][0].
    # This is so that when we do the comparison, t0AcA is known at compile time.
    limit_m = limit_m - tAcA[0][0]
    limit_k = limit_k - tAcA[0][1]
    # Read and cache indices for A
    rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1]))
    cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2]))
    tApA_m = cute.make_fragment(rows_per_thread, Boolean)
    for m in cutlass.range(rows_per_thread, unroll_full=True):
        tApA_m[m] = t0AcA[0, m, 0][0] < limit_m
    m_idx = cute.make_fragment(rows_per_thread, Int32)
    for m in cutlass.range(rows_per_thread, unroll_full=True):
        row_idx = tAcA[0, m, 0][0]
        if tApA_m[m]:
            m_idx[m] = gsAIdx[row_idx]
        else:
            m_idx[m] = 0  # It's ok to load row 0 in the case of OOB

    # Split the K dimension of mA into tiles of tile_shape_mk[1].
    mA_k = cute.logical_divide(mA, (None, tile_shape_mk[1]))

    def copy_fn(src_idx, dst_idx, pred: bool = False):
        tApA_k = None
        if const_expr(pred):
            # Per-column predicate for the current k tile.
            tApA_k = cute.make_fragment(cols_per_thread, Boolean)
            limit_k_cur = limit_k - src_idx * tile_shape_mk[1]
            for k in cutlass.range(cols_per_thread, unroll_full=True):
                tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
        mA_cur = mA_k[None, (None, src_idx)]
        for m in cutlass.range_constexpr(tAcA.shape[1]):
            # cute.tiled_divide(mA_cur[m_idx[m], None], (elems_per_load,)) would give shape
            # ((elems_per_load), thread_per_row)
            # But we actually want shape ((elems_per_load, 1), thread_per_row) to match tAsA
            # So we append 1s to the last dimension and then do tiled_divide, then slice.
            mA_row = cute.tiled_divide(
                cute.append_ones(mA_cur[m_idx[m], None], up_to_rank=2), (elems_per_load, 1)
            )[None, None, 0]
            if const_expr(is_even_m_smem) or tApA_m[m]:
                # There's only 1 load per row
                assert cute.size(tAcA.shape, mode=[2]) == 1
                ki = tAcA[0, 0, 0][1] // elems_per_load
                cute.copy(thr_copy_A, mA_row[None, ki], tAsA[(None, m), dst_idx], pred=tApA_k)

    return copy_fn
|
| 503 |
+
|
| 504 |
+
|
| 505 |
+
@cute.jit
def gather_k_get_copy_fn(
    thr_copy_A: cute.ThrCopy,
    mA: cute.Tensor,  # (tile_M, whatever)
    sA: cute.Tensor,  # (tile_M, tile_N, STAGE)
    gsAIdx: cute.Tensor,  # (tile_K, RestK), either gmem or smem
    limit_m: Int32,
    limit_k: Int32,
) -> Callable:
    """Build a gmem->smem copy closure that gathers *columns* (k) of A via an index tensor.

    Unlike :func:`gather_m_get_copy_fn`, the gather indices are read per k-tile by a
    separate prefetch function. Which prefetch is returned depends on where ``gsAIdx``
    lives: gmem indices are read directly; smem indices go through a prefetch pipeline
    (consumer_wait / consumer_release around the read).

    Returns:
        (copy_fn, prefetch_fn). ``prefetch_fn`` yields (k_idx, tApA_k) that is then
        passed to ``copy_fn(src_idx, dst_idx, (k_idx, tApA_k), pred)``.
    """
    gAIdx, sAIdx = None, None
    if const_expr(gsAIdx.memspace == cute.AddressSpace.gmem):
        gAIdx = gsAIdx
    else:
        assert gsAIdx.memspace == cute.AddressSpace.smem
        sAIdx = gsAIdx
    tile_shape_mk = (cute.size(sA, mode=[0]), cute.size(sA, mode=[1]))
    # (atom_v, CPY_M, 1, STAGE)
    tAsA = thr_copy_A.partition_D(sA)
    # m-major
    tAsA = cute.group_modes(tAsA, 0, 3)

    is_even_m_smem = tile_shape_mk[0] % thr_copy_A.tiler_mn[0].shape == 0
    if const_expr(not is_even_m_smem):
        limit_m = min(limit_m, tile_shape_mk[0])
    elems_per_load = cute.size(tAsA.shape[0][0])
    cA = cute.make_identity_tensor(tile_shape_mk)
    tAcA = thr_copy_A.partition_S(cA)
    t0AcA = thr_copy_A.get_slice(0).partition_S(cA)
    # Instead of comparing tAcA to limit_m, we instead compare t0AcA to limit_m - tAcA[0][0]
    # since we know that tAcA[m][0] = t0AcA[m][0] + tAcA[0][0].
    # This is so that when we do the comparison, t0AcA is known at compile time.
    limit_m = limit_m - tAcA[0][0]
    limit_k = limit_k - tAcA[0][1]
    # Read and cache indices for A
    rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1]))
    cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2]))
    tApA_m = cute.make_fragment(rows_per_thread, Boolean)
    for m in cutlass.range(rows_per_thread, unroll_full=True):
        tApA_m[m] = t0AcA[0, m, 0][0] < limit_m
    threads_per_col = const_expr(thr_copy_A.tiler_mn[0].shape // elems_per_load)
    # This is very convoluted but idk a better way
    # for tile_M=128, flat_divide gives (8, 16, K),
    # then logical_divide gives ((8, 1), (8, 2), K).
    tidx = thr_copy_A.thr_idx
    tAmA = cute.logical_divide(
        cute.flat_divide(mA, (elems_per_load,)), (elems_per_load, threads_per_col)
    )[None, (tidx % threads_per_col, None), None]  # ((8, 1), 2, K)

    def prefetch_from_gmem_fn(src_idx, pred: bool = False) -> Tuple[cute.Tensor, cute.Tensor]:
        # Prefetch mAIdx early, even before smem is free
        tApA_k = None
        if const_expr(pred):
            tApA_k = cute.make_fragment(cols_per_thread, Boolean)
            limit_k_cur = limit_k - src_idx * tile_shape_mk[1]
            for k in cutlass.range(cols_per_thread, unroll_full=True):
                tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
        gAIdx_cur = gAIdx[None, src_idx]
        k_idx = cute.make_fragment(cols_per_thread, Int32)
        for k in cutlass.range(cols_per_thread):
            col_idx = tAcA[0, 0, k][1]
            if const_expr(not pred):
                k_idx[k] = gAIdx_cur[col_idx]
            else:
                if tApA_k[k]:
                    k_idx[k] = gAIdx_cur[col_idx]
                else:
                    # Sentinel for OOB columns; masked out by tApA_k in copy_fn.
                    k_idx[k] = -1
        return k_idx, tApA_k

    def prefetch_from_smem_fn(
        a_prefetch_pipeline, src_idx, dst_idx, a_prefetch_consumer_state, pred: bool = False
    ) -> Tuple[cute.Tensor, cute.Tensor]:
        tApA_k = None
        if const_expr(pred):
            tApA_k = cute.make_fragment(cols_per_thread, Boolean)
            limit_k_cur = limit_k - src_idx * tile_shape_mk[1]
            for k in cutlass.range(cols_per_thread, unroll_full=True):
                tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
        # Wait until the index producer has filled this stage before reading indices.
        a_prefetch_pipeline.consumer_wait(a_prefetch_consumer_state)
        sAIdx_cur = sAIdx[None, dst_idx]
        k_idx = cute.make_fragment(cols_per_thread, Int32)
        for k in cutlass.range(cols_per_thread):
            col_idx = tAcA[0, 0, k][1]
            k_idx[k] = sAIdx_cur[col_idx]
        # All lanes must finish reading before one lane releases the stage.
        cute.arch.sync_warp()
        with cute.arch.elect_one():
            a_prefetch_pipeline.consumer_release(a_prefetch_consumer_state)
        return k_idx, tApA_k

    def copy_fn(
        src_idx, dst_idx, k_idx_tApA_k: Tuple[cute.Tensor, cute.Tensor], pred: bool = False
    ):
        k_idx, tApA_k = k_idx_tApA_k
        tApA_k_pred = None
        if const_expr(pred):
            tApA_k_pred = cute.prepend_ones(tApA_k, up_to_rank=2)  # (1, cols_per_thread)
        for k in cutlass.range_constexpr(tAcA.shape[2]):
            # copy_A(tAmA[None, None, k_idx[k]], tAsA[(None, None, k), smem_idx], pred=cute.prepend_ones(tApA_m, up_to_rank=2))
            for m in cutlass.range_constexpr(tAcA.shape[1]):
                if tApA_m[m]:
                    cute.copy(
                        thr_copy_A,
                        tAmA[None, m, k_idx[k]],
                        tAsA[(None, m, k), dst_idx],
                        pred=None if const_expr(tApA_k_pred is None) else tApA_k_pred[None, k],
                    )

    return copy_fn, prefetch_from_gmem_fn if const_expr(
        gAIdx is not None
    ) else prefetch_from_smem_fn
|
build/torch-cuda/quack/cute_dsl_ptxas.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
System ptxas replacement for CUTLASS DSL.
|
| 3 |
+
Environment variables:
|
| 4 |
+
CUTE_DSL_PTXAS_PATH - Path to ptxas (e.g., /usr/local/cuda/bin/ptxas)
|
| 5 |
+
CUTE_DSL_PTXAS_VERBOSE - Set to 1 for verbose output
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import sys
|
| 10 |
+
import re
|
| 11 |
+
import ctypes
|
| 12 |
+
import subprocess
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
|
| 15 |
+
import cutlass
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
CUTE_DSL_PTXAS_PATH = os.environ.get("CUTE_DSL_PTXAS_PATH", None)
|
| 19 |
+
VERBOSE = os.environ.get("CUTE_DSL_PTXAS_VERBOSE", "0") == "1"
|
| 20 |
+
|
| 21 |
+
_original_load_cuda_library = None
|
| 22 |
+
_user_wanted_ptx = False # True if user originally set CUTE_DSL_KEEP_PTX=1
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _log(msg):
    """Print a diagnostic line to stderr when CUTE_DSL_PTXAS_VERBOSE=1."""
    if not VERBOSE:
        return
    print(f"[ptxas] {msg}", file=sys.stderr)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _get_ptx(compiled_func) -> tuple[str, Path] | None:
|
| 31 |
+
"""Find and read PTX file, stripping null bytes."""
|
| 32 |
+
func_name = getattr(compiled_func, "function_name", None)
|
| 33 |
+
if not func_name:
|
| 34 |
+
return None
|
| 35 |
+
|
| 36 |
+
dump_dir = os.environ.get("CUTE_DSL_DUMP_DIR", Path.cwd())
|
| 37 |
+
for ptx_path in Path(dump_dir).glob(f"*{func_name}*.ptx"):
|
| 38 |
+
content = ptx_path.read_text().rstrip("\x00")
|
| 39 |
+
if ".entry " in content and content.rstrip().endswith("}"):
|
| 40 |
+
_log(f"Found PTX: {ptx_path}")
|
| 41 |
+
return content, ptx_path
|
| 42 |
+
return None
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _compile_ptx(ptx_path: Path, ptx_content: str) -> bytes:
    """Compile PTX to cubin using system ptxas."""
    # Pull the target architecture out of the PTX itself; default to sm_90a.
    arch_match = re.search(r"\.target\s+(sm_\d+[a-z]?)", ptx_content)
    arch = "sm_90a" if arch_match is None else arch_match.group(1)

    # Persist the stripped content so ptxas sees exactly what we validated.
    if ptx_path.read_text() != ptx_content:
        ptx_path.write_text(ptx_content)

    tmp_cubin = ptx_path.with_suffix(".cubin.tmp")
    try:
        assert CUTE_DSL_PTXAS_PATH is not None
        proc = subprocess.run(
            [CUTE_DSL_PTXAS_PATH, f"-arch={arch}", "-O3", "-o", str(tmp_cubin), str(ptx_path)],
            capture_output=True,
            text=True,
        )
        if proc.returncode != 0:
            raise RuntimeError(f"ptxas failed: {proc.stderr}")

        compiled = tmp_cubin.read_bytes()
        _log(f"Compiled {ptx_path.name} -> {len(compiled)} bytes ({arch})")

        # Optionally keep the cubin next to the PTX for inspection.
        if os.environ.get("CUTE_DSL_KEEP_CUBIN", "0") == "1":
            kept = ptx_path.with_suffix(".cubin")
            kept.write_bytes(compiled)
            _log(f"Saved: {kept}")

        return compiled
    finally:
        # Always remove the temporary output, even on failure.
        tmp_cubin.unlink(missing_ok=True)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def _patched_load_cuda_library(self):
    """Replacement for _load_cuda_library that uses system ptxas.

    Falls back to the original (embedded-ptxas) loader whenever any step
    fails: PTX lookup, external compilation, cubin loading, or the per-device
    kernel registration loop.
    """

    result = _get_ptx(self)
    if not result:
        _log("PTX not found, falling back to embedded ptxas")
        return _original_load_cuda_library(self)

    ptx_content, ptx_path = result

    try:
        cubin = _compile_ptx(ptx_path, ptx_content)
    except Exception as e:
        _log(f"Compilation failed ({e}), falling back to embedded ptxas")
        return _original_load_cuda_library(self)

    # Load cubin
    import cuda.bindings.runtime as cuda_runtime

    err, library = cuda_runtime.cudaLibraryLoadData(cubin, None, None, 0, None, None, 0)
    if err != cuda_runtime.cudaError_t.cudaSuccess:
        _log(f"cudaLibraryLoadData failed ({err}), falling back to embedded ptxas")
        return _original_load_cuda_library(self)

    # Register kernels on all devices
    _, cuda_load_to_device = self._get_cuda_init_and_load()
    lib_ptr = ctypes.c_void_p(int(library))
    dev_id = ctypes.c_int32(0)
    err_val = ctypes.c_int32(0)
    # Pointer triple (library handle, device id, out-error) passed as a single
    # void*[3] pack; dev_id/err_val are mutated in place across iterations.
    args = (ctypes.c_void_p * 3)(
        ctypes.cast(ctypes.pointer(lib_ptr), ctypes.c_void_p),
        ctypes.cast(ctypes.pointer(dev_id), ctypes.c_void_p),
        ctypes.cast(ctypes.pointer(err_val), ctypes.c_void_p),
    )

    for dev in range(self.num_devices):
        dev_id.value = dev
        cuda_load_to_device(args)
        # err_val is written by the callee; nonzero means registration failed
        # on this device, so bail out to the original loader entirely.
        if err_val.value != 0:
            _log("cuda_load_to_device failed, falling back to embedded ptxas")
            return _original_load_cuda_library(self)

    _log(f"Loaded kernel from {ptx_path.name}")

    # Delete PTX if user didn't originally want it kept
    if not _user_wanted_ptx:
        ptx_path.unlink(missing_ok=True)

    # NOTE(review): the original loader presumably also returns a list of
    # library handles — confirm against cutlass' CudaDialectJitCompiledFunction.
    return [cuda_runtime.cudaLibrary_t(lib_ptr.value)]
|
| 132 |
+
def patch():
    """Install system ptxas hook. Call before importing cutlass.

    Replaces ``CudaDialectJitCompiledFunction._load_cuda_library`` with
    :func:`_patched_load_cuda_library`, stashing the original so the patched
    version can fall back to it.

    Raises:
        RuntimeError: if ``CUTE_DSL_PTXAS_PATH`` is unset, does not point to an
            executable file, or ``CUTE_DSL_KEEP_PTX=1`` is not set (the hook
            needs the DSL to keep PTX dumps on disk so it can recompile them).
    """
    global _original_load_cuda_library, _user_wanted_ptx

    # Explicit raises instead of `assert`: these are runtime configuration
    # checks and must not be stripped under `python -O`.
    if CUTE_DSL_PTXAS_PATH is None:
        raise RuntimeError("CUTE_DSL_PTXAS_PATH is not set")
    if not os.path.isfile(CUTE_DSL_PTXAS_PATH) or not os.access(CUTE_DSL_PTXAS_PATH, os.X_OK):
        raise RuntimeError(f"ptxas not found: {CUTE_DSL_PTXAS_PATH}")

    # Track if user originally wanted PTX kept (the patched loader deletes the
    # PTX afterwards only when the user did not ask for it).
    _user_wanted_ptx = os.environ.get("CUTE_DSL_KEEP_PTX", "0") == "1"
    if os.environ.get("CUTE_DSL_KEEP_PTX", "0") != "1":
        raise RuntimeError("Require CUTE_DSL_KEEP_PTX=1 to use system's ptxas")

    cls = cutlass.cutlass_dsl.cuda_jit_executor.CudaDialectJitCompiledFunction
    _original_load_cuda_library = cls._load_cuda_library
    cls._load_cuda_library = _patched_load_cuda_library
    _log("Patch applied")
|
build/torch-cuda/quack/cute_dsl_utils.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Tri Dao.
|
| 2 |
+
|
| 3 |
+
from typing import Tuple
|
| 4 |
+
from functools import lru_cache
|
| 5 |
+
from dataclasses import dataclass, fields
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
try:
|
| 10 |
+
from triton.tools.disasm import extract
|
| 11 |
+
except ImportError:
|
| 12 |
+
extract = None
|
| 13 |
+
|
| 14 |
+
import cutlass
|
| 15 |
+
import cutlass.cute as cute
|
| 16 |
+
from cutlass import Int32, Int64, Float16, BFloat16, Float32
|
| 17 |
+
from cutlass.base_dsl.typing import JitArgument
|
| 18 |
+
from cutlass.cutlass_dsl import NumericMeta
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# Types treated as compile-time constants by the CuTe DSL: they contribute no
# runtime MLIR values and are skipped by the (de)serialization helpers below.
StaticTypes = (cutlass.Constexpr, NumericMeta, int, bool, str, float, type(None))


# Keep references to the original ("og") implementations so that code which
# monkeypatches cutlass can still reach the unpatched versions.
load_cubin_module_data_og = cutlass.base_dsl.runtime.cuda.load_cubin_module_data
cute_compile_og = cute.compile


# Mapping from torch dtypes to their cutlass numeric-type equivalents.
torch2cute_dtype_map = {
    torch.float16: Float16,
    torch.bfloat16: BFloat16,
    torch.float32: Float32,
    torch.int32: Int32,
    torch.int64: Int64,
}
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@lru_cache
def get_max_active_clusters(cluster_size: int) -> int:
    """Max number of clusters of `cluster_size` CTAs resident on the current GPU.

    Cached because the underlying hardware/occupancy query is expensive and the
    answer does not change for a given cluster size within a process.
    """
    return cutlass.utils.HardwareInfo().get_max_active_clusters(cluster_size=cluster_size)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@lru_cache
def get_device_capacity(device: "torch.device | None" = None) -> Tuple[int, int]:
    """Return the (major, minor) CUDA compute capability of `device`.

    `None` selects torch's current device. Cached: a device's capability never
    changes during the process lifetime.
    """
    return torch.cuda.get_device_capability(device)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@dataclass
class ParamsBase:
    """Base for kernel-parameter dataclasses that round-trip through MLIR.

    Fields holding `StaticTypes` values are compile-time constants and are
    skipped.  Every other ("dynamic") field contributes zero or more MLIR
    values.  `_values_pos` records how many values each dynamic field produced
    so `__new_from_mlir_values__` can slice them back out in the same order.
    """

    def __extract_mlir_values__(self):
        # Flatten the dynamic fields into one value list, remembering the
        # per-field counts for the inverse operation.
        all_fields = [getattr(self, field.name) for field in fields(self)]
        non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)]
        values, self._values_pos = [], []
        for obj in non_constexpr_fields:
            obj_values = cutlass.extract_mlir_values(obj)
            values += obj_values
            self._values_pos.append(len(obj_values))
        return values

    def __new_from_mlir_values__(self, values):
        # Rebuild a fresh instance: constexpr fields are copied as-is, dynamic
        # fields are reconstructed from consecutive slices of `values` (counts
        # taken from the `_values_pos` recorded by __extract_mlir_values__).
        all_fields = {field.name: getattr(self, field.name) for field in fields(self)}
        constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, StaticTypes)}
        non_constexpr_fields = {
            n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes)
        }
        for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos):
            non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items])
            values = values[n_items:]
        return self.__class__(**non_constexpr_fields, **constexpr_fields)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@dataclass
class ArgumentsBase(JitArgument):
    """Base for host-side argument dataclasses handed to a JIT-compiled kernel.

    Mirrors `ParamsBase`: fields holding `StaticTypes` values are baked in at
    compile time and skipped; the remaining fields supply C pointers and MLIR
    types for the kernel launch.
    """

    def __c_pointers__(self):
        # Collect C pointers from every dynamic field that can provide them;
        # fields without __c_pointers__ contribute nothing.
        all_fields = [getattr(self, field.name) for field in fields(self)]
        non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)]
        c_ptrs = []
        for obj in non_constexpr_fields:
            if hasattr(obj, "__c_pointers__"):
                c_ptrs.extend(obj.__c_pointers__())
        return c_ptrs

    def __get_mlir_types__(self):
        # Gather MLIR types, recording a per-field count (0 for fields that do
        # not expose __get_mlir_types__) so values can be re-sliced later.
        all_fields = [getattr(self, field.name) for field in fields(self)]
        non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)]
        types, self._values_pos = [], []
        for obj in non_constexpr_fields:
            if hasattr(obj, "__get_mlir_types__"):
                obj_types = obj.__get_mlir_types__()
                types.extend(obj_types)
                self._values_pos.append(len(obj_types))
            else:
                self._values_pos.append(0)
        return types

    def __new_from_mlir_values__(self, values):
        # Rebuild: constexpr fields copied, dynamic fields reconstructed from
        # consecutive slices of `values` using the counts from
        # __get_mlir_types__ (fields counted as 0 get an empty slice).
        all_fields = {field.name: getattr(self, field.name) for field in fields(self)}
        constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, StaticTypes)}
        non_constexpr_fields = {
            n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes)
        }
        for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos):
            non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items])
            values = values[n_items:]
        return self.__class__(**non_constexpr_fields, **constexpr_fields)
|
build/torch-cuda/quack/fast_math.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Tri Dao.
|
| 2 |
+
|
| 3 |
+
from typing import Tuple
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
|
| 6 |
+
import cutlass
|
| 7 |
+
import cutlass.cute as cute
|
| 8 |
+
from cutlass import Int32, Uint32
|
| 9 |
+
from cutlass.cutlass_dsl import T, dsl_user_op
|
| 10 |
+
from cutlass._mlir.dialects import llvm
|
| 11 |
+
|
| 12 |
+
from .cute_dsl_utils import ParamsBase
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@cute.jit
def clz(x: Int32) -> Int32:
    """Count leading zero bits of a 32-bit value (returns 32 when x == 0)."""
    # The natural formulation would early-return from the loop:
    # for i in cutlass.range_constexpr(32):
    #     if (1 << (31 - i)) & x:
    #         return Int32(i)
    # return Int32(32)
    # Early exit is not supported yet
    # ...so instead a `done` flag latches the first set bit found, scanning
    # from the most significant bit downward.
    res = Int32(32)
    done = False
    for i in cutlass.range(32):
        if ((1 << (31 - i)) & x) and not done:
            res = Int32(i)
            done = True
    return res
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def find_log2(x: Int32) -> Int32:
    """Return ceil(log2(x)) for positive x (floor(log2(x)) when x is a power of 2... plus 1 otherwise)."""
    # 31 - clz(x) is floor(log2(x)) for x > 0.
    a: Int32 = Int32(31 - clz(x))
    return a + ((x & (x - 1)) != 0)  # Round up, add 1 if not a power of 2.
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@dsl_user_op
def umulhi(a: Int32, b: Int32, *, loc=None, ip=None) -> Uint32:
    """Return the high 32 bits of the unsigned 64-bit product a * b.

    Emitted as the single PTX instruction `mul.hi.u32` via inline asm.
    """
    return Uint32(
        llvm.inline_asm(
            T.i32(),
            [Int32(a).ir_value(loc=loc, ip=ip), Int32(b).ir_value(loc=loc, ip=ip)],
            "mul.hi.u32 $0, $1, $2;",
            "=r,r,r",  # one i32 output register, two i32 input registers
            has_side_effects=False,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    )
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@dataclass
class FastDivmod(ParamsBase):
    """Fast integer division/modulo by a divisor fixed at host time.

    Implements the classic magic-multiplier scheme: division by `divisor`
    becomes a 32x32->64 high multiply by `multiplier` followed by a right
    shift of `shift_right` bits. The expensive precomputation happens once in
    `create` on the host; `div`/`divmod` run on device.
    """

    divisor: Int32
    multiplier: Uint32
    shift_right: Uint32

    # called by host
    @staticmethod
    def create(divisor: Int32) -> "FastDivmod":
        """Construct the FastDivmod object, in host code.
        This precomputes some values based on the divisor and is computationally expensive.
        """
        # p in [31, 62]; multiplier = ceil(2^p / divisor).
        p = Uint32(31 + find_log2(divisor))
        divisor_u32 = Uint32(divisor)
        multiplier = Uint32(((cutlass.Uint64(1) << p) + divisor_u32 - 1) // divisor_u32)
        shift_right = Uint32(p - 32)
        return FastDivmod(divisor, multiplier, shift_right)

    @cute.jit
    def div(self, dividend: Int32) -> Int32:
        # divisor == 1 bypasses the multiply path entirely.
        return (
            Int32(umulhi(dividend, self.multiplier) >> self.shift_right)
            if self.divisor != 1
            else dividend
        )

    def divmod(self, dividend: Int32) -> Tuple[Int32, Int32]:
        """Return (dividend // divisor, dividend % divisor)."""
        quotient = self.div(dividend)
        remainder = dividend - quotient * self.divisor
        return quotient, remainder
|
build/torch-cuda/quack/gemm.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
from functools import partial
|
| 3 |
+
|
| 4 |
+
from torch import Tensor
|
| 5 |
+
|
| 6 |
+
import cutlass.cute as cute
|
| 7 |
+
import cutlass.torch as cutlass_torch
|
| 8 |
+
from cutlass import Float32
|
| 9 |
+
from cutlass.cute.runtime import from_dlpack, make_ptr
|
| 10 |
+
|
| 11 |
+
from .cute_dsl_utils import get_device_capacity, get_max_active_clusters
|
| 12 |
+
from .gemm_wrapper_utils import GemmWrapperBase
|
| 13 |
+
from .gemm_default_epi import GemmDefaultSm90, GemmDefaultSm100
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def gemm(
    # (l, m, k) or (total_m, k) if varlen_m or (m, total_k) if varlen_k or (whatever, k) if gather_A_varlen_m or (m, whatever) if gather_A_varlen_k
    A: Tensor,
    B: Tensor,  # (l, n, k) or (n, total_k) if varlen_k
    D: Tensor,  # (l, m, n) or (total_m, n) if varlen_m
    C: Optional[Tensor],  # (l, m, n) or (total_m, n) if varlen_m
    tile_count_semaphore: Optional[Tensor],  # (1,)
    tile_M: int,
    tile_N: int,
    cluster_M: int,
    cluster_N: int,
    pingpong: bool = False,
    persistent: bool = True,
    max_swizzle_size: int = 8,
    rowvec_bias: Optional[Tensor] = None,  # (l, n)
    colvec_bias: Optional[Tensor] = None,  # (l, m), or (total_m,) if varlen_m
    alpha: float | Tensor = 1.0,
    beta: float | Tensor = 1.0,
    cu_seqlens_m: Optional[Tensor] = None,  # (l+1,) cumulative sum of m values for variable length
    cu_seqlens_k: Optional[Tensor] = None,  # (l+1,) cumulative sum of k values for variable length
    A_idx: Optional[Tensor] = None,  # (total_m,) or (total_k,) indices for gather_A when varlen
    batch_idx_permute: Optional[Tensor] = None,  # (l,) permutation of batch indices for scheduler
    add_to_output: bool = False,
) -> None:
    """Dispatch a CuTe-DSL GEMM writing into `D` in place.

    Validates the varlen/gather configuration, picks the SM90 or SM100 kernel
    from the device's compute capability, compiles the kernel on first use for
    a given configuration (memoized in `gemm.compile_cache`), and launches it
    on torch's current stream.

    NOTE(review): the epilogue presumably computes D = alpha*(A@B) + beta*C
    plus optional row/col bias — the math lives in GemmDefaultSm90/Sm100;
    confirm there.
    """
    varlen = cu_seqlens_m is not None or cu_seqlens_k is not None
    assert not (cu_seqlens_m is not None and cu_seqlens_k is not None), (
        "Only one of cu_seqlens_m and cu_seqlens_k can be specified"
    )
    gather_A = A_idx is not None
    if gather_A:
        assert varlen, "gather_A requires varlen (cu_seqlens_m or cu_seqlens_k must be specified)"
        assert cluster_N == 1, "gather_A requires cluster_N=1"
    if varlen:
        assert persistent, "varlen requires persistent=True"
    if add_to_output:
        assert cu_seqlens_m is None, "Add to output not supported with varlen_m"
    # Varlen kernels require specific contiguity (majorness) of the operands.
    if cu_seqlens_m is not None:
        assert A.stride(-1) == 1, "varlen_m requires A to be k-major"
        assert D.stride(-1) == 1, "varlen_m requires D to be n-major"
    if cu_seqlens_k is not None:
        assert A.stride(-2) == 1, "varlen_k requires A to be m-major"
        assert B.stride(-2) == 1, "varlen_k requires B to be n-major"

    L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(
        A, B, D, C, cu_seqlens_m=cu_seqlens_m, cu_seqlens_k=cu_seqlens_k, A_idx=A_idx
    )
    GemmWrapperBase.permute_tensors(
        tensor_infos, varlen_m=cu_seqlens_m is not None, varlen_k=cu_seqlens_k is not None
    )
    GemmWrapperBase.extract_dtypes(tensor_infos)
    # Logical (mode, mode, batch) ordering for each operand.
    major_configs = {
        "A": ("m", "k", "l"),
        "B": ("n", "k", "l"),
        "D": ("m", "n", "l"),
        "C": ("m", "n", "l"),
    }
    GemmWrapperBase.determine_major_orders(tensor_infos, major_configs)

    device_capacity = get_device_capacity(A.device)
    assert device_capacity[0] in [9, 10], "Only SM90 and SM100 are supported"
    GemmCls = GemmDefaultSm100 if device_capacity[0] > 9 else GemmDefaultSm90

    acc_dtype = Float32
    tile_shape_mn = (tile_M, tile_N)
    cluster_shape_mnk = (cluster_M, cluster_N, 1)
    if not GemmCls.is_valid_dtypes(
        tensor_infos["A"].dtype,
        tensor_infos["B"].dtype,
        acc_dtype,
        tensor_infos["D"].dtype,
        tensor_infos["A"].major,
        tensor_infos["B"].major,
    ):
        raise TypeError("Skipping due to unsupported combination of types and majors")

    max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
    GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs)

    def scalar_arg(scalar: float | Tensor):
        # Scalars of 1.0 become None (kernel skips the multiply); tensor
        # scalars are passed by gmem pointer so they can change between calls
        # without recompiling.
        if isinstance(scalar, float):
            return Float32(scalar) if scalar != 1.0 else None
        else:
            assert isinstance(scalar, Tensor)
            return make_ptr(Float32, scalar.data_ptr(), cute.AddressSpace.gmem, assumed_align=4)

    epi_args = GemmCls.EpilogueArguments(
        scalar_arg(alpha),
        scalar_arg(beta),
        mRowVecBroadcast=from_dlpack(rowvec_bias.detach(), assumed_align=4).mark_layout_dynamic(
            leading_dim=1
        )
        if rowvec_bias is not None
        else None,
        # colvec bias is 1-D (total_m,) under varlen_m, hence leading_dim=0.
        mColVecBroadcast=from_dlpack(colvec_bias.detach(), assumed_align=4).mark_layout_dynamic(
            leading_dim=1 if cu_seqlens_m is None else 0
        )
        if colvec_bias is not None
        else None,
        add_to_output=add_to_output,
    )
    scheduler_args = GemmWrapperBase.create_scheduler_args(
        max_active_clusters,
        tile_count_semaphore,
        batch_idx_permute,
        max_swizzle_size,
    )

    # Create varlen arguments if needed (assumes persistent=True when varlen)
    varlen_args = GemmWrapperBase.create_varlen_args(
        cu_seqlens_m,
        cu_seqlens_k,
        A_idx,
        max_active_clusters,
        cluster_shape_mnk,
        tensor_infos,
        GemmCls.num_epi_tensormaps,
        pingpong,
    )

    current_stream = cutlass_torch.current_stream()
    compile_key = GemmWrapperBase.get_compile_key(
        tensor_infos,
        None,  # activation
        tile_shape_mn,
        cluster_shape_mnk,
        pingpong,
        persistent,
        tile_count_semaphore is not None,
        device_capacity,
        # Technically we don't need to recompile for different max_swizzle_size, but currently
        # not recompiling will skew the autotuning results due to power throttling.
        # Effectively we're recompiling as a way to pause between benchmarks during autotuning.
        max_swizzle_size,
        rowvec_bias.dtype if rowvec_bias is not None else None,
        colvec_bias.dtype if colvec_bias is not None else None,
        # Encode each scalar as 0 (non-trivial float), 1 (identity 1.0), or
        # 2 (tensor-backed) — each case generates different kernel code.
        2 if isinstance(alpha, Tensor) else (1 if alpha == 1.0 else 0),
        2 if isinstance(beta, Tensor) else (1 if beta == 1.0 else 0),
        add_to_output,
        cu_seqlens_m is not None,
        cu_seqlens_k is not None,
        gather_A,
        batch_idx_permute is not None,
        key_tensor_names=("A", "B", "D", "C"),
    )
    cache = gemm.compile_cache
    if compile_key not in cache:
        # pingpong/persistent are constructor-time options on SM90 only.
        if device_capacity[0] == 9:
            GemmCls = partial(GemmCls, pingpong=pingpong, is_persistent=persistent)
        gemm_obj = GemmCls(
            acc_dtype,
            tensor_infos["A"].dtype,
            tile_shape_mn,
            cluster_shape_mnk,
            gather_A=gather_A,
        )
        cache[compile_key] = cute.compile(
            gemm_obj,
            tensor_infos["A"].cute_tensor,
            tensor_infos["B"].cute_tensor,
            tensor_infos["D"].cute_tensor,
            tensor_infos["C"].cute_tensor,
            epi_args,
            scheduler_args,
            varlen_args,
            current_stream,
        )
    cache[compile_key](
        tensor_infos["A"].cute_tensor,
        tensor_infos["B"].cute_tensor,
        tensor_infos["D"].cute_tensor,
        tensor_infos["C"].cute_tensor,
        epi_args,
        scheduler_args,
        varlen_args,
        current_stream,
    )


# Per-process cache of compiled kernels keyed by the full configuration tuple.
gemm.compile_cache = {}
|
build/torch-cuda/quack/gemm_act.py
ADDED
|
@@ -0,0 +1,510 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Wentao Guo, Tri Dao.
|
| 2 |
+
from typing import Tuple, Optional, Callable
|
| 3 |
+
from functools import partial
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
|
| 6 |
+
from torch import Tensor
|
| 7 |
+
|
| 8 |
+
import cutlass
|
| 9 |
+
import cutlass.cute as cute
|
| 10 |
+
import cutlass.utils.hopper_helpers as sm90_utils_og
|
| 11 |
+
import cutlass.utils.blackwell_helpers as sm100_utils
|
| 12 |
+
from cutlass import Int32, Float32, Boolean, const_expr
|
| 13 |
+
from cutlass.cutlass_dsl import if_generate
|
| 14 |
+
import cutlass.torch as cutlass_torch
|
| 15 |
+
from cutlass.cute.runtime import from_dlpack
|
| 16 |
+
|
| 17 |
+
from .cute_dsl_utils import ArgumentsBase, ParamsBase
|
| 18 |
+
from .varlen_utils import VarlenManager
|
| 19 |
+
from .gemm_sm90 import GemmSm90
|
| 20 |
+
from .gemm_sm100 import GemmSm100
|
| 21 |
+
from .gemm_default_epi import GemmDefaultEpiMixin
|
| 22 |
+
from .cute_dsl_utils import get_device_capacity, get_max_active_clusters
|
| 23 |
+
from .gemm_wrapper_utils import GemmWrapperBase
|
| 24 |
+
from . import sm90_utils as sm90_utils
|
| 25 |
+
from . import copy_utils as copy_utils
|
| 26 |
+
from . import activation
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class GemmActMixin(GemmDefaultEpiMixin):
|
| 30 |
+
num_epi_tensormaps: int = 1
|
| 31 |
+
|
| 32 |
+
@dataclass
|
| 33 |
+
class EpilogueArguments(ArgumentsBase):
|
| 34 |
+
mPostAct: cute.Tensor
|
| 35 |
+
act_fn: cutlass.Constexpr[Optional[Callable]] = None
|
| 36 |
+
alpha: Optional[Float32 | cute.Tensor] = None
|
| 37 |
+
beta: Optional[Float32 | cute.Tensor] = None
|
| 38 |
+
mRowVecBroadcast: Optional[cute.Tensor] = None
|
| 39 |
+
mColVecBroadcast: Optional[cute.Tensor] = None
|
| 40 |
+
|
| 41 |
+
@dataclass
|
| 42 |
+
class EpilogueParams(ParamsBase):
|
| 43 |
+
tma_atom_postact: cute.CopyAtom
|
| 44 |
+
mPostAct_mnl: cute.Tensor
|
| 45 |
+
epi_postact_smem_layout_staged: cute.ComposedLayout
|
| 46 |
+
epi_tile_postact: cute.Tile
|
| 47 |
+
act_fn: cutlass.Constexpr[Optional[Callable]] = None
|
| 48 |
+
alpha: Optional[Float32 | cute.Tensor] = None
|
| 49 |
+
beta: Optional[Float32 | cute.Tensor] = None
|
| 50 |
+
mRowVecBroadcast: Optional[cute.Tensor] = None
|
| 51 |
+
mColVecBroadcast: Optional[cute.Tensor] = None
|
| 52 |
+
|
| 53 |
+
def epi_to_underlying_arguments(
|
| 54 |
+
self, args: EpilogueArguments, *, loc=None, ip=None
|
| 55 |
+
) -> EpilogueParams:
|
| 56 |
+
self.postact_dtype = args.mPostAct.element_type
|
| 57 |
+
self.postact_layout = cutlass.utils.LayoutEnum.from_tensor(args.mPostAct)
|
| 58 |
+
|
| 59 |
+
self.cta_tile_shape_postact_mn = self.cta_tile_shape_mnk[:2]
|
| 60 |
+
epi_tile_postact = self.epi_tile
|
| 61 |
+
utils_cls = sm100_utils if self.arch == 100 else sm90_utils
|
| 62 |
+
epi_postact_smem_layout_staged = utils_cls.make_smem_layout_epi(
|
| 63 |
+
self.postact_dtype, self.postact_layout, epi_tile_postact, self.epi_stage
|
| 64 |
+
)
|
| 65 |
+
tma_atom_postact, tma_tensor_postact = self._make_tma_epi_atoms_and_tensors(
|
| 66 |
+
args.mPostAct,
|
| 67 |
+
epi_postact_smem_layout_staged,
|
| 68 |
+
epi_tile_postact,
|
| 69 |
+
op_type="store",
|
| 70 |
+
)
|
| 71 |
+
# Assume all strides are divisible by 32 bits except the last stride
|
| 72 |
+
new_stride = lambda t: tuple(
|
| 73 |
+
cute.assume(s, divby=32 // t.element_type.width) if not cute.is_static(s) else s
|
| 74 |
+
for s in t.stride
|
| 75 |
+
)
|
| 76 |
+
mRowVecBroadcast, mColVecBroadcast = [
|
| 77 |
+
cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
|
| 78 |
+
if t is not None
|
| 79 |
+
else None
|
| 80 |
+
for t in (args.mRowVecBroadcast, args.mColVecBroadcast)
|
| 81 |
+
]
|
| 82 |
+
return self.EpilogueParams(
|
| 83 |
+
tma_atom_postact,
|
| 84 |
+
tma_tensor_postact,
|
| 85 |
+
epi_postact_smem_layout_staged,
|
| 86 |
+
epi_tile_postact,
|
| 87 |
+
args.act_fn,
|
| 88 |
+
alpha=args.alpha,
|
| 89 |
+
beta=args.beta,
|
| 90 |
+
mRowVecBroadcast=mRowVecBroadcast,
|
| 91 |
+
mColVecBroadcast=mColVecBroadcast,
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
def epi_get_tma_atoms(
|
| 95 |
+
self, params: EpilogueParams, *, loc=None, ip=None
|
| 96 |
+
) -> list[cute.CopyAtom]:
|
| 97 |
+
return [params.tma_atom_postact]
|
| 98 |
+
|
| 99 |
+
def epi_get_tensormap_update_shapes_orders(
|
| 100 |
+
self,
|
| 101 |
+
params: EpilogueParams,
|
| 102 |
+
cu_seqlens_m: Optional[cute.Tensor],
|
| 103 |
+
batch_idx: Int32,
|
| 104 |
+
*,
|
| 105 |
+
loc=None,
|
| 106 |
+
ip=None,
|
| 107 |
+
) -> tuple[list[Int32], list[int]]:
|
| 108 |
+
shapes = [cu_seqlens_m[batch_idx + 1] if cu_seqlens_m is not None else None]
|
| 109 |
+
orders = [0 if const_expr(self.postact_layout.is_m_major_c()) else 1]
|
| 110 |
+
return shapes, orders
|
| 111 |
+
|
| 112 |
+
@staticmethod
|
| 113 |
+
def epi_smem_bytes_per_stage(
|
| 114 |
+
args: EpilogueArguments, cta_tile_shape_mnk: Tuple[int, int, int], epi_tile: cute.Tile
|
| 115 |
+
) -> int:
|
| 116 |
+
postact_dtype = args.mPostAct.element_type
|
| 117 |
+
postact_bytes_per_stage = cute.size(cute.shape(epi_tile)) * (postact_dtype.width // 8)
|
| 118 |
+
rowvec_colvec_bytes = GemmDefaultEpiMixin.epi_smem_bytes_per_stage(
|
| 119 |
+
args, cta_tile_shape_mnk, epi_tile
|
| 120 |
+
)
|
| 121 |
+
return postact_bytes_per_stage + rowvec_colvec_bytes
|
| 122 |
+
|
| 123 |
+
def epi_get_smem_struct(self, params: EpilogueParams):
|
| 124 |
+
row_vec_smem_size = 0 if params.mRowVecBroadcast is None else self.cta_tile_shape_mnk[1]
|
| 125 |
+
col_vec_smem_size = 0 if params.mColVecBroadcast is None else self.cta_tile_shape_mnk[0]
|
| 126 |
+
row_vec_dtype = (
|
| 127 |
+
params.mRowVecBroadcast.element_type if params.mRowVecBroadcast is not None else Float32
|
| 128 |
+
)
|
| 129 |
+
col_vec_dtype = (
|
| 130 |
+
params.mColVecBroadcast.element_type if params.mColVecBroadcast is not None else Float32
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
@cute.struct
|
| 134 |
+
class EpiSharedStorage:
|
| 135 |
+
sRowVec: cute.struct.Align[cute.struct.MemRange[row_vec_dtype, row_vec_smem_size], 16]
|
| 136 |
+
sColVec: cute.struct.Align[cute.struct.MemRange[col_vec_dtype, col_vec_smem_size], 16]
|
| 137 |
+
sPostAct: cute.struct.Align[
|
| 138 |
+
cute.struct.MemRange[
|
| 139 |
+
self.postact_dtype, cute.cosize(params.epi_postact_smem_layout_staged)
|
| 140 |
+
],
|
| 141 |
+
self.buffer_align_bytes,
|
| 142 |
+
]
|
| 143 |
+
|
| 144 |
+
return EpiSharedStorage
|
| 145 |
+
|
| 146 |
+
def epi_get_smem_tensors(self, params: EpilogueParams, storage) -> Tuple[cute.Tensor, ...]:
|
| 147 |
+
sRowVec, sColVec = super().epi_get_smem_tensors(params, storage)
|
| 148 |
+
sPostAct = storage.epi.sPostAct.get_tensor(
|
| 149 |
+
params.epi_postact_smem_layout_staged.outer,
|
| 150 |
+
swizzle=params.epi_postact_smem_layout_staged.inner,
|
| 151 |
+
)
|
| 152 |
+
return (sRowVec, sColVec, sPostAct)
|
| 153 |
+
|
| 154 |
+
@cute.jit
|
| 155 |
+
def epilogue(
|
| 156 |
+
self,
|
| 157 |
+
params: EpilogueParams,
|
| 158 |
+
epi_smem_tensors: Tuple[cute.Tensor, ...],
|
| 159 |
+
tma_desc_epi_ptrs: list[Optional[cute.Pointer]],
|
| 160 |
+
epi_pipeline: cutlass.pipeline.PipelineAsync,
|
| 161 |
+
epi_store_pipeline: cutlass.pipeline.PipelineAsync,
|
| 162 |
+
epi_read_state: cutlass.pipeline.PipelineState,
|
| 163 |
+
epi_producer_state: cutlass.pipeline.PipelineState,
|
| 164 |
+
epi_tile: cute.Tile,
|
| 165 |
+
load_acc_subtile: Callable,
|
| 166 |
+
tRS_rD: cute.Tensor,
|
| 167 |
+
tRS_rC: Optional[cute.Tensor],
|
| 168 |
+
tiled_copy_t2r: Optional[cute.TiledCopy], # Only for Sm100
|
| 169 |
+
tiled_copy_r2s: cute.TiledCopy,
|
| 170 |
+
tRS_sD: cute.Tensor,
|
| 171 |
+
tiled_copy_s2r: Optional[cute.TiledCopy],
|
| 172 |
+
tSR_rC: Optional[cute.Tensor],
|
| 173 |
+
tSR_sC: Optional[cute.Tensor],
|
| 174 |
+
copy_D: Optional[Callable],
|
| 175 |
+
copy_C: Optional[Callable],
|
| 176 |
+
tile_coord_mnkl: cute.Coord,
|
| 177 |
+
varlen_manager: VarlenManager,
|
| 178 |
+
epilogue_barrier: cutlass.pipeline.NamedBarrier,
|
| 179 |
+
tile_scheduler,
|
| 180 |
+
tidx: Int32,
|
| 181 |
+
is_tma_warp: Boolean,
|
| 182 |
+
) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState]:
|
| 183 |
+
has_C = const_expr(tRS_rC is not None)
|
| 184 |
+
has_D = const_expr(copy_D is not None)
|
| 185 |
+
|
| 186 |
+
tma_atom_postact = params.tma_atom_postact
|
| 187 |
+
mPostAct_mnl = params.mPostAct_mnl
|
| 188 |
+
sRowVec, sColVec, sPostAct = epi_smem_tensors
|
| 189 |
+
get_smem_store_op = (
|
| 190 |
+
partial(sm100_utils.get_smem_store_op, tiled_tmem_load=tiled_copy_t2r)
|
| 191 |
+
if self.arch == 100
|
| 192 |
+
else sm90_utils_og.sm90_get_smem_store_op
|
| 193 |
+
)
|
| 194 |
+
copy_atom_postact_r2s = get_smem_store_op(
|
| 195 |
+
self.postact_layout, self.postact_dtype, self.acc_dtype
|
| 196 |
+
)
|
| 197 |
+
# tiled_copy_C_atom = self.epilog_smem_copy_atom(tiled_mma)
|
| 198 |
+
# tiled_copy_postact_r2s = cute.make_tiled_copy_S(copy_atom_postact_r2s, tiled_copy_C_atom)
|
| 199 |
+
tiled_copy_postact_r2s = cute.make_tiled_copy_S(copy_atom_postact_r2s, tiled_copy_r2s)
|
| 200 |
+
tRS_sPostAct = tiled_copy_postact_r2s.get_slice(tidx).partition_D(sPostAct)
|
| 201 |
+
(tma_desc_postact_ptr,) = tma_desc_epi_ptrs
|
| 202 |
+
batch_idx = tile_coord_mnkl[3]
|
| 203 |
+
copy_postact, _, _ = self.epilog_gmem_copy_and_partition(
|
| 204 |
+
tma_atom_postact,
|
| 205 |
+
varlen_manager.offset_batch_epi(mPostAct_mnl, batch_idx),
|
| 206 |
+
self.cta_tile_shape_postact_mn,
|
| 207 |
+
params.epi_tile_postact,
|
| 208 |
+
sPostAct,
|
| 209 |
+
tile_coord_mnkl,
|
| 210 |
+
tma_desc_ptr=tma_desc_postact_ptr,
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
# We iterate over epi tiles in the N dimension first before the M dimension
|
| 214 |
+
epi_tile_shape = cute.zipped_divide(
|
| 215 |
+
cute.make_layout(self.cta_tile_shape_mnk[:2]), epi_tile
|
| 216 |
+
).shape[1]
|
| 217 |
+
epi_tile_layout = cute.make_layout(epi_tile_shape, stride=(epi_tile_shape[1], 1))
|
| 218 |
+
epi_tile_num = cute.size(epi_tile_shape)
|
| 219 |
+
num_prev_subtiles = tile_scheduler.num_tiles_executed * epi_tile_num
|
| 220 |
+
|
| 221 |
+
epi_tensors = self.epi_begin(
|
| 222 |
+
params,
|
| 223 |
+
epi_smem_tensors,
|
| 224 |
+
epi_tile,
|
| 225 |
+
tiled_copy_t2r,
|
| 226 |
+
tiled_copy_r2s,
|
| 227 |
+
tile_coord_mnkl,
|
| 228 |
+
varlen_manager,
|
| 229 |
+
epilogue_barrier,
|
| 230 |
+
tidx,
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
if const_expr(copy_C is not None):
|
| 234 |
+
for epi_idx in cutlass.range(min(epi_tile_num, self.epi_c_stage), unroll=1):
|
| 235 |
+
gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx)
|
| 236 |
+
if is_tma_warp:
|
| 237 |
+
epi_pipeline.producer_acquire(epi_producer_state)
|
| 238 |
+
copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state)
|
| 239 |
+
epi_pipeline.producer_commit(epi_producer_state)
|
| 240 |
+
epi_producer_state.advance()
|
| 241 |
+
|
| 242 |
+
def tma_store_fn(src_idx, dst_idx):
|
| 243 |
+
# Fence and barrier to make sure shared memory store is visible to TMA store
|
| 244 |
+
cute.arch.fence_proxy(
|
| 245 |
+
cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
|
| 246 |
+
)
|
| 247 |
+
epilogue_barrier.arrive_and_wait()
|
| 248 |
+
# Copy from shared memory to global memory
|
| 249 |
+
if is_tma_warp:
|
| 250 |
+
if const_expr(has_D):
|
| 251 |
+
copy_D(src_idx=src_idx, dst_idx=dst_idx)
|
| 252 |
+
copy_postact(src_idx=src_idx, dst_idx=dst_idx)
|
| 253 |
+
# Can't use if statement here, epi_store_pipeline object isn't captured somehow
|
| 254 |
+
if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_commit())
|
| 255 |
+
if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_acquire())
|
| 256 |
+
epilogue_barrier.arrive_and_wait()
|
| 257 |
+
|
| 258 |
+
delay_tma_store = True
|
| 259 |
+
|
| 260 |
+
src_idx_prev, dst_idx_prev = None, None
|
| 261 |
+
for epi_idx in cutlass.range_constexpr(epi_tile_num):
|
| 262 |
+
# The global memory coordinate for the current epi tile
|
| 263 |
+
gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
|
| 264 |
+
# Copy from acc to D registers
|
| 265 |
+
load_acc_subtile(tRS_rD, epi_idx)
|
| 266 |
+
epi_loop_tensors = self.epi_begin_loop(params, epi_tensors, gmem_coord)
|
| 267 |
+
if const_expr(has_C):
|
| 268 |
+
epi_pipeline.consumer_wait(epi_read_state)
|
| 269 |
+
cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC)
|
| 270 |
+
# Fence to make sure shared memory read is visible to TMA load
|
| 271 |
+
cute.arch.fence_proxy(
|
| 272 |
+
cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
|
| 273 |
+
)
|
| 274 |
+
cute.arch.sync_warp()
|
| 275 |
+
with cute.arch.elect_one():
|
| 276 |
+
epi_pipeline.consumer_release(epi_read_state)
|
| 277 |
+
epi_read_state.advance()
|
| 278 |
+
if const_expr(copy_C is not None and epi_idx + self.epi_c_stage < epi_tile_num):
|
| 279 |
+
gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx + self.epi_c_stage)
|
| 280 |
+
if is_tma_warp:
|
| 281 |
+
epi_pipeline.producer_acquire(epi_producer_state)
|
| 282 |
+
copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state)
|
| 283 |
+
epi_pipeline.producer_commit(epi_producer_state)
|
| 284 |
+
epi_producer_state.advance()
|
| 285 |
+
tRS_rPostAct = self.epi_visit_subtile(params, epi_loop_tensors, tRS_rD, tRS_rC)
|
| 286 |
+
epi_buffer = (num_prev_subtiles + epi_idx) % self.epi_stage
|
| 287 |
+
if const_expr(delay_tma_store):
|
| 288 |
+
if const_expr(epi_idx > 0):
|
| 289 |
+
tma_store_fn(src_idx=src_idx_prev, dst_idx=dst_idx_prev)
|
| 290 |
+
src_idx_prev, dst_idx_prev = epi_buffer, gmem_coord
|
| 291 |
+
# Copy from D registers to shared memory
|
| 292 |
+
if const_expr(has_D):
|
| 293 |
+
copy_utils.cvt_copy(tiled_copy_r2s, tRS_rD, tRS_sD[None, None, None, epi_buffer])
|
| 294 |
+
cute.copy(
|
| 295 |
+
tiled_copy_postact_r2s,
|
| 296 |
+
tiled_copy_postact_r2s.retile(tRS_rPostAct),
|
| 297 |
+
tRS_sPostAct[None, None, None, epi_buffer],
|
| 298 |
+
)
|
| 299 |
+
if const_expr(not delay_tma_store):
|
| 300 |
+
tma_store_fn(src_idx=epi_buffer, dst_idx=gmem_coord)
|
| 301 |
+
|
| 302 |
+
if const_expr(delay_tma_store):
|
| 303 |
+
tma_store_fn(src_idx=src_idx_prev, dst_idx=dst_idx_prev)
|
| 304 |
+
|
| 305 |
+
self.epi_end(
|
| 306 |
+
params,
|
| 307 |
+
epi_tensors,
|
| 308 |
+
epi_tile,
|
| 309 |
+
tiled_copy_t2r,
|
| 310 |
+
tiled_copy_r2s,
|
| 311 |
+
tile_coord_mnkl,
|
| 312 |
+
varlen_manager,
|
| 313 |
+
tidx,
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
return epi_read_state, epi_producer_state
|
| 317 |
+
|
| 318 |
+
@cute.jit
|
| 319 |
+
def epi_visit_subtile(
|
| 320 |
+
self,
|
| 321 |
+
params: EpilogueParams,
|
| 322 |
+
epi_loop_tensors: Tuple[cute.Tensor, ...],
|
| 323 |
+
tRS_rD: cute.Tensor,
|
| 324 |
+
tRS_rC: Optional[cute.Tensor] = None,
|
| 325 |
+
) -> Optional[cute.Tensor]:
|
| 326 |
+
GemmDefaultEpiMixin.epi_visit_subtile(self, params, epi_loop_tensors, tRS_rD, tRS_rC)
|
| 327 |
+
# Apply activation function if provided
|
| 328 |
+
# If we don't have .shape here, the compiler generates local stores and loads
|
| 329 |
+
if const_expr(params.act_fn is not None):
|
| 330 |
+
tRS_rPostAct = cute.make_fragment(tRS_rD.layout.shape, self.acc_dtype)
|
| 331 |
+
if const_expr(self.arch < 100):
|
| 332 |
+
for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True):
|
| 333 |
+
tRS_rPostAct[i] = params.act_fn(tRS_rD[i])
|
| 334 |
+
else:
|
| 335 |
+
for i in cutlass.range(cute.size(tRS_rPostAct) // 2, unroll_full=True):
|
| 336 |
+
tRS_rPostAct[2 * i], tRS_rPostAct[2 * i + 1] = params.act_fn(
|
| 337 |
+
(tRS_rD[2 * i], tRS_rD[2 * i + 1])
|
| 338 |
+
)
|
| 339 |
+
else:
|
| 340 |
+
tRS_rPostAct = tRS_rD
|
| 341 |
+
# Type conversion
|
| 342 |
+
tRS_rPostAct_out = cute.make_fragment_like(tRS_rPostAct, self.postact_dtype)
|
| 343 |
+
tRS_rPostAct_out.store(tRS_rPostAct.load().to(self.postact_dtype))
|
| 344 |
+
return tRS_rPostAct_out
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
class GemmActSm90(GemmActMixin, GemmSm90):
|
| 348 |
+
pass
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
class GemmActSm100(GemmActMixin, GemmSm100):
|
| 352 |
+
pass
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
act_fn_map = {
|
| 356 |
+
None: None,
|
| 357 |
+
"relu": activation.relu,
|
| 358 |
+
"relu_sq": activation.relu_sq,
|
| 359 |
+
"gelu_tanh_approx": activation.gelu_tanh_approx,
|
| 360 |
+
}
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
def gemm_act(
|
| 364 |
+
A: Tensor, # (l, m, k) or (total_m, k) if varlen_m or (whatever, k) if gather_A with varlen_m
|
| 365 |
+
B: Tensor, # (l, n, k)
|
| 366 |
+
D: Optional[Tensor], # (l, m, n) or (total_m, n) if varlen_m
|
| 367 |
+
C: Optional[Tensor], # (l, m, n) or (total_m, n) if varlen_m
|
| 368 |
+
PostAct: Tensor, # (l, m, n) or (total_m, n) if varlen_m
|
| 369 |
+
tile_count_semaphore: Optional[Tensor], # (1,)
|
| 370 |
+
activation: Optional[str],
|
| 371 |
+
tile_M: int,
|
| 372 |
+
tile_N: int,
|
| 373 |
+
cluster_M: int,
|
| 374 |
+
cluster_N: int,
|
| 375 |
+
pingpong: bool = False,
|
| 376 |
+
persistent: bool = True,
|
| 377 |
+
max_swizzle_size: int = 8,
|
| 378 |
+
rowvec_bias: Optional[Tensor] = None, # (l, n)
|
| 379 |
+
colvec_bias: Optional[Tensor] = None, # (l, m), or (total_m,) if varlen_m
|
| 380 |
+
cu_seqlens_m: Optional[Tensor] = None, # (l+1,) cumulative sum of m values for variable length
|
| 381 |
+
A_idx: Optional[Tensor] = None, # (total_m,) if gather_A with varlen_m
|
| 382 |
+
) -> None:
|
| 383 |
+
if cu_seqlens_m is not None:
|
| 384 |
+
assert persistent, "varlen_m requires persistent=True"
|
| 385 |
+
assert A.stride(-1) == 1, "varlen_m requires A to be k-major"
|
| 386 |
+
if D is not None:
|
| 387 |
+
assert D.stride(-1) == 1, "varlen_m requires D to be n-major"
|
| 388 |
+
assert PostAct.stride(-1) == 1, "varlen_m requires PostAct to be n-major"
|
| 389 |
+
gather_A = A_idx is not None
|
| 390 |
+
if gather_A:
|
| 391 |
+
assert cu_seqlens_m is not None, "gather_A requires varlen (cu_seqlens_m must be specified)"
|
| 392 |
+
assert cluster_N == 1, "gather_A requires cluster_N=1"
|
| 393 |
+
assert activation in act_fn_map, f"Unsupported activation {activation}"
|
| 394 |
+
|
| 395 |
+
L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(
|
| 396 |
+
A, B, D, C, additional_tensors={"PostAct": PostAct}, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx
|
| 397 |
+
)
|
| 398 |
+
GemmWrapperBase.permute_tensors(tensor_infos, varlen_m=cu_seqlens_m is not None)
|
| 399 |
+
GemmWrapperBase.extract_dtypes(tensor_infos)
|
| 400 |
+
major_configs = {
|
| 401 |
+
"A": ("m", "k", "l"),
|
| 402 |
+
"B": ("n", "k", "l"),
|
| 403 |
+
"D": ("m", "n", "l"),
|
| 404 |
+
"C": ("m", "n", "l"),
|
| 405 |
+
"PostAct": ("m", "n", "l"),
|
| 406 |
+
}
|
| 407 |
+
GemmWrapperBase.determine_major_orders(tensor_infos, major_configs)
|
| 408 |
+
|
| 409 |
+
device_capacity = get_device_capacity(A.device)
|
| 410 |
+
assert device_capacity[0] in [9, 10], "Only SM90 and SM100 are supported"
|
| 411 |
+
GemmCls = GemmActSm100 if device_capacity[0] > 9 else GemmActSm90
|
| 412 |
+
|
| 413 |
+
acc_dtype = Float32
|
| 414 |
+
tile_shape_mn = (tile_M, tile_N)
|
| 415 |
+
cluster_shape_mnk = (cluster_M, cluster_N, 1)
|
| 416 |
+
if not GemmCls.is_valid_dtypes(
|
| 417 |
+
tensor_infos["A"].dtype,
|
| 418 |
+
tensor_infos["B"].dtype,
|
| 419 |
+
acc_dtype,
|
| 420 |
+
tensor_infos["D"].dtype,
|
| 421 |
+
tensor_infos["A"].major,
|
| 422 |
+
tensor_infos["B"].major,
|
| 423 |
+
):
|
| 424 |
+
raise TypeError("Skipping due to unsupported combination of types and majors")
|
| 425 |
+
|
| 426 |
+
max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
|
| 427 |
+
GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs)
|
| 428 |
+
act_fn = act_fn_map[activation]
|
| 429 |
+
epi_args = GemmCls.EpilogueArguments(
|
| 430 |
+
tensor_infos["PostAct"].cute_tensor,
|
| 431 |
+
act_fn,
|
| 432 |
+
mRowVecBroadcast=from_dlpack(rowvec_bias.detach(), assumed_align=4).mark_layout_dynamic(
|
| 433 |
+
leading_dim=1
|
| 434 |
+
)
|
| 435 |
+
if rowvec_bias is not None
|
| 436 |
+
else None,
|
| 437 |
+
mColVecBroadcast=from_dlpack(colvec_bias.detach(), assumed_align=4).mark_layout_dynamic(
|
| 438 |
+
leading_dim=1 if cu_seqlens_m is None else 0
|
| 439 |
+
)
|
| 440 |
+
if colvec_bias is not None
|
| 441 |
+
else None,
|
| 442 |
+
)
|
| 443 |
+
scheduler_args = GemmWrapperBase.create_scheduler_args(
|
| 444 |
+
max_active_clusters, tile_count_semaphore, max_swizzle_size=max_swizzle_size
|
| 445 |
+
)
|
| 446 |
+
|
| 447 |
+
# Create varlen arguments if needed (assumes persistent=True when varlen_m)
|
| 448 |
+
varlen_args = GemmWrapperBase.create_varlen_args(
|
| 449 |
+
cu_seqlens_m,
|
| 450 |
+
None, # cu_seqlens_k
|
| 451 |
+
A_idx,
|
| 452 |
+
max_active_clusters,
|
| 453 |
+
cluster_shape_mnk,
|
| 454 |
+
tensor_infos,
|
| 455 |
+
GemmCls.num_epi_tensormaps,
|
| 456 |
+
pingpong,
|
| 457 |
+
)
|
| 458 |
+
|
| 459 |
+
current_stream = cutlass_torch.current_stream()
|
| 460 |
+
compile_key = GemmWrapperBase.get_compile_key(
|
| 461 |
+
tensor_infos,
|
| 462 |
+
activation,
|
| 463 |
+
tile_shape_mn,
|
| 464 |
+
cluster_shape_mnk,
|
| 465 |
+
pingpong,
|
| 466 |
+
persistent,
|
| 467 |
+
tile_count_semaphore is not None,
|
| 468 |
+
device_capacity,
|
| 469 |
+
max_swizzle_size,
|
| 470 |
+
rowvec_bias.dtype if rowvec_bias is not None else None,
|
| 471 |
+
colvec_bias.dtype if colvec_bias is not None else None,
|
| 472 |
+
cu_seqlens_m is not None,
|
| 473 |
+
A_idx is not None,
|
| 474 |
+
key_tensor_names=("A", "B", "D", "PostAct", "C"),
|
| 475 |
+
)
|
| 476 |
+
cache = gemm_act.compile_cache
|
| 477 |
+
if compile_key not in cache:
|
| 478 |
+
if device_capacity[0] == 9:
|
| 479 |
+
GemmCls = partial(GemmCls, pingpong=pingpong, is_persistent=persistent)
|
| 480 |
+
gemm_obj = GemmCls(
|
| 481 |
+
acc_dtype,
|
| 482 |
+
tensor_infos["A"].dtype,
|
| 483 |
+
tile_shape_mn,
|
| 484 |
+
cluster_shape_mnk,
|
| 485 |
+
gather_A=gather_A,
|
| 486 |
+
)
|
| 487 |
+
cache[compile_key] = cute.compile(
|
| 488 |
+
gemm_obj,
|
| 489 |
+
tensor_infos["A"].cute_tensor,
|
| 490 |
+
tensor_infos["B"].cute_tensor,
|
| 491 |
+
tensor_infos["D"].cute_tensor,
|
| 492 |
+
tensor_infos["C"].cute_tensor,
|
| 493 |
+
epi_args,
|
| 494 |
+
scheduler_args,
|
| 495 |
+
varlen_args,
|
| 496 |
+
current_stream,
|
| 497 |
+
)
|
| 498 |
+
cache[compile_key](
|
| 499 |
+
tensor_infos["A"].cute_tensor,
|
| 500 |
+
tensor_infos["B"].cute_tensor,
|
| 501 |
+
tensor_infos["D"].cute_tensor,
|
| 502 |
+
tensor_infos["C"].cute_tensor,
|
| 503 |
+
epi_args,
|
| 504 |
+
scheduler_args,
|
| 505 |
+
varlen_args,
|
| 506 |
+
current_stream,
|
| 507 |
+
)
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
gemm_act.compile_cache = {}
|
build/torch-cuda/quack/gemm_config.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2025, Fri Dao.
|
| 2 |
+
import itertools
|
| 3 |
+
from typing import Optional, List, Literal
|
| 4 |
+
from functools import partial
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@dataclass(frozen=True)
|
| 9 |
+
class GemmConfig:
|
| 10 |
+
tile_m: int = 128
|
| 11 |
+
tile_n: int = 192
|
| 12 |
+
pingpong: bool = True
|
| 13 |
+
cluster_m: int = 2
|
| 14 |
+
cluster_n: int = 1
|
| 15 |
+
swap_ab: bool = False
|
| 16 |
+
# raster_order: int = 1
|
| 17 |
+
max_swizzle_size: int = 8
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def get_all_configs(
|
| 21 |
+
device_capacity: Literal[9, 10] = 9,
|
| 22 |
+
epilogue: Optional[str] = None,
|
| 23 |
+
tune_coop: bool = True,
|
| 24 |
+
# tune_raster_order=True,
|
| 25 |
+
) -> List[GemmConfig]:
|
| 26 |
+
assert device_capacity in [9, 10]
|
| 27 |
+
if device_capacity == 9:
|
| 28 |
+
tile_n_vals = [128, 144, 160, 176, 192, 208]
|
| 29 |
+
tile_mn_coop_vals = [(256, tile_n) for tile_n in tile_n_vals] + [
|
| 30 |
+
(128, 224),
|
| 31 |
+
(128, 256),
|
| 32 |
+
# (192, 256), # Getting IOT instruction (core dumped) in the bwd
|
| 33 |
+
]
|
| 34 |
+
tile_mn_pingpong_vals = [(128, tile_n) for tile_n in tile_n_vals] + [(192, 128)]
|
| 35 |
+
if epilogue in ["gated"]:
|
| 36 |
+
tile_mn_coop_vals = [(m, n) for m, n in tile_mn_coop_vals if n % 32 == 0 and m != 192]
|
| 37 |
+
tile_mn_pingpong_vals = [(m, n) for m, n in tile_mn_pingpong_vals if n % 32 == 0]
|
| 38 |
+
elif epilogue in ["lse"]:
|
| 39 |
+
tile_mn_coop_vals = [(m, n) for m, n in tile_mn_coop_vals if m != 192]
|
| 40 |
+
tile_mn_vals = []
|
| 41 |
+
if tune_coop:
|
| 42 |
+
tile_mn_vals += [(m, n, False) for m, n in tile_mn_coop_vals]
|
| 43 |
+
tile_mn_vals += [(m, n, True) for m, n in tile_mn_pingpong_vals]
|
| 44 |
+
cluster = [(1, 2), (2, 1)]
|
| 45 |
+
# cluster = [(1, 1), (1, 2), (2, 1)]
|
| 46 |
+
if epilogue in ["lse"]:
|
| 47 |
+
cluster = [(1, 2), (2, 1)]
|
| 48 |
+
swap_ab_vals = [False, True]
|
| 49 |
+
if epilogue in ["lse", "gated"]:
|
| 50 |
+
swap_ab_vals = [False]
|
| 51 |
+
# raster_swizzle = (
|
| 52 |
+
# [(0, 1)]
|
| 53 |
+
# if not tune_raster_order
|
| 54 |
+
# else [(1, 1), (1, 2), (1, 4), (1, 8), (2, 1), (2, 2), (2, 4), (2, 8)]
|
| 55 |
+
# )
|
| 56 |
+
return [
|
| 57 |
+
GemmConfig(
|
| 58 |
+
tile_m=tile_m,
|
| 59 |
+
tile_n=tile_n,
|
| 60 |
+
pingpong=pingpong,
|
| 61 |
+
cluster_m=cluster_m,
|
| 62 |
+
cluster_n=cluster_n,
|
| 63 |
+
swap_ab=swap_ab,
|
| 64 |
+
# raster_order=raster_order,
|
| 65 |
+
# max_swizzle_size=max_swizzle_size,
|
| 66 |
+
)
|
| 67 |
+
for (tile_m, tile_n, pingpong), (cluster_m, cluster_n), swap_ab in itertools.product(
|
| 68 |
+
tile_mn_vals,
|
| 69 |
+
cluster,
|
| 70 |
+
swap_ab_vals,
|
| 71 |
+
# raster_swizzle,
|
| 72 |
+
)
|
| 73 |
+
]
|
| 74 |
+
elif device_capacity == 10:
|
| 75 |
+
tile_n_vals = [128, 160, 192, 224, 256]
|
| 76 |
+
tile_n_64_vals = [128, 192, 256]
|
| 77 |
+
tile_mn_cluster_vals = (
|
| 78 |
+
[(128, tile_n, (1, 2)) for tile_n in tile_n_vals]
|
| 79 |
+
# + [(128, tile_n, (2, 1)) for tile_n in tile_n_64_vals]
|
| 80 |
+
+ [(128, tile_n, (2, 1)) for tile_n in tile_n_vals]
|
| 81 |
+
+ [(256, tile_n, (2, 1)) for tile_n in tile_n_vals]
|
| 82 |
+
)
|
| 83 |
+
swap_ab_vals = [False, True]
|
| 84 |
+
if epilogue in ["lse", "gated"]:
|
| 85 |
+
swap_ab_vals = [False]
|
| 86 |
+
max_swizzle_size_vals = [4, 8, 16]
|
| 87 |
+
GemmConfigCls = partial(GemmConfig, pingpong=False) # There's no pingpong on Sm100
|
| 88 |
+
return [
|
| 89 |
+
GemmConfigCls(
|
| 90 |
+
tile_m=m, tile_n=n, cluster_m=cm, cluster_n=cn, swap_ab=sab, max_swizzle_size=ms
|
| 91 |
+
)
|
| 92 |
+
for (m, n, (cm, cn)), sab, ms in itertools.product(
|
| 93 |
+
tile_mn_cluster_vals, swap_ab_vals, max_swizzle_size_vals
|
| 94 |
+
)
|
| 95 |
+
]
|
build/torch-cuda/quack/gemm_dact.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Tri Dao.
|
| 2 |
+
from typing import Optional, Tuple
|
| 3 |
+
from functools import partial
|
| 4 |
+
|
| 5 |
+
from torch import Tensor
|
| 6 |
+
|
| 7 |
+
import cutlass
|
| 8 |
+
import cutlass.cute as cute
|
| 9 |
+
from cutlass import Float32, const_expr
|
| 10 |
+
import cutlass.torch as cutlass_torch
|
| 11 |
+
|
| 12 |
+
from .gemm_sm90 import GemmSm90
|
| 13 |
+
from .gemm_sm100 import GemmSm100
|
| 14 |
+
from .gemm_default_epi import GemmDefaultEpiMixin
|
| 15 |
+
from .gemm_act import GemmActMixin
|
| 16 |
+
from .cute_dsl_utils import get_device_capacity, get_max_active_clusters
|
| 17 |
+
from .gemm_wrapper_utils import GemmWrapperBase
|
| 18 |
+
from . import activation
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class GemmDActMixin(GemmActMixin):
|
| 22 |
+
# Different from GemmActSm90, here act_bwd_fn must take in 2 arguments (x, dout)
|
| 23 |
+
# and return 2 arguments (dx, out)
|
| 24 |
+
EpilogueArguments = GemmActMixin.EpilogueArguments
|
| 25 |
+
EpilogueParams = GemmActMixin.EpilogueParams
|
| 26 |
+
|
| 27 |
+
@cute.jit
|
| 28 |
+
def epi_visit_subtile(
|
| 29 |
+
self,
|
| 30 |
+
params: EpilogueParams,
|
| 31 |
+
epi_loop_tensors: Tuple[cute.Tensor, ...],
|
| 32 |
+
tRS_rD: cute.Tensor,
|
| 33 |
+
tRS_rC: Optional[cute.Tensor] = None,
|
| 34 |
+
) -> Optional[cute.Tensor]:
|
| 35 |
+
assert tRS_rC is not None
|
| 36 |
+
# We don't add C to the accumulator
|
| 37 |
+
GemmDefaultEpiMixin.epi_visit_subtile(self, params, epi_loop_tensors, tRS_rD, tRS_rC=None)
|
| 38 |
+
tRS_rC_acc = cute.make_fragment_like(tRS_rC, self.acc_dtype)
|
| 39 |
+
tRS_rC_acc.store(tRS_rC.load().to(self.acc_dtype))
|
| 40 |
+
# If we don't have .shape here, the compiler generates local stores and loads
|
| 41 |
+
if const_expr(params.act_fn is not None):
|
| 42 |
+
tRS_rPostAct = cute.make_fragment(tRS_rD.layout.shape, self.acc_dtype)
|
| 43 |
+
if const_expr(self.arch < 100):
|
| 44 |
+
for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True):
|
| 45 |
+
tRS_rD[i], tRS_rPostAct[i] = params.act_fn(tRS_rC_acc[i], tRS_rD[i])
|
| 46 |
+
else:
|
| 47 |
+
for i in cutlass.range(cute.size(tRS_rPostAct) // 2, unroll_full=True):
|
| 48 |
+
(
|
| 49 |
+
(tRS_rD[2 * i], tRS_rD[2 * i + 1]),
|
| 50 |
+
(tRS_rPostAct[2 * i], tRS_rPostAct[2 * i + 1]),
|
| 51 |
+
) = params.act_fn(
|
| 52 |
+
(tRS_rC_acc[2 * i], tRS_rC_acc[2 * i + 1]),
|
| 53 |
+
(tRS_rD[2 * i], tRS_rD[2 * i + 1]),
|
| 54 |
+
)
|
| 55 |
+
else:
|
| 56 |
+
tRS_rPostAct = tRS_rC_acc
|
| 57 |
+
# Type conversion
|
| 58 |
+
tRS_rPostAct_out = cute.make_fragment_like(tRS_rPostAct, self.postact_dtype)
|
| 59 |
+
tRS_rPostAct_out.store(tRS_rPostAct.load().to(self.postact_dtype))
|
| 60 |
+
return tRS_rPostAct_out
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class GemmDActSm90(GemmDActMixin, GemmSm90):
|
| 64 |
+
pass
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class GemmDActSm100(GemmDActMixin, GemmSm100):
|
| 68 |
+
pass
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
dact_fn_map = {
|
| 72 |
+
None: None,
|
| 73 |
+
"relu": activation.drelu,
|
| 74 |
+
"relu_sq": activation.drelu_sq,
|
| 75 |
+
"gelu_tanh_approx": activation.dgelu_tanh_approx,
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def gemm_dact(
|
| 80 |
+
A: Tensor, # (l, m, k) or (total_m, k) if varlen_m or (whatever, k) if gather_A with varlen_m
|
| 81 |
+
B: Tensor, # (l, n, k)
|
| 82 |
+
Out: Tensor, # (l, m, n) or (total_m, n) if varlen_m
|
| 83 |
+
PreAct: Tensor, # (l, m, n) or (total_m, n) if varlen_m
|
| 84 |
+
PostAct: Tensor, # (l, m, n) or (total_m, n) if varlen_m
|
| 85 |
+
tile_count_semaphore: Optional[Tensor], # (1,)
|
| 86 |
+
activation: Optional[str],
|
| 87 |
+
tile_M: int,
|
| 88 |
+
tile_N: int,
|
| 89 |
+
cluster_M: int,
|
| 90 |
+
cluster_N: int,
|
| 91 |
+
pingpong: bool = True,
|
| 92 |
+
persistent: bool = True,
|
| 93 |
+
max_swizzle_size: int = 8,
|
| 94 |
+
cu_seqlens_m: Optional[Tensor] = None, # (l+1,) cumulative sum of m values for variable length
|
| 95 |
+
A_idx: Optional[Tensor] = None, # (total_m,) if gather_A with varlen_m
|
| 96 |
+
) -> None:
|
| 97 |
+
if cu_seqlens_m is not None:
|
| 98 |
+
assert persistent, "varlen_m requires persistent=True"
|
| 99 |
+
assert A.stride(-1) == 1, "varlen_m requires A to be k-major"
|
| 100 |
+
assert Out.stride(-1) == 1, "varlen_m requires Out to be n-major"
|
| 101 |
+
assert PreAct.stride(-1) == 1, "varlen_m requires PreAct to be n-major"
|
| 102 |
+
assert PostAct.stride(-1) == 1, "varlen_m requires PostAct to be n-major"
|
| 103 |
+
gather_A = A_idx is not None
|
| 104 |
+
if gather_A:
|
| 105 |
+
assert cu_seqlens_m is not None, "gather_A requires varlen (cu_seqlens_m must be specified)"
|
| 106 |
+
assert cluster_N == 1, "gather_A requires cluster_N=1"
|
| 107 |
+
assert activation in dact_fn_map, f"Unsupported activation {activation}"
|
| 108 |
+
|
| 109 |
+
L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(
|
| 110 |
+
A,
|
| 111 |
+
B,
|
| 112 |
+
Out,
|
| 113 |
+
PreAct,
|
| 114 |
+
additional_tensors={"PostAct": PostAct},
|
| 115 |
+
cu_seqlens_m=cu_seqlens_m,
|
| 116 |
+
A_idx=A_idx,
|
| 117 |
+
)
|
| 118 |
+
GemmWrapperBase.permute_tensors(tensor_infos, varlen_m=cu_seqlens_m is not None)
|
| 119 |
+
GemmWrapperBase.extract_dtypes(tensor_infos)
|
| 120 |
+
major_configs = {
|
| 121 |
+
"A": ("m", "k", "l"),
|
| 122 |
+
"B": ("n", "k", "l"),
|
| 123 |
+
"D": ("m", "n", "l"),
|
| 124 |
+
"C": ("m", "n", "l"),
|
| 125 |
+
"PostAct": ("m", "n", "l"),
|
| 126 |
+
}
|
| 127 |
+
GemmWrapperBase.determine_major_orders(tensor_infos, major_configs)
|
| 128 |
+
|
| 129 |
+
device_capacity = get_device_capacity(A.device)
|
| 130 |
+
assert device_capacity[0] in [9, 10], "Only SM90 and SM100 are supported"
|
| 131 |
+
GemmCls = GemmDActSm100 if device_capacity[0] > 9 else GemmDActSm90
|
| 132 |
+
|
| 133 |
+
acc_dtype = Float32
|
| 134 |
+
tile_shape_mn = (tile_M, tile_N)
|
| 135 |
+
cluster_shape_mnk = (cluster_M, cluster_N, 1)
|
| 136 |
+
if not GemmCls.is_valid_dtypes(
|
| 137 |
+
tensor_infos["A"].dtype,
|
| 138 |
+
tensor_infos["B"].dtype,
|
| 139 |
+
acc_dtype,
|
| 140 |
+
tensor_infos["D"].dtype,
|
| 141 |
+
tensor_infos["A"].major,
|
| 142 |
+
tensor_infos["B"].major,
|
| 143 |
+
):
|
| 144 |
+
raise TypeError("Skipping due to unsupported combination of types and majors")
|
| 145 |
+
|
| 146 |
+
max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
|
| 147 |
+
GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs)
|
| 148 |
+
act_fn = dact_fn_map[activation]
|
| 149 |
+
epi_args = GemmCls.EpilogueArguments(tensor_infos["PostAct"].cute_tensor, act_fn)
|
| 150 |
+
scheduler_args = GemmWrapperBase.create_scheduler_args(
|
| 151 |
+
max_active_clusters, tile_count_semaphore, max_swizzle_size=max_swizzle_size
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
# Create varlen arguments if needed (assumes persistent=True when varlen_m)
|
| 155 |
+
varlen_args = GemmWrapperBase.create_varlen_args(
|
| 156 |
+
cu_seqlens_m,
|
| 157 |
+
None, # cu_seqlens_k
|
| 158 |
+
A_idx,
|
| 159 |
+
max_active_clusters,
|
| 160 |
+
cluster_shape_mnk,
|
| 161 |
+
tensor_infos,
|
| 162 |
+
GemmCls.num_epi_tensormaps,
|
| 163 |
+
pingpong,
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
+
current_stream = cutlass_torch.current_stream()
|
| 167 |
+
compile_key = GemmWrapperBase.get_compile_key(
|
| 168 |
+
tensor_infos,
|
| 169 |
+
activation,
|
| 170 |
+
tile_shape_mn,
|
| 171 |
+
cluster_shape_mnk,
|
| 172 |
+
pingpong,
|
| 173 |
+
persistent,
|
| 174 |
+
tile_count_semaphore is not None,
|
| 175 |
+
device_capacity,
|
| 176 |
+
max_swizzle_size,
|
| 177 |
+
cu_seqlens_m is not None,
|
| 178 |
+
A_idx is not None,
|
| 179 |
+
key_tensor_names=("A", "B", "D", "PostAct", "C"),
|
| 180 |
+
)
|
| 181 |
+
cache = gemm_dact.compile_cache
|
| 182 |
+
if compile_key not in cache:
|
| 183 |
+
if device_capacity[0] == 9:
|
| 184 |
+
GemmCls = partial(GemmCls, pingpong=pingpong, is_persistent=persistent)
|
| 185 |
+
gemm = GemmCls(
|
| 186 |
+
acc_dtype,
|
| 187 |
+
tensor_infos["A"].dtype,
|
| 188 |
+
tile_shape_mn,
|
| 189 |
+
cluster_shape_mnk,
|
| 190 |
+
gather_A=gather_A,
|
| 191 |
+
)
|
| 192 |
+
cache[compile_key] = cute.compile(
|
| 193 |
+
gemm,
|
| 194 |
+
tensor_infos["A"].cute_tensor,
|
| 195 |
+
tensor_infos["B"].cute_tensor,
|
| 196 |
+
tensor_infos["D"].cute_tensor,
|
| 197 |
+
tensor_infos["C"].cute_tensor,
|
| 198 |
+
epi_args,
|
| 199 |
+
scheduler_args,
|
| 200 |
+
varlen_args,
|
| 201 |
+
current_stream,
|
| 202 |
+
)
|
| 203 |
+
cache[compile_key](
|
| 204 |
+
tensor_infos["A"].cute_tensor,
|
| 205 |
+
tensor_infos["B"].cute_tensor,
|
| 206 |
+
tensor_infos["D"].cute_tensor,
|
| 207 |
+
tensor_infos["C"].cute_tensor,
|
| 208 |
+
epi_args,
|
| 209 |
+
scheduler_args,
|
| 210 |
+
varlen_args,
|
| 211 |
+
current_stream,
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
gemm_dact.compile_cache = {}
|
build/torch-cuda/quack/gemm_default_epi.py
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Wentao Guo, Tri Dao.
|
| 2 |
+
from typing import Optional, Tuple
|
| 3 |
+
from functools import partial
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
import cutlass
|
| 8 |
+
import cutlass.cute as cute
|
| 9 |
+
from cutlass import Int32, Float32, Boolean, const_expr
|
| 10 |
+
|
| 11 |
+
from .cute_dsl_utils import ArgumentsBase, ParamsBase
|
| 12 |
+
from .gemm_sm90 import GemmSm90
|
| 13 |
+
from .gemm_sm100 import GemmSm100
|
| 14 |
+
from .sm90_utils import partition_for_epilogue
|
| 15 |
+
from . import utils as utils
|
| 16 |
+
from . import copy_utils as copy_utils
|
| 17 |
+
from .varlen_utils import VarlenManager
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class GemmDefaultEpiMixin:
    """Default GEMM epilogue hooks: D = alpha * acc (+ beta * C) (+ broadcast vectors).

    Mixin supplying the ``epi_*`` callbacks consumed by the SM90/SM100 GEMM
    kernels (see ``GemmDefaultSm90`` / ``GemmDefaultSm100`` below). Optional
    row- and column-broadcast vectors are staged from global to shared memory
    with cp.async in ``epi_begin``, converted to the accumulator dtype in
    ``epi_begin_loop``, and added to each output subtile in
    ``epi_visit_subtile``.
    """

    # This epilogue drives no extra TMA tensormaps of its own.
    num_epi_tensormaps: int = 0

    @dataclass
    class EpilogueArguments(ArgumentsBase):
        # Host-side arguments for the default epilogue.
        alpha: Optional[Float32 | cute.Tensor] = None  # scalar, or tensor holding the scalar
        beta: Optional[Float32 | cute.Tensor] = None  # scalar, or tensor holding the scalar
        mRowVecBroadcast: Optional[cute.Tensor] = None  # per-batch length-N vector, broadcast over rows
        mColVecBroadcast: Optional[cute.Tensor] = None  # per-batch length-M vector, broadcast over columns
        add_to_output: bool = False

    @dataclass
    class EpilogueParams(ParamsBase):
        # Device-side parameters produced by epi_to_underlying_arguments.
        alpha: Optional[Float32 | cute.Tensor] = None
        beta: Optional[Float32 | cute.Tensor] = None
        mRowVecBroadcast: Optional[cute.Tensor] = None
        mColVecBroadcast: Optional[cute.Tensor] = None

    def epi_to_underlying_arguments(
        self, args: EpilogueArguments, *, loc=None, ip=None
    ) -> EpilogueParams:
        """Lower host ``EpilogueArguments`` into device ``EpilogueParams``.

        Rewraps the broadcast-vector tensors with strides annotated via
        ``cute.assume`` so the compiler can vectorize their copies.
        """
        # Assume all strides are divisible by 32 bits except the last stride
        new_stride = lambda t: tuple(
            cute.assume(s, divby=32 // t.element_type.width) if not cute.is_static(s) else s
            for s in t.stride
        )
        mRowVecBroadcast, mColVecBroadcast = [
            cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
            if t is not None
            else None
            for t in (args.mRowVecBroadcast, args.mColVecBroadcast)
        ]
        return self.EpilogueParams(
            alpha=args.alpha,
            beta=args.beta,
            mRowVecBroadcast=mRowVecBroadcast,
            mColVecBroadcast=mColVecBroadcast,
        )

    @cute.jit
    def epi_begin(
        self,
        params: EpilogueParams,
        epi_smem_tensors: Tuple[cute.Tensor, ...],
        epi_tile: cute.Tile,
        tiled_copy_t2r: Optional[cute.TiledCopy],
        tiled_copy_r2s: cute.TiledCopy,
        tile_coord_mnkl: cute.Coord,
        varlen_manager: VarlenManager,
        epilogue_barrier: cutlass.pipeline.NamedBarrier,
        tidx: Int32,
    ):
        """Per-CTA-tile epilogue setup.

        Loads alpha/beta scalars and cp.async-copies the (optional) row/col
        broadcast vectors for this tile into shared memory, then partitions
        those smem tensors for the per-thread epilogue copy layout.

        Returns (alpha, beta, tDsRowVec, tDsColVec); entries are None when
        the corresponding argument was not provided.
        """
        alpha, beta = None, None
        if const_expr(hasattr(params, "alpha") and params.alpha is not None):
            alpha = utils.load_scalar_or_pointer(params.alpha)
        if const_expr(hasattr(params, "beta") and params.beta is not None):
            beta = utils.load_scalar_or_pointer(params.beta)
        sRowVec, sColVec, *rest = epi_smem_tensors
        tile_M, tile_N = self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[1]
        batch_idx = tile_coord_mnkl[3]
        num_epi_threads = self.num_epi_warps * cute.arch.WARP_SIZE
        # Don't need sync as we assume the previous epilogue has finished

        partition_for_epilogue_fn = partial(
            partition_for_epilogue,
            epi_tile=epi_tile,
            tiled_copy=tiled_copy_t2r if tiled_copy_t2r is not None else tiled_copy_r2s,
            tidx=tidx,
            reference_src=tiled_copy_t2r is None,
        )

        tDsRowVec = None
        if const_expr(params.mRowVecBroadcast is not None):
            rowvec_dtype = params.mRowVecBroadcast.element_type
            # Copy in (at least) 32-bit chunks regardless of element width.
            num_copy_elems = const_expr(max(32, rowvec_dtype.width)) // rowvec_dtype.width
            thr_copy_RV = copy_utils.tiled_copy_1d(
                params.mRowVecBroadcast.element_type, num_epi_threads, num_copy_elems, is_async=True
            ).get_slice(tidx)
            mRowVec = params.mRowVecBroadcast[batch_idx, None]
            gRowVec = cute.local_tile(mRowVec, (tile_N,), (tile_coord_mnkl[1],))
            tRVgRV = thr_copy_RV.partition_S(gRowVec)
            tRVsRV = thr_copy_RV.partition_D(sRowVec)
            tRVcRV = thr_copy_RV.partition_S(cute.make_identity_tensor(tile_N))
            # Predicate the copy against the remaining extent along N.
            limit_n = min(mRowVec.shape[0] - tile_coord_mnkl[1] * tile_N, tile_N)
            tRVpRV = cute.make_fragment((1, cute.size(tRVsRV.shape[1])), Boolean)
            for m in cutlass.range(cute.size(tRVsRV.shape[1]), unroll_full=True):
                tRVpRV[0, m] = tRVcRV[0, m] < limit_n
            cute.copy(thr_copy_RV, tRVgRV, tRVsRV, pred=tRVpRV)
            # View smem row vector as an (M, N) tile with stride (0, 1), i.e.
            # broadcast over rows, then partition for the epilogue copy.
            # (CPY, CPY_M, CPY_N, EPI_M, EPI_N)
            tDsRowVec = partition_for_epilogue_fn(
                cute.make_tensor(
                    sRowVec.iterator, cute.make_layout((tile_M, tile_N), stride=(0, 1))
                )
            )
            if const_expr(tiled_copy_t2r is not None):
                tDsRowVec = tiled_copy_r2s.retile(tDsRowVec)

        tDsColVec = None
        if const_expr(params.mColVecBroadcast is not None):
            colvec_dtype = params.mColVecBroadcast.element_type
            num_copy_elems = const_expr(max(32, colvec_dtype.width)) // colvec_dtype.width
            thr_copy_CV = copy_utils.tiled_copy_1d(
                params.mColVecBroadcast.element_type, num_epi_threads, num_copy_elems, is_async=True
            ).get_slice(tidx)
            if const_expr(not varlen_manager.varlen_m):
                mColVec = params.mColVecBroadcast[batch_idx, None]
            else:
                # Varlen M: the column vector is packed; offset by cu_seqlens_m.
                mColVec = cute.domain_offset(
                    (varlen_manager.params.cu_seqlens_m[batch_idx],), params.mColVecBroadcast
                )
            gColVec = cute.local_tile(mColVec, (tile_M,), (tile_coord_mnkl[0],))
            tCVgCV = thr_copy_CV.partition_S(gColVec)
            tCVsCV = thr_copy_CV.partition_D(sColVec)
            tCVcCV = thr_copy_CV.partition_S(cute.make_identity_tensor(tile_M))
            # Predicate against the remaining extent along M for this batch.
            limit_m = min(varlen_manager.len_m(batch_idx) - tile_coord_mnkl[0] * tile_M, tile_M)
            tCVpCV = cute.make_fragment((1, cute.size(tCVsCV.shape[1])), Boolean)
            for m in cutlass.range(cute.size(tCVsCV.shape[1]), unroll_full=True):
                tCVpCV[0, m] = tCVcCV[0, m] < limit_m
            cute.copy(thr_copy_CV, tCVgCV, tCVsCV, pred=tCVpCV)
            # Stride (1, 0): broadcast the column vector over columns.
            tDsColVec = partition_for_epilogue_fn(
                cute.make_tensor(
                    sColVec.iterator, cute.make_layout((tile_M, tile_N), stride=(1, 0))
                )
            )
            if const_expr(tiled_copy_t2r is not None):
                tDsColVec = tiled_copy_r2s.retile(tDsColVec)

        if const_expr(params.mRowVecBroadcast is not None or params.mColVecBroadcast is not None):
            # Wait for the cp.async copies, then sync all epilogue threads so
            # every thread sees the fully-populated smem vectors.
            cute.arch.cp_async_commit_group()
            cute.arch.cp_async_wait_group(0)
            epilogue_barrier.arrive_and_wait()
        return alpha, beta, tDsRowVec, tDsColVec

    def epi_begin_loop(self, params: EpilogueParams, epi_tensors, epi_coord: cute.Coord):
        """Per-epi-subtile setup: slice the smem vectors at ``epi_coord`` and
        convert them to registers in the accumulator dtype."""
        alpha, beta, tDsRowVec, tDsColVec = epi_tensors
        tDrRowVec_cvt = None
        if const_expr(tDsRowVec is not None):
            tDsRowVec_cur = cute.group_modes(tDsRowVec, 3, cute.rank(tDsRowVec))[
                None, None, None, epi_coord
            ]
            # tDrRowVec = cute.make_fragment_like(tDsRowVec_cur)
            tDrRowVec = cute.make_fragment(tDsRowVec_cur.layout, tDsRowVec_cur.element_type)
            # filter_zeros skips the broadcast (stride-0) modes during the copy.
            cute.autovec_copy(cute.filter_zeros(tDsRowVec_cur), cute.filter_zeros(tDrRowVec))
            tDrRowVec_cvt = cute.make_fragment_like(tDrRowVec, self.acc_dtype)
            tDrRowVec_cvt.store(tDrRowVec.load().to(self.acc_dtype))
        tDrColVec_cvt = None
        if const_expr(tDsColVec is not None):
            tDsColVec_cur = cute.group_modes(tDsColVec, 3, cute.rank(tDsColVec))[
                None, None, None, epi_coord
            ]
            # This somehow doesn't work, some dim with stride 0 turns to non-zero stride
            # tDrRowVec = cute.make_fragment_like(tDsRowVec_cur)
            tDrColVec = cute.make_fragment(tDsColVec_cur.layout, tDsColVec_cur.element_type)
            cute.autovec_copy(cute.filter_zeros(tDsColVec_cur), cute.filter_zeros(tDrColVec))
            tDrColVec_cvt = cute.make_fragment_like(tDrColVec, self.acc_dtype)
            tDrColVec_cvt.store(tDrColVec.load().to(self.acc_dtype))
        return alpha, beta, tDrRowVec_cvt, tDrColVec_cvt

    @cute.jit
    def epi_visit_subtile(
        self,
        params: EpilogueParams,
        epi_loop_tensors: Tuple[cute.Tensor, ...],
        tRS_rD: cute.Tensor,
        tRS_rC: Optional[cute.Tensor] = None,
    ) -> Optional[cute.Tensor]:
        """Apply the epilogue to one register subtile in place:
        D = alpha * D (+ beta * C) (+ rowvec) (+ colvec).

        NOTE(review): the alpha/beta unpacked from ``epi_loop_tensors`` are
        immediately shadowed by fresh loads from ``params`` below — confirm
        whether the re-load is intentional (e.g. for register pressure).
        """
        alpha, beta, tDrRowVec, tDrColVec = epi_loop_tensors
        rD = tRS_rD.load()
        # Apply alpha scaling to accumulator if alpha is provided (not None)
        if const_expr(hasattr(params, "alpha") and params.alpha is not None):
            alpha = utils.load_scalar_or_pointer(params.alpha)
            rD *= alpha
        # Apply C with beta scaling
        if const_expr(tRS_rC is not None):
            if const_expr(not hasattr(params, "beta") or params.beta is None):
                # beta is None, default behavior: add C (beta=1.0)
                rD += tRS_rC.load().to(tRS_rD.element_type)
            else:
                beta = utils.load_scalar_or_pointer(params.beta)
                rD += beta * tRS_rC.load().to(tRS_rD.element_type)
        tRS_rD.store(rD)
        if const_expr(tDrRowVec is not None):
            for i in cutlass.range(cute.size(tDrRowVec), unroll_full=True):
                tRS_rD[i] += tDrRowVec[i]
        if const_expr(tDrColVec is not None):
            for i in cutlass.range(cute.size(tDrColVec), unroll_full=True):
                tRS_rD[i] += tDrColVec[i]
        return None

    @staticmethod
    def epi_smem_bytes_per_stage(
        args: Optional[EpilogueArguments],
        cta_tile_shape_mnk: Tuple[int, int, int],
        epi_tile: cute.Tile,
    ) -> int:
        """Bytes of shared memory this epilogue needs per stage (row + col vectors).

        NOTE(review): ``args`` is annotated Optional but dereferenced
        unconditionally — confirm callers never pass None.
        """
        row_vec_smem_size = 0 if args.mRowVecBroadcast is None else cta_tile_shape_mnk[1]
        col_vec_smem_size = 0 if args.mColVecBroadcast is None else cta_tile_shape_mnk[0]
        row_vec_dtype = (
            args.mRowVecBroadcast.element_type if args.mRowVecBroadcast is not None else Float32
        )
        col_vec_dtype = (
            args.mColVecBroadcast.element_type if args.mColVecBroadcast is not None else Float32
        )
        # width is in bits; convert the total to bytes.
        return (
            row_vec_smem_size * row_vec_dtype.width + col_vec_smem_size * col_vec_dtype.width
        ) // 8

    def epi_get_smem_struct(self, params: EpilogueParams):
        """Build the cute.struct describing this epilogue's smem layout.

        Sizes collapse to 0 when the corresponding broadcast vector is absent.
        """
        row_vec_smem_size = 0 if params.mRowVecBroadcast is None else self.cta_tile_shape_mnk[1]
        col_vec_smem_size = 0 if params.mColVecBroadcast is None else self.cta_tile_shape_mnk[0]
        row_vec_dtype = (
            params.mRowVecBroadcast.element_type if params.mRowVecBroadcast is not None else Float32
        )
        col_vec_dtype = (
            params.mColVecBroadcast.element_type if params.mColVecBroadcast is not None else Float32
        )

        @cute.struct
        class EpiSharedStorage:
            sRowVec: cute.struct.Align[cute.struct.MemRange[row_vec_dtype, row_vec_smem_size], 16]
            sColVec: cute.struct.Align[cute.struct.MemRange[col_vec_dtype, col_vec_smem_size], 16]

        return EpiSharedStorage

    def epi_get_smem_tensors(self, params: EpilogueParams, storage) -> Tuple[cute.Tensor, ...]:
        """Materialize the (sRowVec, sColVec) smem tensors from ``storage``;
        each entry is None when the corresponding broadcast vector is absent."""
        sRowVec = None
        if const_expr(params.mRowVecBroadcast is not None):
            sRowVec = storage.epi.sRowVec.get_tensor(cute.make_layout(self.cta_tile_shape_mnk[1]))
        sColVec = None
        if const_expr(params.mColVecBroadcast is not None):
            sColVec = storage.epi.sColVec.get_tensor(cute.make_layout(self.cta_tile_shape_mnk[0]))
        return (sRowVec, sColVec)
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
class GemmDefaultSm90(GemmDefaultEpiMixin, GemmSm90):
    """SM90 GEMM kernel combined with the default epilogue mixin."""
    pass
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
class GemmDefaultSm100(GemmDefaultEpiMixin, GemmSm100):
    """SM100 GEMM kernel combined with the default epilogue mixin."""
    pass
|
build/torch-cuda/quack/gemm_interface.py
ADDED
|
@@ -0,0 +1,1058 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Tri Dao
|
| 2 |
+
from typing import Optional, Tuple, Literal
|
| 3 |
+
from functools import partial
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torch import Tensor
|
| 8 |
+
from ._ops_compat import add_quack_op_namespace_prefix
|
| 9 |
+
|
| 10 |
+
from .gemm_config import GemmConfig, get_all_configs
|
| 11 |
+
|
| 12 |
+
from .autotuner import autotune, AutotuneConfig
|
| 13 |
+
from .cute_dsl_utils import get_device_capacity
|
| 14 |
+
from .gemm import gemm as gemm_sm90_sm100
|
| 15 |
+
from .gemm_act import gemm_act as gemm_act_sm90_sm100
|
| 16 |
+
from .gemm_dact import gemm_dact as gemm_dact_sm90_sm100
|
| 17 |
+
from .gemm_symmetric import gemm_symmetric as gemm_symmetric_sm90_sm100
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# Eager-mode PyTorch reference implementation for each supported activation
# name; the `None` key is the identity (no activation applied).
act_to_pytorch_fn_map = {
    None: lambda x: x,
    "relu": F.relu,
    "relu_sq": lambda x: torch.square(F.relu(x)),
    "gelu_tanh_approx": partial(F.gelu, approximate="tanh"),
}
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _swiglu_ref(gate, up):
    return F.silu(gate) * up


def _swiglu_oai_ref(gate, up):
    # sigmoid(1.702 * gate) gating, with the (up + 1) offset variant
    return gate * torch.sigmoid(1.702 * gate) * (up + 1)


def _reglu_ref(gate, up):
    return F.relu(gate) * up


def _geglu_ref(gate, up):
    return F.gelu(gate, approximate="tanh") * up


def _glu_ref(gate, up):
    return torch.sigmoid(gate) * up


# Eager-mode PyTorch reference for each supported gated activation.
# Each callable maps (gate, up) -> postact.
gated_to_pytorch_fn_map = {
    "swiglu": _swiglu_ref,
    "swiglu_oai": _swiglu_oai_ref,
    "reglu": _reglu_ref,
    "geglu": _geglu_ref,
    "glu": _glu_ref,
}
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _get_default_device_capacity():
    """Return the (major, minor) compute capability used to pick GEMM configs.

    Falls back to (9, 0) when CUDA is unavailable, or when the detected major
    version is not one this library has tuned configs for (9 or 10).
    """
    fallback = (9, 0)
    if not torch.cuda.is_available():
        return fallback
    capability = get_device_capacity(torch.device("cuda"))
    return capability if capability[0] in (9, 10) else fallback
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class _LazyDeviceCapacity:
    """Lazily-resolved device capability proxy.

    Defers the CUDA capability query until first item access so the module
    can be imported in environments without a GPU (e.g. a nix build sandbox).
    """

    _value = None  # cached (major, minor) tuple; resolved on first __getitem__

    def __getitem__(self, idx):
        cached = self._value
        if cached is None:
            cached = _get_default_device_capacity()
            self._value = cached
        return cached[idx]


default_device_capacity = _LazyDeviceCapacity()
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def default_config(device):
    """Return a reasonable un-autotuned GemmConfig for *device*.

    Compute capability 10.x gets larger tiles without pingpong scheduling;
    everything else gets the SM90-style default.
    """
    if get_device_capacity(device)[0] == 10:
        return GemmConfig(tile_m=256, tile_n=256, cluster_m=2, cluster_n=1, pingpong=False)
    return GemmConfig(tile_m=128, tile_n=192, cluster_m=2, cluster_n=1, pingpong=True)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def prune_invalid_gemm_configs(configs, named_args: dict, **kwargs):
    """Drop autotune candidate configs that are invalid for this call.

    swap_ab is unsupported with varlen_m or gather_A; on compute capability 9,
    gather_A additionally rules out cluster_n > 1 and tile_n == 208.
    """
    merged = named_args | kwargs
    gather_A = merged.get("A_idx", None) is not None
    varlen_m = merged.get("cu_seqlens_m", None) is not None
    if gather_A or varlen_m:
        # Neither mode supports swap_ab.
        configs = [c for c in configs if not c.kwargs["config"].swap_ab]
    if gather_A and get_device_capacity(merged["A"].device)[0] == 9:
        # tile_n == 208 causes register spills, as gather_A requires more
        # registers for the producer.
        configs = [
            c
            for c in configs
            if c.kwargs["config"].cluster_n == 1 and c.kwargs["config"].tile_n != 208
        ]
    return configs
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
@autotune(
    configs=[AutotuneConfig(config=c) for c in get_all_configs(default_device_capacity[0])],
    key=["dynamic_scheduler"],
    prune_configs_by={"early_config_prune": prune_invalid_gemm_configs},
)
def gemm_tuned(
    # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
    A: Tensor,
    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
    out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    C: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    bias: Optional[Tensor] = None,  # (N,) or (L, N)
    alpha: float | Tensor = 1.0,  # (1,)
    beta: float | Tensor = 1.0,  # (1,)
    cu_seqlens_m: Optional[Tensor] = None,  # (L+1), int32
    cu_seqlens_k: Optional[Tensor] = None,  # (L+1), int32
    A_idx: Optional[Tensor] = None,  # (total_M,) or (total_K,) indices for gather_A when varlen
    batch_idx_permute: Optional[Tensor] = None,  # (L,) permutation of batch indices for scheduler
    add_to_output: bool = False,
    dynamic_scheduler: bool = False,
    config: Optional[GemmConfig] = None,
) -> None:
    """Autotuned out = alpha * A @ B + beta * C (+ bias), written into `out`.

    Normalizes all operands to a batched (L, ..., ...) layout, validates the
    varlen/gather combinations, then dispatches to the SM90/SM100 kernel,
    swapping the A/B (and transposing outputs) when the chosen config uses
    swap_ab. The autotuner sweeps `config` over all configs for this device.
    """
    if config is None:
        config = default_config(A.device)
    varlen_m = cu_seqlens_m is not None
    varlen_k = cu_seqlens_k is not None
    varlen = varlen_m or varlen_k
    gather_A = A_idx is not None
    if gather_A:
        assert varlen, "gather_A requires either varlen_m or varlen_k"
        assert config.cluster_n == 1, "gather_A requires cluster_n=1"
    if varlen_m:
        assert not config.swap_ab, "Variable-length sequences not supported with swap_ab"
    # Normalize to batched 3D layouts (varlen tensors stay packed 2D).
    if A.ndim == 2 and not varlen:
        A = A.unsqueeze(0)  # (1, M, K)
    B = B.mT  # (N, K) or (L, N, K) or (N, total_K)
    if B.ndim == 2 and not varlen_k:
        B = B.unsqueeze(0)  # (1, N, K)
    if C is not None and C.ndim == 2 and not varlen_m:
        C = C.unsqueeze(0)  # (1, M, N)
    if out.ndim == 2 and not varlen_m:
        out = out.unsqueeze(0)
    if bias is not None and bias.ndim == 1:
        bias = bias.unsqueeze(0)  # (L, N)
    batch_size = B.shape[0] if not varlen_k else cu_seqlens_k.shape[0] - 1
    if varlen_m:
        # If gather_A (A_idx provided), use its length; otherwise use A.shape[0]
        total_m = A_idx.shape[0] if A_idx is not None else A.shape[0]
        out_shape = (total_m, B.shape[-2])
    else:
        out_shape = (batch_size, A.shape[-2], B.shape[-2])
    assert out.shape == out_shape, f"out shape mismatch: {out.shape} vs {out_shape}"
    # Dynamic scheduling uses a global tile counter as a semaphore.
    tile_count_semaphore = (
        torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None
    )
    gemm_sm90_sm100(
        A if not config.swap_ab else B,
        B if not config.swap_ab else A,
        out if not config.swap_ab else out.mT,
        (C if not config.swap_ab else C.mT) if C is not None else None,
        tile_count_semaphore,
        config.tile_m,
        config.tile_n,
        config.cluster_m,
        config.cluster_n,
        config.pingpong,
        persistent=True,
        max_swizzle_size=config.max_swizzle_size,
        # A (N,)-shaped bias becomes a column vector under swap_ab.
        rowvec_bias=bias if not config.swap_ab else None,
        colvec_bias=bias if config.swap_ab else None,
        alpha=alpha,
        beta=beta,
        cu_seqlens_m=cu_seqlens_m,
        cu_seqlens_k=cu_seqlens_k,
        A_idx=A_idx,
        batch_idx_permute=batch_idx_permute,
        add_to_output=add_to_output,
    )
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
@autotune(
    configs=[AutotuneConfig(config=c) for c in get_all_configs(default_device_capacity[0])],
    key=["activation", "dynamic_scheduler"],
    prune_configs_by={"early_config_prune": prune_invalid_gemm_configs},
)
def gemm_act_tuned(
    # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
    A: Tensor,
    B: Tensor,  # (K, N) or (L, K, N)
    # (M, N) or (L, M, N) or (total_M, N) if varlen_m - None if not storing preact
    preact_out: Optional[Tensor],
    postact_out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    C: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    bias: Optional[Tensor] = None,  # (N,) or (L, N)
    activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
    cu_seqlens_m: Optional[Tensor] = None,  # (L+1), int32
    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
    dynamic_scheduler: bool = False,
    config: Optional[GemmConfig] = None,
) -> None:
    """Autotuned fused GEMM + activation epilogue.

    Writes postact = activation(A @ B (+ C) (+ bias)) into `postact_out`,
    optionally also storing the pre-activation into `preact_out`. Operand
    normalization and swap_ab handling mirror `gemm_tuned`.
    """
    if config is None:
        config = default_config(A.device)
    varlen_m = cu_seqlens_m is not None
    if varlen_m:
        assert not config.swap_ab, "Variable-length sequences not supported with swap_ab"
    # Normalize to batched 3D layouts (varlen tensors stay packed 2D).
    if A.ndim == 2 and not varlen_m:
        A = A.unsqueeze(0)  # (1, M, K)
    B = B.mT  # (N, K) or (L, N, K)
    if B.ndim == 2:
        B = B.unsqueeze(0)  # (1, N, K)
    if C is not None and C.ndim == 2 and not varlen_m:
        C = C.unsqueeze(0)  # (1, M, N)
    if preact_out is not None and preact_out.ndim == 2 and not varlen_m:
        D = preact_out.unsqueeze(0)
    else:
        D = preact_out
    if postact_out.ndim == 2 and not varlen_m:
        PostAct = postact_out.unsqueeze(0)
    else:
        PostAct = postact_out
    if bias is not None and bias.ndim == 1:
        bias = bias.unsqueeze(0)  # (L, N)
    # Dynamic scheduling uses a global tile counter as a semaphore.
    tile_count_semaphore = (
        torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None
    )
    gemm_act_sm90_sm100(
        A if not config.swap_ab else B,
        B if not config.swap_ab else A,
        (D if not config.swap_ab else D.mT) if D is not None else None,
        (C if not config.swap_ab else C.mT) if C is not None else None,
        PostAct if not config.swap_ab else PostAct.mT,
        tile_count_semaphore,
        activation,
        config.tile_m,
        config.tile_n,
        config.cluster_m,
        config.cluster_n,
        config.pingpong,
        persistent=True,
        max_swizzle_size=config.max_swizzle_size,
        # A (N,)-shaped bias becomes a column vector under swap_ab.
        rowvec_bias=bias if not config.swap_ab else None,
        colvec_bias=bias if config.swap_ab else None,
        cu_seqlens_m=cu_seqlens_m,
        A_idx=A_idx,
    )
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
@autotune(
    configs=[AutotuneConfig(config=c) for c in get_all_configs(default_device_capacity[0])],
    key=["activation", "dynamic_scheduler"],
    prune_configs_by={"early_config_prune": prune_invalid_gemm_configs},
)
def gemm_dact_tuned(
    # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
    A: Tensor,
    B: Tensor,  # (K, N) or (L, K, N)
    PreAct: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    dx_out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    postact_out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
    cu_seqlens_m: Optional[Tensor] = None,  # (L+1), int32
    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
    dynamic_scheduler: bool = True,  # defaults on here, unlike gemm_tuned/gemm_act_tuned
    config: Optional[GemmConfig] = None,
) -> None:
    """Autotuned fused GEMM + activation-backward epilogue.

    Computes the GEMM A @ B, combines it with the saved `PreAct` through the
    activation's backward rule, and writes the input gradient into `dx_out`
    and the recomputed post-activation into `postact_out`. Operand
    normalization and swap_ab handling mirror `gemm_tuned`.
    """
    if config is None:
        config = default_config(A.device)
    varlen_m = cu_seqlens_m is not None
    if varlen_m:
        assert not config.swap_ab, "Variable-length sequences not supported with swap_ab"
    # Normalize to batched 3D layouts (varlen tensors stay packed 2D).
    if A.ndim == 2 and not varlen_m:
        A = A.unsqueeze(0)  # (1, M, K)
    B = B.mT  # (N, K) or (L, N, K)
    if B.ndim == 2:
        B = B.unsqueeze(0)  # (1, N, K)
    if PreAct.ndim == 2 and not varlen_m:
        PreAct = PreAct.unsqueeze(0)  # (1, M, N)
    if dx_out.ndim == 2 and not varlen_m:
        D = dx_out.unsqueeze(0)
    else:
        D = dx_out
    if postact_out.ndim == 2 and not varlen_m:
        PostAct = postact_out.unsqueeze(0)
    else:
        PostAct = postact_out
    # Dynamic scheduling uses a global tile counter as a semaphore.
    tile_count_semaphore = (
        torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None
    )
    gemm_dact_sm90_sm100(
        A if not config.swap_ab else B,
        B if not config.swap_ab else A,
        D if not config.swap_ab else D.mT,
        PreAct if not config.swap_ab else PreAct.mT,
        PostAct if not config.swap_ab else PostAct.mT,
        tile_count_semaphore,
        activation,
        config.tile_m,
        config.tile_n,
        config.cluster_m,
        config.cluster_n,
        config.pingpong,
        persistent=True,
        max_swizzle_size=config.max_swizzle_size,
        cu_seqlens_m=cu_seqlens_m,
        A_idx=A_idx,
    )
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def gemm(
    # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
    A: Tensor,
    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
    out: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    bias: Optional[Tensor] = None,  # (N,) or (L, N)
    alpha: float | Tensor = 1.0,
    out_dtype: Optional[torch.dtype] = None,
    cu_seqlens_m: Optional[Tensor] = None,
    cu_seqlens_k: Optional[Tensor] = None,
    A_idx: Optional[Tensor] = None,  # (total_M,) or (total_K,) indices for gather_A when varlen
    batch_idx_permute: Optional[Tensor] = None,  # (L,) permutation of batch indices for scheduler
    dynamic_scheduler: bool = False,
    tuned: bool = True,
) -> Tensor:
    """GEMM that allocates the output when one is not supplied, then defers to gemm_out."""
    if out is None:
        dtype = out_dtype if out_dtype is not None else A.dtype
        if cu_seqlens_m is not None:
            # varlen_m: row count follows A_idx when gather_A is active.
            rows = A.shape[0] if A_idx is None else A_idx.shape[0]
            shape = (rows, B.shape[-1])
        elif cu_seqlens_k is not None:
            # varlen_k: one (M, N) slab per batch entry; M is always A.shape[0].
            shape = (cu_seqlens_k.shape[0] - 1, A.shape[0], B.shape[-1])
        elif A.ndim == 2:
            shape = (A.shape[0], B.shape[-1])
        else:
            shape = (A.shape[0], A.shape[-2], B.shape[-1])
        out = torch.empty(shape, dtype=dtype, device=A.device)
    # torch.library requires fixed argument types, so a tensor-valued alpha
    # travels in a separate slot from the float one.
    if isinstance(alpha, float):
        alpha_scalar, alpha_tensor = alpha, None
    else:
        alpha_scalar, alpha_tensor = 1.0, alpha
    gemm_out(
        A,
        B,
        out,
        bias=bias,
        alpha=alpha_scalar,
        alpha_tensor=alpha_tensor,
        cu_seqlens_m=cu_seqlens_m,
        cu_seqlens_k=cu_seqlens_k,
        A_idx=A_idx,
        batch_idx_permute=batch_idx_permute,
        dynamic_scheduler=dynamic_scheduler,
        tuned=tuned,
    )
    return out
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
@torch.library.custom_op(
    add_quack_op_namespace_prefix("gemm_out"),
    mutates_args=("out",),
    device_types="cuda",
    # We have to split out alpha and alpha_tensor since torch.library requires
    # each argument to have a fixed type
    # schema="(Tensor A, Tensor B, Tensor(a2!) out, Tensor? bias, float alpha=1.0, Tensor? alpha_tensor=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
)
def gemm_out(
    # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
    A: Tensor,
    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
    out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    bias: Optional[Tensor] = None,  # (N,) or (L, N)
    alpha: float = 1.0,
    alpha_tensor: Optional[Tensor] = None,
    cu_seqlens_m: Optional[Tensor] = None,
    cu_seqlens_k: Optional[Tensor] = None,
    A_idx: Optional[Tensor] = None,  # (total_M,) or (total_K,) indices for gather_A when varlen
    batch_idx_permute: Optional[Tensor] = None,  # (L,) permutation of batch indices for scheduler
    dynamic_scheduler: bool = False,
    tuned: bool = True,
) -> None:
    """GEMM with pre-allocated output tensor."""
    # Pick the autotuned kernel or a fixed-config fallback.
    if tuned:
        fn = gemm_tuned
    else:
        fn = partial(gemm_tuned.fn, config=None)
    # A tensor-valued alpha takes precedence over the float one.
    if alpha_tensor is not None:
        alpha = alpha_tensor
    fn(
        A,
        B,
        out,
        C=None,
        bias=bias,
        alpha=alpha,
        cu_seqlens_m=cu_seqlens_m,
        cu_seqlens_k=cu_seqlens_k,
        A_idx=A_idx,
        batch_idx_permute=batch_idx_permute,
        dynamic_scheduler=dynamic_scheduler,
    )
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
def gemm_ref(
|
| 387 |
+
# (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
|
| 388 |
+
A: Tensor,
|
| 389 |
+
B: Tensor, # (K, N) or (L, K, N) or (total_K, N) if varlen_k
|
| 390 |
+
out: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
|
| 391 |
+
bias: Optional[Tensor] = None, # (N,) or (L, N)
|
| 392 |
+
alpha: float | Tensor = 1.0,
|
| 393 |
+
cu_seqlens_m: Optional[Tensor] = None,
|
| 394 |
+
cu_seqlens_k: Optional[Tensor] = None,
|
| 395 |
+
A_idx: Optional[Tensor] = None, # (total_M,) or (total_K,) indices for gather_A when varlen
|
| 396 |
+
out_dtype: Optional[torch.dtype] = None,
|
| 397 |
+
) -> Tensor:
|
| 398 |
+
"""Reference implementation for GEMM with pre-allocated output."""
|
| 399 |
+
# The out_dtype argument requires torch >= 2.8
|
| 400 |
+
out_dtype = A.dtype if out_dtype is None else out_dtype
|
| 401 |
+
if cu_seqlens_m is None and cu_seqlens_k is None:
|
| 402 |
+
fn = torch.bmm if A.ndim == 3 else torch.mm
|
| 403 |
+
out = fn(A, B, out_dtype=out_dtype, out=out)
|
| 404 |
+
if not isinstance(alpha, float) or alpha != 1.0:
|
| 405 |
+
out *= alpha
|
| 406 |
+
if bias is not None:
|
| 407 |
+
bias = bias if A.ndim == 2 else bias.unsqueeze(1)
|
| 408 |
+
out += bias
|
| 409 |
+
elif cu_seqlens_m is not None:
|
| 410 |
+
# Handle varlen_m case
|
| 411 |
+
if out is None:
|
| 412 |
+
# When gather_A (A_idx provided), output size is determined by A_idx length
|
| 413 |
+
total_m = A_idx.shape[0] if A_idx is not None else A.shape[0]
|
| 414 |
+
out = torch.empty((total_m, B.shape[-1]), dtype=out_dtype, device=A.device)
|
| 415 |
+
for i in range(cu_seqlens_m.shape[0] - 1):
|
| 416 |
+
A_slice = (
|
| 417 |
+
A[A_idx[cu_seqlens_m[i] : cu_seqlens_m[i + 1]]]
|
| 418 |
+
if A_idx is not None
|
| 419 |
+
else A[cu_seqlens_m[i] : cu_seqlens_m[i + 1]]
|
| 420 |
+
)
|
| 421 |
+
torch.mm(A_slice, B[i], out=out[cu_seqlens_m[i] : cu_seqlens_m[i + 1]])
|
| 422 |
+
if not isinstance(alpha, float) or alpha != 1.0:
|
| 423 |
+
out[cu_seqlens_m[i] : cu_seqlens_m[i + 1]] *= alpha
|
| 424 |
+
if bias is not None:
|
| 425 |
+
out[cu_seqlens_m[i] : cu_seqlens_m[i + 1]] += bias[i]
|
| 426 |
+
else: # cu_seqlens_k is not None
|
| 427 |
+
L = cu_seqlens_k.shape[0] - 1
|
| 428 |
+
if out is None:
|
| 429 |
+
out = torch.empty((L, A.shape[0], B.shape[1]), dtype=out_dtype, device=A.device)
|
| 430 |
+
for i in range(L):
|
| 431 |
+
A_slice = (
|
| 432 |
+
A[:, A_idx[cu_seqlens_k[i] : cu_seqlens_k[i + 1]]]
|
| 433 |
+
if A_idx is not None
|
| 434 |
+
else A[:, cu_seqlens_k[i] : cu_seqlens_k[i + 1]]
|
| 435 |
+
)
|
| 436 |
+
torch.mm(A_slice, B[cu_seqlens_k[i] : cu_seqlens_k[i + 1], :], out=out[i])
|
| 437 |
+
if not isinstance(alpha, float) or alpha != 1.0:
|
| 438 |
+
out *= alpha
|
| 439 |
+
if bias is not None:
|
| 440 |
+
out += bias
|
| 441 |
+
return out
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
def gemm_add(
    # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
    A: Tensor,
    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
    C: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m or (L, M, N) if varlen_k
    out: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    alpha: float | Tensor = 1.0,
    beta: float | Tensor = 1.0,
    out_dtype: Optional[torch.dtype] = None,
    cu_seqlens_m: Optional[Tensor] = None,
    cu_seqlens_k: Optional[Tensor] = None,
    A_idx: Optional[Tensor] = None,  # (total_M,) or (total_K,) indices for gather_A when varlen
    batch_idx_permute: Optional[Tensor] = None,  # (L,) permutation of batch indices for scheduler
    dynamic_scheduler: bool = False,
    tuned: bool = True,
) -> Tensor:
    """GEMM plus addend: allocates ``out`` when missing, then forwards to gemm_add_out."""
    if out is None:
        dtype = out_dtype if out_dtype is not None else A.dtype
        if cu_seqlens_m is not None:
            # varlen_m: if A_idx is provided (gather_A) its length sets the row count.
            rows = A.shape[0] if A_idx is None else A_idx.shape[0]
            shape = (rows, B.shape[-1])
        elif cu_seqlens_k is not None:
            # varlen_k: one (M, N) slab per batch entry; M is always A.shape[0].
            shape = (cu_seqlens_k.shape[0] - 1, A.shape[0], B.shape[-1])
        elif A.ndim == 2:
            shape = (A.shape[0], B.shape[-1])
        else:
            shape = (A.shape[0], A.shape[-2], B.shape[-1])
        out = torch.empty(shape, dtype=dtype, device=A.device)
    # When C aliases out and beta == 1.0, the kernel can accumulate in place.
    add_to_output = C is out and isinstance(beta, float) and beta == 1.0 and cu_seqlens_m is None
    # Tensor-valued scale factors travel in dedicated slots (fixed-type op args).
    if isinstance(alpha, float):
        alpha_scalar, alpha_tensor = alpha, None
    else:
        alpha_scalar, alpha_tensor = 1.0, alpha
    if isinstance(beta, float):
        beta_scalar, beta_tensor = beta, None
    else:
        beta_scalar, beta_tensor = 1.0, beta
    gemm_add_out(
        A,
        B,
        None if add_to_output else C,
        out,
        alpha_scalar,
        beta_scalar,
        alpha_tensor,
        beta_tensor,
        cu_seqlens_m=cu_seqlens_m,
        cu_seqlens_k=cu_seqlens_k,
        A_idx=A_idx,
        batch_idx_permute=batch_idx_permute,
        add_to_output=add_to_output,
        dynamic_scheduler=dynamic_scheduler,
        tuned=tuned,
    )
    return out
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
@torch.library.custom_op(
    add_quack_op_namespace_prefix("gemm_add_out"),
    mutates_args=("out",),
    device_types="cuda",
    # We have to split out alpha and alpha_tensor since torch.library requires
    # each argument to have a fixed type
    # schema="(Tensor A, Tensor B, Tensor C, Tensor(a3!) out, float alpha=1.0, float beta=1.0, Tensor? alpha_tensor=None, Tensor? beta_tensor=None, Tensor? cu_seqlens_m=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
)
def gemm_add_out(
    # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
    A: Tensor,
    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
    C: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m or (L, M, N) if varlen_k
    out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    alpha: float = 1.0,
    beta: float = 1.0,
    alpha_tensor: Optional[Tensor] = None,
    beta_tensor: Optional[Tensor] = None,
    cu_seqlens_m: Optional[Tensor] = None,
    cu_seqlens_k: Optional[Tensor] = None,
    A_idx: Optional[Tensor] = None,  # (total_M,) or (total_K,) indices for gather_A when varlen
    batch_idx_permute: Optional[Tensor] = None,  # (L,) permutation of batch indices for scheduler
    add_to_output: bool = False,
    dynamic_scheduler: bool = False,
    tuned: bool = True,
) -> None:
    """GEMM with addition and pre-allocated output tensor."""
    # Pick the autotuned kernel or a fixed-config fallback.
    if tuned:
        fn = gemm_tuned
    else:
        fn = partial(gemm_tuned.fn, config=None)
    # Tensor-valued scale factors take precedence over the float ones.
    if alpha_tensor is not None:
        alpha = alpha_tensor
    if beta_tensor is not None:
        beta = beta_tensor
    fn(
        A,
        B,
        out,
        C,
        alpha=alpha,
        beta=beta,
        cu_seqlens_m=cu_seqlens_m,
        cu_seqlens_k=cu_seqlens_k,
        A_idx=A_idx,
        batch_idx_permute=batch_idx_permute,
        add_to_output=add_to_output,
        dynamic_scheduler=dynamic_scheduler,
    )
|
| 547 |
+
|
| 548 |
+
|
| 549 |
+
def gemm_add_ref(
|
| 550 |
+
# (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
|
| 551 |
+
A: Tensor,
|
| 552 |
+
B: Tensor, # (K, N) or (L, K, N) or (total_K, N) if varlen_k
|
| 553 |
+
C: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
|
| 554 |
+
bias: Optional[Tensor] = None, # (N,) or (L, N)
|
| 555 |
+
out: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
|
| 556 |
+
alpha: float | Tensor = 1.0,
|
| 557 |
+
beta: float | Tensor = 1.0,
|
| 558 |
+
cu_seqlens_m: Optional[Tensor] = None,
|
| 559 |
+
cu_seqlens_k: Optional[Tensor] = None,
|
| 560 |
+
A_idx: Optional[Tensor] = None, # (total_M,) or (total_K,) indices for gather_A when varlen
|
| 561 |
+
out_dtype: Optional[torch.dtype] = None,
|
| 562 |
+
) -> Tensor:
|
| 563 |
+
"""Reference implementation for GEMM with addition and pre-allocated output."""
|
| 564 |
+
if cu_seqlens_m is None and cu_seqlens_k is None:
|
| 565 |
+
if isinstance(alpha, float) and isinstance(beta, float):
|
| 566 |
+
out = torch.addmm(C, A, B, out_dtype=out_dtype, alpha=alpha, beta=beta, out=out)
|
| 567 |
+
else:
|
| 568 |
+
out_dtype = (
|
| 569 |
+
out.dtype if out is not None else (out_dtype if out_dtype is not None else A.dtype)
|
| 570 |
+
)
|
| 571 |
+
result = (alpha * (A @ B) + beta * C).to(out_dtype)
|
| 572 |
+
if out is not None:
|
| 573 |
+
out.copy_(result)
|
| 574 |
+
if bias is not None:
|
| 575 |
+
bias = bias if A.ndim == 2 else bias.unsqueeze(1)
|
| 576 |
+
out += bias
|
| 577 |
+
elif cu_seqlens_m is not None:
|
| 578 |
+
# Handle varlen_m case
|
| 579 |
+
if out is None:
|
| 580 |
+
# When gather_A (A_idx provided), output size is determined by A_idx length
|
| 581 |
+
total_m = A_idx.shape[0] if A_idx is not None else A.shape[0]
|
| 582 |
+
out_dtype = out_dtype if out_dtype is not None else A.dtype
|
| 583 |
+
out = torch.empty((total_m, B.shape[-1]), dtype=out_dtype, device=A.device)
|
| 584 |
+
for i in range(cu_seqlens_m.shape[0] - 1):
|
| 585 |
+
A_slice = (
|
| 586 |
+
A[A_idx[cu_seqlens_m[i] : cu_seqlens_m[i + 1]]]
|
| 587 |
+
if A_idx is not None
|
| 588 |
+
else A[cu_seqlens_m[i] : cu_seqlens_m[i + 1]]
|
| 589 |
+
)
|
| 590 |
+
C_slice = C[cu_seqlens_m[i] : cu_seqlens_m[i + 1]]
|
| 591 |
+
out_slice = out[cu_seqlens_m[i] : cu_seqlens_m[i + 1]]
|
| 592 |
+
result = alpha * torch.mm(A_slice, B[i]) + beta * C_slice
|
| 593 |
+
if bias is not None:
|
| 594 |
+
result += bias[i]
|
| 595 |
+
out_slice.copy_(result)
|
| 596 |
+
else: # cu_seqlens_k is not None
|
| 597 |
+
# Handle varlen_k case
|
| 598 |
+
L = cu_seqlens_k.shape[0] - 1
|
| 599 |
+
out_dtype = out_dtype if out_dtype is not None else A.dtype
|
| 600 |
+
if out is None:
|
| 601 |
+
out = torch.empty((L, A.shape[0], B.shape[1]), dtype=out_dtype, device=A.device)
|
| 602 |
+
for i in range(L):
|
| 603 |
+
A_slice = (
|
| 604 |
+
A[:, A_idx[cu_seqlens_k[i] : cu_seqlens_k[i + 1]]]
|
| 605 |
+
if A_idx is not None
|
| 606 |
+
else A[:, cu_seqlens_k[i] : cu_seqlens_k[i + 1]]
|
| 607 |
+
)
|
| 608 |
+
B_slice = B[cu_seqlens_k[i] : cu_seqlens_k[i + 1], :]
|
| 609 |
+
result = alpha * torch.mm(A_slice, B_slice) + beta * C[i]
|
| 610 |
+
out[i].copy_(result)
|
| 611 |
+
if bias is not None:
|
| 612 |
+
out += bias
|
| 613 |
+
return out
|
| 614 |
+
|
| 615 |
+
|
| 616 |
+
def gemm_add_inplace(
    # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
    A: Tensor,
    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
    out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m or (L, M, N) if varlen_k
    alpha: float | Tensor = 1.0,
    beta: float | Tensor = 1.0,
    cu_seqlens_m: Optional[Tensor] = None,
    cu_seqlens_k: Optional[Tensor] = None,
    A_idx: Optional[Tensor] = None,  # (total_M,) or (total_K,) indices for gather_A when varlen
    batch_idx_permute: Optional[Tensor] = None,  # (L,) permutation of batch indices for scheduler
    dynamic_scheduler: bool = False,
    tuned: bool = True,
) -> None:
    """In-place GEMM with addition: out = alpha * A @ B + beta * out.

    Args:
        A: (M, K) / (L, M, K) / (total_M, K) if varlen_m / (M, total_K) if varlen_k input.
        B: (K, N) / (L, K, N) / (total_K, N) if varlen_k input.
        out: destination tensor, accumulated into in place.
        alpha: scalar (or tensor) multiplier for A @ B.
        beta: scalar (or tensor) multiplier for the existing contents of out.
        cu_seqlens_m / cu_seqlens_k: optional cumulative sequence lengths.
        dynamic_scheduler: whether to use the dynamic tile scheduler.
        tuned: whether to use the autotuned configuration.
    """
    # Split each scale into a fixed-type (float, tensor) pair for the op.
    if isinstance(alpha, float):
        alpha_scalar, alpha_tensor = alpha, None
    else:
        alpha_scalar, alpha_tensor = 1.0, alpha
    if isinstance(beta, float):
        beta_scalar, beta_tensor = beta, None
    else:
        beta_scalar, beta_tensor = 1.0, beta
    gemm_add_inplace_op(
        A,
        B,
        out,
        alpha_scalar,
        beta_scalar,
        alpha_tensor,
        beta_tensor,
        cu_seqlens_m,
        cu_seqlens_k,
        A_idx=A_idx,
        batch_idx_permute=batch_idx_permute,
        dynamic_scheduler=dynamic_scheduler,
        tuned=tuned,
    )
|
| 661 |
+
|
| 662 |
+
|
| 663 |
+
@torch.library.custom_op(
    add_quack_op_namespace_prefix("gemm_add_inplace"),
    mutates_args=("out",),
    device_types="cuda",
    # We have to split out alpha and alpha_tensor since torch.library requires
    # each argument to have a fixed type
    # schema="(Tensor A, Tensor B, Tensor(a2!) out, float alpha=1.0, float beta=1.0, Tensor? alpha_tensor=None, Tensor? beta_tensor=None, Tensor? cu_seqlens_m=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
)
def gemm_add_inplace_op(
    # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
    A: Tensor,
    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
    out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m or (L, M, N) if varlen_k
    alpha: float = 1.0,
    beta: float = 1.0,
    alpha_tensor: Optional[Tensor] = None,
    beta_tensor: Optional[Tensor] = None,
    cu_seqlens_m: Optional[Tensor] = None,
    cu_seqlens_k: Optional[Tensor] = None,
    A_idx: Optional[Tensor] = None,  # (total_M,) or (total_K,) indices for gather_A when varlen
    batch_idx_permute: Optional[Tensor] = None,  # (L,) permutation of batch indices for scheduler
    dynamic_scheduler: bool = False,
    tuned: bool = True,
) -> None:
    """Kernel entry for in-place GEMM-add; ``out`` is both addend and destination."""
    # Pick the autotuned kernel or a fixed-config fallback.
    if tuned:
        fn = gemm_tuned
    else:
        fn = partial(gemm_tuned.fn, config=None)
    # Tensor-valued scale factors take precedence over the float ones.
    if alpha_tensor is not None:
        alpha = alpha_tensor
    if beta_tensor is not None:
        beta = beta_tensor
    add_to_output = isinstance(beta, float) and beta == 1.0 and cu_seqlens_m is None
    # Use out as both input bias and output
    fn(
        A,
        B,
        out,
        None if add_to_output else out,
        alpha=alpha,
        beta=beta,
        cu_seqlens_m=cu_seqlens_m,
        cu_seqlens_k=cu_seqlens_k,
        A_idx=A_idx,
        batch_idx_permute=batch_idx_permute,
        add_to_output=add_to_output,
        dynamic_scheduler=dynamic_scheduler,
    )
|
| 706 |
+
|
| 707 |
+
|
| 708 |
+
def gemm_act(
    A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
    B: Tensor,  # (K, N) or (L, K, N)
    C: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    bias: Optional[Tensor] = None,  # (N,) or (L, N)
    activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
    preact_out: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    postact_out: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    out_dtype: Optional[torch.dtype] = None,
    postact_dtype: Optional[torch.dtype] = None,
    cu_seqlens_m: Optional[Tensor] = None,
    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
    store_preact: bool = True,
    dynamic_scheduler: bool = False,
    tuned: bool = True,
) -> Tuple[Optional[Tensor], Tensor]:
    """GEMM fused with an activation; allocates missing outputs, forwards to gemm_act_out."""
    out_dtype = A.dtype if out_dtype is None else out_dtype
    postact_dtype = A.dtype if postact_dtype is None else postact_dtype
    # Output shape follows the (possibly gathered / varlen) row count of A.
    if cu_seqlens_m is not None:
        rows = A.shape[0] if A_idx is None else A_idx.shape[0]
        out_shape = (rows, B.shape[-1])
    elif A.ndim == 2:
        out_shape = (A.shape[0], B.shape[-1])
    else:
        out_shape = (A.shape[0], A.shape[-2], B.shape[-1])
    if store_preact and preact_out is None:
        preact_out = torch.empty(out_shape, dtype=out_dtype, device=A.device)
    if postact_out is None:
        postact_out = torch.empty(out_shape, dtype=postact_dtype, device=A.device)
    gemm_act_out(
        A,
        B,
        preact_out,
        postact_out,
        C,
        bias,
        activation,
        cu_seqlens_m,
        A_idx,
        dynamic_scheduler,
        tuned,
    )
    return preact_out, postact_out
|
| 754 |
+
|
| 755 |
+
|
| 756 |
+
@torch.library.custom_op(
    add_quack_op_namespace_prefix("gemm_act_out"),
    mutates_args=("preact_out", "postact_out"),
    device_types="cuda",
    schema="(Tensor A, Tensor B, Tensor(a2!)? preact_out, Tensor(a3!) postact_out, Tensor? C=None, Tensor? bias=None, str? activation=None, Tensor? cu_seqlens_m=None, Tensor? A_idx=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
)
def gemm_act_out(
    A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
    B: Tensor,  # (K, N) or (L, K, N)
    preact_out: Optional[Tensor],  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    postact_out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    C: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    bias: Optional[Tensor] = None,  # (N,) or (L, N)
    activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
    cu_seqlens_m: Optional[Tensor] = None,
    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
    dynamic_scheduler: bool = False,
    tuned: bool = True,
) -> None:
    """GEMM with activation and pre-allocated output tensors."""
    # Pick the autotuned kernel or a fixed-config fallback.
    if tuned:
        fn = gemm_act_tuned
    else:
        fn = partial(gemm_act_tuned.fn, config=None)
    fn(A, B, preact_out, postact_out, C, bias, activation, cu_seqlens_m, A_idx, dynamic_scheduler)
|
| 778 |
+
|
| 779 |
+
|
| 780 |
+
def gemm_act_ref(
    A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A
    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
    C: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    bias: Optional[Tensor] = None,  # (N,) or (L, N)
    activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
    cu_seqlens_m: Optional[Tensor] = None,
    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
    out_dtype: Optional[torch.dtype] = None,
    postact_dtype: Optional[torch.dtype] = None,
    store_preact: bool = True,
) -> Tuple[Optional[Tensor], Tensor]:
    """Reference: matmul (optionally plus C and bias), then elementwise activation.

    Returns (preact or None, postact), cast to out_dtype / postact_dtype.
    """
    out_dtype = A.dtype if out_dtype is None else out_dtype
    postact_dtype = A.dtype if postact_dtype is None else postact_dtype
    # Pre-activation: plain GEMM, or GEMM + C when an addend is supplied.
    if C is None:
        preact = gemm_ref(A, B, bias=bias, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx)
    else:
        preact = gemm_add_ref(A, B, C, bias=bias, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx)
    postact = act_to_pytorch_fn_map[activation](preact).to(postact_dtype)
    preact_ret = preact.to(out_dtype) if store_preact else None
    return preact_ret, postact
|
| 800 |
+
|
| 801 |
+
|
| 802 |
+
def gemm_dact(
    A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
    B: Tensor,  # (K, N) or (L, K, N)
    PreAct: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
    dx_out: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    postact_out: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    out_dtype: Optional[torch.dtype] = None,
    postact_dtype: Optional[torch.dtype] = None,
    cu_seqlens_m: Optional[Tensor] = None,
    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
    dynamic_scheduler: bool = True,
    tuned: bool = True,
) -> Tuple[Tensor, Tensor]:
    """GEMM fused with activation backward; allocates missing outputs, forwards to gemm_dact_out."""
    out_dtype = A.dtype if out_dtype is None else out_dtype
    postact_dtype = PreAct.dtype if postact_dtype is None else postact_dtype
    # Output shape follows the (possibly gathered / varlen) row count of A.
    if cu_seqlens_m is not None:
        rows = A.shape[0] if A_idx is None else A_idx.shape[0]
        out_shape = (rows, B.shape[-1])
    elif A.ndim == 2:
        out_shape = (A.shape[0], B.shape[-1])
    else:
        out_shape = (A.shape[0], A.shape[-2], B.shape[-1])
    if dx_out is None:
        dx_out = torch.empty(out_shape, dtype=out_dtype, device=A.device)
    if postact_out is None:
        postact_out = torch.empty(out_shape, dtype=postact_dtype, device=A.device)
    gemm_dact_out(
        A, B, PreAct, dx_out, postact_out, activation, cu_seqlens_m, A_idx, dynamic_scheduler, tuned
    )
    return dx_out, postact_out
|
| 836 |
+
|
| 837 |
+
|
| 838 |
+
@torch.library.custom_op(
    add_quack_op_namespace_prefix("gemm_dact_out"),
    mutates_args=("dx_out", "postact_out"),
    device_types="cuda",
    schema="(Tensor A, Tensor B, Tensor PreAct, Tensor(a3!) dx_out, Tensor(a4!) postact_out, str? activation=None, Tensor? cu_seqlens_m=None, Tensor? A_idx=None, bool dynamic_scheduler=True, bool tuned=True) -> ()",
)
def gemm_dact_out(
    A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
    B: Tensor,  # (K, N) or (L, K, N)
    PreAct: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    dx_out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    postact_out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
    cu_seqlens_m: Optional[Tensor] = None,
    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
    dynamic_scheduler: bool = True,
    tuned: bool = True,
) -> None:
    """GEMM with activation gradient and pre-allocated output tensors."""
    # Pick the autotuned kernel or a fixed-config fallback.
    if tuned:
        fn = gemm_dact_tuned
    else:
        fn = partial(gemm_dact_tuned.fn, config=None)
    fn(A, B, PreAct, dx_out, postact_out, activation, cu_seqlens_m, A_idx, dynamic_scheduler)
|
| 859 |
+
|
| 860 |
+
|
| 861 |
+
def gemm_dact_ref(
    A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A
    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
    PreAct: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
    cu_seqlens_m: Optional[Tensor] = None,
    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
    out_dtype: Optional[torch.dtype] = None,
    postact_dtype: Optional[torch.dtype] = None,
) -> Tuple[Tensor, Tensor]:
    """Reference implementation for GEMM with activation gradient.

    Computes dout = A @ B (via gemm_ref), then dx = dout * activation'(PreAct)
    using autograd, and recomputes postact = activation(PreAct).

    Returns:
        (dx, postact) cast to out_dtype and postact_dtype respectively.
    """
    out_dtype = A.dtype if out_dtype is None else out_dtype
    postact_dtype = PreAct.dtype if postact_dtype is None else postact_dtype
    # Upstream gradient with respect to the activation output.
    dout = gemm_ref(A, B, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx).to(out_dtype)
    postact = act_to_pytorch_fn_map[activation](PreAct)
    # Compute gradient using autograd
    if activation is None:
        # Identity activation: the gradient passes straight through.
        dx = dout
    else:
        # Temporarily enable requires_grad so autograd can differentiate the
        # activation, then restore the caller's original setting.
        PreAct_requires_grad = PreAct.requires_grad
        PreAct.requires_grad_(True)
        postact_for_grad = act_to_pytorch_fn_map[activation](PreAct)
        dx = torch.autograd.grad(postact_for_grad, PreAct, dout, create_graph=False)[0]
        PreAct.requires_grad_(PreAct_requires_grad)
    return dx.to(out_dtype), postact.to(postact_dtype)
|
| 886 |
+
|
| 887 |
+
|
| 888 |
+
def gemm_gated_ref(
    A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A
    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
    C: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    bias: Optional[Tensor] = None,  # (N,) or (L, N)
    activation: Literal["glu", "swiglu", "swiglu_oai", "reglu", "geglu"] = "swiglu",
    cu_seqlens_m: Optional[Tensor] = None,
    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
    out_dtype: Optional[torch.dtype] = None,
    postact_dtype: Optional[torch.dtype] = None,
    store_preact: bool = True,
) -> Tuple[Optional[Tensor], Tensor]:
    """Reference implementation for GEMM with gated activation forward.

    Args:
        A: (M, K) - input tensor
        B: (K, N) - weight tensor with gate and up projections
        C: (M, N) - optional bias tensor
        activation: Type of gated activation
        out_dtype: Output dtype for preact
        postact_dtype: Output dtype for postact
        store_preact: Whether to return the pre-activation

    Returns:
        (preact, postact) where:
        - preact: (M, N) pre-activation (if store_preact=True, else None)
        - postact: (M, N // 2) post-activation output
    """
    if out_dtype is None:
        out_dtype = A.dtype
    if postact_dtype is None:
        postact_dtype = A.dtype
    preact = (
        gemm_ref(A, B, bias=bias, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx)
        if C is None
        else gemm_add_ref(A, B, C, bias=bias, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx)
    )
    # Gate and up projections are interleaved along the last dimension.
    gate, up = preact[..., ::2], preact[..., 1::2]  # each (M, N//2)
    postact = gated_to_pytorch_fn_map[activation](gate, up)
    preact_out = preact.to(out_dtype) if store_preact else None
    return preact_out, postact.to(postact_dtype)
|
| 927 |
+
|
| 928 |
+
|
| 929 |
+
def gemm_dgated_ref(
    A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A
    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
    PreAct: Tensor,  # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m
    activation: Literal["glu", "swiglu", "swiglu_oai", "reglu", "geglu"],
    cu_seqlens_m: Optional[Tensor] = None,
    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
    out_dtype: Optional[torch.dtype] = None,
    postact_dtype: Optional[torch.dtype] = None,
) -> Tuple[Tensor, Tensor]:
    """Reference implementation for GEMM with gated activation gradient.

    Args:
        A: (M, K) - dout input tensor
        B: (K, N) - weight tensor
        PreAct: (M, 2*N) - pre-activation tensor with gate and up projections interleaved
        activation: Type of gated activation
        out_dtype: Output dtype for dx
        postact_dtype: Output dtype for postact

    Returns:
        (dx, postact) where:
        - dx: (M, 2*N) gradient w.r.t. PreAct
        - postact: (M, N) post-activation output
    """
    if out_dtype is None:
        out_dtype = A.dtype
    if postact_dtype is None:
        postact_dtype = PreAct.dtype
    dout = gemm_ref(A, B, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx).to(out_dtype)
    # Gate and up projections are interleaved along the last dimension.
    gate, up = PreAct[..., ::2], PreAct[..., 1::2]  # each (M, N)
    # Compute gradients w.r.t. gate and up via autograd, restoring the
    # original requires_grad flags afterwards.
    saved_gate_flag, saved_up_flag = gate.requires_grad, up.requires_grad
    gate.requires_grad_(True)
    up.requires_grad_(True)
    postact = gated_to_pytorch_fn_map[activation](gate, up)
    dgate, dup = torch.autograd.grad(postact, [gate, up], dout, create_graph=False)
    gate.requires_grad_(saved_gate_flag)
    up.requires_grad_(saved_up_flag)
    # Re-interleave the two gradients to match PreAct's layout.
    dx = torch.stack([dgate, dup], dim=-1).reshape(PreAct.shape)
    return dx.to(out_dtype), postact.to(postact_dtype)
|
| 971 |
+
|
| 972 |
+
|
| 973 |
+
@torch.library.custom_op(
    add_quack_op_namespace_prefix("gemm_symmetric_out"),
    mutates_args=("out",),
    device_types="cuda",
    schema="(Tensor A, Tensor B, Tensor(a2!) out, Tensor? C=None, bool dynamic_scheduler=False, float alpha=1.0, float beta=1.0) -> ()",
)
def gemm_symmetric_out(
    A: Tensor,  # (M, K) or (L, M, K)
    B: Tensor,  # (K, M) or (L, K, M)
    out: Tensor,  # (M, M) or (L, M, M)
    C: Optional[Tensor] = None,  # (M, M) or (L, M, M)
    dynamic_scheduler: bool = False,
    alpha: float = 1.0,
    beta: float = 1.0,
) -> None:
    """GEMM with guaranteed symmetric output, written into ``out`` in place.

    Args:
        A: (M, K) or (L, M, K) input tensor.
        B: (K, M) or (L, K, M) input tensor (transposed internally).
        out: (M, M) or (L, M, M) pre-allocated output, mutated in place.
        C: optional (M, M) or (L, M, M) addend.
        dynamic_scheduler: if True, allocate a tile-count semaphore for the
            kernel's dynamic tile scheduler.
        alpha, beta: scalar multipliers passed through to the kernel.
    """
    # Normalize all operands to 3D batched form (L, ..., ...).
    if A.ndim == 2:
        A = A.unsqueeze(0)  # (1, M, K)
    B = B.mT  # (M, K) or (L, M, K)
    if B.ndim == 2:
        B = B.unsqueeze(0)  # (1, M, K)
    if C is not None and C.ndim == 2:
        C = C.unsqueeze(0)  # (1, M, M)
    if out.ndim == 2:
        out = out.unsqueeze(0)  # (1, M, M); view shares storage, so the
        # in-place write still lands in the caller's tensor.
    tile_count_semaphore = (
        torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None
    )
    # NOTE: tile sizes below are fixed; presumably tuned for this kernel — TODO confirm.
    gemm_symmetric_sm90_sm100(
        A,
        B,
        out,
        C,
        tile_count_semaphore,
        tile_M=128,
        tile_N=256,
        cluster_M=2,
        cluster_N=1,
        pingpong=False,
        persistent=True,
        max_swizzle_size=8,
        alpha=alpha,
        beta=beta,
    )
|
| 1019 |
+
|
| 1020 |
+
|
| 1021 |
+
def gemm_symmetric(
    A: Tensor,  # (M, K) or (L, M, K)
    B: Tensor,  # (K, M) or (L, K, M)
    C: Optional[Tensor] = None,  # (M, M) or (L, M, M)
    out: Optional[Tensor] = None,  # (M, M) or (L, M, M)
    out_dtype: Optional[torch.dtype] = None,
    dynamic_scheduler: bool = False,
    alpha: float | Tensor = 1.0,
    beta: float | Tensor = 1.0,
) -> Tensor:
    """GEMM with symmetric output.

    Allocates ``out`` if not provided, then dispatches to the in-place
    symmetric GEMM op.

    Args:
        A: (M, K) or (L, M, K) input tensor.
        B: (K, M) or (L, K, M) input tensor.
        C: optional (M, M) or (L, M, M) addend.
        out: optional pre-allocated output; allocated here when None.
        out_dtype: dtype for a newly allocated ``out`` (defaults to ``A.dtype``).
        dynamic_scheduler: forwarded to the kernel.
        alpha, beta: scalar multipliers.

    Returns:
        The output tensor (the same object as ``out`` when provided).
    """
    out_dtype = A.dtype if out_dtype is None else out_dtype
    # Output shape: (M, M) for 2D inputs, (L, M, M) for batched inputs.
    if A.ndim == 2:
        out_shape = (A.shape[0], B.shape[-1])
    else:
        out_shape = (A.shape[0], A.shape[-2], B.shape[-1])
    if out is None:
        out = torch.empty(out_shape, dtype=out_dtype, device=A.device)

    # NOTE(review): tensor-valued alpha/beta are silently replaced by 1.0 here
    # (the custom-op schema only accepts floats) — confirm this is intended.
    alpha_val = alpha if isinstance(alpha, float) else 1.0
    beta_val = beta if isinstance(beta, float) else 1.0

    gemm_symmetric_out(
        A, B, out, C, dynamic_scheduler=dynamic_scheduler, alpha=alpha_val, beta=beta_val
    )
    return out
|
| 1048 |
+
|
| 1049 |
+
|
| 1050 |
+
# TODO: this is not quite right, do we need to register gemm_add not gemm_add_out?
|
| 1051 |
+
# try:
|
| 1052 |
+
# from torch._inductor.fx_passes.reinplace import InplaceableOp
|
| 1053 |
+
# torch._inductor.fx_passes.reinplace.inplaceable_ops.update({
|
| 1054 |
+
# torch.ops.quack.gemm_add_out.default:
|
| 1055 |
+
# InplaceableOp(torch.ops.quack.gemm_add_inplace.default, mutated_arg=2)
|
| 1056 |
+
# })
|
| 1057 |
+
# except ImportError:
|
| 1058 |
+
# pass
|
build/torch-cuda/quack/gemm_sm100.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
build/torch-cuda/quack/gemm_sm90.py
ADDED
|
@@ -0,0 +1,2070 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Based on the cute-dsl example:
|
| 2 |
+
# https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/hopper/dense_gemm.py
|
| 3 |
+
|
| 4 |
+
import enum
|
| 5 |
+
from typing import Tuple, Type, Callable, Optional, Union, Literal
|
| 6 |
+
from functools import partial
|
| 7 |
+
import math
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
import cuda.bindings.driver as cuda
|
| 11 |
+
|
| 12 |
+
import cutlass
|
| 13 |
+
import cutlass.cute as cute
|
| 14 |
+
import cutlass.pipeline as pipeline
|
| 15 |
+
from cutlass.cute.nvgpu import cpasync, warp, warpgroup
|
| 16 |
+
import cutlass.utils.hopper_helpers as sm90_utils
|
| 17 |
+
from cutlass import Int32, Float32, Float16, Boolean, const_expr
|
| 18 |
+
from cutlass.cutlass_dsl import if_generate
|
| 19 |
+
from cutlass.utils import LayoutEnum
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
from .cute_dsl_utils import ParamsBase, ArgumentsBase
|
| 23 |
+
from .tile_scheduler import (
|
| 24 |
+
TileSchedulerOptions,
|
| 25 |
+
TileSchedulerArguments,
|
| 26 |
+
TileScheduler,
|
| 27 |
+
VarlenMTileSchedulerArguments,
|
| 28 |
+
VarlenMTileScheduler,
|
| 29 |
+
)
|
| 30 |
+
from .varlen_utils import VarlenArguments, VarlenManager
|
| 31 |
+
|
| 32 |
+
# return PipelineStateWAdvance instead of PipelineState
|
| 33 |
+
from .pipeline import make_pipeline_state, PipelineTmaCpAsync
|
| 34 |
+
from . import copy_utils as copy_utils
|
| 35 |
+
from . import sm90_utils as quack_sm90_utils
|
| 36 |
+
|
| 37 |
+
"""
|
| 38 |
+
A high-performance batched dense GEMM (C = A * B) example for the NVIDIA Hopper architecture
|
| 39 |
+
using CUTE DSL.
|
| 40 |
+
- Matrix A is MxKxL, L is batch dimension, A can be row-major("K") or column-major("M")
|
| 41 |
+
- Matrix B is NxKxL, L is batch dimension, B can be row-major("N") or column-major("K")
|
| 42 |
+
- Matrix C is MxNxL, L is batch dimension, C can be row-major("N") or column-major("M")
|
| 43 |
+
|
| 44 |
+
This GEMM kernel supports the following features:
|
| 45 |
+
- Utilizes Tensor Memory Access (TMA) for efficient memory operations
|
| 46 |
+
- Utilizes Hopper's WGMMA for matrix multiply-accumulate (MMA) operations
|
| 47 |
+
- Implements TMA multicast with cluster to reduce L2 memory traffic
|
| 48 |
+
- Supports multi-stage pipeline to overlap computation and memory access
|
| 49 |
+
|
| 50 |
+
This GEMM works as follows:
|
| 51 |
+
1. Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using TMA operations.
|
| 52 |
+
2. Perform matrix multiply-accumulate (MMA) operations using WGMMA instruction.
|
| 53 |
+
3. Store results from registers (RMEM) to shared memory (SMEM), then to global memory (GMEM) with TMA operations.
|
| 54 |
+
|
| 55 |
+
Hopper WGMMA instructions operate as follows:
|
| 56 |
+
- Read matrix A from SMEM
|
| 57 |
+
- Read matrix B from SMEM
|
| 58 |
+
- Perform MMA operation and store the result in Accumulator(register)
|
| 59 |
+
|
| 60 |
+
Constraints:
|
| 61 |
+
* Supported input data types: fp16, fp8 (e4m3fn, e5m2)
|
| 62 |
+
* For fp16 types, A and B must have the same data type
|
| 63 |
+
* For fp8 types, A and B can have different types (e4m3fn or e5m2) but both must be 8-bit
|
| 64 |
+
* Fp8 types only support k-major layout
|
| 65 |
+
* Only fp32 accumulation is supported in this example
|
| 66 |
+
* CTA tile shape M must be 64/128
|
| 67 |
+
* CTA tile shape N must be 64/128/256
|
| 68 |
+
* CTA tile shape K must be 64
|
| 69 |
+
* Cluster shape M/N must be positive and power of 2, total cluster size <= 4
|
| 70 |
+
* The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned,
|
| 71 |
+
i.e, number of elements is a multiple of 8, 16 for Float16, and Float8, respectively.
|
| 72 |
+
"""
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class NamedBarrierGemm(enum.IntEnum):
    """Named-barrier identifiers used by the SM90 GEMM kernel.

    Values start at 1 because barrier 0 is reserved for sync_threads().
    """

    Epilogue = 1
    # Mainloop load warps signal this barrier so the epilogue load warp can
    # start; avoids loading C too early, which would interfere with loading
    # A and B.
    EpilogueLoad = 2
    MmaWG0 = 3
    MmaWG1 = 4
    EpiWG0 = 5
    EpiWG1 = 6
    TmemPtr = 7
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class GemmSm90:
|
| 88 |
+
"""
|
| 89 |
+
This class implements batched matrix multiplication (C = A x B) with support for various data types
|
| 90 |
+
and architectural features specific to Hopper GPUs with persistent tile scheduling and warp specialization.
|
| 91 |
+
|
| 92 |
+
:param acc_dtype: Data type for accumulation during computation
|
| 93 |
+
:type acc_dtype: type[cutlass.Numeric]
|
| 94 |
+
:param tile_shape_mn: Shape of the CTA tile (M,N)
|
| 95 |
+
:type tile_shape_mn: Tuple[int, int, int]
|
| 96 |
+
:param cluster_shape_mnk: Cluster dimensions (M,N,K) for parallel processing
|
| 97 |
+
:type cluster_shape_mnk: Tuple[int, int, int]
|
| 98 |
+
|
| 99 |
+
:note: Data type requirements:
|
| 100 |
+
- For 16-bit types: A and B must have the same data type
|
| 101 |
+
- For 8-bit types: A and B can have different types (Float8E4M3FN/Float8E5M2) as long as both are 8-bit
|
| 102 |
+
- Float8 types only support k-major layout
|
| 103 |
+
|
| 104 |
+
:note: Supported data types:
|
| 105 |
+
- Float16
|
| 106 |
+
- BFloat16
|
| 107 |
+
- Float8E4M3FN/Float8E5M2
|
| 108 |
+
|
| 109 |
+
:note: Supported accumulation types:
|
| 110 |
+
- Float32 (for all floating point inputs)
|
| 111 |
+
|
| 112 |
+
:note: Constraints:
|
| 113 |
+
- Cluster shape M/N must be positive and power of 2, total cluster size <= 4
|
| 114 |
+
|
| 115 |
+
Example:
|
| 116 |
+
>>> gemm = GemmSm90(
|
| 117 |
+
... acc_dtype=Float32,
|
| 118 |
+
... tile_shape_mn=(128, 256),
|
| 119 |
+
... cluster_shape_mnk=(1, 1, 1)
|
| 120 |
+
... )
|
| 121 |
+
>>> gemm(a_tensor, b_tensor, c_tensor, stream)
|
| 122 |
+
"""
|
| 123 |
+
|
| 124 |
+
arch = 90
|
| 125 |
+
num_epi_tensormaps: int = 0
|
| 126 |
+
|
| 127 |
+
EpilogueArguments = ArgumentsBase
|
| 128 |
+
EpilogueParams = ParamsBase
|
| 129 |
+
|
| 130 |
+
def __init__(
    self,
    acc_dtype: Type[cutlass.Numeric],
    a_dtype: Type[cutlass.Numeric],
    tile_shape_mn: Tuple[int, int],
    cluster_shape_mnk: Tuple[int, int, int],
    pingpong: bool = False,
    is_persistent: bool = True,
    fp8_fast_accum: bool = False,
    gather_A: bool = False,
):
    """
    Initializes the configuration for a Hopper dense GEMM kernel.

    This configuration includes data types for operands, tile shape, cluster
    configuration, and thread layout.

    :param acc_dtype: Data type for accumulation during computation
    :type acc_dtype: type[cutlass.Numeric]
    :param a_dtype: Data type of operand A (only its bit width is inspected
        here, to decide fp8 slow accumulation)
    :type a_dtype: type[cutlass.Numeric]
    :param tile_shape_mn: Shape of the CTA tile (M,N); the K extent is filled
        in later by _setup_attributes
    :type tile_shape_mn: Tuple[int, int]
    :param cluster_shape_mnk: Cluster dimensions (M,N,K) for parallel processing
    :type cluster_shape_mnk: Tuple[int, int, int]
    :param pingpong: Use two ping-ponging MMA warp groups (requires persistent
        scheduling)
    :param is_persistent: Use a persistent tile scheduler
    :param fp8_fast_accum: Skip the extra accumulator used for fp8 slow
        accumulation
    :param gather_A: Load A via gathered (cp.async) loads instead of TMA;
        requires cluster N == 1
    :raises ValueError: If the CTA tile shape is outside the supported set.
    """

    self.acc_dtype = acc_dtype
    self.pingpong = pingpong
    self.is_persistent = is_persistent
    if self.pingpong:
        assert self.is_persistent, "Pingpong gemm requires persistent scheduler"
    # fp8 inputs (8-bit width) use a second accumulator unless fast-accum is requested.
    self.fp8_slow_accum = not fp8_fast_accum and a_dtype.width == 8
    self.gather_A = gather_A
    if gather_A:
        assert cluster_shape_mnk[1] == 1, "Cluster shape N must be 1 for gather A "

    self.cluster_shape_mnk = cluster_shape_mnk
    # K dimension is deferred in _setup_attributes
    self.cta_tile_shape_mnk = (*tile_shape_mn, 1)
    tile_M, tile_N = self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[1]
    # check the cta tile shape
    if not self.pingpong:
        if tile_M not in [64, 128, 192, 256, 320]:
            raise ValueError("CTA tile shape M must be 64/128/192/256/320")
        if tile_M in [192, 320]:  # special case
            tile_N_max = 256 if tile_M == 192 else 160
            if not (tile_N % 32 == 0 and tile_N <= tile_N_max):
                raise ValueError(
                    f"If tile_m == {tile_M}, CTA tile shape N must be divisible by 32 and <= {tile_N_max}"
                )
        else:
            if not (
                (tile_N % 16 == 0 and tile_N <= 256) or (tile_N % 32 == 0 and tile_N <= 512)
            ):
                raise ValueError(
                    "CTA tile shape N must be divisible by 16 and <= 256, or divisible by 32 and <= 512"
                )
    else:
        if tile_M not in [64, 128, 192]:
            raise ValueError("CTA tile shape M must be 64/128/192 if pingpong")
        tile_N_max = 256 if tile_M == 64 else (208 if tile_M == 128 else 128)
        if not (tile_N % 16 == 0 and tile_N <= tile_N_max):
            raise ValueError(f"CTA tile shape N must be divisible by 16 and <= {tile_N_max}")

    # Choose how the MMA atoms tile the CTA (atom layout M x N).
    if not self.pingpong:
        if tile_M == 320:  # tile_M / 64 is not even so we have to split along N
            atom_layout_m, atom_layout_n = 1, 2
        elif tile_M == 192:
            if tile_N <= 128:
                atom_layout_m, atom_layout_n = 3, 1
            else:
                atom_layout_m, atom_layout_n = 1, 2
        else:
            atom_layout_m = (
                self.cta_tile_shape_mnk[0] // 64 if self.cta_tile_shape_mnk[0] < 256 else 2
            )
            atom_layout_n = 1
        assert atom_layout_m in [1, 2, 3] and atom_layout_n in [1, 2]
    else:
        # Pingpong: each warp group covers the whole tile, so a trivial layout.
        atom_layout_m, atom_layout_n = 1, 1
    self.atom_layout_mnk = (atom_layout_m, atom_layout_n, 1)

    # Multicast: A is shared across the cluster's N extent, B across its M extent.
    self.num_mcast_ctas_a = self.cluster_shape_mnk[1]
    if self.gather_A:
        assert self.num_mcast_ctas_a == 1
    self.num_mcast_ctas_b = self.cluster_shape_mnk[0]
    self.is_a_mcast = self.num_mcast_ctas_a > 1
    self.is_b_mcast = self.num_mcast_ctas_b > 1

    self.occupancy = 1
    # One MMA warp group per atom; pingpong doubles it (two alternating groups).
    self.mma_warp_groups = math.prod(self.atom_layout_mnk) * (1 if not self.pingpong else 2)
    if self.pingpong:
        assert self.mma_warp_groups == 2
    assert self.mma_warp_groups in [1, 2, 3]
    self.num_threads_per_warp_group = 128
    # +1 warp group for the producer (load) warps.
    self.threads_per_cta = (self.mma_warp_groups + 1) * self.num_threads_per_warp_group
    self.smem_capacity = cutlass.utils.get_smem_capacity_in_bytes("sm_90")
    self.num_epi_warps = (self.mma_warp_groups if not self.pingpong else 1) * 4
    # gather_A needs a full warp group (4 warps) of cp.async loaders; TMA needs 1 warp.
    self.num_ab_load_warps = 1 if not self.gather_A else 4
    self.ab_load_warp_id = self.mma_warp_groups * 4
    # self.num_epi_load_threads = cute.arch.WARP_SIZE * 1
    # self.epi_load_warp_id = self.ab_load_warp_id + self.num_ab_load_warps

    # Accumulator registers per thread; used to pick the register split below.
    regs_per_thread = math.prod(self.cta_tile_shape_mnk[:2]) // (
        math.prod(self.atom_layout_mnk) * self.num_threads_per_warp_group
    )
    if self.fp8_slow_accum:
        # Slow accumulation keeps a second accumulator, doubling register use.
        regs_per_thread *= 2
    # Register reallocation between producer (load) and consumer (MMA) warps.
    # Constants are tuning choices — presumably benchmarked; TODO confirm.
    if not self.gather_A:
        if self.mma_warp_groups == 3:
            self.num_regs_load, self.num_regs_mma = 32, 160
        else:
            heavy_register_pressure = regs_per_thread >= 208
            self.num_regs_load, self.num_regs_mma = (
                (40, 232) if not heavy_register_pressure else (24, 240)
            )
    else:
        # gather_A load warps do address math, so they keep more registers.
        if self.mma_warp_groups == 3:
            self.num_regs_load, self.num_regs_mma = 56, 152
        else:
            self.num_regs_load, self.num_regs_mma = (56, 224)

    # Pipeline stage counts; filled in by _setup_attributes.
    self.ab_stage = None
    self.epi_stage = None

    # Staged smem layouts; filled in by _setup_attributes.
    self.a_smem_layout_staged = None
    self.b_smem_layout_staged = None
    self.epi_smem_layout_staged = None
    self.epi_tile = None

    self.shared_storage = None
    self.buffer_align_bytes = 1024
|
| 261 |
+
|
| 262 |
+
    def _setup_attributes(self, epilogue_args: EpilogueArguments):
        """Set up configurations that are dependent on GEMM inputs

        This method configures various attributes based on the input tensor properties
        (data types, leading dimensions) and kernel settings:
        - Configuring tiled MMA
        - Computing MMA/cluster/tile shapes
        - Computing cluster layout
        - Computing multicast CTAs for A/B
        - Computing epilogue subtile
        - Setting up A/B/C stage counts in shared memory
        - Computing A/B/C shared memory layout

        :param epilogue_args: epilogue arguments, forwarded to ``_compute_stages``
            so the stage count accounts for the epilogue's smem needs
        """
        # Build the warpgroup MMA; tiler_mn fixes M to 64 (one WGMMA atom) and
        # divides the CTA N extent across the atom layout's N dimension.
        self.tiled_mma = sm90_utils.make_trivial_tiled_mma(
            self.a_dtype,
            self.b_dtype,
            self.a_layout.sm90_mma_major_mode(),
            self.b_layout.sm90_mma_major_mode(),
            self.acc_dtype,
            self.atom_layout_mnk,
            tiler_mn=(64, self.cta_tile_shape_mnk[1] // self.atom_layout_mnk[1]),
        )
        if const_expr(self.atom_layout_mnk[1] > 1):
            # If N dimension is split among 2 WGs, we need to permute the N dimension so
            # that in the epilogue, WG0 and WG1 can write to epi smem of size e.g. (64, 32)
            # containing accumulators that are next to each other in the N dimension.
            # Without permutation WG0 would write to epi smem of size (64, 16) and
            # WG1 would write to a separate epi smem of size (64, 16) that's far away.
            atom_n = self.atom_layout_mnk[1]
            permutation_n = cute.make_ordered_layout(
                (8, self.cta_tile_shape_mnk[1] // atom_n // 8, atom_n), order=(0, 2, 1)
            )
            self.tiled_mma = cute.make_tiled_mma(
                cute.make_mma_atom(self.tiled_mma.op),
                self.atom_layout_mnk,
                permutation_mnk=(None, permutation_n, None),
            )
        # Rewrite the CTA K extent as a multiple (4) of the MMA instruction's K size,
        # overriding whatever K was passed in at construction time.
        mma_inst_shape_k = cute.size(self.tiled_mma.shape_mnk, mode=[2])
        mma_inst_tile_k = 4
        self.cta_tile_shape_mnk = (
            self.cta_tile_shape_mnk[0],
            self.cta_tile_shape_mnk[1],
            mma_inst_shape_k * mma_inst_tile_k,
        )

        self.cluster_layout_mnk = cute.make_layout(self.cluster_shape_mnk)

        # Epilogue subtile shape (depends on D dtype and the atom layout).
        self.epi_tile = self._sm90_compute_tile_shape_or_override(
            self.cta_tile_shape_mnk,
            self.atom_layout_mnk,
            self.d_dtype,
        )

        # Compute stage before compute smem layout
        self.ab_stage, self.epi_stage, self.epi_c_stage = self._compute_stages(
            self.cta_tile_shape_mnk,
            self.epi_tile,
            self.a_dtype,
            self.b_dtype,
            self.d_dtype,
            self.c_dtype,
            epilogue_args,
            cutlass.utils.get_smem_capacity_in_bytes(f"sm_{self.arch}"),  # smem_capacity
            self.occupancy,
            # epi_smem will reuse smem ab if not persistent.
            overlap_sD_sA=not self.is_persistent,
        )
        # Pingpong uses a 2-deep scheduler pipeline (one slot per math warpgroup).
        self.sched_stage = 2 if self.pingpong else 1

        (
            self.a_smem_layout_staged,
            self.b_smem_layout_staged,
            self.epi_smem_layout_staged,
            self.epi_c_smem_layout_staged,
        ) = self._make_smem_layouts(
            self.cta_tile_shape_mnk,
            self.epi_tile,
            self.a_dtype,
            self.a_layout,
            self.b_dtype,
            self.b_layout,
            self.ab_stage,
            self.d_dtype,
            self.d_layout,
            self.epi_stage,
            self.c_dtype,
            self.c_layout,
            self.epi_c_stage,
        )
|
| 352 |
+
|
| 353 |
+
    @cute.jit
    def __call__(
        self,
        mA: cute.Tensor,
        mB: cute.Tensor,
        mD: Optional[cute.Tensor],
        mC: Optional[cute.Tensor],
        epilogue_args: ArgumentsBase,
        scheduler_args: TileSchedulerOptions,
        varlen_args: Optional[VarlenArguments],
        stream: cuda.CUstream,
    ):
        """Execute the GEMM operation in steps:
        - Setup static attributes
        - Setup TMA load/store atoms and tensors
        - Compute grid size
        - Define shared storage for kernel
        - Launch the kernel synchronously

        :param mA: Input tensor A
        :type mA: cute.Tensor
        :param mB: Input tensor B
        :type mB: cute.Tensor
        :param mD: Output tensor D (optional; None means no D write)
        :type mD: cute.Tensor
        :param mC: Optional C input tensor consumed by the epilogue
        :type mC: cute.Tensor
        :param epilogue_args: user-facing epilogue arguments (lowered via
            ``epi_to_underlying_arguments``)
        :param scheduler_args: tile scheduler options (max active clusters etc.)
        :param varlen_args: variable-length arguments; defaulted when None
        :param stream: CUDA stream for asynchronous execution
        :type stream: cuda.CUstream
        """

        # setup static attributes before smem/grid/tma computation
        self.a_dtype = mA.element_type
        self.b_dtype = mB.element_type
        self.d_dtype = mD.element_type if mD is not None else None
        self.c_dtype = mC.element_type if mC is not None else None
        self.a_layout = LayoutEnum.from_tensor(mA)
        self.b_layout = LayoutEnum.from_tensor(mB)
        self.d_layout = LayoutEnum.from_tensor(mD) if mD is not None else None
        self.c_layout = LayoutEnum.from_tensor(mC) if mC is not None else None

        # Dtype validation: 16-bit A/B must match exactly; otherwise only the
        # widths must match, and only 16-bit / 8-bit element types are supported.
        if const_expr(self.a_dtype.width == 16 and self.a_dtype != self.b_dtype):
            raise TypeError(f"Type mismatch: {self.a_dtype} != {self.b_dtype}")
        if const_expr(self.a_dtype.width != self.b_dtype.width):
            raise TypeError(f"Type width mismatch: {self.a_dtype.width} != {self.b_dtype.width}")
        if const_expr(self.a_dtype.width != 16 and self.a_dtype.width != 8):
            raise TypeError("a_dtype should be float16 or float8")

        if const_expr(varlen_args is None):
            varlen_args = VarlenArguments()
        # gather_A mode requires (and is implied by) the presence of the A-index tensor.
        assert (varlen_args.mAIdx is not None) == self.gather_A

        # Assume all strides are divisible by 128 bits except the last stride
        new_stride = lambda t: tuple(
            cute.assume(s, divby=128 // t.element_type.width) if not cute.is_static(s) else s
            for s in t.stride
        )
        mA, mD = [
            cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
            if t is not None
            else None
            for t in (mA, mD)
        ]

        self._setup_attributes(epilogue_args)

        # Single-stage slices of the staged smem layouts, used for TMA atom creation
        # and per-stage transfer-size accounting below.
        a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, 0))
        b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, 0))
        tma_atom_a, tma_tensor_a = None, None
        if const_expr(not self.gather_A):
            # gather_A loads A with cp.async-style gathers instead of TMA.
            tma_atom_a, tma_tensor_a = self._make_tma_atoms_and_tensors(
                mA,
                a_smem_layout,
                (self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[2]),
                self.cluster_shape_mnk[1],
            )
        tma_atom_b, tma_tensor_b = self._make_tma_atoms_and_tensors(
            mB,
            b_smem_layout,
            (self.cta_tile_shape_mnk[1], self.cta_tile_shape_mnk[2]),
            self.cluster_shape_mnk[0],
        )

        # Bytes per pipeline stage that the TMA transaction barrier must expect.
        self.num_tma_load_bytes = cute.size_in_bytes(self.b_dtype, b_smem_layout)
        if const_expr(not self.gather_A):
            self.num_tma_load_bytes += cute.size_in_bytes(self.a_dtype, a_smem_layout)

        tma_atom_d, tma_tensor_d = None, None
        if const_expr(mD is not None):
            # "add" op accumulates into gmem D when the epilogue requests it.
            tma_atom_d, tma_tensor_d = self._make_tma_epi_atoms_and_tensors(
                mD,
                self.epi_smem_layout_staged,
                self.epi_tile,
                op_type="store"
                if not (hasattr(epilogue_args, "add_to_output") and epilogue_args.add_to_output)
                else "add",
            )
        tma_atom_c, tma_tensor_c = None, None
        if const_expr(mC is not None):
            tma_atom_c, tma_tensor_c = self._make_tma_epi_atoms_and_tensors(
                mC, self.epi_c_smem_layout_staged, self.epi_tile, op_type="load"
            )

        epilogue_params = self.epi_to_underlying_arguments(epilogue_args)
        varlen_params = VarlenManager.to_underlying_arguments(varlen_args)

        TileSchedulerCls = self.get_scheduler_class(varlen_m=varlen_args.mCuSeqlensM is not None)
        tile_sched_args = self.get_scheduler_arguments(mA, mB, mD, scheduler_args, varlen_args)
        tile_sched_params = TileSchedulerCls.to_underlying_arguments(tile_sched_args)
        grid = TileSchedulerCls.get_grid_shape(
            tile_sched_params, scheduler_args.max_active_clusters
        )

        # D smem is only a distinct allocation when persistent (otherwise it
        # aliases the A smem — see sD handling in `kernel`).
        epi_smem_size = (
            cute.cosize(self.epi_smem_layout_staged) if self.is_persistent and mD is not None else 0
        )
        epi_c_smem_size = cute.cosize(self.epi_c_smem_layout_staged) if mC is not None else 0

        # NOTE: field order determines smem placement; sA/sB come last so sD can
        # alias sA's region in the non-persistent case.
        @cute.struct
        class SharedStorage:
            ab_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.ab_stage * 2]
            epi_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.epi_c_stage * 2]
            sched_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.sched_stage * 2]
            tile_count: cute.struct.MemRange[Int32, self.sched_stage]
            sD: cute.struct.Align[
                cute.struct.MemRange[
                    self.d_dtype if self.d_dtype is not None else Int32, epi_smem_size
                ],
                self.buffer_align_bytes,
            ]
            sC: cute.struct.Align[
                cute.struct.MemRange[
                    self.c_dtype if self.c_dtype is not None else Int32, epi_c_smem_size
                ],
                self.buffer_align_bytes,
            ]
            epi: self.epi_get_smem_struct(epilogue_params)
            sA: cute.struct.Align[
                cute.struct.MemRange[self.a_dtype, cute.cosize(self.a_smem_layout_staged)],
                self.buffer_align_bytes,
            ]
            sB: cute.struct.Align[
                cute.struct.MemRange[self.b_dtype, cute.cosize(self.b_smem_layout_staged)],
                self.buffer_align_bytes,
            ]

        self.shared_storage = SharedStorage

        # Launch the kernel synchronously
        self.kernel(
            self.tiled_mma,
            tma_atom_a,
            tma_tensor_a if const_expr(not self.gather_A) else mA,
            tma_atom_b,
            tma_tensor_b,
            tma_atom_d,
            tma_tensor_d,
            tma_atom_c,
            tma_tensor_c,
            epilogue_params,
            varlen_params,
            self.cluster_layout_mnk,
            self.a_smem_layout_staged,
            self.b_smem_layout_staged,
            self.epi_smem_layout_staged,
            self.epi_c_smem_layout_staged,
            tile_sched_params,
            TileSchedulerCls,
        ).launch(
            grid=grid,
            block=[self.threads_per_cta, 1, 1],
            cluster=self.cluster_shape_mnk,
            stream=stream,
            min_blocks_per_mp=1,
        )
        return
|
| 527 |
+
|
| 528 |
+
    # GPU device kernel
    @cute.kernel
    def kernel(
        self,
        tiled_mma: cute.TiledMma,
        tma_atom_a: Optional[cute.CopyAtom],
        mA_mkl: cute.Tensor,
        tma_atom_b: cute.CopyAtom,
        mB_nkl: cute.Tensor,
        tma_atom_d: Optional[cute.CopyAtom],
        mD_mnl: Optional[cute.Tensor],
        tma_atom_c: Optional[cute.CopyAtom],
        mC_mnl: Optional[cute.Tensor],
        epilogue_params: ParamsBase,
        varlen_params: VarlenManager.Params,
        cluster_layout_mnk: cute.Layout,
        a_smem_layout: cute.ComposedLayout,
        b_smem_layout: cute.ComposedLayout,
        epi_smem_layout: cute.ComposedLayout,
        epi_c_smem_layout: cute.ComposedLayout,
        tile_sched_params: ParamsBase,
        TileSchedulerCls: cutlass.Constexpr[Callable],
    ):
        """
        GPU device kernel performing the batched GEMM computation.

        Warp-specialized: warps with ``warp_idx >= ab_load_warp_id`` act as the
        producer (A/B loads + tile scheduling), while warps below it are the
        consumer math warpgroups (MMA + epilogue). Both sides run their own
        persistent tile-scheduler loop over the same sequence of tiles.

        :param tma_atom_a: TMA copy atom for A tensor (None when gather_A)
        :type tma_atom_a: cute.CopyAtom
        :param mA_mkl: Input tensor A
        :type mA_mkl: cute.Tensor
        :param tma_atom_b: TMA copy atom for B tensor
        :type tma_atom_b: cute.CopyAtom
        :param mB_nkl: Input tensor B
        :type mB_nkl: cute.Tensor
        :param tma_atom_d: TMA copy atom for D tensor
        :type tma_atom_d: cute.CopyAtom
        :param mD_mnl: Output tensor D
        :type mD_mnl: cute.Tensor
        :param tiled_mma: Tiled MMA object
        :type tiled_mma: cute.TiledMma
        :param cluster_layout_mnk: CTA layout
        :type cluster_layout_mnk: cute.Layout
        :param a_smem_layout: Shared memory layout for A
        :type a_smem_layout: cute.ComposedLayout
        :param b_smem_layout: Shared memory layout for B
        :type b_smem_layout: cute.ComposedLayout
        :param epi_smem_layout: Shared memory layout for epilogue
        :type epi_smem_layout: cute.ComposedLayout
        """

        varlen_m = const_expr(varlen_params.cu_seqlens_m is not None)
        varlen_k = const_expr(varlen_params.cu_seqlens_k is not None)
        # Variable-length M and K are mutually exclusive modes.
        assert not (varlen_m and varlen_k)
        if const_expr(self.gather_A):
            assert varlen_m or varlen_k
        has_D = const_expr(mD_mnl is not None)
        has_C = const_expr(mC_mnl is not None)

        warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())

        # /////////////////////////////////////////////////////////////////////////////
        #  Prefetch Tma desc
        # /////////////////////////////////////////////////////////////////////////////
        if warp_idx == self.ab_load_warp_id:
            for tma_atom in (tma_atom_a, tma_atom_b, tma_atom_d, tma_atom_c):
                if const_expr(tma_atom is not None):
                    cpasync.prefetch_descriptor(tma_atom)

        # /////////////////////////////////////////////////////////////////////////////
        #  Alloc and init AB full/empty + ACC full mbar (pipeline)
        # /////////////////////////////////////////////////////////////////////////////
        smem = cutlass.utils.SmemAllocator()
        storage = smem.allocate(self.shared_storage)

        ab_pipeline = self.make_ab_pipeline(
            tiled_mma=tiled_mma,
            cluster_layout_vmnk=cute.make_layout((1, *cluster_layout_mnk.shape)),
            ab_pipeline_mbar_ptr=storage.ab_pipeline_array_ptr.data_ptr(),
        )
        epi_pipeline = None
        if const_expr(has_C):
            epi_pipeline = self.make_epi_pipeline(
                c_smem_layout=cute.slice_(epi_c_smem_layout, (None, None, 0)),
                epi_pipeline_mbar_ptr=storage.epi_pipeline_array_ptr.data_ptr(),
            )
        sched_pipeline = None
        tile_count = None
        if const_expr(tile_sched_params.tile_count_semaphore is not None):
            # Dynamic persistent scheduler
            sched_pipeline = self.make_sched_pipeline(
                cluster_layout_mnk,
                sched_pipeline_mbar_ptr=storage.sched_pipeline_array_ptr.data_ptr(),
                varlen_k=varlen_k,
            )
            tile_count = storage.tile_count.get_tensor((self.sched_stage,))

        # ///////////////////////////////////////////////////////////////////////////////
        #  Generate smem tensor A/B
        # ///////////////////////////////////////////////////////////////////////////////
        sA = storage.sA.get_tensor(a_smem_layout.outer, swizzle=a_smem_layout.inner)
        sB = storage.sB.get_tensor(b_smem_layout.outer, swizzle=b_smem_layout.inner)
        sD = None
        if const_expr(has_D):
            if const_expr(not self.is_persistent):
                # Non-persistent: D smem aliases A smem (mainloop is done with it
                # by epilogue time; see the epilogue_barrier below).
                sD_ptr = cute.recast_ptr(sA.iterator, epi_smem_layout.inner, dtype=self.d_dtype)
                sD = cute.make_tensor(sD_ptr, epi_smem_layout.outer)
            else:
                sD = storage.sD.get_tensor(epi_smem_layout.outer, swizzle=epi_smem_layout.inner)
        sC = None
        if const_expr(has_C):
            sC = storage.sC.get_tensor(epi_c_smem_layout.outer, swizzle=epi_c_smem_layout.inner)
        epi_smem_tensors = self.epi_get_smem_tensors(epilogue_params, storage)

        varlen_manager = VarlenManager.create(
            varlen_params,
            has_D,
            self.num_epi_tensormaps,
            # Only used if not varlen_m
            len_m_static=Int32(
                mA_mkl.shape[0]
                if varlen_k or varlen_params.mAIdx is None
                else varlen_params.mAIdx.shape[0]
            ),
            len_k_static=Int32(mA_mkl.shape[1]),
            pingpong=self.pingpong,
            warp_idx=warp_idx,
        )

        # Bind scheduler params so each side below only passes its own flags.
        TileSchedulerCls = partial(
            TileSchedulerCls.create, tile_sched_params, tile_count, sched_pipeline
        )

        # Producer side: load warps shed registers in favor of the math warpgroups.
        if warp_idx >= self.ab_load_warp_id:
            cute.arch.warpgroup_reg_dealloc(self.num_regs_load)
        if (
            warp_idx >= self.ab_load_warp_id
            and warp_idx < self.ab_load_warp_id + self.num_ab_load_warps
        ):
            is_tma_warp = self.num_ab_load_warps == 1 or warp_idx == self.ab_load_warp_id
            # initialize tensormap for A & B
            varlen_manager.init_tensormap_AB(tma_atom_a, tma_atom_b, is_tma_warp)
            tma_desc_a_ptr = varlen_manager.get_tma_desc_a_ptr()
            tma_desc_b_ptr = varlen_manager.get_tma_desc_b_ptr()
            # ///////////////////////////////////////////////////////////////////////////////
            #  Get mcast mask
            # ///////////////////////////////////////////////////////////////////////////////
            cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())
            block_in_cluster_coord_mnk = cluster_layout_mnk.get_flat_coord(cta_rank_in_cluster)
            a_mcast_mask = cute.make_layout_image_mask(
                cluster_layout_mnk, block_in_cluster_coord_mnk, mode=1
            )
            b_mcast_mask = cute.make_layout_image_mask(
                cluster_layout_mnk, block_in_cluster_coord_mnk, mode=0
            )
            a_mcast_mask = a_mcast_mask if self.is_a_mcast else 0
            b_mcast_mask = b_mcast_mask if self.is_b_mcast else 0

            # Persistent tile scheduling loop
            is_scheduler_warp = self.num_ab_load_warps == 1 or warp_idx == self.ab_load_warp_id
            if const_expr(cute.size(cluster_layout_mnk) > 1):
                # Only one CTA per cluster runs the scheduler.
                is_scheduler_warp = is_scheduler_warp and cute.arch.block_idx_in_cluster() == 0
            tile_scheduler = TileSchedulerCls(is_scheduler_warp=is_scheduler_warp)
            work_tile = tile_scheduler.initial_work_tile_info()
            ab_producer_state = make_pipeline_state(
                pipeline.PipelineUserType.Producer, self.ab_stage
            )
            if const_expr(varlen_k):
                # wait tensormap initialization complete before update
                varlen_manager.fence_tensormap_init()
            while work_tile.is_valid_tile:
                tile_coord_mnkl = work_tile.tile_idx
                batch_idx = tile_coord_mnkl[3]
                varlen_manager.update_tensormap_AB(
                    batch_idx,
                    self.a_layout,
                    self.b_layout,
                    is_tma_warp,
                )
                # ///////////////////////////////////////////////////////////////////////////
                #  Local_tile partition global tensors
                # ///////////////////////////////////////////////////////////////////////////
                if const_expr(not self.gather_A):
                    mA_mk = varlen_manager.offset_batch_A(mA_mkl, batch_idx)
                    # (bM, bK, RestK)
                    gA_mk = cute.local_tile(
                        mA_mk,
                        cute.select(self.cta_tile_shape_mnk, [0, 2]),
                        (tile_coord_mnkl[0], None),
                    )
                else:
                    mAIdx_mk = varlen_manager.offset_batch_AIdx(batch_idx)
                    if const_expr(varlen_m):
                        gAIdx = cute.local_tile(
                            mAIdx_mk, (self.cta_tile_shape_mnk[0],), (tile_coord_mnkl[0],)
                        )
                        # (M, K)
                        mA_mk = mA_mkl
                    else:
                        assert varlen_k
                        # (tile_K, RestK)
                        gAIdx = cute.flat_divide(mAIdx_mk, (self.cta_tile_shape_mnk[2],))
                        # (tile_M, K)
                        mA_mk = cute.local_tile(
                            mA_mkl, (self.cta_tile_shape_mnk[0],), (tile_coord_mnkl[0], None)
                        )
                # (bN, bK, RestK)
                gB_nk = cute.local_tile(
                    varlen_manager.offset_batch_B(mB_nkl, batch_idx),
                    cute.select(self.cta_tile_shape_mnk, [1, 2]),
                    (tile_coord_mnkl[1], None),
                )
                # //////////////////////////////////////////////////////////////////////////
                #  Partition shared tensor for TMA load A/B
                # //////////////////////////////////////////////////////////////////////////
                varlen_manager.fence_tensormap_update_AB(is_tma_warp)
                len_m = varlen_manager.len_m(batch_idx)
                len_k = varlen_manager.len_k(batch_idx)
                # TMA load A partition_S/D
                copy_A = None
                if const_expr(not self.gather_A):
                    copy_A, _, _ = copy_utils.tma_get_copy_fn(
                        tma_atom_a,
                        cta_coord=block_in_cluster_coord_mnk[1],
                        cta_layout=cute.make_layout(
                            cute.slice_(cluster_layout_mnk, (0, None, 0)).shape
                        ),
                        src_tensor=gA_mk,
                        dst_tensor=sA,
                        mcast_mask=a_mcast_mask,
                        tma_desc_ptr=tma_desc_a_ptr,
                    )
                else:
                    tiled_copy_A = self._make_gmem_tiled_copy_A(
                        mA_mkl.element_type, self.a_layout, self.num_ab_load_warps * 32
                    )
                    # Thread index relative to the first load warp.
                    tidx = (
                        cute.arch.thread_idx()[0] - cute.arch.WARP_SIZE * self.ab_load_warp_id
                    )
                    thr_copy_A = tiled_copy_A.get_slice(tidx)
                    copy_A, prefetch_A = None, None
                    if const_expr(varlen_m):
                        copy_A = copy_utils.gather_m_get_copy_fn(
                            thr_copy_A,
                            mA_mk,
                            sA,
                            gAIdx,
                            limit_m=len_m - tile_coord_mnkl[0] * self.cta_tile_shape_mnk[0],
                            limit_k=len_k,
                        )
                    else:
                        copy_A, prefetch_A = copy_utils.gather_k_get_copy_fn(
                            thr_copy_A,
                            mA_mk,
                            sA,
                            gAIdx,
                            limit_m=len_m - tile_coord_mnkl[0] * self.cta_tile_shape_mnk[0],
                            limit_k=len_k,
                        )
                # TMA load B partition_S/D
                copy_B, _, _ = copy_utils.tma_get_copy_fn(
                    tma_atom_b,
                    cta_coord=block_in_cluster_coord_mnk[0],
                    cta_layout=cute.make_layout(
                        cute.slice_(cluster_layout_mnk, (None, 0, 0)).shape
                    ),
                    src_tensor=gB_nk,
                    dst_tensor=sB,
                    mcast_mask=b_mcast_mask,
                    tma_desc_ptr=tma_desc_b_ptr,
                )
                k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
                if const_expr(not self.gather_A):
                    ab_producer_state = self.load_AB(
                        ab_pipeline, ab_producer_state, copy_A, copy_B, k_tile_cnt
                    )
                else:
                    ab_producer_state = self.load_AB_gather_A(
                        ab_pipeline,
                        ab_producer_state,
                        copy_A,
                        prefetch_A,
                        copy_B,
                        k_tile_cnt,
                        varlen_m=varlen_m,
                    )
                tile_scheduler.fetch_next_work(is_scheduler_warp=is_scheduler_warp)
                tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
                work_tile = tile_scheduler.get_current_work()
                # End of persistent scheduler loop
            if const_expr(self.pingpong and not varlen_k):
                # Need to write the tile_idx to smem for the next WG in the pingpong mode
                tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
            ab_pipeline.producer_tail(ab_producer_state)
            if is_scheduler_warp:
                tile_scheduler.producer_tail()

        # Consumer side: math warpgroups reclaim the registers the load warps gave up.
        if warp_idx < self.ab_load_warp_id:
            cute.arch.warpgroup_reg_alloc(self.num_regs_mma)
            is_tma_warp = Boolean(
                (not self.pingpong and warp_idx == 0)
                or (self.pingpong and (warp_idx == 0 or warp_idx == 4))
            )
            varlen_manager.init_tensormap_epi(
                tma_atom_d, self.epi_get_tma_atoms(epilogue_params), is_tma_warp
            )
            tma_desc_d_ptr = varlen_manager.get_tma_desc_d_ptr()
            tma_desc_epi_ptrs = varlen_manager.get_tma_desc_epi_ptrs()
            # //////////////////////////////////////////////////////////////////////////////
            #  Partition global tensor for TiledMMA_A/B/C
            # //////////////////////////////////////////////////////////////////////////////
            tidx, _, _ = cute.arch.thread_idx()
            warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group)
            if const_expr(self.pingpong):
                # In pingpong each WG uses a WG-local thread index.
                tidx = tidx % self.num_threads_per_warp_group
            warp_group_thread_layout = cute.make_layout(
                self.mma_warp_groups if not self.pingpong else 1,
                stride=self.num_threads_per_warp_group,
            )
            thr_mma = tiled_mma.get_slice(
                warp_group_thread_layout(warp_group_idx if not self.pingpong else 0)
            )

            # //////////////////////////////////////////////////////////////////////////////
            #  Make fragments
            # //////////////////////////////////////////////////////////////////////////////
            tCrA = tiled_mma.make_fragment_A(thr_mma.partition_A(sA))
            tCrB = tiled_mma.make_fragment_B(thr_mma.partition_B(sB))

            acc_shape = tiled_mma.partition_shape_C(
                cute.select(self.cta_tile_shape_mnk, mode=[0, 1])
            )
            acc = cute.make_fragment(acc_shape, self.acc_dtype)
            acc_slow = None
            if const_expr(self.fp8_slow_accum):
                # Secondary accumulator for fp8 two-level accumulation.
                acc_slow = cute.make_fragment(acc_shape, self.acc_dtype)

            if const_expr(self.pingpong):
                if warp_group_idx == 0:
                    # WG0 needs a start signal at the very beginning
                    self.pingpong_barrier_arrive(warp_group_idx=0, stage="mma")
                    self.pingpong_barrier_arrive(warp_group_idx=0, stage="epi")

            k_tile_cnt_static = cute.ceil_div(mA_mkl.shape[1], self.cta_tile_shape_mnk[2])
            c_tile_cnt = cute.size(cute.ceil_div(self.cta_tile_shape_mnk[:2], self.epi_tile))

            ab_read_state = make_pipeline_state(pipeline.PipelineUserType.Consumer, self.ab_stage)
            epi_store_pipeline = self.make_epi_store_pipeline()
            epi_read_state = make_pipeline_state(
                pipeline.PipelineUserType.Consumer, self.epi_c_stage
            )
            epi_producer_state = make_pipeline_state(
                pipeline.PipelineUserType.Producer, self.epi_c_stage
            )
            tile_scheduler = TileSchedulerCls()
            work_tile = None
            if const_expr(self.pingpong):
                if const_expr(varlen_k):
                    work_tile = tile_scheduler.initial_work_tile_info()
                if warp_idx >= 4:
                    # Advance 2nd Math WG pipeline states to the end of 1st Math WG
                    epi_read_state.advance_iters(c_tile_cnt)
                    epi_producer_state.advance_iters(c_tile_cnt)
                    if const_expr(not varlen_k):
                        ab_read_state.advance_iters(k_tile_cnt_static)
                    else:
                        # varlen_k: the skipped tile's K extent is batch-dependent.
                        len_k = varlen_manager.len_k(batch_idx=work_tile.tile_idx[3])
                        k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
                        ab_read_state.advance_iters(k_tile_cnt)
                    tile_scheduler.advance_to_next_work()
                    if const_expr(varlen_k):
                        work_tile = tile_scheduler.get_current_work()
                if const_expr(not varlen_k):
                    work_tile = tile_scheduler.initial_work_tile_info()
            else:
                work_tile = tile_scheduler.initial_work_tile_info()
            if const_expr(varlen_m):
                # wait tensormap initialization complete before update
                varlen_manager.fence_tensormap_init()
            while work_tile.is_valid_tile:
                tile_coord_mnkl = work_tile.tile_idx
                batch_idx = tile_coord_mnkl[3]
                epi_shapes, epi_orders = self.epi_get_tensormap_update_shapes_orders(
                    epilogue_params, varlen_params.cu_seqlens_m, batch_idx
                )
                varlen_manager.update_tensormap_epi(
                    batch_idx,
                    self.d_layout,
                    epi_shapes,
                    epi_orders,
                    is_tma_warp,
                )
                len_k = varlen_manager.len_k(batch_idx)
                k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
                ab_read_state, tiled_mma = self.mma(
                    ab_pipeline,
                    ab_read_state,
                    tiled_mma,
                    tCrA,
                    tCrB,
                    acc,
                    acc_slow,
                    k_tile_cnt,
                    warp_group_idx,
                )
                if const_expr(varlen_k):
                    # Empty K segment: nothing was accumulated, emit zeros.
                    if k_tile_cnt == 0:
                        acc.fill(0.0)

                # /////////////////////////////////////////////////////////////////////////////
                #  EPILOGUE
                # /////////////////////////////////////////////////////////////////////////////
                if const_expr(self.pingpong):
                    self.pingpong_barrier_sync(warp_group_idx, "epi")

                epilogue_barrier = pipeline.NamedBarrier(
                    barrier_id=int(NamedBarrierGemm.Epilogue),
                    num_threads=self.num_epi_warps * cute.arch.WARP_SIZE,
                )

                varlen_manager.fence_tensormap_update_epi(is_tma_warp)

                copy_D = None
                if const_expr(has_D):
                    copy_D, _, _ = self.epilog_gmem_copy_and_partition(
                        tma_atom_d,
                        varlen_manager.offset_batch_epi(mD_mnl, batch_idx),
                        self.cta_tile_shape_mnk[:2],
                        self.epi_tile,
                        sD,
                        tile_coord_mnkl,
                        tma_desc_ptr=tma_desc_d_ptr,
                    )
                copy_C = None
                if const_expr(has_C):
                    copy_C_fn, _, _ = self.epilog_gmem_copy_and_partition(
                        tma_atom_c,
                        varlen_manager.offset_batch_epi(mC_mnl, batch_idx),
                        self.cta_tile_shape_mnk[:2],
                        self.epi_tile,
                        sC,
                        tile_coord_mnkl,
                    )
                    copy_C = copy_utils.tma_producer_copy_fn(copy_C_fn, epi_pipeline)

                # Placeholder dtype so smem-store partitioning works even without D.
                d_dtype_for_layout = self.d_dtype if self.d_dtype is not None else cutlass.BFloat16
                tiled_copy_r2s, tRS_rD, tRS_sD = self.epilog_smem_store_and_partition(
                    tiled_mma, self.d_layout, d_dtype_for_layout, sD, tidx
                )
                # (R2S, R2S_M, R2S_N)
                tRS_rAcc = tiled_copy_r2s.retile(acc)
                load_acc_subtile = partial(self.epi_load_acc_subtile, tRS_rAcc)
                if const_expr(has_C):
                    tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC = self.epilog_smem_load_and_partition(
                        tiled_mma, self.c_layout, self.c_dtype, sC, tRS_rD.layout, tidx
                    )
                else:
                    tiled_copy_s2r, tSR_sC, tRS_rC, tSR_rC = None, None, None, None

                # Wait for all warp groups in the thread block to finish, because smem for tensor
                # A in the mainloop is reused in the epilogue if not persistent.
                if const_expr(not self.is_persistent):
                    epilogue_barrier.arrive_and_wait()

                self.epi_visit_acc(epilogue_params, acc, tiled_mma, tile_coord_mnkl, tidx)

                epi_read_state, epi_producer_state = self.epilogue(
                    epilogue_params,
                    epi_smem_tensors,
                    tma_desc_epi_ptrs,
                    epi_pipeline,
                    epi_store_pipeline,
                    epi_read_state,
                    epi_producer_state,
                    self.epi_tile,
                    load_acc_subtile,
                    tRS_rD,
                    tRS_rC,
                    None,  # tiled_copy_t2r, for Sm100 only
                    tiled_copy_r2s,
                    tRS_sD,
                    tiled_copy_s2r,
                    tSR_rC,
                    tSR_sC,
                    copy_D,
                    copy_C,
                    tile_coord_mnkl,
                    varlen_manager,
                    epilogue_barrier,
                    tile_scheduler,
                    tidx,
                    is_tma_warp,
                )

                if const_expr(self.pingpong):
                    # With pingpong, 2 WGs write two different output tiles to the same smem,
                    # so we have to make sure the smem content is done reading before signaling
                    # the next WG's epilogue.
                    if is_tma_warp:
                        epi_store_pipeline.producer_tail()
                    self.pingpong_barrier_arrive(1 - warp_group_idx, stage="epi")

                if const_expr(not self.pingpong):
                    tile_scheduler.advance_to_next_work()
                    work_tile = tile_scheduler.get_current_work()
                else:  # Skip a tile for pingpong
                    # Update starting load/store pipeline states for the next tile
                    epi_read_state.advance_iters(c_tile_cnt)
                    epi_producer_state.advance_iters(c_tile_cnt)
                    # Update starting mainloop pipeline state for the next tile
                    if const_expr(not varlen_k):
                        ab_read_state.advance_iters(k_tile_cnt_static)
                        tile_scheduler.advance_to_next_work(advance_count=self.mma_warp_groups)
                        work_tile = tile_scheduler.get_current_work()
                    else:
                        # varlen_k: must read the skipped tile's batch to know how many
                        # K iterations to skip in the AB pipeline.
                        tile_scheduler.advance_to_next_work()
                        work_tile = tile_scheduler.get_current_work()
                        if work_tile.is_valid_tile:
                            len_k = varlen_manager.len_k(batch_idx=work_tile.tile_idx[3])
                            k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
                            ab_read_state.advance_iters(k_tile_cnt)
                            tile_scheduler.advance_to_next_work()
                            work_tile = tile_scheduler.get_current_work()
                # End of persistent scheduler loop

            # Wait for D store complete
            if const_expr(not self.pingpong):
                if is_tma_warp:
                    epi_store_pipeline.producer_tail()
|
| 1056 |
+
|
| 1057 |
+
@cute.jit
def load_AB(
    self,
    ab_pipeline: cutlass.pipeline.PipelineAsync,
    ab_producer_state: cutlass.pipeline.PipelineState,
    copy_A: Optional[Callable],
    copy_B: Callable,
    k_tile_cnt: Int32,
    # These are for Sm100 blockscaled gemm
    copy_SFA: Optional[Callable] = None,
    copy_SFB: Optional[Callable] = None,
) -> cutlass.pipeline.PipelineState:
    """Producer loop: issue TMA loads of A/B (and optionally SFA/SFB scale
    factors) into shared memory for every K tile.

    For each of the ``k_tile_cnt`` tiles, acquires an A/B smem stage from
    ``ab_pipeline``, issues the TMA copies against that stage's transaction
    barrier, commits, and advances the producer state.

    :param ab_pipeline: mainloop A/B pipeline whose stages are being produced.
    :param ab_producer_state: current producer stage; advanced once per K tile.
    :param copy_A: TMA copy callable for A, or None when A is loaded elsewhere.
    :param copy_B: TMA copy callable for B (always present).
    :param k_tile_cnt: number of K tiles to load for this output tile.
    :param copy_SFA: optional scale-factor-A copy (Sm100 blockscaled only).
    :param copy_SFB: optional scale-factor-B copy; required iff copy_SFA is given.
    :return: the producer state after the last advance, so the caller can
        continue the pipeline on the next output tile.
    """
    blockscaled = const_expr(copy_SFA is not None)
    if const_expr(blockscaled):
        assert copy_SFB is not None
    # Peek (try_wait) AB buffer empty for k_block = prefetch_k_tile_cnt
    peek_ab_empty_status = Boolean(True)
    if 0 < k_tile_cnt:
        peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
    # /////////////////////////////////////////////////////////////////////////
    # TMA load
    # /////////////////////////////////////////////////////////////////////////
    for k_tile in cutlass.range(k_tile_cnt, unroll=1):
        # Wait for A/B buffers to be empty before loading into them
        # Also sets the transaction barrier for the A/B buffers
        ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status)
        tma_bar_ptr = ab_pipeline.producer_get_barrier(ab_producer_state)
        smem_idx = ab_producer_state.index
        if const_expr(copy_A is not None):
            copy_A(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr)
        copy_B(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr)
        if const_expr(blockscaled):
            copy_SFA(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr)
            copy_SFB(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr)
        # Mainloop pipeline's producer commit is a NOP
        ab_pipeline.producer_commit(ab_producer_state)
        ab_producer_state.advance()
        # Eagerly peek the next stage so the acquire above can often skip waiting.
        peek_ab_empty_status = Boolean(True)
        if k_tile + 1 < k_tile_cnt:
            peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
    return ab_producer_state
|
| 1098 |
+
|
| 1099 |
+
@cute.jit
def load_AB_gather_A(
    self,
    ab_pipeline: cutlass.pipeline.PipelineAsync,
    ab_producer_state: cutlass.pipeline.PipelineState,
    copy_A: Callable,
    prefetch_A: Optional[Callable],
    copy_B: Callable,
    k_tile_cnt: Int32,
    varlen_m: bool = True,
) -> cutlass.pipeline.PipelineState:
    """Producer loop for gather-A GEMM: B is loaded with TMA while A rows are
    gathered with cp.async.

    The first ``k_tile_cnt - 1`` tiles are loaded without bounds checks; the
    final tile re-runs the same sequence with ``pred=True`` so out-of-range K
    elements are predicated off.

    :param copy_A: cp.async gather copy for A (takes optional prefetch output
        and a ``pred`` keyword on the last tile).
    :param prefetch_A: optional row-index prefetch issued before smem is free.
    :param varlen_m: when True, the TMA-issuing warp is rotated per K tile;
        otherwise warp ``self.ab_load_warp_id`` always issues TMA
        (required for varlen_k, which updates tensormaps from that warp).
    :return: the producer state after the last advance.
    """
    warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
    # Peek (try_wait) AB buffer empty for k_block = prefetch_k_tile_cnt
    peek_ab_empty_status = Boolean(True)
    if 0 < k_tile_cnt:
        peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
    # /////////////////////////////////////////////////////////////////////////
    # TMA load on B and cp.async on A
    # /////////////////////////////////////////////////////////////////////////
    for k_tile in cutlass.range(k_tile_cnt - 1, unroll=1):
        prefetch_out = ()
        if const_expr(prefetch_A is not None):  # Prefetch early, even before smem is free
            prefetch_out = (prefetch_A(k_tile),)
        # Wait for A/B buffers to be empty before loading into them
        # Also sets the transaction barrier for the A/B buffers
        # A tiny bit faster to rotate the warp that does TMA
        # However, for varlen_k, we must use the warp_idx == self.ab_load_warp_id
        # since that's the warp that does the tensormap update.
        is_tma_warp = warp_idx == self.ab_load_warp_id + (
            (k_tile % self.num_ab_load_warps) if const_expr(varlen_m) else 0
        )
        ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status, is_tma_warp)
        smem_idx = ab_producer_state.index
        # A bit faster to load B first while we calculate the indices for A
        if is_tma_warp:
            tma_bar_ptr = ab_pipeline.producer_get_barrier(ab_producer_state)
            copy_B(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr)
        copy_A(k_tile, smem_idx, *prefetch_out)
        # This tells mbarrier to track the completion of cp.async
        ab_pipeline.producer_cpasync_commit(ab_producer_state)
        ab_producer_state.advance()
        peek_ab_empty_status = Boolean(True)
        if k_tile + 1 < k_tile_cnt:
            peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
    # bound checking in the K dimension on the last k_tile
    if 0 < k_tile_cnt:
        k_tile = k_tile_cnt - 1
        prefetch_out = ()
        if const_expr(prefetch_A is not None):  # Prefetch early, even before smem is free
            prefetch_out = (prefetch_A(k_tile, pred=True),)
        is_tma_warp = warp_idx == self.ab_load_warp_id + (
            (k_tile % self.num_ab_load_warps) if const_expr(varlen_m) else 0
        )
        ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status, is_tma_warp)
        smem_idx = ab_producer_state.index
        if is_tma_warp:
            tma_bar_ptr = ab_pipeline.producer_get_barrier(ab_producer_state)
            copy_B(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr)
        copy_A(k_tile, smem_idx, *prefetch_out, pred=True)
        ab_pipeline.producer_cpasync_commit(ab_producer_state)
        ab_producer_state.advance()
    return ab_producer_state
|
| 1161 |
+
|
| 1162 |
+
@cute.jit
def mma(
    self,
    ab_pipeline: cutlass.pipeline.PipelineAsync,
    ab_read_state: cutlass.pipeline.PipelineState,
    tiled_mma: cute.TiledMma,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    acc: cute.Tensor,
    acc_slow: Optional[cute.Tensor],
    k_tile_cnt: Int32,
    warp_group_idx: Int32,
) -> Tuple[cutlass.pipeline.PipelineState, cute.TiledMma]:
    """Consumer loop: run the WGMMA mainloop over ``k_tile_cnt`` K tiles,
    accumulating into ``acc`` (and ``acc_slow`` for fp8 slow accumulation).

    Keeps ``k_pipe_mmas`` (= 1) MMA K tiles in flight: smem stages are
    released one MMA behind the reads via ``ab_release_state`` so the WGMMA
    reading a stage has completed before the producer can overwrite it.
    With pingpong, entry/exit is serialized against the other warp group via
    the "mma" named barrier.

    :param ab_read_state: consumer stage for waiting on full A/B buffers.
    :param acc_slow: fp32 accumulator used only when ``self.fp8_slow_accum``;
        ``acc`` is copied back into it each tile and restored at the end.
    :return: (advanced read state, tiled_mma). tiled_mma must be returned,
        otherwise the compiler errors with "operand #0 does not dominate this use".
    """
    # /////////////////////////////////////////////////////////////////////////////
    # Prologue MMAs
    # /////////////////////////////////////////////////////////////////////////////
    k_pipe_mmas = 1
    ab_release_state = ab_read_state.clone()
    num_prologue_mma = min(k_pipe_mmas, k_tile_cnt)
    if const_expr(self.pingpong):
        self.pingpong_barrier_sync(warp_group_idx, stage="mma")
    peek_ab_full_status = Boolean(True)
    if 0 < k_tile_cnt:
        peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_read_state)
    tiled_mma.set(warpgroup.Field.ACCUMULATE, False)
    num_k_blocks = cute.size(tCrA, mode=[2])
    for k_tile in cutlass.range(num_prologue_mma):
        # Wait for A/B buffer to be ready
        ab_pipeline.consumer_wait(ab_read_state, peek_ab_full_status)
        warpgroup.fence()
        for k_blk_idx in cutlass.range(num_k_blocks, unroll_full=True):
            k_blk_coord = (None, None, k_blk_idx, ab_read_state.index)
            cute.gemm(tiled_mma, acc, tCrA[k_blk_coord], tCrB[k_blk_coord], acc)
            # After the first gemm, subsequent ones accumulate into acc.
            tiled_mma.set(warpgroup.Field.ACCUMULATE, True)
        warpgroup.commit_group()
        ab_read_state.advance()
        peek_ab_full_status = Boolean(True)
        if k_tile + 1 < k_tile_cnt:
            peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_read_state)
    # If k_tile_cnt == 0, this is not correct. But we will set acc to 0 in the mainloop
    # in that case.
    if const_expr(self.fp8_slow_accum):
        warpgroup.wait_group(0)
        acc_slow.store(acc.load())

    # /////////////////////////////////////////////////////////////////////////////
    # MAINLOOP
    # /////////////////////////////////////////////////////////////////////////////
    for k_tile in cutlass.range(num_prologue_mma, k_tile_cnt, unroll=1):
        # Wait for TMA copies to complete
        ab_pipeline.consumer_wait(ab_read_state, peek_ab_full_status)
        # WGMMA
        warpgroup.fence()
        if const_expr(self.fp8_slow_accum):
            # Restart accumulation in acc each tile; the running sum lives in acc_slow.
            tiled_mma.set(warpgroup.Field.ACCUMULATE, False)
        for k_blk_idx in cutlass.range(num_k_blocks, unroll_full=True):
            k_blk_coord = (None, None, k_blk_idx, ab_read_state.index)
            cute.gemm(tiled_mma, acc, tCrA[k_blk_coord], tCrB[k_blk_coord], acc)
            tiled_mma.set(warpgroup.Field.ACCUMULATE, True)
        warpgroup.commit_group()
        # Wait on the wgmma barrier for previous k_pipe_mmas wgmmas to complete
        if const_expr(not self.fp8_slow_accum):
            warpgroup.wait_group(k_pipe_mmas)
        else:
            warpgroup.wait_group(0)
            acc_slow.store(acc_slow.load() + acc.load())
        ab_pipeline.consumer_release(ab_release_state)
        ab_read_state.advance()
        ab_release_state.advance()
        peek_ab_full_status = Boolean(True)
        if k_tile + 1 < k_tile_cnt:
            peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_read_state)
    if const_expr(self.pingpong):
        # Cue for next WG's MMA to start
        self.pingpong_barrier_arrive(1 - warp_group_idx, stage="mma")
    if const_expr(not self.fp8_slow_accum):
        # fp8_slow_accum would have already called wait_group(0) inside the loop
        warpgroup.wait_group(0)
    # Release the smem stages still held by the in-flight prologue MMAs.
    for k_tile in cutlass.range(num_prologue_mma, unroll=1):
        ab_pipeline.consumer_release(ab_release_state)
        ab_release_state.advance()
    if const_expr(self.fp8_slow_accum):
        acc.store(acc_slow.load())
    # If we don't return the tiled_mma, we get compiler error
    # "operand #0 does not dominate this use"
    return ab_read_state, tiled_mma
|
| 1248 |
+
|
| 1249 |
+
@cute.jit
def epilogue(
    self,
    params: EpilogueParams,
    epi_smem_tensors: Tuple[cute.Tensor, ...],
    tma_desc_epi_ptrs: list[Optional[cute.Pointer]],
    epi_pipeline: cutlass.pipeline.PipelineAsync,
    epi_store_pipeline: cutlass.pipeline.PipelineAsync,
    epi_read_state: cutlass.pipeline.PipelineState,
    epi_producer_state: Optional[cutlass.pipeline.PipelineState],
    epi_tile: cute.Tile,
    load_acc_subtile: Callable,
    tRS_rD: cute.Tensor,
    tRS_rC: Optional[cute.Tensor],
    tiled_copy_t2r: Optional[cute.TiledCopy],  # Only for Sm100
    tiled_copy_r2s: cute.TiledCopy,
    tRS_sD: cute.Tensor,
    tiled_copy_s2r: Optional[cute.ThrCopy],
    tSR_rC: Optional[cute.Tensor],
    tSR_sC: Optional[cute.Tensor],
    copy_D: Optional[Callable],
    copy_C: Optional[Callable],
    tile_coord_mnkl: cute.Coord,
    varlen_manager: VarlenManager,
    epilogue_barrier: cutlass.pipeline.NamedBarrier,
    tile_scheduler,
    tidx: Int32,
    is_tma_warp: Boolean,
) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState]:
    """Write one output tile: iterate over epi subtiles, optionally loading C
    (via ``epi_pipeline``), visiting each subtile with the subclass hooks, and
    storing D through shared memory with TMA (via ``epi_store_pipeline``).

    :param load_acc_subtile: copies one subtile of the accumulator into tRS_rD.
    :param epi_read_state: consumer state for the C-load pipeline.
    :param epi_producer_state: producer state for the C-load pipeline
        (only used when ``copy_C`` is not None).
    :param is_tma_warp: whether this warp issues TMA loads/stores.
    :return: the advanced (epi_read_state, epi_producer_state) so the caller
        can continue on the next output tile.
    """
    has_C = const_expr(tRS_rC is not None)
    has_D = const_expr(copy_D is not None)
    epi_tile_shape = cute.zipped_divide(
        cute.make_layout(self.cta_tile_shape_mnk[:2]), epi_tile
    ).shape[1]
    # We iterate over epi tiles in the N dimension first before the M dimension
    epi_tile_layout = cute.make_ordered_layout(epi_tile_shape, order=(1, 0))
    epi_tile_num = cute.size(epi_tile_shape)
    # Used to rotate through the D smem stages across output tiles.
    num_prev_subtiles = tile_scheduler.num_tiles_executed * epi_tile_num

    epi_tensors = self.epi_begin(
        params,
        epi_smem_tensors,
        epi_tile,
        tiled_copy_t2r,
        tiled_copy_r2s,
        tile_coord_mnkl,
        varlen_manager,
        epilogue_barrier,
        tidx,
    )

    # Prime the C-load pipeline with up to epi_c_stage subtiles.
    if const_expr(copy_C is not None):
        for epi_idx in cutlass.range(min(epi_tile_num, self.epi_c_stage), unroll=1):
            gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx)
            if is_tma_warp:
                epi_pipeline.producer_acquire(epi_producer_state)
                copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state)
                epi_pipeline.producer_commit(epi_producer_state)
            epi_producer_state.advance()

    def tma_store_fn(src_idx, dst_idx):
        # Fence and barrier to make sure shared memory store is visible to TMA store
        cute.arch.fence_proxy(
            cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
        )
        epilogue_barrier.arrive_and_wait()
        # Copy from shared memory to global memory
        if is_tma_warp:
            if const_expr(has_D):
                copy_D(src_idx=src_idx, dst_idx=dst_idx)
        # Can't use if statement here, epi_store_pipeline object isn't captured somehow
        if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_commit())
        if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_acquire())
        epilogue_barrier.arrive_and_wait()

    # We could delay the TMA store by 1 epi tile to better overlap the non-TMA ops
    # with the TMA store. However, currently this doesn't seem to improve perf.
    delay_tma_store = False

    src_idx_prev, dst_idx_prev = None, None
    for epi_idx in cutlass.range_constexpr(epi_tile_num):
        # The global memory coordinate for the current epi tile
        gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
        # Copy from acc to D registers
        load_acc_subtile(tRS_rD, epi_idx)
        epi_loop_tensors = self.epi_begin_loop(params, epi_tensors, gmem_coord)
        if const_expr(has_C):
            epi_pipeline.consumer_wait(epi_read_state)
            cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC)
            # Fence to make sure shared memory read is visible to TMA load
            cute.arch.fence_proxy(
                cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
            )
            cute.arch.sync_warp()
            with cute.arch.elect_one():
                epi_pipeline.consumer_release(epi_read_state)
            epi_read_state.advance()
        # Keep the C-load pipeline full: load the subtile epi_c_stage ahead.
        if const_expr(copy_C is not None and epi_idx + self.epi_c_stage < epi_tile_num):
            gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx + self.epi_c_stage)
            if is_tma_warp:
                epi_pipeline.producer_acquire(epi_producer_state)
                copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state)
                epi_pipeline.producer_commit(epi_producer_state)
            epi_producer_state.advance()
        tRS_rEpi = self.epi_visit_subtile(params, epi_loop_tensors, tRS_rD, tRS_rC)
        epi_buffer = (num_prev_subtiles + epi_idx) % self.epi_stage
        if const_expr(delay_tma_store):
            if const_expr(epi_idx > 0):
                tma_store_fn(src_idx=src_idx_prev, dst_idx=dst_idx_prev)
            src_idx_prev, dst_idx_prev = epi_buffer, gmem_coord
        # Copy from D registers to shared memory
        if const_expr(has_D):
            copy_utils.cvt_copy(tiled_copy_r2s, tRS_rD, tRS_sD[None, None, None, epi_buffer])
        if const_expr(not delay_tma_store):
            tma_store_fn(src_idx=epi_buffer, dst_idx=gmem_coord)

    if const_expr(delay_tma_store):
        # Flush the last deferred store.
        tma_store_fn(src_idx=src_idx_prev, dst_idx=dst_idx_prev)

    self.epi_end(
        params,
        epi_tensors,
        epi_tile,
        tiled_copy_t2r,
        tiled_copy_r2s,
        tile_coord_mnkl,
        varlen_manager,
        tidx,
    )

    return epi_read_state, epi_producer_state
|
| 1380 |
+
|
| 1381 |
+
def get_scheduler_class(self, varlen_m: bool = False):
    """Return the scheduler class to use. Override in subclasses for custom schedulers."""
    if varlen_m:
        return VarlenMTileScheduler
    return TileScheduler
|
| 1384 |
+
|
| 1385 |
+
def get_scheduler_arguments(
    self,
    mA: cute.Tensor,
    mB: cute.Tensor,
    mD: Optional[cute.Tensor],
    scheduler_args,
    varlen_args,
):
    """Create scheduler arguments. Override in subclasses for custom schedulers.

    Two cases, keyed on whether M is variable-length:
    - fixed M: builds ``TileSchedulerArguments`` with a static
      (M tiles, N tiles, batch) problem shape;
    - varlen M (``mCuSeqlensM`` given): builds ``VarlenMTileSchedulerArguments``
      with M tile count left None (depends on per-batch seqlens).
    """
    if const_expr(varlen_args.mCuSeqlensM is None):
        # Number of problems (L dim): prefer mD's batch dim; otherwise derive
        # from B's batch dim or, for varlen K, from the cu_seqlens_k length.
        num_problems = (
            mD.shape[2]
            if mD is not None
            else (
                mB.shape[2]
                if varlen_args.mCuSeqlensK is None
                else varlen_args.mCuSeqlensK.shape[0] - 1
            )
        )
        problem_shape_ntile_mnl = (
            cute.ceil_div(mA.shape[0], self.cta_tile_shape_mnk[0]),
            cute.ceil_div(mB.shape[0], self.cta_tile_shape_mnk[1]),
            num_problems,
        )
        tile_sched_args = TileSchedulerArguments(
            problem_shape_ntile_mnl=problem_shape_ntile_mnl,
            raster_order=scheduler_args.raster_order,
            group_size=scheduler_args.max_swizzle_size,
            cluster_shape_mnk=self.cluster_shape_mnk,
            tile_count_semaphore=scheduler_args.tile_count_semaphore,
            batch_idx_permute=scheduler_args.batch_idx_permute,
            is_persistent=self.is_persistent,
        )
    else:
        assert mD is not None or not self.gather_A
        # M tile count is unknown statically for varlen M, hence None.
        problem_shape_ntile_mnl = (
            None,
            cute.ceil_div(mB.shape[0], self.cta_tile_shape_mnk[1]),
            varlen_args.mCuSeqlensM.shape[0] - 1,
        )
        tile_sched_args = VarlenMTileSchedulerArguments(
            problem_shape_ntile_mnl=problem_shape_ntile_mnl,
            total_m=mD.shape[0] if mD is not None else varlen_args.mAIdx.shape[0],
            cu_seqlens_m=varlen_args.mCuSeqlensM,
            raster_order=scheduler_args.raster_order,
            group_size=scheduler_args.max_swizzle_size,
            tile_shape_mn=self.cta_tile_shape_mnk[:2],
            cluster_shape_mnk=self.cluster_shape_mnk,
            tile_count_semaphore=scheduler_args.tile_count_semaphore,
            is_persistent=self.is_persistent,
        )
    return tile_sched_args
|
| 1437 |
+
|
| 1438 |
+
@cute.jit
def epi_load_acc_subtile(self, tRS_rAcc: cute.Tensor, tRS_rD: cute.Tensor, epi_idx: int):
    """Copy the epi subtile at index ``epi_idx`` from the accumulator fragment
    ``tRS_rAcc`` into the D register fragment ``tRS_rD``."""
    frag_size = cute.size(tRS_rD)
    base = epi_idx * frag_size
    for epi_v in cutlass.range_constexpr(frag_size):
        tRS_rD[epi_v] = tRS_rAcc[base + epi_v]
|
| 1442 |
+
|
| 1443 |
+
@cute.jit
def epi_begin(
    self,
    params: EpilogueParams,
    epi_smem_tensors: Tuple[cute.Tensor, ...],
    epi_tile: cute.Tile,
    tiled_copy_t2r: Optional[cute.TiledCopy],
    tiled_copy_r2s: cute.TiledCopy,
    tile_coord_mnkl: cute.Coord,
    varlen_manager: VarlenManager,
    epilogue_barrier: cutlass.pipeline.NamedBarrier,
    tidx: Int32,
) -> Tuple[cute.Tensor, ...]:
    """Hook invoked once per output tile before the epi subtile loop.

    Subclasses may override to set up per-tile tensors; whatever is returned
    here is passed back into epi_begin_loop / epi_end. Default: nothing.
    """
    return ()
|
| 1457 |
+
|
| 1458 |
+
def epi_begin_loop(
    self, params: EpilogueParams, epi_tensors: Tuple[cute.Tensor, ...], epi_coord: cute.Coord
) -> Tuple[cute.Tensor, ...]:
    """Hook invoked at the top of each epi-subtile iteration.

    Subclasses may override to slice per-subtile tensors out of
    ``epi_tensors`` for the subtile at ``epi_coord``. Default: nothing.
    """
    return ()
|
| 1462 |
+
|
| 1463 |
+
def epi_visit_subtile(
    self,
    params: EpilogueParams,
    epi_loop_tensors: Tuple[cute.Tensor, ...],
    tRS_rD: cute.Tensor,
    tRS_rC: Optional[cute.Tensor] = None,
) -> Optional[cute.Tensor]:
    """Hook to transform one epi subtile held in registers (e.g. apply an
    activation or bias) before it is stored. May mutate ``tRS_rD`` in place
    and/or return an auxiliary fragment. Default: no-op, returns None.
    """
    return None
|
| 1471 |
+
|
| 1472 |
+
def epi_visit_acc(
    self,
    params: EpilogueParams,
    acc: cute.Tensor,
    tiled_mma: cute.TiledMma,
    tile_coord_mnkl: cute.Coord,
    tidx: Int32,
) -> None:
    """Hook to post-process the full accumulator for an output tile before
    the epilogue writes it out. Default: no-op."""
|
| 1481 |
+
|
| 1482 |
+
@cute.jit
def epi_end(
    self,
    params: EpilogueParams,
    epi_tensors: Tuple[cute.Tensor, ...],
    epi_tile: cute.Tile,
    tiled_copy_t2r: Optional[cute.TiledCopy],
    tiled_copy_r2s: cute.TiledCopy,
    tile_coord_mnkl: cute.Coord,
    varlen_manager,
    tidx,
) -> None:
    """Hook invoked once per output tile after the epi subtile loop, with the
    ``epi_tensors`` produced by epi_begin. Default: no-op."""
|
| 1495 |
+
|
| 1496 |
+
def epi_to_underlying_arguments(
    self, args: EpilogueArguments, *, loc=None, ip=None
) -> EpilogueParams:
    """Convert host-side epilogue arguments into the device-side params
    object. Default: an empty EpilogueParams; subclasses override to carry
    real state."""
    return self.EpilogueParams()
|
| 1500 |
+
|
| 1501 |
+
def epi_get_tma_atoms(
    self, params: EpilogueParams, *, loc=None, ip=None
) -> list[cute.CopyAtom]:
    """Subclasses can override this to expose extra TMA copy atoms used by
    their epilogue. Default: none."""
    return []
|
| 1506 |
+
|
| 1507 |
+
def epi_get_tensormap_update_shapes_orders(
    self,
    params: EpilogueParams,
    cu_seqlens_m: cute.Tensor,
    batch_idx: Int32,
    *,
    loc=None,
    ip=None,
) -> tuple[list[Int32], list[int]]:
    """Subclasses can override this to supply per-batch tensormap update
    shapes and their orders for varlen-M epilogues. Default: none."""
    return [], []
|
| 1518 |
+
|
| 1519 |
+
@staticmethod
def epi_smem_bytes_per_stage(
    args: Optional[EpilogueArguments],
    cta_tile_shape_mnk: Tuple[int, int, int],
    epi_tile: cute.Tile,
) -> int:
    """Extra shared-memory bytes this epilogue needs per stage, on top of the
    D buffer. Default: zero; subclasses with epilogue smem override."""
    return 0
|
| 1526 |
+
|
| 1527 |
+
def epi_get_smem_struct(self, params: EpilogueParams):
    """Shared-memory struct layout for this epilogue's extra buffers.
    Default: a zero-length dummy struct."""
    return cute.struct.MemRange[Int32, 0]  # Dummy struct
|
| 1529 |
+
|
| 1530 |
+
def epi_get_smem_tensors(self, params: EpilogueParams, storage) -> Tuple[cute.Tensor, ...]:
    """Materialize this epilogue's extra smem tensors out of ``storage``.
    Default: none."""
    return ()
|
| 1532 |
+
|
| 1533 |
+
def pingpong_barrier_sync(self, warp_group_idx: Int32, stage: Literal["mma", "epi"]):
    """Block this warp group on its pingpong named barrier for the given
    stage ("mma" or "epi") until the sibling warp group arrives."""
    assert stage in ["mma", "epi"]
    if stage == "mma":
        base_barrier = NamedBarrierGemm.MmaWG0
    else:
        base_barrier = NamedBarrierGemm.EpiWG0
    # Two warp groups participate in each pingpong barrier.
    participant_threads = 2 * self.num_threads_per_warp_group
    cute.arch.barrier(
        barrier_id=int(base_barrier) + warp_group_idx,
        number_of_threads=participant_threads,
    )
|
| 1540 |
+
|
| 1541 |
+
def pingpong_barrier_arrive(self, warp_group_idx: Int32, stage: Literal["mma", "epi"]):
    """Arrive (without waiting) at the pingpong named barrier of warp group
    ``warp_group_idx`` for the given stage ("mma" or "epi"), releasing the
    sibling warp group blocked in pingpong_barrier_sync."""
    assert stage in ["mma", "epi"]
    if stage == "mma":
        base_barrier = NamedBarrierGemm.MmaWG0
    else:
        base_barrier = NamedBarrierGemm.EpiWG0
    # Two warp groups participate in each pingpong barrier.
    participant_threads = 2 * self.num_threads_per_warp_group
    cute.arch.barrier_arrive(
        barrier_id=int(base_barrier) + warp_group_idx,
        number_of_threads=participant_threads,
    )
|
| 1548 |
+
|
| 1549 |
+
def epilog_smem_copy_atom(self, tiled_mma: cute.TiledMma) -> cute.TiledCopy:
    """Build the st.matrix-based tiled copy atom the epilogue uses to move
    accumulator registers into shared memory, retiled to match ``tiled_mma``'s
    C layout."""
    m_major = False
    if self.d_layout is not None:
        m_major = self.d_layout.is_m_major_c()
    num_mats = 4 if self.epi_tile[1] % 16 == 0 else 2
    st_matrix_op = warp.StMatrix8x8x16bOp(m_major, num_matrices=num_mats)
    # Float16 here is just to get the right source layout
    copy_atom_C = cute.make_copy_atom(st_matrix_op, Float16)
    return cute.make_tiled_copy_C_atom(copy_atom_C, tiled_mma)
|
| 1559 |
+
|
| 1560 |
+
def epilog_smem_store_and_partition(
    self,
    tiled_mma: cute.TiledMma,
    d_layout: Optional[LayoutEnum],
    dtype: Type[cutlass.Numeric],
    sD: Optional[cute.Tensor],
    tidx: Int32,
) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
    """Build the register->smem tiled copy for D and partition smem/register
    fragments for this thread.

    :param d_layout: layout of D; None is treated as row-major.
    :param sD: D staging buffer in smem, or None (then only the register
        fragment shape is derived from ``self.epi_tile``).
    :return: (tiled_copy_r2s, tRS_rD register fragment in acc dtype,
        tRS_sD smem partition or None).
    """
    if d_layout is None:
        d_layout = LayoutEnum.ROW_MAJOR
    tiled_copy_C_atom = self.epilog_smem_copy_atom(tiled_mma)
    # Doesn't work with tile_N % 8 == 0 but tile_n % 16 != since this always
    # get st.matrix with num_matrices=4
    copy_atom_r2s = sm90_utils.sm90_get_smem_store_op(
        d_layout, elem_ty_d=dtype, elem_ty_acc=self.acc_dtype
    )
    tiled_copy_r2s = cute.make_tiled_copy_S(copy_atom_r2s, tiled_copy_C_atom)
    # (R2S, R2S_M, R2S_N, PIPE_D)
    thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
    tRS_sD = thr_copy_r2s.partition_D(sD) if sD is not None else None
    sD_shape = sD.shape[:2] if sD is not None else self.epi_tile
    # Partition an identity tensor just to get the per-thread fragment shape.
    tRS_rD_shape = thr_copy_r2s.partition_S(cute.make_identity_tensor(sD_shape)).shape
    tRS_rD = cute.make_fragment(tRS_rD_shape, self.acc_dtype)
    return tiled_copy_r2s, tRS_rD, tRS_sD
|
| 1584 |
+
|
| 1585 |
+
def epilog_smem_load_and_partition(
    self,
    tiled_mma: cute.TiledMma,
    c_layout: LayoutEnum,
    dtype: Type[cutlass.Numeric],
    sC: cute.Tensor,
    tRS_rD_layout: cutlass.Layout,
    tidx: Int32,
) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor, cute.Tensor]:
    """Build the smem->register tiled copy for C and partition fragments for
    this thread.

    :param tRS_rD_layout: layout of the D register fragment; the C fragment
        is made congruent to it so C and D can be combined elementwise.
    :return: (tiled_copy_s2r, tRS_rC, tSR_rC retiled view of tRS_rC,
        tSR_sC smem partition).
    """
    tiled_copy_C_atom = self.epilog_smem_copy_atom(tiled_mma)
    copy_atom_s2r = copy_utils.sm90_get_smem_load_op(c_layout, dtype)
    tiled_copy_s2r = cute.make_tiled_copy_S(copy_atom_s2r, tiled_copy_C_atom)
    thr_copy_s2r = tiled_copy_s2r.get_slice(tidx)
    tSR_sC = thr_copy_s2r.partition_S(sC)
    tRS_rC = cute.make_fragment(tRS_rD_layout, dtype)
    tSR_rC = thr_copy_s2r.retile(tRS_rC)
    return tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC
|
| 1602 |
+
|
| 1603 |
+
def epilog_gmem_copy_and_partition(
    self,
    atom: Union[cute.CopyAtom, cute.TiledCopy],
    mD_mn: cute.Tensor,
    tile_shape_mn: cute.Tile,
    epi_tile: cute.Tile,
    sD: cute.Tensor,
    tile_coord_mnkl: cute.Coord,
    tma_desc_ptr: Optional[cute.Pointer] = None,
) -> Tuple[cute.Tensor, cute.Tensor]:
    """Slice the output-tile view out of the global D/C tensor and build the
    TMA copy function between it and the smem staging buffer ``sD``.

    The copy direction is inferred from the atom: S2G (store) atoms copy
    smem -> gmem, anything else copies gmem -> smem (C loads).

    NOTE(review): the declared return type looks stale — the value returned
    is whatever ``copy_utils.tma_get_copy_fn`` produces; confirm against
    that helper before relying on the annotation.
    """
    # (bM, bN)
    gD = cute.local_tile(mD_mn, tile_shape_mn, tile_coord_mnkl[:2])
    # Divide the CTA tile into epi tiles so TMA moves one epi tile at a time.
    tDgD_for_tma_partition = cute.zipped_divide(gD, epi_tile)
    is_s2g = isinstance(
        atom.op, (cpasync.CopyBulkTensorTileS2GOp, cpasync.CopyReduceBulkTensorTileS2GOp)
    )
    src_tensor, dst_tensor = (
        (sD, tDgD_for_tma_partition) if is_s2g else (tDgD_for_tma_partition, sD)
    )
    return copy_utils.tma_get_copy_fn(
        atom,
        cta_coord=0,
        cta_layout=cute.make_layout(1),
        src_tensor=src_tensor,
        dst_tensor=dst_tensor,
        tma_desc_ptr=tma_desc_ptr,
    )
|
| 1630 |
+
|
| 1631 |
+
def make_ab_pipeline(
    self,
    tiled_mma: cute.TiledMma,
    cluster_layout_vmnk: cute.Layout,
    ab_pipeline_mbar_ptr: cute.Pointer,
):
    """Create the mainloop A/B smem pipeline.

    Producer count is 1 (the TMA-electing thread) unless gather_A, where all
    ``num_ab_load_warps`` cp.async threads also arrive. Consumers are the MMA
    threads, scaled by the multicast size so each multicasting CTA is counted.
    Uses PipelineTmaCpAsync when gather_A (mbarrier tracks both TMA and
    cp.async), plain PipelineTmaAsync otherwise.
    """
    # Threads/warps participating in this pipeline
    producer_cnt = 1 if const_expr(not self.gather_A) else 1 + self.num_ab_load_warps * 32
    ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, producer_cnt)
    # Each warp will contribute to the arrive count with the number of mcast size
    mcast_size = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
    consumer_arrive_cnt = mcast_size * tiled_mma.size // cute.arch.WARP_SIZE
    ab_pipeline_consumer_group = pipeline.CooperativeGroup(
        pipeline.Agent.Thread, consumer_arrive_cnt
    )
    pipeline_cls = pipeline.PipelineTmaAsync if not self.gather_A else PipelineTmaCpAsync
    return pipeline_cls.create(
        barrier_storage=ab_pipeline_mbar_ptr,
        num_stages=self.ab_stage,
        producer_group=ab_pipeline_producer_group,
        consumer_group=ab_pipeline_consumer_group,
        tx_count=self.num_tma_load_bytes,
        cta_layout_vmnk=cluster_layout_vmnk,
    )
|
| 1655 |
+
|
| 1656 |
+
def make_epi_pipeline(
    self, c_smem_layout: cute.Layout | cute.ComposedLayout, epi_pipeline_mbar_ptr: cute.Pointer
):
    """Create the epilogue C-load TMA pipeline.

    Single-thread producer (the TMA-electing thread); one arrive per epilogue
    warp on the consumer side. ``tx_count`` is the byte size of one C smem
    stage, derived from ``c_smem_layout``.
    """
    # Threads/warps participating in this pipeline
    epi_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
    # Each warp will contribute 1 to the arrive count
    consumer_arrive_cnt = self.num_epi_warps
    epi_pipeline_consumer_group = pipeline.CooperativeGroup(
        pipeline.Agent.Thread, consumer_arrive_cnt
    )
    tma_copy_c_bytes = cute.size_in_bytes(self.c_dtype, c_smem_layout)
    return pipeline.PipelineTmaAsync.create(
        barrier_storage=epi_pipeline_mbar_ptr,
        num_stages=self.epi_c_stage,
        producer_group=epi_pipeline_producer_group,
        consumer_group=epi_pipeline_consumer_group,
        tx_count=tma_copy_c_bytes,
    )
|
| 1674 |
+
|
| 1675 |
+
def make_epi_store_pipeline(self):
    """Create the TMA-store pipeline that throttles epilogue D stores; all
    epilogue threads act as producers across ``self.epi_stage`` stages."""
    # Threads/warps participating in tma store pipeline
    producer_thread_cnt = self.num_epi_warps * cute.arch.WARP_SIZE
    store_producer_group = pipeline.CooperativeGroup(
        pipeline.Agent.Thread, producer_thread_cnt
    )
    return pipeline.PipelineTmaStore.create(
        num_stages=self.epi_stage, producer_group=store_producer_group
    )
|
| 1682 |
+
|
| 1683 |
+
def make_sched_pipeline(
    self, cluster_layout_mnk: cute.Layout, sched_pipeline_mbar_ptr: cute.Pointer, varlen_k: bool
):
    """Create the tile-scheduler broadcast pipeline (scheduler warp produces,
    all other warps in the cluster consume).

    :param varlen_k: changes how many MMA warps consume each scheduler round
        (see comment below on the arrive count).
    """
    # Threads/warps participating in this pipeline
    sched_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
    cluster_size = cute.size(cluster_layout_mnk)
    # Each warp that are not the scheduler warp will contribute 1 to the arrive count
    # If pingpong and varlen_k, then all 8 mma warps will participate in the scheduler barrier
    # at each round. If pingpong and not varlen_k, then only 4 mma warp will participate.
    consumer_arrive_cnt = (
        (self.mma_warp_groups if not (self.pingpong and not varlen_k) else 1) * 4
        + self.num_ab_load_warps
    ) * cluster_size - 1
    sched_pipeline_consumer_group = pipeline.CooperativeGroup(
        pipeline.Agent.Thread, consumer_arrive_cnt
    )
    return pipeline.PipelineAsync.create(
        barrier_storage=sched_pipeline_mbar_ptr,
        num_stages=self.sched_stage,
        producer_group=sched_pipeline_producer_group,
        consumer_group=sched_pipeline_consumer_group,
        # If there's cluster, the consumers must arrive at the mbar of CTA 0 in the cluster.
        consumer_mask=None if const_expr(cluster_size == 1) else 0,
    )
|
| 1707 |
+
|
| 1708 |
+
@classmethod
def _compute_stages(
    cls,
    cta_tile_shape_mnk: Tuple[int, int, int],
    epi_tile: Tuple[int, int],
    a_dtype: Type[cutlass.Numeric],
    b_dtype: Type[cutlass.Numeric],
    d_dtype: Optional[Type[cutlass.Numeric]],
    c_dtype: Optional[Type[cutlass.Numeric]],
    epilogue_args: EpilogueArguments,
    smem_capacity: int,
    occupancy: int,
    overlap_sD_sA: bool = False,
) -> Tuple[int, int, int]:
    """Computes the number of stages for A/B, epilogue D, and epilogue C
    based on heuristics.

    :param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile.
    :type cta_tile_shape_mnk: Tuple[int, int, int]
    :param epi_tile: The epilogue subtile shape (M, N).
    :type epi_tile: Tuple[int, int]
    :param a_dtype: Data type of operand A.
    :type a_dtype: type[cutlass.Numeric]
    :param b_dtype: Data type of operand B.
    :type b_dtype: type[cutlass.Numeric]
    :param d_dtype: Data type of output D (or None if D is not stored).
    :param c_dtype: Data type of source C (or None if C is not loaded).
    :param epilogue_args: Epilogue arguments; may reserve extra smem per stage.
    :param smem_capacity: Total available shared memory capacity in bytes.
    :type smem_capacity: int
    :param occupancy: Target number of CTAs per SM (occupancy).
    :type occupancy: int
    :param overlap_sD_sA: If True, sD reuses sA's smem so the epilogue costs
        no extra bytes.

    :return: A tuple containing the computed number of stages for:
        (A/B operand stages, epilogue D stages, epilogue C stages)
    :rtype: Tuple[int, int, int]
    """
    # Narrow epi tiles get more stages to keep TMA stores busy.
    epi_stage = 4 if epi_tile[1] <= 16 else 2
    if overlap_sD_sA:
        epi_bytes = 0
    else:
        d_bytes_per_stage = (
            cute.size(epi_tile) * d_dtype.width // 8 if d_dtype is not None else 0
        )
        epi_bytes_per_stage = d_bytes_per_stage + cls.epi_smem_bytes_per_stage(
            epilogue_args, cta_tile_shape_mnk, epi_tile
        )
        epi_bytes = epi_bytes_per_stage * epi_stage
    epi_c_stage = 0 if c_dtype is None else (4 if epi_tile[1] <= 16 else 2)
    if c_dtype is not None:
        epi_bytes += cute.size(epi_tile) * c_dtype.width // 8 * epi_c_stage

    # Per-stage smem cost of one A tile (M, K) plus one B tile (N, K).
    a_shape = cute.slice_(cta_tile_shape_mnk, (None, 0, None))
    b_shape = cute.slice_(cta_tile_shape_mnk, (0, None, None))
    ab_bytes_per_stage = (
        cute.size(a_shape) * a_dtype.width // 8 + cute.size(b_shape) * b_dtype.width // 8
    )
    # Reserved bytes for mbarriers and other bookkeeping.
    mbar_helpers_bytes = 1024

    remaining_bytes = smem_capacity // occupancy - mbar_helpers_bytes - epi_bytes
    ab_stage = remaining_bytes // ab_bytes_per_stage

    # Refine epilogue stages:
    # Calculate remaining smem after allocating for A/B stages and reserved bytes,
    # and add the remaining unused smem to the epilogue.
    # (When overlap_sD_sA is True, epi_bytes_per_stage is unbound; the short-circuit
    # on `not overlap_sD_sA` keeps this safe.)
    if not overlap_sD_sA and epi_bytes_per_stage > 0:
        epi_stage += (remaining_bytes - ab_bytes_per_stage * ab_stage) // epi_bytes_per_stage
    return ab_stage, epi_stage, epi_c_stage
|
| 1771 |
+
|
| 1772 |
+
@staticmethod
def _sm90_compute_tile_shape_or_override(
    cta_tile_shape_mnk: Tuple[int, int, int],
    atom_layout_mnk: Tuple[int, int, int],
    element_type: Optional[Type[cutlass.Numeric]] = None,
    epi_tile_override: Tuple[int, int] | None = None,
) -> Tuple[int, int]:
    """Select the SM90 epilogue tile (M, N), honoring an explicit override.

    :param cta_tile_shape_mnk: CTA tile shape (M, N, K)
    :param atom_layout_mnk: MMA atom layout (M, N, K)
    :param element_type: Output element type; 8-bit types widen the preferred N tile
    :param epi_tile_override: If given, returned unchanged
    :return: Epilogue tile shape (tile_m, tile_n)
    """
    if epi_tile_override is not None:
        return epi_tile_override

    m_extent = cute.size(cta_tile_shape_mnk, mode=[0])
    n_extent = cute.size(cta_tile_shape_mnk, mode=[1])
    multi_atom_m = atom_layout_mnk[0] > 1

    if multi_atom_m and cta_tile_shape_mnk[0] % 128 == 0:
        base_m, base_n = 128, 32
    elif multi_atom_m and cta_tile_shape_mnk[0] % 192 == 0:
        base_m, base_n = 192, 32
    else:
        # In the case of tile shape 128 x N but atom_layout 1 x 2, we need to set
        # epi_tile_m = 64. If epi_tile_m = 128, the epilogue would iterate along the
        # M dimension first, then move to the N dimension. But the accumulator in
        # registers iterates along the N dimension first, then moves to the M
        # dimension. We could change the epilogue to accommodate this, but it's
        # easier to just set epi_tile_m = 64.
        base_m = 64
        base_n = 64 if element_type is not None and element_type.width == 8 else 32

    return (math.gcd(base_m, m_extent), math.gcd(base_n, n_extent))
|
| 1812 |
+
|
| 1813 |
+
@staticmethod
def _make_smem_layouts(
    cta_tile_shape_mnk: Tuple[int, int, int],
    epi_tile: Tuple[int, int],
    a_dtype: Type[cutlass.Numeric],
    a_layout: LayoutEnum,
    b_dtype: Type[cutlass.Numeric],
    b_layout: LayoutEnum,
    ab_stage: int,
    d_dtype: Optional[Type[cutlass.Numeric]],
    d_layout: LayoutEnum,
    epi_stage: int,
    c_dtype: Optional[Type[cutlass.Numeric]],
    c_layout: Optional[LayoutEnum],
    epi_c_stage: int,
) -> Tuple[
    cute.ComposedLayout, cute.ComposedLayout, cute.ComposedLayout, Optional[cute.ComposedLayout]
]:
    """Create staged shared memory layouts for the A, B, D, and C tensors.

    :param cta_tile_shape_mnk: CTA tile shape (M,N,K)
    :type cta_tile_shape_mnk: Tuple[int, int, int]
    :param epi_tile: Epilogue tile shape
    :type epi_tile: Tuple[int, int]
    :param a_dtype: Data type for matrix A
    :type a_dtype: type[cutlass.Numeric]
    :param a_layout: Layout enum for matrix A
    :type a_layout: LayoutEnum
    :param b_dtype: Data type for matrix B
    :type b_dtype: type[cutlass.Numeric]
    :param b_layout: Layout enum for matrix B
    :type b_layout: LayoutEnum
    :param ab_stage: Number of stages for A/B tensors
    :type ab_stage: int
    :param d_dtype: Data type for output matrix D (None -> no D smem layout)
    :param d_layout: Layout enum for the output matrix D
    :type d_layout: LayoutEnum
    :param epi_stage: Number of epilogue stages for D
    :type epi_stage: int
    :param c_dtype: Data type for source matrix C (None -> no C smem layout)
    :param c_layout: Layout enum for matrix C (required when c_dtype is not None)
    :param epi_c_stage: Number of epilogue stages for C

    :return: Staged smem layouts for A, B, D-epilogue, and C-epilogue; the last
        two are None when the corresponding dtype is None.
    :rtype: Tuple[cute.ComposedLayout, cute.ComposedLayout,
        Optional[cute.ComposedLayout], Optional[cute.ComposedLayout]]
    """
    # A occupies the (M, K) slice of the CTA tile.
    a_smem_shape = cute.slice_(cta_tile_shape_mnk, (None, 0, None))

    a_is_k_major = a_layout.sm90_mma_major_mode() == warpgroup.OperandMajorMode.K
    b_is_k_major = b_layout.sm90_mma_major_mode() == warpgroup.OperandMajorMode.K
    # The swizzle atom is chosen from the extent along the contiguous (major) mode.
    a_major_mode_size = cta_tile_shape_mnk[2 if a_is_k_major else 0]
    a_smem_layout_atom = warpgroup.make_smem_layout_atom(
        sm90_utils.get_smem_layout_atom(a_layout, a_dtype, a_major_mode_size),
        a_dtype,
    )
    a_smem_layout_staged = cute.tile_to_shape(
        a_smem_layout_atom,
        cute.append(a_smem_shape, ab_stage),
        # Tile-fill order follows the major mode; the stage mode is always last.
        order=(0, 1, 2) if a_is_k_major else (1, 0, 2),
    )

    # B occupies the (N, K) slice of the CTA tile.
    b_smem_shape = cute.slice_(cta_tile_shape_mnk, (0, None, None))

    b_major_mode_size = cta_tile_shape_mnk[2 if b_is_k_major else 1]
    b_smem_layout_atom = warpgroup.make_smem_layout_atom(
        sm90_utils.get_smem_layout_atom(b_layout, b_dtype, b_major_mode_size),
        b_dtype,
    )
    b_smem_layout_staged = cute.tile_to_shape(
        b_smem_layout_atom,
        cute.append(b_smem_shape, ab_stage),
        order=(0, 1, 2) if b_is_k_major else (1, 0, 2),
    )

    # Epilogue smem for D, only when D is actually stored.
    epi_smem_layout_staged = None
    if d_dtype is not None:
        epi_smem_layout_staged = quack_sm90_utils.make_smem_layout_epi(
            d_dtype, d_layout, epi_tile, epi_stage
        )

    # Epilogue smem for C, only when C is loaded.
    epi_c_smem_layout_staged = None
    if c_dtype is not None:
        assert c_layout is not None
        epi_c_smem_layout_staged = quack_sm90_utils.make_smem_layout_epi(
            c_dtype, c_layout, epi_tile, epi_c_stage
        )

    return (
        a_smem_layout_staged,
        b_smem_layout_staged,
        epi_smem_layout_staged,
        epi_c_smem_layout_staged,
    )
|
| 1904 |
+
|
| 1905 |
+
@staticmethod
def _make_tma_epi_atoms_and_tensors(
    tensor_d: cute.Tensor,
    epi_smem_layout_staged: cute.ComposedLayout,
    epi_tile: Tuple[int, int],
    op_type: Literal["store", "load", "add"],
) -> Tuple[cute.CopyAtom, cute.Tensor]:
    """Create TMA atoms and tensors for the epilogue: storing D, loading C,
    or reduce-adding into D.

    :param tensor_d: Output tensor D (or source C when op_type == "load")
    :type tensor_d: cute.Tensor
    :param epi_smem_layout_staged: Staged shared memory layout for the epilogue;
        only one stage slice is used for the TMA descriptor
    :type epi_smem_layout_staged: cute.ComposedLayout
    :param epi_tile: Epilogue tile shape
    :type epi_tile: Tuple[int, int]
    :param op_type: Which bulk-tensor op to build: "load" (G2S), "store" (S2G),
        or "add" (S2G with ADD reduction)

    :return: TMA atom and TMA-backed tensor
    :rtype: Tuple[cute.CopyAtom, cute.Tensor]
    """
    assert op_type in ["load", "store", "add"]
    # TMA descriptors are built against a single smem stage.
    epi_smem_layout = cute.slice_(epi_smem_layout_staged, (None, None, 0))
    d_cta_v_layout = cute.composition(cute.make_identity_layout(tensor_d.shape), epi_tile)
    op = (
        cpasync.CopyBulkTensorTileG2SOp()
        if op_type == "load"
        else cpasync.CopyBulkTensorTileS2GOp()
        if op_type == "store"
        else cpasync.CopyReduceBulkTensorTileS2GOp(cute.ReductionOp.ADD)
    )
    tma_atom_d, tma_tensor_d = cpasync.make_tiled_tma_atom(
        op, tensor_d, epi_smem_layout, d_cta_v_layout
    )
    return tma_atom_d, tma_tensor_d
|
| 1938 |
+
|
| 1939 |
+
@staticmethod
def _make_tma_atoms_and_tensors(
    tensor: cute.Tensor,
    smem_layout: cute.ComposedLayout,
    smem_tile: Tuple[int, int],
    mcast_dim: int,
) -> Tuple[cute.CopyAtom, cute.Tensor]:
    """Build the TMA copy atom and TMA-backed tensor for an input operand.

    :param tensor: Input tensor (A or B)
    :type tensor: cute.Tensor
    :param smem_layout: Shared memory layout for the tensor
    :type smem_layout: cute.ComposedLayout
    :param smem_tile: Shared memory tile shape
    :type smem_tile: Tuple[int, int]
    :param mcast_dim: Multicast dimension; 1 means no multicast
    :type mcast_dim: int

    :return: TMA atom and tensor
    :rtype: Tuple[cute.CopyAtom, cute.Tensor]
    """
    # Multicast only pays off across a cluster; fall back to the plain
    # G2S bulk-tensor op when the multicast extent is 1.
    if mcast_dim == 1:
        copy_op = cpasync.CopyBulkTensorTileG2SOp()
    else:
        copy_op = cpasync.CopyBulkTensorTileG2SMulticastOp()
    return cpasync.make_tiled_tma_atom(
        copy_op,
        tensor,
        smem_layout,
        smem_tile,
        num_multicast=mcast_dim,
    )
|
| 1973 |
+
|
| 1974 |
+
def _make_gmem_tiled_copy_A(self, dtype, major_mode, num_threads, copy_bits=128):
    """Build a tiled cp.async (G2S) copy for loading A from global memory.

    Threads are arranged so that consecutive threads cover consecutive
    vectorized elements along the major mode, capped so one row of threads
    does not span more than a 128-byte cache line.

    :param dtype: Element data type of A
    :param major_mode: LayoutEnum of A; selects row- vs column-major thread/value layouts
    :param num_threads: Number of threads participating in the copy
    :param copy_bits: Bits moved per copy instruction (default 128-bit vectorized)
    :return: A tiled copy created with ``cute.make_tiled_copy_tv``
    """
    atom_async_copy = cute.make_copy_atom(
        cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),
        dtype,
        num_bits_per_copy=copy_bits,
    )
    # Elements moved per copy instruction.
    copy_elems = copy_bits // dtype.width
    loads_per_cache_line = 128 * 8 // copy_bits  # 128 bytes per cache line
    # Number of vectorized loads along K for one tile row.
    shape_dim_1 = cute.size(self.cta_tile_shape_mnk[2]) // copy_elems
    if shape_dim_1 > loads_per_cache_line:
        # Keep a thread row within one cache line.
        shape_dim_1 = math.gcd(shape_dim_1, loads_per_cache_line)
    # Thread layout for copy (K-contiguous threads).
    thread_layout = cute.make_layout(
        (num_threads // shape_dim_1, shape_dim_1), stride=(shape_dim_1, 1)
    )
    if major_mode != LayoutEnum.ROW_MAJOR:
        # M-major A: rebuild the thread layout with M-contiguous threads instead.
        shape_dim_0 = cute.size(self.cta_tile_shape_mnk[0]) // copy_elems
        if shape_dim_0 > loads_per_cache_line:
            shape_dim_0 = math.gcd(shape_dim_0, loads_per_cache_line)
        thread_layout = cute.make_layout(
            (shape_dim_0, num_threads // shape_dim_0), stride=(1, shape_dim_0)
        )
    # Value layout for copy: vectorize along the contiguous mode.
    value_layout = (
        cute.make_layout((1, copy_elems))
        if major_mode == LayoutEnum.ROW_MAJOR
        else cute.make_layout((copy_elems, 1))
    )
    return cute.make_tiled_copy_tv(atom_async_copy, thread_layout, value_layout)
|
| 2003 |
+
|
| 2004 |
+
@staticmethod
def is_valid_dtypes(
    a_dtype: Type[cutlass.Numeric],
    b_dtype: Type[cutlass.Numeric],
    acc_dtype: Type[cutlass.Numeric],
    d_dtype: Optional[Type[cutlass.Numeric]],
    a_major: str,
    b_major: str,
) -> bool:
    """
    Check whether the given dtype/major combination is supported.

    :param a_dtype: The data type of tensor A
    :type a_dtype: Type[cutlass.Numeric]
    :param b_dtype: The data type of tensor B
    :type b_dtype: Type[cutlass.Numeric]
    :param acc_dtype: The data type of the accumulator
    :type acc_dtype: Type[cutlass.Numeric]
    :param d_dtype: The data type of the output tensor (None allowed)
    :type d_dtype: Optional[Type[cutlass.Numeric]]
    :param a_major: major mode of tensor A
    :type a_major: str
    :param b_major: major mode of tensor B
    :type b_major: str

    :return: True if the dtypes are valid, False otherwise
    :rtype: bool
    """
    ab_supported = {
        Float16,
        cutlass.BFloat16,
        cutlass.Float8E4M3FN,
        cutlass.Float8E5M2,
    }
    d_supported = {
        None,
        Float32,
        Float16,
        cutlass.BFloat16,
        cutlass.Float8E4M3FN,
        cutlass.Float8E5M2,
    }
    # Operand dtypes must be one of the tested 16-bit / fp8 types.
    if a_dtype not in ab_supported or b_dtype not in ab_supported:
        return False
    if acc_dtype not in {Float32, Float16}:
        return False
    if d_dtype not in d_supported:
        return False
    # 16-bit operands must match exactly.
    if a_dtype.width == 16 and a_dtype != b_dtype:
        return False
    # Operand widths must agree (e.g. both fp8 variants).
    if a_dtype.width != b_dtype.width:
        return False
    # fp8 operands are only supported in k-major layout by this implementation.
    if (a_dtype.width == 8 and a_major != "k") or (b_dtype.width == 8 and b_major != "k"):
        return False
    return True
|
build/torch-cuda/quack/gemm_symmetric.py
ADDED
|
@@ -0,0 +1,330 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Tuple, Optional, Callable
|
| 2 |
+
from functools import partial
|
| 3 |
+
from torch import Tensor
|
| 4 |
+
from .gemm_act import GemmActMixin, act_fn_map, gemm_act
|
| 5 |
+
from .gemm_sm90 import GemmSm90
|
| 6 |
+
from .gemm_sm100 import GemmSm100
|
| 7 |
+
from .tile_scheduler import TriangularTileScheduler
|
| 8 |
+
from .gemm_wrapper_utils import GemmWrapperBase
|
| 9 |
+
from .cute_dsl_utils import get_device_capacity, get_max_active_clusters
|
| 10 |
+
from .varlen_utils import VarlenManager
|
| 11 |
+
from . import copy_utils as copy_utils
|
| 12 |
+
import cutlass
|
| 13 |
+
import cutlass.cute as cute
|
| 14 |
+
import cutlass.torch as cutlass_torch
|
| 15 |
+
from cutlass.cute.runtime import make_ptr
|
| 16 |
+
from cutlass import Int32, Float32, Boolean, const_expr
|
| 17 |
+
import cutlass.utils.hopper_helpers as sm90_utils_og
|
| 18 |
+
import cutlass.utils.blackwell_helpers as sm100_utils
|
| 19 |
+
from cutlass.cutlass_dsl import if_generate
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class GemmSymmetricMixin(GemmActMixin, GemmSm90):
    """GEMM epilogue mixin for square outputs that also writes each output tile
    to its mirrored (transposed) position.

    The regular D store writes the computed tile; the "PostAct" path (inherited
    from GemmActMixin) is pointed at the transposed output, so each scheduled
    tile also populates the tile across the diagonal. Diagonal tiles are
    written only once (see the ``square_tile_m != square_tile_n`` guard in the
    store helper). Tiles are scheduled via ``TriangularTileScheduler``.
    """

    def get_scheduler_class(self, varlen_m: bool = False):
        # Only a triangular set of output tiles is launched; the epilogue
        # mirrors each tile to cover the other half of the (square) output.
        return TriangularTileScheduler

    @cute.jit
    def epilogue(
        self,
        params: GemmActMixin.EpilogueParams,
        epi_smem_tensors: Tuple[cute.Tensor, ...],
        tma_desc_epi_ptrs: list[Optional[cute.Pointer]],
        epi_pipeline: cutlass.pipeline.PipelineAsync,
        epi_store_pipeline: cutlass.pipeline.PipelineAsync,
        epi_read_state: cutlass.pipeline.PipelineState,
        epi_producer_state: cutlass.pipeline.PipelineState,
        epi_tile: cute.Tile,
        load_acc_subtile: Callable,
        tRS_rD: cute.Tensor,
        tRS_rC: Optional[cute.Tensor],
        tiled_copy_t2r: Optional[cute.TiledCopy],  # Only for Sm100
        tiled_copy_r2s: cute.TiledCopy,
        tRS_sD: cute.Tensor,
        tiled_copy_s2r: Optional[cute.TiledCopy],
        tSR_rC: Optional[cute.Tensor],
        tSR_sC: Optional[cute.Tensor],
        copy_D: Optional[Callable],
        copy_C: Optional[Callable],
        tile_coord_mnkl: cute.Coord,
        varlen_manager: VarlenManager,
        epilogue_barrier: cutlass.pipeline.NamedBarrier,
        tile_scheduler,
        tidx: Int32,
        is_tma_warp: Boolean,
    ) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState]:
        """Per-tile epilogue: iterate over epi subtiles, optionally blend in C,
        apply the epilogue visitor, and TMA-store both D and the mirrored
        PostAct copy.

        Returns the advanced C-load consumer/producer pipeline states.
        """
        has_C = const_expr(tRS_rC is not None)
        has_D = const_expr(copy_D is not None)

        tma_atom_postact = params.tma_atom_postact
        mPostAct_mnl = params.mPostAct_mnl
        # Only sPostAct is used here; sRowVec/sColVec belong to other epilogue hooks.
        sRowVec, sColVec, sPostAct = epi_smem_tensors
        # SM100 derives the smem store op from the tmem load; SM90 has its own helper.
        get_smem_store_op = (
            partial(sm100_utils.get_smem_store_op, tiled_tmem_load=tiled_copy_t2r)
            if self.arch == 100
            else sm90_utils_og.sm90_get_smem_store_op
        )
        copy_atom_postact_r2s = get_smem_store_op(
            self.postact_layout, self.postact_dtype, self.acc_dtype
        )
        # tiled_copy_C_atom = self.epilog_smem_copy_atom(tiled_mma)
        # tiled_copy_postact_r2s = cute.make_tiled_copy_S(copy_atom_postact_r2s, tiled_copy_C_atom)
        tiled_copy_postact_r2s = cute.make_tiled_copy_S(copy_atom_postact_r2s, tiled_copy_r2s)
        tRS_sPostAct = tiled_copy_postact_r2s.get_slice(tidx).partition_D(sPostAct)
        (tma_desc_postact_ptr,) = tma_desc_epi_ptrs
        batch_idx = tile_coord_mnkl[3]
        copy_postact, _, _ = self.epilog_gmem_copy_and_partition(
            tma_atom_postact,
            varlen_manager.offset_batch_epi(mPostAct_mnl, batch_idx),
            self.cta_tile_shape_postact_mn,
            params.epi_tile_postact,
            sPostAct,
            tile_coord_mnkl,
            tma_desc_ptr=tma_desc_postact_ptr,
        )

        # We iterate over epi tiles in the N dimension first before the M dimension
        epi_tile_shape = cute.zipped_divide(
            cute.make_layout(self.cta_tile_shape_mnk[:2]), epi_tile
        ).shape[1]
        epi_tile_layout = cute.make_layout(epi_tile_shape, stride=(epi_tile_shape[1], 1))
        epi_tile_num = cute.size(epi_tile_shape)
        # Used to rotate through the smem epilogue buffers across tiles.
        num_prev_subtiles = tile_scheduler.num_tiles_executed * epi_tile_num

        epi_tensors = self.epi_begin(
            params,
            epi_smem_tensors,
            epi_tile,
            tiled_copy_t2r,
            tiled_copy_r2s,
            tile_coord_mnkl,
            varlen_manager,
            epilogue_barrier,
            tidx,
        )

        # Prologue: prefetch the first epi_c_stage C subtiles into smem.
        if const_expr(copy_C is not None):
            for epi_idx in cutlass.range(min(epi_tile_num, self.epi_c_stage), unroll=1):
                gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx)
                if is_tma_warp:
                    epi_pipeline.producer_acquire(epi_producer_state)
                    copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state)
                    epi_pipeline.producer_commit(epi_producer_state)
                epi_producer_state.advance()

        def tma_store_fn(src_idx, dst_idx, tile_coord_mnkl):
            # Store one epi subtile: D to its tile, PostAct to the mirrored tile.
            pid_m = tile_coord_mnkl[0]
            pid_n = tile_coord_mnkl[1]
            # Fence and barrier to make sure shared memory store is visible to TMA store
            cute.arch.fence_proxy(
                cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
            )
            epilogue_barrier.arrive_and_wait()
            # Copy from shared memory to global memory
            if is_tma_warp:
                square_tile_m = pid_m // self.cluster_shape_mnk[0]
                square_tile_n = pid_n // self.cluster_shape_mnk[1]
                if const_expr(has_D):
                    copy_D(src_idx=src_idx, dst_idx=dst_idx)
                if square_tile_m != square_tile_n:  # don't write twice to the same tile
                    copy_postact(src_idx=src_idx, dst_idx=dst_idx)
            # Can't use if statement here, epi_store_pipeline object isn't captured somehow
            if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_commit())
            if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_acquire())
            epilogue_barrier.arrive_and_wait()

        # Defer each subtile's TMA store by one iteration to overlap it with
        # the next subtile's register->smem traffic.
        delay_tma_store = True

        src_idx_prev, dst_idx_prev = None, None
        for epi_idx in cutlass.range_constexpr(epi_tile_num):
            # The global memory coordinate for the current epi tile
            gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
            # Copy from acc to D registers
            load_acc_subtile(tRS_rD, epi_idx)
            epi_loop_tensors = self.epi_begin_loop(params, epi_tensors, gmem_coord)
            if const_expr(has_C):
                epi_pipeline.consumer_wait(epi_read_state)
                cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC)
                # Fence to make sure shared memory read is visible to TMA load
                cute.arch.fence_proxy(
                    cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
                )
                cute.arch.sync_warp()
                with cute.arch.elect_one():
                    epi_pipeline.consumer_release(epi_read_state)
                epi_read_state.advance()
            # Prefetch the C subtile that is epi_c_stage iterations ahead.
            if const_expr(copy_C is not None and epi_idx + self.epi_c_stage < epi_tile_num):
                gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx + self.epi_c_stage)
                if is_tma_warp:
                    epi_pipeline.producer_acquire(epi_producer_state)
                    copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state)
                    epi_pipeline.producer_commit(epi_producer_state)
                epi_producer_state.advance()
            tRS_rPostAct = self.epi_visit_subtile(params, epi_loop_tensors, tRS_rD, tRS_rC)
            epi_buffer = (num_prev_subtiles + epi_idx) % self.epi_stage
            if const_expr(delay_tma_store):
                if const_expr(epi_idx > 0):
                    tma_store_fn(
                        src_idx=src_idx_prev, dst_idx=dst_idx_prev, tile_coord_mnkl=tile_coord_mnkl
                    )
                src_idx_prev, dst_idx_prev = epi_buffer, gmem_coord
            # Copy from D registers to shared memory
            if const_expr(has_D):
                copy_utils.cvt_copy(tiled_copy_r2s, tRS_rD, tRS_sD[None, None, None, epi_buffer])
            cute.copy(
                tiled_copy_postact_r2s,
                tiled_copy_postact_r2s.retile(tRS_rPostAct),
                tRS_sPostAct[None, None, None, epi_buffer],
            )
            if const_expr(not delay_tma_store):
                tma_store_fn(
                    src_idx=epi_buffer, dst_idx=gmem_coord, tile_coord_mnkl=tile_coord_mnkl
                )

        # Flush the last deferred subtile.
        if const_expr(delay_tma_store):
            tma_store_fn(
                src_idx=src_idx_prev, dst_idx=dst_idx_prev, tile_coord_mnkl=tile_coord_mnkl
            )

        self.epi_end(
            params,
            epi_tensors,
            epi_tile,
            tiled_copy_t2r,
            tiled_copy_r2s,
            tile_coord_mnkl,
            varlen_manager,
            tidx,
        )

        return epi_read_state, epi_producer_state
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
class GemmSymmetricSm90(GemmSymmetricMixin, GemmSm90):
    """SM90 (Hopper) symmetric GEMM; all behavior comes from the mixins."""

    pass
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
class GemmSymmetricSm100(GemmSymmetricMixin, GemmSm100):
    """SM100 (Blackwell) symmetric GEMM; all behavior comes from the mixins."""

    pass
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def gemm_symmetric(
    A: Tensor,  # (l, m, k)
    B: Tensor,  # (l, m, k)
    D: Optional[Tensor],  # (l, m, m)
    C: Optional[Tensor],  # (l, m, m)
    tile_count_semaphore: Optional[Tensor],  # (1,)
    tile_M: int,
    tile_N: int,
    cluster_M: int,
    cluster_N: int,
    pingpong: bool = False,
    persistent: bool = True,
    max_swizzle_size: int = 8,
    alpha: float | Tensor = 1.0,
    beta: float | Tensor = 1.0,
) -> None:
    """Batched GEMM with a square output that is written symmetrically: each
    computed tile is also stored to the mirrored (transposed) position via the
    "PostAct" output, which is set to ``D.mT``.

    Requires M == N. Dispatches to an SM90 or SM100 kernel based on the device
    capability, compiles on first use, and caches the compiled kernel by
    problem configuration. Results are written in place to D.
    """
    # Transpose D so the "activation" output is a write to the mirrored tile.
    # NOTE(review): D is annotated Optional but is dereferenced unconditionally
    # here — confirm callers always pass a real tensor for D.
    PostAct = D.mT

    L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(
        A, B, D, C, additional_tensors={"PostAct": PostAct}
    )
    assert M == N, "M and N must be the same; symmetric gemm only supports square matrices"
    GemmWrapperBase.permute_tensors(tensor_infos)
    GemmWrapperBase.extract_dtypes(tensor_infos)
    # Expected (major, minor, batch) mode names per tensor.
    major_configs = {
        "A": ("m", "k", "l"),
        "B": ("n", "k", "l"),
        "D": ("m", "n", "l"),
        "C": ("m", "n", "l"),
        "PostAct": ("m", "n", "l"),
    }
    GemmWrapperBase.determine_major_orders(tensor_infos, major_configs)

    device_capacity = get_device_capacity(A.device)
    assert device_capacity[0] in [9, 10], "Only SM90 and SM100 are supported"
    GemmCls = GemmSymmetricSm90 if device_capacity[0] == 9 else GemmSymmetricSm100

    acc_dtype = Float32
    tile_shape_mn = (tile_M, tile_N)
    cluster_shape_mnk = (cluster_M, cluster_N, 1)
    if not GemmCls.is_valid_dtypes(
        tensor_infos["A"].dtype,
        tensor_infos["B"].dtype,
        acc_dtype,
        tensor_infos["D"].dtype,
        tensor_infos["A"].major,
        tensor_infos["B"].major,
    ):
        raise TypeError("Skipping due to unsupported combination of types and majors")

    # Persistent kernels size their grid by how many clusters fit on the device.
    max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
    GemmWrapperBase.create_cute_tensors({k: v for k, v in tensor_infos.items()}, major_configs)

    def scalar_arg(scalar: float | Tensor):
        # Float scalars equal to 1.0 become None (identity, no epilogue scaling);
        # tensor scalars are passed as gmem pointers so they can be read on device.
        if isinstance(scalar, float):
            return Float32(scalar) if scalar != 1.0 else None
        else:
            assert isinstance(scalar, Tensor)
            return make_ptr(Float32, scalar.data_ptr(), cute.AddressSpace.gmem, assumed_align=4)

    activation = None  # Equivalent to identity
    act_fn = act_fn_map[activation]
    epi_args = GemmCls.EpilogueArguments(
        tensor_infos["PostAct"].cute_tensor, act_fn, scalar_arg(alpha), scalar_arg(beta)
    )
    scheduler_args = GemmWrapperBase.create_scheduler_args(
        max_active_clusters, tile_count_semaphore, max_swizzle_size=max_swizzle_size
    )
    varlen_args = None

    current_stream = cutlass_torch.current_stream()
    # Compile key encodes everything that changes the generated kernel,
    # including how alpha/beta are supplied (tensor / 1.0 / other float).
    compile_key = GemmWrapperBase.get_compile_key(
        tensor_infos,
        activation,
        tile_shape_mn,
        cluster_shape_mnk,
        pingpong,
        persistent,
        tile_count_semaphore is not None,
        device_capacity,
        max_swizzle_size,
        2 if isinstance(alpha, Tensor) else (1 if alpha == 1.0 else 0),
        2 if isinstance(beta, Tensor) else (1 if beta == 1.0 else 0),
        key_tensor_names=("A", "B", "D", "PostAct", "C"),
    )
    # NOTE(review): the compile cache is shared with gemm_act (stored on the
    # gemm_act function object) — confirm the key spaces cannot collide.
    cache = gemm_act.compile_cache
    if compile_key not in cache:
        if device_capacity[0] == 9:
            # SM90 takes extra constructor options not present on SM100.
            GemmCls = partial(GemmCls, pingpong=pingpong, is_persistent=persistent)
        gemm_obj = GemmCls(
            acc_dtype,
            tensor_infos["A"].dtype,
            tile_shape_mn,
            cluster_shape_mnk,
            gather_A=False,
        )
        cache[compile_key] = cute.compile(
            gemm_obj,
            tensor_infos["A"].cute_tensor,
            tensor_infos["B"].cute_tensor,
            tensor_infos["D"].cute_tensor,
            tensor_infos["C"].cute_tensor,
            epi_args,
            scheduler_args,
            varlen_args,
            current_stream,
        )
    cache[compile_key](
        tensor_infos["A"].cute_tensor,
        tensor_infos["B"].cute_tensor,
        tensor_infos["D"].cute_tensor,
        tensor_infos["C"].cute_tensor,
        epi_args,
        scheduler_args,
        varlen_args,
        current_stream,
    )
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
# NOTE(review): this rebinds the attribute used as the compile cache above
# (``cache = gemm_act.compile_cache``). Resetting it here at import time also
# clears anything gemm_act's own module may have already cached — confirm this
# is intentional (a dedicated ``gemm_symmetric.compile_cache`` may have been
# intended instead).
gemm_act.compile_cache = {}
|
build/torch-cuda/quack/gemm_wrapper_utils.py
ADDED
|
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Tri Dao.
|
| 2 |
+
from typing import Optional, Tuple, Dict, Any
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch import Tensor
|
| 7 |
+
|
| 8 |
+
import cutlass.cute as cute
|
| 9 |
+
from cutlass import Int32
|
| 10 |
+
from cutlass.cute.runtime import from_dlpack, make_ptr
|
| 11 |
+
|
| 12 |
+
from .cute_dsl_utils import torch2cute_dtype_map
|
| 13 |
+
from .varlen_utils import VarlenArguments
|
| 14 |
+
from .tile_scheduler import TileSchedulerOptions
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@dataclass
class GemmTensorInfo:
    """Per-tensor record used by GemmWrapperBase.

    Bundles a torch tensor with the metadata needed to hand it to the kernel.
    Only ``tensor`` is set at construction; the other fields are filled in later
    by GemmWrapperBase.extract_dtypes (``dtype``), determine_major_orders
    (``major``) and create_cute_tensors (``cute_tensor``).
    """

    # The (possibly permuted) torch tensor; None for optional operands (e.g. C).
    tensor: Optional[Tensor]
    # CUTE dtype looked up from torch2cute_dtype_map.
    dtype: Optional[Any] = None
    # Name of the contiguous ("major") dimension, e.g. "m"/"n"/"k".
    major: Optional[str] = None
    # DLPack-converted view of `tensor` for passing to the compiled kernel.
    cute_tensor: Optional[cute.Tensor] = None
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class GemmWrapperBase:
    """Static helpers shared by the Python GEMM wrappers.

    Responsibilities: validate the torch operand tensors for the three call
    modes (dense, varlen_m, varlen_k), permute them into the (X, Y, L) order
    the kernels expect, convert them to cute tensors via DLPack, and build
    scheduler / varlen arguments plus the compile-cache key.
    """

    @staticmethod
    def validate_tensor(tensor: Tensor, name: str, ndim: int) -> None:
        """Check that `tensor` is a rank-`ndim` CUDA tensor with a supported dtype."""
        assert tensor.dim() == ndim and tensor.is_cuda, f"{name} must be a {ndim}D CUDA tensor"
        assert tensor.dtype in torch2cute_dtype_map, f"Unsupported dtype for {name}"

    @staticmethod
    def validate_shape(tensor: Tensor, expected_shape: Tuple[int, ...], name: str) -> None:
        """Check that `tensor` has exactly `expected_shape`."""
        assert tensor.shape == expected_shape, (
            f"{name} must have shape {expected_shape}, got {tensor.shape}"
        )

    @staticmethod
    def get_major_order(tensor: Tensor, dims: Tuple[str, str, str]) -> str:
        """Return the name of the contiguous dimension of an already-permuted tensor."""
        # Tensor is already permuted to (dims[0], dims[1], dims[2])
        # stride(1) == 1 means dims[1] is contiguous (innermost)
        return dims[1] if tensor.stride(1) == 1 else dims[0]

    @staticmethod
    def create_cute_tensor(
        tensor: Optional[Tensor],
        major: Optional[str],
        dims: Tuple[str, str, str],
        assumed_align: int = 16,
    ) -> Optional[cute.Tensor]:
        """Convert a (possibly None) torch tensor to a cute tensor via DLPack.

        The layout is marked dynamic along the non-major dimension so one
        compiled kernel serves all problem sizes with the same major order.
        """
        if tensor is None:
            return None
        # Tensor is already permuted to (dims[0], dims[1], dims[2]) or (dims[0], dims[1]).
        # If major is dims[1], leading_dim is 1; if major is dims[0], leading_dim is 0.
        leading_dim = 1 if major == dims[1] else 0
        return from_dlpack(tensor.detach(), assumed_align=assumed_align).mark_layout_dynamic(
            leading_dim=leading_dim
        )

    @staticmethod
    def validate_and_prepare_tensors(
        A: Tensor,
        B: Tensor,
        D: Optional[Tensor] = None,
        C: Optional[Tensor] = None,
        additional_tensors: Optional[Dict[str, Tensor]] = None,
        cu_seqlens_m: Optional[Tensor] = None,
        cu_seqlens_k: Optional[Tensor] = None,
        A_idx: Optional[Tensor] = None,
    ) -> Tuple[int, int, int, int, Dict[str, GemmTensorInfo]]:
        """Validate the operands for one of three modes and collect them by name.

        Modes: dense (no cu_seqlens), varlen_m (cu_seqlens_m), varlen_k
        (cu_seqlens_k); at most one cu_seqlens may be given. A_idx enables
        gather_A, in which case its length defines total_M (varlen_m) or
        total_K (varlen_k).

        Returns:
            (L, M, K, N, tensors) where `tensors` maps "A"/"B"/"D"/"C" (and any
            `additional_tensors`) to fresh GemmTensorInfo records. For varlen_m,
            M is total_M; for varlen_k, K is total_K.
        """
        assert not (cu_seqlens_m is not None and cu_seqlens_k is not None), (
            "Only one of cu_seqlens_m and cu_seqlens_k can be specified"
        )
        assert B.dtype == A.dtype, "A and B must have the same dtype"

        # Validate A_idx if provided (for gather_A case)
        gather_A = A_idx is not None
        if gather_A:
            assert cu_seqlens_m is not None or cu_seqlens_k is not None, (
                "gather_A requires either varlen_m or varlen_k"
            )
            assert A_idx.dtype == torch.int32, f"A_idx must be int32, got {A_idx.dtype}"
            assert A_idx.dim() == 1, f"A_idx must be 1D, got {A_idx.dim()}D"

        # Determine mode and extract dimensions
        if cu_seqlens_m is not None:
            # varlen_m: A is (total_m, k) or (whatever, k) if gather_A,
            # B is (l, n, k), D/C are (total_m, n)
            assert A.dim() == 2, f"A must be 2D when using varlen_m, got {A.dim()}D"
            assert B.dim() == 3, f"B must be 3D with varlen_m, got {B.dim()}D"

            if gather_A:
                # When gather_A, A can have any number of rows; A_idx.shape[0] is total_M
                total_M = A_idx.shape[0]
                _, K = A.shape
            else:
                total_M, K = A.shape

            L, N, K_B = B.shape
            assert K == K_B, f"K dimension mismatch: A has {K}, B has {K_B}"
            assert cu_seqlens_m.shape == (L + 1,), (
                f"cu_seqlens_m must have shape ({L + 1},), got {cu_seqlens_m.shape}"
            )
            M = total_M
            dc_shape = (total_M, N)
            dc_ndim = 2
        elif cu_seqlens_k is not None:
            # varlen_k: A is (m, total_k) or (m, whatever) if gather_A,
            # B is (n, total_k), D/C are (l, m, n)
            assert A.dim() == 2, f"A must be 2D when using varlen_k, got {A.dim()}D"
            assert B.dim() == 2, f"B must be 2D with varlen_k, got {B.dim()}D"

            if gather_A:
                # When gather_A with varlen_k, A can have any number of columns;
                # A_idx.shape[0] is total_K
                M, _ = A.shape
                total_K = A_idx.shape[0]
            else:
                M, total_K = A.shape

            N, K_B = B.shape
            assert total_K == K_B, f"K dimension mismatch: expected {total_K}, B has {K_B}"
            # L is derived from cu_seqlens_k itself here (B carries no batch dim),
            # so validate the rank instead of re-deriving the shape from L, which
            # would be a tautology.
            assert cu_seqlens_k.dim() == 1, (
                f"cu_seqlens_k must be 1D, got {cu_seqlens_k.dim()}D"
            )
            L = cu_seqlens_k.shape[0] - 1
            K = total_K
            dc_shape = (L, M, N)
            dc_ndim = 3
        else:
            # Normal (dense) case - all tensors must be 3D
            GemmWrapperBase.validate_tensor(A, "A", 3)
            GemmWrapperBase.validate_tensor(B, "B", 3)
            L, M, K = A.shape
            _, N, K_B = B.shape
            assert K == K_B, f"K dimension mismatch: A has {K}, B has {K_B}"
            GemmWrapperBase.validate_shape(B, (L, N, K), "B")
            dc_shape = (L, M, N)
            dc_ndim = 3

        # Validate D and C shapes uniformly
        for tensor, name in [(D, "D"), (C, "C")]:
            if tensor is not None:
                assert tensor.dim() == dc_ndim, (
                    f"{name} must be {dc_ndim}D for this mode, got {tensor.dim()}D"
                )
                assert tensor.shape == dc_shape, (
                    f"{name} shape {tensor.shape} doesn't match expected {dc_shape}"
                )

        tensors = {
            "A": GemmTensorInfo(A),
            "B": GemmTensorInfo(B),
            "D": GemmTensorInfo(D),
            "C": GemmTensorInfo(C),
        }

        if additional_tensors:
            for name, tensor in additional_tensors.items():
                if tensor is not None:
                    assert tensor.dim() == dc_ndim, (
                        f"{name} must be {dc_ndim}D for this mode, got {tensor.dim()}D"
                    )
                    assert tensor.shape == dc_shape, (
                        f"{name} shape {tensor.shape} doesn't match expected {dc_shape}"
                    )
                tensors[name] = GemmTensorInfo(tensor)

        return L, M, K, N, tensors

    @staticmethod
    def permute_tensors(
        tensors: Dict[str, GemmTensorInfo], varlen_m: bool = False, varlen_k: bool = False
    ) -> None:
        """Permute 3D tensors in place from (L, *, *) to (*, *, L).

        In varlen_m mode only B is 3D; in varlen_k mode only D/C are 3D; in
        the dense mode every 3D tensor is permuted.
        """
        # Determine which tensors need permutation
        if varlen_m:
            # Only B needs permutation (3D tensor)
            tensors_to_permute = ["B"]
        elif varlen_k:
            # Only D and C need permutation (3D tensors)
            tensors_to_permute = ["D", "C"]
        else:
            # All tensors need permutation
            tensors_to_permute = None

        # Apply permutation from (L, *, *) -> (*, *, L) for selected tensors
        for name, info in tensors.items():
            if info.tensor is not None and info.tensor.ndim == 3:
                if tensors_to_permute is None or name in tensors_to_permute:
                    info.tensor = info.tensor.permute(1, 2, 0)

    @staticmethod
    def extract_dtypes(tensors: Dict[str, GemmTensorInfo]) -> None:
        """Fill in each info.dtype from the torch dtype (in place)."""
        for name, info in tensors.items():
            if info.tensor is not None:
                info.dtype = torch2cute_dtype_map[info.tensor.dtype]

    @staticmethod
    def determine_major_orders(
        tensors: Dict[str, GemmTensorInfo], major_configs: Dict[str, Tuple[str, str, str]]
    ) -> None:
        """Fill in each info.major from strides, per the tensor's dim names (in place)."""
        for name, dims in major_configs.items():
            if name in tensors and tensors[name].tensor is not None:
                tensors[name].major = GemmWrapperBase.get_major_order(tensors[name].tensor, dims)

    @staticmethod
    def create_cute_tensors(
        tensors: Dict[str, GemmTensorInfo], major_configs: Dict[str, Tuple[str, str, str]]
    ) -> None:
        """Fill in each info.cute_tensor via DLPack conversion (in place)."""
        for name, info in tensors.items():
            if info.tensor is not None and name in major_configs:
                info.cute_tensor = GemmWrapperBase.create_cute_tensor(
                    info.tensor, info.major, major_configs[name]
                )

    @staticmethod
    def create_scheduler_args(
        max_active_clusters: int,
        tile_count_semaphore: Optional[Tensor] = None,
        batch_idx_permute: Optional[Tensor] = None,
        max_swizzle_size: int = 8,
    ) -> TileSchedulerOptions:
        """Build TileSchedulerOptions, wrapping optional device buffers for the kernel.

        `tile_count_semaphore` (if given) is passed as a raw gmem Int32 pointer;
        `batch_idx_permute` (if given) as a 1D dynamic-layout cute tensor.
        """
        return TileSchedulerOptions(
            Int32(max_active_clusters),
            tile_count_semaphore=make_ptr(
                Int32, tile_count_semaphore.data_ptr(), cute.AddressSpace.gmem, assumed_align=4
            )
            if tile_count_semaphore is not None
            else None,
            batch_idx_permute=(
                from_dlpack(batch_idx_permute, assumed_align=4).mark_layout_dynamic(leading_dim=0)
            )
            if batch_idx_permute is not None
            else None,
            max_swizzle_size=Int32(max_swizzle_size),
        )

    @staticmethod
    def create_varlen_args(
        cu_seqlens_m: Optional[Tensor],
        cu_seqlens_k: Optional[Tensor],
        A_idx: Optional[Tensor],
        max_active_clusters: int,
        cluster_shape_mnk: Tuple[int, int, int],
        tensors: Dict[str, GemmTensorInfo],
        num_epi_tensormaps: int = 0,
        pingpong: bool = False,
    ) -> Optional[Any]:
        """Build VarlenArguments (or None for the dense mode).

        Allocates a per-CTA tensormap scratch buffer sized for the number of
        tensormaps the kernel will update, and wraps cu_seqlens / A_idx as
        dynamic-layout cute tensors.
        """
        if cu_seqlens_m is None and cu_seqlens_k is None:
            return None
        # When varlen_m, we assume persistent=True
        # Grid size depends on num_active_clusters and cluster size
        cluster_size = cluster_shape_mnk[0] * cluster_shape_mnk[1]
        num_blocks = max_active_clusters * cluster_size
        # Calculate number of tensormaps needed
        if cu_seqlens_m is not None:
            # For varlen_m: need tensormaps for D and epilogue tensors
            # (doubled with pingpong since two warp groups each keep their own).
            num_tensormaps = num_epi_tensormaps * (1 if not pingpong else 2)
            if tensors["D"].tensor is not None:
                num_tensormaps += 1 if not pingpong else 2  # D tensormap
        else:
            # For varlen_k: need tensormaps for A & B (only B when A is gathered)
            num_tensormaps = 2 if A_idx is None else 1
        # Create tensormap buffer (each tensormap is 128 bytes = 16 int64s)
        tensormap_size = 128 // 8  # 16 int64s
        if num_tensormaps > 0:
            device = cu_seqlens_m.device if cu_seqlens_m is not None else cu_seqlens_k.device
            tensormaps = torch.empty(
                (num_blocks, num_tensormaps, tensormap_size),
                dtype=torch.int64,
                device=device,
            )
            tensormaps_cute = from_dlpack(tensormaps, assumed_align=128).mark_compact_shape_dynamic(
                mode=0, stride_order=(0, 1, 2)
            )
        else:
            tensormaps_cute = None

        return VarlenArguments(
            mCuSeqlensM=(
                from_dlpack(cu_seqlens_m, assumed_align=4).mark_layout_dynamic(leading_dim=0)
                if cu_seqlens_m is not None
                else None
            ),
            mCuSeqlensK=(
                from_dlpack(cu_seqlens_k, assumed_align=4).mark_layout_dynamic(leading_dim=0)
                if cu_seqlens_k is not None
                else None
            ),
            mTensormaps=tensormaps_cute,
            mAIdx=(
                from_dlpack(A_idx, assumed_align=4).mark_layout_dynamic(leading_dim=0)
                if A_idx is not None
                else None
            ),
        )

    @staticmethod
    def get_compile_key(
        tensors: Dict[str, GemmTensorInfo],
        activation: Optional[str],
        tile_shape_mn: Tuple[int, int],
        cluster_shape_mnk: Tuple[int, int, int],
        pingpong: bool,
        persistent: bool,
        has_semaphore: bool,
        *args,
        key_tensor_names: Tuple[str, ...] = ("A", "B", "D", "C"),
    ) -> Tuple:
        """Build a hashable key identifying one compiled-kernel specialization.

        The key covers operand dtypes and major orders, tiling/cluster config,
        scheduling flags, plus any extra `args` the caller wants to discriminate on.
        """
        key_parts = []
        for name in key_tensor_names:
            if name in tensors:
                key_parts.append(tensors[name].dtype)
        key_parts.append(activation)
        key_parts.extend([tile_shape_mn, cluster_shape_mnk])
        for name in key_tensor_names:
            if name in tensors:
                key_parts.append(tensors[name].major)
        key_parts.extend([pingpong, persistent, has_semaphore])
        key_parts.extend(args)
        return tuple(key_parts)
|
build/torch-cuda/quack/layout_utils.py
ADDED
|
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
import cutlass
|
| 5 |
+
import cutlass.cute as cute
|
| 6 |
+
|
| 7 |
+
from cutlass import Int32, const_expr
|
| 8 |
+
|
| 9 |
+
from .utils import prmt
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def transpose_view(a: cute.Tensor) -> cute.Tensor:
    """Return a view of `a` (on smem) with its first two dimensions swapped."""
    trailing = a.shape[2:]
    swapped_shape = (a.shape[1], a.shape[0], *trailing)
    swapped_order = (1, 0, *range(2, cute.rank(a)))
    swapped_layout = cute.make_ordered_layout(swapped_shape, order=swapped_order)
    return cute.composition(a, swapped_layout)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def select(a: cute.Tensor, mode: list[int]) -> cute.Tensor:
    """Return a view of `a` keeping only the layout modes listed in `mode`."""
    selected_layout = cute.select(a.layout, mode)
    return cute.make_tensor(a.iterator, selected_layout)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def expand(a: cute.Tensor, dim: int, size: Int32 | int) -> cute.Tensor:
    """Insert a broadcast (stride-0) dimension of extent `size` at position `dim`."""
    head_shape, tail_shape = a.shape[:dim], a.shape[dim:]
    head_stride, tail_stride = a.layout.stride[:dim], a.layout.stride[dim:]
    expanded_layout = cute.make_layout(
        (*head_shape, size, *tail_shape),
        stride=(*head_stride, 0, *tail_stride),
    )
    return cute.make_tensor(a.iterator, expanded_layout)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@cute.jit
def permute_gated_Cregs_b16(t: cute.Tensor) -> None:
    """Permute 16-bit C-register fragments in place across the 4 lanes of a quad.

    Uses one warp shuffle per register pair plus byte-permutes (prmt) to
    re-interleave the 16-bit halves. NOTE(review): the exact target element
    order is implied by the selector constants and the upper/lower maps below;
    presumably this pairs value/gate halves for gated activations — confirm
    against the epilogue that calls it.
    """
    assert t.element_type.width == 16
    assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b16 permutation"
    # View pairs of b16 elements as one u32 so shuffles/prmt move two at a time.
    t_u32 = cute.recast_tensor(t, Int32)

    quad_idx = cute.arch.lane_idx() % 4
    # Lanes 0 and 3 keep their (upper, lower) order; lanes 1 and 2 swap it,
    # which is also why they get byte selectors with the operands exchanged.
    lane_03 = quad_idx == 0 or quad_idx == 3
    selector_upper = Int32(0x5410) if lane_03 else Int32(0x1054)
    selector_lower = Int32(0x7632) if lane_03 else Int32(0x3276)
    # upper_map = [0, 3, 1, 2]
    # lower_map = [1, 2, 0, 3]
    # upper_idx = upper_map[quad_idx]
    # indexing isn't supported so we have to do arithmetic
    upper_idx = quad_idx // 2 if quad_idx % 2 == 0 else 3 - quad_idx // 2
    lower_idx = upper_idx ^ 1

    # shfl.sync mask_and_clamp encoding: segment mask in bits [12:8], clamp in [4:0].
    # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000
    width = 4
    mask = cute.arch.WARP_SIZE - width
    clamp = cute.arch.WARP_SIZE - 1
    mask_and_clamp = mask << 8 | clamp

    for i in cutlass.range(cute.size(t_u32.shape) // 2, unroll_full=True):
        upper, lower = t_u32[i * 2 + 0], t_u32[i * 2 + 1]
        # Pre-swap on lanes 1/2 so every lane sources the shuffle the same way.
        upper0 = upper if lane_03 else lower
        lower0 = lower if lane_03 else upper
        upper0 = cute.arch.shuffle_sync(upper0, offset=upper_idx, mask_and_clamp=mask_and_clamp)
        lower0 = cute.arch.shuffle_sync(lower0, offset=lower_idx, mask_and_clamp=mask_and_clamp)
        # Recombine the shuffled halves byte-wise into the permuted pair.
        t_u32[i * 2 + 0] = prmt(upper0, lower0, selector_upper)
        t_u32[i * 2 + 1] = prmt(upper0, lower0, selector_lower)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
@cute.jit
def permute_Cregs_b32_for_stsm(t: cute.Tensor) -> None:
    """Permute and shuffle within 4 threads to change the layout from
    T0 | T1 | T2 | T3
    a b | c d | e f | g h
    to
    T0 | T1 | T2 | T3 | T0 | T1 | T2 | T3
    a | b | c | d | e | f | g | h
    This is so that we can use STSM (instead of STS.64) to store C registers without bank conflict.
    """

    assert t.element_type.width == 32
    assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b32 permutation"

    quad_idx = cute.arch.lane_idx() % 4
    # left_map = [0, 2, 1, 3]
    # right_map = [2, 0, 3, 1]
    # indexing isn't supported so we have to do arithmetic
    left_idx = quad_idx // 2 if quad_idx % 2 == 0 else 2 + quad_idx // 2
    right_idx = left_idx ^ 0b10

    # shfl.sync mask_and_clamp encoding: segment mask in bits [12:8], clamp in [4:0].
    # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000
    width = 4
    mask = cute.arch.WARP_SIZE - width
    clamp = cute.arch.WARP_SIZE - 1
    mask_and_clamp = mask << 8 | clamp

    # Process 4 registers at a time: two shuffled pairs, then a final swap of
    # the middle two registers to finish the transposition.
    for i in cutlass.range(cute.size(t.shape) // 4, unroll_full=True):
        for r in cutlass.range(2, unroll_full=True):
            left, right = t[i * 4 + r * 2 + 0], t[i * 4 + r * 2 + 1]
            # a b | c d | e f | g h -> a b | c d | f e | h g
            left0 = left if quad_idx < 2 else right
            right0 = right if quad_idx < 2 else left
            # a b | c d | f e | h g -> a b | f d | c e | h g
            left0 = cute.arch.shuffle_sync(left0, offset=left_idx, mask_and_clamp=mask_and_clamp)
            # a b | f d | c e | h g -> a e | f b | c g | h d
            right0 = cute.arch.shuffle_sync(right0, offset=right_idx, mask_and_clamp=mask_and_clamp)
            # a e | f b | c g | h d -> a e | b f | c g | d h
            t[i * 4 + r * 2 + 0] = left0 if quad_idx % 2 == 0 else right0
            t[i * 4 + r * 2 + 1] = right0 if quad_idx % 2 == 0 else left0
        t[i * 4 + 1], t[i * 4 + 2] = t[i * 4 + 2], t[i * 4 + 1]
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
@cute.jit
def permute_Cregs_b32_for_ldsm(t: cute.Tensor) -> None:
    """Permute and shuffle within 4 threads to change the layout from
    T0 | T1 | T2 | T3 | T0 | T1 | T2 | T3
    a | b | c | d | e | f | g | h
    to
    T0 | T1 | T2 | T3
    a b | c d | e f | g h
    This is so that we can use LDSM (instead of LDS.64) to load C registers without bank conflict.
    """

    assert t.element_type.width == 32
    assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b32 permutation"

    quad_idx = cute.arch.lane_idx() % 4
    # left_map = [0, 2, 1, 3]
    # right_map = [1, 3, 0, 2]
    # indexing isn't supported so we have to do arithmetic
    left_idx = quad_idx // 2 if quad_idx % 2 == 0 else 2 + quad_idx // 2
    right_idx = left_idx ^ 0b01

    # shfl.sync mask_and_clamp encoding: segment mask in bits [12:8], clamp in [4:0].
    # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000
    width = 4
    mask = cute.arch.WARP_SIZE - width
    clamp = cute.arch.WARP_SIZE - 1
    mask_and_clamp = mask << 8 | clamp

    # This is just the inverse of permute_Cregs_b32_for_stsm
    for i in cutlass.range(cute.size(t.shape) // 4, unroll_full=True):
        t[i * 4 + 1], t[i * 4 + 2] = t[i * 4 + 2], t[i * 4 + 1]
        for r in cutlass.range(2, unroll_full=True):
            left, right = t[i * 4 + r * 2 + 0], t[i * 4 + r * 2 + 1]
            # a e | b f | c g | d h -> a e | f b | c g | h d
            left0 = left if quad_idx % 2 == 0 else right
            right0 = right if quad_idx % 2 == 0 else left
            # a e | f b | c g | h d -> a b | f d | c e | h g
            right0 = cute.arch.shuffle_sync(right0, offset=right_idx, mask_and_clamp=mask_and_clamp)
            # a b | f d | c e | h g -> a b | c d | f e | h g
            left0 = cute.arch.shuffle_sync(left0, offset=left_idx, mask_and_clamp=mask_and_clamp)
            # a b | c d | f e | h g -> a b | c d | e f | g h
            t[i * 4 + r * 2 + 0] = left0 if quad_idx < 2 else right0
            t[i * 4 + r * 2 + 1] = right0 if quad_idx < 2 else left0
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
@cute.jit
def concat_layout(*layouts: cute.Layout) -> cute.Layout:
    """Combine several layouts into one layout with one mode per input layout."""
    combined_shape = tuple(layout.shape for layout in layouts)
    combined_stride = tuple(layout.stride for layout in layouts)
    return cute.make_layout(combined_shape, stride=combined_stride)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def convert_layout_acc_mn(acc_layout: cute.Layout) -> cute.Layout:
    """
    For Sm80, convert ((2, 2), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, MMA_N), ...).
    For Sm90, convert ((2, 2, V), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, V, MMA_N), ...).
    """
    # Build the regrouping over a compact column-major layout of the same shape,
    # then compose it with the real layout so the original strides carry through.
    acc_layout_col_major = cute.make_layout(acc_layout.shape)
    acc_layout_mn = cute.make_layout(
        (
            (acc_layout_col_major.shape[0][1], acc_layout_col_major.shape[1]),  # MMA_M
            (
                acc_layout_col_major.shape[0][0],
                *acc_layout_col_major.shape[0][2:],  # includes V on Sm90; empty on Sm80
                acc_layout_col_major.shape[2],
            ),  # MMA_N
            *acc_layout_col_major.shape[3:],  # any trailing modes pass through unchanged
        ),
        stride=(
            (acc_layout_col_major.stride[0][1], acc_layout_col_major.stride[1]),  # MMA_M
            (
                acc_layout_col_major.stride[0][0],
                *acc_layout_col_major.stride[0][2:],
                acc_layout_col_major.stride[2],
            ),  # MMA_N
            *acc_layout_col_major.stride[3:],
        ),
    )
    return cute.composition(acc_layout, acc_layout_mn)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def make_acc_tensor_mn_view(acc: cute.Tensor) -> cute.Tensor:
    """View the accumulator tensor through its (M, N)-grouped layout."""
    mn_layout = convert_layout_acc_mn(acc.layout)
    return cute.make_tensor(acc.iterator, mn_layout)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def reshape_acc_to_mn(acc: cute.Tensor) -> cute.Tensor:
    """View the accumulator tensor through its (M, N)-grouped layout.

    Alias of `make_acc_tensor_mn_view`; delegates to it so the two names
    cannot drift apart (the bodies were previously duplicated).
    """
    return make_acc_tensor_mn_view(acc)
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
@cute.jit
def convert_layout_acc_frgA(acc_layout: cute.Layout) -> cute.Layout:
    # For back to back gemm, convert layout of acc0 to gemm 1 accept layout.
    # For Sm80, as the mma instruction shape is 16x8x16, we need to convert from (4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
    # For Sm90, FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N))
    # TODO: Sm90 FP8
    # Rank-3 first mode distinguishes the Sm90 accumulator shape from Sm80's flat mode.
    if const_expr(cute.rank(acc_layout.shape[0]) == 3):  # Sm90
        # Split the V mode (N / 8) into (2, N / 16) so pairs can move into the A-fragment.
        l = cute.logical_divide(
            acc_layout, ((None, None, 2), None, None)
        )  # ((2, 2, (2, N / 16)), MMA_M, MMA_N)
        # Regroup: inner (2, 2, 2) per-instruction fragment, (N / 16, MMA_N) outer.
        rA_mma_view = cute.make_layout(
            (
                (l.shape[0][0], l.shape[0][1], l.shape[0][2][0]),
                l.shape[1],
                (l.shape[0][2][1], l.shape[2]),
            ),
            stride=(
                (l.stride[0][0], l.stride[0][1], l.stride[0][2][0]),
                l.stride[1],
                (l.stride[0][2][1], l.stride[2]),
            ),
        )
    else:  # Sm80
        # (4, MMA_M, MMA_N) -> (4, MMA_M, (2, MMA_N / 2))
        l = cute.logical_divide(acc_layout, (None, None, 2))
        # Fold the split factor of 2 into the per-instruction fragment: ((4, 2), MMA_M, MMA_N / 2).
        rA_mma_view = cute.make_layout(
            (
                (l.shape[0], l.shape[2][0]),
                l.shape[1],
                l.shape[2][1],
            ),
            stride=(
                (l.stride[0], l.stride[2][0]),
                l.stride[1],
                l.stride[2][1],
            ),
        )
    return rA_mma_view
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def reshape_acc_to_frgA(acc: cute.Tensor) -> cute.Tensor:
    """View the accumulator with the fragment-A layout expected by a second GEMM."""
    frgA_layout = convert_layout_acc_frgA(acc.layout)
    return cute.make_tensor(acc.iterator, frgA_layout)
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def convert_layout_zero_stride(
    input: cute.Tensor | cute.Layout, ref_layout: cute.Layout
) -> cute.Layout:
    """Regroup `input` into ((nonzero-stride modes), (zero-stride modes)).

    Modes are classified by the strides of `ref_layout` (which must flatten to
    the same rank as `input`), not by `input` itself. Returns a tensor view when
    given a tensor, otherwise a layout.
    """
    layout = input.layout if const_expr(isinstance(input, cute.Tensor)) else input
    # Group the modes with non-zero stride in the ref_layout together,
    # and the modes with zero stride together
    layout_flat = cute.flatten(layout)
    ref_layout_flat = cute.flatten(ref_layout)
    nonzero_modes = [i for i in range(cute.rank(layout_flat)) if ref_layout_flat[i].stride != 0]
    zero_modes = [i for i in range(cute.rank(layout_flat)) if ref_layout_flat[i].stride == 0]
    # There's an edge case when all modes are zero stride
    new_shape = (
        tuple(layout_flat[i].shape for i in nonzero_modes) if len(nonzero_modes) > 0 else (1,),
        tuple(layout_flat[i].shape for i in zero_modes),
    )
    new_stride = (
        tuple(layout_flat[i].stride for i in nonzero_modes) if len(nonzero_modes) > 0 else (0,),
        tuple(layout_flat[i].stride for i in zero_modes),
    )
    out_layout = cute.make_layout(new_shape, stride=new_stride)
    if const_expr(isinstance(input, cute.Tensor)):
        return cute.make_tensor(input.iterator, out_layout)
    else:
        return out_layout
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
def mma_partition_C_vec(
    sVec: cute.Tensor, thr_mma: cute.core.ThrMma, expand_shape: int, is_colvec: bool
) -> cute.Tensor:
    """Partition a staged smem vector as a broadcast of the MMA C operand.

    `sVec` is (vec_len, stage) with contiguous vector elements. It is expanded
    with a stride-0 mode of extent `expand_shape` (as columns if `is_colvec`,
    as rows otherwise), partitioned via `thr_mma.partition_C`, viewed in (M, N)
    grouping, and the broadcast mode is sliced away so the result indexes as
    (vec, stage) per thread.
    """
    assert cute.rank(sVec) == 2
    assert sVec.stride[0] == 1  # vector dimension must be contiguous in smem
    stage = sVec.shape[1]
    # (vec, expand, stage) for a column vector; (expand, vec, stage) for a row vector.
    shape = (
        (sVec.shape[0], expand_shape, stage)
        if const_expr(is_colvec)
        else (expand_shape, sVec.shape[0], stage)
    )
    # Stride 0 along the expanded mode broadcasts the vector across it.
    stride = (1, 0, sVec.stride[1]) if const_expr(is_colvec) else (0, 1, sVec.stride[1])
    sVec_mma = cute.make_tensor(sVec.iterator, cute.make_layout(shape, stride=stride))
    tC_sVec = make_acc_tensor_mn_view(thr_mma.partition_C(sVec_mma))
    # Slice index 0 of the broadcast mode; all its entries alias the same data.
    return tC_sVec[None, 0, None] if const_expr(is_colvec) else tC_sVec[0, None, None]
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def mma_partition_A_vec(
    sVec: cute.Tensor, thr_mma: cute.core.ThrMma, expand_shape: int, is_colvec: bool
) -> cute.Tensor:
    """Partition a staged smem vector as a broadcast of the MMA A operand.

    `sVec` is (vec_len, stage) with contiguous vector elements. A stride-0 mode
    of extent `expand_shape` is inserted (columns if `is_colvec`, rows
    otherwise), the result is partitioned via `thr_mma.partition_A`, viewed in
    (M, N) grouping, and the broadcast mode is sliced away.
    """
    assert cute.rank(sVec) == 2
    assert sVec.stride[0] == 1
    vec_len = sVec.shape[0]
    num_stages = sVec.shape[1]
    stage_stride = sVec.stride[1]
    bcast_shape = (
        (vec_len, expand_shape, num_stages)
        if const_expr(is_colvec)
        else (expand_shape, vec_len, num_stages)
    )
    bcast_stride = (1, 0, stage_stride) if const_expr(is_colvec) else (0, 1, stage_stride)
    bcast_layout = cute.make_layout(bcast_shape, stride=bcast_stride)
    sVec_bcast = cute.make_tensor(sVec.iterator, bcast_layout)
    tA_sVec = make_acc_tensor_mn_view(thr_mma.partition_A(sVec_bcast))
    return tA_sVec[None, 0, None] if const_expr(is_colvec) else tA_sVec[0, None, None]
|
build/torch-cuda/quack/pipeline.py
ADDED
|
@@ -0,0 +1,324 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Tri Dao.
|
| 2 |
+
|
| 3 |
+
from typing import Optional
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
|
| 6 |
+
import cutlass.cute as cute
|
| 7 |
+
from cutlass import Boolean, Int32, const_expr
|
| 8 |
+
from cutlass.cutlass_dsl import if_generate, and_, dsl_user_op
|
| 9 |
+
from cutlass.pipeline import MbarrierArray, CooperativeGroup, PipelineOp, pipeline_init_wait
|
| 10 |
+
from cutlass.pipeline import PipelineAsync, PipelineTmaAsync, PipelineState, PipelineUserType
|
| 11 |
+
from cutlass.pipeline import PipelineTmaUmma
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class PipelineStateWAdvance(PipelineState):
    """PipelineState that can jump ahead by multiple iterations at once.

    The base class advances one stage at a time; ``advance_iters`` updates
    count, index, and phase for ``num_iterations`` steps in O(1).
    """

    @dsl_user_op
    def advance_iters(self, num_iterations: Int32, *, loc=None, ip=None):
        """Advance the pipeline state by ``num_iterations`` stages."""
        self._count += Int32(num_iterations)
        new_index = self._index + Int32(num_iterations)
        # How many times did we cross the stages boundary
        num_crossings = new_index // self.stages
        # The phase bit flips once per crossing, so only the parity of the
        # crossing count matters. Reduce mod 2: XOR-ing the raw count would set
        # high bits whenever num_crossings >= 2, and the phase must stay 0/1
        # for mbarrier parity waits.
        self._phase ^= num_crossings % 2
        self._index = new_index % self.stages

    # This can be overridden by derived classes
    def __new_from_mlir_values__(self, values):
        # values = [count, index, phase]; stages is a Python-level constant.
        return PipelineStateWAdvance(
            self.stages, Int32(values[0]), Int32(values[1]), Int32(values[2])
        )
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def make_pipeline_state(type: PipelineUserType, stages: int):
    """
    Creates a pipeline state. Producers are assumed to start with an empty buffer and have a flipped phase bit of 1.

    :param type: producer or consumer side of the pipeline
    :param stages: number of buffer stages
    :raises ValueError: if ``type`` is neither Producer nor Consumer
    """
    if type is PipelineUserType.Producer:
        # Phase starts at 1 so the first wait on the (initially empty)
        # buffers succeeds immediately.
        return PipelineStateWAdvance(
            stages,
            Int32(0),
            Int32(0),
            Int32(1),
        )
    elif type is PipelineUserType.Consumer:
        return PipelineStateWAdvance(
            stages,
            Int32(0),
            Int32(0),
            Int32(0),
        )
    else:
        # Raise instead of `assert False` so the check survives `python -O`.
        raise ValueError(
            "Error: invalid PipelineUserType specified for make_pipeline_state."
        )
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
@dataclass(frozen=True)
class PipelineTmaCpAsync(PipelineTmaAsync):
    """
    PipelineTmaCpAsync is used for CpAsync + TMA producers and AsyncThread consumers
    """

    @staticmethod
    def create(
        *,
        num_stages: int,
        producer_group: CooperativeGroup,
        consumer_group: CooperativeGroup,
        tx_count: int,
        barrier_storage: cute.Pointer = None,
        cta_layout_vmnk: Optional[cute.Layout] = None,
        tidx: Optional[Int32] = None,
    ):
        """
        This helper function computes any necessary attributes and returns an instance of PipelineTmaAsync.
        :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers
        :type barrier_storage: cute.Pointer
        :param num_stages: Number of buffer stages for this pipeline
        :type num_stages: Int32
        :param producer_group: CooperativeGroup for the producer agent
        :type producer_group: CooperativeGroup
        :param consumer_group: CooperativeGroup for the consumer agent
        :type consumer_group: CooperativeGroup
        :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage
        :type tx_count: int
        :param cta_layout_vmnk: Layout of the cluster shape
        :type cta_layout_vmnk: cute.Layout | None
        :param tidx: thread index to consumer async threads
        :type tidx: Int32 | None
        """
        if not isinstance(barrier_storage, cute.Pointer):
            raise ValueError(
                f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}"
            )

        producer_type = PipelineOp.TmaLoad
        consumer_type = PipelineOp.AsyncThread

        producer = (producer_type, producer_group)
        consumer = (consumer_type, consumer_group)

        # Full barriers (transaction barriers) live first in smem, the empty
        # barriers immediately after them (offset by num_stages).
        sync_object_full = PipelineAsync._make_sync_object(
            barrier_storage.align(min_align=8), num_stages, producer, tx_count
        )
        sync_object_empty = PipelineAsync._make_sync_object(
            barrier_storage.align(min_align=8) + num_stages, num_stages, consumer
        )
        if tidx is None:
            tidx, _, _ = cute.arch.thread_idx()
        if cta_layout_vmnk is None:
            # Degenerate 1x1x1x1 cluster when no cluster layout is given.
            cta_layout_vmnk = cute.make_layout((1, 1, 1, 1))
        (
            dst_rank,
            is_signalling_thread,
        ) = PipelineTmaAsync.init_empty_barrier_arrive_signal(cta_layout_vmnk, tidx)
        # NOTE(review): cta_layout_vmnk can no longer be None here (it was
        # defaulted above), so only the size==1 test is live in this condition.
        if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1:
            # Single-CTA "cluster": empty-barrier arrivals stay local.
            dst_rank = None
        else:
            dst_rank = dst_rank  # no-op; keeps the branch structure explicit

        producer_mask = None

        pipeline_init_wait(cta_layout_vmnk)

        return PipelineTmaCpAsync(
            sync_object_full,
            sync_object_empty,
            num_stages,
            producer_mask,
            dst_rank,
            is_signalling_thread,
        )

    @dsl_user_op
    def producer_acquire(
        self,
        state: PipelineState,
        try_acquire_token: Optional[Boolean] = None,
        is_tma_warp: Optional[Boolean] = True,
        *,
        loc=None,
        ip=None,
    ):
        """
        TMA producer commit conditionally waits on buffer empty and sets the transaction barrier.

        :param state: current producer pipeline state (stage index + phase)
        :param try_acquire_token: if nonzero, a prior try_acquire already succeeded and the wait is skipped
        :param is_tma_warp: only the TMA-issuing warp arrives on the full barrier
        """
        # Wait for the buffer to be empty unless a try-acquire already proved it.
        if_generate(
            try_acquire_token is None or try_acquire_token == 0,
            lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip),
        )
        # This is the difference between this and PipelineTmaAsync: we could have multiple
        # warps calling this, but only 1 warp should do the arrive on the full barrier
        if_generate(
            is_tma_warp,
            lambda: self.sync_object_full.arrive(state.index, self.producer_mask, loc=loc, ip=ip),
        )

    @dsl_user_op
    def producer_cpasync_commit(self, state: PipelineState, *, loc=None, ip=None):
        """
        We need the mbarrier to track the completion of cp.async

        Registers all outstanding cp.async operations of this thread with the
        stage's full barrier (noinc: does not bump the arrival count).
        """
        cute.arch.cp_async_mbarrier_arrive_noinc(self.producer_get_barrier(state, loc=loc, ip=ip), loc=loc, ip=ip)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
class MbarrierArrayWDropCount(MbarrierArray):
    """MbarrierArray whose expected arrival count can be reduced by ``drop_count``.

    Useful when some threads of the cooperative group are known not to
    participate (e.g. dropped producer warps), so the barrier should expect
    fewer arrivals than ``cg.size``.
    """

    @dsl_user_op
    def __init__(
        self,
        barrier_storage: cute.Pointer,
        num_stages: int,
        agent: tuple[PipelineOp, CooperativeGroup],
        tx_count: int = 0,
        drop_count: Optional[Int32] = None,
        *,
        loc=None,
        ip=None,
    ) -> None:
        # NOTE(review): deliberately re-implements MbarrierArray.__init__
        # (no super().__init__ call) so arrive_count can be adjusted before
        # the barriers are initialized — confirm against the base class.
        self.barrier_storage = barrier_storage
        self.tx_count = tx_count
        self.num_stages = num_stages
        self.op_type, self.cg = agent
        self.arrive_count = self.cg.size
        self.drop_count = drop_count

        if self.num_stages <= 0:
            raise ValueError("Error: Mbarrier stage count must be greater than 0.")
        if self.arrive_count <= 0:
            raise ValueError("Error: Mbarrier arrive count must be greater than 0.")
        if self.op_type is PipelineOp.TmaLoad and self.tx_count < 0:
            raise ValueError("Error: Mbarrier tx count must not be less than 0 for TMA ops.")

        # Shrink the expected arrival count by the number of dropped arrivals.
        if const_expr(drop_count is not None):
            self.arrive_count = self.arrive_count - drop_count

        # Store mbarrier base pointer
        self.mbarrier_base = self.barrier_storage

        # Mbarrier initialization in constructor
        self.mbarrier_init(loc=loc, ip=ip)

    def __extract_mlir_values__(self):
        # Dynamic values carried across the JIT boundary: the smem pointer
        # and the (possibly dynamic) drop count.
        return [self.barrier_storage, self.drop_count]

    def __new_from_mlir_values__(self, values):
        return MbarrierArrayWDropCount(
            values[0], self.num_stages, (self.op_type, self.cg), self.tx_count, values[1]
        )
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
@dataclass(frozen=True)
class PipelineTmaCpAsyncUmma(PipelineTmaUmma):
    """
    PipelineTmaCpAsyncUmma is used for CpAsync + TMA producers and UMMA consumers
    (e.g. Blackwell mainloops)
    """

    @staticmethod
    def create(
        *,
        num_stages: int,
        producer_group: CooperativeGroup,
        consumer_group: CooperativeGroup,
        tx_count: int,
        barrier_storage: cute.Pointer = None,
        cta_layout_vmnk: Optional[cute.Layout] = None,
        producer_drop_count: Optional[Int32] = None,
        mcast_mode_mn: tuple[int, int] = (1, 1),
    ):
        """
        This helper function computes any necessary attributes and returns an instance of PipelineTmaUmma.
        :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers
        :type barrier_storage: cute.Pointer
        :param num_stages: Number of buffer stages for this pipeline
        :type num_stages: Int32
        :param producer_group: `CooperativeGroup` for the producer agent
        :type producer_group: CooperativeGroup
        :param consumer_group: `CooperativeGroup` for the consumer agent
        :type consumer_group: CooperativeGroup
        :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage
        :type tx_count: int
        :param cta_layout_vmnk: Layout of the cluster shape
        :type cta_layout_vmnk: cute.Layout | None
        :param producer_drop_count: number of producer arrivals to drop from the full barrier
        :type producer_drop_count: Int32 | None
        :param mcast_mode_mn: Tuple specifying multicast modes for m and n dimensions (each 0 or 1)
        :type mcast_mode_mn: tuple[int, int], optional
        """
        if not isinstance(barrier_storage, cute.Pointer):
            raise ValueError(
                f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}"
            )

        producer_type = PipelineOp.TmaLoad
        consumer_type = PipelineOp.TCGen05Mma

        producer = (producer_type, producer_group)
        consumer = (consumer_type, consumer_group)

        # Full barriers use the drop-count variant so dropped producer warps
        # are not counted toward the expected arrivals.
        sync_object_full = MbarrierArrayWDropCount(
            barrier_storage.align(min_align=8),
            num_stages,
            producer,
            tx_count,
            drop_count=producer_drop_count,
        )
        # Empty barriers follow the full barriers in smem (offset num_stages).
        sync_object_empty = PipelineTmaUmma._make_sync_object(
            barrier_storage.align(min_align=8) + num_stages, num_stages, consumer
        )

        if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1:
            # No mcast mask if not using clusters
            producer_mask = None
            # All threadblocks are leaders if not using clusters
            is_leader_cta = True
        else:
            producer_mask = PipelineTmaUmma._compute_mcast_arrival_mask(cta_layout_vmnk, mcast_mode_mn)
            is_leader_cta = PipelineTmaUmma._compute_is_leader_cta(cta_layout_vmnk)

        # 2-CTA tcgen05 groups only when the V mode of the cluster spans 2 CTAs.
        cta_group = (
            cute.nvgpu.tcgen05.CtaGroup.ONE
            if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1
            else cute.nvgpu.tcgen05.CtaGroup.TWO
        )

        consumer_mask = producer_mask

        pipeline_init_wait(cta_layout_vmnk)

        return PipelineTmaCpAsyncUmma(
            sync_object_full,
            sync_object_empty,
            num_stages,
            producer_mask,
            consumer_mask,
            is_leader_cta,
            cta_group,
        )

    @dsl_user_op
    def producer_acquire(
        self,
        state: PipelineState,
        try_acquire_token: Optional[Boolean] = None,
        is_tma_warp: Optional[Boolean] = True,
        *,
        loc=None,
        ip=None,
    ):
        """
        TMA producer commit conditionally waits on buffer empty and sets the
        transaction barrier for leader threadblocks.

        :param state: current producer pipeline state (stage index + phase)
        :param try_acquire_token: if nonzero, a prior try_acquire already succeeded and the wait is skipped
        :param is_tma_warp: only the TMA-issuing warp of the leader CTA arrives on the full barrier
        """
        # Wait for the buffer to be empty unless a try-acquire already proved it.
        if_generate(
            try_acquire_token is None or try_acquire_token == 0,
            lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip),
        )
        # This is the difference between this and PipelineTmaAsync: we could have multiple
        # warps calling this, but only 1 warp should do the arrive on the full barrier
        if_generate(
            and_(self.is_leader_cta, is_tma_warp),
            lambda: self.sync_object_full.arrive(state.index, self.producer_mask, loc=loc, ip=ip),
        )

    @dsl_user_op
    def producer_cpasync_commit(self, state: PipelineState, *, loc=None, ip=None):
        """
        We need the mbarrier to track the completion of cp.async

        Registers all outstanding cp.async operations of this thread with the
        stage's full barrier (noinc: does not bump the arrival count).
        """
        cute.arch.cp_async_mbarrier_arrive_noinc(self.producer_get_barrier(state, loc=loc, ip=ip), loc=loc, ip=ip)
|
build/torch-cuda/quack/reduce.py
ADDED
|
@@ -0,0 +1,279 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Tri Dao.
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
import operator
|
| 5 |
+
from typing import Callable, Optional
|
| 6 |
+
|
| 7 |
+
import cutlass
|
| 8 |
+
import cutlass.cute as cute
|
| 9 |
+
from cutlass import Int32, Int64, Float32, Boolean, const_expr
|
| 10 |
+
|
| 11 |
+
from . import utils as utils
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@cute.jit
def block_reduce(
    val: cute.Numeric, op: Callable, reduction_buffer: cute.Tensor, init_val: cute.Numeric = 0.0
) -> cute.Numeric:
    """Reduce per-warp partial values across the warps of one row of a threadblock.

    reduction_buffer has shape (num_warps / warp_per_row, warps_per_row).
    Each warp's lane 0 writes its partial to the buffer; after a block barrier,
    the first warps_per_row lanes of every warp in the row re-read the partials
    and a final warp reduction combines them.

    :param val: per-warp partial value (already warp-reduced by the caller)
    :param op: binary associative reduction operator
    :param init_val: identity value for ``op`` (used by lanes with no partial)
    :return: the row-wide reduced value (in every lane of the warp)
    """
    lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
    warps_per_row = cute.size(reduction_buffer.shape[1])
    row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
    # One lane per warp publishes the partial to smem.
    if lane_idx == 0:
        reduction_buffer[row_idx, col_idx] = val
    cute.arch.barrier()
    block_reduce_val = init_val
    # Lanes [0, warps_per_row) each pick up one warp's partial.
    if lane_idx < warps_per_row:
        block_reduce_val = reduction_buffer[row_idx, lane_idx]
    return cute.arch.warp_reduction(block_reduce_val, op)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@cute.jit
def cluster_reduce(
    val: cute.Numeric,
    op: Callable,
    reduction_buffer: cute.Tensor,
    mbar_ptr: cute.Pointer,
    init_val: cute.Numeric = 0.0,
    phase: Optional[Int32] = None,
) -> cute.Numeric:
    """Reduce per-warp partial values across all CTAs of a cluster.

    reduction_buffer has shape (num_warps / warps_per_row, (warps_per_row, cluster_n)).
    Each warp multicasts its partial to the corresponding buffer slot of every
    CTA in the cluster; completion is tracked with a transaction mbarrier.

    :param val: per-warp partial value (already warp-reduced by the caller)
    :param op: binary associative reduction operator
    :param mbar_ptr: smem mbarrier used as the transaction barrier for the remote stores
    :param init_val: identity value for ``op``
    :param phase: mbarrier wait phase; defaults to 0 (first use)
    :return: the cluster-wide reduced value (in every lane of the warp)
    """
    cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
    lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
    rows_per_block, (warps_per_row, cluster_n) = reduction_buffer.shape
    row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
    # One thread arms the transaction barrier with the total byte count
    # expected from all warps of all CTAs in the cluster.
    if warp_idx == 0:
        with cute.arch.elect_one():
            num_warps = rows_per_block * warps_per_row
            cute.arch.mbarrier_arrive_and_expect_tx(
                mbar_ptr,
                num_warps * cluster_n * reduction_buffer.element_type.width // 8,
            )
    # Lane i of each warp stores this warp's partial into CTA i of the cluster,
    # so every CTA ends up holding all partials.
    if lane_idx < cluster_n:
        utils.store_shared_remote(
            val,
            utils.elem_pointer(reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster))),
            mbar_ptr,
            peer_cta_rank_in_cluster=lane_idx,
        )
    cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0)
    block_reduce_val = init_val
    # Strided sweep over all (warps_per_row * cluster_n) partials of this row.
    num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
    for i in cutlass.range_constexpr(num_iter):
        idx = lane_idx + i * cute.arch.WARP_SIZE
        if idx < cute.size(reduction_buffer, mode=[1]):
            block_reduce_val = op(block_reduce_val, reduction_buffer[row_idx, idx])
    return cute.arch.warp_reduction(block_reduce_val, op)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
@cute.jit
def block_or_cluster_reduce(
    val: cute.Numeric,
    op: Callable,
    reduction_buffer: cute.Tensor,
    mbar_ptr: Optional[cute.Pointer],
    phase: Optional[Int32] = None,
    init_val: cute.Numeric = 0.0,
) -> cute.Numeric:
    """Dispatch to a block-level or cluster-level reduction.

    A compile-time branch: when ``mbar_ptr`` is supplied the partials are
    combined across the whole cluster, otherwise only across the warps of
    this threadblock.
    """
    if const_expr(mbar_ptr is not None):
        return cluster_reduce(val, op, reduction_buffer, mbar_ptr, phase=phase, init_val=init_val)
    return block_reduce(val, op, reduction_buffer, init_val=init_val)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
@cute.jit
def row_reduce(
    x: cute.TensorSSA | cute.Numeric,
    op: cute.ReductionOp,
    threads_per_row: cutlass.Constexpr[int],
    reduction_buffer: Optional[cute.Tensor] = None,
    mbar_ptr: Optional[cute.Pointer] = None,
    phase: Optional[Int32] = None,
    init_val: cute.Numeric = 0.0,
    hook_fn: Optional[Callable] = None,
) -> cute.Numeric:
    """Reduce ``x`` along a logical row: thread-local, then warp, then optionally block/cluster.

    reduction_buffer must have shape (num_warps / warps_per_row, (warps_per_row, cluster_n)).

    :param x: per-thread fragment (TensorSSA, reduced locally first) or an already-scalar value
    :param op: which reduction (ADD / MAX / MIN / MUL)
    :param threads_per_row: number of threads cooperating on one row (compile-time)
    :param reduction_buffer: smem buffer for the cross-warp step; None skips it
    :param mbar_ptr: mbarrier for cluster-wide reduction (required when cluster_n > 1)
    :param phase: mbarrier wait phase passed through to the cluster reduction
    :param init_val: identity value for ``op``
    :param hook_fn: optional callable invoked between the warp and block steps
    """
    if const_expr(isinstance(x, cute.TensorSSA)):
        val = x.reduce(op, init_val=init_val, reduction_profile=0)
    else:
        val = x
    # Map the ReductionOp enum to a scalar combiner for warp shuffles.
    # NOTE(review): `x.dtype` is read even when x is a plain scalar — assumes
    # cute.Numeric scalars also expose `.dtype`; confirm.
    warp_op = {
        cute.ReductionOp.ADD: operator.add,
        cute.ReductionOp.MAX: cute.arch.fmax if const_expr(x.dtype == Float32) else max,
        cute.ReductionOp.MIN: min,
        cute.ReductionOp.MUL: operator.mul,
    }[op]
    val = cute.arch.warp_reduction(
        val,
        warp_op,
        threads_in_group=min(threads_per_row, cute.arch.WARP_SIZE),
    )
    if const_expr(hook_fn is not None):
        hook_fn()
    if const_expr(reduction_buffer is not None):
        warps_per_row, cluster_n = reduction_buffer.shape[1]
        assert cluster_n == 1 or mbar_ptr is not None, (
            "mbar_ptr must be provided for cluster reduction"
        )
        # Cross-warp / cross-CTA step only needed when a row spans more than one warp.
        if const_expr(warps_per_row > 1 or cluster_n > 1):
            val = block_or_cluster_reduce(
                val, warp_op, reduction_buffer, mbar_ptr, phase=phase, init_val=init_val
            )
    return val
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
@cute.jit
def online_softmax_reduce(
    x: cute.TensorSSA,
    threads_per_row: cutlass.Constexpr[int],
    reduction_buffer: Optional[cute.Tensor] = None,
    mbar_ptr: Optional[cute.Pointer] = None,
    hook_fn: Optional[Callable] = None,
    phase: Optional[Int32] = None,
    return_exp_x: bool = False,
) -> [Float32, Float32, Optional[cute.TensorSSA]]:
    """Online softmax statistics for a row: running max and sum of exponentials.

    Computes max(x) and sum(exp(x - max(x))) across the row using warp shuffles,
    then optionally across warps (smem path) or across the whole cluster
    (mbarrier + remote-store path). Partial (max, sum) pairs are packed into a
    single Int64 per warp; partial sums are rescaled by exp(local_max - global_max)
    when merged, the standard online-softmax correction.

    reduction_buffer must have shape (num_warps / warps_per_row, (warps_per_row, cluster_n))
    with element type Int64.

    :param return_exp_x: if True, also return exp(x - max) rescaled to the final max
    :return: (max_x, sum_exp_x, exp_x or None)
    """
    assert x.dtype == Float32, "x must be of type Float32"
    # Thread-local max, then warp-level max over the row's lanes.
    max_x = cute.arch.warp_reduction(
        x.reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0),
        cute.arch.fmax,
        threads_in_group=min(threads_per_row, cute.arch.WARP_SIZE),
    )
    # exp(x - max) via exp2 for speed: exp(v) == exp2(v * log2(e)).
    log2_e = math.log2(math.e)
    exp_x = cute.math.exp2(x * log2_e - (max_x * log2_e), fastmath=True)
    sum_exp_x = cute.arch.warp_reduction(
        exp_x.reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0),
        operator.add,
        threads_in_group=min(threads_per_row, cute.arch.WARP_SIZE),
    )
    if const_expr(hook_fn is not None):
        hook_fn()
    if const_expr(reduction_buffer is not None):
        rows_per_block, (warps_per_row, cluster_n) = reduction_buffer.shape
        assert cluster_n == 1 or mbar_ptr is not None, (
            "mbar_ptr must be provided for cluster reduction"
        )
        if const_expr(warps_per_row > 1 or cluster_n > 1):
            # (max, sum) pairs are packed as two f32s in one i64 slot.
            assert reduction_buffer.element_type == Int64, (
                "reduction_buffer must be of type cute.Int64"
            )
            lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
            row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
            if const_expr(mbar_ptr is None):
                # Block-only path: publish packed partials to smem.
                if lane_idx == 0:
                    reduction_buffer[row_idx, col_idx] = utils.f32x2_to_i64(max_x, sum_exp_x)
                cute.arch.barrier()
                max_x_single_warp = -Float32.inf
                sum_exp_x = 0.0
                # Lanes [0, warps_per_row) each unpack one warp's partial.
                if lane_idx < warps_per_row:
                    max_x_single_warp, sum_exp_x = utils.i64_to_f32x2(
                        reduction_buffer[row_idx, lane_idx]
                    )
                max_x_final = cute.arch.warp_reduction(max_x_single_warp, cute.arch.fmax)
                # Online-softmax correction: rescale each partial sum to the global max.
                sum_exp_x *= cute.math.exp(max_x_single_warp - max_x_final, fastmath=True)
                sum_exp_x = cute.arch.warp_reduction(sum_exp_x, operator.add)
                if const_expr(return_exp_x):
                    exp_x *= cute.math.exp(max_x - max_x_final, fastmath=True)
                max_x = max_x_final
            else:
                # Cluster path: multicast packed partials to every CTA, tracked
                # by a transaction mbarrier (same scheme as cluster_reduce).
                cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
                if warp_idx == 0:
                    with cute.arch.elect_one():
                        num_warps = rows_per_block * warps_per_row
                        cute.arch.mbarrier_arrive_and_expect_tx(
                            mbar_ptr,
                            num_warps * cluster_n * reduction_buffer.element_type.width // 8,
                        )
                if lane_idx < cluster_n:
                    utils.store_shared_remote(
                        utils.f32x2_to_i64(max_x, sum_exp_x),
                        utils.elem_pointer(
                            reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster))
                        ),
                        mbar_ptr,
                        peer_cta_rank_in_cluster=lane_idx,
                    )
                cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0)
                # Strided sweep over all partials; keep them in fragments so
                # each can be rescaled by its own local max before summing.
                num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
                max_x_single_warp = cute.make_fragment(num_iter, Float32)
                max_x_single_warp.fill(-Float32.inf)
                sum_exp_x_single_warp = cute.make_fragment(num_iter, Float32)
                sum_exp_x_single_warp.fill(0.0)
                for i in cutlass.range_constexpr(num_iter):
                    idx = lane_idx + i * cute.arch.WARP_SIZE
                    if idx < cute.size(reduction_buffer, mode=[1]):
                        max_x_single_warp[i], sum_exp_x_single_warp[i] = utils.i64_to_f32x2(
                            reduction_buffer[row_idx, idx]
                        )
                max_x_final = max_x_single_warp.load().reduce(
                    cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0
                )
                max_x_final = cute.arch.warp_reduction(max_x_final, cute.arch.fmax)
                sum_exp_x = 0.0
                for i in cutlass.range_constexpr(num_iter):
                    # Rescale each partial sum to the global max before accumulating.
                    sum_exp_x += sum_exp_x_single_warp[i] * cute.math.exp(
                        max_x_single_warp[i] - max_x_final, fastmath=True
                    )
                sum_exp_x = cute.arch.warp_reduction(sum_exp_x, operator.add)
                if const_expr(return_exp_x):
                    exp_x *= cute.math.exp(max_x - max_x_final, fastmath=True)
                max_x = max_x_final
    return max_x, sum_exp_x, (exp_x if const_expr(return_exp_x) else None)
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
@cute.jit
def sum_swap_shuffle(
    X: cute.Tensor, elem_per_lane: int = 1, subwarp_size: int = 1, warp_size: int = 32
) -> cute.Tensor:
    """
    For warp reduction, we use Swap Shuffle
    The normal way to reduction among threads:
    use shuffle to let *** the first half of threads *** have *** whole data *** from the second half of threads.
    After each step of reduction, a half of threads won't work in the following steps.
    That is, as the reduction progresses, the efficiency of shuffle & reduction instructions gradually change from 1/2, 1/4 to 1/32 (the worst case).
    To overcome this shortcoming, for a NxN matrix to be reduced among N threads as a 1XN vectors,
    we use swap & shuffle aiming to let *** each half of threads *** have *** a half of data *** from the other half of threads.
    After reduction, each half of threads should deal with a (N/2)x(N/2) sub-matrix independently in the following step.
    We can recursively do this until the problem size is 1.
    """
    # Both sizes must be powers of two with subwarp_size | warp_size.
    assert (
        subwarp_size >= 1
        and subwarp_size <= 32
        and subwarp_size == 1 << int(math.log2(subwarp_size))
    )
    assert (
        warp_size <= 32
        and warp_size % subwarp_size == 0
        and warp_size == 1 << int(math.log2(warp_size))
    )
    # Logical lane id within the warp at subwarp granularity.
    lane_idx = cute.arch.lane_idx() // subwarp_size
    X = cute.logical_divide(X, cute.make_layout(elem_per_lane))  # (elem_per_lane, M)
    numvec = cute.size(X, mode=[1])
    assert numvec <= 32 // subwarp_size
    # If X has more values than warp_size // subwarp_size, we first do a normal warp reduction
    # to sum up values held by lanes further than size(X) away
    for i in cutlass.range(
        int(math.log2(numvec)), int(math.log2(warp_size // subwarp_size)), unroll_full=True
    ):
        for v in cutlass.range(cute.size(X), unroll_full=True):
            shfl_val = cute.arch.shuffle_sync_bfly(X[v], offset=(1 << i) * subwarp_size)
            X[v] = X[v] + shfl_val
    # Recursive halving with swap-shuffle: every step keeps all lanes busy.
    for logm in cutlass.range_constexpr(int(math.log2(cute.size(X, mode=[1]))) - 1, -1, -1):
        m = 1 << logm  # half-size of the current sub-problem
        for r in cutlass.range(m, unroll_full=True):
            frg_A = X[None, r]
            frg_B = X[None, r + m]
            # First half of threads swap fragments from the first half of data to the second
            should_swap = not Boolean(lane_idx & m)
            for v in cutlass.range(cute.size(frg_A), unroll_full=True):
                # Step 1: swap
                lower, upper = frg_A[v], frg_B[v]
                frg_A[v] = upper if should_swap else lower
                frg_B[v] = lower if should_swap else upper
                # Step 2: shuffle
                # each half of threads get a half of data from the other half of threads
                shfl_val = cute.arch.shuffle_sync_bfly(frg_A[v], offset=m * subwarp_size)
                # Step 3: reduction
                frg_A[v] = frg_B[v] + shfl_val
    return X[None, 0]
|
build/torch-cuda/quack/reduction_base.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
|
| 2 |
+
|
| 3 |
+
from typing import Type, Tuple, Optional
|
| 4 |
+
|
| 5 |
+
import cutlass
|
| 6 |
+
import cutlass.cute as cute
|
| 7 |
+
from cutlass import Int32, Int64, Float32, const_expr
|
| 8 |
+
|
| 9 |
+
from . import copy_utils as copy_utils
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class ReductionBase:
    """Shared scaffolding for row-reduction kernels.

    Subclasses decide how many threads cooperate on one row via
    ``_threads_per_row`` and may enable thread-block clusters by overriding
    ``_set_cluster_n`` (the default is a single CTA, ``cluster_n == 1``).
    """

    def __init__(self, dtype: Type[cutlass.Numeric], N: int, stage: int, reduction_dtype=Float32):
        # Element type of the input tensor being reduced.
        self.dtype = dtype
        # Row (reduction) length in elements.
        self.N = N
        # Number of buffering stages for the shared-memory reduction buffer.
        self.stage = stage
        # Accumulator type used in the reduction buffer (defaults to fp32).
        self.reduction_dtype = reduction_dtype

    def _threads_per_row(self):
        # Subclasses must choose how many threads cooperate on a single row.
        raise NotImplementedError()

    def _num_threads(self):
        # Larger CTA for long rows; 16384 is the heuristic cutover used here.
        return 128 if self.N <= 16384 else 256

    def _set_cluster_n(self):
        # Default: no thread-block cluster.
        self.cluster_n = 1

    def _get_tiled_copy(self, vecsize: int = 1):
        """Build the 2D tiled copy used to load input rows.

        Returns ``(tiled_copy, tiler_mn, threads_per_row)`` where ``tiler_mn``
        is the (rows, cols) tile covered per copy iteration.
        """
        assert self.N % vecsize == 0, f"Input N {self.N} is not divisible by vector size {vecsize}"
        threads_per_row = self._threads_per_row()
        num_threads = self._num_threads()
        assert num_threads % cute.arch.WARP_SIZE == 0
        # Column blocks needed so the row threads (across the cluster) cover N.
        num_blocks_N = cute.ceil_div(self.N // vecsize, threads_per_row * self.cluster_n)
        tiler_mn = (num_threads // threads_per_row, vecsize * num_blocks_N * threads_per_row)
        tiled_copy = copy_utils.tiled_copy_2d(self.dtype, threads_per_row, num_threads, vecsize)
        return tiled_copy, tiler_mn, threads_per_row

    def _get_reduction_buffer_layout(self, tv_layout: cute.Layout, cluster_n: int):
        """Shared-memory layout for cross-warp (and cross-CTA) partial results."""
        num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
        # When the thread mode is hierarchical, only its first sub-mode spans a row.
        warps_per_row = (
            num_warps
            if cute.rank(tv_layout.shape[0]) == 1
            else max(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1)
        )
        # Shape: (row groups, (warps per row, cluster CTAs), stages);
        # order=(1, 0, 2) makes the per-row warp entries contiguous.
        return cute.make_ordered_layout(
            (num_warps // warps_per_row, (warps_per_row, cluster_n), self.stage),
            order=(1, 0, 2),
        )

    def _allocate_reduction_buffer_and_mbar(
        self, smem: cutlass.utils.SmemAllocator, tv_layout: cute.Layout, is_persistent: bool = False
    ) -> Tuple[cute.Tensor, Optional[cute.Pointer]]:
        """Allocate the smem reduction buffer and, for clusters, the mbarriers."""
        reduction_buffer = smem.allocate_tensor(
            self.reduction_dtype,
            self._get_reduction_buffer_layout(tv_layout, self.cluster_n),
            byte_alignment=8,
        )
        if const_expr(self.cluster_n > 1):
            # Persistent kernels need a second set of barriers (empty) per stage.
            mbar_ptr = smem.allocate_array(
                Int64, num_elems=self.stage if not is_persistent else self.stage * 2
            )
        else:
            mbar_ptr = None
        return reduction_buffer, mbar_ptr

    @cute.jit
    def _initialize_cluster(
        self,
        tidx: Int32,
        mbar_ptr: cute.Pointer,
        num_warps: int,
        is_persistent: bool = False,
    ):
        """Initialize cluster mbarriers; compiles to a no-op when cluster_n == 1."""
        if const_expr(self.cluster_n > 1):
            if tidx < self.stage:  # Initialize full barrier
                # One arrival expected on the "full" barrier of each stage.
                cute.arch.mbarrier_init(mbar_ptr + tidx, 1)
                if const_expr(is_persistent):  # Initialize empty barrier
                    # "Empty" barrier waits for every warp in every cluster CTA.
                    cute.arch.mbarrier_init(
                        mbar_ptr + self.stage + tidx, num_warps * self.cluster_n
                    )
            cute.arch.mbarrier_init_fence()
            # Cluster arrive after barrier init
            cute.arch.cluster_arrive_relaxed()
|
build/torch-cuda/quack/sm100_utils.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Tri Dao.
|
| 2 |
+
|
| 3 |
+
from typing import Type, Union
|
| 4 |
+
|
| 5 |
+
import cutlass.cute as cute
|
| 6 |
+
import cutlass.utils.blackwell_helpers as sm100_utils_og
|
| 7 |
+
from cutlass.cute.nvgpu.tcgen05 import OperandMajorMode
|
| 8 |
+
from cutlass.cutlass_dsl import Numeric, dsl_user_op
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dsl_user_op
def make_smem_layout_cpasync_a(
    tiled_mma: cute.TiledMma,
    mma_tiler_mnk: cute.Tile,
    a_dtype: Type[Numeric],
    num_stages: int,
    *,
    loc=None,
    ip=None,
) -> Union[cute.Layout, cute.ComposedLayout]:
    """
    Build the staged SMEM layout for operand A loaded via cp.async.

    :param tiled_mma: The tiled MMA used to partition tensor A
    :type tiled_mma: cute.TiledMma
    :param mma_tiler_mnk: The MMA tile shape
    :type mma_tiler_mnk: cute.cute.Tile
    :param a_dtype: The element type for tensor A
    :type a_dtype: Type[Numeric]
    :param num_stages: The number of pipeline stages for tensor A
    :type num_stages: int

    :return: SMEM layout for tensor A
    :rtype: Union[cute.Layout, cute.ComposedLayout]
    """

    is_k_major = tiled_mma.op.a_major_mode == OperandMajorMode.K
    # Keep only the (M, K) modes of the tiler when partitioning A.
    a_smem_shape = tiled_mma.partition_shape_A(
        cute.dice(mma_tiler_mnk, (1, None, 1), loc=loc, ip=ip)
    )
    # Collapse the partitioned shape back to a flat (MN, K) pair by merging
    # the atom modes into the rest modes.
    a_smem_shape_mn_k = (
        cute.size(a_smem_shape[0][0], loc=loc, ip=ip) * a_smem_shape[1],
        cute.size(a_smem_shape[0][1], loc=loc, ip=ip) * a_smem_shape[2],
    )
    # Pick a swizzle atom appropriate for A's major mode, dtype, and tile size.
    a_smem_layout_atom = sm100_utils_og.make_smem_layout_atom(
        sm100_utils_og.get_smem_layout_atom_ab(
            tiled_mma.op.a_major_mode,
            a_dtype,
            a_smem_shape_mn_k,
            loc=loc,
            ip=ip,
        ),
        a_dtype,
        loc=loc,
        ip=ip,
    )
    # Tile the atom over (MN, K, stages); MN-major layouts iterate MN first.
    a_smem_layout_staged = cute.tile_to_shape(
        a_smem_layout_atom,
        cute.append(a_smem_shape_mn_k, num_stages, loc=loc, ip=ip),
        order=((1, 0, 2) if not is_k_major else (0, 1, 2)),
        loc=loc,
        ip=ip,
    )
    return a_smem_layout_staged
|
build/torch-cuda/quack/sm90_utils.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Tri Dao.
|
| 2 |
+
|
| 3 |
+
from typing import Type, Union, Optional
|
| 4 |
+
|
| 5 |
+
import cutlass
|
| 6 |
+
import cutlass.cute as cute
|
| 7 |
+
import cutlass.utils.hopper_helpers as sm90_utils_og
|
| 8 |
+
from cutlass.cute.nvgpu import warpgroup
|
| 9 |
+
from cutlass.cutlass_dsl import Numeric, dsl_user_op
|
| 10 |
+
from cutlass import Float32, Int32, Boolean, const_expr
|
| 11 |
+
from cutlass.utils import LayoutEnum
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dsl_user_op
def make_smem_layout(
    dtype: Type[Numeric],
    layout: LayoutEnum,
    tile: cute.Tile,
    stage: Optional[int] = None,
    *,
    loc=None,
    ip=None,
) -> Union[cute.Layout, cute.ComposedLayout]:
    """Build a swizzled SMEM layout for a tile on SM90.

    When ``stage`` is None the result is 2D; otherwise a third stage mode is
    appended and the atom is tiled over all three modes.
    """
    shape = cute.product_each(cute.shape(tile, loc=loc, ip=ip), loc=loc, ip=ip)
    # The swizzle atom is chosen from the size of the contiguous (major) mode.
    major_mode_size = shape[1] if layout.is_n_major_c() else shape[0]
    smem_layout_atom = warpgroup.make_smem_layout_atom(
        sm90_utils_og.get_smem_layout_atom(layout, dtype, major_mode_size),
        dtype,
    )
    # M-major tiles iterate M first when tiling the atom to the full shape.
    order = (1, 0, 2) if const_expr(layout.is_m_major_c()) else (0, 1, 2)
    smem_layout_staged = cute.tile_to_shape(
        smem_layout_atom,
        cute.append(shape, stage) if const_expr(stage is not None) else shape,
        order=order if const_expr(stage is not None) else order[:2],
    )
    return smem_layout_staged


# For compatibility with blackwell_helpers.py
make_smem_layout_epi = make_smem_layout
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@dsl_user_op
def partition_for_epilogue(
    cT: cute.Tensor,
    epi_tile: cute.Tile,
    tiled_copy: cute.TiledCopy,
    tidx: Int32,
    reference_src: bool,  # do register tensors reference the src or dst layout of the tiled copy
    *,
    loc=None,
    ip=None,
) -> cute.Tensor:
    """Partition ``cT`` into per-thread epilogue subtiles.

    Divides ``cT`` by ``epi_tile`` and slices it for thread ``tidx`` of
    ``tiled_copy``, referencing either the copy's source (``reference_src``)
    or destination side.
    """
    thr_copy = tiled_copy.get_slice(tidx)
    cT_epi = cute.flat_divide(cT, epi_tile)
    # (CPY, CPY_M, CPY_N, EPI_M, EPI_N)
    if const_expr(reference_src):
        return thr_copy.partition_S(cT_epi, loc=loc, ip=ip)
    else:
        return thr_copy.partition_D(cT_epi, loc=loc, ip=ip)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
@cute.jit
def gemm(
    tiled_mma: cute.TiledMma,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    zero_init: cutlass.Constexpr[bool] = False,
    wg_wait: cutlass.Constexpr[int] = 0,
    # A_in_regs: cutlass.Constexpr[bool] = False,
    swap_AB: cutlass.Constexpr[bool] = False,
) -> None:
    """Issue warpgroup MMAs over the K mode of the fragments into ``acc``.

    ``zero_init`` clears the accumulator on the first k-block only;
    ``wg_wait < 0`` skips the wait so callers can overlap the MMA with other
    work. ``swap_AB`` exchanges the operand roles.
    """
    if const_expr(swap_AB):
        gemm(tiled_mma, acc, tCrB, tCrA, zero_init=zero_init, wg_wait=wg_wait, swap_AB=False)
    else:
        warpgroup.fence()
        # We make a new mma_atom since we'll be modifying its attribute (accumulate).
        # Otherwise the compiler complains "operand #0 does not dominate this use"
        mma_atom = cute.make_mma_atom(tiled_mma.op)
        mma_atom.set(warpgroup.Field.ACCUMULATE, not zero_init)
        for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])):
            cute.gemm(mma_atom, acc, tCrA[None, None, k], tCrB[None, None, k], acc)
            # After the first k-block the accumulator must always be kept.
            mma_atom.set(warpgroup.Field.ACCUMULATE, True)
        warpgroup.commit_group()
        if const_expr(wg_wait >= 0):
            warpgroup.wait_group(wg_wait)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def gemm_zero_init(
    tiled_mma: cute.TiledMma,
    shape: cute.Shape,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    A_idx: Optional[Int32] = None,
    B_idx: Optional[Int32] = None,
    wg_wait: int = -1,
    swap_AB: bool = False,
) -> cute.Tensor:
    """Run ``gemm`` into a fresh zero-initialized fp32 accumulator and return it.

    ``A_idx``/``B_idx`` optionally select one slice along the last mode of the
    register fragments (e.g. a pipeline stage). ``swap_AB`` computes the
    transposed problem by swapping operands and reversing ``shape``.
    """
    if const_expr(swap_AB):
        # Swap operands (and their stage indices) and transpose the output shape.
        return gemm_zero_init(
            tiled_mma, shape[::-1], tCrB, tCrA, B_idx, A_idx, wg_wait, swap_AB=False
        )
    else:
        acc = cute.make_fragment(tiled_mma.partition_shape_C(shape), Float32)
        rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx]
        rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx]
        gemm(tiled_mma, acc, rA, rB, zero_init=True, wg_wait=wg_wait)
        return acc
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def gemm_w_idx(
    tiled_mma: cute.TiledMma,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    zero_init: Boolean,
    A_idx: Optional[Int32] = None,
    B_idx: Optional[Int32] = None,
    wg_wait: int = -1,
    swap_AB: bool = False,
) -> None:
    """Like ``gemm`` but with a runtime ``zero_init`` flag and optional
    stage indices into the A/B fragments.

    ``A_idx``/``B_idx`` select one slice along the last mode of the register
    fragments; ``swap_AB`` exchanges operand roles (and their indices).
    """
    if const_expr(swap_AB):
        gemm_w_idx(tiled_mma, acc, tCrB, tCrA, zero_init, B_idx, A_idx, wg_wait, swap_AB=False)
    else:
        rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx]
        rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx]
        gemm(tiled_mma, acc, rA, rB, zero_init=zero_init, wg_wait=wg_wait)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def partition_fragment_ABC(
    thr_mma: cute.ThrMma,
    shape_mnk: cute.Shape,
    sA: Optional[cute.Tensor],
    sB: Optional[cute.Tensor],
    swap_AB: bool = False,
):
    """Create the accumulator and A/B register fragments for a thread-slice MMA.

    If the MMA sources its first operand from registers (RS MMA), that
    fragment is created by shape only; otherwise it is partitioned from the
    corresponding SMEM tensor. With ``swap_AB`` the roles of A and B are
    exchanged and the accumulator takes the transposed (N, M) shape.
    """
    # RS MMA: operand A is fed from registers rather than SMEM.
    is_rs = thr_mma.op.a_src == warpgroup.OperandSource.RMEM
    if const_expr(not swap_AB):
        acc = cute.make_fragment(thr_mma.partition_shape_C(shape_mnk[:2]), Float32)
        if const_expr(not is_rs):
            assert sA is not None
            tCrA = thr_mma.make_fragment_A(thr_mma.partition_A(sA))
        else:
            # A comes from registers: build the fragment from its (M, K) shape.
            tCrA = thr_mma.make_fragment_A(thr_mma.partition_shape_A((shape_mnk[0], shape_mnk[2])))
        assert sB is not None
        tCrB = thr_mma.make_fragment_B(thr_mma.partition_B(sB))
    else:
        # Swapped: B plays the role of operand A, so the accumulator is (N, M).
        acc = cute.make_fragment(thr_mma.partition_shape_C((shape_mnk[1], shape_mnk[0])), Float32)
        if const_expr(not is_rs):
            assert sB is not None
            tCrB = thr_mma.make_fragment_A(thr_mma.partition_A(sB))
        else:  # B in rmem
            tCrB = thr_mma.make_fragment_A(thr_mma.partition_shape_A((shape_mnk[1], shape_mnk[2])))
        assert sA is not None
        tCrA = thr_mma.make_fragment_B(thr_mma.partition_B(sA))
    return acc, tCrA, tCrB
|
build/torch-cuda/quack/sort/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
build/torch-cuda/quack/sort/bitonic_sort.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Tri Dao.
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
import cutlass
|
| 7 |
+
import cutlass.cute as cute
|
| 8 |
+
from cutlass import Int32, Float32, const_expr
|
| 9 |
+
|
| 10 |
+
from .. import utils
|
| 11 |
+
from .utils import compare_and_swap
|
| 12 |
+
from .sorting_networks import optimal_sort
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@cute.jit
def bitonic_merge(
    arr: cute.Tensor,
    n: Optional[cutlass.Constexpr[int]] = None,
    start: cutlass.Constexpr[int] = 0,
    ascending: cutlass.Constexpr[bool] = True,
) -> None:
    """Merge a bitonic sequence into a sorted sequence using iterative approach."""
    if const_expr(n is None):
        n = cute.size(arr.shape)
    if const_expr(n > 1):
        num_levels = int(math.log2(n))
        assert n == 2**num_levels, "n must be a power of 2"
        # This one must be range_constexpr otherwise it's very slow for n = 128
        for level in cutlass.range_constexpr(num_levels):
            # Compare elements `step` apart within each chunk of size `length`.
            length = n >> level  # n // (2^level)
            step = length // 2
            for i in cutlass.range(n // length, unroll_full=True):
                start_i = start + i * length
                for j in cutlass.range(step, unroll_full=True):
                    compare_and_swap(arr, start_i + j, start_i + j + step, ascending)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@cute.jit
def bitonic_sort(
    arr: cute.Tensor,
    n: Optional[cutlass.Constexpr[int]] = None,
    start: cutlass.Constexpr[int] = 0,
    ascending: cutlass.Constexpr[bool] = True,
) -> None:
    """
    Bitonic sort for small arrays of size N (power of 2, N <= 128).

    Args:
        arr: Array to sort
        n: Size of array (must be power of 2 and <= 128)
        start: Starting index (default 0)
        ascending: Sort in ascending order (default True)
    """
    if const_expr(n is None):
        n = cute.size(arr.shape)
    assert n <= 128
    if const_expr(n > 1):
        if const_expr(n in [2, 4, 8, 16, 32, 64]):
            # Use a size-specialized optimal sorting network when available.
            optimal_sort(arr, n, start, ascending)
        else:  # Fall back to bitonic sort
            assert n % 2 == 0
            # Sort first half in ascending order
            bitonic_sort(arr, n // 2, start, True)
            # Sort second half in descending order
            bitonic_sort(arr, n // 2, start + n // 2, False)
            # Merge the whole sequence
            bitonic_merge(arr, n, start, ascending)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
@cute.jit
def bitonic_topk_merge(
    arr0: cute.Tensor,
    arr1: cute.Tensor,
    k: Optional[cutlass.Constexpr[int]] = None,
    start0: cutlass.Constexpr[int] = 0,
    start1: cutlass.Constexpr[int] = 0,
    ascending: cutlass.Constexpr[bool] = False,
) -> None:
    """Merge two sorted top-k sequences; the merged top-k ends up in ``arr0``."""
    if const_expr(k is None):
        k = cute.size(arr0.shape)
    if const_expr(arr0.element_type == Float32):
        # Hardware min/max for floats (utils.fmin / cute.arch.fmax).
        minmax_fn = utils.fmin if ascending else cute.arch.fmax
    else:
        minmax_fn = min if ascending else max
    # Write the top k elements to the first half of the array
    # Pairing arr0[i] with arr1[k-1-i] yields a bitonic sequence in arr0.
    for i in cutlass.range(k, unroll_full=True):
        arr0[start0 + i] = minmax_fn(arr0[start0 + i], arr1[start1 + k - 1 - i])
    # Now the 1st half is bitonic, we just need to merge it
    bitonic_merge(arr0, k, start0, ascending)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
@cute.jit
def bitonic_topk(
    arr: cute.Tensor,
    k: cutlass.Constexpr[int],
    ascending: cutlass.Constexpr[bool] = False,
    warp_width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
) -> cute.Tensor:
    """
    Bitonic top-k for small arrays of size N (power of 2, N <= 128).

    Args:
        arr: Array to sort
        k: must be power of 2 and <= 128
        ascending: Sort in ascending order (default False)

    Returns:
        A k-element register fragment holding the top-k values after the
        cross-lane merge rounds.
    """
    assert arr.element_type in [Float32, Int32]
    n = cute.size(arr.shape)
    assert k == 1 << int(math.log2(k)), "k must be a power of 2"
    assert n % k == 0, "n must be divisible by k"
    # Seed with the first k elements of the local array, sorted.
    topk_vals = cute.make_fragment(k, arr.element_type)
    for v in cutlass.range(k, unroll_full=True):
        topk_vals[v] = arr[v]
    bitonic_sort(topk_vals, ascending=ascending)
    # Fold each remaining k-sized chunk of the local array into the running top-k.
    for i in cutlass.range(1, n // k, unroll_full=True):
        other_vals = cute.make_fragment(k, arr.element_type)
        for v in cutlass.range(k, unroll_full=True):
            other_vals[v] = arr[i * k + v]
        bitonic_sort(other_vals, ascending=ascending)
        # Merge 2 sorted top-k sequences to get a new top-k sequence
        bitonic_topk_merge(topk_vals, other_vals, ascending=ascending)
    # Butterfly shuffle + merge across lanes over log2(warp_width) rounds.
    # TODO: this is not efficient for large k (e.g. >= 16) since threads in the same warps
    # do duplicate work.
    for i in cutlass.range(int(math.log2(warp_width)), unroll_full=True):
        other_vals = cute.make_fragment(k, arr.element_type)
        for v in cutlass.range(k, unroll_full=True):
            other_vals[v] = cute.arch.shuffle_sync_bfly(topk_vals[v], offset=1 << i)
        bitonic_topk_merge(topk_vals, other_vals, ascending=ascending)
    return topk_vals
|
build/torch-cuda/quack/sort/generate_sorting_networks.py
ADDED
|
@@ -0,0 +1,326 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Generate optimized sorting network code from the optimal sorting network data.
|
| 4 |
+
Based on data from: https://bertdobbelaere.github.io/sorting_networks.html
|
| 5 |
+
|
| 6 |
+
This script generates CUTE DSL functions for optimal sorting networks of various sizes.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import argparse
|
| 10 |
+
import os
|
| 11 |
+
import re
|
| 12 |
+
from typing import List, Tuple, Dict
|
| 13 |
+
|
| 14 |
+
# Network strings from bertdobbelaere.github.io/sorting_networks.html
|
| 15 |
+
# Copy-paste network strings here, then run initialize_networks() to parse them
|
| 16 |
+
NETWORK_STRINGS = {
|
| 17 |
+
# Size 2: 1 CE, depth 1
|
| 18 |
+
2: """
|
| 19 |
+
[(0,1)]
|
| 20 |
+
""",
|
| 21 |
+
# Size 4: 5 CEs, depth 3
|
| 22 |
+
4: """
|
| 23 |
+
[(0,2),(1,3)]
|
| 24 |
+
[(0,1),(2,3)]
|
| 25 |
+
[(1,2)]
|
| 26 |
+
""",
|
| 27 |
+
# Size 8: 19 CEs, depth 6
|
| 28 |
+
8: """
|
| 29 |
+
[(0,2),(1,3),(4,6),(5,7)]
|
| 30 |
+
[(0,4),(1,5),(2,6),(3,7)]
|
| 31 |
+
[(0,1),(2,3),(4,5),(6,7)]
|
| 32 |
+
[(2,4),(3,5)]
|
| 33 |
+
[(1,4),(3,6)]
|
| 34 |
+
[(1,2),(3,4),(5,6)]
|
| 35 |
+
""",
|
| 36 |
+
# Size 16: 60 CEs, depth 10
|
| 37 |
+
16: """
|
| 38 |
+
[(0,13),(1,12),(2,15),(3,14),(4,8),(5,6),(7,11),(9,10)]
|
| 39 |
+
[(0,5),(1,7),(2,9),(3,4),(6,13),(8,14),(10,15),(11,12)]
|
| 40 |
+
[(0,1),(2,3),(4,5),(6,8),(7,9),(10,11),(12,13),(14,15)]
|
| 41 |
+
[(0,2),(1,3),(4,10),(5,11),(6,7),(8,9),(12,14),(13,15)]
|
| 42 |
+
[(1,2),(3,12),(4,6),(5,7),(8,10),(9,11),(13,14)]
|
| 43 |
+
[(1,4),(2,6),(5,8),(7,10),(9,13),(11,14)]
|
| 44 |
+
[(2,4),(3,6),(9,12),(11,13)]
|
| 45 |
+
[(3,5),(6,8),(7,9),(10,12)]
|
| 46 |
+
[(3,4),(5,6),(7,8),(9,10),(11,12)]
|
| 47 |
+
[(6,7),(8,9)]
|
| 48 |
+
""",
|
| 49 |
+
# Size 32: 185 CEs, depth 14
|
| 50 |
+
32: """
|
| 51 |
+
[(0,1),(2,3),(4,5),(6,7),(8,9),(10,11),(12,13),(14,15),(16,17),(18,19),(20,21),(22,23),(24,25),(26,27),(28,29),(30,31)]
|
| 52 |
+
[(0,2),(1,3),(4,6),(5,7),(8,10),(9,11),(12,14),(13,15),(16,18),(17,19),(20,22),(21,23),(24,26),(25,27),(28,30),(29,31)]
|
| 53 |
+
[(0,4),(1,5),(2,6),(3,7),(8,12),(9,13),(10,14),(11,15),(16,20),(17,21),(18,22),(19,23),(24,28),(25,29),(26,30),(27,31)]
|
| 54 |
+
[(0,8),(1,9),(2,10),(3,11),(4,12),(5,13),(6,14),(7,15),(16,24),(17,25),(18,26),(19,27),(20,28),(21,29),(22,30),(23,31)]
|
| 55 |
+
[(0,16),(1,8),(2,4),(3,12),(5,10),(6,9),(7,14),(11,13),(15,31),(17,24),(18,20),(19,28),(21,26),(22,25),(23,30),(27,29)]
|
| 56 |
+
[(1,2),(3,5),(4,8),(6,22),(7,11),(9,25),(10,12),(13,14),(17,18),(19,21),(20,24),(23,27),(26,28),(29,30)]
|
| 57 |
+
[(1,17),(2,18),(3,19),(4,20),(5,10),(7,23),(8,24),(11,27),(12,28),(13,29),(14,30),(21,26)]
|
| 58 |
+
[(3,17),(4,16),(5,21),(6,18),(7,9),(8,20),(10,26),(11,23),(13,25),(14,28),(15,27),(22,24)]
|
| 59 |
+
[(1,4),(3,8),(5,16),(7,17),(9,21),(10,22),(11,19),(12,20),(14,24),(15,26),(23,28),(27,30)]
|
| 60 |
+
[(2,5),(7,8),(9,18),(11,17),(12,16),(13,22),(14,20),(15,19),(23,24),(26,29)]
|
| 61 |
+
[(2,4),(6,12),(9,16),(10,11),(13,17),(14,18),(15,22),(19,25),(20,21),(27,29)]
|
| 62 |
+
[(5,6),(8,12),(9,10),(11,13),(14,16),(15,17),(18,20),(19,23),(21,22),(25,26)]
|
| 63 |
+
[(3,5),(6,7),(8,9),(10,12),(11,14),(13,16),(15,18),(17,20),(19,21),(22,23),(24,25),(26,28)]
|
| 64 |
+
[(3,4),(5,6),(7,8),(9,10),(11,12),(13,14),(15,16),(17,18),(19,20),(21,22),(23,24),(25,26),(27,28)]
|
| 65 |
+
""",
|
| 66 |
+
# Size 64: 512 CEs, depth 21
|
| 67 |
+
64: """
|
| 68 |
+
[(0,2),(1,3),(4,6),(5,7),(8,10),(9,11),(12,14),(13,15),(16,18),(17,19),(20,22),(21,23),(24,26),(25,27),(28,30),(29,31),(32,34),(33,35),(36,38),(37,39),(40,42),(41,43),(44,46),(45,47),(48,50),(49,51),(52,54),(53,55),(56,58),(57,59),(60,62),(61,63)]
|
| 69 |
+
[(0,1),(2,3),(4,5),(6,7),(8,9),(10,11),(12,13),(14,15),(16,17),(18,19),(20,21),(22,23),(24,25),(26,27),(28,29),(30,31),(32,33),(34,35),(36,37),(38,39),(40,41),(42,43),(44,45),(46,47),(48,49),(50,51),(52,53),(54,55),(56,57),(58,59),(60,61),(62,63)]
|
| 70 |
+
[(0,52),(1,2),(3,55),(4,48),(5,6),(7,51),(8,60),(9,10),(11,63),(12,56),(13,14),(15,59),(16,32),(17,18),(19,35),(20,24),(21,22),(23,27),(25,26),(28,44),(29,30),(31,47),(33,34),(36,40),(37,38),(39,43),(41,42),(45,46),(49,50),(53,54),(57,58),(61,62)]
|
| 71 |
+
[(0,20),(1,53),(2,54),(3,23),(4,28),(5,49),(6,50),(7,31),(8,36),(9,61),(10,62),(11,39),(12,16),(13,57),(14,58),(15,19),(17,33),(18,34),(21,25),(22,26),(24,52),(27,55),(29,45),(30,46),(32,56),(35,59),(37,41),(38,42),(40,60),(43,63),(44,48),(47,51)]
|
| 72 |
+
[(0,4),(1,21),(2,22),(3,7),(5,29),(6,30),(8,12),(9,37),(10,38),(11,15),(13,17),(14,18),(16,20),(19,23),(24,32),(25,53),(26,54),(27,35),(28,36),(31,39),(33,57),(34,58),(40,44),(41,61),(42,62),(43,47),(45,49),(46,50),(48,52),(51,55),(56,60),(59,63)]
|
| 73 |
+
[(0,8),(1,5),(2,6),(3,11),(4,12),(7,15),(9,13),(10,14),(16,40),(17,21),(18,22),(19,43),(20,44),(23,47),(24,28),(25,33),(26,34),(27,31),(29,37),(30,38),(32,36),(35,39),(41,45),(42,46),(48,56),(49,53),(50,54),(51,59),(52,60),(55,63),(57,61),(58,62)]
|
| 74 |
+
[(1,9),(2,10),(4,8),(5,13),(6,14),(7,11),(12,48),(15,51),(16,24),(17,41),(18,42),(19,27),(20,28),(21,45),(22,46),(23,31),(25,29),(26,30),(32,40),(33,37),(34,38),(35,43),(36,44),(39,47),(49,57),(50,58),(52,56),(53,61),(54,62),(55,59)]
|
| 75 |
+
[(4,16),(5,9),(6,10),(7,19),(8,24),(11,27),(13,49),(14,50),(17,25),(18,26),(20,32),(21,29),(22,30),(23,35),(28,40),(31,43),(33,41),(34,42),(36,52),(37,45),(38,46),(39,55),(44,56),(47,59),(53,57),(54,58)]
|
| 76 |
+
[(1,4),(5,17),(6,18),(8,16),(9,25),(10,26),(11,19),(12,24),(15,27),(21,33),(22,34),(29,41),(30,42),(36,48),(37,53),(38,54),(39,51),(44,52),(45,57),(46,58),(47,55),(59,62)]
|
| 77 |
+
[(2,8),(9,17),(10,18),(12,20),(13,25),(14,26),(15,23),(24,32),(27,35),(28,36),(31,39),(37,49),(38,50),(40,48),(43,51),(45,53),(46,54),(55,61)]
|
| 78 |
+
[(2,4),(12,16),(13,21),(14,22),(15,19),(20,24),(23,27),(25,33),(26,34),(28,32),(29,37),(30,38),(31,35),(36,40),(39,43),(41,49),(42,50),(44,48),(47,51),(59,61)]
|
| 79 |
+
[(4,16),(5,20),(10,40),(13,17),(14,18),(21,25),(22,26),(23,53),(24,28),(27,31),(29,33),(30,34),(32,36),(35,39),(37,41),(38,42),(43,58),(45,49),(46,50),(47,59)]
|
| 80 |
+
[(3,17),(6,36),(7,21),(8,32),(9,24),(11,41),(13,28),(14,44),(15,45),(18,48),(19,49),(22,52),(25,29),(26,30),(27,57),(31,55),(33,37),(34,38),(35,50),(39,54),(42,56),(46,60)]
|
| 81 |
+
[(6,20),(8,16),(10,24),(11,25),(14,28),(15,29),(17,33),(18,32),(21,37),(22,36),(26,42),(27,41),(30,46),(31,45),(34,48),(35,49),(38,52),(39,53),(43,57),(47,55)]
|
| 82 |
+
[(3,18),(5,8),(6,12),(7,22),(15,21),(17,32),(19,33),(23,37),(26,40),(30,44),(31,46),(41,56),(42,48),(45,60),(51,57),(55,58)]
|
| 83 |
+
[(3,16),(7,20),(11,26),(18,24),(19,25),(22,28),(23,29),(27,33),(30,36),(34,40),(35,41),(37,52),(38,44),(39,45),(43,56),(47,60)]
|
| 84 |
+
[(3,9),(7,13),(10,16),(11,17),(14,20),(15,30),(19,34),(21,36),(23,38),(25,40),(26,32),(27,42),(29,44),(31,37),(33,48),(43,49),(46,52),(47,53),(50,56),(54,60)]
|
| 85 |
+
[(3,8),(7,10),(9,12),(11,18),(13,14),(15,24),(17,22),(19,28),(21,26),(23,25),(27,34),(29,36),(30,32),(31,33),(35,44),(37,42),(38,40),(39,48),(41,46),(45,52),(49,50),(51,54),(53,56),(55,60)]
|
| 86 |
+
[(3,6),(7,12),(11,16),(15,17),(18,20),(19,24),(21,22),(23,30),(25,32),(26,28),(27,29),(31,38),(33,40),(34,36),(35,37),(39,44),(41,42),(43,45),(46,48),(47,52),(51,56),(57,60)]
|
| 87 |
+
[(3,5),(6,8),(7,9),(10,12),(11,13),(14,16),(15,18),(17,20),(19,21),(22,24),(23,26),(25,28),(27,30),(29,32),(31,34),(33,36),(35,38),(37,40),(39,41),(42,44),(43,46),(45,48),(47,49),(50,52),(51,53),(54,56),(55,57),(58,60)]
|
| 88 |
+
[(3,4),(7,8),(11,12),(13,14),(15,16),(17,18),(19,20),(21,22),(23,24),(25,26),(27,28),(29,30),(31,32),(33,34),(35,36),(37,38),(39,40),(41,42),(43,44),(45,46),(47,48),(49,50),(51,52),(55,56),(59,60)]
|
| 89 |
+
""",
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
# This will be populated by initialize_networks()
# Maps network size -> (depth, number of compare-exchanges, layers of (i, j) pairs).
OPTIMAL_NETWORKS: Dict[int, Tuple[int, int, List[List[Tuple[int, int]]]]] = {}
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def parse_network_string(network_str: str) -> List[List[Tuple[int, int]]]:
    """
    Parse a sorting network string from bertdobbelaere.github.io format.

    Examples:
        Input: "[(0,2),(1,3)], [(0,1),(2,3)], [(1,2)]"
        Output: [[(0, 2), (1, 3)], [(0, 1), (2, 3)], [(1, 2)]]

        Input: "[(0,1)], [(1,2)], [(0,1)]"
        Output: [[(0, 1)], [(1, 2)], [(0, 1)]]
    """
    text = network_str.strip()
    if not text:
        return []

    # Each bracketed group is one parallel layer; an empty "[]" stays as [].
    layers: List[List[Tuple[int, int]]] = []
    for layer_match in re.finditer(r"\[((?:\(\d+,\d+\)(?:,\(\d+,\d+\))*)?)\]", text):
        body = layer_match.group(1)
        if not body.strip():
            layers.append([])
            continue
        # Extract every "(i,j)" compare-exchange pair inside this layer.
        layers.append([(int(a), int(b)) for a, b in re.findall(r"\((\d+),(\d+)\)", body)])
    return layers
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def calculate_network_stats(layers: List[List[Tuple[int, int]]]) -> Tuple[int, int, int]:
    """Calculate depth, total comparisons, and max index from network layers."""
    depth = len(layers)
    total_comparisons = sum(len(layer) for layer in layers)
    # Network size is one past the largest wire index touched (indices are 0-based).
    highest_wire = max((idx for layer in layers for pair in layer for idx in pair), default=0)
    return depth, total_comparisons, highest_wire + 1
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def add_network_from_string(size: int, network_str: str, description: str = ""):
    """
    Add a network from a string representation to the OPTIMAL_NETWORKS dictionary.

    Args:
        size: Size of the network (number of elements)
        network_str: Network string in bertdobbelaere.github.io format
        description: Optional description for debugging

    Returns:
        True if the network was parsed and registered, False on any failure.
    """
    try:
        parsed = parse_network_string(network_str)
        depth, num_ces, detected_size = calculate_network_stats(parsed)

        # Reject networks whose largest wire index disagrees with the declared size.
        if detected_size != size:
            print(f"Warning: Network size mismatch! Expected {size}, detected {detected_size}")
            print(f"Network string: {network_str[:100]}...")
            return False

        OPTIMAL_NETWORKS[size] = (depth, num_ces, parsed)

        if description:
            print(f"Added network for size {size}: {description}")
            print(f"  Depth: {depth}, Comparisons: {num_ces}")
        return True

    except Exception as e:
        print(f"Error parsing network for size {size}: {e}")
        print(f"Network string: {network_str[:100]}...")
        return False
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def generate_networks_dict(
    networks_data: Dict[int, Tuple[int, int, List[List[Tuple[int, int]]]]]
) -> str:
    """Render the module-level ``networks = {...}`` dictionary as Python source.

    Each entry is preceded by a comment giving the comparator count (CEs) and
    depth.  Single-layer networks are emitted on one line; larger networks are
    emitted one layer per line with 8-space indentation.
    """
    out = ["networks = {"]

    for size, (depth, num_comparisons, layers) in sorted(networks_data.items()):
        # Layers joined with ",\n" reproduce the first-layer/rest-layers
        # formatting of the original loop.
        body = ",\n".join(f"        {layer}" for layer in layers)

        if len(layers) == 1:
            rendered = f"[{body.strip()}]"
        else:
            rendered = "[\n" + body + "\n    ]"

        out.append(f"    # Size {size}: {num_comparisons} CEs, depth {depth}")
        out.append(f"    {size}: {rendered},")
        out.append("")

    out.append("}")
    return "\n".join(out)
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def generate_optimal_sort_function() -> str:
    """Return the source text of the single ``optimal_sort`` dispatcher.

    The emitted function looks its network up in the generated ``networks``
    dict at trace time, so one function serves every supported size.
    """
    source = """@cute.jit
def optimal_sort(
    arr: cute.Tensor,
    n: cutlass.Constexpr[int],
    start: cutlass.Constexpr[int] = 0,
    ascending: cutlass.Constexpr[bool] = True
) -> None:
    \"\"\"
    Optimal sorting network dispatcher.

    Args:
        arr: Array to sort
        n: Size of array (must be power of 2 and available in networks)
        start: Starting index (default 0)
        ascending: Sort in ascending order (default True)

    Source: https://bertdobbelaere.github.io/sorting_networks.html
    \"\"\"
    assert n in networks
    for level in networks[n]:
        for i, j in level:
            compare_and_swap(arr, start + i, start + j, ascending)
"""
    return source
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def generate_sorting_networks_file(max_size: int = 64):
    """Generate a complete sorting networks file with optimal networks up to max_size.

    Writes ``sorting_networks.py`` next to this script and returns the list of
    network sizes that were emitted.

    Args:
        max_size: Largest network size to include in the generated file.

    Returns:
        Sorted list of sizes actually emitted (those present in OPTIMAL_NETWORKS).
    """
    output_file = os.path.join(os.path.dirname(__file__), "sorting_networks.py")

    # Header of the generated module.  The outer literal is single-quoted
    # because the header itself contains a triple-double-quoted docstring.
    header = '''# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Tri Dao.
"""
Optimal sorting networks generated from: https://bertdobbelaere.github.io/sorting_networks.html

This file was auto-generated by quack/sort/generate_sorting_networks.py. Do not edit it directly.
"""

# fmt: off
# ruff: noqa
# isort: skip_file

import cutlass
import cutlass.cute as cute

from .utils import compare_and_swap


'''

    sizes = [n for n in range(2, max_size + 1) if n in OPTIMAL_NETWORKS]
    # BUG FIX: previously the entire OPTIMAL_NETWORKS dict was emitted
    # regardless of max_size, so --max-size had no effect on the output.
    # Restrict the generated dictionary to the selected sizes.
    selected = {n: OPTIMAL_NETWORKS[n] for n in sizes}
    networks_dict = generate_networks_dict(selected)
    optimal_sort_func = generate_optimal_sort_function()

    # Combine everything
    content = header + networks_dict + "\n\n\n" + optimal_sort_func

    with open(output_file, "w") as f:
        f.write(content)

    print(f"Generated optimal sorting networks for sizes {sizes}")
    print(f"Output written to: {output_file}")
    return sizes
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def initialize_networks():
    """Populate OPTIMAL_NETWORKS by parsing every entry of NETWORK_STRINGS."""
    # Rebuild from scratch so repeated calls cannot accumulate stale entries.
    OPTIMAL_NETWORKS.clear()

    for size, net_str in NETWORK_STRINGS.items():
        registered = add_network_from_string(size, net_str, f"Size {size} optimal network")
        if not registered:
            print(f"Warning: Failed to parse network for size {size}")
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def main():
    """CLI entry point: parse args, build OPTIMAL_NETWORKS, emit sorting_networks.py."""
    parser = argparse.ArgumentParser(
        description="Generate optimal sorting network code from bertdobbelaere.github.io data"
    )
    parser.add_argument(
        "--max-size",
        "-m",
        type=int,
        default=64,
        # BUG FIX: the help text hard-coded "default: 32" while the actual
        # default is 64; %(default)s keeps it in sync automatically.
        help="Maximum sorting network size to generate (default: %(default)s)",
    )
    parser.add_argument(
        "--stats", "-s", action="store_true", help="Print statistics about the optimal networks"
    )

    args = parser.parse_args()

    # Initialize networks from strings
    initialize_networks()

    if args.stats:
        print("Optimal Sorting Network Statistics:")
        print("Size\tDepth\tComparisons\tLayers")
        print("-" * 35)
        for n in sorted(OPTIMAL_NETWORKS.keys()):
            if n <= args.max_size:
                depth, comparisons, layers = OPTIMAL_NETWORKS[n]
                print(f"{n}\t{depth}\t{comparisons}\t\t{len(layers)}")

    # Generate the sorting networks file
    sizes = generate_sorting_networks_file(args.max_size)

    print(f"\nGenerated optimal sorting networks for {len(sizes)} sizes")
    print(f"Total networks: {len(sizes)}")
    # BUG FIX: max() raises on an empty list when --max-size excludes all networks.
    if sizes:
        print(f"Max network size: {max(sizes)}")
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
# Script entry point: regenerate sorting_networks.py from the embedded network data.
if __name__ == "__main__":
    main()
|
build/torch-cuda/quack/sort/sorting_networks.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Tri Dao.
|
| 2 |
+
"""
|
| 3 |
+
Optimal sorting networks generated from: https://bertdobbelaere.github.io/sorting_networks.html
|
| 4 |
+
|
| 5 |
+
This file was auto-generated by quack/sort/generate_sorting_networks.py. Do not edit it directly.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
# fmt: off
|
| 9 |
+
# ruff: noqa
|
| 10 |
+
# isort: skip_file
|
| 11 |
+
|
| 12 |
+
import cutlass
|
| 13 |
+
import cutlass.cute as cute
|
| 14 |
+
|
| 15 |
+
from .utils import compare_and_swap
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# networks[n] is a list of layers for an n-input sorting network; each layer is
# a list of (i, j) comparator pairs whose compare-and-swap operations are
# mutually independent within the layer.  "CEs" = comparison elements
# (total comparators); "depth" = number of layers.
networks = {
    # Size 2: 1 CEs, depth 1
    2: [[(0, 1)]],

    # Size 4: 5 CEs, depth 3
    4: [
        [(0, 2), (1, 3)],
        [(0, 1), (2, 3)],
        [(1, 2)]
    ],

    # Size 8: 19 CEs, depth 6
    8: [
        [(0, 2), (1, 3), (4, 6), (5, 7)],
        [(0, 4), (1, 5), (2, 6), (3, 7)],
        [(0, 1), (2, 3), (4, 5), (6, 7)],
        [(2, 4), (3, 5)],
        [(1, 4), (3, 6)],
        [(1, 2), (3, 4), (5, 6)]
    ],

    # Size 16: 60 CEs, depth 10
    16: [
        [(0, 13), (1, 12), (2, 15), (3, 14), (4, 8), (5, 6), (7, 11), (9, 10)],
        [(0, 5), (1, 7), (2, 9), (3, 4), (6, 13), (8, 14), (10, 15), (11, 12)],
        [(0, 1), (2, 3), (4, 5), (6, 8), (7, 9), (10, 11), (12, 13), (14, 15)],
        [(0, 2), (1, 3), (4, 10), (5, 11), (6, 7), (8, 9), (12, 14), (13, 15)],
        [(1, 2), (3, 12), (4, 6), (5, 7), (8, 10), (9, 11), (13, 14)],
        [(1, 4), (2, 6), (5, 8), (7, 10), (9, 13), (11, 14)],
        [(2, 4), (3, 6), (9, 12), (11, 13)],
        [(3, 5), (6, 8), (7, 9), (10, 12)],
        [(3, 4), (5, 6), (7, 8), (9, 10), (11, 12)],
        [(6, 7), (8, 9)]
    ],

    # Size 32: 185 CEs, depth 14
    32: [
        [(0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11), (12, 13), (14, 15), (16, 17), (18, 19), (20, 21), (22, 23), (24, 25), (26, 27), (28, 29), (30, 31)],
        [(0, 2), (1, 3), (4, 6), (5, 7), (8, 10), (9, 11), (12, 14), (13, 15), (16, 18), (17, 19), (20, 22), (21, 23), (24, 26), (25, 27), (28, 30), (29, 31)],
        [(0, 4), (1, 5), (2, 6), (3, 7), (8, 12), (9, 13), (10, 14), (11, 15), (16, 20), (17, 21), (18, 22), (19, 23), (24, 28), (25, 29), (26, 30), (27, 31)],
        [(0, 8), (1, 9), (2, 10), (3, 11), (4, 12), (5, 13), (6, 14), (7, 15), (16, 24), (17, 25), (18, 26), (19, 27), (20, 28), (21, 29), (22, 30), (23, 31)],
        [(0, 16), (1, 8), (2, 4), (3, 12), (5, 10), (6, 9), (7, 14), (11, 13), (15, 31), (17, 24), (18, 20), (19, 28), (21, 26), (22, 25), (23, 30), (27, 29)],
        [(1, 2), (3, 5), (4, 8), (6, 22), (7, 11), (9, 25), (10, 12), (13, 14), (17, 18), (19, 21), (20, 24), (23, 27), (26, 28), (29, 30)],
        [(1, 17), (2, 18), (3, 19), (4, 20), (5, 10), (7, 23), (8, 24), (11, 27), (12, 28), (13, 29), (14, 30), (21, 26)],
        [(3, 17), (4, 16), (5, 21), (6, 18), (7, 9), (8, 20), (10, 26), (11, 23), (13, 25), (14, 28), (15, 27), (22, 24)],
        [(1, 4), (3, 8), (5, 16), (7, 17), (9, 21), (10, 22), (11, 19), (12, 20), (14, 24), (15, 26), (23, 28), (27, 30)],
        [(2, 5), (7, 8), (9, 18), (11, 17), (12, 16), (13, 22), (14, 20), (15, 19), (23, 24), (26, 29)],
        [(2, 4), (6, 12), (9, 16), (10, 11), (13, 17), (14, 18), (15, 22), (19, 25), (20, 21), (27, 29)],
        [(5, 6), (8, 12), (9, 10), (11, 13), (14, 16), (15, 17), (18, 20), (19, 23), (21, 22), (25, 26)],
        [(3, 5), (6, 7), (8, 9), (10, 12), (11, 14), (13, 16), (15, 18), (17, 20), (19, 21), (22, 23), (24, 25), (26, 28)],
        [(3, 4), (5, 6), (7, 8), (9, 10), (11, 12), (13, 14), (15, 16), (17, 18), (19, 20), (21, 22), (23, 24), (25, 26), (27, 28)]
    ],

    # Size 64: 521 CEs, depth 21
    64: [
        [(0, 2), (1, 3), (4, 6), (5, 7), (8, 10), (9, 11), (12, 14), (13, 15), (16, 18), (17, 19), (20, 22), (21, 23), (24, 26), (25, 27), (28, 30), (29, 31), (32, 34), (33, 35), (36, 38), (37, 39), (40, 42), (41, 43), (44, 46), (45, 47), (48, 50), (49, 51), (52, 54), (53, 55), (56, 58), (57, 59), (60, 62), (61, 63)],
        [(0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11), (12, 13), (14, 15), (16, 17), (18, 19), (20, 21), (22, 23), (24, 25), (26, 27), (28, 29), (30, 31), (32, 33), (34, 35), (36, 37), (38, 39), (40, 41), (42, 43), (44, 45), (46, 47), (48, 49), (50, 51), (52, 53), (54, 55), (56, 57), (58, 59), (60, 61), (62, 63)],
        [(0, 52), (1, 2), (3, 55), (4, 48), (5, 6), (7, 51), (8, 60), (9, 10), (11, 63), (12, 56), (13, 14), (15, 59), (16, 32), (17, 18), (19, 35), (20, 24), (21, 22), (23, 27), (25, 26), (28, 44), (29, 30), (31, 47), (33, 34), (36, 40), (37, 38), (39, 43), (41, 42), (45, 46), (49, 50), (53, 54), (57, 58), (61, 62)],
        [(0, 20), (1, 53), (2, 54), (3, 23), (4, 28), (5, 49), (6, 50), (7, 31), (8, 36), (9, 61), (10, 62), (11, 39), (12, 16), (13, 57), (14, 58), (15, 19), (17, 33), (18, 34), (21, 25), (22, 26), (24, 52), (27, 55), (29, 45), (30, 46), (32, 56), (35, 59), (37, 41), (38, 42), (40, 60), (43, 63), (44, 48), (47, 51)],
        [(0, 4), (1, 21), (2, 22), (3, 7), (5, 29), (6, 30), (8, 12), (9, 37), (10, 38), (11, 15), (13, 17), (14, 18), (16, 20), (19, 23), (24, 32), (25, 53), (26, 54), (27, 35), (28, 36), (31, 39), (33, 57), (34, 58), (40, 44), (41, 61), (42, 62), (43, 47), (45, 49), (46, 50), (48, 52), (51, 55), (56, 60), (59, 63)],
        [(0, 8), (1, 5), (2, 6), (3, 11), (4, 12), (7, 15), (9, 13), (10, 14), (16, 40), (17, 21), (18, 22), (19, 43), (20, 44), (23, 47), (24, 28), (25, 33), (26, 34), (27, 31), (29, 37), (30, 38), (32, 36), (35, 39), (41, 45), (42, 46), (48, 56), (49, 53), (50, 54), (51, 59), (52, 60), (55, 63), (57, 61), (58, 62)],
        [(1, 9), (2, 10), (4, 8), (5, 13), (6, 14), (7, 11), (12, 48), (15, 51), (16, 24), (17, 41), (18, 42), (19, 27), (20, 28), (21, 45), (22, 46), (23, 31), (25, 29), (26, 30), (32, 40), (33, 37), (34, 38), (35, 43), (36, 44), (39, 47), (49, 57), (50, 58), (52, 56), (53, 61), (54, 62), (55, 59)],
        [(4, 16), (5, 9), (6, 10), (7, 19), (8, 24), (11, 27), (13, 49), (14, 50), (17, 25), (18, 26), (20, 32), (21, 29), (22, 30), (23, 35), (28, 40), (31, 43), (33, 41), (34, 42), (36, 52), (37, 45), (38, 46), (39, 55), (44, 56), (47, 59), (53, 57), (54, 58)],
        [(1, 4), (5, 17), (6, 18), (8, 16), (9, 25), (10, 26), (11, 19), (12, 24), (15, 27), (21, 33), (22, 34), (29, 41), (30, 42), (36, 48), (37, 53), (38, 54), (39, 51), (44, 52), (45, 57), (46, 58), (47, 55), (59, 62)],
        [(2, 8), (9, 17), (10, 18), (12, 20), (13, 25), (14, 26), (15, 23), (24, 32), (27, 35), (28, 36), (31, 39), (37, 49), (38, 50), (40, 48), (43, 51), (45, 53), (46, 54), (55, 61)],
        [(2, 4), (12, 16), (13, 21), (14, 22), (15, 19), (20, 24), (23, 27), (25, 33), (26, 34), (28, 32), (29, 37), (30, 38), (31, 35), (36, 40), (39, 43), (41, 49), (42, 50), (44, 48), (47, 51), (59, 61)],
        [(4, 16), (5, 20), (10, 40), (13, 17), (14, 18), (21, 25), (22, 26), (23, 53), (24, 28), (27, 31), (29, 33), (30, 34), (32, 36), (35, 39), (37, 41), (38, 42), (43, 58), (45, 49), (46, 50), (47, 59)],
        [(3, 17), (6, 36), (7, 21), (8, 32), (9, 24), (11, 41), (13, 28), (14, 44), (15, 45), (18, 48), (19, 49), (22, 52), (25, 29), (26, 30), (27, 57), (31, 55), (33, 37), (34, 38), (35, 50), (39, 54), (42, 56), (46, 60)],
        [(6, 20), (8, 16), (10, 24), (11, 25), (14, 28), (15, 29), (17, 33), (18, 32), (21, 37), (22, 36), (26, 42), (27, 41), (30, 46), (31, 45), (34, 48), (35, 49), (38, 52), (39, 53), (43, 57), (47, 55)],
        [(3, 18), (5, 8), (6, 12), (7, 22), (15, 21), (17, 32), (19, 33), (23, 37), (26, 40), (30, 44), (31, 46), (41, 56), (42, 48), (45, 60), (51, 57), (55, 58)],
        [(3, 16), (7, 20), (11, 26), (18, 24), (19, 25), (22, 28), (23, 29), (27, 33), (30, 36), (34, 40), (35, 41), (37, 52), (38, 44), (39, 45), (43, 56), (47, 60)],
        [(3, 9), (7, 13), (10, 16), (11, 17), (14, 20), (15, 30), (19, 34), (21, 36), (23, 38), (25, 40), (26, 32), (27, 42), (29, 44), (31, 37), (33, 48), (43, 49), (46, 52), (47, 53), (50, 56), (54, 60)],
        [(3, 8), (7, 10), (9, 12), (11, 18), (13, 14), (15, 24), (17, 22), (19, 28), (21, 26), (23, 25), (27, 34), (29, 36), (30, 32), (31, 33), (35, 44), (37, 42), (38, 40), (39, 48), (41, 46), (45, 52), (49, 50), (51, 54), (53, 56), (55, 60)],
        [(3, 6), (7, 12), (11, 16), (15, 17), (18, 20), (19, 24), (21, 22), (23, 30), (25, 32), (26, 28), (27, 29), (31, 38), (33, 40), (34, 36), (35, 37), (39, 44), (41, 42), (43, 45), (46, 48), (47, 52), (51, 56), (57, 60)],
        [(3, 5), (6, 8), (7, 9), (10, 12), (11, 13), (14, 16), (15, 18), (17, 20), (19, 21), (22, 24), (23, 26), (25, 28), (27, 30), (29, 32), (31, 34), (33, 36), (35, 38), (37, 40), (39, 41), (42, 44), (43, 46), (45, 48), (47, 49), (50, 52), (51, 53), (54, 56), (55, 57), (58, 60)],
        [(3, 4), (7, 8), (11, 12), (13, 14), (15, 16), (17, 18), (19, 20), (21, 22), (23, 24), (25, 26), (27, 28), (29, 30), (31, 32), (33, 34), (35, 36), (37, 38), (39, 40), (41, 42), (43, 44), (45, 46), (47, 48), (49, 50), (51, 52), (55, 56), (59, 60)]
    ],

}
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
@cute.jit
def optimal_sort(
    arr: cute.Tensor,
    n: cutlass.Constexpr[int],
    start: cutlass.Constexpr[int] = 0,
    ascending: cutlass.Constexpr[bool] = True
) -> None:
    """
    Optimal sorting network dispatcher.

    Args:
        arr: Array to sort
        n: Size of array (must be power of 2 and available in networks)
        start: Starting index (default 0)
        ascending: Sort in ascending order (default True)

    Source: https://bertdobbelaere.github.io/sorting_networks.html
    """
    # n, start and ascending are Constexpr, so these loops unroll at trace
    # time into a fixed sequence of compare_and_swap calls.
    assert n in networks
    for level in networks[n]:
        # Comparators within one level are independent of each other.
        for i, j in level:
            compare_and_swap(arr, start + i, start + j, ascending)
|
build/torch-cuda/quack/sort/utils.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cutlass.cute as cute
|
| 2 |
+
from cutlass import Float32, const_expr
|
| 3 |
+
|
| 4 |
+
from .. import utils
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@cute.jit
def compare_and_swap(
    arr: cute.Tensor, i: int, j: int, ascending: bool = True, use_selection: bool = False
) -> None:
    """Compare and swap elements at indices i and j in ascending or descending order."""
    if const_expr(use_selection):
        # Selection form: one comparison plus a conditional swap.  XOR-ing the
        # predicate with (not ascending) flips the order for descending sorts.
        a, b = arr[i], arr[j]
        if (a > b) ^ (not ascending):
            arr[i] = b
            arr[j] = a
    else:
        # Min/max form: write min and max back simultaneously (no branch on data).
        # For Float32 the specialized utils.fmin / cute.arch.fmax are used
        # instead of Python's min/max.
        # NOTE(review): min uses utils.fmin while max uses cute.arch.fmax —
        # presumably deliberate (e.g. NaN/codegen behavior), but confirm the
        # asymmetry is intended.
        min_fn = min if const_expr(arr.element_type != Float32) else utils.fmin
        max_fn = max if const_expr(arr.element_type != Float32) else cute.arch.fmax
        if const_expr(ascending):
            arr[i], arr[j] = min_fn(arr[i], arr[j]), max_fn(arr[i], arr[j])
        else:
            arr[i], arr[j] = max_fn(arr[i], arr[j]), min_fn(arr[i], arr[j])
|
build/torch-cuda/quack/tensormap_manager.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Tri Dao.
|
| 2 |
+
|
| 3 |
+
from typing import Tuple
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
|
| 6 |
+
import cutlass
|
| 7 |
+
import cutlass.cute as cute
|
| 8 |
+
from cutlass.cutlass_dsl import Boolean, const_expr, Int32
|
| 9 |
+
from cutlass.utils import TensorMapUpdateMode, TensorMapManager
|
| 10 |
+
from cutlass._mlir.dialects import llvm
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclass(frozen=True)
|
| 14 |
+
class TensorMapManagerSm90(TensorMapManager):
|
| 15 |
+
"""
|
| 16 |
+
We have to subclass cutlass.utils.TensorMapManager bc it takes in warp_id and only
|
| 17 |
+
perform the operation if warp_id matches the current warp.
|
| 18 |
+
But for Hopper pingpong gemm we want to call it with warp_id 0 and 4.
|
| 19 |
+
So we take in a boolean `is_manager_warp` to determine whether to perform the operation or not.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
    @cute.jit
    def init_tensormap_from_atom(
        self, copy_atom: cute.CopyAtom, dst_ptr: cute.Pointer, is_manager_warp: Boolean
    ) -> None:
        """Copy the tensormap template from `copy_atom` to `dst_ptr`.

        Unlike the base class (which gates on a fixed warp id), this runs for
        any warp whose `is_manager_warp` predicate is true — needed for the
        Hopper pingpong schedule where warps 0 and 4 both act as managers.
        """
        if is_manager_warp:
            with cute.arch.elect_one():
                # A single elected thread performs the copy.
                cute.nvgpu.cpasync.copy_tensormap(copy_atom, dst_ptr)
            # Make the copy visible to the whole warp before anyone proceeds.
            cute.arch.sync_warp()
        return
|
| 31 |
+
|
| 32 |
+
    @cute.jit
    def update_tensormap(
        self,
        tensor_gmem: Tuple[cute.Tensor, ...],
        tma_copy_atom: Tuple[cute.CopyAtom, ...],
        tensormap_gmem_ptr: Tuple[cute.Pointer, ...],
        is_manager_warp: Boolean,
        tensormap_smem_ptr: Tuple[cute.Pointer, ...],
    ) -> None:
        """Re-point each TMA descriptor at a new gmem tensor.

        In SMEM mode the descriptors are patched in shared memory and then
        copied out to global memory with a release fence; in GMEM mode the
        global-memory descriptors are patched directly.  Runs only when
        `is_manager_warp` is true.
        """
        # updates before touching tensormap in global memory
        if is_manager_warp:
            if const_expr(self.tensormap_update_mode == TensorMapUpdateMode.SMEM):
                for copy_atom, tensor, smem_ptr in zip(
                    tma_copy_atom, tensor_gmem, tensormap_smem_ptr
                ):
                    cute.nvgpu.cpasync.update_tma_descriptor(copy_atom, tensor, smem_ptr)
                # wait until it's safe to update tensormap in global memory
                with cute.arch.elect_one():
                    cute.arch.cp_async_bulk_commit_group()
                    # read=True: wait for outstanding reads of the descriptor too.
                    cute.arch.cp_async_bulk_wait_group(0, read=True)
                cute.arch.sync_warp()
            # updates to tensormap in global memory
            if const_expr(self.tensormap_update_mode == TensorMapUpdateMode.SMEM):
                # Copy the patched smem descriptor to gmem with release semantics.
                for gmem_ptr, smem_ptr in zip(tensormap_gmem_ptr, tensormap_smem_ptr):
                    cute.nvgpu.cpasync.cp_fence_tma_desc_release(gmem_ptr, smem_ptr)
            else:
                for copy_atom, tensor, gmem_ptr in zip(
                    tma_copy_atom, tensor_gmem, tensormap_gmem_ptr
                ):
                    cute.nvgpu.cpasync.update_tma_descriptor(copy_atom, tensor, gmem_ptr)
                cute.arch.sync_warp()
                # Release fence so subsequent TMA ops observe the new descriptor.
                cute.nvgpu.cpasync.fence_tma_desc_release()
|
| 64 |
+
|
| 65 |
+
    @cute.jit
    def update_tensormap_shape(
        self,
        tensormap_gmem_ptr: Tuple[cute.Pointer, ...],
        is_manager_warp: Boolean,
        tensormap_smem_ptr: Tuple[cute.Pointer, ...],
        shapes: Tuple[Int32, ...],
        orders: cutlass.Constexpr[Tuple[int, ...]],
    ) -> None:
        """Patch only the global_dim (shape) field of each tensormap.

        Uses the PTX `tensormap.replace.tile.global_dim` instruction directly
        (the shared::cta variant in SMEM mode, the global variant otherwise).
        `orders` gives the compile-time dimension index to replace; `shapes`
        gives the new runtime extents.  Runs only when `is_manager_warp` is true.
        """
        # updates before touching tensormap in global memory
        if is_manager_warp:
            if const_expr(self.tensormap_update_mode == TensorMapUpdateMode.SMEM):
                for smem_ptr, shape, order in zip(tensormap_smem_ptr, shapes, orders):
                    smem_ptr_i32 = smem_ptr.toint().ir_value()
                    # NOTE(review): three operand values are passed but the
                    # constraint string "r,r" declares only two, and `order` is
                    # both interpolated into the asm text and passed as an
                    # operand — confirm the extra operand is intentional.
                    llvm.inline_asm(
                        None,
                        [smem_ptr_i32, Int32(shape).ir_value(), Int32(order).ir_value()],
                        "{\n\t"
                        ".reg .b64 smem_ptr_i64;\n\t"
                        "cvt.u64.u32 smem_ptr_i64, $0;\n\t"
                        f"tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [smem_ptr_i64], {order}, $1;\n\t"
                        "}\n",
                        "r,r",
                        has_side_effects=True,
                        is_align_stack=False,
                        asm_dialect=llvm.AsmDialect.AD_ATT,
                    )
                # wait until it's safe to update tensormap in global memory
                with cute.arch.elect_one():
                    cute.arch.cp_async_bulk_commit_group()
                    cute.arch.cp_async_bulk_wait_group(0, read=True)
                cute.arch.sync_warp()
            # updates to tensormap in global memory
            if const_expr(self.tensormap_update_mode == TensorMapUpdateMode.SMEM):
                for gmem_ptr, smem_ptr in zip(tensormap_gmem_ptr, tensormap_smem_ptr):
                    cute.nvgpu.cpasync.cp_fence_tma_desc_release(gmem_ptr, smem_ptr)
            else:
                assert len(shapes) == len(orders) == len(tensormap_gmem_ptr)
                for gmem_ptr, shape, order in zip(tensormap_gmem_ptr, shapes, orders):
                    gmem_ptr_i64 = gmem_ptr.toint().ir_value()
                    # NOTE(review): same operand/constraint count mismatch as
                    # the SMEM branch ("l,r" declares two, three values passed).
                    llvm.inline_asm(
                        None,
                        [gmem_ptr_i64, Int32(shape).ir_value(), Int32(order).ir_value()],
                        f"tensormap.replace.tile.global_dim.global.b1024.b32 [$0], {order}, $1;",
                        "l,r",
                        has_side_effects=True,
                        is_align_stack=False,
                        asm_dialect=llvm.AsmDialect.AD_ATT,
                    )
                cute.arch.sync_warp()
                cute.nvgpu.cpasync.fence_tma_desc_release()
|