| """Utils for training/fine-tuning scripts.""" |
|
|
| import torch |
|
|
| from prismatic.vla.constants import ACTION_DIM, ACTION_TOKEN_BEGIN_IDX, IGNORE_INDEX |
|
|
|
|
def get_current_action_mask(token_ids):
    """Return a boolean mask selecting the action tokens of the current action.

    A position is selected when both of the following hold:

    * the running count (along dim 1) of entries that differ from
      ``IGNORE_INDEX`` is between 1 and ``ACTION_DIM`` inclusive, and
    * the token id at that position is greater than ``ACTION_TOKEN_BEGIN_IDX``.

    Args:
        token_ids: Tensor of token ids. The cumulative sum runs over dim 1, so
            at least two dimensions are required (presumably
            ``(batch, seq_len)`` -- confirm against the caller).

    Returns:
        Boolean tensor with the same shape as ``token_ids``.
    """
    # Positions that are not masked out with IGNORE_INDEX.
    is_kept = token_ids != IGNORE_INDEX

    # Per-row running count of kept positions; the first kept token gets 1.
    kept_count = torch.cumsum(is_kept, dim=1)

    # Window covering the first ACTION_DIM kept positions of every row.
    in_current_window = (kept_count >= 1) & (kept_count <= ACTION_DIM)

    # Only ids above ACTION_TOKEN_BEGIN_IDX count as action tokens.
    is_action_token = token_ids > ACTION_TOKEN_BEGIN_IDX

    return is_action_token & in_current_window
|
|
|
|
def get_next_actions_mask(token_ids):
    """Return a boolean mask selecting the action tokens after the current action.

    A position is selected when both of the following hold:

    * the running count (along dim 1) of entries that differ from
      ``IGNORE_INDEX`` is strictly greater than ``ACTION_DIM``, and
    * the token id at that position is greater than ``ACTION_TOKEN_BEGIN_IDX``.

    Args:
        token_ids: Tensor of token ids. The cumulative sum runs over dim 1, so
            at least two dimensions are required (presumably
            ``(batch, seq_len)`` -- confirm against the caller).

    Returns:
        Boolean tensor with the same shape as ``token_ids``.
    """
    # Positions that are not masked out with IGNORE_INDEX.
    is_kept = token_ids != IGNORE_INDEX

    # Per-row running count of kept positions; the first kept token gets 1.
    kept_count = torch.cumsum(is_kept, dim=1)

    # Everything past the first ACTION_DIM kept positions of every row.
    past_current_window = kept_count > ACTION_DIM

    # Only ids above ACTION_TOKEN_BEGIN_IDX count as action tokens.
    is_action_token = token_ids > ACTION_TOKEN_BEGIN_IDX

    return is_action_token & past_current_window
|
|
|
|
def compute_token_accuracy(predicted_token_ids, ground_truth_token_ids, mask):
    """Compute the fraction of masked positions where prediction equals ground truth.

    Args:
        predicted_token_ids: Tensor of predicted token ids.
        ground_truth_token_ids: Tensor of reference token ids, comparable
            element-wise with ``predicted_token_ids``.
        mask: Tensor combined with the match tensor via ``&``; positions where
            it is set are the ones that count towards the accuracy.

    Returns:
        Scalar float tensor: matching masked positions divided by the number of
        masked positions. The result is NaN when ``mask`` selects nothing,
        because the division is 0 / 0.
    """
    matches = predicted_token_ids == ground_truth_token_ids
    num_correct = (matches & mask).sum().float()
    num_selected = mask.sum().float()
    return num_correct / num_selected
|
|
|
|
def compute_actions_l1_loss(action_tokenizer, predicted_token_ids, ground_truth_token_ids, mask):
    """Compute the L1 loss between decoded predicted and ground-truth actions.

    The token ids selected by ``mask`` are moved to the CPU, converted to
    NumPy, decoded with ``action_tokenizer.decode_token_ids_to_actions`` and
    wrapped back into tensors before the loss is taken.

    Args:
        action_tokenizer: Object exposing ``decode_token_ids_to_actions``,
            called with a NumPy array of token ids.
        predicted_token_ids: Tensor of predicted token ids.
        ground_truth_token_ids: Tensor of reference token ids.
        mask: Index (boolean mask) applied to both id tensors to pick the
            tokens that are decoded.

    Returns:
        Scalar tensor produced by ``torch.nn.functional.l1_loss`` with its
        default reduction, computed on CPU tensors.
    """

    def _decode_selected(token_ids):
        # Decoding happens on NumPy data, so the selected ids leave the device first.
        selected_ids = token_ids[mask].cpu().numpy()
        return torch.tensor(action_tokenizer.decode_token_ids_to_actions(selected_ids))

    predicted_actions = _decode_selected(predicted_token_ids)
    reference_actions = _decode_selected(ground_truth_token_ids)
    return torch.nn.functional.l1_loss(predicted_actions, reference_actions)
|
|