Harley-ml committed on
Commit
3fce85f
·
verified ·
1 Parent(s): 47fa26c

Upload 2 files

Browse files
Files changed (2) hide show
  1. configuration.py +80 -0
  2. modeling.py +177 -0
configuration.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ #Configuration for the MNiST-IMG-390k
3
+
4
+ from __future__ import annotations
5
+
6
+ from typing import Iterable, Tuple
7
+
8
+ from transformers import PretrainedConfig
9
+
10
+
11
class DigitDiffusionConfig(PretrainedConfig):
    """Hyperparameters for the digit diffusion model.

    The stored values are read by ``DigitDiffusionModel`` when it builds its
    ``UNet2DConditionModel``.

    Args:
        image_size: Image size; also the fallback for ``sample_size``.
        in_channels: Passed to the UNet as ``in_channels``.
        out_channels: Passed to the UNet as ``out_channels``.
        num_classes: Passed to the UNet as ``num_class_embeds``.
        block_out_channels: Channel width per UNet block. Its length also
            determines how many down/up blocks the model creates.
        layers_per_block: Passed to the UNet as ``layers_per_block``.
        norm_num_groups: Passed to the UNet as ``norm_num_groups``.
        cross_attention_dim: Passed to the UNet as ``cross_attention_dim``
            and used as the last dimension of the model's dummy context.
        class_embed_type: Passed to the UNet as ``class_embed_type``.
        sample_size: Passed to the UNet as ``sample_size``. Defaults to
            ``image_size`` when ``None``.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``block_out_channels`` is empty or any numeric setting
            is not strictly positive after integer coercion.
    """

    model_type = "digit_diffusion"

    def __init__(
        self,
        image_size: int = 32,
        in_channels: int = 1,
        out_channels: int = 1,
        num_classes: int = 10,
        block_out_channels: Iterable[int] = (12, 16, 20),
        layers_per_block: int = 8,
        norm_num_groups: int = 4,
        cross_attention_dim: int = 8,
        class_embed_type: str | None = None,
        sample_size: int | None = None,
        **kwargs,
    ) -> None:
        # Coerce everything to int *before* validating. Checking the raw
        # values would let a fractional input such as 0.5 pass the `<= 0`
        # test and then be truncated to 0 when stored.
        image_size = int(image_size)
        sample_size = int(sample_size) if sample_size is not None else image_size
        block_out_channels = tuple(int(v) for v in block_out_channels)
        in_channels = int(in_channels)
        out_channels = int(out_channels)
        num_classes = int(num_classes)
        layers_per_block = int(layers_per_block)
        norm_num_groups = int(norm_num_groups)
        cross_attention_dim = int(cross_attention_dim)

        if not block_out_channels:
            raise ValueError("block_out_channels must contain at least one entry.")
        if any(v <= 0 for v in block_out_channels):
            raise ValueError("block_out_channels must contain only positive integers.")

        if image_size <= 0:
            raise ValueError("image_size must be a positive integer.")
        if sample_size <= 0:
            raise ValueError("sample_size must be a positive integer.")
        if in_channels <= 0 or out_channels <= 0:
            raise ValueError("in_channels and out_channels must be positive integers.")
        if num_classes <= 0:
            raise ValueError("num_classes must be a positive integer.")
        if layers_per_block <= 0:
            raise ValueError("layers_per_block must be a positive integer.")
        if norm_num_groups <= 0:
            raise ValueError("norm_num_groups must be a positive integer.")
        if cross_attention_dim <= 0:
            raise ValueError("cross_attention_dim must be a positive integer.")

        self.image_size = image_size
        self.sample_size = sample_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_classes = num_classes
        self.block_out_channels = block_out_channels
        self.layers_per_block = layers_per_block
        self.norm_num_groups = norm_num_groups
        self.cross_attention_dim = cross_attention_dim
        self.class_embed_type = class_embed_type

        # Handy for HF model pages and AutoClass loading.
        kwargs.setdefault("architectures", ["DigitDiffusionModel"])

        super().__init__(**kwargs)

    @property
    def num_blocks(self) -> int:
        """Number of entries in ``block_out_channels``."""
        return len(self.block_out_channels)

    def to_dict(self):
        """Return the parent's dict with ``block_out_channels`` as a list."""
        data = super().to_dict()
        # Keep the serialized values compact and JSON-friendly: the attribute
        # is stored as a tuple, so emit it as a plain list.
        data["block_out_channels"] = list(self.block_out_channels)
        return data
78
+
79
+
80
# Register the config for auto-class loading of custom code. No auto class is
# named here, so the library default applies (presumably AutoConfig - confirm).
DigitDiffusionConfig.register_for_auto_class()
modeling.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Model for MNiST-IMG-390k
3
+
4
+ from __future__ import annotations
5
+
6
+ from dataclasses import dataclass
7
+ from typing import Any, Optional
8
+
9
+ import torch
10
+ from diffusers import UNet2DConditionModel
11
+ from transformers import PreTrainedModel
12
+ from transformers.utils import ModelOutput
13
+
14
+ from configuration import DigitDiffusionConfig
15
+
16
+
17
@dataclass
class DigitDiffusionOutput(ModelOutput):
    """Output container returned by ``DigitDiffusionModel.forward``.

    ``forward`` fills ``sample`` with the tensor produced by the wrapped UNet
    when it is called with ``return_dict=True``.
    """

    # UNet prediction; stays ``None`` until ``forward`` populates it.
    sample: torch.FloatTensor | None = None
20
+
21
+
22
class DigitDiffusionModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper around a diffusers ``UNet2DConditionModel``.

    The UNet is built from a ``DigitDiffusionConfig`` using only
    ``DownBlock2D`` / ``UpBlock2D`` / ``UNetMidBlock2D`` blocks and is
    conditioned on integer class labels via ``num_class_embeds``.
    ``forward`` accepts both this model's argument names
    (``noisy_images`` / ``timesteps``) and the diffusers-style aliases
    (``sample`` / ``timestep``).
    """

    config_class = DigitDiffusionConfig
    base_model_prefix = "unet"
    main_input_name = "noisy_images"

    # Top-level key prefixes that identify a checkpoint saved directly from
    # the bare UNet, i.e. without the leading "unet." this wrapper adds.
    _PLAIN_UNET_PREFIXES = (
        "conv_in.",
        "conv_norm_out.",
        "conv_out.",
        "time_embedding.",
        "class_embedding.",
        "down_blocks.",
        "up_blocks.",
        "mid_block.",
    )

    def __init__(self, config: DigitDiffusionConfig) -> None:
        """Build the wrapped UNet from ``config``."""
        super().__init__(config)

        block_count = len(config.block_out_channels)

        self.unet = UNet2DConditionModel(
            sample_size=config.sample_size,
            in_channels=config.in_channels,
            out_channels=config.out_channels,
            layers_per_block=config.layers_per_block,
            block_out_channels=tuple(config.block_out_channels),
            # One plain (attention-free) block per entry in block_out_channels.
            down_block_types=("DownBlock2D",) * block_count,
            up_block_types=("UpBlock2D",) * block_count,
            mid_block_type="UNetMidBlock2D",
            norm_num_groups=config.norm_num_groups,
            num_class_embeds=config.num_classes,
            cross_attention_dim=config.cross_attention_dim,
            class_embed_type=config.class_embed_type,
        )

        # Standard PreTrainedModel finalisation hook. Because `_init_weights`
        # below is a no-op, this should leave the UNet's own initialisation
        # untouched.
        # NOTE(review): newer transformers releases appear to rely on state
        # set up in post_init(); confirm against the pinned version.
        self.post_init()

    def _init_weights(self, module):
        # Diffusers initializes the UNet internally, so there is nothing extra
        # to initialize here.
        return

    def _make_dummy_context(
        self,
        batch_size: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Return an all-zero ``(batch_size, 1, cross_attention_dim)`` tensor.

        Used as ``encoder_hidden_states`` when the caller supplies none.
        NOTE(review): the configured block types contain no cross-attention,
        so this is presumably only a placeholder for the UNet call - confirm.
        """
        return torch.zeros(
            batch_size,
            1,
            self.config.cross_attention_dim,
            device=device,
            dtype=dtype,
        )

    def _normalize_inputs(
        self,
        noisy_images: Optional[torch.Tensor] = None,
        timesteps: Optional[torch.Tensor | int] = None,
        sample: Optional[torch.Tensor] = None,
        timestep: Optional[torch.Tensor | int] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Resolve argument aliases and return ``(images, timesteps)``.

        ``sample`` / ``timestep`` are used only when ``noisy_images`` /
        ``timesteps`` are ``None``. The returned timesteps are a 1-D
        ``torch.long`` tensor on the images' device with one entry per batch
        element.

        Raises:
            ValueError: If images or timesteps are missing, or if the
                timesteps hold neither one value nor one value per batch
                element.
        """
        if noisy_images is None:
            noisy_images = sample
        if timesteps is None:
            timesteps = timestep

        if noisy_images is None:
            raise ValueError("Either `noisy_images` or `sample` must be provided.")
        if timesteps is None:
            raise ValueError("Either `timesteps` or `timestep` must be provided.")

        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor(
                timesteps,
                device=noisy_images.device,
                dtype=torch.long,
            )

        batch_size = noisy_images.shape[0]

        # Always flatten first. Previously a tensor whose first dimension
        # already matched the batch size (e.g. shape (B, 1)) skipped this
        # step and was handed to the UNet with extra dimensions.
        timesteps = timesteps.reshape(-1)
        if timesteps.numel() == 1:
            # Scalar or single-value tensor: broadcast across the batch.
            timesteps = timesteps.expand(batch_size)
        elif timesteps.shape[0] != batch_size:
            raise ValueError(
                "Timesteps must be a scalar, a batch-sized tensor, or a single-value tensor."
            )

        return noisy_images, timesteps.to(device=noisy_images.device, dtype=torch.long)

    def forward(
        self,
        noisy_images: Optional[torch.Tensor] = None,
        timesteps: Optional[torch.Tensor | int] = None,
        class_labels: Optional[torch.Tensor] = None,
        sample: Optional[torch.Tensor] = None,
        timestep: Optional[torch.Tensor | int] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        **kwargs: Any,
    ):
        """Run the wrapped UNet on a batch of images.

        Args:
            noisy_images: Input images; ``sample`` is an accepted alias.
            timesteps: Scalar, single-value, or per-batch timesteps;
                ``timestep`` is an accepted alias.
            class_labels: Integer class labels. When ``None``, every batch
                element is given label ``0``.
            sample: Alias for ``noisy_images``.
            timestep: Alias for ``timesteps``.
            encoder_hidden_states: Optional context tensor. When ``None``, an
                all-zero tensor from ``_make_dummy_context`` is used.
            return_dict: If true, return a ``DigitDiffusionOutput``;
                otherwise return a one-element tuple.
            **kwargs: Forwarded unchanged to the UNet call.

        Returns:
            ``DigitDiffusionOutput(sample=...)`` or ``(prediction,)``.
        """
        noisy_images, timesteps = self._normalize_inputs(
            noisy_images=noisy_images,
            timesteps=timesteps,
            sample=sample,
            timestep=timestep,
        )

        batch_size = noisy_images.shape[0]
        if class_labels is None:
            # No label supplied: fall back to class index 0 for the whole batch.
            class_labels = torch.zeros(
                batch_size,
                device=noisy_images.device,
                dtype=torch.long,
            )
        else:
            class_labels = class_labels.to(device=noisy_images.device, dtype=torch.long)

        if encoder_hidden_states is None:
            encoder_hidden_states = self._make_dummy_context(
                batch_size=batch_size,
                device=noisy_images.device,
                dtype=noisy_images.dtype,
            )

        # Always request the dict-style output from the UNet so `.sample` is
        # available; the caller's `return_dict` only shapes our own return.
        noise_pred = self.unet(
            sample=noisy_images,
            timestep=timesteps,
            encoder_hidden_states=encoder_hidden_states,
            class_labels=class_labels,
            return_dict=True,
            **kwargs,
        ).sample

        if return_dict:
            return DigitDiffusionOutput(sample=noise_pred)
        return (noise_pred,)

    def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False):
        """Load weights, accepting checkpoints saved from the bare UNet.

        If no key starts with ``"unet."`` but at least one key starts with a
        known top-level UNet prefix, every key is re-prefixed with ``"unet."``
        before delegating to the parent implementation.
        """
        if state_dict:
            keys = list(state_dict.keys())
            has_prefixed = any(k.startswith("unet.") for k in keys)
            has_plain_unet = any(
                k.startswith(self._PLAIN_UNET_PREFIXES) for k in keys
            )

            if has_plain_unet and not has_prefixed:
                state_dict = {f"unet.{k}": v for k, v in state_dict.items()}

        return super().load_state_dict(state_dict, strict=strict, assign=assign)
175
+
176
+
177
# Register the model class against the "AutoModel" auto class so it can be
# resolved through that entry point when loaded as custom code.
DigitDiffusionModel.register_for_auto_class("AutoModel")