import torch
from torch import nn
import torch.nn.functional as F
import math


class PositionEncodingSine2D(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used in the "Attention Is All You Need" paper, generalized to work on images.
    """
    def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None):
        super(PositionEncodingSine2D, self).__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, x, isTarget=False):
        '''
        input x: B, C, H, W
        return pos: B, 2 * num_pos_feats, H, W
        (isTarget is unused; kept for interface compatibility)
        '''
        # every position is valid: the "not masked" indicator is all ones
        not_mask = torch.ones(x.size(0), x.size(2), x.size(3), device=x.device)
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)

        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        # frequencies come in sin/cos pairs: temperature ** (2 * floor(i / 2) / num_pos_feats)
        dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode='floor') / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos
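
# A minimal shape sketch: the channel count of the returned encoding is
# 2 * num_pos_feats, independent of the channel count of the input, e.g.
#
#   pe = PositionEncodingSine2D(num_pos_feats=64)
#   pos = pe(torch.zeros(2, 256, 10, 10))  # -> shape (2, 128, 10, 10)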
class EncoderLayerInnerAttention(nn.Module):
    """
    Encoder layer that restricts attention to positions within the same image
    (inner attention). The x features first cross-attend to the mask features.
    """
    def __init__(self, d_model, nhead, dim_feedforward, dropout, activation, pos_weight, feat_weight):
        super(EncoderLayerInnerAttention, self).__init__()

        self.pos_weight = pos_weight
        self.feat_weight = feat_weight
        self.inner_encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead,
                                                              dim_feedforward=dim_feedforward,
                                                              dropout=dropout, activation=activation)
        self.posEncoder = PositionEncodingSine2D(d_model // 2)

        self.cross_encoder_layer = EncoderLayerCrossAttention(d_model=d_model, nhead=nhead,
                                                              dim_feedforward=dim_feedforward,
                                                              dropout=dropout, activation=activation)
    def forward(self, x, y, featmask, x_mask=None, y_mask=None):
        '''
        input x: B, C, H, W
        input y: B, C, H, W
        input x_mask: B, 1, H, W, mask == True will be ignored
        input y_mask: B, 1, H, W, mask == True will be ignored
        '''
        # x first cross-attends to the mask features
        x = self.cross_encoder_layer(x, featmask, None)[0]

        bx, cx, hx, wx = x.size()
        by, cy, hy, wy = y.size()

        posx = self.posEncoder(x)
        posy = self.posEncoder(y)

        featx = self.feat_weight * x + self.pos_weight * posx
        featy = self.feat_weight * y + self.pos_weight * posy

        # flatten to sequences of shape (H * W, B, C)
        featx = featx.flatten(2).permute(2, 0, 1)
        featy = featy.flatten(2).permute(2, 0, 1)
        x_mask = x_mask.flatten(2).squeeze(1) if x_mask is not None else torch.zeros(bx, hx * wx, dtype=torch.bool, device=x.device)
        y_mask = y_mask.flatten(2).squeeze(1) if y_mask is not None else torch.zeros(by, hy * wy, dtype=torch.bool, device=y.device)

        len_seq_x, len_seq_y = featx.size(0), featy.size(0)

        output = torch.cat([featx, featy], dim=0)
        src_key_padding_mask = torch.cat((x_mask, y_mask), dim=1)
        with torch.no_grad():
            # block-diagonal attention mask: True entries are blocked, so
            # only positions from the same image may attend to each other
            src_mask = torch.ones(hx * wx + hy * wy, hx * wx + hy * wy, dtype=torch.bool, device=x.device)
            src_mask[:hx * wx, :hx * wx] = False
            src_mask[hx * wx:, hx * wx:] = False

        output = self.inner_encoder_layer(output, src_mask=src_mask, src_key_padding_mask=src_key_padding_mask)

        outx, outy = output.narrow(0, 0, len_seq_x), output.narrow(0, len_seq_x, len_seq_y)
        outx, outy = outx.permute(1, 2, 0).view(bx, cx, hx, wx), outy.permute(1, 2, 0).view(by, cy, hy, wy)
        x_mask, y_mask = x_mask.view(bx, 1, hx, wx), y_mask.view(by, 1, hy, wy)

        return outx, outy, x_mask, y_mask
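
# Mask convention sketch: for nn.TransformerEncoderLayer, True in src_mask
# means "do not attend". With hx * wx = 2 and hy * wy = 2, the inner-attention
# mask built above is block-diagonal:
#
#   [[False, False,  True,  True],
#    [False, False,  True,  True],
#    [ True,  True, False, False],
#    [ True,  True, False, False]]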
class EncoderLayerCrossAttention(nn.Module):
    """
    Encoder layer that restricts attention to positions of the *other* image
    (cross attention between the two feature maps).
    """
    def __init__(self, d_model, nhead, dim_feedforward, dropout, activation):
        super(EncoderLayerCrossAttention, self).__init__()

        self.cross_encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead,
                                                              dim_feedforward=dim_feedforward,
                                                              dropout=dropout, activation=activation)
    def forward(self, featx, featy, featmask, x_mask=None, y_mask=None):
        '''
        input featx: B, C, H, W
        input featy: B, C, H, W
        input x_mask: B, 1, H, W, mask == True will be ignored
        input y_mask: B, 1, H, W, mask == True will be ignored
        (featmask is unused here; kept for a uniform layer interface)
        '''
        bx, cx, hx, wx = featx.size()
        by, cy, hy, wy = featy.size()

        # flatten to sequences of shape (H * W, B, C)
        featx = featx.flatten(2).permute(2, 0, 1)
        featy = featy.flatten(2).permute(2, 0, 1)
        x_mask = x_mask.flatten(2).squeeze(1) if x_mask is not None else torch.zeros(bx, hx * wx, dtype=torch.bool, device=featx.device)
        y_mask = y_mask.flatten(2).squeeze(1) if y_mask is not None else torch.zeros(by, hy * wy, dtype=torch.bool, device=featy.device)

        len_seq_x, len_seq_y = featx.size(0), featy.size(0)

        output = torch.cat([featx, featy], dim=0)
        src_key_padding_mask = torch.cat((x_mask, y_mask), dim=1)
        with torch.no_grad():
            # complement of the inner-attention mask: block the diagonal
            # blocks, so positions may only attend to the other image
            src_mask = torch.zeros(hx * wx + hy * wy, hx * wx + hy * wy, dtype=torch.bool, device=featx.device)
            src_mask[:hx * wx, :hx * wx] = True
            src_mask[hx * wx:, hx * wx:] = True

        output = self.cross_encoder_layer(output, src_mask=src_mask, src_key_padding_mask=src_key_padding_mask)

        outx, outy = output.narrow(0, 0, len_seq_x), output.narrow(0, len_seq_x, len_seq_y)
        outx, outy = outx.permute(1, 2, 0).view(bx, cx, hx, wx), outy.permute(1, 2, 0).view(by, cy, hy, wy)
        x_mask, y_mask = x_mask.view(bx, 1, hx, wx), y_mask.view(by, 1, hy, wy)

        return outx, outy, x_mask, y_mask
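
# The cross mask is the complement of the inner one: with hx * wx = 2 and
# hy * wy = 2, the diagonal blocks are True (blocked), e.g.
#
#   [[ True,  True, False, False],
#    [ True,  True, False, False],
#    [False, False,  True,  True],
#    [False, False,  True,  True]]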
class EncoderLayerEmpty(nn.Module):
    """
    Identity layer: passes features and masks through unchanged.
    """
    def __init__(self):
        super(EncoderLayerEmpty, self).__init__()

    def forward(self, featx, featy, featmask, x_mask=None, y_mask=None):
        '''
        input featx: B, C, H, W
        input featy: B, C, H, W
        input x_mask: B, 1, H, W, mask == True will be ignored
        input y_mask: B, 1, H, W, mask == True will be ignored
        '''
        return featx, featy, x_mask, y_mask

class EncoderLayerBlock(nn.Module):
    """
    A block of two encoder layers, selected by a two-character layer_type:
    'C' = cross attention, 'I' = inner attention, 'N' = identity.
    """
    def __init__(self, d_model, nhead, dim_feedforward, dropout, activation, pos_weight, feat_weight, layer_type):
        super(EncoderLayerBlock, self).__init__()

        # note: if both characters select the same letter, the two slots
        # share one module instance (and therefore its weights)
        cross_encoder_layer = EncoderLayerCrossAttention(d_model, nhead, dim_feedforward, dropout, activation)
        att_encoder_layer = EncoderLayerInnerAttention(d_model, nhead, dim_feedforward, dropout, activation, pos_weight, feat_weight)

        if layer_type[0] == 'C':
            self.layer1 = cross_encoder_layer
        elif layer_type[0] == 'I':
            self.layer1 = att_encoder_layer
        elif layer_type[0] == 'N':
            self.layer1 = EncoderLayerEmpty()
        else:
            raise ValueError("layer_type[0] must be one of 'C', 'I', 'N'")

        if layer_type[1] == 'C':
            self.layer2 = cross_encoder_layer
        elif layer_type[1] == 'I':
            self.layer2 = att_encoder_layer
        elif layer_type[1] == 'N':
            self.layer2 = EncoderLayerEmpty()
        else:
            raise ValueError("layer_type[1] must be one of 'C', 'I', 'N'")

    def forward(self, featx, featy, featmask, x_mask=None, y_mask=None):
        '''
        input featx: B, C, H, W
        input featy: B, C, H, W
        input x_mask: B, 1, H, W, mask == True will be ignored
        input y_mask: B, 1, H, W, mask == True will be ignored
        '''
        featx, featy, x_mask, y_mask = self.layer1(featx, featy, featmask, x_mask, y_mask)
        featx, featy, x_mask, y_mask = self.layer2(featx, featy, featmask, x_mask, y_mask)

        return featx, featy, x_mask, y_mask
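
# For example (a sketch of how layer_type is consumed): with
# layer_type = ['I', 'C'], the block runs inner attention first,
# then cross attention:
#
#   block = EncoderLayerBlock(d_model=256, nhead=2, dim_feedforward=256,
#                             dropout=0.1, activation='relu',
#                             pos_weight=0.1, feat_weight=1, layer_type=['I', 'C'])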
class Decoder(nn.Module):
    """
    Upsampling decoder: two stride-2 transposed convolutions followed by a
    4x bilinear interpolation, i.e. 16x total upsampling.
    """
    def __init__(self, in_channels=256, out_channels=3):
        super(Decoder, self).__init__()

        self.deconv1 = nn.ConvTranspose2d(in_channels, 128, 2, stride=2)
        self.relu1 = nn.ReLU()

        self.deconv2 = nn.ConvTranspose2d(128, out_channels, 2, stride=2)

    def forward(self, x):
        x = self.deconv1(x)
        x = self.relu1(x)

        x = self.deconv2(x)

        # 2x * 2x * 4x = 16x overall upsampling
        x = F.interpolate(x, scale_factor=4, mode='bilinear', align_corners=False)

        return x
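
# Shape sketch: a (B, 256, h, w) feature map decodes to (B, 3, 16 * h, 16 * w),
# e.g. Decoder()(torch.zeros(1, 256, 30, 30)) has shape (1, 3, 480, 480).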
class ClsBranch(nn.Module):
    """
    Classification branch: collapses the feature map to a single
    sigmoid score per sample.
    """
    def __init__(self, in_dim):
        super(ClsBranch, self).__init__()

        # 3x3 conv without padding: a 30x30 input map becomes 28x28,
        # matching the 28 * 28 input size of the MLP below
        self.conv = nn.Conv2d(in_dim, 1, 3)
        self.relu = nn.ReLU()

        self.mlp = nn.Sequential(nn.Linear(28 * 28, 32),
                                 nn.ReLU(),
                                 nn.Linear(32, 1),
                                 nn.Sigmoid())

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        x = torch.flatten(x, start_dim=1)
        x = self.mlp(x)
        return x
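
# Note the fixed spatial assumption: ClsBranch only works when the incoming
# feature map is 30x30 (conv output 28x28 = 784 MLP inputs), e.g.
#
#   cls = ClsBranch(in_dim=256)
#   score = cls(torch.zeros(2, 256, 30, 30))  # -> shape (2, 1), values in (0, 1)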
class Encoder(nn.Module):
    """
    Full transformer encoder: feature projection, a stack of encoder blocks,
    a reconstruction decoder, and a classification branch.
    """
    def __init__(self, feat_dim, pos_weight=0.1, feat_weight=1, d_model=512, nhead=8, num_layers=6,
                 dim_feedforward=2048, dropout=0.1, activation='relu',
                 layer_type=['I', 'C', 'I', 'C', 'I', 'C'], drop_feat=0.1):
        super(Encoder, self).__init__()

        # each block consumes two entries of layer_type
        assert len(layer_type) == 2 * num_layers, "layer_type must have 2 * num_layers entries"

        self.num_layers = num_layers
        self.feat_proj = nn.Conv2d(feat_dim, d_model, kernel_size=1)
        self.drop_feat = nn.Dropout2d(p=drop_feat)
        self.encoder_blocks = nn.ModuleList([EncoderLayerBlock(d_model, nhead, dim_feedforward, dropout,
                                                               activation, pos_weight, feat_weight,
                                                               layer_type[i * 2 : i * 2 + 2])
                                             for i in range(num_layers)])

        self.decoder = Decoder(d_model, 3)
        self.cls_branch = ClsBranch(in_dim=d_model)
        self.sigmoid = nn.Sigmoid()
        self.eps = 1e-7

    def forward(self, x, y, fmask, x_mask=None, y_mask=None):
        '''
        input x: B, C, H, W
        input y: B, C, H, W
        input x_mask: B, 1, H, W, mask == True will be ignored
        input y_mask: B, 1, H, W, mask == True will be ignored
        '''
        featx = self.feat_proj(x)
        featx = self.drop_feat(featx)

        featy = self.feat_proj(y)
        featy = self.drop_feat(featy)

        featmask = self.feat_proj(fmask)

        for i in range(self.num_layers):
            featx, featy, x_mask, y_mask = self.encoder_blocks[i](featx, featy, featmask, x_mask, y_mask)

        out_cls = self.cls_branch(featy)
        outx = self.sigmoid(self.decoder(featx))
        outy = self.sigmoid(self.decoder(featy))

        # keep outputs strictly inside (0, 1), e.g. for BCE / log losses
        outx = torch.clamp(outx, min=self.eps, max=1 - self.eps)
        outy = torch.clamp(outy, min=self.eps, max=1 - self.eps)

        return outx, outy, out_cls
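
# Output sketch: with inputs of size (B, feat_dim, 30, 30), the encoder
# returns outx and outy of shape (B, 3, 480, 480) (16x upsampled, sigmoid
# range) and out_cls of shape (B, 1); the 30x30 size is forced by ClsBranch.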
class TransEncoder(nn.Module):
    """
    Transformer encoder wrapper: tiny, small, base and large variants.
    """
    def __init__(self, feat_dim=1024, pos_weight=0.1, feat_weight=1, dropout=0.1, activation='relu',
                 mode='small', layer_type=['I', 'C', 'I', 'C', 'I', 'N'], drop_feat=0.1):
        super(TransEncoder, self).__init__()

        if mode == 'tiny':
            d_model = 128
            nhead = 2
            num_layers = 3
            dim_feedforward = 256

        elif mode == 'small':
            d_model = 256
            nhead = 2
            num_layers = 3
            dim_feedforward = 256

        elif mode == 'base':
            d_model = 512
            nhead = 8
            num_layers = 3
            dim_feedforward = 2048

        elif mode == 'large':
            d_model = 512
            nhead = 8
            num_layers = 6
            dim_feedforward = 2048

        else:
            raise ValueError("mode must be one of 'tiny', 'small', 'base', 'large'")

        # note: 'large' uses num_layers=6, so it needs a layer_type with 12 entries
        self.net = Encoder(feat_dim, pos_weight, feat_weight, d_model, nhead, num_layers,
                           dim_feedforward, dropout, activation, layer_type, drop_feat)

    def forward(self, x, y, fmask, x_mask=None, y_mask=None):
        '''
        input x: B, C, H, W
        input y: B, C, H, W
        '''
        outx, outy, out_cls = self.net(x, y, fmask, x_mask, y_mask)

        return outx, outy, out_cls
if __name__ == '__main__':

    feat_dim = 256
    mode = 'small'
    x = torch.randn(2, feat_dim, 10, 10)
    x_mask = torch.zeros(2, 1, 10, 10, dtype=torch.bool)

    # feat_dim and mode must match the inputs built above
    net = TransEncoder(feat_dim=feat_dim, mode=mode)

    print(net)
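
    # A forward-pass smoke test (sketch): ClsBranch requires 30x30 feature
    # maps, and fmask must have feat_dim channels like x and y
    x = torch.randn(2, feat_dim, 30, 30)
    y = torch.randn(2, feat_dim, 30, 30)
    fmask = torch.randn(2, feat_dim, 30, 30)
    outx, outy, out_cls = net(x, y, fmask)
    print(outx.shape, outy.shape, out_cls.shape)  # (2, 3, 480, 480) twice, then (2, 1)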