Kernels
danieldk HF Staff commited on
Commit
74d778e
·
verified ·
1 Parent(s): 056134d

Benchmarks uploaded using `kernels`.

Browse files
Files changed (1) hide show
  1. benchmarks/benchmark.py +119 -0
benchmarks/benchmark.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from kernels.benchmark import Benchmark
4
+
5
+
6
+ def apply_rotary_reference(
7
+ x1: torch.Tensor, x2: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, conj: bool
8
+ ) -> tuple[torch.Tensor, torch.Tensor]:
9
+ if not conj:
10
+ out1 = x1 * cos - x2 * sin
11
+ out2 = x1 * sin + x2 * cos
12
+ else:
13
+ out1 = x1 * cos + x2 * sin
14
+ out2 = -x1 * sin + x2 * cos
15
+ return out1, out2
16
+
17
+
18
+ class RotaryBenchmark(Benchmark):
19
+ seed: int = 42
20
+
21
+ def setup(self):
22
+ batch_size = 2
23
+ seqlen = 128
24
+ num_heads = 8
25
+ head_dim = 64
26
+ rotary_dim = 32
27
+
28
+ # Query tensor split into rotary parts
29
+ self.x1 = torch.randn(
30
+ batch_size,
31
+ seqlen,
32
+ num_heads,
33
+ rotary_dim,
34
+ device=self.device,
35
+ dtype=torch.float32,
36
+ )
37
+ self.x2 = torch.randn(
38
+ batch_size,
39
+ seqlen,
40
+ num_heads,
41
+ rotary_dim,
42
+ device=self.device,
43
+ dtype=torch.float32,
44
+ )
45
+
46
+ # Rotary position embeddings
47
+ self.cos = torch.randn(
48
+ seqlen, 1, rotary_dim, device=self.device, dtype=torch.float32
49
+ )
50
+ self.sin = torch.randn(
51
+ seqlen, 1, rotary_dim, device=self.device, dtype=torch.float32
52
+ )
53
+
54
+ # Output tensors (in-place, so clone inputs)
55
+ self.out1 = self.x1.clone()
56
+ self.out2 = self.x2.clone()
57
+
58
+ def benchmark_base(self):
59
+ # Reset outputs to input values for in-place operation
60
+ self.out1.copy_(self.x1)
61
+ self.out2.copy_(self.x2)
62
+ self.kernel.apply_rotary(
63
+ self.out1, self.out2, self.cos, self.sin, self.out1, self.out2, False
64
+ )
65
+
66
+ def verify_base(self) -> torch.Tensor:
67
+ ref_out1, ref_out2 = apply_rotary_reference(
68
+ self.x1, self.x2, self.cos, self.sin, False
69
+ )
70
+ # Concatenate for comparison (benchmark compares self.out with returned tensor)
71
+ self.out = torch.cat([self.out1, self.out2], dim=-1)
72
+ return torch.cat([ref_out1, ref_out2], dim=-1)
73
+
74
+ def setup_large(self):
75
+ batch_size = 8
76
+ seqlen = 512
77
+ num_heads = 32
78
+ rotary_dim = 64
79
+
80
+ self.x1 = torch.randn(
81
+ batch_size,
82
+ seqlen,
83
+ num_heads,
84
+ rotary_dim,
85
+ device=self.device,
86
+ dtype=torch.float32,
87
+ )
88
+ self.x2 = torch.randn(
89
+ batch_size,
90
+ seqlen,
91
+ num_heads,
92
+ rotary_dim,
93
+ device=self.device,
94
+ dtype=torch.float32,
95
+ )
96
+
97
+ self.cos = torch.randn(
98
+ seqlen, 1, rotary_dim, device=self.device, dtype=torch.float32
99
+ )
100
+ self.sin = torch.randn(
101
+ seqlen, 1, rotary_dim, device=self.device, dtype=torch.float32
102
+ )
103
+
104
+ self.out1 = self.x1.clone()
105
+ self.out2 = self.x2.clone()
106
+
107
+ def benchmark_large(self):
108
+ self.out1.copy_(self.x1)
109
+ self.out2.copy_(self.x2)
110
+ self.kernel.apply_rotary(
111
+ self.out1, self.out2, self.cos, self.sin, self.out1, self.out2, False
112
+ )
113
+
114
+ def verify_large(self) -> torch.Tensor:
115
+ ref_out1, ref_out2 = apply_rotary_reference(
116
+ self.x1, self.x2, self.cos, self.sin, False
117
+ )
118
+ self.out = torch.cat([self.out1, self.out2], dim=-1)
119
+ return torch.cat([ref_out1, ref_out2], dim=-1)