# Source: BadVLA/raw_data/prismatic/training/train_utils.py
# (uploaded by czxlovesu03 via huggingface_hub, commit c075fba)
"""Utils for training/fine-tuning scripts."""
import torch
from prismatic.vla.constants import ACTION_DIM, ACTION_TOKEN_BEGIN_IDX, IGNORE_INDEX
def get_current_action_mask(token_ids):
    """Return a boolean mask selecting the current-step action tokens in ``token_ids``.

    A position is selected when both conditions hold:
      (a) it lies within the window covering the first ``ACTION_DIM``
          non-``IGNORE_INDEX`` positions along dim 1; and
      (b) its token id is greater than ``ACTION_TOKEN_BEGIN_IDX``.

    Args:
        token_ids: Integer tensor of token ids with at least 2 dims (the running
            count is taken along dim 1). Presumably labels of shape
            (batch, seq_len) with ``IGNORE_INDEX`` on unsupervised positions --
            TODO confirm against callers.

    Returns:
        Boolean tensor with the same shape as ``token_ids``.
    """
    # True wherever the position is NOT IGNORE_INDEX (i.e. a supervised target).
    supervised = token_ids != IGNORE_INDEX
    # Running count of supervised positions along dim 1: 0 before the first
    # supervised token, 1 at it, and so on.
    supervised_count = torch.cumsum(supervised, dim=1)
    # Window spanning the first ACTION_DIM supervised positions.
    in_current_window = (supervised_count >= 1) & (supervised_count <= ACTION_DIM)
    # Keep only ids above ACTION_TOKEN_BEGIN_IDX (the action-token range). This
    # also drops IGNORE_INDEX positions inside the window, assuming
    # IGNORE_INDEX <= ACTION_TOKEN_BEGIN_IDX (e.g. the usual -100) -- verify.
    is_action_token = token_ids > ACTION_TOKEN_BEGIN_IDX
    return in_current_window & is_action_token
def get_next_actions_mask(token_ids):
    """Return a boolean mask selecting the action tokens after the current step.

    A position is selected when both conditions hold:
      (a) more than ``ACTION_DIM`` non-``IGNORE_INDEX`` positions have been
          seen along dim 1 up to and including it; and
      (b) its token id is greater than ``ACTION_TOKEN_BEGIN_IDX``.

    This is the complement (within action tokens) of the window used by
    ``get_current_action_mask``.

    Args:
        token_ids: Integer tensor of token ids with at least 2 dims (the running
            count is taken along dim 1). Presumably labels of shape
            (batch, seq_len) with ``IGNORE_INDEX`` on unsupervised positions --
            TODO confirm against callers.

    Returns:
        Boolean tensor with the same shape as ``token_ids``.
    """
    # True wherever the position is NOT IGNORE_INDEX (i.e. a supervised target).
    supervised = token_ids != IGNORE_INDEX
    # Running count of supervised positions along dim 1.
    supervised_count = torch.cumsum(supervised, dim=1)
    # Everything past the first ACTION_DIM supervised positions.
    past_current_window = supervised_count > ACTION_DIM
    # Keep only ids above ACTION_TOKEN_BEGIN_IDX (the action-token range). This
    # also drops IGNORE_INDEX positions, assuming
    # IGNORE_INDEX <= ACTION_TOKEN_BEGIN_IDX (e.g. the usual -100) -- verify.
    is_action_token = token_ids > ACTION_TOKEN_BEGIN_IDX
    return past_current_window & is_action_token
def compute_token_accuracy(predicted_token_ids, ground_truth_token_ids, mask):
    """Return the fraction of masked positions where prediction equals ground truth.

    Args:
        predicted_token_ids: Tensor of predicted token ids.
        ground_truth_token_ids: Tensor of target token ids, broadcast-compatible
            with ``predicted_token_ids``.
        mask: Boolean tensor selecting which positions count toward accuracy.

    Returns:
        0-dim float32 tensor in [0, 1]. If ``mask`` selects no positions the
        result is 0.0 rather than NaN.
    """
    correct_preds = (predicted_token_ids == ground_truth_token_ids) & mask
    # Clamp the denominator so an empty mask yields 0/1 = 0 instead of 0/0 = NaN.
    # Clamping (rather than branching on the count) avoids a device->host sync.
    num_masked = mask.sum().clamp(min=1)
    return correct_preds.sum().float() / num_masked.float()
def compute_actions_l1_loss(action_tokenizer, predicted_token_ids, ground_truth_token_ids, mask):
    """Return the mean L1 distance between decoded predicted and target actions.

    Both token-id tensors are indexed with ``mask`` and moved to CPU as NumPy
    arrays. They are decoded with ``action_tokenizer.decode_token_ids_to_actions``
    and wrapped back into tensors before ``torch.nn.functional.l1_loss``
    (default mean reduction) is applied.

    Args:
        action_tokenizer: Object exposing ``decode_token_ids_to_actions``.
        predicted_token_ids: Tensor of predicted token ids.
        ground_truth_token_ids: Tensor of target token ids.
        mask: Boolean tensor selecting the positions to compare.

    Returns:
        Scalar tensor (on CPU) holding the L1 loss.
    """

    def _decode_selected(token_ids):
        # Pick masked positions, move them to host memory, and decode them there.
        selected_ids = token_ids[mask].cpu().numpy()
        return torch.tensor(action_tokenizer.decode_token_ids_to_actions(selected_ids))

    predicted_actions = _decode_selected(predicted_token_ids)
    target_actions = _decode_selected(ground_truth_token_ids)
    return torch.nn.functional.l1_loss(predicted_actions, target_actions)