drbh committed on
Commit
69b4990
·
unverified ·
0 Parent(s):

Migrated from kernels-community/finegrained-fp8

Browse files
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
build/torch-cuda/__init__.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .act_quant import fp8_act_quant
2
+ from .batched import (
3
+ w8a8_fp8_matmul_batched,
4
+ w8a8_block_fp8_matmul_batched,
5
+ w8a8_tensor_fp8_matmul_batched,
6
+ )
7
+ from .grouped import (
8
+ w8a8_fp8_matmul_grouped,
9
+ w8a8_block_fp8_matmul_grouped,
10
+ w8a8_tensor_fp8_matmul_grouped,
11
+ )
12
+ from .matmul import (
13
+ w8a8_fp8_matmul,
14
+ w8a8_block_fp8_matmul,
15
+ w8a8_tensor_fp8_matmul,
16
+ )
17
+
18
+ __all__ = [
19
+ "fp8_act_quant",
20
+ # Single matmul
21
+ "w8a8_fp8_matmul",
22
+ "w8a8_block_fp8_matmul",
23
+ "w8a8_tensor_fp8_matmul",
24
+ # Batched matmul
25
+ "w8a8_fp8_matmul_batched",
26
+ "w8a8_block_fp8_matmul_batched",
27
+ "w8a8_tensor_fp8_matmul_batched",
28
+ # Grouped matmul
29
+ "w8a8_fp8_matmul_grouped",
30
+ "w8a8_block_fp8_matmul_grouped",
31
+ "w8a8_tensor_fp8_matmul_grouped",
32
+ ]
build/torch-cuda/_ops.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
import torch

# Handle to this build's registered op namespace on torch.ops.
ops = torch.ops._finegrained_fp8_75cbe1b


def add_op_namespace_prefix(op_name: str):
    """Return *op_name* qualified with this build's op namespace."""
    return "_finegrained_fp8_75cbe1b::" + op_name
build/torch-cuda/act_quant.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ import triton
17
+ import triton.language as tl
18
+ from torch.library import triton_op, wrap_triton
19
+
20
+ from .utils import device_context
21
+
22
+
23
+ _FP8_DTYPE = torch.float8_e4m3fn
24
+
25
+
26
# Copied from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
@triton.jit
def _fp8_act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
    """Quantize one contiguous BLOCK_SIZE-element block of x with a dynamic scale.

    Each program computes ``s = max(|x_block|) / 448`` (448 is the
    float8_e4m3fn max), stores the scaled block in the output dtype, and
    stores ``s`` for the block.
    """
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    x = tl.load(x_ptr + offs).to(tl.float32)
    s = tl.max(tl.abs(x)) / 448.0  # float8_e4m3fn max
    # Floor the divisor at 1e-12 so an all-zero block does not produce NaNs
    # (0 / 0). This matches the guard used by the fused matmul kernels; the
    # stored scale stays exact, so dequantizing an all-zero block still
    # yields zeros.
    y = (x / tl.maximum(s, 1e-12)).to(y_ptr.dtype.element_ty)
    tl.store(y_ptr + offs, y)
    tl.store(s_ptr + pid, s)
36
+
37
+
38
@triton_op("finegrained_fp8::fp8_act_quant", mutates_args=())
def _fp8_act_quant(
    x: torch.Tensor, block_size: int = 128
) -> tuple[torch.Tensor, torch.Tensor]:
    """Registered op backing :func:`fp8_act_quant`: per-block FP8 quantization."""
    assert x.is_contiguous()
    assert x.shape[-1] % block_size == 0
    # One scale per block along the last dimension.
    scale_shape = (*x.size()[:-1], x.size(-1) // block_size)
    scales = x.new_empty(scale_shape, dtype=torch.float32)
    quantized = torch.empty_like(x, dtype=_FP8_DTYPE)
    # One program per block over the flattened tensor (valid because x is
    # contiguous and its last dim divides evenly into blocks).
    launch_grid = (triton.cdiv(x.numel(), block_size),)

    with device_context(x.device):
        wrap_triton(_fp8_act_quant_kernel)[launch_grid](
            x, quantized, scales, BLOCK_SIZE=block_size
        )

    return quantized, scales
52
+
53
+
54
def fp8_act_quant(
    x: torch.Tensor, block_size: int = 128
) -> tuple[torch.Tensor, torch.Tensor]:
    """Dynamically quantize activations to ``float8_e4m3fn``, block by block.

    The last dimension of ``x`` is split into contiguous blocks of
    ``block_size`` elements; each block is scaled by its own
    ``max(|block|) / 448`` (the float8_e4m3fn maximum) before the cast.

    Args:
        x: Contiguous bf16/fp16/fp32 tensor whose last dimension is a
            multiple of ``block_size``.
        block_size: Elements per quantization block (default: 128).

    Returns:
        A ``(quantized, scales)`` pair: ``quantized`` keeps ``x``'s shape
        with dtype ``float8_e4m3fn``; ``scales`` is float32 of shape
        ``(*x.shape[:-1], x.shape[-1] // block_size)``.
    """
    return torch.ops.finegrained_fp8.fp8_act_quant(x, block_size)
build/torch-cuda/batched.py ADDED
@@ -0,0 +1,390 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ import triton
17
+ import triton.language as tl
18
+ from .act_quant import fp8_act_quant
19
+ from torch.library import triton_op, wrap_triton
20
+
21
+ from .utils import device_context
22
+
23
+
24
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=w, num_stages=s)
        for w in [2, 4, 8, 16]
        for s in [2, 3, 4, 5]
    ],
    key=["N", "K"],
)
@triton.jit
def w8a8_block_fp8_matmul_batched_kernel(
    A,  # (S, K) raw BF16/FP16 activations
    B,  # (E, N, K) FP8 weight matrices
    C,  # (S, N) output
    Bs,  # (E, N // BLOCK_SIZE_N, K // BLOCK_SIZE_K) weight scales
    ExpertIds,  # (S,) — which expert each batch element routes to
    # Shape
    S,
    N,
    K,
    # Strides
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_bs_e,
    stride_bs_k,
    stride_bs_n,
    # Meta-parameters
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
):
    """Block-scale batched FP8 expert matmul kernel.

    Each program handles one routed token row and one N-tile, looks up the
    owning expert from ``ExpertIds``, and applies fused activation quantization.
    """
    batch_id = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)

    # Cast expert_id to int64 to prevent int32 overflow when computing
    # expert_id * stride_be (e.g. 255 * 9_437_184 > 2^31 for 256 experts of
    # 3072×3072 FP8 weights).
    expert_id = tl.load(ExpertIds + batch_id).to(tl.int64)

    # Rebase pointers onto this token's row and its expert's weight/scale slabs.
    A = A + batch_id * stride_am
    B = B + expert_id * stride_be
    C = C + batch_id * stride_cm
    Bs = Bs + expert_id * stride_bs_e

    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # The `* 0` row stride makes all BLOCK_SIZE_M rows alias the same single
    # token row: tl.dot needs an M dimension, so the one row is broadcast and
    # every M lane computes the same result.
    a_ptrs = A + tl.arange(0, BLOCK_SIZE_M)[:, None] * 0 + offs_k[None, :] * stride_ak
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    bs_ptrs = Bs + pid_n * stride_bs_n

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # ---- fused fp8_act_quant ----
        a_raw = tl.load(a_ptrs).to(tl.float32)
        # Scalar max over the tile equals the per-token max because all M
        # rows alias the same row; 448 is the float8_e4m3fn max.
        a_s = tl.max(tl.abs(a_raw)) / 448.0
        # 1e-12 floor avoids NaNs when an entire K-block is zero.
        a = (a_raw / tl.maximum(a_s, 1e-12)).to(tl.float8e4nv)
        # ---- matmul ----
        b = tl.load(b_ptrs)
        b_s = tl.load(bs_ptrs + k * stride_bs_k)
        accumulator += tl.dot(a, b) * a_s * b_s[None, :]
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Emit in the output dtype; accumulation stays in fp32 throughout.
    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    offs_cm = tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    # Row stride 0 again: all M rows write identical values to the one output
    # row, which is a benign duplicate store.
    c_ptrs = C + offs_cm[:, None] * 0 + stride_cn * offs_cn[None, :]
    tl.store(c_ptrs, c)
107
+
108
+
109
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=w, num_stages=s)
        for w in [2, 4, 8, 16]
        for s in [2, 3, 4, 5]
    ],
    key=["N", "K"],
)
@triton.jit
def w8a8_tensor_fp8_matmul_batched_kernel(
    A,  # (S, K) pre-quantized FP8 activations
    B,  # (E, N, K) FP8 weight matrices
    C,  # (S, N) output
    As,  # (S, 1) per-tensor activation scales
    Bs,  # (E, 1, 1) per-tensor weight scales
    ExpertIds,  # (S,) — which expert each batch element routes to
    # Shape
    S,
    N,
    K,
    # Strides
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_as_m,
    stride_bs_e,
    # Meta-parameters
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
):
    """Tensor-scale batched FP8 expert matmul kernel.

    Activations are already quantized; the kernel applies per-token activation
    scales and per-expert tensor weight scales.
    """
    batch_id = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)

    # int64 cast prevents int32 overflow in expert_id * stride_be for large
    # expert counts / weight slabs.
    expert_id = tl.load(ExpertIds + batch_id).to(tl.int64)

    # Rebase pointers onto this token's row and its expert's weight slab.
    A = A + batch_id * stride_am
    B = B + expert_id * stride_be
    C = C + batch_id * stride_cm
    Bs = Bs + expert_id * stride_bs_e

    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # Row stride 0 broadcasts the single token row across the M tile so
    # tl.dot has an M dimension to work with.
    a_ptrs = A + tl.arange(0, BLOCK_SIZE_M)[:, None] * 0 + offs_k[None, :] * stride_ak
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # Scales are constant over the whole K loop; load them once up front.
    b_s = tl.load(Bs)
    a_s = tl.load(As + batch_id * stride_as_m)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for _ in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        accumulator += tl.dot(a, b)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Single rescale after the loop — valid because both scales are
    # per-tensor (constant across all K blocks).
    accumulator = accumulator * a_s * b_s

    # Emit in the output dtype; accumulation stays in fp32 throughout.
    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    offs_cm = tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    # All M rows write the same values to the one output row (stride 0) —
    # benign duplicate store.
    c_ptrs = C + offs_cm[:, None] * 0 + stride_cn * offs_cn[None, :]
    tl.store(c_ptrs, c)
+
186
+
187
+ @triton_op("finegrained_fp8::w8a8_block_fp8_matmul_batched", mutates_args=())
188
+ def _w8a8_block_fp8_matmul_batched(
189
+ A: torch.Tensor,
190
+ B: torch.Tensor,
191
+ Bs: torch.Tensor,
192
+ expert_ids: torch.Tensor,
193
+ block_size: list[int],
194
+ ) -> torch.Tensor:
195
+ """Block-scale batched FP8 matmul: C[s] = A[s] @ B[expert_ids[s]].T, with fused act quant.
196
+
197
+ A: (S, K) raw bf16/fp16 activations
198
+ B: (E, N, K) FP8 expert weights
199
+ Bs: (E, N // block_n, K // block_k) per-block weight scales
200
+ """
201
+ assert A.ndim == 2, f"A must be 2D (S, K), got ndim={A.ndim}"
202
+ assert A.is_contiguous(), "A must be contiguous"
203
+ assert B.ndim == 3, f"B must be 3D (E, N, K), got ndim={B.ndim}"
204
+ assert B.is_contiguous(), "B must be contiguous"
205
+ assert A.shape[1] == B.shape[2], (
206
+ f"K mismatch: A has K={A.shape[1]}, B has K={B.shape[2]}"
207
+ )
208
+
209
+ S, K = A.shape
210
+ E, N, _ = B.shape
211
+
212
+ assert len(block_size) == 2, (
213
+ f"block_size must be [block_n, block_k], got {block_size}"
214
+ )
215
+ block_n, block_k = block_size[0], block_size[1]
216
+ # MoE expert dimensions must be block-aligned; non-aligned N/K is not supported.
217
+ assert N % block_n == 0, f"N ({N}) must be divisible by block_n ({block_n})"
218
+ assert K % block_k == 0, f"K ({K}) must be divisible by block_k ({block_k})"
219
+ assert Bs.ndim == 3, (
220
+ f"Bs must be 3D (E, N//block_n, K//block_k), got ndim={Bs.ndim}"
221
+ )
222
+ assert Bs.shape == (E, N // block_n, K // block_k), (
223
+ f"Bs shape {tuple(Bs.shape)} != expected ({E}, {N // block_n}, {K // block_k})"
224
+ )
225
+
226
+ C = A.new_empty(S, N)
227
+ # Adaptive BLOCK_SIZE_M: smallest power-of-2 >= M, floored at 16, capped at 128.
228
+ # Matches the WGMMA tile to the actual row count — smaller tiles use less
229
+ # register pressure and a better-matched FP8 WGMMA instruction, improving
230
+ # both accuracy and performance for small M (decode).
231
+ BLOCK_SIZE_M = min(max(triton.next_power_of_2((S + E - 1) // E), 16), 128)
232
+ grid = (S, triton.cdiv(N, block_n))
233
+ with device_context(A.device):
234
+ wrap_triton(w8a8_block_fp8_matmul_batched_kernel)[grid](
235
+ A,
236
+ B,
237
+ C,
238
+ Bs,
239
+ expert_ids,
240
+ S,
241
+ N,
242
+ K,
243
+ A.stride(0),
244
+ A.stride(1),
245
+ B.stride(0),
246
+ B.stride(2),
247
+ B.stride(1),
248
+ C.stride(0),
249
+ C.stride(1),
250
+ Bs.stride(0),
251
+ Bs.stride(2),
252
+ Bs.stride(1),
253
+ BLOCK_SIZE_N=block_n,
254
+ BLOCK_SIZE_K=block_k,
255
+ BLOCK_SIZE_M=BLOCK_SIZE_M,
256
+ )
257
+
258
+ return C
259
+
260
+
261
+ @triton_op("finegrained_fp8::w8a8_tensor_fp8_matmul_batched", mutates_args=())
262
+ def _w8a8_tensor_fp8_matmul_batched(
263
+ A: torch.Tensor,
264
+ B: torch.Tensor,
265
+ Bs: torch.Tensor,
266
+ expert_ids: torch.Tensor,
267
+ ) -> torch.Tensor:
268
+ """Tensor-scale batched FP8 matmul: C[s] = A[s] @ B[expert_ids[s]].T, with fused act quant.
269
+
270
+ A: (S, K) raw bf16/fp16 activations
271
+ B: (E, N, K) FP8 expert weights
272
+ Bs: (E,) or (E, 1, 1) per-expert weight scales
273
+ """
274
+ assert A.ndim == 2, f"A must be 2D (S, K), got ndim={A.ndim}"
275
+ assert A.is_contiguous(), "A must be contiguous"
276
+ assert B.ndim == 3, f"B must be 3D (E, N, K), got ndim={B.ndim}"
277
+ assert B.is_contiguous(), "B must be contiguous"
278
+ assert A.shape[1] == B.shape[2], (
279
+ f"K mismatch: A has K={A.shape[1]}, B has K={B.shape[2]}"
280
+ )
281
+
282
+ S, K = A.shape
283
+ E, N, _ = B.shape
284
+
285
+ # Normalize Bs to (E, 1, 1)
286
+ if Bs.ndim == 1:
287
+ assert Bs.shape[0] == E, f"Bs shape {tuple(Bs.shape)} != expected ({E},)"
288
+ Bs = Bs.reshape(E, 1, 1)
289
+ else:
290
+ assert Bs.shape == (E, 1, 1), (
291
+ f"Bs shape {tuple(Bs.shape)} != expected ({E}, 1, 1)"
292
+ )
293
+
294
+ BLOCK_SIZE_N = 128
295
+ BLOCK_SIZE_K = 128
296
+ C = A.new_empty(S, N)
297
+ qA, As = fp8_act_quant(A, K)
298
+ grid = (S, triton.cdiv(N, BLOCK_SIZE_N))
299
+ # Adaptive BLOCK_SIZE_M: smallest power-of-2 >= M, floored at 16, capped at 128.
300
+ # Matches the WGMMA tile to the actual row count — smaller tiles use less
301
+ # register pressure and a better-matched FP8 WGMMA instruction, improving
302
+ # both accuracy and performance for small M (decode).
303
+ BLOCK_SIZE_M = min(max(triton.next_power_of_2((S + E - 1) // E), 16), 128)
304
+ grid = (S, triton.cdiv(N, BLOCK_SIZE_N))
305
+ with device_context(A.device):
306
+ wrap_triton(w8a8_tensor_fp8_matmul_batched_kernel)[grid](
307
+ qA,
308
+ B,
309
+ C,
310
+ As,
311
+ Bs,
312
+ expert_ids,
313
+ S,
314
+ N,
315
+ K,
316
+ qA.stride(0),
317
+ qA.stride(1),
318
+ B.stride(0),
319
+ B.stride(2),
320
+ B.stride(1),
321
+ C.stride(0),
322
+ C.stride(1),
323
+ As.stride(0),
324
+ Bs.stride(0),
325
+ BLOCK_SIZE_N=BLOCK_SIZE_N,
326
+ BLOCK_SIZE_K=BLOCK_SIZE_K,
327
+ BLOCK_SIZE_M=BLOCK_SIZE_M,
328
+ )
329
+
330
+ return C
331
+
332
+
333
def w8a8_block_fp8_matmul_batched(
    A: torch.Tensor,
    B: torch.Tensor,
    Bs: torch.Tensor,
    expert_ids: torch.Tensor,
    block_size: list[int],
) -> torch.Tensor:
    """Batched FP8 matmul with per-block weight scales and fused act quant.

    Args:
        A: (S, K) raw activations, bf16/fp16/fp32.
        B: (E, N, K) FP8 expert weights.
        Bs: (E, N // block_n, K // block_k) per-block weight scales.
        expert_ids: (S,) expert index for each row of ``A``.
        block_size: ``[block_n, block_k]`` weight-quantization block shape.

    Returns:
        (S, N) output tensor in ``A``'s dtype.
    """
    op = torch.ops.finegrained_fp8.w8a8_block_fp8_matmul_batched
    return op(A, B, Bs, expert_ids, block_size)
349
+
350
+
351
def w8a8_tensor_fp8_matmul_batched(
    A: torch.Tensor,
    B: torch.Tensor,
    Bs: torch.Tensor,
    expert_ids: torch.Tensor,
) -> torch.Tensor:
    """Batched FP8 matmul with per-expert tensor scales; activations are
    quantized per token on the fly.

    Args:
        A: (S, K) raw activations, bf16/fp16/fp32.
        B: (E, N, K) FP8 expert weights.
        Bs: (E,) or (E, 1, 1) per-expert weight scales.
        expert_ids: (S,) expert index for each row of ``A``.

    Returns:
        (S, N) output tensor in ``A``'s dtype.
    """
    op = torch.ops.finegrained_fp8.w8a8_tensor_fp8_matmul_batched
    return op(A, B, Bs, expert_ids)
366
+
367
+
368
def w8a8_fp8_matmul_batched(
    A: torch.Tensor,
    B: torch.Tensor,
    Bs: torch.Tensor,
    expert_ids: torch.Tensor,
    block_size: list[int] | None,
) -> torch.Tensor:
    """Unified batched W8A8 FP8 matmul dispatcher.

    Dispatch rules:
      - tensor mode when ``block_size is None``
      - tensor mode when ``block_size == [N, K]`` (one block = whole weight)
      - otherwise block mode

    Returns:
        Output tensor ``[S, N]`` in the same dtype as ``A``.
    """
    if block_size is None:
        return w8a8_tensor_fp8_matmul_batched(A, B, Bs, expert_ids)

    covers_whole_weight = (
        block_size[0] == B.size(1) and block_size[1] == B.size(2)
    )
    if covers_whole_weight:
        return w8a8_tensor_fp8_matmul_batched(A, B, Bs, expert_ids)

    return w8a8_block_fp8_matmul_batched(A, B, Bs, expert_ids, block_size)
build/torch-cuda/finegrained_fp8/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ctypes
2
+ import importlib.util
3
+ import sys
4
+ from pathlib import Path
5
+ from types import ModuleType
6
+
7
+
8
+ def _import_from_path(file_path: Path) -> ModuleType:
9
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
10
+ # it would also be used for other imports. So, we make a module name that
11
+ # depends on the path for it to be unique using the hex-encoded hash of
12
+ # the path.
13
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
14
+ module_name = path_hash
15
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
16
+ if spec is None:
17
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
18
+ module = importlib.util.module_from_spec(spec)
19
+ if module is None:
20
+ raise ImportError(f"Cannot load module {module_name} from spec")
21
+ sys.modules[module_name] = module
22
+ spec.loader.exec_module(module) # type: ignore
23
+ return module
24
+
25
+
26
+ globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch-cuda/grouped.py ADDED
@@ -0,0 +1,477 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ import triton
17
+ import triton.language as tl
18
+ from .act_quant import fp8_act_quant
19
+ from torch.library import triton_op, wrap_triton
20
+
21
+ from .utils import device_context
22
+
23
+
24
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=w, num_stages=s)
        for w in [2, 4, 8, 16]
        for s in [2, 3, 4, 5]
    ],
    key=["N", "K", "BLOCK_SIZE_M"],
)
@triton.jit
def w8a8_block_fp8_matmul_grouped_kernel(
    A,  # (S, K) raw BF16/FP16 activations, sorted/grouped by expert id
    B,  # (E, N, K) FP8 weight matrices
    C,  # (S, N) output
    Bs,  # (E, N // BLOCK_SIZE_N, K // BLOCK_SIZE_K) weight scales
    Offsets,  # (E,) int32 — cumulative row-end per expert
    TileOffsets,  # (E,) int32 — cumulative tile-end per expert
    # Shape
    S,
    N,
    K,
    # Strides
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_bs_e,
    stride_bs_k,
    stride_bs_n,
    # Meta-parameters
    NUM_EXPERTS: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    NUM_EXPERTS_BIT_LENGTH: tl.constexpr,
):
    """Block-scale grouped FP8 expert matmul kernel.

    Tokens are assumed sorted by expert. The kernel maps each M-tile to its
    owning expert via ``TileOffsets`` and applies fused activation quantization.
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)

    # Exit early for programs beyond the actual tile count.
    total_tiles = tl.load(TileOffsets + NUM_EXPERTS - 1)
    if pid_m >= total_tiles:
        return

    # Binary search in TileOffsets to find the owning expert.
    # Finds the smallest e such that TileOffsets[e] > pid_m (upper_bound semantics),
    # which is the expert whose tile range contains pid_m.
    # O(log2(NUM_EXPERTS)) loads instead of the O(NUM_EXPERTS) linear scan.
    # NUM_EXPERTS_BIT_LENGTH is ceil(log2(E))+1 for powers-of-two, giving one
    # harmless extra iteration when lo==hi; it's a compile-time constant so the
    # loop is fully unrolled by the compiler.
    lo = 0
    hi = NUM_EXPERTS
    for _ in tl.static_range(NUM_EXPERTS_BIT_LENGTH):
        mid = (lo + hi) >> 1
        mid_val = tl.load(TileOffsets + mid)
        is_left = mid_val <= pid_m
        lo = tl.where(is_left, mid + 1, lo)
        hi = tl.where(is_left, hi, mid)

    # Cast expert_id to int64 to prevent int32 overflow when computing
    # expert_id * stride_be (e.g. 255 * 9_437_184 > 2^31 for 256 experts of
    # 3072×3072 FP8 weights).
    expert_id = lo.to(tl.int64)

    # Row range [expert_start, expert_end) owned by this expert; Offsets
    # holds cumulative row ends, so expert 0 starts at row 0.
    prev_eid = tl.maximum(expert_id - 1, 0)
    expert_start = tl.where(expert_id == 0, 0, tl.load(Offsets + prev_eid))
    expert_end = tl.load(Offsets + expert_id)
    M_expert = expert_end - expert_start

    # Position of this program's M-tile within the expert's own rows.
    expert_tile_start = tl.where(expert_id == 0, 0, tl.load(TileOffsets + prev_eid))
    local_tile = pid_m - expert_tile_start
    m_off = local_tile * BLOCK_SIZE_M

    offs_am = m_off + tl.arange(0, BLOCK_SIZE_M)
    row_mask = offs_am < M_expert  # an expert's last tile may be ragged
    offs_global_m = expert_start + offs_am

    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    a_ptrs = A + offs_global_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = (
        B
        + expert_id * stride_be
        + offs_k[:, None] * stride_bk
        + offs_bn[None, :] * stride_bn
    )
    bs_ptrs = Bs + expert_id * stride_bs_e + pid_n * stride_bs_n

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # ---- fused fp8_act_quant ----
        a_raw = tl.load(a_ptrs, mask=row_mask[:, None], other=0.0).to(tl.float32)
        # Per-row dynamic scale; 448 is the float8_e4m3fn max. The 1e-12
        # floor avoids NaNs for all-zero (including masked-out) rows.
        a_s = tl.max(tl.abs(a_raw), axis=1) / 448.0
        a = (a_raw / tl.maximum(a_s[:, None], 1e-12)).to(tl.float8e4nv)
        # ---- matmul ----
        # b is loaded unmasked: the host asserts N and K are multiples of
        # the block sizes.
        b = tl.load(b_ptrs)
        b_s = tl.load(bs_ptrs + k * stride_bs_k)
        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Emit in the output dtype; accumulation stays in fp32 throughout.
    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    c_ptrs = C + stride_cm * offs_global_m[:, None] + stride_cn * offs_bn[None, :]
    c_mask = row_mask[:, None]
    tl.store(c_ptrs, c, mask=c_mask)
+
145
+
146
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=w, num_stages=s)
        for w in [2, 4, 8, 16]
        for s in [2, 3, 4, 5]
    ],
    key=["N", "K", "BLOCK_SIZE_M"],
)
@triton.jit
def w8a8_tensor_fp8_matmul_grouped_kernel(
    A,  # (S, K) pre-quantized FP8 activations, sorted/grouped by expert id
    B,  # (E, N, K) FP8 weight matrices
    C,  # (S, N) output
    As,  # (S, 1) activation scales
    Bs,  # (E, 1, 1) per-tensor weight scales
    Offsets,  # (E,) cumulative row-end per expert
    TileOffsets,  # (E,) cumulative tile-end per expert
    # Shape
    S,
    N,
    K,
    # Strides
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_as_m,
    stride_bs_e,
    # Meta-parameters
    NUM_EXPERTS: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    NUM_EXPERTS_BIT_LENGTH: tl.constexpr,
):
    """Tensor-scale grouped FP8 expert matmul kernel.

    Uses grouped expert scheduling with pre-quantized activations plus
    per-token activation scales and per-expert tensor weight scales.
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)

    # Exit early for programs beyond the actual tile count.
    total_tiles = tl.load(TileOffsets + NUM_EXPERTS - 1)
    if pid_m >= total_tiles:
        return

    # Binary search (upper bound) in TileOffsets for the smallest expert e
    # with TileOffsets[e] > pid_m — the expert whose tile range holds pid_m.
    # NUM_EXPERTS_BIT_LENGTH iterations are a compile-time constant, so the
    # loop unrolls fully.
    lo = 0
    hi = NUM_EXPERTS
    for _ in tl.static_range(NUM_EXPERTS_BIT_LENGTH):
        mid = (lo + hi) >> 1
        mid_val = tl.load(TileOffsets + mid)
        is_left = mid_val <= pid_m
        lo = tl.where(is_left, mid + 1, lo)
        hi = tl.where(is_left, hi, mid)
    # int64 cast prevents int32 overflow in expert_id * stride_be.
    expert_id = lo.to(tl.int64)

    # Row range [expert_start, expert_end) owned by this expert.
    prev_eid = tl.maximum(expert_id - 1, 0)
    expert_start = tl.where(expert_id == 0, 0, tl.load(Offsets + prev_eid))
    expert_end = tl.load(Offsets + expert_id)
    M_expert = expert_end - expert_start

    # Position of this program's M-tile within the expert's rows.
    expert_tile_start = tl.where(expert_id == 0, 0, tl.load(TileOffsets + prev_eid))
    local_tile = pid_m - expert_tile_start
    m_off = local_tile * BLOCK_SIZE_M

    offs_am = m_off + tl.arange(0, BLOCK_SIZE_M)
    row_mask = offs_am < M_expert  # an expert's last tile may be ragged
    offs_global_m = expert_start + offs_am

    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    a_ptrs = A + offs_global_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = (
        B
        + expert_id * stride_be
        + offs_k[:, None] * stride_bk
        + offs_bn[None, :] * stride_bn
    )

    # Scales are constant across the K loop; load once up front.
    a_s = tl.load(As + offs_global_m * stride_as_m, mask=row_mask, other=0.0)
    b_s = tl.load(Bs + expert_id * stride_bs_e)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for _ in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs, mask=row_mask[:, None], other=0.0)
        b = tl.load(b_ptrs)

        accumulator += tl.dot(a, b)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Single rescale after the loop — valid because both scales are constant
    # for a given row/expert.
    accumulator = accumulator * a_s[:, None] * b_s

    # Emit in the output dtype; accumulation stays in fp32 throughout.
    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    c_ptrs = C + stride_cm * offs_global_m[:, None] + stride_cn * offs_bn[None, :]
    c_mask = row_mask[:, None]
    tl.store(c_ptrs, c, mask=c_mask)
251
+
252
+
253
@triton_op("finegrained_fp8::w8a8_block_fp8_matmul_grouped", mutates_args=())
def _w8a8_block_fp8_matmul_grouped(
    A: torch.Tensor,
    B: torch.Tensor,
    Bs: torch.Tensor,
    offsets: torch.Tensor,
    tokens_per_expert: torch.Tensor,
    block_size: list[int],
) -> torch.Tensor:
    """Block-scale grouped FP8 matmul: C = A @ B.T per expert, with fused act quant.

    A: (S, K) raw bf16/fp16 activations, sorted by expert
    B: (E, N, K) FP8 expert weights
    Bs: (E, N // block_n, K // block_k) per-block weight scales
    offsets: cumulative row-end per expert (consumed by the kernel as (E,))
    tokens_per_expert: integer row count per expert — presumably shape (E,)
        on the same device; TODO confirm against callers
    """
    assert A.ndim == 2, f"A must be 2D (S, K), got ndim={A.ndim}"
    assert A.is_contiguous(), "A must be contiguous"
    assert B.ndim == 3, f"B must be 3D (E, N, K), got ndim={B.ndim}"
    assert B.is_contiguous(), "B must be contiguous"
    assert A.shape[1] == B.shape[2], (
        f"K mismatch: A has K={A.shape[1]}, B has K={B.shape[2]}"
    )

    S, K = A.shape
    E, N, _ = B.shape

    assert len(block_size) == 2, (
        f"block_size must be [block_n, block_k], got {block_size}"
    )
    block_n, block_k = block_size[0], block_size[1]
    # MoE expert dimensions must be block-aligned; non-aligned N/K is not supported.
    assert N % block_n == 0, f"N ({N}) must be divisible by block_n ({block_n})"
    assert K % block_k == 0, f"K ({K}) must be divisible by block_k ({block_k})"
    assert Bs.ndim == 3, (
        f"Bs must be 3D (E, N//block_n, K//block_k), got ndim={Bs.ndim}"
    )
    assert Bs.shape == (E, N // block_n, K // block_k), (
        f"Bs shape {tuple(Bs.shape)} != expected ({E}, {N // block_n}, {K // block_k})"
    )

    C = A.new_empty(S, N)
    # Adaptive BLOCK_SIZE_M: match tile to average tokens per expert.
    BLOCK_SIZE_M = min(max(triton.next_power_of_2((S + E - 1) // E), 16), 128)
    # Per-expert tile counts and their cumulative ends, computed on device
    # (no host synchronization).
    tiles_per_expert = (tokens_per_expert + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M
    tile_offsets = torch.cumsum(tiles_per_expert, dim=0).to(torch.int32)
    # Upper bound on M-tiles: sum_e ceil(M_e / BLOCK_M) <= ceil(S / BLOCK_M) + E.
    # Programs beyond the real tile count exit immediately via the early-return
    # guard inside the kernel. This is faster than syncing for the exact count
    # and keeps the grid size data-independent (cuda-graph / torch.compile safe).
    max_M_tiles = triton.cdiv(S, BLOCK_SIZE_M) + E
    grid = (max_M_tiles, triton.cdiv(N, block_n))
    with device_context(A.device):
        wrap_triton(w8a8_block_fp8_matmul_grouped_kernel)[grid](
            A,
            B,
            C,
            Bs,
            offsets,
            tile_offsets,
            S,
            N,
            K,
            A.stride(0),
            A.stride(1),
            B.stride(0),
            # B is (E, N, K) but the kernel indexes (K, N): pass strides in
            # (k, n) order so B is consumed transposed.
            B.stride(2),
            B.stride(1),
            C.stride(0),
            C.stride(1),
            Bs.stride(0),
            # Same (k, n) swap for the per-block weight scales.
            Bs.stride(2),
            Bs.stride(1),
            # Meta-parameters
            NUM_EXPERTS=E,
            BLOCK_SIZE_N=block_n,
            BLOCK_SIZE_K=block_k,
            BLOCK_SIZE_M=BLOCK_SIZE_M,
            NUM_EXPERTS_BIT_LENGTH=E.bit_length(),
        )

    return C
334
+
335
+
336
+ @triton_op("finegrained_fp8::w8a8_tensor_fp8_matmul_grouped", mutates_args=())
337
+ def _w8a8_tensor_fp8_matmul_grouped(
338
+ A: torch.Tensor,
339
+ B: torch.Tensor,
340
+ Bs: torch.Tensor,
341
+ offsets: torch.Tensor,
342
+ tokens_per_expert: torch.Tensor,
343
+ ) -> torch.Tensor:
344
+ """Tensor-scale grouped FP8 matmul: C = A @ B.T per expert, with fused act quant.
345
+
346
+ A: (S, K) raw bf16/fp16 activations, sorted by expert
347
+ B: (E, N, K) FP8 expert weights
348
+ Bs: (E,) or (E, 1, 1) per-expert weight scales
349
+ """
350
+ assert A.ndim == 2, f"A must be 2D (S, K), got ndim={A.ndim}"
351
+ assert A.is_contiguous(), "A must be contiguous"
352
+ assert B.ndim == 3, f"B must be 3D (E, N, K), got ndim={B.ndim}"
353
+ assert B.is_contiguous(), "B must be contiguous"
354
+ assert A.shape[1] == B.shape[2], (
355
+ f"K mismatch: A has K={A.shape[1]}, B has K={B.shape[2]}"
356
+ )
357
+
358
+ S, K = A.shape
359
+ E, N, _ = B.shape
360
+
361
+ # Normalize Bs to (E, 1, 1)
362
+ if Bs.ndim == 1:
363
+ assert Bs.shape[0] == E, f"Bs shape {tuple(Bs.shape)} != expected ({E},)"
364
+ Bs = Bs.reshape(E, 1, 1)
365
+ else:
366
+ assert Bs.shape == (E, 1, 1), (
367
+ f"Bs shape {tuple(Bs.shape)} != expected ({E}, 1, 1)"
368
+ )
369
+
370
+ BLOCK_SIZE_N = 128
371
+ BLOCK_SIZE_K = 128
372
+ C = A.new_empty(S, N)
373
+ qA, As = fp8_act_quant(A, K)
374
+ BLOCK_SIZE_M = min(max(triton.next_power_of_2((S + E - 1) // E), 16), 128)
375
+ tiles_per_expert = (tokens_per_expert + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M
376
+ tile_offsets = torch.cumsum(tiles_per_expert, dim=0).to(torch.int32)
377
+ # Upper bound on M-tiles: sum_e ceil(M_e / BLOCK_M) <= ceil(S / BLOCK_M) + E.
378
+ # Programs beyond the real tile count exit immediately via the early-return
379
+ # guard inside the kernel. This is faster than syncing for the exact count
380
+ # and keeps the grid size data-independent (cuda-graph / torch.compile safe).
381
+ max_M_tiles = triton.cdiv(S, BLOCK_SIZE_M) + E
382
+ grid = (max_M_tiles, triton.cdiv(N, BLOCK_SIZE_N))
383
+ with device_context(A.device):
384
+ wrap_triton(w8a8_tensor_fp8_matmul_grouped_kernel)[grid](
385
+ qA,
386
+ B,
387
+ C,
388
+ As,
389
+ Bs,
390
+ offsets,
391
+ tile_offsets,
392
+ S,
393
+ N,
394
+ K,
395
+ qA.stride(0),
396
+ qA.stride(1),
397
+ B.stride(0),
398
+ B.stride(2),
399
+ B.stride(1),
400
+ C.stride(0),
401
+ C.stride(1),
402
+ As.stride(0),
403
+ Bs.stride(0),
404
+ # Meta-parameters
405
+ NUM_EXPERTS=E,
406
+ BLOCK_SIZE_N=BLOCK_SIZE_N,
407
+ BLOCK_SIZE_K=BLOCK_SIZE_K,
408
+ BLOCK_SIZE_M=BLOCK_SIZE_M,
409
+ NUM_EXPERTS_BIT_LENGTH=E.bit_length(),
410
+ )
411
+
412
+ return C
413
+
414
+
415
+ def w8a8_block_fp8_matmul_grouped(
416
+ A: torch.Tensor,
417
+ B: torch.Tensor,
418
+ Bs: torch.Tensor,
419
+ offsets: torch.Tensor,
420
+ tokens_per_expert: torch.Tensor,
421
+ block_size: list[int],
422
+ ) -> torch.Tensor:
423
+ """Block-scale grouped FP8 matmul with fused activation quantization.
424
+
425
+ A: (S, K) raw activations sorted by expert, bf16/fp16/fp32
426
+ B: (E, N, K) FP8 expert weights
427
+ Bs: (E, N // block_n, K // block_k) per-block weight scales
428
+ """
429
+ return torch.ops.finegrained_fp8.w8a8_block_fp8_matmul_grouped(
430
+ A, B, Bs, offsets, tokens_per_expert, block_size
431
+ )
432
+
433
+
434
+ def w8a8_tensor_fp8_matmul_grouped(
435
+ A: torch.Tensor,
436
+ B: torch.Tensor,
437
+ Bs: torch.Tensor,
438
+ offsets: torch.Tensor,
439
+ tokens_per_expert: torch.Tensor,
440
+ ) -> torch.Tensor:
441
+ """Tensor-scale grouped FP8 matmul with fused activation quantization.
442
+
443
+ A: (S, K) raw activations sorted by expert, bf16/fp16/fp32
444
+ B: (E, N, K) FP8 expert weights
445
+ Bs: (E,) or (E, 1, 1) per-expert weight scales
446
+ """
447
+ return torch.ops.finegrained_fp8.w8a8_tensor_fp8_matmul_grouped(
448
+ A, B, Bs, offsets, tokens_per_expert
449
+ )
450
+
451
+
452
+ def w8a8_fp8_matmul_grouped(
453
+ A: torch.Tensor,
454
+ B: torch.Tensor,
455
+ Bs: torch.Tensor,
456
+ offsets: torch.Tensor,
457
+ tokens_per_expert: torch.Tensor,
458
+ block_size: list[int] | None,
459
+ ) -> torch.Tensor:
460
+ """Unified grouped W8A8 FP8 matmul dispatcher.
461
+
462
+ Dispatch rules:
463
+ - tensor mode when ``block_size is None``
464
+ - tensor mode when ``block_size == [N, K]``
465
+ - otherwise block mode
466
+
467
+ Returns:
468
+ Output tensor ``[S, N]`` in the same dtype as ``A``, in expert-sorted order.
469
+ """
470
+ if block_size is None or (
471
+ block_size[0] == B.size(1) and block_size[1] == B.size(2)
472
+ ):
473
+ return w8a8_tensor_fp8_matmul_grouped(A, B, Bs, offsets, tokens_per_expert)
474
+
475
+ return w8a8_block_fp8_matmul_grouped(
476
+ A, B, Bs, offsets, tokens_per_expert, block_size
477
+ )
build/torch-cuda/matmul.py ADDED
@@ -0,0 +1,411 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ import triton
17
+ import triton.language as tl
18
+ from torch.library import triton_op, wrap_triton
19
+
20
+ from .utils import device_context
21
+
22
+
23
+ # Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/quantization/fp8_kernel.py
24
+ @triton.autotune(
25
+ configs=[
26
+ triton.Config({}, num_warps=w, num_stages=s)
27
+ for w in [2, 4, 8, 16]
28
+ for s in [2, 3, 4]
29
+ ],
30
+ key=["N", "K", "BLOCK_SIZE_M"],
31
+ )
32
+ @triton.jit
33
+ def w8a8_block_fp8_matmul_kernel(
34
+ # Pointers to inputs and output
35
+ A,
36
+ B,
37
+ C,
38
+ As,
39
+ Bs,
40
+ # Shape for matmul
41
+ M,
42
+ N,
43
+ K,
44
+ stride_am,
45
+ stride_ak,
46
+ stride_bk,
47
+ stride_bn,
48
+ stride_cm,
49
+ stride_cn,
50
+ stride_as_m,
51
+ stride_as_k,
52
+ stride_bs_k,
53
+ stride_bs_n,
54
+ # Meta-parameters
55
+ BLOCK_SIZE_M: tl.constexpr,
56
+ BLOCK_SIZE_N: tl.constexpr,
57
+ BLOCK_SIZE_K: tl.constexpr,
58
+ GROUP_SIZE_M: tl.constexpr,
59
+ ):
60
+ """Block-scale FP8 GEMM kernel.
61
+
62
+ Computes ``C = A @ B.T`` with block-wise activation/weight scales.
63
+ Uses a 2D grid with swizzle for L2 cache locality on B tiles.
64
+ """
65
+ pid_m = tl.program_id(axis=0)
66
+ pid_n = tl.program_id(axis=1)
67
+ num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
68
+ num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
69
+ pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M)
70
+
71
+ offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
72
+ offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
73
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
74
+ a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
75
+ b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
76
+
77
+ as_ptrs = As + offs_am * stride_as_m
78
+ offs_bsn = offs_bn // BLOCK_SIZE_N
79
+ bs_ptrs = Bs + offs_bsn * stride_bs_n
80
+
81
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
82
+ for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
83
+ k_remaining = K - k * BLOCK_SIZE_K
84
+ a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0.0)
85
+ b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
86
+
87
+ a_s = tl.load(as_ptrs + k * stride_as_k)
88
+ b_s = tl.load(bs_ptrs + k * stride_bs_k)
89
+
90
+ accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
91
+ a_ptrs += BLOCK_SIZE_K * stride_ak
92
+ b_ptrs += BLOCK_SIZE_K * stride_bk
93
+
94
+ if C.dtype.element_ty == tl.bfloat16:
95
+ c = accumulator.to(tl.bfloat16)
96
+ elif C.dtype.element_ty == tl.float16:
97
+ c = accumulator.to(tl.float16)
98
+ else:
99
+ c = accumulator.to(tl.float32)
100
+
101
+ offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
102
+ offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
103
+ c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
104
+ c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
105
+ tl.store(c_ptrs, c, mask=c_mask)
106
+
107
+
108
+ @triton.autotune(
109
+ configs=[
110
+ triton.Config({}, num_warps=w, num_stages=s)
111
+ for w in [2, 4, 8, 16]
112
+ for s in [2, 3, 4]
113
+ ],
114
+ key=["N", "K", "BLOCK_SIZE_M"],
115
+ )
116
+ @triton.jit
117
+ def w8a8_tensor_fp8_matmul_kernel(
118
+ A,
119
+ B,
120
+ C,
121
+ As,
122
+ Bs,
123
+ M,
124
+ N,
125
+ K,
126
+ stride_am,
127
+ stride_ak,
128
+ stride_bk,
129
+ stride_bn,
130
+ stride_cm,
131
+ stride_cn,
132
+ stride_as_m,
133
+ BLOCK_SIZE_M: tl.constexpr,
134
+ BLOCK_SIZE_N: tl.constexpr,
135
+ BLOCK_SIZE_K: tl.constexpr,
136
+ GROUP_SIZE_M: tl.constexpr,
137
+ ):
138
+ """Tensor-scale FP8 GEMM kernel.
139
+
140
+ Computes ``C = A @ B.T`` with one activation scale per row and one
141
+ weight scale for the full matrix.
142
+ Uses a 2D grid with swizzle for L2 cache locality on B tiles.
143
+ """
144
+ pid_m = tl.program_id(axis=0)
145
+ pid_n = tl.program_id(axis=1)
146
+ num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
147
+ num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
148
+ pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M)
149
+
150
+ offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
151
+ offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
152
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
153
+
154
+ a_ptrs = A + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
155
+ b_ptrs = B + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn
156
+
157
+ a_s = tl.load(As + offs_am * stride_as_m)
158
+ b_s = tl.load(Bs)
159
+
160
+ # Accumulate raw dot products, apply scales once after the loop.
161
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
162
+ for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
163
+ k_remaining = K - k * BLOCK_SIZE_K
164
+ a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0.0)
165
+ b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
166
+ accumulator += tl.dot(a, b)
167
+ a_ptrs += BLOCK_SIZE_K * stride_ak
168
+ b_ptrs += BLOCK_SIZE_K * stride_bk
169
+
170
+ accumulator = accumulator * a_s[:, None] * b_s
171
+
172
+ if C.dtype.element_ty == tl.bfloat16:
173
+ c = accumulator.to(tl.bfloat16)
174
+ elif C.dtype.element_ty == tl.float16:
175
+ c = accumulator.to(tl.float16)
176
+ else:
177
+ c = accumulator.to(tl.float32)
178
+
179
+ offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
180
+ offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
181
+ c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
182
+ c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
183
+ tl.store(c_ptrs, c, mask=c_mask)
184
+
185
+
186
+ @triton_op("finegrained_fp8::w8a8_block_fp8_matmul", mutates_args=())
187
+ def _w8a8_block_fp8_matmul(
188
+ A: torch.Tensor,
189
+ B: torch.Tensor,
190
+ As: torch.Tensor,
191
+ Bs: torch.Tensor,
192
+ block_size: list[int],
193
+ output_dtype: torch.dtype = torch.float32,
194
+ ) -> torch.Tensor:
195
+ """Block-scale FP8 matmul: C = A @ B.T with per-block scales.
196
+
197
+ As: (M, K // block_k) — per-token-group activation scales
198
+ Bs: (N // block_n, K // block_k) — per-block weight scales
199
+ """
200
+ assert len(block_size) == 2, (
201
+ f"block_size must be [block_n, block_k], got {block_size}"
202
+ )
203
+ block_n, block_k = block_size[0], block_size[1]
204
+
205
+ assert A.shape[-1] == B.shape[-1], (
206
+ f"K mismatch: A has K={A.shape[-1]}, B has K={B.shape[-1]}"
207
+ )
208
+ assert A.is_contiguous(), "A must be contiguous"
209
+ assert B.ndim == 2, f"B must be 2D (N, K), got ndim={B.ndim}"
210
+ assert B.is_contiguous(), "B must be contiguous"
211
+
212
+ N, K = B.shape
213
+ M = A.numel() // A.shape[-1]
214
+
215
+ assert As.ndim >= 2, f"As must be at least 2D, got ndim={As.ndim}"
216
+ assert As.shape[-1] == triton.cdiv(K, block_k), (
217
+ f"As last dim {As.shape[-1]} != expected {triton.cdiv(K, block_k)} (cdiv(K={K}, block_k={block_k}))"
218
+ )
219
+ assert Bs.ndim == 2, f"Bs must be 2D (N//block_n, K//block_k), got ndim={Bs.ndim}"
220
+ assert Bs.shape == (triton.cdiv(N, block_n), triton.cdiv(K, block_k)), (
221
+ f"Bs shape {tuple(Bs.shape)} != expected ({triton.cdiv(N, block_n)}, {triton.cdiv(K, block_k)})"
222
+ )
223
+
224
+ BLOCK_SIZE_K = block_k
225
+ BLOCK_SIZE_N = block_n
226
+ C_shape = A.shape[:-1] + (N,)
227
+ C = A.new_empty(C_shape, dtype=output_dtype)
228
+ # Adaptive BLOCK_SIZE_M: smallest power-of-2 >= M, floored at 16, capped at 128.
229
+ # Matches the WGMMA tile to the actual row count — smaller tiles use less
230
+ # register pressure and a better-matched FP8 WGMMA instruction, improving
231
+ # both accuracy and performance for small M (decode).
232
+ BLOCK_SIZE_M = min(max(triton.next_power_of_2(M), 16), 128)
233
+ grid = (triton.cdiv(M, BLOCK_SIZE_M), triton.cdiv(N, BLOCK_SIZE_N))
234
+ with device_context(A.device):
235
+ wrap_triton(w8a8_block_fp8_matmul_kernel)[grid](
236
+ A,
237
+ B,
238
+ C,
239
+ As,
240
+ Bs,
241
+ M,
242
+ N,
243
+ K,
244
+ A.stride(-2),
245
+ A.stride(-1),
246
+ B.stride(1),
247
+ B.stride(0),
248
+ C.stride(-2),
249
+ C.stride(-1),
250
+ As.stride(-2),
251
+ As.stride(-1),
252
+ Bs.stride(1),
253
+ Bs.stride(0),
254
+ # Meta-parameters
255
+ BLOCK_SIZE_M=BLOCK_SIZE_M,
256
+ BLOCK_SIZE_N=BLOCK_SIZE_N,
257
+ BLOCK_SIZE_K=BLOCK_SIZE_K,
258
+ GROUP_SIZE_M=8,
259
+ )
260
+
261
+ return C
262
+
263
+
264
+ @triton_op("finegrained_fp8::w8a8_tensor_fp8_matmul", mutates_args=())
265
+ def _w8a8_tensor_fp8_matmul(
266
+ A: torch.Tensor,
267
+ B: torch.Tensor,
268
+ As: torch.Tensor,
269
+ Bs: torch.Tensor,
270
+ output_dtype: torch.dtype = torch.float32,
271
+ ) -> torch.Tensor:
272
+ """Tensor-scale FP8 matmul: C = A @ B.T with per-row / per-tensor scales.
273
+
274
+ As: scalar, (M,), or (M, 1) — per-row activation scales
275
+ Bs: scalar, (1,), or (1, 1) — single weight scale
276
+ """
277
+ assert A.shape[-1] == B.shape[-1], (
278
+ f"K mismatch: A has K={A.shape[-1]}, B has K={B.shape[-1]}"
279
+ )
280
+ assert A.is_contiguous(), "A must be contiguous"
281
+ assert B.ndim == 2, f"B must be 2D (N, K), got ndim={B.ndim}"
282
+ assert B.is_contiguous(), "B must be contiguous"
283
+
284
+ N, K = B.shape
285
+ M = A.numel() // A.shape[-1]
286
+
287
+ # Normalize As to (M,)
288
+ if As.numel() == 1:
289
+ As = As.reshape(1).expand(M).contiguous()
290
+ elif As.ndim == 2:
291
+ As = As.reshape(M)
292
+ assert As.ndim == 1 and As.shape[0] == M, (
293
+ f"As must be scalar, (M,), or (M,1) with M={M}, got {tuple(As.shape)}"
294
+ )
295
+
296
+ # Normalize Bs to (1,)
297
+ assert Bs.numel() == 1, f"Bs must be scalar or (1,), got {tuple(Bs.shape)}"
298
+ Bs = Bs.reshape(1)
299
+
300
+ BLOCK_SIZE_N = 128
301
+ BLOCK_SIZE_K = 128
302
+ C_shape = A.shape[:-1] + (N,)
303
+ C = A.new_empty(C_shape, dtype=output_dtype)
304
+ BLOCK_SIZE_M = min(max(triton.next_power_of_2(M), 16), 128)
305
+ grid = (triton.cdiv(M, BLOCK_SIZE_M), triton.cdiv(N, BLOCK_SIZE_N))
306
+ with device_context(A.device):
307
+ wrap_triton(w8a8_tensor_fp8_matmul_kernel)[grid](
308
+ A,
309
+ B,
310
+ C,
311
+ As,
312
+ Bs,
313
+ M,
314
+ N,
315
+ K,
316
+ A.stride(-2),
317
+ A.stride(-1),
318
+ B.stride(1),
319
+ B.stride(0),
320
+ C.stride(-2),
321
+ C.stride(-1),
322
+ As.stride(0),
323
+ # Meta-parameters
324
+ BLOCK_SIZE_M=BLOCK_SIZE_M,
325
+ BLOCK_SIZE_N=BLOCK_SIZE_N,
326
+ BLOCK_SIZE_K=BLOCK_SIZE_K,
327
+ GROUP_SIZE_M=8,
328
+ )
329
+
330
+ return C
331
+
332
+
333
+ def w8a8_block_fp8_matmul(
334
+ A: torch.Tensor,
335
+ B: torch.Tensor,
336
+ As: torch.Tensor,
337
+ Bs: torch.Tensor,
338
+ block_size: list[int],
339
+ output_dtype: torch.dtype = torch.float32,
340
+ ) -> torch.Tensor:
341
+ """Block-wise W8A8 FP8 matrix multiplication.
342
+
343
+ Computes ``C = A @ B.T`` where both operands are pre-quantized to
344
+ ``float8_e4m3fn`` with per-block scales, and accumulates in float32
345
+ before casting to ``output_dtype``.
346
+
347
+ Args:
348
+ A: Quantized activation tensor ``[M, K]`` in ``float8_e4m3fn``.
349
+ B: Quantized weight tensor ``[N, K]`` in ``float8_e4m3fn``.
350
+ As: Per-token-group activation scales ``[M, K // block_size[1]]``.
351
+ Bs: Per-block weight scales ``[N // block_size[0], K // block_size[1]]``.
352
+ block_size: ``[block_n, block_k]`` quantization block dimensions, e.g. ``[128, 128]``.
353
+ output_dtype: dtype of the returned tensor (default: ``torch.float32``).
354
+
355
+ Returns:
356
+ Output tensor ``[M, N]`` in ``output_dtype``.
357
+ """
358
+ return torch.ops.finegrained_fp8.w8a8_block_fp8_matmul(
359
+ A, B, As, Bs, block_size, output_dtype
360
+ )
361
+
362
+
363
+ def w8a8_tensor_fp8_matmul(
364
+ A: torch.Tensor,
365
+ B: torch.Tensor,
366
+ As: torch.Tensor,
367
+ Bs: torch.Tensor,
368
+ output_dtype: torch.dtype = torch.float32,
369
+ ) -> torch.Tensor:
370
+ """Tensor-scale W8A8 FP8 matrix multiplication.
371
+
372
+ Computes ``C = A @ B.T`` in tensor-scale mode using pre-quantized FP8
373
+ activations/weights and tensor scales.
374
+
375
+ Args:
376
+ A: Quantized activation tensor ``[M, K]`` in ``float8_e4m3fn``.
377
+ B: Quantized weight tensor ``[N, K]`` in ``float8_e4m3fn``.
378
+ As: Per-row activation scales ``[M]``.
379
+ Bs: Single weight scale, scalar or ``[1]``.
380
+ output_dtype: dtype of the returned tensor.
381
+
382
+ Returns:
383
+ Output tensor ``[M, N]`` in ``output_dtype``.
384
+ """
385
+ return torch.ops.finegrained_fp8.w8a8_tensor_fp8_matmul(A, B, As, Bs, output_dtype)
386
+
387
+
388
+ def w8a8_fp8_matmul(
389
+ A: torch.Tensor,
390
+ B: torch.Tensor,
391
+ As: torch.Tensor,
392
+ Bs: torch.Tensor,
393
+ block_size: list[int] | None,
394
+ output_dtype: torch.dtype = torch.float32,
395
+ ) -> torch.Tensor:
396
+ """Unified W8A8 FP8 matmul dispatcher.
397
+
398
+ Dispatch rules:
399
+ - tensor mode when ``block_size is None``
400
+ - tensor mode when ``block_size == [N, K]``
401
+ - otherwise block mode
402
+
403
+ Returns:
404
+ Output tensor ``[M, N]`` in ``output_dtype``.
405
+ """
406
+ if block_size is None or (
407
+ block_size[0] == B.size(0) and block_size[1] == B.size(1)
408
+ ):
409
+ return w8a8_tensor_fp8_matmul(A, B, As, Bs, output_dtype)
410
+
411
+ return w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype)
build/torch-cuda/metadata.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "version": 1,
3
+ "license": "Apache-2.0",
4
+ "python-depends": [],
5
+ "backend": {
6
+ "type": "cuda"
7
+ }
8
+ }
build/torch-cuda/utils.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from contextlib import contextmanager
3
+
4
+
5
+ @contextmanager
6
+ def device_context(device: torch.device):
7
+ """Context manager that sets the active device for any backend (cuda, xpu, etc.)."""
8
+ backend = getattr(torch, device.type, None)
9
+ if backend is not None and hasattr(backend, "device"):
10
+ with backend.device(device):
11
+ yield
12
+ else:
13
+ yield
build/torch-rocm/__init__.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .act_quant import fp8_act_quant
2
+ from .batched import (
3
+ w8a8_fp8_matmul_batched,
4
+ w8a8_block_fp8_matmul_batched,
5
+ w8a8_tensor_fp8_matmul_batched,
6
+ )
7
+ from .grouped import (
8
+ w8a8_fp8_matmul_grouped,
9
+ w8a8_block_fp8_matmul_grouped,
10
+ w8a8_tensor_fp8_matmul_grouped,
11
+ )
12
+ from .matmul import (
13
+ w8a8_fp8_matmul,
14
+ w8a8_block_fp8_matmul,
15
+ w8a8_tensor_fp8_matmul,
16
+ )
17
+
18
+ __all__ = [
19
+ "fp8_act_quant",
20
+ # Single matmul
21
+ "w8a8_fp8_matmul",
22
+ "w8a8_block_fp8_matmul",
23
+ "w8a8_tensor_fp8_matmul",
24
+ # Batched matmul
25
+ "w8a8_fp8_matmul_batched",
26
+ "w8a8_block_fp8_matmul_batched",
27
+ "w8a8_tensor_fp8_matmul_batched",
28
+ # Grouped matmul
29
+ "w8a8_fp8_matmul_grouped",
30
+ "w8a8_block_fp8_matmul_grouped",
31
+ "w8a8_tensor_fp8_matmul_grouped",
32
+ ]
build/torch-rocm/_ops.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ ops = torch.ops._finegrained_fp8_75cbe1b
3
+
4
+ def add_op_namespace_prefix(op_name: str):
5
+ """
6
+ Prefix op by namespace.
7
+ """
8
+ return f"_finegrained_fp8_75cbe1b::{op_name}"
build/torch-rocm/act_quant.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ import triton
17
+ import triton.language as tl
18
+ from torch.library import triton_op, wrap_triton
19
+
20
+ from .utils import device_context
21
+
22
+
23
+ _FP8_DTYPE = torch.float8_e4m3fn
24
+
25
+
26
+ # Copied from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
27
+ @triton.jit
28
+ def _fp8_act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
29
+ pid = tl.program_id(axis=0)
30
+ offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
31
+ x = tl.load(x_ptr + offs).to(tl.float32)
32
+ s = tl.max(tl.abs(x)) / 448.0 # float8_e4m3fn max
33
+ y = (x / s).to(y_ptr.dtype.element_ty)
34
+ tl.store(y_ptr + offs, y)
35
+ tl.store(s_ptr + pid, s)
36
+
37
+
38
+ @triton_op("finegrained_fp8::fp8_act_quant", mutates_args=())
39
+ def _fp8_act_quant(
40
+ x: torch.Tensor, block_size: int = 128
41
+ ) -> tuple[torch.Tensor, torch.Tensor]:
42
+ assert x.is_contiguous()
43
+ assert x.shape[-1] % block_size == 0
44
+ y = torch.empty_like(x, dtype=_FP8_DTYPE)
45
+ grid = (triton.cdiv(x.numel(), block_size),)
46
+ s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)
47
+
48
+ with device_context(x.device):
49
+ wrap_triton(_fp8_act_quant_kernel)[grid](x, y, s, BLOCK_SIZE=block_size)
50
+
51
+ return y, s
52
+
53
+
54
+ def fp8_act_quant(
55
+ x: torch.Tensor, block_size: int = 128
56
+ ) -> tuple[torch.Tensor, torch.Tensor]:
57
+ """Quantize activations to FP8 with per-block dynamic scaling.
58
+
59
+ Splits the last dimension of ``x`` into blocks of ``block_size`` elements,
60
+ computes ``scale = max(|x_block|) / 448`` per block, and quantizes to
61
+ ``float8_e4m3fn``.
62
+
63
+ Args:
64
+ x: Input tensor in bf16/fp16/fp32. Last dimension must be divisible by
65
+ ``block_size`` and the tensor must be contiguous.
66
+ block_size: Number of elements per quantization block (default: 128).
67
+
68
+ Returns:
69
+ A tuple ``(quantized, scales)`` where ``quantized`` has dtype
70
+ ``float8_e4m3fn`` with the same shape as ``x``, and ``scales`` has
71
+ shape ``(*x.shape[:-1], x.shape[-1] // block_size)`` in float32.
72
+ """
73
+ return torch.ops.finegrained_fp8.fp8_act_quant(x, block_size)
build/torch-rocm/batched.py ADDED
@@ -0,0 +1,390 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ import triton
17
+ import triton.language as tl
18
+ from .act_quant import fp8_act_quant
19
+ from torch.library import triton_op, wrap_triton
20
+
21
+ from .utils import device_context
22
+
23
+
24
+ @triton.autotune(
25
+ configs=[
26
+ triton.Config({}, num_warps=w, num_stages=s)
27
+ for w in [2, 4, 8, 16]
28
+ for s in [2, 3, 4, 5]
29
+ ],
30
+ key=["N", "K"],
31
+ )
32
+ @triton.jit
33
+ def w8a8_block_fp8_matmul_batched_kernel(
34
+ A, # (S, K) raw BF16/FP16 activations
35
+ B, # (E, N, K) FP8 weight matrices
36
+ C, # (S, N) output
37
+ Bs, # (E, N // BLOCK_SIZE_N, K // BLOCK_SIZE_K) weight scales
38
+ ExpertIds, # (S,) — which expert each batch element routes to
39
+ # Shape
40
+ S,
41
+ N,
42
+ K,
43
+ stride_am,
44
+ stride_ak,
45
+ stride_be,
46
+ stride_bk,
47
+ stride_bn,
48
+ stride_cm,
49
+ stride_cn,
50
+ stride_bs_e,
51
+ stride_bs_k,
52
+ stride_bs_n,
53
+ # Meta-parameters
54
+ BLOCK_SIZE_N: tl.constexpr,
55
+ BLOCK_SIZE_K: tl.constexpr,
56
+ BLOCK_SIZE_M: tl.constexpr,
57
+ ):
58
+ """Block-scale batched FP8 expert matmul kernel.
59
+
60
+ Each program handles one routed token row and one N-tile, looks up the
61
+ owning expert from ``ExpertIds``, and applies fused activation quantization.
62
+ """
63
+ batch_id = tl.program_id(axis=0)
64
+ pid_n = tl.program_id(axis=1)
65
+
66
+ # Cast expert_id to int64 to prevent int32 overflow when computing
67
+ # expert_id * stride_Eb (e.g. 255 * 9_437_184 > 2^31 for 256 experts of
68
+ # 3072×3072 FP8 weights).
69
+ expert_id = tl.load(ExpertIds + batch_id).to(tl.int64)
70
+
71
+ A = A + batch_id * stride_am
72
+ B = B + expert_id * stride_be
73
+ C = C + batch_id * stride_cm
74
+ Bs = Bs + expert_id * stride_bs_e
75
+
76
+ offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
77
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
78
+ a_ptrs = A + tl.arange(0, BLOCK_SIZE_M)[:, None] * 0 + offs_k[None, :] * stride_ak
79
+ b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
80
+
81
+ bs_ptrs = Bs + pid_n * stride_bs_n
82
+
83
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
84
+ for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
85
+ # ---- fused fp8_act_quant ----
86
+ a_raw = tl.load(a_ptrs).to(tl.float32)
87
+ a_s = tl.max(tl.abs(a_raw)) / 448.0
88
+ a = (a_raw / tl.maximum(a_s, 1e-12)).to(tl.float8e4nv)
89
+ # ---- matmul ----
90
+ b = tl.load(b_ptrs)
91
+ b_s = tl.load(bs_ptrs + k * stride_bs_k)
92
+ accumulator += tl.dot(a, b) * a_s * b_s[None, :]
93
+ a_ptrs += BLOCK_SIZE_K * stride_ak
94
+ b_ptrs += BLOCK_SIZE_K * stride_bk
95
+
96
+ if C.dtype.element_ty == tl.bfloat16:
97
+ c = accumulator.to(tl.bfloat16)
98
+ elif C.dtype.element_ty == tl.float16:
99
+ c = accumulator.to(tl.float16)
100
+ else:
101
+ c = accumulator.to(tl.float32)
102
+
103
+ offs_cm = tl.arange(0, BLOCK_SIZE_M)
104
+ offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
105
+ c_ptrs = C + offs_cm[:, None] * 0 + stride_cn * offs_cn[None, :]
106
+ tl.store(c_ptrs, c)
107
+
108
+
109
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=w, num_stages=s)
        for w in [2, 4, 8, 16]
        for s in [2, 3, 4, 5]
    ],
    key=["N", "K"],
)
@triton.jit
def w8a8_tensor_fp8_matmul_batched_kernel(
    A,  # (S, K) pre-quantized FP8 activations
    B,  # (E, N, K) FP8 weight matrices
    C,  # (S, N) output
    As,  # (S, 1) per-tensor activation scales
    Bs,  # (E, 1, 1) per-tensor weight scales
    ExpertIds,  # (S,) expert index selecting which B matrix each row uses
    S,
    N,
    K,
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_as_m,
    stride_bs_e,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
):
    """Tensor-scale batched FP8 expert matmul kernel.

    Activations are already quantized; the kernel applies per-token activation
    scales and per-expert tensor weight scales.
    """
    # One program per (row, N-tile): axis 0 walks rows of A, axis 1 walks
    # BLOCK_SIZE_N-wide column tiles of the selected expert's weights.
    batch_id = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)

    # int64 keeps expert_id * stride_be from overflowing int32 for large B.
    expert_id = tl.load(ExpertIds + batch_id).to(tl.int64)

    # Rebase all pointers onto this row / this expert.
    A = A + batch_id * stride_am
    B = B + expert_id * stride_be
    C = C + batch_id * stride_cm
    Bs = Bs + expert_id * stride_bs_e

    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # The `* 0` row offset replicates the single activation row across all
    # BLOCK_SIZE_M lanes so tl.dot sees a full (M, K) tile.
    a_ptrs = A + tl.arange(0, BLOCK_SIZE_M)[:, None] * 0 + offs_k[None, :] * stride_ak
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # Scalar scales: one per expert (weights), one per token (activations).
    b_s = tl.load(Bs)
    a_s = tl.load(As + batch_id * stride_as_m)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for _ in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        accumulator += tl.dot(a, b)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Both scales are loop-invariant scalars, so apply them once at the end.
    accumulator = accumulator * a_s * b_s

    # Cast the fp32 accumulator to the output element type.
    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    # Mirror of the `* 0` trick on the load side: every M lane holds the same
    # row, so all lanes store the same values to the same C row.
    offs_cm = tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = C + offs_cm[:, None] * 0 + stride_cn * offs_cn[None, :]
    tl.store(c_ptrs, c)
185
+
186
+
187
+ @triton_op("finegrained_fp8::w8a8_block_fp8_matmul_batched", mutates_args=())
188
+ def _w8a8_block_fp8_matmul_batched(
189
+ A: torch.Tensor,
190
+ B: torch.Tensor,
191
+ Bs: torch.Tensor,
192
+ expert_ids: torch.Tensor,
193
+ block_size: list[int],
194
+ ) -> torch.Tensor:
195
+ """Block-scale batched FP8 matmul: C[s] = A[s] @ B[expert_ids[s]].T, with fused act quant.
196
+
197
+ A: (S, K) raw bf16/fp16 activations
198
+ B: (E, N, K) FP8 expert weights
199
+ Bs: (E, N // block_n, K // block_k) per-block weight scales
200
+ """
201
+ assert A.ndim == 2, f"A must be 2D (S, K), got ndim={A.ndim}"
202
+ assert A.is_contiguous(), "A must be contiguous"
203
+ assert B.ndim == 3, f"B must be 3D (E, N, K), got ndim={B.ndim}"
204
+ assert B.is_contiguous(), "B must be contiguous"
205
+ assert A.shape[1] == B.shape[2], (
206
+ f"K mismatch: A has K={A.shape[1]}, B has K={B.shape[2]}"
207
+ )
208
+
209
+ S, K = A.shape
210
+ E, N, _ = B.shape
211
+
212
+ assert len(block_size) == 2, (
213
+ f"block_size must be [block_n, block_k], got {block_size}"
214
+ )
215
+ block_n, block_k = block_size[0], block_size[1]
216
+ # MoE expert dimensions must be block-aligned; non-aligned N/K is not supported.
217
+ assert N % block_n == 0, f"N ({N}) must be divisible by block_n ({block_n})"
218
+ assert K % block_k == 0, f"K ({K}) must be divisible by block_k ({block_k})"
219
+ assert Bs.ndim == 3, (
220
+ f"Bs must be 3D (E, N//block_n, K//block_k), got ndim={Bs.ndim}"
221
+ )
222
+ assert Bs.shape == (E, N // block_n, K // block_k), (
223
+ f"Bs shape {tuple(Bs.shape)} != expected ({E}, {N // block_n}, {K // block_k})"
224
+ )
225
+
226
+ C = A.new_empty(S, N)
227
+ # Adaptive BLOCK_SIZE_M: smallest power-of-2 >= M, floored at 16, capped at 128.
228
+ # Matches the WGMMA tile to the actual row count — smaller tiles use less
229
+ # register pressure and a better-matched FP8 WGMMA instruction, improving
230
+ # both accuracy and performance for small M (decode).
231
+ BLOCK_SIZE_M = min(max(triton.next_power_of_2((S + E - 1) // E), 16), 128)
232
+ grid = (S, triton.cdiv(N, block_n))
233
+ with device_context(A.device):
234
+ wrap_triton(w8a8_block_fp8_matmul_batched_kernel)[grid](
235
+ A,
236
+ B,
237
+ C,
238
+ Bs,
239
+ expert_ids,
240
+ S,
241
+ N,
242
+ K,
243
+ A.stride(0),
244
+ A.stride(1),
245
+ B.stride(0),
246
+ B.stride(2),
247
+ B.stride(1),
248
+ C.stride(0),
249
+ C.stride(1),
250
+ Bs.stride(0),
251
+ Bs.stride(2),
252
+ Bs.stride(1),
253
+ BLOCK_SIZE_N=block_n,
254
+ BLOCK_SIZE_K=block_k,
255
+ BLOCK_SIZE_M=BLOCK_SIZE_M,
256
+ )
257
+
258
+ return C
259
+
260
+
261
+ @triton_op("finegrained_fp8::w8a8_tensor_fp8_matmul_batched", mutates_args=())
262
+ def _w8a8_tensor_fp8_matmul_batched(
263
+ A: torch.Tensor,
264
+ B: torch.Tensor,
265
+ Bs: torch.Tensor,
266
+ expert_ids: torch.Tensor,
267
+ ) -> torch.Tensor:
268
+ """Tensor-scale batched FP8 matmul: C[s] = A[s] @ B[expert_ids[s]].T, with fused act quant.
269
+
270
+ A: (S, K) raw bf16/fp16 activations
271
+ B: (E, N, K) FP8 expert weights
272
+ Bs: (E,) or (E, 1, 1) per-expert weight scales
273
+ """
274
+ assert A.ndim == 2, f"A must be 2D (S, K), got ndim={A.ndim}"
275
+ assert A.is_contiguous(), "A must be contiguous"
276
+ assert B.ndim == 3, f"B must be 3D (E, N, K), got ndim={B.ndim}"
277
+ assert B.is_contiguous(), "B must be contiguous"
278
+ assert A.shape[1] == B.shape[2], (
279
+ f"K mismatch: A has K={A.shape[1]}, B has K={B.shape[2]}"
280
+ )
281
+
282
+ S, K = A.shape
283
+ E, N, _ = B.shape
284
+
285
+ # Normalize Bs to (E, 1, 1)
286
+ if Bs.ndim == 1:
287
+ assert Bs.shape[0] == E, f"Bs shape {tuple(Bs.shape)} != expected ({E},)"
288
+ Bs = Bs.reshape(E, 1, 1)
289
+ else:
290
+ assert Bs.shape == (E, 1, 1), (
291
+ f"Bs shape {tuple(Bs.shape)} != expected ({E}, 1, 1)"
292
+ )
293
+
294
+ BLOCK_SIZE_N = 128
295
+ BLOCK_SIZE_K = 128
296
+ C = A.new_empty(S, N)
297
+ qA, As = fp8_act_quant(A, K)
298
+ grid = (S, triton.cdiv(N, BLOCK_SIZE_N))
299
+ # Adaptive BLOCK_SIZE_M: smallest power-of-2 >= M, floored at 16, capped at 128.
300
+ # Matches the WGMMA tile to the actual row count — smaller tiles use less
301
+ # register pressure and a better-matched FP8 WGMMA instruction, improving
302
+ # both accuracy and performance for small M (decode).
303
+ BLOCK_SIZE_M = min(max(triton.next_power_of_2((S + E - 1) // E), 16), 128)
304
+ grid = (S, triton.cdiv(N, BLOCK_SIZE_N))
305
+ with device_context(A.device):
306
+ wrap_triton(w8a8_tensor_fp8_matmul_batched_kernel)[grid](
307
+ qA,
308
+ B,
309
+ C,
310
+ As,
311
+ Bs,
312
+ expert_ids,
313
+ S,
314
+ N,
315
+ K,
316
+ qA.stride(0),
317
+ qA.stride(1),
318
+ B.stride(0),
319
+ B.stride(2),
320
+ B.stride(1),
321
+ C.stride(0),
322
+ C.stride(1),
323
+ As.stride(0),
324
+ Bs.stride(0),
325
+ BLOCK_SIZE_N=BLOCK_SIZE_N,
326
+ BLOCK_SIZE_K=BLOCK_SIZE_K,
327
+ BLOCK_SIZE_M=BLOCK_SIZE_M,
328
+ )
329
+
330
+ return C
331
+
332
+
333
def w8a8_block_fp8_matmul_batched(
    A: torch.Tensor,
    B: torch.Tensor,
    Bs: torch.Tensor,
    expert_ids: torch.Tensor,
    block_size: list[int],
) -> torch.Tensor:
    """Block-scale batched FP8 matmul with fused activation quantization.

    Thin wrapper forwarding to the registered ``finegrained_fp8`` custom op.

    A: (S, K) raw activations, bf16/fp16/fp32
    B: (E, N, K) FP8 expert weights
    Bs: (E, N // block_n, K // block_k) per-block weight scales
    """
    op = torch.ops.finegrained_fp8.w8a8_block_fp8_matmul_batched
    return op(A, B, Bs, expert_ids, block_size)
349
+
350
+
351
def w8a8_tensor_fp8_matmul_batched(
    A: torch.Tensor,
    B: torch.Tensor,
    Bs: torch.Tensor,
    expert_ids: torch.Tensor,
) -> torch.Tensor:
    """Tensor-scale batched FP8 matmul with fused activation quantization.

    Thin wrapper forwarding to the registered ``finegrained_fp8`` custom op.

    A: (S, K) raw activations, bf16/fp16/fp32
    B: (E, N, K) FP8 expert weights
    Bs: (E,) or (E, 1, 1) per-expert weight scales
    """
    op = torch.ops.finegrained_fp8.w8a8_tensor_fp8_matmul_batched
    return op(A, B, Bs, expert_ids)
366
+
367
+
368
def w8a8_fp8_matmul_batched(
    A: torch.Tensor,
    B: torch.Tensor,
    Bs: torch.Tensor,
    expert_ids: torch.Tensor,
    block_size: list[int] | None,
) -> torch.Tensor:
    """Unified batched W8A8 FP8 matmul dispatcher.

    Dispatch rules:
        - tensor mode when ``block_size is None``
        - tensor mode when ``block_size == [N, K]`` (one block == whole matrix)
        - otherwise block mode

    Returns:
        Output tensor ``[S, N]`` in the same dtype as ``A``.
    """
    use_tensor_mode = block_size is None or (
        block_size[0] == B.size(1) and block_size[1] == B.size(2)
    )
    if use_tensor_mode:
        return w8a8_tensor_fp8_matmul_batched(A, B, Bs, expert_ids)
    return w8a8_block_fp8_matmul_batched(A, B, Bs, expert_ids, block_size)
build/torch-rocm/finegrained_fp8/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ctypes
2
+ import importlib.util
3
+ import sys
4
+ from pathlib import Path
5
+ from types import ModuleType
6
+
7
+
8
+ def _import_from_path(file_path: Path) -> ModuleType:
9
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
10
+ # it would also be used for other imports. So, we make a module name that
11
+ # depends on the path for it to be unique using the hex-encoded hash of
12
+ # the path.
13
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
14
+ module_name = path_hash
15
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
16
+ if spec is None:
17
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
18
+ module = importlib.util.module_from_spec(spec)
19
+ if module is None:
20
+ raise ImportError(f"Cannot load module {module_name} from spec")
21
+ sys.modules[module_name] = module
22
+ spec.loader.exec_module(module) # type: ignore
23
+ return module
24
+
25
+
26
+ globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch-rocm/grouped.py ADDED
@@ -0,0 +1,477 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ import triton
17
+ import triton.language as tl
18
+ from .act_quant import fp8_act_quant
19
+ from torch.library import triton_op, wrap_triton
20
+
21
+ from .utils import device_context
22
+
23
+
24
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=w, num_stages=s)
        for w in [2, 4, 8, 16]
        for s in [2, 3, 4, 5]
    ],
    key=["N", "K", "BLOCK_SIZE_M"],
)
@triton.jit
def w8a8_block_fp8_matmul_grouped_kernel(
    A,  # (S, K) raw BF16/FP16 activations, sorted/grouped by expert id
    B,  # (E, N, K) FP8 weight matrices
    C,  # (S, N) output
    Bs,  # (E, N // BLOCK_SIZE_N, K // BLOCK_SIZE_K) weight scales
    Offsets,  # (E,) int32 — cumulative row-end per expert
    TileOffsets,  # (E,) int32 — cumulative tile-end per expert
    # Shape
    S,
    N,
    K,
    # Strides
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_bs_e,
    stride_bs_k,
    stride_bs_n,
    # Meta-parameters
    NUM_EXPERTS: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    NUM_EXPERTS_BIT_LENGTH: tl.constexpr,
):
    """Block-scale grouped FP8 expert matmul kernel.

    Tokens are assumed sorted by expert. The kernel maps each M-tile to its
    owning expert via ``TileOffsets`` and applies fused activation quantization.
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)

    # Exit early for programs beyond the actual tile count (the launcher
    # over-provisions the grid with a data-independent upper bound).
    total_tiles = tl.load(TileOffsets + NUM_EXPERTS - 1)
    if pid_m >= total_tiles:
        return

    # Binary search in TileOffsets to find the owning expert.
    # Finds the smallest e such that TileOffsets[e] > pid_m (upper_bound semantics),
    # which is the expert whose tile range contains pid_m.
    # O(log2(NUM_EXPERTS)) loads instead of the O(NUM_EXPERTS) linear scan.
    # NUM_EXPERTS_BIT_LENGTH is ceil(log2(E))+1 for powers-of-two, giving one
    # harmless extra iteration when lo==hi; it's a compile-time constant so the
    # loop is fully unrolled by the compiler.
    lo = 0
    hi = NUM_EXPERTS
    for _ in tl.static_range(NUM_EXPERTS_BIT_LENGTH):
        mid = (lo + hi) >> 1
        mid_val = tl.load(TileOffsets + mid)
        is_left = mid_val <= pid_m
        lo = tl.where(is_left, mid + 1, lo)
        hi = tl.where(is_left, hi, mid)

    # Cast expert_id to int64 to prevent int32 overflow when computing
    # expert_id * stride_be (e.g. 255 * 9_437_184 > 2^31 for 256 experts of
    # 3072×3072 FP8 weights).
    expert_id = lo.to(tl.int64)

    # Row range [expert_start, expert_end) that this expert owns in A/C.
    prev_eid = tl.maximum(expert_id - 1, 0)
    expert_start = tl.where(expert_id == 0, 0, tl.load(Offsets + prev_eid))
    expert_end = tl.load(Offsets + expert_id)
    M_expert = expert_end - expert_start

    # Position of this program's tile within the expert's own tile range.
    expert_tile_start = tl.where(expert_id == 0, 0, tl.load(TileOffsets + prev_eid))
    local_tile = pid_m - expert_tile_start
    m_off = local_tile * BLOCK_SIZE_M

    offs_am = m_off + tl.arange(0, BLOCK_SIZE_M)
    # Masks off the tail rows of the expert's last (partial) M-tile.
    row_mask = offs_am < M_expert
    offs_global_m = expert_start + offs_am

    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    a_ptrs = A + offs_global_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = (
        B
        + expert_id * stride_be
        + offs_k[:, None] * stride_bk
        + offs_bn[None, :] * stride_bn
    )
    bs_ptrs = Bs + expert_id * stride_bs_e + pid_n * stride_bs_n

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # ---- fused fp8_act_quant ----
        # Dynamic per-row scale: max-abs over the K-block mapped onto the
        # FP8 e4m3 max-normal value 448; the 1e-12 floor avoids div-by-zero.
        a_raw = tl.load(a_ptrs, mask=row_mask[:, None], other=0.0).to(tl.float32)
        a_s = tl.max(tl.abs(a_raw), axis=1) / 448.0
        a = (a_raw / tl.maximum(a_s[:, None], 1e-12)).to(tl.float8e4nv)
        # ---- matmul ----
        b = tl.load(b_ptrs)
        b_s = tl.load(bs_ptrs + k * stride_bs_k)
        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Cast the fp32 accumulator to the output element type.
    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    c_ptrs = C + stride_cm * offs_global_m[:, None] + stride_cn * offs_bn[None, :]
    c_mask = row_mask[:, None]
    tl.store(c_ptrs, c, mask=c_mask)
144
+
145
+
146
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=w, num_stages=s)
        for w in [2, 4, 8, 16]
        for s in [2, 3, 4, 5]
    ],
    key=["N", "K", "BLOCK_SIZE_M"],
)
@triton.jit
def w8a8_tensor_fp8_matmul_grouped_kernel(
    A,  # (S, K) pre-quantized FP8 activations, sorted/grouped by expert id
    B,  # (E, N, K) FP8 weight matrices
    C,  # (S, N) output
    As,  # (S, 1) activation scales
    Bs,  # (E, 1, 1) per-tensor weight scales
    Offsets,  # (E,) cumulative row-end per expert
    TileOffsets,  # (E,) cumulative tile-end per expert
    S,
    N,
    K,
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_as_m,
    stride_bs_e,
    NUM_EXPERTS: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    NUM_EXPERTS_BIT_LENGTH: tl.constexpr,
):
    """Tensor-scale grouped FP8 expert matmul kernel.

    Uses grouped expert scheduling with pre-quantized activations plus
    per-token activation scales and per-expert tensor weight scales.
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)

    # Early-out for grid over-provisioning (see launcher comment).
    total_tiles = tl.load(TileOffsets + NUM_EXPERTS - 1)
    if pid_m >= total_tiles:
        return

    # Binary search (upper_bound) in TileOffsets for the expert owning pid_m;
    # same scheme as the block-scale kernel above.
    lo = 0
    hi = NUM_EXPERTS
    for _ in tl.static_range(NUM_EXPERTS_BIT_LENGTH):
        mid = (lo + hi) >> 1
        mid_val = tl.load(TileOffsets + mid)
        is_left = mid_val <= pid_m
        lo = tl.where(is_left, mid + 1, lo)
        hi = tl.where(is_left, hi, mid)
    # int64 guards expert_id * stride_be against int32 overflow.
    expert_id = lo.to(tl.int64)

    # Row range this expert owns in A/C, and this tile's local position.
    prev_eid = tl.maximum(expert_id - 1, 0)
    expert_start = tl.where(expert_id == 0, 0, tl.load(Offsets + prev_eid))
    expert_end = tl.load(Offsets + expert_id)
    M_expert = expert_end - expert_start

    expert_tile_start = tl.where(expert_id == 0, 0, tl.load(TileOffsets + prev_eid))
    local_tile = pid_m - expert_tile_start
    m_off = local_tile * BLOCK_SIZE_M

    offs_am = m_off + tl.arange(0, BLOCK_SIZE_M)
    # Masks off tail rows of the expert's last (partial) M-tile.
    row_mask = offs_am < M_expert
    offs_global_m = expert_start + offs_am

    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    a_ptrs = A + offs_global_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = (
        B
        + expert_id * stride_be
        + offs_k[:, None] * stride_bk
        + offs_bn[None, :] * stride_bn
    )

    # Per-token activation scales and the expert's scalar weight scale.
    a_s = tl.load(As + offs_global_m * stride_as_m, mask=row_mask, other=0.0)
    b_s = tl.load(Bs + expert_id * stride_bs_e)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for _ in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs, mask=row_mask[:, None], other=0.0)
        b = tl.load(b_ptrs)

        accumulator += tl.dot(a, b)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Scales are loop-invariant, so apply them once after the K loop.
    accumulator = accumulator * a_s[:, None] * b_s

    # Cast the fp32 accumulator to the output element type.
    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    c_ptrs = C + stride_cm * offs_global_m[:, None] + stride_cn * offs_bn[None, :]
    c_mask = row_mask[:, None]
    tl.store(c_ptrs, c, mask=c_mask)
251
+
252
+
253
+ @triton_op("finegrained_fp8::w8a8_block_fp8_matmul_grouped", mutates_args=())
254
+ def _w8a8_block_fp8_matmul_grouped(
255
+ A: torch.Tensor,
256
+ B: torch.Tensor,
257
+ Bs: torch.Tensor,
258
+ offsets: torch.Tensor,
259
+ tokens_per_expert: torch.Tensor,
260
+ block_size: list[int],
261
+ ) -> torch.Tensor:
262
+ """Block-scale grouped FP8 matmul: C = A @ B.T per expert, with fused act quant.
263
+
264
+ A: (S, K) raw bf16/fp16 activations, sorted by expert
265
+ B: (E, N, K) FP8 expert weights
266
+ Bs: (E, N // block_n, K // block_k) per-block weight scales
267
+ """
268
+ assert A.ndim == 2, f"A must be 2D (S, K), got ndim={A.ndim}"
269
+ assert A.is_contiguous(), "A must be contiguous"
270
+ assert B.ndim == 3, f"B must be 3D (E, N, K), got ndim={B.ndim}"
271
+ assert B.is_contiguous(), "B must be contiguous"
272
+ assert A.shape[1] == B.shape[2], (
273
+ f"K mismatch: A has K={A.shape[1]}, B has K={B.shape[2]}"
274
+ )
275
+
276
+ S, K = A.shape
277
+ E, N, _ = B.shape
278
+
279
+ assert len(block_size) == 2, (
280
+ f"block_size must be [block_n, block_k], got {block_size}"
281
+ )
282
+ block_n, block_k = block_size[0], block_size[1]
283
+ # MoE expert dimensions must be block-aligned; non-aligned N/K is not supported.
284
+ assert N % block_n == 0, f"N ({N}) must be divisible by block_n ({block_n})"
285
+ assert K % block_k == 0, f"K ({K}) must be divisible by block_k ({block_k})"
286
+ assert Bs.ndim == 3, (
287
+ f"Bs must be 3D (E, N//block_n, K//block_k), got ndim={Bs.ndim}"
288
+ )
289
+ assert Bs.shape == (E, N // block_n, K // block_k), (
290
+ f"Bs shape {tuple(Bs.shape)} != expected ({E}, {N // block_n}, {K // block_k})"
291
+ )
292
+
293
+ C = A.new_empty(S, N)
294
+ # Adaptive BLOCK_SIZE_M: match tile to average tokens per expert.
295
+ BLOCK_SIZE_M = min(max(triton.next_power_of_2((S + E - 1) // E), 16), 128)
296
+ tiles_per_expert = (tokens_per_expert + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M
297
+ tile_offsets = torch.cumsum(tiles_per_expert, dim=0).to(torch.int32)
298
+ # Upper bound on M-tiles: sum_e ceil(M_e / BLOCK_M) <= ceil(S / BLOCK_M) + E.
299
+ # Programs beyond the real tile count exit immediately via the early-return
300
+ # guard inside the kernel. This is faster than syncing for the exact count
301
+ # and keeps the grid size data-independent (cuda-graph / torch.compile safe).
302
+ max_M_tiles = triton.cdiv(S, BLOCK_SIZE_M) + E
303
+ grid = (max_M_tiles, triton.cdiv(N, block_n))
304
+ with device_context(A.device):
305
+ wrap_triton(w8a8_block_fp8_matmul_grouped_kernel)[grid](
306
+ A,
307
+ B,
308
+ C,
309
+ Bs,
310
+ offsets,
311
+ tile_offsets,
312
+ S,
313
+ N,
314
+ K,
315
+ A.stride(0),
316
+ A.stride(1),
317
+ B.stride(0),
318
+ B.stride(2),
319
+ B.stride(1),
320
+ C.stride(0),
321
+ C.stride(1),
322
+ Bs.stride(0),
323
+ Bs.stride(2),
324
+ Bs.stride(1),
325
+ # Meta-parameters
326
+ NUM_EXPERTS=E,
327
+ BLOCK_SIZE_N=block_n,
328
+ BLOCK_SIZE_K=block_k,
329
+ BLOCK_SIZE_M=BLOCK_SIZE_M,
330
+ NUM_EXPERTS_BIT_LENGTH=E.bit_length(),
331
+ )
332
+
333
+ return C
334
+
335
+
336
+ @triton_op("finegrained_fp8::w8a8_tensor_fp8_matmul_grouped", mutates_args=())
337
+ def _w8a8_tensor_fp8_matmul_grouped(
338
+ A: torch.Tensor,
339
+ B: torch.Tensor,
340
+ Bs: torch.Tensor,
341
+ offsets: torch.Tensor,
342
+ tokens_per_expert: torch.Tensor,
343
+ ) -> torch.Tensor:
344
+ """Tensor-scale grouped FP8 matmul: C = A @ B.T per expert, with fused act quant.
345
+
346
+ A: (S, K) raw bf16/fp16 activations, sorted by expert
347
+ B: (E, N, K) FP8 expert weights
348
+ Bs: (E,) or (E, 1, 1) per-expert weight scales
349
+ """
350
+ assert A.ndim == 2, f"A must be 2D (S, K), got ndim={A.ndim}"
351
+ assert A.is_contiguous(), "A must be contiguous"
352
+ assert B.ndim == 3, f"B must be 3D (E, N, K), got ndim={B.ndim}"
353
+ assert B.is_contiguous(), "B must be contiguous"
354
+ assert A.shape[1] == B.shape[2], (
355
+ f"K mismatch: A has K={A.shape[1]}, B has K={B.shape[2]}"
356
+ )
357
+
358
+ S, K = A.shape
359
+ E, N, _ = B.shape
360
+
361
+ # Normalize Bs to (E, 1, 1)
362
+ if Bs.ndim == 1:
363
+ assert Bs.shape[0] == E, f"Bs shape {tuple(Bs.shape)} != expected ({E},)"
364
+ Bs = Bs.reshape(E, 1, 1)
365
+ else:
366
+ assert Bs.shape == (E, 1, 1), (
367
+ f"Bs shape {tuple(Bs.shape)} != expected ({E}, 1, 1)"
368
+ )
369
+
370
+ BLOCK_SIZE_N = 128
371
+ BLOCK_SIZE_K = 128
372
+ C = A.new_empty(S, N)
373
+ qA, As = fp8_act_quant(A, K)
374
+ BLOCK_SIZE_M = min(max(triton.next_power_of_2((S + E - 1) // E), 16), 128)
375
+ tiles_per_expert = (tokens_per_expert + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M
376
+ tile_offsets = torch.cumsum(tiles_per_expert, dim=0).to(torch.int32)
377
+ # Upper bound on M-tiles: sum_e ceil(M_e / BLOCK_M) <= ceil(S / BLOCK_M) + E.
378
+ # Programs beyond the real tile count exit immediately via the early-return
379
+ # guard inside the kernel. This is faster than syncing for the exact count
380
+ # and keeps the grid size data-independent (cuda-graph / torch.compile safe).
381
+ max_M_tiles = triton.cdiv(S, BLOCK_SIZE_M) + E
382
+ grid = (max_M_tiles, triton.cdiv(N, BLOCK_SIZE_N))
383
+ with device_context(A.device):
384
+ wrap_triton(w8a8_tensor_fp8_matmul_grouped_kernel)[grid](
385
+ qA,
386
+ B,
387
+ C,
388
+ As,
389
+ Bs,
390
+ offsets,
391
+ tile_offsets,
392
+ S,
393
+ N,
394
+ K,
395
+ qA.stride(0),
396
+ qA.stride(1),
397
+ B.stride(0),
398
+ B.stride(2),
399
+ B.stride(1),
400
+ C.stride(0),
401
+ C.stride(1),
402
+ As.stride(0),
403
+ Bs.stride(0),
404
+ # Meta-parameters
405
+ NUM_EXPERTS=E,
406
+ BLOCK_SIZE_N=BLOCK_SIZE_N,
407
+ BLOCK_SIZE_K=BLOCK_SIZE_K,
408
+ BLOCK_SIZE_M=BLOCK_SIZE_M,
409
+ NUM_EXPERTS_BIT_LENGTH=E.bit_length(),
410
+ )
411
+
412
+ return C
413
+
414
+
415
def w8a8_block_fp8_matmul_grouped(
    A: torch.Tensor,
    B: torch.Tensor,
    Bs: torch.Tensor,
    offsets: torch.Tensor,
    tokens_per_expert: torch.Tensor,
    block_size: list[int],
) -> torch.Tensor:
    """Block-scale grouped FP8 matmul with fused activation quantization.

    Thin wrapper forwarding to the registered ``finegrained_fp8`` custom op.

    A: (S, K) raw activations sorted by expert, bf16/fp16/fp32
    B: (E, N, K) FP8 expert weights
    Bs: (E, N // block_n, K // block_k) per-block weight scales
    """
    op = torch.ops.finegrained_fp8.w8a8_block_fp8_matmul_grouped
    return op(A, B, Bs, offsets, tokens_per_expert, block_size)
432
+
433
+
434
def w8a8_tensor_fp8_matmul_grouped(
    A: torch.Tensor,
    B: torch.Tensor,
    Bs: torch.Tensor,
    offsets: torch.Tensor,
    tokens_per_expert: torch.Tensor,
) -> torch.Tensor:
    """Tensor-scale grouped FP8 matmul with fused activation quantization.

    Thin wrapper forwarding to the registered ``finegrained_fp8`` custom op.

    A: (S, K) raw activations sorted by expert, bf16/fp16/fp32
    B: (E, N, K) FP8 expert weights
    Bs: (E,) or (E, 1, 1) per-expert weight scales
    """
    op = torch.ops.finegrained_fp8.w8a8_tensor_fp8_matmul_grouped
    return op(A, B, Bs, offsets, tokens_per_expert)
450
+
451
+
452
def w8a8_fp8_matmul_grouped(
    A: torch.Tensor,
    B: torch.Tensor,
    Bs: torch.Tensor,
    offsets: torch.Tensor,
    tokens_per_expert: torch.Tensor,
    block_size: list[int] | None,
) -> torch.Tensor:
    """Unified grouped W8A8 FP8 matmul dispatcher.

    Dispatch rules:
        - tensor mode when ``block_size is None``
        - tensor mode when ``block_size == [N, K]`` (one block == whole matrix)
        - otherwise block mode

    Returns:
        Output tensor ``[S, N]`` in the same dtype as ``A``, in expert-sorted order.
    """
    use_tensor_mode = block_size is None or (
        block_size[0] == B.size(1) and block_size[1] == B.size(2)
    )
    if use_tensor_mode:
        return w8a8_tensor_fp8_matmul_grouped(A, B, Bs, offsets, tokens_per_expert)
    return w8a8_block_fp8_matmul_grouped(
        A, B, Bs, offsets, tokens_per_expert, block_size
    )
build/torch-rocm/matmul.py ADDED
@@ -0,0 +1,411 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ import triton
17
+ import triton.language as tl
18
+ from torch.library import triton_op, wrap_triton
19
+
20
+ from .utils import device_context
21
+
22
+
23
+ # Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/quantization/fp8_kernel.py
24
# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/quantization/fp8_kernel.py
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=w, num_stages=s)
        for w in [2, 4, 8, 16]
        for s in [2, 3, 4]
    ],
    key=["N", "K", "BLOCK_SIZE_M"],
)
@triton.jit
def w8a8_block_fp8_matmul_kernel(
    # Pointers to inputs and output
    A,  # (M, K) FP8 activations
    B,  # weight matrix, read as (K, N) tiles via the strides below
    C,  # (M, N) output
    As,  # per-(row, K-block) activation scales
    Bs,  # per-(N-block, K-block) weight scales
    # Shape for matmul
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_as_m,
    stride_as_k,
    stride_bs_k,
    stride_bs_n,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Block-scale FP8 GEMM kernel.

    Computes ``C = A @ B.T`` with block-wise activation/weight scales.
    Uses a 2D grid with swizzle for L2 cache locality on B tiles.
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M)

    # Wrap out-of-range lanes modulo M/N so loads stay in bounds; the tail
    # lanes compute garbage that the masked store below discards.
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    as_ptrs = As + offs_am * stride_as_m
    # Map each output column to its N-block scale index.
    offs_bsn = offs_bn // BLOCK_SIZE_N
    bs_ptrs = Bs + offs_bsn * stride_bs_n

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Mask the ragged final K-block.
        k_remaining = K - k * BLOCK_SIZE_K
        a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)

        # Scales vary per K-block, so they must be applied inside the loop.
        a_s = tl.load(as_ptrs + k * stride_as_k)
        b_s = tl.load(bs_ptrs + k * stride_bs_k)

        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Cast the fp32 accumulator to the output element type.
    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
106
+
107
+
108
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=w, num_stages=s)
        for w in [2, 4, 8, 16]
        for s in [2, 3, 4]
    ],
    key=["N", "K", "BLOCK_SIZE_M"],
)
@triton.jit
def w8a8_tensor_fp8_matmul_kernel(
    A,  # (M, K) FP8 activations
    B,  # weight matrix, read as (K, N) tiles via the strides below
    C,  # (M, N) output
    As,  # (M,) per-row activation scales
    Bs,  # scalar weight scale
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_as_m,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Tensor-scale FP8 GEMM kernel.

    Computes ``C = A @ B.T`` with one activation scale per row and one
    weight scale for the full matrix.
    Uses a 2D grid with swizzle for L2 cache locality on B tiles.
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M)

    # Wrap out-of-range lanes modulo M/N so loads stay in bounds; the tail
    # lanes compute garbage that the masked store below discards.
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    a_ptrs = A + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = B + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn

    a_s = tl.load(As + offs_am * stride_as_m)
    b_s = tl.load(Bs)

    # Accumulate raw dot products, apply scales once after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Mask the ragged final K-block.
        k_remaining = K - k * BLOCK_SIZE_K
        a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
        accumulator += tl.dot(a, b)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    accumulator = accumulator * a_s[:, None] * b_s

    # Cast the fp32 accumulator to the output element type.
    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
184
+
185
+
186
@triton_op("finegrained_fp8::w8a8_block_fp8_matmul", mutates_args=())
def _w8a8_block_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: list[int],
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Block-scale FP8 matmul: C = A @ B.T with per-block scales.

    As: (M, K // block_k) — per-token-group activation scales
    Bs: (N // block_n, K // block_k) — per-block weight scales

    Returns a tensor of shape ``A.shape[:-1] + (N,)`` in ``output_dtype``.
    """
    assert len(block_size) == 2, (
        f"block_size must be [block_n, block_k], got {block_size}"
    )
    block_n, block_k = block_size[0], block_size[1]

    assert A.shape[-1] == B.shape[-1], (
        f"K mismatch: A has K={A.shape[-1]}, B has K={B.shape[-1]}"
    )
    assert A.is_contiguous(), "A must be contiguous"
    assert B.ndim == 2, f"B must be 2D (N, K), got ndim={B.ndim}"
    assert B.is_contiguous(), "B must be contiguous"

    N, K = B.shape
    # Leading dims of A are flattened into M (A may be >2D).
    M = A.numel() // A.shape[-1]

    assert As.ndim >= 2, f"As must be at least 2D, got ndim={As.ndim}"
    assert As.shape[-1] == triton.cdiv(K, block_k), (
        f"As last dim {As.shape[-1]} != expected {triton.cdiv(K, block_k)} (cdiv(K={K}, block_k={block_k}))"
    )
    assert Bs.ndim == 2, f"Bs must be 2D (N//block_n, K//block_k), got ndim={Bs.ndim}"
    assert Bs.shape == (triton.cdiv(N, block_n), triton.cdiv(K, block_k)), (
        f"Bs shape {tuple(Bs.shape)} != expected ({triton.cdiv(N, block_n)}, {triton.cdiv(K, block_k)})"
    )

    # Kernel tile sizes must match the quantization block sizes so one scale
    # covers exactly one tile.
    BLOCK_SIZE_K = block_k
    BLOCK_SIZE_N = block_n
    C_shape = A.shape[:-1] + (N,)
    C = A.new_empty(C_shape, dtype=output_dtype)
    # Adaptive BLOCK_SIZE_M: smallest power-of-2 >= M, floored at 16, capped at 128.
    # Matches the WGMMA tile to the actual row count — smaller tiles use less
    # register pressure and a better-matched FP8 WGMMA instruction, improving
    # both accuracy and performance for small M (decode).
    BLOCK_SIZE_M = min(max(triton.next_power_of_2(M), 16), 128)
    grid = (triton.cdiv(M, BLOCK_SIZE_M), triton.cdiv(N, BLOCK_SIZE_N))
    with device_context(A.device):
        wrap_triton(w8a8_block_fp8_matmul_kernel)[grid](
            A,
            B,
            C,
            As,
            Bs,
            M,
            N,
            K,
            A.stride(-2),
            A.stride(-1),
            # B is (N, K); passing strides in (k, n) order gives the kernel a
            # transposed view, implementing the A @ B.T contraction.
            B.stride(1),
            B.stride(0),
            C.stride(-2),
            C.stride(-1),
            As.stride(-2),
            As.stride(-1),
            Bs.stride(1),
            Bs.stride(0),
            # Meta-parameters
            BLOCK_SIZE_M=BLOCK_SIZE_M,
            BLOCK_SIZE_N=BLOCK_SIZE_N,
            BLOCK_SIZE_K=BLOCK_SIZE_K,
            GROUP_SIZE_M=8,
        )

    return C
262
+
263
+
264
@triton_op("finegrained_fp8::w8a8_tensor_fp8_matmul", mutates_args=())
def _w8a8_tensor_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Tensor-scale FP8 matmul: C = A @ B.T with per-row / per-tensor scales.

    As: scalar, (M,), or (M, 1) — per-row activation scales
    Bs: scalar, (1,), or (1, 1) — single weight scale

    Returns a tensor of shape ``A.shape[:-1] + (N,)`` in ``output_dtype``.
    """
    assert A.shape[-1] == B.shape[-1], (
        f"K mismatch: A has K={A.shape[-1]}, B has K={B.shape[-1]}"
    )
    assert A.is_contiguous(), "A must be contiguous"
    assert B.ndim == 2, f"B must be 2D (N, K), got ndim={B.ndim}"
    assert B.is_contiguous(), "B must be contiguous"

    N, K = B.shape
    # Leading dims of A are flattened into M (A may be >2D).
    M = A.numel() // A.shape[-1]

    # Normalize As to (M,)
    if As.numel() == 1:
        # A single scalar scale is broadcast to every row.
        As = As.reshape(1).expand(M).contiguous()
    elif As.ndim == 2:
        As = As.reshape(M)
    assert As.ndim == 1 and As.shape[0] == M, (
        f"As must be scalar, (M,), or (M,1) with M={M}, got {tuple(As.shape)}"
    )

    # Normalize Bs to (1,)
    assert Bs.numel() == 1, f"Bs must be scalar or (1,), got {tuple(Bs.shape)}"
    Bs = Bs.reshape(1)

    BLOCK_SIZE_N = 128
    BLOCK_SIZE_K = 128
    C_shape = A.shape[:-1] + (N,)
    C = A.new_empty(C_shape, dtype=output_dtype)
    # Adaptive BLOCK_SIZE_M: smallest power-of-2 >= M, floored at 16, capped
    # at 128 (same rationale as the block-scale launcher: match the tile to
    # the actual row count for small-M decode workloads).
    BLOCK_SIZE_M = min(max(triton.next_power_of_2(M), 16), 128)
    grid = (triton.cdiv(M, BLOCK_SIZE_M), triton.cdiv(N, BLOCK_SIZE_N))
    with device_context(A.device):
        wrap_triton(w8a8_tensor_fp8_matmul_kernel)[grid](
            A,
            B,
            C,
            As,
            Bs,
            M,
            N,
            K,
            A.stride(-2),
            A.stride(-1),
            # B strides passed in (k, n) order: the kernel reads B transposed.
            B.stride(1),
            B.stride(0),
            C.stride(-2),
            C.stride(-1),
            As.stride(0),
            # Meta-parameters
            BLOCK_SIZE_M=BLOCK_SIZE_M,
            BLOCK_SIZE_N=BLOCK_SIZE_N,
            BLOCK_SIZE_K=BLOCK_SIZE_K,
            GROUP_SIZE_M=8,
        )

    return C
331
+
332
+
333
def w8a8_block_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: list[int],
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Block-wise W8A8 FP8 GEMM, ``C = A @ B.T``.

    Both operands must already be quantized to ``float8_e4m3fn`` with
    per-block scales; accumulation happens in float32 and the result is
    cast to ``output_dtype``.

    Args:
        A: ``[M, K]`` quantized activations (``float8_e4m3fn``).
        B: ``[N, K]`` quantized weights (``float8_e4m3fn``).
        As: ``[M, K // block_size[1]]`` per-token-group activation scales.
        Bs: ``[N // block_size[0], K // block_size[1]]`` per-block weight scales.
        block_size: ``[block_n, block_k]`` quantization block dims, e.g. ``[128, 128]``.
        output_dtype: dtype of the returned tensor (default: ``torch.float32``).

    Returns:
        ``[M, N]`` tensor in ``output_dtype``.
    """
    op = torch.ops.finegrained_fp8.w8a8_block_fp8_matmul
    return op(A, B, As, Bs, block_size, output_dtype)
361
+
362
+
363
def w8a8_tensor_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Tensor-scale W8A8 FP8 GEMM, ``C = A @ B.T``.

    Runs the tensor-scale path on pre-quantized FP8 activations and weights:
    one scale per activation row, one scale for the whole weight matrix.

    Args:
        A: ``[M, K]`` quantized activations (``float8_e4m3fn``).
        B: ``[N, K]`` quantized weights (``float8_e4m3fn``).
        As: Per-row activation scales ``[M]``.
        Bs: Single weight scale, scalar or ``[1]``.
        output_dtype: dtype of the returned tensor.

    Returns:
        ``[M, N]`` tensor in ``output_dtype``.
    """
    op = torch.ops.finegrained_fp8.w8a8_tensor_fp8_matmul
    return op(A, B, As, Bs, output_dtype)
386
+
387
+
388
def w8a8_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: list[int] | None,
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Unified W8A8 FP8 matmul dispatcher.

    Dispatch rules:
        - tensor mode when ``block_size is None``
        - tensor mode when ``block_size == [N, K]`` (one block covers B)
        - otherwise block mode

    Returns:
        Output tensor ``[M, N]`` in ``output_dtype``.
    """
    if block_size is not None:
        covers_whole_weight = (
            block_size[0] == B.size(0) and block_size[1] == B.size(1)
        )
        if not covers_whole_weight:
            return w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype)
    return w8a8_tensor_fp8_matmul(A, B, As, Bs, output_dtype)
build/torch-rocm/metadata.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "version": 1,
3
+ "license": "Apache-2.0",
4
+ "python-depends": [],
5
+ "backend": {
6
+ "type": "rocm"
7
+ }
8
+ }
build/torch-rocm/utils.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from contextlib import contextmanager
3
+
4
+
5
@contextmanager
def device_context(device: torch.device):
    """Context manager that sets the active device for any backend (cuda, xpu, etc.).

    Looks up ``torch.<device.type>`` and enters its ``device(...)`` context
    when available; otherwise it is a no-op.
    """
    backend = getattr(torch, device.type, None)
    if backend is None or not hasattr(backend, "device"):
        yield
    else:
        with backend.device(device):
            yield
build/torch-xpu/__init__.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .act_quant import fp8_act_quant
2
+ from .batched import (
3
+ w8a8_fp8_matmul_batched,
4
+ w8a8_block_fp8_matmul_batched,
5
+ w8a8_tensor_fp8_matmul_batched,
6
+ )
7
+ from .grouped import (
8
+ w8a8_fp8_matmul_grouped,
9
+ w8a8_block_fp8_matmul_grouped,
10
+ w8a8_tensor_fp8_matmul_grouped,
11
+ )
12
+ from .matmul import (
13
+ w8a8_fp8_matmul,
14
+ w8a8_block_fp8_matmul,
15
+ w8a8_tensor_fp8_matmul,
16
+ )
17
+
18
+ __all__ = [
19
+ "fp8_act_quant",
20
+ # Single matmul
21
+ "w8a8_fp8_matmul",
22
+ "w8a8_block_fp8_matmul",
23
+ "w8a8_tensor_fp8_matmul",
24
+ # Batched matmul
25
+ "w8a8_fp8_matmul_batched",
26
+ "w8a8_block_fp8_matmul_batched",
27
+ "w8a8_tensor_fp8_matmul_batched",
28
+ # Grouped matmul
29
+ "w8a8_fp8_matmul_grouped",
30
+ "w8a8_block_fp8_matmul_grouped",
31
+ "w8a8_tensor_fp8_matmul_grouped",
32
+ ]
build/torch-xpu/_ops.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ ops = torch.ops._finegrained_fp8_75cbe1b
3
+
4
def add_op_namespace_prefix(op_name: str):
    """Qualify *op_name* with this kernel build's private op namespace."""
    namespace = "_finegrained_fp8_75cbe1b"
    return f"{namespace}::{op_name}"
build/torch-xpu/act_quant.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ import triton
17
+ import triton.language as tl
18
+ from torch.library import triton_op, wrap_triton
19
+
20
+ from .utils import device_context
21
+
22
+
23
+ _FP8_DTYPE = torch.float8_e4m3fn
24
+
25
+
26
# Copied from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
@triton.jit
def _fp8_act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
    """Quantize one BLOCK_SIZE-element block of ``x`` with a dynamic scale.

    Each program handles one block: scale = max(|block|) / 448, quantized
    values go to ``y`` and the float32 scale to ``s[pid]``.
    """
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    x = tl.load(x_ptr + offs).to(tl.float32)
    s = tl.max(tl.abs(x)) / 448.0  # float8_e4m3fn max
    # NOTE(review): an all-zero block gives s == 0, so x / s divides by zero.
    # The batched matmul kernel guards its fused quant with
    # tl.maximum(a_s, 1e-12) — confirm whether the same guard is wanted here.
    y = (x / s).to(y_ptr.dtype.element_ty)
    tl.store(y_ptr + offs, y)
    tl.store(s_ptr + pid, s)
36
+
37
+
38
@triton_op("finegrained_fp8::fp8_act_quant", mutates_args=())
def _fp8_act_quant(
    x: torch.Tensor, block_size: int = 128
) -> tuple[torch.Tensor, torch.Tensor]:
    """Launch the per-block FP8 activation quantization kernel.

    Returns ``(y, s)``: ``y`` is ``x`` quantized to ``float8_e4m3fn`` and
    ``s`` holds one float32 scale per ``block_size`` elements of the last
    dimension.
    """
    assert x.is_contiguous()
    assert x.shape[-1] % block_size == 0
    y = torch.empty_like(x, dtype=_FP8_DTYPE)
    # One program per block of the flattened tensor — valid because x is
    # contiguous and its last dimension divides evenly into block_size.
    grid = (triton.cdiv(x.numel(), block_size),)
    s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)

    with device_context(x.device):
        wrap_triton(_fp8_act_quant_kernel)[grid](x, y, s, BLOCK_SIZE=block_size)

    return y, s
52
+
53
+
54
def fp8_act_quant(
    x: torch.Tensor, block_size: int = 128
) -> tuple[torch.Tensor, torch.Tensor]:
    """Dynamically quantize activations to FP8, block by block.

    The last dimension of ``x`` is partitioned into groups of ``block_size``
    elements; each group is scaled by ``max(|group|) / 448`` and cast to
    ``float8_e4m3fn``.

    Args:
        x: Contiguous bf16/fp16/fp32 tensor whose last dimension is a
            multiple of ``block_size``.
        block_size: Elements per quantization block (default: 128).

    Returns:
        A tuple ``(quantized, scales)``: ``quantized`` has dtype
        ``float8_e4m3fn`` and the same shape as ``x``; ``scales`` is float32
        with shape ``(*x.shape[:-1], x.shape[-1] // block_size)``.
    """
    op = torch.ops.finegrained_fp8.fp8_act_quant
    return op(x, block_size)
build/torch-xpu/batched.py ADDED
@@ -0,0 +1,390 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ import triton
17
+ import triton.language as tl
18
+ from .act_quant import fp8_act_quant
19
+ from torch.library import triton_op, wrap_triton
20
+
21
+ from .utils import device_context
22
+
23
+
24
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=w, num_stages=s)
        for w in [2, 4, 8, 16]
        for s in [2, 3, 4, 5]
    ],
    key=["N", "K"],
)
@triton.jit
def w8a8_block_fp8_matmul_batched_kernel(
    A,  # (S, K) raw BF16/FP16 activations
    B,  # (E, N, K) FP8 weight matrices
    C,  # (S, N) output
    Bs,  # (E, N // BLOCK_SIZE_N, K // BLOCK_SIZE_K) weight scales
    ExpertIds,  # (S,) — which expert each batch element routes to
    # Shape
    S,
    N,
    K,
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_bs_e,
    stride_bs_k,
    stride_bs_n,
    # Meta-parameters
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
):
    """Block-scale batched FP8 expert matmul kernel.

    Each program handles one routed token row and one N-tile, looks up the
    owning expert from ``ExpertIds``, and applies fused activation quantization.
    """
    batch_id = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)

    # Cast expert_id to int64 to prevent int32 overflow when computing
    # expert_id * stride_Eb (e.g. 255 * 9_437_184 > 2^31 for 256 experts of
    # 3072×3072 FP8 weights).
    expert_id = tl.load(ExpertIds + batch_id).to(tl.int64)

    # Rebase all pointers onto this token's row and its expert's weights.
    A = A + batch_id * stride_am
    B = B + expert_id * stride_be
    C = C + batch_id * stride_cm
    Bs = Bs + expert_id * stride_bs_e

    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # The `* 0` row term broadcasts the single token row across all
    # BLOCK_SIZE_M rows of the tile (every row loads the same activations).
    a_ptrs = A + tl.arange(0, BLOCK_SIZE_M)[:, None] * 0 + offs_k[None, :] * stride_ak
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    bs_ptrs = Bs + pid_n * stride_bs_n

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # ---- fused fp8_act_quant ----
        a_raw = tl.load(a_ptrs).to(tl.float32)
        a_s = tl.max(tl.abs(a_raw)) / 448.0
        # 1e-12 floor guards against division by zero for an all-zero block.
        a = (a_raw / tl.maximum(a_s, 1e-12)).to(tl.float8e4nv)
        # ---- matmul ----
        b = tl.load(b_ptrs)
        b_s = tl.load(bs_ptrs + k * stride_bs_k)
        accumulator += tl.dot(a, b) * a_s * b_s[None, :]
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Cast the fp32 accumulator to the output buffer's element type.
    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    offs_cm = tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    # All tile rows hold identical results (same broadcast activations), so
    # the `* 0` row term collapses the store onto this token's single C row.
    c_ptrs = C + offs_cm[:, None] * 0 + stride_cn * offs_cn[None, :]
    tl.store(c_ptrs, c)
107
+
108
+
109
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=w, num_stages=s)
        for w in [2, 4, 8, 16]
        for s in [2, 3, 4, 5]
    ],
    key=["N", "K"],
)
@triton.jit
def w8a8_tensor_fp8_matmul_batched_kernel(
    A,  # (S, K) pre-quantized FP8 activations
    B,  # (E, N, K) FP8 weight matrices
    C,  # (S, N) output
    As,  # (S, 1) per-tensor activation scales
    Bs,  # (E, 1, 1) per-tensor weight scales
    ExpertIds,
    S,
    N,
    K,
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_as_m,
    stride_bs_e,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
):
    """Tensor-scale batched FP8 expert matmul kernel.

    Activations are already quantized; the kernel applies per-token activation
    scales and per-expert tensor weight scales.
    """
    batch_id = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)

    # int64 avoids int32 overflow in expert_id * stride_be for large expert
    # counts (same rationale as the block-scale batched kernel).
    expert_id = tl.load(ExpertIds + batch_id).to(tl.int64)

    # Rebase all pointers onto this token's row and its expert's weights.
    A = A + batch_id * stride_am
    B = B + expert_id * stride_be
    C = C + batch_id * stride_cm
    Bs = Bs + expert_id * stride_bs_e

    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # The `* 0` row term broadcasts the single token row across all
    # BLOCK_SIZE_M rows of the tile.
    a_ptrs = A + tl.arange(0, BLOCK_SIZE_M)[:, None] * 0 + offs_k[None, :] * stride_ak
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # Scales are constant over the K loop; load once up front.
    b_s = tl.load(Bs)
    a_s = tl.load(As + batch_id * stride_as_m)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for _ in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        accumulator += tl.dot(a, b)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Apply both scales once after the loop.
    accumulator = accumulator * a_s * b_s

    # Cast the fp32 accumulator to the output buffer's element type.
    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    offs_cm = tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    # `* 0` collapses the identical tile rows onto this token's single C row.
    c_ptrs = C + offs_cm[:, None] * 0 + stride_cn * offs_cn[None, :]
    tl.store(c_ptrs, c)
185
+
186
+
187
@triton_op("finegrained_fp8::w8a8_block_fp8_matmul_batched", mutates_args=())
def _w8a8_block_fp8_matmul_batched(
    A: torch.Tensor,
    B: torch.Tensor,
    Bs: torch.Tensor,
    expert_ids: torch.Tensor,
    block_size: list[int],
) -> torch.Tensor:
    """Block-scale batched FP8 matmul: C[s] = A[s] @ B[expert_ids[s]].T, with fused act quant.

    A: (S, K) raw bf16/fp16 activations
    B: (E, N, K) FP8 expert weights
    Bs: (E, N // block_n, K // block_k) per-block weight scales

    Returns a (S, N) tensor in A's dtype.
    """
    assert A.ndim == 2, f"A must be 2D (S, K), got ndim={A.ndim}"
    assert A.is_contiguous(), "A must be contiguous"
    assert B.ndim == 3, f"B must be 3D (E, N, K), got ndim={B.ndim}"
    assert B.is_contiguous(), "B must be contiguous"
    assert A.shape[1] == B.shape[2], (
        f"K mismatch: A has K={A.shape[1]}, B has K={B.shape[2]}"
    )

    S, K = A.shape
    E, N, _ = B.shape

    assert len(block_size) == 2, (
        f"block_size must be [block_n, block_k], got {block_size}"
    )
    block_n, block_k = block_size[0], block_size[1]
    # MoE expert dimensions must be block-aligned; non-aligned N/K is not supported.
    assert N % block_n == 0, f"N ({N}) must be divisible by block_n ({block_n})"
    assert K % block_k == 0, f"K ({K}) must be divisible by block_k ({block_k})"
    assert Bs.ndim == 3, (
        f"Bs must be 3D (E, N//block_n, K//block_k), got ndim={Bs.ndim}"
    )
    assert Bs.shape == (E, N // block_n, K // block_k), (
        f"Bs shape {tuple(Bs.shape)} != expected ({E}, {N // block_n}, {K // block_k})"
    )

    # Output keeps A's (raw activation) dtype.
    C = A.new_empty(S, N)
    # Adaptive BLOCK_SIZE_M: smallest power-of-2 >= M, floored at 16, capped at 128.
    # Matches the WGMMA tile to the actual row count — smaller tiles use less
    # register pressure and a better-matched FP8 WGMMA instruction, improving
    # both accuracy and performance for small M (decode).
    BLOCK_SIZE_M = min(max(triton.next_power_of_2((S + E - 1) // E), 16), 128)
    # One program per token row and per N-tile of the expert weight.
    grid = (S, triton.cdiv(N, block_n))
    with device_context(A.device):
        wrap_triton(w8a8_block_fp8_matmul_batched_kernel)[grid](
            A,
            B,
            C,
            Bs,
            expert_ids,
            S,
            N,
            K,
            A.stride(0),
            A.stride(1),
            B.stride(0),
            # Per-expert (N, K) strides in (k, n) order: the kernel reads the
            # expert weight transposed.
            B.stride(2),
            B.stride(1),
            C.stride(0),
            C.stride(1),
            Bs.stride(0),
            Bs.stride(2),
            Bs.stride(1),
            BLOCK_SIZE_N=block_n,
            BLOCK_SIZE_K=block_k,
            BLOCK_SIZE_M=BLOCK_SIZE_M,
        )

    return C
259
+
260
+
261
@triton_op("finegrained_fp8::w8a8_tensor_fp8_matmul_batched", mutates_args=())
def _w8a8_tensor_fp8_matmul_batched(
    A: torch.Tensor,
    B: torch.Tensor,
    Bs: torch.Tensor,
    expert_ids: torch.Tensor,
) -> torch.Tensor:
    """Tensor-scale batched FP8 matmul: C[s] = A[s] @ B[expert_ids[s]].T, with fused act quant.

    A: (S, K) raw bf16/fp16 activations
    B: (E, N, K) FP8 expert weights
    Bs: (E,) or (E, 1, 1) per-expert weight scales

    Returns a (S, N) tensor in A's dtype.
    """
    assert A.ndim == 2, f"A must be 2D (S, K), got ndim={A.ndim}"
    assert A.is_contiguous(), "A must be contiguous"
    assert B.ndim == 3, f"B must be 3D (E, N, K), got ndim={B.ndim}"
    assert B.is_contiguous(), "B must be contiguous"
    assert A.shape[1] == B.shape[2], (
        f"K mismatch: A has K={A.shape[1]}, B has K={B.shape[2]}"
    )

    S, K = A.shape
    E, N, _ = B.shape

    # Normalize Bs to (E, 1, 1)
    if Bs.ndim == 1:
        assert Bs.shape[0] == E, f"Bs shape {tuple(Bs.shape)} != expected ({E},)"
        Bs = Bs.reshape(E, 1, 1)
    else:
        assert Bs.shape == (E, 1, 1), (
            f"Bs shape {tuple(Bs.shape)} != expected ({E}, 1, 1)"
        )

    BLOCK_SIZE_N = 128
    BLOCK_SIZE_K = 128
    # Output keeps A's (raw activation) dtype.
    C = A.new_empty(S, N)
    # Quantize activations up front with block_size=K: one block spans the
    # whole row, i.e. one scale per token.
    qA, As = fp8_act_quant(A, K)
    # Adaptive BLOCK_SIZE_M: smallest power-of-2 >= M, floored at 16, capped at 128.
    # Matches the WGMMA tile to the actual row count — smaller tiles use less
    # register pressure and a better-matched FP8 WGMMA instruction, improving
    # both accuracy and performance for small M (decode).
    BLOCK_SIZE_M = min(max(triton.next_power_of_2((S + E - 1) // E), 16), 128)
    # One program per token row and per N-tile of the expert weight.
    # (Previously this grid was computed twice; the duplicate is removed.)
    grid = (S, triton.cdiv(N, BLOCK_SIZE_N))
    with device_context(A.device):
        wrap_triton(w8a8_tensor_fp8_matmul_batched_kernel)[grid](
            qA,
            B,
            C,
            As,
            Bs,
            expert_ids,
            S,
            N,
            K,
            qA.stride(0),
            qA.stride(1),
            B.stride(0),
            # Per-expert (N, K) strides in (k, n) order: the kernel reads the
            # expert weight transposed.
            B.stride(2),
            B.stride(1),
            C.stride(0),
            C.stride(1),
            As.stride(0),
            Bs.stride(0),
            BLOCK_SIZE_N=BLOCK_SIZE_N,
            BLOCK_SIZE_K=BLOCK_SIZE_K,
            BLOCK_SIZE_M=BLOCK_SIZE_M,
        )

    return C
331
+
332
+
333
def w8a8_block_fp8_matmul_batched(
    A: torch.Tensor,
    B: torch.Tensor,
    Bs: torch.Tensor,
    expert_ids: torch.Tensor,
    block_size: list[int],
) -> torch.Tensor:
    """Block-scale batched FP8 matmul with fused activation quantization.

    Each token row of ``A`` is multiplied against the expert weight selected
    by ``expert_ids``.

    A: (S, K) raw activations, bf16/fp16/fp32
    B: (E, N, K) FP8 expert weights
    Bs: (E, N // block_n, K // block_k) per-block weight scales
    """
    op = torch.ops.finegrained_fp8.w8a8_block_fp8_matmul_batched
    return op(A, B, Bs, expert_ids, block_size)
349
+
350
+
351
def w8a8_tensor_fp8_matmul_batched(
    A: torch.Tensor,
    B: torch.Tensor,
    Bs: torch.Tensor,
    expert_ids: torch.Tensor,
) -> torch.Tensor:
    """Tensor-scale batched FP8 matmul with fused activation quantization.

    Each token row of ``A`` is multiplied against the expert weight selected
    by ``expert_ids``, using one scale per expert.

    A: (S, K) raw activations, bf16/fp16/fp32
    B: (E, N, K) FP8 expert weights
    Bs: (E,) or (E, 1, 1) per-expert weight scales
    """
    op = torch.ops.finegrained_fp8.w8a8_tensor_fp8_matmul_batched
    return op(A, B, Bs, expert_ids)
366
+
367
+
368
def w8a8_fp8_matmul_batched(
    A: torch.Tensor,
    B: torch.Tensor,
    Bs: torch.Tensor,
    expert_ids: torch.Tensor,
    block_size: list[int] | None,
) -> torch.Tensor:
    """Unified batched W8A8 FP8 matmul dispatcher.

    Dispatch rules:
        - tensor mode when ``block_size is None``
        - tensor mode when ``block_size == [N, K]`` (one block per expert)
        - otherwise block mode

    Returns:
        Output tensor ``[S, N]`` in the same dtype as ``A``.
    """
    if block_size is not None:
        covers_whole_weight = (
            block_size[0] == B.size(1) and block_size[1] == B.size(2)
        )
        if not covers_whole_weight:
            return w8a8_block_fp8_matmul_batched(A, B, Bs, expert_ids, block_size)
    return w8a8_tensor_fp8_matmul_batched(A, B, Bs, expert_ids)
build/torch-xpu/finegrained_fp8/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ctypes
2
+ import importlib.util
3
+ import sys
4
+ from pathlib import Path
5
+ from types import ModuleType
6
+
7
+
8
+ def _import_from_path(file_path: Path) -> ModuleType:
9
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
10
+ # it would also be used for other imports. So, we make a module name that
11
+ # depends on the path for it to be unique using the hex-encoded hash of
12
+ # the path.
13
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
14
+ module_name = path_hash
15
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
16
+ if spec is None:
17
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
18
+ module = importlib.util.module_from_spec(spec)
19
+ if module is None:
20
+ raise ImportError(f"Cannot load module {module_name} from spec")
21
+ sys.modules[module_name] = module
22
+ spec.loader.exec_module(module) # type: ignore
23
+ return module
24
+
25
+
26
+ globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch-xpu/grouped.py ADDED
@@ -0,0 +1,477 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ import triton
17
+ import triton.language as tl
18
+ from .act_quant import fp8_act_quant
19
+ from torch.library import triton_op, wrap_triton
20
+
21
+ from .utils import device_context
22
+
23
+
24
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=w, num_stages=s)
        for w in [2, 4, 8, 16]
        for s in [2, 3, 4, 5]
    ],
    key=["N", "K", "BLOCK_SIZE_M"],
)
@triton.jit
def w8a8_block_fp8_matmul_grouped_kernel(
    A,  # (S, K) raw BF16/FP16 activations, sorted/grouped by expert id
    B,  # (E, N, K) FP8 weight matrices
    C,  # (S, N) output
    Bs,  # (E, N // BLOCK_SIZE_N, K // BLOCK_SIZE_K) weight scales
    Offsets,  # (E,) int32 — cumulative row-end per expert
    TileOffsets,  # (E,) int32 — cumulative tile-end per expert
    # Shape
    S,
    N,
    K,
    # Strides
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_bs_e,
    stride_bs_k,
    stride_bs_n,
    # Meta-parameters
    NUM_EXPERTS: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    NUM_EXPERTS_BIT_LENGTH: tl.constexpr,
):
    """Block-scale grouped FP8 expert matmul kernel.

    Tokens are assumed sorted by expert. The kernel maps each M-tile to its
    owning expert via ``TileOffsets`` and applies fused activation quantization
    (a per-row dynamic FP8 scale computed inside the K loop) before each dot.
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)

    # Exit early for programs beyond the actual tile count. The host launches
    # an upper-bound grid, so trailing programs have no tile to work on.
    total_tiles = tl.load(TileOffsets + NUM_EXPERTS - 1)
    if pid_m >= total_tiles:
        return

    # Binary search in TileOffsets to find the owning expert.
    # Finds the smallest e such that TileOffsets[e] > pid_m (upper_bound semantics),
    # which is the expert whose tile range contains pid_m.
    # O(log2(NUM_EXPERTS)) loads instead of the O(NUM_EXPERTS) linear scan.
    # NUM_EXPERTS_BIT_LENGTH is ceil(log2(E))+1 for powers-of-two, giving one
    # harmless extra iteration when lo==hi; it's a compile-time constant so the
    # loop is fully unrolled by the compiler.
    lo = 0
    hi = NUM_EXPERTS
    for _ in tl.static_range(NUM_EXPERTS_BIT_LENGTH):
        mid = (lo + hi) >> 1
        mid_val = tl.load(TileOffsets + mid)
        is_left = mid_val <= pid_m
        lo = tl.where(is_left, mid + 1, lo)
        hi = tl.where(is_left, hi, mid)

    # Cast expert_id to int64 to prevent int32 overflow when computing
    # expert_id * stride_be (e.g. 255 * 9_437_184 > 2^31 for 256 experts of
    # 3072×3072 FP8 weights).
    expert_id = lo.to(tl.int64)

    # Row range [expert_start, expert_end) owned by this expert; expert 0 has
    # no predecessor, so its start is hard-wired to 0.
    prev_eid = tl.maximum(expert_id - 1, 0)
    expert_start = tl.where(expert_id == 0, 0, tl.load(Offsets + prev_eid))
    expert_end = tl.load(Offsets + expert_id)
    M_expert = expert_end - expert_start

    # Tile index local to this expert → row offset within the expert's rows.
    expert_tile_start = tl.where(expert_id == 0, 0, tl.load(TileOffsets + prev_eid))
    local_tile = pid_m - expert_tile_start
    m_off = local_tile * BLOCK_SIZE_M

    offs_am = m_off + tl.arange(0, BLOCK_SIZE_M)
    row_mask = offs_am < M_expert  # rows past the expert's end are masked out
    offs_global_m = expert_start + offs_am

    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    a_ptrs = A + offs_global_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = (
        B
        + expert_id * stride_be
        + offs_k[:, None] * stride_bk
        + offs_bn[None, :] * stride_bn
    )
    bs_ptrs = Bs + expert_id * stride_bs_e + pid_n * stride_bs_n

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # ---- fused fp8_act_quant ----
        # Per-row dynamic scale: 448.0 is the max magnitude representable in
        # FP8 E4M3; the 1e-12 clamp guards against all-zero rows.
        a_raw = tl.load(a_ptrs, mask=row_mask[:, None], other=0.0).to(tl.float32)
        a_s = tl.max(tl.abs(a_raw), axis=1) / 448.0
        a = (a_raw / tl.maximum(a_s[:, None], 1e-12)).to(tl.float8e4nv)
        # ---- matmul ----
        # B/Bs loads are unmasked: the host wrapper asserts N and K are exact
        # multiples of the block sizes.
        b = tl.load(b_ptrs)
        b_s = tl.load(bs_ptrs + k * stride_bs_k)
        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Cast the fp32 accumulator down to the output dtype.
    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    c_ptrs = C + stride_cm * offs_global_m[:, None] + stride_cn * offs_bn[None, :]
    c_mask = row_mask[:, None]
    tl.store(c_ptrs, c, mask=c_mask)
144
+
145
+
146
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=w, num_stages=s)
        for w in [2, 4, 8, 16]
        for s in [2, 3, 4, 5]
    ],
    key=["N", "K", "BLOCK_SIZE_M"],
)
@triton.jit
def w8a8_tensor_fp8_matmul_grouped_kernel(
    A,  # (S, K) pre-quantized FP8 activations, sorted/grouped by expert id
    B,  # (E, N, K) FP8 weight matrices
    C,  # (S, N) output
    As,  # (S, 1) activation scales
    Bs,  # (E, 1, 1) per-tensor weight scales
    Offsets,  # (E,) int32 — cumulative row-end per expert
    TileOffsets,  # (E,) int32 — cumulative tile-end per expert
    S,
    N,
    K,
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_as_m,
    stride_bs_e,
    NUM_EXPERTS: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    NUM_EXPERTS_BIT_LENGTH: tl.constexpr,
):
    """Tensor-scale grouped FP8 expert matmul kernel.

    Uses grouped expert scheduling with pre-quantized activations plus
    per-token activation scales and per-expert tensor weight scales. Scales
    are applied once after the K loop since they are K-invariant.
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)

    # Exit early for grid programs beyond the real tile count (upper-bound grid).
    total_tiles = tl.load(TileOffsets + NUM_EXPERTS - 1)
    if pid_m >= total_tiles:
        return

    # upper_bound binary search in TileOffsets: smallest e with
    # TileOffsets[e] > pid_m is the expert owning this M-tile.
    lo = 0
    hi = NUM_EXPERTS
    for _ in tl.static_range(NUM_EXPERTS_BIT_LENGTH):
        mid = (lo + hi) >> 1
        mid_val = tl.load(TileOffsets + mid)
        is_left = mid_val <= pid_m
        lo = tl.where(is_left, mid + 1, lo)
        hi = tl.where(is_left, hi, mid)
    # int64 to avoid int32 overflow in expert_id * stride_be for large weights.
    expert_id = lo.to(tl.int64)

    # Row range owned by this expert; expert 0 starts at row 0.
    prev_eid = tl.maximum(expert_id - 1, 0)
    expert_start = tl.where(expert_id == 0, 0, tl.load(Offsets + prev_eid))
    expert_end = tl.load(Offsets + expert_id)
    M_expert = expert_end - expert_start

    expert_tile_start = tl.where(expert_id == 0, 0, tl.load(TileOffsets + prev_eid))
    local_tile = pid_m - expert_tile_start
    m_off = local_tile * BLOCK_SIZE_M

    offs_am = m_off + tl.arange(0, BLOCK_SIZE_M)
    row_mask = offs_am < M_expert  # rows past the expert's end are masked out
    offs_global_m = expert_start + offs_am

    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    a_ptrs = A + offs_global_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = (
        B
        + expert_id * stride_be
        + offs_k[:, None] * stride_bk
        + offs_bn[None, :] * stride_bn
    )

    # Per-row activation scale and one scalar weight scale for this expert.
    a_s = tl.load(As + offs_global_m * stride_as_m, mask=row_mask, other=0.0)
    b_s = tl.load(Bs + expert_id * stride_bs_e)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for _ in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Loads are unmasked along K/N: host wrapper guarantees block alignment.
        a = tl.load(a_ptrs, mask=row_mask[:, None], other=0.0)
        b = tl.load(b_ptrs)

        accumulator += tl.dot(a, b)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Apply both scales once after accumulation.
    accumulator = accumulator * a_s[:, None] * b_s

    # Cast fp32 accumulator to the output dtype.
    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    c_ptrs = C + stride_cm * offs_global_m[:, None] + stride_cn * offs_bn[None, :]
    c_mask = row_mask[:, None]
    tl.store(c_ptrs, c, mask=c_mask)
251
+
252
+
253
@triton_op("finegrained_fp8::w8a8_block_fp8_matmul_grouped", mutates_args=())
def _w8a8_block_fp8_matmul_grouped(
    A: torch.Tensor,
    B: torch.Tensor,
    Bs: torch.Tensor,
    offsets: torch.Tensor,
    tokens_per_expert: torch.Tensor,
    block_size: list[int],
) -> torch.Tensor:
    """Block-scale grouped FP8 matmul: C = A @ B.T per expert, with fused act quant.

    A: (S, K) raw bf16/fp16 activations, sorted by expert
    B: (E, N, K) FP8 expert weights
    Bs: (E, N // block_n, K // block_k) per-block weight scales
    offsets: (E,) cumulative row-end per expert (kernel reads it as int32 —
        presumably prepared by the caller; verify upstream)
    tokens_per_expert: (E,) row count per expert, used to size M-tiles
    block_size: [block_n, block_k]

    Returns:
        (S, N) tensor in A's dtype, rows in the same expert-sorted order as A.
    """
    assert A.ndim == 2, f"A must be 2D (S, K), got ndim={A.ndim}"
    assert A.is_contiguous(), "A must be contiguous"
    assert B.ndim == 3, f"B must be 3D (E, N, K), got ndim={B.ndim}"
    assert B.is_contiguous(), "B must be contiguous"
    assert A.shape[1] == B.shape[2], (
        f"K mismatch: A has K={A.shape[1]}, B has K={B.shape[2]}"
    )

    S, K = A.shape
    E, N, _ = B.shape

    assert len(block_size) == 2, (
        f"block_size must be [block_n, block_k], got {block_size}"
    )
    block_n, block_k = block_size[0], block_size[1]
    # MoE expert dimensions must be block-aligned; non-aligned N/K is not supported.
    assert N % block_n == 0, f"N ({N}) must be divisible by block_n ({block_n})"
    assert K % block_k == 0, f"K ({K}) must be divisible by block_k ({block_k})"
    assert Bs.ndim == 3, (
        f"Bs must be 3D (E, N//block_n, K//block_k), got ndim={Bs.ndim}"
    )
    assert Bs.shape == (E, N // block_n, K // block_k), (
        f"Bs shape {tuple(Bs.shape)} != expected ({E}, {N // block_n}, {K // block_k})"
    )

    C = A.new_empty(S, N)
    # Adaptive BLOCK_SIZE_M: match tile to average tokens per expert.
    BLOCK_SIZE_M = min(max(triton.next_power_of_2((S + E - 1) // E), 16), 128)
    tiles_per_expert = (tokens_per_expert + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M
    tile_offsets = torch.cumsum(tiles_per_expert, dim=0).to(torch.int32)
    # Upper bound on M-tiles: sum_e ceil(M_e / BLOCK_M) <= ceil(S / BLOCK_M) + E.
    # Programs beyond the real tile count exit immediately via the early-return
    # guard inside the kernel. This is faster than syncing for the exact count
    # and keeps the grid size data-independent (cuda-graph / torch.compile safe).
    max_M_tiles = triton.cdiv(S, BLOCK_SIZE_M) + E
    grid = (max_M_tiles, triton.cdiv(N, block_n))
    with device_context(A.device):
        # NOTE: B is (E, N, K) but the kernel wants (K-major, N-minor) tiles,
        # so stride(2)/stride(1) are passed as stride_bk/stride_bn; likewise
        # Bs strides are swapped to (k, n) order.
        wrap_triton(w8a8_block_fp8_matmul_grouped_kernel)[grid](
            A,
            B,
            C,
            Bs,
            offsets,
            tile_offsets,
            S,
            N,
            K,
            A.stride(0),
            A.stride(1),
            B.stride(0),
            B.stride(2),
            B.stride(1),
            C.stride(0),
            C.stride(1),
            Bs.stride(0),
            Bs.stride(2),
            Bs.stride(1),
            # Meta-parameters
            NUM_EXPERTS=E,
            BLOCK_SIZE_N=block_n,
            BLOCK_SIZE_K=block_k,
            BLOCK_SIZE_M=BLOCK_SIZE_M,
            NUM_EXPERTS_BIT_LENGTH=E.bit_length(),
        )

    return C
334
+
335
+
336
@triton_op("finegrained_fp8::w8a8_tensor_fp8_matmul_grouped", mutates_args=())
def _w8a8_tensor_fp8_matmul_grouped(
    A: torch.Tensor,
    B: torch.Tensor,
    Bs: torch.Tensor,
    offsets: torch.Tensor,
    tokens_per_expert: torch.Tensor,
) -> torch.Tensor:
    """Tensor-scale grouped FP8 matmul: C = A @ B.T per expert, with fused act quant.

    A: (S, K) raw bf16/fp16 activations, sorted by expert
    B: (E, N, K) FP8 expert weights
    Bs: (E,) or (E, 1, 1) per-expert weight scales
    offsets: (E,) cumulative row-end per expert
    tokens_per_expert: (E,) row count per expert, used to size M-tiles

    Returns:
        (S, N) tensor in A's dtype, rows in the same expert-sorted order as A.
    """
    assert A.ndim == 2, f"A must be 2D (S, K), got ndim={A.ndim}"
    assert A.is_contiguous(), "A must be contiguous"
    assert B.ndim == 3, f"B must be 3D (E, N, K), got ndim={B.ndim}"
    assert B.is_contiguous(), "B must be contiguous"
    assert A.shape[1] == B.shape[2], (
        f"K mismatch: A has K={A.shape[1]}, B has K={B.shape[2]}"
    )

    S, K = A.shape
    E, N, _ = B.shape

    # Normalize Bs to (E, 1, 1)
    if Bs.ndim == 1:
        assert Bs.shape[0] == E, f"Bs shape {tuple(Bs.shape)} != expected ({E},)"
        Bs = Bs.reshape(E, 1, 1)
    else:
        assert Bs.shape == (E, 1, 1), (
            f"Bs shape {tuple(Bs.shape)} != expected ({E}, 1, 1)"
        )

    BLOCK_SIZE_N = 128
    BLOCK_SIZE_K = 128
    C = A.new_empty(S, N)
    # Quantize activations once up front (group size K → one scale per row).
    qA, As = fp8_act_quant(A, K)
    # Adaptive BLOCK_SIZE_M: match tile to average tokens per expert.
    BLOCK_SIZE_M = min(max(triton.next_power_of_2((S + E - 1) // E), 16), 128)
    tiles_per_expert = (tokens_per_expert + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M
    tile_offsets = torch.cumsum(tiles_per_expert, dim=0).to(torch.int32)
    # Upper bound on M-tiles: sum_e ceil(M_e / BLOCK_M) <= ceil(S / BLOCK_M) + E.
    # Programs beyond the real tile count exit immediately via the early-return
    # guard inside the kernel. This is faster than syncing for the exact count
    # and keeps the grid size data-independent (cuda-graph / torch.compile safe).
    max_M_tiles = triton.cdiv(S, BLOCK_SIZE_M) + E
    grid = (max_M_tiles, triton.cdiv(N, BLOCK_SIZE_N))
    with device_context(A.device):
        # B strides are passed in (e, k, n) order because the kernel addresses
        # (K-major, N-minor) tiles of the (E, N, K) weight layout.
        wrap_triton(w8a8_tensor_fp8_matmul_grouped_kernel)[grid](
            qA,
            B,
            C,
            As,
            Bs,
            offsets,
            tile_offsets,
            S,
            N,
            K,
            qA.stride(0),
            qA.stride(1),
            B.stride(0),
            B.stride(2),
            B.stride(1),
            C.stride(0),
            C.stride(1),
            As.stride(0),
            Bs.stride(0),
            # Meta-parameters
            NUM_EXPERTS=E,
            BLOCK_SIZE_N=BLOCK_SIZE_N,
            BLOCK_SIZE_K=BLOCK_SIZE_K,
            BLOCK_SIZE_M=BLOCK_SIZE_M,
            NUM_EXPERTS_BIT_LENGTH=E.bit_length(),
        )

    return C
413
+
414
+
415
def w8a8_block_fp8_matmul_grouped(
    A: torch.Tensor,
    B: torch.Tensor,
    Bs: torch.Tensor,
    offsets: torch.Tensor,
    tokens_per_expert: torch.Tensor,
    block_size: list[int],
) -> torch.Tensor:
    """Public entry point for the block-scale grouped FP8 matmul.

    Dispatches through the registered ``finegrained_fp8`` custom op, with
    activation quantization fused into the kernel.

    A: (S, K) raw activations sorted by expert, bf16/fp16/fp32
    B: (E, N, K) FP8 expert weights
    Bs: (E, N // block_n, K // block_k) per-block weight scales
    """
    grouped_op = torch.ops.finegrained_fp8.w8a8_block_fp8_matmul_grouped
    return grouped_op(A, B, Bs, offsets, tokens_per_expert, block_size)
432
+
433
+
434
def w8a8_tensor_fp8_matmul_grouped(
    A: torch.Tensor,
    B: torch.Tensor,
    Bs: torch.Tensor,
    offsets: torch.Tensor,
    tokens_per_expert: torch.Tensor,
) -> torch.Tensor:
    """Public entry point for the tensor-scale grouped FP8 matmul.

    Dispatches through the registered ``finegrained_fp8`` custom op, with
    activation quantization fused into the op.

    A: (S, K) raw activations sorted by expert, bf16/fp16/fp32
    B: (E, N, K) FP8 expert weights
    Bs: (E,) or (E, 1, 1) per-expert weight scales
    """
    grouped_op = torch.ops.finegrained_fp8.w8a8_tensor_fp8_matmul_grouped
    return grouped_op(A, B, Bs, offsets, tokens_per_expert)
450
+
451
+
452
def w8a8_fp8_matmul_grouped(
    A: torch.Tensor,
    B: torch.Tensor,
    Bs: torch.Tensor,
    offsets: torch.Tensor,
    tokens_per_expert: torch.Tensor,
    block_size: list[int] | None,
) -> torch.Tensor:
    """Unified grouped W8A8 FP8 matmul dispatcher.

    Dispatch rules:
        - tensor mode when ``block_size is None``
        - tensor mode when ``block_size == [N, K]`` (one block spans the matrix)
        - otherwise block mode

    Returns:
        Output tensor ``[S, N]`` in the same dtype as ``A``, in expert-sorted order.
    """
    if block_size is None:
        whole_matrix = True
    else:
        # A block covering the full (N, K) weight is per-tensor scaling.
        whole_matrix = (block_size[0], block_size[1]) == (B.size(1), B.size(2))

    if whole_matrix:
        return w8a8_tensor_fp8_matmul_grouped(A, B, Bs, offsets, tokens_per_expert)
    return w8a8_block_fp8_matmul_grouped(
        A, B, Bs, offsets, tokens_per_expert, block_size
    )
build/torch-xpu/matmul.py ADDED
@@ -0,0 +1,411 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ import triton
17
+ import triton.language as tl
18
+ from torch.library import triton_op, wrap_triton
19
+
20
+ from .utils import device_context
21
+
22
+
23
+ # Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/quantization/fp8_kernel.py
24
# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/quantization/fp8_kernel.py
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=w, num_stages=s)
        for w in [2, 4, 8, 16]
        for s in [2, 3, 4]
    ],
    key=["N", "K", "BLOCK_SIZE_M"],
)
@triton.jit
def w8a8_block_fp8_matmul_kernel(
    # Pointers to inputs and output
    A,  # (M, K) FP8 activations
    B,  # (N, K) FP8 weights
    C,  # (M, N) output
    As,  # (M, K // BLOCK_SIZE_K) activation scales
    Bs,  # (N // BLOCK_SIZE_N, K // BLOCK_SIZE_K) weight scales
    # Shape for matmul
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_as_m,
    stride_as_k,
    stride_bs_k,
    stride_bs_n,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Block-scale FP8 GEMM kernel.

    Computes ``C = A @ B.T`` with block-wise activation/weight scales.
    Uses a 2D grid with swizzle for L2 cache locality on B tiles.
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M)

    # Wrap row/col offsets with % M / % N so partial tiles still load valid
    # memory; the store mask at the end discards the wrapped rows/cols.
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    as_ptrs = As + offs_am * stride_as_m
    offs_bsn = offs_bn // BLOCK_SIZE_N  # map each output col to its scale block
    bs_ptrs = Bs + offs_bsn * stride_bs_n

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        k_remaining = K - k * BLOCK_SIZE_K
        a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)

        # Per-K-block scales: one per row of A and one per (n-block, k-block) of B.
        a_s = tl.load(as_ptrs + k * stride_as_k)
        b_s = tl.load(bs_ptrs + k * stride_bs_k)

        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Cast the fp32 accumulator to the output dtype.
    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
106
+
107
+
108
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=w, num_stages=s)
        for w in [2, 4, 8, 16]
        for s in [2, 3, 4]
    ],
    key=["N", "K", "BLOCK_SIZE_M"],
)
@triton.jit
def w8a8_tensor_fp8_matmul_kernel(
    A,  # (M, K) FP8 activations
    B,  # (N, K) FP8 weights
    C,  # (M, N) output
    As,  # (M,) per-row activation scales
    Bs,  # scalar weight scale
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_as_m,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Tensor-scale FP8 GEMM kernel.

    Computes ``C = A @ B.T`` with one activation scale per row and one
    weight scale for the full matrix.
    Uses a 2D grid with swizzle for L2 cache locality on B tiles.
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M)

    # Wrap offsets with % M / % N so partial tiles load valid memory; the
    # store mask discards the wrapped rows/cols.
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    a_ptrs = A + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = B + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn

    # Scales are K-invariant, so load them once up front.
    a_s = tl.load(As + offs_am * stride_as_m)
    b_s = tl.load(Bs)

    # Accumulate raw dot products, apply scales once after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        k_remaining = K - k * BLOCK_SIZE_K
        a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
        accumulator += tl.dot(a, b)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    accumulator = accumulator * a_s[:, None] * b_s

    # Cast the fp32 accumulator to the output dtype.
    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
184
+
185
+
186
@triton_op("finegrained_fp8::w8a8_block_fp8_matmul", mutates_args=())
def _w8a8_block_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: list[int],
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Block-scale FP8 matmul: C = A @ B.T with per-block scales.

    A: (..., K) pre-quantized FP8 activations (leading dims are flattened to M)
    B: (N, K) pre-quantized FP8 weights
    As: (M, K // block_k) — per-token-group activation scales
    Bs: (N // block_n, K // block_k) — per-block weight scales

    Returns:
        (..., N) tensor in ``output_dtype``.
    """
    assert len(block_size) == 2, (
        f"block_size must be [block_n, block_k], got {block_size}"
    )
    block_n, block_k = block_size[0], block_size[1]

    assert A.shape[-1] == B.shape[-1], (
        f"K mismatch: A has K={A.shape[-1]}, B has K={B.shape[-1]}"
    )
    assert A.is_contiguous(), "A must be contiguous"
    assert B.ndim == 2, f"B must be 2D (N, K), got ndim={B.ndim}"
    assert B.is_contiguous(), "B must be contiguous"

    N, K = B.shape
    M = A.numel() // A.shape[-1]  # flatten leading dims into the row count

    assert As.ndim >= 2, f"As must be at least 2D, got ndim={As.ndim}"
    assert As.shape[-1] == triton.cdiv(K, block_k), (
        f"As last dim {As.shape[-1]} != expected {triton.cdiv(K, block_k)} (cdiv(K={K}, block_k={block_k}))"
    )
    assert Bs.ndim == 2, f"Bs must be 2D (N//block_n, K//block_k), got ndim={Bs.ndim}"
    assert Bs.shape == (triton.cdiv(N, block_n), triton.cdiv(K, block_k)), (
        f"Bs shape {tuple(Bs.shape)} != expected ({triton.cdiv(N, block_n)}, {triton.cdiv(K, block_k)})"
    )

    BLOCK_SIZE_K = block_k
    BLOCK_SIZE_N = block_n
    C_shape = A.shape[:-1] + (N,)
    C = A.new_empty(C_shape, dtype=output_dtype)
    # Adaptive BLOCK_SIZE_M: smallest power-of-2 >= M, floored at 16, capped at 128.
    # Matches the WGMMA tile to the actual row count — smaller tiles use less
    # register pressure and a better-matched FP8 WGMMA instruction, improving
    # both accuracy and performance for small M (decode).
    BLOCK_SIZE_M = min(max(triton.next_power_of_2(M), 16), 128)
    grid = (triton.cdiv(M, BLOCK_SIZE_M), triton.cdiv(N, BLOCK_SIZE_N))
    with device_context(A.device):
        # B is (N, K); the kernel wants K-major addressing, hence
        # stride(1)/stride(0) passed as stride_bk/stride_bn (same for Bs).
        wrap_triton(w8a8_block_fp8_matmul_kernel)[grid](
            A,
            B,
            C,
            As,
            Bs,
            M,
            N,
            K,
            A.stride(-2),
            A.stride(-1),
            B.stride(1),
            B.stride(0),
            C.stride(-2),
            C.stride(-1),
            As.stride(-2),
            As.stride(-1),
            Bs.stride(1),
            Bs.stride(0),
            # Meta-parameters
            BLOCK_SIZE_M=BLOCK_SIZE_M,
            BLOCK_SIZE_N=BLOCK_SIZE_N,
            BLOCK_SIZE_K=BLOCK_SIZE_K,
            GROUP_SIZE_M=8,
        )

    return C
262
+
263
+
264
@triton_op("finegrained_fp8::w8a8_tensor_fp8_matmul", mutates_args=())
def _w8a8_tensor_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Tensor-scale FP8 matmul: C = A @ B.T with per-row / per-tensor scales.

    A: (..., K) pre-quantized FP8 activations (leading dims flattened to M)
    B: (N, K) pre-quantized FP8 weights
    As: scalar, (M,), or (M, 1) — per-row activation scales
    Bs: scalar, (1,), or (1, 1) — single weight scale

    Returns:
        (..., N) tensor in ``output_dtype``.
    """
    assert A.shape[-1] == B.shape[-1], (
        f"K mismatch: A has K={A.shape[-1]}, B has K={B.shape[-1]}"
    )
    assert A.is_contiguous(), "A must be contiguous"
    assert B.ndim == 2, f"B must be 2D (N, K), got ndim={B.ndim}"
    assert B.is_contiguous(), "B must be contiguous"

    N, K = B.shape
    M = A.numel() // A.shape[-1]

    # Normalize As to (M,)
    if As.numel() == 1:
        # Single scale shared by all rows: broadcast, then materialize so the
        # kernel can index it with a real stride.
        As = As.reshape(1).expand(M).contiguous()
    elif As.ndim == 2:
        As = As.reshape(M)
    assert As.ndim == 1 and As.shape[0] == M, (
        f"As must be scalar, (M,), or (M,1) with M={M}, got {tuple(As.shape)}"
    )

    # Normalize Bs to (1,)
    assert Bs.numel() == 1, f"Bs must be scalar or (1,), got {tuple(Bs.shape)}"
    Bs = Bs.reshape(1)

    BLOCK_SIZE_N = 128
    BLOCK_SIZE_K = 128
    C_shape = A.shape[:-1] + (N,)
    C = A.new_empty(C_shape, dtype=output_dtype)
    # Adaptive BLOCK_SIZE_M: power-of-2 near M, clamped to [16, 128].
    BLOCK_SIZE_M = min(max(triton.next_power_of_2(M), 16), 128)
    grid = (triton.cdiv(M, BLOCK_SIZE_M), triton.cdiv(N, BLOCK_SIZE_N))
    with device_context(A.device):
        # B strides are swapped to K-major order for the kernel.
        wrap_triton(w8a8_tensor_fp8_matmul_kernel)[grid](
            A,
            B,
            C,
            As,
            Bs,
            M,
            N,
            K,
            A.stride(-2),
            A.stride(-1),
            B.stride(1),
            B.stride(0),
            C.stride(-2),
            C.stride(-1),
            As.stride(0),
            # Meta-parameters
            BLOCK_SIZE_M=BLOCK_SIZE_M,
            BLOCK_SIZE_N=BLOCK_SIZE_N,
            BLOCK_SIZE_K=BLOCK_SIZE_K,
            GROUP_SIZE_M=8,
        )

    return C
331
+
332
+
333
def w8a8_block_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: list[int],
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Block-wise W8A8 FP8 matrix multiplication.

    Computes ``C = A @ B.T`` where both operands are pre-quantized to
    ``float8_e4m3fn`` with per-block scales; accumulation happens in float32
    before the final cast to ``output_dtype``.

    Args:
        A: Quantized activation tensor ``[M, K]`` in ``float8_e4m3fn``.
        B: Quantized weight tensor ``[N, K]`` in ``float8_e4m3fn``.
        As: Per-token-group activation scales ``[M, K // block_size[1]]``.
        Bs: Per-block weight scales ``[N // block_size[0], K // block_size[1]]``.
        block_size: ``[block_n, block_k]`` quantization block dimensions, e.g. ``[128, 128]``.
        output_dtype: dtype of the returned tensor (default: ``torch.float32``).

    Returns:
        Output tensor ``[M, N]`` in ``output_dtype``.
    """
    block_op = torch.ops.finegrained_fp8.w8a8_block_fp8_matmul
    return block_op(A, B, As, Bs, block_size, output_dtype)
361
+
362
+
363
def w8a8_tensor_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Tensor-scale W8A8 FP8 matrix multiplication.

    Computes ``C = A @ B.T`` in tensor-scale mode from pre-quantized FP8
    activations/weights, one scale per activation row and a single scale for
    the whole weight matrix.

    Args:
        A: Quantized activation tensor ``[M, K]`` in ``float8_e4m3fn``.
        B: Quantized weight tensor ``[N, K]`` in ``float8_e4m3fn``.
        As: Per-row activation scales ``[M]``.
        Bs: Single weight scale, scalar or ``[1]``.
        output_dtype: dtype of the returned tensor.

    Returns:
        Output tensor ``[M, N]`` in ``output_dtype``.
    """
    tensor_op = torch.ops.finegrained_fp8.w8a8_tensor_fp8_matmul
    return tensor_op(A, B, As, Bs, output_dtype)
386
+
387
+
388
def w8a8_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: list[int] | None,
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Unified W8A8 FP8 matmul dispatcher.

    Dispatch rules:
        - tensor mode when ``block_size is None``
        - tensor mode when ``block_size == [N, K]`` (one block spans the matrix)
        - otherwise block mode

    Returns:
        Output tensor ``[M, N]`` in ``output_dtype``.
    """
    if block_size is None:
        whole_matrix = True
    else:
        # A block covering the full (N, K) weight is per-tensor scaling.
        whole_matrix = (block_size[0], block_size[1]) == (B.size(0), B.size(1))

    if whole_matrix:
        return w8a8_tensor_fp8_matmul(A, B, As, Bs, output_dtype)
    return w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype)
build/torch-xpu/metadata.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "version": 1,
3
+ "license": "Apache-2.0",
4
+ "python-depends": [],
5
+ "backend": {
6
+ "type": "xpu"
7
+ }
8
+ }
build/torch-xpu/utils.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from contextlib import contextmanager
3
+
4
+
5
@contextmanager
def device_context(device: torch.device):
    """Run the enclosed block with *device* active, backend-agnostically.

    Looks up the backend module named after ``device.type`` (``torch.cuda``,
    ``torch.xpu``, ...) and enters its ``device(...)`` context when one
    exists; otherwise the block simply runs with no device switch.
    """
    backend_module = getattr(torch, device.type, None)
    if backend_module is None or not hasattr(backend_module, "device"):
        yield
    else:
        with backend_module.device(device):
            yield