| import logging |
|
|
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
|
|
| import kornia |
| from kornia.testing import assert_close |
|
|
# Module-level logger; the regression test below reports per-iteration progress at DEBUG level.
logger = logging.getLogger(name=__name__)
|
|
|
|
| class TestIntegrationSoftArgmax2d: |
| |
| lr = 1e-3 |
| num_iterations = 500 |
|
|
| |
| height = 240 |
| width = 320 |
|
|
| def generate_sample(self, base_target, std_val=1.0): |
| """Generate a random sample around the given point. |
| |
| The standard deviation is in pixel. |
| """ |
| noise = std_val * torch.rand_like(base_target) |
| return base_target + noise |
|
|
| def test_regression_2d(self, device): |
| |
| params = nn.Parameter(torch.rand(1, 1, self.height, self.width).to(device)) |
|
|
| |
| target = torch.zeros(1, 1, 2).to(device) |
| target[..., 0] = self.width / 2 |
| target[..., 1] = self.height / 2 |
|
|
| |
| optimizer = optim.Adam([params], lr=self.lr) |
|
|
| |
| criterion = nn.MSELoss() |
|
|
| |
| soft_argmax2d = kornia.geometry.SpatialSoftArgmax2d(normalized_coordinates=False) |
|
|
| |
| temperature = (self.height * self.width) ** (0.5) |
|
|
| for _ in range(self.num_iterations): |
| x = params |
| sample = self.generate_sample(target).to(device) |
| pred = soft_argmax2d(temperature * x) |
| loss = criterion(pred, sample) |
| logger.debug(f"Loss: {loss.item():.3f} Pred: {pred}") |
|
|
| optimizer.zero_grad() |
| loss.backward() |
| optimizer.step() |
|
|
| assert_close(pred[..., 0], target[..., 0], rtol=1e-2, atol=1e-2) |
| assert_close(pred[..., 1], target[..., 1], rtol=1e-2, atol=1e-2) |
|
|