| import torch |
| import torch.nn as nn |
| from typing import List |
| from einops import * |
| from torch import einsum as ein |
| from .unit_nn import SinusoidalPosEmb |
| import numpy as np |
| |
|
|
|
|
def sinusoidal_pos_emb(t: torch.Tensor, dim: int) -> torch.Tensor:
    """Sinusoidal timestep embedding.

    Args:
        t: (B,) tensor of timesteps.
        dim: embedding width.

    Returns:
        (B, dim) tensor laid out as [sin | cos] halves; for odd ``dim`` the
        last column is zero-padded so the promised width is always honored.
    """
    device = t.device
    half = dim // 2
    # max(half - 1, 1): with dim == 2 the original `half - 1` denominator is 0,
    # which produces inf -> exp(-0 * inf) -> NaN embeddings.
    freqs = torch.exp(
        -torch.arange(half, device=device) * (np.log(10000) / max(half - 1, 1))
    )
    args = t[:, None] * freqs[None]
    emb = torch.cat([args.sin(), args.cos()], dim=-1)
    if dim % 2 == 1:
        # sin/cat only yields 2*half columns; pad one zero to reach `dim`.
        emb = torch.nn.functional.pad(emb, (0, 1))
    return emb
|
|
|
|
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal time embedding as a (stateless) module.

    forward: t with B elements (any flattenable shape) -> (B, dim).
    NOTE(review): this class shadows the `SinusoidalPosEmb` imported from
    `.unit_nn` at the top of the file — confirm which one is intended.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim  # output embedding width (expected even)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        # Flatten so callers may pass (B,) or (B, 1) timesteps.
        flat = t.view(t.shape[0])
        half = self.dim // 2
        # Geometric frequency ladder from 1 down to 1/10000.
        scale = np.log(10000) / (half - 1)
        freqs = torch.exp(-scale * torch.arange(half, device=flat.device))
        angles = flat[:, None] * freqs[None]
        return torch.cat([angles.sin(), angles.cos()], dim=-1)
|
|
|
|
| |
|
|
|
|
| |
class MLP(nn.Module):
    """Three-layer pointwise (1x1-conv) MLP with additive time conditioning.

    forward(x, time_emb):
        x        : (B, in_c, L)
        time_emb : (B, time_emb_dim)
        returns  : (B, out_c, L)
    """

    def __init__(self, in_c, hidden_c, out_c, time_emb_dim):
        super().__init__()

        def conv_block(cin, cout):
            # A 1x1 conv is a per-position linear layer; ReLU after it.
            return nn.Sequential(nn.Conv1d(cin, cout, 1), nn.ReLU())

        # NOTE: module creation order is kept stable (time_net first, then the
        # conv stack) so seeded initialization matches existing checkpoints.
        self.time_net = nn.Sequential(nn.Linear(time_emb_dim, hidden_c), nn.Mish())
        self.net1 = conv_block(in_c, hidden_c)
        self.net2 = conv_block(hidden_c, hidden_c)
        self.net3 = conv_block(hidden_c, hidden_c)
        self.out = nn.Conv1d(hidden_c, out_c, 1)

    def forward(self, x, time_emb):
        feat = self.net1(x)
        # Broadcast the (B, hidden_c) time signal across the length axis.
        feat = feat + self.time_net(time_emb)[..., None]
        feat = self.net3(self.net2(feat))
        return self.out(feat)
|
|
|
|
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| |
| |
|
|
|
|
class Decoder(nn.Module):
    """
    Small MLP velocity estimator for toy 2-D flow matching.

    Tensor contract
    ---------------
    forward(x, mu, t) -> vel
        x   : (B, feat_dim, L)       current sample
        mu  : (B, feat_dim, L)       conditioning signal
        t   : (B,) | (B,1) | (B,1,1) timestep — all accepted
        vel : (B, feat_dim, L)       predicted velocity
    """

    def __init__(
        self,
        in_c: int = 2,
        hidden_dim: int = 128,
        out_c: int = 2,
        time_emb_dim: int = 64,
        cond_dim: int = 0,
    ):
        super().__init__()
        # Sinusoidal features followed by a single learned projection.
        self.time_emb = SinusoidalPosEmb(time_emb_dim)
        self.time_mlp = nn.Sequential(nn.Linear(time_emb_dim, time_emb_dim))
        # NOTE(review): `cond_dim` (and `forward`'s `cond`) are accepted but
        # never used — apparently placeholders for future conditioning.
        # x and mu are concatenated channel-wise, hence in_c * 2.
        self.net = MLP(
            in_c=in_c * 2, hidden_c=hidden_dim, out_c=out_c, time_emb_dim=time_emb_dim
        )
        self._init_weights()

    def _init_weights(self):
        """Re-initialise every Linear layer: N(0, 0.02) weights, zero bias."""
        for mod in self.modules():
            if not isinstance(mod, nn.Linear):
                continue
            nn.init.normal_(mod.weight, 0.0, 0.02)
            if mod.bias is not None:
                nn.init.zeros_(mod.bias)

    def forward(
        self,
        x: torch.Tensor,
        mu: torch.Tensor,
        t: torch.Tensor,
        cond=None,
    ) -> torch.Tensor:
        # reshape collapses any of the accepted timestep shapes to (B,).
        batch = x.shape[0]
        t_embedding = self.time_mlp(self.time_emb(t.reshape(batch)))
        # Channel-wise concat of sample and condition before the trunk.
        stacked = torch.cat([x, mu], dim=1)
        return self.net(stacked, t_embedding)
|
|