drbh committed on
Commit ·
69b4990
unverified ·
0
Parent(s):
Migrated from kernels-community/finegrained-fp8
Browse files- .gitattributes +35 -0
- build/torch-cuda/__init__.py +32 -0
- build/torch-cuda/_ops.py +8 -0
- build/torch-cuda/act_quant.py +73 -0
- build/torch-cuda/batched.py +390 -0
- build/torch-cuda/finegrained_fp8/__init__.py +26 -0
- build/torch-cuda/grouped.py +477 -0
- build/torch-cuda/matmul.py +411 -0
- build/torch-cuda/metadata.json +8 -0
- build/torch-cuda/utils.py +13 -0
- build/torch-rocm/__init__.py +32 -0
- build/torch-rocm/_ops.py +8 -0
- build/torch-rocm/act_quant.py +73 -0
- build/torch-rocm/batched.py +390 -0
- build/torch-rocm/finegrained_fp8/__init__.py +26 -0
- build/torch-rocm/grouped.py +477 -0
- build/torch-rocm/matmul.py +411 -0
- build/torch-rocm/metadata.json +8 -0
- build/torch-rocm/utils.py +13 -0
- build/torch-xpu/__init__.py +32 -0
- build/torch-xpu/_ops.py +8 -0
- build/torch-xpu/act_quant.py +73 -0
- build/torch-xpu/batched.py +390 -0
- build/torch-xpu/finegrained_fp8/__init__.py +26 -0
- build/torch-xpu/grouped.py +477 -0
- build/torch-xpu/matmul.py +411 -0
- build/torch-xpu/metadata.json +8 -0
- build/torch-xpu/utils.py +13 -0
.gitattributes
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
build/torch-cuda/__init__.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Public API for the finegrained-fp8 kernel package.

Re-exports activation quantization plus the single, batched, and grouped
W8A8 FP8 matmul entry points from their implementation modules.
"""

from .act_quant import fp8_act_quant
from .batched import (
    w8a8_fp8_matmul_batched,
    w8a8_block_fp8_matmul_batched,
    w8a8_tensor_fp8_matmul_batched,
)
from .grouped import (
    w8a8_fp8_matmul_grouped,
    w8a8_block_fp8_matmul_grouped,
    w8a8_tensor_fp8_matmul_grouped,
)
from .matmul import (
    w8a8_fp8_matmul,
    w8a8_block_fp8_matmul,
    w8a8_tensor_fp8_matmul,
)

__all__ = [
    "fp8_act_quant",
    # Single matmul
    "w8a8_fp8_matmul",
    "w8a8_block_fp8_matmul",
    "w8a8_tensor_fp8_matmul",
    # Batched matmul
    "w8a8_fp8_matmul_batched",
    "w8a8_block_fp8_matmul_batched",
    "w8a8_tensor_fp8_matmul_batched",
    # Grouped matmul
    "w8a8_fp8_matmul_grouped",
    "w8a8_block_fp8_matmul_grouped",
    "w8a8_tensor_fp8_matmul_grouped",
]
|
build/torch-cuda/_ops.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
# Handle to this build's registered op namespace. The hash suffix pins the
# namespace to a specific compiled variant so multiple builds can coexist.
ops = torch.ops._finegrained_fp8_75cbe1b
|
| 3 |
+
|
| 4 |
+
def add_op_namespace_prefix(op_name: str):
    """Return *op_name* qualified with this build's op namespace.

    Example: ``"gemm"`` becomes ``"_finegrained_fp8_75cbe1b::gemm"``.
    """
    prefixed = f"_finegrained_fp8_75cbe1b::{op_name}"
    return prefixed
|
build/torch-cuda/act_quant.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import triton
|
| 17 |
+
import triton.language as tl
|
| 18 |
+
from torch.library import triton_op, wrap_triton
|
| 19 |
+
|
| 20 |
+
from .utils import device_context
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# Target quantization dtype; its largest finite value (448.0) is used as the
# scaling reference below.
_FP8_DTYPE = torch.float8_e4m3fn


# Copied from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
@triton.jit
def _fp8_act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
    """Quantize one BLOCK_SIZE-element block of ``x`` to FP8 with a dynamic scale.

    Each program owns one contiguous block: it computes a single absmax-based
    scale, writes the scaled values to ``y_ptr`` and the float32 scale to
    ``s_ptr[pid]``.
    """
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Upcast to fp32 so absmax/division are computed at full precision.
    x = tl.load(x_ptr + offs).to(tl.float32)
    # NOTE(review): an all-zero block gives s == 0, so x / s produces NaN.
    # The fused batched kernel guards with tl.maximum(s, 1e-12) — confirm
    # whether the same guard is intended here.
    s = tl.max(tl.abs(x)) / 448.0  # float8_e4m3fn max
    y = (x / s).to(y_ptr.dtype.element_ty)
    tl.store(y_ptr + offs, y)
    tl.store(s_ptr + pid, s)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@triton_op("finegrained_fp8::fp8_act_quant", mutates_args=())
|
| 39 |
+
def _fp8_act_quant(
|
| 40 |
+
x: torch.Tensor, block_size: int = 128
|
| 41 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 42 |
+
assert x.is_contiguous()
|
| 43 |
+
assert x.shape[-1] % block_size == 0
|
| 44 |
+
y = torch.empty_like(x, dtype=_FP8_DTYPE)
|
| 45 |
+
grid = (triton.cdiv(x.numel(), block_size),)
|
| 46 |
+
s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)
|
| 47 |
+
|
| 48 |
+
with device_context(x.device):
|
| 49 |
+
wrap_triton(_fp8_act_quant_kernel)[grid](x, y, s, BLOCK_SIZE=block_size)
|
| 50 |
+
|
| 51 |
+
return y, s
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def fp8_act_quant(
    x: torch.Tensor, block_size: int = 128
) -> tuple[torch.Tensor, torch.Tensor]:
    """Dynamically quantize activations to ``float8_e4m3fn`` per block.

    The last axis of ``x`` is partitioned into ``block_size``-element blocks;
    each block gets its own scale ``max(|x_block|) / 448`` (448 being the
    largest finite e4m3 value).

    Args:
        x: Contiguous bf16/fp16/fp32 tensor whose last dimension is a
            multiple of ``block_size``.
        block_size: Elements per quantization block (default: 128).

    Returns:
        ``(quantized, scales)``: the FP8 tensor (same shape as ``x``) and a
        float32 scale tensor of shape ``(*x.shape[:-1], x.shape[-1] // block_size)``.
    """
    result = torch.ops.finegrained_fp8.fp8_act_quant(x, block_size)
    return result
|
build/torch-cuda/batched.py
ADDED
|
@@ -0,0 +1,390 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import triton
|
| 17 |
+
import triton.language as tl
|
| 18 |
+
from .act_quant import fp8_act_quant
|
| 19 |
+
from torch.library import triton_op, wrap_triton
|
| 20 |
+
|
| 21 |
+
from .utils import device_context
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=w, num_stages=s)
        for w in [2, 4, 8, 16]
        for s in [2, 3, 4, 5]
    ],
    key=["N", "K"],
)
@triton.jit
def w8a8_block_fp8_matmul_batched_kernel(
    A,  # (S, K) raw BF16/FP16 activations
    B,  # (E, N, K) FP8 weight matrices
    C,  # (S, N) output
    Bs,  # (E, N // BLOCK_SIZE_N, K // BLOCK_SIZE_K) weight scales
    ExpertIds,  # (S,) — which expert each batch element routes to
    # Shape
    S,
    N,
    K,
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_bs_e,
    stride_bs_k,
    stride_bs_n,
    # Meta-parameters
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
):
    """Block-scale batched FP8 expert matmul kernel.

    Each program handles one routed token row and one N-tile, looks up the
    owning expert from ``ExpertIds``, and applies fused activation quantization.
    """
    batch_id = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)

    # Cast expert_id to int64 to prevent int32 overflow when computing
    # expert_id * stride_Eb (e.g. 255 * 9_437_184 > 2^31 for 256 experts of
    # 3072×3072 FP8 weights).
    expert_id = tl.load(ExpertIds + batch_id).to(tl.int64)

    # Rebase the pointers onto this token's row / this expert's slices.
    A = A + batch_id * stride_am
    B = B + expert_id * stride_be
    C = C + batch_id * stride_cm
    Bs = Bs + expert_id * stride_bs_e

    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # `* 0` broadcasts the single token row across all BLOCK_SIZE_M lanes of
    # the M tile (every row of the tile reads the same activation row).
    a_ptrs = A + tl.arange(0, BLOCK_SIZE_M)[:, None] * 0 + offs_k[None, :] * stride_ak
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    bs_ptrs = Bs + pid_n * stride_bs_n

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # ---- fused fp8_act_quant ----
        a_raw = tl.load(a_ptrs).to(tl.float32)
        # 448.0 is the largest finite float8_e4m3fn value; the 1e-12 clamp
        # guards against division by zero for all-zero K-blocks.
        a_s = tl.max(tl.abs(a_raw)) / 448.0
        a = (a_raw / tl.maximum(a_s, 1e-12)).to(tl.float8e4nv)
        # ---- matmul ----
        b = tl.load(b_ptrs)
        b_s = tl.load(bs_ptrs + k * stride_bs_k)
        # Rescale each partial product by its activation and weight scales.
        accumulator += tl.dot(a, b) * a_s * b_s[None, :]
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Downcast the fp32 accumulator to the output dtype.
    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    # `* 0` again: all M lanes write the same output row.
    offs_cm = tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = C + offs_cm[:, None] * 0 + stride_cn * offs_cn[None, :]
    tl.store(c_ptrs, c)
| 107 |
+
|
| 108 |
+
|
| 109 |
+
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=w, num_stages=s)
        for w in [2, 4, 8, 16]
        for s in [2, 3, 4, 5]
    ],
    key=["N", "K"],
)
@triton.jit
def w8a8_tensor_fp8_matmul_batched_kernel(
    A,  # (S, K) pre-quantized FP8 activations
    B,  # (E, N, K) FP8 weight matrices
    C,  # (S, N) output
    As,  # (S, 1) per-tensor activation scales
    Bs,  # (E, 1, 1) per-tensor weight scales
    ExpertIds,
    S,
    N,
    K,
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_as_m,
    stride_bs_e,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
):
    """Tensor-scale batched FP8 expert matmul kernel.

    Activations are already quantized; the kernel applies per-token activation
    scales and per-expert tensor weight scales.
    """
    batch_id = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)

    # int64 to avoid int32 overflow in expert_id * stride_be for large experts.
    expert_id = tl.load(ExpertIds + batch_id).to(tl.int64)

    # Rebase pointers onto this token's row / this expert's slices.
    A = A + batch_id * stride_am
    B = B + expert_id * stride_be
    C = C + batch_id * stride_cm
    Bs = Bs + expert_id * stride_bs_e

    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # `* 0` broadcasts the single token row across the whole M tile.
    a_ptrs = A + tl.arange(0, BLOCK_SIZE_M)[:, None] * 0 + offs_k[None, :] * stride_ak
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # One scalar scale per expert (weights) and per token row (activations).
    b_s = tl.load(Bs)
    a_s = tl.load(As + batch_id * stride_as_m)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for _ in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        accumulator += tl.dot(a, b)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Scales are scalars here, so a single rescale after the K loop suffices.
    accumulator = accumulator * a_s * b_s

    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    # All M lanes write the same output row (`* 0` collapses the M offset).
    offs_cm = tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = C + offs_cm[:, None] * 0 + stride_cn * offs_cn[None, :]
    tl.store(c_ptrs, c)
| 185 |
+
|
| 186 |
+
|
| 187 |
+
@triton_op("finegrained_fp8::w8a8_block_fp8_matmul_batched", mutates_args=())
|
| 188 |
+
def _w8a8_block_fp8_matmul_batched(
|
| 189 |
+
A: torch.Tensor,
|
| 190 |
+
B: torch.Tensor,
|
| 191 |
+
Bs: torch.Tensor,
|
| 192 |
+
expert_ids: torch.Tensor,
|
| 193 |
+
block_size: list[int],
|
| 194 |
+
) -> torch.Tensor:
|
| 195 |
+
"""Block-scale batched FP8 matmul: C[s] = A[s] @ B[expert_ids[s]].T, with fused act quant.
|
| 196 |
+
|
| 197 |
+
A: (S, K) raw bf16/fp16 activations
|
| 198 |
+
B: (E, N, K) FP8 expert weights
|
| 199 |
+
Bs: (E, N // block_n, K // block_k) per-block weight scales
|
| 200 |
+
"""
|
| 201 |
+
assert A.ndim == 2, f"A must be 2D (S, K), got ndim={A.ndim}"
|
| 202 |
+
assert A.is_contiguous(), "A must be contiguous"
|
| 203 |
+
assert B.ndim == 3, f"B must be 3D (E, N, K), got ndim={B.ndim}"
|
| 204 |
+
assert B.is_contiguous(), "B must be contiguous"
|
| 205 |
+
assert A.shape[1] == B.shape[2], (
|
| 206 |
+
f"K mismatch: A has K={A.shape[1]}, B has K={B.shape[2]}"
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
S, K = A.shape
|
| 210 |
+
E, N, _ = B.shape
|
| 211 |
+
|
| 212 |
+
assert len(block_size) == 2, (
|
| 213 |
+
f"block_size must be [block_n, block_k], got {block_size}"
|
| 214 |
+
)
|
| 215 |
+
block_n, block_k = block_size[0], block_size[1]
|
| 216 |
+
# MoE expert dimensions must be block-aligned; non-aligned N/K is not supported.
|
| 217 |
+
assert N % block_n == 0, f"N ({N}) must be divisible by block_n ({block_n})"
|
| 218 |
+
assert K % block_k == 0, f"K ({K}) must be divisible by block_k ({block_k})"
|
| 219 |
+
assert Bs.ndim == 3, (
|
| 220 |
+
f"Bs must be 3D (E, N//block_n, K//block_k), got ndim={Bs.ndim}"
|
| 221 |
+
)
|
| 222 |
+
assert Bs.shape == (E, N // block_n, K // block_k), (
|
| 223 |
+
f"Bs shape {tuple(Bs.shape)} != expected ({E}, {N // block_n}, {K // block_k})"
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
C = A.new_empty(S, N)
|
| 227 |
+
# Adaptive BLOCK_SIZE_M: smallest power-of-2 >= M, floored at 16, capped at 128.
|
| 228 |
+
# Matches the WGMMA tile to the actual row count — smaller tiles use less
|
| 229 |
+
# register pressure and a better-matched FP8 WGMMA instruction, improving
|
| 230 |
+
# both accuracy and performance for small M (decode).
|
| 231 |
+
BLOCK_SIZE_M = min(max(triton.next_power_of_2((S + E - 1) // E), 16), 128)
|
| 232 |
+
grid = (S, triton.cdiv(N, block_n))
|
| 233 |
+
with device_context(A.device):
|
| 234 |
+
wrap_triton(w8a8_block_fp8_matmul_batched_kernel)[grid](
|
| 235 |
+
A,
|
| 236 |
+
B,
|
| 237 |
+
C,
|
| 238 |
+
Bs,
|
| 239 |
+
expert_ids,
|
| 240 |
+
S,
|
| 241 |
+
N,
|
| 242 |
+
K,
|
| 243 |
+
A.stride(0),
|
| 244 |
+
A.stride(1),
|
| 245 |
+
B.stride(0),
|
| 246 |
+
B.stride(2),
|
| 247 |
+
B.stride(1),
|
| 248 |
+
C.stride(0),
|
| 249 |
+
C.stride(1),
|
| 250 |
+
Bs.stride(0),
|
| 251 |
+
Bs.stride(2),
|
| 252 |
+
Bs.stride(1),
|
| 253 |
+
BLOCK_SIZE_N=block_n,
|
| 254 |
+
BLOCK_SIZE_K=block_k,
|
| 255 |
+
BLOCK_SIZE_M=BLOCK_SIZE_M,
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
return C
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
@triton_op("finegrained_fp8::w8a8_tensor_fp8_matmul_batched", mutates_args=())
|
| 262 |
+
def _w8a8_tensor_fp8_matmul_batched(
|
| 263 |
+
A: torch.Tensor,
|
| 264 |
+
B: torch.Tensor,
|
| 265 |
+
Bs: torch.Tensor,
|
| 266 |
+
expert_ids: torch.Tensor,
|
| 267 |
+
) -> torch.Tensor:
|
| 268 |
+
"""Tensor-scale batched FP8 matmul: C[s] = A[s] @ B[expert_ids[s]].T, with fused act quant.
|
| 269 |
+
|
| 270 |
+
A: (S, K) raw bf16/fp16 activations
|
| 271 |
+
B: (E, N, K) FP8 expert weights
|
| 272 |
+
Bs: (E,) or (E, 1, 1) per-expert weight scales
|
| 273 |
+
"""
|
| 274 |
+
assert A.ndim == 2, f"A must be 2D (S, K), got ndim={A.ndim}"
|
| 275 |
+
assert A.is_contiguous(), "A must be contiguous"
|
| 276 |
+
assert B.ndim == 3, f"B must be 3D (E, N, K), got ndim={B.ndim}"
|
| 277 |
+
assert B.is_contiguous(), "B must be contiguous"
|
| 278 |
+
assert A.shape[1] == B.shape[2], (
|
| 279 |
+
f"K mismatch: A has K={A.shape[1]}, B has K={B.shape[2]}"
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
S, K = A.shape
|
| 283 |
+
E, N, _ = B.shape
|
| 284 |
+
|
| 285 |
+
# Normalize Bs to (E, 1, 1)
|
| 286 |
+
if Bs.ndim == 1:
|
| 287 |
+
assert Bs.shape[0] == E, f"Bs shape {tuple(Bs.shape)} != expected ({E},)"
|
| 288 |
+
Bs = Bs.reshape(E, 1, 1)
|
| 289 |
+
else:
|
| 290 |
+
assert Bs.shape == (E, 1, 1), (
|
| 291 |
+
f"Bs shape {tuple(Bs.shape)} != expected ({E}, 1, 1)"
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
BLOCK_SIZE_N = 128
|
| 295 |
+
BLOCK_SIZE_K = 128
|
| 296 |
+
C = A.new_empty(S, N)
|
| 297 |
+
qA, As = fp8_act_quant(A, K)
|
| 298 |
+
grid = (S, triton.cdiv(N, BLOCK_SIZE_N))
|
| 299 |
+
# Adaptive BLOCK_SIZE_M: smallest power-of-2 >= M, floored at 16, capped at 128.
|
| 300 |
+
# Matches the WGMMA tile to the actual row count — smaller tiles use less
|
| 301 |
+
# register pressure and a better-matched FP8 WGMMA instruction, improving
|
| 302 |
+
# both accuracy and performance for small M (decode).
|
| 303 |
+
BLOCK_SIZE_M = min(max(triton.next_power_of_2((S + E - 1) // E), 16), 128)
|
| 304 |
+
grid = (S, triton.cdiv(N, BLOCK_SIZE_N))
|
| 305 |
+
with device_context(A.device):
|
| 306 |
+
wrap_triton(w8a8_tensor_fp8_matmul_batched_kernel)[grid](
|
| 307 |
+
qA,
|
| 308 |
+
B,
|
| 309 |
+
C,
|
| 310 |
+
As,
|
| 311 |
+
Bs,
|
| 312 |
+
expert_ids,
|
| 313 |
+
S,
|
| 314 |
+
N,
|
| 315 |
+
K,
|
| 316 |
+
qA.stride(0),
|
| 317 |
+
qA.stride(1),
|
| 318 |
+
B.stride(0),
|
| 319 |
+
B.stride(2),
|
| 320 |
+
B.stride(1),
|
| 321 |
+
C.stride(0),
|
| 322 |
+
C.stride(1),
|
| 323 |
+
As.stride(0),
|
| 324 |
+
Bs.stride(0),
|
| 325 |
+
BLOCK_SIZE_N=BLOCK_SIZE_N,
|
| 326 |
+
BLOCK_SIZE_K=BLOCK_SIZE_K,
|
| 327 |
+
BLOCK_SIZE_M=BLOCK_SIZE_M,
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
return C
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def w8a8_block_fp8_matmul_batched(
    A: torch.Tensor,
    B: torch.Tensor,
    Bs: torch.Tensor,
    expert_ids: torch.Tensor,
    block_size: list[int],
) -> torch.Tensor:
    """Batched W8A8 FP8 matmul using per-block weight scales.

    Activation quantization is fused into the kernel.

    A: (S, K) raw activations, bf16/fp16/fp32
    B: (E, N, K) FP8 expert weights
    Bs: (E, N // block_n, K // block_k) per-block weight scales
    """
    op = torch.ops.finegrained_fp8.w8a8_block_fp8_matmul_batched
    return op(A, B, Bs, expert_ids, block_size)
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
def w8a8_tensor_fp8_matmul_batched(
    A: torch.Tensor,
    B: torch.Tensor,
    Bs: torch.Tensor,
    expert_ids: torch.Tensor,
) -> torch.Tensor:
    """Batched W8A8 FP8 matmul using per-expert tensor-level weight scales.

    Activations are dynamically quantized internally before the matmul.

    A: (S, K) raw activations, bf16/fp16/fp32
    B: (E, N, K) FP8 expert weights
    Bs: (E,) or (E, 1, 1) per-expert weight scales
    """
    op = torch.ops.finegrained_fp8.w8a8_tensor_fp8_matmul_batched
    return op(A, B, Bs, expert_ids)
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
def w8a8_fp8_matmul_batched(
    A: torch.Tensor,
    B: torch.Tensor,
    Bs: torch.Tensor,
    expert_ids: torch.Tensor,
    block_size: list[int] | None,
) -> torch.Tensor:
    """Unified batched W8A8 FP8 matmul dispatcher.

    Tensor mode is chosen when ``block_size`` is ``None`` or covers the whole
    weight matrix (``[N, K]``); every other ``block_size`` selects block mode.

    Returns:
        Output tensor ``[S, N]`` in the same dtype as ``A``.
    """
    use_tensor_mode = block_size is None or (
        block_size[0] == B.size(1) and block_size[1] == B.size(2)
    )
    if use_tensor_mode:
        return w8a8_tensor_fp8_matmul_batched(A, B, Bs, expert_ids)
    return w8a8_block_fp8_matmul_batched(A, B, Bs, expert_ids, block_size)
|
build/torch-cuda/finegrained_fp8/__init__.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import ctypes
|
| 2 |
+
import importlib.util
|
| 3 |
+
import sys
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from types import ModuleType
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def _import_from_path(file_path: Path) -> ModuleType:
|
| 9 |
+
# We cannot use the module name as-is, after adding it to `sys.modules`,
|
| 10 |
+
# it would also be used for other imports. So, we make a module name that
|
| 11 |
+
# depends on the path for it to be unique using the hex-encoded hash of
|
| 12 |
+
# the path.
|
| 13 |
+
path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
|
| 14 |
+
module_name = path_hash
|
| 15 |
+
spec = importlib.util.spec_from_file_location(module_name, file_path)
|
| 16 |
+
if spec is None:
|
| 17 |
+
raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
|
| 18 |
+
module = importlib.util.module_from_spec(spec)
|
| 19 |
+
if module is None:
|
| 20 |
+
raise ImportError(f"Cannot load module {module_name} from spec")
|
| 21 |
+
sys.modules[module_name] = module
|
| 22 |
+
spec.loader.exec_module(module) # type: ignore
|
| 23 |
+
return module
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
|
build/torch-cuda/grouped.py
ADDED
|
@@ -0,0 +1,477 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import triton
|
| 17 |
+
import triton.language as tl
|
| 18 |
+
from .act_quant import fp8_act_quant
|
| 19 |
+
from torch.library import triton_op, wrap_triton
|
| 20 |
+
|
| 21 |
+
from .utils import device_context
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=w, num_stages=s)
        for w in [2, 4, 8, 16]
        for s in [2, 3, 4, 5]
    ],
    key=["N", "K", "BLOCK_SIZE_M"],
)
@triton.jit
def w8a8_block_fp8_matmul_grouped_kernel(
    A,  # (S, K) raw BF16/FP16 activations, sorted/grouped by expert id
    B,  # (E, N, K) FP8 weight matrices
    C,  # (S, N) output
    Bs,  # (E, N // BLOCK_SIZE_N, K // BLOCK_SIZE_K) weight scales
    Offsets,  # (E,) int32 — cumulative row-end per expert
    TileOffsets,  # (E,) int32 — cumulative tile-end per expert
    # Shape
    S,
    N,
    K,
    # Strides
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_bs_e,
    stride_bs_k,
    stride_bs_n,
    # Meta-parameters
    NUM_EXPERTS: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    NUM_EXPERTS_BIT_LENGTH: tl.constexpr,
):
    """Block-scale grouped FP8 expert matmul kernel.

    Tokens are assumed sorted by expert. The kernel maps each M-tile to its
    owning expert via ``TileOffsets`` and applies fused activation quantization:
    each (row, K-block) slice of A is scaled to FP8 on the fly, so no separate
    quantization pass over A is needed.
    """
    pid_m = tl.program_id(axis=0)  # M-tile index across ALL experts
    pid_n = tl.program_id(axis=1)  # N-tile index within the output

    # Exit early for programs beyond the actual tile count.
    # (The host launches an upper-bound grid so the grid shape stays
    # data-independent; surplus programs do nothing.)
    total_tiles = tl.load(TileOffsets + NUM_EXPERTS - 1)
    if pid_m >= total_tiles:
        return

    # Binary search in TileOffsets to find the owning expert.
    # Finds the smallest e such that TileOffsets[e] > pid_m (upper_bound semantics),
    # which is the expert whose tile range contains pid_m.
    # O(log2(NUM_EXPERTS)) loads instead of the O(NUM_EXPERTS) linear scan.
    # NUM_EXPERTS_BIT_LENGTH is ceil(log2(E))+1 for powers-of-two, giving one
    # harmless extra iteration when lo==hi; it's a compile-time constant so the
    # loop is fully unrolled by the compiler.
    lo = 0
    hi = NUM_EXPERTS
    for _ in tl.static_range(NUM_EXPERTS_BIT_LENGTH):
        mid = (lo + hi) >> 1
        mid_val = tl.load(TileOffsets + mid)
        is_left = mid_val <= pid_m
        lo = tl.where(is_left, mid + 1, lo)
        hi = tl.where(is_left, hi, mid)

    # Cast expert_id to int64 to prevent int32 overflow when computing
    # expert_id * stride_be (e.g. 255 * 9_437_184 > 2^31 for 256 experts of
    # 3072×3072 FP8 weights).
    expert_id = lo.to(tl.int64)

    # Row range [expert_start, expert_end) owned by this expert in A/C.
    # Offsets stores cumulative row-ends, so expert 0 implicitly starts at 0.
    prev_eid = tl.maximum(expert_id - 1, 0)
    expert_start = tl.where(expert_id == 0, 0, tl.load(Offsets + prev_eid))
    expert_end = tl.load(Offsets + expert_id)
    M_expert = expert_end - expert_start

    # Translate the global tile id into a tile-local row offset for this expert.
    expert_tile_start = tl.where(expert_id == 0, 0, tl.load(TileOffsets + prev_eid))
    local_tile = pid_m - expert_tile_start
    m_off = local_tile * BLOCK_SIZE_M

    offs_am = m_off + tl.arange(0, BLOCK_SIZE_M)
    row_mask = offs_am < M_expert  # masks the expert's ragged last M-tile
    offs_global_m = expert_start + offs_am

    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    a_ptrs = A + offs_global_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = (
        B
        + expert_id * stride_be
        + offs_k[:, None] * stride_bk
        + offs_bn[None, :] * stride_bn
    )
    bs_ptrs = Bs + expert_id * stride_bs_e + pid_n * stride_bs_n

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # ---- fused fp8_act_quant ----
        # Per-row scale over this K-block; 448.0 is the max finite magnitude of
        # the e4m3 FP8 format the values are cast to (tl.float8e4nv), so the
        # block max maps to full FP8 range. The 1e-12 clamp avoids div-by-zero
        # on all-zero rows.
        a_raw = tl.load(a_ptrs, mask=row_mask[:, None], other=0.0).to(tl.float32)
        a_s = tl.max(tl.abs(a_raw), axis=1) / 448.0
        a = (a_raw / tl.maximum(a_s[:, None], 1e-12)).to(tl.float8e4nv)
        # ---- matmul ----
        # B is loaded unmasked: the host asserts N and K are block-aligned.
        b = tl.load(b_ptrs)
        b_s = tl.load(bs_ptrs + k * stride_bs_k)
        # Undo both quantization scales while accumulating in fp32.
        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Down-cast the fp32 accumulator to the output dtype.
    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    c_ptrs = C + stride_cm * offs_global_m[:, None] + stride_cn * offs_bn[None, :]
    c_mask = row_mask[:, None]  # N is block-aligned, only rows need masking
    tl.store(c_ptrs, c, mask=c_mask)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=w, num_stages=s)
        for w in [2, 4, 8, 16]
        for s in [2, 3, 4, 5]
    ],
    key=["N", "K", "BLOCK_SIZE_M"],
)
@triton.jit
def w8a8_tensor_fp8_matmul_grouped_kernel(
    A,  # (S, K) pre-quantized FP8 activations, sorted/grouped by expert id
    B,  # (E, N, K) FP8 weight matrices
    C,  # (S, N) output
    As,  # (S, 1) activation scales
    Bs,  # (E, 1, 1) per-tensor weight scales
    Offsets,  # (E,) cumulative row-end per expert
    TileOffsets,  # (E,) cumulative tile-end per expert
    S,
    N,
    K,
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_as_m,
    stride_bs_e,
    NUM_EXPERTS: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    NUM_EXPERTS_BIT_LENGTH: tl.constexpr,
):
    """Tensor-scale grouped FP8 expert matmul kernel.

    Uses grouped expert scheduling with pre-quantized activations plus
    per-token activation scales and per-expert tensor weight scales.
    Because both scales are constant across K, they are applied once after
    the K-loop rather than per block.
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)

    # Early-out for surplus programs of the upper-bound grid.
    total_tiles = tl.load(TileOffsets + NUM_EXPERTS - 1)
    if pid_m >= total_tiles:
        return

    # Upper-bound binary search: smallest e with TileOffsets[e] > pid_m,
    # i.e. the expert owning this M-tile. Unrolled at compile time.
    lo = 0
    hi = NUM_EXPERTS
    for _ in tl.static_range(NUM_EXPERTS_BIT_LENGTH):
        mid = (lo + hi) >> 1
        mid_val = tl.load(TileOffsets + mid)
        is_left = mid_val <= pid_m
        lo = tl.where(is_left, mid + 1, lo)
        hi = tl.where(is_left, hi, mid)
    # int64 to avoid int32 overflow in expert_id * stride_be for large weights.
    expert_id = lo.to(tl.int64)

    # Row range [expert_start, expert_end) owned by this expert.
    prev_eid = tl.maximum(expert_id - 1, 0)
    expert_start = tl.where(expert_id == 0, 0, tl.load(Offsets + prev_eid))
    expert_end = tl.load(Offsets + expert_id)
    M_expert = expert_end - expert_start

    # Tile-local row offset for this expert.
    expert_tile_start = tl.where(expert_id == 0, 0, tl.load(TileOffsets + prev_eid))
    local_tile = pid_m - expert_tile_start
    m_off = local_tile * BLOCK_SIZE_M

    offs_am = m_off + tl.arange(0, BLOCK_SIZE_M)
    row_mask = offs_am < M_expert  # masks the ragged last M-tile
    offs_global_m = expert_start + offs_am

    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    a_ptrs = A + offs_global_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = (
        B
        + expert_id * stride_be
        + offs_k[:, None] * stride_bk
        + offs_bn[None, :] * stride_bn
    )

    # One scale per token row, one scale per expert.
    a_s = tl.load(As + offs_global_m * stride_as_m, mask=row_mask, other=0.0)
    b_s = tl.load(Bs + expert_id * stride_bs_e)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for _ in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs, mask=row_mask[:, None], other=0.0)
        b = tl.load(b_ptrs)  # unmasked: host uses block-aligned N/K

        accumulator += tl.dot(a, b)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Apply both dequantization scales once after accumulation.
    accumulator = accumulator * a_s[:, None] * b_s

    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    c_ptrs = C + stride_cm * offs_global_m[:, None] + stride_cn * offs_bn[None, :]
    c_mask = row_mask[:, None]
    tl.store(c_ptrs, c, mask=c_mask)
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
@triton_op("finegrained_fp8::w8a8_block_fp8_matmul_grouped", mutates_args=())
def _w8a8_block_fp8_matmul_grouped(
    A: torch.Tensor,
    B: torch.Tensor,
    Bs: torch.Tensor,
    offsets: torch.Tensor,
    tokens_per_expert: torch.Tensor,
    block_size: list[int],
) -> torch.Tensor:
    """Block-scale grouped FP8 matmul: C = A @ B.T per expert, with fused act quant.

    A: (S, K) raw bf16/fp16 activations, sorted by expert
    B: (E, N, K) FP8 expert weights
    Bs: (E, N // block_n, K // block_k) per-block weight scales
    offsets: (E,) cumulative row-end per expert
    tokens_per_expert: (E,) row count per expert
    block_size: [block_n, block_k] quantization block shape

    Returns a (S, N) tensor in A's dtype, in expert-sorted row order.
    """
    assert A.ndim == 2, f"A must be 2D (S, K), got ndim={A.ndim}"
    assert A.is_contiguous(), "A must be contiguous"
    assert B.ndim == 3, f"B must be 3D (E, N, K), got ndim={B.ndim}"
    assert B.is_contiguous(), "B must be contiguous"
    assert A.shape[1] == B.shape[2], (
        f"K mismatch: A has K={A.shape[1]}, B has K={B.shape[2]}"
    )

    S, K = A.shape
    E, N, _ = B.shape

    assert len(block_size) == 2, (
        f"block_size must be [block_n, block_k], got {block_size}"
    )
    block_n, block_k = block_size[0], block_size[1]
    # MoE expert dimensions must be block-aligned; non-aligned N/K is not supported.
    # (The kernel relies on this: its B loads and C stores are unmasked along N/K.)
    assert N % block_n == 0, f"N ({N}) must be divisible by block_n ({block_n})"
    assert K % block_k == 0, f"K ({K}) must be divisible by block_k ({block_k})"
    assert Bs.ndim == 3, (
        f"Bs must be 3D (E, N//block_n, K//block_k), got ndim={Bs.ndim}"
    )
    assert Bs.shape == (E, N // block_n, K // block_k), (
        f"Bs shape {tuple(Bs.shape)} != expected ({E}, {N // block_n}, {K // block_k})"
    )

    C = A.new_empty(S, N)
    # Adaptive BLOCK_SIZE_M: match tile to average tokens per expert.
    BLOCK_SIZE_M = min(max(triton.next_power_of_2((S + E - 1) // E), 16), 128)
    tiles_per_expert = (tokens_per_expert + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M
    tile_offsets = torch.cumsum(tiles_per_expert, dim=0).to(torch.int32)
    # Upper bound on M-tiles: sum_e ceil(M_e / BLOCK_M) <= ceil(S / BLOCK_M) + E.
    # Programs beyond the real tile count exit immediately via the early-return
    # guard inside the kernel. This is faster than syncing for the exact count
    # and keeps the grid size data-independent (cuda-graph / torch.compile safe).
    max_M_tiles = triton.cdiv(S, BLOCK_SIZE_M) + E
    grid = (max_M_tiles, triton.cdiv(N, block_n))
    with device_context(A.device):
        wrap_triton(w8a8_block_fp8_matmul_grouped_kernel)[grid](
            A,
            B,
            C,
            Bs,
            offsets,
            tile_offsets,
            S,
            N,
            K,
            A.stride(0),
            A.stride(1),
            B.stride(0),  # stride_be
            B.stride(2),  # stride_bk — B is (E, N, K); K/N strides are swapped
            B.stride(1),  # stride_bn   so the kernel effectively computes A @ B.T
            C.stride(0),
            C.stride(1),
            Bs.stride(0),  # stride_bs_e
            Bs.stride(2),  # stride_bs_k — same transposition for the scale grid
            Bs.stride(1),  # stride_bs_n
            # Meta-parameters
            NUM_EXPERTS=E,
            BLOCK_SIZE_N=block_n,
            BLOCK_SIZE_K=block_k,
            BLOCK_SIZE_M=BLOCK_SIZE_M,
            NUM_EXPERTS_BIT_LENGTH=E.bit_length(),
        )

    return C
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
@triton_op("finegrained_fp8::w8a8_tensor_fp8_matmul_grouped", mutates_args=())
def _w8a8_tensor_fp8_matmul_grouped(
    A: torch.Tensor,
    B: torch.Tensor,
    Bs: torch.Tensor,
    offsets: torch.Tensor,
    tokens_per_expert: torch.Tensor,
) -> torch.Tensor:
    """Tensor-scale grouped FP8 matmul: C = A @ B.T per expert, with fused act quant.

    A: (S, K) raw bf16/fp16 activations, sorted by expert
    B: (E, N, K) FP8 expert weights
    Bs: (E,) or (E, 1, 1) per-expert weight scales
    offsets: (E,) cumulative row-end per expert
    tokens_per_expert: (E,) row count per expert

    Activations are quantized to FP8 once up front (``fp8_act_quant``) and the
    kernel consumes the quantized tensor plus its per-token scales.
    Returns a (S, N) tensor in A's dtype, in expert-sorted row order.
    """
    assert A.ndim == 2, f"A must be 2D (S, K), got ndim={A.ndim}"
    assert A.is_contiguous(), "A must be contiguous"
    assert B.ndim == 3, f"B must be 3D (E, N, K), got ndim={B.ndim}"
    assert B.is_contiguous(), "B must be contiguous"
    assert A.shape[1] == B.shape[2], (
        f"K mismatch: A has K={A.shape[1]}, B has K={B.shape[2]}"
    )

    S, K = A.shape
    E, N, _ = B.shape

    # Normalize Bs to (E, 1, 1)
    if Bs.ndim == 1:
        assert Bs.shape[0] == E, f"Bs shape {tuple(Bs.shape)} != expected ({E},)"
        Bs = Bs.reshape(E, 1, 1)
    else:
        assert Bs.shape == (E, 1, 1), (
            f"Bs shape {tuple(Bs.shape)} != expected ({E}, 1, 1)"
        )

    # NOTE(review): fixed tile sizes imply N and K should be multiples of 128
    # (kernel B loads are unmasked along N/K) — confirm with callers.
    BLOCK_SIZE_N = 128
    BLOCK_SIZE_K = 128
    C = A.new_empty(S, N)
    qA, As = fp8_act_quant(A, K)  # pre-quantize A; As holds per-token scales
    # Adaptive BLOCK_SIZE_M: match tile to average tokens per expert.
    BLOCK_SIZE_M = min(max(triton.next_power_of_2((S + E - 1) // E), 16), 128)
    tiles_per_expert = (tokens_per_expert + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M
    tile_offsets = torch.cumsum(tiles_per_expert, dim=0).to(torch.int32)
    # Upper bound on M-tiles: sum_e ceil(M_e / BLOCK_M) <= ceil(S / BLOCK_M) + E.
    # Programs beyond the real tile count exit immediately via the early-return
    # guard inside the kernel. This is faster than syncing for the exact count
    # and keeps the grid size data-independent (cuda-graph / torch.compile safe).
    max_M_tiles = triton.cdiv(S, BLOCK_SIZE_M) + E
    grid = (max_M_tiles, triton.cdiv(N, BLOCK_SIZE_N))
    with device_context(A.device):
        wrap_triton(w8a8_tensor_fp8_matmul_grouped_kernel)[grid](
            qA,
            B,
            C,
            As,
            Bs,
            offsets,
            tile_offsets,
            S,
            N,
            K,
            qA.stride(0),
            qA.stride(1),
            B.stride(0),  # stride_be
            B.stride(2),  # stride_bk — K/N strides swapped: kernel computes A @ B.T
            B.stride(1),  # stride_bn
            C.stride(0),
            C.stride(1),
            As.stride(0),
            Bs.stride(0),
            # Meta-parameters
            NUM_EXPERTS=E,
            BLOCK_SIZE_N=BLOCK_SIZE_N,
            BLOCK_SIZE_K=BLOCK_SIZE_K,
            BLOCK_SIZE_M=BLOCK_SIZE_M,
            NUM_EXPERTS_BIT_LENGTH=E.bit_length(),
        )

    return C
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
def w8a8_block_fp8_matmul_grouped(
    A: torch.Tensor,
    B: torch.Tensor,
    Bs: torch.Tensor,
    offsets: torch.Tensor,
    tokens_per_expert: torch.Tensor,
    block_size: list[int],
) -> torch.Tensor:
    """Block-scale grouped FP8 matmul with fused activation quantization.

    Thin wrapper that routes through the registered custom op.

    A: (S, K) raw activations sorted by expert, bf16/fp16/fp32
    B: (E, N, K) FP8 expert weights
    Bs: (E, N // block_n, K // block_k) per-block weight scales
    """
    grouped_op = torch.ops.finegrained_fp8.w8a8_block_fp8_matmul_grouped
    return grouped_op(A, B, Bs, offsets, tokens_per_expert, block_size)
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
def w8a8_tensor_fp8_matmul_grouped(
    A: torch.Tensor,
    B: torch.Tensor,
    Bs: torch.Tensor,
    offsets: torch.Tensor,
    tokens_per_expert: torch.Tensor,
) -> torch.Tensor:
    """Tensor-scale grouped FP8 matmul with fused activation quantization.

    Thin wrapper that routes through the registered custom op.

    A: (S, K) raw activations sorted by expert, bf16/fp16/fp32
    B: (E, N, K) FP8 expert weights
    Bs: (E,) or (E, 1, 1) per-expert weight scales
    """
    grouped_op = torch.ops.finegrained_fp8.w8a8_tensor_fp8_matmul_grouped
    return grouped_op(A, B, Bs, offsets, tokens_per_expert)
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
def w8a8_fp8_matmul_grouped(
    A: torch.Tensor,
    B: torch.Tensor,
    Bs: torch.Tensor,
    offsets: torch.Tensor,
    tokens_per_expert: torch.Tensor,
    block_size: list[int] | None,
) -> torch.Tensor:
    """Unified grouped W8A8 FP8 matmul dispatcher.

    Dispatch rules:
        - tensor mode when ``block_size is None``
        - tensor mode when ``block_size == [N, K]`` (one block spans the weight)
        - otherwise block mode

    Returns:
        Output tensor ``[S, N]`` in the same dtype as ``A``, in expert-sorted order.
    """
    if block_size is not None:
        # A block covering the full (N, K) weight is equivalent to a single
        # per-tensor scale, so it also routes to tensor mode below.
        spans_full_weight = (
            block_size[0] == B.size(1) and block_size[1] == B.size(2)
        )
        if not spans_full_weight:
            return w8a8_block_fp8_matmul_grouped(
                A, B, Bs, offsets, tokens_per_expert, block_size
            )

    return w8a8_tensor_fp8_matmul_grouped(A, B, Bs, offsets, tokens_per_expert)
|
build/torch-cuda/matmul.py
ADDED
|
@@ -0,0 +1,411 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import triton
|
| 17 |
+
import triton.language as tl
|
| 18 |
+
from torch.library import triton_op, wrap_triton
|
| 19 |
+
|
| 20 |
+
from .utils import device_context
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/quantization/fp8_kernel.py
|
| 24 |
+
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=w, num_stages=s)
        for w in [2, 4, 8, 16]
        for s in [2, 3, 4]
    ],
    key=["N", "K", "BLOCK_SIZE_M"],
)
@triton.jit
def w8a8_block_fp8_matmul_kernel(
    # Pointers to inputs and output
    A,  # (M, K) FP8 activations
    B,  # (N, K) FP8 weights
    C,  # (M, N) output
    As,  # per-token-group activation scales, one per (row, K-block)
    Bs,  # per-block weight scales, one per (N-block, K-block)
    # Shape for matmul
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_as_m,
    stride_as_k,
    stride_bs_k,
    stride_bs_n,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Block-scale FP8 GEMM kernel.

    Computes ``C = A @ B.T`` with block-wise activation/weight scales.
    Uses a 2D grid with swizzle for L2 cache locality on B tiles.
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M)

    # `% M` / `% N` wrap out-of-range lanes back in-bounds so the loads below
    # need no row/col mask; the final store masks the true tail instead.
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # One activation scale per row (advanced along K in the loop), one weight
    # scale per N-block column.
    as_ptrs = As + offs_am * stride_as_m
    offs_bsn = offs_bn // BLOCK_SIZE_N
    bs_ptrs = Bs + offs_bsn * stride_bs_n

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Mask the ragged final K-block (K need not be BLOCK_SIZE_K-aligned).
        k_remaining = K - k * BLOCK_SIZE_K
        a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)

        a_s = tl.load(as_ptrs + k * stride_as_k)
        b_s = tl.load(bs_ptrs + k * stride_bs_k)

        # Dequantize per K-block while accumulating in fp32.
        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Down-cast to the output dtype.
    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=w, num_stages=s)
        for w in [2, 4, 8, 16]
        for s in [2, 3, 4]
    ],
    key=["N", "K", "BLOCK_SIZE_M"],
)
@triton.jit
def w8a8_tensor_fp8_matmul_kernel(
    A,  # (M, K) FP8 activations
    B,  # (N, K) FP8 weights
    C,  # (M, N) output
    As,  # (M,) per-row activation scales
    Bs,  # scalar weight scale
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_as_m,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Tensor-scale FP8 GEMM kernel.

    Computes ``C = A @ B.T`` with one activation scale per row and one
    weight scale for the full matrix.
    Uses a 2D grid with swizzle for L2 cache locality on B tiles.
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M)

    # `% M` / `% N` wrap out-of-range lanes in-bounds so loads go unmasked;
    # the final store masks the true tail instead.
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    a_ptrs = A + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = B + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn

    a_s = tl.load(As + offs_am * stride_as_m)
    b_s = tl.load(Bs)

    # Accumulate raw dot products, apply scales once after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Mask only the ragged last K-block.
        k_remaining = K - k * BLOCK_SIZE_K
        a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
        accumulator += tl.dot(a, b)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Dequantize once: scales are constant across K.
    accumulator = accumulator * a_s[:, None] * b_s

    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
@triton_op("finegrained_fp8::w8a8_block_fp8_matmul", mutates_args=())
def _w8a8_block_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: list[int],
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Block-scale FP8 matmul: C = A @ B.T with per-block scales.

    A: (..., K) FP8 activations; leading dims are flattened into M
    B: (N, K) FP8 weights
    As: (M, K // block_k) — per-token-group activation scales
    Bs: (N // block_n, K // block_k) — per-block weight scales

    Returns a (..., N) tensor of ``output_dtype``.
    NOTE(review): A.stride(-2) below requires A.ndim >= 2 — confirm callers
    never pass a 1-D A.
    """
    assert len(block_size) == 2, (
        f"block_size must be [block_n, block_k], got {block_size}"
    )
    block_n, block_k = block_size[0], block_size[1]

    assert A.shape[-1] == B.shape[-1], (
        f"K mismatch: A has K={A.shape[-1]}, B has K={B.shape[-1]}"
    )
    assert A.is_contiguous(), "A must be contiguous"
    assert B.ndim == 2, f"B must be 2D (N, K), got ndim={B.ndim}"
    assert B.is_contiguous(), "B must be contiguous"

    N, K = B.shape
    M = A.numel() // A.shape[-1]  # all leading dims collapse into M rows

    assert As.ndim >= 2, f"As must be at least 2D, got ndim={As.ndim}"
    assert As.shape[-1] == triton.cdiv(K, block_k), (
        f"As last dim {As.shape[-1]} != expected {triton.cdiv(K, block_k)} (cdiv(K={K}, block_k={block_k}))"
    )
    assert Bs.ndim == 2, f"Bs must be 2D (N//block_n, K//block_k), got ndim={Bs.ndim}"
    assert Bs.shape == (triton.cdiv(N, block_n), triton.cdiv(K, block_k)), (
        f"Bs shape {tuple(Bs.shape)} != expected ({triton.cdiv(N, block_n)}, {triton.cdiv(K, block_k)})"
    )

    # Kernel tile sizes along N/K mirror the quantization block shape.
    BLOCK_SIZE_K = block_k
    BLOCK_SIZE_N = block_n
    C_shape = A.shape[:-1] + (N,)
    C = A.new_empty(C_shape, dtype=output_dtype)
    # Adaptive BLOCK_SIZE_M: smallest power-of-2 >= M, floored at 16, capped at 128.
    # Matches the WGMMA tile to the actual row count — smaller tiles use less
    # register pressure and a better-matched FP8 WGMMA instruction, improving
    # both accuracy and performance for small M (decode).
    BLOCK_SIZE_M = min(max(triton.next_power_of_2(M), 16), 128)
    grid = (triton.cdiv(M, BLOCK_SIZE_M), triton.cdiv(N, BLOCK_SIZE_N))
    with device_context(A.device):
        wrap_triton(w8a8_block_fp8_matmul_kernel)[grid](
            A,
            B,
            C,
            As,
            Bs,
            M,
            N,
            K,
            A.stride(-2),
            A.stride(-1),
            B.stride(1),  # stride_bk — B is (N, K); strides swapped so the
            B.stride(0),  # stride_bn   kernel computes A @ B.T
            C.stride(-2),
            C.stride(-1),
            As.stride(-2),
            As.stride(-1),
            Bs.stride(1),  # stride_bs_k — same transposition for the scale grid
            Bs.stride(0),  # stride_bs_n
            # Meta-parameters
            BLOCK_SIZE_M=BLOCK_SIZE_M,
            BLOCK_SIZE_N=BLOCK_SIZE_N,
            BLOCK_SIZE_K=BLOCK_SIZE_K,
            GROUP_SIZE_M=8,
        )

    return C
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
@triton_op("finegrained_fp8::w8a8_tensor_fp8_matmul", mutates_args=())
def _w8a8_tensor_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Tensor-scale FP8 matmul: C = A @ B.T with per-row / per-tensor scales.

    As: scalar, (M,), or (M, 1) — per-row activation scales
    Bs: scalar, (1,), or (1, 1) — single weight scale
    """
    assert A.shape[-1] == B.shape[-1], (
        f"K mismatch: A has K={A.shape[-1]}, B has K={B.shape[-1]}"
    )
    assert A.is_contiguous(), "A must be contiguous"
    assert B.ndim == 2, f"B must be 2D (N, K), got ndim={B.ndim}"
    assert B.is_contiguous(), "B must be contiguous"

    N, K = B.shape
    # Total row count folded across any leading batch dimensions of A.
    M = A.numel() // A.shape[-1]

    # Normalize As to (M,)
    if As.numel() == 1:
        # Single scalar scale: replicate it so the kernel sees one scale per row.
        As = As.reshape(1).expand(M).contiguous()
    elif As.ndim == 2:
        As = As.reshape(M)
    assert As.ndim == 1 and As.shape[0] == M, (
        f"As must be scalar, (M,), or (M,1) with M={M}, got {tuple(As.shape)}"
    )

    # Normalize Bs to (1,)
    assert Bs.numel() == 1, f"Bs must be scalar or (1,), got {tuple(Bs.shape)}"
    Bs = Bs.reshape(1)

    BLOCK_SIZE_N = 128
    BLOCK_SIZE_K = 128
    C_shape = A.shape[:-1] + (N,)
    C = A.new_empty(C_shape, dtype=output_dtype)
    # Adaptive BLOCK_SIZE_M: smallest power-of-2 >= M, floored at 16, capped
    # at 128, so small-M (decode) launches use a tile matched to the row count.
    BLOCK_SIZE_M = min(max(triton.next_power_of_2(M), 16), 128)
    grid = (triton.cdiv(M, BLOCK_SIZE_M), triton.cdiv(N, BLOCK_SIZE_N))
    # NOTE(review): A.stride(-2) below requires A.ndim >= 2 even though M is
    # computed generically; a 1-D A would raise here — confirm callers.
    with device_context(A.device):
        wrap_triton(w8a8_tensor_fp8_matmul_kernel)[grid](
            A,
            B,
            C,
            As,
            Bs,
            M,
            N,
            K,
            A.stride(-2),
            A.stride(-1),
            # B is (N, K); strides are passed inner-dim (K) first, consistent
            # with the C = A @ B.T contract in the docstring.
            B.stride(1),
            B.stride(0),
            C.stride(-2),
            C.stride(-1),
            As.stride(0),
            # Meta-parameters
            BLOCK_SIZE_M=BLOCK_SIZE_M,
            BLOCK_SIZE_N=BLOCK_SIZE_N,
            BLOCK_SIZE_K=BLOCK_SIZE_K,
            GROUP_SIZE_M=8,
        )

    return C
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def w8a8_block_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: list[int],
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Public entry point for block-wise W8A8 FP8 matrix multiplication.

    Forwards to the registered custom op
    ``finegrained_fp8::w8a8_block_fp8_matmul``, which computes ``C = A @ B.T``
    from operands pre-quantized to ``float8_e4m3fn`` with per-block scales,
    accumulating in float32 before casting to ``output_dtype``.

    Args:
        A: Quantized activation tensor ``[M, K]`` in ``float8_e4m3fn``.
        B: Quantized weight tensor ``[N, K]`` in ``float8_e4m3fn``.
        As: Per-token-group activation scales ``[M, K // block_size[1]]``.
        Bs: Per-block weight scales ``[N // block_size[0], K // block_size[1]]``.
        block_size: ``[block_n, block_k]`` quantization block dimensions, e.g. ``[128, 128]``.
        output_dtype: dtype of the returned tensor (default: ``torch.float32``).

    Returns:
        Output tensor ``[M, N]`` in ``output_dtype``.
    """
    block_matmul_op = torch.ops.finegrained_fp8.w8a8_block_fp8_matmul
    return block_matmul_op(A, B, As, Bs, block_size, output_dtype)
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
def w8a8_tensor_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Public entry point for tensor-scale W8A8 FP8 matrix multiplication.

    Forwards to the registered custom op
    ``finegrained_fp8::w8a8_tensor_fp8_matmul``, computing ``C = A @ B.T``
    from pre-quantized FP8 activations/weights using tensor scales.

    Args:
        A: Quantized activation tensor ``[M, K]`` in ``float8_e4m3fn``.
        B: Quantized weight tensor ``[N, K]`` in ``float8_e4m3fn``.
        As: Per-row activation scales ``[M]``.
        Bs: Single weight scale, scalar or ``[1]``.
        output_dtype: dtype of the returned tensor.

    Returns:
        Output tensor ``[M, N]`` in ``output_dtype``.
    """
    tensor_matmul_op = torch.ops.finegrained_fp8.w8a8_tensor_fp8_matmul
    return tensor_matmul_op(A, B, As, Bs, output_dtype)
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
def w8a8_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: list[int] | None,
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Unified W8A8 FP8 matmul dispatcher.

    Routes to tensor mode when ``block_size`` is ``None`` or spans the whole
    weight matrix (``block_size == [N, K]`` for a ``[N, K]`` weight ``B``);
    otherwise dispatches to block mode.

    Returns:
        Output tensor ``[M, N]`` in ``output_dtype``.
    """
    covers_whole_weight = block_size is not None and (
        block_size[0] == B.size(0) and block_size[1] == B.size(1)
    )
    if block_size is None or covers_whole_weight:
        return w8a8_tensor_fp8_matmul(A, B, As, Bs, output_dtype)
    return w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype)
|
build/torch-cuda/metadata.json
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"version": 1,
|
| 3 |
+
"license": "Apache-2.0",
|
| 4 |
+
"python-depends": [],
|
| 5 |
+
"backend": {
|
| 6 |
+
"type": "cuda"
|
| 7 |
+
}
|
| 8 |
+
}
|
build/torch-cuda/utils.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from contextlib import contextmanager
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
@contextmanager
def device_context(device: torch.device):
    """Make *device* the active device for the duration of the block.

    Works for any backend that exposes a ``device`` context manager under
    ``torch.<backend>`` (cuda, xpu, ...); for backends without one it is a
    no-op.
    """
    backend_mod = getattr(torch, device.type, None)
    if backend_mod is None or not hasattr(backend_mod, "device"):
        yield
    else:
        with backend_mod.device(device):
            yield
|
build/torch-rocm/__init__.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .act_quant import fp8_act_quant
|
| 2 |
+
from .batched import (
|
| 3 |
+
w8a8_fp8_matmul_batched,
|
| 4 |
+
w8a8_block_fp8_matmul_batched,
|
| 5 |
+
w8a8_tensor_fp8_matmul_batched,
|
| 6 |
+
)
|
| 7 |
+
from .grouped import (
|
| 8 |
+
w8a8_fp8_matmul_grouped,
|
| 9 |
+
w8a8_block_fp8_matmul_grouped,
|
| 10 |
+
w8a8_tensor_fp8_matmul_grouped,
|
| 11 |
+
)
|
| 12 |
+
from .matmul import (
|
| 13 |
+
w8a8_fp8_matmul,
|
| 14 |
+
w8a8_block_fp8_matmul,
|
| 15 |
+
w8a8_tensor_fp8_matmul,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
__all__ = [
|
| 19 |
+
"fp8_act_quant",
|
| 20 |
+
# Single matmul
|
| 21 |
+
"w8a8_fp8_matmul",
|
| 22 |
+
"w8a8_block_fp8_matmul",
|
| 23 |
+
"w8a8_tensor_fp8_matmul",
|
| 24 |
+
# Batched matmul
|
| 25 |
+
"w8a8_fp8_matmul_batched",
|
| 26 |
+
"w8a8_block_fp8_matmul_batched",
|
| 27 |
+
"w8a8_tensor_fp8_matmul_batched",
|
| 28 |
+
# Grouped matmul
|
| 29 |
+
"w8a8_fp8_matmul_grouped",
|
| 30 |
+
"w8a8_block_fp8_matmul_grouped",
|
| 31 |
+
"w8a8_tensor_fp8_matmul_grouped",
|
| 32 |
+
]
|
build/torch-rocm/_ops.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch

# Handle to this build's registered op namespace; the hash suffix ties the
# ops to this specific compiled artifact.
ops = torch.ops._finegrained_fp8_75cbe1b

def add_op_namespace_prefix(op_name: str) -> str:
    """
    Prefix op by namespace.
    """
    return f"_finegrained_fp8_75cbe1b::{op_name}"
|
build/torch-rocm/act_quant.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import triton
|
| 17 |
+
import triton.language as tl
|
| 18 |
+
from torch.library import triton_op, wrap_triton
|
| 19 |
+
|
| 20 |
+
from .utils import device_context
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
_FP8_DTYPE = torch.float8_e4m3fn
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# Copied from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
@triton.jit
def _fp8_act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
    # Each program quantizes one contiguous block of BLOCK_SIZE elements:
    # scale = max(|x|) / fp8_max, then y = x / scale cast to FP8.
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Loads/stores are unmasked: the Python wrapper guarantees the last dim
    # divides evenly into BLOCK_SIZE.
    x = tl.load(x_ptr + offs).to(tl.float32)
    s = tl.max(tl.abs(x)) / 448.0  # float8_e4m3fn max
    # NOTE(review): an all-zero block gives s == 0, so x / s produces NaN;
    # the batched kernels clamp with tl.maximum(s, 1e-12) — confirm whether
    # the same guard is intended here.
    y = (x / s).to(y_ptr.dtype.element_ty)
    tl.store(y_ptr + offs, y)
    tl.store(s_ptr + pid, s)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@triton_op("finegrained_fp8::fp8_act_quant", mutates_args=())
def _fp8_act_quant(
    x: torch.Tensor, block_size: int = 128
) -> tuple[torch.Tensor, torch.Tensor]:
    # The kernel loads full blocks without masking, so the input must be
    # contiguous and its last dimension an exact multiple of block_size.
    assert x.is_contiguous()
    assert x.shape[-1] % block_size == 0
    y = torch.empty_like(x, dtype=_FP8_DTYPE)
    # One program per block of `block_size` consecutive elements.
    grid = (triton.cdiv(x.numel(), block_size),)
    # One float32 scale per block along the last dimension.
    s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)

    with device_context(x.device):
        wrap_triton(_fp8_act_quant_kernel)[grid](x, y, s, BLOCK_SIZE=block_size)

    return y, s
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def fp8_act_quant(
    x: torch.Tensor, block_size: int = 128
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize activations to FP8 with per-block dynamic scaling.

    Splits the last dimension of ``x`` into blocks of ``block_size`` elements,
    computes ``scale = max(|x_block|) / 448`` per block, and quantizes to
    ``float8_e4m3fn``.

    Args:
        x: Input tensor in bf16/fp16/fp32. Last dimension must be divisible by
            ``block_size`` and the tensor must be contiguous.
        block_size: Number of elements per quantization block (default: 128).

    Returns:
        A tuple ``(quantized, scales)`` where ``quantized`` has dtype
        ``float8_e4m3fn`` with the same shape as ``x``, and ``scales`` has
        shape ``(*x.shape[:-1], x.shape[-1] // block_size)`` in float32.
    """
    # Thin wrapper over the registered custom op (see _fp8_act_quant) so the
    # call composes with torch.compile and dispatch.
    return torch.ops.finegrained_fp8.fp8_act_quant(x, block_size)
|
build/torch-rocm/batched.py
ADDED
|
@@ -0,0 +1,390 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import triton
|
| 17 |
+
import triton.language as tl
|
| 18 |
+
from .act_quant import fp8_act_quant
|
| 19 |
+
from torch.library import triton_op, wrap_triton
|
| 20 |
+
|
| 21 |
+
from .utils import device_context
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=w, num_stages=s)
        for w in [2, 4, 8, 16]
        for s in [2, 3, 4, 5]
    ],
    key=["N", "K"],
)
@triton.jit
def w8a8_block_fp8_matmul_batched_kernel(
    A,  # (S, K) raw BF16/FP16 activations
    B,  # (E, N, K) FP8 weight matrices
    C,  # (S, N) output
    Bs,  # (E, N // BLOCK_SIZE_N, K // BLOCK_SIZE_K) weight scales
    ExpertIds,  # (S,) — which expert each batch element routes to
    # Shape
    S,
    N,
    K,
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_bs_e,
    stride_bs_k,
    stride_bs_n,
    # Meta-parameters
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
):
    """Block-scale batched FP8 expert matmul kernel.

    Each program handles one routed token row and one N-tile, looks up the
    owning expert from ``ExpertIds``, and applies fused activation quantization.
    """
    batch_id = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)

    # Cast expert_id to int64 to prevent int32 overflow when computing
    # expert_id * stride_Eb (e.g. 255 * 9_437_184 > 2^31 for 256 experts of
    # 3072×3072 FP8 weights).
    expert_id = tl.load(ExpertIds + batch_id).to(tl.int64)

    A = A + batch_id * stride_am
    B = B + expert_id * stride_be
    C = C + batch_id * stride_cm
    Bs = Bs + expert_id * stride_bs_e

    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # The row-index term is multiplied by 0: every row of the M-tile aliases
    # the single token row owned by this program.
    a_ptrs = A + tl.arange(0, BLOCK_SIZE_M)[:, None] * 0 + offs_k[None, :] * stride_ak
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    bs_ptrs = Bs + pid_n * stride_bs_n

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    # Loads below are unmasked; the Python wrapper asserts N % BLOCK_SIZE_N == 0
    # and K % BLOCK_SIZE_K == 0.
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # ---- fused fp8_act_quant ----
        # a_s is recomputed inside the loop, i.e. one dynamic scale per
        # (token, K-block); 1e-12 clamp avoids NaN on all-zero blocks.
        a_raw = tl.load(a_ptrs).to(tl.float32)
        a_s = tl.max(tl.abs(a_raw)) / 448.0
        a = (a_raw / tl.maximum(a_s, 1e-12)).to(tl.float8e4nv)
        # ---- matmul ----
        b = tl.load(b_ptrs)
        b_s = tl.load(bs_ptrs + k * stride_bs_k)
        accumulator += tl.dot(a, b) * a_s * b_s[None, :]
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Cast the float32 accumulator to the output element type.
    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    offs_cm = tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    # Same `* 0` aliasing trick: all M-tile rows write the one output row.
    c_ptrs = C + offs_cm[:, None] * 0 + stride_cn * offs_cn[None, :]
    tl.store(c_ptrs, c)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=w, num_stages=s)
        for w in [2, 4, 8, 16]
        for s in [2, 3, 4, 5]
    ],
    key=["N", "K"],
)
@triton.jit
def w8a8_tensor_fp8_matmul_batched_kernel(
    A,  # (S, K) pre-quantized FP8 activations
    B,  # (E, N, K) FP8 weight matrices
    C,  # (S, N) output
    As,  # (S, 1) per-tensor activation scales
    Bs,  # (E, 1, 1) per-tensor weight scales
    ExpertIds,
    S,
    N,
    K,
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_as_m,
    stride_bs_e,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
):
    """Tensor-scale batched FP8 expert matmul kernel.

    Activations are already quantized; the kernel applies per-token activation
    scales and per-expert tensor weight scales.
    """
    batch_id = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)

    # int64 cast guards expert_id * stride_be against int32 overflow for
    # large expert counts / weight matrices.
    expert_id = tl.load(ExpertIds + batch_id).to(tl.int64)

    A = A + batch_id * stride_am
    B = B + expert_id * stride_be
    C = C + batch_id * stride_cm
    Bs = Bs + expert_id * stride_bs_e

    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # Row-index term times 0: every row of the M-tile aliases this token row.
    a_ptrs = A + tl.arange(0, BLOCK_SIZE_M)[:, None] * 0 + offs_k[None, :] * stride_ak
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # One scalar scale per expert / per token; hoisted out of the K loop.
    b_s = tl.load(Bs)
    a_s = tl.load(As + batch_id * stride_as_m)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    # Loads are unmasked — assumes N and K are exact multiples of the block
    # sizes; TODO(review): confirm all callers guarantee this.
    for _ in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        accumulator += tl.dot(a, b)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Apply both scalar scales once after accumulation.
    accumulator = accumulator * a_s * b_s

    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    offs_cm = tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = C + offs_cm[:, None] * 0 + stride_cn * offs_cn[None, :]
    tl.store(c_ptrs, c)
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
@triton_op("finegrained_fp8::w8a8_block_fp8_matmul_batched", mutates_args=())
def _w8a8_block_fp8_matmul_batched(
    A: torch.Tensor,
    B: torch.Tensor,
    Bs: torch.Tensor,
    expert_ids: torch.Tensor,
    block_size: list[int],
) -> torch.Tensor:
    """Block-scale batched FP8 matmul: C[s] = A[s] @ B[expert_ids[s]].T, with fused act quant.

    A: (S, K) raw bf16/fp16 activations
    B: (E, N, K) FP8 expert weights
    Bs: (E, N // block_n, K // block_k) per-block weight scales
    """
    assert A.ndim == 2, f"A must be 2D (S, K), got ndim={A.ndim}"
    assert A.is_contiguous(), "A must be contiguous"
    assert B.ndim == 3, f"B must be 3D (E, N, K), got ndim={B.ndim}"
    assert B.is_contiguous(), "B must be contiguous"
    assert A.shape[1] == B.shape[2], (
        f"K mismatch: A has K={A.shape[1]}, B has K={B.shape[2]}"
    )

    S, K = A.shape
    E, N, _ = B.shape

    assert len(block_size) == 2, (
        f"block_size must be [block_n, block_k], got {block_size}"
    )
    block_n, block_k = block_size[0], block_size[1]
    # MoE expert dimensions must be block-aligned; non-aligned N/K is not
    # supported because the kernel's loads/stores are unmasked.
    assert N % block_n == 0, f"N ({N}) must be divisible by block_n ({block_n})"
    assert K % block_k == 0, f"K ({K}) must be divisible by block_k ({block_k})"
    assert Bs.ndim == 3, (
        f"Bs must be 3D (E, N//block_n, K//block_k), got ndim={Bs.ndim}"
    )
    assert Bs.shape == (E, N // block_n, K // block_k), (
        f"Bs shape {tuple(Bs.shape)} != expected ({E}, {N // block_n}, {K // block_k})"
    )

    # Output keeps A's dtype (activation quantization happens inside the kernel).
    C = A.new_empty(S, N)
    # Adaptive BLOCK_SIZE_M: smallest power-of-2 >= M, floored at 16, capped at 128.
    # Matches the WGMMA tile to the actual row count — smaller tiles use less
    # register pressure and a better-matched FP8 WGMMA instruction, improving
    # both accuracy and performance for small M (decode). (S + E - 1) // E is
    # the average tokens-per-expert, used as the M estimate.
    BLOCK_SIZE_M = min(max(triton.next_power_of_2((S + E - 1) // E), 16), 128)
    # One program per (token, N-tile).
    grid = (S, triton.cdiv(N, block_n))
    with device_context(A.device):
        wrap_triton(w8a8_block_fp8_matmul_batched_kernel)[grid](
            A,
            B,
            C,
            Bs,
            expert_ids,
            S,
            N,
            K,
            A.stride(0),
            A.stride(1),
            # B is (E, N, K): expert stride, then K stride, then N stride.
            B.stride(0),
            B.stride(2),
            B.stride(1),
            C.stride(0),
            C.stride(1),
            # Bs is (E, N-blocks, K-blocks): expert, K-block, N-block strides.
            Bs.stride(0),
            Bs.stride(2),
            Bs.stride(1),
            BLOCK_SIZE_N=block_n,
            BLOCK_SIZE_K=block_k,
            BLOCK_SIZE_M=BLOCK_SIZE_M,
        )

    return C
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
@triton_op("finegrained_fp8::w8a8_tensor_fp8_matmul_batched", mutates_args=())
def _w8a8_tensor_fp8_matmul_batched(
    A: torch.Tensor,
    B: torch.Tensor,
    Bs: torch.Tensor,
    expert_ids: torch.Tensor,
) -> torch.Tensor:
    """Tensor-scale batched FP8 matmul: C[s] = A[s] @ B[expert_ids[s]].T, with fused act quant.

    A: (S, K) raw bf16/fp16 activations
    B: (E, N, K) FP8 expert weights
    Bs: (E,) or (E, 1, 1) per-expert weight scales

    Returns:
        (S, N) output in the same dtype as ``A``.
    """
    assert A.ndim == 2, f"A must be 2D (S, K), got ndim={A.ndim}"
    assert A.is_contiguous(), "A must be contiguous"
    assert B.ndim == 3, f"B must be 3D (E, N, K), got ndim={B.ndim}"
    assert B.is_contiguous(), "B must be contiguous"
    assert A.shape[1] == B.shape[2], (
        f"K mismatch: A has K={A.shape[1]}, B has K={B.shape[2]}"
    )

    S, K = A.shape
    E, N, _ = B.shape

    # Normalize Bs to (E, 1, 1)
    if Bs.ndim == 1:
        assert Bs.shape[0] == E, f"Bs shape {tuple(Bs.shape)} != expected ({E},)"
        Bs = Bs.reshape(E, 1, 1)
    else:
        assert Bs.shape == (E, 1, 1), (
            f"Bs shape {tuple(Bs.shape)} != expected ({E}, 1, 1)"
        )

    BLOCK_SIZE_N = 128
    BLOCK_SIZE_K = 128
    # The kernel's loads/stores are unmasked, so N and K must tile evenly —
    # same requirement the block-scale variant already enforces. Without this
    # guard a non-aligned K/N would read/write out of bounds.
    assert K % BLOCK_SIZE_K == 0, (
        f"K ({K}) must be divisible by BLOCK_SIZE_K ({BLOCK_SIZE_K})"
    )
    assert N % BLOCK_SIZE_N == 0, (
        f"N ({N}) must be divisible by BLOCK_SIZE_N ({BLOCK_SIZE_N})"
    )
    C = A.new_empty(S, N)
    # Quantize each token row as one block (block_size=K), i.e. one dynamic
    # scale per token — the "tensor" scale the kernel expects.
    qA, As = fp8_act_quant(A, K)
    # Adaptive BLOCK_SIZE_M: smallest power-of-2 >= M, floored at 16, capped at 128.
    # Matches the WGMMA tile to the actual row count — smaller tiles use less
    # register pressure and a better-matched FP8 WGMMA instruction, improving
    # both accuracy and performance for small M (decode).
    BLOCK_SIZE_M = min(max(triton.next_power_of_2((S + E - 1) // E), 16), 128)
    # One program per (token, N-tile). (Previously computed twice; deduplicated.)
    grid = (S, triton.cdiv(N, BLOCK_SIZE_N))
    with device_context(A.device):
        wrap_triton(w8a8_tensor_fp8_matmul_batched_kernel)[grid](
            qA,
            B,
            C,
            As,
            Bs,
            expert_ids,
            S,
            N,
            K,
            qA.stride(0),
            qA.stride(1),
            # B is (E, N, K): expert stride, then K stride, then N stride.
            B.stride(0),
            B.stride(2),
            B.stride(1),
            C.stride(0),
            C.stride(1),
            As.stride(0),
            Bs.stride(0),
            BLOCK_SIZE_N=BLOCK_SIZE_N,
            BLOCK_SIZE_K=BLOCK_SIZE_K,
            BLOCK_SIZE_M=BLOCK_SIZE_M,
        )

    return C
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def w8a8_block_fp8_matmul_batched(
    A: torch.Tensor,
    B: torch.Tensor,
    Bs: torch.Tensor,
    expert_ids: torch.Tensor,
    block_size: list[int],
) -> torch.Tensor:
    """Block-scale batched FP8 matmul with fused activation quantization.

    Forwards to the registered custom op
    ``finegrained_fp8::w8a8_block_fp8_matmul_batched``.

    A: (S, K) raw activations, bf16/fp16/fp32
    B: (E, N, K) FP8 expert weights
    Bs: (E, N // block_n, K // block_k) per-block weight scales
    """
    batched_block_op = torch.ops.finegrained_fp8.w8a8_block_fp8_matmul_batched
    return batched_block_op(A, B, Bs, expert_ids, block_size)
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
def w8a8_tensor_fp8_matmul_batched(
    A: torch.Tensor,
    B: torch.Tensor,
    Bs: torch.Tensor,
    expert_ids: torch.Tensor,
) -> torch.Tensor:
    """Tensor-scale batched FP8 matmul with fused activation quantization.

    Forwards to the registered custom op
    ``finegrained_fp8::w8a8_tensor_fp8_matmul_batched``.

    A: (S, K) raw activations, bf16/fp16/fp32
    B: (E, N, K) FP8 expert weights
    Bs: (E,) or (E, 1, 1) per-expert weight scales
    """
    batched_tensor_op = torch.ops.finegrained_fp8.w8a8_tensor_fp8_matmul_batched
    return batched_tensor_op(A, B, Bs, expert_ids)
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
def w8a8_fp8_matmul_batched(
    A: torch.Tensor,
    B: torch.Tensor,
    Bs: torch.Tensor,
    expert_ids: torch.Tensor,
    block_size: list[int] | None,
) -> torch.Tensor:
    """Unified batched W8A8 FP8 matmul dispatcher.

    Routes to tensor mode when ``block_size`` is ``None`` or spans a whole
    expert weight matrix (``block_size == [N, K]`` for ``B`` of shape
    ``(E, N, K)``); otherwise dispatches to block mode.

    Returns:
        Output tensor ``[S, N]`` in the same dtype as ``A``.
    """
    covers_whole_expert = block_size is not None and (
        block_size[0] == B.size(1) and block_size[1] == B.size(2)
    )
    if block_size is None or covers_whole_expert:
        return w8a8_tensor_fp8_matmul_batched(A, B, Bs, expert_ids)
    return w8a8_block_fp8_matmul_batched(A, B, Bs, expert_ids, block_size)
|
build/torch-rocm/finegrained_fp8/__init__.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import ctypes
|
| 2 |
+
import importlib.util
|
| 3 |
+
import sys
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from types import ModuleType
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def _import_from_path(file_path: Path) -> ModuleType:
|
| 9 |
+
# We cannot use the module name as-is, after adding it to `sys.modules`,
|
| 10 |
+
# it would also be used for other imports. So, we make a module name that
|
| 11 |
+
# depends on the path for it to be unique using the hex-encoded hash of
|
| 12 |
+
# the path.
|
| 13 |
+
path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
|
| 14 |
+
module_name = path_hash
|
| 15 |
+
spec = importlib.util.spec_from_file_location(module_name, file_path)
|
| 16 |
+
if spec is None:
|
| 17 |
+
raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
|
| 18 |
+
module = importlib.util.module_from_spec(spec)
|
| 19 |
+
if module is None:
|
| 20 |
+
raise ImportError(f"Cannot load module {module_name} from spec")
|
| 21 |
+
sys.modules[module_name] = module
|
| 22 |
+
spec.loader.exec_module(module) # type: ignore
|
| 23 |
+
return module
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# Re-export the parent build directory's package API under this subpackage
# name by executing its __init__.py and copying its namespace here.
globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
|
build/torch-rocm/grouped.py
ADDED
|
@@ -0,0 +1,477 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import triton
|
| 17 |
+
import triton.language as tl
|
| 18 |
+
from .act_quant import fp8_act_quant
|
| 19 |
+
from torch.library import triton_op, wrap_triton
|
| 20 |
+
|
| 21 |
+
from .utils import device_context
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@triton.autotune(
|
| 25 |
+
configs=[
|
| 26 |
+
triton.Config({}, num_warps=w, num_stages=s)
|
| 27 |
+
for w in [2, 4, 8, 16]
|
| 28 |
+
for s in [2, 3, 4, 5]
|
| 29 |
+
],
|
| 30 |
+
key=["N", "K", "BLOCK_SIZE_M"],
|
| 31 |
+
)
|
| 32 |
+
@triton.jit
|
| 33 |
+
def w8a8_block_fp8_matmul_grouped_kernel(
|
| 34 |
+
A, # (S, K) raw BF16/FP16 activations, sorted/grouped by expert id
|
| 35 |
+
B, # (E, N, K) FP8 weight matrices
|
| 36 |
+
C, # (S, N) output
|
| 37 |
+
Bs, # (E, N // BLOCK_SIZE_N, K // BLOCK_SIZE_K) weight scales
|
| 38 |
+
Offsets, # (E,) int32 — cumulative row-end per expert
|
| 39 |
+
TileOffsets, # (E,) int32 — cumulative tile-end per expert
|
| 40 |
+
# Shape
|
| 41 |
+
S,
|
| 42 |
+
N,
|
| 43 |
+
K,
|
| 44 |
+
# Strides
|
| 45 |
+
stride_am,
|
| 46 |
+
stride_ak,
|
| 47 |
+
stride_be,
|
| 48 |
+
stride_bk,
|
| 49 |
+
stride_bn,
|
| 50 |
+
stride_cm,
|
| 51 |
+
stride_cn,
|
| 52 |
+
stride_bs_e,
|
| 53 |
+
stride_bs_k,
|
| 54 |
+
stride_bs_n,
|
| 55 |
+
# Meta-parameters
|
| 56 |
+
NUM_EXPERTS: tl.constexpr,
|
| 57 |
+
BLOCK_SIZE_N: tl.constexpr,
|
| 58 |
+
BLOCK_SIZE_K: tl.constexpr,
|
| 59 |
+
BLOCK_SIZE_M: tl.constexpr,
|
| 60 |
+
NUM_EXPERTS_BIT_LENGTH: tl.constexpr,
|
| 61 |
+
):
|
| 62 |
+
"""Block-scale grouped FP8 expert matmul kernel.
|
| 63 |
+
|
| 64 |
+
Tokens are assumed sorted by expert. The kernel maps each M-tile to its
|
| 65 |
+
owning expert via ``TileOffsets`` and applies fused activation quantization.
|
| 66 |
+
"""
|
| 67 |
+
pid_m = tl.program_id(axis=0)
|
| 68 |
+
pid_n = tl.program_id(axis=1)
|
| 69 |
+
|
| 70 |
+
# Exit early for programs beyond the actual tile count.
|
| 71 |
+
total_tiles = tl.load(TileOffsets + NUM_EXPERTS - 1)
|
| 72 |
+
if pid_m >= total_tiles:
|
| 73 |
+
return
|
| 74 |
+
|
| 75 |
+
# Binary search in TileOffsets to find the owning expert.
|
| 76 |
+
# Finds the smallest e such that TileOffsets[e] > pid_m (upper_bound semantics),
|
| 77 |
+
# which is the expert whose tile range contains pid_m.
|
| 78 |
+
# O(log2(NUM_EXPERTS)) loads instead of the O(NUM_EXPERTS) linear scan.
|
| 79 |
+
# NUM_EXPERTS_BIT_LENGTH is ceil(log2(E))+1 for powers-of-two, giving one
|
| 80 |
+
# harmless extra iteration when lo==hi; it's a compile-time constant so the
|
| 81 |
+
# loop is fully unrolled by the compiler.
|
| 82 |
+
lo = 0
|
| 83 |
+
hi = NUM_EXPERTS
|
| 84 |
+
for _ in tl.static_range(NUM_EXPERTS_BIT_LENGTH):
|
| 85 |
+
mid = (lo + hi) >> 1
|
| 86 |
+
mid_val = tl.load(TileOffsets + mid)
|
| 87 |
+
is_left = mid_val <= pid_m
|
| 88 |
+
lo = tl.where(is_left, mid + 1, lo)
|
| 89 |
+
hi = tl.where(is_left, hi, mid)
|
| 90 |
+
|
| 91 |
+
# Cast expert_id to int64 to prevent int32 overflow when computing
|
| 92 |
+
# expert_id * stride_be (e.g. 255 * 9_437_184 > 2^31 for 256 experts of
|
| 93 |
+
# 3072×3072 FP8 weights).
|
| 94 |
+
expert_id = lo.to(tl.int64)
|
| 95 |
+
|
| 96 |
+
prev_eid = tl.maximum(expert_id - 1, 0)
|
| 97 |
+
expert_start = tl.where(expert_id == 0, 0, tl.load(Offsets + prev_eid))
|
| 98 |
+
expert_end = tl.load(Offsets + expert_id)
|
| 99 |
+
M_expert = expert_end - expert_start
|
| 100 |
+
|
| 101 |
+
expert_tile_start = tl.where(expert_id == 0, 0, tl.load(TileOffsets + prev_eid))
|
| 102 |
+
local_tile = pid_m - expert_tile_start
|
| 103 |
+
m_off = local_tile * BLOCK_SIZE_M
|
| 104 |
+
|
| 105 |
+
offs_am = m_off + tl.arange(0, BLOCK_SIZE_M)
|
| 106 |
+
row_mask = offs_am < M_expert
|
| 107 |
+
offs_global_m = expert_start + offs_am
|
| 108 |
+
|
| 109 |
+
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 110 |
+
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
| 111 |
+
|
| 112 |
+
a_ptrs = A + offs_global_m[:, None] * stride_am + offs_k[None, :] * stride_ak
|
| 113 |
+
b_ptrs = (
|
| 114 |
+
B
|
| 115 |
+
+ expert_id * stride_be
|
| 116 |
+
+ offs_k[:, None] * stride_bk
|
| 117 |
+
+ offs_bn[None, :] * stride_bn
|
| 118 |
+
)
|
| 119 |
+
bs_ptrs = Bs + expert_id * stride_bs_e + pid_n * stride_bs_n
|
| 120 |
+
|
| 121 |
+
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
| 122 |
+
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
| 123 |
+
# ---- fused fp8_act_quant ----
|
| 124 |
+
a_raw = tl.load(a_ptrs, mask=row_mask[:, None], other=0.0).to(tl.float32)
|
| 125 |
+
a_s = tl.max(tl.abs(a_raw), axis=1) / 448.0
|
| 126 |
+
a = (a_raw / tl.maximum(a_s[:, None], 1e-12)).to(tl.float8e4nv)
|
| 127 |
+
# ---- matmul ----
|
| 128 |
+
b = tl.load(b_ptrs)
|
| 129 |
+
b_s = tl.load(bs_ptrs + k * stride_bs_k)
|
| 130 |
+
accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
|
| 131 |
+
a_ptrs += BLOCK_SIZE_K * stride_ak
|
| 132 |
+
b_ptrs += BLOCK_SIZE_K * stride_bk
|
| 133 |
+
|
| 134 |
+
if C.dtype.element_ty == tl.bfloat16:
|
| 135 |
+
c = accumulator.to(tl.bfloat16)
|
| 136 |
+
elif C.dtype.element_ty == tl.float16:
|
| 137 |
+
c = accumulator.to(tl.float16)
|
| 138 |
+
else:
|
| 139 |
+
c = accumulator.to(tl.float32)
|
| 140 |
+
|
| 141 |
+
c_ptrs = C + stride_cm * offs_global_m[:, None] + stride_cn * offs_bn[None, :]
|
| 142 |
+
c_mask = row_mask[:, None]
|
| 143 |
+
tl.store(c_ptrs, c, mask=c_mask)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
@triton.autotune(
|
| 147 |
+
configs=[
|
| 148 |
+
triton.Config({}, num_warps=w, num_stages=s)
|
| 149 |
+
for w in [2, 4, 8, 16]
|
| 150 |
+
for s in [2, 3, 4, 5]
|
| 151 |
+
],
|
| 152 |
+
key=["N", "K", "BLOCK_SIZE_M"],
|
| 153 |
+
)
|
| 154 |
+
@triton.jit
|
| 155 |
+
def w8a8_tensor_fp8_matmul_grouped_kernel(
|
| 156 |
+
A, # (S, K) raw BF16/FP16 activations, sorted/grouped by expert idc
|
| 157 |
+
B, # (E, N, K) FP8 weight matrices
|
| 158 |
+
C, # (S, N) output
|
| 159 |
+
As, # (S, 1) activation scales
|
| 160 |
+
Bs, # (E, 1, 1) per-tensor weight scales
|
| 161 |
+
Offsets,
|
| 162 |
+
TileOffsets,
|
| 163 |
+
S,
|
| 164 |
+
N,
|
| 165 |
+
K,
|
| 166 |
+
stride_am,
|
| 167 |
+
stride_ak,
|
| 168 |
+
stride_be,
|
| 169 |
+
stride_bk,
|
| 170 |
+
stride_bn,
|
| 171 |
+
stride_cm,
|
| 172 |
+
stride_cn,
|
| 173 |
+
stride_as_m,
|
| 174 |
+
stride_bs_e,
|
| 175 |
+
NUM_EXPERTS: tl.constexpr,
|
| 176 |
+
BLOCK_SIZE_N: tl.constexpr,
|
| 177 |
+
BLOCK_SIZE_K: tl.constexpr,
|
| 178 |
+
BLOCK_SIZE_M: tl.constexpr,
|
| 179 |
+
NUM_EXPERTS_BIT_LENGTH: tl.constexpr,
|
| 180 |
+
):
|
| 181 |
+
"""Tensor-scale grouped FP8 expert matmul kernel.
|
| 182 |
+
|
| 183 |
+
Uses grouped expert scheduling with pre-quantized activations plus
|
| 184 |
+
per-token activation scales and per-expert tensor weight scales.
|
| 185 |
+
"""
|
| 186 |
+
pid_m = tl.program_id(axis=0)
|
| 187 |
+
pid_n = tl.program_id(axis=1)
|
| 188 |
+
|
| 189 |
+
total_tiles = tl.load(TileOffsets + NUM_EXPERTS - 1)
|
| 190 |
+
if pid_m >= total_tiles:
|
| 191 |
+
return
|
| 192 |
+
|
| 193 |
+
lo = 0
|
| 194 |
+
hi = NUM_EXPERTS
|
| 195 |
+
for _ in tl.static_range(NUM_EXPERTS_BIT_LENGTH):
|
| 196 |
+
mid = (lo + hi) >> 1
|
| 197 |
+
mid_val = tl.load(TileOffsets + mid)
|
| 198 |
+
is_left = mid_val <= pid_m
|
| 199 |
+
lo = tl.where(is_left, mid + 1, lo)
|
| 200 |
+
hi = tl.where(is_left, hi, mid)
|
| 201 |
+
expert_id = lo.to(tl.int64)
|
| 202 |
+
|
| 203 |
+
prev_eid = tl.maximum(expert_id - 1, 0)
|
| 204 |
+
expert_start = tl.where(expert_id == 0, 0, tl.load(Offsets + prev_eid))
|
| 205 |
+
expert_end = tl.load(Offsets + expert_id)
|
| 206 |
+
M_expert = expert_end - expert_start
|
| 207 |
+
|
| 208 |
+
expert_tile_start = tl.where(expert_id == 0, 0, tl.load(TileOffsets + prev_eid))
|
| 209 |
+
local_tile = pid_m - expert_tile_start
|
| 210 |
+
m_off = local_tile * BLOCK_SIZE_M
|
| 211 |
+
|
| 212 |
+
offs_am = m_off + tl.arange(0, BLOCK_SIZE_M)
|
| 213 |
+
row_mask = offs_am < M_expert
|
| 214 |
+
offs_global_m = expert_start + offs_am
|
| 215 |
+
|
| 216 |
+
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 217 |
+
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
| 218 |
+
|
| 219 |
+
a_ptrs = A + offs_global_m[:, None] * stride_am + offs_k[None, :] * stride_ak
|
| 220 |
+
b_ptrs = (
|
| 221 |
+
B
|
| 222 |
+
+ expert_id * stride_be
|
| 223 |
+
+ offs_k[:, None] * stride_bk
|
| 224 |
+
+ offs_bn[None, :] * stride_bn
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
a_s = tl.load(As + offs_global_m * stride_as_m, mask=row_mask, other=0.0)
|
| 228 |
+
b_s = tl.load(Bs + expert_id * stride_bs_e)
|
| 229 |
+
|
| 230 |
+
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
| 231 |
+
for _ in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
| 232 |
+
a = tl.load(a_ptrs, mask=row_mask[:, None], other=0.0)
|
| 233 |
+
b = tl.load(b_ptrs)
|
| 234 |
+
|
| 235 |
+
accumulator += tl.dot(a, b)
|
| 236 |
+
a_ptrs += BLOCK_SIZE_K * stride_ak
|
| 237 |
+
b_ptrs += BLOCK_SIZE_K * stride_bk
|
| 238 |
+
|
| 239 |
+
accumulator = accumulator * a_s[:, None] * b_s
|
| 240 |
+
|
| 241 |
+
if C.dtype.element_ty == tl.bfloat16:
|
| 242 |
+
c = accumulator.to(tl.bfloat16)
|
| 243 |
+
elif C.dtype.element_ty == tl.float16:
|
| 244 |
+
c = accumulator.to(tl.float16)
|
| 245 |
+
else:
|
| 246 |
+
c = accumulator.to(tl.float32)
|
| 247 |
+
|
| 248 |
+
c_ptrs = C + stride_cm * offs_global_m[:, None] + stride_cn * offs_bn[None, :]
|
| 249 |
+
c_mask = row_mask[:, None]
|
| 250 |
+
tl.store(c_ptrs, c, mask=c_mask)
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
@triton_op("finegrained_fp8::w8a8_block_fp8_matmul_grouped", mutates_args=())
|
| 254 |
+
def _w8a8_block_fp8_matmul_grouped(
|
| 255 |
+
A: torch.Tensor,
|
| 256 |
+
B: torch.Tensor,
|
| 257 |
+
Bs: torch.Tensor,
|
| 258 |
+
offsets: torch.Tensor,
|
| 259 |
+
tokens_per_expert: torch.Tensor,
|
| 260 |
+
block_size: list[int],
|
| 261 |
+
) -> torch.Tensor:
|
| 262 |
+
"""Block-scale grouped FP8 matmul: C = A @ B.T per expert, with fused act quant.
|
| 263 |
+
|
| 264 |
+
A: (S, K) raw bf16/fp16 activations, sorted by expert
|
| 265 |
+
B: (E, N, K) FP8 expert weights
|
| 266 |
+
Bs: (E, N // block_n, K // block_k) per-block weight scales
|
| 267 |
+
"""
|
| 268 |
+
assert A.ndim == 2, f"A must be 2D (S, K), got ndim={A.ndim}"
|
| 269 |
+
assert A.is_contiguous(), "A must be contiguous"
|
| 270 |
+
assert B.ndim == 3, f"B must be 3D (E, N, K), got ndim={B.ndim}"
|
| 271 |
+
assert B.is_contiguous(), "B must be contiguous"
|
| 272 |
+
assert A.shape[1] == B.shape[2], (
|
| 273 |
+
f"K mismatch: A has K={A.shape[1]}, B has K={B.shape[2]}"
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
S, K = A.shape
|
| 277 |
+
E, N, _ = B.shape
|
| 278 |
+
|
| 279 |
+
assert len(block_size) == 2, (
|
| 280 |
+
f"block_size must be [block_n, block_k], got {block_size}"
|
| 281 |
+
)
|
| 282 |
+
block_n, block_k = block_size[0], block_size[1]
|
| 283 |
+
# MoE expert dimensions must be block-aligned; non-aligned N/K is not supported.
|
| 284 |
+
assert N % block_n == 0, f"N ({N}) must be divisible by block_n ({block_n})"
|
| 285 |
+
assert K % block_k == 0, f"K ({K}) must be divisible by block_k ({block_k})"
|
| 286 |
+
assert Bs.ndim == 3, (
|
| 287 |
+
f"Bs must be 3D (E, N//block_n, K//block_k), got ndim={Bs.ndim}"
|
| 288 |
+
)
|
| 289 |
+
assert Bs.shape == (E, N // block_n, K // block_k), (
|
| 290 |
+
f"Bs shape {tuple(Bs.shape)} != expected ({E}, {N // block_n}, {K // block_k})"
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
C = A.new_empty(S, N)
|
| 294 |
+
# Adaptive BLOCK_SIZE_M: match tile to average tokens per expert.
|
| 295 |
+
BLOCK_SIZE_M = min(max(triton.next_power_of_2((S + E - 1) // E), 16), 128)
|
| 296 |
+
tiles_per_expert = (tokens_per_expert + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M
|
| 297 |
+
tile_offsets = torch.cumsum(tiles_per_expert, dim=0).to(torch.int32)
|
| 298 |
+
# Upper bound on M-tiles: sum_e ceil(M_e / BLOCK_M) <= ceil(S / BLOCK_M) + E.
|
| 299 |
+
# Programs beyond the real tile count exit immediately via the early-return
|
| 300 |
+
# guard inside the kernel. This is faster than syncing for the exact count
|
| 301 |
+
# and keeps the grid size data-independent (cuda-graph / torch.compile safe).
|
| 302 |
+
max_M_tiles = triton.cdiv(S, BLOCK_SIZE_M) + E
|
| 303 |
+
grid = (max_M_tiles, triton.cdiv(N, block_n))
|
| 304 |
+
with device_context(A.device):
|
| 305 |
+
wrap_triton(w8a8_block_fp8_matmul_grouped_kernel)[grid](
|
| 306 |
+
A,
|
| 307 |
+
B,
|
| 308 |
+
C,
|
| 309 |
+
Bs,
|
| 310 |
+
offsets,
|
| 311 |
+
tile_offsets,
|
| 312 |
+
S,
|
| 313 |
+
N,
|
| 314 |
+
K,
|
| 315 |
+
A.stride(0),
|
| 316 |
+
A.stride(1),
|
| 317 |
+
B.stride(0),
|
| 318 |
+
B.stride(2),
|
| 319 |
+
B.stride(1),
|
| 320 |
+
C.stride(0),
|
| 321 |
+
C.stride(1),
|
| 322 |
+
Bs.stride(0),
|
| 323 |
+
Bs.stride(2),
|
| 324 |
+
Bs.stride(1),
|
| 325 |
+
# Meta-parameters
|
| 326 |
+
NUM_EXPERTS=E,
|
| 327 |
+
BLOCK_SIZE_N=block_n,
|
| 328 |
+
BLOCK_SIZE_K=block_k,
|
| 329 |
+
BLOCK_SIZE_M=BLOCK_SIZE_M,
|
| 330 |
+
NUM_EXPERTS_BIT_LENGTH=E.bit_length(),
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
return C
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
@triton_op("finegrained_fp8::w8a8_tensor_fp8_matmul_grouped", mutates_args=())
|
| 337 |
+
def _w8a8_tensor_fp8_matmul_grouped(
|
| 338 |
+
A: torch.Tensor,
|
| 339 |
+
B: torch.Tensor,
|
| 340 |
+
Bs: torch.Tensor,
|
| 341 |
+
offsets: torch.Tensor,
|
| 342 |
+
tokens_per_expert: torch.Tensor,
|
| 343 |
+
) -> torch.Tensor:
|
| 344 |
+
"""Tensor-scale grouped FP8 matmul: C = A @ B.T per expert, with fused act quant.
|
| 345 |
+
|
| 346 |
+
A: (S, K) raw bf16/fp16 activations, sorted by expert
|
| 347 |
+
B: (E, N, K) FP8 expert weights
|
| 348 |
+
Bs: (E,) or (E, 1, 1) per-expert weight scales
|
| 349 |
+
"""
|
| 350 |
+
assert A.ndim == 2, f"A must be 2D (S, K), got ndim={A.ndim}"
|
| 351 |
+
assert A.is_contiguous(), "A must be contiguous"
|
| 352 |
+
assert B.ndim == 3, f"B must be 3D (E, N, K), got ndim={B.ndim}"
|
| 353 |
+
assert B.is_contiguous(), "B must be contiguous"
|
| 354 |
+
assert A.shape[1] == B.shape[2], (
|
| 355 |
+
f"K mismatch: A has K={A.shape[1]}, B has K={B.shape[2]}"
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
+
S, K = A.shape
|
| 359 |
+
E, N, _ = B.shape
|
| 360 |
+
|
| 361 |
+
# Normalize Bs to (E, 1, 1)
|
| 362 |
+
if Bs.ndim == 1:
|
| 363 |
+
assert Bs.shape[0] == E, f"Bs shape {tuple(Bs.shape)} != expected ({E},)"
|
| 364 |
+
Bs = Bs.reshape(E, 1, 1)
|
| 365 |
+
else:
|
| 366 |
+
assert Bs.shape == (E, 1, 1), (
|
| 367 |
+
f"Bs shape {tuple(Bs.shape)} != expected ({E}, 1, 1)"
|
| 368 |
+
)
|
| 369 |
+
|
| 370 |
+
BLOCK_SIZE_N = 128
|
| 371 |
+
BLOCK_SIZE_K = 128
|
| 372 |
+
C = A.new_empty(S, N)
|
| 373 |
+
qA, As = fp8_act_quant(A, K)
|
| 374 |
+
BLOCK_SIZE_M = min(max(triton.next_power_of_2((S + E - 1) // E), 16), 128)
|
| 375 |
+
tiles_per_expert = (tokens_per_expert + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M
|
| 376 |
+
tile_offsets = torch.cumsum(tiles_per_expert, dim=0).to(torch.int32)
|
| 377 |
+
# Upper bound on M-tiles: sum_e ceil(M_e / BLOCK_M) <= ceil(S / BLOCK_M) + E.
|
| 378 |
+
# Programs beyond the real tile count exit immediately via the early-return
|
| 379 |
+
# guard inside the kernel. This is faster than syncing for the exact count
|
| 380 |
+
# and keeps the grid size data-independent (cuda-graph / torch.compile safe).
|
| 381 |
+
max_M_tiles = triton.cdiv(S, BLOCK_SIZE_M) + E
|
| 382 |
+
grid = (max_M_tiles, triton.cdiv(N, BLOCK_SIZE_N))
|
| 383 |
+
with device_context(A.device):
|
| 384 |
+
wrap_triton(w8a8_tensor_fp8_matmul_grouped_kernel)[grid](
|
| 385 |
+
qA,
|
| 386 |
+
B,
|
| 387 |
+
C,
|
| 388 |
+
As,
|
| 389 |
+
Bs,
|
| 390 |
+
offsets,
|
| 391 |
+
tile_offsets,
|
| 392 |
+
S,
|
| 393 |
+
N,
|
| 394 |
+
K,
|
| 395 |
+
qA.stride(0),
|
| 396 |
+
qA.stride(1),
|
| 397 |
+
B.stride(0),
|
| 398 |
+
B.stride(2),
|
| 399 |
+
B.stride(1),
|
| 400 |
+
C.stride(0),
|
| 401 |
+
C.stride(1),
|
| 402 |
+
As.stride(0),
|
| 403 |
+
Bs.stride(0),
|
| 404 |
+
# Meta-parameters
|
| 405 |
+
NUM_EXPERTS=E,
|
| 406 |
+
BLOCK_SIZE_N=BLOCK_SIZE_N,
|
| 407 |
+
BLOCK_SIZE_K=BLOCK_SIZE_K,
|
| 408 |
+
BLOCK_SIZE_M=BLOCK_SIZE_M,
|
| 409 |
+
NUM_EXPERTS_BIT_LENGTH=E.bit_length(),
|
| 410 |
+
)
|
| 411 |
+
|
| 412 |
+
return C
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
def w8a8_block_fp8_matmul_grouped(
|
| 416 |
+
A: torch.Tensor,
|
| 417 |
+
B: torch.Tensor,
|
| 418 |
+
Bs: torch.Tensor,
|
| 419 |
+
offsets: torch.Tensor,
|
| 420 |
+
tokens_per_expert: torch.Tensor,
|
| 421 |
+
block_size: list[int],
|
| 422 |
+
) -> torch.Tensor:
|
| 423 |
+
"""Block-scale grouped FP8 matmul with fused activation quantization.
|
| 424 |
+
|
| 425 |
+
A: (S, K) raw activations sorted by expert, bf16/fp16/fp32
|
| 426 |
+
B: (E, N, K) FP8 expert weights
|
| 427 |
+
Bs: (E, N // block_n, K // block_k) per-block weight scales
|
| 428 |
+
"""
|
| 429 |
+
return torch.ops.finegrained_fp8.w8a8_block_fp8_matmul_grouped(
|
| 430 |
+
A, B, Bs, offsets, tokens_per_expert, block_size
|
| 431 |
+
)
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
def w8a8_tensor_fp8_matmul_grouped(
|
| 435 |
+
A: torch.Tensor,
|
| 436 |
+
B: torch.Tensor,
|
| 437 |
+
Bs: torch.Tensor,
|
| 438 |
+
offsets: torch.Tensor,
|
| 439 |
+
tokens_per_expert: torch.Tensor,
|
| 440 |
+
) -> torch.Tensor:
|
| 441 |
+
"""Tensor-scale grouped FP8 matmul with fused activation quantization.
|
| 442 |
+
|
| 443 |
+
A: (S, K) raw activations sorted by expert, bf16/fp16/fp32
|
| 444 |
+
B: (E, N, K) FP8 expert weights
|
| 445 |
+
Bs: (E,) or (E, 1, 1) per-expert weight scales
|
| 446 |
+
"""
|
| 447 |
+
return torch.ops.finegrained_fp8.w8a8_tensor_fp8_matmul_grouped(
|
| 448 |
+
A, B, Bs, offsets, tokens_per_expert
|
| 449 |
+
)
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
def w8a8_fp8_matmul_grouped(
|
| 453 |
+
A: torch.Tensor,
|
| 454 |
+
B: torch.Tensor,
|
| 455 |
+
Bs: torch.Tensor,
|
| 456 |
+
offsets: torch.Tensor,
|
| 457 |
+
tokens_per_expert: torch.Tensor,
|
| 458 |
+
block_size: list[int] | None,
|
| 459 |
+
) -> torch.Tensor:
|
| 460 |
+
"""Unified grouped W8A8 FP8 matmul dispatcher.
|
| 461 |
+
|
| 462 |
+
Dispatch rules:
|
| 463 |
+
- tensor mode when ``block_size is None``
|
| 464 |
+
- tensor mode when ``block_size == [N, K]``
|
| 465 |
+
- otherwise block mode
|
| 466 |
+
|
| 467 |
+
Returns:
|
| 468 |
+
Output tensor ``[S, N]`` in the same dtype as ``A``, in expert-sorted order.
|
| 469 |
+
"""
|
| 470 |
+
if block_size is None or (
|
| 471 |
+
block_size[0] == B.size(1) and block_size[1] == B.size(2)
|
| 472 |
+
):
|
| 473 |
+
return w8a8_tensor_fp8_matmul_grouped(A, B, Bs, offsets, tokens_per_expert)
|
| 474 |
+
|
| 475 |
+
return w8a8_block_fp8_matmul_grouped(
|
| 476 |
+
A, B, Bs, offsets, tokens_per_expert, block_size
|
| 477 |
+
)
|
build/torch-rocm/matmul.py
ADDED
|
@@ -0,0 +1,411 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import triton
|
| 17 |
+
import triton.language as tl
|
| 18 |
+
from torch.library import triton_op, wrap_triton
|
| 19 |
+
|
| 20 |
+
from .utils import device_context
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/quantization/fp8_kernel.py
|
| 24 |
+
@triton.autotune(
|
| 25 |
+
configs=[
|
| 26 |
+
triton.Config({}, num_warps=w, num_stages=s)
|
| 27 |
+
for w in [2, 4, 8, 16]
|
| 28 |
+
for s in [2, 3, 4]
|
| 29 |
+
],
|
| 30 |
+
key=["N", "K", "BLOCK_SIZE_M"],
|
| 31 |
+
)
|
| 32 |
+
@triton.jit
|
| 33 |
+
def w8a8_block_fp8_matmul_kernel(
|
| 34 |
+
# Pointers to inputs and output
|
| 35 |
+
A,
|
| 36 |
+
B,
|
| 37 |
+
C,
|
| 38 |
+
As,
|
| 39 |
+
Bs,
|
| 40 |
+
# Shape for matmul
|
| 41 |
+
M,
|
| 42 |
+
N,
|
| 43 |
+
K,
|
| 44 |
+
stride_am,
|
| 45 |
+
stride_ak,
|
| 46 |
+
stride_bk,
|
| 47 |
+
stride_bn,
|
| 48 |
+
stride_cm,
|
| 49 |
+
stride_cn,
|
| 50 |
+
stride_as_m,
|
| 51 |
+
stride_as_k,
|
| 52 |
+
stride_bs_k,
|
| 53 |
+
stride_bs_n,
|
| 54 |
+
# Meta-parameters
|
| 55 |
+
BLOCK_SIZE_M: tl.constexpr,
|
| 56 |
+
BLOCK_SIZE_N: tl.constexpr,
|
| 57 |
+
BLOCK_SIZE_K: tl.constexpr,
|
| 58 |
+
GROUP_SIZE_M: tl.constexpr,
|
| 59 |
+
):
|
| 60 |
+
"""Block-scale FP8 GEMM kernel.
|
| 61 |
+
|
| 62 |
+
Computes ``C = A @ B.T`` with block-wise activation/weight scales.
|
| 63 |
+
Uses a 2D grid with swizzle for L2 cache locality on B tiles.
|
| 64 |
+
"""
|
| 65 |
+
pid_m = tl.program_id(axis=0)
|
| 66 |
+
pid_n = tl.program_id(axis=1)
|
| 67 |
+
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
| 68 |
+
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
| 69 |
+
pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M)
|
| 70 |
+
|
| 71 |
+
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
|
| 72 |
+
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
|
| 73 |
+
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
| 74 |
+
a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
|
| 75 |
+
b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
|
| 76 |
+
|
| 77 |
+
as_ptrs = As + offs_am * stride_as_m
|
| 78 |
+
offs_bsn = offs_bn // BLOCK_SIZE_N
|
| 79 |
+
bs_ptrs = Bs + offs_bsn * stride_bs_n
|
| 80 |
+
|
| 81 |
+
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
| 82 |
+
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
| 83 |
+
k_remaining = K - k * BLOCK_SIZE_K
|
| 84 |
+
a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0.0)
|
| 85 |
+
b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
|
| 86 |
+
|
| 87 |
+
a_s = tl.load(as_ptrs + k * stride_as_k)
|
| 88 |
+
b_s = tl.load(bs_ptrs + k * stride_bs_k)
|
| 89 |
+
|
| 90 |
+
accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
|
| 91 |
+
a_ptrs += BLOCK_SIZE_K * stride_ak
|
| 92 |
+
b_ptrs += BLOCK_SIZE_K * stride_bk
|
| 93 |
+
|
| 94 |
+
if C.dtype.element_ty == tl.bfloat16:
|
| 95 |
+
c = accumulator.to(tl.bfloat16)
|
| 96 |
+
elif C.dtype.element_ty == tl.float16:
|
| 97 |
+
c = accumulator.to(tl.float16)
|
| 98 |
+
else:
|
| 99 |
+
c = accumulator.to(tl.float32)
|
| 100 |
+
|
| 101 |
+
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
| 102 |
+
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 103 |
+
c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
|
| 104 |
+
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
|
| 105 |
+
tl.store(c_ptrs, c, mask=c_mask)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
@triton.autotune(
|
| 109 |
+
configs=[
|
| 110 |
+
triton.Config({}, num_warps=w, num_stages=s)
|
| 111 |
+
for w in [2, 4, 8, 16]
|
| 112 |
+
for s in [2, 3, 4]
|
| 113 |
+
],
|
| 114 |
+
key=["N", "K", "BLOCK_SIZE_M"],
|
| 115 |
+
)
|
| 116 |
+
@triton.jit
|
| 117 |
+
def w8a8_tensor_fp8_matmul_kernel(
|
| 118 |
+
A,
|
| 119 |
+
B,
|
| 120 |
+
C,
|
| 121 |
+
As,
|
| 122 |
+
Bs,
|
| 123 |
+
M,
|
| 124 |
+
N,
|
| 125 |
+
K,
|
| 126 |
+
stride_am,
|
| 127 |
+
stride_ak,
|
| 128 |
+
stride_bk,
|
| 129 |
+
stride_bn,
|
| 130 |
+
stride_cm,
|
| 131 |
+
stride_cn,
|
| 132 |
+
stride_as_m,
|
| 133 |
+
BLOCK_SIZE_M: tl.constexpr,
|
| 134 |
+
BLOCK_SIZE_N: tl.constexpr,
|
| 135 |
+
BLOCK_SIZE_K: tl.constexpr,
|
| 136 |
+
GROUP_SIZE_M: tl.constexpr,
|
| 137 |
+
):
|
| 138 |
+
"""Tensor-scale FP8 GEMM kernel.
|
| 139 |
+
|
| 140 |
+
Computes ``C = A @ B.T`` with one activation scale per row and one
|
| 141 |
+
weight scale for the full matrix.
|
| 142 |
+
Uses a 2D grid with swizzle for L2 cache locality on B tiles.
|
| 143 |
+
"""
|
| 144 |
+
pid_m = tl.program_id(axis=0)
|
| 145 |
+
pid_n = tl.program_id(axis=1)
|
| 146 |
+
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
| 147 |
+
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
| 148 |
+
pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M)
|
| 149 |
+
|
| 150 |
+
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
|
| 151 |
+
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
|
| 152 |
+
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
| 153 |
+
|
| 154 |
+
a_ptrs = A + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
|
| 155 |
+
b_ptrs = B + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn
|
| 156 |
+
|
| 157 |
+
a_s = tl.load(As + offs_am * stride_as_m)
|
| 158 |
+
b_s = tl.load(Bs)
|
| 159 |
+
|
| 160 |
+
# Accumulate raw dot products, apply scales once after the loop.
|
| 161 |
+
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
| 162 |
+
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
| 163 |
+
k_remaining = K - k * BLOCK_SIZE_K
|
| 164 |
+
a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0.0)
|
| 165 |
+
b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
|
| 166 |
+
accumulator += tl.dot(a, b)
|
| 167 |
+
a_ptrs += BLOCK_SIZE_K * stride_ak
|
| 168 |
+
b_ptrs += BLOCK_SIZE_K * stride_bk
|
| 169 |
+
|
| 170 |
+
accumulator = accumulator * a_s[:, None] * b_s
|
| 171 |
+
|
| 172 |
+
if C.dtype.element_ty == tl.bfloat16:
|
| 173 |
+
c = accumulator.to(tl.bfloat16)
|
| 174 |
+
elif C.dtype.element_ty == tl.float16:
|
| 175 |
+
c = accumulator.to(tl.float16)
|
| 176 |
+
else:
|
| 177 |
+
c = accumulator.to(tl.float32)
|
| 178 |
+
|
| 179 |
+
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
| 180 |
+
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 181 |
+
c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
|
| 182 |
+
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
|
| 183 |
+
tl.store(c_ptrs, c, mask=c_mask)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
@triton_op("finegrained_fp8::w8a8_block_fp8_matmul", mutates_args=())
def _w8a8_block_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: list[int],
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Block-scale FP8 matmul: C = A @ B.T with per-block scales.

    As: (M, K // block_k) — per-token-group activation scales
    Bs: (N // block_n, K // block_k) — per-block weight scales
    """
    assert len(block_size) == 2, (
        f"block_size must be [block_n, block_k], got {block_size}"
    )
    block_n, block_k = block_size[0], block_size[1]

    assert A.shape[-1] == B.shape[-1], (
        f"K mismatch: A has K={A.shape[-1]}, B has K={B.shape[-1]}"
    )
    assert A.is_contiguous(), "A must be contiguous"
    assert B.ndim == 2, f"B must be 2D (N, K), got ndim={B.ndim}"
    assert B.is_contiguous(), "B must be contiguous"

    # B is stored (N, K); the kernel computes A @ B.T, so N comes first.
    N, K = B.shape
    # Collapse any leading batch dims of A into a single row count M.
    M = A.numel() // A.shape[-1]

    assert As.ndim >= 2, f"As must be at least 2D, got ndim={As.ndim}"
    assert As.shape[-1] == triton.cdiv(K, block_k), (
        f"As last dim {As.shape[-1]} != expected {triton.cdiv(K, block_k)} (cdiv(K={K}, block_k={block_k}))"
    )
    assert Bs.ndim == 2, f"Bs must be 2D (N//block_n, K//block_k), got ndim={Bs.ndim}"
    assert Bs.shape == (triton.cdiv(N, block_n), triton.cdiv(K, block_k)), (
        f"Bs shape {tuple(Bs.shape)} != expected ({triton.cdiv(N, block_n)}, {triton.cdiv(K, block_k)})"
    )

    # The N/K tile sizes equal the quantization block sizes so one stored
    # scale covers exactly one tile of the computation.
    BLOCK_SIZE_K = block_k
    BLOCK_SIZE_N = block_n
    C_shape = A.shape[:-1] + (N,)
    C = A.new_empty(C_shape, dtype=output_dtype)
    # Adaptive BLOCK_SIZE_M: smallest power-of-2 >= M, floored at 16, capped at 128.
    # Matches the WGMMA tile to the actual row count — smaller tiles use less
    # register pressure and a better-matched FP8 WGMMA instruction, improving
    # both accuracy and performance for small M (decode).
    BLOCK_SIZE_M = min(max(triton.next_power_of_2(M), 16), 128)
    grid = (triton.cdiv(M, BLOCK_SIZE_M), triton.cdiv(N, BLOCK_SIZE_N))
    with device_context(A.device):
        # B's strides are passed as (stride(1), stride(0)) — i.e. swapped —
        # so the kernel indexes B as (K, N), effectively B.T, without
        # materializing a transpose. Same trick for Bs.
        wrap_triton(w8a8_block_fp8_matmul_kernel)[grid](
            A,
            B,
            C,
            As,
            Bs,
            M,
            N,
            K,
            A.stride(-2),
            A.stride(-1),
            B.stride(1),
            B.stride(0),
            C.stride(-2),
            C.stride(-1),
            As.stride(-2),
            As.stride(-1),
            Bs.stride(1),
            Bs.stride(0),
            # Meta-parameters
            BLOCK_SIZE_M=BLOCK_SIZE_M,
            BLOCK_SIZE_N=BLOCK_SIZE_N,
            BLOCK_SIZE_K=BLOCK_SIZE_K,
            GROUP_SIZE_M=8,
        )

    return C
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
@triton_op("finegrained_fp8::w8a8_tensor_fp8_matmul", mutates_args=())
def _w8a8_tensor_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Tensor-scale FP8 matmul: C = A @ B.T with per-row / per-tensor scales.

    As: scalar, (M,), or (M, 1) — per-row activation scales
    Bs: scalar, (1,), or (1, 1) — single weight scale
    """
    assert A.shape[-1] == B.shape[-1], (
        f"K mismatch: A has K={A.shape[-1]}, B has K={B.shape[-1]}"
    )
    assert A.is_contiguous(), "A must be contiguous"
    assert B.ndim == 2, f"B must be 2D (N, K), got ndim={B.ndim}"
    assert B.is_contiguous(), "B must be contiguous"

    N, K = B.shape
    # Collapse any leading batch dims of A into a single row count M.
    M = A.numel() // A.shape[-1]

    # Normalize As to (M,)
    if As.numel() == 1:
        # Scalar scale: broadcast one value to every row.
        As = As.reshape(1).expand(M).contiguous()
    elif As.ndim == 2:
        As = As.reshape(M)
    assert As.ndim == 1 and As.shape[0] == M, (
        f"As must be scalar, (M,), or (M,1) with M={M}, got {tuple(As.shape)}"
    )

    # Normalize Bs to (1,)
    assert Bs.numel() == 1, f"Bs must be scalar or (1,), got {tuple(Bs.shape)}"
    Bs = Bs.reshape(1)

    # Fixed N/K tiles; scales here are per-row/per-tensor, so tile size is
    # decoupled from any quantization block.
    BLOCK_SIZE_N = 128
    BLOCK_SIZE_K = 128
    C_shape = A.shape[:-1] + (N,)
    C = A.new_empty(C_shape, dtype=output_dtype)
    # Adaptive BLOCK_SIZE_M: smallest power-of-2 >= M, floored at 16, capped at 128.
    BLOCK_SIZE_M = min(max(triton.next_power_of_2(M), 16), 128)
    grid = (triton.cdiv(M, BLOCK_SIZE_M), triton.cdiv(N, BLOCK_SIZE_N))
    with device_context(A.device):
        # B strides swapped (stride(1), stride(0)) so the kernel reads B.T.
        wrap_triton(w8a8_tensor_fp8_matmul_kernel)[grid](
            A,
            B,
            C,
            As,
            Bs,
            M,
            N,
            K,
            A.stride(-2),
            A.stride(-1),
            B.stride(1),
            B.stride(0),
            C.stride(-2),
            C.stride(-1),
            As.stride(0),
            # Meta-parameters
            BLOCK_SIZE_M=BLOCK_SIZE_M,
            BLOCK_SIZE_N=BLOCK_SIZE_N,
            BLOCK_SIZE_K=BLOCK_SIZE_K,
            GROUP_SIZE_M=8,
        )

    return C
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def w8a8_block_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: list[int],
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Public entry point for block-wise W8A8 FP8 GEMM.

    Thin wrapper over the registered ``finegrained_fp8::w8a8_block_fp8_matmul``
    custom op, which computes ``C = A @ B.T`` from ``float8_e4m3fn`` operands
    with per-block scales, accumulating in float32.

    Args:
        A: FP8 (``float8_e4m3fn``) activations of shape ``[M, K]``.
        B: FP8 (``float8_e4m3fn``) weights of shape ``[N, K]``.
        As: Per-token-group activation scales ``[M, K // block_size[1]]``.
        Bs: Per-block weight scales ``[N // block_size[0], K // block_size[1]]``.
        block_size: Quantization block ``[block_n, block_k]``, e.g. ``[128, 128]``.
        output_dtype: dtype of the result (default ``torch.float32``).

    Returns:
        ``[M, N]`` tensor in ``output_dtype``.
    """
    block_op = torch.ops.finegrained_fp8.w8a8_block_fp8_matmul
    return block_op(A, B, As, Bs, block_size, output_dtype)
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
def w8a8_tensor_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Public entry point for tensor-scale W8A8 FP8 GEMM.

    Delegates to the registered ``finegrained_fp8::w8a8_tensor_fp8_matmul``
    custom op: ``C = A @ B.T`` from pre-quantized FP8 operands with per-row
    activation scales and a single weight scale.

    Args:
        A: FP8 (``float8_e4m3fn``) activations of shape ``[M, K]``.
        B: FP8 (``float8_e4m3fn``) weights of shape ``[N, K]``.
        As: Per-row activation scales ``[M]``.
        Bs: Single weight scale, scalar or ``[1]``.
        output_dtype: dtype of the result.

    Returns:
        ``[M, N]`` tensor in ``output_dtype``.
    """
    tensor_op = torch.ops.finegrained_fp8.w8a8_tensor_fp8_matmul
    return tensor_op(A, B, As, Bs, output_dtype)
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
def w8a8_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: list[int] | None,
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Route a W8A8 FP8 matmul to tensor-scale or block-scale mode.

    Tensor mode is selected when ``block_size`` is ``None`` or equals the full
    ``(N, K)`` extent of ``B``; any other block size selects block mode.

    Returns:
        ``[M, N]`` tensor in ``output_dtype``.
    """
    if block_size is None:
        tensor_mode = True
    else:
        tensor_mode = block_size[0] == B.size(0) and block_size[1] == B.size(1)

    if tensor_mode:
        return w8a8_tensor_fp8_matmul(A, B, As, Bs, output_dtype)
    return w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype)
|
build/torch-rocm/metadata.json
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"version": 1,
|
| 3 |
+
"license": "Apache-2.0",
|
| 4 |
+
"python-depends": [],
|
| 5 |
+
"backend": {
|
| 6 |
+
"type": "rocm"
|
| 7 |
+
}
|
| 8 |
+
}
|
build/torch-rocm/utils.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from contextlib import contextmanager
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
@contextmanager
def device_context(device: torch.device):
    """Make *device* the active device for the duration of the block.

    Looks up the backend module named after the device type (``torch.cuda``,
    ``torch.xpu``, ...) and enters its ``device(...)`` context when one is
    available; otherwise the body runs with no device switch.
    """
    backend_module = getattr(torch, device.type, None)
    if backend_module is None or not hasattr(backend_module, "device"):
        # Unknown backend or one without a device context — no-op.
        yield
    else:
        with backend_module.device(device):
            yield
|
build/torch-xpu/__init__.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .act_quant import fp8_act_quant
|
| 2 |
+
from .batched import (
|
| 3 |
+
w8a8_fp8_matmul_batched,
|
| 4 |
+
w8a8_block_fp8_matmul_batched,
|
| 5 |
+
w8a8_tensor_fp8_matmul_batched,
|
| 6 |
+
)
|
| 7 |
+
from .grouped import (
|
| 8 |
+
w8a8_fp8_matmul_grouped,
|
| 9 |
+
w8a8_block_fp8_matmul_grouped,
|
| 10 |
+
w8a8_tensor_fp8_matmul_grouped,
|
| 11 |
+
)
|
| 12 |
+
from .matmul import (
|
| 13 |
+
w8a8_fp8_matmul,
|
| 14 |
+
w8a8_block_fp8_matmul,
|
| 15 |
+
w8a8_tensor_fp8_matmul,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
__all__ = [
|
| 19 |
+
"fp8_act_quant",
|
| 20 |
+
# Single matmul
|
| 21 |
+
"w8a8_fp8_matmul",
|
| 22 |
+
"w8a8_block_fp8_matmul",
|
| 23 |
+
"w8a8_tensor_fp8_matmul",
|
| 24 |
+
# Batched matmul
|
| 25 |
+
"w8a8_fp8_matmul_batched",
|
| 26 |
+
"w8a8_block_fp8_matmul_batched",
|
| 27 |
+
"w8a8_tensor_fp8_matmul_batched",
|
| 28 |
+
# Grouped matmul
|
| 29 |
+
"w8a8_fp8_matmul_grouped",
|
| 30 |
+
"w8a8_block_fp8_matmul_grouped",
|
| 31 |
+
"w8a8_tensor_fp8_matmul_grouped",
|
| 32 |
+
]
|
build/torch-xpu/_ops.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
ops = torch.ops._finegrained_fp8_75cbe1b
|
| 3 |
+
|
| 4 |
+
def add_op_namespace_prefix(op_name: str) -> str:
    """Return *op_name* qualified with this build's private op namespace.

    Args:
        op_name: Bare operator name, e.g. ``"fp8_act_quant"``.

    Returns:
        ``"_finegrained_fp8_75cbe1b::<op_name>"``, the fully qualified name
        used when registering/looking up the op in ``torch.ops``.
    """
    return f"_finegrained_fp8_75cbe1b::{op_name}"
|
build/torch-xpu/act_quant.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import triton
|
| 17 |
+
import triton.language as tl
|
| 18 |
+
from torch.library import triton_op, wrap_triton
|
| 19 |
+
|
| 20 |
+
from .utils import device_context
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
_FP8_DTYPE = torch.float8_e4m3fn
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# Adapted from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
@triton.jit
def _fp8_act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
    """Quantize one BLOCK_SIZE-element group of `x` to FP8 with a dynamic scale.

    Each program loads its contiguous block, computes
    ``s = max(|x|) / 448`` (448 is the float8_e4m3fn max), stores the
    quantized block to `y` and the float32 scale to `s_ptr[pid]`.
    """
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    x = tl.load(x_ptr + offs).to(tl.float32)
    s = tl.max(tl.abs(x)) / 448.0  # float8_e4m3fn max
    # Clamp the divisor so an all-zero block yields 0 instead of NaN (0/0),
    # matching the fused activation quantization in the batched kernels.
    y = (x / tl.maximum(s, 1e-12)).to(y_ptr.dtype.element_ty)
    tl.store(y_ptr + offs, y)
    tl.store(s_ptr + pid, s)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@triton_op("finegrained_fp8::fp8_act_quant", mutates_args=())
def _fp8_act_quant(
    x: torch.Tensor, block_size: int = 128
) -> tuple[torch.Tensor, torch.Tensor]:
    # Custom-op wrapper: quantize `x` to FP8 in groups of `block_size`
    # contiguous elements and return (quantized, per-group float32 scales).
    assert x.is_contiguous()
    assert x.shape[-1] % block_size == 0
    y = torch.empty_like(x, dtype=_FP8_DTYPE)
    # One Triton program per group of `block_size` elements of the flat tensor.
    grid = (triton.cdiv(x.numel(), block_size),)
    # One scale per group, shaped alongside x's leading dimensions.
    s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)

    with device_context(x.device):
        wrap_triton(_fp8_act_quant_kernel)[grid](x, y, s, BLOCK_SIZE=block_size)

    return y, s
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def fp8_act_quant(
    x: torch.Tensor, block_size: int = 128
) -> tuple[torch.Tensor, torch.Tensor]:
    """Dynamically quantize ``x`` to ``float8_e4m3fn`` in fixed-size groups.

    The last dimension is split into groups of ``block_size`` elements; each
    group is scaled by ``max(|group|) / 448`` (the float8_e4m3fn maximum)
    before the FP8 cast.

    Args:
        x: Contiguous bf16/fp16/fp32 tensor whose last dimension is a
            multiple of ``block_size``.
        block_size: Elements per quantization group (default 128).

    Returns:
        Tuple ``(quantized, scales)``: the FP8 tensor with the same shape as
        ``x`` and a float32 scale tensor of shape
        ``(*x.shape[:-1], x.shape[-1] // block_size)``.
    """
    quant_op = torch.ops.finegrained_fp8.fp8_act_quant
    return quant_op(x, block_size)
|
build/torch-xpu/batched.py
ADDED
|
@@ -0,0 +1,390 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import triton
|
| 17 |
+
import triton.language as tl
|
| 18 |
+
from .act_quant import fp8_act_quant
|
| 19 |
+
from torch.library import triton_op, wrap_triton
|
| 20 |
+
|
| 21 |
+
from .utils import device_context
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=w, num_stages=s)
        for w in [2, 4, 8, 16]
        for s in [2, 3, 4, 5]
    ],
    key=["N", "K"],
)
@triton.jit
def w8a8_block_fp8_matmul_batched_kernel(
    A,  # (S, K) raw BF16/FP16 activations
    B,  # (E, N, K) FP8 weight matrices
    C,  # (S, N) output
    Bs,  # (E, N // BLOCK_SIZE_N, K // BLOCK_SIZE_K) weight scales
    ExpertIds,  # (S,) — which expert each batch element routes to
    # Shape
    S,
    N,
    K,
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_bs_e,
    stride_bs_k,
    stride_bs_n,
    # Meta-parameters
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
):
    """Block-scale batched FP8 expert matmul kernel.

    Each program handles one routed token row and one N-tile, looks up the
    owning expert from ``ExpertIds``, and applies fused activation quantization.
    """
    batch_id = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)

    # Cast expert_id to int64 to prevent int32 overflow when computing
    # expert_id * stride_Eb (e.g. 255 * 9_437_184 > 2^31 for 256 experts of
    # 3072×3072 FP8 weights).
    expert_id = tl.load(ExpertIds + batch_id).to(tl.int64)

    # Rebase pointers to this token's row and this expert's weight slab.
    A = A + batch_id * stride_am
    B = B + expert_id * stride_be
    C = C + batch_id * stride_cm
    Bs = Bs + expert_id * stride_bs_e

    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # `* 0` gives the M tile a zero row stride: every row of the
    # (BLOCK_SIZE_M, BLOCK_SIZE_K) tile aliases the single routed token row.
    a_ptrs = A + tl.arange(0, BLOCK_SIZE_M)[:, None] * 0 + offs_k[None, :] * stride_ak
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    bs_ptrs = Bs + pid_n * stride_bs_n

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # ---- fused fp8_act_quant ----
        a_raw = tl.load(a_ptrs).to(tl.float32)
        a_s = tl.max(tl.abs(a_raw)) / 448.0
        # Clamped divisor: an all-zero K-block quantizes to 0, not NaN.
        a = (a_raw / tl.maximum(a_s, 1e-12)).to(tl.float8e4nv)
        # ---- matmul ----
        b = tl.load(b_ptrs)
        b_s = tl.load(bs_ptrs + k * stride_bs_k)
        # Scales are per K-block, so they must be applied inside the loop.
        accumulator += tl.dot(a, b) * a_s * b_s[None, :]
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Cast the float32 accumulator to the output element type.
    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    offs_cm = tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    # Zero row stride again: all tile rows store to the same output row.
    c_ptrs = C + offs_cm[:, None] * 0 + stride_cn * offs_cn[None, :]
    tl.store(c_ptrs, c)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=w, num_stages=s)
        for w in [2, 4, 8, 16]
        for s in [2, 3, 4, 5]
    ],
    key=["N", "K"],
)
@triton.jit
def w8a8_tensor_fp8_matmul_batched_kernel(
    A,  # (S, K) pre-quantized FP8 activations
    B,  # (E, N, K) FP8 weight matrices
    C,  # (S, N) output
    As,  # (S, 1) per-tensor activation scales
    Bs,  # (E, 1, 1) per-tensor weight scales
    ExpertIds,
    S,
    N,
    K,
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_as_m,
    stride_bs_e,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
):
    """Tensor-scale batched FP8 expert matmul kernel.

    Activations are already quantized; the kernel applies per-token activation
    scales and per-expert tensor weight scales.
    """
    batch_id = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)

    # int64 avoids int32 overflow in expert_id * stride_be for large weights.
    expert_id = tl.load(ExpertIds + batch_id).to(tl.int64)

    # Rebase pointers to this token's row and this expert's weight slab.
    A = A + batch_id * stride_am
    B = B + expert_id * stride_be
    C = C + batch_id * stride_cm
    Bs = Bs + expert_id * stride_bs_e

    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # `* 0` gives the M tile a zero row stride: every row of the tile
    # aliases the single routed token row.
    a_ptrs = A + tl.arange(0, BLOCK_SIZE_M)[:, None] * 0 + offs_k[None, :] * stride_ak
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    b_s = tl.load(Bs)
    a_s = tl.load(As + batch_id * stride_as_m)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for _ in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        accumulator += tl.dot(a, b)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Per-tensor scales factor out of the K loop; apply them once at the end.
    accumulator = accumulator * a_s * b_s

    # Cast the float32 accumulator to the output element type.
    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    offs_cm = tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    # Zero row stride again: all tile rows store to the same output row.
    c_ptrs = C + offs_cm[:, None] * 0 + stride_cn * offs_cn[None, :]
    tl.store(c_ptrs, c)
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
@triton_op("finegrained_fp8::w8a8_block_fp8_matmul_batched", mutates_args=())
def _w8a8_block_fp8_matmul_batched(
    A: torch.Tensor,
    B: torch.Tensor,
    Bs: torch.Tensor,
    expert_ids: torch.Tensor,
    block_size: list[int],
) -> torch.Tensor:
    """Block-scale batched FP8 matmul: C[s] = A[s] @ B[expert_ids[s]].T, with fused act quant.

    A: (S, K) raw bf16/fp16 activations
    B: (E, N, K) FP8 expert weights
    Bs: (E, N // block_n, K // block_k) per-block weight scales
    """
    assert A.ndim == 2, f"A must be 2D (S, K), got ndim={A.ndim}"
    assert A.is_contiguous(), "A must be contiguous"
    assert B.ndim == 3, f"B must be 3D (E, N, K), got ndim={B.ndim}"
    assert B.is_contiguous(), "B must be contiguous"
    assert A.shape[1] == B.shape[2], (
        f"K mismatch: A has K={A.shape[1]}, B has K={B.shape[2]}"
    )

    S, K = A.shape
    E, N, _ = B.shape

    assert len(block_size) == 2, (
        f"block_size must be [block_n, block_k], got {block_size}"
    )
    block_n, block_k = block_size[0], block_size[1]
    # MoE expert dimensions must be block-aligned; non-aligned N/K is not supported.
    assert N % block_n == 0, f"N ({N}) must be divisible by block_n ({block_n})"
    assert K % block_k == 0, f"K ({K}) must be divisible by block_k ({block_k})"
    assert Bs.ndim == 3, (
        f"Bs must be 3D (E, N//block_n, K//block_k), got ndim={Bs.ndim}"
    )
    assert Bs.shape == (E, N // block_n, K // block_k), (
        f"Bs shape {tuple(Bs.shape)} != expected ({E}, {N // block_n}, {K // block_k})"
    )

    # Output keeps A's dtype; the kernel quantizes A on the fly.
    C = A.new_empty(S, N)
    # Adaptive BLOCK_SIZE_M: smallest power-of-2 >= M, floored at 16, capped at 128.
    # Matches the WGMMA tile to the actual row count — smaller tiles use less
    # register pressure and a better-matched FP8 WGMMA instruction, improving
    # both accuracy and performance for small M (decode).
    # (S + E - 1) // E estimates the rows-per-expert for the tile size.
    BLOCK_SIZE_M = min(max(triton.next_power_of_2((S + E - 1) // E), 16), 128)
    # One program per (token row, N tile).
    grid = (S, triton.cdiv(N, block_n))
    with device_context(A.device):
        # B's and Bs's inner strides are passed swapped (stride(2), stride(1))
        # so the kernel indexes them as (K, N) — i.e. transposed.
        wrap_triton(w8a8_block_fp8_matmul_batched_kernel)[grid](
            A,
            B,
            C,
            Bs,
            expert_ids,
            S,
            N,
            K,
            A.stride(0),
            A.stride(1),
            B.stride(0),
            B.stride(2),
            B.stride(1),
            C.stride(0),
            C.stride(1),
            Bs.stride(0),
            Bs.stride(2),
            Bs.stride(1),
            BLOCK_SIZE_N=block_n,
            BLOCK_SIZE_K=block_k,
            BLOCK_SIZE_M=BLOCK_SIZE_M,
        )

    return C
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
@triton_op("finegrained_fp8::w8a8_tensor_fp8_matmul_batched", mutates_args=())
def _w8a8_tensor_fp8_matmul_batched(
    A: torch.Tensor,
    B: torch.Tensor,
    Bs: torch.Tensor,
    expert_ids: torch.Tensor,
) -> torch.Tensor:
    """Tensor-scale batched FP8 matmul: C[s] = A[s] @ B[expert_ids[s]].T, with fused act quant.

    A: (S, K) raw bf16/fp16 activations
    B: (E, N, K) FP8 expert weights
    Bs: (E,) or (E, 1, 1) per-expert weight scales

    Returns:
        (S, N) tensor in A's dtype.
    """
    assert A.ndim == 2, f"A must be 2D (S, K), got ndim={A.ndim}"
    assert A.is_contiguous(), "A must be contiguous"
    assert B.ndim == 3, f"B must be 3D (E, N, K), got ndim={B.ndim}"
    assert B.is_contiguous(), "B must be contiguous"
    assert A.shape[1] == B.shape[2], (
        f"K mismatch: A has K={A.shape[1]}, B has K={B.shape[2]}"
    )

    S, K = A.shape
    E, N, _ = B.shape

    # Normalize Bs to (E, 1, 1)
    if Bs.ndim == 1:
        assert Bs.shape[0] == E, f"Bs shape {tuple(Bs.shape)} != expected ({E},)"
        Bs = Bs.reshape(E, 1, 1)
    else:
        assert Bs.shape == (E, 1, 1), (
            f"Bs shape {tuple(Bs.shape)} != expected ({E}, 1, 1)"
        )

    BLOCK_SIZE_N = 128
    BLOCK_SIZE_K = 128
    C = A.new_empty(S, N)
    # Quantize activations with one scale per row (group size == K).
    qA, As = fp8_act_quant(A, K)
    # Adaptive BLOCK_SIZE_M: smallest power-of-2 >= M, floored at 16, capped at 128.
    # Matches the WGMMA tile to the actual row count — smaller tiles use less
    # register pressure and a better-matched FP8 WGMMA instruction, improving
    # both accuracy and performance for small M (decode).
    BLOCK_SIZE_M = min(max(triton.next_power_of_2((S + E - 1) // E), 16), 128)
    # One program per (token row, N tile). (Previously assigned twice;
    # computed exactly once here.)
    grid = (S, triton.cdiv(N, BLOCK_SIZE_N))
    with device_context(A.device):
        # B strides passed swapped (stride(2), stride(1)) so the kernel
        # reads B[expert] as its transpose (K, N).
        wrap_triton(w8a8_tensor_fp8_matmul_batched_kernel)[grid](
            qA,
            B,
            C,
            As,
            Bs,
            expert_ids,
            S,
            N,
            K,
            qA.stride(0),
            qA.stride(1),
            B.stride(0),
            B.stride(2),
            B.stride(1),
            C.stride(0),
            C.stride(1),
            As.stride(0),
            Bs.stride(0),
            BLOCK_SIZE_N=BLOCK_SIZE_N,
            BLOCK_SIZE_K=BLOCK_SIZE_K,
            BLOCK_SIZE_M=BLOCK_SIZE_M,
        )

    return C
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def w8a8_block_fp8_matmul_batched(
    A: torch.Tensor,
    B: torch.Tensor,
    Bs: torch.Tensor,
    expert_ids: torch.Tensor,
    block_size: list[int],
) -> torch.Tensor:
    """Per-token expert GEMM in block-scale FP8 mode.

    Delegates to the ``finegrained_fp8::w8a8_block_fp8_matmul_batched``
    custom op, which quantizes each activation row on the fly and multiplies
    it against the expert selected by ``expert_ids``.

    Args:
        A: ``(S, K)`` raw activations, bf16/fp16/fp32.
        B: ``(E, N, K)`` FP8 expert weights.
        Bs: ``(E, N // block_n, K // block_k)`` per-block weight scales.
        expert_ids: ``(S,)`` expert index per row.
        block_size: ``[block_n, block_k]`` quantization block dimensions.

    Returns:
        ``(S, N)`` tensor in ``A``'s dtype.
    """
    batched_op = torch.ops.finegrained_fp8.w8a8_block_fp8_matmul_batched
    return batched_op(A, B, Bs, expert_ids, block_size)
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
def w8a8_tensor_fp8_matmul_batched(
    A: torch.Tensor,
    B: torch.Tensor,
    Bs: torch.Tensor,
    expert_ids: torch.Tensor,
) -> torch.Tensor:
    """Per-token expert GEMM in tensor-scale FP8 mode.

    Delegates to the ``finegrained_fp8::w8a8_tensor_fp8_matmul_batched``
    custom op, which quantizes activations with one scale per row and applies
    the per-expert weight scale from ``Bs``.

    Args:
        A: ``(S, K)`` raw activations, bf16/fp16/fp32.
        B: ``(E, N, K)`` FP8 expert weights.
        Bs: ``(E,)`` or ``(E, 1, 1)`` per-expert weight scales.
        expert_ids: ``(S,)`` expert index per row.

    Returns:
        ``(S, N)`` tensor in ``A``'s dtype.
    """
    batched_op = torch.ops.finegrained_fp8.w8a8_tensor_fp8_matmul_batched
    return batched_op(A, B, Bs, expert_ids)
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
def w8a8_fp8_matmul_batched(
    A: torch.Tensor,
    B: torch.Tensor,
    Bs: torch.Tensor,
    expert_ids: torch.Tensor,
    block_size: list[int] | None,
) -> torch.Tensor:
    """Unified batched W8A8 FP8 matmul dispatcher.

    Dispatch rules:
    - tensor mode when ``block_size is None``
    - tensor mode when ``block_size == [N, K]`` (one block covers the whole
      weight matrix, which is exactly per-tensor scaling)
    - otherwise block mode

    Returns:
        Output tensor ``[S, N]`` in the same dtype as ``A``.
    """
    whole_matrix_block = (
        block_size is not None
        and block_size[0] == B.size(1)
        and block_size[1] == B.size(2)
    )
    if block_size is None or whole_matrix_block:
        return w8a8_tensor_fp8_matmul_batched(A, B, Bs, expert_ids)
    return w8a8_block_fp8_matmul_batched(A, B, Bs, expert_ids, block_size)
|
build/torch-xpu/finegrained_fp8/__init__.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import ctypes
|
| 2 |
+
import importlib.util
|
| 3 |
+
import sys
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from types import ModuleType
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def _import_from_path(file_path: Path) -> ModuleType:
|
| 9 |
+
# We cannot use the module name as-is, after adding it to `sys.modules`,
|
| 10 |
+
# it would also be used for other imports. So, we make a module name that
|
| 11 |
+
# depends on the path for it to be unique using the hex-encoded hash of
|
| 12 |
+
# the path.
|
| 13 |
+
path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
|
| 14 |
+
module_name = path_hash
|
| 15 |
+
spec = importlib.util.spec_from_file_location(module_name, file_path)
|
| 16 |
+
if spec is None:
|
| 17 |
+
raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
|
| 18 |
+
module = importlib.util.module_from_spec(spec)
|
| 19 |
+
if module is None:
|
| 20 |
+
raise ImportError(f"Cannot load module {module_name} from spec")
|
| 21 |
+
sys.modules[module_name] = module
|
| 22 |
+
spec.loader.exec_module(module) # type: ignore
|
| 23 |
+
return module
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# Re-export every attribute of the parent directory's __init__.py into this
# package's namespace, so `finegrained_fp8` exposes the same public API as
# the surrounding build directory without a normal (name-colliding) import.
globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
|
build/torch-xpu/grouped.py
ADDED
|
@@ -0,0 +1,477 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import triton
|
| 17 |
+
import triton.language as tl
|
| 18 |
+
from .act_quant import fp8_act_quant
|
| 19 |
+
from torch.library import triton_op, wrap_triton
|
| 20 |
+
|
| 21 |
+
from .utils import device_context
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=w, num_stages=s)
        for w in [2, 4, 8, 16]
        for s in [2, 3, 4, 5]
    ],
    key=["N", "K", "BLOCK_SIZE_M"],
)
@triton.jit
def w8a8_block_fp8_matmul_grouped_kernel(
    A,  # (S, K) raw BF16/FP16 activations, sorted/grouped by expert id
    B,  # (E, N, K) FP8 weight matrices
    C,  # (S, N) output
    Bs,  # (E, N // BLOCK_SIZE_N, K // BLOCK_SIZE_K) weight scales
    Offsets,  # (E,) int32 — cumulative row-end per expert
    TileOffsets,  # (E,) int32 — cumulative tile-end per expert
    # Shape
    S,
    N,
    K,
    # Strides
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_bs_e,
    stride_bs_k,
    stride_bs_n,
    # Meta-parameters
    NUM_EXPERTS: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    NUM_EXPERTS_BIT_LENGTH: tl.constexpr,
):
    """Block-scale grouped FP8 expert matmul kernel.

    Tokens are assumed sorted by expert. The kernel maps each M-tile to its
    owning expert via ``TileOffsets`` and applies fused activation quantization
    (per-row dynamic scaling to FP8 e4m3 inside the K loop).
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)

    # Exit early for programs beyond the actual tile count.
    # (The host launches a data-independent upper bound of M-tiles.)
    total_tiles = tl.load(TileOffsets + NUM_EXPERTS - 1)
    if pid_m >= total_tiles:
        return

    # Binary search in TileOffsets to find the owning expert.
    # Finds the smallest e such that TileOffsets[e] > pid_m (upper_bound semantics),
    # which is the expert whose tile range contains pid_m.
    # O(log2(NUM_EXPERTS)) loads instead of the O(NUM_EXPERTS) linear scan.
    # NUM_EXPERTS_BIT_LENGTH is ceil(log2(E))+1 for powers-of-two, giving one
    # harmless extra iteration when lo==hi; it's a compile-time constant so the
    # loop is fully unrolled by the compiler.
    lo = 0
    hi = NUM_EXPERTS
    for _ in tl.static_range(NUM_EXPERTS_BIT_LENGTH):
        mid = (lo + hi) >> 1
        mid_val = tl.load(TileOffsets + mid)
        is_left = mid_val <= pid_m
        lo = tl.where(is_left, mid + 1, lo)
        hi = tl.where(is_left, hi, mid)

    # Cast expert_id to int64 to prevent int32 overflow when computing
    # expert_id * stride_be (e.g. 255 * 9_437_184 > 2^31 for 256 experts of
    # 3072×3072 FP8 weights).
    expert_id = lo.to(tl.int64)

    # Offsets are cumulative, so expert e's row range is
    # [Offsets[e-1], Offsets[e]), with an implicit 0 start for expert 0.
    prev_eid = tl.maximum(expert_id - 1, 0)
    expert_start = tl.where(expert_id == 0, 0, tl.load(Offsets + prev_eid))
    expert_end = tl.load(Offsets + expert_id)
    M_expert = expert_end - expert_start

    expert_tile_start = tl.where(expert_id == 0, 0, tl.load(TileOffsets + prev_eid))
    local_tile = pid_m - expert_tile_start
    m_off = local_tile * BLOCK_SIZE_M

    offs_am = m_off + tl.arange(0, BLOCK_SIZE_M)
    row_mask = offs_am < M_expert
    offs_global_m = expert_start + offs_am

    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    a_ptrs = A + offs_global_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = (
        B
        + expert_id * stride_be
        + offs_k[:, None] * stride_bk
        + offs_bn[None, :] * stride_bn
    )
    bs_ptrs = Bs + expert_id * stride_bs_e + pid_n * stride_bs_n

    # NOTE(review): the K loop has no K-tail mask, so K is assumed to be a
    # multiple of BLOCK_SIZE_K — the host asserts block alignment.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # ---- fused fp8_act_quant ----
        # Per-row absmax scale over this K block; 448.0 is the FP8 e4m3 max.
        a_raw = tl.load(a_ptrs, mask=row_mask[:, None], other=0.0).to(tl.float32)
        a_s = tl.max(tl.abs(a_raw), axis=1) / 448.0
        # 1e-12 floor avoids division by zero for all-zero rows.
        a = (a_raw / tl.maximum(a_s[:, None], 1e-12)).to(tl.float8e4nv)
        # ---- matmul ----
        b = tl.load(b_ptrs)
        b_s = tl.load(bs_ptrs + k * stride_bs_k)
        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Cast the fp32 accumulator to the output dtype.
    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    c_ptrs = C + stride_cm * offs_global_m[:, None] + stride_cn * offs_bn[None, :]
    c_mask = row_mask[:, None]
    tl.store(c_ptrs, c, mask=c_mask)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=w, num_stages=s)
        for w in [2, 4, 8, 16]
        for s in [2, 3, 4, 5]
    ],
    key=["N", "K", "BLOCK_SIZE_M"],
)
@triton.jit
def w8a8_tensor_fp8_matmul_grouped_kernel(
    A,  # (S, K) pre-quantized FP8 activations, sorted/grouped by expert id
    B,  # (E, N, K) FP8 weight matrices
    C,  # (S, N) output
    As,  # (S, 1) activation scales
    Bs,  # (E, 1, 1) per-tensor weight scales
    Offsets,  # (E,) cumulative row-end per expert
    TileOffsets,  # (E,) cumulative tile-end per expert
    S,
    N,
    K,
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_as_m,
    stride_bs_e,
    NUM_EXPERTS: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    NUM_EXPERTS_BIT_LENGTH: tl.constexpr,
):
    """Tensor-scale grouped FP8 expert matmul kernel.

    Uses grouped expert scheduling with pre-quantized activations plus
    per-token activation scales and per-expert tensor weight scales.
    Scales are applied once after the K loop (a single a_s/b_s per row/expert).
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)

    # Early-exit for over-provisioned grid programs (host launches an upper bound).
    total_tiles = tl.load(TileOffsets + NUM_EXPERTS - 1)
    if pid_m >= total_tiles:
        return

    # upper_bound binary search: smallest e with TileOffsets[e] > pid_m.
    # Fully unrolled — NUM_EXPERTS_BIT_LENGTH is a compile-time constant.
    lo = 0
    hi = NUM_EXPERTS
    for _ in tl.static_range(NUM_EXPERTS_BIT_LENGTH):
        mid = (lo + hi) >> 1
        mid_val = tl.load(TileOffsets + mid)
        is_left = mid_val <= pid_m
        lo = tl.where(is_left, mid + 1, lo)
        hi = tl.where(is_left, hi, mid)
    # int64 to avoid int32 overflow in expert_id * stride_be for large E/weights.
    expert_id = lo.to(tl.int64)

    # Row range [Offsets[e-1], Offsets[e]) for this expert (0 start for e == 0).
    prev_eid = tl.maximum(expert_id - 1, 0)
    expert_start = tl.where(expert_id == 0, 0, tl.load(Offsets + prev_eid))
    expert_end = tl.load(Offsets + expert_id)
    M_expert = expert_end - expert_start

    expert_tile_start = tl.where(expert_id == 0, 0, tl.load(TileOffsets + prev_eid))
    local_tile = pid_m - expert_tile_start
    m_off = local_tile * BLOCK_SIZE_M

    offs_am = m_off + tl.arange(0, BLOCK_SIZE_M)
    row_mask = offs_am < M_expert
    offs_global_m = expert_start + offs_am

    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    a_ptrs = A + offs_global_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = (
        B
        + expert_id * stride_be
        + offs_k[:, None] * stride_bk
        + offs_bn[None, :] * stride_bn
    )

    # One activation scale per token row, one weight scale per expert.
    a_s = tl.load(As + offs_global_m * stride_as_m, mask=row_mask, other=0.0)
    b_s = tl.load(Bs + expert_id * stride_bs_e)

    # NOTE(review): no K-tail mask — assumes K % BLOCK_SIZE_K == 0.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for _ in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs, mask=row_mask[:, None], other=0.0)
        b = tl.load(b_ptrs)

        accumulator += tl.dot(a, b)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Apply both scales once — valid because they are constant over K.
    accumulator = accumulator * a_s[:, None] * b_s

    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    c_ptrs = C + stride_cm * offs_global_m[:, None] + stride_cn * offs_bn[None, :]
    c_mask = row_mask[:, None]
    tl.store(c_ptrs, c, mask=c_mask)
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
@triton_op("finegrained_fp8::w8a8_block_fp8_matmul_grouped", mutates_args=())
def _w8a8_block_fp8_matmul_grouped(
    A: torch.Tensor,
    B: torch.Tensor,
    Bs: torch.Tensor,
    offsets: torch.Tensor,
    tokens_per_expert: torch.Tensor,
    block_size: list[int],
) -> torch.Tensor:
    """Block-scale grouped FP8 matmul: C = A @ B.T per expert, with fused act quant.

    A: (S, K) raw bf16/fp16 activations, sorted by expert
    B: (E, N, K) FP8 expert weights
    Bs: (E, N // block_n, K // block_k) per-block weight scales
    offsets: cumulative row-end per expert (consumed by the kernel's Offsets)
    tokens_per_expert: per-expert token counts, used to size the tile schedule

    Returns a (S, N) tensor in A's dtype, in expert-sorted row order.
    """
    assert A.ndim == 2, f"A must be 2D (S, K), got ndim={A.ndim}"
    assert A.is_contiguous(), "A must be contiguous"
    assert B.ndim == 3, f"B must be 3D (E, N, K), got ndim={B.ndim}"
    assert B.is_contiguous(), "B must be contiguous"
    assert A.shape[1] == B.shape[2], (
        f"K mismatch: A has K={A.shape[1]}, B has K={B.shape[2]}"
    )

    S, K = A.shape
    E, N, _ = B.shape

    assert len(block_size) == 2, (
        f"block_size must be [block_n, block_k], got {block_size}"
    )
    block_n, block_k = block_size[0], block_size[1]
    # MoE expert dimensions must be block-aligned; non-aligned N/K is not supported.
    assert N % block_n == 0, f"N ({N}) must be divisible by block_n ({block_n})"
    assert K % block_k == 0, f"K ({K}) must be divisible by block_k ({block_k})"
    assert Bs.ndim == 3, (
        f"Bs must be 3D (E, N//block_n, K//block_k), got ndim={Bs.ndim}"
    )
    assert Bs.shape == (E, N // block_n, K // block_k), (
        f"Bs shape {tuple(Bs.shape)} != expected ({E}, {N // block_n}, {K // block_k})"
    )

    C = A.new_empty(S, N)
    # Adaptive BLOCK_SIZE_M: match tile to average tokens per expert,
    # clamped to [16, 128].
    BLOCK_SIZE_M = min(max(triton.next_power_of_2((S + E - 1) // E), 16), 128)
    tiles_per_expert = (tokens_per_expert + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M
    tile_offsets = torch.cumsum(tiles_per_expert, dim=0).to(torch.int32)
    # Upper bound on M-tiles: sum_e ceil(M_e / BLOCK_M) <= ceil(S / BLOCK_M) + E.
    # Programs beyond the real tile count exit immediately via the early-return
    # guard inside the kernel. This is faster than syncing for the exact count
    # and keeps the grid size data-independent (cuda-graph / torch.compile safe).
    max_M_tiles = triton.cdiv(S, BLOCK_SIZE_M) + E
    grid = (max_M_tiles, triton.cdiv(N, block_n))
    with device_context(A.device):
        # Note: B strides are passed as (e, k, n) = (0, 2, 1) because the
        # kernel multiplies K along stride_bk and N along stride_bn on the
        # (E, N, K) layout — i.e. it computes A @ B.T.
        wrap_triton(w8a8_block_fp8_matmul_grouped_kernel)[grid](
            A,
            B,
            C,
            Bs,
            offsets,
            tile_offsets,
            S,
            N,
            K,
            A.stride(0),
            A.stride(1),
            B.stride(0),
            B.stride(2),
            B.stride(1),
            C.stride(0),
            C.stride(1),
            Bs.stride(0),
            Bs.stride(2),
            Bs.stride(1),
            # Meta-parameters
            NUM_EXPERTS=E,
            BLOCK_SIZE_N=block_n,
            BLOCK_SIZE_K=block_k,
            BLOCK_SIZE_M=BLOCK_SIZE_M,
            NUM_EXPERTS_BIT_LENGTH=E.bit_length(),
        )

    return C
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
@triton_op("finegrained_fp8::w8a8_tensor_fp8_matmul_grouped", mutates_args=())
def _w8a8_tensor_fp8_matmul_grouped(
    A: torch.Tensor,
    B: torch.Tensor,
    Bs: torch.Tensor,
    offsets: torch.Tensor,
    tokens_per_expert: torch.Tensor,
) -> torch.Tensor:
    """Tensor-scale grouped FP8 matmul: C = A @ B.T per expert, with fused act quant.

    A: (S, K) raw bf16/fp16 activations, sorted by expert
    B: (E, N, K) FP8 expert weights
    Bs: (E,) or (E, 1, 1) per-expert weight scales
    offsets: cumulative row-end per expert
    tokens_per_expert: per-expert token counts, used to size the tile schedule

    Activations are quantized to FP8 up front (fp8_act_quant) and passed to
    the kernel together with per-token scales. Returns (S, N) in A's dtype.
    """
    assert A.ndim == 2, f"A must be 2D (S, K), got ndim={A.ndim}"
    assert A.is_contiguous(), "A must be contiguous"
    assert B.ndim == 3, f"B must be 3D (E, N, K), got ndim={B.ndim}"
    assert B.is_contiguous(), "B must be contiguous"
    assert A.shape[1] == B.shape[2], (
        f"K mismatch: A has K={A.shape[1]}, B has K={B.shape[2]}"
    )

    S, K = A.shape
    E, N, _ = B.shape

    # Normalize Bs to (E, 1, 1)
    if Bs.ndim == 1:
        assert Bs.shape[0] == E, f"Bs shape {tuple(Bs.shape)} != expected ({E},)"
        Bs = Bs.reshape(E, 1, 1)
    else:
        assert Bs.shape == (E, 1, 1), (
            f"Bs shape {tuple(Bs.shape)} != expected ({E}, 1, 1)"
        )

    # Fixed N/K tiling; the kernel assumes K % BLOCK_SIZE_K == 0.
    BLOCK_SIZE_N = 128
    BLOCK_SIZE_K = 128
    C = A.new_empty(S, N)
    # Pre-quantize activations once (per-token scales over the full K row).
    qA, As = fp8_act_quant(A, K)
    # Adaptive BLOCK_SIZE_M: average tokens per expert, clamped to [16, 128].
    BLOCK_SIZE_M = min(max(triton.next_power_of_2((S + E - 1) // E), 16), 128)
    tiles_per_expert = (tokens_per_expert + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M
    tile_offsets = torch.cumsum(tiles_per_expert, dim=0).to(torch.int32)
    # Upper bound on M-tiles: sum_e ceil(M_e / BLOCK_M) <= ceil(S / BLOCK_M) + E.
    # Programs beyond the real tile count exit immediately via the early-return
    # guard inside the kernel. This is faster than syncing for the exact count
    # and keeps the grid size data-independent (cuda-graph / torch.compile safe).
    max_M_tiles = triton.cdiv(S, BLOCK_SIZE_M) + E
    grid = (max_M_tiles, triton.cdiv(N, BLOCK_SIZE_N))
    with device_context(A.device):
        # B strides passed as (e, k, n) = (0, 2, 1): A @ B.T on (E, N, K) layout.
        wrap_triton(w8a8_tensor_fp8_matmul_grouped_kernel)[grid](
            qA,
            B,
            C,
            As,
            Bs,
            offsets,
            tile_offsets,
            S,
            N,
            K,
            qA.stride(0),
            qA.stride(1),
            B.stride(0),
            B.stride(2),
            B.stride(1),
            C.stride(0),
            C.stride(1),
            As.stride(0),
            Bs.stride(0),
            # Meta-parameters
            NUM_EXPERTS=E,
            BLOCK_SIZE_N=BLOCK_SIZE_N,
            BLOCK_SIZE_K=BLOCK_SIZE_K,
            BLOCK_SIZE_M=BLOCK_SIZE_M,
            NUM_EXPERTS_BIT_LENGTH=E.bit_length(),
        )

    return C
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
def w8a8_block_fp8_matmul_grouped(
    A: torch.Tensor,
    B: torch.Tensor,
    Bs: torch.Tensor,
    offsets: torch.Tensor,
    tokens_per_expert: torch.Tensor,
    block_size: list[int],
) -> torch.Tensor:
    """Block-scale grouped FP8 matmul with fused activation quantization.

    Thin public wrapper around the registered custom op.

    A: (S, K) raw activations sorted by expert, bf16/fp16/fp32
    B: (E, N, K) FP8 expert weights
    Bs: (E, N // block_n, K // block_k) per-block weight scales
    """
    op = torch.ops.finegrained_fp8.w8a8_block_fp8_matmul_grouped
    return op(A, B, Bs, offsets, tokens_per_expert, block_size)
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
def w8a8_tensor_fp8_matmul_grouped(
    A: torch.Tensor,
    B: torch.Tensor,
    Bs: torch.Tensor,
    offsets: torch.Tensor,
    tokens_per_expert: torch.Tensor,
) -> torch.Tensor:
    """Tensor-scale grouped FP8 matmul with fused activation quantization.

    Thin public wrapper around the registered custom op.

    A: (S, K) raw activations sorted by expert, bf16/fp16/fp32
    B: (E, N, K) FP8 expert weights
    Bs: (E,) or (E, 1, 1) per-expert weight scales
    """
    op = torch.ops.finegrained_fp8.w8a8_tensor_fp8_matmul_grouped
    return op(A, B, Bs, offsets, tokens_per_expert)
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
def w8a8_fp8_matmul_grouped(
    A: torch.Tensor,
    B: torch.Tensor,
    Bs: torch.Tensor,
    offsets: torch.Tensor,
    tokens_per_expert: torch.Tensor,
    block_size: list[int] | None,
) -> torch.Tensor:
    """Unified grouped W8A8 FP8 matmul dispatcher.

    Dispatch rules:
    - tensor mode when ``block_size is None``
    - tensor mode when ``block_size == [N, K]`` (a single block spanning the
      whole weight matrix is per-tensor scaling)
    - otherwise block mode

    Returns:
        Output tensor ``[S, N]`` in the same dtype as ``A``, in expert-sorted order.
    """
    whole_matrix_block = (
        block_size is not None
        and block_size[0] == B.size(1)
        and block_size[1] == B.size(2)
    )
    if block_size is None or whole_matrix_block:
        return w8a8_tensor_fp8_matmul_grouped(A, B, Bs, offsets, tokens_per_expert)

    return w8a8_block_fp8_matmul_grouped(
        A, B, Bs, offsets, tokens_per_expert, block_size
    )
|
build/torch-xpu/matmul.py
ADDED
|
@@ -0,0 +1,411 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import triton
|
| 17 |
+
import triton.language as tl
|
| 18 |
+
from torch.library import triton_op, wrap_triton
|
| 19 |
+
|
| 20 |
+
from .utils import device_context
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/quantization/fp8_kernel.py
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=w, num_stages=s)
        for w in [2, 4, 8, 16]
        for s in [2, 3, 4]
    ],
    key=["N", "K", "BLOCK_SIZE_M"],
)
@triton.jit
def w8a8_block_fp8_matmul_kernel(
    # Pointers to inputs and output
    A,  # (M, K) FP8 activations
    B,  # weight matrix, indexed with (k, n) strides
    C,  # (M, N) output
    As,  # per-(row, K-block) activation scales
    Bs,  # per-(K-block, N-block) weight scales
    # Shape for matmul
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_as_m,
    stride_as_k,
    stride_bs_k,
    stride_bs_n,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Block-scale FP8 GEMM kernel.

    Computes ``C = A @ B.T`` with block-wise activation/weight scales.
    Uses a 2D grid with swizzle for L2 cache locality on B tiles.
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M)

    # Modulo wrap instead of an M/N load mask: out-of-range rows/cols re-read
    # valid data; the final store mask discards their results.
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    as_ptrs = As + offs_am * stride_as_m
    # Map each output column to its N-block scale index.
    offs_bsn = offs_bn // BLOCK_SIZE_N
    bs_ptrs = Bs + offs_bsn * stride_bs_n

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # K-tail mask handles K not divisible by BLOCK_SIZE_K.
        k_remaining = K - k * BLOCK_SIZE_K
        a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)

        # Scales are per K-block, so they must be applied inside the loop.
        a_s = tl.load(as_ptrs + k * stride_as_k)
        b_s = tl.load(bs_ptrs + k * stride_bs_k)

        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Cast the fp32 accumulator to the output dtype.
    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
@triton.autotune(
    configs=[
        # Tile sizes are fixed by the launcher; only sweep warp count and
        # software-pipelining stages.
        triton.Config({}, num_warps=w, num_stages=s)
        for w in [2, 4, 8, 16]
        for s in [2, 3, 4]
    ],
    # BLOCK_SIZE_M is chosen per launch from M, so it must be in the key.
    key=["N", "K", "BLOCK_SIZE_M"],
)
@triton.jit
def w8a8_tensor_fp8_matmul_kernel(
    A,  # FP8 activations, logically (M, K)
    B,  # FP8 weights, (N, K); consumed as B.T via the strides below
    C,  # output, (M, N)
    As,  # per-row activation scales, one per M row
    Bs,  # single weight scale (one element)
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_as_m,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Tensor-scale FP8 GEMM kernel.

    Computes ``C = A @ B.T`` with one activation scale per row and one
    weight scale for the full matrix.
    Uses a 2D grid with swizzle for L2 cache locality on B tiles.
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    # Remap program ids so neighbouring programs reuse the same B tiles.
    pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M)

    # Wrap row/col offsets with % so edge programs read in-bounds addresses;
    # the masked store at the end discards the wrapped lanes' results.
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    a_ptrs = A + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = B + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn

    # Scales are loop-invariant in tensor mode: load once before the K loop.
    a_s = tl.load(As + offs_am * stride_as_m)
    b_s = tl.load(Bs)

    # Accumulate raw dot products, apply scales once after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Mask the ragged K tail; masked lanes load 0.0 and add nothing.
        k_remaining = K - k * BLOCK_SIZE_K
        a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
        accumulator += tl.dot(a, b)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Per-row activation scale times the broadcast scalar weight scale.
    accumulator = accumulator * a_s[:, None] * b_s

    # Downcast the fp32 accumulator to the output element type.
    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    # Un-wrapped offsets plus a bounds mask handle the M/N edges on store.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
| 184 |
+
|
| 185 |
+
|
| 186 |
+
@triton_op("finegrained_fp8::w8a8_block_fp8_matmul", mutates_args=())
def _w8a8_block_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: list[int],
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Block-scale FP8 matmul: C = A @ B.T with per-block scales.

    Args:
        A: FP8 activations, shape ``(..., K)``; must be contiguous.
        B: FP8 weights, shape ``(N, K)``; must be contiguous.
        As: per-token-group activation scales, shape ``(..., K // block_k)``.
        Bs: per-block weight scales, shape ``(N // block_n, K // block_k)``.
        block_size: ``[block_n, block_k]`` quantization block dimensions.
        output_dtype: dtype of the returned tensor.

    Returns:
        Tensor of shape ``A.shape[:-1] + (N,)`` in ``output_dtype``.
    """
    assert len(block_size) == 2, (
        f"block_size must be [block_n, block_k], got {block_size}"
    )
    block_n, block_k = block_size

    assert A.shape[-1] == B.shape[-1], (
        f"K mismatch: A has K={A.shape[-1]}, B has K={B.shape[-1]}"
    )
    assert A.is_contiguous(), "A must be contiguous"
    assert B.ndim == 2, f"B must be 2D (N, K), got ndim={B.ndim}"
    assert B.is_contiguous(), "B must be contiguous"

    N, K = B.shape
    # Fold all leading dims of A into one logical row dimension M.
    M = A.numel() // A.shape[-1]

    # Expected number of scale groups along each quantized axis.
    k_groups = triton.cdiv(K, block_k)
    n_groups = triton.cdiv(N, block_n)

    assert As.ndim >= 2, f"As must be at least 2D, got ndim={As.ndim}"
    assert As.shape[-1] == k_groups, (
        f"As last dim {As.shape[-1]} != expected {k_groups} (cdiv(K={K}, block_k={block_k}))"
    )
    assert Bs.ndim == 2, f"Bs must be 2D (N//block_n, K//block_k), got ndim={Bs.ndim}"
    assert Bs.shape == (n_groups, k_groups), (
        f"Bs shape {tuple(Bs.shape)} != expected ({n_groups}, {k_groups})"
    )

    out = A.new_empty(A.shape[:-1] + (N,), dtype=output_dtype)

    # Adaptive tile height: smallest power-of-2 >= M, floored at 16, capped
    # at 128. Matching the tile to the real row count lowers register
    # pressure and picks a better FP8 MMA shape for small M (decode).
    tile_m = min(max(triton.next_power_of_2(M), 16), 128)
    grid = (triton.cdiv(M, tile_m), triton.cdiv(N, block_n))
    with device_context(A.device):
        wrap_triton(w8a8_block_fp8_matmul_kernel)[grid](
            A,
            B,
            out,
            As,
            Bs,
            M,
            N,
            K,
            A.stride(-2),
            A.stride(-1),
            B.stride(1),
            B.stride(0),
            out.stride(-2),
            out.stride(-1),
            As.stride(-2),
            As.stride(-1),
            Bs.stride(1),
            Bs.stride(0),
            # Meta-parameters
            BLOCK_SIZE_M=tile_m,
            BLOCK_SIZE_N=block_n,
            BLOCK_SIZE_K=block_k,
            GROUP_SIZE_M=8,
        )

    return out
| 262 |
+
|
| 263 |
+
|
| 264 |
+
@triton_op("finegrained_fp8::w8a8_tensor_fp8_matmul", mutates_args=())
def _w8a8_tensor_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Tensor-scale FP8 matmul: C = A @ B.T with per-row / per-tensor scales.

    Args:
        A: FP8 activations, shape ``(..., K)``; must be contiguous.
        B: FP8 weights, shape ``(N, K)``; must be contiguous.
        As: scalar, ``(M,)``, or ``(M, 1)`` — per-row activation scales.
        Bs: scalar, ``(1,)``, or ``(1, 1)`` — single weight scale.
        output_dtype: dtype of the returned tensor.
    """
    assert A.shape[-1] == B.shape[-1], (
        f"K mismatch: A has K={A.shape[-1]}, B has K={B.shape[-1]}"
    )
    assert A.is_contiguous(), "A must be contiguous"
    assert B.ndim == 2, f"B must be 2D (N, K), got ndim={B.ndim}"
    assert B.is_contiguous(), "B must be contiguous"

    N, K = B.shape
    # Fold all leading dims of A into one logical row dimension M.
    M = A.numel() // A.shape[-1]

    # Canonicalize As to a flat (M,) vector.
    if As.numel() == 1:
        # One shared scale: broadcast it across all M rows.
        As = As.reshape(1).expand(M).contiguous()
    elif As.ndim == 2:
        As = As.reshape(M)
    assert As.ndim == 1 and As.shape[0] == M, (
        f"As must be scalar, (M,), or (M,1) with M={M}, got {tuple(As.shape)}"
    )

    # Canonicalize Bs to a one-element vector.
    assert Bs.numel() == 1, f"Bs must be scalar or (1,), got {tuple(Bs.shape)}"
    Bs = Bs.reshape(1)

    tile_n = 128
    tile_k = 128
    out = A.new_empty(A.shape[:-1] + (N,), dtype=output_dtype)
    # Adaptive tile height: power-of-2 >= M, clamped to [16, 128].
    tile_m = min(max(triton.next_power_of_2(M), 16), 128)
    grid = (triton.cdiv(M, tile_m), triton.cdiv(N, tile_n))
    with device_context(A.device):
        wrap_triton(w8a8_tensor_fp8_matmul_kernel)[grid](
            A,
            B,
            out,
            As,
            Bs,
            M,
            N,
            K,
            A.stride(-2),
            A.stride(-1),
            B.stride(1),
            B.stride(0),
            out.stride(-2),
            out.stride(-1),
            As.stride(0),
            # Meta-parameters
            BLOCK_SIZE_M=tile_m,
            BLOCK_SIZE_N=tile_n,
            BLOCK_SIZE_K=tile_k,
            GROUP_SIZE_M=8,
        )

    return out
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def w8a8_block_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: list[int],
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Block-wise W8A8 FP8 matrix multiplication.

    Thin public wrapper over the registered custom op. Computes
    ``C = A @ B.T`` from operands pre-quantized to ``float8_e4m3fn`` with
    per-block scales, accumulating in float32 before casting to
    ``output_dtype``.

    Args:
        A: Quantized activation tensor ``[M, K]`` in ``float8_e4m3fn``.
        B: Quantized weight tensor ``[N, K]`` in ``float8_e4m3fn``.
        As: Per-token-group activation scales ``[M, K // block_size[1]]``.
        Bs: Per-block weight scales ``[N // block_size[0], K // block_size[1]]``.
        block_size: ``[block_n, block_k]`` quantization block dimensions, e.g. ``[128, 128]``.
        output_dtype: dtype of the returned tensor (default: ``torch.float32``).

    Returns:
        Output tensor ``[M, N]`` in ``output_dtype``.
    """
    op = torch.ops.finegrained_fp8.w8a8_block_fp8_matmul
    return op(A, B, As, Bs, block_size, output_dtype)
| 361 |
+
|
| 362 |
+
|
| 363 |
+
def w8a8_tensor_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Tensor-scale W8A8 FP8 matrix multiplication.

    Thin public wrapper over the registered custom op. Computes
    ``C = A @ B.T`` in tensor-scale mode from pre-quantized FP8
    activations/weights and tensor scales.

    Args:
        A: Quantized activation tensor ``[M, K]`` in ``float8_e4m3fn``.
        B: Quantized weight tensor ``[N, K]`` in ``float8_e4m3fn``.
        As: Per-row activation scales ``[M]``.
        Bs: Single weight scale, scalar or ``[1]``.
        output_dtype: dtype of the returned tensor.

    Returns:
        Output tensor ``[M, N]`` in ``output_dtype``.
    """
    op = torch.ops.finegrained_fp8.w8a8_tensor_fp8_matmul
    return op(A, B, As, Bs, output_dtype)
| 386 |
+
|
| 387 |
+
|
| 388 |
+
def w8a8_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: list[int] | None,
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Unified W8A8 FP8 matmul dispatcher.

    Routes to tensor-scale mode when ``block_size`` is ``None`` or covers
    the whole weight matrix (``[N, K]``); every other block size goes to
    block-scale mode.

    Returns:
        Output tensor ``[M, N]`` in ``output_dtype``.
    """
    # A block spanning all of B is equivalent to a single tensor scale.
    covers_whole_weight = block_size is not None and (
        block_size[0] == B.size(0) and block_size[1] == B.size(1)
    )
    if block_size is None or covers_whole_weight:
        return w8a8_tensor_fp8_matmul(A, B, As, Bs, output_dtype)

    return w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype)
build/torch-xpu/metadata.json
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"version": 1,
|
| 3 |
+
"license": "Apache-2.0",
|
| 4 |
+
"python-depends": [],
|
| 5 |
+
"backend": {
|
| 6 |
+
"type": "xpu"
|
| 7 |
+
}
|
| 8 |
+
}
|
build/torch-xpu/utils.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from contextlib import contextmanager
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
@contextmanager
def device_context(device: torch.device):
    """Make *device* current within the block, for any backend (cuda, xpu, ...).

    Falls back to a no-op when the backend module is missing or does not
    expose a ``device`` context manager (e.g. CPU-only situations).
    """
    backend_mod = getattr(torch, device.type, None)
    if backend_mod is None or not hasattr(backend_mod, "device"):
        yield
        return
    with backend_mod.device(device):
        yield
|