PyTorch
normwear2
custom_code
normwear2 / latent_bayesian.py
yunfeiluo's picture
Upload latent_bayesian.py
b1e5e3a verified
import torch
import numpy as np
from sklearn.cluster import KMeans
################## Bayesian Functions Start ########################################################################
# helper function for determining state based on transit matrix
def get_traj_of_state(last_s, transit_p, centroids, centroid_std, sample_steps,
                      top_k=-1, temperature=1, future_action_enc_out=None,
                      embed_dim=768,
                      **kwargs):
    """Sample one latent trajectory by walking a discrete-state transition matrix.

    The walk starts at the centroid nearest to ``last_s``. At every step the
    next state is drawn from the current row of ``transit_p``. That row is
    optionally re-weighted by ``future_action_enc_out``, tempered, and
    truncated to the ``top_k`` most likely states. An embedding for the drawn
    state is then sampled from a normal distribution centred on its centroid.

    Args:
        last_s: torch tensor of shape (1, D); the last observed embedding
            (D = embed_dim*bn*nvar in the caller's layout).
        transit_p: (K, K) numpy array; row ``i`` holds the transition weights
            out of state ``i``.
        centroids: (K, D) numpy array of state centroids.
        centroid_std: dict mapping a state index to the per-dimension std
            used as the sampling scale. States missing from the dict are
            emitted as their centroid exactly (scale 0).
        sample_steps: number of steps to sample.
        top_k: if > 0, sample only among the ``top_k`` most likely next states.
        temperature: sampling temperature, clamped to [1e-6, 2]. Very small
            values approach greedy (arg-max) decoding.
        future_action_enc_out: optional tensor of shape (sample_steps, D_action).
            When given, transitions to states whose trailing ``D_action``
            centroid features are closer to that step's action embedding are
            up-weighted.
        embed_dim, **kwargs: accepted for call compatibility; unused here.

    Returns:
        ``(embeds, traj_log)``. ``embeds`` is a float tensor of shape
        (1, sample_steps, D) on ``last_s.device``. ``traj_log`` is the summed
        log-probability of the sampled state sequence.
    """
    temperature = min(max(1e-6, temperature), 2)
    # Start from the centroid closest (Euclidean) to the last observed embedding.
    prev_ci = np.argmin(np.sqrt(np.sum((centroids - last_s.cpu().numpy()) ** 2, axis=1)))
    result_embeds = torch.zeros(1, sample_steps, last_s.shape[-1])
    traj_log = 0
    for ss in range(sample_steps):
        p = transit_p[prev_ci]
        if future_action_enc_out is not None:
            action_emb = future_action_enc_out[ss].cpu().numpy()  # embed_dim*bn*action_nvar
            # NOTE(review): this compares against the *trailing* columns of the
            # flattened centroids. With the (token, embed_dim, channel)
            # flattening done in bayesian_forecast, those columns do not appear
            # to be the action channels' features. Confirm the intended layout.
            centroids_action_emb = centroids[:, -future_action_enc_out.shape[-1]:]
            action_distance = np.linalg.norm(centroids_action_emb - action_emb[None, :], axis=-1)
            # Min-max normalise to [0, 1]; the closest state keeps its full weight.
            action_distance = (action_distance - action_distance.min()) / (
                action_distance.max() - action_distance.min() + 1e-8)
            p = p * (1 - action_distance)
        # Apply temperature.
        tempered = p ** (1.0 / temperature)
        if not tempered.sum() > 0:
            # Every weight underflowed to 0 (tiny temperature). Redo the
            # tempering in log space: after normalisation this is the same
            # distribution, but the largest weight becomes exp(0) = 1 instead
            # of 0, so the sampling below does not divide 0 by 0.
            with np.errstate(divide="ignore"):
                logits = np.log(p) / temperature
            tempered = np.exp(logits - logits.max())
        p = tempered
        if top_k > 0:
            topk_idx = np.argsort(p)[-top_k:]
            topk_p = p[topk_idx]
        else:
            topk_idx = np.arange(len(p))
            topk_p = p
        topk_p = topk_p / topk_p.sum()  # renormalise over the kept states
        new_cidx = np.random.choice(np.arange(len(topk_idx)), p=topk_p)
        new_ci = topk_idx[new_cidx]
        traj_log += np.log(topk_p[new_cidx] + 1e-12)
        # A state may be missing from centroid_std (e.g. fewer distinct clusters
        # were observed than requested); emit its centroid without noise.
        curr_scale = 0 if centroid_std.get(new_ci) is None else centroid_std[new_ci]
        result_embeds[:, ss, :] = torch.from_numpy(
            np.random.normal(loc=centroids[new_ci], scale=curr_scale)
        )
        prev_ci = new_ci
    return result_embeds.float().to(last_s.device), traj_log  # 1, sample_steps, D
def quantile_traj_of_state(last_s, transit_p, centroids, centroid_std, sample_steps,
                           top_k=-1, temperature=1, num_traj=20, future_action_enc_out=None):
    """Sample several latent trajectories and fuse them into one.

    ``num_traj`` trajectories are drawn with ``get_traj_of_state``. They are
    combined as a weighted sum with weights ``softmax(traj_log)``, so
    trajectories whose sampled state sequence had a higher log-probability
    contribute more.

    Args:
        last_s, transit_p, centroids, centroid_std, sample_steps, top_k,
        temperature, future_action_enc_out: forwarded unchanged to
            ``get_traj_of_state``.
        num_traj: number of trajectories to sample. It is clamped to [1, 100];
            non-positive values are treated as 1.

    Returns:
        Tensor with the same shape as one sampled trajectory, i.e.
        (1, sample_steps, D).
    """
    # At least one trajectory is required: an empty list cannot be fused.
    num_traj = int(min(100, max(1, num_traj)))
    samples = [
        get_traj_of_state(last_s, transit_p, centroids, centroid_std,
                          sample_steps, top_k=top_k, temperature=temperature,
                          future_action_enc_out=future_action_enc_out)
        for _ in range(num_traj)
    ]
    # Most likely trajectory first (kept for parity with the original ordering).
    samples.sort(key=lambda x: x[1], reverse=True)
    weights = torch.tensor([log_p for _, log_p in samples]).float().to(samples[0][0].device)
    weights = torch.softmax(weights, 0)[:, None, None, None]
    stacked = torch.stack([embeds for embeds, _ in samples])  # num_traj, 1, sample_steps, D
    return (stacked * weights).sum(dim=0)  # 1, sample_steps, D
# Helper function to fit new bayesian
def fit_observed_bayesian(observed_emebds, num_states=16,
                          original_knowledge=None, post_w=1.0,
                          ):
    """Fit a discrete-state Markov model to a sequence of embeddings.

    The embeddings are clustered with KMeans. The cluster labels, read in
    sequence order, form the observed state path, and transitions between
    consecutive labels are counted. Optionally, the counted (posterior)
    transition probabilities are blended with a prior transition matrix. The
    prior is re-indexed through the prior centroid nearest to each fitted
    centroid.

    Args:
        observed_emebds: (N, D) array of embeddings in temporal order.
        num_states: number of KMeans clusters (states).
        original_knowledge: optional ``(original_transit, original_centroids)``
            pair, shaped (M, M) and (M, D).
        post_w: weight of the counted transitions; the prior gets
            ``1 - post_w``. Forced to 1.0 when no prior is given.

    Returns:
        ``(transit_p, centroids, centroid_std)``:
        - ``transit_p``: row-normalised (num_states, num_states) array.
        - ``centroids``: the (num_states, D) KMeans centres.
        - ``centroid_std``: dict mapping each observed label to the
          per-dimension root-mean-square deviation of its members from their
          centre.
    """
    km = KMeans(
        n_clusters=num_states,
        random_state=42,
        n_init=1,
        algorithm="elkan",
    ).fit(observed_emebds)
    centers = km.cluster_centers_  # num_states, D
    labels = km.labels_  # N

    # Optional prior: look up the reference transition between the nearest
    # reference centroids of every pair of fitted centroids.
    if original_knowledge is None:
        prior_transit_p, post_w = 0, 1.0
    else:
        original_transit, original_centroids = original_knowledge
        sq_dist = np.sum((centers[:, None, :] - original_centroids[None, :, :]) ** 2, axis=-1)
        nearest = np.argmin(sq_dist, axis=-1)  # num_states
        prior = original_transit[nearest, :][:, nearest]  # num_states, num_states
        prior_transit_p = (prior + 1e-8) / ((prior + 1e-8).sum(axis=1, keepdims=True))

    # Count label -> next-label transitions along the observed path.
    transition_counts = np.zeros((num_states, num_states))
    np.add.at(transition_counts, (labels[:-1], labels[1:]), 1)
    posterior_transit_p = (transition_counts + 1e-8) / (
        (transition_counts + 1e-8).sum(axis=-1, keepdims=True))

    # Per-cluster RMS deviation. Accumulate the last sample first, then samples
    # 0..N-2, so sums and key order are reproducible.
    sq_err = (observed_emebds - centers[labels]) ** 2
    visit_order = [len(labels) - 1] + list(range(len(labels) - 1))
    totals = {}
    for idx in visit_order:
        lab = labels[idx]
        if lab in totals:
            acc, cnt = totals[lab]
            totals[lab] = (acc + sq_err[idx], cnt + 1)
        else:
            totals[lab] = (sq_err[idx], 1)
    centroid_std = {lab: np.sqrt(acc / cnt) for lab, (acc, cnt) in totals.items()}

    # Blend posterior with prior and renormalise each row.
    blended = (post_w * posterior_transit_p) + ((1 - post_w) * prior_transit_p)
    blended = (blended + 1e-8) / ((blended + 1e-8).sum(axis=-1, keepdims=True))
    return blended, centers, centroid_std
def bayesian_forecast(in_tensor, n_channels, physio_channels,
                      context_length=2048-16, pred_length=2048+16,
                      num_states=16, action_channels=None,
                      condition_bayes=False, num_traj_sampled=1,
                      latent_encoder=None, return_transit_matrix=False):
    """Forecast ``pred_length`` samples by sampling a latent Markov chain.

    Steps:
      1. Encode the context with ``latent_encoder.forward_encoder``.
      2. Fit a small Markov model to the latent tokens with
         ``fit_observed_bayesian``.
      3. Sample ``num_traj_sampled`` latent continuations from the last token
         and fuse them with ``quantile_traj_of_state``. When ``condition_bayes``
         is set and action channels exist, the continuations are steered
         toward the encoded future actions.
      4. Decode context plus continuation with ``latent_encoder.forward_decoder``
         and slice out the forecast window for ``physio_channels``.

    Args:
        in_tensor: input series (the original comment gives its shape as
            (1, nvar, length)). When ``action_channels`` is non-empty, only the
            first ``context_length`` samples are encoded as context. The action
            channels beyond that point are treated as known future actions.
        n_channels: number of channels used to reshape the decoder output.
        physio_channels: channel indices included in the returned forecast.
        context_length: context size, in samples.
        pred_length: forecast horizon, in samples.
        num_states: maximum number of latent states. It is capped at the
            number of latent tokens minus one, and is at least 1.
        action_channels: indices of action channels; ``None`` or empty means
            none.
        condition_bayes: if True (and action channels exist), bias state
            transitions toward the encoded future actions.
        num_traj_sampled: number of trajectories sampled and fused.
        latent_encoder: object providing ``forward_encoder``,
            ``forward_decoder`` and ``patch_size``.
        return_transit_matrix: if True, also return the fitted model and the
            context tokens.

    Returns:
        ``dec_out[0, physio_channels, context_length:context_length+pred_length]``.
        With ``return_transit_matrix``, returns the tuple
        ``(forecast, transit_p, centroids, centroid_std, enc_out)``.
    """
    action_channels = [] if action_channels is None else action_channels
    # With action channels, in_tensor also holds the future; encode only the context.
    end_idx = in_tensor.shape[-1] if len(action_channels) < 1 else context_length
    with torch.no_grad():
        enc_out, ids_restore, masked_patches = latent_encoder.forward_encoder(
            in_tensor.clone()[:, :, :end_idx], masking=False)
    embed_dim = enc_out.shape[-1]
    # One row per latent token: n_tokens, embed_dim*bn*nvar.
    enc_out = enc_out.permute(1, 2, 0).flatten(start_dim=1)

    # KMeans needs at least one cluster, even if only a single token exists.
    curr_num_state = max(1, min(num_states, len(enc_out) - 1))
    regularized_transit_p, regularized_centroids, centroid_std = fit_observed_bayesian(
        enc_out.cpu().numpy(),
        num_states=curr_num_state,
        post_w=1.0,
    )[:3]

    sample_steps = (pred_length // latent_encoder.patch_size) + 1
    # Encode the future actions only when they are actually used for conditioning.
    future_action_enc_out = None
    if condition_bayes and len(action_channels) > 0:
        with torch.no_grad():
            future_action_enc_out, _, _ = latent_encoder.forward_encoder(
                in_tensor.clone()[:, action_channels, :], masking=False)
        context_npatchs = context_length // latent_encoder.patch_size
        # sample_steps, embed_dim*bn*action_nvar
        future_action_enc_out = future_action_enc_out.permute(1, 2, 0).flatten(start_dim=1)[
            context_npatchs:context_npatchs + sample_steps, :]

    appended_embeds = quantile_traj_of_state(
        enc_out[-1:, :],
        regularized_transit_p,
        regularized_centroids,
        centroid_std,
        sample_steps,
        top_k=curr_num_state,
        temperature=1.0,
        num_traj=num_traj_sampled,
        future_action_enc_out=future_action_enc_out,
    )  # 1, sample_steps, embed_dim*bn*nvar

    # Append the sampled tokens and restore the decoder layout:
    # bn*nvar, n_tokens + sample_steps, embed_dim.
    enc_with_append = torch.concatenate((enc_out, appended_embeds[0]), dim=0)
    enc_with_append = enc_with_append.reshape(enc_with_append.shape[0], embed_dim, -1).permute(2, 0, 1)
    dec_out = latent_encoder.forward_decoder(enc_with_append, ids_restore, masked_patches)
    dec_out = dec_out.reshape(dec_out.shape[0], -1)
    total_L = dec_out.shape[1]
    bayesian_out = dec_out.reshape(1, n_channels, total_L)[
        0, physio_channels, context_length:context_length + pred_length]
    if return_transit_matrix:
        return bayesian_out, regularized_transit_p, regularized_centroids, centroid_std, enc_out
    return bayesian_out
################## Bayesian Functions End ########################################################################