File size: 4,564 Bytes
4edc9aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import torch
import torch.nn as nn
from typing import List
from einops import *
from torch import einsum as ein
from .unit_nn import SinusoidalPosEmb
import numpy as np
# ── helpers ──────────────────────────────────────────────────────────────────


def sinusoidal_pos_emb(t: torch.Tensor, dim: int) -> torch.Tensor:
    """Sinusoidal time-step embedding.

    Parameters
    ----------
    t   : (B,) tensor of (possibly fractional) time steps.
    dim : embedding width; output is always exactly (B, dim).

    Returns
    -------
    (B, dim) tensor: [sin(t * f_0..f_{h-1}), cos(t * f_0..f_{h-1})] with
    geometrically spaced frequencies, zero-padded by one column if dim is odd.
    """
    device = t.device
    half = dim // 2
    # Guard: the naive `half - 1` denominator divides by zero when dim == 2.
    denom = max(half - 1, 1)
    freqs = torch.exp(-torch.arange(half, device=device) * (np.log(10000) / denom))
    args = t[:, None] * freqs[None]
    emb = torch.cat([args.sin(), args.cos()], dim=-1)
    if dim % 2 == 1:
        # Odd dim: 2 * half == dim - 1, so pad one zero column to honor the contract.
        emb = torch.nn.functional.pad(emb, (0, 1))
    return emb


class SinusoidalPosEmb(nn.Module):
    """Sinusoidal time-step embedding module: maps a batch of scalars to (B, dim)."""

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        # Accept (B,), (B,1), or (B,1,1) and collapse to a flat (B,) vector.
        flat = t.view(t.shape[0])
        half = self.dim // 2
        # Geometrically spaced frequencies from 1 down to 1/10000.
        scale = np.log(10000) / (half - 1)
        freqs = torch.exp(-scale * torch.arange(half, device=flat.device))
        phases = flat[:, None] * freqs[None]
        return torch.cat([phases.sin(), phases.cos()], dim=-1)


# ── MLP block ─────────────────────────────────────────────────────────────────


# %%
class MLP(nn.Module):
    """Pointwise (kernel-size-1) conv MLP with additive time conditioning.

    forward(x, time_emb):
      x        : (B, in_c, L)
      time_emb : (B, time_emb_dim)
      returns  : (B, out_c, L)
    """

    def __init__(self, in_c, hidden_c, out_c, time_emb_dim):
        super().__init__()
        # Projects the time embedding into the hidden channel width.
        self.time_net = nn.Sequential(nn.Linear(time_emb_dim, hidden_c), nn.Mish())
        self.net1 = nn.Sequential(nn.Conv1d(in_c, hidden_c, 1), nn.ReLU())
        self.net2 = nn.Sequential(nn.Conv1d(hidden_c, hidden_c, 1), nn.ReLU())
        self.net3 = nn.Sequential(nn.Conv1d(hidden_c, hidden_c, 1), nn.ReLU())
        self.out = nn.Conv1d(hidden_c, out_c, 1)

    def forward(self, x, time_emb):
        # Time features broadcast over the length axis via the trailing singleton dim.
        t_feat = self.time_net(time_emb).unsqueeze(-1)  # (B, hidden_c, 1)
        h = self.net1(x) + t_feat
        h = self.net3(self.net2(h))
        return self.out(h)


# class MLP(nn.Module):
#     def __init__(self, in_c: int, hidden_c: int, out_c: int, time_emb_dim: int):
#         super().__init__()
#         self.time_net = nn.Sequential(nn.Linear(time_emb_dim, hidden_c), nn.Mish())
#         self.net1 = nn.Sequential(nn.Linear(in_c, hidden_c), nn.ReLU())
#         self.net2 = nn.Linear(hidden_c, out_c)

#     def forward(self, x: torch.Tensor, time_emb: torch.Tensor) -> torch.Tensor:
#         # x         : (B, in_c, L)
#         # time_emb  : (B, time_emb_dim)
#         x_t = x.transpose(1, 2)  # (B, L, in_c) for Linear
#         out = self.net1(x_t)  # (B, L, hidden_c)
#         out = out + self.time_net(time_emb).unsqueeze(1)  # broadcast over L
#         out = self.net2(out)  # (B, L, out_c)
#         return out.transpose(1, 2)  # (B, out_c, L)


# %%
# ── Decoder ───────────────────────────────────────────────────────────────────


class Decoder(nn.Module):
    """
    Lightweight MLP velocity estimator for toy 2-D flow-matching.

    Tensor contract
    ---------------
    forward(x, mu, t) -> vel
      x   : (B, feat_dim, L)
      mu  : (B, feat_dim, L)
      t   : (B,) | (B,1) | (B,1,1)   # all accepted
      vel : (B, feat_dim, L)
    """

    def __init__(
        self,
        in_c: int = 2,
        hidden_dim: int = 128,
        out_c: int = 2,
        time_emb_dim: int = 64,
        cond_dim: int = 0,
    ):
        super().__init__()
        self.time_emb = SinusoidalPosEmb(time_emb_dim)
        self.time_mlp = nn.Sequential(
            nn.Linear(time_emb_dim, time_emb_dim),
        )
        # x and mu are concatenated channel-wise, hence the in_c * 2 input width.
        self.net = MLP(
            in_c=in_c * 2, hidden_c=hidden_dim, out_c=out_c, time_emb_dim=time_emb_dim
        )
        self._init_weights()

    def _init_weights(self):
        # Small-variance normal init for every Linear layer; zero the biases.
        for module in self.modules():
            if not isinstance(module, nn.Linear):
                continue
            nn.init.normal_(module.weight, 0.0, 0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(
        self,
        x: torch.Tensor,
        mu: torch.Tensor,
        t: torch.Tensor,
        cond=None,
    ) -> torch.Tensor:
        # NOTE(review): `cond` / `cond_dim` are accepted but never used here —
        # presumably reserved for a conditional variant; confirm with callers.
        t_emb = self.time_mlp(self.time_emb(t.reshape(x.shape[0])))  # (B, time_emb_dim)
        joint = torch.cat([x, mu], dim=1)  # (B, 2*feat_dim, L)
        return self.net(joint, t_emb)  # (B, feat_dim, L)