Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- versatile_diffusion/lib/model_zoo/__pycache__/ddim.cpython-310.pyc +0 -0
- versatile_diffusion/lib/model_zoo/__pycache__/ddim.cpython-38.pyc +0 -0
- versatile_diffusion/lib/model_zoo/__pycache__/ddim_dualcontext.cpython-38.pyc +0 -0
- versatile_diffusion/lib/model_zoo/__pycache__/ddim_vd.cpython-310.pyc +0 -0
- versatile_diffusion/lib/model_zoo/__pycache__/ddim_vd.cpython-38.pyc +0 -0
- versatile_diffusion/lib/model_zoo/__pycache__/diffusion_modules.cpython-310.pyc +0 -0
- versatile_diffusion/lib/model_zoo/__pycache__/diffusion_modules.cpython-38.pyc +0 -0
- versatile_diffusion/lib/model_zoo/__pycache__/diffusion_utils.cpython-310.pyc +0 -0
- versatile_diffusion/lib/model_zoo/__pycache__/diffusion_utils.cpython-38.pyc +0 -0
- versatile_diffusion/lib/model_zoo/__pycache__/distributions.cpython-310.pyc +0 -0
- versatile_diffusion/lib/model_zoo/__pycache__/distributions.cpython-38.pyc +0 -0
- versatile_diffusion/lib/model_zoo/__pycache__/ema.cpython-310.pyc +0 -0
- versatile_diffusion/lib/model_zoo/__pycache__/ema.cpython-38.pyc +0 -0
- versatile_diffusion/lib/model_zoo/__pycache__/openaimodel.cpython-310.pyc +0 -0
- versatile_diffusion/lib/model_zoo/__pycache__/openaimodel.cpython-38.pyc +0 -0
- versatile_diffusion/lib/model_zoo/__pycache__/optimus.cpython-310.pyc +0 -0
- versatile_diffusion/lib/model_zoo/__pycache__/optimus.cpython-38.pyc +0 -0
- versatile_diffusion/lib/model_zoo/__pycache__/sd.cpython-310.pyc +0 -0
- versatile_diffusion/lib/model_zoo/__pycache__/sd.cpython-38.pyc +0 -0
- versatile_diffusion/lib/model_zoo/clip_justin/__pycache__/model.cpython-38.pyc +0 -0
- versatile_diffusion/lib/model_zoo/clip_justin/model.py +436 -0
- versatile_diffusion/lib/model_zoo/common/__pycache__/get_model.cpython-310.pyc +0 -0
- versatile_diffusion/lib/model_zoo/common/__pycache__/get_model.cpython-38.pyc +0 -0
- versatile_diffusion/lib/model_zoo/common/__pycache__/get_optimizer.cpython-310.pyc +0 -0
- versatile_diffusion/lib/model_zoo/common/__pycache__/get_optimizer.cpython-38.pyc +0 -0
- versatile_diffusion/lib/model_zoo/common/__pycache__/get_scheduler.cpython-310.pyc +0 -0
- versatile_diffusion/lib/model_zoo/common/__pycache__/get_scheduler.cpython-38.pyc +0 -0
- versatile_diffusion/lib/model_zoo/common/__pycache__/utils.cpython-310.pyc +0 -0
- versatile_diffusion/lib/model_zoo/common/__pycache__/utils.cpython-38.pyc +0 -0
- versatile_diffusion/lib/model_zoo/common/utils.py +292 -0
- versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/configuration_bert.cpython-310.pyc +0 -0
- versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/configuration_bert.cpython-38.pyc +0 -0
- versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/configuration_gpt2.cpython-310.pyc +0 -0
- versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/configuration_gpt2.cpython-38.pyc +0 -0
- versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/configuration_utils.cpython-310.pyc +0 -0
- versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/configuration_utils.cpython-38.pyc +0 -0
- versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/file_utils.cpython-310.pyc +0 -0
- versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/file_utils.cpython-38.pyc +0 -0
- versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/modeling_utils.cpython-38.pyc +0 -0
- versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/optimus_bert.cpython-38.pyc +0 -0
- versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/tokenization_gpt2.cpython-310.pyc +0 -0
- versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/tokenization_utils.cpython-38.pyc +0 -0
- versatile_diffusion/lib/model_zoo/optimus_models/configuration_bert.py +113 -0
- versatile_diffusion/lib/model_zoo/optimus_models/configuration_gpt2.py +143 -0
- versatile_diffusion/lib/model_zoo/optimus_models/configuration_utils.py +205 -0
- versatile_diffusion/lib/model_zoo/optimus_models/file_utils.py +294 -0
- versatile_diffusion/lib/model_zoo/optimus_models/modeling_utils.py +780 -0
- versatile_diffusion/lib/model_zoo/optimus_models/optimus_bert.py +1439 -0
- versatile_diffusion/lib/model_zoo/optimus_models/optimus_gpt2.py +1122 -0
- versatile_diffusion/lib/model_zoo/optimus_models/tokenization_bert.py +457 -0
versatile_diffusion/lib/model_zoo/__pycache__/ddim.cpython-310.pyc
ADDED
|
Binary file (6.25 kB). View file
|
|
|
versatile_diffusion/lib/model_zoo/__pycache__/ddim.cpython-38.pyc
ADDED
|
Binary file (6.2 kB). View file
|
|
|
versatile_diffusion/lib/model_zoo/__pycache__/ddim_dualcontext.cpython-38.pyc
ADDED
|
Binary file (4.4 kB). View file
|
|
|
versatile_diffusion/lib/model_zoo/__pycache__/ddim_vd.cpython-310.pyc
ADDED
|
Binary file (9.66 kB). View file
|
|
|
versatile_diffusion/lib/model_zoo/__pycache__/ddim_vd.cpython-38.pyc
ADDED
|
Binary file (9.54 kB). View file
|
|
|
versatile_diffusion/lib/model_zoo/__pycache__/diffusion_modules.cpython-310.pyc
ADDED
|
Binary file (20.3 kB). View file
|
|
|
versatile_diffusion/lib/model_zoo/__pycache__/diffusion_modules.cpython-38.pyc
ADDED
|
Binary file (20.7 kB). View file
|
|
|
versatile_diffusion/lib/model_zoo/__pycache__/diffusion_utils.cpython-310.pyc
ADDED
|
Binary file (9.56 kB). View file
|
|
|
versatile_diffusion/lib/model_zoo/__pycache__/diffusion_utils.cpython-38.pyc
ADDED
|
Binary file (9.52 kB). View file
|
|
|
versatile_diffusion/lib/model_zoo/__pycache__/distributions.cpython-310.pyc
ADDED
|
Binary file (3.79 kB). View file
|
|
|
versatile_diffusion/lib/model_zoo/__pycache__/distributions.cpython-38.pyc
ADDED
|
Binary file (3.79 kB). View file
|
|
|
versatile_diffusion/lib/model_zoo/__pycache__/ema.cpython-310.pyc
ADDED
|
Binary file (3.04 kB). View file
|
|
|
versatile_diffusion/lib/model_zoo/__pycache__/ema.cpython-38.pyc
ADDED
|
Binary file (2.99 kB). View file
|
|
|
versatile_diffusion/lib/model_zoo/__pycache__/openaimodel.cpython-310.pyc
ADDED
|
Binary file (44.8 kB). View file
|
|
|
versatile_diffusion/lib/model_zoo/__pycache__/openaimodel.cpython-38.pyc
ADDED
|
Binary file (47.3 kB). View file
|
|
|
versatile_diffusion/lib/model_zoo/__pycache__/optimus.cpython-310.pyc
ADDED
|
Binary file (18.3 kB). View file
|
|
|
versatile_diffusion/lib/model_zoo/__pycache__/optimus.cpython-38.pyc
ADDED
|
Binary file (18.3 kB). View file
|
|
|
versatile_diffusion/lib/model_zoo/__pycache__/sd.cpython-310.pyc
ADDED
|
Binary file (22 kB). View file
|
|
|
versatile_diffusion/lib/model_zoo/__pycache__/sd.cpython-38.pyc
ADDED
|
Binary file (22.2 kB). View file
|
|
|
versatile_diffusion/lib/model_zoo/clip_justin/__pycache__/model.cpython-38.pyc
ADDED
|
Binary file (15 kB). View file
|
|
|
versatile_diffusion/lib/model_zoo/clip_justin/model.py
ADDED
|
@@ -0,0 +1,436 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import OrderedDict
|
| 2 |
+
from typing import Tuple, Union
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torch import nn
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Bottleneck(nn.Module):
    """CLIP's modified ResNet bottleneck block: 1x1 -> 3x3 -> (avgpool) -> 1x1 conv.

    Differs from torchvision's Bottleneck in that spatial downsampling is done
    with an AvgPool2d after the 3x3 convolution (and in the shortcut branch)
    rather than with strided convolutions.
    """

    # Output channels = planes * expansion.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()

        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu2 = nn.ReLU(inplace=True)

        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu3 = nn.ReLU(inplace=True)

        self.downsample = None
        self.stride = stride

        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
            self.downsample = nn.Sequential(OrderedDict([
                ("-1", nn.AvgPool2d(stride)),
                ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
                ("1", nn.BatchNorm2d(planes * self.expansion))
            ]))

    def forward(self, x: torch.Tensor):
        identity = x

        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.relu2(self.bn2(self.conv2(out)))
        out = self.avgpool(out)
        out = self.bn3(self.conv3(out))

        # Residual connection; project the input when channels/stride changed.
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu3(out)
        return out
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class AttentionPool2d(nn.Module):
    """Pool an NCHW feature map to a single vector with QKV attention.

    A mean-pooled token is prepended to the flattened spatial positions and
    used as the sole attention query; its attended output is returned, so an
    (N, C, H, W) feature map becomes an (N, output_dim) embedding.
    """

    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        # spacial_dim**2 positions plus one slot for the prepended mean token.
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        x = x.flatten(start_dim=2).permute(2, 0, 1)  # NCHW -> (HW)NC
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
        # Only the mean token (x[:1]) queries; q/k/v use this module's own
        # separate Linear weights via use_separate_proj_weight.
        x, _ = F.multi_head_attention_forward(
            query=x[:1], key=x, value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )
        return x.squeeze(0)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """

    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
        super().__init__()
        self.output_dim = output_dim
        self.input_resolution = input_resolution

        # the 3-layer stem
        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width // 2)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.relu3 = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(2)

        # residual layers
        self._inplanes = width  # this is a *mutable* variable used during construction
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        embed_dim = width * 32  # the ResNet feature dimension
        self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)

    def _make_layer(self, planes, blocks, stride=1):
        # Build one stage: a (possibly strided) Bottleneck followed by
        # `blocks - 1` stride-1 ones; updates self._inplanes as a side effect.
        layers = [Bottleneck(self._inplanes, planes, stride)]

        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.append(Bottleneck(self._inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        def stem(x):
            x = self.relu1(self.bn1(self.conv1(x)))
            x = self.relu2(self.bn2(self.conv2(x)))
            x = self.relu3(self.bn3(self.conv3(x)))
            x = self.avgpool(x)
            return x

        # Match the input dtype to the weights (e.g. fp16 after convert_weights).
        x = x.type(self.conv1.weight.dtype)
        x = stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.attnpool(x)

        return x
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
class LayerNorm(nn.LayerNorm):
    """LayerNorm that computes its statistics in float32.

    Keeps fp16 activations numerically stable: the input is upcast to
    float32 for the normalization itself, and the result is cast back to
    the caller's original dtype.
    """

    def forward(self, x: torch.Tensor):
        # Normalize in full precision, then restore the incoming dtype.
        return super().forward(x.type(torch.float32)).type(x.dtype)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class QuickGELU(nn.Module):
|
| 167 |
+
def forward(self, x: torch.Tensor):
|
| 168 |
+
return x * torch.sigmoid(1.702 * x)
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
class ResidualAttentionBlock(nn.Module):
    """Pre-LN transformer block: x + attn(ln_1(x)) then x + mlp(ln_2(x))."""

    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        # 4x-expansion MLP with QuickGELU; submodule names (c_fc/gelu/c_proj)
        # must stay as-is to match pretrained CLIP state dicts.
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        # Additive attention mask (e.g. causal for text); None means full attention.
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        # NOTE: re-assigns self.attn_mask on every call to move it to the
        # input's dtype/device (a no-op once they already match).
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class Transformer(nn.Module):
    """A stack of `layers` identical residual attention blocks of width `width`."""

    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        # All blocks share the same (optional) additive attention mask.
        blocks = [ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]
        self.resblocks = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor):
        return self.resblocks(x)
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
class VisionTransformer(nn.Module):
    """ViT image encoder: patchify with a conv, prepend a class token, run a
    transformer, and project the class token to `output_dim`."""

    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        # Non-overlapping patch embedding: stride == kernel_size == patch_size.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        # Prepend the class token (broadcast to the batch via the zeros tensor).
        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        # Only the class token (position 0) is used as the image representation.
        x = self.ln_post(x[:, 0, :])

        if self.proj is not None:
            x = x @ self.proj

        return x
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
class CLIP(nn.Module):
    """Contrastive Language-Image Pretraining model.

    Pairs an image encoder (ModifiedResNet when `vision_layers` is a
    tuple/list of 4 stage sizes, VisionTransformer when it is an int) with a
    causally-masked text transformer, projecting both modalities into a
    shared `embed_dim` space.
    """

    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int
                 ):
        super().__init__()

        self.context_length = context_length

        if isinstance(vision_layers, (tuple, list)):
            # ResNet backbone; head count derived from the final feature width.
            vision_heads = vision_width * 32 // 64
            self.visual = ModifiedResNet(
                layers=vision_layers,
                output_dim=embed_dim,
                heads=vision_heads,
                input_resolution=image_resolution,
                width=vision_width
            )
        else:
            # ViT backbone; 64 channels per attention head.
            vision_heads = vision_width // 64
            self.visual = VisionTransformer(
                input_resolution=image_resolution,
                patch_size=vision_patch_size,
                width=vision_width,
                layers=vision_layers,
                heads=vision_heads,
                output_dim=embed_dim
            )

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask()
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        # Learnable temperature, initialized to 1/0.07 (stored as a log).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.initialize_parameters()

    def initialize_parameters(self):
        """Initialize embeddings and transformer weights in place."""
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        if isinstance(self.visual, ModifiedResNet):
            if self.visual.attnpool is not None:
                std = self.visual.attnpool.c_proj.in_features ** -0.5
                nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)

            # Zero-init the last BN gain of each bottleneck so residual
            # branches start out as identity mappings.
            for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
                for name, param in resnet_block.named_parameters():
                    if name.endswith("bn3.weight"):
                        nn.init.zeros_(param)

        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    @property
    def dtype(self):
        # The visual stem's conv dtype tracks whether the model was halved
        # by convert_weights.
        return self.visual.conv1.weight.dtype

    def encode_image(self, image):
        """Return (N, embed_dim) image features (not normalized)."""
        return self.visual(image.type(self.dtype))

    def encode_text(self, text):
        """Return (N, embed_dim) text features (not normalized) for token ids `text`."""
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

        return x

    def forward(self, image, text):
        """Return (logits_per_image, logits_per_text), each (N, N) cosine-similarity logits."""
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # normalized features
        image_features = image_features / image_features.norm(dim=1, keepdim=True)
        text_features = text_features / text_features.norm(dim=1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
def convert_weights(model: nn.Module):
    """Convert applicable model parameters to fp16, in place.

    Halves the weights/biases of conv and linear layers, the projection
    tensors of nn.MultiheadAttention modules, and any `text_projection` /
    `proj` attributes, while leaving everything else (e.g. layer norms,
    embeddings) in fp32.
    """

    def _to_half(module):
        # Conv / Linear layers: halve weight and (optional) bias.
        if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            module.weight.data = module.weight.data.half()
            if module.bias is not None:
                module.bias.data = module.bias.data.half()

        # Multi-head attention keeps its projections as raw attributes.
        if isinstance(module, nn.MultiheadAttention):
            proj_attrs = [f"{prefix}_proj_weight" for prefix in ["in", "q", "k", "v"]]
            proj_attrs += ["in_proj_bias", "bias_k", "bias_v"]
            for attr_name in proj_attrs:
                tensor = getattr(module, attr_name)
                if tensor is not None:
                    tensor.data = tensor.data.half()

        # CLIP-specific projection parameters stored directly on modules.
        for attr_name in ("text_projection", "proj"):
            attr = getattr(module, attr_name, None)
            if attr is not None:
                attr.data = attr.data.half()

    model.apply(_to_half)
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
def build_model(state_dict: dict):
    """Build a CLIP model whose architecture is inferred from `state_dict`,
    convert it to fp16, load the weights, and return it in eval mode.

    Mutates `state_dict`: bookkeeping keys (input_resolution, context_length,
    vocab_size) are deleted before loading.
    """
    # ViT checkpoints have a `visual.proj` parameter; ResNet backbones don't.
    vit = "visual.proj" in state_dict

    if vit:
        vision_width = state_dict["visual.conv1.weight"].shape[0]
        vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
        # Positional embedding has grid**2 + 1 rows (class token included).
        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
        image_resolution = vision_patch_size * grid_size
    else:
        # Count bottleneck blocks per stage from the parameter names.
        counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
        vision_layers = tuple(counts)
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
        # ResNet downsamples by a total factor of 32.
        image_resolution = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))

    model = CLIP(
        embed_dim,
        image_resolution, vision_layers, vision_width, vision_patch_size,
        context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
    )

    # These entries are checkpoint metadata, not module parameters.
    for key in ["input_resolution", "context_length", "vocab_size"]:
        if key in state_dict:
            del state_dict[key]

    convert_weights(model)
    model.load_state_dict(state_dict)
    return model.eval()
|
versatile_diffusion/lib/model_zoo/common/__pycache__/get_model.cpython-310.pyc
ADDED
|
Binary file (3.31 kB). View file
|
|
|
versatile_diffusion/lib/model_zoo/common/__pycache__/get_model.cpython-38.pyc
ADDED
|
Binary file (3.26 kB). View file
|
|
|
versatile_diffusion/lib/model_zoo/common/__pycache__/get_optimizer.cpython-310.pyc
ADDED
|
Binary file (1.98 kB). View file
|
|
|
versatile_diffusion/lib/model_zoo/common/__pycache__/get_optimizer.cpython-38.pyc
ADDED
|
Binary file (1.93 kB). View file
|
|
|
versatile_diffusion/lib/model_zoo/common/__pycache__/get_scheduler.cpython-310.pyc
ADDED
|
Binary file (9.47 kB). View file
|
|
|
versatile_diffusion/lib/model_zoo/common/__pycache__/get_scheduler.cpython-38.pyc
ADDED
|
Binary file (9.54 kB). View file
|
|
|
versatile_diffusion/lib/model_zoo/common/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (9.75 kB). View file
|
|
|
versatile_diffusion/lib/model_zoo/common/__pycache__/utils.cpython-38.pyc
ADDED
|
Binary file (9.77 kB). View file
|
|
|
versatile_diffusion/lib/model_zoo/common/utils.py
ADDED
|
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import numpy as np
|
| 5 |
+
import copy
|
| 6 |
+
import functools
|
| 7 |
+
import itertools
|
| 8 |
+
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
|
| 11 |
+
########
|
| 12 |
+
# unit #
|
| 13 |
+
########
|
| 14 |
+
|
| 15 |
+
def singleton(class_):
    """Decorator: cache one shared instance per decorated class.

    The first call constructs the instance with the given arguments;
    every later call returns that same cached object (any arguments
    passed later are ignored).
    """
    _cache = {}

    def _get_instance(*args, **kwargs):
        try:
            return _cache[class_]
        except KeyError:
            _cache[class_] = class_(*args, **kwargs)
            return _cache[class_]

    return _get_instance
|
| 22 |
+
|
| 23 |
+
def str2value(v):
    """Parse a string into int, float, bool, or return it unchanged.

    Conversions are tried in order of specificity: int first, then
    float, then the literals ``True``/``true``/``False``/``false``.
    Anything else comes back as the stripped string itself.
    """
    v = v.strip()
    # Fix: the original used bare `except:` clauses, which also swallow
    # KeyboardInterrupt/SystemExit. int()/float() on a str only raise
    # ValueError on malformed input, so catch exactly that.
    try:
        return int(v)
    except ValueError:
        pass
    try:
        return float(v)
    except ValueError:
        pass
    if v in ('True', 'true'):
        return True
    elif v in ('False', 'false'):
        return False
    else:
        return v
|
| 39 |
+
|
| 40 |
+
@singleton
class get_unit(object):
    """Singleton registry mapping short names to layer/activation factories.

    Calling the instance with a spec string such as ``'conv(kernel_size=3)'``
    looks up the factory registered under the part before ``(`` and returns
    it partially applied with the parsed keyword arguments. With no
    parenthesized part, the bare factory is returned.
    """
    def __init__(self):
        # name -> factory callable; 'none' maps to None.
        self.unit = {}
        self.register('none', None)

        # general convolution
        self.register('conv' , nn.Conv2d)
        self.register('bn' , nn.BatchNorm2d)
        self.register('relu' , nn.ReLU)
        self.register('relu6' , nn.ReLU6)
        self.register('lrelu' , nn.LeakyReLU)
        self.register('dropout' , nn.Dropout)
        self.register('dropout2d', nn.Dropout2d)
        self.register('sine', Sine)
        self.register('relusine', ReLUSine)

    def register(self,
                 name,
                 unitf,):
        # Later registrations under the same name overwrite earlier ones.
        self.unit[name] = unitf

    def __call__(self, name):
        """Resolve a spec string to a factory callable (or None)."""
        if name is None:
            return None
        # Split 'type(args)' into the registered type name and its arg text.
        i = name.find('(')
        i = len(name) if i==-1 else i
        t = name[:i]
        f = self.unit[t]
        args = name[i:].strip('()')
        if len(args) == 0:
            args = {}  # NOTE(review): dead assignment — `return f` follows immediately
            return f
        else:
            # Parse 'k1=v1,k2=v2,...' into alternating keys/values. Splitting
            # on '=' and re-joining on the last comma keeps compound values
            # like 'k=(1,2)' attached to their key rather than split apart.
            args = args.split('=')
            args = [[','.join(i.split(',')[:-1]), i.split(',')[-1]] for i in args]
            args = list(itertools.chain.from_iterable(args))
            args = [i.strip() for i in args if len(i)>0]
            kwargs = {}
            for k, v in zip(args[::2], args[1::2]):
                # '(…)' values become tuples, '[…]' become lists; each element
                # is coerced with str2value. Scalars are coerced directly.
                if v[0]=='(' and v[-1]==')':
                    kwargs[k] = tuple([str2value(i) for i in v.strip('()').split(',')])
                elif v[0]=='[' and v[-1]==']':
                    kwargs[k] = [str2value(i) for i in v.strip('[]').split(',')]
                else:
                    kwargs[k] = str2value(v)
            return functools.partial(f, **kwargs)
|
| 88 |
+
|
| 89 |
+
def register(name):
    """Class decorator adding the decorated class to the get_unit registry.

    Usage: ``@register('myunit')`` above a class definition registers the
    class under *name* and returns it unchanged.
    """
    def _decorator(cls):
        get_unit().register(name, cls)
        return cls
    return _decorator
|
| 94 |
+
|
| 95 |
+
class Sine(object):
    """Sine activation: ``x -> sin(freq * x) * gain``.

    Plain callable (not an nn.Module); *gain* passed at call time is
    multiplied with the constructor gain.
    """

    def __init__(self, freq, gain=1):
        self.freq = freq
        self.gain = gain
        self.repr = 'sine(freq={}, gain={})'.format(freq, gain)

    def __call__(self, x, gain=1):
        return torch.sin(x * self.freq) * (self.gain * gain)

    def __repr__(self):
        return self.repr
|
| 107 |
+
|
| 108 |
+
class ReLUSine(nn.Module):
    """Activation returning ``sin(30 * x) + ReLU(x)``."""

    def __init__(self):
        # Bug fix: the constructor was misspelled `__init`, so it was dead
        # code and nn.Module's default __init__ ran instead. Behavior is
        # the same, but the typo'd method would silently hide any future
        # initialization added here.
        super().__init__()

    def forward(self, input):
        a = torch.sin(30 * input)
        b = nn.ReLU(inplace=False)(input)
        return a + b
|
| 116 |
+
|
| 117 |
+
@register('lrelu_agc')
class lrelu_agc(object):
    """
    The lrelu layer with alpha, gain and clamp
    """

    def __init__(self, alpha=0.1, gain=1, clamp=None):
        self.alpha = alpha
        # 'sqrt_2' is accepted as a symbolic gain value.
        self.gain = np.sqrt(2) if gain == 'sqrt_2' else gain
        self.clamp = clamp
        self.repr = 'lrelu_agc(alpha={}, gain={}, clamp={})'.format(
            alpha, gain, clamp)

    def __call__(self, x, gain=1):
        # NOTE: leaky_relu runs in-place, so the caller's tensor is mutated.
        x = F.leaky_relu(x, negative_slope=self.alpha, inplace=True)
        total_gain = self.gain * gain
        if total_gain != 1:
            x = x * total_gain
        # The clamp limit scales with the call-time gain; clamping happens
        # after the gain multiplication.
        if self.clamp is not None:
            limit = self.clamp * gain
            x = x.clamp(-limit, limit)
        return x

    def __repr__(self):
        return self.repr
|
| 147 |
+
|
| 148 |
+
####################
|
| 149 |
+
# spatial encoding #
|
| 150 |
+
####################
|
| 151 |
+
|
| 152 |
+
@register('se')
class SpatialEncoding(nn.Module):
    """Deterministic Fourier-feature encoding of coordinates.

    Builds a fixed projection of ``out_dim // 2 // in_dim`` log-spaced
    frequencies (``2**linspace(0, sigma, n)``) per input dimension; the
    output is ``[sin(xW), cos(xW)]``, optionally prefixed with the raw
    input when ``cat_input`` is True.
    """
    def __init__(self,
                 in_dim,
                 out_dim,
                 sigma = 6,
                 cat_input=True,
                 require_grad=False,):

        super().__init__()
        assert out_dim % (2*in_dim) == 0, "dimension must be dividable"

        # One frequency bank per input dimension: each row of the stacked
        # matrix holds one frequency in a single column (zeros elsewhere),
        # rolled so every input dimension receives its own copy of the bank.
        n = out_dim // 2 // in_dim
        m = 2**np.linspace(0, sigma, n)
        m = np.stack([m] + [np.zeros_like(m)]*(in_dim-1), axis=-1)
        m = np.concatenate([np.roll(m, i, axis=-1) for i in range(in_dim)], axis=0)
        self.emb = torch.FloatTensor(m)
        if require_grad:
            # Learnable frequencies when requested; otherwise `emb` stays a
            # plain tensor (not a registered buffer) moved lazily in forward.
            self.emb = nn.Parameter(self.emb, requires_grad=True)
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.sigma = sigma
        self.cat_input = cat_input
        self.require_grad = require_grad

    def forward(self, x, format='[n x c]'):
        """
        Args:
            x: [n x m1],
                m1 usually is 2
        Outputs:
            y: [n x m2]
                m2 dimention number
        """
        # '[bs x c x 2D]' input is flattened channel-last to rows for the
        # matmul, then restored to feature-map layout at the end.
        if format == '[bs x c x 2D]':
            xshape = x.shape
            x = x.permute(0, 2, 3, 1).contiguous()
            x = x.view(-1, x.size(-1))
        elif format == '[n x c]':
            pass
        else:
            raise ValueError

        if not self.require_grad:
            # Lazy device move: rebinds self.emb to the input's device on
            # every call (no-op when already there).
            self.emb = self.emb.to(x.device)
        y = torch.mm(x, self.emb.T)
        if self.cat_input:
            z = torch.cat([x, torch.sin(y), torch.cos(y)], dim=-1)
        else:
            z = torch.cat([torch.sin(y), torch.cos(y)], dim=-1)

        if format == '[bs x c x 2D]':
            z = z.view(xshape[0], xshape[2], xshape[3], -1)
            z = z.permute(0, 3, 1, 2).contiguous()
        return z

    def extra_repr(self):
        outstr = 'SpatialEncoding (in={}, out={}, sigma={}, cat_input={}, require_grad={})'.format(
            self.in_dim, self.out_dim, self.sigma, self.cat_input, self.require_grad)
        return outstr
|
| 212 |
+
|
| 213 |
+
@register('rffe')
class RFFEncoding(SpatialEncoding):
    """
    Random Fourier Features

    Same interface as SpatialEncoding, but the parent's deterministic
    log-spaced projection is replaced with an ``(out_dim // 2, in_dim)``
    Gaussian matrix drawn once at construction (unseeded np.random).
    """
    def __init__(self,
                 in_dim,
                 out_dim,
                 sigma = 6,
                 cat_input=True,
                 require_grad=False,):

        # Parent __init__ builds its own emb first; it is overwritten below.
        super().__init__(in_dim, out_dim, sigma, cat_input, require_grad)
        n = out_dim // 2
        m = np.random.normal(0, sigma, size=(n, in_dim))
        self.emb = torch.FloatTensor(m)
        if require_grad:
            self.emb = nn.Parameter(self.emb, requires_grad=True)

    def extra_repr(self):
        outstr = 'RFFEncoding (in={}, out={}, sigma={}, cat_input={}, require_grad={})'.format(
            self.in_dim, self.out_dim, self.sigma, self.cat_input, self.require_grad)
        return outstr
|
| 236 |
+
|
| 237 |
+
##########
|
| 238 |
+
# helper #
|
| 239 |
+
##########
|
| 240 |
+
|
| 241 |
+
def freeze(net):
    """Freeze *net*: put norm layers in eval mode, disable all gradients.

    Returns the same network object for chaining.
    """
    # inplace_abn not supported
    norm_types = (nn.BatchNorm2d, nn.SyncBatchNorm)
    for module in net.modules():
        if isinstance(module, norm_types):
            module.eval()
    for param in net.parameters():
        param.requires_grad = False
    return net
|
| 251 |
+
|
| 252 |
+
def common_init(m):
    """Default per-module init: kaiming conv weights, unit-affine norms.

    Conv/ConvTranspose get kaiming-normal weights (fan_out, relu) and zero
    bias; BatchNorm/SyncBatchNorm get weight 1 and bias 0. Any other module
    is left untouched.
    """
    conv_types = (nn.Conv2d, nn.ConvTranspose2d)
    norm_types = (nn.BatchNorm2d, nn.SyncBatchNorm)
    if isinstance(m, conv_types):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, norm_types):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)
|
| 266 |
+
|
| 267 |
+
def init_module(module):
    """Apply common_init to every submodule of the given module(s).

    Args:
        module: [nn.module] list or nn.module
            a list of module to be initialized.
    """
    modules = list(module) if isinstance(module, (list, tuple)) else [module]
    for m in modules:
        for sub in m.modules():
            common_init(sub)
|
| 281 |
+
|
| 282 |
+
def get_total_param(net):
    """Return the total number of parameter elements, or 0 for non-modules."""
    if getattr(net, 'parameters', None) is None:
        return 0
    total = 0
    for p in net.parameters():
        total += p.numel()
    return total
|
| 286 |
+
|
| 287 |
+
def get_total_param_sum(net):
    """Return the scalar sum of all parameter values, or 0 for non-modules."""
    if getattr(net, 'parameters', None) is None:
        return 0
    total = 0
    with torch.no_grad():
        for p in net.parameters():
            total += p.detach().cpu().numpy().sum().item()
    return total
|
versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/configuration_bert.cpython-310.pyc
ADDED
|
Binary file (5.34 kB). View file
|
|
|
versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/configuration_bert.cpython-38.pyc
ADDED
|
Binary file (5.24 kB). View file
|
|
|
versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/configuration_gpt2.cpython-310.pyc
ADDED
|
Binary file (5.11 kB). View file
|
|
|
versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/configuration_gpt2.cpython-38.pyc
ADDED
|
Binary file (5.02 kB). View file
|
|
|
versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/configuration_utils.cpython-310.pyc
ADDED
|
Binary file (9.46 kB). View file
|
|
|
versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/configuration_utils.cpython-38.pyc
ADDED
|
Binary file (9.4 kB). View file
|
|
|
versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/file_utils.cpython-310.pyc
ADDED
|
Binary file (8.37 kB). View file
|
|
|
versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/file_utils.cpython-38.pyc
ADDED
|
Binary file (8.32 kB). View file
|
|
|
versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/modeling_utils.cpython-38.pyc
ADDED
|
Binary file (31.7 kB). View file
|
|
|
versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/optimus_bert.cpython-38.pyc
ADDED
|
Binary file (54.6 kB). View file
|
|
|
versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/tokenization_gpt2.cpython-310.pyc
ADDED
|
Binary file (9.12 kB). View file
|
|
|
versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/tokenization_utils.cpython-38.pyc
ADDED
|
Binary file (33.5 kB). View file
|
|
|
versatile_diffusion/lib/model_zoo/optimus_models/configuration_bert.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
| 3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
""" BERT model configuration """
|
| 17 |
+
|
| 18 |
+
from __future__ import absolute_import, division, print_function, unicode_literals
|
| 19 |
+
|
| 20 |
+
import json
|
| 21 |
+
import logging
|
| 22 |
+
import sys
|
| 23 |
+
from io import open
|
| 24 |
+
|
| 25 |
+
from .configuration_utils import PretrainedConfig
|
| 26 |
+
|
| 27 |
+
logger = logging.getLogger(__name__)
|
| 28 |
+
|
| 29 |
+
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
| 30 |
+
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
|
| 31 |
+
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
|
| 32 |
+
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json",
|
| 33 |
+
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json",
|
| 34 |
+
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json",
|
| 35 |
+
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json",
|
| 36 |
+
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json",
|
| 37 |
+
'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json",
|
| 38 |
+
'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json",
|
| 39 |
+
'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json",
|
| 40 |
+
'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json",
|
| 41 |
+
'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json",
|
| 42 |
+
'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json",
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class BertConfig(PretrainedConfig):
    r"""
    :class:`~pytorch_transformers.BertConfig` is the configuration class to store the configuration of a
    `BertModel`.


    Arguments:
        vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`.
        hidden_size: Size of the encoder layers and the pooler layer.
        num_hidden_layers: Number of hidden layers in the Transformer encoder.
        num_attention_heads: Number of attention heads for each attention layer in
            the Transformer encoder.
        intermediate_size: The size of the "intermediate" (i.e., feed-forward)
            layer in the Transformer encoder.
        hidden_act: The non-linear activation function (function or string) in the
            encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
        hidden_dropout_prob: The dropout probabilitiy for all fully connected
            layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob: The dropout ratio for the attention
            probabilities.
        max_position_embeddings: The maximum sequence length that this model might
            ever be used with. Typically set this to something large just in case
            (e.g., 512 or 1024 or 2048).
        type_vocab_size: The vocabulary size of the `token_type_ids` passed into
            `BertModel`.
        initializer_range: The sttdev of the truncated_normal_initializer for
            initializing all weight matrices.
        layer_norm_eps: The epsilon used by LayerNorm.
    """
    pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(self,
                 vocab_size_or_config_json_file=30522,
                 hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 hidden_act="gelu",
                 hidden_dropout_prob=0.1,
                 attention_probs_dropout_prob=0.1,
                 max_position_embeddings=512,
                 type_vocab_size=2,
                 initializer_range=0.02,
                 layer_norm_eps=1e-12,
                 **kwargs):
        super(BertConfig, self).__init__(**kwargs)
        # Legacy Py2/Py3 guard: `unicode` only exists on Python 2, but the
        # short-circuiting `and` after `sys.version_info[0] == 2` means it is
        # never evaluated on Python 3, so this does not raise NameError.
        if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
                        and isinstance(vocab_size_or_config_json_file, unicode)):
            # String argument: treat it as a path to a JSON config file and
            # copy every key/value straight into the instance dict.
            with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
                json_config = json.loads(reader.read())
            for key, value in json_config.items():
                self.__dict__[key] = value
        elif isinstance(vocab_size_or_config_json_file, int):
            # Int argument: interpret it as the vocabulary size and take the
            # remaining hyperparameters from the keyword defaults above.
            self.vocab_size = vocab_size_or_config_json_file
            self.hidden_size = hidden_size
            self.num_hidden_layers = num_hidden_layers
            self.num_attention_heads = num_attention_heads
            self.hidden_act = hidden_act
            self.intermediate_size = intermediate_size
            self.hidden_dropout_prob = hidden_dropout_prob
            self.attention_probs_dropout_prob = attention_probs_dropout_prob
            self.max_position_embeddings = max_position_embeddings
            self.type_vocab_size = type_vocab_size
            self.initializer_range = initializer_range
            self.layer_norm_eps = layer_norm_eps
        else:
            raise ValueError("First argument must be either a vocabulary size (int)"
                             " or the path to a pretrained model config file (str)")
|
versatile_diffusion/lib/model_zoo/optimus_models/configuration_gpt2.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
|
| 3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
""" OpenAI GPT-2 configuration """
|
| 17 |
+
|
| 18 |
+
from __future__ import absolute_import, division, print_function, unicode_literals
|
| 19 |
+
|
| 20 |
+
import json
|
| 21 |
+
import logging
|
| 22 |
+
import sys
|
| 23 |
+
from io import open
|
| 24 |
+
|
| 25 |
+
from .configuration_utils import PretrainedConfig
|
| 26 |
+
|
| 27 |
+
logger = logging.getLogger(__name__)
|
| 28 |
+
|
| 29 |
+
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
|
| 30 |
+
"gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json",
|
| 31 |
+
"gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-config.json"}
|
| 32 |
+
|
| 33 |
+
class GPT2Config(PretrainedConfig):
    """Configuration class to store the configuration of a `GPT2Model`.

    Args:
        vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `GPT2Model` or a configuration json file.
        n_positions: Number of positional embeddings.
        n_ctx: Size of the causal mask (usually same as n_positions).
        n_embd: Dimensionality of the embeddings and hidden states.
        n_layer: Number of hidden layers in the Transformer encoder.
        n_head: Number of attention heads for each attention layer in
            the Transformer encoder.
        layer_norm_epsilon: epsilon to use in the layer norm layers
        resid_pdrop: The dropout probabilitiy for all fully connected
            layers in the embeddings, encoder, and pooler.
        attn_pdrop: The dropout ratio for the attention
            probabilities.
        embd_pdrop: The dropout ratio for the embeddings.
        initializer_range: The sttdev of the truncated_normal_initializer for
            initializing all weight matrices.
    """
    pretrained_config_archive_map = GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(
        self,
        vocab_size_or_config_json_file=50257,
        n_positions=1024,
        n_ctx=1024,
        n_embd=768,
        n_layer=12,
        n_head=12,
        resid_pdrop=0.1,
        embd_pdrop=0.1,
        attn_pdrop=0.1,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,

        num_labels=1,
        summary_type='cls_index',
        summary_use_proj=True,
        summary_activation=None,
        summary_proj_to_labels=True,
        summary_first_dropout=0.1,
        **kwargs
    ):
        """Constructs GPT2Config.

        Args:
            vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `GPT2Model` or a configuration json file.
            n_positions: Number of positional embeddings.
            n_ctx: Size of the causal mask (usually same as n_positions).
            n_embd: Dimensionality of the embeddings and hidden states.
            n_layer: Number of hidden layers in the Transformer encoder.
            n_head: Number of attention heads for each attention layer in
                the Transformer encoder.
            layer_norm_epsilon: epsilon to use in the layer norm layers
            resid_pdrop: The dropout probabilitiy for all fully connected
                layers in the embeddings, encoder, and pooler.
            attn_pdrop: The dropout ratio for the attention
                probabilities.
            embd_pdrop: The dropout ratio for the embeddings.
            initializer_range: The sttdev of the truncated_normal_initializer for
                initializing all weight matrices.
        """
        super(GPT2Config, self).__init__(**kwargs)

        # Legacy Py2/Py3 guard: `unicode` only exists on Python 2; the
        # short-circuiting `and` ensures it is never evaluated on Python 3.
        if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
                        and isinstance(vocab_size_or_config_json_file, unicode)):
            # String argument: path to a JSON config; copy every key/value
            # straight into the instance dict.
            with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
                json_config = json.loads(reader.read())
            for key, value in json_config.items():
                self.__dict__[key] = value
        elif isinstance(vocab_size_or_config_json_file, int):
            # Int argument: vocabulary size; the remaining hyperparameters
            # come from the keyword defaults above.
            self.vocab_size = vocab_size_or_config_json_file
            self.n_ctx = n_ctx
            self.n_positions = n_positions
            self.n_embd = n_embd
            self.n_layer = n_layer
            self.n_head = n_head
            self.resid_pdrop = resid_pdrop
            self.embd_pdrop = embd_pdrop
            self.attn_pdrop = attn_pdrop
            self.layer_norm_epsilon = layer_norm_epsilon
            self.initializer_range = initializer_range

            self.num_labels = num_labels
            self.summary_type = summary_type
            self.summary_use_proj = summary_use_proj
            self.summary_activation = summary_activation
            self.summary_first_dropout = summary_first_dropout
            self.summary_proj_to_labels = summary_proj_to_labels
        else:
            raise ValueError(
                "First argument must be either a vocabulary size (int)"
                "or the path to a pretrained model config file (str)"
            )

    # Aliases exposing BERT-style attribute names over the GPT-2 fields.
    @property
    def max_position_embeddings(self):
        return self.n_positions

    @property
    def hidden_size(self):
        return self.n_embd

    @property
    def num_attention_heads(self):
        return self.n_head

    @property
    def num_hidden_layers(self):
        return self.n_layer
|
versatile_diffusion/lib/model_zoo/optimus_models/configuration_utils.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
| 3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
""" Configuration base class and utilities."""
|
| 17 |
+
|
| 18 |
+
from __future__ import (absolute_import, division, print_function,
|
| 19 |
+
unicode_literals)
|
| 20 |
+
|
| 21 |
+
import copy
|
| 22 |
+
import json
|
| 23 |
+
import logging
|
| 24 |
+
import os
|
| 25 |
+
from io import open
|
| 26 |
+
|
| 27 |
+
from .file_utils import cached_path, CONFIG_NAME
|
| 28 |
+
|
| 29 |
+
logger = logging.getLogger(__name__)
|
| 30 |
+
|
| 31 |
+
class PretrainedConfig(object):
|
| 32 |
+
r""" Base class for all configuration classes.
|
| 33 |
+
Handles a few parameters common to all models' configurations as well as methods for loading/downloading/saving configurations.
|
| 34 |
+
|
| 35 |
+
Note:
|
| 36 |
+
A configuration file can be loaded and saved to disk. Loading the configuration file and using this file to initialize a model does **not** load the model weights.
|
| 37 |
+
It only affects the model's configuration.
|
| 38 |
+
|
| 39 |
+
Class attributes (overridden by derived classes):
|
| 40 |
+
- ``pretrained_config_archive_map``: a python ``dict`` of with `short-cut-names` (string) as keys and `url` (string) of associated pretrained model configurations as values.
|
| 41 |
+
|
| 42 |
+
Parameters:
|
| 43 |
+
``finetuning_task``: string, default `None`. Name of the task used to fine-tune the model. This can be used when converting from an original (TensorFlow or PyTorch) checkpoint.
|
| 44 |
+
``num_labels``: integer, default `2`. Number of classes to use when the model is a classification model (sequences/tokens)
|
| 45 |
+
``output_attentions``: boolean, default `False`. Should the model returns attentions weights.
|
| 46 |
+
``output_hidden_states``: string, default `False`. Should the model returns all hidden-states.
|
| 47 |
+
``torchscript``: string, default `False`. Is the model used with Torchscript.
|
| 48 |
+
"""
|
| 49 |
+
pretrained_config_archive_map = {}
|
| 50 |
+
|
| 51 |
+
def __init__(self, **kwargs):
|
| 52 |
+
self.finetuning_task = kwargs.pop('finetuning_task', None)
|
| 53 |
+
self.num_labels = kwargs.pop('num_labels', 2)
|
| 54 |
+
self.output_attentions = kwargs.pop('output_attentions', False)
|
| 55 |
+
self.output_hidden_states = kwargs.pop('output_hidden_states', False)
|
| 56 |
+
self.torchscript = kwargs.pop('torchscript', False)
|
| 57 |
+
self.pruned_heads = kwargs.pop('pruned_heads', {})
|
| 58 |
+
|
| 59 |
+
def save_pretrained(self, save_directory):
|
| 60 |
+
""" Save a configuration object to the directory `save_directory`, so that it
|
| 61 |
+
can be re-loaded using the :func:`~pytorch_transformers.PretrainedConfig.from_pretrained` class method.
|
| 62 |
+
"""
|
| 63 |
+
assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved"
|
| 64 |
+
|
| 65 |
+
# If we save using the predefined names, we can load using `from_pretrained`
|
| 66 |
+
output_config_file = os.path.join(save_directory, CONFIG_NAME)
|
| 67 |
+
|
| 68 |
+
self.to_json_file(output_config_file)
|
| 69 |
+
|
| 70 |
+
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
    r"""Instantiate a :class:`~pytorch_transformers.PretrainedConfig` (or a derived class) from a pre-trained model configuration.

    Parameters:
        pretrained_model_name_or_path: a model shortcut name (e.g.
            ``bert-base-uncased``), a directory containing a configuration
            saved with :func:`save_pretrained`, or a direct path/url to a
            configuration JSON file.
        cache_dir: (`optional`) directory in which downloaded configurations
            are cached instead of the standard cache.
        force_download: (`optional`) if True, re-download and override any
            cached version.
        proxies: (`optional`) dict of proxy servers by protocol or endpoint,
            used on each request.
        return_unused_kwargs: (`optional`) if True, return a tuple
            ``(config, unused_kwargs)`` where ``unused_kwargs`` holds the
            key/value pairs whose keys are not configuration attributes.

    Remaining ``kwargs`` whose keys match existing configuration attributes
    override the loaded values.
    """
    cache_dir = kwargs.pop('cache_dir', None)
    force_download = kwargs.pop('force_download', False)
    proxies = kwargs.pop('proxies', None)
    return_unused_kwargs = kwargs.pop('return_unused_kwargs', False)

    is_shortcut = pretrained_model_name_or_path in cls.pretrained_config_archive_map
    # Shortcut name -> archive url; directory -> bundled config file;
    # anything else is treated as a direct path or url.
    if is_shortcut:
        config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
    elif os.path.isdir(pretrained_model_name_or_path):
        config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
    else:
        config_file = pretrained_model_name_or_path
    # redirect to the cache, if necessary
    try:
        resolved_config_file = cached_path(
            config_file, cache_dir=cache_dir,
            force_download=force_download, proxies=proxies)
    except EnvironmentError as e:
        if is_shortcut:
            logger.error(
                "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
                    config_file))
        else:
            logger.error(
                "Model name '{}' was not found in model name list ({}). "
                "We assumed '{}' was a path or url but couldn't find any file "
                "associated to this path or url.".format(
                    pretrained_model_name_or_path,
                    ', '.join(cls.pretrained_config_archive_map.keys()),
                    config_file))
        raise e
    if resolved_config_file == config_file:
        logger.info("loading configuration file {}".format(config_file))
    else:
        logger.info("loading configuration file {} from cache at {}".format(
            config_file, resolved_config_file))

    # Load config
    config = cls.from_json_file(resolved_config_file)

    if hasattr(config, 'pruned_heads'):
        # JSON keys are strings; restore {int layer -> set(head indices)}.
        config.pruned_heads = {int(layer): set(heads)
                               for layer, heads in config.pruned_heads.items()}

    # Override loaded values with matching kwargs, tracking what we consumed.
    consumed = []
    for key, value in kwargs.items():
        if hasattr(config, key):
            setattr(config, key, value)
            consumed.append(key)
    for key in consumed:
        kwargs.pop(key, None)

    logger.info("Model config %s", config)
    if return_unused_kwargs:
        return config, kwargs
    return config
|
| 171 |
+
|
| 172 |
+
@classmethod
def from_dict(cls, json_object):
    """Constructs a `Config` from a Python dictionary of parameters."""
    config = cls(vocab_size_or_config_json_file=-1)
    # Mirror every key straight into the instance dict, same as assigning
    # config.__dict__[key] = value one key at a time.
    config.__dict__.update(json_object)
    return config
|
| 179 |
+
|
| 180 |
+
@classmethod
def from_json_file(cls, json_file):
    """Constructs a `BertConfig` from a json file of parameters."""
    with open(json_file, "r", encoding='utf-8') as reader:
        contents = reader.read()
    return cls.from_dict(json.loads(contents))
|
| 186 |
+
|
| 187 |
+
def __eq__(self, other):
    # Configs compare equal iff every attribute matches.
    return vars(self) == vars(other)
|
| 189 |
+
|
| 190 |
+
def __repr__(self):
    # Debug representation: the config rendered as sorted, indented JSON.
    rendered = self.to_json_string()
    return str(rendered)
|
| 192 |
+
|
| 193 |
+
def to_dict(self):
    """Serializes this instance to a Python dictionary."""
    # Deep copy so callers may freely mutate the result without
    # affecting the configuration object itself.
    return copy.deepcopy(self.__dict__)
|
| 197 |
+
|
| 198 |
+
def to_json_string(self):
    """Serializes this instance to a JSON string."""
    # Sorted keys + indent give a stable, human-readable rendering;
    # trailing newline keeps the emitted file POSIX-friendly.
    body = json.dumps(self.to_dict(), indent=2, sort_keys=True)
    return body + "\n"
|
| 201 |
+
|
| 202 |
+
def to_json_file(self, json_file_path):
    """ Save this instance to a json file."""
    serialized = self.to_json_string()
    with open(json_file_path, "w", encoding='utf-8') as writer:
        writer.write(serialized)
|
versatile_diffusion/lib/model_zoo/optimus_models/file_utils.py
ADDED
|
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utilities for working with the local dataset cache.
|
| 3 |
+
This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
|
| 4 |
+
Copyright by the AllenNLP authors.
|
| 5 |
+
"""
|
| 6 |
+
from __future__ import (absolute_import, division, print_function, unicode_literals)
|
| 7 |
+
|
| 8 |
+
import sys
|
| 9 |
+
import json
|
| 10 |
+
import logging
|
| 11 |
+
import os
|
| 12 |
+
import six
|
| 13 |
+
import shutil
|
| 14 |
+
import tempfile
|
| 15 |
+
import fnmatch
|
| 16 |
+
from functools import wraps
|
| 17 |
+
from hashlib import sha256
|
| 18 |
+
from io import open
|
| 19 |
+
|
| 20 |
+
# import boto3
|
| 21 |
+
# from botocore.config import Config
|
| 22 |
+
# from botocore.exceptions import ClientError
|
| 23 |
+
import requests
|
| 24 |
+
from tqdm import tqdm
|
| 25 |
+
|
| 26 |
+
try:
|
| 27 |
+
from torch.hub import _get_torch_home
|
| 28 |
+
torch_cache_home = _get_torch_home()
|
| 29 |
+
except ImportError:
|
| 30 |
+
torch_cache_home = os.path.expanduser(
|
| 31 |
+
os.getenv('TORCH_HOME', os.path.join(
|
| 32 |
+
os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch')))
|
| 33 |
+
default_cache_path = os.path.join(torch_cache_home, 'pytorch_transformers')
|
| 34 |
+
|
| 35 |
+
try:
|
| 36 |
+
from urllib.parse import urlparse
|
| 37 |
+
except ImportError:
|
| 38 |
+
from urlparse import urlparse
|
| 39 |
+
|
| 40 |
+
try:
|
| 41 |
+
from pathlib import Path
|
| 42 |
+
PYTORCH_PRETRAINED_BERT_CACHE = Path(
|
| 43 |
+
os.getenv('PYTORCH_TRANSFORMERS_CACHE', os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path)))
|
| 44 |
+
except (AttributeError, ImportError):
|
| 45 |
+
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_TRANSFORMERS_CACHE',
|
| 46 |
+
os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
|
| 47 |
+
default_cache_path))
|
| 48 |
+
|
| 49 |
+
PYTORCH_TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility
|
| 50 |
+
|
| 51 |
+
WEIGHTS_NAME = "pytorch_model.bin"
|
| 52 |
+
TF_WEIGHTS_NAME = 'model.ckpt'
|
| 53 |
+
CONFIG_NAME = "config.json"
|
| 54 |
+
|
| 55 |
+
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
| 56 |
+
|
| 57 |
+
# `add_start_docstrings` / `add_end_docstrings` prepend/append boilerplate to
# a function's docstring.  Class docstrings are read-only on Python 2, so the
# decorators degrade to no-ops there.  The version check uses the stdlib
# `sys` module (already imported above) instead of the third-party `six`.
if sys.version_info[0] >= 3:
    def add_start_docstrings(*docstr):
        """Decorator factory: prepend the given strings to ``fn.__doc__``."""
        def docstring_decorator(fn):
            # `fn.__doc__ or ''` guards against functions without a
            # docstring, which would otherwise raise TypeError on concat.
            fn.__doc__ = ''.join(docstr) + (fn.__doc__ or '')
            return fn
        return docstring_decorator

    def add_end_docstrings(*docstr):
        """Decorator factory: append the given strings to ``fn.__doc__``."""
        def docstring_decorator(fn):
            fn.__doc__ = (fn.__doc__ or '') + ''.join(docstr)
            return fn
        return docstring_decorator
else:
    # Not possible to update class docstrings on python2
    def add_start_docstrings(*docstr):
        def docstring_decorator(fn):
            return fn
        return docstring_decorator

    def add_end_docstrings(*docstr):
        def docstring_decorator(fn):
            return fn
        return docstring_decorator
|
| 80 |
+
|
| 81 |
+
def url_to_filename(url, etag=None):
    """
    Convert `url` into a hashed filename in a repeatable way.
    If `etag` is specified, append its hash to the url's, delimited
    by a period.
    """
    pieces = [sha256(url.encode('utf-8')).hexdigest()]
    if etag:
        pieces.append(sha256(etag.encode('utf-8')).hexdigest())
    return '.'.join(pieces)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def filename_to_url(filename, cache_dir=None):
    """
    Return the url and etag (which may be ``None``) stored for `filename`.
    Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_TRANSFORMERS_CACHE
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    cache_path = os.path.join(cache_dir, filename)
    meta_path = cache_path + '.json'
    # Both the cached file and its sidecar metadata must exist;
    # check them in the same order as before (file first, then metadata).
    for required in (cache_path, meta_path):
        if not os.path.exists(required):
            raise EnvironmentError("file {} not found".format(required))

    with open(meta_path, encoding="utf-8") as meta_file:
        metadata = json.load(meta_file)
    return metadata['url'], metadata['etag']
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=None):
    """
    Given something that might be a URL (or might be a local path),
    determine which. If it's a URL, download the file and cache it, and
    return the path to the cached file. If it's already a local path,
    make sure the file exists and then return the path.
    Args:
        cache_dir: specify a cache directory to save the file to (overwrite the default cache dir).
        force_download: if True, re-dowload the file even if it's already cached in the cache dir.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_TRANSFORMERS_CACHE
    if sys.version_info[0] == 3:
        # Normalise pathlib.Path inputs to plain strings (py3 only).
        if isinstance(url_or_filename, Path):
            url_or_filename = str(url_or_filename)
        if isinstance(cache_dir, Path):
            cache_dir = str(cache_dir)

    scheme = urlparse(url_or_filename).scheme

    if scheme in ('http', 'https', 's3'):
        # Remote resource: serve from the cache, downloading if needed.
        return get_from_cache(url_or_filename, cache_dir=cache_dir,
                              force_download=force_download, proxies=proxies)
    if os.path.exists(url_or_filename):
        # Existing local file: return untouched.
        return url_or_filename
    if scheme == '':
        # Looks like a local path, but nothing is there.
        raise EnvironmentError("file {} not found".format(url_or_filename))
    # Something unknown
    raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def split_s3_path(url):
    """Split a full s3 path into the bucket name and path."""
    parsed = urlparse(url)
    bucket_name, s3_path = parsed.netloc, parsed.path
    if not bucket_name or not s3_path:
        raise ValueError("bad s3 path {}".format(url))
    # Drop the single leading '/' urlparse leaves on the key.
    if s3_path.startswith("/"):
        s3_path = s3_path[1:]
    return bucket_name, s3_path
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def s3_request(func):
    """
    Wrapper function for s3 requests in order to create more helpful error
    messages.
    """

    @wraps(func)
    def wrapper(url, *args, **kwargs):
        try:
            return func(url, *args, **kwargs)
        # NOTE(review): the botocore import providing ClientError is commented
        # out at the top of this module; this clause raises NameError if it
        # ever fires — confirm before enabling S3 code paths.
        except ClientError as exc:
            status = int(exc.response["Error"]["Code"])
            if status == 404:
                raise EnvironmentError("file {} not found".format(url))
            raise

    return wrapper
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
@s3_request
def s3_etag(url, proxies=None):
    """Check ETag on S3 object."""
    # NOTE(review): boto3's import is commented out at module top; calling
    # this raises NameError until that import is restored.
    resource = boto3.resource("s3", config=Config(proxies=proxies))
    bucket_name, s3_path = split_s3_path(url)
    return resource.Object(bucket_name, s3_path).e_tag
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
@s3_request
def s3_get(url, temp_file, proxies=None):
    """Pull a file directly from S3."""
    # NOTE(review): boto3's import is commented out at module top; calling
    # this raises NameError until that import is restored.
    resource = boto3.resource("s3", config=Config(proxies=proxies))
    bucket_name, s3_path = split_s3_path(url)
    resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def http_get(url, temp_file, proxies=None):
    """Stream `url` into `temp_file`, reporting progress with tqdm."""
    response = requests.get(url, stream=True, proxies=proxies)
    content_length = response.headers.get('Content-Length')
    total = int(content_length) if content_length is not None else None
    progress = tqdm(unit="B", total=total)
    for chunk in response.iter_content(chunk_size=1024):
        # Empty chunks are keep-alive heartbeats; skip them.
        if chunk:
            progress.update(len(chunk))
            temp_file.write(chunk)
    progress.close()
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def get_from_cache(url, cache_dir=None, force_download=False, proxies=None):
    """
    Given a URL, look for the corresponding dataset in the local cache.
    If it's not there, download it. Then return the path to the cached file.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_TRANSFORMERS_CACHE
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)
    if sys.version_info[0] == 2 and not isinstance(cache_dir, str):
        cache_dir = str(cache_dir)

    if not os.path.exists(cache_dir):
        os.makedirs(cache_dir)

    # Resolve the remote ETag so distinct versions get distinct cache entries.
    if url.startswith("s3://"):
        etag = s3_etag(url, proxies=proxies)
    else:
        try:
            response = requests.head(url, allow_redirects=True, proxies=proxies)
            etag = response.headers.get("ETag") if response.status_code == 200 else None
        except EnvironmentError:
            # Offline or unreachable host: fall back to an etag-less entry.
            etag = None

    if sys.version_info[0] == 2 and etag is not None:
        etag = etag.decode('utf-8')
    filename = url_to_filename(url, etag)

    # get cache path to put the file
    cache_path = os.path.join(cache_dir, filename)

    # If we don't have a connection (etag is None) and can't identify the file
    # try to get the last downloaded one
    if not os.path.exists(cache_path) and etag is None:
        candidates = fnmatch.filter(os.listdir(cache_dir), filename + '.*')
        candidates = [name for name in candidates if not name.endswith('.json')]
        if candidates:
            cache_path = os.path.join(cache_dir, candidates[-1])

    if not os.path.exists(cache_path) or force_download:
        # Download to temporary file, then copy to cache dir once finished.
        # Otherwise you get corrupt cache entries if the download gets interrupted.
        with tempfile.NamedTemporaryFile() as temp_file:
            logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)

            # GET file object
            if url.startswith("s3://"):
                s3_get(url, temp_file, proxies=proxies)
            else:
                http_get(url, temp_file, proxies=proxies)

            # we are copying the file before closing it, so flush to avoid truncation
            temp_file.flush()
            # shutil.copyfileobj() starts at the current position, so go to the start
            temp_file.seek(0)

            logger.info("copying %s to cache at %s", temp_file.name, cache_path)
            with open(cache_path, 'wb') as cache_file:
                shutil.copyfileobj(temp_file, cache_file)

            logger.info("creating metadata file for %s", cache_path)
            meta = {'url': url, 'etag': etag}
            meta_path = cache_path + '.json'
            with open(meta_path, 'w') as meta_file:
                output_string = json.dumps(meta)
                if sys.version_info[0] == 2 and isinstance(output_string, str):
                    output_string = unicode(output_string, 'utf-8')  # The beauty of python 2
                meta_file.write(output_string)

            logger.info("removing temp file %s", temp_file.name)

    return cache_path
|
versatile_diffusion/lib/model_zoo/optimus_models/modeling_utils.py
ADDED
|
@@ -0,0 +1,780 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
| 3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
"""PyTorch BERT model."""
|
| 17 |
+
|
| 18 |
+
from __future__ import (absolute_import, division, print_function,
|
| 19 |
+
unicode_literals)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
import pdb
|
| 23 |
+
import copy
|
| 24 |
+
import json
|
| 25 |
+
import logging
|
| 26 |
+
import os
|
| 27 |
+
from io import open
|
| 28 |
+
|
| 29 |
+
import six
|
| 30 |
+
import torch
|
| 31 |
+
from torch import nn
|
| 32 |
+
from torch.nn import CrossEntropyLoss
|
| 33 |
+
from torch.nn import functional as F
|
| 34 |
+
|
| 35 |
+
from .configuration_utils import PretrainedConfig
|
| 36 |
+
from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME
|
| 37 |
+
|
| 38 |
+
logger = logging.getLogger(__name__)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
try:
    from torch.nn import Identity
except ImportError:
    # Older PyTorch has no nn.Identity; provide a drop-in stand-in.
    class Identity(nn.Module):
        r"""A placeholder identity operator that is argument-insensitive."""

        def __init__(self, *args, **kwargs):
            # Constructor arguments are deliberately ignored.
            super(Identity, self).__init__()

        def forward(self, input):
            # Pass the input through untouched.
            return input
|
| 53 |
+
|
| 54 |
+
class PreTrainedModel(nn.Module):
|
| 55 |
+
r""" Base class for all models.
|
| 56 |
+
|
| 57 |
+
:class:`~pytorch_transformers.PreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models
|
| 58 |
+
as well as a few methods commons to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention heads.
|
| 59 |
+
|
| 60 |
+
Class attributes (overridden by derived classes):
|
| 61 |
+
- ``config_class``: a class derived from :class:`~pytorch_transformers.PretrainedConfig` to use as configuration class for this model architecture.
|
| 62 |
+
- ``pretrained_model_archive_map``: a python ``dict`` of with `short-cut-names` (string) as keys and `url` (string) of associated pretrained weights as values.
|
| 63 |
+
- ``load_tf_weights``: a python ``method`` for loading a TensorFlow checkpoint in a PyTorch model, taking as arguments:
|
| 64 |
+
|
| 65 |
+
- ``model``: an instance of the relevant subclass of :class:`~pytorch_transformers.PreTrainedModel`,
|
| 66 |
+
- ``config``: an instance of the relevant subclass of :class:`~pytorch_transformers.PretrainedConfig`,
|
| 67 |
+
- ``path``: a path (string) to the TensorFlow checkpoint.
|
| 68 |
+
|
| 69 |
+
- ``base_model_prefix``: a string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model.
|
| 70 |
+
"""
|
| 71 |
+
config_class = None
|
| 72 |
+
pretrained_model_archive_map = {}
|
| 73 |
+
load_tf_weights = lambda model, config, path: None
|
| 74 |
+
base_model_prefix = ""
|
| 75 |
+
|
| 76 |
+
def __init__(self, config, *inputs, **kwargs):
    """Validate and store the model configuration.

    Raises:
        ValueError: if ``config`` is not a :class:`PretrainedConfig` instance.
    """
    super(PreTrainedModel, self).__init__()
    if not isinstance(config, PretrainedConfig):
        cls_name = self.__class__.__name__
        raise ValueError(
            "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
            "To create a model from a pretrained model use "
            "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(cls_name, cls_name))
    # Keep the config on the instance so save/load helpers can reach it.
    self.config = config
|
| 87 |
+
|
| 88 |
+
def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
    """ Build a resized Embedding Module from a provided token Embedding Module.

    Growing adds newly initialized rows at the end; shrinking drops rows
    from the end.

    Args:
        old_embeddings: the ``nn.Embedding`` to resize.
        new_num_tokens: (`optional`) int target vocabulary size. If None,
            the original module is returned unchanged.
    Return: ``torch.nn.Embeddings``
        The resized module, or ``old_embeddings`` when no resize is needed.
    """
    if new_num_tokens is None:
        return old_embeddings

    old_num_tokens, old_dim = old_embeddings.weight.size()
    if old_num_tokens == new_num_tokens:
        return old_embeddings

    # Fresh table on the same device as the old weights.
    resized = nn.Embedding(new_num_tokens, old_dim)
    resized.to(old_embeddings.weight.device)

    # Initialise everything first (this covers the newly added rows)...
    self._init_weights(resized)

    # ...then overwrite the overlapping prefix with the trained weights.
    keep = min(old_num_tokens, new_num_tokens)
    resized.weight.data[:keep, :] = old_embeddings.weight.data[:keep, :]

    return resized
|
| 121 |
+
|
| 122 |
+
def _tie_or_clone_weights(self, first_module, second_module):
|
| 123 |
+
""" Tie or clone module weights depending of weither we are using TorchScript or not
|
| 124 |
+
"""
|
| 125 |
+
if self.config.torchscript:
|
| 126 |
+
first_module.weight = nn.Parameter(second_module.weight.clone())
|
| 127 |
+
else:
|
| 128 |
+
first_module.weight = second_module.weight
|
| 129 |
+
|
| 130 |
+
if hasattr(first_module, 'bias') and first_module.bias is not None:
|
| 131 |
+
first_module.bias.data = torch.nn.functional.pad(
|
| 132 |
+
first_module.bias.data,
|
| 133 |
+
(0, first_module.weight.shape[0] - first_module.bias.shape[0]),
|
| 134 |
+
'constant',
|
| 135 |
+
0
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
def resize_token_embeddings(self, new_num_tokens=None):
    """Resize the model's input token embedding matrix.

    Delegates the actual resize to the base model, then keeps the config's
    ``vocab_size`` and (when the model supports it) the tied output
    embeddings in sync.

    Args:
        new_num_tokens (int, optional): new vocabulary size. Increasing
            appends newly initialized vectors at the end; decreasing drops
            vectors from the end. When None, nothing is resized and the
            current input embedding module is returned.

    Returns:
        ``torch.nn.Embeddings``: pointer to the model's input token
        embedding module (possibly resized).
    """
    # Operate on the base transformer even when `self` carries a task head.
    base_model = getattr(self, self.base_model_prefix, self)

    model_embeds = base_model._resize_token_embeddings(new_num_tokens)
    if new_num_tokens is None:
        return model_embeds

    # Keep both the shared config and the base model in sync.
    self.config.vocab_size = new_num_tokens
    base_model.vocab_size = new_num_tokens

    # Re-tie input/output embedding weights when the model defines tying.
    if hasattr(self, 'tie_weights'):
        self.tie_weights()

    return model_embeds
def init_weights(self):
    """Initialize all weights, then re-apply any configured head pruning."""
    # Model-specific init is applied recursively to every submodule.
    self.apply(self._init_weights)

    # Restore pruning recorded in the config (e.g. from a saved checkpoint).
    if self.config.pruned_heads:
        self.prune_heads(self.config.pruned_heads)
def prune_heads(self, heads_to_prune):
    """Prune attention heads of the base model.

    Args:
        heads_to_prune (dict): maps a layer index (int) to the list of head
            indices to prune in that layer. E.g. ``{1: [0, 2], 2: [2, 3]}``
            prunes heads 0 and 2 of layer 1 and heads 2 and 3 of layer 2.
    """
    # Operate on the base transformer even when `self` carries a task head.
    base_model = getattr(self, self.base_model_prefix, self)

    # Record the union of previously pruned and newly pruned heads so the
    # pruning survives save/load round trips. Stored as a list because a
    # set is not JSON-serializable.
    for layer, heads in heads_to_prune.items():
        merged = set(self.config.pruned_heads.get(layer, [])) | set(heads)
        self.config.pruned_heads[layer] = list(merged)

    base_model._prune_heads(heads_to_prune)
def save_pretrained(self, save_directory):
    """Save the model's weights and configuration to ``save_directory``.

    The directory must already exist. Files are written under the standard
    names so the model can be re-loaded with
    `:func:`~pytorch_transformers.PreTrainedModel.from_pretrained``.
    """
    assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved"

    # When wrapped (e.g. DistributedDataParallel exposes `.module`), save
    # the inner model so the checkpoint keys are not prefixed.
    model_to_save = self.module if hasattr(self, 'module') else self

    # Configuration first, then weights, each under its canonical name.
    model_to_save.config.save_pretrained(save_directory)

    output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
    torch.save(model_to_save.state_dict(), output_model_file)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
    r"""Instantiate a pretrained pytorch model from a pre-trained model configuration.

    The model is set in evaluation mode by default using ``model.eval()``
    (Dropout modules are deactivated). To train the model, first set it
    back in training mode with ``model.train()``.

    The warning ``Weights from XXX not initialized from pretrained model``
    means those weights are newly initialized and should be fine-tuned; the
    warning ``Weights from XXX not used in YYY`` means those checkpoint
    weights were discarded.

    Parameters:
        pretrained_model_name_or_path: either:

            - a string `shortcut name` of a pre-trained model to load from
              cache or download, e.g.: ``bert-base-uncased``;
            - a path to a `directory` containing weights saved with
              :func:`~pytorch_transformers.PreTrainedModel.save_pretrained`,
              e.g.: ``./my_model_directory/``;
            - a path or url to a `tensorflow index checkpoint file` (e.g.
              ``./tf_model/model.ckpt.index``) — requires ``from_tf=True``
              and an explicit ``config``. This path is slower than
              converting the checkpoint and loading the PyTorch model.

        model_args: (`optional`) positional arguments forwarded to the
            underlying model's ``__init__`` method.
        config: (`optional`) a :class:`~pytorch_transformers.PretrainedConfig`
            instance to use instead of an automatically loaded configuration.
        state_dict: (`optional`) dict: a state dictionary to load instead of
            the saved weights file (e.g. to supply your own weights for a
            pretrained configuration).
        cache_dir: (`optional`) string: directory in which downloaded model
            files are cached when the standard cache should not be used.
        force_download: (`optional`) boolean, default False: re-download the
            weights and configuration even when cached versions exist.
        proxies: (`optional`) dict, default None: proxy servers by protocol
            or endpoint, e.g. ``{'http': 'foo.bar:3128'}``; used on each
            request.
        output_loading_info: (`optional`) boolean: set to True to also
            return a dict with missing keys, unexpected keys and error
            messages.
        kwargs: (`optional`) remaining keyword arguments. When ``config``
            is supplied they are passed straight to the model's
            ``__init__``; otherwise they first update the auto-loaded
            configuration, and any leftovers go to ``__init__``.

    Examples::

        model = BertModel.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
        model = BertModel.from_pretrained('./test/saved_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
        model = BertModel.from_pretrained('bert-base-uncased', output_attention=True)  # Update configuration during loading
        assert model.config.output_attention == True
        # Loading from a TF checkpoint file instead of a PyTorch model (slower)
        config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json')
        model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)

    """
    config = kwargs.pop('config', None)
    state_dict = kwargs.pop('state_dict', None)
    cache_dir = kwargs.pop('cache_dir', None)
    from_tf = kwargs.pop('from_tf', False)
    force_download = kwargs.pop('force_download', False)
    proxies = kwargs.pop('proxies', None)
    output_loading_info = kwargs.pop('output_loading_info', False)

    # Load config. Unused kwargs (returned via return_unused_kwargs) become
    # the model's constructor kwargs.
    if config is None:
        config, model_kwargs = cls.config_class.from_pretrained(
            pretrained_model_name_or_path, *model_args,
            cache_dir=cache_dir, return_unused_kwargs=True,
            force_download=force_download,
            **kwargs
        )
    else:
        model_kwargs = kwargs

    # Resolve the weights location: known shortcut name, local directory,
    # or a direct path/url.
    if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
        archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
    elif os.path.isdir(pretrained_model_name_or_path):
        if from_tf:
            # Directly load from a TensorFlow checkpoint
            archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
        else:
            archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
    else:
        if from_tf:
            # Directly load from a TensorFlow checkpoint
            archive_file = pretrained_model_name_or_path + ".index"
        else:
            archive_file = pretrained_model_name_or_path
    # redirect to the cache, if necessary
    try:
        resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
    except EnvironmentError as e:
        if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
            logger.error(
                "Couldn't reach server at '{}' to download pretrained weights.".format(
                    archive_file))
        else:
            logger.error(
                "Model name '{}' was not found in model name list ({}). "
                "We assumed '{}' was a path or url but couldn't find any file "
                "associated to this path or url.".format(
                    pretrained_model_name_or_path,
                    ', '.join(cls.pretrained_model_archive_map.keys()),
                    archive_file))
        raise e
    if resolved_archive_file == archive_file:
        logger.info("loading weights file {}".format(archive_file))
    else:
        logger.info("loading weights file {} from cache at {}".format(
            archive_file, resolved_archive_file))

    # Instantiate model.
    model = cls(config, *model_args, **model_kwargs)

    if state_dict is None and not from_tf:
        state_dict = torch.load(resolved_archive_file, map_location='cpu')
    if from_tf:
        # Directly load from a TensorFlow checkpoint
        return cls.load_tf_weights(model, config, resolved_archive_file[:-6])  # Remove the '.index'

    # Convert old format to new format if needed from a PyTorch state_dict
    # (old checkpoints named LayerNorm parameters 'gamma'/'beta').
    old_keys = []
    new_keys = []
    for key in state_dict.keys():
        new_key = None
        if 'gamma' in key:
            new_key = key.replace('gamma', 'weight')
        if 'beta' in key:
            new_key = key.replace('beta', 'bias')
        if new_key:
            old_keys.append(key)
            new_keys.append(new_key)
    for old_key, new_key in zip(old_keys, new_keys):
        state_dict[new_key] = state_dict.pop(old_key)

    # Load from a PyTorch state_dict
    missing_keys = []
    unexpected_keys = []
    error_msgs = []
    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    def load(module, prefix=''):
        # Recursive per-module load mirroring nn.Module.load_state_dict,
        # collecting missing/unexpected keys instead of raising per module.
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        module._load_from_state_dict(
            state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    # Make sure we are able to load base models as well as derived models (with heads):
    # strip or add the base-model prefix depending on which side has the head.
    start_prefix = ''
    model_to_load = model
    if not hasattr(model, cls.base_model_prefix) and any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
        start_prefix = cls.base_model_prefix + '.'
    if hasattr(model, cls.base_model_prefix) and not any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
        model_to_load = getattr(model, cls.base_model_prefix)

    load(model_to_load, prefix=start_prefix)
    if len(missing_keys) > 0:
        logger.info("Weights of {} not initialized from pretrained model: {}".format(
            model.__class__.__name__, missing_keys))
    if len(unexpected_keys) > 0:
        logger.info("Weights from pretrained model not used in {}: {}".format(
            model.__class__.__name__, unexpected_keys))
    if len(error_msgs) > 0:
        raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
            model.__class__.__name__, "\n\t".join(error_msgs)))

    if hasattr(model, 'tie_weights'):
        model.tie_weights()  # make sure word embedding weights are still tied

    # Set model in evaluation mode to desactivate DropOut modules by default
    model.eval()

    if output_loading_info:
        loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs}
        return model, loading_info

    return model
class Conv1D(nn.Module):
    """Conv1D layer as defined by Radford et al. for OpenAI GPT (and GPT-2).

    Works like a linear layer, but the weight is stored transposed:
    ``weight`` has shape ``(nx, nf)`` = (input features, output features).
    """

    def __init__(self, nf, nx):
        super(Conv1D, self).__init__()
        self.nf = nf
        weight = torch.empty(nx, nf)
        nn.init.normal_(weight, std=0.02)
        self.weight = nn.Parameter(weight)
        self.bias = nn.Parameter(torch.zeros(nf))

    def forward(self, x):
        # Flatten all leading dims, apply bias + x @ weight, then restore them.
        out_shape = x.size()[:-1] + (self.nf,)
        flat = x.view(-1, x.size(-1))
        out = torch.addmm(self.bias, flat, self.weight)
        return out.view(*out_shape)
class PoolerStartLogits(nn.Module):
    """Compute SQuAD start_logits from sequence hidden states."""

    def __init__(self, config):
        super(PoolerStartLogits, self).__init__()
        self.dense = nn.Linear(config.hidden_size, 1)

    def forward(self, hidden_states, p_mask=None):
        """Project hidden states to one start logit per token.

        Args:
            **p_mask**: (`optional`) ``torch.FloatTensor`` of shape
                `(batch_size, seq_len)` — mask of invalid positions such as
                query and special symbols (PAD, SEP, CLS); 1.0 means the
                token should be masked.
        """
        logits = self.dense(hidden_states).squeeze(-1)

        if p_mask is not None:
            # Push masked positions to a huge negative value; fp16 cannot
            # represent -1e30, so use a value inside its range instead.
            if next(self.parameters()).dtype == torch.float16:
                logits = logits * (1 - p_mask) - 65500 * p_mask
            else:
                logits = logits * (1 - p_mask) - 1e30 * p_mask

        return logits
class PoolerEndLogits(nn.Module):
    """Compute SQuAD end_logits from sequence hidden states and the start
    token's hidden state.
    """

    def __init__(self, config):
        super(PoolerEndLogits, self).__init__()
        self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
        self.activation = nn.Tanh()
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dense_1 = nn.Linear(config.hidden_size, 1)

    def forward(self, hidden_states, start_states=None, start_positions=None, p_mask=None):
        """Compute per-token end logits conditioned on the start token.

        One of ``start_states``, ``start_positions`` must be given; when
        both are set, ``start_positions`` overrides ``start_states``.

        Args:
            **start_states**: ``torch.LongTensor`` of shape identical to
                hidden_states — hidden states of the span's first tokens.
            **start_positions**: ``torch.LongTensor`` of shape
                ``(batch_size,)`` — position of the span's first token.
            **p_mask**: (`optional`) ``torch.FloatTensor`` of shape
                ``(batch_size, seq_len)`` — 1.0 marks invalid positions
                (query, PAD, SEP, CLS) whose logits are masked out.
        """
        assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None"
        if start_positions is not None:
            seq_len, hidden = hidden_states.shape[-2:]
            # Gather each example's start-token state, then broadcast it
            # across the sequence dimension.
            gather_idx = start_positions[:, None, None].expand(-1, -1, hidden)   # (bsz, 1, hsz)
            start_states = hidden_states.gather(-2, gather_idx)                  # (bsz, 1, hsz)
            start_states = start_states.expand(-1, seq_len, -1)                  # (bsz, slen, hsz)

        out = self.dense_0(torch.cat([hidden_states, start_states], dim=-1))
        out = self.activation(out)
        out = self.LayerNorm(out)
        out = self.dense_1(out).squeeze(-1)

        if p_mask is not None:
            out = out * (1 - p_mask) - 1e30 * p_mask

        return out
class PoolerAnswerClass(nn.Module):
    """Compute SQuAD 2.0 answer class from classification and start tokens hidden states."""

    def __init__(self, config):
        super(PoolerAnswerClass, self).__init__()
        self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
        self.activation = nn.Tanh()
        self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False)

    def forward(self, hidden_states, start_states=None, start_positions=None, cls_index=None):
        """Compute one answerability logit per example.

        One of ``start_states``, ``start_positions`` must be given; when
        both are set, ``start_positions`` overrides ``start_states``.

        Args:
            **start_states**: ``torch.LongTensor`` of shape identical to
                ``hidden_states`` — hidden states of the span's first tokens.
            **start_positions**: ``torch.LongTensor`` of shape
                ``(batch_size,)`` — position of the span's first token.
            **cls_index**: ``torch.LongTensor`` of shape ``(batch_size,)`` —
                position of the CLS token; when None, the last token is used.

        Note (original repo): no dependency on end_feature so that a single
        `cls_logits` is obtained for each sample.
        """
        hidden = hidden_states.shape[-1]
        assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None"
        if start_positions is not None:
            gather_idx = start_positions[:, None, None].expand(-1, -1, hidden)   # (bsz, 1, hsz)
            start_states = hidden_states.gather(-2, gather_idx).squeeze(-2)      # (bsz, hsz)

        if cls_index is not None:
            cls_idx = cls_index[:, None, None].expand(-1, -1, hidden)            # (bsz, 1, hsz)
            cls_token_state = hidden_states.gather(-2, cls_idx).squeeze(-2)      # (bsz, hsz)
        else:
            # Default to the last token as the classification token.
            cls_token_state = hidden_states[:, -1, :]                            # (bsz, hsz)

        out = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1))
        out = self.activation(out)
        out = self.dense_1(out).squeeze(-1)

        return out
class SQuADHead(nn.Module):
    r""" A SQuAD head inspired by XLNet.

    Combines three poolers: start-logit prediction, start-conditioned
    end-logit prediction, and an answerability classifier. Training uses
    the gold start position; inference beam-searches over the top
    ``start_n_top`` starts and ``end_n_top`` ends per start.

    Parameters:
        config (:class:`~pytorch_transformers.XLNetConfig`): Model configuration class with all the parameters of the model.

    Inputs:
        **hidden_states**: ``torch.FloatTensor`` of shape ``(batch_size, seq_len, hidden_size)``
            hidden states of sequence tokens
        **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
            position of the first token for the labeled span.
        **end_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
            position of the last token for the labeled span.
        **cls_index**: torch.LongTensor of shape ``(batch_size,)``
            position of the CLS token. If None, take the last token.
        **is_impossible**: ``torch.LongTensor`` of shape ``(batch_size,)``
            Whether the question has a possible answer in the paragraph or not.
        **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
            Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
            1.0 means token should be masked.

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned if both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses.
        **start_top_log_probs**: (`optional`, returned during inference)
            ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)``
            Log probabilities for the top config.start_n_top start token possibilities (beam-search).
        **start_top_index**: (`optional`, returned during inference)
            ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)``
            Indices for the top config.start_n_top start token possibilities (beam-search).
        **end_top_log_probs**: (`optional`, returned during inference)
            ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
            Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
        **end_top_index**: (`optional`, returned during inference)
            ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
            Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
        **cls_logits**: (`optional`, returned during inference)
            ``torch.FloatTensor`` of shape ``(batch_size,)``
            Log probabilities for the ``is_impossible`` label of the answers.
    """
    def __init__(self, config):
        super(SQuADHead, self).__init__()
        self.start_n_top = config.start_n_top
        self.end_n_top = config.end_n_top

        self.start_logits = PoolerStartLogits(config)
        self.end_logits = PoolerEndLogits(config)
        self.answer_class = PoolerAnswerClass(config)

    def forward(self, hidden_states, start_positions=None, end_positions=None,
                cls_index=None, is_impossible=None, p_mask=None):
        # Returns (total_loss,) during training, or the 5-tuple of
        # beam-search tensors during inference (see class docstring).
        outputs = ()

        start_logits = self.start_logits(hidden_states, p_mask=p_mask)

        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, let's remove the dimension added by batch splitting
            for x in (start_positions, end_positions, cls_index, is_impossible):
                if x is not None and x.dim() > 1:
                    x.squeeze_(-1)

            # during training, compute the end logits based on the ground truth of the start position
            end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask)

            loss_fct = CrossEntropyLoss()
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

            if cls_index is not None and is_impossible is not None:
                # Predict answerability from the representation of CLS and START
                cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index)
                loss_fct_cls = nn.BCEWithLogitsLoss()
                cls_loss = loss_fct_cls(cls_logits, is_impossible)

                # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
                total_loss += cls_loss * 0.5

            outputs = (total_loss,) + outputs

        else:
            # during inference, compute the end logits based on beam search
            bsz, slen, hsz = hidden_states.size()
            start_log_probs = F.softmax(start_logits, dim=-1)  # shape (bsz, slen)

            # Top-k start candidates; gather their hidden states and
            # broadcast each candidate across the full sequence so the end
            # pooler can score every (start, end) pair.
            start_top_log_probs, start_top_index = torch.topk(start_log_probs, self.start_n_top, dim=-1)  # shape (bsz, start_n_top)
            start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz)  # shape (bsz, start_n_top, hsz)
            start_states = torch.gather(hidden_states, -2, start_top_index_exp)  # shape (bsz, start_n_top, hsz)
            start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1)  # shape (bsz, slen, start_n_top, hsz)

            hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(start_states)  # shape (bsz, slen, start_n_top, hsz)
            p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
            end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
            end_log_probs = F.softmax(end_logits, dim=1)  # shape (bsz, slen, start_n_top)

            end_top_log_probs, end_top_index = torch.topk(end_log_probs, self.end_n_top, dim=1)  # shape (bsz, end_n_top, start_n_top)
            end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
            end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)

            # Probability-weighted mixture of hidden states stands in for
            # the start state when scoring answerability at inference time.
            start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
            cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)

            outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) + outputs

        # return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits
        # or (if labels are provided) (total_loss,)
        return outputs
class SequenceSummary(nn.Module):
|
| 645 |
+
r""" Compute a single vector summary of a sequence hidden states according to various possibilities:
|
| 646 |
+
Args of the config class:
|
| 647 |
+
summary_type:
|
| 648 |
+
- 'last' => [default] take the last token hidden state (like XLNet)
|
| 649 |
+
- 'first' => take the first token hidden state (like Bert)
|
| 650 |
+
- 'mean' => take the mean of all tokens hidden states
|
| 651 |
+
- 'cls_index' => supply a Tensor of classification token position (GPT/GPT-2)
|
| 652 |
+
- 'attn' => Not implemented now, use multi-head attention
|
| 653 |
+
summary_use_proj: Add a projection after the vector extraction
|
| 654 |
+
summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False.
|
| 655 |
+
summary_activation: 'tanh' => add a tanh activation to the output, Other => no activation. Default
|
| 656 |
+
summary_first_dropout: Add a dropout before the projection and activation
|
| 657 |
+
summary_last_dropout: Add a dropout after the projection and activation
|
| 658 |
+
"""
|
| 659 |
+
def __init__(self, config):
|
| 660 |
+
super(SequenceSummary, self).__init__()
|
| 661 |
+
|
| 662 |
+
self.summary_type = config.summary_type if hasattr(config, 'summary_use_proj') else 'last'
|
| 663 |
+
if self.summary_type == 'attn':
|
| 664 |
+
# We should use a standard multi-head attention module with absolute positional embedding for that.
|
| 665 |
+
# Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
|
| 666 |
+
# We can probably just use the multi-head attention module of PyTorch >=1.1.0
|
| 667 |
+
raise NotImplementedError
|
| 668 |
+
|
| 669 |
+
self.summary = Identity()
|
| 670 |
+
if hasattr(config, 'summary_use_proj') and config.summary_use_proj:
|
| 671 |
+
if hasattr(config, 'summary_proj_to_labels') and config.summary_proj_to_labels and config.num_labels > 0:
|
| 672 |
+
num_classes = config.num_labels
|
| 673 |
+
else:
|
| 674 |
+
num_classes = config.hidden_size
|
| 675 |
+
self.summary = nn.Linear(config.hidden_size, num_classes)
|
| 676 |
+
|
| 677 |
+
self.activation = Identity()
|
| 678 |
+
if hasattr(config, 'summary_activation') and config.summary_activation == 'tanh':
|
| 679 |
+
self.activation = nn.Tanh()
|
| 680 |
+
|
| 681 |
+
self.first_dropout = Identity()
|
| 682 |
+
if hasattr(config, 'summary_first_dropout') and config.summary_first_dropout > 0:
|
| 683 |
+
self.first_dropout = nn.Dropout(config.summary_first_dropout)
|
| 684 |
+
|
| 685 |
+
self.last_dropout = Identity()
|
| 686 |
+
if hasattr(config, 'summary_last_dropout') and config.summary_last_dropout > 0:
|
| 687 |
+
self.last_dropout = nn.Dropout(config.summary_last_dropout)
|
| 688 |
+
|
| 689 |
+
def forward(self, hidden_states, cls_index=None):
    """Summarize ``hidden_states`` [bsz, seq_len, hidden_size] into one vector per sequence.

    cls_index: optional positions of the classification token, shape (bsz,)
    or more generally (bsz, ...) matching leading dims of hidden_states;
    only used when ``summary_type == 'cls_index'``. When omitted, the last
    token of each sequence is used as the classification token.
    """
    mode = self.summary_type
    if mode == 'last':
        output = hidden_states[:, -1]
    elif mode == 'first':
        output = hidden_states[:, 0]
    elif mode == 'mean':
        output = hidden_states.mean(dim=1)
    elif mode == 'cls_index':
        if cls_index is None:
            # Default: index of the last token, for every sequence.
            cls_index = torch.full_like(hidden_states[..., :1, :], hidden_states.shape[-2]-1, dtype=torch.long)
        else:
            cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
            cls_index = cls_index.expand((-1,) * (cls_index.dim()-1) + (hidden_states.size(-1),))
        # cls_index now has shape (bsz, XX, 1, hidden_size), XX = optional leading dims.
        output = hidden_states.gather(-2, cls_index).squeeze(-2)  # (bsz, XX, hidden_size)
    elif mode == 'attn':
        raise NotImplementedError

    output = self.first_dropout(output)
    output = self.summary(output)
    output = self.activation(output)
    return self.last_dropout(output)
|
| 721 |
+
def prune_linear_layer(layer, index, dim=0):
    """Return a new nn.Linear keeping only the entries of ``layer`` selected by ``index`` along ``dim``.

    The returned layer lives on the same device, has requires_grad=True,
    and copies the selected weight (and bias) values. Used to remove heads.
    """
    index = index.to(layer.weight.device)
    kept_weight = layer.weight.index_select(dim, index).clone().detach()
    has_bias = layer.bias is not None
    if has_bias:
        # Pruning input features (dim == 1) leaves the bias untouched.
        kept_bias = (layer.bias if dim == 1 else layer.bias[index]).clone().detach()
    out_size = list(layer.weight.size())
    out_size[dim] = len(index)
    pruned = nn.Linear(out_size[1], out_size[0], bias=has_bias).to(layer.weight.device)
    pruned.weight.requires_grad = False
    pruned.weight.copy_(kept_weight.contiguous())
    pruned.weight.requires_grad = True
    if has_bias:
        pruned.bias.requires_grad = False
        pruned.bias.copy_(kept_bias.contiguous())
        pruned.bias.requires_grad = True
    return pruned
|
| 745 |
+
|
| 746 |
+
def prune_conv1d_layer(layer, index, dim=1):
    """Return a new Conv1D keeping only the entries of ``layer`` selected by ``index`` along ``dim``.

    A Conv1D works like a Linear layer (see e.g. BERT) but with transposed
    weights. The returned layer has requires_grad=True. Used to remove heads.
    """
    index = index.to(layer.weight.device)
    kept_weight = layer.weight.index_select(dim, index).clone().detach()
    # Pruning along dim 0 (input features, given the transposed layout)
    # leaves the bias untouched.
    kept_bias = (layer.bias if dim == 0 else layer.bias[index]).clone().detach()
    out_size = list(layer.weight.size())
    out_size[dim] = len(index)
    pruned = Conv1D(out_size[1], out_size[0]).to(layer.weight.device)
    pruned.weight.requires_grad = False
    pruned.weight.copy_(kept_weight.contiguous())
    pruned.weight.requires_grad = True
    pruned.bias.requires_grad = False
    pruned.bias.copy_(kept_bias.contiguous())
    pruned.bias.requires_grad = True
    return pruned
|
| 769 |
+
|
| 770 |
+
def prune_layer(layer, index, dim=None):
    """Prune a Conv1D or nn.Linear layer to keep only the entries in ``index``.

    Dispatches to the matching helper; ``dim`` defaults to 0 for Linear and
    1 for Conv1D. Returns the pruned layer with requires_grad=True.
    """
    if isinstance(layer, nn.Linear):
        return prune_linear_layer(layer, index, dim=dim if dim is not None else 0)
    if isinstance(layer, Conv1D):
        return prune_conv1d_layer(layer, index, dim=dim if dim is not None else 1)
    raise ValueError("Can't prune layer of class {}".format(layer.__class__))
versatile_diffusion/lib/model_zoo/optimus_models/optimus_bert.py
ADDED
|
@@ -0,0 +1,1439 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
| 3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
"""PyTorch BERT model. """
|
| 17 |
+
|
| 18 |
+
from __future__ import absolute_import, division, print_function, unicode_literals
|
| 19 |
+
|
| 20 |
+
import json
|
| 21 |
+
import logging
|
| 22 |
+
import math
|
| 23 |
+
import os
|
| 24 |
+
import sys
|
| 25 |
+
from io import open
|
| 26 |
+
|
| 27 |
+
import pdb
|
| 28 |
+
|
| 29 |
+
import torch
|
| 30 |
+
from torch import nn
|
| 31 |
+
from torch.nn import CrossEntropyLoss, MSELoss
|
| 32 |
+
|
| 33 |
+
from .modeling_utils import PreTrainedModel, prune_linear_layer
|
| 34 |
+
from .configuration_bert import BertConfig
|
| 35 |
+
from .file_utils import add_start_docstrings
|
| 36 |
+
|
| 37 |
+
logger = logging.getLogger(__name__)
|
| 38 |
+
|
| 39 |
+
BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
| 40 |
+
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin",
|
| 41 |
+
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin",
|
| 42 |
+
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin",
|
| 43 |
+
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-pytorch_model.bin",
|
| 44 |
+
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-pytorch_model.bin",
|
| 45 |
+
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-pytorch_model.bin",
|
| 46 |
+
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-pytorch_model.bin",
|
| 47 |
+
'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-pytorch_model.bin",
|
| 48 |
+
'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-pytorch_model.bin",
|
| 49 |
+
'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-pytorch_model.bin",
|
| 50 |
+
'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-pytorch_model.bin",
|
| 51 |
+
'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-pytorch_model.bin",
|
| 52 |
+
'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin",
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
    """ Load tf checkpoints in a pytorch model.

    Walks every variable in the TensorFlow checkpoint, maps its
    slash-separated name onto attributes of ``model`` (translating TF
    naming such as kernel/gamma/beta to PyTorch weight/bias) and copies
    the values in place. Optimizer-only variables are skipped. Returns
    the same ``model`` instance.
    """
    try:
        import re
        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error("Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions.")
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info("Loading TF weight {} with shape {}".format(name, shape))
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    for name, array in zip(names, arrays):
        name = name.split('/')
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
        # which are not required for using pretrained model
        if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
            logger.info("Skipping {}".format("/".join(name)))
            continue
        # Walk the attribute path segment by segment, starting at the model.
        pointer = model
        for m_name in name:
            # A segment like "layer_11" splits into ['layer', '11', ''].
            if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
                l = re.split(r'_(\d+)', m_name)
            else:
                l = [m_name]
            if l[0] == 'kernel' or l[0] == 'gamma':
                pointer = getattr(pointer, 'weight')
            elif l[0] == 'output_bias' or l[0] == 'beta':
                pointer = getattr(pointer, 'bias')
            elif l[0] == 'output_weights':
                pointer = getattr(pointer, 'weight')
            elif l[0] == 'squad':
                pointer = getattr(pointer, 'classifier')
            else:
                try:
                    pointer = getattr(pointer, l[0])
                except AttributeError:
                    # NOTE(review): this ``continue`` only skips the current
                    # path segment (inner loop), not the whole variable —
                    # presumably intentional upstream behavior; confirm.
                    logger.info("Skipping {}".format("/".join(name)))
                    continue
            if len(l) >= 2:
                # Numeric suffix indexes into a module list (e.g. layer_3).
                num = int(l[1])
                pointer = pointer[num]
        if m_name[-11:] == '_embeddings':
            pointer = getattr(pointer, 'weight')
        elif m_name == 'kernel':
            # TF stores dense kernels transposed relative to PyTorch.
            array = np.transpose(array)
        try:
            assert pointer.shape == array.shape
        except AssertionError as e:
            # Append both shapes so the mismatch is visible in the traceback.
            e.args += (pointer.shape, array.shape)
            raise
        logger.info("Initialize PyTorch weight {}".format(name))
        pointer.data = torch.from_numpy(array)
    return model
|
| 121 |
+
|
| 122 |
+
def gelu(x):
    """Gaussian Error Linear Unit activation (exact erf formulation).

    For information: OpenAI GPT's gelu is slightly different (and gives
    slightly different results):
    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    Also see https://arxiv.org/abs/1606.08415
    """
    return 0.5 * x * (torch.erf(x / math.sqrt(2.0)) + 1.0)
|
| 130 |
+
|
| 131 |
+
def swish(x):
    """Swish (SiLU) activation: x * sigmoid(x)."""
    return torch.sigmoid(x) * x
|
| 134 |
+
|
| 135 |
+
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
try:
|
| 139 |
+
from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
|
| 140 |
+
except (ImportError, AttributeError) as e:
|
| 141 |
+
logger.info("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .")
|
| 142 |
+
BertLayerNorm = torch.nn.LayerNorm
|
| 143 |
+
|
| 144 |
+
class BertEmbeddings(nn.Module):
    """Sum word, position and token-type embeddings, then LayerNorm and dropout."""

    def __init__(self, config):
        super(BertEmbeddings, self).__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model
        # variable names and be able to load any TensorFlow checkpoint file.
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids, token_type_ids=None, position_ids=None):
        seq_len = input_ids.size(1)
        if position_ids is None:
            # Default positions: 0..seq_len-1, broadcast over the batch.
            position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        combined = (self.word_embeddings(input_ids)
                    + self.position_embeddings(position_ids)
                    + self.token_type_embeddings(token_type_ids))
        return self.dropout(self.LayerNorm(combined))
|
| 175 |
+
|
| 176 |
+
class BertSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention (without the output projection)."""

    def __init__(self, config):
        super(BertSelfAttention, self).__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
        self.output_attentions = config.output_attentions

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        # [bsz, seq, hidden] -> [bsz, heads, seq, head_size]
        split_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        return x.view(*split_shape).permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask, head_mask=None):
        q = self.transpose_for_scores(self.query(hidden_states))
        k = self.transpose_for_scores(self.key(hidden_states))
        v = self.transpose_for_scores(self.value(hidden_states))

        # Raw scores = scaled dot product, plus the additive attention mask
        # (precomputed for all layers in BertModel's forward() function).
        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.attention_head_size)
        scores = scores + attention_mask

        # Softmax over keys; the dropout here drops entire tokens to attend
        # to, which looks unusual but follows the original Transformer paper.
        probs = self.dropout(nn.Softmax(dim=-1)(scores))

        # Zero out pruned/masked heads if requested.
        if head_mask is not None:
            probs = probs * head_mask

        context = torch.matmul(probs, v)
        context = context.permute(0, 2, 1, 3).contiguous()
        context = context.view(*(context.size()[:-2] + (self.all_head_size,)))

        if self.output_attentions:
            return (context, probs)
        return (context,)
|
| 235 |
+
|
| 236 |
+
class BertSelfOutput(nn.Module):
    """Post-attention projection with dropout and a residual LayerNorm."""

    def __init__(self, config):
        super(BertSelfOutput, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        projected = self.dropout(self.dense(hidden_states))
        # Residual connection around the attention block, then LayerNorm.
        return self.LayerNorm(projected + input_tensor)
|
| 249 |
+
|
| 250 |
+
class BertAttention(nn.Module):
    """Attention sub-layer: BertSelfAttention + BertSelfOutput, with head pruning."""

    def __init__(self, config):
        super(BertAttention, self).__init__()
        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove ``heads`` (indices in the original head numbering) from this layer."""
        if len(heads) == 0:
            return
        mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
        heads = set(heads) - self.pruned_heads  # ignore heads pruned earlier
        for head in heads:
            # Shift the index down by the number of already-pruned heads before it.
            shifted = head - sum(1 if h < head else 0 for h in self.pruned_heads)
            mask[shifted] = 0
        mask = mask.view(-1).contiguous().eq(1)
        index = torch.arange(len(mask))[mask].long()

        # Prune the linear projections.
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper-params and remember what was pruned.
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(self, input_tensor, attention_mask, head_mask=None):
        self_outputs = self.self(input_tensor, attention_mask, head_mask)
        attention_output = self.output(self_outputs[0], input_tensor)
        # Append attention probs when the self-attention produced them.
        return (attention_output,) + self_outputs[1:]
|
| 286 |
+
|
| 287 |
+
class BertIntermediate(nn.Module):
    """Feed-forward expansion: hidden_size -> intermediate_size with activation."""

    def __init__(self, config):
        super(BertIntermediate, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # config.hidden_act may be an activation name (str) or a callable.
        is_name = isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode))
        self.intermediate_act_fn = ACT2FN[config.hidden_act] if is_name else config.hidden_act

    def forward(self, hidden_states):
        return self.intermediate_act_fn(self.dense(hidden_states))
|
| 301 |
+
|
| 302 |
+
class BertOutput(nn.Module):
    """Feed-forward contraction back to hidden_size with dropout and residual LayerNorm."""

    def __init__(self, config):
        super(BertOutput, self).__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        contracted = self.dropout(self.dense(hidden_states))
        # Residual connection around the feed-forward block, then LayerNorm.
        return self.LayerNorm(contracted + input_tensor)
|
| 315 |
+
|
| 316 |
+
class BertLayer(nn.Module):
    """One transformer encoder block: attention -> intermediate -> output."""

    def __init__(self, config):
        super(BertLayer, self).__init__()
        self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(self, hidden_states, attention_mask, head_mask=None):
        attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
        attn_out = attention_outputs[0]
        layer_out = self.output(self.intermediate(attn_out), attn_out)
        # Append attention probs if the attention sub-layer returned them.
        return (layer_out,) + attention_outputs[1:]
|
| 331 |
+
|
| 332 |
+
class BertEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` BertLayer modules.

    forward() returns (last hidden state, [all hidden states], [all
    attentions]); the optional tuple members are controlled by the config's
    output_hidden_states / output_attentions flags.
    """

    def __init__(self, config):
        super(BertEncoder, self).__init__()
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(self, hidden_states, attention_mask, head_mask=None):
        all_hidden_states = ()
        all_attentions = ()
        for i, layer_module in enumerate(self.layer):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # BUGFIX: with the default head_mask=None the original crashed on
            # head_mask[i] (TypeError); None now disables per-layer masking.
            layer_head_mask = head_mask[i] if head_mask is not None else None
            layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask)
            hidden_states = layer_outputs[0]

            if self.output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        # Add the hidden state after the last layer.
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            outputs = outputs + (all_attentions,)
        return outputs  # last-layer hidden state, (all hidden states), (all attentions)
|
| 363 |
+
|
| 364 |
+
class BertPooler(nn.Module):
    """Pools a sequence by passing the first ([CLS]) token through Linear + Tanh."""

    def __init__(self, config):
        super(BertPooler, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # Only the hidden state of the first token is used for pooling.
        cls_state = hidden_states[:, 0]
        return self.activation(self.dense(cls_state))
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
class BertPredictionHeadTransform(nn.Module):
    """Dense -> activation -> LayerNorm transform applied before the LM decoder.

    Fix: the original guarded ``isinstance(config.hidden_act, unicode)`` behind
    ``sys.version_info[0] == 2``; ``unicode`` does not exist on Python 3 and the
    clause was dead code kept alive only by short-circuit evaluation. A plain
    ``str`` check is equivalent on Python 3.
    """

    def __init__(self, config):
        super(BertPredictionHeadTransform, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # config.hidden_act may be the name of a registered activation or a callable.
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        """Apply dense projection, activation, and layer normalization in order."""
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
class BertLMPredictionHead(nn.Module):
    """Masked-LM head: transform, then a vocab-size decoder plus a per-token bias.

    The decoder's weight matrix is tied to the input word embeddings elsewhere
    (see the models' ``tie_weights``); only ``self.bias`` is head-specific.
    """

    def __init__(self, config):
        super(BertLMPredictionHead, self).__init__()
        self.transform = BertPredictionHeadTransform(config)
        # bias=False because the shared embedding matrix carries no output bias;
        # the dedicated parameter below supplies it.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

    def forward(self, hidden_states):
        transformed = self.transform(hidden_states)
        return self.decoder(transformed) + self.bias
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
class BertOnlyMLMHead(nn.Module):
    """Wraps the LM prediction head for masked-LM-only models."""

    def __init__(self, config):
        super(BertOnlyMLMHead, self).__init__()
        self.predictions = BertLMPredictionHead(config)

    def forward(self, sequence_output):
        return self.predictions(sequence_output)
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
class BertOnlyNSPHead(nn.Module):
    """Next-sentence-prediction head: binary classifier over the pooled output."""

    def __init__(self, config):
        super(BertOnlyNSPHead, self).__init__()
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, pooled_output):
        return self.seq_relationship(pooled_output)
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
class BertPreTrainingHeads(nn.Module):
    """Both pre-training heads: masked-LM predictions and NSP classification."""

    def __init__(self, config):
        super(BertPreTrainingHeads, self).__init__()
        self.predictions = BertLMPredictionHead(config)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output):
        lm_scores = self.predictions(sequence_output)
        nsp_scores = self.seq_relationship(pooled_output)
        return lm_scores, nsp_scores
|
| 445 |
+
|
| 446 |
+
|
| 447 |
+
class BertPreTrainedModel(PreTrainedModel):
    """Abstract base handling weight initialization and a simple interface
    for downloading and loading pretrained BERT checkpoints."""

    config_class = BertConfig
    pretrained_model_archive_map = BERT_PRETRAINED_MODEL_ARCHIVE_MAP
    load_tf_weights = load_tf_weights_in_bert
    base_model_prefix = "bert"

    def _init_weights(self, module):
        """Initialize one module's weights in place."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # TF BERT uses truncated_normal; plain normal is the accepted
            # PyTorch stand-in (cf https://github.com/pytorch/pytorch/pull/5617).
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, BertLayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
# Shared docstring fragments injected into model classes via @add_start_docstrings.
# NOTE(review): original indentation was lost in extraction; reconstructed with
# conventional layout, text content preserved.
BERT_START_DOCSTRING = r"""    The BERT model was proposed in
    `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`_
    by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova. It's a bidirectional transformer
    pre-trained using a combination of masked language modeling objective and next sentence prediction
    on a large corpus comprising the Toronto Book Corpus and Wikipedia.

    This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
    refer to the PyTorch documentation for all matter related to general usage and behavior.

    .. _`BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`:
        https://arxiv.org/abs/1810.04805

    .. _`torch.nn.Module`:
        https://pytorch.org/docs/stable/nn.html#module

    Parameters:
        config (:class:`~pytorch_transformers.BertConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~pytorch_transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

BERT_INPUTS_DOCSTRING = r"""
    Inputs:
        **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Indices of input sequence tokens in the vocabulary.
            To match pre-training, BERT input sequence should be formatted with [CLS] and [SEP] tokens as follows:

            (a) For sequence pairs:

                ``tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]``

                ``token_type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1``

            (b) For single sequences:

                ``tokens: [CLS] the dog is hairy . [SEP]``

                ``token_type_ids: 0 0 0 0 0 0 0``

            Bert is a model with absolute position embeddings so it's usually advised to pad the inputs on
            the right rather than the left.

            Indices can be obtained using :class:`pytorch_transformers.BertTokenizer`.
            See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
            :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
        **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
        **token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Segment token indices to indicate first and second portions of the inputs.
            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
            corresponds to a `sentence B` token
            (see `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`_ for more details).
        **position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.
        **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
"""
|
| 531 |
+
|
| 532 |
+
@add_start_docstrings("The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
                      BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class BertModel(BertPreTrainedModel):
    r"""Bare BERT encoder (embeddings + transformer stack + pooler).

    Outputs a tuple of:
        **last_hidden_state**: ``(batch_size, sequence_length, hidden_size)``
        **pooler_output**: ``(batch_size, hidden_size)`` — first-token state through Linear + Tanh
        **hidden_states**: (`optional`, when ``config.output_hidden_states=True``)
        **attentions**: (`optional`, when ``config.output_attentions=True``)
    """

    def __init__(self, config):
        super(BertModel, self).__init__(config)
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)
        self.init_weights()

    def _resize_token_embeddings(self, new_num_tokens):
        # Swap in a resized word-embedding matrix and return it.
        resized = self._get_resized_embeddings(self.embeddings.word_embeddings, new_num_tokens)
        self.embeddings.word_embeddings = resized
        return self.embeddings.word_embeddings

    def _prune_heads(self, heads_to_prune):
        """Prune attention heads.

        heads_to_prune: dict mapping layer_num -> list of head indices to prune.
        See base class PreTrainedModel.
        """
        for layer_idx, head_ids in heads_to_prune.items():
            self.encoder.layer[layer_idx].attention.prune_heads(head_ids)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # Broadcastable [batch, 1, 1, seq] mask: 0.0 where attended, -10000.0
        # where masked. It is added to the raw attention scores before softmax,
        # which effectively removes masked positions.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        # Expand head_mask from [num_heads] or [num_layers, num_heads] to
        # [num_layers, batch, num_heads, seq, seq]; 1.0 keeps a head.
        # A None head_mask becomes a per-layer list of Nones (no masking).
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # per-layer head masks
            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        else:
            head_mask = [None] * self.config.num_hidden_layers

        embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
        encoder_outputs = self.encoder(embedding_output,
                                       extended_attention_mask,
                                       head_mask=head_mask)
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        # Append (hidden_states) and (attentions) when the encoder produced them.
        return (sequence_output, pooled_output,) + encoder_outputs[1:]
|
| 631 |
+
|
| 632 |
+
|
| 633 |
+
|
| 634 |
+
|
| 635 |
+
|
| 636 |
+
|
| 637 |
+
@add_start_docstrings("The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
                      BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class BertForLatentConnector(BertPreTrainedModel):
    r"""BERT encoder carrying an extra projection to a latent space.

    Identical to :class:`BertModel` except for ``self.linear``, which maps the
    hidden size to ``2 * latent_size`` (mean and log-variance sized halves).
    NOTE(review): ``self.linear`` is never applied inside ``forward`` here —
    presumably callers apply it to the pooled output; confirm at call sites.

    Outputs: sequence_output, pooled_output, (hidden_states), (attentions)
    """

    def __init__(self, config, latent_size):
        super(BertForLatentConnector, self).__init__(config)
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)
        self.linear = nn.Linear(config.hidden_size, 2 * latent_size, bias=False)
        self.init_weights()

    def _resize_token_embeddings(self, new_num_tokens):
        # Swap in a resized word-embedding matrix and return it.
        resized = self._get_resized_embeddings(self.embeddings.word_embeddings, new_num_tokens)
        self.embeddings.word_embeddings = resized
        return self.embeddings.word_embeddings

    def _prune_heads(self, heads_to_prune):
        """Prune attention heads.

        heads_to_prune: dict mapping layer_num -> list of head indices to prune.
        See base class PreTrainedModel.
        """
        for layer_idx, head_ids in heads_to_prune.items():
            self.encoder.layer[layer_idx].attention.prune_heads(head_ids)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # Broadcastable [batch, 1, 1, seq] mask: 0.0 where attended, -10000.0
        # where masked; added to raw attention scores before the softmax.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        # Expand head_mask from [num_heads] or [num_layers, num_heads] to
        # [num_layers, batch, num_heads, seq, seq]; None disables head masking.
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # per-layer head masks
            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        else:
            head_mask = [None] * self.config.num_hidden_layers

        embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
        encoder_outputs = self.encoder(embedding_output,
                                       extended_attention_mask,
                                       head_mask=head_mask)
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        # Append (hidden_states) and (attentions) when the encoder produced them.
        return (sequence_output, pooled_output,) + encoder_outputs[1:]
|
| 738 |
+
|
| 739 |
+
|
| 740 |
+
|
| 741 |
+
@add_start_docstrings("""Bert Model with two heads on top as done during the pre-training:
    a `masked language modeling` head and a `next sentence prediction (classification)` head. """,
                      BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class BertForPreTraining(BertPreTrainedModel):
    r"""BERT with both pre-training heads (masked LM + next-sentence prediction).

    Optional inputs:
        **masked_lm_labels**: ``(batch_size, sequence_length)``; index ``-1`` is ignored.
        **next_sentence_label**: ``(batch_size,)`` in ``[0, 1]`` (0 = continuation, 1 = random).

    When BOTH label tensors are given, the summed MLM + NSP loss is prepended.
    Outputs: (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions)
    """

    def __init__(self, config):
        super(BertForPreTraining, self).__init__(config)
        self.bert = BertModel(config)
        self.cls = BertPreTrainingHeads(config)
        self.init_weights()
        self.tie_weights()

    def tie_weights(self):
        """Share the decoder and input-embedding weights.

        Export to TorchScript can't handle parameter sharing, so they are
        cloned instead in that case.
        """
        self._tie_or_clone_weights(self.cls.predictions.decoder,
                                   self.bert.embeddings.word_embeddings)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                masked_lm_labels=None, next_sentence_label=None):
        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            position_ids=position_ids,
                            head_mask=head_mask)

        sequence_output, pooled_output = outputs[:2]
        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

        # Keep hidden states / attentions if the base model returned them.
        outputs = (prediction_scores, seq_relationship_score,) + outputs[2:]

        # Loss is computed only when both label tensors are provided.
        if masked_lm_labels is not None and next_sentence_label is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
            next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
            outputs = (masked_lm_loss + next_sentence_loss,) + outputs

        return outputs  # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions)
|
| 819 |
+
|
| 820 |
+
|
| 821 |
+
@add_start_docstrings("""Bert Model with a `language modeling` head on top. """,
                      BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class BertForMaskedLM(BertPreTrainedModel):
    r"""BERT with a masked-language-modeling head.

    Optional input **masked_lm_labels**: ``(batch_size, sequence_length)``;
    index ``-1`` marks tokens ignored by the loss. When provided, the MLM loss
    is prepended to the outputs.
    Outputs: (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
    """

    def __init__(self, config):
        super(BertForMaskedLM, self).__init__(config)
        self.bert = BertModel(config)
        self.cls = BertOnlyMLMHead(config)
        self.init_weights()
        self.tie_weights()

    def tie_weights(self):
        """Share the decoder and input-embedding weights.

        Export to TorchScript can't handle parameter sharing, so they are
        cloned instead in that case.
        """
        self._tie_or_clone_weights(self.cls.predictions.decoder,
                                   self.bert.embeddings.word_embeddings)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                masked_lm_labels=None):
        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            position_ids=position_ids,
                            head_mask=head_mask)

        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)

        # Keep hidden states / attentions if the base model returned them.
        outputs = (prediction_scores,) + outputs[2:]
        if masked_lm_labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
            outputs = (masked_lm_loss,) + outputs

        return outputs  # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
|
| 888 |
+
|
| 889 |
+
|
| 890 |
+
@add_start_docstrings("""Bert Model with a `next sentence prediction (classification)` head on top. """,
                      BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class BertForNextSentencePrediction(BertPreTrainedModel):
    r"""
        **next_sentence_label**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring)
            Indices should be in ``[0, 1]``.
            ``0`` indicates sequence B is a continuation of sequence A,
            ``1`` indicates sequence B is a random sequence.

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``next_sentence_label`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Next sequence prediction (classification) loss.
        **seq_relationship_scores**: ``torch.FloatTensor`` of shape ``(batch_size, 2)``
            Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids)
        seq_relationship_scores = outputs[0]

    """
    def __init__(self, config):
        super(BertForNextSentencePrediction, self).__init__(config)

        self.bert = BertModel(config)
        self.cls = BertOnlyNSPHead(config)  # binary is-next classifier over pooled [CLS]

        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                next_sentence_label=None):
        # outputs = (sequence_output, pooled_output, [hidden_states], [attentions])
        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            position_ids=position_ids,
                            head_mask=head_mask)

        pooled_output = outputs[1]  # pooled [CLS] representation

        seq_relationship_score = self.cls(pooled_output)

        outputs = (seq_relationship_score,) + outputs[2:]  # add hidden states and attention if they are here
        if next_sentence_label is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
            outputs = (next_sentence_loss,) + outputs

        return outputs  # (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions)
|
| 950 |
+
|
| 951 |
+
|
| 952 |
+
@add_start_docstrings("""Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of
    the pooled output) e.g. for GLUE tasks. """,
    BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class BertForSequenceClassification(BertPreTrainedModel):
    r"""
        **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Labels for computing the sequence classification/regression loss.
            Indices should be in ``[0, ..., config.num_labels - 1]``.
            If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
            If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Classification (or regression if config.num_labels==1) loss.
        **logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)``
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    NOTE(review): unlike the stock HuggingFace head, ``forward`` here returns the
    pair ``(outputs_tuple, pooled_output)`` — the example below shows the
    upstream calling convention and does not match this variant.

    Examples::

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        labels = torch.tensor([1]).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, labels=labels)
        loss, logits = outputs[:2]

    """
    def __init__(self, config):
        super(BertForSequenceClassification, self).__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
        # When True, the pooled output is detached so only the classifier trains
        # (linear-probe / frozen-encoder mode).
        self.use_freeze = False

        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
                position_ids=None, head_mask=None, labels=None):

        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            position_ids=position_ids,
                            head_mask=head_mask)

        pooled_output = outputs[1]  # pooled [CLS] representation

        if self.use_freeze:
            # Block gradients into the encoder.
            pooled_output = pooled_output.detach()

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here

        if labels is not None:
            if self.num_labels == 1:
                # We are doing regression
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs

        # NOTE(review): returns a 2-tuple — ((loss), logits, (hidden_states),
        # (attentions)) plus the (possibly dropped-out) pooled_output.
        return outputs, pooled_output
|
| 1028 |
+
|
| 1029 |
+
|
| 1030 |
+
@add_start_docstrings("""Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of
    the pooled output) e.g. for GLUE tasks. """,
    BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class BertForSequenceClassificationLatentConnector(BertPreTrainedModel):
    r"""
        **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Labels for computing the sequence classification/regression loss.
            Indices should be in ``[0, ..., config.num_labels - 1]``.
            If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
            If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Classification (or regression if config.num_labels==1) loss.
        **logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)``
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    NOTE(review): like ``BertForSequenceClassification`` in this file, ``forward``
    returns ``(outputs_tuple, pooled_output)``; the example below shows the
    upstream calling convention only.

    Examples::

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = BertForSequenceClassificationLatentConnector.from_pretrained('bert-base-uncased')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        labels = torch.tensor([1]).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, labels=labels)
        loss, logits = outputs[:2]

    """
    def __init__(self, config, latent_size):
        super(BertForSequenceClassificationLatentConnector, self).__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
        # Projection to (mean, logvar) of a latent Gaussian; currently unused in
        # forward (see the commented-out chunk below) but kept so pretrained
        # latent-connector checkpoints load without missing keys.
        self.linear = nn.Linear(config.hidden_size, 2 * latent_size, bias=False)
        # When True, the pooled output is detached so only the classifier trains.
        self.use_freeze = False

        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
                position_ids=None, head_mask=None, labels=None):

        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            position_ids=position_ids,
                            head_mask=head_mask)

        pooled_output = outputs[1]  # pooled [CLS] representation
        # mean, logvar = self.linear(pooled_output).chunk(2, -1)

        if self.use_freeze:
            # Block gradients into the encoder.
            pooled_output = pooled_output.detach()

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here

        if labels is not None:
            if self.num_labels == 1:
                # We are doing regression
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs

        # Returns ((loss), logits, (hidden_states), (attentions)) plus pooled_output.
        return outputs, pooled_output
|
| 1109 |
+
|
| 1110 |
+
|
| 1111 |
+
@add_start_docstrings("""Bert Model with a multiple choice classification head on top (a linear layer on top of
    the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
    BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class BertForMultipleChoice(BertPreTrainedModel):
    r"""
        **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Labels for computing the multiple choice classification loss.
            Indices should be in ``[0, ..., num_choices - 1]`` where `num_choices` is the size of the second dimension
            of the input tensors. (see `input_ids` above)

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Classification loss.
        **classification_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices)`` where `num_choices` is the size of the second dimension
            of the input tensors. (see `input_ids` above).
            Classification scores (before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = BertForMultipleChoice.from_pretrained('bert-base-uncased')
        choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
        input_ids = torch.tensor([tokenizer.encode(s) for s in choices]).unsqueeze(0)  # Batch size 1, 2 choices
        labels = torch.tensor(1).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, labels=labels)
        loss, classification_scores = outputs[:2]

    """
    def __init__(self, config):
        super(BertForMultipleChoice, self).__init__(config)

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Scores each choice with a single logit; logits are regrouped per
        # example in forward before the softmax/loss.
        self.classifier = nn.Linear(config.hidden_size, 1)

        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
                position_ids=None, head_mask=None, labels=None):
        num_choices = input_ids.shape[1]

        # Flatten (batch, num_choices, seq_len) -> (batch * num_choices, seq_len)
        # so every choice runs through BERT as an independent sequence.
        input_ids = input_ids.view(-1, input_ids.size(-1))
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None

        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            position_ids=position_ids,
                            head_mask=head_mask)

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        # Regroup per-choice logits back to (batch, num_choices).
        reshaped_logits = logits.view(-1, num_choices)

        outputs = (reshaped_logits,) + outputs[2:]  # add hidden states and attention if they are here

        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)
            outputs = (loss,) + outputs

        return outputs  # (loss), reshaped_logits, (hidden_states), (attentions)
|
| 1184 |
+
|
| 1185 |
+
|
| 1186 |
+
@add_start_docstrings("""Bert Model with a token classification head on top (a linear layer on top of
    the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
    BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class BertForTokenClassification(BertPreTrainedModel):
    r"""
        **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Labels for computing the token classification loss.
            Indices should be in ``[0, ..., config.num_labels - 1]``.

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Classification loss.
        **scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.num_labels)``
            Classification scores (before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = BertForTokenClassification.from_pretrained('bert-base-uncased')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, labels=labels)
        loss, scores = outputs[:2]

    """
    def __init__(self, config):
        super(BertForTokenClassification, self).__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Per-token classifier over the sequence output.
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
                position_ids=None, head_mask=None, labels=None):

        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            position_ids=position_ids,
                            head_mask=head_mask)

        sequence_output = outputs[0]  # (batch, seq_len, hidden)

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # Only keep active parts of the loss
            if attention_mask is not None:
                # Mask value 1 marks real tokens; padding positions are excluded.
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)[active_loss]
                active_labels = labels.view(-1)[active_loss]
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (loss), scores, (hidden_states), (attentions)
|
| 1256 |
+
|
| 1257 |
+
|
| 1258 |
+
@add_start_docstrings("""Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
    the hidden-states output to compute `span start logits` and `span end logits`). """,
    BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class BertForQuestionAnswering(BertPreTrainedModel):
    r"""
        **start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Position outside of the sequence are not taken into account for computing the loss.
        **end_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Position outside of the sequence are not taken into account for computing the loss.

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
        **start_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
            Span-start scores (before SoftMax).
        **end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
            Span-end scores (before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = BertForQuestionAnswering.from_pretrained('bert-base-uncased')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        start_positions = torch.tensor([1])
        end_positions = torch.tensor([3])
        outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
        loss, start_scores, end_scores = outputs[:3]

    """
    def __init__(self, config):
        super(BertForQuestionAnswering, self).__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        # Single linear layer producing (start, end) logits per token;
        # assumes config.num_labels == 2 for the split below.
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                start_positions=None, end_positions=None):

        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            position_ids=position_ids,
                            head_mask=head_mask)

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        # Split the last dim into start/end logit channels.
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        outputs = (start_logits, end_logits,) + outputs[2:]
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2
            outputs = (total_loss,) + outputs

        return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)
|
| 1342 |
+
|
| 1343 |
+
|
| 1344 |
+
############
|
| 1345 |
+
# XX Added #
|
| 1346 |
+
############
|
| 1347 |
+
|
| 1348 |
+
class BertForLatentConnector_XX(nn.Module):
    """Standalone BERT encoder with an extra latent projection.

    Plain ``nn.Module`` re-implementation of ``BertModel`` (embeddings ->
    encoder -> pooler) that additionally owns ``self.linear``, mapping the
    pooled representation to ``2 * latent_size`` values (mean / logvar halves
    of a latent Gaussian).  ``forward`` returns the usual BertModel tuple;
    callers apply ``self.linear`` themselves.
    """

    def __init__(self, config, latent_size):
        super().__init__()
        self.config = config
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)
        # hidden_size -> 2 * latent_size projection (mean and logvar halves).
        self.linear = nn.Linear(config.hidden_size, 2 * latent_size, bias=False)
        self.init_weights()

    def init_weights(self):
        """Initialize all sub-module weights and prune heads if requested."""
        # Initialize weights recursively.
        self.apply(self._init_weights)

        # Prune heads if needed.
        if self.config.pruned_heads:
            # BUGFIX: the helper below is named ``_prune_heads``; the original
            # called the non-existent ``self.prune_heads`` and raised
            # AttributeError whenever ``config.pruned_heads`` was non-empty.
            self._prune_heads(self.config.pruned_heads)

    def _init_weights(self, module):
        """ Initialize the weights """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, BertLayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def _resize_token_embeddings(self, new_num_tokens):
        # NOTE(review): relies on ``self._get_resized_embeddings``, which is not
        # defined in this class (it exists on transformers' PreTrainedModel) —
        # calling this on a plain nn.Module instance will fail; confirm intent.
        old_embeddings = self.embeddings.word_embeddings
        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
        self.embeddings.word_embeddings = new_embeddings
        return self.embeddings.word_embeddings

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
            See base class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
        """Encode ``input_ids``; returns (sequence_output, pooled_output, ...)."""
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # this attention mask is more simple than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to float if need + fp16 compatibility
        else:
            head_mask = [None] * self.config.num_hidden_layers

        embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
        encoder_outputs = self.encoder(embedding_output,
                                       extended_attention_mask,
                                       head_mask=head_mask)
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]  # add hidden_states and attentions if they are here
        return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)
|
| 1438 |
+
|
| 1439 |
+
|
versatile_diffusion/lib/model_zoo/optimus_models/optimus_gpt2.py
ADDED
|
@@ -0,0 +1,1122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
|
| 3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
"""PyTorch OpenAI GPT-2 model."""
|
| 17 |
+
|
| 18 |
+
from __future__ import absolute_import, division, print_function, unicode_literals
|
| 19 |
+
|
| 20 |
+
import pdb
|
| 21 |
+
|
| 22 |
+
import collections
|
| 23 |
+
import json
|
| 24 |
+
import logging
|
| 25 |
+
import math
|
| 26 |
+
import os
|
| 27 |
+
import sys
|
| 28 |
+
from io import open
|
| 29 |
+
|
| 30 |
+
import torch
|
| 31 |
+
import torch.nn as nn
|
| 32 |
+
from torch.nn import CrossEntropyLoss
|
| 33 |
+
from torch.nn.parameter import Parameter
|
| 34 |
+
|
| 35 |
+
from .modeling_utils import PreTrainedModel, Conv1D, prune_conv1d_layer, SequenceSummary
|
| 36 |
+
from .configuration_gpt2 import GPT2Config
|
| 37 |
+
from .file_utils import add_start_docstrings
|
| 38 |
+
|
| 39 |
+
logger = logging.getLogger(__name__)
|
| 40 |
+
|
| 41 |
+
# Map from pretrained-model shortcut names to the URLs of their PyTorch weights.
# NOTE(review): these legacy s3.amazonaws.com URLs may no longer resolve — verify
# against the current HuggingFace hub before relying on them.
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin",
                                     "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin",
                                     "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-pytorch_model.bin"}
|
| 44 |
+
|
| 45 |
+
def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
    """Load weights from a TensorFlow GPT-2 checkpoint into ``model`` (in place).

    Walks every TF variable, maps its slash-separated name onto the PyTorch
    module tree, and copies the numpy array into the matching parameter.

    Args:
        model: PyTorch GPT-2 model to populate.
        config: unused here; kept for the shared ``load_tf_weights`` signature.
        gpt2_checkpoint_path: path to the TensorFlow checkpoint.

    Returns:
        The same ``model`` instance, with weights loaded.

    Raises:
        ImportError: if TensorFlow is not installed.
        AssertionError: if a checkpoint array's shape does not match the target
            parameter (both shapes are appended to the exception args).
    """
    try:
        import re
        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error("Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
                     "https://www.tensorflow.org/install/ for installation instructions.")
        raise
    tf_path = os.path.abspath(gpt2_checkpoint_path)
    logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info("Loading TF weight {} with shape {}".format(name, shape))
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        # NOTE(review): squeeze() drops *all* singleton dims — assumes no
        # meaningful size-1 axes exist in the checkpoint; confirm.
        arrays.append(array.squeeze())

    for name, array in zip(names, arrays):
        name = name[6:]  # skip "model/"
        name = name.split('/')
        pointer = model
        # Walk the PyTorch module tree following the TF variable path segments.
        for m_name in name:
            if re.fullmatch(r'[A-Za-z]+\d+', m_name):
                # e.g. "h0" -> ['h', '0', ...]: ModuleList name followed by index
                l = re.split(r'(\d+)', m_name)
            else:
                l = [m_name]
            if l[0] == 'w' or l[0] == 'g':
                # TF 'w' (weight) / 'g' (gain) both map to the PyTorch 'weight'
                pointer = getattr(pointer, 'weight')
            elif l[0] == 'b':
                pointer = getattr(pointer, 'bias')
            elif l[0] == 'wpe' or l[0] == 'wte':
                # embedding tables: descend into the module, then its weight
                pointer = getattr(pointer, l[0])
                pointer = getattr(pointer, 'weight')
            else:
                pointer = getattr(pointer, l[0])
            if len(l) >= 2:
                num = int(l[1])
                pointer = pointer[num]  # index into the ModuleList
        try:
            assert pointer.shape == array.shape
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)  # attach both shapes for debugging
            raise
        logger.info("Initialize PyTorch weight {}".format(name))
        pointer.data = torch.from_numpy(array)
    return model
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def gelu(x):
    """Gaussian Error Linear Unit activation, tanh approximation.

    Uses the same approximation as the original GPT-2 TensorFlow code:
    ``0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))``.
    """
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3))
    return 0.5 * x * (1.0 + torch.tanh(inner))
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class Attention(nn.Module):
    """Multi-head causal self-attention (GPT-2 style) with optional cached state.

    When ``layer_past`` is given to :meth:`forward`, its key/value entries are
    concatenated in front of the current keys/values; in this Optimus variant
    that slot carries the projected VAE latent produced by ``GPT2Model.forward``.
    """

    def __init__(self, nx, n_ctx, config, scale=False):
        super(Attention, self).__init__()
        self.output_attentions = config.output_attentions

        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
        assert n_state % config.n_head == 0
        # Causal mask: lower-triangular ones, shaped (1, 1, n_ctx, n_ctx) for broadcasting.
        self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
        self.n_head = config.n_head
        self.split_size = n_state
        self.scale = scale

        self.c_attn = Conv1D(n_state * 3, nx)  # fused query/key/value projection
        self.c_proj = Conv1D(n_state, nx)      # output projection
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Permanently remove the given attention heads from this layer (in place)."""
        if len(heads) == 0:
            return
        mask = torch.ones(self.n_head, self.split_size // self.n_head)
        heads = set(heads) - self.pruned_heads  # Convert to set and remove already pruned heads
        for head in heads:
            # Compute how many pruned heads are before the head and move the index accordingly
            head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
            mask[head] = 0
        mask = mask.view(-1).contiguous().eq(1)
        index = torch.arange(len(mask))[mask].long()
        # The fused c_attn holds q, k, v back-to-back, so repeat the index per section.
        index_attn = torch.cat([index, index + self.split_size, index + (2*self.split_size)])

        # Prune conv1d layers
        self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
        self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)

        # Update hyper params
        self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
        self.n_head = self.n_head - len(heads)
        self.pruned_heads = self.pruned_heads.union(heads)

    def _attn(self, q, k, v, attention_mask=None, head_mask=None):
        # q: (batch, head, seq_q, head_dim); k arrives pre-transposed from
        # split_heads(k=True): (batch, head, head_dim, seq_k).
        w = torch.matmul(q, k)
        if self.scale:
            w = w / math.sqrt(v.size(-1))
        nd, ns = w.size(-2), w.size(-1)
        # Slice the causal mask to the current query/key lengths.
        b = self.bias[:, :, ns-nd:ns, :ns]
        # Keep allowed scores; push masked positions to a large negative value pre-softmax.
        w = w * b - 1e4 * (1 - b)

        if attention_mask is not None:
            # Apply the attention mask
            w = w + attention_mask

        w = nn.Softmax(dim=-1)(w)
        w = self.attn_dropout(w)

        # Mask heads if we want to
        if head_mask is not None:
            w = w * head_mask

        outputs = [torch.matmul(w, v)]
        if self.output_attentions:
            outputs.append(w)
        return outputs

    def merge_heads(self, x):
        # (batch, head, seq, head_dim) -> (batch, seq, n_embd)
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*new_x_shape)  # in Tensorflow implem: fct merge_states

    def split_heads(self, x, k=False):
        # (batch, seq, n_embd) -> per-head layout; keys get an extra transpose
        # so the _attn matmul needs no further permutation.
        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
        if k:
            return x.permute(0, 2, 3, 1)  # (batch, head, head_features, seq_length)
        else:
            return x.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

    def forward(self, x, layer_past=None, attention_mask=None, head_mask=None):
        x = self.c_attn(x)  # fused projection, then split into q/k/v
        query, key, value = x.split(self.split_size, dim=2)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)

        if layer_past is not None:
            past_key, past_value = layer_past[0], layer_past[1]  # transpose back cf below
            # NOTE(review): unlike stock GPT-2 (which caches already-split heads),
            # layer_past here appears to be un-split (batch, seq, n_embd) tensors —
            # e.g. the latent chunks built in GPT2Model.forward — so heads are
            # split here; confirm against callers.
            past_key = self.split_heads(past_key, k=True)
            past_value = self.split_heads(past_value)
            # pdb.set_trace()
            # Keys are stored transposed, hence the concat on the last dim.
            key = torch.cat((past_key, key), dim=-1)
            value = torch.cat((past_value, value), dim=-2)
        present = torch.stack((key.transpose(-2, -1), value))  # transpose to have same shapes for stacking

        attn_outputs = self._attn(query, key, value, attention_mask, head_mask)
        a = attn_outputs[0]

        a = self.merge_heads(a)
        a = self.c_proj(a)
        a = self.resid_dropout(a)

        outputs = [a, present] + attn_outputs[1:]
        return outputs  # a, present, (attentions)
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
class MLP(nn.Module):
    """Position-wise feed-forward block of a GPT-2 layer.

    Expands the embedding to ``n_state`` units, applies the GELU activation,
    projects back to the embedding size, and applies residual dropout.
    """

    def __init__(self, n_state, config):  # in MLP: n_state=3072 (4 * n_embd)
        super(MLP, self).__init__()
        embed_dim = config.n_embd
        self.c_fc = Conv1D(n_state, embed_dim)    # expansion projection
        self.c_proj = Conv1D(embed_dim, n_state)  # contraction projection
        self.act = gelu
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, x):
        # expand -> activate -> contract -> dropout, as a single pipeline
        return self.dropout(self.c_proj(self.act(self.c_fc(x))))
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
class Block(nn.Module):
    """One GPT-2 transformer layer: pre-LayerNorm self-attention followed by a
    pre-LayerNorm MLP, each wrapped in a residual connection."""

    def __init__(self, n_ctx, config, scale=False):
        super(Block, self).__init__()
        embed_dim = config.n_embd
        self.ln_1 = nn.LayerNorm(embed_dim, eps=config.layer_norm_epsilon)
        self.attn = Attention(embed_dim, n_ctx, config, scale)
        self.ln_2 = nn.LayerNorm(embed_dim, eps=config.layer_norm_epsilon)
        self.mlp = MLP(4 * embed_dim, config)

    def forward(self, x, layer_past=None, attention_mask=None, head_mask=None):
        # Self-attention sub-layer (pre-norm) with residual connection.
        attn_results = self.attn(self.ln_1(x),
                                 layer_past=layer_past,
                                 attention_mask=attention_mask,
                                 head_mask=head_mask)  # -> [attn_out, present, (attentions)]
        x = x + attn_results[0]
        # Feed-forward sub-layer (pre-norm) with residual connection.
        x = x + self.mlp(self.ln_2(x))
        return [x] + attn_results[1:]  # x, present, (attentions)
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
class GPT2PreTrainedModel(PreTrainedModel):
    """An abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
    """
    config_class = GPT2Config
    pretrained_model_archive_map = GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
    load_tf_weights = load_tf_weights_in_gpt2
    base_model_prefix = "transformer"

    def __init__(self, *inputs, **kwargs):
        super(GPT2PreTrainedModel, self).__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        """Initialize the weights of a single submodule.

        LayerNorm gets unit weight / zero bias; Linear, Embedding and Conv1D
        get a normal(0, initializer_range) weight (and zero bias where present).
        """
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
            # Slightly different from the TF version, which uses truncated_normal
            # for initialization; cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
                module.bias.data.zero_()
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
# Shared documentation fragment prepended to model-class docstrings via
# @add_start_docstrings.
GPT2_START_DOCSTRING = r""" OpenAI GPT-2 model was proposed in
    `Language Models are Unsupervised Multitask Learners`_
    by Alec Radford*, Jeffrey Wu*, Rewon Child, David Luan, Dario Amodei** and Ilya Sutskever**.
    It's a causal (unidirectional) transformer pre-trained using language modeling on a very large
    corpus of ~40 GB of text data.

    This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
    refer to the PyTorch documentation for all matter related to general usage and behavior.

    .. _`Language Models are Unsupervised Multitask Learners`:
        https://openai.com/blog/better-language-models/

    .. _`torch.nn.Module`:
        https://pytorch.org/docs/stable/nn.html#module

    Parameters:
        config (:class:`~pytorch_transformers.GPT2Config`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~pytorch_transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

# Shared documentation fragment describing the inputs accepted by the models.
GPT2_INPUTS_DOCSTRING = r"""    Inputs:
        **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Indices of input sequence tokens in the vocabulary.
            GPT-2 is a model with absolute position embeddings so it's usually advised to pad the inputs on
            the right rather than the left.
            Indices can be obtained using :class:`pytorch_transformers.GPT2Tokenizer`.
            See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
            :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
        **past**:
            list of ``torch.FloatTensor`` (one for each layer):
            that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
            (see `past` output below). Can be used to speed up sequential decoding.
        **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
        **token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            A parallel sequence of tokens (can be used to indicate various portions of the inputs).
            The embeddings from these tokens will be summed with the respective token embeddings.
            Indices are selected in the vocabulary (unlike BERT which has a specific vocabulary for segment indices).
        **position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.
        **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
"""
|
| 324 |
+
|
| 325 |
+
@add_start_docstrings("The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
                      GPT2_START_DOCSTRING, GPT2_INPUTS_DOCSTRING)
class GPT2Model(GPT2PreTrainedModel):
    r"""
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
            Sequence of hidden-states at the last layer of the model.
        **past**:
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            that contains pre-computed hidden-states (key and values in the attention blocks).
            Can be used (see `past` input) to speed up sequential decoding.
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        model = GPT2Model.from_pretrained('gpt2')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids)
        last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple

    """
    def __init__(self, config):
        super(GPT2Model, self).__init__(config)
        self.output_hidden_states = config.output_hidden_states
        self.output_attentions = config.output_attentions

        self.wte = nn.Embedding(config.vocab_size, config.n_embd)   # token embeddings
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)  # absolute position embeddings
        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        # NOTE(review): bare except — any error reading config.latent_size (not
        # just a missing attribute) silently falls back to 32; consider
        # `except AttributeError:` instead.
        try:
            self.latent_size = config.latent_size
        except:
            self.latent_size = 32  # default size is 32

        # Optimus latent projections:
        self.linear = nn.Linear(self.latent_size, config.hidden_size * config.n_layer, bias=False)  # different latent vector for each layer
        self.linear_emb = nn.Linear(self.latent_size, config.hidden_size, bias=False)  # share the same latent vector as the embeddings

        self.config = config
        self.init_weights()

    def _resize_token_embeddings(self, new_num_tokens):
        # Swap in a resized token-embedding matrix (existing rows copied over).
        self.wte = self._get_resized_embeddings(self.wte, new_num_tokens)
        return self.wte

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        """
        for layer, heads in heads_to_prune.items():
            self.h[layer].attn.prune_heads(heads)

    def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, latent_as_gpt_emb=False, latent_as_gpt_memory=True):
        # In this Optimus variant `past` doubles as the VAE latent vector
        # (assumed shape (batch_size, latent_size) — TODO confirm against
        # callers), not the usual per-layer key/value cache of stock GPT-2.
        if past is None:
            past_length = 0
            past = [None] * len(self.h)
            # NOTE(review): if latent_as_gpt_emb=True while past is None, the
            # later `past_emb` reference raises NameError; callers apparently
            # always pass a latent when latent_as_gpt_emb=True — verify.
        else:
            if latent_as_gpt_emb:
                past_emb = self.linear_emb(past)  # used as embeddings to add on other three embeddings

            if latent_as_gpt_memory:
                # Project the latent into per-layer memory slots.
                past = self.linear(past)
                share_latent = False
                if share_latent:
                    # the same latent vector shared by all layers
                    past = [past.unsqueeze(-2), past.unsqueeze(-2)]  # query, key
                    past = [past] * len(self.h)
                    past_length = past[0][0].size(-2)
                else:
                    # different latent vectors for each layer
                    # NOTE(review): relies on config.hidden_size == config.n_embd
                    # (the GPT2Config alias) so the split yields n_layer chunks
                    # of shape (batch, 1, hidden) — confirm.
                    past_split = torch.split(past.unsqueeze(1), self.config.hidden_size, dim=2)
                    past = list(zip(past_split, past_split))

                    # past = past.view(batch_size,len(self.h),-1)
                    # past = [[past[:,i,:].unsqueeze(-2), past[:,i,:].unsqueeze(-2) ] for i in range(len(self.h))]
                    past_length = 1  # past[0][0].size(-2)
            else:
                past_length = 0
                past = [None] * len(self.h)

        if position_ids is None:
            # Positions start after the cached prefix so new tokens continue the sequence.
            position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        # Attention mask.
        if attention_mask is not None:
            # We create a 3D attention mask from a 2D tensor mask.
            # Sizes are [batch_size, 1, 1, to_seq_length]
            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
            # this attention mask is more simple than the triangular masking of causal attention
            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
            # masked positions, this operation will create a tensor which is 0.0 for
            # positions we want to attend and -10000.0 for masked positions.
            # Since we are adding it to the raw scores before the softmax, this is
            # effectively the same as removing these entirely.
            attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
            attention_mask = (1.0 - attention_mask) * -10000.0

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # head_mask has shape n_layer x batch x n_heads x N x N
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to float if needed + fp16 compatibility
        else:
            head_mask = [None] * self.config.n_layer

        # Flatten any extra leading dims so the embeddings see (batch, seq_len).
        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_ids.size(-1))
        position_ids = position_ids.view(-1, position_ids.size(-1))

        inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
            token_type_embeds = self.wte(token_type_ids)  # token types share the word-embedding table
        else:
            token_type_embeds = 0

        hidden_states = inputs_embeds + position_embeds + token_type_embeds
        if latent_as_gpt_emb:
            # pdb.set_trace()
            # Add the projected latent to every position's embedding.
            hidden_states = hidden_states + past_emb.unsqueeze(1)

        hidden_states = self.drop(hidden_states)

        output_shape = input_shape + (hidden_states.size(-1),)

        presents = ()
        all_attentions = []
        all_hidden_states = ()
        for i, (block, layer_past) in enumerate(zip(self.h, past)):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)

            outputs = block(hidden_states,
                            layer_past=layer_past,
                            attention_mask=attention_mask,
                            head_mask=head_mask[i])

            hidden_states, present = outputs[:2]
            presents = presents + (present,)

            if self.output_attentions:
                all_attentions.append(outputs[2])

        hidden_states = self.ln_f(hidden_states)  # final layer norm

        hidden_states = hidden_states.view(*output_shape)
        # Add last hidden state
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states, presents)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            # let the number of heads free (-1) so we can extract attention even after head pruning
            attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
            all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions)
            outputs = outputs + (all_attentions,)
        return outputs  # last hidden state, presents, (all hidden_states), (attentions)
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
@add_start_docstrings("""The GPT2 Model transformer with a language modeling head on top
(linear layer with weights tied to the input embeddings). """, GPT2_START_DOCSTRING, GPT2_INPUTS_DOCSTRING)
class GPT2LMHeadModel(GPT2PreTrainedModel):
    r"""GPT-2 transformer with a language-modeling head tied to the input embeddings.

    **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
        Labels for language modeling. The labels **are shifted** inside the model,
        i.e. you can set ``labels = input_ids``. Tokens whose label equals
        ``label_ignore`` are excluded from the loss.
    **label_ignore**: (`optional`) int passed to ``CrossEntropyLoss(ignore_index=...)``
        (per the original authors, e.g. 50258 for the padding id, otherwise -1).

    Outputs: `Tuple` comprising various elements depending on the configuration and inputs:
        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor``
            of shape ``(batch_size,)``: per-sequence sum of token-level LM losses.
        **prediction_scores**: ``torch.FloatTensor`` of shape
            ``(batch_size, sequence_length, config.vocab_size)``: pre-softmax LM logits.
        **past**: list of ``torch.FloatTensor`` (one per layer) holding pre-computed
            key/value states; can be fed back as `past` to speed up sequential decoding.
        **hidden_states** / **attentions**: optional, returned when
            ``config.output_hidden_states`` / ``config.output_attentions`` is set.

    Examples::

        import torch
        from pytorch_transformers import GPT2Tokenizer, GPT2LMHeadModel

        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        model = GPT2LMHeadModel.from_pretrained('gpt2')

        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, labels=input_ids)
        loss, logits = outputs[:2]

    """
    def __init__(self, config):
        super(GPT2LMHeadModel, self).__init__(config)
        self.transformer = GPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.init_weights()
        self.tie_weights()

    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        self._tie_or_clone_weights(self.lm_head,
                                   self.transformer.wte)

    def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                labels=None, label_ignore=None):
        transformer_outputs = self.transformer(input_ids,
                                               past=past,
                                               attention_mask=attention_mask,
                                               token_type_ids=token_type_ids,
                                               position_ids=position_ids,
                                               head_mask=head_mask)
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)

        outputs = (lm_logits,) + transformer_outputs[1:]
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens. `reduction='none'` replaces the deprecated
            # `reduce=False` kwarg (same behavior: keep per-token losses so they
            # can be summed per sequence below). 50258 is the padding id,
            # otherwise -1 is used for masked LM.
            loss_fct = CrossEntropyLoss(ignore_index=label_ignore, reduction='none')
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                            shift_labels.view(-1))
            # Sum token losses over the sequence dimension -> one loss per sequence.
            loss = torch.sum(loss.view(-1, shift_labels.shape[-1]), -1)
            outputs = (loss,) + outputs

        return outputs  # (loss), lm_logits, presents, (all hidden_states), (attentions)
|
| 599 |
+
|
| 600 |
+
|
| 601 |
+
|
| 602 |
+
@add_start_docstrings("""The GPT2 Model transformer with a language modeling head on top
(linear layer with weights tied to the input embeddings). """, GPT2_START_DOCSTRING, GPT2_INPUTS_DOCSTRING)
class GPT2ForLatentConnector(GPT2PreTrainedModel):
    r"""GPT-2 LM-head decoder that can additionally consume a latent vector.

    The latent is injected through the underlying ``GPT2Model`` via the
    ``latent_as_gpt_emb`` (latent added to the token/position embeddings) and/or
    ``latent_as_gpt_memory`` (latent used as per-layer key/value "past" memory)
    flags stored at construction time.

    **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
        Labels for language modeling; shifted inside the model, so
        ``labels = input_ids`` is valid. Tokens equal to ``label_ignore`` are
        excluded from the loss.

    Outputs: `Tuple`:
        **loss**: (`optional`, when ``labels`` is given) ``(batch_size,)`` tensor —
            per-sequence sum of token-level losses.
        **prediction_scores**: ``(batch_size, sequence_length, config.vocab_size)``
            pre-softmax LM logits.
        **past** / **hidden_states** / **attentions**: as in ``GPT2LMHeadModel``.

    Examples::

        import torch
        from pytorch_transformers import GPT2Tokenizer, GPT2LMHeadModel

        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        model = GPT2LMHeadModel.from_pretrained('gpt2')

        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, labels=input_ids)
        loss, logits = outputs[:2]

    """
    def __init__(self, config, latent_size=32, latent_as_gpt_emb=True, latent_as_gpt_memory=True):
        # NOTE(review): `latent_size` is accepted for API compatibility but is not
        # read here — the transformer reads `config.latent_size` instead; confirm
        # callers keep the two in sync.
        super(GPT2ForLatentConnector, self).__init__(config)

        self.transformer = GPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.init_weights()
        self.tie_weights()

        self.latent_as_gpt_emb = latent_as_gpt_emb
        self.latent_as_gpt_memory = latent_as_gpt_memory

    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        self._tie_or_clone_weights(self.lm_head,
                                   self.transformer.wte)

    def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                labels=None, label_ignore=None):
        # `past` carries the latent vector; how it is consumed is controlled by
        # the latent_as_gpt_* flags captured at construction time.
        transformer_outputs = self.transformer(input_ids,
                                               past=past,
                                               attention_mask=attention_mask,
                                               token_type_ids=token_type_ids,
                                               position_ids=position_ids,
                                               head_mask=head_mask,
                                               latent_as_gpt_emb=self.latent_as_gpt_emb,
                                               latent_as_gpt_memory=self.latent_as_gpt_memory)
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)

        outputs = (lm_logits,) + transformer_outputs[1:]
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens. `reduction='none'` replaces the deprecated
            # `reduce=False` kwarg (identical behavior). 50258 is the padding id,
            # otherwise -1 is used for masked LM.
            loss_fct = CrossEntropyLoss(ignore_index=label_ignore, reduction='none')
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                            shift_labels.view(-1))
            # Sum token losses over the sequence dimension -> one loss per sequence.
            loss = torch.sum(loss.view(-1, shift_labels.shape[-1]), -1)
            outputs = (loss,) + outputs

        return outputs  # (loss), lm_logits, presents, (all hidden_states), (attentions)
|
| 696 |
+
|
| 697 |
+
@add_start_docstrings("""The GPT2 Model transformer with a language modeling and a multiple-choice classification
head on top e.g. for RocStories/SWAG tasks. The two heads are two linear layers.
The language modeling head has its weights tied to the input embeddings,
the classification head takes as input the input of a specified classification token index in the input sequence).
""", GPT2_START_DOCSTRING, GPT2_INPUTS_DOCSTRING)
class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
    r"""GPT-2 with a tied LM head plus a multiple-choice classification head.

    **mc_token_ids**: (`optional`, default to index of the last token of the input)
        ``torch.LongTensor`` of shape ``(batch_size, num_choices)``:
        Index of the classification token in each input sequence,
        selected in the range ``[0, input_ids.size(-1) - 1[``.
    **lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
        Labels for language modeling; shifted inside the model (``lm_labels = input_ids``
        is valid). Labels set to ``-1`` are ignored (masked).
    **mc_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size)``:
        Labels for the multiple-choice classification loss, in ``[0, ..., num_choices]``.

    Outputs: `Tuple`:
        **lm_loss** (when ``lm_labels`` given), **mc_loss** (when ``mc_labels`` given),
        **lm_prediction_scores** ``(batch_size, num_choices, sequence_length, config.vocab_size)``,
        **mc_prediction_scores** ``(batch_size, num_choices)``,
        **past**, and optionally **hidden_states** / **attentions** depending on the config.

    Examples::

        import torch
        from pytorch_transformers import GPT2Tokenizer, GPT2DoubleHeadsModel

        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        model = GPT2DoubleHeadsModel.from_pretrained('gpt2')

        # Add a [CLS] to the vocabulary (we should train it also!)
        tokenizer.add_special_tokens({'cls_token': '[CLS]'})
        model.resize_token_embeddings(len(tokenizer))  # Update the model embeddings with the new vocabulary size

        choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
        encoded_choices = [tokenizer.encode(s) for s in choices]
        cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices]

        input_ids = torch.tensor(encoded_choices).unsqueeze(0)  # Batch size: 1, number of choices: 2
        mc_token_ids = torch.tensor([cls_token_location])  # Batch size: 1

        outputs = model(input_ids, mc_token_ids=mc_token_ids)
        lm_prediction_scores, mc_prediction_scores = outputs[:2]

    """
    def __init__(self, config):
        super(GPT2DoubleHeadsModel, self).__init__(config)
        self.transformer = GPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Pools the hidden state at `mc_token_ids` into one score per choice.
        self.multiple_choice_head = SequenceSummary(config)

        self.init_weights()
        self.tie_weights()

    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        self._tie_or_clone_weights(self.lm_head,
                                   self.transformer.wte)

    def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                mc_token_ids=None, lm_labels=None, mc_labels=None):
        transformer_outputs = self.transformer(input_ids,
                                               past=past,
                                               attention_mask=attention_mask,
                                               token_type_ids=token_type_ids,
                                               position_ids=position_ids,
                                               head_mask=head_mask)

        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)
        mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)

        outputs = (lm_logits, mc_logits) + transformer_outputs[1:]
        if mc_labels is not None:
            # Multiple-choice loss: one softmax over the num_choices scores.
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)),
                            mc_labels.view(-1))
            outputs = (loss,) + outputs
        if lm_labels is not None:
            # Shift so that tokens < n predict n; -1 labels are masked.
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = lm_labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                            shift_labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (lm loss), (mc loss), lm logits, mc logits, presents, (all hidden_states), (attentions)
|
| 808 |
+
|
| 809 |
+
############
|
| 810 |
+
# XX Added #
|
| 811 |
+
############
|
| 812 |
+
|
| 813 |
+
class GPT2Model_XX(nn.Module):
    """Bare GPT-2 transformer (Optimus variant) that can inject a latent vector.

    Unlike the stock ``GPT2Model``, ``forward`` accepts the latent through the
    ``past`` argument and two flags: ``latent_as_gpt_emb`` adds a projection of
    the latent to the input embeddings, ``latent_as_gpt_memory`` turns it into
    per-layer key/value memory.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.output_hidden_states = config.output_hidden_states
        self.output_attentions = config.output_attentions

        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        # FIX: the original used a bare `except:`; only a missing config
        # attribute should trigger the fallback.
        try:
            self.latent_size = config.latent_size
        except AttributeError:
            self.latent_size = 32  # default size is 32

        self.linear = nn.Linear(self.latent_size, config.hidden_size * config.n_layer, bias=False)  # different latent vector for each layer
        self.linear_emb = nn.Linear(self.latent_size, config.hidden_size, bias=False)  # share the same latent vector as the embeddings

        self.config = config
        self.init_weights()

    def init_weights(self):
        """ Initialize and prunes weights if needed. """
        # Initialize weights
        self.apply(self._init_weights)

        # Prune heads if needed.
        # FIX: the original called `self.prune_heads(...)`, a method this class
        # never defines (the helper is named `_prune_heads`), which raised
        # AttributeError whenever config.pruned_heads was non-empty.
        if self.config.pruned_heads:
            self._prune_heads(self.config.pruned_heads)

    def _init_weights(self, module):
        """ Initialize the weights (normal init for linear/embedding/Conv1D, unit LayerNorm). """
        if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def _resize_token_embeddings(self, new_num_tokens):
        self.wte = self._get_resized_embeddings(self.wte, new_num_tokens)
        return self.wte

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        """
        for layer, heads in heads_to_prune.items():
            self.h[layer].attn.prune_heads(heads)

    def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, latent_as_gpt_emb=False, latent_as_gpt_memory=True):
        if past is None:
            past_length = 0
            past = [None] * len(self.h)
        else:
            # Here `past` is the latent vector, not cached key/values.
            if latent_as_gpt_emb:
                past_emb = self.linear_emb(past)  # used as embeddings to add on other three embeddings

            if latent_as_gpt_memory:
                past = self.linear(past)
                share_latent = False
                if share_latent:
                    # the same latent vector shared by all layers
                    past = [past.unsqueeze(-2), past.unsqueeze(-2)]  # query, key
                    past = [past] * len(self.h)
                    past_length = past[0][0].size(-2)
                else:
                    # different latent vectors for each layer
                    past_split = torch.split(past.unsqueeze(1), self.config.hidden_size, dim=2)
                    past = list(zip(past_split, past_split))
                    past_length = 1  # past[0][0].size(-2)
            else:
                past_length = 0
                past = [None] * len(self.h)

        if position_ids is None:
            position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        # Attention mask.
        if attention_mask is not None:
            # We create a 3D attention mask from a 2D tensor mask.
            # Sizes are [batch_size, 1, 1, to_seq_length] so we can broadcast to
            # [batch_size, num_heads, from_seq_length, to_seq_length].
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

            # attention_mask is 1.0 for positions to attend and 0.0 for masked
            # positions; convert to 0.0 / -10000.0 additive bias (pre-softmax),
            # which effectively removes masked positions.
            attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
            attention_mask = (1.0 - attention_mask) * -10000.0

        # Prepare head mask if needed: 1.0 means keep the head.
        # attention_probs has shape bsz x n_heads x N x N
        # head_mask has shape n_layer x batch x n_heads x N x N
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to float if need + fp16 compatibility
        else:
            head_mask = [None] * self.config.n_layer

        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_ids.size(-1))
        position_ids = position_ids.view(-1, position_ids.size(-1))

        inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
            token_type_embeds = self.wte(token_type_ids)
        else:
            token_type_embeds = 0

        hidden_states = inputs_embeds + position_embeds + token_type_embeds
        if latent_as_gpt_emb:
            # Broadcast the latent projection across the sequence dimension.
            hidden_states = hidden_states + past_emb.unsqueeze(1)

        hidden_states = self.drop(hidden_states)

        output_shape = input_shape + (hidden_states.size(-1),)

        presents = ()
        all_attentions = []
        all_hidden_states = ()
        for i, (block, layer_past) in enumerate(zip(self.h, past)):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)

            outputs = block(hidden_states,
                            layer_past=layer_past,
                            attention_mask=attention_mask,
                            head_mask=head_mask[i])

            hidden_states, present = outputs[:2]
            presents = presents + (present,)

            if self.output_attentions:
                all_attentions.append(outputs[2])

        hidden_states = self.ln_f(hidden_states)

        hidden_states = hidden_states.view(*output_shape)
        # Add last hidden state
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states, presents)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            # let the number of heads free (-1) so we can extract attention even after head pruning
            attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
            all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions)
            outputs = outputs + (all_attentions,)
        return outputs  # last hidden state, presents, (all hidden_states), (attentions)

    def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
        """ Build a resized Embedding Module from a provided token Embedding Module.
            Increasing the size will add newly initialized vectors at the end;
            reducing the size will remove vectors from the end.

            Args:
                new_num_tokens: (`optional`) int
                    New number of tokens in the embedding matrix.
                    If not provided or None: return the provided token Embedding Module.
            Return: ``torch.nn.Embeddings``
                Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None
        """
        if new_num_tokens is None:
            return old_embeddings
        old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
        if old_num_tokens == new_num_tokens:
            return old_embeddings
        # Build new embeddings
        new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
        new_embeddings.to(old_embeddings.weight.device)
        # initialize all new embeddings (in particular added tokens)
        self._init_weights(new_embeddings)
        # Copy word embeddings from the previous weights
        num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
        new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]
        return new_embeddings
|
| 1024 |
+
|
| 1025 |
+
class GPT2ForLatentConnector_XX(nn.Module):
|
| 1026 |
+
def __init__(self,
|
| 1027 |
+
config,
|
| 1028 |
+
latent_size=32,
|
| 1029 |
+
latent_as_gpt_emb=True,
|
| 1030 |
+
latent_as_gpt_memory=True):
|
| 1031 |
+
|
| 1032 |
+
super().__init__()
|
| 1033 |
+
self.config = config
|
| 1034 |
+
self.transformer = GPT2Model_XX(config)
|
| 1035 |
+
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
| 1036 |
+
self.init_weights()
|
| 1037 |
+
self.tie_weights()
|
| 1038 |
+
self.latent_as_gpt_emb = latent_as_gpt_emb
|
| 1039 |
+
self.latent_as_gpt_memory = latent_as_gpt_memory
|
| 1040 |
+
|
| 1041 |
+
def init_weights(self):
|
| 1042 |
+
""" Initialize and prunes weights if needed. """
|
| 1043 |
+
# Initialize weights
|
| 1044 |
+
self.apply(self._init_weights)
|
| 1045 |
+
|
| 1046 |
+
# Prune heads if needed
|
| 1047 |
+
if self.config.pruned_heads:
|
| 1048 |
+
self.prune_heads(self.config.pruned_heads)
|
| 1049 |
+
|
| 1050 |
+
def _init_weights(self, module):
|
| 1051 |
+
""" Initialize the weights.
|
| 1052 |
+
"""
|
| 1053 |
+
if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
|
| 1054 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
| 1055 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
| 1056 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 1057 |
+
if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
|
| 1058 |
+
module.bias.data.zero_()
|
| 1059 |
+
elif isinstance(module, nn.LayerNorm):
|
| 1060 |
+
module.bias.data.zero_()
|
| 1061 |
+
module.weight.data.fill_(1.0)
|
| 1062 |
+
|
| 1063 |
+
def _tie_or_clone_weights(self, first_module, second_module):
|
| 1064 |
+
""" Tie or clone module weights depending of weither we are using TorchScript or not
|
| 1065 |
+
"""
|
| 1066 |
+
if self.config.torchscript:
|
| 1067 |
+
first_module.weight = nn.Parameter(second_module.weight.clone())
|
| 1068 |
+
else:
|
| 1069 |
+
first_module.weight = second_module.weight
|
| 1070 |
+
|
| 1071 |
+
if hasattr(first_module, 'bias') and first_module.bias is not None:
|
| 1072 |
+
first_module.bias.data = torch.nn.functional.pad(
|
| 1073 |
+
first_module.bias.data,
|
| 1074 |
+
(0, first_module.weight.shape[0] - first_module.bias.shape[0]),
|
| 1075 |
+
'constant', 0,)
|
| 1076 |
+
|
| 1077 |
+
def tie_weights(self):
|
| 1078 |
+
""" Make sure we are sharing the input and output embeddings.
|
| 1079 |
+
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
|
| 1080 |
+
"""
|
| 1081 |
+
self._tie_or_clone_weights(self.lm_head,
|
| 1082 |
+
self.transformer.wte)
|
| 1083 |
+
|
| 1084 |
+
def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
|
| 1085 |
+
labels=None, label_ignore=None):
|
| 1086 |
+
|
| 1087 |
+
|
| 1088 |
+
transformer_outputs = self.transformer(input_ids,
|
| 1089 |
+
past=past,
|
| 1090 |
+
attention_mask=attention_mask,
|
| 1091 |
+
token_type_ids=token_type_ids,
|
| 1092 |
+
position_ids=position_ids,
|
| 1093 |
+
head_mask=head_mask,
|
| 1094 |
+
latent_as_gpt_emb=self.latent_as_gpt_emb,
|
| 1095 |
+
latent_as_gpt_memory=self.latent_as_gpt_memory)
|
| 1096 |
+
hidden_states = transformer_outputs[0]
|
| 1097 |
+
|
| 1098 |
+
lm_logits = self.lm_head(hidden_states)
|
| 1099 |
+
|
| 1100 |
+
outputs = (lm_logits,) + transformer_outputs[1:]
|
| 1101 |
+
if labels is not None:
|
| 1102 |
+
# Shift so that tokens < n predict n
|
| 1103 |
+
shift_logits = lm_logits[..., :-1, :].contiguous()
|
| 1104 |
+
shift_labels = labels[..., 1:].contiguous()
|
| 1105 |
+
# Flatten the tokens
|
| 1106 |
+
loss_fct = CrossEntropyLoss(ignore_index=label_ignore, reduce=False) # 50258 is the padding id, otherwise -1 is used for masked LM.
|
| 1107 |
+
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
|
| 1108 |
+
shift_labels.view(-1))
|
| 1109 |
+
loss = torch.sum(loss.view(-1, shift_labels.shape[-1]), -1)
|
| 1110 |
+
outputs = (loss,) + outputs
|
| 1111 |
+
|
| 1112 |
+
return outputs # (loss), lm_logits, presents, (all hidden_states), (attentions)
|
| 1113 |
+
|
| 1114 |
+
def resize_token_embeddings(self, new_num_tokens=None):
    """Resize the transformer's input token embedding matrix.

    With ``new_num_tokens=None`` this is a pure getter for the embedding
    module; otherwise the vocab-size bookkeeping is updated and the output
    head is re-tied to the resized input embeddings.
    """
    model_embeds = self.transformer._resize_token_embeddings(new_num_tokens)
    if new_num_tokens is None:
        return model_embeds

    # Keep config and transformer bookkeeping consistent with the new size.
    self.config.vocab_size = new_num_tokens
    self.transformer.vocab_size = new_num_tokens

    # Re-tie output weights so lm_head stays in sync with the resized wte.
    if hasattr(self, 'tie_weights'):
        self.tie_weights()

    return model_embeds
versatile_diffusion/lib/model_zoo/optimus_models/tokenization_bert.py
ADDED
|
@@ -0,0 +1,457 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Tokenization classes."""
|
| 16 |
+
|
| 17 |
+
from __future__ import absolute_import, division, print_function, unicode_literals
|
| 18 |
+
|
| 19 |
+
import collections
|
| 20 |
+
import logging
|
| 21 |
+
import os
|
| 22 |
+
import unicodedata
|
| 23 |
+
from io import open
|
| 24 |
+
|
| 25 |
+
from .tokenization_utils import PreTrainedTokenizer
|
| 26 |
+
|
| 27 |
+
logger = logging.getLogger(__name__)
|
| 28 |
+
|
| 29 |
+
VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'}

# Canonical list of stock BERT checkpoints; the three lookup tables below are
# all derived from it, in this order.
_BERT_MODELS = [
    'bert-base-uncased',
    'bert-large-uncased',
    'bert-base-cased',
    'bert-large-cased',
    'bert-base-multilingual-uncased',
    'bert-base-multilingual-cased',
    'bert-base-chinese',
    'bert-base-german-cased',
    'bert-large-uncased-whole-word-masking',
    'bert-large-cased-whole-word-masking',
    'bert-large-uncased-whole-word-masking-finetuned-squad',
    'bert-large-cased-whole-word-masking-finetuned-squad',
    'bert-base-cased-finetuned-mrpc',
]

# Every vocab lives on the HuggingFace S3 bucket, except the German model,
# which is hosted on deepset's bucket.
PRETRAINED_VOCAB_FILES_MAP = {
    'vocab_file': {
        name: ("https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/{}-vocab.txt".format(name)
               if name == 'bert-base-german-cased'
               else "https://s3.amazonaws.com/models.huggingface.co/bert/{}-vocab.txt".format(name))
        for name in _BERT_MODELS
    }
}

# All stock BERT models share a 512-position embedding table.
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {name: 512 for name in _BERT_MODELS}

# Lower-casing is enabled exactly for the 'uncased' checkpoints.
PRETRAINED_INIT_CONFIGURATION = {name: {'do_lower_case': 'uncased' in name}
                                 for name in _BERT_MODELS}
|
| 82 |
+
|
| 83 |
+
def load_vocab(vocab_file):
    """Load a one-token-per-line vocabulary file into an ordered token->id mapping."""
    with open(vocab_file, "r", encoding="utf-8") as reader:
        # Stream the file; line number == token id.
        return collections.OrderedDict(
            (line.rstrip('\n'), index) for index, line in enumerate(reader))
| 93 |
+
|
| 94 |
+
def whitespace_tokenize(text):
    """Split *text* on runs of whitespace; returns [] for empty or blank input."""
    # str.split() with no separator already strips leading/trailing whitespace
    # and yields an empty list for empty/whitespace-only strings, so the
    # explicit strip-and-check is unnecessary.
    return text.split()
|
| 103 |
+
class BertTokenizer(PreTrainedTokenizer):
    r"""End-to-end BERT tokenizer: basic (whitespace/punctuation) splitting
    followed by WordPiece sub-word tokenization.

    Args:
        vocab_file: Path to a one-wordpiece-per-line vocabulary file.
        do_lower_case: Whether to lower case the input; only effective when
            ``do_basic_tokenize=True``.
        do_basic_tokenize: Whether to run basic tokenization before WordPiece.
        never_split: Tokens which will never be split during basic tokenization.
        tokenize_chinese_chars: Whether to surround CJK characters with spaces;
            should likely be deactivated for Japanese
            (see https://github.com/huggingface/pytorch-pretrained-BERT/issues/328).
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

    def __init__(self, vocab_file, do_lower_case=True, do_basic_tokenize=True, never_split=None,
                 unk_token="[UNK]", sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]",
                 mask_token="[MASK]", tokenize_chinese_chars=True, **kwargs):
        """Construct a BertTokenizer; see the class docstring for argument semantics."""
        super(BertTokenizer, self).__init__(unk_token=unk_token, sep_token=sep_token,
                                            pad_token=pad_token, cls_token=cls_token,
                                            mask_token=mask_token, **kwargs)
        # Reserve room for the special tokens added around single sequences
        # ([CLS] X [SEP]) and pairs ([CLS] A [SEP] B [SEP]).
        self.max_len_single_sentence = self.max_len - 2
        self.max_len_sentences_pair = self.max_len - 3

        if not os.path.isfile(vocab_file):
            raise ValueError(
                "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
                "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file))
        self.vocab = load_vocab(vocab_file)
        # Reverse mapping id -> token, preserving vocabulary order.
        self.ids_to_tokens = collections.OrderedDict(
            (ids, tok) for tok, ids in self.vocab.items())
        self.do_basic_tokenize = do_basic_tokenize
        if do_basic_tokenize:
            self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
                                                  never_split=never_split,
                                                  tokenize_chinese_chars=tokenize_chinese_chars)
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)

    @property
    def vocab_size(self):
        """Number of entries in the loaded vocabulary."""
        return len(self.vocab)

    def _tokenize(self, text):
        """Tokenize *text* into wordpieces, optionally after basic tokenization."""
        if not self.do_basic_tokenize:
            return self.wordpiece_tokenizer.tokenize(text)
        # Special tokens are protected from basic splitting.
        return [piece
                for word in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens)
                for piece in self.wordpiece_tokenizer.tokenize(word)]

    def _convert_token_to_id(self, token):
        """Token (str) -> vocabulary id; unknown tokens map to the unk id."""
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Vocabulary id (int) -> token (str); unknown ids map to the unk token."""
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """Join wordpieces back into plain text (drops the '##' continuation markers)."""
        return ' '.join(tokens).replace(' ##', '').strip()

    def add_special_tokens_single_sentence(self, token_ids):
        """Wrap a single sequence in BERT's [CLS] X [SEP] format."""
        return [self.cls_token_id] + token_ids + [self.sep_token_id]

    def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1):
        """Combine a sequence pair into BERT's [CLS] A [SEP] B [SEP] format."""
        sep = [self.sep_token_id]
        return [self.cls_token_id] + token_ids_0 + sep + token_ids_1 + sep

    def save_vocabulary(self, vocab_path):
        """Write the vocabulary (one token per line, in id order) to a directory or file.

        Returns a 1-tuple with the written file path.
        """
        if os.path.isdir(vocab_path):
            vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file'])
        else:
            vocab_file = vocab_path
        with open(vocab_file, "w", encoding="utf-8") as writer:
            expected_index = 0
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                # Warn (once per gap) if ids are not consecutive; the file
                # format cannot represent holes in the id space.
                if expected_index != token_index:
                    logger.warning("Saving vocabulary to {}: vocabulary indices are not consecutive."
                                   " Please check that the vocabulary is not corrupted!".format(vocab_file))
                    expected_index = token_index
                writer.write(token + u'\n')
                expected_index += 1
        return (vocab_file,)
+
|
| 224 |
+
class BasicTokenizer(object):
    """Runs basic tokenization: text cleaning, optional lower-casing and
    accent stripping, CJK spacing, and punctuation splitting."""

    def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True):
        """ Constructs a BasicTokenizer.

        Args:
            **do_lower_case**: Whether to lower case the input.
            **never_split**: (`optional`) list of str
                Kept for backward compatibility purposes.
                Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`)
                List of token not to split.
            **tokenize_chinese_chars**: (`optional`) boolean (default True)
                Whether to tokenize Chinese characters.
                This should likely be deactivated for Japanese:
                see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328
        """
        if never_split is None:
            never_split = []
        self.do_lower_case = do_lower_case
        self.never_split = never_split
        self.tokenize_chinese_chars = tokenize_chinese_chars

    def tokenize(self, text, never_split=None):
        """ Basic Tokenization of a piece of text.
            Split on "white spaces" only, for sub-word tokenization, see WordPieceTokenizer.

        Args:
            **never_split**: (`optional`) list of str
                Extra tokens not to split, merged with the constructor-provided list.
        """
        never_split = self.never_split + (never_split if never_split is not None else [])
        text = self._clean_text(text)
        # This was added on November 1st, 2018 for the multilingual and Chinese
        # models. This is also applied to the English models now, but it doesn't
        # matter since the English models were not trained on any Chinese data
        # and generally don't have any Chinese data in them (there are Chinese
        # characters in the vocabulary because Wikipedia does have some Chinese
        # words in the English Wikipedia.).
        if self.tokenize_chinese_chars:
            text = self._tokenize_chinese_chars(text)
        orig_tokens = whitespace_tokenize(text)
        split_tokens = []
        for token in orig_tokens:
            if self.do_lower_case and token not in never_split:
                token = token.lower()
                token = self._run_strip_accents(token)
            # BUGFIX: forward never_split so that protected tokens such as
            # "[UNK]" are not broken apart at their punctuation characters
            # (the original call dropped the list entirely).
            split_tokens.extend(self._run_split_on_punc(token, never_split))

        output_tokens = whitespace_tokenize(" ".join(split_tokens))
        return output_tokens

    def _run_strip_accents(self, text):
        """Strips accents from a piece of text (drops Unicode 'Mn' combining marks
        after NFD normalization)."""
        text = unicodedata.normalize("NFD", text)
        output = []
        for char in text:
            cat = unicodedata.category(char)
            if cat == "Mn":
                continue
            output.append(char)
        return "".join(output)

    def _run_split_on_punc(self, text, never_split=None):
        """Splits punctuation on a piece of text; tokens in never_split pass
        through intact."""
        if never_split is not None and text in never_split:
            return [text]
        chars = list(text)
        i = 0
        start_new_word = True
        output = []
        while i < len(chars):
            char = chars[i]
            if _is_punctuation(char):
                # Each punctuation character becomes its own token.
                output.append([char])
                start_new_word = True
            else:
                if start_new_word:
                    output.append([])
                start_new_word = False
                output[-1].append(char)
            i += 1

        return ["".join(x) for x in output]

    def _tokenize_chinese_chars(self, text):
        """Adds whitespace around any CJK character."""
        output = []
        for char in text:
            cp = ord(char)
            if self._is_chinese_char(cp):
                output.append(" ")
                output.append(char)
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and handled
        # like the all of the other languages.
        if ((cp >= 0x4E00 and cp <= 0x9FFF) or  #
                (cp >= 0x3400 and cp <= 0x4DBF) or  #
                (cp >= 0x20000 and cp <= 0x2A6DF) or  #
                (cp >= 0x2A700 and cp <= 0x2B73F) or  #
                (cp >= 0x2B740 and cp <= 0x2B81F) or  #
                (cp >= 0x2B820 and cp <= 0x2CEAF) or
                (cp >= 0xF900 and cp <= 0xFAFF) or  #
                (cp >= 0x2F800 and cp <= 0x2FA1F)):  #
            return True

        return False

    def _clean_text(self, text):
        """Performs invalid character removal and whitespace cleanup on text."""
        output = []
        for char in text:
            cp = ord(char)
            # Drop NUL, the Unicode replacement character, and control chars.
            if cp == 0 or cp == 0xfffd or _is_control(char):
                continue
            if _is_whitespace(char):
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)
|
| 360 |
+
class WordpieceTokenizer(object):
    """Greedy longest-match-first WordPiece tokenizer."""

    def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
        self.vocab = vocab              # token -> id mapping; membership only is used here
        self.unk_token = unk_token      # emitted for undecomposable / over-long words
        self.max_input_chars_per_word = max_input_chars_per_word

    def tokenize(self, text):
        """Tokenize a piece of text into its word pieces.

        Uses a greedy longest-match-first algorithm against the vocabulary.
        The input should already have been passed through `BasicTokenizer`.

        For example:
            input = "unaffable"
            output = ["un", "##aff", "##able"]

        Returns:
            A list of wordpiece tokens; words that cannot be decomposed (or
            exceed max_input_chars_per_word) become the unk token.
        """
        output_tokens = []
        for word in whitespace_tokenize(text):
            chars = list(word)
            if len(chars) > self.max_input_chars_per_word:
                output_tokens.append(self.unk_token)
                continue
            pieces = self._greedy_pieces(chars)
            if pieces is None:
                output_tokens.append(self.unk_token)
            else:
                output_tokens.extend(pieces)
        return output_tokens

    def _greedy_pieces(self, chars):
        """Greedy longest-match decomposition of one word; None if impossible."""
        pieces = []
        start = 0
        while start < len(chars):
            end = len(chars)
            match = None
            # Shrink the candidate from the right until it is in the vocab.
            while start < end:
                candidate = "".join(chars[start:end])
                if start > 0:
                    candidate = "##" + candidate  # continuation-piece marker
                if candidate in self.vocab:
                    match = candidate
                    break
                end -= 1
            if match is None:
                return None
            pieces.append(match)
            start = end
        return pieces
| 419 |
+
|
| 420 |
+
def _is_whitespace(char):
|
| 421 |
+
"""Checks whether `chars` is a whitespace character."""
|
| 422 |
+
# \t, \n, and \r are technically contorl characters but we treat them
|
| 423 |
+
# as whitespace since they are generally considered as such.
|
| 424 |
+
if char == " " or char == "\t" or char == "\n" or char == "\r":
|
| 425 |
+
return True
|
| 426 |
+
cat = unicodedata.category(char)
|
| 427 |
+
if cat == "Zs":
|
| 428 |
+
return True
|
| 429 |
+
return False
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
def _is_control(char):
|
| 433 |
+
"""Checks whether `chars` is a control character."""
|
| 434 |
+
# These are technically control characters but we count them as whitespace
|
| 435 |
+
# characters.
|
| 436 |
+
if char == "\t" or char == "\n" or char == "\r":
|
| 437 |
+
return False
|
| 438 |
+
cat = unicodedata.category(char)
|
| 439 |
+
if cat.startswith("C"):
|
| 440 |
+
return True
|
| 441 |
+
return False
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
def _is_punctuation(char):
|
| 445 |
+
"""Checks whether `chars` is a punctuation character."""
|
| 446 |
+
cp = ord(char)
|
| 447 |
+
# We treat all non-letter/number ASCII as punctuation.
|
| 448 |
+
# Characters such as "^", "$", and "`" are not in the Unicode
|
| 449 |
+
# Punctuation class but we treat them as punctuation anyways, for
|
| 450 |
+
# consistency.
|
| 451 |
+
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
|
| 452 |
+
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
|
| 453 |
+
return True
|
| 454 |
+
cat = unicodedata.category(char)
|
| 455 |
+
if cat.startswith("P"):
|
| 456 |
+
return True
|
| 457 |
+
return False
|