import torch

from wireseghr.model import WireSegHR


def test_wireseghr_forward_shapes():
    """Check coarse/fine output shapes of WireSegHR with the ``mit_b2`` backbone.

    For a single 64x64 three-channel input, ``forward_coarse`` must return
    logits of shape (1, 2, 16, 16) and ``cond`` of shape (1, 1, 16, 16),
    i.e. a quarter of the input's spatial size. ``forward_fine`` on the same
    input must return logits with exactly the coarse logits' shape.
    """
    model = WireSegHR(backbone="mit_b2", in_channels=3, pretrained=False)
    # This is a pure shape check: evaluate in inference mode so the outcome
    # does not depend on training-only layer behaviour (e.g. any normalisation
    # that uses batch statistics with a batch of one), and skip autograd.
    model.eval()

    x = torch.randn(1, 3, 64, 64)
    with torch.no_grad():
        logits_coarse, cond = model.forward_coarse(x)
        logits_fine = model.forward_fine(x)

    assert tuple(logits_coarse.shape) == (1, 2, 16, 16), logits_coarse.shape
    assert tuple(cond.shape) == (1, 1, 16, 16), cond.shape
    assert logits_fine.shape == logits_coarse.shape, (
        logits_fine.shape,
        logits_coarse.shape,
    )


def test_wireseghr_forward_shapes_resnet50():
    """Check coarse/fine output shapes of WireSegHR with the ``resnet50`` backbone.

    Mirrors the ``mit_b2`` shape test. For a single 64x64 three-channel input,
    ``forward_coarse`` must return logits of shape (1, 2, 16, 16) and ``cond``
    of shape (1, 1, 16, 16). ``forward_fine`` on the same input must return
    logits with exactly the coarse logits' shape.
    """
    model = WireSegHR(backbone="resnet50", in_channels=3, pretrained=False)
    # This is a pure shape check: evaluate in inference mode so the outcome
    # does not depend on training-only layer behaviour (e.g. any normalisation
    # that uses batch statistics with a batch of one), and skip autograd.
    model.eval()

    x = torch.randn(1, 3, 64, 64)
    with torch.no_grad():
        logits_coarse, cond = model.forward_coarse(x)
        logits_fine = model.forward_fine(x)

    assert tuple(logits_coarse.shape) == (1, 2, 16, 16), logits_coarse.shape
    assert tuple(cond.shape) == (1, 1, 16, 16), cond.shape
    assert logits_fine.shape == logits_coarse.shape, (
        logits_fine.shape,
        logits_coarse.shape,
    )