| import torch |
| import megablocks |
|
|
|
|
def randn(bs, x, y, device="cuda"):
    """Return a random bfloat16 tensor of shape ``(bs, x, y)`` on ``device``.

    Despite the name, values come from ``torch.rand`` (uniform, not normal):
    they are mapped to [-1, 1) and then divided by ``x * y``.

    Args:
        bs: size of the leading dimension.
        x: size of the middle dimension.
        y: size of the last dimension.
        device: device the result is moved to; defaults to ``"cuda"`` so
            existing callers behave as before.

    Returns:
        A ``torch.bfloat16`` tensor of shape ``(bs, x, y)`` on ``device``.
    """
    # Centre on zero BEFORE doubling. The old form ``rand - 0.5 * 2`` parsed
    # as ``rand - 1`` and produced only non-positive values.
    out = (torch.rand(bs, x, y) - 0.5) * 2 / (y * x)
    # The tensor is sampled on the host (so torch.manual_seed governs it),
    # moved to the target device, then cast.
    return out.to(device).to(torch.bfloat16)
|
|
|
|
def gmm(a, b, batch_sizes, trans_b=False):
    """Reference grouped matmul, computed one group at a time.

    Consecutive row blocks of ``a`` are paired with the matrices in ``b``.
    Block ``i`` holds ``batch_sizes[i]`` rows and starts right after the
    previous block. It is multiplied by ``b[i, :, :]``, or by that matrix's
    transpose when ``trans_b`` is true. The per-group products are
    concatenated along dim 0.

    Args:
        a: tensor sliced as ``a[rows, :]``; its rows are partitioned by
            ``batch_sizes``.
        b: tensor indexed as ``b[i, :, :]`` for each group ``i``.
        batch_sizes: tensor of per-group row counts. It is copied to the
            host before use, so it may live on any device.
        trans_b: when true, multiply by ``b[i, :, :].t()`` instead.

    Returns:
        ``torch.cat`` of the per-group products. Rows of ``a`` past
        ``sum(batch_sizes)`` are not used.
    """
    row_counts = batch_sizes.cpu().tolist()

    pieces = []
    row0 = 0
    for group, count in enumerate(row_counts):
        weight = b[group, :, :]
        if trans_b:
            weight = weight.t()
        pieces.append(a[row0 : row0 + count, :] @ weight)
        row0 += count

    return torch.cat(pieces)
|
|
|
|
def _assert_close(actual, expected, name, rel_tol=1.6e-2):
    """Assert ``actual`` matches ``expected`` with a magnitude-aware tolerance.

    The inputs from ``randn`` are scaled by 1/(x*y), so the tensors compared
    here can be far smaller than a fixed absolute tolerance such as 1e-3,
    which would make the comparison vacuous. ``atol`` is therefore set to
    ``rel_tol`` times the largest absolute value in ``expected``. The default
    ``rel_tol`` is the bfloat16 default rtol of ``torch.testing.assert_close``.

    Args:
        actual: tensor produced by the code under test.
        expected: reference tensor.
        name: label used in the failure message.
        rel_tol: relative tolerance; also scales the absolute tolerance.
    """
    scale = expected.detach().abs().max().item()
    atol = rel_tol * scale
    assert torch.allclose(actual, expected, rtol=rel_tol, atol=atol), (
        f"{name}: expected {expected}, got {actual}"
    )


def test_gmm():
    """Compare ``megablocks.gg_ops.gmm`` with the pure-PyTorch ``gmm`` reference.

    Builds bf16 inputs with ``randn`` (CUDA by default), runs both
    implementations forward, then backpropagates ``out.sum()`` through each.
    It checks the outputs and the gradients w.r.t. ``a`` and ``b``.
    """
    z = 1  # number of groups; with z == 1 there is a single row block
    m = 128  # rows of ``a`` per group
    n = 128
    k = 128
    trans_b = False
    batch_sizes_on_device = False

    torch.manual_seed(0)
    a = randn(z, m, k).view(-1, k)
    b = randn(z, n, k) if trans_b else randn(z, k, n)
    batch_sizes = torch.tensor([m] * z)
    if batch_sizes_on_device:
        batch_sizes = batch_sizes.cuda()

    # Independent leaf copies so the two backward passes do not share grads.
    a.requires_grad_(True)
    b.requires_grad_(True)
    a_ref = a.detach().clone().requires_grad_(True)
    b_ref = b.detach().clone().requires_grad_(True)

    out = megablocks.gg_ops.gmm(a, b, batch_sizes, trans_b)
    expected_out = gmm(a_ref, b_ref, batch_sizes, trans_b)
    _assert_close(out, expected_out, "out")

    out.sum().backward()
    expected_out.sum().backward()
    _assert_close(a.grad, a_ref.grad, "a.grad")
    _assert_close(b.grad, b_ref.grad, "b.grad")
    print("Test passed successfully!")