dineshsai07 commited on
Commit
d5d32a4
·
verified ·
1 Parent(s): 0ccacae

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. versatile_diffusion/lib/model_zoo/__pycache__/ddim.cpython-310.pyc +0 -0
  2. versatile_diffusion/lib/model_zoo/__pycache__/ddim.cpython-38.pyc +0 -0
  3. versatile_diffusion/lib/model_zoo/__pycache__/ddim_dualcontext.cpython-38.pyc +0 -0
  4. versatile_diffusion/lib/model_zoo/__pycache__/ddim_vd.cpython-310.pyc +0 -0
  5. versatile_diffusion/lib/model_zoo/__pycache__/ddim_vd.cpython-38.pyc +0 -0
  6. versatile_diffusion/lib/model_zoo/__pycache__/diffusion_modules.cpython-310.pyc +0 -0
  7. versatile_diffusion/lib/model_zoo/__pycache__/diffusion_modules.cpython-38.pyc +0 -0
  8. versatile_diffusion/lib/model_zoo/__pycache__/diffusion_utils.cpython-310.pyc +0 -0
  9. versatile_diffusion/lib/model_zoo/__pycache__/diffusion_utils.cpython-38.pyc +0 -0
  10. versatile_diffusion/lib/model_zoo/__pycache__/distributions.cpython-310.pyc +0 -0
  11. versatile_diffusion/lib/model_zoo/__pycache__/distributions.cpython-38.pyc +0 -0
  12. versatile_diffusion/lib/model_zoo/__pycache__/ema.cpython-310.pyc +0 -0
  13. versatile_diffusion/lib/model_zoo/__pycache__/ema.cpython-38.pyc +0 -0
  14. versatile_diffusion/lib/model_zoo/__pycache__/openaimodel.cpython-310.pyc +0 -0
  15. versatile_diffusion/lib/model_zoo/__pycache__/openaimodel.cpython-38.pyc +0 -0
  16. versatile_diffusion/lib/model_zoo/__pycache__/optimus.cpython-310.pyc +0 -0
  17. versatile_diffusion/lib/model_zoo/__pycache__/optimus.cpython-38.pyc +0 -0
  18. versatile_diffusion/lib/model_zoo/__pycache__/sd.cpython-310.pyc +0 -0
  19. versatile_diffusion/lib/model_zoo/__pycache__/sd.cpython-38.pyc +0 -0
  20. versatile_diffusion/lib/model_zoo/clip_justin/__pycache__/model.cpython-38.pyc +0 -0
  21. versatile_diffusion/lib/model_zoo/clip_justin/model.py +436 -0
  22. versatile_diffusion/lib/model_zoo/common/__pycache__/get_model.cpython-310.pyc +0 -0
  23. versatile_diffusion/lib/model_zoo/common/__pycache__/get_model.cpython-38.pyc +0 -0
  24. versatile_diffusion/lib/model_zoo/common/__pycache__/get_optimizer.cpython-310.pyc +0 -0
  25. versatile_diffusion/lib/model_zoo/common/__pycache__/get_optimizer.cpython-38.pyc +0 -0
  26. versatile_diffusion/lib/model_zoo/common/__pycache__/get_scheduler.cpython-310.pyc +0 -0
  27. versatile_diffusion/lib/model_zoo/common/__pycache__/get_scheduler.cpython-38.pyc +0 -0
  28. versatile_diffusion/lib/model_zoo/common/__pycache__/utils.cpython-310.pyc +0 -0
  29. versatile_diffusion/lib/model_zoo/common/__pycache__/utils.cpython-38.pyc +0 -0
  30. versatile_diffusion/lib/model_zoo/common/utils.py +292 -0
  31. versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/configuration_bert.cpython-310.pyc +0 -0
  32. versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/configuration_bert.cpython-38.pyc +0 -0
  33. versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/configuration_gpt2.cpython-310.pyc +0 -0
  34. versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/configuration_gpt2.cpython-38.pyc +0 -0
  35. versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/configuration_utils.cpython-310.pyc +0 -0
  36. versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/configuration_utils.cpython-38.pyc +0 -0
  37. versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/file_utils.cpython-310.pyc +0 -0
  38. versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/file_utils.cpython-38.pyc +0 -0
  39. versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/modeling_utils.cpython-38.pyc +0 -0
  40. versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/optimus_bert.cpython-38.pyc +0 -0
  41. versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/tokenization_gpt2.cpython-310.pyc +0 -0
  42. versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/tokenization_utils.cpython-38.pyc +0 -0
  43. versatile_diffusion/lib/model_zoo/optimus_models/configuration_bert.py +113 -0
  44. versatile_diffusion/lib/model_zoo/optimus_models/configuration_gpt2.py +143 -0
  45. versatile_diffusion/lib/model_zoo/optimus_models/configuration_utils.py +205 -0
  46. versatile_diffusion/lib/model_zoo/optimus_models/file_utils.py +294 -0
  47. versatile_diffusion/lib/model_zoo/optimus_models/modeling_utils.py +780 -0
  48. versatile_diffusion/lib/model_zoo/optimus_models/optimus_bert.py +1439 -0
  49. versatile_diffusion/lib/model_zoo/optimus_models/optimus_gpt2.py +1122 -0
  50. versatile_diffusion/lib/model_zoo/optimus_models/tokenization_bert.py +457 -0
versatile_diffusion/lib/model_zoo/__pycache__/ddim.cpython-310.pyc ADDED
Binary file (6.25 kB). View file
 
versatile_diffusion/lib/model_zoo/__pycache__/ddim.cpython-38.pyc ADDED
Binary file (6.2 kB). View file
 
versatile_diffusion/lib/model_zoo/__pycache__/ddim_dualcontext.cpython-38.pyc ADDED
Binary file (4.4 kB). View file
 
versatile_diffusion/lib/model_zoo/__pycache__/ddim_vd.cpython-310.pyc ADDED
Binary file (9.66 kB). View file
 
versatile_diffusion/lib/model_zoo/__pycache__/ddim_vd.cpython-38.pyc ADDED
Binary file (9.54 kB). View file
 
versatile_diffusion/lib/model_zoo/__pycache__/diffusion_modules.cpython-310.pyc ADDED
Binary file (20.3 kB). View file
 
versatile_diffusion/lib/model_zoo/__pycache__/diffusion_modules.cpython-38.pyc ADDED
Binary file (20.7 kB). View file
 
versatile_diffusion/lib/model_zoo/__pycache__/diffusion_utils.cpython-310.pyc ADDED
Binary file (9.56 kB). View file
 
versatile_diffusion/lib/model_zoo/__pycache__/diffusion_utils.cpython-38.pyc ADDED
Binary file (9.52 kB). View file
 
versatile_diffusion/lib/model_zoo/__pycache__/distributions.cpython-310.pyc ADDED
Binary file (3.79 kB). View file
 
versatile_diffusion/lib/model_zoo/__pycache__/distributions.cpython-38.pyc ADDED
Binary file (3.79 kB). View file
 
versatile_diffusion/lib/model_zoo/__pycache__/ema.cpython-310.pyc ADDED
Binary file (3.04 kB). View file
 
versatile_diffusion/lib/model_zoo/__pycache__/ema.cpython-38.pyc ADDED
Binary file (2.99 kB). View file
 
versatile_diffusion/lib/model_zoo/__pycache__/openaimodel.cpython-310.pyc ADDED
Binary file (44.8 kB). View file
 
versatile_diffusion/lib/model_zoo/__pycache__/openaimodel.cpython-38.pyc ADDED
Binary file (47.3 kB). View file
 
versatile_diffusion/lib/model_zoo/__pycache__/optimus.cpython-310.pyc ADDED
Binary file (18.3 kB). View file
 
versatile_diffusion/lib/model_zoo/__pycache__/optimus.cpython-38.pyc ADDED
Binary file (18.3 kB). View file
 
versatile_diffusion/lib/model_zoo/__pycache__/sd.cpython-310.pyc ADDED
Binary file (22 kB). View file
 
versatile_diffusion/lib/model_zoo/__pycache__/sd.cpython-38.pyc ADDED
Binary file (22.2 kB). View file
 
versatile_diffusion/lib/model_zoo/clip_justin/__pycache__/model.cpython-38.pyc ADDED
Binary file (15 kB). View file
 
versatile_diffusion/lib/model_zoo/clip_justin/model.py ADDED
@@ -0,0 +1,436 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+ from typing import Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+
9
+
10
+ class Bottleneck(nn.Module):
11
+ expansion = 4
12
+
13
+ def __init__(self, inplanes, planes, stride=1):
14
+ super().__init__()
15
+
16
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
17
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
18
+ self.bn1 = nn.BatchNorm2d(planes)
19
+ self.relu1 = nn.ReLU(inplace=True)
20
+
21
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
22
+ self.bn2 = nn.BatchNorm2d(planes)
23
+ self.relu2 = nn.ReLU(inplace=True)
24
+
25
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
26
+
27
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
28
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
29
+ self.relu3 = nn.ReLU(inplace=True)
30
+
31
+ self.downsample = None
32
+ self.stride = stride
33
+
34
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
35
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
36
+ self.downsample = nn.Sequential(OrderedDict([
37
+ ("-1", nn.AvgPool2d(stride)),
38
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
39
+ ("1", nn.BatchNorm2d(planes * self.expansion))
40
+ ]))
41
+
42
+ def forward(self, x: torch.Tensor):
43
+ identity = x
44
+
45
+ out = self.relu1(self.bn1(self.conv1(x)))
46
+ out = self.relu2(self.bn2(self.conv2(out)))
47
+ out = self.avgpool(out)
48
+ out = self.bn3(self.conv3(out))
49
+
50
+ if self.downsample is not None:
51
+ identity = self.downsample(x)
52
+
53
+ out += identity
54
+ out = self.relu3(out)
55
+ return out
56
+
57
+
58
+ class AttentionPool2d(nn.Module):
59
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
60
+ super().__init__()
61
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
62
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
63
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
64
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
65
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
66
+ self.num_heads = num_heads
67
+
68
+ def forward(self, x):
69
+ x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
70
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
71
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
72
+ x, _ = F.multi_head_attention_forward(
73
+ query=x[:1], key=x, value=x,
74
+ embed_dim_to_check=x.shape[-1],
75
+ num_heads=self.num_heads,
76
+ q_proj_weight=self.q_proj.weight,
77
+ k_proj_weight=self.k_proj.weight,
78
+ v_proj_weight=self.v_proj.weight,
79
+ in_proj_weight=None,
80
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
81
+ bias_k=None,
82
+ bias_v=None,
83
+ add_zero_attn=False,
84
+ dropout_p=0,
85
+ out_proj_weight=self.c_proj.weight,
86
+ out_proj_bias=self.c_proj.bias,
87
+ use_separate_proj_weight=True,
88
+ training=self.training,
89
+ need_weights=False
90
+ )
91
+ return x.squeeze(0)
92
+
93
+
94
+ class ModifiedResNet(nn.Module):
95
+ """
96
+ A ResNet class that is similar to torchvision's but contains the following changes:
97
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
98
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
99
+ - The final pooling layer is a QKV attention instead of an average pool
100
+ """
101
+
102
+ def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
103
+ super().__init__()
104
+ self.output_dim = output_dim
105
+ self.input_resolution = input_resolution
106
+
107
+ # the 3-layer stem
108
+ self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
109
+ self.bn1 = nn.BatchNorm2d(width // 2)
110
+ self.relu1 = nn.ReLU(inplace=True)
111
+ self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
112
+ self.bn2 = nn.BatchNorm2d(width // 2)
113
+ self.relu2 = nn.ReLU(inplace=True)
114
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
115
+ self.bn3 = nn.BatchNorm2d(width)
116
+ self.relu3 = nn.ReLU(inplace=True)
117
+ self.avgpool = nn.AvgPool2d(2)
118
+
119
+ # residual layers
120
+ self._inplanes = width # this is a *mutable* variable used during construction
121
+ self.layer1 = self._make_layer(width, layers[0])
122
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
123
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
124
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
125
+
126
+ embed_dim = width * 32 # the ResNet feature dimension
127
+ self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
128
+
129
+ def _make_layer(self, planes, blocks, stride=1):
130
+ layers = [Bottleneck(self._inplanes, planes, stride)]
131
+
132
+ self._inplanes = planes * Bottleneck.expansion
133
+ for _ in range(1, blocks):
134
+ layers.append(Bottleneck(self._inplanes, planes))
135
+
136
+ return nn.Sequential(*layers)
137
+
138
+ def forward(self, x):
139
+ def stem(x):
140
+ x = self.relu1(self.bn1(self.conv1(x)))
141
+ x = self.relu2(self.bn2(self.conv2(x)))
142
+ x = self.relu3(self.bn3(self.conv3(x)))
143
+ x = self.avgpool(x)
144
+ return x
145
+
146
+ x = x.type(self.conv1.weight.dtype)
147
+ x = stem(x)
148
+ x = self.layer1(x)
149
+ x = self.layer2(x)
150
+ x = self.layer3(x)
151
+ x = self.layer4(x)
152
+ x = self.attnpool(x)
153
+
154
+ return x
155
+
156
+
157
+ class LayerNorm(nn.LayerNorm):
158
+ """Subclass torch's LayerNorm to handle fp16."""
159
+
160
+ def forward(self, x: torch.Tensor):
161
+ orig_type = x.dtype
162
+ ret = super().forward(x.type(torch.float32))
163
+ return ret.type(orig_type)
164
+
165
+
166
+ class QuickGELU(nn.Module):
167
+ def forward(self, x: torch.Tensor):
168
+ return x * torch.sigmoid(1.702 * x)
169
+
170
+
171
+ class ResidualAttentionBlock(nn.Module):
172
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
173
+ super().__init__()
174
+
175
+ self.attn = nn.MultiheadAttention(d_model, n_head)
176
+ self.ln_1 = LayerNorm(d_model)
177
+ self.mlp = nn.Sequential(OrderedDict([
178
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
179
+ ("gelu", QuickGELU()),
180
+ ("c_proj", nn.Linear(d_model * 4, d_model))
181
+ ]))
182
+ self.ln_2 = LayerNorm(d_model)
183
+ self.attn_mask = attn_mask
184
+
185
+ def attention(self, x: torch.Tensor):
186
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
187
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
188
+
189
+ def forward(self, x: torch.Tensor):
190
+ x = x + self.attention(self.ln_1(x))
191
+ x = x + self.mlp(self.ln_2(x))
192
+ return x
193
+
194
+
195
+ class Transformer(nn.Module):
196
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
197
+ super().__init__()
198
+ self.width = width
199
+ self.layers = layers
200
+ self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
201
+
202
+ def forward(self, x: torch.Tensor):
203
+ return self.resblocks(x)
204
+
205
+
206
+ class VisionTransformer(nn.Module):
207
+ def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
208
+ super().__init__()
209
+ self.input_resolution = input_resolution
210
+ self.output_dim = output_dim
211
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
212
+
213
+ scale = width ** -0.5
214
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
215
+ self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
216
+ self.ln_pre = LayerNorm(width)
217
+
218
+ self.transformer = Transformer(width, layers, heads)
219
+
220
+ self.ln_post = LayerNorm(width)
221
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
222
+
223
+ def forward(self, x: torch.Tensor):
224
+ x = self.conv1(x) # shape = [*, width, grid, grid]
225
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
226
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
227
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
228
+ x = x + self.positional_embedding.to(x.dtype)
229
+ x = self.ln_pre(x)
230
+
231
+ x = x.permute(1, 0, 2) # NLD -> LND
232
+ x = self.transformer(x)
233
+ x = x.permute(1, 0, 2) # LND -> NLD
234
+
235
+ x = self.ln_post(x[:, 0, :])
236
+
237
+ if self.proj is not None:
238
+ x = x @ self.proj
239
+
240
+ return x
241
+
242
+
243
+ class CLIP(nn.Module):
244
+ def __init__(self,
245
+ embed_dim: int,
246
+ # vision
247
+ image_resolution: int,
248
+ vision_layers: Union[Tuple[int, int, int, int], int],
249
+ vision_width: int,
250
+ vision_patch_size: int,
251
+ # text
252
+ context_length: int,
253
+ vocab_size: int,
254
+ transformer_width: int,
255
+ transformer_heads: int,
256
+ transformer_layers: int
257
+ ):
258
+ super().__init__()
259
+
260
+ self.context_length = context_length
261
+
262
+ if isinstance(vision_layers, (tuple, list)):
263
+ vision_heads = vision_width * 32 // 64
264
+ self.visual = ModifiedResNet(
265
+ layers=vision_layers,
266
+ output_dim=embed_dim,
267
+ heads=vision_heads,
268
+ input_resolution=image_resolution,
269
+ width=vision_width
270
+ )
271
+ else:
272
+ vision_heads = vision_width // 64
273
+ self.visual = VisionTransformer(
274
+ input_resolution=image_resolution,
275
+ patch_size=vision_patch_size,
276
+ width=vision_width,
277
+ layers=vision_layers,
278
+ heads=vision_heads,
279
+ output_dim=embed_dim
280
+ )
281
+
282
+ self.transformer = Transformer(
283
+ width=transformer_width,
284
+ layers=transformer_layers,
285
+ heads=transformer_heads,
286
+ attn_mask=self.build_attention_mask()
287
+ )
288
+
289
+ self.vocab_size = vocab_size
290
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
291
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
292
+ self.ln_final = LayerNorm(transformer_width)
293
+
294
+ self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
295
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
296
+
297
+ self.initialize_parameters()
298
+
299
+ def initialize_parameters(self):
300
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
301
+ nn.init.normal_(self.positional_embedding, std=0.01)
302
+
303
+ if isinstance(self.visual, ModifiedResNet):
304
+ if self.visual.attnpool is not None:
305
+ std = self.visual.attnpool.c_proj.in_features ** -0.5
306
+ nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
307
+ nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
308
+ nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
309
+ nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
310
+
311
+ for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
312
+ for name, param in resnet_block.named_parameters():
313
+ if name.endswith("bn3.weight"):
314
+ nn.init.zeros_(param)
315
+
316
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
317
+ attn_std = self.transformer.width ** -0.5
318
+ fc_std = (2 * self.transformer.width) ** -0.5
319
+ for block in self.transformer.resblocks:
320
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
321
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
322
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
323
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
324
+
325
+ if self.text_projection is not None:
326
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
327
+
328
+ def build_attention_mask(self):
329
+ # lazily create causal attention mask, with full attention between the vision tokens
330
+ # pytorch uses additive attention mask; fill with -inf
331
+ mask = torch.empty(self.context_length, self.context_length)
332
+ mask.fill_(float("-inf"))
333
+ mask.triu_(1) # zero out the lower diagonal
334
+ return mask
335
+
336
+ @property
337
+ def dtype(self):
338
+ return self.visual.conv1.weight.dtype
339
+
340
+ def encode_image(self, image):
341
+ return self.visual(image.type(self.dtype))
342
+
343
+ def encode_text(self, text):
344
+ x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
345
+
346
+ x = x + self.positional_embedding.type(self.dtype)
347
+ x = x.permute(1, 0, 2) # NLD -> LND
348
+ x = self.transformer(x)
349
+ x = x.permute(1, 0, 2) # LND -> NLD
350
+ x = self.ln_final(x).type(self.dtype)
351
+
352
+ # x.shape = [batch_size, n_ctx, transformer.width]
353
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
354
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
355
+
356
+ return x
357
+
358
+ def forward(self, image, text):
359
+ image_features = self.encode_image(image)
360
+ text_features = self.encode_text(text)
361
+
362
+ # normalized features
363
+ image_features = image_features / image_features.norm(dim=1, keepdim=True)
364
+ text_features = text_features / text_features.norm(dim=1, keepdim=True)
365
+
366
+ # cosine similarity as logits
367
+ logit_scale = self.logit_scale.exp()
368
+ logits_per_image = logit_scale * image_features @ text_features.t()
369
+ logits_per_text = logits_per_image.t()
370
+
371
+ # shape = [global_batch_size, global_batch_size]
372
+ return logits_per_image, logits_per_text
373
+
374
+
375
+ def convert_weights(model: nn.Module):
376
+ """Convert applicable model parameters to fp16"""
377
+
378
+ def _convert_weights_to_fp16(l):
379
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
380
+ l.weight.data = l.weight.data.half()
381
+ if l.bias is not None:
382
+ l.bias.data = l.bias.data.half()
383
+
384
+ if isinstance(l, nn.MultiheadAttention):
385
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
386
+ tensor = getattr(l, attr)
387
+ if tensor is not None:
388
+ tensor.data = tensor.data.half()
389
+
390
+ for name in ["text_projection", "proj"]:
391
+ if hasattr(l, name):
392
+ attr = getattr(l, name)
393
+ if attr is not None:
394
+ attr.data = attr.data.half()
395
+
396
+ model.apply(_convert_weights_to_fp16)
397
+
398
+
399
+ def build_model(state_dict: dict):
400
+ vit = "visual.proj" in state_dict
401
+
402
+ if vit:
403
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
404
+ vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
405
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
406
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
407
+ image_resolution = vision_patch_size * grid_size
408
+ else:
409
+ counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
410
+ vision_layers = tuple(counts)
411
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
412
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
413
+ vision_patch_size = None
414
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
415
+ image_resolution = output_width * 32
416
+
417
+ embed_dim = state_dict["text_projection"].shape[1]
418
+ context_length = state_dict["positional_embedding"].shape[0]
419
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
420
+ transformer_width = state_dict["ln_final.weight"].shape[0]
421
+ transformer_heads = transformer_width // 64
422
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
423
+
424
+ model = CLIP(
425
+ embed_dim,
426
+ image_resolution, vision_layers, vision_width, vision_patch_size,
427
+ context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
428
+ )
429
+
430
+ for key in ["input_resolution", "context_length", "vocab_size"]:
431
+ if key in state_dict:
432
+ del state_dict[key]
433
+
434
+ convert_weights(model)
435
+ model.load_state_dict(state_dict)
436
+ return model.eval()
versatile_diffusion/lib/model_zoo/common/__pycache__/get_model.cpython-310.pyc ADDED
Binary file (3.31 kB). View file
 
versatile_diffusion/lib/model_zoo/common/__pycache__/get_model.cpython-38.pyc ADDED
Binary file (3.26 kB). View file
 
versatile_diffusion/lib/model_zoo/common/__pycache__/get_optimizer.cpython-310.pyc ADDED
Binary file (1.98 kB). View file
 
versatile_diffusion/lib/model_zoo/common/__pycache__/get_optimizer.cpython-38.pyc ADDED
Binary file (1.93 kB). View file
 
versatile_diffusion/lib/model_zoo/common/__pycache__/get_scheduler.cpython-310.pyc ADDED
Binary file (9.47 kB). View file
 
versatile_diffusion/lib/model_zoo/common/__pycache__/get_scheduler.cpython-38.pyc ADDED
Binary file (9.54 kB). View file
 
versatile_diffusion/lib/model_zoo/common/__pycache__/utils.cpython-310.pyc ADDED
Binary file (9.75 kB). View file
 
versatile_diffusion/lib/model_zoo/common/__pycache__/utils.cpython-38.pyc ADDED
Binary file (9.77 kB). View file
 
versatile_diffusion/lib/model_zoo/common/utils.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ import copy
6
+ import functools
7
+ import itertools
8
+
9
+ import matplotlib.pyplot as plt
10
+
11
+ ########
12
+ # unit #
13
+ ########
14
+
15
+ def singleton(class_):
16
+ instances = {}
17
+ def getinstance(*args, **kwargs):
18
+ if class_ not in instances:
19
+ instances[class_] = class_(*args, **kwargs)
20
+ return instances[class_]
21
+ return getinstance
22
+
23
+ def str2value(v):
24
+ v = v.strip()
25
+ try:
26
+ return int(v)
27
+ except:
28
+ pass
29
+ try:
30
+ return float(v)
31
+ except:
32
+ pass
33
+ if v in ('True', 'true'):
34
+ return True
35
+ elif v in ('False', 'false'):
36
+ return False
37
+ else:
38
+ return v
39
+
40
+ @singleton
41
+ class get_unit(object):
42
+ def __init__(self):
43
+ self.unit = {}
44
+ self.register('none', None)
45
+
46
+ # general convolution
47
+ self.register('conv' , nn.Conv2d)
48
+ self.register('bn' , nn.BatchNorm2d)
49
+ self.register('relu' , nn.ReLU)
50
+ self.register('relu6' , nn.ReLU6)
51
+ self.register('lrelu' , nn.LeakyReLU)
52
+ self.register('dropout' , nn.Dropout)
53
+ self.register('dropout2d', nn.Dropout2d)
54
+ self.register('sine', Sine)
55
+ self.register('relusine', ReLUSine)
56
+
57
+ def register(self,
58
+ name,
59
+ unitf,):
60
+
61
+ self.unit[name] = unitf
62
+
63
+ def __call__(self, name):
64
+ if name is None:
65
+ return None
66
+ i = name.find('(')
67
+ i = len(name) if i==-1 else i
68
+ t = name[:i]
69
+ f = self.unit[t]
70
+ args = name[i:].strip('()')
71
+ if len(args) == 0:
72
+ args = {}
73
+ return f
74
+ else:
75
+ args = args.split('=')
76
+ args = [[','.join(i.split(',')[:-1]), i.split(',')[-1]] for i in args]
77
+ args = list(itertools.chain.from_iterable(args))
78
+ args = [i.strip() for i in args if len(i)>0]
79
+ kwargs = {}
80
+ for k, v in zip(args[::2], args[1::2]):
81
+ if v[0]=='(' and v[-1]==')':
82
+ kwargs[k] = tuple([str2value(i) for i in v.strip('()').split(',')])
83
+ elif v[0]=='[' and v[-1]==']':
84
+ kwargs[k] = [str2value(i) for i in v.strip('[]').split(',')]
85
+ else:
86
+ kwargs[k] = str2value(v)
87
+ return functools.partial(f, **kwargs)
88
+
89
+ def register(name):
90
+ def wrapper(class_):
91
+ get_unit().register(name, class_)
92
+ return class_
93
+ return wrapper
94
+
95
+ class Sine(object):
96
+ def __init__(self, freq, gain=1):
97
+ self.freq = freq
98
+ self.gain = gain
99
+ self.repr = 'sine(freq={}, gain={})'.format(freq, gain)
100
+
101
+ def __call__(self, x, gain=1):
102
+ act_gain = self.gain * gain
103
+ return torch.sin(self.freq * x) * act_gain
104
+
105
+ def __repr__(self,):
106
+ return self.repr
107
+
108
+ class ReLUSine(nn.Module):
109
+ def __init(self):
110
+ super().__init__()
111
+
112
+ def forward(self, input):
113
+ a = torch.sin(30 * input)
114
+ b = nn.ReLU(inplace=False)(input)
115
+ return a+b
116
+
117
+ @register('lrelu_agc')
118
+ # class lrelu_agc(nn.Module):
119
+ class lrelu_agc(object):
120
+ """
121
+ The lrelu layer with alpha, gain and clamp
122
+ """
123
+ def __init__(self, alpha=0.1, gain=1, clamp=None):
124
+ # super().__init__()
125
+ self.alpha = alpha
126
+ if gain == 'sqrt_2':
127
+ self.gain = np.sqrt(2)
128
+ else:
129
+ self.gain = gain
130
+ self.clamp = clamp
131
+ self.repr = 'lrelu_agc(alpha={}, gain={}, clamp={})'.format(
132
+ alpha, gain, clamp)
133
+
134
+ # def forward(self, x, gain=1):
135
+ def __call__(self, x, gain=1):
136
+ x = F.leaky_relu(x, negative_slope=self.alpha, inplace=True)
137
+ act_gain = self.gain * gain
138
+ act_clamp = self.clamp * gain if self.clamp is not None else None
139
+ if act_gain != 1:
140
+ x = x * act_gain
141
+ if act_clamp is not None:
142
+ x = x.clamp(-act_clamp, act_clamp)
143
+ return x
144
+
145
+ def __repr__(self,):
146
+ return self.repr
147
+
148
+ ####################
149
+ # spatial encoding #
150
+ ####################
151
+
152
+ @register('se')
153
+ class SpatialEncoding(nn.Module):
154
+ def __init__(self,
155
+ in_dim,
156
+ out_dim,
157
+ sigma = 6,
158
+ cat_input=True,
159
+ require_grad=False,):
160
+
161
+ super().__init__()
162
+ assert out_dim % (2*in_dim) == 0, "dimension must be dividable"
163
+
164
+ n = out_dim // 2 // in_dim
165
+ m = 2**np.linspace(0, sigma, n)
166
+ m = np.stack([m] + [np.zeros_like(m)]*(in_dim-1), axis=-1)
167
+ m = np.concatenate([np.roll(m, i, axis=-1) for i in range(in_dim)], axis=0)
168
+ self.emb = torch.FloatTensor(m)
169
+ if require_grad:
170
+ self.emb = nn.Parameter(self.emb, requires_grad=True)
171
+ self.in_dim = in_dim
172
+ self.out_dim = out_dim
173
+ self.sigma = sigma
174
+ self.cat_input = cat_input
175
+ self.require_grad = require_grad
176
+
177
+ def forward(self, x, format='[n x c]'):
178
+ """
179
+ Args:
180
+ x: [n x m1],
181
+ m1 usually is 2
182
+ Outputs:
183
+ y: [n x m2]
184
+ m2 dimention number
185
+ """
186
+ if format == '[bs x c x 2D]':
187
+ xshape = x.shape
188
+ x = x.permute(0, 2, 3, 1).contiguous()
189
+ x = x.view(-1, x.size(-1))
190
+ elif format == '[n x c]':
191
+ pass
192
+ else:
193
+ raise ValueError
194
+
195
+ if not self.require_grad:
196
+ self.emb = self.emb.to(x.device)
197
+ y = torch.mm(x, self.emb.T)
198
+ if self.cat_input:
199
+ z = torch.cat([x, torch.sin(y), torch.cos(y)], dim=-1)
200
+ else:
201
+ z = torch.cat([torch.sin(y), torch.cos(y)], dim=-1)
202
+
203
+ if format == '[bs x c x 2D]':
204
+ z = z.view(xshape[0], xshape[2], xshape[3], -1)
205
+ z = z.permute(0, 3, 1, 2).contiguous()
206
+ return z
207
+
208
+ def extra_repr(self):
209
+ outstr = 'SpatialEncoding (in={}, out={}, sigma={}, cat_input={}, require_grad={})'.format(
210
+ self.in_dim, self.out_dim, self.sigma, self.cat_input, self.require_grad)
211
+ return outstr
212
+
213
+ @register('rffe')
214
+ class RFFEncoding(SpatialEncoding):
215
+ """
216
+ Random Fourier Features
217
+ """
218
+ def __init__(self,
219
+ in_dim,
220
+ out_dim,
221
+ sigma = 6,
222
+ cat_input=True,
223
+ require_grad=False,):
224
+
225
+ super().__init__(in_dim, out_dim, sigma, cat_input, require_grad)
226
+ n = out_dim // 2
227
+ m = np.random.normal(0, sigma, size=(n, in_dim))
228
+ self.emb = torch.FloatTensor(m)
229
+ if require_grad:
230
+ self.emb = nn.Parameter(self.emb, requires_grad=True)
231
+
232
+ def extra_repr(self):
233
+ outstr = 'RFFEncoding (in={}, out={}, sigma={}, cat_input={}, require_grad={})'.format(
234
+ self.in_dim, self.out_dim, self.sigma, self.cat_input, self.require_grad)
235
+ return outstr
236
+
237
+ ##########
238
+ # helper #
239
+ ##########
240
+
241
+ def freeze(net):
242
+ for m in net.modules():
243
+ if isinstance(m, (
244
+ nn.BatchNorm2d,
245
+ nn.SyncBatchNorm,)):
246
+ # inplace_abn not supported
247
+ m.eval()
248
+ for pi in net.parameters():
249
+ pi.requires_grad = False
250
+ return net
251
+
252
+ def common_init(m):
253
+ if isinstance(m, (
254
+ nn.Conv2d,
255
+ nn.ConvTranspose2d,)):
256
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
257
+ if m.bias is not None:
258
+ nn.init.constant_(m.bias, 0)
259
+ elif isinstance(m, (
260
+ nn.BatchNorm2d,
261
+ nn.SyncBatchNorm,)):
262
+ nn.init.constant_(m.weight, 1)
263
+ nn.init.constant_(m.bias, 0)
264
+ else:
265
+ pass
266
+
267
+ def init_module(module):
268
+ """
269
+ Args:
270
+ module: [nn.module] list or nn.module
271
+ a list of module to be initialized.
272
+ """
273
+ if isinstance(module, (list, tuple)):
274
+ module = list(module)
275
+ else:
276
+ module = [module]
277
+
278
+ for mi in module:
279
+ for mii in mi.modules():
280
+ common_init(mii)
281
+
282
+ def get_total_param(net):
283
+ if getattr(net, 'parameters', None) is None:
284
+ return 0
285
+ return sum(p.numel() for p in net.parameters())
286
+
287
+ def get_total_param_sum(net):
288
+ if getattr(net, 'parameters', None) is None:
289
+ return 0
290
+ with torch.no_grad():
291
+ s = sum(p.cpu().detach().numpy().sum().item() for p in net.parameters())
292
+ return s
versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/configuration_bert.cpython-310.pyc ADDED
Binary file (5.34 kB). View file
 
versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/configuration_bert.cpython-38.pyc ADDED
Binary file (5.24 kB). View file
 
versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/configuration_gpt2.cpython-310.pyc ADDED
Binary file (5.11 kB). View file
 
versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/configuration_gpt2.cpython-38.pyc ADDED
Binary file (5.02 kB). View file
 
versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/configuration_utils.cpython-310.pyc ADDED
Binary file (9.46 kB). View file
 
versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/configuration_utils.cpython-38.pyc ADDED
Binary file (9.4 kB). View file
 
versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/file_utils.cpython-310.pyc ADDED
Binary file (8.37 kB). View file
 
versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/file_utils.cpython-38.pyc ADDED
Binary file (8.32 kB). View file
 
versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/modeling_utils.cpython-38.pyc ADDED
Binary file (31.7 kB). View file
 
versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/optimus_bert.cpython-38.pyc ADDED
Binary file (54.6 kB). View file
 
versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/tokenization_gpt2.cpython-310.pyc ADDED
Binary file (9.12 kB). View file
 
versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/tokenization_utils.cpython-38.pyc ADDED
Binary file (33.5 kB). View file
 
versatile_diffusion/lib/model_zoo/optimus_models/configuration_bert.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ BERT model configuration """
17
+
18
+ from __future__ import absolute_import, division, print_function, unicode_literals
19
+
20
+ import json
21
+ import logging
22
+ import sys
23
+ from io import open
24
+
25
+ from .configuration_utils import PretrainedConfig
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
    'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
    'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json",
    'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json",
    'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json",
    'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json",
    'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json",
    'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json",
    'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json",
    'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json",
    'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json",
    'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json",
    'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json",
}


class BertConfig(PretrainedConfig):
    r"""Configuration container for ``BertModel``.

    Built either from explicit hyper-parameters (pass an ``int`` vocabulary
    size as the first argument) or from a JSON configuration file (pass its
    path as the first argument; every key/value pair is copied verbatim).

    Arguments:
        vocab_size_or_config_json_file: vocabulary size of ``inputs_ids``,
            or the path to a JSON config file.
        hidden_size: size of the encoder layers and the pooler layer.
        num_hidden_layers: number of hidden layers in the Transformer encoder.
        num_attention_heads: attention heads per encoder layer.
        intermediate_size: size of the feed-forward layer in the encoder.
        hidden_act: activation function ("gelu", "relu" or "swish", or a
            callable).
        hidden_dropout_prob: dropout for fully-connected layers.
        attention_probs_dropout_prob: dropout for attention probabilities.
        max_position_embeddings: maximum supported sequence length.
        type_vocab_size: vocabulary size of ``token_type_ids``.
        initializer_range: stddev of the truncated-normal initializer.
        layer_norm_eps: epsilon used by LayerNorm.
    """
    pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(self,
                 vocab_size_or_config_json_file=30522,
                 hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 hidden_act="gelu",
                 hidden_dropout_prob=0.1,
                 attention_probs_dropout_prob=0.1,
                 max_position_embeddings=512,
                 type_vocab_size=2,
                 initializer_range=0.02,
                 layer_norm_eps=1e-12,
                 **kwargs):
        super(BertConfig, self).__init__(**kwargs)
        first = vocab_size_or_config_json_file
        # `unicode` exists only on Python 2; the short-circuit keeps
        # this expression safe on Python 3.
        if isinstance(first, str) or (sys.version_info[0] == 2
                                      and isinstance(first, unicode)):
            # Treat the argument as a path and copy every JSON key verbatim.
            with open(first, "r", encoding='utf-8') as reader:
                for key, value in json.loads(reader.read()).items():
                    self.__dict__[key] = value
        elif isinstance(first, int):
            self.vocab_size = first
            self.hidden_size = hidden_size
            self.num_hidden_layers = num_hidden_layers
            self.num_attention_heads = num_attention_heads
            self.hidden_act = hidden_act
            self.intermediate_size = intermediate_size
            self.hidden_dropout_prob = hidden_dropout_prob
            self.attention_probs_dropout_prob = attention_probs_dropout_prob
            self.max_position_embeddings = max_position_embeddings
            self.type_vocab_size = type_vocab_size
            self.initializer_range = initializer_range
            self.layer_norm_eps = layer_norm_eps
        else:
            raise ValueError("First argument must be either a vocabulary size (int)"
                             " or the path to a pretrained model config file (str)")
versatile_diffusion/lib/model_zoo/optimus_models/configuration_gpt2.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ OpenAI GPT-2 configuration """
17
+
18
+ from __future__ import absolute_import, division, print_function, unicode_literals
19
+
20
+ import json
21
+ import logging
22
+ import sys
23
+ from io import open
24
+
25
+ from .configuration_utils import PretrainedConfig
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
    "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json",
    "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-config.json",
}


class GPT2Config(PretrainedConfig):
    """Configuration container for ``GPT2Model``.

    Built either from explicit hyper-parameters (pass an ``int`` vocabulary
    size as the first argument) or from a JSON configuration file (pass its
    path as the first argument; every key/value pair is copied verbatim).

    Args:
        vocab_size_or_config_json_file: vocabulary size of ``inputs_ids``,
            or the path to a JSON config file.
        n_positions: number of positional embeddings.
        n_ctx: size of the causal mask (usually same as ``n_positions``).
        n_embd: dimensionality of embeddings and hidden states.
        n_layer: number of hidden Transformer layers.
        n_head: attention heads per attention layer.
        resid_pdrop: dropout for fully-connected layers.
        embd_pdrop: dropout for the embeddings.
        attn_pdrop: dropout for the attention probabilities.
        layer_norm_epsilon: epsilon used by the layer-norm layers.
        initializer_range: stddev of the truncated-normal initializer.
    """
    pretrained_config_archive_map = GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(
        self,
        vocab_size_or_config_json_file=50257,
        n_positions=1024,
        n_ctx=1024,
        n_embd=768,
        n_layer=12,
        n_head=12,
        resid_pdrop=0.1,
        embd_pdrop=0.1,
        attn_pdrop=0.1,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,

        num_labels=1,
        summary_type='cls_index',
        summary_use_proj=True,
        summary_activation=None,
        summary_proj_to_labels=True,
        summary_first_dropout=0.1,
        **kwargs
    ):
        super(GPT2Config, self).__init__(**kwargs)

        first = vocab_size_or_config_json_file
        # `unicode` exists only on Python 2; the short-circuit keeps
        # this expression safe on Python 3.
        if isinstance(first, str) or (sys.version_info[0] == 2
                                      and isinstance(first, unicode)):
            # Treat the argument as a path and copy every JSON key verbatim.
            with open(first, "r", encoding="utf-8") as reader:
                for key, value in json.loads(reader.read()).items():
                    self.__dict__[key] = value
        elif isinstance(first, int):
            self.vocab_size = first
            self.n_ctx = n_ctx
            self.n_positions = n_positions
            self.n_embd = n_embd
            self.n_layer = n_layer
            self.n_head = n_head
            self.resid_pdrop = resid_pdrop
            self.embd_pdrop = embd_pdrop
            self.attn_pdrop = attn_pdrop
            self.layer_norm_epsilon = layer_norm_epsilon
            self.initializer_range = initializer_range

            self.num_labels = num_labels
            self.summary_type = summary_type
            self.summary_use_proj = summary_use_proj
            self.summary_activation = summary_activation
            self.summary_first_dropout = summary_first_dropout
            self.summary_proj_to_labels = summary_proj_to_labels
        else:
            raise ValueError(
                "First argument must be either a vocabulary size (int)"
                "or the path to a pretrained model config file (str)"
            )

    # Aliases exposing the generic BERT-style attribute names.
    @property
    def max_position_embeddings(self):
        return self.n_positions

    @property
    def hidden_size(self):
        return self.n_embd

    @property
    def num_attention_heads(self):
        return self.n_head

    @property
    def num_hidden_layers(self):
        return self.n_layer
versatile_diffusion/lib/model_zoo/optimus_models/configuration_utils.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ Configuration base class and utilities."""
17
+
18
+ from __future__ import (absolute_import, division, print_function,
19
+ unicode_literals)
20
+
21
+ import copy
22
+ import json
23
+ import logging
24
+ import os
25
+ from io import open
26
+
27
+ from .file_utils import cached_path, CONFIG_NAME
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
class PretrainedConfig(object):
    r"""Base class for all model configuration objects.

    Carries the parameters common to every model configuration and the
    load/download/save machinery (:meth:`from_pretrained`,
    :meth:`save_pretrained`).

    Note:
        Loading a configuration file does **not** load the model weights;
        it only sets up the model's configuration.

    Class attributes (overridden by derived classes):
        - ``pretrained_config_archive_map``: dict mapping `short-cut-names`
          (string) to the `url` (string) of the pretrained configuration.

    Parameters:
        ``finetuning_task``: string, default `None`. Task the model was
            fine-tuned on (used when converting original checkpoints).
        ``num_labels``: integer, default `2`. Number of classes for
            classification models.
        ``output_attentions``: boolean, default `False`. Return attention
            weights from the model.
        ``output_hidden_states``: boolean, default `False`. Return all
            hidden states from the model.
        ``torchscript``: boolean, default `False`. Model is used with
            TorchScript.
    """
    pretrained_config_archive_map = {}

    def __init__(self, **kwargs):
        # Only the shared knobs are consumed here; everything else is left
        # in kwargs for the derived class.
        self.finetuning_task = kwargs.pop('finetuning_task', None)
        self.num_labels = kwargs.pop('num_labels', 2)
        self.output_attentions = kwargs.pop('output_attentions', False)
        self.output_hidden_states = kwargs.pop('output_hidden_states', False)
        self.torchscript = kwargs.pop('torchscript', False)
        self.pruned_heads = kwargs.pop('pruned_heads', {})

    def save_pretrained(self, save_directory):
        """Write this configuration into `save_directory` so it can later be
        reloaded with :meth:`from_pretrained`.
        """
        assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved"
        # Saving under the canonical name is what makes `from_pretrained`
        # able to find it again.
        self.to_json_file(os.path.join(save_directory, CONFIG_NAME))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        r"""Instantiate a :class:`PretrainedConfig` (or derived class) from a
        pre-trained model configuration.

        Parameters:
            pretrained_model_name_or_path: a shortcut name (e.g.
                ``bert-base-uncased``), a directory previously populated by
                :meth:`save_pretrained`, or a path/url to a config JSON file.
            cache_dir: (`optional`) string: overrides the default cache
                directory for downloaded configurations.
            kwargs: (`optional`) dict: key/value pairs used to override
                loaded configuration attributes; unknown keys are handled
                per ``return_unused_kwargs``.
            force_download: (`optional`) boolean, default False: re-download
                even if a cached copy exists.
            proxies: (`optional`) dict, default None: proxy servers by
                protocol or endpoint, passed to each request.
            return_unused_kwargs: (`optional`) bool: when True, return
                ``(config, unused_kwargs)`` instead of just ``config``.

        Examples::

            config = BertConfig.from_pretrained('bert-base-uncased')
            config = BertConfig.from_pretrained('./test/saved_model/')
            config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json')
            config, unused = BertConfig.from_pretrained(
                'bert-base-uncased', output_attention=True, foo=False,
                return_unused_kwargs=True)
        """
        cache_dir = kwargs.pop('cache_dir', None)
        force_download = kwargs.pop('force_download', False)
        proxies = kwargs.pop('proxies', None)
        return_unused_kwargs = kwargs.pop('return_unused_kwargs', False)

        # Resolve the three accepted forms: shortcut name, directory, file/url.
        if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
            config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
        elif os.path.isdir(pretrained_model_name_or_path):
            config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
        else:
            config_file = pretrained_model_name_or_path

        # Redirect to the cache, downloading if necessary.
        try:
            resolved_config_file = cached_path(config_file, cache_dir=cache_dir,
                                               force_download=force_download, proxies=proxies)
        except EnvironmentError as e:
            if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
                logger.error(
                    "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
                        config_file))
            else:
                logger.error(
                    "Model name '{}' was not found in model name list ({}). "
                    "We assumed '{}' was a path or url but couldn't find any file "
                    "associated to this path or url.".format(
                        pretrained_model_name_or_path,
                        ', '.join(cls.pretrained_config_archive_map.keys()),
                        config_file))
            raise e

        if resolved_config_file == config_file:
            logger.info("loading configuration file {}".format(config_file))
        else:
            logger.info("loading configuration file {} from cache at {}".format(
                config_file, resolved_config_file))

        config = cls.from_json_file(resolved_config_file)

        if hasattr(config, 'pruned_heads'):
            # JSON keys are strings and values are lists; restore int keys
            # and set values.
            config.pruned_heads = {int(key): set(value) for key, value in config.pruned_heads.items()}

        # Override loaded attributes with any matching kwargs; whatever does
        # not match stays in kwargs.
        for key in [k for k in kwargs if hasattr(config, k)]:
            setattr(config, key, kwargs.pop(key))

        logger.info("Model config %s", config)
        return (config, kwargs) if return_unused_kwargs else config

    @classmethod
    def from_dict(cls, json_object):
        """Build a config whose attributes are exactly *json_object*."""
        config = cls(vocab_size_or_config_json_file=-1)
        for key, value in json_object.items():
            config.__dict__[key] = value
        return config

    @classmethod
    def from_json_file(cls, json_file):
        """Build a config from a JSON file of parameters."""
        with open(json_file, "r", encoding='utf-8') as reader:
            return cls.from_dict(json.loads(reader.read()))

    def __eq__(self, other):
        return self.__dict__ == other.__dict__

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Return a deep copy of this configuration as a plain dict."""
        return copy.deepcopy(self.__dict__)

    def to_json_string(self):
        """Serialize this configuration to a JSON string."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"

    def to_json_file(self, json_file_path):
        """Write this configuration to *json_file_path* as JSON."""
        with open(json_file_path, "w", encoding='utf-8') as writer:
            writer.write(self.to_json_string())
versatile_diffusion/lib/model_zoo/optimus_models/file_utils.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utilities for working with the local dataset cache.
3
+ This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
4
+ Copyright by the AllenNLP authors.
5
+ """
6
+ from __future__ import (absolute_import, division, print_function, unicode_literals)
7
+
8
+ import sys
9
+ import json
10
+ import logging
11
+ import os
12
+ import six
13
+ import shutil
14
+ import tempfile
15
+ import fnmatch
16
+ from functools import wraps
17
+ from hashlib import sha256
18
+ from io import open
19
+
20
+ # import boto3
21
+ # from botocore.config import Config
22
+ # from botocore.exceptions import ClientError
23
+ import requests
24
+ from tqdm import tqdm
25
+
26
+ try:
27
+ from torch.hub import _get_torch_home
28
+ torch_cache_home = _get_torch_home()
29
+ except ImportError:
30
+ torch_cache_home = os.path.expanduser(
31
+ os.getenv('TORCH_HOME', os.path.join(
32
+ os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch')))
33
+ default_cache_path = os.path.join(torch_cache_home, 'pytorch_transformers')
34
+
35
+ try:
36
+ from urllib.parse import urlparse
37
+ except ImportError:
38
+ from urlparse import urlparse
39
+
40
+ try:
41
+ from pathlib import Path
42
+ PYTORCH_PRETRAINED_BERT_CACHE = Path(
43
+ os.getenv('PYTORCH_TRANSFORMERS_CACHE', os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path)))
44
+ except (AttributeError, ImportError):
45
+ PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_TRANSFORMERS_CACHE',
46
+ os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
47
+ default_cache_path))
48
+
49
+ PYTORCH_TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility
50
+
51
+ WEIGHTS_NAME = "pytorch_model.bin"
52
+ TF_WEIGHTS_NAME = 'model.ckpt'
53
+ CONFIG_NAME = "config.json"
54
+
55
+ logger = logging.getLogger(__name__) # pylint: disable=invalid-name
56
+
57
if six.PY2:
    # Class docstrings cannot be rewritten on Python 2, so both decorators
    # degrade to no-ops there.
    def add_start_docstrings(*docstr):
        def docstring_decorator(fn):
            return fn
        return docstring_decorator

    def add_end_docstrings(*docstr):
        def docstring_decorator(fn):
            return fn
        return docstring_decorator
else:
    def add_start_docstrings(*docstr):
        # Prepend the joined fragments to the decorated callable's __doc__.
        def docstring_decorator(fn):
            fn.__doc__ = ''.join(docstr) + fn.__doc__
            return fn
        return docstring_decorator

    def add_end_docstrings(*docstr):
        # Append the joined fragments to the decorated callable's __doc__.
        def docstring_decorator(fn):
            fn.__doc__ = fn.__doc__ + ''.join(docstr)
            return fn
        return docstring_decorator
80
+
81
def url_to_filename(url, etag=None):
    """
    Map `url` to a repeatable hashed filename.

    The name is the sha256 hex digest of the url; when `etag` is given,
    a period plus the sha256 hex digest of the etag is appended, so each
    remote version gets its own cache entry.
    """
    name = sha256(url.encode('utf-8')).hexdigest()
    if etag:
        name += '.' + sha256(etag.encode('utf-8')).hexdigest()
    return name
97
+
98
+
99
def filename_to_url(filename, cache_dir=None):
    """
    Return the ``(url, etag)`` recorded for the cached file `filename`
    (etag may be ``None``).

    Raises ``EnvironmentError`` when the cached file or its ``.json``
    metadata sidecar is missing.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_TRANSFORMERS_CACHE
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    cache_path = os.path.join(cache_dir, filename)
    if not os.path.exists(cache_path):
        raise EnvironmentError("file {} not found".format(cache_path))

    # Metadata lives in a sidecar file next to the cached blob.
    meta_path = cache_path + '.json'
    if not os.path.exists(meta_path):
        raise EnvironmentError("file {} not found".format(meta_path))

    with open(meta_path, encoding="utf-8") as meta_file:
        metadata = json.load(meta_file)
    return metadata['url'], metadata['etag']
123
+
124
+
125
def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=None):
    """
    Resolve `url_or_filename` to a local file path.

    URLs (http/https/s3) are downloaded into the cache (if not already
    there) and the cached path is returned; existing local paths are
    returned unchanged.

    Args:
        cache_dir: cache directory to use instead of the default.
        force_download: if True, re-download even when already cached.

    Raises:
        EnvironmentError: a local path that does not exist.
        ValueError: something that is neither a known URL scheme nor a path.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_TRANSFORMERS_CACHE
    if sys.version_info[0] == 3:
        # Accept pathlib.Path for either argument on Python 3.
        if isinstance(url_or_filename, Path):
            url_or_filename = str(url_or_filename)
        if isinstance(cache_dir, Path):
            cache_dir = str(cache_dir)

    scheme = urlparse(url_or_filename).scheme
    if scheme in ('http', 'https', 's3'):
        # URL: serve from the cache, downloading if necessary.
        return get_from_cache(url_or_filename, cache_dir=cache_dir,
                              force_download=force_download, proxies=proxies)
    if os.path.exists(url_or_filename):
        # Existing local file.
        return url_or_filename
    if scheme == '':
        # Looks like a local path, but nothing is there.
        raise EnvironmentError("file {} not found".format(url_or_filename))
    # Unknown scheme.
    raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
156
+
157
+
158
def split_s3_path(url):
    """Split an ``s3://bucket/key`` url into ``(bucket, key)``.

    Raises ``ValueError`` when either component is missing.
    """
    parsed = urlparse(url)
    if not parsed.netloc or not parsed.path:
        raise ValueError("bad s3 path {}".format(url))
    key = parsed.path
    # urlparse keeps a leading '/' on the key; drop exactly one.
    if key.startswith("/"):
        key = key[1:]
    return parsed.netloc, key
169
+
170
+
171
def s3_request(func):
    """
    Decorator for S3 helpers: translates a botocore 404 ``ClientError``
    into an ``EnvironmentError`` with a friendlier message; any other
    ``ClientError`` is re-raised unchanged.
    """

    @wraps(func)
    def wrapper(url, *args, **kwargs):
        try:
            return func(url, *args, **kwargs)
        except ClientError as exc:
            if int(exc.response["Error"]["Code"]) != 404:
                raise
            raise EnvironmentError("file {} not found".format(url))

    return wrapper
188
+
189
+
190
@s3_request
def s3_etag(url, proxies=None):
    """Return the ETag of the S3 object at *url*.

    NOTE(review): the boto3/botocore imports are commented out at the top
    of this file, so calling this currently raises NameError -- confirm
    whether the S3 path is meant to be supported.
    """
    resource = boto3.resource("s3", config=Config(proxies=proxies))
    bucket_name, s3_path = split_s3_path(url)
    return resource.Object(bucket_name, s3_path).e_tag
197
+
198
+
199
@s3_request
def s3_get(url, temp_file, proxies=None):
    """Download the S3 object at *url* into the open file *temp_file*.

    NOTE(review): requires boto3, whose import is commented out at the top
    of this file -- confirm whether the S3 path is meant to be supported.
    """
    resource = boto3.resource("s3", config=Config(proxies=proxies))
    bucket_name, s3_path = split_s3_path(url)
    resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
205
+
206
+
207
def http_get(url, temp_file, proxies=None):
    """Stream *url* into the open file *temp_file*, showing a tqdm
    progress bar (total is unknown when the server sends no
    Content-Length header)."""
    response = requests.get(url, stream=True, proxies=proxies)
    length = response.headers.get('Content-Length')
    progress = tqdm(unit="B", total=int(length) if length is not None else None)
    for chunk in response.iter_content(chunk_size=1024):
        if not chunk:
            # Keep-alive chunks are empty; skip them.
            continue
        progress.update(len(chunk))
        temp_file.write(chunk)
    progress.close()
217
+
218
+
219
def get_from_cache(url, cache_dir=None, force_download=False, proxies=None):
    """
    Resolve *url* to a file in the local cache, downloading it first when
    missing (or when ``force_download``), and return the cached path.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_TRANSFORMERS_CACHE
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)
    if sys.version_info[0] == 2 and not isinstance(cache_dir, str):
        cache_dir = str(cache_dir)

    if not os.path.exists(cache_dir):
        os.makedirs(cache_dir)

    # The ETag (when reachable) becomes part of the cache key, so a new
    # remote version produces a new cache entry.
    if url.startswith("s3://"):
        etag = s3_etag(url, proxies=proxies)
    else:
        try:
            response = requests.head(url, allow_redirects=True, proxies=proxies)
            etag = response.headers.get("ETag") if response.status_code == 200 else None
        except EnvironmentError:
            etag = None

    if sys.version_info[0] == 2 and etag is not None:
        etag = etag.decode('utf-8')
    filename = url_to_filename(url, etag)

    cache_path = os.path.join(cache_dir, filename)

    # Offline fallback: with no ETag and no exact match, reuse the most
    # recently listed previously-downloaded variant of this url, if any.
    if not os.path.exists(cache_path) and etag is None:
        candidates = fnmatch.filter(os.listdir(cache_dir), filename + '.*')
        candidates = [c for c in candidates if not c.endswith('.json')]
        if candidates:
            cache_path = os.path.join(cache_dir, candidates[-1])

    if not os.path.exists(cache_path) or force_download:
        # Download into a temporary file and copy to the cache only once
        # complete, so an interrupted download never leaves a corrupt entry.
        with tempfile.NamedTemporaryFile() as temp_file:
            logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)

            if url.startswith("s3://"):
                s3_get(url, temp_file, proxies=proxies)
            else:
                http_get(url, temp_file, proxies=proxies)

            # Flush, then rewind: copyfileobj copies from the current offset.
            temp_file.flush()
            temp_file.seek(0)

            logger.info("copying %s to cache at %s", temp_file.name, cache_path)
            with open(cache_path, 'wb') as cache_file:
                shutil.copyfileobj(temp_file, cache_file)

            logger.info("creating metadata file for %s", cache_path)
            meta = {'url': url, 'etag': etag}
            meta_path = cache_path + '.json'
            with open(meta_path, 'w') as meta_file:
                output_string = json.dumps(meta)
                if sys.version_info[0] == 2 and isinstance(output_string, str):
                    # On py2, json.dumps may return bytes; write unicode.
                    output_string = unicode(output_string, 'utf-8')
                meta_file.write(output_string)

            logger.info("removing temp file %s", temp_file.name)

    return cache_path
versatile_diffusion/lib/model_zoo/optimus_models/modeling_utils.py ADDED
@@ -0,0 +1,780 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch BERT model."""
17
+
18
+ from __future__ import (absolute_import, division, print_function,
19
+ unicode_literals)
20
+
21
+
22
+ import pdb
23
+ import copy
24
+ import json
25
+ import logging
26
+ import os
27
+ from io import open
28
+
29
+ import six
30
+ import torch
31
+ from torch import nn
32
+ from torch.nn import CrossEntropyLoss
33
+ from torch.nn import functional as F
34
+
35
+ from .configuration_utils import PretrainedConfig
36
+ from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME
37
+
38
+ logger = logging.getLogger(__name__)
39
+
40
+
41
try:
    from torch.nn import Identity
except ImportError:
    # Older PyTorch compatibility: nn.Identity only exists in newer releases,
    # so provide a drop-in replacement with the same contract.
    class Identity(nn.Module):
        r"""A placeholder identity operator that is argument-insensitive.

        Accepts (and ignores) any constructor arguments; ``forward`` returns
        its input unchanged.
        """
        def __init__(self, *args, **kwargs):
            super(Identity, self).__init__()

        def forward(self, input):
            # Pass-through: the input tensor is returned as-is.
            return input
53
+
54
class PreTrainedModel(nn.Module):
    r""" Base class for all models.

    Stores the model configuration and provides methods common to all models:
    loading/downloading/saving weights, resizing the input token embeddings and
    pruning self-attention heads.

    Class attributes (overridden by derived classes):
        - ``config_class``: a class derived from :class:`~pytorch_transformers.PretrainedConfig` to use as configuration class for this model architecture.
        - ``pretrained_model_archive_map``: a python ``dict`` with `short-cut-names` (string) as keys and `url` (string) of associated pretrained weights as values.
        - ``load_tf_weights``: a python ``method`` for loading a TensorFlow checkpoint in a PyTorch model, taking as arguments:

            - ``model``: an instance of the relevant subclass of :class:`~pytorch_transformers.PreTrainedModel`,
            - ``config``: an instance of the relevant subclass of :class:`~pytorch_transformers.PretrainedConfig`,
            - ``path``: a path (string) to the TensorFlow checkpoint.

        - ``base_model_prefix``: a string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model.
    """
    config_class = None
    pretrained_model_archive_map = {}
    load_tf_weights = lambda model, config, path: None
    base_model_prefix = ""

    def __init__(self, config, *inputs, **kwargs):
        super(PreTrainedModel, self).__init__()
        if not isinstance(config, PretrainedConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
                "To create a model from a pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                ))
        # Save config in model
        self.config = config

    def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
        """ Build a resized Embedding Module from a provided token Embedding Module.
            Increasing the size will add newly initialized vectors at the end.
            Reducing the size will remove vectors from the end.

        Args:
            new_num_tokens: (`optional`) int
                New number of tokens in the embedding matrix.
                Increasing the size will add newly initialized vectors at the end.
                Reducing the size will remove vectors from the end.
                If not provided or None: return the provided token Embedding Module.
        Return: ``torch.nn.Embeddings``
            Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None
        """
        if new_num_tokens is None:
            return old_embeddings

        old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
        if old_num_tokens == new_num_tokens:
            # Nothing to resize; reuse the existing module.
            return old_embeddings

        # Build new embeddings
        new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
        new_embeddings.to(old_embeddings.weight.device)

        # initialize all new embeddings (in particular added tokens)
        self._init_weights(new_embeddings)

        # Copy word embeddings from the previous weights
        num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
        new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]

        return new_embeddings

    def _tie_or_clone_weights(self, first_module, second_module):
        """ Tie or clone module weights depending on whether we are using TorchScript or not
        """
        if self.config.torchscript:
            # TorchScript cannot share parameters, so copy instead of tying.
            first_module.weight = nn.Parameter(second_module.weight.clone())
        else:
            first_module.weight = second_module.weight

        if hasattr(first_module, 'bias') and first_module.bias is not None:
            # Zero-pad the bias so its length matches the (possibly larger)
            # tied weight's first dimension.
            first_module.bias.data = torch.nn.functional.pad(
                first_module.bias.data,
                (0, first_module.weight.shape[0] - first_module.bias.shape[0]),
                'constant',
                0
            )

    def resize_token_embeddings(self, new_num_tokens=None):
        """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
        Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.

        Arguments:

            new_num_tokens: (`optional`) int:
                New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end.
                If not provided or None: does nothing and just returns a pointer to the input tokens ``torch.nn.Embeddings`` Module of the model.

        Return: ``torch.nn.Embeddings``
            Pointer to the input tokens Embeddings Module of the model
        """
        base_model = getattr(self, self.base_model_prefix, self)  # get the base model if needed

        model_embeds = base_model._resize_token_embeddings(new_num_tokens)
        if new_num_tokens is None:
            return model_embeds

        # Update base model and current model config
        self.config.vocab_size = new_num_tokens
        base_model.vocab_size = new_num_tokens

        # Tie weights again if needed
        if hasattr(self, 'tie_weights'):
            self.tie_weights()

        return model_embeds

    def init_weights(self):
        """ Initialize and prunes weights if needed. """
        # Initialize weights
        self.apply(self._init_weights)

        # Prune heads if needed
        if self.config.pruned_heads:
            self.prune_heads(self.config.pruned_heads)

    def prune_heads(self, heads_to_prune):
        """ Prunes heads of the base model.

        Arguments:

            heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`).
                E.g. {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
        """
        base_model = getattr(self, self.base_model_prefix, self)  # get the base model if needed

        # save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads
        for layer, heads in heads_to_prune.items():
            union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads)
            self.config.pruned_heads[layer] = list(union_heads)  # Unfortunately we have to store it as list for JSON

        base_model._prune_heads(heads_to_prune)

    def save_pretrained(self, save_directory):
        """ Save a model and its configuration file to a directory, so that it
            can be re-loaded using the `:func:`~pytorch_transformers.PreTrainedModel.from_pretrained`` class method.
        """
        assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved"

        # Only save the model it-self if we are using distributed training
        model_to_save = self.module if hasattr(self, 'module') else self

        # Save configuration file
        model_to_save.config.save_pretrained(save_directory)

        # If we save using the predefined names, we can load using `from_pretrained`
        output_model_file = os.path.join(save_directory, WEIGHTS_NAME)

        torch.save(model_to_save.state_dict(), output_model_file)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r"""Instantiate a pretrained pytorch model from a pre-trained model configuration.

        The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated).
        To train the model, you should first set it back in training mode with ``model.train()``.

        Parameters:
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a pre-trained model (e.g. ``bert-base-uncased``),
                - a path to a `directory` containing weights saved with :func:`save_pretrained`,
                - a path or url to a `tensorflow index checkpoint file` (requires ``from_tf=True`` and a ``config``).

            model_args: (`optional`) positional arguments forwarded to the model's ``__init__``.
            config: (`optional`) a :class:`~pytorch_transformers.PretrainedConfig` instance; loaded automatically when omitted.
            state_dict: (`optional`) dict to use instead of a state dictionary loaded from the saved weights file.
            cache_dir: (`optional`) string: directory in which to cache downloaded files.
            force_download: (`optional`) boolean, default False: re-download files and override any cached versions.
            proxies: (`optional`) dict, default None: proxy servers to use by protocol or endpoint for each request.
            output_loading_info: (`optional`) boolean: also return a dict with missing/unexpected keys and error messages.
            kwargs: (`optional`) remaining keyword arguments: passed to the configuration load when no ``config``
                is given (matching keys override config attributes, the rest go to the model ``__init__``);
                passed directly to the model ``__init__`` otherwise.

        Examples::

            model = BertModel.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = BertModel.from_pretrained('./test/saved_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
            model = BertModel.from_pretrained('bert-base-uncased', output_attention=True)  # Update configuration during loading
            assert model.config.output_attention == True
            # Loading from a TF checkpoint file instead of a PyTorch model (slower)
            config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json')
            model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)

        """
        config = kwargs.pop('config', None)
        state_dict = kwargs.pop('state_dict', None)
        cache_dir = kwargs.pop('cache_dir', None)
        from_tf = kwargs.pop('from_tf', False)
        force_download = kwargs.pop('force_download', False)
        proxies = kwargs.pop('proxies', None)
        output_loading_info = kwargs.pop('output_loading_info', False)

        # Load config
        if config is None:
            config, model_kwargs = cls.config_class.from_pretrained(
                pretrained_model_name_or_path, *model_args,
                cache_dir=cache_dir, return_unused_kwargs=True,
                force_download=force_download,
                **kwargs
            )
        else:
            model_kwargs = kwargs

        # Load model: resolve the weights location (archive map entry,
        # local directory, or direct path/url).
        if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
            archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
        elif os.path.isdir(pretrained_model_name_or_path):
            if from_tf:
                # Directly load from a TensorFlow checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
            else:
                archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
        else:
            if from_tf:
                # Directly load from a TensorFlow checkpoint
                archive_file = pretrained_model_name_or_path + ".index"
            else:
                archive_file = pretrained_model_name_or_path
        # redirect to the cache, if necessary
        try:
            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
        except EnvironmentError as e:
            if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
                logger.error(
                    "Couldn't reach server at '{}' to download pretrained weights.".format(
                        archive_file))
            else:
                logger.error(
                    "Model name '{}' was not found in model name list ({}). "
                    "We assumed '{}' was a path or url but couldn't find any file "
                    "associated to this path or url.".format(
                        pretrained_model_name_or_path,
                        ', '.join(cls.pretrained_model_archive_map.keys()),
                        archive_file))
            raise e
        if resolved_archive_file == archive_file:
            logger.info("loading weights file {}".format(archive_file))
        else:
            logger.info("loading weights file {} from cache at {}".format(
                archive_file, resolved_archive_file))

        # Instantiate model.
        model = cls(config, *model_args, **model_kwargs)

        if state_dict is None and not from_tf:
            state_dict = torch.load(resolved_archive_file, map_location='cpu')
        if from_tf:
            # Directly load from a TensorFlow checkpoint
            return cls.load_tf_weights(model, config, resolved_archive_file[:-6])  # Remove the '.index'

        # Convert old format to new format if needed from a PyTorch state_dict
        # (old TF-style checkpoints named LayerNorm parameters gamma/beta).
        old_keys = []
        new_keys = []
        for key in state_dict.keys():
            new_key = None
            if 'gamma' in key:
                new_key = key.replace('gamma', 'weight')
            if 'beta' in key:
                new_key = key.replace('beta', 'bias')
            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)
        for old_key, new_key in zip(old_keys, new_keys):
            state_dict[new_key] = state_dict.pop(old_key)

        # Load from a PyTorch state_dict
        missing_keys = []
        unexpected_keys = []
        error_msgs = []
        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        def load(module, prefix=''):
            # Recursively load each submodule's parameters, collecting
            # missing/unexpected keys and error messages.
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + '.')

        # Make sure we are able to load base models as well as derived models (with heads)
        start_prefix = ''
        model_to_load = model
        if not hasattr(model, cls.base_model_prefix) and any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
            start_prefix = cls.base_model_prefix + '.'
        if hasattr(model, cls.base_model_prefix) and not any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
            model_to_load = getattr(model, cls.base_model_prefix)

        load(model_to_load, prefix=start_prefix)
        if len(missing_keys) > 0:
            logger.info("Weights of {} not initialized from pretrained model: {}".format(
                model.__class__.__name__, missing_keys))
        if len(unexpected_keys) > 0:
            logger.info("Weights from pretrained model not used in {}: {}".format(
                model.__class__.__name__, unexpected_keys))
        if len(error_msgs) > 0:
            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                model.__class__.__name__, "\n\t".join(error_msgs)))

        if hasattr(model, 'tie_weights'):
            model.tie_weights()  # make sure word embedding weights are still tied

        # Set model in evaluation mode to deactivate DropOut modules by default
        model.eval()

        if output_loading_info:
            loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs}
            return model, loading_info

        return model
406
+
407
+
408
class Conv1D(nn.Module):
    """Conv1D layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).

    Behaves like a Linear layer whose weight matrix is stored transposed:
    ``weight`` has shape ``(nx, nf)`` so the output feature dimension is ``nf``.
    """
    def __init__(self, nf, nx):
        super(Conv1D, self).__init__()
        self.nf = nf
        # Normal(0, 0.02) initialization, as in the GPT reference code.
        init_weight = torch.empty(nx, nf)
        nn.init.normal_(init_weight, std=0.02)
        self.weight = nn.Parameter(init_weight)
        self.bias = nn.Parameter(torch.zeros(nf))

    def forward(self, x):
        # Collapse all leading dims, apply the affine map, then restore the
        # shape with the trailing feature dim replaced by nf.
        out_shape = x.size()[:-1] + (self.nf,)
        flat = x.view(-1, x.size(-1))
        projected = torch.addmm(self.bias, flat, self.weight)
        return projected.view(*out_shape)
425
+
426
+
427
class PoolerStartLogits(nn.Module):
    """ Compute SQuAD start_logits from sequence hidden states. """
    def __init__(self, config):
        super(PoolerStartLogits, self).__init__()
        self.dense = nn.Linear(config.hidden_size, 1)

    def forward(self, hidden_states, p_mask=None):
        """ Args:
            **p_mask**: (`optional`) ``torch.FloatTensor`` of shape `(batch_size, seq_len)`
                invalid position mask such as query and special symbols (PAD, SEP, CLS)
                1.0 means token should be masked.
        """
        logits = self.dense(hidden_states).squeeze(-1)

        if p_mask is not None:
            # Push masked positions to a very large negative value; fp16
            # cannot represent 1e30, so use 65500 there instead.
            fp16 = next(self.parameters()).dtype == torch.float16
            neg_inf = 65500 if fp16 else 1e30
            logits = logits * (1 - p_mask) - neg_inf * p_mask

        return logits
448
+
449
+
450
class PoolerEndLogits(nn.Module):
    """ Compute SQuAD end_logits from sequence hidden states and start token hidden state.
    """
    def __init__(self, config):
        super(PoolerEndLogits, self).__init__()
        self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
        self.activation = nn.Tanh()
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dense_1 = nn.Linear(config.hidden_size, 1)

    def forward(self, hidden_states, start_states=None, start_positions=None, p_mask=None):
        """ Args:
            One of ``start_states``, ``start_positions`` should be not None.
            If both are set, ``start_positions`` overrides ``start_states``.

            **start_states**: ``torch.LongTensor`` of shape identical to hidden_states
                hidden states of the first tokens for the labeled span.
            **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
                position of the first token for the labeled span.
            **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
                Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
                1.0 means token should be masked.
        """
        assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None"
        if start_positions is not None:
            # Gather the hidden state at each start position and broadcast it
            # across the sequence length: (bsz, 1, hsz) -> (bsz, slen, hsz).
            slen, hsz = hidden_states.shape[-2:]
            gather_idx = start_positions[:, None, None].expand(-1, -1, hsz)
            start_states = hidden_states.gather(-2, gather_idx).expand(-1, slen, -1)

        features = torch.cat([hidden_states, start_states], dim=-1)
        features = self.LayerNorm(self.activation(self.dense_0(features)))
        logits = self.dense_1(features).squeeze(-1)

        if p_mask is not None:
            logits = logits * (1 - p_mask) - 1e30 * p_mask

        return logits
489
+
490
+
491
class PoolerAnswerClass(nn.Module):
    """ Compute SQuAD 2.0 answer class from classification and start tokens hidden states. """
    def __init__(self, config):
        super(PoolerAnswerClass, self).__init__()
        self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
        self.activation = nn.Tanh()
        self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False)

    def forward(self, hidden_states, start_states=None, start_positions=None, cls_index=None):
        """
        Args:
            One of ``start_states``, ``start_positions`` should be not None.
            If both are set, ``start_positions`` overrides ``start_states``.

            **start_states**: ``torch.LongTensor`` of shape identical to ``hidden_states``.
                hidden states of the first tokens for the labeled span.
            **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
                position of the first token for the labeled span.
            **cls_index**: torch.LongTensor of shape ``(batch_size,)``
                position of the CLS token. If None, take the last token.

        note(Original repo):
            no dependency on end_feature so that we can obtain one single `cls_logits`
            for each sample
        """
        hsz = hidden_states.shape[-1]
        assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None"
        if start_positions is not None:
            # Pull out the hidden state at each start position: (bsz, hsz).
            gather_idx = start_positions[:, None, None].expand(-1, -1, hsz)
            start_states = hidden_states.gather(-2, gather_idx).squeeze(-2)

        if cls_index is None:
            # Default to the last token's hidden state as the CLS state.
            cls_token_state = hidden_states[:, -1, :]
        else:
            gather_idx = cls_index[:, None, None].expand(-1, -1, hsz)
            cls_token_state = hidden_states.gather(-2, gather_idx).squeeze(-2)

        combined = torch.cat([start_states, cls_token_state], dim=-1)
        return self.dense_1(self.activation(self.dense_0(combined))).squeeze(-1)
533
+
534
+
535
class SQuADHead(nn.Module):
    r""" A SQuAD head inspired by XLNet.

    Parameters:
        config (:class:`~pytorch_transformers.XLNetConfig`): Model configuration class with all the parameters of the model.

    Inputs:
        **hidden_states**: ``torch.FloatTensor`` of shape ``(batch_size, seq_len, hidden_size)``
            hidden states of sequence tokens
        **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
            position of the first token for the labeled span.
        **end_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
            position of the last token for the labeled span.
        **cls_index**: torch.LongTensor of shape ``(batch_size,)``
            position of the CLS token. If None, take the last token.
        **is_impossible**: ``torch.LongTensor`` of shape ``(batch_size,)``
            Whether the question has a possible answer in the paragraph or not.
        **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
            Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
            1.0 means token should be masked.

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned if both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses.
        **start_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
            ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)``
            Log probabilities for the top config.start_n_top start token possibilities (beam-search).
        **start_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
            ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)``
            Indices for the top config.start_n_top start token possibilities (beam-search).
        **end_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
            ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
            Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
        **end_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
            ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
            Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
        **cls_logits**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
            ``torch.FloatTensor`` of shape ``(batch_size,)``
            Log probabilities for the ``is_impossible`` label of the answers.
    """
    def __init__(self, config):
        super(SQuADHead, self).__init__()
        # Beam widths for the start/end token search at inference time.
        self.start_n_top = config.start_n_top
        self.end_n_top = config.end_n_top

        self.start_logits = PoolerStartLogits(config)
        self.end_logits = PoolerEndLogits(config)
        self.answer_class = PoolerAnswerClass(config)

    def forward(self, hidden_states, start_positions=None, end_positions=None,
                cls_index=None, is_impossible=None, p_mask=None):
        outputs = ()

        start_logits = self.start_logits(hidden_states, p_mask=p_mask)

        if start_positions is not None and end_positions is not None:
            # Training path: labels are available, so return only the loss.
            # If we are on multi-GPU, let's remove the dimension added by batch splitting
            for x in (start_positions, end_positions, cls_index, is_impossible):
                if x is not None and x.dim() > 1:
                    x.squeeze_(-1)

            # during training, compute the end logits based on the ground truth of the start position
            end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask)

            loss_fct = CrossEntropyLoss()
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

            if cls_index is not None and is_impossible is not None:
                # Predict answerability from the representation of CLS and START
                cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index)
                loss_fct_cls = nn.BCEWithLogitsLoss()
                cls_loss = loss_fct_cls(cls_logits, is_impossible)

                # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
                total_loss += cls_loss * 0.5

            outputs = (total_loss,) + outputs

        else:
            # Inference path: compute the end logits based on beam search
            # over the top start_n_top start candidates.
            bsz, slen, hsz = hidden_states.size()
            start_log_probs = F.softmax(start_logits, dim=-1)  # shape (bsz, slen)

            start_top_log_probs, start_top_index = torch.topk(start_log_probs, self.start_n_top, dim=-1)  # shape (bsz, start_n_top)
            start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz)  # shape (bsz, start_n_top, hsz)
            start_states = torch.gather(hidden_states, -2, start_top_index_exp)  # shape (bsz, start_n_top, hsz)
            start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1)  # shape (bsz, slen, start_n_top, hsz)

            hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(start_states)  # shape (bsz, slen, start_n_top, hsz)
            p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
            end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
            end_log_probs = F.softmax(end_logits, dim=1)  # shape (bsz, slen, start_n_top)

            end_top_log_probs, end_top_index = torch.topk(end_log_probs, self.end_n_top, dim=1)  # shape (bsz, end_n_top, start_n_top)
            end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
            end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)

            # Expected start state under the start distribution, used as the
            # "start" representation for the answerability classifier.
            start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
            cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)

            outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) + outputs

        # return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits
        # or (if labels are provided) (total_loss,)
        return outputs
642
+
643
+
644
class SequenceSummary(nn.Module):
    r""" Compute a single vector summary of a sequence hidden states according to various possibilities:
        Args of the config class:
            summary_type:
                - 'last' => [default] take the last token hidden state (like XLNet)
                - 'first' => take the first token hidden state (like Bert)
                - 'mean' => take the mean of all tokens hidden states
                - 'cls_index' => supply a Tensor of classification token position (GPT/GPT-2)
                - 'attn' => Not implemented now, use multi-head attention
            summary_use_proj: Add a projection after the vector extraction
            summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False.
            summary_activation: 'tanh' => add a tanh activation to the output, Other => no activation. Default
            summary_first_dropout: Add a dropout before the projection and activation
            summary_last_dropout: Add a dropout after the projection and activation
    """
    def __init__(self, config):
        super(SequenceSummary, self).__init__()

        # BUGFIX: the original guarded this on hasattr(config, 'summary_use_proj'),
        # so configs that define summary_type but not summary_use_proj silently
        # fell back to 'last'. Guard on the attribute that is actually read.
        self.summary_type = config.summary_type if hasattr(config, 'summary_type') else 'last'
        if self.summary_type == 'attn':
            # We should use a standard multi-head attention module with absolute positional embedding for that.
            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
            raise NotImplementedError

        # Each stage defaults to a pass-through; it is only replaced when the
        # corresponding config flag is present and enabled.
        self.summary = Identity()
        if hasattr(config, 'summary_use_proj') and config.summary_use_proj:
            if hasattr(config, 'summary_proj_to_labels') and config.summary_proj_to_labels and config.num_labels > 0:
                num_classes = config.num_labels
            else:
                num_classes = config.hidden_size
            self.summary = nn.Linear(config.hidden_size, num_classes)

        self.activation = Identity()
        if hasattr(config, 'summary_activation') and config.summary_activation == 'tanh':
            self.activation = nn.Tanh()

        self.first_dropout = Identity()
        if hasattr(config, 'summary_first_dropout') and config.summary_first_dropout > 0:
            self.first_dropout = nn.Dropout(config.summary_first_dropout)

        self.last_dropout = Identity()
        if hasattr(config, 'summary_last_dropout') and config.summary_last_dropout > 0:
            self.last_dropout = nn.Dropout(config.summary_last_dropout)

    def forward(self, hidden_states, cls_index=None):
        """ hidden_states: float Tensor in shape [bsz, seq_len, hidden_size], the hidden-states of the last layer.
            cls_index: [optional] position of the classification token if summary_type == 'cls_index',
                shape (bsz,) or more generally (bsz, ...) where ... are optional leading dimensions of hidden_states.
                if summary_type == 'cls_index' and cls_index is None:
                    we take the last token of the sequence as classification token
        """
        if self.summary_type == 'last':
            output = hidden_states[:, -1]
        elif self.summary_type == 'first':
            output = hidden_states[:, 0]
        elif self.summary_type == 'mean':
            output = hidden_states.mean(dim=1)
        elif self.summary_type == 'cls_index':
            if cls_index is None:
                # Default: point every batch entry at the last sequence position.
                cls_index = torch.full_like(hidden_states[..., :1, :], hidden_states.shape[-2]-1, dtype=torch.long)
            else:
                cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
                cls_index = cls_index.expand((-1,) * (cls_index.dim()-1) + (hidden_states.size(-1),))
            # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
            output = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, XX, hidden_size)
        elif self.summary_type == 'attn':
            raise NotImplementedError

        output = self.first_dropout(output)
        output = self.summary(output)
        output = self.activation(output)
        output = self.last_dropout(output)

        return output
719
+
720
+
721
def prune_linear_layer(layer, index, dim=0):
    """Return a new ``nn.Linear`` keeping only the rows (dim=0) or columns
    (dim=1) of ``layer`` selected by ``index``.

    The result lives on the same device as ``layer`` and has
    requires_grad=True; used to remove attention heads.
    """
    index = index.to(layer.weight.device)
    kept_weight = layer.weight.index_select(dim, index).clone().detach()
    has_bias = layer.bias is not None
    if has_bias:
        # Pruning input columns leaves the bias untouched; pruning output
        # rows keeps only the selected bias entries.
        kept_bias = (layer.bias if dim == 1 else layer.bias[index]).clone().detach()
    new_shape = list(layer.weight.size())
    new_shape[dim] = len(index)
    pruned = nn.Linear(new_shape[1], new_shape[0], bias=has_bias).to(layer.weight.device)
    with torch.no_grad():
        pruned.weight.copy_(kept_weight.contiguous())
        if has_bias:
            pruned.bias.copy_(kept_bias.contiguous())
    return pruned
744
+
745
+
746
def prune_conv1d_layer(layer, index, dim=1):
    """Return a new ``Conv1D`` keeping only the entries of ``layer`` selected
    by ``index`` along ``dim``.

    A Conv1D works like a Linear layer (see e.g. BERT) but with the weight
    matrix transposed; the returned layer has requires_grad=True.
    """
    index = index.to(layer.weight.device)
    kept_weight = layer.weight.index_select(dim, index).clone().detach()
    # dim 0 prunes inputs (bias untouched); otherwise prune the bias too.
    kept_bias = (layer.bias if dim == 0 else layer.bias[index]).clone().detach()
    new_shape = list(layer.weight.size())
    new_shape[dim] = len(index)
    pruned = Conv1D(new_shape[1], new_shape[0]).to(layer.weight.device)
    with torch.no_grad():
        pruned.weight.copy_(kept_weight.contiguous())
        pruned.bias.copy_(kept_bias.contiguous())
    return pruned
768
+
769
+
770
def prune_layer(layer, index, dim=None):
    """Prune a ``Conv1D`` or ``nn.Linear`` layer to keep only entries in
    ``index``, dispatching on the layer type.

    ``dim`` defaults to the natural pruning axis of each layer type.
    Returns the pruned layer with requires_grad=True; used to remove heads.
    """
    if isinstance(layer, nn.Linear):
        return prune_linear_layer(layer, index, dim=dim if dim is not None else 0)
    if isinstance(layer, Conv1D):
        return prune_conv1d_layer(layer, index, dim=dim if dim is not None else 1)
    raise ValueError("Can't prune layer of class {}".format(layer.__class__))
versatile_diffusion/lib/model_zoo/optimus_models/optimus_bert.py ADDED
@@ -0,0 +1,1439 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch BERT model. """
17
+
18
+ from __future__ import absolute_import, division, print_function, unicode_literals
19
+
20
+ import json
21
+ import logging
22
+ import math
23
+ import os
24
+ import sys
25
+ from io import open
26
+
27
+ import pdb
28
+
29
+ import torch
30
+ from torch import nn
31
+ from torch.nn import CrossEntropyLoss, MSELoss
32
+
33
+ from .modeling_utils import PreTrainedModel, prune_linear_layer
34
+ from .configuration_bert import BertConfig
35
+ from .file_utils import add_start_docstrings
36
+
37
+ logger = logging.getLogger(__name__)
38
+
39
+ BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
40
+ 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin",
41
+ 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin",
42
+ 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin",
43
+ 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-pytorch_model.bin",
44
+ 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-pytorch_model.bin",
45
+ 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-pytorch_model.bin",
46
+ 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-pytorch_model.bin",
47
+ 'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-pytorch_model.bin",
48
+ 'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-pytorch_model.bin",
49
+ 'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-pytorch_model.bin",
50
+ 'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-pytorch_model.bin",
51
+ 'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-pytorch_model.bin",
52
+ 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin",
53
+ }
54
+
55
def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
    """ Load tf checkpoints in a pytorch model.

    Walks every variable of the TensorFlow checkpoint, maps its slash-separated
    name onto the matching attribute path of ``model``, and copies the numpy
    array into the PyTorch parameter in place. Returns ``model``.
    Requires TensorFlow to be installed.
    """
    try:
        import re
        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error("Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions.")
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info("Loading TF weight {} with shape {}".format(name, shape))
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    for name, array in zip(names, arrays):
        name = name.split('/')
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
        # which are not required for using pretrained model
        if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
            logger.info("Skipping {}".format("/".join(name)))
            continue
        pointer = model
        for m_name in name:
            # TF names like "layer_3" encode an index; split it off so we can
            # descend into ModuleList entries.
            if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
                l = re.split(r'_(\d+)', m_name)
            else:
                l = [m_name]
            # Map TF variable-name conventions onto PyTorch attribute names.
            if l[0] == 'kernel' or l[0] == 'gamma':
                pointer = getattr(pointer, 'weight')
            elif l[0] == 'output_bias' or l[0] == 'beta':
                pointer = getattr(pointer, 'bias')
            elif l[0] == 'output_weights':
                pointer = getattr(pointer, 'weight')
            elif l[0] == 'squad':
                pointer = getattr(pointer, 'classifier')
            else:
                try:
                    pointer = getattr(pointer, l[0])
                except AttributeError:
                    # NOTE(review): this `continue` advances to the next name
                    # *component*, not the next variable — matches the upstream
                    # code, but verify that is the intended behaviour.
                    logger.info("Skipping {}".format("/".join(name)))
                    continue
            if len(l) >= 2:
                num = int(l[1])
                pointer = pointer[num]
            if m_name[-11:] == '_embeddings':
                pointer = getattr(pointer, 'weight')
            elif m_name == 'kernel':
                # TF stores dense kernels transposed relative to nn.Linear.
                array = np.transpose(array)
        try:
            assert pointer.shape == array.shape
        except AssertionError as e:
            # Attach both shapes to make mismatches easy to diagnose.
            e.args += (pointer.shape, array.shape)
            raise
        logger.info("Initialize PyTorch weight {}".format(name))
        pointer.data = torch.from_numpy(array)
    return model
120
+
121
+
122
def gelu(x):
    """Gaussian Error Linear Unit, exact erf formulation.

    For information: OpenAI GPT's gelu is slightly different (and gives
    slightly different results):
    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    Also see https://arxiv.org/abs/1606.08415
    """
    cdf = 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
    return x * cdf
129
+
130
+
131
def swish(x):
    """Swish activation: x * sigmoid(x) (https://arxiv.org/abs/1710.05941)."""
    return torch.sigmoid(x) * x
133
+
134
+
135
# Map config.hidden_act strings onto the corresponding activation callables.
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}


# Prefer NVIDIA apex's fused LayerNorm kernel when available; otherwise fall
# back to the stock PyTorch implementation under the same name.
try:
    from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
except (ImportError, AttributeError) as e:
    logger.info("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .")
    BertLayerNorm = torch.nn.LayerNorm
143
+
144
class BertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config):
        super(BertEmbeddings, self).__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model
        # variable name and be able to load any TensorFlow checkpoint file.
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids, token_type_ids=None, position_ids=None):
        """Sum the three embedding tables, then apply LayerNorm and dropout."""
        seq_length = input_ids.size(1)
        if position_ids is None:
            # Default positions 0..seq_len-1, broadcast over the batch.
            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        summed = (self.word_embeddings(input_ids)
                  + self.position_embeddings(position_ids)
                  + self.token_type_embeddings(token_type_ids))
        return self.dropout(self.LayerNorm(summed))
174
+
175
+
176
class BertSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention (without the output
    projection, which lives in BertSelfOutput)."""

    def __init__(self, config):
        super(BertSelfAttention, self).__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
        self.output_attentions = config.output_attentions

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        # (bsz, seq, all_head_size) -> (bsz, num_heads, seq, head_size)
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask, head_mask=None):
        """Returns (context,) or (context, attention_probs) when
        output_attentions is set. attention_mask is additive (large negative
        values at masked positions) and is applied before the softmax."""
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
        attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        # (bsz, heads, seq, head_size) -> (bsz, seq, all_head_size)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,)
        return outputs
234
+
235
+
236
class BertSelfOutput(nn.Module):
    """Output projection of the attention sublayer: dense + dropout followed
    by a residual connection and LayerNorm."""

    def __init__(self, config):
        super(BertSelfOutput, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        projected = self.dense(hidden_states)
        projected = self.dropout(projected)
        # Residual add, then normalize.
        return self.LayerNorm(projected + input_tensor)
248
+
249
+
250
class BertAttention(nn.Module):
    """Full attention sublayer (BertSelfAttention + BertSelfOutput) with
    support for permanently pruning attention heads."""

    def __init__(self, config):
        super(BertAttention, self).__init__()
        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)
        # Head indices pruned so far, expressed in the ORIGINAL model's numbering.
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given heads (original-model indices) from this layer."""
        if len(heads) == 0:
            return
        mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
        heads = set(heads) - self.pruned_heads  # Convert to set and remove already pruned heads
        for head in heads:
            # Compute how many pruned heads are before the head and move the index accordingly
            head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
            mask[head] = 0
        # Flatten to per-unit mask and collect the indices that survive.
        mask = mask.view(-1).contiguous().eq(1)
        index = torch.arange(len(mask))[mask].long()

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(self, input_tensor, attention_mask, head_mask=None):
        self_outputs = self.self(input_tensor, attention_mask, head_mask)
        attention_output = self.output(self_outputs[0], input_tensor)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs
285
+
286
+
287
class BertIntermediate(nn.Module):
    """Feed-forward expansion: hidden_size -> intermediate_size + activation."""

    def __init__(self, config):
        super(BertIntermediate, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # config.hidden_act is either an activation name or the callable itself
        # (the `unicode` branch only ever runs on Python 2).
        if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states):
        return self.intermediate_act_fn(self.dense(hidden_states))
300
+
301
+
302
class BertOutput(nn.Module):
    """Feed-forward contraction: intermediate_size -> hidden_size, with
    dropout, residual connection and LayerNorm."""

    def __init__(self, config):
        super(BertOutput, self).__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        projected = self.dense(hidden_states)
        projected = self.dropout(projected)
        # Residual add, then normalize.
        return self.LayerNorm(projected + input_tensor)
314
+
315
+
316
class BertLayer(nn.Module):
    """One full transformer block: attention sublayer + feed-forward sublayer."""

    def __init__(self, config):
        super(BertLayer, self).__init__()
        self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(self, hidden_states, attention_mask, head_mask=None):
        attn_results = self.attention(hidden_states, attention_mask, head_mask)
        attn_output = attn_results[0]
        ffn_hidden = self.intermediate(attn_output)
        layer_output = self.output(ffn_hidden, attn_output)
        # Pass through attention probabilities when the attention module emits them.
        return (layer_output,) + attn_results[1:]
330
+
331
+
332
class BertEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` BertLayer modules.

    forward returns ``(last_hidden_state,)`` optionally followed by the tuple
    of all hidden states and/or the tuple of all attention maps, depending on
    ``config.output_hidden_states`` / ``config.output_attentions``.
    """
    def __init__(self, config):
        super(BertEncoder, self).__init__()
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(self, hidden_states, attention_mask, head_mask=None):
        all_hidden_states = ()
        all_attentions = ()
        for i, layer_module in enumerate(self.layer):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # BUGFIX: the declared default head_mask=None crashed here on
            # head_mask[i]; treat a missing mask as "no heads masked".
            layer_head_mask = head_mask[i] if head_mask is not None else None
            layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask)
            hidden_states = layer_outputs[0]

            if self.output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        # Add last layer
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            outputs = outputs + (all_attentions,)
        return outputs  # last-layer hidden state, (all hidden states), (all attentions)
362
+
363
+
364
class BertPooler(nn.Module):
    """Pool the sequence by passing the first ([CLS]) token's hidden state
    through a dense layer with tanh."""

    def __init__(self, config):
        super(BertPooler, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        cls_state = hidden_states[:, 0]
        return self.activation(self.dense(cls_state))
377
+
378
+
379
class BertPredictionHeadTransform(nn.Module):
    """Dense + activation + LayerNorm applied before the LM decoder."""

    def __init__(self, config):
        super(BertPredictionHeadTransform, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # config.hidden_act may be a name or a callable (unicode check is Py2-only).
        if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        transformed = self.dense(hidden_states)
        transformed = self.transform_act_fn(transformed)
        return self.LayerNorm(transformed)
394
+
395
+
396
class BertLMPredictionHead(nn.Module):
    """Transform hidden states and project them onto the vocabulary."""

    def __init__(self, config):
        super(BertLMPredictionHead, self).__init__()
        self.transform = BertPredictionHeadTransform(config)
        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

    def forward(self, hidden_states):
        transformed = self.transform(hidden_states)
        return self.decoder(transformed) + self.bias
413
+
414
+
415
class BertOnlyMLMHead(nn.Module):
    """Masked-LM head only (no next-sentence-prediction head)."""

    def __init__(self, config):
        super(BertOnlyMLMHead, self).__init__()
        self.predictions = BertLMPredictionHead(config)

    def forward(self, sequence_output):
        return self.predictions(sequence_output)
423
+
424
+
425
class BertOnlyNSPHead(nn.Module):
    """Next-sentence-prediction head only: pooled output -> 2 logits."""

    def __init__(self, config):
        super(BertOnlyNSPHead, self).__init__()
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, pooled_output):
        return self.seq_relationship(pooled_output)
433
+
434
+
435
class BertPreTrainingHeads(nn.Module):
    """Both pre-training heads: masked-LM scores over the sequence output and
    next-sentence logits over the pooled output."""

    def __init__(self, config):
        super(BertPreTrainingHeads, self).__init__()
        self.predictions = BertLMPredictionHead(config)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output):
        lm_scores = self.predictions(sequence_output)
        nsp_scores = self.seq_relationship(pooled_output)
        return lm_scores, nsp_scores
445
+
446
+
447
class BertPreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """
    # Class attributes consumed by PreTrainedModel.from_pretrained().
    config_class = BertConfig
    pretrained_model_archive_map = BERT_PRETRAINED_MODEL_ARCHIVE_MAP
    load_tf_weights = load_tf_weights_in_bert
    base_model_prefix = "bert"

    def _init_weights(self, module):
        """ Initialize the weights """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, BertLayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        # Linear biases are zeroed in addition to the normal_ init above.
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
467
+
468
+
469
# Shared docstring fragment prepended to each BERT model class below via
# the `add_start_docstrings` decorator.
BERT_START_DOCSTRING = r"""    The BERT model was proposed in
    `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`_
    by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova. It's a bidirectional transformer
    pre-trained using a combination of masked language modeling objective and next sentence prediction
    on a large corpus comprising the Toronto Book Corpus and Wikipedia.

    This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
    refer to the PyTorch documentation for all matter related to general usage and behavior.

    .. _`BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`:
        https://arxiv.org/abs/1810.04805

    .. _`torch.nn.Module`:
        https://pytorch.org/docs/stable/nn.html#module

    Parameters:
        config (:class:`~pytorch_transformers.BertConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~pytorch_transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""
489
+
490
# Shared docstring fragment describing the common forward() inputs; appended
# to each BERT model class below via the `add_start_docstrings` decorator.
BERT_INPUTS_DOCSTRING = r"""
    Inputs:
        **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Indices of input sequence tokens in the vocabulary.
            To match pre-training, BERT input sequence should be formatted with [CLS] and [SEP] tokens as follows:

            (a) For sequence pairs:

                ``tokens:         [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]``

                ``token_type_ids:   0   0  0    0    0     0       0   0   1  1  1  1   1   1``

            (b) For single sequences:

                ``tokens:         [CLS] the dog is hairy . [SEP]``

                ``token_type_ids:   0   0   0   0  0     0   0``

            Bert is a model with absolute position embeddings so it's usually advised to pad the inputs on
            the right rather than the left.

            Indices can be obtained using :class:`pytorch_transformers.BertTokenizer`.
            See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
            :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
        **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
        **token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Segment token indices to indicate first and second portions of the inputs.
            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
            corresponds to a `sentence B` token
            (see `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`_ for more details).
        **position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.
        **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
"""
531
+
532
@add_start_docstrings("The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
                      BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class BertModel(BertPreTrainedModel):
    r"""
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
            Sequence of hidden-states at the output of the last layer of the model.
        **pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)``
            Last layer hidden-state of the first token of the sequence (classification token)
            further processed by a Linear layer and a Tanh activation function. The Linear
            layer weights are trained from the next sentence prediction (classification)
            objective during Bert pretraining. This output is usually *not* a good summary
            of the semantic content of the input, you're often better with averaging or pooling
            the sequence of hidden-states for the whole input sequence.
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = BertModel.from_pretrained('bert-base-uncased')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids)
        last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple

    """
    def __init__(self, config):
        super(BertModel, self).__init__(config)

        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)

        self.init_weights()

    def _resize_token_embeddings(self, new_num_tokens):
        # Swap in a resized word-embedding matrix (used by resize_token_embeddings).
        old_embeddings = self.embeddings.word_embeddings
        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
        self.embeddings.word_embeddings = new_embeddings
        return self.embeddings.word_embeddings

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
            See base class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
        # Default to attending everywhere / a single segment when masks are omitted.
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # this attention mask is more simple than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to float if need + fp16 compatibility
        else:
            head_mask = [None] * self.config.num_hidden_layers

        embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
        encoder_outputs = self.encoder(embedding_output,
                                       extended_attention_mask,
                                       head_mask=head_mask)
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]  # add hidden_states and attentions if they are here
        return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)
631
+
632
+
633
+
634
+
635
+
636
+
637
@add_start_docstrings("The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
                      BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class BertForLatentConnector(BertPreTrainedModel):
    r"""
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
            Sequence of hidden-states at the output of the last layer of the model.
        **pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)``
            Last layer hidden-state of the first token of the sequence (classification token)
            further processed by a Linear layer and a Tanh activation function. The Linear
            layer weights are trained from the next sentence prediction (classification)
            objective during Bert pretraining. This output is usually *not* a good summary
            of the semantic content of the input, you're often better with averaging or pooling
            the sequence of hidden-states for the whole input sequence.
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = BertModel.from_pretrained('bert-base-uncased')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids)
        last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple

    """
    def __init__(self, config, latent_size):
        super(BertForLatentConnector, self).__init__(config)

        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)

        # Projects the pooled output to (mean, logvar) of a latent Gaussian;
        # not called inside forward() — external VAE code applies it to the
        # pooled output and splits the result with .chunk(2, -1).
        self.linear = nn.Linear(config.hidden_size, 2 * latent_size, bias=False)

        self.init_weights()

    def _resize_token_embeddings(self, new_num_tokens):
        # Swap in a resized word-embedding matrix (used by resize_token_embeddings).
        old_embeddings = self.embeddings.word_embeddings
        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
        self.embeddings.word_embeddings = new_embeddings
        return self.embeddings.word_embeddings

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
            See base class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
        # Default to attending everywhere / a single segment when masks are omitted.
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # this attention mask is more simple than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to float if need + fp16 compatibility
        else:
            head_mask = [None] * self.config.num_hidden_layers

        embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
        encoder_outputs = self.encoder(embedding_output,
                                       extended_attention_mask,
                                       head_mask=head_mask)
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]  # add hidden_states and attentions if they are here
        return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)
738
+
739
+
740
+
741
@add_start_docstrings("""Bert Model with two heads on top as done during the pre-training:
    a `masked language modeling` head and a `next sentence prediction (classification)` head. """,
                      BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class BertForPreTraining(BertPreTrainedModel):
    r"""
        **masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Labels for computing the masked language modeling loss.
            Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
            Tokens with indices set to ``-1`` are ignored (masked), the loss is only computed for the tokens with labels
            in ``[0, ..., config.vocab_size]``
        **next_sentence_label**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring)
            Indices should be in ``[0, 1]``.
            ``0`` indicates sequence B is a continuation of sequence A,
            ``1`` indicates sequence B is a random sequence.

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when both ``masked_lm_labels`` and ``next_sentence_label`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Total loss as the sum of the masked language modeling loss and the next sequence prediction (classification) loss.
        **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        **seq_relationship_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, 2)``
            Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = BertForPreTraining.from_pretrained('bert-base-uncased')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids)
        prediction_scores, seq_relationship_scores = outputs[:2]

    """
    def __init__(self, config):
        super(BertForPreTraining, self).__init__(config)

        self.bert = BertModel(config)
        self.cls = BertPreTrainingHeads(config)

        self.init_weights()
        self.tie_weights()

    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        self._tie_or_clone_weights(self.cls.predictions.decoder,
                                   self.bert.embeddings.word_embeddings)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                masked_lm_labels=None, next_sentence_label=None):
        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            position_ids=position_ids,
                            head_mask=head_mask)

        sequence_output, pooled_output = outputs[:2]
        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

        outputs = (prediction_scores, seq_relationship_score,) + outputs[2:]  # add hidden states and attention if they are here

        # Loss is only computed when BOTH label sets are supplied.
        if masked_lm_labels is not None and next_sentence_label is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-1)  # -1 labels are ignored (masked-out positions)
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
            next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
            total_loss = masked_lm_loss + next_sentence_loss
            outputs = (total_loss,) + outputs

        return outputs  # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions)
819
+
820
+
821
@add_start_docstrings("""Bert Model with a `language modeling` head on top. """,
                      BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class BertForMaskedLM(BertPreTrainedModel):
    r"""
        **masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Labels for computing the masked language modeling loss.
            Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
            Tokens with indices set to ``-1`` are ignored (masked), the loss is only computed for the tokens with labels
            in ``[0, ..., config.vocab_size]``

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Masked language modeling loss.
        **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = BertForMaskedLM.from_pretrained('bert-base-uncased')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, masked_lm_labels=input_ids)
        loss, prediction_scores = outputs[:2]

    """
    def __init__(self, config):
        super(BertForMaskedLM, self).__init__(config)

        self.bert = BertModel(config)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()
        self.tie_weights()

    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        self._tie_or_clone_weights(self.cls.predictions.decoder,
                                   self.bert.embeddings.word_embeddings)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                masked_lm_labels=None):
        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            position_ids=position_ids,
                            head_mask=head_mask)

        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)

        outputs = (prediction_scores,) + outputs[2:]  # Add hidden states and attention if they are here
        if masked_lm_labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-1)  # -1 labels are ignored (unmasked positions)
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
            outputs = (masked_lm_loss,) + outputs

        return outputs  # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
888
+
889
+
890
@add_start_docstrings("""Bert Model with a `next sentence prediction (classification)` head on top. """,
                      BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class BertForNextSentencePrediction(BertPreTrainedModel):
    r"""
        **next_sentence_label**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring)
            Indices should be in ``[0, 1]``.
            ``0`` indicates sequence B is a continuation of sequence A,
            ``1`` indicates sequence B is a random sequence.

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``next_sentence_label`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Next sequence prediction (classification) loss.
        **seq_relationship_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, 2)``
            Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids)
        seq_relationship_scores = outputs[0]

    """
    def __init__(self, config):
        super(BertForNextSentencePrediction, self).__init__(config)

        self.bert = BertModel(config)
        self.cls = BertOnlyNSPHead(config)

        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                next_sentence_label=None):
        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            position_ids=position_ids,
                            head_mask=head_mask)

        # NSP classifies from the pooled [CLS] representation only.
        pooled_output = outputs[1]

        seq_relationship_score = self.cls(pooled_output)

        outputs = (seq_relationship_score,) + outputs[2:]  # add hidden states and attention if they are here
        if next_sentence_label is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
            outputs = (next_sentence_loss,) + outputs

        return outputs  # (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions)
950
+
951
+
952
@add_start_docstrings("""Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of
    the pooled output) e.g. for GLUE tasks. """,
                      BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class BertForSequenceClassification(BertPreTrainedModel):
    r"""
        **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Labels for computing the sequence classification/regression loss.
            Indices should be in ``[0, ..., config.num_labels - 1]``.
            If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
            If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Classification (or regression if config.num_labels==1) loss.
        **logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)``
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        labels = torch.tensor([1]).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, labels=labels)
        loss, logits = outputs[:2]

    NOTE(review): unlike the upstream pytorch-transformers version, forward()
    returns ``(outputs, pooled_output)`` — a 2-tuple whose first element is the
    usual output tuple. Callers in this repository rely on this shape.
    """
    def __init__(self, config):
        super(BertForSequenceClassification, self).__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
        # When True, the pooled BERT output is detached so gradients only
        # flow through the classifier head (encoder stays frozen).
        self.use_freeze = False

        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
                position_ids=None, head_mask=None, labels=None):
        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            position_ids=position_ids,
                            head_mask=head_mask)

        pooled_output = outputs[1]

        if self.use_freeze:
            pooled_output = pooled_output.detach()

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here

        if labels is not None:
            if self.num_labels == 1:
                # We are doing regression
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs

        # pdb.set_trace()
        return outputs, pooled_output  # (loss), logits, (hidden_states), (attentions)
1028
+
1029
+
1030
@add_start_docstrings("""Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of
    the pooled output) e.g. for GLUE tasks. """,
    BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class BertForSequenceClassificationLatentConnector(BertPreTrainedModel):
    r"""BERT encoder with a linear classification head on the pooled output,
    plus a latent-connector projection (``self.linear``) created for the
    Optimus-style latent space but not applied inside ``forward``.

    **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
        Indices in ``[0, ..., config.num_labels - 1]``. With
        ``config.num_labels == 1`` an MSE (regression) loss is computed,
        otherwise a cross-entropy (classification) loss.

    Returns a pair ``(outputs, pooled_output)`` where ``outputs`` is
    ``(loss?), logits, (hidden_states?), (attentions?)`` and ``pooled_output``
    is the dropout-applied pooled [CLS] representation.
    """

    def __init__(self, config, latent_size):
        super(BertForSequenceClassificationLatentConnector, self).__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
        # Latent connector: hidden -> 2 * latent_size (presumably mean/logvar;
        # callers apply it -- it is unused in forward()).
        self.linear = nn.Linear(config.hidden_size, 2 * latent_size, bias=False)
        self.use_freeze = False

        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
                position_ids=None, head_mask=None, labels=None):
        encoder_outputs = self.bert(input_ids,
                                    attention_mask=attention_mask,
                                    token_type_ids=token_type_ids,
                                    position_ids=position_ids,
                                    head_mask=head_mask)

        pooled = encoder_outputs[1]

        # Optionally cut the gradient flow into the encoder.
        if self.use_freeze:
            pooled = pooled.detach()

        # NOTE: dropout is applied before both the classifier and the returned
        # pooled representation (matching the original control flow).
        pooled = self.dropout(pooled)
        logits = self.classifier(pooled)

        result = (logits,) + encoder_outputs[2:]  # keep hidden states / attentions if present

        if labels is not None:
            if self.num_labels == 1:
                # Single label => regression with mean-squared error.
                loss = MSELoss()(logits.view(-1), labels.view(-1))
            else:
                loss = CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))
            result = (loss,) + result

        return result, pooled  # (loss), logits, (hidden_states), (attentions)
1109
+
1110
+
1111
@add_start_docstrings("""Bert Model with a multiple choice classification head on top (a linear layer on top of
    the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
    BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class BertForMultipleChoice(BertPreTrainedModel):
    r"""BERT with a single-logit head scored per choice.

    ``input_ids`` has shape ``(batch_size, num_choices, sequence_length)``;
    every choice is encoded independently and its pooled output is mapped to
    one scalar, giving ``(batch_size, num_choices)`` classification scores.

    **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
        Index of the correct choice; enables a cross-entropy loss.

    Returns ``(loss?), classification_scores, (hidden_states?), (attentions?)``.
    """

    def __init__(self, config):
        super(BertForMultipleChoice, self).__init__(config)

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, 1)

        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
                position_ids=None, head_mask=None, labels=None):
        num_choices = input_ids.shape[1]

        def _flat(tensor):
            # Collapse (batch, num_choices, seq_len) -> (batch * num_choices, seq_len).
            return None if tensor is None else tensor.view(-1, tensor.size(-1))

        encoder_outputs = self.bert(_flat(input_ids),
                                    attention_mask=_flat(attention_mask),
                                    token_type_ids=_flat(token_type_ids),
                                    position_ids=_flat(position_ids),
                                    head_mask=head_mask)

        pooled = self.dropout(encoder_outputs[1])
        # One logit per (example, choice), regrouped by example.
        choice_scores = self.classifier(pooled).view(-1, num_choices)

        result = (choice_scores,) + encoder_outputs[2:]  # keep hidden states / attentions if present

        if labels is not None:
            loss = CrossEntropyLoss()(choice_scores, labels)
            result = (loss,) + result

        return result  # (loss), reshaped_logits, (hidden_states), (attentions)
1184
+
1185
+
1186
@add_start_docstrings("""Bert Model with a token classification head on top (a linear layer on top of
    the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
    BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class BertForTokenClassification(BertPreTrainedModel):
    r"""BERT with a per-token linear classifier over the sequence output.

    **labels**: (`optional`) ``torch.LongTensor`` of shape
        ``(batch_size, sequence_length)`` with values in
        ``[0, ..., config.num_labels - 1]``. When an ``attention_mask`` is
        given, only positions where the mask equals 1 contribute to the loss.

    Returns ``(loss?), scores, (hidden_states?), (attentions?)``.
    """

    def __init__(self, config):
        super(BertForTokenClassification, self).__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
                position_ids=None, head_mask=None, labels=None):
        encoder_outputs = self.bert(input_ids,
                                    attention_mask=attention_mask,
                                    token_type_ids=token_type_ids,
                                    position_ids=position_ids,
                                    head_mask=head_mask)

        token_states = self.dropout(encoder_outputs[0])
        logits = self.classifier(token_states)

        result = (logits,) + encoder_outputs[2:]  # keep hidden states / attentions if present

        if labels is not None:
            loss_fct = CrossEntropyLoss()
            flat_logits = logits.view(-1, self.num_labels)
            flat_labels = labels.view(-1)
            if attention_mask is not None:
                # Restrict the loss to non-padded positions.
                keep = attention_mask.view(-1) == 1
                loss = loss_fct(flat_logits[keep], flat_labels[keep])
            else:
                loss = loss_fct(flat_logits, flat_labels)
            result = (loss,) + result

        return result  # (loss), scores, (hidden_states), (attentions)
1256
+
1257
+
1258
@add_start_docstrings("""Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
    the hidden-states output to compute `span start logits` and `span end logits`). """,
    BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class BertForQuestionAnswering(BertPreTrainedModel):
    r"""BERT with a span-extraction head producing start and end logits.

    **start_positions** / **end_positions**: (`optional`)
        ``torch.LongTensor`` of shape ``(batch_size,)`` — gold span boundary
        indices. Values are clamped in place to ``sequence_length`` and
        out-of-range positions are ignored by the loss.

    Returns ``(loss?), start_logits, end_logits, (hidden_states?), (attentions?)``
    (so with labels supplied, ``loss, start_scores, end_scores = outputs[:3]``).
    """

    def __init__(self, config):
        super(BertForQuestionAnswering, self).__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        # num_labels is expected to be 2 here (start + end logits).
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                start_positions=None, end_positions=None):
        encoder_outputs = self.bert(input_ids,
                                    attention_mask=attention_mask,
                                    token_type_ids=token_type_ids,
                                    position_ids=position_ids,
                                    head_mask=head_mask)

        span_logits = self.qa_outputs(encoder_outputs[0])
        # Split the last dim into two (batch, seq_len) score tensors.
        start_logits, end_logits = span_logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        result = (start_logits, end_logits,) + encoder_outputs[2:]

        if start_positions is not None and end_positions is not None:
            # Multi-GPU gather may add a trailing dimension; drop it.
            if start_positions.dim() > 1:
                start_positions = start_positions.squeeze(-1)
            if end_positions.dim() > 1:
                end_positions = end_positions.squeeze(-1)
            # Positions outside the sequence are clamped to sequence_length,
            # which is then used as the ignored class (in place, as upstream).
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            total_loss = (loss_fct(start_logits, start_positions)
                          + loss_fct(end_logits, end_positions)) / 2
            result = (total_loss,) + result

        return result  # (loss), start_logits, end_logits, (hidden_states), (attentions)
1342
+
1343
+
1344
+ ############
1345
+ # XX Added #
1346
+ ############
1347
+
1348
class BertForLatentConnector_XX(nn.Module):
    """BERT encoder (embeddings + transformer encoder + pooler) with an extra
    linear projection from ``hidden_size`` to ``2 * latent_size``
    (presumably mean and log-variance of a latent code -- the projection is
    created here but not applied inside ``forward``; callers use it).

    Unlike the other models in this file, this subclasses ``nn.Module``
    directly, so it has none of the ``PreTrainedModel`` loading helpers.
    """

    def __init__(self, config, latent_size):
        super().__init__()
        self.config = config
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)
        self.linear = nn.Linear(config.hidden_size, 2 * latent_size, bias=False)
        self.init_weights()

    def init_weights(self):
        """Initialize all sub-module weights and prune heads if configured."""
        # Initialize weights
        self.apply(self._init_weights)

        # BUG FIX: this previously called ``self.prune_heads(...)``, which does
        # not exist on this class (only ``_prune_heads`` is defined), so any
        # non-empty ``config.pruned_heads`` raised AttributeError. getattr
        # also tolerates configs that lack the attribute entirely.
        pruned_heads = getattr(self.config, 'pruned_heads', None)
        if pruned_heads:
            self._prune_heads(pruned_heads)

    def _init_weights(self, module):
        """ Initialize the weights (mirrors BertPreTrainedModel._init_weights). """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, BertLayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def _resize_token_embeddings(self, new_num_tokens):
        # NOTE(review): ``_get_resized_embeddings`` is provided by
        # ``PreTrainedModel``, which this class does NOT inherit from, so this
        # method would raise AttributeError if called. No caller is visible in
        # this file -- confirm before relying on it.
        old_embeddings = self.embeddings.word_embeddings
        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
        self.embeddings.word_embeddings = new_embeddings
        return self.embeddings.word_embeddings

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
            See base class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # this attention mask is more simple than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to float if need + fp16 compatibility
        else:
            head_mask = [None] * self.config.num_hidden_layers

        embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
        encoder_outputs = self.encoder(embedding_output,
                                       extended_attention_mask,
                                       head_mask=head_mask)
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]  # add hidden_states and attentions if they are here
        return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)
1438
+
1439
+
versatile_diffusion/lib/model_zoo/optimus_models/optimus_gpt2.py ADDED
@@ -0,0 +1,1122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch OpenAI GPT-2 model."""
17
+
18
+ from __future__ import absolute_import, division, print_function, unicode_literals
19
+
20
+ import pdb
21
+
22
+ import collections
23
+ import json
24
+ import logging
25
+ import math
26
+ import os
27
+ import sys
28
+ from io import open
29
+
30
+ import torch
31
+ import torch.nn as nn
32
+ from torch.nn import CrossEntropyLoss
33
+ from torch.nn.parameter import Parameter
34
+
35
+ from .modeling_utils import PreTrainedModel, Conv1D, prune_conv1d_layer, SequenceSummary
36
+ from .configuration_gpt2 import GPT2Config
37
+ from .file_utils import add_start_docstrings
38
+
39
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Download URLs for the official pretrained GPT-2 checkpoints, keyed by
# model shortcut name (used by PreTrainedModel.from_pretrained).
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin",
                                     "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin",
                                     "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-pytorch_model.bin"}
44
+
45
def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
    """Load a TensorFlow GPT-2 checkpoint into a PyTorch ``model``.

    Walks every TF variable, maps its slash-separated name onto the attribute
    path of ``model`` (e.g. ``h0/attn/c_attn/w`` -> ``model.h[0].attn.c_attn.weight``),
    asserts the shapes match and copies the array in. Returns ``model``.
    Requires TensorFlow to be installed; raises ImportError otherwise.
    """
    try:
        import re
        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error("Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions.")
        raise
    tf_path = os.path.abspath(gpt2_checkpoint_path)
    logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info("Loading TF weight {} with shape {}".format(name, shape))
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        # squeeze() drops TF's singleton dimensions so shapes line up with
        # the PyTorch parameters.
        arrays.append(array.squeeze())

    for name, array in zip(names, arrays):
        name = name[6:]  # skip "model/"
        name = name.split('/')
        # Walk the attribute path segment by segment, starting at the model.
        pointer = model
        for m_name in name:
            # Segments like "h0" carry an index; split into ("h", "0", "").
            if re.fullmatch(r'[A-Za-z]+\d+', m_name):
                l = re.split(r'(\d+)', m_name)
            else:
                l = [m_name]
            # TF short names map onto PyTorch parameter names:
            # 'w'/'g' -> weight, 'b' -> bias, embeddings keep their own name.
            if l[0] == 'w' or l[0] == 'g':
                pointer = getattr(pointer, 'weight')
            elif l[0] == 'b':
                pointer = getattr(pointer, 'bias')
            elif l[0] == 'wpe' or l[0] == 'wte':
                pointer = getattr(pointer, l[0])
                pointer = getattr(pointer, 'weight')
            else:
                pointer = getattr(pointer, l[0])
            # Indexed segment: descend into the ModuleList entry.
            if len(l) >= 2:
                num = int(l[1])
                pointer = pointer[num]
        try:
            assert pointer.shape == array.shape
        except AssertionError as e:
            # Attach both shapes to the error for easier debugging.
            e.args += (pointer.shape, array.shape)
            raise
        logger.info("Initialize PyTorch weight {}".format(name))
        pointer.data = torch.from_numpy(array)
    return model
97
+
98
+
99
def gelu(x):
    """Gaussian Error Linear Unit activation, tanh approximation.

    Computes ``0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))`` —
    the approximation used by the original TF GPT-2/BERT implementations.
    """
    cubic = x + 0.044715 * torch.pow(x, 3)
    return 0.5 * x * (torch.tanh(math.sqrt(2 / math.pi) * cubic) + 1)
101
+
102
+
103
class Attention(nn.Module):
    """Multi-head causal self-attention for GPT-2.

    Uses a fused qkv projection (``c_attn``) and a causal lower-triangular
    mask stored in the ``bias`` buffer. Supports head pruning and cached
    key/value state via ``layer_past``.
    """

    def __init__(self, nx, n_ctx, config, scale=False):
        super(Attention, self).__init__()
        self.output_attentions = config.output_attentions

        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
        assert n_state % config.n_head == 0
        # Lower-triangular causal mask, broadcastable over (batch, head).
        self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
        self.n_head = config.n_head
        self.split_size = n_state
        self.scale = scale

        # Fused projection producing query, key and value in one matmul.
        self.c_attn = Conv1D(n_state * 3, nx)
        self.c_proj = Conv1D(n_state, nx)
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given head indices from this attention layer in place."""
        if len(heads) == 0:
            return
        mask = torch.ones(self.n_head, self.split_size // self.n_head)
        heads = set(heads) - self.pruned_heads  # Convert to set and emove already pruned heads
        for head in heads:
            # Compute how many pruned heads are before the head and move the index accordingly
            head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
            mask[head] = 0
        mask = mask.view(-1).contiguous().eq(1)
        index = torch.arange(len(mask))[mask].long()
        # The fused c_attn holds q|k|v side by side, so the same column
        # selection is applied at three offsets.
        index_attn = torch.cat([index, index + self.split_size, index + (2*self.split_size)])

        # Prune conv1d layers
        self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
        self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)

        # Update hyper params
        self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
        self.n_head = self.n_head - len(heads)
        self.pruned_heads = self.pruned_heads.union(heads)

    def _attn(self, q, k, v, attention_mask=None, head_mask=None):
        # Raw attention scores; k arrives pre-transposed from split_heads(k=True).
        w = torch.matmul(q, k)
        if self.scale:
            w = w / math.sqrt(v.size(-1))
        nd, ns = w.size(-2), w.size(-1)
        # Slice the causal mask to the current query/key lengths; masked
        # positions get a large negative score instead of -inf.
        b = self.bias[:, :, ns-nd:ns, :ns]
        w = w * b - 1e4 * (1 - b)

        if attention_mask is not None:
            # Apply the attention mask
            w = w + attention_mask

        w = nn.Softmax(dim=-1)(w)
        w = self.attn_dropout(w)

        # Mask heads if we want to
        if head_mask is not None:
            w = w * head_mask

        outputs = [torch.matmul(w, v)]
        if self.output_attentions:
            outputs.append(w)
        return outputs

    def merge_heads(self, x):
        # (batch, head, seq, feat) -> (batch, seq, head * feat)
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*new_x_shape)  # in Tensorflow implem: fct merge_states

    def split_heads(self, x, k=False):
        # (batch, seq, n_state) -> per-head layout; keys are transposed so
        # _attn can matmul q @ k directly.
        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
        if k:
            return x.permute(0, 2, 3, 1)  # (batch, head, head_features, seq_length)
        else:
            return x.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

    def forward(self, x, layer_past=None, attention_mask=None, head_mask=None):
        x = self.c_attn(x)
        query, key, value = x.split(self.split_size, dim=2)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)


        if layer_past is not None:
            past_key, past_value = layer_past[0], layer_past[1]  # transpose back cf below

            # NOTE(review): unlike stock GPT-2, the past key/value here are
            # re-run through split_heads, i.e. the cache is presumably stored
            # in merged (batch, seq, n_state) layout -- confirm against the
            # caller that builds layer_past.
            past_key = self.split_heads(past_key, k=True)
            past_value = self.split_heads(past_value)
            # pdb.set_trace()
            key = torch.cat((past_key, key), dim=-1)
            value = torch.cat((past_value, value), dim=-2)
        present = torch.stack((key.transpose(-2, -1), value))  # transpose to have same shapes for stacking

        attn_outputs = self._attn(query, key, value, attention_mask, head_mask)
        a = attn_outputs[0]

        a = self.merge_heads(a)
        a = self.c_proj(a)
        a = self.resid_dropout(a)

        outputs = [a, present] + attn_outputs[1:]
        return outputs  # a, present, (attentions)
208
+
209
+
210
class MLP(nn.Module):
    """GPT-2 position-wise feed-forward network: expand to ``n_state``
    (typically ``4 * n_embd``), apply GELU, project back, then dropout."""

    def __init__(self, n_state, config):  # in MLP: n_state=3072 (4 * n_embd)
        super(MLP, self).__init__()
        embed_dim = config.n_embd
        self.c_fc = Conv1D(n_state, embed_dim)
        self.c_proj = Conv1D(embed_dim, n_state)
        self.act = gelu
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, x):
        hidden = self.act(self.c_fc(x))
        projected = self.c_proj(hidden)
        return self.dropout(projected)
223
+
224
+
225
class Block(nn.Module):
    """A single GPT-2 transformer block: pre-LayerNorm self-attention and a
    pre-LayerNorm MLP, each wrapped in a residual connection."""

    def __init__(self, n_ctx, config, scale=False):
        super(Block, self).__init__()
        embed_dim = config.n_embd
        self.ln_1 = nn.LayerNorm(embed_dim, eps=config.layer_norm_epsilon)
        self.attn = Attention(embed_dim, n_ctx, config, scale)
        self.ln_2 = nn.LayerNorm(embed_dim, eps=config.layer_norm_epsilon)
        self.mlp = MLP(4 * embed_dim, config)

    def forward(self, x, layer_past=None, attention_mask=None, head_mask=None):
        attn_results = self.attn(self.ln_1(x),
                                 layer_past=layer_past,
                                 attention_mask=attention_mask,
                                 head_mask=head_mask)  # -> [a, present, (attentions)]

        # Residual around attention, then residual around the MLP.
        x = x + attn_results[0]
        x = x + self.mlp(self.ln_2(x))

        return [x] + attn_results[1:]  # x, present, (attentions)
247
+
248
+
249
class GPT2PreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for dowloading and loading pretrained models.
    """
    config_class = GPT2Config
    pretrained_model_archive_map = GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
    load_tf_weights = load_tf_weights_in_gpt2
    base_model_prefix = "transformer"

    def __init__(self, *inputs, **kwargs):
        super(GPT2PreTrainedModel, self).__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        """Initialize one sub-module's weights, GPT-2 style."""
        if isinstance(module, nn.LayerNorm):
            # LayerNorm: identity transform at initialization.
            module.weight.data.fill_(1.0)
            module.bias.data.zero_()
        elif isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
            # Slightly different from the TF version which uses truncated_normal
            # for initialization; cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
                module.bias.data.zero_()
273
+
274
+
275
# Shared docstring fragments: `add_start_docstrings` prepends these to the
# docstring of every GPT-2 model class defined below.
GPT2_START_DOCSTRING = r""" OpenAI GPT-2 model was proposed in
`Language Models are Unsupervised Multitask Learners`_
by Alec Radford*, Jeffrey Wu*, Rewon Child, David Luan, Dario Amodei** and Ilya Sutskever**.
It's a causal (unidirectional) transformer pre-trained using language modeling on a very large
corpus of ~40 GB of text data.

This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
refer to the PyTorch documentation for all matter related to general usage and behavior.

.. _`Language Models are Unsupervised Multitask Learners`:
    https://openai.com/blog/better-language-models/

.. _`torch.nn.Module`:
    https://pytorch.org/docs/stable/nn.html#module

Parameters:
    config (:class:`~pytorch_transformers.GPT2Config`): Model configuration class with all the parameters of the model.
        Initializing with a config file does not load the weights associated with the model, only the configuration.
        Check out the :meth:`~pytorch_transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

GPT2_INPUTS_DOCSTRING = r""" Inputs:
    **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
        Indices of input sequence tokens in the vocabulary.
        GPT-2 is a model with absolute position embeddings so it's usually advised to pad the inputs on
        the right rather than the left.
        Indices can be obtained using :class:`pytorch_transformers.GPT2Tokenizer`.
        See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
        :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
    **past**:
        list of ``torch.FloatTensor`` (one for each layer):
        that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
        (see `past` output below). Can be used to speed up sequential decoding.
    **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
        Mask to avoid performing attention on padding token indices.
        Mask values selected in ``[0, 1]``:
        ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
    **token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
        A parallel sequence of tokens (can be used to indicate various portions of the inputs).
        The embeddings from these tokens will be summed with the respective token embeddings.
        Indices are selected in the vocabulary (unlike BERT which has a specific vocabulary for segment indices).
    **position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
        Indices of positions of each input sequence tokens in the position embeddings.
        Selected in the range ``[0, config.max_position_embeddings - 1]``.
    **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
        Mask to nullify selected heads of the self-attention modules.
        Mask values selected in ``[0, 1]``:
        ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
"""
324
+
325
@add_start_docstrings("The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
                      GPT2_START_DOCSTRING, GPT2_INPUTS_DOCSTRING)
class GPT2Model(GPT2PreTrainedModel):
    r"""
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
            Sequence of hidden-states at the last layer of the model.
        **past**:
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            that contains pre-computed hidden-states (key and values in the attention blocks).
            Can be used (see `past` input) to speed up sequential decoding.
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    This variant additionally accepts a latent code through ``past`` which can be
    injected either as an extra embedding (``latent_as_gpt_emb``) and/or as a
    per-layer key/value memory (``latent_as_gpt_memory``) — the Optimus VAE setup.

    Examples::

        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        model = GPT2Model.from_pretrained('gpt2')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids)
        last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple

    """
    def __init__(self, config):
        super(GPT2Model, self).__init__(config)
        self.output_hidden_states = config.output_hidden_states
        self.output_attentions = config.output_attentions

        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        # FIX: the original wrapped `config.latent_size` in a bare `except:`,
        # which silently swallows *any* exception (even KeyboardInterrupt).
        # getattr with a default expresses the same intent safely.
        self.latent_size = getattr(config, 'latent_size', 32)  # default size is 32

        self.linear = nn.Linear(self.latent_size, config.hidden_size * config.n_layer, bias=False)  # different latent vector for each layer
        self.linear_emb = nn.Linear(self.latent_size, config.hidden_size, bias=False)  # share the same latent vector as the embeddings

        self.config = config
        self.init_weights()

    def _resize_token_embeddings(self, new_num_tokens):
        self.wte = self._get_resized_embeddings(self.wte, new_num_tokens)
        return self.wte

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        """
        for layer, heads in heads_to_prune.items():
            self.h[layer].attn.prune_heads(heads)

    def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, latent_as_gpt_emb=False, latent_as_gpt_memory=True):
        if past is None:
            past_length = 0
            past = [None] * len(self.h)
        else:
            # `past` carries the VAE latent code here, not cached key/values.
            if latent_as_gpt_emb:
                past_emb = self.linear_emb(past)  # used as embeddings to add on other three embeddings

            if latent_as_gpt_memory:
                past = self.linear(past)
                share_latent = False
                if share_latent:
                    # the same latent vector shared by all layers
                    past = [past.unsqueeze(-2), past.unsqueeze(-2)]  # query, key
                    past = [past] * len(self.h)
                    past_length = past[0][0].size(-2)
                else:
                    # different latent vectors for each layer
                    past_split = torch.split(past.unsqueeze(1), self.config.hidden_size, dim=2)
                    past = list(zip(past_split, past_split))
                    past_length = 1  # past[0][0].size(-2)
            else:
                past_length = 0
                past = [None] * len(self.h)

        if position_ids is None:
            position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        # Attention mask.
        if attention_mask is not None:
            # We create a 3D attention mask from a 2D tensor mask.
            # Sizes are [batch_size, 1, 1, to_seq_length]
            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
            # this attention mask is more simple than the triangular masking of causal attention
            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
            # masked positions, this operation will create a tensor which is 0.0 for
            # positions we want to attend and -10000.0 for masked positions.
            # Since we are adding it to the raw scores before the softmax, this is
            # effectively the same as removing these entirely.
            attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
            attention_mask = (1.0 - attention_mask) * -10000.0

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # head_mask has shape n_layer x batch x n_heads x N x N
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to float if need + fp16 compatibility
        else:
            head_mask = [None] * self.config.n_layer

        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_ids.size(-1))
        position_ids = position_ids.view(-1, position_ids.size(-1))

        inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
            token_type_embeds = self.wte(token_type_ids)
        else:
            token_type_embeds = 0

        hidden_states = inputs_embeds + position_embeds + token_type_embeds
        if latent_as_gpt_emb:
            # Broadcast the latent embedding over the sequence dimension.
            hidden_states = hidden_states + past_emb.unsqueeze(1)

        hidden_states = self.drop(hidden_states)

        output_shape = input_shape + (hidden_states.size(-1),)

        presents = ()
        all_attentions = []
        all_hidden_states = ()
        for i, (block, layer_past) in enumerate(zip(self.h, past)):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)

            outputs = block(hidden_states,
                            layer_past=layer_past,
                            attention_mask=attention_mask,
                            head_mask=head_mask[i])

            hidden_states, present = outputs[:2]
            presents = presents + (present,)

            if self.output_attentions:
                all_attentions.append(outputs[2])

        hidden_states = self.ln_f(hidden_states)

        hidden_states = hidden_states.view(*output_shape)
        # Add last hidden state
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states, presents)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            # let the number of heads free (-1) so we can extract attention even after head pruning
            attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
            all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions)
            outputs = outputs + (all_attentions,)
        return outputs  # last hidden state, presents, (all hidden_states), (attentions)
514
+
515
+
516
@add_start_docstrings("""The GPT2 Model transformer with a language modeling head on top
(linear layer with weights tied to the input embeddings). """, GPT2_START_DOCSTRING, GPT2_INPUTS_DOCSTRING)
class GPT2LMHeadModel(GPT2PreTrainedModel):
    r"""
    **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
        Labels for language modeling.
        Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids``
        Tokens whose label equals ``label_ignore`` are masked out of the loss.

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(batch_size,)``:
            Per-sequence language modeling loss (sum of per-token losses).
        **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        **past**, **hidden_states**, **attentions**: forwarded unchanged from :class:`GPT2Model`.

    Examples::

        import torch
        from pytorch_transformers import GPT2Tokenizer, GPT2LMHeadModel

        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        model = GPT2LMHeadModel.from_pretrained('gpt2')

        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, labels=input_ids)
        loss, logits = outputs[:2]

    """
    def __init__(self, config):
        super(GPT2LMHeadModel, self).__init__(config)
        self.transformer = GPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.init_weights()
        self.tie_weights()

    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        self._tie_or_clone_weights(self.lm_head,
                                   self.transformer.wte)

    def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                labels=None, label_ignore=None):
        transformer_outputs = self.transformer(input_ids,
                                               past=past,
                                               attention_mask=attention_mask,
                                               token_type_ids=token_type_ids,
                                               position_ids=position_ids,
                                               head_mask=head_mask)
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)

        outputs = (lm_logits,) + transformer_outputs[1:]
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens.
            # FIX: `reduce=False` is long-deprecated in torch.nn; `reduction='none'`
            # has identical semantics (element-wise, unreduced losses).
            # 50258 is the padding id, otherwise -1 is used for masked LM.
            loss_fct = CrossEntropyLoss(ignore_index=label_ignore, reduction='none')
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                            shift_labels.view(-1))
            # Sum the per-token losses within each sequence -> shape (batch_size,)
            loss = torch.sum(loss.view(-1, shift_labels.shape[-1]), -1)
            outputs = (loss,) + outputs

        return outputs  # (loss), lm_logits, presents, (all hidden_states), (attentions)
599
+
600
+
601
+
602
@add_start_docstrings("""The GPT2 Model transformer with a language modeling head on top
(linear layer with weights tied to the input embeddings). """, GPT2_START_DOCSTRING, GPT2_INPUTS_DOCSTRING)
class GPT2ForLatentConnector(GPT2PreTrainedModel):
    r"""
    GPT-2 LM-head decoder that consumes a VAE latent code through ``past``:
    the latent is injected as an extra embedding (``latent_as_gpt_emb``) and/or
    as per-layer key/value memory (``latent_as_gpt_memory``), both forwarded to
    :class:`GPT2Model`.

    **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
        Labels for language modeling.
        Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids``
        Tokens whose label equals ``label_ignore`` are masked out of the loss.

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(batch_size,)``:
            Per-sequence language modeling loss (sum of per-token losses).
        **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        **past**, **hidden_states**, **attentions**: forwarded unchanged from :class:`GPT2Model`.
    """
    def __init__(self, config, latent_size=32, latent_as_gpt_emb=True, latent_as_gpt_memory=True):
        # NOTE(review): `latent_size` is accepted but unused here — the
        # transformer reads the latent size from `config`; kept for
        # backward-compatible signatures. TODO confirm against callers.
        super(GPT2ForLatentConnector, self).__init__(config)

        self.transformer = GPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.init_weights()
        self.tie_weights()

        self.latent_as_gpt_emb = latent_as_gpt_emb
        self.latent_as_gpt_memory = latent_as_gpt_memory

    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        self._tie_or_clone_weights(self.lm_head,
                                   self.transformer.wte)

    def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                labels=None, label_ignore=None):
        transformer_outputs = self.transformer(input_ids,
                                               past=past,
                                               attention_mask=attention_mask,
                                               token_type_ids=token_type_ids,
                                               position_ids=position_ids,
                                               head_mask=head_mask,
                                               latent_as_gpt_emb=self.latent_as_gpt_emb,
                                               latent_as_gpt_memory=self.latent_as_gpt_memory)
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)

        outputs = (lm_logits,) + transformer_outputs[1:]
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens.
            # FIX: `reduce=False` is long-deprecated in torch.nn; `reduction='none'`
            # has identical semantics (element-wise, unreduced losses).
            # 50258 is the padding id, otherwise -1 is used for masked LM.
            loss_fct = CrossEntropyLoss(ignore_index=label_ignore, reduction='none')
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                            shift_labels.view(-1))
            # Sum the per-token losses within each sequence -> shape (batch_size,)
            loss = torch.sum(loss.view(-1, shift_labels.shape[-1]), -1)
            outputs = (loss,) + outputs

        return outputs  # (loss), lm_logits, presents, (all hidden_states), (attentions)
696
+
697
@add_start_docstrings("""The GPT2 Model transformer with a language modeling and a multiple-choice classification
head on top e.g. for RocStories/SWAG tasks. The two heads are two linear layers.
The language modeling head has its weights tied to the input embeddings,
the classification head takes as input the input of a specified classification token index in the input sequence).
""", GPT2_START_DOCSTRING, GPT2_INPUTS_DOCSTRING)
class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
    r"""
    GPT-2 with two heads: a language-modeling head tied to the input embeddings
    and a multiple-choice classification head (``SequenceSummary``).

    **mc_token_ids**: (`optional`, default to index of the last token of the input) ``torch.LongTensor`` of shape ``(batch_size, num_choices)``:
        Index of the classification token in each input sequence,
        selected in the range ``[0, input_ids.size(-1) - 1[``.
    **lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
        Labels for language modeling; shifted inside the model, so
        ``lm_labels = input_ids`` is valid. Labels of ``-1`` are ignored.
    **mc_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size)``:
        Labels for the multiple-choice classification loss, in
        ``[0, num_choices]`` where ``num_choices`` is the second dimension
        of the input tensors.

    Outputs (tuple, presence depending on config/inputs):
        **lm_loss** (when ``lm_labels`` given), **mc_loss** (when ``mc_labels``
        given), **lm_prediction_scores** of shape
        ``(batch_size, num_choices, sequence_length, config.vocab_size)``,
        **mc_prediction_scores** of shape ``(batch_size, num_choices)``,
        then ``past`` / ``hidden_states`` / ``attentions`` forwarded from
        :class:`GPT2Model`.

    Examples::

        import torch
        from pytorch_transformers import GPT2Tokenizer, GPT2DoubleHeadsModel

        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        model = GPT2DoubleHeadsModel.from_pretrained('gpt2')

        # Add a [CLS] to the vocabulary (we should train it also!)
        tokenizer.add_special_tokens({'cls_token': '[CLS]'})
        model.resize_token_embeddings(len(tokenizer))

        choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
        encoded_choices = [tokenizer.encode(s) for s in choices]
        cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices]

        input_ids = torch.tensor(encoded_choices).unsqueeze(0)  # Batch size: 1, number of choices: 2
        mc_token_ids = torch.tensor([cls_token_location])       # Batch size: 1

        outputs = model(input_ids, mc_token_ids=mc_token_ids)
        lm_prediction_scores, mc_prediction_scores = outputs[:2]

    """
    def __init__(self, config):
        super(GPT2DoubleHeadsModel, self).__init__(config)
        self.transformer = GPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.multiple_choice_head = SequenceSummary(config)
        self.init_weights()
        self.tie_weights()

    def tie_weights(self):
        """Share the LM head weights with the input embeddings (cloned for
        TorchScript export, which cannot handle parameter sharing)."""
        self._tie_or_clone_weights(self.lm_head, self.transformer.wte)

    def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                mc_token_ids=None, lm_labels=None, mc_labels=None):
        transformer_outputs = self.transformer(
            input_ids,
            past=past,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
        )
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)
        mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)

        outputs = (lm_logits, mc_logits) + transformer_outputs[1:]
        if mc_labels is not None:
            # Multiple-choice classification loss over the summary logits.
            mc_loss = CrossEntropyLoss()(mc_logits.view(-1, mc_logits.size(-1)),
                                         mc_labels.view(-1))
            outputs = (mc_loss,) + outputs
        if lm_labels is not None:
            # Shift so that token < n predicts token n; -1 labels are ignored.
            shifted_logits = lm_logits[..., :-1, :].contiguous()
            shifted_labels = lm_labels[..., 1:].contiguous()
            lm_loss = CrossEntropyLoss(ignore_index=-1)(shifted_logits.view(-1, shifted_logits.size(-1)),
                                                        shifted_labels.view(-1))
            outputs = (lm_loss,) + outputs

        return outputs  # (lm loss), (mc loss), lm logits, mc logits, presents, (all hidden_states), (attentions)
808
+
809
+ ############
810
+ # XX Added #
811
+ ############
812
+
813
class GPT2Model_XX(nn.Module):
    """Standalone copy of GPT2Model (plain nn.Module, no PreTrainedModel base)
    with latent-vector injection, used by the Versatile-Diffusion/Optimus stack."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.output_hidden_states = config.output_hidden_states
        self.output_attentions = config.output_attentions

        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        # FIX: the original used a bare `except:` around `config.latent_size`,
        # which swallows any exception; getattr with a default is equivalent
        # and safe. (Also dropped a duplicated `self.config = config`.)
        self.latent_size = getattr(config, 'latent_size', 32)  # default size is 32

        self.linear = nn.Linear(self.latent_size, config.hidden_size * config.n_layer, bias=False)  # different latent vector for each layer
        self.linear_emb = nn.Linear(self.latent_size, config.hidden_size, bias=False)  # share the same latent vector as the embeddings

        self.init_weights()
836
+
837
def init_weights(self):
    """ Initialize and prunes weights if needed. """
    # Initialize weights
    self.apply(self._init_weights)

    # Prune heads if needed
    if self.config.pruned_heads:
        # FIX: this class only defines `_prune_heads` (and, being a plain
        # nn.Module rather than a PreTrainedModel, has no public
        # `prune_heads`), so the original `self.prune_heads(...)` raised
        # AttributeError whenever `config.pruned_heads` was non-empty.
        self._prune_heads(self.config.pruned_heads)
845
+
846
def _init_weights(self, module):
    """ Initialize the weights.

    Applied recursively to every submodule via ``self.apply`` in
    ``init_weights``. Linear/Embedding/Conv1D weights get a normal init with
    ``config.initializer_range``; LayerNorm starts as the identity transform.
    """
    if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
        # Slightly different from the TF version which uses truncated_normal for initialization
        # cf https://github.com/pytorch/pytorch/pull/5617
        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)
858
+
859
def _resize_token_embeddings(self, new_num_tokens):
    """Replace the token-embedding table with one resized to `new_num_tokens`.

    NOTE(review): `_get_resized_embeddings` comes from PreTrainedModel, but this
    class only subclasses nn.Module — calling this would raise AttributeError
    unless the helper is provided elsewhere. TODO confirm this method is unused.
    """
    self.wte = self._get_resized_embeddings(self.wte, new_num_tokens)
    return self.wte
862
+
863
+ def _prune_heads(self, heads_to_prune):
864
+ """ Prunes heads of the model.
865
+ heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
866
+ """
867
+ for layer, heads in heads_to_prune.items():
868
+ self.h[layer].attn.prune_heads(heads)
869
+
870
+ def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, latent_as_gpt_emb=False, latent_as_gpt_memory=True):
871
+ if past is None:
872
+ past_length = 0
873
+ past = [None] * len(self.h)
874
+ else:
875
+ if latent_as_gpt_emb:
876
+ past_emb = self.linear_emb(past) # used as embeddings to add on other three embeddings
877
+
878
+ if latent_as_gpt_memory:
879
+ past = self.linear(past)
880
+ share_latent = False
881
+ if share_latent:
882
+ # the same latent vector shared by all layers
883
+ past = [past.unsqueeze(-2), past.unsqueeze(-2)] # query, key
884
+ past = [past] * len(self.h)
885
+ past_length = past[0][0].size(-2)
886
+ else:
887
+ # different latent vectors for each layer
888
+ past_split = torch.split(past.unsqueeze(1), self.config.hidden_size, dim=2)
889
+ past = list(zip(past_split,past_split))
890
+
891
+ # past = past.view(batch_size,len(self.h),-1)
892
+ # past = [[past[:,i,:].unsqueeze(-2), past[:,i,:].unsqueeze(-2) ] for i in range(len(self.h))]
893
+ past_length = 1 # past[0][0].size(-2)
894
+ else:
895
+ past_length = 0
896
+ past = [None] * len(self.h)
897
+
898
+
899
+ if position_ids is None:
900
+ position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device)
901
+ position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
902
+
903
+
904
+ # Attention mask.
905
+ if attention_mask is not None:
906
+ # We create a 3D attention mask from a 2D tensor mask.
907
+ # Sizes are [batch_size, 1, 1, to_seq_length]
908
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
909
+ # this attention mask is more simple than the triangular masking of causal attention
910
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
911
+ attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
912
+
913
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
914
+ # masked positions, this operation will create a tensor which is 0.0 for
915
+ # positions we want to attend and -10000.0 for masked positions.
916
+ # Since we are adding it to the raw scores before the softmax, this is
917
+ # effectively the same as removing these entirely.
918
+ attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
919
+ attention_mask = (1.0 - attention_mask) * -10000.0
920
+
921
+ # Prepare head mask if needed
922
+ # 1.0 in head_mask indicate we keep the head
923
+ # attention_probs has shape bsz x n_heads x N x N
924
+ # head_mask has shape n_layer x batch x n_heads x N x N
925
+ if head_mask is not None:
926
+ if head_mask.dim() == 1:
927
+ head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
928
+ head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1)
929
+ elif head_mask.dim() == 2:
930
+ head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
931
+ head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility
932
+ else:
933
+ head_mask = [None] * self.config.n_layer
934
+
935
+
936
+ input_shape = input_ids.size()
937
+ input_ids = input_ids.view(-1, input_ids.size(-1))
938
+ position_ids = position_ids.view(-1, position_ids.size(-1))
939
+
940
+
941
+ inputs_embeds = self.wte(input_ids)
942
+ position_embeds = self.wpe(position_ids)
943
+ if token_type_ids is not None:
944
+ token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
945
+ token_type_embeds = self.wte(token_type_ids)
946
+ else:
947
+ token_type_embeds = 0
948
+
949
+
950
+ hidden_states = inputs_embeds + position_embeds + token_type_embeds
951
+ if latent_as_gpt_emb:
952
+ # pdb.set_trace()
953
+ hidden_states = hidden_states + past_emb.unsqueeze(1)
954
+
955
+ hidden_states = self.drop(hidden_states)
956
+
957
+ output_shape = input_shape + (hidden_states.size(-1),)
958
+
959
+ presents = ()
960
+ all_attentions = []
961
+ all_hidden_states = ()
962
+ for i, (block, layer_past) in enumerate(zip(self.h, past)):
963
+ if self.output_hidden_states:
964
+ all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
965
+
966
+
967
+ outputs = block(hidden_states,
968
+ layer_past=layer_past,
969
+ attention_mask=attention_mask,
970
+ head_mask=head_mask[i])
971
+
972
+
973
+ hidden_states, present = outputs[:2]
974
+ presents = presents + (present,)
975
+
976
+ if self.output_attentions:
977
+ all_attentions.append(outputs[2])
978
+
979
+ hidden_states = self.ln_f(hidden_states)
980
+
981
+ hidden_states = hidden_states.view(*output_shape)
982
+ # Add last hidden state
983
+ if self.output_hidden_states:
984
+ all_hidden_states = all_hidden_states + (hidden_states,)
985
+
986
+ outputs = (hidden_states, presents)
987
+ if self.output_hidden_states:
988
+ outputs = outputs + (all_hidden_states,)
989
+ if self.output_attentions:
990
+ # let the number of heads free (-1) so we can extract attention even after head pruning
991
+ attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
992
+ all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions)
993
+ outputs = outputs + (all_attentions,)
994
+ return outputs # last hidden state, presents, (all hidden_states), (attentions)
995
+
996
+ def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
997
+ """ Build a resized Embedding Module from a provided token Embedding Module.
998
+ Increasing the size will add newly initialized vectors at the end
999
+ Reducing the size will remove vectors from the end
1000
+
1001
+ Args:
1002
+ new_num_tokens: (`optional`) int
1003
+ New number of tokens in the embedding matrix.
1004
+ Increasing the size will add newly initialized vectors at the end
1005
+ Reducing the size will remove vectors from the end
1006
+ If not provided or None: return the provided token Embedding Module.
1007
+ Return: ``torch.nn.Embeddings``
1008
+ Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None
1009
+ """
1010
+ if new_num_tokens is None:
1011
+ return old_embeddings
1012
+ old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
1013
+ if old_num_tokens == new_num_tokens:
1014
+ return old_embeddings
1015
+ # Build new embeddings
1016
+ new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
1017
+ new_embeddings.to(old_embeddings.weight.device)
1018
+ # initialize all new embeddings (in particular added tokens)
1019
+ self._init_weights(new_embeddings)
1020
+ # Copy word embeddings from the previous weights
1021
+ num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
1022
+ new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]
1023
+ return new_embeddings
1024
+
1025
class GPT2ForLatentConnector_XX(nn.Module):
    """GPT-2 language-modeling head whose transformer can be conditioned on a
    latent vector, injected as an additive embedding (``latent_as_gpt_emb``)
    and/or as per-layer key/value memory (``latent_as_gpt_memory``).

    The LM head weight is tied to the transformer's input embedding.
    """

    def __init__(self,
                 config,
                 latent_size=32,
                 latent_as_gpt_emb=True,
                 latent_as_gpt_memory=True):
        super().__init__()
        self.config = config
        self.transformer = GPT2Model_XX(config)
        # Projection back to the vocabulary; tied to transformer.wte below.
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.init_weights()
        self.tie_weights()
        self.latent_as_gpt_emb = latent_as_gpt_emb
        self.latent_as_gpt_memory = latent_as_gpt_memory

    def init_weights(self):
        """Initialize all weights and prune heads if the config requests it."""
        self.apply(self._init_weights)
        if self.config.pruned_heads:
            self.prune_heads(self.config.pruned_heads)

    def _init_weights(self, module):
        """Initialize one submodule: normal(0, initializer_range) for
        linear/embedding/Conv1D weights, zero bias, unit LayerNorm."""
        if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
            # Slightly different from the TF version which uses truncated_normal
            # for initialization: cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def _tie_or_clone_weights(self, first_module, second_module):
        """Tie or clone module weights depending on whether TorchScript is used.

        TorchScript cannot handle parameter sharing, so the weight is cloned
        in that case instead of shared.
        """
        if self.config.torchscript:
            first_module.weight = nn.Parameter(second_module.weight.clone())
        else:
            first_module.weight = second_module.weight

        if hasattr(first_module, 'bias') and first_module.bias is not None:
            # Pad the bias in case the tied weight grew (e.g. resized vocab).
            first_module.bias.data = torch.nn.functional.pad(
                first_module.bias.data,
                (0, first_module.weight.shape[0] - first_module.bias.shape[0]),
                'constant', 0,)

    def tie_weights(self):
        """Share the input and output embedding matrices.

        Export to TorchScript can't handle parameter sharing, so the weights
        are cloned instead in that mode (see ``_tie_or_clone_weights``).
        """
        self._tie_or_clone_weights(self.lm_head,
                                   self.transformer.wte)

    def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                labels=None, label_ignore=None):
        """Run the transformer and LM head.

        Args:
            labels: (`optional`) target token ids; when given, a per-sequence
                summed cross-entropy loss is prepended to the outputs.
            label_ignore: label id excluded from the loss (e.g. the padding
                id). Must be an int when ``labels`` is provided.

        Returns:
            ``(loss,)? + (lm_logits, presents, [all hidden_states], [attentions])``
        """
        transformer_outputs = self.transformer(input_ids,
                                               past=past,
                                               attention_mask=attention_mask,
                                               token_type_ids=token_type_ids,
                                               position_ids=position_ids,
                                               head_mask=head_mask,
                                               latent_as_gpt_emb=self.latent_as_gpt_emb,
                                               latent_as_gpt_memory=self.latent_as_gpt_memory)
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)

        outputs = (lm_logits,) + transformer_outputs[1:]
        if labels is not None:
            # Shift so that tokens < n predict n.
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # BUGFIX: `reduce=False` is the long-deprecated (and in recent
            # PyTorch removed) spelling; `reduction='none'` is the equivalent
            # per-element form. `label_ignore` is e.g. the padding id.
            loss_fct = CrossEntropyLoss(ignore_index=label_ignore, reduction='none')
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                            shift_labels.view(-1))
            # Sum the token losses within each sequence.
            loss = torch.sum(loss.view(-1, shift_labels.shape[-1]), -1)
            outputs = (loss,) + outputs

        return outputs  # (loss), lm_logits, presents, (all hidden_states), (attentions)

    def resize_token_embeddings(self, new_num_tokens=None):
        """Resize the input token embeddings and keep weight tying intact."""
        model_embeds = self.transformer._resize_token_embeddings(new_num_tokens)
        if new_num_tokens is None:
            return model_embeds
        self.config.vocab_size = new_num_tokens
        self.transformer.vocab_size = new_num_tokens
        if hasattr(self, 'tie_weights'):
            self.tie_weights()
        return model_embeds
versatile_diffusion/lib/model_zoo/optimus_models/tokenization_bert.py ADDED
@@ -0,0 +1,457 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes."""
16
+
17
+ from __future__ import absolute_import, division, print_function, unicode_literals
18
+
19
+ import collections
20
+ import logging
21
+ import os
22
+ import unicodedata
23
+ from io import open
24
+
25
+ from .tokenization_utils import PreTrainedTokenizer
26
+
27
logger = logging.getLogger(__name__)

VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'}

# Canonical download locations for each pretrained BERT vocabulary file.
PRETRAINED_VOCAB_FILES_MAP = {
    'vocab_file': {
        'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
        'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
        'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
        'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
        'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
        'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
        'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
        'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt",
        'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt",
        'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt",
        'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt",
        'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt",
        'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt",
    }
}

# Every public BERT checkpoint was trained with 512 positional embeddings.
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    name: 512 for name in PRETRAINED_VOCAB_FILES_MAP['vocab_file']
}

# "uncased" checkpoints lower-case their input; every other checkpoint is
# case-sensitive.
PRETRAINED_INIT_CONFIGURATION = {
    name: {'do_lower_case': 'uncased' in name}
    for name in PRETRAINED_VOCAB_FILES_MAP['vocab_file']
}
81
+
82
+
83
def load_vocab(vocab_file):
    """Read a one-token-per-line vocabulary file into an ordered
    token -> index mapping."""
    vocab = collections.OrderedDict()
    with open(vocab_file, "r", encoding="utf-8") as reader:
        for index, line in enumerate(reader):
            # Only the trailing newline is stripped; other whitespace is part
            # of the token.
            vocab[line.rstrip('\n')] = index
    return vocab
92
+
93
+
94
def whitespace_tokenize(text):
    """Split *text* on runs of whitespace, returning [] for blank input."""
    # str.split() with no argument already discards leading/trailing
    # whitespace and collapses internal runs, so no explicit strip() or
    # empty-check is needed.
    return text.split()
101
+
102
+
103
class BertTokenizer(PreTrainedTokenizer):
    r"""
    Constructs a BertTokenizer.
    :class:`~pytorch_transformers.BertTokenizer` runs end-to-end tokenization:
    punctuation splitting + wordpiece.

    Args:
        vocab_file: Path to a one-wordpiece-per-line vocabulary file
        do_lower_case: Whether to lower case the input. Only has an effect
            when do_basic_tokenize=True
        do_basic_tokenize: Whether to do basic tokenization before wordpiece.
        never_split: List of tokens which will never be split during
            tokenization. Only has an effect when do_basic_tokenize=True
        tokenize_chinese_chars: Whether to add whitespace around CJK
            characters before tokenizing.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

    def __init__(self, vocab_file, do_lower_case=True, do_basic_tokenize=True, never_split=None,
                 unk_token="[UNK]", sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]",
                 mask_token="[MASK]", tokenize_chinese_chars=True, **kwargs):
        """Constructs a BertTokenizer.

        Raises:
            ValueError: if ``vocab_file`` does not point to an existing file.
        """
        super(BertTokenizer, self).__init__(unk_token=unk_token, sep_token=sep_token,
                                            pad_token=pad_token, cls_token=cls_token,
                                            mask_token=mask_token, **kwargs)
        # Special tokens consume part of the model's maximum sequence length:
        # [CLS] X [SEP] and [CLS] A [SEP] B [SEP] respectively.
        self.max_len_single_sentence = self.max_len - 2
        self.max_len_sentences_pair = self.max_len - 3

        if not os.path.isfile(vocab_file):
            raise ValueError(
                "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
                "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file))
        self.vocab = load_vocab(vocab_file)
        # Reverse mapping for id -> token lookups.
        self.ids_to_tokens = collections.OrderedDict(
            (index, token) for token, index in self.vocab.items())
        self.do_basic_tokenize = do_basic_tokenize
        if do_basic_tokenize:
            self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
                                                  never_split=never_split,
                                                  tokenize_chinese_chars=tokenize_chinese_chars)
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)

    @property
    def vocab_size(self):
        return len(self.vocab)

    def _tokenize(self, text):
        """Tokenize text: basic (whitespace/punctuation) pass, then WordPiece."""
        if not self.do_basic_tokenize:
            return self.wordpiece_tokenizer.tokenize(text)
        pieces = []
        for word in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
            pieces.extend(self.wordpiece_tokenizer.tokenize(word))
        return pieces

    def _convert_token_to_id(self, token):
        """Convert a token (str/unicode) to an id using the vocab."""
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Convert an index (integer) to a token (string/unicode) using the vocab."""
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """Join wordpiece tokens back into a single string, undoing '##' joins."""
        return ' '.join(tokens).replace(' ##', '').strip()

    def add_special_tokens_single_sentence(self, token_ids):
        """Wrap a single sequence in BERT's special tokens: [CLS] X [SEP]."""
        return [self.cls_token_id] + token_ids + [self.sep_token_id]

    def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1):
        """Wrap a sequence pair in BERT's special tokens: [CLS] A [SEP] B [SEP]."""
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        return cls + token_ids_0 + sep + token_ids_1 + sep

    def save_vocabulary(self, vocab_path):
        """Write the vocabulary, one token per line in index order, to a
        directory (using the standard vocab filename) or to an explicit file
        path. Returns a one-tuple with the written file path."""
        if os.path.isdir(vocab_path):
            vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file'])
        else:
            vocab_file = vocab_path
        expected_index = 0
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                if expected_index != token_index:
                    # Non-consecutive indices indicate a corrupted vocabulary;
                    # warn but keep writing in index order.
                    logger.warning("Saving vocabulary to {}: vocabulary indices are not consecutive."
                                   " Please check that the vocabulary is not corrupted!".format(vocab_file))
                    expected_index = token_index
                writer.write(token + u'\n')
                expected_index += 1
        return (vocab_file,)
222
+
223
+
224
class BasicTokenizer(object):
    """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""

    def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True):
        """ Constructs a BasicTokenizer.

        Args:
            **do_lower_case**: Whether to lower case the input.
            **never_split**: (`optional`) list of str
                Kept for backward compatibility purposes.
                Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`)
                List of token not to split.
            **tokenize_chinese_chars**: (`optional`) boolean (default True)
                Whether to tokenize Chinese characters.
                This should likely be deactivated for Japanese:
                see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328
        """
        if never_split is None:
            never_split = []
        self.do_lower_case = do_lower_case
        self.never_split = never_split
        self.tokenize_chinese_chars = tokenize_chinese_chars

    def tokenize(self, text, never_split=None):
        """ Basic Tokenization of a piece of text.
        Split on "white spaces" only, for sub-word tokenization, see WordPieceTokenizer.

        Args:
            **never_split**: (`optional`) list of str
                Extra tokens (merged with the constructor's list) that must
                not be lower-cased or split at punctuation.
        """
        never_split = self.never_split + (never_split if never_split is not None else [])
        text = self._clean_text(text)
        # This was added on November 1st, 2018 for the multilingual and Chinese
        # models. This is also applied to the English models now, but it doesn't
        # matter since the English models were not trained on any Chinese data
        # and generally don't have any Chinese data in them (there are Chinese
        # characters in the vocabulary because Wikipedia does have some Chinese
        # words in the English Wikipedia.).
        if self.tokenize_chinese_chars:
            text = self._tokenize_chinese_chars(text)
        orig_tokens = whitespace_tokenize(text)
        split_tokens = []
        for token in orig_tokens:
            if self.do_lower_case and token not in never_split:
                token = token.lower()
                token = self._run_strip_accents(token)
            # BUGFIX: forward never_split here; the original called
            # _run_split_on_punc(token) without it, so protected tokens such
            # as "[UNK]" were still broken apart at their punctuation.
            split_tokens.extend(self._run_split_on_punc(token, never_split))

        output_tokens = whitespace_tokenize(" ".join(split_tokens))
        return output_tokens

    def _run_strip_accents(self, text):
        """Strips accents from a piece of text (NFD decompose, drop Mn marks)."""
        text = unicodedata.normalize("NFD", text)
        output = []
        for char in text:
            cat = unicodedata.category(char)
            if cat == "Mn":
                continue
            output.append(char)
        return "".join(output)

    def _run_split_on_punc(self, text, never_split=None):
        """Splits punctuation on a piece of text, keeping never_split tokens whole."""
        if never_split is not None and text in never_split:
            return [text]
        chars = list(text)
        i = 0
        start_new_word = True
        output = []
        while i < len(chars):
            char = chars[i]
            if _is_punctuation(char):
                # Every punctuation character becomes its own token.
                output.append([char])
                start_new_word = True
            else:
                if start_new_word:
                    output.append([])
                start_new_word = False
                output[-1].append(char)
            i += 1

        return ["".join(x) for x in output]

    def _tokenize_chinese_chars(self, text):
        """Adds whitespace around any CJK character."""
        output = []
        for char in text:
            cp = ord(char)
            if self._is_chinese_char(cp):
                output.append(" ")
                output.append(char)
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and handled
        # like the all of the other languages.
        if ((cp >= 0x4E00 and cp <= 0x9FFF) or
                (cp >= 0x3400 and cp <= 0x4DBF) or
                (cp >= 0x20000 and cp <= 0x2A6DF) or
                (cp >= 0x2A700 and cp <= 0x2B73F) or
                (cp >= 0x2B740 and cp <= 0x2B81F) or
                (cp >= 0x2B820 and cp <= 0x2CEAF) or
                (cp >= 0xF900 and cp <= 0xFAFF) or
                (cp >= 0x2F800 and cp <= 0x2FA1F)):
            return True

        return False

    def _clean_text(self, text):
        """Performs invalid character removal and whitespace cleanup on text."""
        output = []
        for char in text:
            cp = ord(char)
            # Drop NUL, the Unicode replacement character and control chars;
            # normalize all whitespace to a single space.
            if cp == 0 or cp == 0xfffd or _is_control(char):
                continue
            if _is_whitespace(char):
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)
358
+
359
+
360
class WordpieceTokenizer(object):
    """Runs WordPiece tokenization."""

    def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word

    def tokenize(self, text):
        """Tokenizes a piece of text into its word pieces.

        This uses a greedy longest-match-first algorithm to perform
        tokenization using the given vocabulary.

        For example:
            input = "unaffable"
            output = ["un", "##aff", "##able"]

        Args:
            text: A single token or whitespace separated tokens. This should
                have already been passed through `BasicTokenizer`.

        Returns:
            A list of wordpiece tokens.
        """
        output_tokens = []
        for word in whitespace_tokenize(text):
            chars = list(word)
            # Overlong words map straight to the unknown token.
            if len(chars) > self.max_input_chars_per_word:
                output_tokens.append(self.unk_token)
                continue

            sub_tokens = []
            start = 0
            unmatched = False
            while start < len(chars):
                # Greedily look for the longest vocab entry starting at `start`;
                # non-initial pieces carry the "##" continuation prefix.
                end = len(chars)
                piece = None
                while start < end:
                    candidate = "".join(chars[start:end])
                    if start > 0:
                        candidate = "##" + candidate
                    if candidate in self.vocab:
                        piece = candidate
                        break
                    end -= 1
                if piece is None:
                    # No prefix of the remainder is in the vocab: whole word
                    # becomes the unknown token.
                    unmatched = True
                    break
                sub_tokens.append(piece)
                start = end

            if unmatched:
                output_tokens.append(self.unk_token)
            else:
                output_tokens.extend(sub_tokens)
        return output_tokens
418
+
419
+
420
def _is_whitespace(char):
    """Checks whether `char` is a whitespace character."""
    # \t, \n and \r are technically control characters, but they are treated
    # as whitespace since that is how they behave in text.
    if char in (" ", "\t", "\n", "\r"):
        return True
    return unicodedata.category(char) == "Zs"
430
+
431
+
432
def _is_control(char):
    """Checks whether `char` is a control character."""
    # Tab/newline/carriage-return are technically control characters but are
    # counted as whitespace here, so they are excluded.
    if char in ("\t", "\n", "\r"):
        return False
    return unicodedata.category(char).startswith("C")
442
+
443
+
444
def _is_punctuation(char):
    """Checks whether `char` is a punctuation character."""
    cp = ord(char)
    # Treat every non-alphanumeric ASCII character as punctuation. Characters
    # such as "^", "$" and "`" are not in the Unicode Punctuation category but
    # are treated as punctuation anyway, for consistency.
    if 33 <= cp <= 47 or 58 <= cp <= 64 or 91 <= cp <= 96 or 123 <= cp <= 126:
        return True
    return unicodedata.category(char).startswith("P")