#!/usr/bin/env python3
# Configuration for the MNiST-IMG-390k model.
from __future__ import annotations
from typing import Iterable, Tuple
from transformers import PretrainedConfig
class DigitDiffusionConfig(PretrainedConfig):
    """Hyper-parameter container for the ``digit_diffusion`` model type.

    Every numeric argument is coerced to ``int`` and then checked to be
    strictly positive, so the stored attributes are always plain positive
    ``int`` values (``block_out_channels`` is a tuple of them).

    Args:
        image_size: Positive integer; also the fallback for ``sample_size``.
        in_channels: Positive integer.
        out_channels: Positive integer.
        num_classes: Positive integer.
        block_out_channels: Non-empty iterable of positive integers; stored
            as a tuple. ``num_blocks`` is its length.
        layers_per_block: Positive integer.
        norm_num_groups: Positive integer.
        cross_attention_dim: Positive integer.
        class_embed_type: Optional string, stored unchanged.
        sample_size: Positive integer; defaults to ``image_size`` when
            ``None``.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``. The
            ``architectures`` key defaults to ``["DigitDiffusionModel"]``.

    Raises:
        ValueError: If any numeric argument is not strictly positive after
            integer coercion, or if ``block_out_channels`` is empty.
        TypeError: Propagated from ``int()`` for values it cannot convert
            (for example ``None``).

    NOTE(review): how the model interprets each field (e.g. whether
    ``norm_num_groups`` must divide every entry of ``block_out_channels``)
    is defined by the model code, which is not visible here.
    """

    model_type = "digit_diffusion"

    def __init__(
        self,
        image_size: int = 32,
        in_channels: int = 1,
        out_channels: int = 1,
        num_classes: int = 10,
        block_out_channels: Iterable[int] = (12, 16, 20),
        layers_per_block: int = 8,
        norm_num_groups: int = 4,
        cross_attention_dim: int = 8,
        class_embed_type: str | None = None,
        sample_size: int | None = None,
        **kwargs,
    ) -> None:
        # Coerce first, validate second. Validating the raw value and only
        # then truncating with int() would let e.g. 0.5 pass the "> 0" check
        # and be stored as 0, and would reject numeric strings with a
        # TypeError instead of accepting or cleanly rejecting them.
        image_size = int(image_size)
        sample_size = int(sample_size) if sample_size is not None else image_size
        block_out_channels = tuple(int(v) for v in block_out_channels)
        in_channels = int(in_channels)
        out_channels = int(out_channels)
        num_classes = int(num_classes)
        layers_per_block = int(layers_per_block)
        norm_num_groups = int(norm_num_groups)
        cross_attention_dim = int(cross_attention_dim)

        if not block_out_channels:
            raise ValueError("block_out_channels must contain at least one entry.")
        if any(v <= 0 for v in block_out_channels):
            raise ValueError("block_out_channels must contain only positive integers.")
        if image_size <= 0:
            raise ValueError("image_size must be a positive integer.")
        if sample_size <= 0:
            raise ValueError("sample_size must be a positive integer.")
        if in_channels <= 0 or out_channels <= 0:
            raise ValueError("in_channels and out_channels must be positive integers.")
        if num_classes <= 0:
            raise ValueError("num_classes must be a positive integer.")
        if layers_per_block <= 0:
            raise ValueError("layers_per_block must be a positive integer.")
        if norm_num_groups <= 0:
            raise ValueError("norm_num_groups must be a positive integer.")
        if cross_attention_dim <= 0:
            raise ValueError("cross_attention_dim must be a positive integer.")

        self.image_size = image_size
        self.sample_size = sample_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_classes = num_classes
        self.block_out_channels = block_out_channels
        self.layers_per_block = layers_per_block
        self.norm_num_groups = norm_num_groups
        self.cross_attention_dim = cross_attention_dim
        self.class_embed_type = class_embed_type

        # Handy for HF model pages and AutoClass loading.
        kwargs.setdefault("architectures", ["DigitDiffusionModel"])
        super().__init__(**kwargs)

    @property
    def num_blocks(self) -> int:
        """Return the number of entries in ``block_out_channels``."""
        return len(self.block_out_channels)

    def to_dict(self):
        """Return the parent's dict with ``block_out_channels`` as a list."""
        data = super().to_dict()
        # Keep the serialized values compact and JSON-friendly.
        data["block_out_channels"] = list(self.block_out_channels)
        return data
# Module-level side effect: runs once, when this module is imported.
# Calls the register_for_auto_class() hook that DigitDiffusionConfig gets from
# PretrainedConfig, with no arguments.
# NOTE(review): presumably this lets the config be loaded through AutoConfig
# as custom code - confirm against the transformers documentation.
DigitDiffusionConfig.register_for_auto_class()
|