|
|
| import torch |
| import custom_ctc |
| import custom_ctc_gpu |
|
|
| import numpy as np |
|
|
| import torch.nn.functional as F |
| from torch.autograd import gradcheck |
|
|
| custom_ctc_f = custom_ctc.CustomCTCLossFunction.apply |
| custom_ctc_f_gpu = custom_ctc_gpu.CustomCTCLossFunction.apply |
|
|
| def test_ctc_loss_custom(device): |
| batch_size = 64 |
| num_labels = 101 |
| target_length = 15 |
| gradcheck_input_size = 10 |
|
|
| ZERO_NONE = 0 |
| ZERO_SOME = 1 |
| ZERO_ALL = 2 |
|
|
| |
| tests = [(150, False, ZERO_NONE), |
| (150, True, ZERO_NONE), |
| (50, True, ZERO_SOME), |
| (50, True, ZERO_ALL)] |
|
|
| tests += [(50, False, ZERO_NONE), |
| (50, True, ZERO_NONE), |
| (150, True, ZERO_SOME), |
| (150, True, ZERO_ALL)] |
|
|
| for input_length, vary_lengths, zero_mode in tests: |
| targets = torch.randint(1, num_labels, (batch_size, target_length), |
| device=device, dtype=torch.long) |
| x = torch.randn(gradcheck_input_size, dtype=torch.double, device=device, requires_grad=True) |
| tile_factors = torch.randn(input_length * batch_size * num_labels // gradcheck_input_size + 1, |
| device=device) |
| input_lengths = [(torch.randint(input_length // 2, input_length + 1, ()).item() |
| if vary_lengths or i == 0 else input_length) for i in range(batch_size)] |
| if zero_mode == ZERO_ALL: |
| target_lengths = [0 for _ in range(batch_size)] |
| else: |
| target_lengths = [(torch.randint(target_length // 2, target_length + 1, ()).item() |
| if vary_lengths else target_length) for _ in range(batch_size)] |
| if zero_mode == ZERO_SOME: |
| idxes = torch.randint(0, batch_size, (10,)) |
| for i in idxes: |
| target_lengths[i] = 0 |
|
|
| num_realval = np.random.randint(1, 16) |
| rv_x = torch.randn(gradcheck_input_size, dtype=torch.double, device=device, requires_grad=True) |
| tile_factors_rv = torch.randn(batch_size * input_length * num_realval // gradcheck_input_size + 1, |
| device=device) |
| |
| targets_realvals = torch.randn(batch_size, input_length, num_realval, dtype=torch.double) |
|
|
| blank1 = np.random.randint(1, num_labels - 1) |
|
|
| def ctc_after_softmax(x, rv): |
| x_full = ((x[:, None] * tile_factors[None, :]).view(-1)[:input_length * batch_size * num_labels] |
| .view(batch_size, input_length, num_labels)) |
| rv_full = ((rv[:, None] * tile_factors_rv[None, :]).view(-1)[:input_length * batch_size * num_realval] |
| .view(batch_size, input_length, num_realval)) |
| log_probs = torch.log_softmax(x_full, 2) |
| return custom_ctc_f(log_probs, targets, rv_full, targets_realvals, input_lengths, target_lengths, 2.2, 0, blank1, 'mean', True) |
|
|
| gradcheck(ctc_after_softmax, [x, rv_x]) |
|
|
|
|
| def test_ctc_loss_custom_gpu(device, fp = torch.float32): |
| print('testing GPU gradient for %s' % str(fp)) |
| batch_size = 64 |
| num_labels = 101 |
| target_length = 15 |
| gradcheck_input_size = 10 |
|
|
| ZERO_NONE = 0 |
| ZERO_SOME = 1 |
| ZERO_ALL = 2 |
|
|
| |
| tests = [(150, False, ZERO_NONE), |
| (150, True, ZERO_NONE), |
| (50, True, ZERO_SOME), |
| (50, True, ZERO_ALL)] |
|
|
| tests += [(50, False, ZERO_NONE), |
| (50, True, ZERO_NONE), |
| (150, True, ZERO_SOME), |
| (150, True, ZERO_ALL)] |
|
|
| for input_length, vary_lengths, zero_mode in tests: |
| targets = torch.randint(1, num_labels, (batch_size, target_length), |
| device=device, dtype=torch.long) |
| x = torch.randn(gradcheck_input_size, dtype=fp, device=device) |
| tile_factors = torch.randn(input_length * batch_size * num_labels // gradcheck_input_size + 1, |
| device=device) |
| input_lengths = [(torch.randint(input_length // 2, input_length + 1, ()).item() |
| if vary_lengths or i == 0 else input_length) for i in range(batch_size)] |
| if zero_mode == ZERO_ALL: |
| target_lengths = [0 for _ in range(batch_size)] |
| else: |
| target_lengths = [(torch.randint(target_length // 2, target_length + 1, ()).item() |
| if vary_lengths else target_length) for _ in range(batch_size)] |
| if zero_mode == ZERO_SOME: |
| idxes = torch.randint(0, batch_size, (10,)) |
| for i in idxes: |
| target_lengths[i] = 0 |
|
|
| num_realval = np.random.randint(1, 16) |
| rv_x = torch.randn(gradcheck_input_size, dtype=fp, device=device) |
| tile_factors_rv = torch.randn(batch_size * input_length * num_realval // gradcheck_input_size + 1, |
| device=device) |
| |
| targets_realvals = torch.randn(batch_size, input_length, num_realval, dtype=fp) |
|
|
| blank1 = np.random.randint(1, num_labels - 1) |
|
|
| x_full = ((x[:, None] * tile_factors[None, :]).view(-1)[:input_length * batch_size * num_labels] |
| .view(batch_size, input_length, num_labels)) |
| rv_full = ((rv_x[:, None] * tile_factors_rv[None, :]).view(-1)[:input_length * batch_size * num_realval] |
| .view(batch_size, input_length, num_realval)) |
| log_probs = torch.log_softmax(x_full, 2) |
| log_probs.requires_grad_() |
| rv_full.requires_grad_() |
| grad_out = torch.randn(batch_size, device='cpu', dtype=fp) |
| loss_native = custom_ctc_f(log_probs, targets, rv_full, targets_realvals, input_lengths, target_lengths, 1, 0, blank1, 'none', True) |
| grad_native = torch.autograd.grad(loss_native, [log_probs, rv_full], grad_out) |
| if torch.any(loss_native < 0) : |
| breakpoint() |
| |
| log_probs.requires_grad_(False) |
| rv_full.requires_grad_(False) |
| log_probs = log_probs.cuda() |
| rv_full = rv_full.cuda() |
| log_probs.requires_grad_() |
| rv_full.requires_grad_() |
| targets = targets.cuda() |
| targets_realvals = targets_realvals.cuda() |
|
|
| loss_gpu = custom_ctc_f_gpu(log_probs, targets, rv_full, targets_realvals, input_lengths, target_lengths, 1, 0, blank1, 'none', True) |
| grad_gpu = torch.autograd.grad(loss_gpu, [log_probs, rv_full], grad_out.cuda()) |
| |
| assert torch.allclose(loss_native, loss_gpu.cpu(), rtol=1e-4, atol=1e-4) |
| print((grad_native[0] - grad_gpu[0].cpu()).abs().sum()) |
| if not torch.allclose(grad_native[0], grad_gpu[0].cpu(), rtol=1e-2, atol=1e-2) : |
| breakpoint() |
| print((grad_native[1] - grad_gpu[1].cpu()).abs().sum()) |
| assert torch.allclose(grad_native[1], grad_gpu[1].cpu(), rtol=1e-2, atol=1e-2) |
|
|
| if __name__ == '__main__' : |
| test_ctc_loss_custom('cpu:0') |
| for _ in range(100) : |
| test_ctc_loss_custom_gpu('cpu:0') |
| test_ctc_loss_custom_gpu('cpu:0', torch.double) |
| |
| print('test passed') |
|
|