|
|
| #include <torch/extension.h> |
|
|
| #include <vector> |
|
|
| using torch::Tensor; |
| using torch::IntArrayRef; |
|
|
| // Forward-pass entry point. Only declared here; the definition lives in |
| // another translation unit (presumably the CUDA source -- confirm in the build). |
| // Takes the per-sample lengths as IntArrayRef and yields a pair of tensors. |
| // NOTE(review): judging by the arguments the backward declaration accepts, the |
| // pair is presumably (neg_log_likelihood, log_alpha) -- TODO confirm. |
| std::tuple<Tensor, Tensor> custom_ctc_loss_gpu( |
|     const Tensor& log_probs, |
|     const Tensor& targets, |
|     const Tensor& realval, |
|     const Tensor& targets_realval, |
|     IntArrayRef input_lengths, |
|     IntArrayRef target_lengths, |
|     const double sigma, |
|     int64_t BLANK, |
|     int64_t BLANK_1); |
| // Backward-pass entry point. Only declared here; the definition lives in |
| // another translation unit (presumably the CUDA source -- confirm in the build). |
| // Takes the incoming gradient, the forward-pass inputs, and the |
| // neg_log_likelihood / log_alpha tensors, and yields a pair of tensors. |
| // NOTE(review): the pair is presumably the gradients w.r.t. log_probs and |
| // realval -- TODO confirm against the definition. |
| std::tuple<Tensor, Tensor> custom_ctc_loss_backward_gpu( |
|     const Tensor& grad, |
|     const Tensor& log_probs, |
|     const Tensor& targets, |
|     const Tensor& realval, |
|     const Tensor& targets_realval, |
|     IntArrayRef input_lengths, |
|     IntArrayRef target_lengths, |
|     const Tensor& neg_log_likelihood, |
|     const Tensor& log_alpha, |
|     const double sigma, |
|     int64_t BLANK, |
|     int64_t BLANK_1, |
|     bool zero_infinity); |
|
|
|
|
| // Python-facing wrapper around custom_ctc_loss_gpu. |
| // |
| // Receives the per-sample lengths as tensors, exposes them to the entry point |
| // as IntArrayRef views, and forwards every other argument unchanged. |
| // |
| // input_lengths / target_lengths: any device, any integral dtype. They are |
| //   copied to host int64 storage before being viewed, because |
| //   data_ptr<int64_t>() on a device tensor yields a device address that must |
| //   not be read through a host-side IntArrayRef, and on a non-int64 tensor |
| //   the typed accessor throws. Non-integral dtypes are rejected by TORCH_CHECK. |
| // zero_infinity: accepted for signature parity with the backward driver; it is |
| //   not used by the forward pass. |
| // |
| // Returns whatever custom_ctc_loss_gpu returns. |
| std::tuple<Tensor, Tensor> custom_ctc_loss_gpu_driver( |
|     const Tensor& log_probs, |
|     const Tensor& targets, |
|     const Tensor& realval, |
|     const Tensor& targets_realval, |
|     const Tensor& input_lengths, |
|     const Tensor& target_lengths, |
|     double const sigma, |
|     int64_t BLANK, |
|     int64_t BLANK_1, |
|     bool zero_infinity |
| ) { |
|     (void)zero_infinity;  // intentionally unused here |
|     TORCH_CHECK(c10::isIntegralType(input_lengths.scalar_type(), /*includeBool=*/false), |
|                 "input_lengths must be integral"); |
|     TORCH_CHECK(c10::isIntegralType(target_lengths.scalar_type(), /*includeBool=*/false), |
|                 "target_lengths must be integral"); |
|  |
|     // Named locals own the host copies; il/tl below only borrow their storage, |
|     // so these must stay alive until the call returns. |
|     const Tensor ilc = input_lengths.to(torch::Device(torch::kCPU), torch::kLong).contiguous(); |
|     const Tensor tlc = target_lengths.to(torch::Device(torch::kCPU), torch::kLong).contiguous(); |
|     IntArrayRef il(ilc.data_ptr<int64_t>(), ilc.numel()); |
|     IntArrayRef tl(tlc.data_ptr<int64_t>(), tlc.numel()); |
|  |
|     return custom_ctc_loss_gpu(log_probs, targets, realval, targets_realval, |
|                                il, tl, sigma, BLANK, BLANK_1); |
| } |
|
|
| // Python-facing wrapper around custom_ctc_loss_backward_gpu. |
| // |
| // Receives the per-sample lengths as tensors, exposes them to the entry point |
| // as IntArrayRef views, and forwards every other argument unchanged. |
| // |
| // input_lengths / target_lengths: any device, any integral dtype. They are |
| //   copied to host int64 storage before being viewed, because |
| //   data_ptr<int64_t>() on a device tensor yields a device address that must |
| //   not be read through a host-side IntArrayRef, and on a non-int64 tensor |
| //   the typed accessor throws. Non-integral dtypes are rejected by TORCH_CHECK. |
| // |
| // Returns whatever custom_ctc_loss_backward_gpu returns. |
| std::tuple<Tensor, Tensor> custom_ctc_loss_backward_gpu_driver( |
|     const Tensor& grad, |
|     const Tensor& log_probs, |
|     const Tensor& targets, |
|     const Tensor& realval, |
|     const Tensor& targets_realval, |
|     const Tensor& input_lengths, |
|     const Tensor& target_lengths, |
|     const Tensor& neg_log_likelihood, |
|     const Tensor& log_alpha, |
|     double const sigma, |
|     int64_t BLANK, |
|     int64_t BLANK_1, |
|     bool zero_infinity |
| ) { |
|     TORCH_CHECK(c10::isIntegralType(input_lengths.scalar_type(), /*includeBool=*/false), |
|                 "input_lengths must be integral"); |
|     TORCH_CHECK(c10::isIntegralType(target_lengths.scalar_type(), /*includeBool=*/false), |
|                 "target_lengths must be integral"); |
|  |
|     // Named locals own the host copies; il/tl below only borrow their storage, |
|     // so these must stay alive until the call returns. |
|     const Tensor ilc = input_lengths.to(torch::Device(torch::kCPU), torch::kLong).contiguous(); |
|     const Tensor tlc = target_lengths.to(torch::Device(torch::kCPU), torch::kLong).contiguous(); |
|     IntArrayRef il(ilc.data_ptr<int64_t>(), ilc.numel()); |
|     IntArrayRef tl(tlc.data_ptr<int64_t>(), tlc.numel()); |
|  |
|     return custom_ctc_loss_backward_gpu(grad, log_probs, targets, realval, targets_realval, |
|                                         il, tl, neg_log_likelihood, log_alpha, |
|                                         sigma, BLANK, BLANK_1, zero_infinity); |
| } |
|
|
| // Python module definition: exposes the two driver functions as `forward` and |
| // `backward`. Every parameter is named so callers may pass arguments by |
| // keyword; positional calls keep working unchanged. The number of py::arg |
| // entries must equal each driver's parameter count (10 and 13). |
| PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { |
|     namespace py = pybind11; |
|     m.def("forward", &custom_ctc_loss_gpu_driver, "custom CTC forward (CUDA)", |
|           py::arg("log_probs"), py::arg("targets"), |
|           py::arg("realval"), py::arg("targets_realval"), |
|           py::arg("input_lengths"), py::arg("target_lengths"), |
|           py::arg("sigma"), py::arg("BLANK"), py::arg("BLANK_1"), |
|           py::arg("zero_infinity")); |
|     m.def("backward", &custom_ctc_loss_backward_gpu_driver, "custom CTC backward (CUDA)", |
|           py::arg("grad"), py::arg("log_probs"), py::arg("targets"), |
|           py::arg("realval"), py::arg("targets_realval"), |
|           py::arg("input_lengths"), py::arg("target_lengths"), |
|           py::arg("neg_log_likelihood"), py::arg("log_alpha"), |
|           py::arg("sigma"), py::arg("BLANK"), py::arg("BLANK_1"), |
|           py::arg("zero_infinity")); |
| } |
|
|