| |
| |
|
|
| import numpy as np |
| import pytest |
| import torch |
|
|
| from megablocks import ops |
|
|
# Parameterizations for the topology test, as (sequence_length, hidden_size,
# num_experts) triples: two workload shapes swept over power-of-two expert
# counts, plus one tiny-sequence / wide-hidden case.
TOPOLOGY_TESTS = (
    tuple((1024, 1536, 2**k) for k in range(1, 10)) +
    tuple((16384, 768, 2**k) for k in range(1, 11)) +
    ((8, 14336, 8),)
)
|
|
|
|
@pytest.mark.gpu
@pytest.mark.parametrize(('sl', 'hs', 'ne'), TOPOLOGY_TESTS)
def test_topology(sl: int, hs: int, ne: int):
    """Check ``ops.topology`` against a NumPy reference implementation.

    Args:
        sl: Sequence length — number of tokens routed to experts.
        hs: Hidden size; must be divisible by the 128-wide block size.
        ne: Number of experts.
    """
    blocking = 128
    assert hs % blocking == 0

    # Randomly assign each token to one expert, then build the block-padded
    # per-expert histogram and its inclusive cumulative sums (bin offsets).
    top_expert = torch.randint(0, ne, (sl,)).cuda().int()
    tokens_per_expert = ops.histogram(top_expert, ne)
    padded_tokens_per_expert = ops.round_up(tokens_per_expert, blocking)
    padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)

    # Dimensions of the output block-sparse matrix, measured in blocks.
    output_block_rows = int(padded_bins[-1]) // blocking
    output_block_columns = hs // blocking

    def topology(
        padded_bins: torch.Tensor,
        blocking: int,  # integer block size (the original annotated this as a Tensor)
        rows: int,
        columns: int,
    ):
        """Reference: column index of every nonzero block, row-major."""
        padded_bins = padded_bins.cpu().numpy()

        # For expert i, every block-row in [start, end) owns `columns`
        # consecutive weight blocks whose global column ids start at
        # i * columns. Build int16 directly; the original built float64
        # and relied on the final .short() cast.
        out = np.zeros([rows * columns], dtype=np.int16)
        start = 0
        for i in range(padded_bins.shape[0]):
            end = padded_bins[i] // blocking
            while start < end:
                for j in range(columns):
                    out[start * columns + j] = j + i * columns
                start += 1
        return torch.from_numpy(out).cuda().short()

    out = ops.topology(
        padded_bins,
        blocking,
        output_block_rows,
        output_block_columns,
    )
    expected_out = topology(
        padded_bins,
        blocking,
        output_block_rows,
        output_block_columns,
    )
    assert torch.all(torch.eq(out, expected_out))
|
|