import math
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import PreTrainedModel

from .configuration_dalle import DallEConfig
|
|
|
|
class Conv2d(nn.Module):
    """Same-padding 2-D convolution with optional float16 execution on CUDA.

    The weight is drawn from N(0, 1 / (n_in * kw**2)) and the bias starts at
    zero.  ``kw`` must be odd so that ``padding=(kw - 1) // 2`` preserves the
    spatial dimensions of the input.
    """

    def __init__(self, n_in, n_out, kw, config, use_float16=True):
        """
        Args:
            n_in: number of input channels (>= 1).
            n_out: number of output channels (>= 1).
            kw: square kernel size; must be odd (same-padding).
            config: object providing ``device`` and ``requires_grad``.
            use_float16: run the convolution in half precision when the
                weights live on a CUDA device.
        """
        super().__init__()

        assert n_in >= 1
        assert n_out >= 1
        assert kw >= 1 and kw % 2 == 1

        self.n_in = n_in
        self.n_out = n_out
        self.kw = kw
        self.config = config
        self.use_float16 = use_float16

        # Create and initialize the weight BEFORE enabling requires_grad:
        # the in-place ``normal_`` on a leaf tensor that already requires
        # grad raises a RuntimeError in autograd.
        w = torch.empty(
            (n_out, n_in, kw, kw),
            dtype=torch.float32,
            device=config.device,
        )
        w.normal_(std=1 / math.sqrt(n_in * kw ** 2))
        w.requires_grad_(config.requires_grad)

        # Zero init for the bias; no in-place op follows, so setting
        # requires_grad at construction time is safe here.
        b = torch.zeros(
            (n_out,),
            dtype=torch.float32,
            device=config.device,
            requires_grad=config.requires_grad,
        )

        self.w = nn.Parameter(w)
        self.b = nn.Parameter(b)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the convolution, casting to float16 on CUDA if enabled."""
        if self.use_float16 and 'cuda' in self.w.device.type:
            if x.dtype != torch.float16:
                x = x.half()
            w, b = self.w.half(), self.b.half()
        else:
            if x.dtype != torch.float32:
                x = x.float()
            w, b = self.w, self.b
        return F.conv2d(x, w, b, padding=(self.kw - 1) // 2)

    def extra_repr(self):
        """Summary string shown inside ``repr()`` of the module."""
        inner_repr = f"n_in={self.n_in}, n_out={self.n_out}, kw={self.kw}, "
        inner_repr += f"use_float16={self.use_float16}, "
        inner_repr += f"device={self.config.device}, "
        inner_repr += f"requires_grad={self.config.requires_grad}"
        return inner_repr
| |
| |
class EncoderBlock(nn.Module):
    """Residual encoder block.

    A shortcut path (identity, or a 1x1 projection when the channel count
    changes) is added to a four-convolution bottleneck path that is scaled
    by ``1 / n_layers**2``.
    """

    def __init__(self, n_in, n_out, n_layers, config):
        super().__init__()

        assert n_in >= 1
        assert n_out >= 1 and n_out % 4 == 0
        assert n_layers >= 1

        self.n_in = n_in
        self.n_out = n_out
        self.n_hid = n_out // 4
        # Scale the residual branch down so deep stacks stay stable.
        self.post_gain = 1 / (n_layers ** 2)

        # Shortcut: project with a 1x1 conv only when channels change.
        self.id_path = (
            Conv2d(n_in, n_out, 1, config)
            if n_in != n_out
            else nn.Identity()
        )

        # Bottleneck: three 3x3 convs at n_hid channels, then a 1x1 back
        # up to n_out, each preceded by a ReLU.  The OrderedDict keys match
        # the original checkpoint layout (relu_i / conv_i).
        specs = [
            (n_in, self.n_hid, 3),
            (self.n_hid, self.n_hid, 3),
            (self.n_hid, self.n_hid, 3),
            (self.n_hid, n_out, 1),
        ]
        stages = []
        for idx, (c_in, c_out, kw) in enumerate(specs, start=1):
            stages.append((f'relu_{idx}', nn.ReLU()))
            stages.append((f'conv_{idx}', Conv2d(c_in, c_out, kw, config)))
        self.res_path = nn.Sequential(OrderedDict(stages))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Shortcut output plus the scaled residual-path output."""
        return self.id_path(x) + self.post_gain * self.res_path(x)
|
|
|
|
class DallEPreTrainedModel(PreTrainedModel):
    """Base class hooking DALL-E modules into the transformers
    ``PreTrainedModel`` machinery (config class and weight-name prefix)."""

    config_class = DallEConfig
    base_model_prefix = "dalle"
| |
| |
class DallEEncoder(DallEPreTrainedModel):
    """Convolutional encoder mapping an image batch to per-pixel logits over
    the DALL-E visual vocabulary.

    Four groups of residual ``EncoderBlock``s; the first three are each
    followed by 2x2 max-pooling (8x spatial downsampling overall), and a
    final ReLU + 1x1 conv produces ``config.vocab_size`` channels.

    NOTE(review): the structure hard-codes exactly four groups while
    ``n_layers`` is derived from ``config.group_count`` — presumably
    ``group_count == 4``; confirm against the config class.
    """

    def __init__(self, config):
        super().__init__(config)
        blk_range = range(config.n_blk_per_group)
        n_layers = config.group_count * config.n_blk_per_group

        in_channels = config.input_channels
        n_hid = config.n_hid

        # Stored so forward() can validate inputs.  Previously this attribute
        # was never set, so the channel check in forward() raised
        # AttributeError instead of the intended ValueError.
        self.input_channels = in_channels

        self.blocks = nn.Sequential(OrderedDict([
            ('input', Conv2d(in_channels, n_hid, 7, config)),
            ('group_1', nn.Sequential(OrderedDict([
                *[(f'block_{i + 1}',
                   EncoderBlock(n_hid, n_hid, n_layers, config))
                  for i in blk_range],
                ('pool', nn.MaxPool2d(kernel_size=2)),
            ]))),
            ('group_2', nn.Sequential(OrderedDict([
                # First block of each group widens the channels; the rest
                # keep the group's width.
                *[(f'block_{i + 1}',
                   EncoderBlock(
                       n_hid if i == 0 else 2 * n_hid,
                       2 * n_hid, n_layers, config))
                  for i in blk_range],
                ('pool', nn.MaxPool2d(kernel_size=2)),
            ]))),
            ('group_3', nn.Sequential(OrderedDict([
                *[(f'block_{i + 1}',
                   EncoderBlock(
                       2 * n_hid if i == 0 else 4 * n_hid,
                       4 * n_hid, n_layers, config))
                  for i in blk_range],
                ('pool', nn.MaxPool2d(kernel_size=2)),
            ]))),
            # Last group has no pooling stage.
            ('group_4', nn.Sequential(OrderedDict([
                *[(f'block_{i + 1}',
                   EncoderBlock(
                       4 * n_hid if i == 0 else 8 * n_hid,
                       8 * n_hid, n_layers, config))
                  for i in blk_range],
            ]))),
            ('output', nn.Sequential(OrderedDict([
                ('relu', nn.ReLU()),
                ('conv', Conv2d(
                    8 * n_hid, config.vocab_size,
                    1, config, use_float16=False)),
            ]))),
        ]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a float32 batch of shape (B, input_channels, H, W).

        Raises:
            ValueError: if the input is not 4-D, has the wrong channel
                count, or is not float32.
        """
        if x.ndim != 4:
            raise ValueError(f'input shape {x.shape} is not 4d')
        if x.shape[1] != self.input_channels:
            raise ValueError(f'input has {x.shape[1]} channels but model built for {self.input_channels}')
        if x.dtype != torch.float32:
            raise ValueError('input must have dtype torch.float32')

        return self.blocks(x)