| from transformers import PretrainedConfig |
| from typing import List |
|
|
|
|
class ViTMixConfig(PretrainedConfig):
    """Configuration holder for the ViTMix model.

    Every named constructor argument is stored unchanged as an instance
    attribute of the same name; any remaining keyword arguments are forwarded
    to ``PretrainedConfig.__init__``.

    Args:
        image_size: Image size; expected to be an exact multiple of
            ``patch_size``.
        patch_size: Patch size used to check ``image_size``.
        num_classes: Stored as ``self.num_classes`` (presumably the number of
            output classes -- confirm against the model code).
        dim: Stored as ``self.dim`` (presumably the embedding width --
            confirm against the model code).
        depth: Stored as ``self.depth`` (presumably the number of layers --
            confirm against the model code).
        heads: Stored as ``self.heads`` (presumably the attention head
            count -- confirm against the model code).
        mlp_dim: Stored as ``self.mlp_dim`` (presumably the feed-forward
            hidden width -- confirm against the model code).
        num_experts: Stored as ``self.num_experts`` (presumably the number of
            mixture experts -- confirm against the model code).
        **kwargs: Passed through to ``PretrainedConfig.__init__``.

    Warns:
        UserWarning: If ``image_size`` is not divisible by ``patch_size``.
            The configuration is still created, so construction keeps
            succeeding for such values.
    """

    model_type = "VitMix"

    def __init__(
        self,
        image_size: int = 28,
        patch_size: int = 14,
        num_classes: int = 10,
        dim: int = 1024,
        depth: int = 6,
        heads: int = 16,
        mlp_dim: int = 2048,
        num_experts: int = 12,
        **kwargs,
    ):
        if image_size % patch_size != 0:
            # Local import keeps this block self-contained. A warning (rather
            # than an exception) preserves the non-fatal behaviour callers
            # may rely on, while being filterable and attributed to the
            # caller via stacklevel.
            import warnings

            warnings.warn(
                "image_size must be divisible by patch_size! "
                f"image_size: {image_size} | patch_size: {patch_size}",
                stacklevel=2,
            )

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_classes = num_classes
        self.dim = dim
        self.depth = depth
        self.heads = heads
        self.mlp_dim = mlp_dim
        self.num_experts = num_experts
        # Attributes are assigned before delegating so the base-class
        # initializer runs last, as in the original ordering.
        super().__init__(**kwargs)