Benchmarks uploaded using `kernels`.
Browse files- benchmarks/benchmark.py +80 -0
benchmarks/benchmark.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from kernels.benchmark import Benchmark
|
| 4 |
+
|
| 5 |
+
# Monkey patch torch.allclose to use higher tolerance for FP8 comparisons
|
| 6 |
+
_original_allclose = torch.allclose
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def _fp8_tolerant_allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False):
|
| 10 |
+
"""Custom allclose that uses higher tolerance for FP8-related comparisons."""
|
| 11 |
+
# Use higher tolerance since FP8 has low precision (~3 bits mantissa)
|
| 12 |
+
# FP8 e4m3 has relative precision of ~12.5%, so use atol based on max value
|
| 13 |
+
max_val = max(input.abs().max().item(), other.abs().max().item(), 1.0)
|
| 14 |
+
fp8_atol = max(atol, max_val * 0.15) # 15% relative tolerance
|
| 15 |
+
return _original_allclose(
|
| 16 |
+
input, other, rtol=rtol, atol=fp8_atol, equal_nan=equal_nan
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# Apply the monkey patch
|
| 21 |
+
torch.allclose = _fp8_tolerant_allclose
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def quantize_fp8_per_row_reference(
|
| 25 |
+
a: torch.Tensor,
|
| 26 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 27 |
+
"""Reference implementation of FP8 per-row quantization."""
|
| 28 |
+
pt_fp8_dtype = torch.float8_e4m3fn
|
| 29 |
+
max_fp8 = torch.finfo(pt_fp8_dtype).max
|
| 30 |
+
eps = 1e-12
|
| 31 |
+
|
| 32 |
+
original_shape = a.shape
|
| 33 |
+
a_2d = a.view(-1, a.shape[-1])
|
| 34 |
+
|
| 35 |
+
# Compute max absolute value per row
|
| 36 |
+
row_max = a_2d.abs().max(dim=-1).values
|
| 37 |
+
row_max = torch.clamp(row_max, min=eps)
|
| 38 |
+
|
| 39 |
+
# Compute scale: MAX_FP8 / max_abs
|
| 40 |
+
scale = max_fp8 / row_max
|
| 41 |
+
|
| 42 |
+
# Quantize
|
| 43 |
+
a_scaled = a_2d * scale.unsqueeze(-1)
|
| 44 |
+
a_scaled = torch.clamp(a_scaled, -max_fp8, max_fp8)
|
| 45 |
+
a_fp8 = a_scaled.to(pt_fp8_dtype)
|
| 46 |
+
|
| 47 |
+
# Return reciprocal scale
|
| 48 |
+
a_scale = 1.0 / scale
|
| 49 |
+
|
| 50 |
+
return a_fp8.view(original_shape), a_scale.view(original_shape[:-1])
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class QuantizeFp8PerRowBenchmark(Benchmark):
|
| 54 |
+
seed: int = 42
|
| 55 |
+
|
| 56 |
+
def setup(self):
|
| 57 |
+
M, K = 512, 1024
|
| 58 |
+
self.a = torch.randn(M, K, device=self.device, dtype=torch.float32)
|
| 59 |
+
self.out = torch.empty(M, K, device=self.device, dtype=torch.float32)
|
| 60 |
+
|
| 61 |
+
def benchmark_base(self):
|
| 62 |
+
a_fp8, a_scale = self.kernel.quantize_fp8_per_row(self.a)
|
| 63 |
+
self.out = a_fp8.to(torch.float32)
|
| 64 |
+
|
| 65 |
+
def verify_base(self) -> torch.Tensor:
|
| 66 |
+
a_fp8, _ = quantize_fp8_per_row_reference(self.a)
|
| 67 |
+
return a_fp8.to(torch.float32)
|
| 68 |
+
|
| 69 |
+
def setup_large(self):
|
| 70 |
+
M, K = 2048, 4096
|
| 71 |
+
self.a = torch.randn(M, K, device=self.device, dtype=torch.float32)
|
| 72 |
+
self.out = torch.empty(M, K, device=self.device, dtype=torch.float32)
|
| 73 |
+
|
| 74 |
+
def benchmark_large(self):
|
| 75 |
+
a_fp8, a_scale = self.kernel.quantize_fp8_per_row(self.a)
|
| 76 |
+
self.out = a_fp8.to(torch.float32)
|
| 77 |
+
|
| 78 |
+
def verify_large(self) -> torch.Tensor:
|
| 79 |
+
a_fp8, _ = quantize_fp8_per_row_reference(self.a)
|
| 80 |
+
return a_fp8.to(torch.float32)
|