| from glob import glob |
| from tqdm import tqdm, trange |
| import numpy as np |
| import mmcv |
| import pickle |
| from sklearn.cluster import KMeans |
|
|
def load_and_process_traj(
    ann_file,
    out_file='/cpfs01/user/liuhaochen/AlgEngine_nuplan/traj_vocab/full_traj.pkl',
    horizon=8,
):
    """Extract fully valid ego future trajectories and pickle them.

    Loads the annotation pickle at ``ann_file`` and keeps every record whose
    future mask is set for all of the first ``horizon`` steps. For each kept
    record it stores the first ``horizon`` future steps (first two
    coordinates) together with the record's ``'map_location'``. The result is
    written to ``out_file`` as the two-element list ``[ego_trajs, map_locs]``.

    Args:
        ann_file: Path of the annotation pickle, read with ``mmcv.load``.
        out_file: Destination of the output pickle. Defaults to the path the
            script has always written to.
        horizon: Number of leading future steps that must be valid and that
            are kept per trajectory. Defaults to 8.

    NOTE(review): assumes each record exposes ``gt_fut_bbox_sdc_mask`` and
    ``gt_fut_bbox_sdc_lidar`` as arrays indexable as ``[0, :horizon]`` and
    ``[0, :horizon, :2]`` — confirm against the annotation generator.
    """
    data_infos = mmcv.load(ann_file, file_format='pkl')
    ego_trajs = []
    map_locs = []

    for data in tqdm(data_infos):
        valid = np.array(
            data["gt_fut_bbox_sdc_mask"][0, :horizon], dtype=np.float32
        )
        # Keep the sample only if every step in the horizon is flagged valid.
        if np.sum(valid) != horizon:
            continue
        ego_trajs.append(data["gt_fut_bbox_sdc_lidar"][0, :horizon, :2])
        # Appended together with the trajectory so both lists stay
        # index-aligned for downstream consumers.
        map_locs.append(data['map_location'])

    with open(out_file, 'wb') as writer:
        pickle.dump([ego_trajs, map_locs], writer)
|
|
def process_kmeans_vocab(
    traj_file,
    anchors=4096,
    loc='',
    out_dir='/cpfs01/user/liuhaochen/AlgEngine_nuplan/traj_vocab',
):
    """Cluster trajectory end points and store a representative vocabulary.

    Reads ``[trajectories, locations]`` from ``traj_file``, optionally keeps
    only the entries whose location equals ``loc``, and runs KMeans with
    ``anchors`` clusters on the final ``(x, y)`` point of each trajectory.
    Two pickles are written into ``out_dir``:

    * ``vocab_center_{anchors}.pkl`` — the cluster centres.
    * ``vocab_traj_{anchors}.pkl`` — ``[rep_traj, rep_loc]``: for every
      centre, the trajectory whose end point is closest to it, and that
      trajectory's location.

    Args:
        traj_file: Pickle holding ``[trajectories, locations]``.
        anchors: Number of KMeans clusters.
        loc: Location to restrict to; ``''`` uses every trajectory.
        out_dir: Directory receiving the two output pickles. Defaults to the
            directory the script has always written to.

    Raises:
        ValueError: If no trajectory matches ``loc``.
    """
    with open(traj_file, 'rb') as reader:
        # NOTE: pickle.load can execute arbitrary code; only feed this
        # function files produced by a trusted pipeline.
        traj_data, locs = pickle.load(reader)
    print(traj_data[0].shape)

    if loc != '':
        # Filter trajectories AND locations together. The nearest-neighbour
        # index computed below is an index into ``end_p``; it is only valid
        # for ``traj_data`` / ``locs`` if all three share the same ordering.
        keep = [idx for idx, name in enumerate(locs) if name == loc]
        if not keep:
            raise ValueError(f'no trajectories found for location {loc!r}')
        traj_data = [traj_data[idx] for idx in keep]
        locs = [locs[idx] for idx in keep]

    end_p = np.array([traj[-1, :2] for traj in traj_data])
    print(end_p.shape)

    kmeans = KMeans(n_clusters=anchors, verbose=True)
    kmeans.fit(end_p)
    print('fit end')
    centroids = kmeans.cluster_centers_

    with open(f'{out_dir}/vocab_center_{anchors}.pkl', 'wb') as writer:
        pickle.dump(centroids, writer)

    print('processing the representative trajs...')
    rep_traj = []
    rep_loc = []
    for i in trange(centroids.shape[0]):
        # A centre is a mean, not an actual sample: pick the real trajectory
        # whose end point lies closest to it.
        offsets = end_p - centroids[i][np.newaxis, :2]
        nearest = np.argmin(np.linalg.norm(offsets, axis=1))
        rep_traj.append(traj_data[nearest])
        rep_loc.append(locs[nearest])

    with open(f'{out_dir}/vocab_traj_{anchors}.pkl', 'wb') as writer:
        pickle.dump([rep_traj, rep_loc], writer)
|
|
| import matplotlib.pyplot as plt |
def visualization(traj_file, kmeans_files, suffix='4096', out_file=None):
    """Plot cluster centres over their trajectories and save the figure.

    All centres are drawn as orange stars; the centres selected by the
    ``'singapore'`` location mask are redrawn in red on top. Every trajectory
    from ``traj_file`` is drawn as a navy line underneath.

    Args:
        traj_file: Pickle holding ``[trajectories, locations]``.
        kmeans_files: Pickle holding an array of centres; columns 0 and 1 are
            plotted as x and y.
        suffix: Tag used in the default output file name.
        out_file: Image path to write. When ``None`` the image goes to the
            script's usual ``vocab_{suffix}.png`` location.

    NOTE(review): the location mask is built from ``traj_file`` but applied
    to the rows of ``kmeans_files``, so both files must have the same length
    and ordering (e.g. a matching vocab_traj / vocab_center pair) — confirm
    at the call site.
    """
    with open(traj_file, 'rb') as reader:
        traj_data, loc_data = pickle.load(reader)

    with open(kmeans_files, 'rb') as reader:
        k_means_data = pickle.load(reader)

    if out_file is None:
        out_file = (
            f'/cpfs01/user/liuhaochen/AlgEngine_nuplan/traj_vocab/vocab_{suffix}.png'
        )

    sing_mask = np.array(loc_data) == 'singapore'

    fig = plt.figure()
    try:
        plt.scatter(k_means_data[:, 0], k_means_data[:, 1], c='orange',
                    marker='*', s=5, zorder=3)
        plt.scatter(k_means_data[sing_mask, 0], k_means_data[sing_mask, 1],
                    c='red', marker='*', s=5, zorder=4)

        for traj in tqdm(traj_data):
            plt.plot(traj[:, 0], traj[:, 1], color='navy',
                     alpha=0.6, zorder=1)

        plt.savefig(out_file)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # release it so repeated calls do not accumulate figures in memory.
        plt.close(fig)
|
|
|
|
|
|
if __name__ == '__main__':
    # One (trajectory pickle, centre pickle, file-name suffix) triple per
    # vocabulary size to render; processed in the listed order.
    render_jobs = (
        (
            '/cpfs01/user/liuhaochen/AlgEngine_nuplan/traj_vocab/vocab_traj_8192.pkl',
            '/cpfs01/user/liuhaochen/AlgEngine_nuplan/traj_vocab/vocab_center_8192.pkl',
            '8192',
        ),
        (
            '/cpfs01/user/liuhaochen/AlgEngine_nuplan/traj_vocab/vocab_traj_4096.pkl',
            '/cpfs01/user/liuhaochen/AlgEngine_nuplan/traj_vocab/vocab_center_4096.pkl',
            '4096',
        ),
    )
    for job in render_jobs:
        visualization(*job)
| |
| |
|
|