File size: 2,916 Bytes
3fce85f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#!/usr/bin/env python3
# Configuration for the MNiST-IMG-390k model.

from __future__ import annotations

from typing import Iterable, Tuple

from transformers import PretrainedConfig


class DigitDiffusionConfig(PretrainedConfig):
    """Configuration object for the ``digit_diffusion`` model type.

    This class only coerces, validates and stores hyper-parameters. Their
    meaning is defined by the model that consumes the config. Every integer
    argument is coerced with ``int()`` *before* it is validated, so numeric
    strings or NumPy scalars are handled the same way for all fields.

    Args:
        image_size: Must be a positive integer.
        in_channels: Must be a positive integer.
        out_channels: Must be a positive integer.
        num_classes: Must be a positive integer.
        block_out_channels: Non-empty iterable of positive integers. It is
            stored as a tuple and serialized as a list by ``to_dict``.
        layers_per_block: Must be a positive integer.
        norm_num_groups: Must be a positive integer.
        cross_attention_dim: Must be a positive integer.
        class_embed_type: Stored as-is; not validated here.
        sample_size: Defaults to ``image_size`` when ``None``. Must be a
            positive integer.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``. If
            ``architectures`` is not given, it defaults to
            ``["DigitDiffusionModel"]``.

    Raises:
        ValueError: If ``block_out_channels`` is empty, or any integer
            argument is not positive, or a value cannot be parsed by ``int()``.
        TypeError: If a value is of a type ``int()`` cannot convert
            (e.g. ``None`` for a required integer field).
    """

    model_type = "digit_diffusion"

    def __init__(
        self,
        image_size: int = 32,
        in_channels: int = 1,
        out_channels: int = 1,
        num_classes: int = 10,
        block_out_channels: Iterable[int] = (12, 16, 20),
        layers_per_block: int = 8,
        norm_num_groups: int = 4,
        cross_attention_dim: int = 8,
        class_embed_type: str | None = None,
        sample_size: int | None = None,
        **kwargs,
    ) -> None:
        # Coerce every integer field up front so that all comparisons below
        # operate on ints. Otherwise e.g. in_channels="1" would hit a
        # str-vs-int comparison TypeError while image_size="32" is accepted.
        image_size = int(image_size)
        sample_size = int(sample_size) if sample_size is not None else image_size
        in_channels = int(in_channels)
        out_channels = int(out_channels)
        num_classes = int(num_classes)
        layers_per_block = int(layers_per_block)
        norm_num_groups = int(norm_num_groups)
        cross_attention_dim = int(cross_attention_dim)
        block_out_channels = tuple(int(v) for v in block_out_channels)

        if not block_out_channels:
            raise ValueError("block_out_channels must contain at least one entry.")
        if any(v <= 0 for v in block_out_channels):
            raise ValueError("block_out_channels must contain only positive integers.")

        if image_size <= 0:
            raise ValueError("image_size must be a positive integer.")
        if sample_size <= 0:
            raise ValueError("sample_size must be a positive integer.")
        if in_channels <= 0 or out_channels <= 0:
            raise ValueError("in_channels and out_channels must be positive integers.")
        if num_classes <= 0:
            raise ValueError("num_classes must be a positive integer.")
        if layers_per_block <= 0:
            raise ValueError("layers_per_block must be a positive integer.")
        if norm_num_groups <= 0:
            raise ValueError("norm_num_groups must be a positive integer.")
        if cross_attention_dim <= 0:
            raise ValueError("cross_attention_dim must be a positive integer.")

        self.image_size = image_size
        self.sample_size = sample_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_classes = num_classes
        self.block_out_channels = block_out_channels
        self.layers_per_block = layers_per_block
        self.norm_num_groups = norm_num_groups
        self.cross_attention_dim = cross_attention_dim
        self.class_embed_type = class_embed_type

        # Handy for HF model pages and AutoClass loading.
        kwargs.setdefault("architectures", ["DigitDiffusionModel"])

        super().__init__(**kwargs)

    @property
    def num_blocks(self) -> int:
        """Return the number of entries in ``block_out_channels``."""
        return len(self.block_out_channels)

    def to_dict(self) -> dict:
        """Serialize the config, emitting ``block_out_channels`` as a list.

        The tuple stored on the instance is converted to a list so that the
        serialized value is a plain JSON-friendly sequence.
        """
        data = super().to_dict()
        # Keep the serialized values compact and JSON-friendly.
        data["block_out_channels"] = list(self.block_out_channels)
        return data


# Register the config class with transformers' auto-class machinery so that
# checkpoints saved with custom code can resolve it by name. Called with no
# argument, this presumably registers it as AutoConfig (the library default);
# confirm against the installed transformers version.
DigitDiffusionConfig.register_for_auto_class()