PyTorch
normwear2
custom_code
yunfeiluo commited on
Commit
b1e5e3a
·
verified ·
1 Parent(s): 27f677b

Upload latent_bayesian.py

Browse files
Files changed (1) hide show
  1. latent_bayesian.py +2 -1
latent_bayesian.py CHANGED
@@ -197,7 +197,8 @@ def bayesian_forecast(in_tensor, n_channels, physio_channels,
197
  if len(action_channels) > 0:
198
  with torch.no_grad():
199
  future_action_enc_out, _, _ = latent_encoder.forward_encoder(in_tensor.clone()[:, action_channels, :], masking=False)
200
- future_action_enc_out = future_action_enc_out.permute(1, 2, 0).flatten(start_dim=1)[1:1+(pred_length//latent_encoder.patch_size)+1, :] # sample_steps, embed_dim*bn*action_nvar
 
201
  appended_embeds = quantile_traj_of_state(
202
  # enc_out[0, -1, :],
203
  enc_out[-1:, :],
 
197
  if len(action_channels) > 0:
198
  with torch.no_grad():
199
  future_action_enc_out, _, _ = latent_encoder.forward_encoder(in_tensor.clone()[:, action_channels, :], masking=False)
200
+ context_npatchs = context_length // latent_encoder.patch_size
201
+ future_action_enc_out = future_action_enc_out.permute(1, 2, 0).flatten(start_dim=1)[context_npatchs:context_npatchs+(pred_length//latent_encoder.patch_size)+1, :] # sample_steps, embed_dim*bn*action_nvar
202
  appended_embeds = quantile_traj_of_state(
203
  # enc_out[0, -1, :],
204
  enc_out[-1:, :],