haiphamcse commited on
Commit
f74d123
·
verified ·
1 Parent(s): b7338c6

Upload train_jit.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_jit.py +5 -20
train_jit.py CHANGED
@@ -185,16 +185,13 @@ def build_jit(cfg: JiTTrainConfig) -> JiT:
185
 
186
  class CFMFlowWrapper(nn.Module):
187
  """
188
- torchdyn NeuralODE expects f(t, x) with dx/dt returned.
189
- - velocity mode: JiT predicts v directly return model(x, t).
190
- - x_pred mode (JiT denoiser): JiT predicts x1 (clean); v = (x_pred - x) / (1-t).
191
  """
192
 
193
- def __init__(self, model: JiT, prediction_mode: str = "velocity", t_eps: float = 1e-5):
194
  super().__init__()
195
  self.model = model
196
- self.prediction_mode = "x_pred"
197
- self.t_eps = t_eps
198
 
199
  def forward(self, t: torch.Tensor, x: torch.Tensor, y=None, *args, **kwargs) -> torch.Tensor:
200
  batch = x.shape[0]
@@ -203,12 +200,6 @@ class CFMFlowWrapper(nn.Module):
203
  t_flat = t_flat.expand(batch)
204
  elif t_flat.shape[0] != batch:
205
  t_flat = t_flat[:batch]
206
-
207
- if self.prediction_mode == "x_pred":
208
- x_pred = self.model(x, t_flat)
209
- one_minus_t = (1.0 - t_flat).clamp(min=self.t_eps)
210
- t_bc = one_minus_t.reshape(-1, *([1] * (x.dim() - 1)))
211
- return (x_pred - x) / t_bc
212
  return self.model(x, t_flat)
213
 
214
 
@@ -277,14 +268,8 @@ def main() -> None:
277
  x0 = torch.randn_like(x1)
278
  t, xt, ut = fm.sample_location_and_conditional_flow(x0, x1)
279
  t_b = t.reshape(-1).float()
280
- # vt = net_model(xt, t_b)
281
- # loss = torch.mean((vt - ut) ** 2)
282
- x_pred = net_model(xt, t_b)
283
- one_minus_t = (1.0 - t_b).clamp(min=1.0e-5)
284
- t_bc = one_minus_t.reshape(-1, *([1] * (xt.dim() - 1)))
285
- v_pred = (x_pred - xt) / t_bc
286
- v_target = (x1 - xt) / t_bc
287
- loss = torch.mean((v_target - v_pred) ** 2)
288
 
289
  optim.zero_grad(set_to_none=True)
290
  loss.backward()
 
185
 
186
  class CFMFlowWrapper(nn.Module):
187
  """
188
+ torchdyn NeuralODE expects f(t, x) with same batch as x.
189
+ JiT is forward(x, t) with t shape (N,).
 
190
  """
191
 
192
+ def __init__(self, model: JiT):
193
  super().__init__()
194
  self.model = model
 
 
195
 
196
  def forward(self, t: torch.Tensor, x: torch.Tensor, y=None, *args, **kwargs) -> torch.Tensor:
197
  batch = x.shape[0]
 
200
  t_flat = t_flat.expand(batch)
201
  elif t_flat.shape[0] != batch:
202
  t_flat = t_flat[:batch]
 
 
 
 
 
 
203
  return self.model(x, t_flat)
204
 
205
 
 
268
  x0 = torch.randn_like(x1)
269
  t, xt, ut = fm.sample_location_and_conditional_flow(x0, x1)
270
  t_b = t.reshape(-1).float()
271
+ vt = net_model(xt, t_b)
272
+ loss = torch.mean((vt - ut) ** 2)
 
 
 
 
 
 
273
 
274
  optim.zero_grad(set_to_none=True)
275
  loss.backward()