| import math
|
| import torch
|
|
|
| from .base import BaseModule
|
| from .modules import Mish, Upsample, Downsample, Rezero, Block, ResnetBlock
|
| from .modules import LinearAttention, Residual, Timesteps, TimbreBlock, PitchPosEmb
|
|
|
| from einops import rearrange
|
|
|
|
|
class UNetPitcher(BaseModule):
    """U-Net score estimator for diffusion-based, pitch-conditioned generation.

    The network stacks the noisy input ``x`` and the prior ``mean`` into two
    input channels, appends broadcast pitch (and optionally timbre) condition
    channels, and runs a down/mid/up U-Net with linear attention to predict a
    single-channel output of the same spatial size.

    Args:
        dim_base: base channel width; scaled by ``dim_mults`` per resolution.
        dim_cond: channel width of the pitch/timbre condition features.
        use_ref_t: if True, derive a timbre vector from a reference input
            via ``TimbreBlock``.
        use_embed: if True, accept an external embedding as timbre input.
        dim_embed: size of the external embedding (default 256).
        dim_mults: per-resolution channel multipliers (default ``(1, 2, 4)``).
        pitch_type: stored on the instance; not read inside this class.
    """

    def __init__(self,
                 dim_base,
                 dim_cond,
                 use_ref_t,
                 use_embed,
                 dim_embed=256,
                 dim_mults=(1, 2, 4),
                 pitch_type='bins'):
        super().__init__()
        self.use_ref_t = use_ref_t
        self.use_embed = use_embed
        self.pitch_type = pitch_type

        # Input channels: noisy sample + prior mean (stacked in forward()).
        dim_in = 2

        # Sinusoidal diffusion-step embedding followed by a small MLP.
        self.time_pos_emb = Timesteps(num_channels=dim_base,
                                      flip_sin_to_cos=True,
                                      downscale_freq_shift=0)
        self.mlp = torch.nn.Sequential(torch.nn.Linear(dim_base, dim_base * 4),
                                       Mish(),
                                       torch.nn.Linear(dim_base * 4, dim_base))

        # Timbre conditioning: reference encoder and/or external embedding,
        # fused by a shared MLP down to dim_cond channels.
        timbre_total = 0
        if use_ref_t:
            self.ref_block = TimbreBlock(out_dim=dim_cond)
            timbre_total += dim_cond
        if use_embed:
            timbre_total += dim_embed
        if timbre_total != 0:
            self.timbre_block = torch.nn.Sequential(
                torch.nn.Linear(timbre_total, 4 * dim_cond),
                Mish(),
                torch.nn.Linear(4 * dim_cond, dim_cond))
        if use_embed or use_ref_t:
            dim_in += dim_cond  # timbre condition appended as extra channels

        # Pitch conditioning: positional embedding + pointwise-conv MLP.
        self.pitch_pos_emb = PitchPosEmb(dim_cond)
        self.pitch_mlp = torch.nn.Sequential(
            torch.nn.Conv1d(dim_cond, dim_cond * 4, 1, stride=1),
            Mish(),
            torch.nn.Conv1d(dim_cond * 4, dim_cond, 1, stride=1))
        dim_in += dim_cond  # pitch condition channels (always present)

        # Channel plan per resolution, e.g. [dim_in, b, 2b, 4b] -> (in, out) pairs.
        dims = [dim_in, *map(lambda m: dim_base * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        self.downs = torch.nn.ModuleList([])
        self.ups = torch.nn.ModuleList([])
        num_resolutions = len(in_out)

        # Encoder: two resnets + attention per level; downsample except at the
        # bottom level. (Loop variables renamed so they don't shadow dim_in.)
        for ind, (ch_in, ch_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)
            self.downs.append(torch.nn.ModuleList([
                ResnetBlock(ch_in, ch_out, time_emb_dim=dim_base),
                ResnetBlock(ch_out, ch_out, time_emb_dim=dim_base),
                Residual(Rezero(LinearAttention(ch_out))),
                Downsample(ch_out) if not is_last else torch.nn.Identity()]))

        mid_dim = dims[-1]
        self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim_base)
        self.mid_attn = Residual(Rezero(LinearAttention(mid_dim)))
        self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim_base)

        # Decoder mirrors the encoder, skipping the first pair (whose input
        # width is dim_in); 2x input channels because each level consumes the
        # matching skip connection.
        for ch_in, ch_out in reversed(in_out[1:]):
            self.ups.append(torch.nn.ModuleList([
                ResnetBlock(ch_out * 2, ch_in, time_emb_dim=dim_base),
                ResnetBlock(ch_in, ch_in, time_emb_dim=dim_base),
                Residual(Rezero(LinearAttention(ch_in))),
                Upsample(ch_in)]))
        self.final_block = Block(dim_base, dim_base)
        self.final_conv = torch.nn.Conv2d(dim_base, 1, 1)

    def forward(self, x, mean, f0, t, ref=None, embed=None):
        """Predict the model output for noisy input ``x`` at diffusion step ``t``.

        Args:
            x: noisy input; assumed (batch, freq, time) — TODO confirm against
                callers.
            mean: prior mean, same shape as ``x``.
            f0: pitch sequence fed to ``PitchPosEmb``; its embedded form must
                be accepted by a ``Conv1d`` over ``dim_cond`` channels.
            t: diffusion step(s); Python scalar, 0-d tensor, or (batch,) tensor.
            ref: reference input for the timbre encoder (required when
                ``use_ref_t`` is set).
            embed: external embedding (required when ``use_embed`` is set).

        Returns:
            Tensor shaped like ``x`` (the single output channel is squeezed).
        """
        # Normalize t to a (batch,)-shaped tensor before embedding it.
        if not torch.is_tensor(t):
            t = torch.tensor([t], dtype=torch.long, device=x.device)
        if len(t.shape) == 0:
            t = t * torch.ones(x.shape[0], dtype=t.dtype, device=x.device)
        t = self.time_pos_emb(t)
        t = self.mlp(t)

        # Two input channels: noisy sample and prior mean.
        x = torch.stack([x, mean], 1)

        # Pitch condition, broadcast over the frequency axis. unsqueeze+expand
        # is value- and gradient-identical to the old torch.cat(n * [f0], 2)
        # but avoids materializing an intermediate copy; the later torch.cat
        # makes the data contiguous anyway.
        f0 = self.pitch_pos_emb(f0)
        f0 = self.pitch_mlp(f0)
        f0 = f0.unsqueeze(2).expand(-1, -1, x.shape[2], -1)

        # Assemble the timbre vector from whichever sources are enabled.
        timbre = None
        if self.use_ref_t:
            timbre = self.ref_block(ref.unsqueeze(1))  # add channel dim to ref
        if self.use_embed:
            timbre = embed if timbre is None else torch.cat([timbre, embed], 1)

        if timbre is None:
            condition = f0
        else:
            # Project to dim_cond, then broadcast over both spatial axes.
            timbre = self.timbre_block(timbre).unsqueeze(-1).unsqueeze(-1)
            timbre = timbre.expand(-1, -1, x.shape[2], x.shape[3])
            condition = torch.cat([f0, timbre], 1)

        x = torch.cat([x, condition], 1)

        # U-Net encoder, saving pre-downsample activations for the skips.
        hiddens = []
        for resnet1, resnet2, attn, downsample in self.downs:
            x = resnet1(x, t)
            x = resnet2(x, t)
            x = attn(x)
            hiddens.append(x)
            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        # Decoder: concatenate the matching skip, refine, then upsample.
        for resnet1, resnet2, attn, upsample in self.ups:
            x = torch.cat((x, hiddens.pop()), dim=1)
            x = resnet1(x, t)
            x = resnet2(x, t)
            x = attn(x)
            x = upsample(x)

        x = self.final_block(x)
        output = self.final_conv(x)

        return output.squeeze(1)