File size: 1,927 Bytes
cf587f4 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 | """Utils for training/fine-tuning scripts."""
import torch
from prismatic.vla.constants import ACTION_DIM, ACTION_TOKEN_BEGIN_IDX, IGNORE_INDEX
def get_current_action_mask(token_ids):
    """Return a boolean mask selecting the current action's tokens in ``token_ids``.

    A position is selected when both hold:
      (a) it is among the first ``ACTION_DIM`` positions along dim 1 whose id
          is not ``IGNORE_INDEX``;
      (b) its id is strictly greater than ``ACTION_TOKEN_BEGIN_IDX``.

    Args:
        token_ids: Integer tensor of token ids; counting happens along dim 1
            (presumably shaped (batch, seq_len) -- confirm against callers).

    Returns:
        Boolean tensor with the same shape as ``token_ids``.
    """
    # True wherever the position does NOT hold IGNORE_INDEX.
    not_ignored = token_ids != IGNORE_INDEX
    # Running count of non-ignored positions along each row. Positions before
    # the first non-ignored token get 0; the first one gets 1, and so on.
    kept_count = torch.cumsum(not_ignored, dim=1)
    # Window covering the first ACTION_DIM non-ignored positions of each row.
    in_current_action = (kept_count >= 1) & (kept_count <= ACTION_DIM)
    # Restrict to action-token ids. Ignored positions that fall inside the
    # window are presumably dropped here too, assuming
    # IGNORE_INDEX <= ACTION_TOKEN_BEGIN_IDX -- confirm against the constants.
    is_action_token = token_ids > ACTION_TOKEN_BEGIN_IDX
    return is_action_token & in_current_action
def get_next_actions_mask(token_ids):
    """Return a boolean mask selecting action tokens that follow the current action.

    A position is selected when both hold:
      (a) more than ``ACTION_DIM`` positions along dim 1, up to and including
          it, have an id other than ``IGNORE_INDEX``;
      (b) its id is strictly greater than ``ACTION_TOKEN_BEGIN_IDX``.

    Args:
        token_ids: Integer tensor of token ids; counting happens along dim 1
            (presumably shaped (batch, seq_len) -- confirm against callers).

    Returns:
        Boolean tensor with the same shape as ``token_ids``.
    """
    # True wherever the position does NOT hold IGNORE_INDEX.
    not_ignored = token_ids != IGNORE_INDEX
    # Running count of non-ignored positions along each row.
    kept_count = torch.cumsum(not_ignored, dim=1)
    # Everything after the first ACTION_DIM non-ignored positions.
    past_current_action = kept_count > ACTION_DIM
    # Restrict to action-token ids. Trailing IGNORE_INDEX positions are
    # presumably dropped here, assuming IGNORE_INDEX <= ACTION_TOKEN_BEGIN_IDX
    # -- confirm against the constants.
    is_action_token = token_ids > ACTION_TOKEN_BEGIN_IDX
    return is_action_token & past_current_action
def compute_token_accuracy(predicted_token_ids, ground_truth_token_ids, mask):
    """Return the fraction of masked positions where prediction equals ground truth.

    Args:
        predicted_token_ids: Tensor of predicted token ids.
        ground_truth_token_ids: Tensor of reference token ids, compared
            element-wise with ``predicted_token_ids``.
        mask: Tensor marking which positions are scored (expected boolean).

    Returns:
        0-dim float32 tensor equal to (matching masked positions) / (masked
        positions). If ``mask`` selects nothing, this is 0/0, i.e. NaN.
    """
    hits = torch.eq(predicted_token_ids, ground_truth_token_ids) & mask
    num_hits = hits.sum().float()
    num_scored = mask.sum().float()
    return num_hits / num_scored
def compute_actions_l1_loss(action_tokenizer, predicted_token_ids, ground_truth_token_ids, mask):
    """Return the L1 loss between decoded predicted and ground-truth actions.

    Only the positions selected by ``mask`` are used. Both id tensors are
    indexed with ``mask`` and moved to CPU/NumPy. They are then passed through
    ``action_tokenizer.decode_token_ids_to_actions``, and the results are
    compared with ``torch.nn.functional.l1_loss`` using its default
    (mean) reduction.

    Args:
        action_tokenizer: Object exposing ``decode_token_ids_to_actions``,
            which receives a NumPy array of token ids.
        predicted_token_ids: Tensor of predicted token ids.
        ground_truth_token_ids: Tensor of reference token ids.
        mask: Tensor used to index both id tensors (expected boolean).

    Returns:
        Scalar tensor holding the L1 loss; its dtype follows the decoder output.
    """

    def _decode_selected(token_ids):
        # Gather masked ids, hand them to the tokenizer as NumPy, wrap back in torch.
        selected_ids = token_ids[mask].cpu().numpy()
        return torch.tensor(action_tokenizer.decode_token_ids_to_actions(selected_ids))

    # Arguments are evaluated left to right: predictions are decoded first.
    return torch.nn.functional.l1_loss(
        _decode_selected(predicted_token_ids),
        _decode_selected(ground_truth_token_ids),
    )
|