Kernels
danieldk HF Staff commited on
Commit
1c3bedd
·
verified ·
1 Parent(s): 4f712ee

Benchmarks uploaded using `kernels`.

Browse files
Files changed (1) hide show
  1. benchmarks/benchmark.py +80 -0
benchmarks/benchmark.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from kernels.benchmark import Benchmark
4
+
5
+ # Monkey patch torch.allclose to use higher tolerance for FP8 comparisons
6
+ _original_allclose = torch.allclose
7
+
8
+
9
+ def _fp8_tolerant_allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False):
10
+ """Custom allclose that uses higher tolerance for FP8-related comparisons."""
11
+ # Use higher tolerance since FP8 has low precision (~3 bits mantissa)
12
+ # FP8 e4m3 has relative precision of ~12.5%, so use atol based on max value
13
+ max_val = max(input.abs().max().item(), other.abs().max().item(), 1.0)
14
+ fp8_atol = max(atol, max_val * 0.15) # 15% relative tolerance
15
+ return _original_allclose(
16
+ input, other, rtol=rtol, atol=fp8_atol, equal_nan=equal_nan
17
+ )
18
+
19
+
20
+ # Apply the monkey patch
21
+ torch.allclose = _fp8_tolerant_allclose
22
+
23
+
24
+ def quantize_fp8_per_row_reference(
25
+ a: torch.Tensor,
26
+ ) -> tuple[torch.Tensor, torch.Tensor]:
27
+ """Reference implementation of FP8 per-row quantization."""
28
+ pt_fp8_dtype = torch.float8_e4m3fn
29
+ max_fp8 = torch.finfo(pt_fp8_dtype).max
30
+ eps = 1e-12
31
+
32
+ original_shape = a.shape
33
+ a_2d = a.view(-1, a.shape[-1])
34
+
35
+ # Compute max absolute value per row
36
+ row_max = a_2d.abs().max(dim=-1).values
37
+ row_max = torch.clamp(row_max, min=eps)
38
+
39
+ # Compute scale: MAX_FP8 / max_abs
40
+ scale = max_fp8 / row_max
41
+
42
+ # Quantize
43
+ a_scaled = a_2d * scale.unsqueeze(-1)
44
+ a_scaled = torch.clamp(a_scaled, -max_fp8, max_fp8)
45
+ a_fp8 = a_scaled.to(pt_fp8_dtype)
46
+
47
+ # Return reciprocal scale
48
+ a_scale = 1.0 / scale
49
+
50
+ return a_fp8.view(original_shape), a_scale.view(original_shape[:-1])
51
+
52
+
53
class QuantizeFp8PerRowBenchmark(Benchmark):
    """Benchmark the FP8 per-row quantization kernel against the reference.

    Two variants are registered with the harness: a base case (512x1024)
    and a large case (2048x4096).  Both run the kernel under test and are
    verified against ``quantize_fp8_per_row_reference``.
    """

    seed: int = 42

    def _alloc(self, rows, cols):
        # Fresh random input matrix plus a float32 output buffer.
        self.a = torch.randn(rows, cols, device=self.device, dtype=torch.float32)
        self.out = torch.empty(rows, cols, device=self.device, dtype=torch.float32)

    def _run_kernel(self):
        # Invoke the kernel under test; keep a float32 copy of the quantized result.
        quantized, _scale = self.kernel.quantize_fp8_per_row(self.a)
        self.out = quantized.to(torch.float32)

    def _expected(self) -> torch.Tensor:
        # Reference output the harness compares against.
        quantized, _ = quantize_fp8_per_row_reference(self.a)
        return quantized.to(torch.float32)

    def setup(self):
        self._alloc(512, 1024)

    def benchmark_base(self):
        self._run_kernel()

    def verify_base(self) -> torch.Tensor:
        return self._expected()

    def setup_large(self):
        self._alloc(2048, 4096)

    def benchmark_large(self):
        self._run_kernel()

    def verify_large(self) -> torch.Tensor:
        return self._expected()