| import yaml |
| import matplotlib.pyplot as plt |
| import torch |
| import numpy as np |
| import os |
| from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas |
| from PIL import Image |
| from torchvision import transforms |
| import torchvision.transforms.functional as TF |
|
|
|
|
| IMAGE_ASPECT_RATIO = (4 / 3) |
|
|
| with open("config/data_config.yaml", "r") as f: |
| data_config = yaml.safe_load(f) |
|
|
|
|
| def get_action_torch(diffusion_output, action_stats): |
| ndeltas = diffusion_output |
| ndeltas = ndeltas.reshape(ndeltas.shape[0], -1, 2) |
| ndeltas = unnormalize_data(ndeltas, action_stats) |
| actions = torch.cumsum(ndeltas, dim=1) |
| return actions.to(ndeltas) |
|
|
| def log_viz_single(dataset_name, obs_image, goal_image, preds, deltas, loss, min_idx, actions, action_stats, plan_iter=0, output_dir='plot.png'): |
| ''' |
| Visualize a single instance |
| actions is gt actions |
| ''' |
| viz_obs_image = unnormalize(obs_image.detach().cpu())[-1] |
| viz_goal_image = unnormalize(goal_image.detach().cpu()) |
| deltas = deltas.detach().cpu() |
| loss = loss.detach().cpu() |
| actions = actions.detach().cpu() |
| pred_actions = get_action_torch(deltas[:, :, :2], action_stats) |
| plot_array = plot_images_and_actions(dataset_name, viz_obs_image, viz_goal_image, pred_actions, actions, min_idx, loss=loss) |
|
|
| plt.imshow(plot_array) |
| plt.axis('off') |
|
|
| |
| plt.savefig(output_dir, format='png', dpi=300, bbox_inches='tight') |
|
|
| def plot_images_and_actions(dataset_name, curr_viz_obs_image, curr_viz_goal_image, curr_viz_pred_actions, curr_viz_actions, min_idx, loss): |
| curr_viz_obs_image = curr_viz_obs_image.permute(1, 2, 0).cpu().numpy() |
| curr_viz_goal_image = curr_viz_goal_image.permute(1, 2, 0).cpu().numpy() |
|
|
| |
| curr_viz_pred_actions = curr_viz_pred_actions * data_config[dataset_name]['metric_waypoint_spacing'] |
| curr_viz_actions = curr_viz_actions * data_config[dataset_name]['metric_waypoint_spacing'] |
| |
| |
| fig, axs = plt.subplots(1, 3, figsize=(9, 3)) |
|
|
| |
| axs[0].imshow(curr_viz_obs_image) |
| axs[0].set_title("Condition Image", fontsize=13) |
| axs[0].axis("off") |
|
|
| |
| axs[1].imshow(curr_viz_goal_image) |
| axs[1].set_title("Goal Image", fontsize=13) |
| axs[1].axis("off") |
|
|
| colors = ['red', 'orange', 'cyan'] |
| for i in range(1, curr_viz_pred_actions.shape[0]): |
| color = colors[(i - 1) % len(colors)] |
| label = f"Sample {i} Min Loss" if i == min_idx.item() else f"{i}" |
|
|
| if i != min_idx.item(): |
| axs[2].plot(-curr_viz_pred_actions[i, :, 1], curr_viz_pred_actions[i, :, 0], |
| color=color, marker="o", markersize=5, label=label) |
| axs[2].text(-curr_viz_pred_actions[i, -1, 1], |
| curr_viz_pred_actions[i, -1, 0], |
| round(loss[i].item(), 3), |
| color='black', |
| fontsize=10, |
| ha='left', va='bottom') |
|
|
| |
| axs[2].plot(-curr_viz_pred_actions[min_idx.item(), :, 1], curr_viz_pred_actions[min_idx.item(), :, 0], |
| color='green', marker="o", markersize=5, label=f"{min_idx.item()}") |
| axs[2].text(-curr_viz_pred_actions[min_idx.item(), -1, 1], |
| curr_viz_pred_actions[min_idx.item(), -1, 0], |
| round(loss[min_idx.item()].item(), 3), |
| color='black', |
| fontsize=10, |
| ha='left', va='bottom') |
|
|
| |
| axs[2].plot(-curr_viz_actions[:, 1], curr_viz_actions[:, 0], color='blue', marker="o", label="GT") |
|
|
| |
| axs[2].set_title(" ", fontsize=13) |
| axs[2].set_xlabel("X (m)", fontsize=11) |
| axs[2].set_ylabel("Y (m)", fontsize=11) |
|
|
| |
| axs[2].set_aspect('equal', adjustable='box') |
| x_min, x_max = axs[2].get_xlim() |
| y_min, y_max = axs[2].get_ylim() |
| axis_range = max(x_max - x_min, y_max - y_min) / 2 |
| x_mid = (x_max + x_min) / 2 |
| y_mid = (y_max + y_min) / 2 |
| axs[2].set_xlim(x_mid - axis_range, x_mid + axis_range) |
| axs[2].set_ylim(y_mid - axis_range, y_mid + axis_range) |
|
|
| axs[2].legend(loc='lower left', fontsize=10, frameon=True, bbox_to_anchor=(0, 0)) |
| plt.tight_layout() |
|
|
| canvas = FigureCanvas(fig) |
| canvas.draw() |
| plot_array = np.frombuffer(canvas.tostring_rgb(), dtype='uint8') |
| plot_array = plot_array.reshape(canvas.get_width_height()[::-1] + (3,)) |
| plt.close(fig) |
| return plot_array |
|
|
|
|
| def normalize_data(data, stats): |
| |
| ndata = (data - stats['min']) / (stats['max'] - stats['min']) |
| |
| ndata = ndata * 2 - 1 |
| return ndata |
|
|
| def unnormalize_data(ndata, stats): |
| ndata = (ndata + 1) / 2 |
| data = ndata * (stats['max'].to(ndata) - stats['min'].to(ndata)) + stats['min'].to(ndata) |
| return data |
|
|
| def get_data_path(data_folder: str, f: str, time: int, data_type: str = "image"): |
| data_ext = { |
| "image": ".jpg", |
| "audio": ".wav" |
| |
| } |
| return os.path.join(data_folder, f, f"{str(time)}{data_ext[data_type]}") |
|
|
| def yaw_rotmat(yaw: float) -> np.ndarray: |
| return np.array( |
| [ |
| [np.cos(yaw), -np.sin(yaw), 0.0], |
| [np.sin(yaw), np.cos(yaw), 0.0], |
| [0.0, 0.0, 1.0], |
| ], |
| ) |
|
|
| def angle_difference(theta1, theta2): |
| delta_theta = theta2 - theta1 |
| delta_theta = delta_theta - 2 * np.pi * np.floor((delta_theta + np.pi) / (2 * np.pi)) |
| return delta_theta |
|
|
| def get_delta_np(actions): |
| |
| ex_actions = np.concatenate((np.zeros((1, actions.shape[1])), actions), axis=0) |
| delta = ex_actions[1:] - ex_actions[:-1] |
| |
| return delta |
|
|
| def to_local_coords( |
| positions: np.ndarray, curr_pos: np.ndarray, curr_yaw: float |
| ) -> np.ndarray: |
| """ |
| Convert positions to local coordinates |
| |
| Args: |
| positions (np.ndarray): positions to convert |
| curr_pos (np.ndarray): current position |
| curr_yaw (float): current yaw |
| Returns: |
| np.ndarray: positions in local coordinates |
| """ |
| rotmat = yaw_rotmat(curr_yaw) |
| if positions.shape[-1] == 2: |
| rotmat = rotmat[:2, :2] |
| elif positions.shape[-1] == 3: |
| pass |
| else: |
| raise ValueError |
|
|
| return (positions - curr_pos).dot(rotmat) |
|
|
| def calculate_delta_yaw(unnorm_actions): |
| x = unnorm_actions[..., 0] |
| y = unnorm_actions[..., 1] |
| |
| yaw = torch.atan2(y, x).unsqueeze(-1) |
| delta_yaw = torch.cat((torch.zeros(yaw.shape[0], 1, yaw.shape[2]).to(yaw.device), yaw), dim=1) |
| delta_yaw = delta_yaw[:, 1:, :] - delta_yaw[:, :-1, :] |
| |
| return delta_yaw |
|
|
| def save_planning_pred(dataset_save_output_dir, B, idxs, obs_image, goal_image, preds, deltas, loss, gt_actions, plan_iter=0): |
| for batch_idx, idx in enumerate(idxs.flatten()): |
| sample_idx = int(idx) |
| sample_folder = os.path.join(dataset_save_output_dir, f'id_{sample_idx}') |
| os.makedirs(sample_folder, exist_ok=True) |
| |
| preds_save = { |
| 'obs_image': obs_image[batch_idx], |
| 'goal_image': goal_image[batch_idx], |
| 'preds': preds[batch_idx], |
| 'deltas': deltas[batch_idx], |
| 'loss': loss[batch_idx], |
| 'gt_actions': gt_actions[batch_idx], |
| } |
| preds_file = os.path.join(sample_folder, f"preds_{plan_iter}.pth") |
| torch.save(preds_save, preds_file) |
| |
| class CenterCropAR: |
| def __init__(self, ar: float = IMAGE_ASPECT_RATIO): |
| self.ar = ar |
|
|
| def __call__(self, img: Image.Image): |
| w, h = img.size |
| if w > h: |
| img = TF.center_crop(img, (h, int(h * self.ar))) |
| else: |
| img = TF.center_crop(img, (int(w / self.ar), w)) |
| return img |
|
|
| transform = transforms.Compose([ |
| CenterCropAR(), |
| transforms.Resize((224, 224)), |
| transforms.ToTensor(), |
| transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), |
| ]) |
|
|
| unnormalize = transforms.Normalize( |
| mean=[-0.5 / 0.5, -0.5 / 0.5, -0.5 / 0.5], |
| std=[1 / 0.5, 1 / 0.5, 1 / 0.5] |
| ) |
|
|