import math
from dataclasses import dataclass
from enum import Enum
from typing import NamedTuple, Tuple

import torch
from choices import *
from config_base import BaseConfig
from torch import nn
from torch.nn import init

from .blocks import *
from .nn import timestep_embedding
from .unet import *


class LatentNetType(Enum):
    none = 'none'
    # inject (concatenate) the input into the hidden layers
    skip = 'skip'


class LatentNetReturn(NamedTuple):
    pred: torch.Tensor = None


@dataclass
class MLPSkipNetConfig(BaseConfig):
| """ |
| default MLP for the latent DPM in the paper! |
| """ |
| num_channels: int |
    skip_layers: Tuple[int, ...]
| num_hid_channels: int |
| num_layers: int |
| num_time_emb_channels: int = 64 |
| activation: Activation = Activation.silu |
| use_norm: bool = True |
| condition_bias: float = 1 |
| dropout: float = 0 |
| last_act: Activation = Activation.none |
| num_time_layers: int = 2 |
    time_last_act: bool = False

    def make_model(self):
        return MLPSkipNet(self)
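

# Rough usage sketch (illustrative values, not necessarily the paper's
# defaults; `Activation` and its options come from `choices` as imported above):
#
#   conf = MLPSkipNetConfig(num_channels=512,
#                           skip_layers=(1, 2, 3),
#                           num_hid_channels=1024,
#                           num_layers=10)
#   latent_net = conf.make_model()  # -> MLPSkipNet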
|
|
|
|
class MLPSkipNet(nn.Module):
    """
    Concatenates x to selected hidden layers (the skip connections).

    Default MLP for the latent DPM in the paper.
    """
    def __init__(self, conf: MLPSkipNetConfig):
        super().__init__()
        self.conf = conf

        # small MLP that turns the timestep embedding into the conditioning vector
        layers = []
        for i in range(conf.num_time_layers):
            if i == 0:
                a = conf.num_time_emb_channels
                b = conf.num_channels
            else:
                a = conf.num_channels
                b = conf.num_channels
            layers.append(nn.Linear(a, b))
            if i < conf.num_time_layers - 1 or conf.time_last_act:
                layers.append(conf.activation.get_act())
        self.time_embed = nn.Sequential(*layers)

        # main MLP; layers listed in skip_layers also receive the input x
        self.layers = nn.ModuleList([])
        for i in range(conf.num_layers):
            if i == 0:
                # first layer: input -> hidden
                act = conf.activation
                norm = conf.use_norm
                cond = True
                a, b = conf.num_channels, conf.num_hid_channels
                dropout = conf.dropout
            elif i == conf.num_layers - 1:
                # last layer: hidden -> output, no act/norm/conditioning
                act = Activation.none
                norm = False
                cond = False
                a, b = conf.num_hid_channels, conf.num_channels
                dropout = 0
            else:
                # intermediate layers: hidden -> hidden
                act = conf.activation
                norm = conf.use_norm
                cond = True
                a, b = conf.num_hid_channels, conf.num_hid_channels
                dropout = conf.dropout

            if i in conf.skip_layers:
                # this layer also receives the concatenated input x
                a += conf.num_channels

            self.layers.append(
                MLPLNAct(
                    a,
                    b,
                    norm=norm,
                    activation=act,
                    cond_channels=conf.num_channels,
                    use_cond=cond,
                    condition_bias=conf.condition_bias,
                    dropout=dropout,
                ))
        self.last_act = conf.last_act.get_act()

    def forward(self, x, t, **kwargs):
        t = timestep_embedding(t, self.conf.num_time_emb_channels)
        cond = self.time_embed(t)
        h = x
        for i in range(len(self.layers)):
            if i in self.conf.skip_layers:
                # inject the input into the hidden layers
                h = torch.cat([h, x], dim=1)
            h = self.layers[i](x=h, cond=cond)
        h = self.last_act(h)
        return LatentNetReturn(h)


class MLPLNAct(nn.Module):
| def __init__( |
| self, |
| in_channels: int, |
| out_channels: int, |
| norm: bool, |
| use_cond: bool, |
| activation: Activation, |
| cond_channels: int, |
| condition_bias: float = 0, |
| dropout: float = 0, |
| ): |
| super().__init__() |
| self.activation = activation |
| self.condition_bias = condition_bias |
        self.use_cond = use_cond

        self.linear = nn.Linear(in_channels, out_channels)
| self.act = activation.get_act() |
| if self.use_cond: |
| self.linear_emb = nn.Linear(cond_channels, out_channels) |
| self.cond_layers = nn.Sequential(self.act, self.linear_emb) |
| if norm: |
| self.norm = nn.LayerNorm(out_channels) |
| else: |
            self.norm = nn.Identity()

        if dropout > 0:
| self.dropout = nn.Dropout(p=dropout) |
| else: |
            self.dropout = nn.Identity()

        self.init_weights()

    def init_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Linear):
                if self.activation == Activation.relu:
                    init.kaiming_normal_(module.weight,
                                         a=0,
                                         nonlinearity='relu')
                elif self.activation == Activation.lrelu:
                    init.kaiming_normal_(module.weight,
                                         a=0.2,
                                         nonlinearity='leaky_relu')
                elif self.activation == Activation.silu:
                    # kaiming_normal_ has no silu mode; relu is the closest match
                    init.kaiming_normal_(module.weight,
                                         a=0,
                                         nonlinearity='relu')
                else:
                    # keep PyTorch's default initialization
                    pass

    def forward(self, x, cond=None):
        x = self.linear(x)
        if self.use_cond:
            # condition on the time embedding
            cond = self.cond_layers(cond)
            cond = (cond, None)

            # scale first (the shift slot is unused here), then normalize
            x = x * (self.condition_bias + cond[0])
            if cond[1] is not None:
                x = x + cond[1]
            x = self.norm(x)
        else:
            # no conditioning
            x = self.norm(x)
        x = self.act(x)
        x = self.dropout(x)
        return x
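

# Minimal smoke test; a sketch only, not part of the original module. The sizes
# below are hypothetical, and `timestep_embedding(t, dim)` is assumed to return
# a (batch, dim) tensor. Run as a module from the repo root so the relative
# imports above resolve.
if __name__ == '__main__':
    _conf = MLPSkipNetConfig(
        num_channels=512,
        skip_layers=(1, 2),
        num_hid_channels=1024,
        num_layers=4,
    )
    _net = _conf.make_model()
    _x = torch.randn(8, 512)            # a batch of semantic latents
    _t = torch.randint(0, 1000, (8, ))  # integer diffusion timesteps
    _out = _net(_x, _t)
    assert _out.pred.shape == (8, 512)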