import numpy as np
import torch
from sklearn.cluster import KMeans

################## Bayesian Functions Start ########################################################################


def _apply_temperature(p, temperature):
    """Return weights proportional to ``p ** (1 / temperature)``, computed in log space.

    The largest weight is rescaled to exactly 1, so very small temperatures
    sharpen the distribution towards its arg-max instead of underflowing every
    entry to 0 (which would later produce NaN probabilities).

    Raises:
        ValueError: if ``p`` has no strictly positive, finite entry to anchor on.
    """
    with np.errstate(divide="ignore", invalid="ignore"):
        log_p = np.log(np.asarray(p, dtype=np.float64)) / temperature
    peak = log_p.max()
    if not np.isfinite(peak):
        raise ValueError("transition row has no positive probability mass")
    return np.exp(log_p - peak)


# helper function for determining state based on transit matrix
def get_traj_of_state(last_s, transit_p, centroids, centroid_std, sample_steps, top_k=-1, temperature=1,
                      future_action_enc_out=None, embed_dim=768, **kwargs):
    """Sample one trajectory of latent states from a transition matrix.

    Starting from the centroid nearest to ``last_s``, repeatedly draws the next
    state from the corresponding row of ``transit_p`` and emits a noisy copy of
    that state's centroid.

    Args:
        last_s: torch tensor holding the last observed embedding; the source
            comments describe it as (1, embed_dim*bn*nvar).
        transit_p: array (num_centroids, num_centroids) of transition weights.
        centroids: array (num_centroids, feature_dim) of state centroids.
        centroid_std: mapping from state index to the per-feature standard
            deviation used as Gaussian noise scale; states missing from the
            mapping are emitted without noise.
        sample_steps: number of steps to generate.
        top_k: if > 0, sample only among the ``top_k`` most likely next states;
            otherwise all states are candidates.
        temperature: sampling temperature, clamped to [1e-6, 2].
        future_action_enc_out: optional torch tensor with one row per step. When
            given, transitions towards centroids whose trailing columns are close
            to that step's row are up-weighted.
        embed_dim: unused; accepted for call-site compatibility.
        **kwargs: ignored.

    Returns:
        Tuple ``(result_embeds, traj_log)``: a float32 tensor of shape
        (1, sample_steps, last_s.shape[-1]) on ``last_s``'s device, and the
        summed log-probability (float) of the sampled transitions.
    """
    temperature = min(max(1e-6, temperature), 2)
    centroids = np.asarray(centroids)
    transit_p = np.asarray(transit_p)

    # start from the centroid closest to the last observed embedding
    # (squared distance has the same arg-min as the Euclidean distance)
    last_state = last_s.detach().cpu().numpy()
    prev_ci = int(np.argmin(np.sum((centroids - last_state) ** 2, axis=1)))

    # loop-invariant pieces of the action conditioning
    future_actions = None
    centroids_action_emb = None
    if future_action_enc_out is not None:
        future_actions = future_action_enc_out.detach().cpu().numpy()
        # NOTE(review): this assumes the action features occupy the trailing
        # columns of each centroid - confirm against the flatten order used by
        # the caller that builds the embeddings.
        centroids_action_emb = centroids[:, -future_actions.shape[-1]:]  # num_centroids, action_dim

    sampled = np.zeros((sample_steps, last_s.shape[-1]), dtype=np.float64)
    traj_log = 0.0

    # generate across target steps
    for ss in range(sample_steps):
        p = transit_p[prev_ci]

        # up-weight transitions whose target state is closer to this step's action
        if centroids_action_emb is not None:
            action_distance = np.linalg.norm(centroids_action_emb - future_actions[ss][None, :], axis=-1)
            spread = action_distance.max() - action_distance.min()
            action_distance = (action_distance - action_distance.min()) / (spread + 1e-8)  # min-max norm
            weighted = p * (1 - action_distance)
            # the farthest state gets weight 0; never let that wipe out every candidate
            if weighted.sum() > 0:
                p = weighted

        # apply temperature (log space, so tiny temperatures cannot underflow to all zeros)
        p = _apply_temperature(p, temperature)

        # restrict to the top-k candidates if requested, otherwise keep all states
        if top_k > 0:
            topk_idx = np.argsort(p)[-top_k:]
        else:
            topk_idx = np.arange(len(p))
        topk_p = p[topk_idx]
        topk_p = topk_p / topk_p.sum()  # make sure p sums to 1

        # sampling step
        new_cidx = np.random.choice(len(topk_idx), p=topk_p)
        new_ci = int(topk_idx[new_cidx])
        traj_log += float(np.log(topk_p[new_cidx] + 1e-12))

        # a state without recorded spread (e.g. a cluster that never received
        # observations) is emitted without noise
        curr_scale = centroid_std.get(new_ci)
        if curr_scale is None:
            curr_scale = 0
        sampled[ss] = np.random.normal(loc=centroids[new_ci], scale=curr_scale)
        prev_ci = new_ci

    result_embeds = torch.from_numpy(sampled).float().unsqueeze(0)  # 1, sample_steps, dim
    return result_embeds.to(last_s.device), traj_log


def quantile_traj_of_state(last_s, transit_p, centroids, centroid_std, sample_steps, top_k=-1, temperature=1,
                           num_traj=20, future_action_enc_out=None):
    """Sample several trajectories and fuse them into one, weighted by likelihood.

    Calls ``get_traj_of_state`` ``num_traj`` times (clamped to [1, 100]) and
    returns the average of the sampled trajectories, weighted by the softmax of
    their total log-probabilities.

    Args:
        last_s, transit_p, centroids, centroid_std, sample_steps, top_k,
        temperature, future_action_enc_out: forwarded unchanged to
            ``get_traj_of_state``.
        num_traj: number of trajectories to sample; values below 1 are treated
            as 1 so there is always something to fuse.

    Returns:
        Tensor of shape (1, sample_steps, dim).
    """
    num_traj = int(min(100, max(1, num_traj)))

    samples = [
        get_traj_of_state(
            last_s, transit_p, centroids, centroid_std, sample_steps,
            top_k=top_k, temperature=temperature, future_action_enc_out=future_action_enc_out,
        )
        for _ in range(num_traj)
    ]

    total_traj = torch.stack([embeds for embeds, _ in samples])  # num_traj, 1, sample_steps, dim
    total_p = torch.tensor([log_p for _, log_p in samples], dtype=torch.float32, device=total_traj.device)
    # fuse each sampled trajectory, weighted by its total log-likelihood
    total_p = torch.softmax(total_p, 0)[:, None, None, None]
    return (total_traj * total_p).sum(dim=0)  # 1, sample_steps, dim
def _row_normalize(matrix, eps=1e-8):
    """Return ``matrix`` with ``eps`` added to every entry and each row rescaled to sum to 1."""
    smoothed = matrix + eps
    return smoothed / smoothed.sum(axis=-1, keepdims=True)


# Helper function to fit new bayesian
def fit_observed_bayesian(observed_emebds, num_states=16, original_knowledge=None, post_w=1.0):
    """Cluster observed embeddings into states and estimate their transition matrix.

    Args:
        observed_emebds: array (N, feature_dim) of embeddings in temporal order.
        num_states: number of KMeans clusters (states).
        original_knowledge: optional ``(original_transit, original_centroids)``
            prior; the source comments give ((3600, 3600), (3600, 768)) as an
            example of their shapes. Each fitted centroid is matched to its
            nearest prior centroid to build a prior transition matrix.
        post_w: weight of the observed (posterior) transition matrix when mixing
            with the prior. Forced to 1.0 when no prior is supplied.

    Returns:
        Tuple ``(regularized_transit_p, regularized_centroids, centroid_std)``:
        a row-normalized (num_states, num_states) matrix, the
        (num_states, feature_dim) cluster centres, and a dict mapping each state
        index that received observations to the per-feature standard deviation
        of its members around the centre.
    """
    observed_emebds = np.asarray(observed_emebds)

    # NOTE: clustering uses every column passed in. If action channels should not
    # influence the states, restrict the input to physio channels at the call site.
    reg_km = KMeans(
        n_clusters=num_states,
        random_state=42,
        n_init=1,
        algorithm="elkan",
    ).fit(observed_emebds)
    regularized_centroids = reg_km.cluster_centers_  # num_states, feature_dim
    labels = reg_km.labels_  # N

    # identify prior transit
    if original_knowledge is not None:
        original_transit, original_centroids = original_knowledge
        # match one fitted centroid at a time so that no
        # (num_states, num_prior, feature_dim) array is ever materialised
        closest_prior_centroids = np.array([
            np.argmin(np.sum((original_centroids - centroid) ** 2, axis=-1))
            for centroid in regularized_centroids
        ])  # num_states
        prior_transit = original_transit[np.ix_(closest_prior_centroids, closest_prior_centroids)]
        prior_transit_p = _row_normalize(prior_transit)  # num_states, num_states
    else:
        prior_transit_p, post_w = 0, 1.0

    # count observed transitions between consecutive observations
    posterior_transit = np.zeros((num_states, num_states))
    np.add.at(posterior_transit, (labels[:-1], labels[1:]), 1)
    posterior_transit_p = _row_normalize(posterior_transit)

    # per-state, per-feature standard deviation around the cluster centre
    sq_dev = (observed_emebds - regularized_centroids[labels]) ** 2
    centroid_std = {
        int(state): np.sqrt(sq_dev[labels == state].mean(axis=0))
        for state in np.unique(labels)
    }

    # aggregate posterior and prior, then renormalize
    regularized_transit_p = _row_normalize((post_w * posterior_transit_p) + ((1 - post_w) * prior_transit_p))

    return regularized_transit_p, regularized_centroids, centroid_std


def bayesian_forecast(in_tensor, n_channels, physio_channels, context_length=2048-16, pred_length=2048+16,
                      num_states=16, action_channels=None, condition_bayes=False, num_traj_sampled=1,
                      latent_encoder=None, return_transit_matrix=False):
    """Forecast future values by sampling latent-state trajectories and decoding them.

    Encodes the context with ``latent_encoder``, fits a state transition model on
    the encoded sequence (``fit_observed_bayesian``), samples future embeddings
    (``quantile_traj_of_state``), appends them to the context embeddings and
    decodes the result.

    Args:
        in_tensor: input tensor; the source comment describes it as (1, nvar, length).
        n_channels: number of channels used to reshape the decoder output.
        physio_channels: channel indices selected from the decoded output.
        context_length: number of time steps used as context when action
            channels are present; also the start of the returned window.
        pred_length: number of time steps to return.
        num_states: upper bound on the number of latent states.
        action_channels: indices of action channels in ``in_tensor``;
            ``None`` (default) or empty means no action channels.
        condition_bayes: if true and action channels are given, condition the
            sampling on the encoded future actions.
        num_traj_sampled: number of trajectories fused by the sampler.
        latent_encoder: object providing ``forward_encoder``, ``forward_decoder``
            and ``patch_size``. Required.
        return_transit_matrix: if true, also return the fitted model.

    Returns:
        ``bayesian_out`` (decoded values for ``physio_channels`` over
        ``[context_length, context_length + pred_length)``), or, when
        ``return_transit_matrix`` is true, the tuple ``(bayesian_out,
        regularized_transit_p, regularized_centroids, centroid_std, enc_out)``.

    Raises:
        ValueError: if ``latent_encoder`` is not provided.
    """
    if latent_encoder is None:
        raise ValueError("bayesian_forecast requires a latent_encoder")
    if action_channels is None:
        action_channels = []
    has_actions = len(action_channels) > 0

    patch_size = latent_encoder.patch_size
    sample_steps = (pred_length // patch_size) + 1

    # in_tensor: 1, nvar, length
    # with action channels only the context window is encoded, otherwise the full input
    end_idx = context_length if has_actions else in_tensor.shape[-1]
    with torch.no_grad():
        enc_out, ids_restore, masked_patches = latent_encoder.forward_encoder(
            in_tensor[:, :, :end_idx].clone(), masking=False
        )

    embed_dim = enc_out.shape[-1]
    enc_out = enc_out.permute(1, 2, 0).flatten(start_dim=1)  # L//patch_size + 1, embed_dim*bn*nvar

    # never ask for more states than there are observations to support, and never for zero
    curr_num_state = max(1, min(num_states, len(enc_out) - 1))

    # fit bayesian
    regularized_transit_p, regularized_centroids, centroid_std = fit_observed_bayesian(
        enc_out.cpu().numpy(),
        num_states=curr_num_state,
        post_w=1.0,
    )

    # the action encoding is only consumed when conditioning is enabled, so skip
    # the extra encoder pass otherwise
    future_action_enc_out = None
    if has_actions and condition_bayes:
        with torch.no_grad():
            action_enc_out, _, _ = latent_encoder.forward_encoder(in_tensor[:, action_channels, :], masking=False)
        context_npatchs = context_length // patch_size
        future_action_enc_out = action_enc_out.permute(1, 2, 0).flatten(start_dim=1)[
            context_npatchs:context_npatchs + sample_steps, :
        ]  # sample_steps, embed_dim*bn*action_nvar

    appended_embeds = quantile_traj_of_state(
        enc_out[-1:, :],
        regularized_transit_p,
        regularized_centroids,
        centroid_std,
        sample_steps,
        top_k=curr_num_state,
        temperature=1.0,
        num_traj=num_traj_sampled,  # maybe increase this later
        future_action_enc_out=future_action_enc_out,
    )  # 1, sample_steps, embed_dim*bn*nvar

    # decoding
    enc_with_append = torch.cat((enc_out, appended_embeds[0]), dim=0)  # context + sampled steps, embed_dim*bn*nvar
    enc_with_append = enc_with_append.reshape(enc_with_append.shape[0], embed_dim, -1).permute(2, 0, 1)  # bn*nvar, steps, embed_dim
    # NOTE(review): unlike the encoder calls, the decoder runs outside
    # torch.no_grad() - confirm whether gradients through it are intended.
    dec_out = latent_encoder.forward_decoder(enc_with_append, ids_restore, masked_patches)  # bn*nvar, L
    dec_out = dec_out.reshape(dec_out.shape[0], -1)
    total_L = dec_out.shape[-1]
    bayesian_out = dec_out.reshape(1, n_channels, total_L)[
        0, physio_channels, context_length:context_length + pred_length
    ]

    if return_transit_matrix:
        return bayesian_out, regularized_transit_p, regularized_centroids, centroid_std, enc_out
    return bayesian_out

################## Bayesian Functions End ########################################################################