import torch
import torch.nn as nn
from typing import List
from einops import *
from torch import einsum as ein
# NOTE(review): this import is shadowed by the local SinusoidalPosEmb class
# defined below; kept in case other code relies on the import side effects.
from .unit_nn import SinusoidalPosEmb
import numpy as np


# ── helpers ──────────────────────────────────────────────────────────────────
def sinusoidal_pos_emb(t: torch.Tensor, dim: int) -> torch.Tensor:
    """Transformer-style sinusoidal timestep embedding.

    Args:
        t:   (B,) tensor of timesteps.
        dim: target embedding width.

    Returns:
        (B, 2 * (dim // 2)) tensor — equal to (B, dim) when ``dim`` is even;
        an odd ``dim`` silently yields one fewer column.
    """
    device = t.device
    half = dim // 2
    # max(half - 1, 1) guards dim == 2 or 3 (half == 1), where the original
    # (half - 1) divisor is zero and the embedding becomes inf/NaN.
    freqs = torch.exp(
        -torch.arange(half, device=device) * (np.log(10000) / max(half - 1, 1))
    )
    args = t[:, None] * freqs[None]
    return torch.cat([args.sin(), args.cos()], dim=-1)


class SinusoidalPosEmb(nn.Module):
    """Module wrapper around :func:`sinusoidal_pos_emb`.

    Accepts timesteps shaped (B,), (B, 1), or (B, 1, 1) and always returns
    a (B, dim) embedding.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        # Collapse any trailing singleton dims to (B,).
        t = t.view(t.shape[0])
        return sinusoidal_pos_emb(t, self.dim)


# ── MLP block ─────────────────────────────────────────────────────────────────
# %%
class MLP(nn.Module):
    """Pointwise (kernel-size-1 Conv1d) MLP with additive time conditioning.

    Args:
        in_c:         input channel count.
        hidden_c:     hidden channel count.
        out_c:        output channel count.
        time_emb_dim: width of the time-embedding vector.
    """

    def __init__(self, in_c: int, hidden_c: int, out_c: int, time_emb_dim: int):
        super().__init__()
        self.time_net = nn.Sequential(nn.Linear(time_emb_dim, hidden_c), nn.Mish())
        self.net1 = nn.Sequential(nn.Conv1d(in_c, hidden_c, 1), nn.ReLU())
        self.net2 = nn.Sequential(nn.Conv1d(hidden_c, hidden_c, 1), nn.ReLU())
        self.net3 = nn.Sequential(nn.Conv1d(hidden_c, hidden_c, 1), nn.ReLU())
        self.out = nn.Conv1d(hidden_c, out_c, 1)

    def forward(self, x: torch.Tensor, time_emb: torch.Tensor) -> torch.Tensor:
        """x: (B, in_c, L); time_emb: (B, time_emb_dim) -> (B, out_c, L)."""
        h = self.net1(x)
        # Broadcast the projected time embedding over the length axis L.
        h = h + self.time_net(time_emb).unsqueeze(-1)
        h = self.net2(h)
        h = self.net3(h)
        return self.out(h)


# %%
# ── Decoder ───────────────────────────────────────────────────────────────────
class Decoder(nn.Module):
    """
    Lightweight MLP velocity estimator for toy 2-D flow-matching.

    Tensor contract
    ---------------
    forward(x, mu, t) -> vel
        x   : (B, feat_dim, L)
        mu  : (B, feat_dim, L)
        t   : (B,) | (B,1) | (B,1,1)   # all accepted
        vel : (B, feat_dim, L)
    """

    def __init__(
        self,
        in_c: int = 2,
        hidden_dim: int = 128,
        out_c: int = 2,
        time_emb_dim: int = 64,
        cond_dim: int = 0,  # NOTE(review): accepted but unused — conditioning
                            # is not wired up; kept for interface stability.
    ):
        super().__init__()
        self.time_emb = SinusoidalPosEmb(time_emb_dim)
        self.time_mlp = nn.Sequential(
            nn.Linear(time_emb_dim, time_emb_dim),
        )
        # concat(x, mu) along channel dim -> 2*feat_dim channels
        self.net = MLP(
            in_c=in_c * 2, hidden_c=hidden_dim, out_c=out_c, time_emb_dim=time_emb_dim
        )
        self._init_weights()

    def _init_weights(self):
        """Re-initialize Linear layers (time path) to N(0, 0.02), zero bias.

        NOTE(review): the main trunk is entirely Conv1d, which this loop does
        not touch — those layers keep PyTorch's default init. Confirm that is
        intended rather than an oversight.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0.0, 0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(
        self,
        x: torch.Tensor,
        mu: torch.Tensor,
        t: torch.Tensor,
        cond=None,  # NOTE(review): unused — see cond_dim above.
    ) -> torch.Tensor:
        # Normalise t to (B,) regardless of input shape.
        t_flat = t.reshape(x.shape[0])                 # (B,)
        t_emb = self.time_mlp(self.time_emb(t_flat))   # (B, time_emb_dim)

        # Concat along channel axis -> (B, 2*feat_dim, L).
        xmu = torch.cat([x, mu], dim=1)
        return self.net(xmu, t_emb)                    # (B, feat_dim, L)