import torch
import torch.nn as nn
from typing import List
from einops import *
from torch import einsum as ein
from .unit_nn import SinusoidalPosEmb
import numpy as np
# ββ helpers ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
def sinusoidal_pos_emb(t: torch.Tensor, dim: int) -> torch.Tensor:
    """Transformer-style sinusoidal embedding of timesteps.

    Parameters
    ----------
    t : torch.Tensor
        Shape ``(B,)``; timesteps (may be fractional).
    dim : int
        Embedding width.

    Returns
    -------
    torch.Tensor
        Shape ``(B, dim)``: ``[sin(t*f_0..f_{h-1}), cos(t*f_0..f_{h-1})]``
        with geometrically spaced frequencies; zero-padded if ``dim`` is odd.
    """
    device = t.device
    half = dim // 2
    # max(half - 1, 1) guards dim == 2 (half == 1): the original divided by
    # zero here, yielding inf/NaN frequencies.
    freqs = torch.exp(
        -torch.arange(half, device=device) * (np.log(10000) / max(half - 1, 1))
    )
    args = t[:, None] * freqs[None]
    emb = torch.cat([args.sin(), args.cos()], dim=-1)
    if dim % 2 == 1:
        # Odd dim: sin/cos halves cover only 2*half columns; pad one zero so
        # the documented (B, dim) contract actually holds.
        emb = torch.nn.functional.pad(emb, (0, 1))
    return emb
class SinusoidalPosEmb(nn.Module):
    """Module wrapper around :func:`sinusoidal_pos_emb`.

    NOTE(review): this local class shadows the ``SinusoidalPosEmb`` imported
    from ``.unit_nn`` above — confirm the local definition is the intended one.
    """

    def __init__(self, dim: int):
        super().__init__()
        # Target embedding width produced by forward().
        self.dim = dim

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Embed timesteps of shape (B,), (B,1), or (B,1,1) into (B, dim)."""
        batch = t.shape[0]
        # view() (not reshape(-1)) intentionally errors if t has more than
        # one element per batch item.
        return sinusoidal_pos_emb(t.view(batch), self.dim)
# ββ MLP block βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# %%
class MLP(nn.Module):
    """Pointwise (1x1-conv) MLP with additive time-embedding conditioning.

    forward(x, time_emb):
        x        : (B, in_c, L)
        time_emb : (B, time_emb_dim)
        returns  : (B, out_c, L)
    """

    def __init__(self, in_c, hidden_c, out_c, time_emb_dim):
        super().__init__()
        # Registration order matters for seeded init reproducibility.
        self.time_net = nn.Sequential(nn.Linear(time_emb_dim, hidden_c), nn.Mish())
        self.net1 = nn.Sequential(nn.Conv1d(in_c, hidden_c, 1), nn.ReLU())
        self.net2 = nn.Sequential(nn.Conv1d(hidden_c, hidden_c, 1), nn.ReLU())
        self.net3 = nn.Sequential(nn.Conv1d(hidden_c, hidden_c, 1), nn.ReLU())
        self.out = nn.Conv1d(hidden_c, out_c, 1)

    def forward(self, x, time_emb):
        # Project the time embedding to hidden width and broadcast it over L.
        t_bias = self.time_net(time_emb)[..., None]  # (B, hidden_c, 1)
        feat = self.net1(x) + t_bias
        feat = self.net3(self.net2(feat))
        return self.out(feat)
# class MLP(nn.Module):
# def __init__(self, in_c: int, hidden_c: int, out_c: int, time_emb_dim: int):
# super().__init__()
# self.time_net = nn.Sequential(nn.Linear(time_emb_dim, hidden_c), nn.Mish())
# self.net1 = nn.Sequential(nn.Linear(in_c, hidden_c), nn.ReLU())
# self.net2 = nn.Linear(hidden_c, out_c)
# def forward(self, x: torch.Tensor, time_emb: torch.Tensor) -> torch.Tensor:
# # x : (B, in_c, L)
# # time_emb : (B, time_emb_dim)
# x_t = x.transpose(1, 2) # (B, L, in_c) for Linear
# out = self.net1(x_t) # (B, L, hidden_c)
# out = out + self.time_net(time_emb).unsqueeze(1) # broadcast over L
# out = self.net2(out) # (B, L, out_c)
# return out.transpose(1, 2) # (B, out_c, L)
# %%
# ββ Decoder βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
class Decoder(nn.Module):
    """
    Lightweight MLP velocity estimator for toy 2-D flow-matching.

    Tensor contract
    ---------------
    forward(x, mu, t) -> vel
        x   : (B, feat_dim, L)
        mu  : (B, feat_dim, L)
        t   : (B,) | (B,1) | (B,1,1)   # all accepted
        vel : (B, feat_dim, L)

    NOTE(review): ``cond_dim`` and the ``cond`` argument are accepted but
    never used — presumably reserved for future conditioning; confirm.
    """

    def __init__(
        self,
        in_c: int = 2,
        hidden_dim: int = 128,
        out_c: int = 2,
        time_emb_dim: int = 64,
        cond_dim: int = 0,
    ):
        super().__init__()
        self.time_emb = SinusoidalPosEmb(time_emb_dim)
        self.time_mlp = nn.Sequential(
            nn.Linear(time_emb_dim, time_emb_dim),
        )
        # x and mu are concatenated channel-wise, hence 2 * in_c inputs.
        self.net = MLP(
            in_c=in_c * 2, hidden_c=hidden_dim, out_c=out_c, time_emb_dim=time_emb_dim
        )
        self._init_weights()

    def _init_weights(self):
        # Re-initialise every Linear with N(0, 0.02); apply() visits the same
        # Linear modules in the same order as self.modules(), so seeded
        # initialisation matches the previous loop exactly.
        def _reset(module):
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0.0, 0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

        self.apply(_reset)

    def forward(
        self,
        x: torch.Tensor,
        mu: torch.Tensor,
        t: torch.Tensor,
        cond=None,
    ) -> torch.Tensor:
        batch = x.shape[0]
        # Collapse t to (B,) no matter which accepted shape it arrived in.
        emb = self.time_mlp(self.time_emb(t.reshape(batch)))  # (B, time_emb_dim)
        stacked = torch.cat((x, mu), dim=1)  # (B, 2*feat_dim, L)
        return self.net(stacked, emb)  # (B, feat_dim, L)
|