| import torch |
| from torch.autograd import gradcheck |
|
|
| import kornia |
| import kornia.testing as utils |
|
|
|
|
class TestBatchedForward:
    """Tests for ``kornia.utils.memory.batched_forward``."""

    def _assert_matches_direct_call(self, num_patches, device, batch_size=32):
        """Check that a batched forward pass equals a direct module call.

        Both the input patches and the descriptor module are created on
        ``device`` so that the direct call and the batched call operate on
        tensors living on the same device.

        Args:
            num_patches: number of random 1x32x32 patches to describe.
            device: device fixture the test is parametrized with.
            batch_size: batch size handed to ``batched_forward``.
        """
        patches = torch.rand(num_patches, 1, 32, 32, device=device)
        sift = kornia.feature.SIFTDescriptor(32).to(device)
        desc_batched = kornia.utils.memory.batched_forward(
            sift, patches, device, batch_size
        )
        desc = sift(patches)
        assert torch.allclose(desc, desc_batched)

    def test_runbatch(self, device):
        """34 patches with a batch size of 32: input larger than one batch."""
        self._assert_matches_direct_call(34, device)

    def test_runone(self, device):
        """16 patches with a batch size of 32: input smaller than one batch."""
        self._assert_matches_direct_call(16, device)

    def test_gradcheck(self, device):
        """Run ``gradcheck`` through ``batched_forward`` wrapping ``BlobHessian``."""
        batch_size, channels, height, width = 3, 2, 5, 4
        img = torch.rand(batch_size, channels, height, width, device=device)
        # Convert the input with the testing helper before handing it to gradcheck.
        img = utils.tensor_to_gradcheck_var(img)
        assert gradcheck(
            kornia.utils.memory.batched_forward,
            (kornia.feature.BlobHessian(), img, device, 2),
            raise_exception=True,
            nondet_tol=1e-4,
        )
|
|