Upload train_jit.py with huggingface_hub
Browse files — train_jit.py (+5 lines, −20 lines)
train_jit.py
CHANGED
|
@@ -185,16 +185,13 @@ def build_jit(cfg: JiTTrainConfig) -> JiT:
|
|
| 185 |
|
| 186 |
class CFMFlowWrapper(nn.Module):
|
| 187 |
"""
|
| 188 |
-
torchdyn NeuralODE expects f(t, x) with the same batch size as x.
|
| 189 |
-
|
| 190 |
-
- x_pred mode (JiT denoiser): JiT predicts x1 (clean); v = (x_pred - x) / (1-t).
|
| 191 |
"""
|
| 192 |
|
| 193 |
-
def __init__(self, model: JiT, t_eps: float = 1.0e-5):
|
| 194 |
super().__init__()
|
| 195 |
self.model = model
|
| 196 |
-
self.prediction_mode = "x_pred"
|
| 197 |
-
self.t_eps = t_eps
|
| 198 |
|
| 199 |
def forward(self, t: torch.Tensor, x: torch.Tensor, y=None, *args, **kwargs) -> torch.Tensor:
|
| 200 |
batch = x.shape[0]
|
|
@@ -203,12 +200,6 @@ class CFMFlowWrapper(nn.Module):
|
|
| 203 |
t_flat = t_flat.expand(batch)
|
| 204 |
elif t_flat.shape[0] != batch:
|
| 205 |
t_flat = t_flat[:batch]
|
| 206 |
-
|
| 207 |
-
if self.prediction_mode == "x_pred":
|
| 208 |
-
x_pred = self.model(x, t_flat)
|
| 209 |
-
one_minus_t = (1.0 - t_flat).clamp(min=self.t_eps)
|
| 210 |
-
t_bc = one_minus_t.reshape(-1, *([1] * (x.dim() - 1)))
|
| 211 |
-
return (x_pred - x) / t_bc
|
| 212 |
return self.model(x, t_flat)
|
| 213 |
|
| 214 |
|
|
@@ -277,14 +268,8 @@ def main() -> None:
|
|
| 277 |
x0 = torch.randn_like(x1)
|
| 278 |
t, xt, ut = fm.sample_location_and_conditional_flow(x0, x1)
|
| 279 |
t_b = t.reshape(-1).float()
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
x_pred = net_model(xt, t_b)
|
| 283 |
-
one_minus_t = (1.0 - t_b).clamp(min=1.0e-5)
|
| 284 |
-
t_bc = one_minus_t.reshape(-1, *([1] * (xt.dim() - 1)))
|
| 285 |
-
v_pred = (x_pred - xt) / t_bc
|
| 286 |
-
v_target = (x1 - xt) / t_bc
|
| 287 |
-
loss = torch.mean((v_target - v_pred) ** 2)
|
| 288 |
|
| 289 |
optim.zero_grad(set_to_none=True)
|
| 290 |
loss.backward()
|
|
|
|
| 185 |
|
| 186 |
class CFMFlowWrapper(nn.Module):
|
| 187 |
"""
|
| 188 |
+
torchdyn NeuralODE expects f(t, x) with same batch as x.
|
| 189 |
+
JiT is forward(x, t) with t shape (N,).
|
|
|
|
| 190 |
"""
|
| 191 |
|
| 192 |
+
def __init__(self, model: JiT):
|
| 193 |
super().__init__()
|
| 194 |
self.model = model
|
|
|
|
|
|
|
| 195 |
|
| 196 |
def forward(self, t: torch.Tensor, x: torch.Tensor, y=None, *args, **kwargs) -> torch.Tensor:
|
| 197 |
batch = x.shape[0]
|
|
|
|
| 200 |
t_flat = t_flat.expand(batch)
|
| 201 |
elif t_flat.shape[0] != batch:
|
| 202 |
t_flat = t_flat[:batch]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
return self.model(x, t_flat)
|
| 204 |
|
| 205 |
|
|
|
|
| 268 |
x0 = torch.randn_like(x1)
|
| 269 |
t, xt, ut = fm.sample_location_and_conditional_flow(x0, x1)
|
| 270 |
t_b = t.reshape(-1).float()
|
| 271 |
+
vt = net_model(xt, t_b)
|
| 272 |
+
loss = torch.mean((vt - ut) ** 2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 273 |
|
| 274 |
optim.zero_grad(set_to_none=True)
|
| 275 |
loss.backward()
|