| """Utils for training/fine-tuning scripts."""
|
|
|
| import torch
|
|
|
| from prismatic.vla.constants import ACTION_DIM, ACTION_TOKEN_BEGIN_IDX, IGNORE_INDEX
|
|
|
|
|
def get_current_action_mask(token_ids):
    """Select the first ``ACTION_DIM`` supervised action tokens of each row.

    Entries equal to ``IGNORE_INDEX`` are treated as unsupervised. Walking along
    dim 1, a position is kept when (a) the number of supervised entries up to and
    including it lies in ``[1, ACTION_DIM]`` and (b) its token ID is strictly
    greater than ``ACTION_TOKEN_BEGIN_IDX``. Judging by the name, these are the
    tokens of the current action, as opposed to later actions in a chunk (see
    ``get_next_actions_mask``) — confirm against callers.

    Args:
        token_ids: Integer tensor of token IDs. The running count is taken along
            dim 1, so this is presumably a ``(batch, seq_len)`` label tensor whose
            non-target positions hold ``IGNORE_INDEX`` — TODO confirm.

    Returns:
        Boolean tensor with the same shape as ``token_ids``.
    """
    supervised = token_ids != IGNORE_INDEX
    # Number of supervised positions seen so far along the sequence (inclusive).
    seen = supervised.cumsum(dim=1)
    within_first_action = (seen >= 1) & (seen <= ACTION_DIM)
    # Restrict to IDs above ACTION_TOKEN_BEGIN_IDX so other supervised tokens are dropped.
    is_action_token = token_ids > ACTION_TOKEN_BEGIN_IDX
    return within_first_action & is_action_token
|
|
|
|
|
def get_next_actions_mask(token_ids):
    """Select supervised action tokens that come after the first ``ACTION_DIM`` of each row.

    Entries equal to ``IGNORE_INDEX`` are treated as unsupervised. Walking along
    dim 1, a position is kept when more than ``ACTION_DIM`` supervised entries
    have been seen up to and including it, and its token ID is strictly greater
    than ``ACTION_TOKEN_BEGIN_IDX``. This is the complement, among supervised
    action tokens, of ``get_current_action_mask``. Judging by the name, it covers
    the later actions of a chunk — confirm against callers.

    Args:
        token_ids: Integer tensor of token IDs. The running count is taken along
            dim 1, so this is presumably a ``(batch, seq_len)`` label tensor whose
            non-target positions hold ``IGNORE_INDEX`` — TODO confirm.

    Returns:
        Boolean tensor with the same shape as ``token_ids``.
    """
    # Inclusive running count of supervised (non-IGNORE_INDEX) positions along dim 1.
    seen = torch.cumsum(token_ids != IGNORE_INDEX, dim=1)
    past_first_action = seen > ACTION_DIM
    return past_first_action & (token_ids > ACTION_TOKEN_BEGIN_IDX)
|
|
|
|
|
def compute_token_accuracy(predicted_token_ids, ground_truth_token_ids, mask):
    """Return the fraction of masked positions where prediction equals ground truth.

    Args:
        predicted_token_ids: Tensor of predicted token IDs.
        ground_truth_token_ids: Tensor of target token IDs, comparable element-wise
            with ``predicted_token_ids``.
        mask: Boolean tensor selecting which positions count towards the accuracy.

    Returns:
        Scalar float tensor in ``[0, 1]``. When ``mask`` selects no positions the
        result is ``0.0`` instead of the NaN that ``0 / 0`` would produce.
    """
    correct_preds = (predicted_token_ids == ground_truth_token_ids) & mask
    # An empty mask (e.g. no tokens beyond the first action) would give 0 / 0 = NaN,
    # which poisons any running average of this metric; clamp the denominator so the
    # result is 0.0 instead. Non-empty masks are unaffected.
    num_selected = mask.sum().float().clamp(min=1.0)
    return correct_preds.sum().float() / num_selected
|
|
|
|
|
def compute_actions_l1_loss(action_tokenizer, predicted_token_ids, ground_truth_token_ids, mask):
    """Return the mean L1 distance between decoded predicted and ground-truth actions.

    The token IDs selected by ``mask`` are moved to NumPy and converted to
    continuous values with ``action_tokenizer.decode_token_ids_to_actions``.
    The two results are then compared element-wise. Because of the NumPy round
    trip, the returned value is detached from autograd: it is a logging metric,
    not a differentiable training loss.

    Args:
        action_tokenizer: Object exposing ``decode_token_ids_to_actions``, which
            takes a NumPy array of token IDs. It is presumably expected to return
            one continuous value per ID — confirm against the tokenizer.
        predicted_token_ids: Tensor of predicted token IDs.
        ground_truth_token_ids: Tensor of target token IDs, same shape as
            ``predicted_token_ids``.
        mask: Boolean tensor used to index both token-ID tensors.

    Returns:
        Scalar CPU tensor with the mean absolute difference. When ``mask`` selects
        no positions, this is ``0.0`` rather than the NaN produced by averaging an
        empty tensor.
    """
    # Boolean indexing flattens the selected positions into a 1-D tensor before decoding.
    # torch.as_tensor reuses the decoded NumPy buffers instead of copying them.
    pred_continuous_actions = torch.as_tensor(
        action_tokenizer.decode_token_ids_to_actions(predicted_token_ids[mask].cpu().numpy())
    )
    true_continuous_actions = torch.as_tensor(
        action_tokenizer.decode_token_ids_to_actions(ground_truth_token_ids[mask].cpu().numpy())
    )

    # F.l1_loss averages over elements; with nothing selected that mean is NaN, which
    # would poison running averages of this metric, so report 0.0 instead.
    if true_continuous_actions.numel() == 0:
        return torch.zeros((), dtype=true_continuous_actions.dtype)

    return torch.nn.functional.l1_loss(pred_continuous_actions, true_continuous_actions)
|
|
|