| import utils |
| import hydra |
| import torch |
| import einops |
| import numpy as np |
| from workspaces import base |
| from accelerate import Accelerator |
| from utils import get_split_idx |
|
|
# Named 2-D state components and the pair of indices each occupies in the
# (subset) state vector; components are laid out contiguously, two dims each.
OBS_ELEMENT_INDICES = {
    name: np.arange(2 * slot, 2 * slot + 2)
    for slot, name in enumerate(
        [
            "block_translation",
            "block2_translation",
            "effector_translation",
            "target_translation",
            "target2_translation",
        ]
    )
}
|
|
| accelerator = Accelerator() |
|
|
|
|
def calc_state_dist(a, b):
    """Mean-squared distance between two state vectors, broken out per component.

    Returns a dict with one MSE entry per key of ``OBS_ELEMENT_INDICES``
    (computed over that component's index pair) plus a ``"total"`` entry
    with the MSE over the full vectors.
    """

    def _component_mse(raw_idx):
        # Index with a long tensor so torch-style fancy indexing applies.
        idx = torch.Tensor(raw_idx).long()
        return ((a[idx] - b[idx]) ** 2).mean()

    dists = {name: _component_mse(v) for name, v in OBS_ELEMENT_INDICES.items()}
    dists["total"] = ((a - b) ** 2).mean()
    return dists
|
|
|
|
def mean_dicts(dicts):
    """Key-wise average of a list of dicts.

    The key set is taken from the first dict; every dict in ``dicts`` is
    expected to contain those keys.
    """
    return {key: np.mean([d[key] for d in dicts]) for key in dicts[0]}
|
|
|
|
class BlockPushMultiviewWorkspace(base.Workspace):
    """Workspace for the multiview block-push task.

    Adds offline-eval metrics (linear probes from embeddings, and
    nearest-neighbor state distances) on top of the base Workspace.
    """

    def __init__(self, cfg, work_dir):
        super().__init__(cfg, work_dir)

    def _report_result_upon_completion(self, goal_idx=None):
        # Success flags exposed by the block-push environment after a rollout.
        return {
            "entered": self.env.entered,
            "moved": self.env.moved,
        }

    def run_offline_eval(self):
        """Run offline evaluation; returns a metrics dict on the main process.

        Non-main processes return None. Embedding extraction runs on every
        process before the main-process guard — presumably because it is a
        collective/distributed op; TODO confirm.

        Fix vs. previous revision: removed a dead ``states = []`` assignment
        before the main-process branch (it was immediately re-assigned inside
        the branch and unused in the else branch).
        """
        train_idx, val_idx = get_split_idx(
            len(self.dataset),
            self.cfg.seed,
            train_fraction=self.cfg.train_fraction,
        )
        embeddings = utils.inference.embed_trajectory_dataset(
            self.encoder, self.dataset
        )
        # Flatten the per-frame (view, embedding) grid into a single vector.
        embeddings = [
            einops.rearrange(x, "T V E -> T (V E)") for x in embeddings
        ]

        # Columns of the raw state used for probing; yields a 10-dim vector
        # whose layout matches the index pairs in OBS_ELEMENT_INDICES.
        state_subset_idx = [0, 1, 3, 4, 6, 7, 10, 11, 13, 14]
        if self.accelerator.is_main_process:
            states = []
            actions = []
            for i in range(len(self.dataset)):
                # Trim padding: keep only the first T valid frames.
                T = self.dataset.get_seq_length(i)
                state = self.dataset.states[i, :T]
                state = state[:, state_subset_idx]
                states.append(state)
                actions.append(self.dataset.actions[i, :T])
            # Linear probe: predict states from embeddings.
            embd_state_linear_probe_results = (
                utils.inference.linear_probe_with_trajectory_split(
                    embeddings,
                    states,
                    train_idx,
                    val_idx,
                )
            )
            embd_state_linear_probe_results = {
                f"embd_state_{k}": v for k, v in embd_state_linear_probe_results.items()
            }
            # Linear probe: predict actions from embeddings.
            embd_action_linear_probe_results = (
                utils.inference.linear_probe_with_trajectory_split(
                    embeddings,
                    actions,
                    train_idx,
                    val_idx,
                )
            )
            embd_action_linear_probe_results = {
                f"embd_action_{k}": v
                for k, v in embd_action_linear_probe_results.items()
            }

            # Nearest-neighbor probe: sample N query frames; for each, find
            # the closest embedding among all *other* trajectories and
            # measure the distance between the corresponding true states.
            state_dists = []
            N = 200
            rng = np.random.default_rng(self.cfg.seed)
            for _ in range(N):
                query_traj_idx = rng.choice(len(self.dataset))
                # Frames before index 10 are excluded — presumably early
                # frames are uninformative/near-identical; TODO confirm.
                query_frame_idx = rng.choice(
                    range(10, self.dataset.get_seq_length(query_traj_idx))
                )
                query_embedding = embeddings[query_traj_idx][query_frame_idx]
                query_frame_state = self.dataset.states[
                    query_traj_idx, query_frame_idx, state_subset_idx
                ]
                # Candidate pool excludes the query's own trajectory.
                pool_embeddings = torch.cat(
                    [x for i, x in enumerate(embeddings) if i != query_traj_idx]
                )
                pool_states = torch.cat(
                    [x for i, x in enumerate(states) if i != query_traj_idx]
                )
                _, nn_idx = utils.inference.batch_knn(
                    query_embedding.unsqueeze(0),
                    pool_embeddings,
                    metric=utils.inference.mse,
                    k=1,
                    batch_size=1,
                )
                closest_frame_state = pool_states[nn_idx[0, 0]]
                state_dist = calc_state_dist(query_frame_state, closest_frame_state)
                state_dists.append(state_dist)
            mean_state_dist = mean_dicts(state_dists)
            return {
                **embd_state_linear_probe_results,
                **embd_action_linear_probe_results,
                **mean_state_dist,
            }
        else:
            return None
|
|