| import torch |
| import torch.nn.functional as F |
| from unicore.modules import softmax_dropout |
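
# Summary comment (added): parity tests for unicore's fused softmax_dropout
# CUDA kernel against a plain PyTorch reference. Forward outputs and the
# gradients w.r.t. the input and the bias are compared across several
# last-dim sizes and dtypes (fp32/fp16/bf16), for 4-D and 5-D score layouts.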
|
|
def gen_attn_mask(mask, neg_inf):
    # Convert a {0, 1} keep-mask into an additive attention mask: kept
    # positions stay 0 and masked positions get a large negative value so
    # that they vanish after the softmax.
    assert neg_inf < -1e4
    attn_mask = torch.zeros_like(mask)
    attn_mask[mask == 0] = neg_inf
    return attn_mask
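
# Illustrative example (not executed): gen_attn_mask(torch.tensor([[1.0, 1.0, 0.0]]), -3e4)
# produces [[0., 0., -3e4]]: zeros at kept positions and -3e4 at the masked
# one, so adding it to attention logits removes masked keys from the softmax.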
|
|
def normal_softmax(a, mask, bias):
    # Reference implementation: plain PyTorch softmax over the masked,
    # biased attention logits.
    return F.softmax(a + mask + bias, dim=-1)
|
|
def fused_softmax(a, mask, bias):
    # Fused kernel from unicore, called with dropout probability 0 so that it
    # can be compared directly against the reference implementation.
    return softmax_dropout(a, 0, True, mask=mask, bias=bias)
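
# With dropout_prob=0 the fused call should be functionally equivalent to the
# reference, i.e. softmax_dropout(a, 0, True, mask=mask, bias=bias) should
# compute softmax(a + mask + bias, dim=-1); the tests below check this for
# both the forward output and the gradients, within a small tolerance.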
|
|
def wrap_forward_backward(func, a1, mask, bias1):
    # Run one forward/backward pass of `func` on fresh copies of the input
    # and bias, and return the output plus the gradients w.r.t. both.
    a = a1.clone()
    bias = bias1.clone()
    a.requires_grad = True
    bias.requires_grad = True
    output = func(a, mask, bias)
    # Sum in fp32 so the scalar loss (and hence the backward pass) is not
    # degraded by reduced-precision accumulation.
    o = output.float().sum()
    o.backward()
    return output, a.grad, bias.grad
|
|
def check_diff(a, b, name, eps=1e-3):
    # Assert element-wise closeness and report the max absolute difference.
    assert (a - b).abs().max() < eps, "name {}, diff {}".format(
        name, (a - b).abs().max()
    )
|
|
| def test_softmax(): |
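    # 4-D attention scores (batch, heads, query, key); the mask is broadcast
    # over heads and queries while the bias is dense per head and query.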
| n_batch = 4 |
| n_heads = 8 |
| n_query = 128 |
| test_dims = [64, 128, 256, 512, 1024, 1536, 2048] |
| test_dtype = [torch.float32, torch.float16, torch.bfloat16] |
| test_device = torch.device("cuda") |
| for last_dim in test_dims: |
| for dtype in test_dtype: |
| x = torch.rand( |
| n_batch, |
| n_heads, |
| n_query, |
| last_dim, |
| dtype=dtype, |
| device=test_device, |
| ) |
| mask = gen_attn_mask( |
| ( |
| torch.rand( |
| n_batch, |
| 1, |
| 1, |
| last_dim, |
| dtype=dtype, |
| device=test_device, |
| ) |
| > 0.2 |
| ).type(x.dtype), |
| -3e4, |
| ) |
| bias = torch.rand( |
| n_batch, n_heads, n_query, last_dim, dtype=dtype, device=test_device |
| ) |
| out_a1, out_b1, out_c1 = wrap_forward_backward( |
| normal_softmax, x, mask, bias |
| ) |
| out_a2, out_b2, out_c2 = wrap_forward_backward(fused_softmax, x, mask, bias) |
| check_diff(out_a1, out_a2, "output") |
| check_diff(out_b1, out_b2, "grad_input") |
| check_diff(out_c1, out_c2, "grad_bias") |
|
|
| def test_tri_softmax1(): |
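    # 5-D scores (batch, groups, heads, query, key); the mask is broadcast
    # over heads and queries, and the bias is shared across batch and groups.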
| n_batch = 2 |
| n_groups = 32 |
| n_heads = 8 |
| n_query = 128 |
| test_dims = [64, 128, 256, 512, 1024, 1536, 2048] |
| test_dtype = [torch.float32, torch.float16, torch.bfloat16] |
| test_device = torch.device("cuda") |
| for last_dim in test_dims: |
| for dtype in test_dtype: |
| x = torch.rand( |
| n_batch, |
| n_groups, |
| n_heads, |
| n_query, |
| last_dim, |
| dtype=dtype, |
| device=test_device, |
| ) |
| mask = gen_attn_mask( |
| ( |
| torch.rand( |
| n_batch, |
| n_groups, |
| 1, |
| 1, |
| last_dim, |
| dtype=dtype, |
| device=test_device, |
| ) |
| > 0.2 |
| ).type(x.dtype), |
| -3e4, |
| ) |
| bias = torch.rand( |
| 1, 1, n_heads, n_query, last_dim, dtype=dtype, device=test_device |
| ) |
| out_a1, out_b1, out_c1 = wrap_forward_backward( |
| normal_softmax, x, mask, bias |
| ) |
| out_a2, out_b2, out_c2 = wrap_forward_backward(fused_softmax, x, mask, bias) |
| check_diff(out_a1, out_a2, "output") |
| check_diff(out_b1, out_b2, "grad_input") |
| check_diff(out_c1, out_c2, "grad_bias") |
|
|
| def test_tri_softmax2(): |
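    # 5-D scores where the mask additionally varies per head, and the bias is
    # shared only across the batch dimension.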
| n_batch = 2 |
| n_groups = 32 |
| n_heads = 8 |
| n_query = 128 |
| test_dims = [64, 128, 256, 512, 1024, 1536, 2048] |
| test_dtype = [torch.float32, torch.float16, torch.bfloat16] |
| test_device = torch.device("cuda") |
| for last_dim in test_dims: |
| for dtype in test_dtype: |
| x = torch.rand( |
| n_batch, |
| n_groups, |
| n_heads, |
| n_query, |
| last_dim, |
| dtype=dtype, |
| device=test_device, |
| ) |
| mask = gen_attn_mask( |
| ( |
| torch.rand( |
| n_batch, |
| n_groups, |
| n_heads, |
| 1, |
| last_dim, |
| dtype=dtype, |
| device=test_device, |
| ) |
| > 0.2 |
| ).type(x.dtype), |
| -3e4, |
| ) |
| bias = torch.rand( |
| 1, n_groups, n_heads, n_query, last_dim, dtype=dtype, device=test_device |
| ) |
| out_a1, out_b1, out_c1 = wrap_forward_backward( |
| normal_softmax, x, mask, bias |
| ) |
| out_a2, out_b2, out_c2 = wrap_forward_backward(fused_softmax, x, mask, bias) |
| check_diff(out_a1, out_a2, "output") |
| check_diff(out_b1, out_b2, "grad_input") |
| check_diff(out_c1, out_c2, "grad_bias") |
|
|
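# Minimal convenience runner for ad-hoc local runs (an assumption; in the
# repository these tests are presumably collected and run by pytest).
if __name__ == "__main__":
    test_softmax()
    test_tri_softmax1()
    test_tri_softmax2()
    print("all softmax parity tests passed")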