| import numpy as np |
| import pytest |
| import torch |
|
|
| from mmseg.models.losses import Accuracy, reduce_loss, weight_reduce_loss |
|
|
|
|
def test_utils():
    """Check ``reduce_loss`` and ``weight_reduce_loss`` for each reduction.

    The weight mask is 1 only on the top-left 2x2 patch of every channel, so
    weighted results differ from the unweighted ones.
    """
    loss = torch.rand(1, 3, 4, 4)
    weight = torch.zeros(1, 3, 4, 4)
    weight[:, :, :2, :2] = 1

    # reduce_loss with 'none' must return the very same tensor object.
    reduced = reduce_loss(loss, 'none')
    assert reduced is loss

    # reduce_loss with 'mean' / 'sum' matches the plain tensor reductions.
    reduced = reduce_loss(loss, 'mean')
    np.testing.assert_almost_equal(reduced.numpy(), loss.mean().numpy())

    reduced = reduce_loss(loss, 'sum')
    np.testing.assert_almost_equal(reduced.numpy(), loss.sum().numpy())

    # Without a weight, weight_reduce_loss with 'none' is the identity.
    reduced = weight_reduce_loss(loss, weight=None, reduction='none')
    assert reduced is loss

    # Weighted 'mean' is expected to equal the mean over ALL elements of
    # loss * weight (masked-out zeros included in the denominator).
    reduced = weight_reduce_loss(loss, weight=weight, reduction='mean')
    target = (loss * weight).mean()
    np.testing.assert_almost_equal(reduced.numpy(), target.numpy())

    reduced = weight_reduce_loss(loss, weight=weight, reduction='sum')
    target = (loss * weight).sum()
    np.testing.assert_almost_equal(reduced.numpy(), target.numpy())

    # Mis-shaped weights must be rejected. They are built outside
    # pytest.raises so that only the call under test is guarded.
    # Shape (4, 4): fewer dimensions than the (1, 3, 4, 4) loss.
    weight_wrong = weight[0, 0, ...]
    with pytest.raises(AssertionError):
        weight_reduce_loss(loss, weight=weight_wrong, reduction='mean')

    # Shape (1, 2, 4, 4): channel size 2 does not match the loss's 3.
    weight_wrong = weight[:, 0:2, ...]
    with pytest.raises(AssertionError):
        weight_reduce_loss(loss, weight=weight_wrong, reduction='mean')
|
|
|
|
def test_ce_loss():
    """Check ``CrossEntropyLoss`` in softmax and sigmoid modes."""
    from mmseg.models import build_loss

    # use_mask and use_sigmoid together is an invalid combination.
    # Only build_loss is guarded, not the plain dict construction.
    loss_cfg = dict(
        type='CrossEntropyLoss',
        use_mask=True,
        use_sigmoid=True,
        loss_weight=1.0)
    with pytest.raises(AssertionError):
        build_loss(loss_cfg)

    # Softmax CE with class weights. Logits [100, -100] with target class 1
    # give an unweighted CE of 200. The target class weight is 0.2, so the
    # expected loss is 0.2 * 200 = 40.
    loss_cls_cfg = dict(
        type='CrossEntropyLoss',
        use_sigmoid=False,
        class_weight=[0.8, 0.2],
        loss_weight=1.0)
    loss_cls = build_loss(loss_cls_cfg)
    fake_pred = torch.Tensor([[100, -100]])
    fake_label = torch.Tensor([1]).long()
    assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(40.))

    # Same input, softmax CE without class weights: 200.
    loss_cls_cfg = dict(
        type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
    loss_cls = build_loss(loss_cls_cfg)
    assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(200.))

    # Sigmoid (binary) CE on the same input: expected 100. This is presumably
    # the per-class BCE of 100 averaged over the two classes; confirm
    # against the one-hot expansion used by the implementation.
    loss_cls_cfg = dict(
        type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)
    loss_cls = build_loss(loss_cls_cfg)
    assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(100.))

    # Sigmoid CE on a dense (N, C, H, W) prediction with 21 classes,
    # a constant logit of 0.5 everywhere and every pixel labelled 1.
    fake_pred = torch.full(size=(2, 21, 8, 8), fill_value=0.5)
    fake_label = torch.ones(2, 8, 8).long()
    assert torch.allclose(
        loss_cls(fake_pred, fake_label), torch.tensor(0.9503), atol=1e-4)

    # Marking one pixel per image as 255 and passing ignore_index=255
    # changes the expected value.
    fake_label[:, 0, 0] = 255
    assert torch.allclose(
        loss_cls(fake_pred, fake_label, ignore_index=255),
        torch.tensor(0.9354),
        atol=1e-4)
|
|
| |
|
|
|
|
def test_accuracy():
    """Check ``Accuracy`` top-k scoring, thresholding and input validation."""
    # Empty input: no samples is expected to report an accuracy of 0.
    pred = torch.empty(0, 4)
    label = torch.empty(0)
    accuracy = Accuracy(topk=1)
    acc = accuracy(pred, label)
    assert acc.item() == 0

    # Five samples over four classes. The arg-max of every row equals its
    # entry in true_label, so top-1 accuracy is 100.
    pred = torch.Tensor([[0.2, 0.3, 0.6, 0.5], [0.1, 0.1, 0.2, 0.6],
                         [0.9, 0.0, 0.0, 0.1], [0.4, 0.7, 0.1, 0.1],
                         [0.0, 0.0, 0.99, 0]])
    true_label = torch.Tensor([2, 3, 0, 1, 2]).long()
    accuracy = Accuracy(topk=1)
    acc = accuracy(pred, true_label)
    assert acc.item() == 100

    # With thresh=0.8 only rows 2 (0.9) and 4 (0.99) have a top score above
    # the threshold, so 2 / 5 = 40.
    accuracy = Accuracy(topk=1, thresh=0.8)
    acc = accuracy(pred, true_label)
    assert acc.item() == 40

    # Top-2: each label lies among its row's two highest scores.
    accuracy = Accuracy(topk=2)
    label = torch.Tensor([3, 2, 0, 0, 2]).long()
    acc = accuracy(pred, label)
    assert acc.item() == 100

    # A tuple topk yields one result per k. Check the count so the loop
    # below cannot pass vacuously.
    accuracy = Accuracy(topk=(1, 2))
    acc = accuracy(pred, true_label)
    assert len(acc) == 2
    for a in acc:
        assert a.item() == 100

    # topk larger than the number of classes (4) must be rejected.
    # Construction stays inside the guard in case validation happens there.
    with pytest.raises(AssertionError):
        accuracy = Accuracy(topk=5)
        accuracy(pred, true_label)

    # A non-int / non-tuple topk must be rejected (same note as above).
    with pytest.raises(AssertionError):
        accuracy = Accuracy(topk='wrong type')
        accuracy(pred, true_label)

    # Label count (6) differs from prediction count (5).
    label = torch.Tensor([2, 3, 0, 1, 2, 0]).long()
    accuracy = Accuracy()
    with pytest.raises(AssertionError):
        accuracy(pred, label)

    # pred of shape (5, 4, 1) has the wrong rank for a 1-D label.
    accuracy = Accuracy()
    with pytest.raises(AssertionError):
        accuracy(pred[:, :, None], true_label)
|
|
|
|
def test_lovasz_loss():
    """Smoke-test ``LovaszLoss`` configs in multi-class and binary modes."""
    from mmseg.models import build_loss

    # 'Binary' (capitalised) is not an accepted loss_type. Only lowercase
    # 'binary' is used successfully below.
    loss_cfg = dict(
        type='LovaszLoss',
        loss_type='Binary',
        reduction='none',
        loss_weight=1.0)
    with pytest.raises(AssertionError):
        build_loss(loss_cfg)

    # multi_class with the default reduction / per_image must be rejected.
    # NOTE(review): presumably because per_image defaults to False, which
    # requires reduction='none'. The valid configs below follow that
    # pattern; confirm against LovaszLoss.__init__.
    loss_cfg = dict(type='LovaszLoss', loss_type='multi_class')
    with pytest.raises(AssertionError):
        build_loss(loss_cfg)

    # Multi-class, whole batch, reduction='none'.
    loss_cfg = dict(type='LovaszLoss', reduction='none', loss_weight=1.0)
    lovasz_loss = build_loss(loss_cfg)
    logits = torch.rand(1, 3, 4, 4)
    labels = torch.randint(0, 2, (1, 4, 4))
    loss = lovasz_loss(logits, labels)
    assert torch.isfinite(loss).all()

    # Multi-class, per image, with class weights and no ignored label.
    loss_cfg = dict(
        type='LovaszLoss',
        per_image=True,
        reduction='mean',
        class_weight=[1.0, 2.0, 3.0],
        loss_weight=1.0)
    lovasz_loss = build_loss(loss_cfg)
    logits = torch.rand(1, 3, 4, 4)
    labels = torch.randint(0, 2, (1, 4, 4))
    loss = lovasz_loss(logits, labels, ignore_index=None)
    assert torch.isfinite(loss).all()

    # Binary mode needs labels in {0, 1}. The former
    # `torch.rand(...).long()` truncated every value in [0, 1) to 0, so
    # the foreground class was never exercised.
    loss_cfg = dict(
        type='LovaszLoss',
        loss_type='binary',
        reduction='none',
        loss_weight=1.0)
    lovasz_loss = build_loss(loss_cfg)
    logits = torch.rand(2, 4, 4)
    labels = torch.randint(0, 2, (2, 4, 4))
    loss = lovasz_loss(logits, labels)
    assert torch.isfinite(loss).all()

    # Binary mode, per image, mean reduction, no ignored label.
    loss_cfg = dict(
        type='LovaszLoss',
        loss_type='binary',
        per_image=True,
        reduction='mean',
        loss_weight=1.0)
    lovasz_loss = build_loss(loss_cfg)
    logits = torch.rand(2, 4, 4)
    labels = torch.randint(0, 2, (2, 4, 4))
    loss = lovasz_loss(logits, labels, ignore_index=None)
    assert torch.isfinite(loss).all()
|
|