Harley-ml commited on
Commit
d09dc68
·
verified ·
1 Parent(s): d4806e0

Upload 3 files

Browse files
Files changed (3) hide show
  1. config.json +11 -50
  2. configuration.py +80 -0
  3. modeling.py +177 -0
config.json CHANGED
@@ -1,61 +1,22 @@
1
  {
2
- "_class_name": "UNet2DConditionModel",
3
- "_diffusers_version": "0.37.1",
4
- "act_fn": "silu",
5
- "addition_embed_type": null,
6
- "addition_embed_type_num_heads": 64,
7
- "addition_time_embed_dim": null,
8
- "attention_head_dim": 8,
9
- "attention_type": "default",
10
- "block_out_channels": [
11
- 12,
12
- 16
13
- ],
14
- "center_input_sample": false,
15
  "class_embed_type": null,
16
- "class_embeddings_concat": false,
17
- "conv_in_kernel": 3,
18
- "conv_out_kernel": 3,
19
  "cross_attention_dim": 8,
20
- "cross_attention_norm": null,
21
- "down_block_types": [
22
- "DownBlock2D",
23
- "DownBlock2D"
24
- ],
25
- "downsample_padding": 1,
26
- "dropout": 0.0,
27
- "dual_cross_attention": false,
28
- "encoder_hid_dim": null,
29
- "encoder_hid_dim_type": null,
30
- "flip_sin_to_cos": true,
31
- "freq_shift": 0,
32
  "in_channels": 1,
33
  "layers_per_block": 8,
34
- "mid_block_only_cross_attention": null,
35
- "mid_block_scale_factor": 1,
36
  "mid_block_type": "UNetMidBlock2D",
37
- "norm_eps": 1e-05,
38
  "norm_num_groups": 4,
39
- "num_attention_heads": null,
40
  "num_class_embeds": 10,
41
- "only_cross_attention": false,
42
  "out_channels": 1,
43
- "projection_class_embeddings_input_dim": null,
44
- "resnet_out_scale_factor": 1.0,
45
- "resnet_skip_time_act": false,
46
- "resnet_time_scale_shift": "default",
47
- "reverse_transformer_layers_per_block": null,
48
  "sample_size": 32,
49
- "time_cond_proj_dim": null,
50
- "time_embedding_act_fn": null,
51
- "time_embedding_dim": null,
52
- "time_embedding_type": "positional",
53
- "timestep_post_act": null,
54
- "transformer_layers_per_block": 1,
55
- "up_block_types": [
56
- "UpBlock2D",
57
- "UpBlock2D"
58
- ],
59
- "upcast_attention": false,
60
- "use_linear_projection": false
61
  }
 
1
  {
2
+ "architectures": ["DigitDiffusionModel"],
3
+ "auto_map": {
4
+ "AutoConfig": "configuration.DigitDiffusionConfig",
5
+ "AutoModel": "modeling.DigitDiffusionModel"
6
+ },
7
+ "block_out_channels": [12, 16, 20],
 
 
 
 
 
 
 
8
  "class_embed_type": null,
 
 
 
9
  "cross_attention_dim": 8,
10
+ "down_block_types": ["DownBlock2D", "DownBlock2D", "DownBlock2D"],
11
+ "image_size": 32,
 
 
 
 
 
 
 
 
 
 
12
  "in_channels": 1,
13
  "layers_per_block": 8,
 
 
14
  "mid_block_type": "UNetMidBlock2D",
15
+ "model_type": "digit_diffusion",
16
  "norm_num_groups": 4,
 
17
  "num_class_embeds": 10,
18
+ "num_classes": 10,
19
  "out_channels": 1,
 
 
 
 
 
20
  "sample_size": 32,
21
+ "up_block_types": ["UpBlock2D", "UpBlock2D", "UpBlock2D"]
 
 
 
 
 
 
 
 
 
 
 
22
  }
configuration.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ #Configuration for the MNiST-IMG-390k
3
+
4
+ from __future__ import annotations
5
+
6
+ from typing import Iterable, Tuple
7
+
8
+ from transformers import PretrainedConfig
9
+
10
+
11
+ class DigitDiffusionConfig(PretrainedConfig):
12
+
13
+ model_type = "digit_diffusion"
14
+
15
+ def __init__(
16
+ self,
17
+ image_size: int = 32,
18
+ in_channels: int = 1,
19
+ out_channels: int = 1,
20
+ num_classes: int = 10,
21
+ block_out_channels: Iterable[int] = (12, 16, 20),
22
+ layers_per_block: int = 8,
23
+ norm_num_groups: int = 4,
24
+ cross_attention_dim: int = 8,
25
+ class_embed_type: str | None = None,
26
+ sample_size: int | None = None,
27
+ **kwargs,
28
+ ) -> None:
29
+ image_size = int(image_size)
30
+ sample_size = int(sample_size) if sample_size is not None else image_size
31
+
32
+ block_out_channels = tuple(int(v) for v in block_out_channels)
33
+ if not block_out_channels:
34
+ raise ValueError("block_out_channels must contain at least one entry.")
35
+ if any(v <= 0 for v in block_out_channels):
36
+ raise ValueError("block_out_channels must contain only positive integers.")
37
+
38
+ if image_size <= 0:
39
+ raise ValueError("image_size must be a positive integer.")
40
+ if sample_size <= 0:
41
+ raise ValueError("sample_size must be a positive integer.")
42
+ if in_channels <= 0 or out_channels <= 0:
43
+ raise ValueError("in_channels and out_channels must be positive integers.")
44
+ if num_classes <= 0:
45
+ raise ValueError("num_classes must be a positive integer.")
46
+ if layers_per_block <= 0:
47
+ raise ValueError("layers_per_block must be a positive integer.")
48
+ if norm_num_groups <= 0:
49
+ raise ValueError("norm_num_groups must be a positive integer.")
50
+ if cross_attention_dim <= 0:
51
+ raise ValueError("cross_attention_dim must be a positive integer.")
52
+
53
+ self.image_size = image_size
54
+ self.sample_size = sample_size
55
+ self.in_channels = int(in_channels)
56
+ self.out_channels = int(out_channels)
57
+ self.num_classes = int(num_classes)
58
+ self.block_out_channels = block_out_channels
59
+ self.layers_per_block = int(layers_per_block)
60
+ self.norm_num_groups = int(norm_num_groups)
61
+ self.cross_attention_dim = int(cross_attention_dim)
62
+ self.class_embed_type = class_embed_type
63
+
64
+ # Handy for HF model pages and AutoClass loading.
65
+ kwargs.setdefault("architectures", ["DigitDiffusionModel"])
66
+
67
+ super().__init__(**kwargs)
68
+
69
+ @property
70
+ def num_blocks(self) -> int:
71
+ return len(self.block_out_channels)
72
+
73
+ def to_dict(self):
74
+ data = super().to_dict()
75
+ # Keep the serialized values compact and JSON-friendly.
76
+ data["block_out_channels"] = list(self.block_out_channels)
77
+ return data
78
+
79
+
80
+ DigitDiffusionConfig.register_for_auto_class()
modeling.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Model for MNiST-IMG-390k
3
+
4
+ from __future__ import annotations
5
+
6
+ from dataclasses import dataclass
7
+ from typing import Any, Optional
8
+
9
+ import torch
10
+ from diffusers import UNet2DConditionModel
11
+ from transformers import PreTrainedModel
12
+ from transformers.utils import ModelOutput
13
+
14
+ from configuration import DigitDiffusionConfig
15
+
16
+
17
+ @dataclass
18
+ class DigitDiffusionOutput(ModelOutput):
19
+ sample: torch.FloatTensor | None = None
20
+
21
+
22
+ class DigitDiffusionModel(PreTrainedModel):
23
+
24
+ config_class = DigitDiffusionConfig
25
+ base_model_prefix = "unet"
26
+ main_input_name = "noisy_images"
27
+
28
+ def __init__(self, config: DigitDiffusionConfig) -> None:
29
+ super().__init__(config)
30
+
31
+ block_count = len(config.block_out_channels)
32
+
33
+ self.unet = UNet2DConditionModel(
34
+ sample_size=config.sample_size,
35
+ in_channels=config.in_channels,
36
+ out_channels=config.out_channels,
37
+ layers_per_block=config.layers_per_block,
38
+ block_out_channels=tuple(config.block_out_channels),
39
+ down_block_types=("DownBlock2D",) * block_count,
40
+ up_block_types=("UpBlock2D",) * block_count,
41
+ mid_block_type="UNetMidBlock2D",
42
+ norm_num_groups=config.norm_num_groups,
43
+ num_class_embeds=config.num_classes,
44
+ cross_attention_dim=config.cross_attention_dim,
45
+ class_embed_type=config.class_embed_type,
46
+ )
47
+
48
+ def _init_weights(self, module):
49
+ # Diffusers initializes the UNet internally, so there is nothing extra
50
+ # to initialize here.
51
+ return
52
+
53
+ def _make_dummy_context(
54
+ self,
55
+ batch_size: int,
56
+ device: torch.device,
57
+ dtype: torch.dtype,
58
+ ) -> torch.Tensor:
59
+ return torch.zeros(
60
+ batch_size,
61
+ 1,
62
+ self.config.cross_attention_dim,
63
+ device=device,
64
+ dtype=dtype,
65
+ )
66
+
67
+ def _normalize_inputs(
68
+ self,
69
+ noisy_images: Optional[torch.Tensor] = None,
70
+ timesteps: Optional[torch.Tensor | int] = None,
71
+ sample: Optional[torch.Tensor] = None,
72
+ timestep: Optional[torch.Tensor | int] = None,
73
+ ) -> tuple[torch.Tensor, torch.Tensor]:
74
+ if noisy_images is None:
75
+ noisy_images = sample
76
+ if timesteps is None:
77
+ timesteps = timestep
78
+
79
+ if noisy_images is None:
80
+ raise ValueError("Either `noisy_images` or `sample` must be provided.")
81
+ if timesteps is None:
82
+ raise ValueError("Either `timesteps` or `timestep` must be provided.")
83
+
84
+ if not torch.is_tensor(timesteps):
85
+ timesteps = torch.tensor(
86
+ timesteps,
87
+ device=noisy_images.device,
88
+ dtype=torch.long,
89
+ )
90
+ if timesteps.ndim == 0:
91
+ timesteps = timesteps.expand(noisy_images.shape[0])
92
+ elif timesteps.shape[0] != noisy_images.shape[0]:
93
+ timesteps = timesteps.reshape(-1)
94
+ if timesteps.numel() == 1:
95
+ timesteps = timesteps.expand(noisy_images.shape[0])
96
+ elif timesteps.shape[0] != noisy_images.shape[0]:
97
+ raise ValueError(
98
+ "Timesteps must be a scalar, a batch-sized tensor, or a single-value tensor."
99
+ )
100
+
101
+ return noisy_images, timesteps.to(device=noisy_images.device, dtype=torch.long)
102
+
103
+ def forward(
104
+ self,
105
+ noisy_images: Optional[torch.Tensor] = None,
106
+ timesteps: Optional[torch.Tensor | int] = None,
107
+ class_labels: Optional[torch.Tensor] = None,
108
+ sample: Optional[torch.Tensor] = None,
109
+ timestep: Optional[torch.Tensor | int] = None,
110
+ encoder_hidden_states: Optional[torch.Tensor] = None,
111
+ return_dict: bool = True,
112
+ **kwargs: Any,
113
+ ):
114
+ noisy_images, timesteps = self._normalize_inputs(
115
+ noisy_images=noisy_images,
116
+ timesteps=timesteps,
117
+ sample=sample,
118
+ timestep=timestep,
119
+ )
120
+
121
+ batch_size = noisy_images.shape[0]
122
+ if class_labels is None:
123
+ class_labels = torch.zeros(
124
+ batch_size,
125
+ device=noisy_images.device,
126
+ dtype=torch.long,
127
+ )
128
+ else:
129
+ class_labels = class_labels.to(device=noisy_images.device, dtype=torch.long)
130
+
131
+ if encoder_hidden_states is None:
132
+ encoder_hidden_states = self._make_dummy_context(
133
+ batch_size=batch_size,
134
+ device=noisy_images.device,
135
+ dtype=noisy_images.dtype,
136
+ )
137
+
138
+ noise_pred = self.unet(
139
+ sample=noisy_images,
140
+ timestep=timesteps,
141
+ encoder_hidden_states=encoder_hidden_states,
142
+ class_labels=class_labels,
143
+ return_dict=True,
144
+ **kwargs,
145
+ ).sample
146
+
147
+ if return_dict:
148
+ return DigitDiffusionOutput(sample=noise_pred)
149
+ return (noise_pred,)
150
+
151
+ def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False):
152
+ if state_dict:
153
+ keys = list(state_dict.keys())
154
+ has_prefixed = any(k.startswith("unet.") for k in keys)
155
+ has_plain_unet = any(
156
+ k.startswith(
157
+ (
158
+ "conv_in.",
159
+ "conv_norm_out.",
160
+ "conv_out.",
161
+ "time_embedding.",
162
+ "class_embedding.",
163
+ "down_blocks.",
164
+ "up_blocks.",
165
+ "mid_block.",
166
+ )
167
+ )
168
+ for k in keys
169
+ )
170
+
171
+ if has_plain_unet and not has_prefixed:
172
+ state_dict = {f"unet.{k}": v for k, v in state_dict.items()}
173
+
174
+ return super().load_state_dict(state_dict, strict=strict, assign=assign)
175
+
176
+
177
+ DigitDiffusionModel.register_for_auto_class("AutoModel")