| import numpy as np |
| import pytest |
| import torch |
| from PIL import Image |
|
|
| from pytorch_fid import fid_score, inception |
|
|
|
|
@pytest.fixture
def device():
    """Provide the torch device the tests in this module run on.

    This is always the CPU, so the suite does not require a GPU.
    """
    cpu_device = torch.device("cpu")
    return cpu_device
|
|
|
|
def test_calculate_fid_given_statistics(mocker, tmp_path, device):
    """FID between two mocked statistics sets reduces to the squared mean gap.

    ``fid_score.compute_statistics_of_path`` is patched so that the directory
    named "1" yields ``(zeros, I)`` and the one named "2" yields ``(ones, I)``.
    With identical covariances the covariance term of the Frechet distance
    vanishes, so the FID must equal ``||m1 - m2||^2``.
    """
    dim = 2048
    m1, m2 = np.zeros((dim,)), np.ones((dim,))
    sigma = np.eye(dim)

    def dummy_statistics(path, model, batch_size, dims, device, num_workers):
        # Dispatch on the directory name created below; any other path means
        # the test is wired up incorrectly.
        if path.endswith("1"):
            return m1, sigma
        elif path.endswith("2"):
            return m2, sigma
        else:
            raise ValueError

    mocker.patch(
        "pytorch_fid.fid_score.compute_statistics_of_path", side_effect=dummy_statistics
    )

    # Real (empty) directories are created. Presumably
    # calculate_fid_given_paths checks that its inputs exist; TODO confirm.
    paths = []
    for name in ["1", "2"]:
        path = tmp_path / name
        path.mkdir()
        paths.append(str(path))

    fid_value = fid_score.calculate_fid_given_paths(
        paths, batch_size=dim, device=device, dims=dim, num_workers=0
    )

    # Compare with a tolerance rather than exact float equality: the FID
    # covariance term involves a matrix square root, whose numerical result
    # may differ from the ideal value by rounding error.
    assert fid_value == pytest.approx(np.sum((m1 - m2) ** 2))
|
|
|
|
def test_compute_statistics_of_path(mocker, tmp_path, device):
    """Mean/covariance computed from a directory of images with a mocked model.

    The mocked model returns the per-channel spatial mean of each input, so the
    "activations" have 3 dimensions. Three uniform images (black, mid-grey,
    white) give a per-channel mean of ~0.5. Because all three channels are
    identical, every covariance entry is ~0.25.
    """
    model = mocker.MagicMock(inception.InceptionV3)()
    # Return a one-element list of spatially pooled activations, keeping the
    # trailing singleton dims (assumes inp is (N, C, H, W) - TODO confirm).
    model.side_effect = lambda inp: [inp.mean(dim=(2, 3), keepdim=True)]

    size = (4, 4, 3)
    arrays = [np.zeros(size), np.ones(size) * 0.5, np.ones(size)]
    # The uint8 cast truncates 127.5 to 127, so the statistics are only
    # approximately 0.5 / 0.25.
    images = [(arr * 255).astype(np.uint8) for arr in arrays]

    paths = []
    for idx, image in enumerate(images):
        paths.append(str(tmp_path / "{}.png".format(idx)))
        # A uint8 (H, W, 3) array is inferred as RGB; passing mode= explicitly
        # is redundant and deprecated in recent Pillow releases.
        Image.fromarray(image).save(paths[-1])

    stats = fid_score.compute_statistics_of_path(
        str(tmp_path),
        model,
        batch_size=len(images),
        dims=3,
        device=device,
        num_workers=0,
    )

    # Both checks use the same absolute tolerance to absorb the uint8
    # quantisation above (true covariance is (1 - a + a**2) / 3 with
    # a = 127/255, i.e. ~0.2500013).
    assert np.allclose(stats[0], np.ones((3,)) * 0.5, atol=1e-3)
    assert np.allclose(stats[1], np.ones((3, 3)) * 0.25, atol=1e-3)
|
|
|
|
def test_compute_statistics_of_path_from_file(mocker, tmp_path, device):
    """Statistics saved to an ``.npz`` file are returned unchanged.

    A precomputed ``mu``/``sigma`` pair is written to disk, and its path is
    handed to ``compute_statistics_of_path`` together with a mocked model
    (presumably unused for a precomputed file; TODO confirm).
    """
    mock_model = mocker.MagicMock(inception.InceptionV3)()

    expected_mu = np.random.randn(5)
    expected_sigma = np.random.randn(5, 5)

    npz_file = tmp_path / "stats.npz"
    with npz_file.open("wb") as handle:
        np.savez(handle, mu=expected_mu, sigma=expected_sigma)

    result = fid_score.compute_statistics_of_path(
        str(npz_file),
        mock_model,
        batch_size=1,
        dims=5,
        device=device,
        num_workers=0,
    )
    loaded_mu, loaded_sigma = result[0], result[1]

    assert np.allclose(loaded_mu, expected_mu)
    assert np.allclose(loaded_sigma, expected_sigma)
|
|
|
|
def test_image_types(tmp_path):
    """ImagePathDataset reads back a white image saved in each supported format.

    One file is written per entry of ``fid_score.IMAGE_EXTENSIONS``. Every
    item the dataset yields must match the original pixel array.
    """
    in_arr = np.ones((24, 24, 3), dtype=np.uint8) * 255
    # A uint8 (H, W, 3) array is inferred as RGB; passing mode= explicitly is
    # redundant and deprecated in recent Pillow releases.
    in_image = Image.fromarray(in_arr)

    paths = []
    for ext in fid_score.IMAGE_EXTENSIONS:
        paths.append(str(tmp_path / "img.{}".format(ext)))
        in_image.save(paths[-1])

    dataset = fid_score.ImagePathDataset(paths)
    # Iterate exactly as the original test did (sequence protocol), but keep
    # the results so the number of checked images can be verified.
    images = list(dataset)

    # Guard against the comparison loop passing vacuously.
    assert paths, "IMAGE_EXTENSIONS is empty; no image format was tested"
    assert len(images) == len(paths)

    for path, img in zip(paths, images):
        assert np.allclose(np.array(img), in_arr), path
|
|