| from math import pi |
|
| import pytest |
| import torch |
| from torch.autograd import gradcheck |
|
| import kornia.testing as utils |
| from kornia.feature.mkd import ( |
| COEFFS, |
| EmbedGradients, |
| ExplicitSpacialEncoding, |
| get_grid_dict, |
| get_kron_order, |
| MKDDescriptor, |
| MKDGradients, |
| SimpleKD, |
| spatial_kernel_embedding, |
| VonMisesKernel, |
| Whitening, |
| ) |
| from kornia.testing import assert_close |
|
|
| @pytest.mark.parametrize("ps", [5, 13, 25]) |
| def test_get_grid_dict(ps): |
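    # The grid dict should expose matching cartesian (x, y) and polar (phi, rho)
    # coordinate grids, each of shape (ps, ps).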
| grid_dict = get_grid_dict(ps) |
| param_keys = ['x', 'y', 'phi', 'rho'] |
| assert set(grid_dict.keys()) == set(param_keys) |
| for k in param_keys: |
| assert grid_dict[k].shape == (ps, ps) |
|
|
| @pytest.mark.parametrize("d1,d2", [(1, 1), (1, 2), (2, 1), (5, 6)]) |
| def test_get_kron_order(d1, d2): |
| out = get_kron_order(d1, d2) |
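    # get_kron_order enumerates the (i, j) index pairs of a Kronecker product of
    # d1- and d2-dimensional factors, hence d1 * d2 rows of 2 indices each.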
| assert out.shape == (d1 * d2, 2) |
|
|
| class TestMKDGradients: |
| @pytest.mark.parametrize("ps", [5, 13, 25]) |
| def test_shape(self, ps, device): |
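        # MKDGradients produces a 2-channel map per patch: gradient magnitudes and orientations.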
| inp = torch.ones(1, 1, ps, ps).to(device) |
| gradients = MKDGradients().to(device) |
| out = gradients(inp) |
| assert out.shape == (1, 2, ps, ps) |
|
| @pytest.mark.parametrize("bs", [1, 5, 13]) |
| def test_batch_shape(self, bs, device): |
| inp = torch.ones(bs, 1, 15, 15).to(device) |
| gradients = MKDGradients().to(device) |
| out = gradients(inp) |
| assert out.shape == (bs, 2, 15, 15) |
|
| def test_print(self, device): |
| gradients = MKDGradients().to(device) |
| gradients.__repr__() |
|
| def test_toy(self, device): |
| patch = torch.ones(1, 1, 6, 6).to(device).float() |
| patch[0, 0, :, 3:] = 0 |
| gradients = MKDGradients().to(device) |
| out = gradients(patch) |
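        # The patch is a vertical step edge (ones in columns 0-2, zeros in columns 3-5):
        # the gradient magnitude should respond only on the two columns straddling the
        # edge, and each expected map is constant down every column.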
        expected_mags_1 = torch.tensor([0.0, 0.0, 1.0, 1.0, 0.0, 0.0]).to(device)
        expected_mags = expected_mags_1.unsqueeze(0).repeat(6, 1)
        expected_oris_1 = torch.tensor([-pi, -pi, 0.0, 0.0, -pi, -pi]).to(device)
        expected_oris = expected_oris_1.unsqueeze(0).repeat(6, 1)
| assert_close(out[0, 0, :, :], expected_mags, atol=1e-3, rtol=1e-3) |
| assert_close(out[0, 1, :, :], expected_oris, atol=1e-3, rtol=1e-3) |
|
| def test_gradcheck(self, device): |
| batch_size, channels, height, width = 1, 1, 13, 13 |
| patches = torch.rand(batch_size, channels, height, width).to(device) |
| patches = utils.tensor_to_gradcheck_var(patches) |
|
| def grad_describe(patches): |
| mkd_grads = MKDGradients() |
| mkd_grads.to(device) |
| return mkd_grads(patches) |
|
        assert gradcheck(grad_describe, (patches,), raise_exception=True, nondet_tol=1e-4)
|
| @pytest.mark.jit |
| def test_jit(self, device, dtype): |
| B, C, H, W = 2, 1, 13, 13 |
| patches = torch.rand(B, C, H, W, device=device, dtype=dtype) |
| model = MKDGradients().to(patches.device, patches.dtype).eval() |
| model_jit = torch.jit.script(MKDGradients().to(patches.device, patches.dtype).eval()) |
| assert_close(model(patches), model_jit(patches)) |
|
|
| class TestVonMisesKernel: |
| @pytest.mark.parametrize("ps", [5, 13, 25]) |
| def test_shape(self, ps, device): |
| inp = torch.ones(1, 1, ps, ps).to(device) |
| vm = VonMisesKernel(patch_size=ps, coeffs=[0.38214156, 0.48090413]).to(device) |
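        # Two coefficients yield 2 * 2 - 1 = 3 output channels: a constant term plus
        # one cos/sin harmonic pair (see test_coeffs below for the general rule).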
| out = vm(inp) |
| assert out.shape == (1, 3, ps, ps) |
|
| @pytest.mark.parametrize("bs", [1, 5, 13]) |
| def test_batch_shape(self, bs, device): |
| inp = torch.ones(bs, 1, 15, 15).to(device) |
| vm = VonMisesKernel(patch_size=15, coeffs=[0.38214156, 0.48090413]).to(device) |
| out = vm(inp) |
| assert out.shape == (bs, 3, 15, 15) |
|
| @pytest.mark.parametrize("coeffs", COEFFS.values()) |
| def test_coeffs(self, coeffs, device): |
| inp = torch.ones(1, 1, 15, 15).to(device) |
| vm = VonMisesKernel(patch_size=15, coeffs=coeffs).to(device) |
| out = vm(inp) |
| assert out.shape == (1, 2 * len(coeffs) - 1, 15, 15) |
|
| def test_print(self, device): |
| vm = VonMisesKernel(patch_size=32, coeffs=[0.38214156, 0.48090413]).to(device) |
| vm.__repr__() |
|
| def test_toy(self, device): |
| patch = torch.ones(1, 1, 6, 6).float().to(device) |
| patch[0, 0, :, 3:] = 0 |
| vm = VonMisesKernel(patch_size=6, coeffs=[0.38214156, 0.48090413]).to(device) |
| out = vm(patch) |
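        # Channel 0 should be spatially uniform: it carries the constant term of the
        # Von Mises embedding, and 0.6182 ≈ sqrt(0.38214156), the first coefficient.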
| expected = torch.ones_like(out[0, 0, :, :]).to(device) |
| assert_close(out[0, 0, :, :], expected * 0.6182, atol=1e-3, rtol=1e-3) |
|
        expected = torch.tensor([0.3747, 0.3747, 0.3747, 0.6935, 0.6935, 0.6935]).to(device)
| expected = expected.unsqueeze(0).repeat(6, 1) |
| assert_close(out[0, 1, :, :], expected, atol=1e-3, rtol=1e-3) |
|
        expected = torch.tensor([0.5835, 0.5835, 0.5835, 0.0000, 0.0000, 0.0000]).to(device)
| expected = expected.unsqueeze(0).repeat(6, 1) |
| assert_close(out[0, 2, :, :], expected, atol=1e-3, rtol=1e-3) |
|
| def test_gradcheck(self, device): |
| batch_size, channels, ps = 1, 1, 13 |
| patches = torch.rand(batch_size, channels, ps, ps).to(device) |
| patches = utils.tensor_to_gradcheck_var(patches) |
|
| def vm_describe(patches, ps=13): |
| vmkernel = VonMisesKernel(patch_size=ps, coeffs=[0.38214156, 0.48090413]).double() |
| vmkernel.to(device) |
| return vmkernel(patches.double()) |
|
| assert gradcheck(vm_describe, (patches, ps), raise_exception=True, nondet_tol=1e-4) |
|
| @pytest.mark.jit |
| def test_jit(self, device, dtype): |
| B, C, H, W = 2, 1, 13, 13 |
| patches = torch.rand(B, C, H, W, device=device, dtype=dtype) |
| model = VonMisesKernel(patch_size=13, coeffs=[0.38214156, 0.48090413]).to(patches.device, patches.dtype).eval() |
| model_jit = torch.jit.script( |
| VonMisesKernel(patch_size=13, coeffs=[0.38214156, 0.48090413]).to(patches.device, patches.dtype).eval() |
| ) |
| assert_close(model(patches), model_jit(patches)) |
|
|
| class TestEmbedGradients: |
| @pytest.mark.parametrize("ps,relative", [(5, True), (13, True), (25, True), (5, False), (13, False), (25, False)]) |
| def test_shape(self, ps, relative, device): |
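        # EmbedGradients maps the 2-channel (magnitude, orientation) gradient field
        # to a 7-channel embedding, regardless of `relative`.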
| inp = torch.ones(1, 2, ps, ps).to(device) |
| emb_grads = EmbedGradients(patch_size=ps, relative=relative).to(device) |
| out = emb_grads(inp) |
| assert out.shape == (1, 7, ps, ps) |
|
| @pytest.mark.parametrize("bs", [1, 5, 13]) |
| def test_batch_shape(self, bs, device): |
| inp = torch.ones(bs, 2, 15, 15).to(device) |
| emb_grads = EmbedGradients(patch_size=15, relative=True).to(device) |
| out = emb_grads(inp) |
| assert out.shape == (bs, 7, 15, 15) |
|
| def test_print(self, device): |
| emb_grads = EmbedGradients(patch_size=15, relative=True).to(device) |
| emb_grads.__repr__() |
|
| def test_toy(self, device): |
| grads = torch.ones(1, 2, 6, 6).float().to(device) |
| grads[0, 0, :, 3:] = 0 |
| emb_grads = EmbedGradients(patch_size=6, relative=True).to(device) |
| out = emb_grads(grads) |
| expected = torch.ones_like(out[0, 0, :, :3]).to(device) |
| assert_close(out[0, 0, :, :3], expected * 0.3787, atol=1e-3, rtol=1e-3) |
| assert_close(out[0, 0, :, 3:], expected * 0, atol=1e-3, rtol=1e-3) |
|
    @pytest.mark.xfail(reason="RuntimeError: Jacobian mismatch for output 0 with respect to input 0")
| def test_gradcheck(self, device): |
| batch_size, channels, ps = 1, 2, 13 |
| patches = torch.rand(batch_size, channels, ps, ps).to(device) |
| patches = utils.tensor_to_gradcheck_var(patches) |
|
| def emb_grads_describe(patches, ps=13): |
| emb_grads = EmbedGradients(patch_size=ps, relative=True).double() |
| emb_grads.to(device) |
| return emb_grads(patches.double()) |
|
| assert gradcheck(emb_grads_describe, (patches, ps), raise_exception=True, nondet_tol=1e-4) |
|
| @pytest.mark.jit |
| def test_jit(self, device, dtype): |
| B, C, H, W = 2, 2, 13, 13 |
| patches = torch.rand(B, C, H, W, device=device, dtype=dtype) |
| model = EmbedGradients(patch_size=W, relative=True).to(patches.device, patches.dtype).eval() |
| model_jit = torch.jit.script( |
| EmbedGradients(patch_size=W, relative=True).to(patches.device, patches.dtype).eval() |
| ) |
| assert_close(model(patches), model_jit(patches)) |
|
|
| @pytest.mark.parametrize("kernel_type,d,ps", [('cart', 9, 9), ('polar', 25, 9), ('cart', 9, 16), ('polar', 25, 16)]) |
def test_spatial_kernel_embedding(kernel_type, d, ps):
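    # 'cart' kernels embed into d = 9 dimensions and 'polar' kernels into d = 25,
    # independent of the patch size ps.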
| grids = get_grid_dict(ps) |
| spatial_kernel = spatial_kernel_embedding(kernel_type, grids) |
| assert spatial_kernel.shape == (d, ps, ps) |
|
|
| class TestExplicitSpacialEncoding: |
| @pytest.mark.parametrize( |
| "kernel_type,ps,in_dims", [('cart', 9, 3), ('polar', 9, 3), ('cart', 13, 7), ('polar', 13, 7)] |
| ) |
| def test_shape(self, kernel_type, ps, in_dims, device): |
| inp = torch.ones(1, in_dims, ps, ps).to(device) |
| ese = ExplicitSpacialEncoding(kernel_type=kernel_type, fmap_size=ps, in_dims=in_dims).to(device) |
| out = ese(inp) |
| d_ = 9 if kernel_type == 'cart' else 25 |
| assert out.shape == (1, d_ * in_dims) |
|
| @pytest.mark.parametrize( |
| "kernel_type,bs", [('cart', 1), ('cart', 5), ('cart', 13), ('polar', 1), ('polar', 5), ('polar', 13)] |
| ) |
| def test_batch_shape(self, kernel_type, bs, device): |
| inp = torch.ones(bs, 7, 15, 15).to(device) |
| ese = ExplicitSpacialEncoding(kernel_type=kernel_type, fmap_size=15, in_dims=7).to(device) |
| out = ese(inp) |
| d_ = 9 if kernel_type == 'cart' else 25 |
| assert out.shape == (bs, d_ * 7) |
|
| @pytest.mark.parametrize("kernel_type", ['cart', 'polar']) |
| def test_print(self, kernel_type, device): |
| ese = ExplicitSpacialEncoding(kernel_type=kernel_type, fmap_size=15, in_dims=7).to(device) |
| ese.__repr__() |
|
| def test_toy(self, device): |
| inp = torch.ones(1, 2, 6, 6).to(device).float() |
| inp[0, 0, :, :] = 0 |
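        # With input channel 0 zeroed, the first block of the encoding (d_ values,
        # i.e. 9 for 'cart' and 25 for 'polar') is driven solely by that channel
        # and must vanish.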
| cart_ese = ExplicitSpacialEncoding(kernel_type='cart', fmap_size=6, in_dims=2).to(device) |
| out = cart_ese(inp) |
| out_part = out[:, :9] |
| expected = torch.zeros_like(out_part).to(device) |
| assert_close(out_part, expected, atol=1e-3, rtol=1e-3) |
|
| polar_ese = ExplicitSpacialEncoding(kernel_type='polar', fmap_size=6, in_dims=2).to(device) |
| out = polar_ese(inp) |
| out_part = out[:, :25] |
| expected = torch.zeros_like(out_part).to(device) |
| assert_close(out_part, expected, atol=1e-3, rtol=1e-3) |
|
| @pytest.mark.parametrize("kernel_type", ['cart', 'polar']) |
| def test_gradcheck(self, kernel_type, device): |
| batch_size, channels, ps = 1, 2, 13 |
| patches = torch.rand(batch_size, channels, ps, ps).to(device) |
| patches = utils.tensor_to_gradcheck_var(patches) |
|
| def explicit_spatial_describe(patches, ps=13): |
| ese = ExplicitSpacialEncoding(kernel_type=kernel_type, fmap_size=ps, in_dims=2) |
| ese.to(device) |
| return ese(patches) |
|
| assert gradcheck(explicit_spatial_describe, (patches, ps), raise_exception=True, nondet_tol=1e-4) |
|
| @pytest.mark.jit |
| def test_jit(self, device, dtype): |
| B, C, H, W = 2, 2, 13, 13 |
| patches = torch.rand(B, C, H, W, device=device, dtype=dtype) |
| model = ( |
| ExplicitSpacialEncoding(kernel_type='cart', fmap_size=W, in_dims=2).to(patches.device, patches.dtype).eval() |
| ) |
| model_jit = torch.jit.script( |
| ExplicitSpacialEncoding(kernel_type='cart', fmap_size=W, in_dims=2).to(patches.device, patches.dtype).eval() |
| ) |
| assert_close(model(patches), model_jit(patches)) |
|
|
| class TestWhitening: |
| @pytest.mark.parametrize( |
| "kernel_type,xform,output_dims", |
| [ |
| ('cart', None, 3), |
| ('polar', None, 3), |
| ('cart', 'lw', 7), |
| ('polar', 'lw', 7), |
| ('cart', 'pca', 9), |
| ('polar', 'pca', 9), |
| ], |
| ) |
| def test_shape(self, kernel_type, xform, output_dims, device): |
| in_dims = 63 if kernel_type == 'cart' else 175 |
| wh = Whitening(xform=xform, whitening_model=None, in_dims=in_dims, output_dims=output_dims).to(device) |
| inp = torch.ones(1, in_dims).to(device) |
| out = wh(inp) |
| assert out.shape == (1, output_dims) |
|
| @pytest.mark.parametrize("bs", [1, 3, 7]) |
| def test_batch_shape(self, bs, device): |
| wh = Whitening(xform='lw', whitening_model=None, in_dims=175, output_dims=128).to(device) |
| inp = torch.ones(bs, 175).to(device) |
| out = wh(inp) |
| assert out.shape == (bs, 128) |
|
| def test_print(self, device): |
| wh = Whitening(xform='lw', whitening_model=None, in_dims=175, output_dims=128).to(device) |
| wh.__repr__() |
|
| def test_toy(self, device): |
| wh = Whitening(xform='lw', whitening_model=None, in_dims=175, output_dims=175).to(device) |
| inp = torch.ones(1, 175).to(device).float() |
| out = wh(inp) |
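        # The expected constant equals 1 / sqrt(175) ≈ 0.0756, consistent with the
        # all-ones input being plainly L2-normalized when no whitening model is fitted.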
| expected = torch.ones_like(inp).to(device) * 0.0756 |
| assert_close(out, expected, atol=1e-3, rtol=1e-3) |
|
| def test_gradcheck(self, device): |
| batch_size, in_dims = 1, 175 |
| patches = torch.rand(batch_size, in_dims).to(device) |
| patches = utils.tensor_to_gradcheck_var(patches) |
|
| def whitening_describe(patches, in_dims=175): |
| wh = Whitening(xform='lw', whitening_model=None, in_dims=in_dims).double() |
| wh.to(device) |
| return wh(patches.double()) |
|
| assert gradcheck(whitening_describe, (patches, in_dims), raise_exception=True, nondet_tol=1e-4) |
|
| @pytest.mark.jit |
| def test_jit(self, device, dtype): |
| batch_size, in_dims = 1, 175 |
| patches = torch.rand(batch_size, in_dims).to(device) |
| model = Whitening(xform='lw', whitening_model=None, in_dims=in_dims).to(patches.device, patches.dtype).eval() |
| model_jit = torch.jit.script( |
| Whitening(xform='lw', whitening_model=None, in_dims=in_dims).to(patches.device, patches.dtype).eval() |
| ) |
| assert_close(model(patches), model_jit(patches)) |
|
|
| class TestMKDDescriptor: |
| dims = {'cart': 63, 'polar': 175, 'concat': 238} |
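    # 7 gradient-embedding channels x 9 ('cart') or 25 ('polar') spatial dims;
    # 'concat' stacks both: 63 + 175 = 238.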
|
| @pytest.mark.parametrize( |
| "ps,kernel_type", [(9, 'concat'), (9, 'cart'), (9, 'polar'), (32, 'concat'), (32, 'cart'), (32, 'polar')] |
| ) |
| def test_shape(self, ps, kernel_type, device): |
| mkd = MKDDescriptor(patch_size=ps, kernel_type=kernel_type, whitening=None).to(device) |
| inp = torch.ones(1, 1, ps, ps).to(device) |
| out = mkd(inp) |
| assert out.shape == (1, self.dims[kernel_type]) |
|
| @pytest.mark.parametrize( |
| "ps,kernel_type,whitening", |
| [ |
| (9, 'concat', 'lw'), |
| (9, 'cart', 'lw'), |
| (9, 'polar', 'lw'), |
| (9, 'concat', 'pcawt'), |
| (9, 'cart', 'pcawt'), |
| (9, 'polar', 'pcawt'), |
| ], |
| ) |
| def test_whitened_shape(self, ps, kernel_type, whitening, device): |
| mkd = MKDDescriptor(patch_size=ps, kernel_type=kernel_type, whitening=whitening).to(device) |
| inp = torch.ones(1, 1, ps, ps).to(device) |
| out = mkd(inp) |
| output_dims = min(self.dims[kernel_type], 128) |
| assert out.shape == (1, output_dims) |
|
| @pytest.mark.parametrize("bs", [1, 3, 7]) |
| def test_batch_shape(self, bs, device): |
| mkd = MKDDescriptor(patch_size=19, kernel_type='concat', whitening=None).to(device) |
| inp = torch.ones(bs, 1, 19, 19).to(device) |
| out = mkd(inp) |
| assert out.shape == (bs, 238) |
|
| def test_print(self, device): |
| mkd = MKDDescriptor(patch_size=32, whitening='lw', training_set='liberty', output_dims=128).to(device) |
| mkd.__repr__() |
|
| def test_toy(self, device): |
| inp = torch.ones(1, 1, 6, 6).to(device).float() |
| inp[0, 0, :, :] = 0 |
| mkd = MKDDescriptor(patch_size=6, kernel_type='concat', whitening=None).to(device) |
| out = mkd(inp) |
| out_part = out[0, -28:] |
| expected = torch.zeros_like(out_part).to(device) |
| assert_close(out_part, expected, atol=1e-3, rtol=1e-3) |
|
| @pytest.mark.skip("Just because") |
| @pytest.mark.parametrize("whitening", [None, 'lw', 'pca']) |
| def test_gradcheck(self, whitening, device): |
| batch_size, channels, ps = 1, 1, 19 |
| patches = torch.rand(batch_size, channels, ps, ps).to(device) |
| patches = utils.tensor_to_gradcheck_var(patches) |
|
| def mkd_describe(patches, patch_size=19): |
| mkd = MKDDescriptor(patch_size=patch_size, kernel_type='concat', whitening=whitening).double() |
| mkd.to(device) |
| return mkd(patches.double()) |
|
| assert gradcheck(mkd_describe, (patches, ps), raise_exception=True, nondet_tol=1e-4) |
|
| @pytest.mark.skip("neither dict, nor nn.ModuleDict works") |
| @pytest.mark.jit |
| def test_jit(self, device, dtype): |
| batch_size, channels, ps = 1, 1, 19 |
| patches = torch.rand(batch_size, channels, ps, ps).to(device) |
| kt = 'concat' |
| wt = 'lw' |
| model = MKDDescriptor(patch_size=ps, kernel_type=kt, whitening=wt).to(patches.device, patches.dtype).eval() |
| model_jit = torch.jit.script( |
| MKDDescriptor(patch_size=ps, kernel_type=kt, whitening=wt).to(patches.device, patches.dtype).eval() |
| ) |
| assert_close(model(patches), model_jit(patches)) |
|
|
| class TestSimpleKD: |
| dims = {'cart': 63, 'polar': 175} |
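    # Raw descriptor sizes before whitening; SimpleKD reduces them to at most 128 dims.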
|
| @pytest.mark.parametrize("ps,kernel_type", [(9, 'cart'), (9, 'polar'), (32, 'cart'), (32, 'polar')]) |
| def test_shape(self, ps, kernel_type, device): |
| skd = SimpleKD(patch_size=ps, kernel_type=kernel_type).to(device) |
| inp = torch.ones(1, 1, ps, ps).to(device) |
| out = skd(inp) |
| assert out.shape == (1, min(128, self.dims[kernel_type])) |
|
| @pytest.mark.parametrize("bs", [1, 3, 7]) |
| def test_batch_shape(self, bs, device): |
| skd = SimpleKD(patch_size=19, kernel_type='polar').to(device) |
| inp = torch.ones(bs, 1, 19, 19).to(device) |
| out = skd(inp) |
| assert out.shape == (bs, 128) |
|
| def test_print(self, device): |
| skd = SimpleKD(patch_size=19, kernel_type='polar').to(device) |
| skd.__repr__() |
|
| def test_gradcheck(self, device): |
| batch_size, channels, ps = 1, 1, 19 |
| patches = torch.rand(batch_size, channels, ps, ps).to(device) |
| patches = utils.tensor_to_gradcheck_var(patches) |
|
| def skd_describe(patches, patch_size=19): |
            skd = SimpleKD(patch_size=patch_size, kernel_type='polar', whitening='lw').double()
| skd.to(device) |
| return skd(patches.double()) |
|
| assert gradcheck(skd_describe, (patches, ps), raise_exception=True, nondet_tol=1e-4) |
|
| @pytest.mark.jit |
| def test_jit(self, device, dtype): |
| batch_size, channels, ps = 1, 1, 19 |
| patches = torch.rand(batch_size, channels, ps, ps).to(device) |
| model = SimpleKD(patch_size=ps, kernel_type='polar', whitening='lw').to(patches.device, patches.dtype).eval() |
| model_jit = torch.jit.script( |
| SimpleKD(patch_size=ps, kernel_type='polar', whitening='lw').to(patches.device, patches.dtype).eval() |
| ) |
| assert_close(model(patches), model_jit(patches)) |
|