| import torch |
| import numpy as np |
| from sklearn.cluster import KMeans |
|
|
| |
|
|
| |
def get_traj_of_state(last_s, transit_p, centroids, centroid_std, sample_steps,
                      top_k=-1, temperature=1, future_action_enc_out=None,
                      embed_dim=768,
                      **kwargs):
    """Sample one latent trajectory by walking a discrete state-transition model.

    The walk starts at the centroid nearest (Euclidean) to ``last_s``. At each
    of ``sample_steps`` steps it does three things:

    - draws the next state from the current state's transition row, after the
      optional action re-weighting, temperature and top-k filtering;
    - writes a Gaussian draw centred on that state's centroid into the output;
    - adds the log-probability of the drawn state to a running total.

    Randomness comes from NumPy's global RNG (``np.random``), so seed it for
    reproducible samples.

    Args:
        last_s: torch tensor with the starting embedding. It is moved to NumPy
            and broadcast against ``centroids`` (presumably shape ``(1, D)``;
            TODO confirm).
        transit_p: transition-weight matrix indexed by state. Row ``i`` is the
            weight vector for leaving state ``i``. Rows are renormalized after
            filtering, so they need not sum to 1.
        centroids: ``(K, D)`` array of state centroids.
        centroid_std: mapping ``state -> scale`` passed to
            ``np.random.normal``. States missing from the mapping emit their
            centroid exactly (scale 0).
        sample_steps: number of states to sample.
        top_k: if ``> 0``, only the ``top_k`` highest-weight states are
            candidates at each step.
        temperature: clamped to ``[1e-6, 2]``. Weights are raised to
            ``1 / temperature`` before normalizing.
        future_action_enc_out: optional per-step action embeddings, indexed by
            step. When given, each state's weight is multiplied by
            ``1 - d``. Here ``d`` is the min-max normalized distance between
            the step's action embedding and the trailing
            ``future_action_enc_out.shape[-1]`` columns of that state's
            centroid.
        embed_dim: unused; kept for interface compatibility.
        **kwargs: ignored.

    Returns:
        A pair ``(samples, log_likelihood)``:

        - ``samples`` is a float32 tensor of shape ``(1, sample_steps, D)`` on
          ``last_s.device``;
        - ``log_likelihood`` is the sum of ``log(p + 1e-12)`` over the drawn
          states, using the renormalized per-step candidate probabilities.
    """
    temperature = min(max(1e-6, temperature), 2)
    inv_temp = 1.0 / temperature

    start_vec = last_s.cpu().numpy()
    state = np.argmin(np.sqrt(np.sum((centroids - start_vec) ** 2, axis=1)))

    samples = torch.zeros(1, sample_steps, last_s.shape[-1])
    log_likelihood = 0

    for step in range(sample_steps):
        weights = transit_p[state]

        if future_action_enc_out is not None:
            act_vec = future_action_enc_out[step].cpu().numpy()
            act_width = future_action_enc_out.shape[-1]
            # NOTE(review): assumes the action embedding lives in the trailing
            # ``act_width`` columns of each centroid; confirm this matches how
            # the caller flattens per-channel embeddings.
            dist = np.linalg.norm(centroids[:, -act_width:] - act_vec[None, :], axis=-1)
            d_lo, d_hi = dist.min(), dist.max()
            dist = (dist - d_lo) / (d_hi - d_lo + 1e-8)
            # The closest action match keeps its full weight; the farthest is zeroed.
            weights = weights * (1 - dist)

        weights = weights ** inv_temp

        candidates = np.argsort(weights)[-top_k:] if top_k > 0 else np.arange(len(weights))
        cand_p = weights[candidates]
        cand_p = cand_p / cand_p.sum()

        pick = np.random.choice(np.arange(len(candidates)), p=cand_p)
        state = candidates[pick]
        log_likelihood += np.log(cand_p[pick] + 1e-12)

        spread = centroid_std.get(state)
        scale = 0 if spread is None else spread
        draw = np.random.normal(loc=centroids[state], scale=scale)
        samples[:, step, :] = torch.from_numpy(draw)

    return samples.float().to(last_s.device), log_likelihood
|
|
def quantile_traj_of_state(last_s, transit_p, centroids, centroid_std, sample_steps,
                           top_k=-1, temperature=1, num_traj=20, future_action_enc_out=None):
    """Blend several sampled trajectories into one likelihood-weighted forecast.

    Draws ``num_traj`` independent trajectories with ``get_traj_of_state`` and
    returns their weighted mean. The weights are ``softmax(log_likelihood)``,
    i.e. each trajectory's weight is proportional to the product of the step
    probabilities it was sampled with.

    Args:
        last_s, transit_p, centroids, centroid_std, sample_steps, top_k,
        temperature, future_action_enc_out: forwarded unchanged to
            ``get_traj_of_state``.
        num_traj: number of trajectories to draw. It is clamped to
            ``[1, 100]`` and then truncated with ``int()``.

    Returns:
        Tensor with the same shape as a single ``get_traj_of_state`` sample,
        on that sample's device.
    """
    # At least one trajectory is required: with none, there is nothing to
    # stack or weight below. The upper cap of 100 bounds the sampling cost.
    num_traj = int(min(100, max(1, num_traj)))

    trajs = []
    log_likelihoods = []
    for _ in range(num_traj):
        traj, log_p = get_traj_of_state(last_s, transit_p, centroids, centroid_std,
                                        sample_steps, top_k=top_k, temperature=temperature,
                                        future_action_enc_out=future_action_enc_out)
        trajs.append(traj)
        log_likelihoods.append(log_p)

    # (num_traj, *traj.shape). Trajectory order does not affect the weighted sum.
    stacked = torch.stack(trajs)
    # Softmax over summed log-probabilities == normalized product of probabilities.
    weights = torch.softmax(torch.tensor(log_likelihoods).float().to(stacked.device), 0)
    # Broadcast one scalar weight per trajectory over all remaining dimensions.
    weights = weights.reshape((-1,) + (1,) * (stacked.dim() - 1))
    return (stacked * weights).sum(dim=0)
| |
|
|
| |
def fit_observed_bayesian(observed_emebds, num_states=16,
                          original_knowledge=None, post_w=1.0,
                          ):
    """Fit a discrete Markov model over a sequence of embeddings.

    The rows of ``observed_emebds`` are clustered with k-means into
    ``num_states`` states. Transitions between the labels of consecutive rows
    are counted and row-normalized into a posterior transition matrix. That
    matrix is optionally blended with a prior matrix.

    Args:
        observed_emebds: 2-D array ``(n_samples, D)``. Row order is treated as
            time order when counting transitions.
        num_states: number of k-means clusters (states).
        original_knowledge: optional ``(prior_transit, prior_centroids)`` pair.
            Each fitted centroid is mapped to its nearest prior centroid by
            squared Euclidean distance. The prior transition matrix is then
            re-indexed through that mapping.
        post_w: weight of the observed (posterior) transitions in the blend;
            the prior gets ``1 - post_w``. It is forced to ``1.0`` when no
            prior is supplied.

    Returns:
        A tuple ``(transit_p, centroids, centroid_std)``:

        - ``transit_p`` is a ``(num_states, num_states)`` row-normalized
          matrix. Every entry is smoothed by ``1e-8``, so rows with no
          observed transitions become uniform.
        - ``centroids`` holds the k-means cluster centres.
        - ``centroid_std`` is a dict mapping each observed label to the
          per-dimension RMS deviation of that label's members from its
          centroid.
    """
    km = KMeans(
        n_clusters=num_states,
        random_state=42,
        n_init=1,
        algorithm="elkan",
    ).fit(observed_emebds)
    centers = km.cluster_centers_
    labels = km.labels_

    # Optional prior, re-indexed onto the freshly fitted centroids.
    if original_knowledge is None:
        prior_transit_p, post_w = 0, 1.0
    else:
        original_transit, original_centroids = original_knowledge
        pair_sq_dist = np.sum((centers[:, None, :] - original_centroids[None, :, :]) ** 2, axis=-1)
        nearest_prior = np.argmin(pair_sq_dist, axis=-1)
        prior_rows = original_transit[nearest_prior, :][:, nearest_prior] + 1e-8
        prior_transit_p = prior_rows / prior_rows.sum(axis=1, keepdims=True)

    # Empirical transition counts between consecutive labels (add.at handles
    # repeated (from, to) pairs correctly).
    counts = np.zeros((num_states, num_states))
    np.add.at(counts, (labels[:-1], labels[1:]), 1)
    smoothed = counts + 1e-8
    posterior_transit_p = smoothed / smoothed.sum(axis=-1, keepdims=True)

    # Per-state spread: accumulate squared residuals and member counts per
    # label. The last sample is visited first, then the rest in time order.
    # This fixes both the float summation order and the dict's key order.
    n_obs = len(labels)
    residual_sums = {}
    for idx in (n_obs - 1, *range(n_obs - 1)):
        label = labels[idx]
        sq_resid = (observed_emebds[idx] - centers[label]) ** 2
        if label in residual_sums:
            residual_sums[label][0] += sq_resid
            residual_sums[label][1] += 1
        else:
            residual_sums[label] = [sq_resid, 1]
    centroid_std = {
        label: np.sqrt(total / count)
        for label, (total, count) in residual_sums.items()
    }

    # Blend posterior with prior, then renormalize each row.
    regularized_transit_p = (post_w * posterior_transit_p) + ((1 - post_w) * prior_transit_p)
    regularized_transit_p = (regularized_transit_p + 1e-8) / ((regularized_transit_p + 1e-8).sum(axis=-1, keepdims=True))

    return regularized_transit_p, centers, centroid_std
|
|
|
|
def bayesian_forecast(in_tensor, n_channels, physio_channels,
                      context_length=2048-16, pred_length=2048+16,
                      num_states=16, action_channels=None,
                      condition_bayes=False, num_traj_sampled=1,
                      latent_encoder=None, return_transit_matrix=False):
    """Forecast a signal by rolling a latent Markov model forward and decoding.

    The pipeline has four steps:

    1. Encode the input with ``latent_encoder.forward_encoder`` (no masking).
    2. Fit a k-means/Markov model over the per-patch latent rows
       (``fit_observed_bayesian``).
    3. Sample ``pred_length // patch_size + 1`` future latent rows
       (``quantile_traj_of_state``). If ``condition_bayes`` is set, the
       sampling is conditioned on encoded future action channels.
    4. Append those rows and decode the result with
       ``latent_encoder.forward_decoder``.

    Args:
        in_tensor: 3-D tensor indexed as ``[:, channel, time]`` (presumably
            ``(1, n_channels, T)``, since the decoder output is reshaped to
            ``(1, n_channels, ...)``; TODO confirm).
        n_channels: number of channels in the decoded output.
        physio_channels: channel index/indices returned from the decoded output.
        context_length: number of leading time steps treated as context.
        pred_length: number of time steps returned after ``context_length``.
        num_states: upper bound on the number of latent states. It is further
            limited by the number of encoded patches.
        action_channels: optional channel indices. When non-empty, only the
            first ``context_length`` steps of all channels are encoded for
            fitting. The action channels are encoded over the full input, and
            their post-context patches can steer sampling.
        condition_bayes: if true (and action channels are given), condition
            trajectory sampling on the future action embeddings.
        num_traj_sampled: number of trajectories blended by
            ``quantile_traj_of_state``.
        latent_encoder: object providing ``forward_encoder(x, masking=False)``,
            ``forward_decoder(enc, ids_restore, masked_patches)`` and
            ``patch_size``.
        return_transit_matrix: also return the fitted model and the flattened
            encoder output.

    Returns:
        The decoded ``physio_channels`` slice
        ``[context_length : context_length + pred_length]``. If
        ``return_transit_matrix`` is set, the return value is instead
        ``(forecast, transit_p, centroids, centroid_std, flattened_enc_out)``.
    """
    # ``None`` instead of a shared mutable ``[]`` default.
    action_channels = [] if action_channels is None else action_channels
    has_actions = len(action_channels) > 0

    # With action channels, only the context window is encoded for fitting.
    # The remainder is read solely through the action channels below.
    end_idx = context_length if has_actions else in_tensor.shape[-1]
    with torch.no_grad():
        # Slice before cloning so only the encoded window is copied.
        enc_out, ids_restore, masked_patches = latent_encoder.forward_encoder(
            in_tensor[:, :, :end_idx].clone(), masking=False)

    embed_dim = enc_out.shape[-1]
    # 3-D (d0, d1, embed_dim) -> (d1, embed_dim * d0): one flattened latent row
    # per position along dim 1 (presumably the patch axis; TODO confirm).
    enc_out = enc_out.permute(1, 2, 0).flatten(start_dim=1)

    # Fewer states than observed rows, but always at least one: KMeans rejects
    # n_clusters < 1, which a single-patch input would otherwise produce.
    curr_num_state = max(1, min(num_states, len(enc_out) - 1))

    regularized_transit_p, regularized_centroids, centroid_std = fit_observed_bayesian(
        enc_out.cpu().numpy(),
        num_states=curr_num_state,
        post_w=1.0,
    )[:3]

    patch_size = latent_encoder.patch_size
    n_future_patches = (pred_length // patch_size) + 1

    future_action_enc_out = None
    if has_actions:
        with torch.no_grad():
            future_action_enc_out, _, _ = latent_encoder.forward_encoder(
                in_tensor[:, action_channels, :].clone(), masking=False)
        context_npatchs = context_length // patch_size
        # Keep only the action embeddings for the patches being forecast.
        future_action_enc_out = future_action_enc_out.permute(1, 2, 0).flatten(start_dim=1)[
            context_npatchs:context_npatchs + n_future_patches, :]

    appended_embeds = quantile_traj_of_state(
        enc_out[-1:, :],
        regularized_transit_p,
        regularized_centroids,
        centroid_std,
        n_future_patches,
        top_k=curr_num_state,
        temperature=1.0,
        num_traj=num_traj_sampled,
        future_action_enc_out=future_action_enc_out if condition_bayes else None,
    )

    # Append the sampled rows and undo the flattening: (rows, E*d0) -> (d0, rows, E).
    enc_with_append = torch.concatenate((enc_out, appended_embeds[0]), dim=0)
    enc_with_append = enc_with_append.reshape(enc_with_append.shape[0], embed_dim, -1).permute(2, 0, 1)
    # NOTE(review): decoding runs outside torch.no_grad(), so autograd may
    # record the decoder graph; confirm whether gradients are needed here.
    dec_out = latent_encoder.forward_decoder(enc_with_append, ids_restore, masked_patches)

    dec_out = dec_out.reshape(dec_out.shape[0], -1)
    total_L = dec_out.shape[-1]
    # Fails loudly unless the decoder's leading dim equals n_channels.
    bayesian_out = dec_out.reshape(1, n_channels, total_L)[0, physio_channels, context_length:context_length + pred_length]

    if return_transit_matrix:
        return bayesian_out, regularized_transit_p, regularized_centroids, centroid_std, enc_out
    return bayesian_out
|
|
|
|
| |
|
|