drbh committed
Commit 1988537 · unverified · 0 parent(s)

Migrated from kernels-community/paged-attention

Files changed (showing 50; this view is limited to 50 files because the commit contains too many changes; see the raw diff for the rest):
  1. .gitattributes +37 -0
  2. README.md +23 -0
  3. benchmarks/benchmark.py +263 -0
  4. build.toml +110 -0
  5. build/torch210-cxx11-cu126-aarch64-linux/__init__.py +21 -0
  6. build/torch210-cxx11-cu126-aarch64-linux/_custom_ops.py +173 -0
  7. build/torch210-cxx11-cu126-aarch64-linux/_ops.py +9 -0
  8. build/torch210-cxx11-cu126-aarch64-linux/_paged_attention_cuda_83cf4a3.abi3.so +3 -0
  9. build/torch210-cxx11-cu126-aarch64-linux/metadata.json +18 -0
  10. build/torch210-cxx11-cu126-aarch64-linux/paged_attention/__init__.py +26 -0
  11. build/torch210-cxx11-cu126-aarch64-linux/platforms.py +92 -0
  12. build/torch210-cxx11-cu126-x86_64-linux/__init__.py +21 -0
  13. build/torch210-cxx11-cu126-x86_64-linux/_custom_ops.py +173 -0
  14. build/torch210-cxx11-cu126-x86_64-linux/_ops.py +9 -0
  15. build/torch210-cxx11-cu126-x86_64-linux/_paged_attention_cuda_83cf4a3.abi3.so +3 -0
  16. build/torch210-cxx11-cu126-x86_64-linux/metadata.json +18 -0
  17. build/torch210-cxx11-cu126-x86_64-linux/paged_attention/__init__.py +26 -0
  18. build/torch210-cxx11-cu126-x86_64-linux/platforms.py +92 -0
  19. build/torch210-cxx11-cu128-aarch64-linux/__init__.py +21 -0
  20. build/torch210-cxx11-cu128-aarch64-linux/_custom_ops.py +173 -0
  21. build/torch210-cxx11-cu128-aarch64-linux/_ops.py +9 -0
  22. build/torch210-cxx11-cu128-aarch64-linux/_paged_attention_cuda_83cf4a3.abi3.so +3 -0
  23. build/torch210-cxx11-cu128-aarch64-linux/metadata.json +21 -0
  24. build/torch210-cxx11-cu128-aarch64-linux/paged_attention/__init__.py +26 -0
  25. build/torch210-cxx11-cu128-aarch64-linux/platforms.py +92 -0
  26. build/torch210-cxx11-cu128-x86_64-linux/__init__.py +21 -0
  27. build/torch210-cxx11-cu128-x86_64-linux/_custom_ops.py +173 -0
  28. build/torch210-cxx11-cu128-x86_64-linux/_ops.py +9 -0
  29. build/torch210-cxx11-cu128-x86_64-linux/_paged_attention_cuda_83cf4a3.abi3.so +3 -0
  30. build/torch210-cxx11-cu128-x86_64-linux/metadata.json +21 -0
  31. build/torch210-cxx11-cu128-x86_64-linux/paged_attention/__init__.py +26 -0
  32. build/torch210-cxx11-cu128-x86_64-linux/platforms.py +92 -0
  33. build/torch210-cxx11-cu130-aarch64-linux/__init__.py +21 -0
  34. build/torch210-cxx11-cu130-aarch64-linux/_custom_ops.py +173 -0
  35. build/torch210-cxx11-cu130-aarch64-linux/_ops.py +9 -0
  36. build/torch210-cxx11-cu130-aarch64-linux/_paged_attention_cuda_83cf4a3.abi3.so +3 -0
  37. build/torch210-cxx11-cu130-aarch64-linux/metadata.json +19 -0
  38. build/torch210-cxx11-cu130-aarch64-linux/paged_attention/__init__.py +26 -0
  39. build/torch210-cxx11-cu130-aarch64-linux/platforms.py +92 -0
  40. build/torch210-cxx11-cu130-x86_64-linux/__init__.py +21 -0
  41. build/torch210-cxx11-cu130-x86_64-linux/_custom_ops.py +173 -0
  42. build/torch210-cxx11-cu130-x86_64-linux/_ops.py +9 -0
  43. build/torch210-cxx11-cu130-x86_64-linux/_paged_attention_cuda_83cf4a3.abi3.so +3 -0
  44. build/torch210-cxx11-cu130-x86_64-linux/metadata.json +19 -0
  45. build/torch210-cxx11-cu130-x86_64-linux/paged_attention/__init__.py +26 -0
  46. build/torch210-cxx11-cu130-x86_64-linux/platforms.py +92 -0
  47. build/torch210-cxx11-rocm70-x86_64-linux/__init__.py +21 -0
  48. build/torch210-cxx11-rocm70-x86_64-linux/_custom_ops.py +173 -0
  49. build/torch210-cxx11-rocm70-x86_64-linux/_ops.py +9 -0
  50. build/torch210-cxx11-rocm70-x86_64-linux/_paged_attention_rocm_83cf4a3.abi3.so +3 -0
.gitattributes ADDED
@@ -0,0 +1,37 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.so filter=lfs diff=lfs merge=lfs -text
+ *.metallib filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,23 @@
+ ---
+ license: apache-2.0
+ tags:
+ - kernels
+ ---
+
+ ![Status](https://hubwebhook.dholtz.com/shield?repo=kernels-community/paged-attention)
+
+ ## paged-attention
+
+ Paged attention kernels from [vLLM](https://github.com/vllm-project/) and [mistral.rs](https://github.com/EricLBuehler/mistral.rs).
+
+ ### Performance
+
+ <img class="dark:hidden border border-gray-200 dark:border-gray-700 rounded-lg" src="media/benches_light_animation.svg" />
+ <img class="hidden dark:block border border-gray-200 dark:border-gray-700 rounded-lg" src="media/benches_dark_animation.svg" />
+
+ <img class="dark:hidden border border-gray-200 dark:border-gray-700 rounded-lg" src="media/benches_light_latency.svg" />
+ <img class="hidden dark:block border border-gray-200 dark:border-gray-700 rounded-lg" src="media/benches_dark_latency.svg" />
+
+ <img class="dark:hidden border border-gray-200 dark:border-gray-700 rounded-lg" src="media/benches_light_throughput.svg" />
+ <img class="hidden dark:block border border-gray-200 dark:border-gray-700 rounded-lg" src="media/benches_dark_throughput.svg" />
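For orientation (an editor's sketch, not part of this commit): the build/ directories below hold prebuilt variants that the Hub `kernels` package selects among at load time. Assuming the `kernels` package and its `get_kernel` loader are available, usage looks roughly like:

```python
# Minimal sketch, assuming the `kernels` package is installed and a build
# matching the local torch/CUDA toolchain exists under build/.
import torch
from kernels import get_kernel

paged_attention = get_kernel("kernels-community/paged-attention")
# The loaded module re-exports the wrappers defined in _custom_ops.py below,
# e.g. paged_attention.paged_attention_v1(...).
```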
benchmarks/benchmark.py ADDED
@@ -0,0 +1,263 @@
+ import torch
+
+ from kernels.benchmark import Benchmark
+
+
+ def ref_masked_attention(
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     scale: float,
+ ) -> torch.Tensor:
+     # query: (q, h, d), key: (k, h, d), value: (k, h, d)
+     # Transpose to (h, q, d) and (h, k, d) for batched matmul
+     q = query.transpose(0, 1)  # (h, q, d)
+     k = key.transpose(0, 1)  # (h, k, d)
+     v = value.transpose(0, 1)  # (h, k, d)
+
+     # Compute attention scores: (h, q, d) @ (h, d, k) -> (h, q, k)
+     attn_weights = (scale * torch.matmul(q, k.transpose(-1, -2))).float()
+     attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
+
+     # Compute output: (h, q, k) @ (h, k, d) -> (h, q, d)
+     out = torch.matmul(attn_weights, v)
+
+     # Transpose back to (q, h, d)
+     return out.transpose(0, 1)
+
+
+ def ref_paged_attention(
+     query: torch.Tensor,
+     key_cache: torch.Tensor,
+     value_cache: torch.Tensor,
+     block_tables: torch.Tensor,
+     seq_lens: torch.Tensor,
+     scale: float,
+ ) -> torch.Tensor:
+     num_seqs = query.shape[0]
+     num_heads = query.shape[1]
+     head_size = query.shape[2]
+     block_size = value_cache.shape[3]
+     max_seq_len = int(seq_lens.max().item())
+
+     # Create position indices for all sequences up to max_seq_len
+     positions = torch.arange(max_seq_len, device=query.device)
+     block_indices = positions // block_size  # (max_seq_len,)
+     block_offsets = positions % block_size  # (max_seq_len,)
+
+     # Gather block numbers for all sequences: (num_seqs, max_seq_len)
+     block_numbers = block_tables[:, block_indices.long()]
+
+     # Flatten for gathering: (num_seqs * max_seq_len,)
+     flat_block_numbers = block_numbers.reshape(-1)
+     flat_offsets = block_offsets.repeat(num_seqs)
+
+     # Gather keys: key_cache is (num_blocks, num_heads, head_size // x, block_size, x)
+     # Index into [block_number, :, :, offset, :] and reshape
+     keys = key_cache[flat_block_numbers, :, :, flat_offsets, :]
+     keys = keys.reshape(num_seqs, max_seq_len, num_heads, head_size)
+     keys = keys.transpose(1, 2)  # (num_seqs, num_heads, max_seq_len, head_size)
+
+     # Gather values: value_cache is (num_blocks, num_heads, head_size, block_size)
+     values = value_cache[flat_block_numbers, :, :, flat_offsets]
+     values = values.reshape(num_seqs, max_seq_len, num_heads, head_size)
+     values = values.transpose(1, 2)  # (num_seqs, num_heads, max_seq_len, head_size)
+
+     # Query: (num_seqs, num_heads, head_size) -> (num_seqs, num_heads, 1, head_size)
+     q = query.unsqueeze(2)
+
+     # Compute attention scores: (num_seqs, num_heads, 1, head_size) @ (num_seqs, num_heads, head_size, max_seq_len)
+     attn_weights = (scale * torch.matmul(q, keys.transpose(-1, -2))).float()
+
+     # Create causal mask for variable sequence lengths
+     # Mask out positions beyond seq_len for each sequence
+     seq_mask = positions.unsqueeze(0) >= seq_lens.unsqueeze(1)  # (num_seqs, max_seq_len)
+     seq_mask = seq_mask.unsqueeze(1).unsqueeze(2)  # (num_seqs, 1, 1, max_seq_len)
+     attn_weights = attn_weights.masked_fill(seq_mask, float("-inf"))
+
+     attn_weights = torch.softmax(attn_weights, dim=-1).to(values.dtype)
+
+     # Compute output: (num_seqs, num_heads, 1, max_seq_len) @ (num_seqs, num_heads, max_seq_len, head_size)
+     out = torch.matmul(attn_weights, values)
+
+     return out.squeeze(2)  # (num_seqs, num_heads, head_size)
+
+
+ class PagedAttentionBenchmark(Benchmark):
+     seed: int = 42
+
+     def setup(self):
+         num_seqs = 4
+         num_heads = 8
+         head_size = 64
+         block_size = 16
+         max_seq_len = 128
+         num_blocks = 64
+         dtype = torch.float16
+
+         self.num_heads = num_heads
+         self.block_size = block_size
+         self.max_seq_len = max_seq_len
+         self.scale = 1.0 / (head_size**0.5)
+
+         # Query tensor (current token)
+         self.query = torch.randn(
+             num_seqs, num_heads, head_size, device=self.device, dtype=dtype
+         )
+
+         # KV cache with proper layout for the kernel
+         # x = 16 // element_size, for float16 x = 8
+         x = 16 // torch.tensor([], dtype=dtype).element_size()
+         self.key_cache = torch.randn(
+             num_blocks,
+             num_heads,
+             head_size // x,
+             block_size,
+             x,
+             device=self.device,
+             dtype=dtype,
+         )
+         self.value_cache = torch.randn(
+             num_blocks,
+             num_heads,
+             head_size,
+             block_size,
+             device=self.device,
+             dtype=dtype,
+         )
+
+         # Block tables: mapping from sequences to memory blocks
+         max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
+         self.block_tables = torch.randint(
+             0,
+             num_blocks,
+             (num_seqs, max_num_blocks_per_seq),
+             device=self.device,
+             dtype=torch.int32,
+         )
+
+         # Sequence lengths
+         self.seq_lens = torch.tensor(
+             [64, 96, 48, 128], device=self.device, dtype=torch.int32
+         )
+
+         # KV scales
+         self.k_scale = torch.tensor(1.0, dtype=torch.float32, device=self.device)
+         self.v_scale = torch.tensor(1.0, dtype=torch.float32, device=self.device)
+
+         # Output tensor
+         self.out = torch.empty_like(self.query)
+
+     def benchmark_base(self):
+         self.kernel.paged_attention_v1(
+             self.out,
+             self.query,
+             self.key_cache,
+             self.value_cache,
+             num_kv_heads=self.num_heads,
+             scale=self.scale,
+             block_tables=self.block_tables,
+             seq_lens=self.seq_lens,
+             block_size=self.block_size,
+             max_seq_len=self.max_seq_len,
+             alibi_slopes=None,
+             kv_cache_dtype="auto",
+             k_scale=self.k_scale,
+             v_scale=self.v_scale,
+         )
+
+     def verify_base(self) -> torch.Tensor:
+         return ref_paged_attention(
+             self.query,
+             self.key_cache,
+             self.value_cache,
+             self.block_tables,
+             self.seq_lens,
+             self.scale,
+         )
+
+     def setup_large(self):
+         num_seqs = 16
+         num_heads = 32
+         head_size = 128
+         block_size = 16
+         max_seq_len = 512
+         num_blocks = 256
+         dtype = torch.float16
+
+         self.num_heads = num_heads
+         self.block_size = block_size
+         self.max_seq_len = max_seq_len
+         self.scale = 1.0 / (head_size**0.5)
+
+         self.query = torch.randn(
+             num_seqs, num_heads, head_size, device=self.device, dtype=dtype
+         )
+
+         x = 16 // torch.tensor([], dtype=dtype).element_size()
+         self.key_cache = torch.randn(
+             num_blocks,
+             num_heads,
+             head_size // x,
+             block_size,
+             x,
+             device=self.device,
+             dtype=dtype,
+         )
+         self.value_cache = torch.randn(
+             num_blocks,
+             num_heads,
+             head_size,
+             block_size,
+             device=self.device,
+             dtype=dtype,
+         )
+
+         max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
+         self.block_tables = torch.randint(
+             0,
+             num_blocks,
+             (num_seqs, max_num_blocks_per_seq),
+             device=self.device,
+             dtype=torch.int32,
+         )
+
+         # Variable sequence lengths
+         self.seq_lens = torch.randint(
+             64, max_seq_len + 1, (num_seqs,), device=self.device, dtype=torch.int32
+         )
+
+         self.k_scale = torch.tensor(1.0, dtype=torch.float32, device=self.device)
+         self.v_scale = torch.tensor(1.0, dtype=torch.float32, device=self.device)
+
+         self.out = torch.empty_like(self.query)
+
+     def benchmark_large(self):
+         self.kernel.paged_attention_v1(
+             self.out,
+             self.query,
+             self.key_cache,
+             self.value_cache,
+             num_kv_heads=self.num_heads,
+             scale=self.scale,
+             block_tables=self.block_tables,
+             seq_lens=self.seq_lens,
+             block_size=self.block_size,
+             max_seq_len=self.max_seq_len,
+             alibi_slopes=None,
+             kv_cache_dtype="auto",
+             k_scale=self.k_scale,
+             v_scale=self.v_scale,
+         )
+
+     def verify_large(self) -> torch.Tensor:
+         return ref_paged_attention(
+             self.query,
+             self.key_cache,
+             self.value_cache,
+             self.block_tables,
+             self.seq_lens,
+             self.scale,
+         )
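The reference above hinges on the block-table indirection: a token's position selects a logical block (`position // block_size`) and an offset within it (`position % block_size`), and the block table maps that logical block to a physical cache block. A toy, CPU-only illustration (editor's sketch, not part of the commit):

```python
import torch

block_size = 16
block_table = torch.tensor([7, 2, 9])    # logical block i -> physical block
positions = torch.arange(40)             # token positions of one sequence
block_indices = positions // block_size  # logical block per position
block_offsets = positions % block_size   # slot within that block
physical_blocks = block_table[block_indices]
# Position 20 sits in logical block 1 -> physical block 2, offset 4.
assert physical_blocks[20].item() == 2 and block_offsets[20].item() == 4
```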
build.toml ADDED
@@ -0,0 +1,110 @@
+ [general]
+ name = "paged_attention"
+ universal = false
+
+ [torch]
+ src = [
+   "torch-ext/torch_binding.cpp",
+   "torch-ext/torch_binding.h"
+ ]
+
+ [kernel.cuda_utils]
+ backend = "cuda"
+ src = [
+   "cuda-utils/cuda_utils.h",
+   "cuda-utils/cuda_utils_kernels.cu",
+ ]
+ depends = []
+
+ [kernel.cuda_utils_rocm]
+ backend = "rocm"
+ rocm-archs = [
+   "gfx906",
+   "gfx908",
+   "gfx90a",
+   "gfx940",
+   "gfx941",
+   "gfx942",
+   "gfx1030",
+   "gfx1100",
+   "gfx1101",
+ ]
+ src = [
+   "cuda-utils/cuda_utils.h",
+   "cuda-utils/cuda_utils_kernels.cu",
+ ]
+ depends = ["torch"]
+
+ [kernel.paged_attention]
+ backend = "cuda"
+ src = [
+   "cuda-utils/cuda_utils.h",
+   "paged-attention/attention/attention_dtypes.h",
+   "paged-attention/attention/attention_generic.cuh",
+   "paged-attention/attention/attention_kernels.cuh",
+   "paged-attention/attention/attention_utils.cuh",
+   "paged-attention/attention/dtype_bfloat16.cuh",
+   "paged-attention/attention/dtype_float16.cuh",
+   "paged-attention/attention/dtype_float32.cuh",
+   "paged-attention/attention/dtype_fp8.cuh",
+   "paged-attention/attention/paged_attention_v1.cu",
+   "paged-attention/attention/paged_attention_v2.cu",
+   "paged-attention/cache_kernels.cu",
+   "paged-attention/cuda_compat.h",
+   "paged-attention/dispatch_utils.h",
+   "paged-attention/quantization/fp8/amd/quant_utils.cuh",
+   "paged-attention/quantization/fp8/nvidia/quant_utils.cuh",
+ ]
+ include = [ "cuda-utils", "paged-attention" ]
+ depends = [ "torch" ]
+
+ [kernel.paged_attention_rocm]
+ backend = "rocm"
+ rocm-archs = [
+   "gfx906",
+   "gfx908",
+   "gfx90a",
+   "gfx940",
+   "gfx941",
+   "gfx942",
+   "gfx1030",
+   "gfx1100",
+   "gfx1101",
+ ]
+ src = [
+   "cuda-utils/cuda_utils.h",
+   "paged-attention/attention/attention_dtypes.h",
+   "paged-attention/attention/attention_generic.cuh",
+   "paged-attention/attention/attention_kernels.cuh",
+   "paged-attention/attention/attention_utils.cuh",
+   "paged-attention/attention/dtype_bfloat16.cuh",
+   "paged-attention/attention/dtype_float16.cuh",
+   "paged-attention/attention/dtype_float32.cuh",
+   "paged-attention/attention/dtype_fp8.cuh",
+   "paged-attention/attention/paged_attention_v1.cu",
+   "paged-attention/attention/paged_attention_v2.cu",
+   "paged-attention/cache_kernels.cu",
+   "paged-attention/cuda_compat.h",
+   "paged-attention/dispatch_utils.h",
+   "paged-attention/quantization/fp8/amd/quant_utils.cuh",
+   "paged-attention/quantization/fp8/nvidia/quant_utils.cuh",
+ ]
+ include = [ "cuda-utils", "paged-attention" ]
+ depends = [ "torch" ]
+
+ [kernel.paged_attention_metal]
+ backend = "metal"
+ src = [
+   "paged-attention-metal/attention/paged_attention.metal",
+   "paged-attention-metal/cache/copy_blocks.metal",
+   "paged-attention-metal/cache/reshape_and_cache.metal",
+   "paged-attention-metal/convert_fp8.metal",
+   "paged-attention-metal/float8.metal",
+   "paged-attention-metal/utils.metal",
+   "paged-attention-metal/paged_attention.mm",
+   "paged-attention-metal/cache.mm",
+   "paged-attention-metal/convert_fp8.mm",
+   "paged-attention-metal/device.mm",
+ ]
+ include = [ "." ]
+ depends = [ "torch" ]
build/torch210-cxx11-cu126-aarch64-linux/__init__.py ADDED
@@ -0,0 +1,21 @@
+ from ._custom_ops import (
+     convert_fp8,
+     copy_blocks,
+     paged_attention_v1,
+     paged_attention_v2,
+     reshape_and_cache,
+     reshape_and_cache_flash,
+     swap_blocks,
+ )
+ from ._ops import ops
+
+ __all__ = [
+     "convert_fp8",
+     "copy_blocks",
+     "ops",
+     "paged_attention_v1",
+     "paged_attention_v2",
+     "reshape_and_cache",
+     "reshape_and_cache_flash",
+     "swap_blocks",
+ ]
build/torch210-cxx11-cu126-aarch64-linux/_custom_ops.py ADDED
@@ -0,0 +1,173 @@
+ from typing import List, Optional
+
+ import torch
+
+ from ._ops import ops
+
+
+ # paged attention ops
+ def paged_attention_v1(
+     out: torch.Tensor,
+     query: torch.Tensor,
+     key_cache: torch.Tensor,
+     value_cache: torch.Tensor,
+     num_kv_heads: int,
+     scale: float,
+     block_tables: torch.Tensor,
+     seq_lens: torch.Tensor,
+     block_size: int,
+     max_seq_len: int,
+     alibi_slopes: Optional[torch.Tensor],
+     kv_cache_dtype: str,
+     k_scale: float,
+     v_scale: float,
+     tp_rank: int = 0,
+     blocksparse_local_blocks: int = 0,
+     blocksparse_vert_stride: int = 0,
+     blocksparse_block_size: int = 64,
+     blocksparse_head_sliding_step: int = 0,
+ ) -> None:
+     ops.paged_attention_v1(
+         out,
+         query,
+         key_cache,
+         value_cache,
+         num_kv_heads,
+         scale,
+         block_tables,
+         seq_lens,
+         block_size,
+         max_seq_len,
+         alibi_slopes,
+         kv_cache_dtype,
+         k_scale,
+         v_scale,
+         tp_rank,
+         blocksparse_local_blocks,
+         blocksparse_vert_stride,
+         blocksparse_block_size,
+         blocksparse_head_sliding_step,
+     )
+
+
+ def paged_attention_v2(
+     out: torch.Tensor,
+     exp_sum: torch.Tensor,
+     max_logits: torch.Tensor,
+     tmp_out: torch.Tensor,
+     query: torch.Tensor,
+     key_cache: torch.Tensor,
+     value_cache: torch.Tensor,
+     num_kv_heads: int,
+     scale: float,
+     block_tables: torch.Tensor,
+     seq_lens: torch.Tensor,
+     block_size: int,
+     max_seq_len: int,
+     alibi_slopes: Optional[torch.Tensor],
+     kv_cache_dtype: str,
+     k_scale: float,
+     v_scale: float,
+     tp_rank: int = 0,
+     blocksparse_local_blocks: int = 0,
+     blocksparse_vert_stride: int = 0,
+     blocksparse_block_size: int = 64,
+     blocksparse_head_sliding_step: int = 0,
+ ) -> None:
+     ops.paged_attention_v2(
+         out,
+         exp_sum,
+         max_logits,
+         tmp_out,
+         query,
+         key_cache,
+         value_cache,
+         num_kv_heads,
+         scale,
+         block_tables,
+         seq_lens,
+         block_size,
+         max_seq_len,
+         alibi_slopes,
+         kv_cache_dtype,
+         k_scale,
+         v_scale,
+         tp_rank,
+         blocksparse_local_blocks,
+         blocksparse_vert_stride,
+         blocksparse_block_size,
+         blocksparse_head_sliding_step,
+     )
+
+
+ def reshape_and_cache(
+     key: torch.Tensor,
+     value: torch.Tensor,
+     key_cache: torch.Tensor,
+     value_cache: torch.Tensor,
+     slot_mapping: torch.Tensor,
+     kv_cache_dtype: str,
+     k_scale: float,
+     v_scale: float,
+ ) -> None:
+     ops.reshape_and_cache(
+         key,
+         value,
+         key_cache,
+         value_cache,
+         slot_mapping,
+         kv_cache_dtype,
+         k_scale,
+         v_scale,
+     )
+
+
+ def reshape_and_cache_flash(
+     key: torch.Tensor,
+     value: torch.Tensor,
+     key_cache: torch.Tensor,
+     value_cache: torch.Tensor,
+     slot_mapping: torch.Tensor,
+     kv_cache_dtype: str,
+     k_scale: torch.Tensor,
+     v_scale: torch.Tensor,
+ ) -> None:
+     ops.reshape_and_cache_flash(
+         key,
+         value,
+         key_cache,
+         value_cache,
+         slot_mapping,
+         kv_cache_dtype,
+         k_scale,
+         v_scale,
+     )
+
+
+ def copy_blocks(
+     key_caches: List[torch.Tensor],
+     value_caches: List[torch.Tensor],
+     block_mapping: torch.Tensor,
+ ) -> None:
+     ops.copy_blocks(key_caches, value_caches, block_mapping)
+
+
+ def swap_blocks(
+     src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor
+ ) -> None:
+     ops.swap_blocks(src, dst, block_mapping)
+
+
+ def convert_fp8(
+     output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8"
+ ) -> None:
+     ops.convert_fp8(output, input, scale, kv_dtype)
+
+
+ __all__ = [
+     "convert_fp8",
+     "paged_attention_v1",
+     "paged_attention_v2",
+     "reshape_and_cache",
+     "copy_blocks",
+ ]
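Unlike v1, `paged_attention_v2` takes partitioned scratch buffers (`exp_sum`, `max_logits`, `tmp_out`) for its split-and-reduce pass over long sequences. A hedged allocation sketch by the editor; the 512-token partition size follows vLLM's convention and is an assumption here, not something this file pins down:

```python
import torch

num_seqs, num_heads, head_size, max_seq_len = 4, 8, 64, 1024
PARTITION_SIZE = 512  # assumption: vLLM's default; the kernel's constant may differ
max_num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE

# Per-partition partial outputs plus softmax statistics for the reduction.
tmp_out = torch.empty(num_seqs, num_heads, max_num_partitions, head_size,
                      dtype=torch.float16, device="cuda")
exp_sum = torch.empty(num_seqs, num_heads, max_num_partitions,
                      dtype=torch.float32, device="cuda")
max_logits = torch.empty_like(exp_sum)
```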
build/torch210-cxx11-cu126-aarch64-linux/_ops.py ADDED
@@ -0,0 +1,9 @@
+ import torch
+ from . import _paged_attention_cuda_83cf4a3
+ ops = torch.ops._paged_attention_cuda_83cf4a3
+
+ def add_op_namespace_prefix(op_name: str):
+     """
+     Prefix op by namespace.
+     """
+     return f"_paged_attention_cuda_83cf4a3::{op_name}"
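`add_op_namespace_prefix` lets callers refer to the build-hash-versioned op namespace without hard-coding it; for example:

```python
add_op_namespace_prefix("paged_attention_v1")
# -> "_paged_attention_cuda_83cf4a3::paged_attention_v1"
```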
build/torch210-cxx11-cu126-aarch64-linux/_paged_attention_cuda_83cf4a3.abi3.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:33d5f8b98a2a171fee0e0106dfd9174438e40cbea4d13f0f53105a0c0d49695b
+ size 140013424
build/torch210-cxx11-cu126-aarch64-linux/metadata.json ADDED
@@ -0,0 +1,18 @@
+ {
+   "version": 1,
+   "license": "Apache-2.0",
+   "python-depends": [],
+   "backend": {
+     "type": "cuda",
+     "archs": [
+       "7.0",
+       "7.2",
+       "7.5",
+       "8.0",
+       "8.6",
+       "8.7",
+       "8.9",
+       "9.0+PTX"
+     ]
+   }
+ }
build/torch210-cxx11-cu126-aarch64-linux/paged_attention/__init__.py ADDED
@@ -0,0 +1,26 @@
+ import ctypes
+ import importlib.util
+ import sys
+ from pathlib import Path
+ from types import ModuleType
+
+
+ def _import_from_path(file_path: Path) -> ModuleType:
+     # We cannot use the module name as-is: after adding it to `sys.modules`,
+     # it would also be used for other imports. So we build a module name that
+     # depends on the path, making it unique via the hex-encoded hash of the
+     # path.
+     path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+     module_name = path_hash
+     spec = importlib.util.spec_from_file_location(module_name, file_path)
+     if spec is None:
+         raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+     module = importlib.util.module_from_spec(spec)
+     if module is None:
+         raise ImportError(f"Cannot load module {module_name} from spec")
+     sys.modules[module_name] = module
+     spec.loader.exec_module(module)  # type: ignore
+     return module
+
+
+ globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch210-cxx11-cu126-aarch64-linux/platforms.py ADDED
@@ -0,0 +1,92 @@
+ import os
+ import random
+ from abc import ABC, abstractmethod
+ from functools import lru_cache, wraps
+ from typing import Callable, ParamSpec, TypeVar
+
+ import numpy as np
+ import torch
+
+ IS_ROCM = torch.version.hip is not None
+ IS_MPS = torch.backends.mps.is_available()
+
+
+ class Platform(ABC):
+     @classmethod
+     def seed_everything(cls, seed: int) -> None:
+         """
+         Set the seed of each random module.
+         `torch.manual_seed` will set the seed on all devices.
+
+         Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
+         """
+         random.seed(seed)
+         np.random.seed(seed)
+         torch.manual_seed(seed)
+
+     @abstractmethod
+     def get_device_name(self, device_id: int = 0) -> str: ...
+
+     @abstractmethod
+     def is_cuda(self) -> bool: ...
+
+     @abstractmethod
+     def is_rocm(self) -> bool: ...
+
+     @abstractmethod
+     def is_mps(self) -> bool: ...
+
+
+ class CudaPlatform(Platform):
+     @classmethod
+     @lru_cache(maxsize=8)
+     def get_device_name(cls, device_id: int = 0) -> str:
+         # Use the requested device rather than hard-coding device 0.
+         return torch.cuda.get_device_name(device_id)
+
+     def is_cuda(self) -> bool:
+         return True
+
+     def is_rocm(self) -> bool:
+         return False
+
+     def is_mps(self) -> bool:
+         return False
+
+
+ class RocmPlatform(Platform):
+     @classmethod
+     @lru_cache(maxsize=8)
+     def get_device_name(cls, device_id: int = 0) -> str:
+         return torch.cuda.get_device_name(device_id)
+
+     def is_cuda(self) -> bool:
+         return False
+
+     def is_rocm(self) -> bool:
+         return True
+
+     def is_mps(self) -> bool:
+         return False
+
+
+ class MpsPlatform(Platform):
+     @classmethod
+     @lru_cache(maxsize=8)
+     def get_device_name(cls, device_id: int = 0) -> str:
+         return torch.cuda.get_device_name(device_id)
+
+     def is_cuda(self) -> bool:
+         return False
+
+     def is_rocm(self) -> bool:
+         return False
+
+     def is_mps(self) -> bool:
+         return True
+
+
+ current_platform = (
+     RocmPlatform() if IS_ROCM else
+     MpsPlatform() if IS_MPS else
+     CudaPlatform() if torch.cuda.is_available() else
+     None
+ )
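A small usage sketch (editor's addition, not part of the commit) for the `current_platform` singleton defined above:

```python
# Assumes the module above is importable; current_platform is None when no
# supported accelerator is detected.
if current_platform is None:
    raise RuntimeError("no supported accelerator found")
if current_platform.is_cuda() or current_platform.is_rocm():
    print("device:", current_platform.get_device_name())
```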
build/torch210-cxx11-cu126-x86_64-linux/__init__.py ADDED
@@ -0,0 +1,21 @@
+ from ._custom_ops import (
+     convert_fp8,
+     copy_blocks,
+     paged_attention_v1,
+     paged_attention_v2,
+     reshape_and_cache,
+     reshape_and_cache_flash,
+     swap_blocks,
+ )
+ from ._ops import ops
+
+ __all__ = [
+     "convert_fp8",
+     "copy_blocks",
+     "ops",
+     "paged_attention_v1",
+     "paged_attention_v2",
+     "reshape_and_cache",
+     "reshape_and_cache_flash",
+     "swap_blocks",
+ ]
build/torch210-cxx11-cu126-x86_64-linux/_custom_ops.py ADDED
@@ -0,0 +1,173 @@
+ from typing import List, Optional
+
+ import torch
+
+ from ._ops import ops
+
+
+ # paged attention ops
+ def paged_attention_v1(
+     out: torch.Tensor,
+     query: torch.Tensor,
+     key_cache: torch.Tensor,
+     value_cache: torch.Tensor,
+     num_kv_heads: int,
+     scale: float,
+     block_tables: torch.Tensor,
+     seq_lens: torch.Tensor,
+     block_size: int,
+     max_seq_len: int,
+     alibi_slopes: Optional[torch.Tensor],
+     kv_cache_dtype: str,
+     k_scale: float,
+     v_scale: float,
+     tp_rank: int = 0,
+     blocksparse_local_blocks: int = 0,
+     blocksparse_vert_stride: int = 0,
+     blocksparse_block_size: int = 64,
+     blocksparse_head_sliding_step: int = 0,
+ ) -> None:
+     ops.paged_attention_v1(
+         out,
+         query,
+         key_cache,
+         value_cache,
+         num_kv_heads,
+         scale,
+         block_tables,
+         seq_lens,
+         block_size,
+         max_seq_len,
+         alibi_slopes,
+         kv_cache_dtype,
+         k_scale,
+         v_scale,
+         tp_rank,
+         blocksparse_local_blocks,
+         blocksparse_vert_stride,
+         blocksparse_block_size,
+         blocksparse_head_sliding_step,
+     )
+
+
+ def paged_attention_v2(
+     out: torch.Tensor,
+     exp_sum: torch.Tensor,
+     max_logits: torch.Tensor,
+     tmp_out: torch.Tensor,
+     query: torch.Tensor,
+     key_cache: torch.Tensor,
+     value_cache: torch.Tensor,
+     num_kv_heads: int,
+     scale: float,
+     block_tables: torch.Tensor,
+     seq_lens: torch.Tensor,
+     block_size: int,
+     max_seq_len: int,
+     alibi_slopes: Optional[torch.Tensor],
+     kv_cache_dtype: str,
+     k_scale: float,
+     v_scale: float,
+     tp_rank: int = 0,
+     blocksparse_local_blocks: int = 0,
+     blocksparse_vert_stride: int = 0,
+     blocksparse_block_size: int = 64,
+     blocksparse_head_sliding_step: int = 0,
+ ) -> None:
+     ops.paged_attention_v2(
+         out,
+         exp_sum,
+         max_logits,
+         tmp_out,
+         query,
+         key_cache,
+         value_cache,
+         num_kv_heads,
+         scale,
+         block_tables,
+         seq_lens,
+         block_size,
+         max_seq_len,
+         alibi_slopes,
+         kv_cache_dtype,
+         k_scale,
+         v_scale,
+         tp_rank,
+         blocksparse_local_blocks,
+         blocksparse_vert_stride,
+         blocksparse_block_size,
+         blocksparse_head_sliding_step,
+     )
+
+
+ def reshape_and_cache(
+     key: torch.Tensor,
+     value: torch.Tensor,
+     key_cache: torch.Tensor,
+     value_cache: torch.Tensor,
+     slot_mapping: torch.Tensor,
+     kv_cache_dtype: str,
+     k_scale: float,
+     v_scale: float,
+ ) -> None:
+     ops.reshape_and_cache(
+         key,
+         value,
+         key_cache,
+         value_cache,
+         slot_mapping,
+         kv_cache_dtype,
+         k_scale,
+         v_scale,
+     )
+
+
+ def reshape_and_cache_flash(
+     key: torch.Tensor,
+     value: torch.Tensor,
+     key_cache: torch.Tensor,
+     value_cache: torch.Tensor,
+     slot_mapping: torch.Tensor,
+     kv_cache_dtype: str,
+     k_scale: torch.Tensor,
+     v_scale: torch.Tensor,
+ ) -> None:
+     ops.reshape_and_cache_flash(
+         key,
+         value,
+         key_cache,
+         value_cache,
+         slot_mapping,
+         kv_cache_dtype,
+         k_scale,
+         v_scale,
+     )
+
+
+ def copy_blocks(
+     key_caches: List[torch.Tensor],
+     value_caches: List[torch.Tensor],
+     block_mapping: torch.Tensor,
+ ) -> None:
+     ops.copy_blocks(key_caches, value_caches, block_mapping)
+
+
+ def swap_blocks(
+     src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor
+ ) -> None:
+     ops.swap_blocks(src, dst, block_mapping)
+
+
+ def convert_fp8(
+     output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8"
+ ) -> None:
+     ops.convert_fp8(output, input, scale, kv_dtype)
+
+
+ __all__ = [
+     "convert_fp8",
+     "paged_attention_v1",
+     "paged_attention_v2",
+     "reshape_and_cache",
+     "copy_blocks",
+ ]
build/torch210-cxx11-cu126-x86_64-linux/_ops.py ADDED
@@ -0,0 +1,9 @@
+ import torch
+ from . import _paged_attention_cuda_83cf4a3
+ ops = torch.ops._paged_attention_cuda_83cf4a3
+
+ def add_op_namespace_prefix(op_name: str):
+     """
+     Prefix op by namespace.
+     """
+     return f"_paged_attention_cuda_83cf4a3::{op_name}"
build/torch210-cxx11-cu126-x86_64-linux/_paged_attention_cuda_83cf4a3.abi3.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f84331fb1023844b101c03c8f12818bb3b09c273a9442b631cc2efe87b1eee2f
+ size 140162704
build/torch210-cxx11-cu126-x86_64-linux/metadata.json ADDED
@@ -0,0 +1,18 @@
+ {
+   "version": 1,
+   "license": "Apache-2.0",
+   "python-depends": [],
+   "backend": {
+     "type": "cuda",
+     "archs": [
+       "7.0",
+       "7.2",
+       "7.5",
+       "8.0",
+       "8.6",
+       "8.7",
+       "8.9",
+       "9.0+PTX"
+     ]
+   }
+ }
build/torch210-cxx11-cu126-x86_64-linux/paged_attention/__init__.py ADDED
@@ -0,0 +1,26 @@
+ import ctypes
+ import importlib.util
+ import sys
+ from pathlib import Path
+ from types import ModuleType
+
+
+ def _import_from_path(file_path: Path) -> ModuleType:
+     # We cannot use the module name as-is: after adding it to `sys.modules`,
+     # it would also be used for other imports. So we build a module name that
+     # depends on the path, making it unique via the hex-encoded hash of the
+     # path.
+     path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+     module_name = path_hash
+     spec = importlib.util.spec_from_file_location(module_name, file_path)
+     if spec is None:
+         raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+     module = importlib.util.module_from_spec(spec)
+     if module is None:
+         raise ImportError(f"Cannot load module {module_name} from spec")
+     sys.modules[module_name] = module
+     spec.loader.exec_module(module)  # type: ignore
+     return module
+
+
+ globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch210-cxx11-cu126-x86_64-linux/platforms.py ADDED
@@ -0,0 +1,92 @@
+ import os
+ import random
+ from abc import ABC, abstractmethod
+ from functools import lru_cache, wraps
+ from typing import Callable, ParamSpec, TypeVar
+
+ import numpy as np
+ import torch
+
+ IS_ROCM = torch.version.hip is not None
+ IS_MPS = torch.backends.mps.is_available()
+
+
+ class Platform(ABC):
+     @classmethod
+     def seed_everything(cls, seed: int) -> None:
+         """
+         Set the seed of each random module.
+         `torch.manual_seed` will set the seed on all devices.
+
+         Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
+         """
+         random.seed(seed)
+         np.random.seed(seed)
+         torch.manual_seed(seed)
+
+     @abstractmethod
+     def get_device_name(self, device_id: int = 0) -> str: ...
+
+     @abstractmethod
+     def is_cuda(self) -> bool: ...
+
+     @abstractmethod
+     def is_rocm(self) -> bool: ...
+
+     @abstractmethod
+     def is_mps(self) -> bool: ...
+
+
+ class CudaPlatform(Platform):
+     @classmethod
+     @lru_cache(maxsize=8)
+     def get_device_name(cls, device_id: int = 0) -> str:
+         # Use the requested device rather than hard-coding device 0.
+         return torch.cuda.get_device_name(device_id)
+
+     def is_cuda(self) -> bool:
+         return True
+
+     def is_rocm(self) -> bool:
+         return False
+
+     def is_mps(self) -> bool:
+         return False
+
+
+ class RocmPlatform(Platform):
+     @classmethod
+     @lru_cache(maxsize=8)
+     def get_device_name(cls, device_id: int = 0) -> str:
+         return torch.cuda.get_device_name(device_id)
+
+     def is_cuda(self) -> bool:
+         return False
+
+     def is_rocm(self) -> bool:
+         return True
+
+     def is_mps(self) -> bool:
+         return False
+
+
+ class MpsPlatform(Platform):
+     @classmethod
+     @lru_cache(maxsize=8)
+     def get_device_name(cls, device_id: int = 0) -> str:
+         return torch.cuda.get_device_name(device_id)
+
+     def is_cuda(self) -> bool:
+         return False
+
+     def is_rocm(self) -> bool:
+         return False
+
+     def is_mps(self) -> bool:
+         return True
+
+
+ current_platform = (
+     RocmPlatform() if IS_ROCM else
+     MpsPlatform() if IS_MPS else
+     CudaPlatform() if torch.cuda.is_available() else
+     None
+ )
build/torch210-cxx11-cu128-aarch64-linux/__init__.py ADDED
@@ -0,0 +1,21 @@
+ from ._custom_ops import (
+     convert_fp8,
+     copy_blocks,
+     paged_attention_v1,
+     paged_attention_v2,
+     reshape_and_cache,
+     reshape_and_cache_flash,
+     swap_blocks,
+ )
+ from ._ops import ops
+
+ __all__ = [
+     "convert_fp8",
+     "copy_blocks",
+     "ops",
+     "paged_attention_v1",
+     "paged_attention_v2",
+     "reshape_and_cache",
+     "reshape_and_cache_flash",
+     "swap_blocks",
+ ]
build/torch210-cxx11-cu128-aarch64-linux/_custom_ops.py ADDED
@@ -0,0 +1,173 @@
+ from typing import List, Optional
+
+ import torch
+
+ from ._ops import ops
+
+
+ # paged attention ops
+ def paged_attention_v1(
+     out: torch.Tensor,
+     query: torch.Tensor,
+     key_cache: torch.Tensor,
+     value_cache: torch.Tensor,
+     num_kv_heads: int,
+     scale: float,
+     block_tables: torch.Tensor,
+     seq_lens: torch.Tensor,
+     block_size: int,
+     max_seq_len: int,
+     alibi_slopes: Optional[torch.Tensor],
+     kv_cache_dtype: str,
+     k_scale: float,
+     v_scale: float,
+     tp_rank: int = 0,
+     blocksparse_local_blocks: int = 0,
+     blocksparse_vert_stride: int = 0,
+     blocksparse_block_size: int = 64,
+     blocksparse_head_sliding_step: int = 0,
+ ) -> None:
+     ops.paged_attention_v1(
+         out,
+         query,
+         key_cache,
+         value_cache,
+         num_kv_heads,
+         scale,
+         block_tables,
+         seq_lens,
+         block_size,
+         max_seq_len,
+         alibi_slopes,
+         kv_cache_dtype,
+         k_scale,
+         v_scale,
+         tp_rank,
+         blocksparse_local_blocks,
+         blocksparse_vert_stride,
+         blocksparse_block_size,
+         blocksparse_head_sliding_step,
+     )
+
+
+ def paged_attention_v2(
+     out: torch.Tensor,
+     exp_sum: torch.Tensor,
+     max_logits: torch.Tensor,
+     tmp_out: torch.Tensor,
+     query: torch.Tensor,
+     key_cache: torch.Tensor,
+     value_cache: torch.Tensor,
+     num_kv_heads: int,
+     scale: float,
+     block_tables: torch.Tensor,
+     seq_lens: torch.Tensor,
+     block_size: int,
+     max_seq_len: int,
+     alibi_slopes: Optional[torch.Tensor],
+     kv_cache_dtype: str,
+     k_scale: float,
+     v_scale: float,
+     tp_rank: int = 0,
+     blocksparse_local_blocks: int = 0,
+     blocksparse_vert_stride: int = 0,
+     blocksparse_block_size: int = 64,
+     blocksparse_head_sliding_step: int = 0,
+ ) -> None:
+     ops.paged_attention_v2(
+         out,
+         exp_sum,
+         max_logits,
+         tmp_out,
+         query,
+         key_cache,
+         value_cache,
+         num_kv_heads,
+         scale,
+         block_tables,
+         seq_lens,
+         block_size,
+         max_seq_len,
+         alibi_slopes,
+         kv_cache_dtype,
+         k_scale,
+         v_scale,
+         tp_rank,
+         blocksparse_local_blocks,
+         blocksparse_vert_stride,
+         blocksparse_block_size,
+         blocksparse_head_sliding_step,
+     )
+
+
+ def reshape_and_cache(
+     key: torch.Tensor,
+     value: torch.Tensor,
+     key_cache: torch.Tensor,
+     value_cache: torch.Tensor,
+     slot_mapping: torch.Tensor,
+     kv_cache_dtype: str,
+     k_scale: float,
+     v_scale: float,
+ ) -> None:
+     ops.reshape_and_cache(
+         key,
+         value,
+         key_cache,
+         value_cache,
+         slot_mapping,
+         kv_cache_dtype,
+         k_scale,
+         v_scale,
+     )
+
+
+ def reshape_and_cache_flash(
+     key: torch.Tensor,
+     value: torch.Tensor,
+     key_cache: torch.Tensor,
+     value_cache: torch.Tensor,
+     slot_mapping: torch.Tensor,
+     kv_cache_dtype: str,
+     k_scale: torch.Tensor,
+     v_scale: torch.Tensor,
+ ) -> None:
+     ops.reshape_and_cache_flash(
+         key,
+         value,
+         key_cache,
+         value_cache,
+         slot_mapping,
+         kv_cache_dtype,
+         k_scale,
+         v_scale,
+     )
+
+
+ def copy_blocks(
+     key_caches: List[torch.Tensor],
+     value_caches: List[torch.Tensor],
+     block_mapping: torch.Tensor,
+ ) -> None:
+     ops.copy_blocks(key_caches, value_caches, block_mapping)
+
+
+ def swap_blocks(
+     src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor
+ ) -> None:
+     ops.swap_blocks(src, dst, block_mapping)
+
+
+ def convert_fp8(
+     output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8"
+ ) -> None:
+     ops.convert_fp8(output, input, scale, kv_dtype)
+
+
+ __all__ = [
+     "convert_fp8",
+     "paged_attention_v1",
+     "paged_attention_v2",
+     "reshape_and_cache",
+     "copy_blocks",
+ ]
build/torch210-cxx11-cu128-aarch64-linux/_ops.py ADDED
@@ -0,0 +1,9 @@
+ import torch
+ from . import _paged_attention_cuda_83cf4a3
+ ops = torch.ops._paged_attention_cuda_83cf4a3
+
+ def add_op_namespace_prefix(op_name: str):
+     """
+     Prefix op by namespace.
+     """
+     return f"_paged_attention_cuda_83cf4a3::{op_name}"
build/torch210-cxx11-cu128-aarch64-linux/_paged_attention_cuda_83cf4a3.abi3.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6dd2622118e8d4a9e7d952da74ffdb90627c4bb7a76a3be349847427b43db1dd
+ size 167603936
build/torch210-cxx11-cu128-aarch64-linux/metadata.json ADDED
@@ -0,0 +1,21 @@
+ {
+   "version": 1,
+   "license": "Apache-2.0",
+   "python-depends": [],
+   "backend": {
+     "type": "cuda",
+     "archs": [
+       "10.0",
+       "10.1",
+       "12.0+PTX",
+       "7.0",
+       "7.2",
+       "7.5",
+       "8.0",
+       "8.6",
+       "8.7",
+       "8.9",
+       "9.0"
+     ]
+   }
+ }
build/torch210-cxx11-cu128-aarch64-linux/paged_attention/__init__.py ADDED
@@ -0,0 +1,26 @@
+ import ctypes
+ import importlib.util
+ import sys
+ from pathlib import Path
+ from types import ModuleType
+
+
+ def _import_from_path(file_path: Path) -> ModuleType:
+     # We cannot use the module name as-is: after adding it to `sys.modules`,
+     # it would also be used for other imports. So we build a module name that
+     # depends on the path, making it unique via the hex-encoded hash of the
+     # path.
+     path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+     module_name = path_hash
+     spec = importlib.util.spec_from_file_location(module_name, file_path)
+     if spec is None:
+         raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+     module = importlib.util.module_from_spec(spec)
+     if module is None:
+         raise ImportError(f"Cannot load module {module_name} from spec")
+     sys.modules[module_name] = module
+     spec.loader.exec_module(module)  # type: ignore
+     return module
+
+
+ globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch210-cxx11-cu128-aarch64-linux/platforms.py ADDED
@@ -0,0 +1,92 @@
+ import os
+ import random
+ from abc import ABC, abstractmethod
+ from functools import lru_cache, wraps
+ from typing import Callable, ParamSpec, TypeVar
+
+ import numpy as np
+ import torch
+
+ IS_ROCM = torch.version.hip is not None
+ IS_MPS = torch.backends.mps.is_available()
+
+
+ class Platform(ABC):
+     @classmethod
+     def seed_everything(cls, seed: int) -> None:
+         """
+         Set the seed of each random module.
+         `torch.manual_seed` will set the seed on all devices.
+
+         Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
+         """
+         random.seed(seed)
+         np.random.seed(seed)
+         torch.manual_seed(seed)
+
+     @abstractmethod
+     def get_device_name(self, device_id: int = 0) -> str: ...
+
+     @abstractmethod
+     def is_cuda(self) -> bool: ...
+
+     @abstractmethod
+     def is_rocm(self) -> bool: ...
+
+     @abstractmethod
+     def is_mps(self) -> bool: ...
+
+
+ class CudaPlatform(Platform):
+     @classmethod
+     @lru_cache(maxsize=8)
+     def get_device_name(cls, device_id: int = 0) -> str:
+         # Use the requested device rather than hard-coding device 0.
+         return torch.cuda.get_device_name(device_id)
+
+     def is_cuda(self) -> bool:
+         return True
+
+     def is_rocm(self) -> bool:
+         return False
+
+     def is_mps(self) -> bool:
+         return False
+
+
+ class RocmPlatform(Platform):
+     @classmethod
+     @lru_cache(maxsize=8)
+     def get_device_name(cls, device_id: int = 0) -> str:
+         return torch.cuda.get_device_name(device_id)
+
+     def is_cuda(self) -> bool:
+         return False
+
+     def is_rocm(self) -> bool:
+         return True
+
+     def is_mps(self) -> bool:
+         return False
+
+
+ class MpsPlatform(Platform):
+     @classmethod
+     @lru_cache(maxsize=8)
+     def get_device_name(cls, device_id: int = 0) -> str:
+         return torch.cuda.get_device_name(device_id)
+
+     def is_cuda(self) -> bool:
+         return False
+
+     def is_rocm(self) -> bool:
+         return False
+
+     def is_mps(self) -> bool:
+         return True
+
+
+ current_platform = (
+     RocmPlatform() if IS_ROCM else
+     MpsPlatform() if IS_MPS else
+     CudaPlatform() if torch.cuda.is_available() else
+     None
+ )
build/torch210-cxx11-cu128-x86_64-linux/__init__.py ADDED
@@ -0,0 +1,21 @@
+ from ._custom_ops import (
+     convert_fp8,
+     copy_blocks,
+     paged_attention_v1,
+     paged_attention_v2,
+     reshape_and_cache,
+     reshape_and_cache_flash,
+     swap_blocks,
+ )
+ from ._ops import ops
+
+ __all__ = [
+     "convert_fp8",
+     "copy_blocks",
+     "ops",
+     "paged_attention_v1",
+     "paged_attention_v2",
+     "reshape_and_cache",
+     "reshape_and_cache_flash",
+     "swap_blocks",
+ ]
build/torch210-cxx11-cu128-x86_64-linux/_custom_ops.py ADDED
@@ -0,0 +1,173 @@
+ from typing import List, Optional
+
+ import torch
+
+ from ._ops import ops
+
+
+ # paged attention ops
+ def paged_attention_v1(
+     out: torch.Tensor,
+     query: torch.Tensor,
+     key_cache: torch.Tensor,
+     value_cache: torch.Tensor,
+     num_kv_heads: int,
+     scale: float,
+     block_tables: torch.Tensor,
+     seq_lens: torch.Tensor,
+     block_size: int,
+     max_seq_len: int,
+     alibi_slopes: Optional[torch.Tensor],
+     kv_cache_dtype: str,
+     k_scale: float,
+     v_scale: float,
+     tp_rank: int = 0,
+     blocksparse_local_blocks: int = 0,
+     blocksparse_vert_stride: int = 0,
+     blocksparse_block_size: int = 64,
+     blocksparse_head_sliding_step: int = 0,
+ ) -> None:
+     ops.paged_attention_v1(
+         out,
+         query,
+         key_cache,
+         value_cache,
+         num_kv_heads,
+         scale,
+         block_tables,
+         seq_lens,
+         block_size,
+         max_seq_len,
+         alibi_slopes,
+         kv_cache_dtype,
+         k_scale,
+         v_scale,
+         tp_rank,
+         blocksparse_local_blocks,
+         blocksparse_vert_stride,
+         blocksparse_block_size,
+         blocksparse_head_sliding_step,
+     )
+
+
+ def paged_attention_v2(
+     out: torch.Tensor,
+     exp_sum: torch.Tensor,
+     max_logits: torch.Tensor,
+     tmp_out: torch.Tensor,
+     query: torch.Tensor,
+     key_cache: torch.Tensor,
+     value_cache: torch.Tensor,
+     num_kv_heads: int,
+     scale: float,
+     block_tables: torch.Tensor,
+     seq_lens: torch.Tensor,
+     block_size: int,
+     max_seq_len: int,
+     alibi_slopes: Optional[torch.Tensor],
+     kv_cache_dtype: str,
+     k_scale: float,
+     v_scale: float,
+     tp_rank: int = 0,
+     blocksparse_local_blocks: int = 0,
+     blocksparse_vert_stride: int = 0,
+     blocksparse_block_size: int = 64,
+     blocksparse_head_sliding_step: int = 0,
+ ) -> None:
+     ops.paged_attention_v2(
+         out,
+         exp_sum,
+         max_logits,
+         tmp_out,
+         query,
+         key_cache,
+         value_cache,
+         num_kv_heads,
+         scale,
+         block_tables,
+         seq_lens,
+         block_size,
+         max_seq_len,
+         alibi_slopes,
+         kv_cache_dtype,
+         k_scale,
+         v_scale,
+         tp_rank,
+         blocksparse_local_blocks,
+         blocksparse_vert_stride,
+         blocksparse_block_size,
+         blocksparse_head_sliding_step,
+     )
+
+
+ def reshape_and_cache(
+     key: torch.Tensor,
+     value: torch.Tensor,
+     key_cache: torch.Tensor,
+     value_cache: torch.Tensor,
+     slot_mapping: torch.Tensor,
+     kv_cache_dtype: str,
+     k_scale: float,
+     v_scale: float,
+ ) -> None:
+     ops.reshape_and_cache(
+         key,
+         value,
+         key_cache,
+         value_cache,
+         slot_mapping,
+         kv_cache_dtype,
+         k_scale,
+         v_scale,
+     )
+
+
+ def reshape_and_cache_flash(
+     key: torch.Tensor,
+     value: torch.Tensor,
+     key_cache: torch.Tensor,
+     value_cache: torch.Tensor,
+     slot_mapping: torch.Tensor,
+     kv_cache_dtype: str,
+     k_scale: torch.Tensor,
+     v_scale: torch.Tensor,
+ ) -> None:
+     ops.reshape_and_cache_flash(
+         key,
+         value,
+         key_cache,
+         value_cache,
+         slot_mapping,
+         kv_cache_dtype,
+         k_scale,
+         v_scale,
+     )
+
+
+ def copy_blocks(
+     key_caches: List[torch.Tensor],
+     value_caches: List[torch.Tensor],
+     block_mapping: torch.Tensor,
+ ) -> None:
+     ops.copy_blocks(key_caches, value_caches, block_mapping)
+
+
+ def swap_blocks(
+     src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor
+ ) -> None:
+     ops.swap_blocks(src, dst, block_mapping)
+
+
+ def convert_fp8(
+     output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8"
+ ) -> None:
+     ops.convert_fp8(output, input, scale, kv_dtype)
+
+
+ __all__ = [
+     "convert_fp8",
+     "paged_attention_v1",
+     "paged_attention_v2",
+     "reshape_and_cache",
+     "copy_blocks",
+ ]
build/torch210-cxx11-cu128-x86_64-linux/_ops.py ADDED
@@ -0,0 +1,9 @@
+ import torch
+ from . import _paged_attention_cuda_83cf4a3
+ ops = torch.ops._paged_attention_cuda_83cf4a3
+
+ def add_op_namespace_prefix(op_name: str):
+     """
+     Prefix op by namespace.
+     """
+     return f"_paged_attention_cuda_83cf4a3::{op_name}"
build/torch210-cxx11-cu128-x86_64-linux/_paged_attention_cuda_83cf4a3.abi3.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:37c0783a4a3628ffc43d64b65090cb4fa8b2f5cc2fe913a51901378f518d11af
+ size 167726096
build/torch210-cxx11-cu128-x86_64-linux/metadata.json ADDED
@@ -0,0 +1,21 @@
+ {
+   "version": 1,
+   "license": "Apache-2.0",
+   "python-depends": [],
+   "backend": {
+     "type": "cuda",
+     "archs": [
+       "10.0",
+       "10.1",
+       "12.0+PTX",
+       "7.0",
+       "7.2",
+       "7.5",
+       "8.0",
+       "8.6",
+       "8.7",
+       "8.9",
+       "9.0"
+     ]
+   }
+ }
build/torch210-cxx11-cu128-x86_64-linux/paged_attention/__init__.py ADDED
@@ -0,0 +1,26 @@
+ import ctypes
+ import importlib.util
+ import sys
+ from pathlib import Path
+ from types import ModuleType
+
+
+ def _import_from_path(file_path: Path) -> ModuleType:
+     # We cannot use the module name as-is: after adding it to `sys.modules`,
+     # it would also be used for other imports. So we build a module name that
+     # depends on the path, making it unique via the hex-encoded hash of the
+     # path.
+     path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+     module_name = path_hash
+     spec = importlib.util.spec_from_file_location(module_name, file_path)
+     if spec is None:
+         raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+     module = importlib.util.module_from_spec(spec)
+     if module is None:
+         raise ImportError(f"Cannot load module {module_name} from spec")
+     sys.modules[module_name] = module
+     spec.loader.exec_module(module)  # type: ignore
+     return module
+
+
+ globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch210-cxx11-cu128-x86_64-linux/platforms.py ADDED
@@ -0,0 +1,92 @@
1
+ import os
2
+ import random
3
+ from abc import ABC, abstractmethod
4
+ from functools import lru_cache, wraps
5
+ from typing import Callable, ParamSpec, TypeVar
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+ IS_ROCM = torch.version.hip is not None
11
+ IS_MPS = torch.backends.mps.is_available()
12
+
13
+
14
+ class Platform(ABC):
15
+ @classmethod
16
+ def seed_everything(cls, seed: int) -> None:
17
+ """
18
+ Set the seed of each random module.
19
+ `torch.manual_seed` will set seed on all devices.
20
+
21
+ Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
22
+ """
23
+ random.seed(seed)
24
+ np.random.seed(seed)
25
+ torch.manual_seed(seed)
26
+
27
+ @abstractmethod
28
+ def get_device_name(self, device_id: int = 0) -> str: ...
29
+
30
+ @abstractmethod
31
+ def is_cuda(self) -> bool: ...
32
+
33
+ @abstractmethod
34
+ def is_rocm(self) -> bool: ...
35
+
36
+ @abstractmethod
37
+ def is_mps(self) -> bool: ...
38
+
39
+
40
+ class CudaPlatform(Platform):
41
+ @classmethod
42
+ @lru_cache(maxsize=8)
43
+ def get_device_name(cls, device_id: int = 0) -> str:
44
+ return torch.cuda.get_device_name(0)
45
+
46
+ def is_cuda(self) -> bool:
47
+ return True
48
+
49
+ def is_rocm(self) -> bool:
50
+ return False
51
+
52
+ def is_mps(self) -> bool:
53
+ return False
54
+
55
+
56
+ class RocmPlatform(Platform):
57
+ @classmethod
58
+ @lru_cache(maxsize=8)
59
+ def get_device_name(cls, device_id: int = 0) -> str:
60
+ return torch.cuda.get_device_name(device_id)
61
+
62
+ def is_cuda(self) -> bool:
63
+ return False
64
+
65
+ def is_rocm(self) -> bool:
66
+ return True
67
+
68
+ def is_mps(self) -> bool:
69
+ return False
70
+
71
+
72
+ class MpsPlatform(Platform):
73
+ @classmethod
74
+ @lru_cache(maxsize=8)
75
+ def get_device_name(cls, device_id: int = 0) -> str:
76
+ return torch.cuda.get_device_name(device_id)
77
+
78
+ def is_cuda(self) -> bool:
79
+ return False
80
+
81
+ def is_rocm(self) -> bool:
82
+ return False
83
+
84
+ def is_mps(self) -> bool:
85
+ return True
86
+
87
+ current_platform = (
88
+ RocmPlatform() if IS_ROCM else
89
+ MpsPlatform() if IS_MPS else
90
+ CudaPlatform() if torch.cuda.is_available() else
91
+ None
92
+ )
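
A hedged usage sketch for `current_platform`, assuming `platforms.py` is importable from the build-variant directory on `sys.path`:

```python
from platforms import current_platform

if current_platform is None:
    print("no supported accelerator detected")
elif current_platform.is_rocm():
    print("ROCm:", current_platform.get_device_name())
elif current_platform.is_cuda():
    print("CUDA:", current_platform.get_device_name())
elif current_platform.is_mps():
    print("MPS:", current_platform.get_device_name())

# Deterministic runs for benchmarks/tests:
if current_platform is not None:
    current_platform.seed_everything(0)
```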
build/torch210-cxx11-cu130-aarch64-linux/__init__.py ADDED
@@ -0,0 +1,21 @@
+ from ._custom_ops import (
+     convert_fp8,
+     copy_blocks,
+     paged_attention_v1,
+     paged_attention_v2,
+     reshape_and_cache,
+     reshape_and_cache_flash,
+     swap_blocks,
+ )
+ from ._ops import ops
+
+ __all__ = [
+     "convert_fp8",
+     "copy_blocks",
+     "ops",
+     "paged_attention_v1",
+     "paged_attention_v2",
+     "reshape_and_cache",
+     "reshape_and_cache_flash",
+     "swap_blocks",
+ ]
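
This variant-level `__init__.py` re-exports every wrapper, and the inner `paged_attention/` stub appears to exist precisely so that `import paged_attention` resolves when the build directory is on `sys.path`. A minimal import sketch (package name per this build layout; treat the path setup as an assumption):

```python
# Minimal import sketch; requires the built variant directory on sys.path.
from paged_attention import ops, paged_attention_v1, reshape_and_cache

print(ops)  # torch.ops namespace bound to the compiled extension
```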
build/torch210-cxx11-cu130-aarch64-linux/_custom_ops.py ADDED
@@ -0,0 +1,175 @@
+ from typing import List, Optional
+
+ import torch
+
+ from ._ops import ops
+
+
+ # paged attention ops
+ def paged_attention_v1(
+     out: torch.Tensor,
+     query: torch.Tensor,
+     key_cache: torch.Tensor,
+     value_cache: torch.Tensor,
+     num_kv_heads: int,
+     scale: float,
+     block_tables: torch.Tensor,
+     seq_lens: torch.Tensor,
+     block_size: int,
+     max_seq_len: int,
+     alibi_slopes: Optional[torch.Tensor],
+     kv_cache_dtype: str,
+     k_scale: float,
+     v_scale: float,
+     tp_rank: int = 0,
+     blocksparse_local_blocks: int = 0,
+     blocksparse_vert_stride: int = 0,
+     blocksparse_block_size: int = 64,
+     blocksparse_head_sliding_step: int = 0,
+ ) -> None:
+     ops.paged_attention_v1(
+         out,
+         query,
+         key_cache,
+         value_cache,
+         num_kv_heads,
+         scale,
+         block_tables,
+         seq_lens,
+         block_size,
+         max_seq_len,
+         alibi_slopes,
+         kv_cache_dtype,
+         k_scale,
+         v_scale,
+         tp_rank,
+         blocksparse_local_blocks,
+         blocksparse_vert_stride,
+         blocksparse_block_size,
+         blocksparse_head_sliding_step,
+     )
+
+
+ def paged_attention_v2(
+     out: torch.Tensor,
+     exp_sum: torch.Tensor,
+     max_logits: torch.Tensor,
+     tmp_out: torch.Tensor,
+     query: torch.Tensor,
+     key_cache: torch.Tensor,
+     value_cache: torch.Tensor,
+     num_kv_heads: int,
+     scale: float,
+     block_tables: torch.Tensor,
+     seq_lens: torch.Tensor,
+     block_size: int,
+     max_seq_len: int,
+     alibi_slopes: Optional[torch.Tensor],
+     kv_cache_dtype: str,
+     k_scale: float,
+     v_scale: float,
+     tp_rank: int = 0,
+     blocksparse_local_blocks: int = 0,
+     blocksparse_vert_stride: int = 0,
+     blocksparse_block_size: int = 64,
+     blocksparse_head_sliding_step: int = 0,
+ ) -> None:
+     ops.paged_attention_v2(
+         out,
+         exp_sum,
+         max_logits,
+         tmp_out,
+         query,
+         key_cache,
+         value_cache,
+         num_kv_heads,
+         scale,
+         block_tables,
+         seq_lens,
+         block_size,
+         max_seq_len,
+         alibi_slopes,
+         kv_cache_dtype,
+         k_scale,
+         v_scale,
+         tp_rank,
+         blocksparse_local_blocks,
+         blocksparse_vert_stride,
+         blocksparse_block_size,
+         blocksparse_head_sliding_step,
+     )
+
+
+ def reshape_and_cache(
+     key: torch.Tensor,
+     value: torch.Tensor,
+     key_cache: torch.Tensor,
+     value_cache: torch.Tensor,
+     slot_mapping: torch.Tensor,
+     kv_cache_dtype: str,
+     k_scale: float,
+     v_scale: float,
+ ) -> None:
+     ops.reshape_and_cache(
+         key,
+         value,
+         key_cache,
+         value_cache,
+         slot_mapping,
+         kv_cache_dtype,
+         k_scale,
+         v_scale,
+     )
+
+
+ def reshape_and_cache_flash(
+     key: torch.Tensor,
+     value: torch.Tensor,
+     key_cache: torch.Tensor,
+     value_cache: torch.Tensor,
+     slot_mapping: torch.Tensor,
+     kv_cache_dtype: str,
+     k_scale: torch.Tensor,
+     v_scale: torch.Tensor,
+ ) -> None:
+     ops.reshape_and_cache_flash(
+         key,
+         value,
+         key_cache,
+         value_cache,
+         slot_mapping,
+         kv_cache_dtype,
+         k_scale,
+         v_scale,
+     )
+
+
+ def copy_blocks(
+     key_caches: List[torch.Tensor],
+     value_caches: List[torch.Tensor],
+     block_mapping: torch.Tensor,
+ ) -> None:
+     ops.copy_blocks(key_caches, value_caches, block_mapping)
+
+
+ def swap_blocks(
+     src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor
+ ) -> None:
+     ops.swap_blocks(src, dst, block_mapping)
+
+
+ def convert_fp8(
+     output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8"
+ ) -> None:
+     ops.convert_fp8(output, input, scale, kv_dtype)
+
+
+ __all__ = [
+     "convert_fp8",
+     "paged_attention_v1",
+     "paged_attention_v2",
+     "reshape_and_cache",
+     "reshape_and_cache_flash",
+     "copy_blocks",
+     "swap_blocks",
+ ]
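
For orientation, a hedged end-to-end sketch of driving `paged_attention_v1` through these wrappers. The cache layout (`key_cache` split by a vector width `x`, `value_cache` stored per block) follows the vLLM-style convention these kernels derive from; treat the exact shapes and the `"auto"` cache dtype as assumptions, and note it needs a CUDA device plus the compiled extension:

```python
import torch

from paged_attention import paged_attention_v1

num_seqs, num_heads, num_kv_heads, head_size = 2, 8, 8, 128
block_size, max_blocks_per_seq, num_blocks = 16, 4, 32
x = 16 // torch.float16.itemsize  # kernel vector width (assumption)

query = torch.randn(num_seqs, num_heads, head_size, dtype=torch.float16, device="cuda")
out = torch.empty_like(query)
key_cache = torch.randn(
    num_blocks, num_kv_heads, head_size // x, block_size, x,
    dtype=torch.float16, device="cuda",
)
value_cache = torch.randn(
    num_blocks, num_kv_heads, head_size, block_size,
    dtype=torch.float16, device="cuda",
)
block_tables = torch.randint(
    0, num_blocks, (num_seqs, max_blocks_per_seq), dtype=torch.int32, device="cuda"
)
seq_lens = torch.tensor([48, 33], dtype=torch.int32, device="cuda")

paged_attention_v1(
    out, query, key_cache, value_cache,
    num_kv_heads=num_kv_heads,
    scale=head_size ** -0.5,
    block_tables=block_tables,
    seq_lens=seq_lens,
    block_size=block_size,
    max_seq_len=block_size * max_blocks_per_seq,
    alibi_slopes=None,
    kv_cache_dtype="auto",
    k_scale=1.0,
    v_scale=1.0,
)
print(out.shape)  # same shape as query
```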
build/torch210-cxx11-cu130-aarch64-linux/_ops.py ADDED
@@ -0,0 +1,9 @@
+ import torch
+ from . import _paged_attention_cuda_83cf4a3
+ ops = torch.ops._paged_attention_cuda_83cf4a3
+
+ def add_op_namespace_prefix(op_name: str):
+     """
+     Prefix op by namespace.
+     """
+     return f"_paged_attention_cuda_83cf4a3::{op_name}"
build/torch210-cxx11-cu130-aarch64-linux/_paged_attention_cuda_83cf4a3.abi3.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:31b7d92afaaffa6d335dad007ca97f76c66a5470e6a380e03a93fca6ff2232dc
+ size 86068816
build/torch210-cxx11-cu130-aarch64-linux/metadata.json ADDED
@@ -0,0 +1,19 @@
+ {
+   "version": 1,
+   "license": "Apache-2.0",
+   "python-depends": [],
+   "backend": {
+     "type": "cuda",
+     "archs": [
+       "10.0",
+       "11.0",
+       "12.0+PTX",
+       "7.5",
+       "8.0",
+       "8.6",
+       "8.7",
+       "8.9",
+       "9.0"
+     ]
+   }
+ }
build/torch210-cxx11-cu130-aarch64-linux/paged_attention/__init__.py ADDED
@@ -0,0 +1,26 @@
+ import ctypes
+ import importlib.util
+ import sys
+ from pathlib import Path
+ from types import ModuleType
+
+
+ def _import_from_path(file_path: Path) -> ModuleType:
+     # We cannot use the module name as-is: after adding it to
+     # `sys.modules`, it would also be used for other imports. So we
+     # build a module name that is unique per path, using the
+     # hex-encoded hash of the path.
+     path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+     module_name = path_hash
+     spec = importlib.util.spec_from_file_location(module_name, file_path)
+     if spec is None:
+         raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+     module = importlib.util.module_from_spec(spec)
+     if module is None:
+         raise ImportError(f"Cannot load module {module_name} from spec")
+     sys.modules[module_name] = module
+     spec.loader.exec_module(module)  # type: ignore
+     return module
+
+
+ globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch210-cxx11-cu130-aarch64-linux/platforms.py ADDED
@@ -0,0 +1,92 @@
+ import os
+ import random
+ from abc import ABC, abstractmethod
+ from functools import lru_cache, wraps
+ from typing import Callable, ParamSpec, TypeVar
+
+ import numpy as np
+ import torch
+
+ IS_ROCM = torch.version.hip is not None
+ IS_MPS = torch.backends.mps.is_available()
+
+
+ class Platform(ABC):
+     @classmethod
+     def seed_everything(cls, seed: int) -> None:
+         """
+         Set the seed of each random module.
+         `torch.manual_seed` also seeds all available devices.
+
+         Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
+         """
+         random.seed(seed)
+         np.random.seed(seed)
+         torch.manual_seed(seed)
+
+     @abstractmethod
+     def get_device_name(self, device_id: int = 0) -> str: ...
+
+     @abstractmethod
+     def is_cuda(self) -> bool: ...
+
+     @abstractmethod
+     def is_rocm(self) -> bool: ...
+
+     @abstractmethod
+     def is_mps(self) -> bool: ...
+
+
+ class CudaPlatform(Platform):
+     @classmethod
+     @lru_cache(maxsize=8)
+     def get_device_name(cls, device_id: int = 0) -> str:
+         return torch.cuda.get_device_name(device_id)
+
+     def is_cuda(self) -> bool:
+         return True
+
+     def is_rocm(self) -> bool:
+         return False
+
+     def is_mps(self) -> bool:
+         return False
+
+
+ class RocmPlatform(Platform):
+     @classmethod
+     @lru_cache(maxsize=8)
+     def get_device_name(cls, device_id: int = 0) -> str:
+         return torch.cuda.get_device_name(device_id)
+
+     def is_cuda(self) -> bool:
+         return False
+
+     def is_rocm(self) -> bool:
+         return True
+
+     def is_mps(self) -> bool:
+         return False
+
+
+ class MpsPlatform(Platform):
+     @classmethod
+     @lru_cache(maxsize=8)
+     def get_device_name(cls, device_id: int = 0) -> str:
+         return "mps"  # MPS exposes no CUDA-style device-name query
+
+     def is_cuda(self) -> bool:
+         return False
+
+     def is_rocm(self) -> bool:
+         return False
+
+     def is_mps(self) -> bool:
+         return True
+
+ current_platform = (
+     RocmPlatform() if IS_ROCM else
+     MpsPlatform() if IS_MPS else
+     CudaPlatform() if torch.cuda.is_available() else
+     None
+ )
build/torch210-cxx11-cu130-x86_64-linux/__init__.py ADDED
@@ -0,0 +1,21 @@
+ from ._custom_ops import (
+     convert_fp8,
+     copy_blocks,
+     paged_attention_v1,
+     paged_attention_v2,
+     reshape_and_cache,
+     reshape_and_cache_flash,
+     swap_blocks,
+ )
+ from ._ops import ops
+
+ __all__ = [
+     "convert_fp8",
+     "copy_blocks",
+     "ops",
+     "paged_attention_v1",
+     "paged_attention_v2",
+     "reshape_and_cache",
+     "reshape_and_cache_flash",
+     "swap_blocks",
+ ]
build/torch210-cxx11-cu130-x86_64-linux/_custom_ops.py ADDED
@@ -0,0 +1,175 @@
+ from typing import List, Optional
+
+ import torch
+
+ from ._ops import ops
+
+
+ # paged attention ops
+ def paged_attention_v1(
+     out: torch.Tensor,
+     query: torch.Tensor,
+     key_cache: torch.Tensor,
+     value_cache: torch.Tensor,
+     num_kv_heads: int,
+     scale: float,
+     block_tables: torch.Tensor,
+     seq_lens: torch.Tensor,
+     block_size: int,
+     max_seq_len: int,
+     alibi_slopes: Optional[torch.Tensor],
+     kv_cache_dtype: str,
+     k_scale: float,
+     v_scale: float,
+     tp_rank: int = 0,
+     blocksparse_local_blocks: int = 0,
+     blocksparse_vert_stride: int = 0,
+     blocksparse_block_size: int = 64,
+     blocksparse_head_sliding_step: int = 0,
+ ) -> None:
+     ops.paged_attention_v1(
+         out,
+         query,
+         key_cache,
+         value_cache,
+         num_kv_heads,
+         scale,
+         block_tables,
+         seq_lens,
+         block_size,
+         max_seq_len,
+         alibi_slopes,
+         kv_cache_dtype,
+         k_scale,
+         v_scale,
+         tp_rank,
+         blocksparse_local_blocks,
+         blocksparse_vert_stride,
+         blocksparse_block_size,
+         blocksparse_head_sliding_step,
+     )
+
+
+ def paged_attention_v2(
+     out: torch.Tensor,
+     exp_sum: torch.Tensor,
+     max_logits: torch.Tensor,
+     tmp_out: torch.Tensor,
+     query: torch.Tensor,
+     key_cache: torch.Tensor,
+     value_cache: torch.Tensor,
+     num_kv_heads: int,
+     scale: float,
+     block_tables: torch.Tensor,
+     seq_lens: torch.Tensor,
+     block_size: int,
+     max_seq_len: int,
+     alibi_slopes: Optional[torch.Tensor],
+     kv_cache_dtype: str,
+     k_scale: float,
+     v_scale: float,
+     tp_rank: int = 0,
+     blocksparse_local_blocks: int = 0,
+     blocksparse_vert_stride: int = 0,
+     blocksparse_block_size: int = 64,
+     blocksparse_head_sliding_step: int = 0,
+ ) -> None:
+     ops.paged_attention_v2(
+         out,
+         exp_sum,
+         max_logits,
+         tmp_out,
+         query,
+         key_cache,
+         value_cache,
+         num_kv_heads,
+         scale,
+         block_tables,
+         seq_lens,
+         block_size,
+         max_seq_len,
+         alibi_slopes,
+         kv_cache_dtype,
+         k_scale,
+         v_scale,
+         tp_rank,
+         blocksparse_local_blocks,
+         blocksparse_vert_stride,
+         blocksparse_block_size,
+         blocksparse_head_sliding_step,
+     )
+
+
+ def reshape_and_cache(
+     key: torch.Tensor,
+     value: torch.Tensor,
+     key_cache: torch.Tensor,
+     value_cache: torch.Tensor,
+     slot_mapping: torch.Tensor,
+     kv_cache_dtype: str,
+     k_scale: float,
+     v_scale: float,
+ ) -> None:
+     ops.reshape_and_cache(
+         key,
+         value,
+         key_cache,
+         value_cache,
+         slot_mapping,
+         kv_cache_dtype,
+         k_scale,
+         v_scale,
+     )
+
+
+ def reshape_and_cache_flash(
+     key: torch.Tensor,
+     value: torch.Tensor,
+     key_cache: torch.Tensor,
+     value_cache: torch.Tensor,
+     slot_mapping: torch.Tensor,
+     kv_cache_dtype: str,
+     k_scale: torch.Tensor,
+     v_scale: torch.Tensor,
+ ) -> None:
+     ops.reshape_and_cache_flash(
+         key,
+         value,
+         key_cache,
+         value_cache,
+         slot_mapping,
+         kv_cache_dtype,
+         k_scale,
+         v_scale,
+     )
+
+
+ def copy_blocks(
+     key_caches: List[torch.Tensor],
+     value_caches: List[torch.Tensor],
+     block_mapping: torch.Tensor,
+ ) -> None:
+     ops.copy_blocks(key_caches, value_caches, block_mapping)
+
+
+ def swap_blocks(
+     src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor
+ ) -> None:
+     ops.swap_blocks(src, dst, block_mapping)
+
+
+ def convert_fp8(
+     output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8"
+ ) -> None:
+     ops.convert_fp8(output, input, scale, kv_dtype)
+
+
+ __all__ = [
+     "convert_fp8",
+     "paged_attention_v1",
+     "paged_attention_v2",
+     "reshape_and_cache",
+     "reshape_and_cache_flash",
+     "copy_blocks",
+     "swap_blocks",
+ ]
build/torch210-cxx11-cu130-x86_64-linux/_ops.py ADDED
@@ -0,0 +1,9 @@
+ import torch
+ from . import _paged_attention_cuda_83cf4a3
+ ops = torch.ops._paged_attention_cuda_83cf4a3
+
+ def add_op_namespace_prefix(op_name: str):
+     """
+     Prefix op by namespace.
+     """
+     return f"_paged_attention_cuda_83cf4a3::{op_name}"
build/torch210-cxx11-cu130-x86_64-linux/_paged_attention_cuda_83cf4a3.abi3.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7fc05b440e24ece432bd009e23dbf721d191d03cfa3d020c2d52d3eaface9992
+ size 86563792
build/torch210-cxx11-cu130-x86_64-linux/metadata.json ADDED
@@ -0,0 +1,19 @@
+ {
+   "version": 1,
+   "license": "Apache-2.0",
+   "python-depends": [],
+   "backend": {
+     "type": "cuda",
+     "archs": [
+       "10.0",
+       "11.0",
+       "12.0+PTX",
+       "7.5",
+       "8.0",
+       "8.6",
+       "8.7",
+       "8.9",
+       "9.0"
+     ]
+   }
+ }
build/torch210-cxx11-cu130-x86_64-linux/paged_attention/__init__.py ADDED
@@ -0,0 +1,26 @@
+ import ctypes
+ import importlib.util
+ import sys
+ from pathlib import Path
+ from types import ModuleType
+
+
+ def _import_from_path(file_path: Path) -> ModuleType:
+     # We cannot use the module name as-is: after adding it to
+     # `sys.modules`, it would also be used for other imports. So we
+     # build a module name that is unique per path, using the
+     # hex-encoded hash of the path.
+     path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+     module_name = path_hash
+     spec = importlib.util.spec_from_file_location(module_name, file_path)
+     if spec is None:
+         raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+     module = importlib.util.module_from_spec(spec)
+     if module is None:
+         raise ImportError(f"Cannot load module {module_name} from spec")
+     sys.modules[module_name] = module
+     spec.loader.exec_module(module)  # type: ignore
+     return module
+
+
+ globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch210-cxx11-cu130-x86_64-linux/platforms.py ADDED
@@ -0,0 +1,92 @@
+ import os
+ import random
+ from abc import ABC, abstractmethod
+ from functools import lru_cache, wraps
+ from typing import Callable, ParamSpec, TypeVar
+
+ import numpy as np
+ import torch
+
+ IS_ROCM = torch.version.hip is not None
+ IS_MPS = torch.backends.mps.is_available()
+
+
+ class Platform(ABC):
+     @classmethod
+     def seed_everything(cls, seed: int) -> None:
+         """
+         Set the seed of each random module.
+         `torch.manual_seed` also seeds all available devices.
+
+         Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
+         """
+         random.seed(seed)
+         np.random.seed(seed)
+         torch.manual_seed(seed)
+
+     @abstractmethod
+     def get_device_name(self, device_id: int = 0) -> str: ...
+
+     @abstractmethod
+     def is_cuda(self) -> bool: ...
+
+     @abstractmethod
+     def is_rocm(self) -> bool: ...
+
+     @abstractmethod
+     def is_mps(self) -> bool: ...
+
+
+ class CudaPlatform(Platform):
+     @classmethod
+     @lru_cache(maxsize=8)
+     def get_device_name(cls, device_id: int = 0) -> str:
+         return torch.cuda.get_device_name(device_id)
+
+     def is_cuda(self) -> bool:
+         return True
+
+     def is_rocm(self) -> bool:
+         return False
+
+     def is_mps(self) -> bool:
+         return False
+
+
+ class RocmPlatform(Platform):
+     @classmethod
+     @lru_cache(maxsize=8)
+     def get_device_name(cls, device_id: int = 0) -> str:
+         return torch.cuda.get_device_name(device_id)
+
+     def is_cuda(self) -> bool:
+         return False
+
+     def is_rocm(self) -> bool:
+         return True
+
+     def is_mps(self) -> bool:
+         return False
+
+
+ class MpsPlatform(Platform):
+     @classmethod
+     @lru_cache(maxsize=8)
+     def get_device_name(cls, device_id: int = 0) -> str:
+         return "mps"  # MPS exposes no CUDA-style device-name query
+
+     def is_cuda(self) -> bool:
+         return False
+
+     def is_rocm(self) -> bool:
+         return False
+
+     def is_mps(self) -> bool:
+         return True
+
+ current_platform = (
+     RocmPlatform() if IS_ROCM else
+     MpsPlatform() if IS_MPS else
+     CudaPlatform() if torch.cuda.is_available() else
+     None
+ )
build/torch210-cxx11-rocm70-x86_64-linux/__init__.py ADDED
@@ -0,0 +1,21 @@
+ from ._custom_ops import (
+     convert_fp8,
+     copy_blocks,
+     paged_attention_v1,
+     paged_attention_v2,
+     reshape_and_cache,
+     reshape_and_cache_flash,
+     swap_blocks,
+ )
+ from ._ops import ops
+
+ __all__ = [
+     "convert_fp8",
+     "copy_blocks",
+     "ops",
+     "paged_attention_v1",
+     "paged_attention_v2",
+     "reshape_and_cache",
+     "reshape_and_cache_flash",
+     "swap_blocks",
+ ]
build/torch210-cxx11-rocm70-x86_64-linux/_custom_ops.py ADDED
@@ -0,0 +1,175 @@
+ from typing import List, Optional
+
+ import torch
+
+ from ._ops import ops
+
+
+ # paged attention ops
+ def paged_attention_v1(
+     out: torch.Tensor,
+     query: torch.Tensor,
+     key_cache: torch.Tensor,
+     value_cache: torch.Tensor,
+     num_kv_heads: int,
+     scale: float,
+     block_tables: torch.Tensor,
+     seq_lens: torch.Tensor,
+     block_size: int,
+     max_seq_len: int,
+     alibi_slopes: Optional[torch.Tensor],
+     kv_cache_dtype: str,
+     k_scale: float,
+     v_scale: float,
+     tp_rank: int = 0,
+     blocksparse_local_blocks: int = 0,
+     blocksparse_vert_stride: int = 0,
+     blocksparse_block_size: int = 64,
+     blocksparse_head_sliding_step: int = 0,
+ ) -> None:
+     ops.paged_attention_v1(
+         out,
+         query,
+         key_cache,
+         value_cache,
+         num_kv_heads,
+         scale,
+         block_tables,
+         seq_lens,
+         block_size,
+         max_seq_len,
+         alibi_slopes,
+         kv_cache_dtype,
+         k_scale,
+         v_scale,
+         tp_rank,
+         blocksparse_local_blocks,
+         blocksparse_vert_stride,
+         blocksparse_block_size,
+         blocksparse_head_sliding_step,
+     )
+
+
+ def paged_attention_v2(
+     out: torch.Tensor,
+     exp_sum: torch.Tensor,
+     max_logits: torch.Tensor,
+     tmp_out: torch.Tensor,
+     query: torch.Tensor,
+     key_cache: torch.Tensor,
+     value_cache: torch.Tensor,
+     num_kv_heads: int,
+     scale: float,
+     block_tables: torch.Tensor,
+     seq_lens: torch.Tensor,
+     block_size: int,
+     max_seq_len: int,
+     alibi_slopes: Optional[torch.Tensor],
+     kv_cache_dtype: str,
+     k_scale: float,
+     v_scale: float,
+     tp_rank: int = 0,
+     blocksparse_local_blocks: int = 0,
+     blocksparse_vert_stride: int = 0,
+     blocksparse_block_size: int = 64,
+     blocksparse_head_sliding_step: int = 0,
+ ) -> None:
+     ops.paged_attention_v2(
+         out,
+         exp_sum,
+         max_logits,
+         tmp_out,
+         query,
+         key_cache,
+         value_cache,
+         num_kv_heads,
+         scale,
+         block_tables,
+         seq_lens,
+         block_size,
+         max_seq_len,
+         alibi_slopes,
+         kv_cache_dtype,
+         k_scale,
+         v_scale,
+         tp_rank,
+         blocksparse_local_blocks,
+         blocksparse_vert_stride,
+         blocksparse_block_size,
+         blocksparse_head_sliding_step,
+     )
+
+
+ def reshape_and_cache(
+     key: torch.Tensor,
+     value: torch.Tensor,
+     key_cache: torch.Tensor,
+     value_cache: torch.Tensor,
+     slot_mapping: torch.Tensor,
+     kv_cache_dtype: str,
+     k_scale: float,
+     v_scale: float,
+ ) -> None:
+     ops.reshape_and_cache(
+         key,
+         value,
+         key_cache,
+         value_cache,
+         slot_mapping,
+         kv_cache_dtype,
+         k_scale,
+         v_scale,
+     )
+
+
+ def reshape_and_cache_flash(
+     key: torch.Tensor,
+     value: torch.Tensor,
+     key_cache: torch.Tensor,
+     value_cache: torch.Tensor,
+     slot_mapping: torch.Tensor,
+     kv_cache_dtype: str,
+     k_scale: torch.Tensor,
+     v_scale: torch.Tensor,
+ ) -> None:
+     ops.reshape_and_cache_flash(
+         key,
+         value,
+         key_cache,
+         value_cache,
+         slot_mapping,
+         kv_cache_dtype,
+         k_scale,
+         v_scale,
+     )
+
+
+ def copy_blocks(
+     key_caches: List[torch.Tensor],
+     value_caches: List[torch.Tensor],
+     block_mapping: torch.Tensor,
+ ) -> None:
+     ops.copy_blocks(key_caches, value_caches, block_mapping)
+
+
+ def swap_blocks(
+     src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor
+ ) -> None:
+     ops.swap_blocks(src, dst, block_mapping)
+
+
+ def convert_fp8(
+     output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8"
+ ) -> None:
+     ops.convert_fp8(output, input, scale, kv_dtype)
+
+
+ __all__ = [
+     "convert_fp8",
+     "paged_attention_v1",
+     "paged_attention_v2",
+     "reshape_and_cache",
+     "reshape_and_cache_flash",
+     "copy_blocks",
+     "swap_blocks",
+ ]
build/torch210-cxx11-rocm70-x86_64-linux/_ops.py ADDED
@@ -0,0 +1,9 @@
+ import torch
+ from . import _paged_attention_rocm_83cf4a3
+ ops = torch.ops._paged_attention_rocm_83cf4a3
+
+ def add_op_namespace_prefix(op_name: str):
+     """
+     Prefix op by namespace.
+     """
+     return f"_paged_attention_rocm_83cf4a3::{op_name}"
build/torch210-cxx11-rocm70-x86_64-linux/_paged_attention_rocm_83cf4a3.abi3.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c715078de15626c6dc53b2bb321828478a33952ed5bac5e6f5730a984445b321
+ size 58992416