| import unittest |
|
|
| from diffusers.pipelines.pipeline_utils import is_safetensors_compatible |
|
|
|
|
class IsSafetensorsCompatibleTests(unittest.TestCase):
    """Exercise `is_safetensors_compatible` with hand-written file listings.

    Every test passes a list of relative weight-file paths (and, in the
    `*_variant` tests, `variant="fp16"`) and asserts on the truthiness of
    the returned value.
    """

    def test_all_is_compatible(self):
        # Each of the four component folders lists a `.bin` file and a
        # `.safetensors` file.
        repo_files = [
            "safety_checker/pytorch_model.bin",
            "safety_checker/model.safetensors",
            "vae/diffusion_pytorch_model.bin",
            "vae/diffusion_pytorch_model.safetensors",
            "text_encoder/pytorch_model.bin",
            "text_encoder/model.safetensors",
            "unet/diffusion_pytorch_model.bin",
            "unet/diffusion_pytorch_model.safetensors",
        ]
        compatible = is_safetensors_compatible(repo_files)
        self.assertTrue(compatible)

    def test_diffusers_model_is_compatible(self):
        # A lone `unet` folder with both weight formats listed.
        repo_files = [
            "unet/diffusion_pytorch_model.bin",
            "unet/diffusion_pytorch_model.safetensors",
        ]
        compatible = is_safetensors_compatible(repo_files)
        self.assertTrue(compatible)

    def test_diffusers_model_is_not_compatible(self):
        # Same listing as the "all compatible" case, except that `unet`
        # has only its `.bin` file; the result is expected to be falsy.
        repo_files = [
            "safety_checker/pytorch_model.bin",
            "safety_checker/model.safetensors",
            "vae/diffusion_pytorch_model.bin",
            "vae/diffusion_pytorch_model.safetensors",
            "text_encoder/pytorch_model.bin",
            "text_encoder/model.safetensors",
            "unet/diffusion_pytorch_model.bin",
        ]
        compatible = is_safetensors_compatible(repo_files)
        self.assertFalse(compatible)

    def test_transformer_model_is_compatible(self):
        # A lone `text_encoder` folder with both weight formats listed.
        repo_files = [
            "text_encoder/pytorch_model.bin",
            "text_encoder/model.safetensors",
        ]
        compatible = is_safetensors_compatible(repo_files)
        self.assertTrue(compatible)

    def test_transformer_model_is_not_compatible(self):
        # Same listing as the "all compatible" case, except that
        # `text_encoder` has only its `.bin` file; the result is expected
        # to be falsy.
        repo_files = [
            "safety_checker/pytorch_model.bin",
            "safety_checker/model.safetensors",
            "vae/diffusion_pytorch_model.bin",
            "vae/diffusion_pytorch_model.safetensors",
            "text_encoder/pytorch_model.bin",
            "unet/diffusion_pytorch_model.bin",
            "unet/diffusion_pytorch_model.safetensors",
        ]
        compatible = is_safetensors_compatible(repo_files)
        self.assertFalse(compatible)

    def test_all_is_compatible_variant(self):
        # Every component lists `.fp16.bin` and `.fp16.safetensors` files,
        # and the matching variant is requested.
        repo_files = [
            "safety_checker/pytorch_model.fp16.bin",
            "safety_checker/model.fp16.safetensors",
            "vae/diffusion_pytorch_model.fp16.bin",
            "vae/diffusion_pytorch_model.fp16.safetensors",
            "text_encoder/pytorch_model.fp16.bin",
            "text_encoder/model.fp16.safetensors",
            "unet/diffusion_pytorch_model.fp16.bin",
            "unet/diffusion_pytorch_model.fp16.safetensors",
        ]
        compatible = is_safetensors_compatible(repo_files, variant="fp16")
        self.assertTrue(compatible)

    def test_diffusers_model_is_compatible_variant(self):
        # A lone `unet` folder with both fp16 weight formats listed.
        repo_files = [
            "unet/diffusion_pytorch_model.fp16.bin",
            "unet/diffusion_pytorch_model.fp16.safetensors",
        ]
        compatible = is_safetensors_compatible(repo_files, variant="fp16")
        self.assertTrue(compatible)

    def test_diffusers_model_is_compatible_variant_partial(self):
        # The "fp16" variant is requested, but the listed file names carry
        # no variant infix; the result is still expected to be truthy.
        repo_files = [
            "unet/diffusion_pytorch_model.bin",
            "unet/diffusion_pytorch_model.safetensors",
        ]
        compatible = is_safetensors_compatible(repo_files, variant="fp16")
        self.assertTrue(compatible)

    def test_diffusers_model_is_not_compatible_variant(self):
        # fp16 listing in which `unet` has only its `.fp16.bin` file; the
        # result is expected to be falsy.
        repo_files = [
            "safety_checker/pytorch_model.fp16.bin",
            "safety_checker/model.fp16.safetensors",
            "vae/diffusion_pytorch_model.fp16.bin",
            "vae/diffusion_pytorch_model.fp16.safetensors",
            "text_encoder/pytorch_model.fp16.bin",
            "text_encoder/model.fp16.safetensors",
            "unet/diffusion_pytorch_model.fp16.bin",
        ]
        compatible = is_safetensors_compatible(repo_files, variant="fp16")
        self.assertFalse(compatible)

    def test_transformer_model_is_compatible_variant(self):
        # A lone `text_encoder` folder with both fp16 weight formats listed.
        repo_files = [
            "text_encoder/pytorch_model.fp16.bin",
            "text_encoder/model.fp16.safetensors",
        ]
        compatible = is_safetensors_compatible(repo_files, variant="fp16")
        self.assertTrue(compatible)

    def test_transformer_model_is_compatible_variant_partial(self):
        # The "fp16" variant is requested, but the listed file names carry
        # no variant infix; the result is still expected to be truthy.
        repo_files = [
            "text_encoder/pytorch_model.bin",
            "text_encoder/model.safetensors",
        ]
        compatible = is_safetensors_compatible(repo_files, variant="fp16")
        self.assertTrue(compatible)

    def test_transformer_model_is_not_compatible_variant(self):
        # fp16 listing in which `text_encoder` has only its `.fp16.bin`
        # file; the result is expected to be falsy.
        repo_files = [
            "safety_checker/pytorch_model.fp16.bin",
            "safety_checker/model.fp16.safetensors",
            "vae/diffusion_pytorch_model.fp16.bin",
            "vae/diffusion_pytorch_model.fp16.safetensors",
            "text_encoder/pytorch_model.fp16.bin",
            "unet/diffusion_pytorch_model.fp16.bin",
            "unet/diffusion_pytorch_model.fp16.safetensors",
        ]
        compatible = is_safetensors_compatible(repo_files, variant="fp16")
        self.assertFalse(compatible)
|
|