PyTorch
normwear2
custom_code
yunfeiluo committed on
Commit
5f7b8bf
·
verified ·
1 Parent(s): 438a66f

Upload 9 files

Browse files
Files changed (9) hide show
  1. __init__.py +2 -0
  2. config.json +31 -0
  3. configuration_normwear.py +54 -0
  4. latent_bayesian.py +265 -0
  5. layers.py +540 -0
  6. modeling_normwear.py +45 -0
  7. normwear2.py +706 -0
  8. pytorch_model.bin +3 -0
  9. utils.py +26 -0
__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .modeling_normwear import NormWear2Model
2
+ from .configuration_normwear import NormWear2Config
config.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "normwear2",
3
+ "architectures": ["NormWear2Model"],
4
+
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_normwear.NormWear2Config",
7
+ "AutoModel": "modeling_normwear.NormWear2Model"
8
+ },
9
+
10
+ "patch_size" : 16,
11
+ "mlp_ratio" : 4.0,
12
+ "fuse_freq" : 2,
13
+ "drop_p" : 0.0,
14
+
15
+ "max_in_length" : 256,
16
+ "trainable_pe" : true,
17
+
18
+ "embed_dim" : 768,
19
+ "num_heads" : 12,
20
+ "depth" : 12,
21
+
22
+ "decoder_embed_dim" : 512,
23
+ "decoder_num_head" : 8,
24
+ "decoder_depth" : 2,
25
+
26
+ "token_level_fuse" : true,
27
+ "use_casual" : true,
28
+ "use_cls" : false,
29
+ "jepa" : false,
30
+ "jepa_post_decoder_train" : false
31
+ }
configuration_normwear.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+
4
class NormWear2Config(PretrainedConfig):
    """Configuration object carrying the NormWear2 hyper-parameters.

    Each named keyword argument is stored verbatim as an attribute of the
    same name; any remaining keyword arguments are forwarded to
    ``PretrainedConfig``.

    Args:
        patch_size: Patch size setting.
        embed_dim / num_heads / depth: Encoder width, head count and depth.
        decoder_embed_dim / decoder_num_head / decoder_depth: Decoder
            counterparts of the encoder settings.
        mlp_ratio: MLP ratio setting.
        drop_p: Dropout setting.
        fuse_freq: Per the original note, channel attention every
            ``fuse_freq`` blocks.
        max_in_length: Position-embedding length. Per the original note the
            actual value is total sequence length // patch_size.
        trainable_pe: Whether a trainable position embedding is used.
        token_level_fuse, use_casual, use_cls: Mechanism switches.
        jepa, jepa_post_decoder_train: JEPA-related switches.
    """

    model_type = "normwear2"

    def __init__(
        self,
        patch_size=16,
        embed_dim=768,
        decoder_embed_dim=512,
        depth=4,
        decoder_depth=2,
        num_heads=12,
        decoder_num_head=8,
        mlp_ratio=4.0,
        drop_p=0.0,
        fuse_freq=2,
        max_in_length=256,
        trainable_pe=True,
        token_level_fuse=True,
        use_casual=True,
        use_cls=False,
        jepa=False,
        jepa_post_decoder_train=False,
        **kwargs
    ):
        # Let the base class consume the generic Hugging Face options first.
        super().__init__(**kwargs)

        # Shared basics.
        self.patch_size = patch_size
        self.mlp_ratio = mlp_ratio
        self.fuse_freq = fuse_freq
        self.drop_p = drop_p

        # Absolute position embedding.
        self.max_in_length = max_in_length
        self.trainable_pe = trainable_pe

        # Encoder.
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.depth = depth

        # Decoder.
        self.decoder_embed_dim = decoder_embed_dim
        self.decoder_num_head = decoder_num_head
        self.decoder_depth = decoder_depth

        # Mechanism switches.
        self.token_level_fuse = token_level_fuse
        self.use_casual = use_casual
        self.use_cls = use_cls
        self.jepa = jepa
        self.jepa_post_decoder_train = jepa_post_decoder_train
latent_bayesian.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from sklearn.cluster import KMeans
4
+
5
+ from .normwear2 import NormWear2
6
+
7
+ ################## Bayesian Functions Start ########################################################################
8
+
9
+ # helper function for determining state based on transit matrix
10
def _temper_transition_probs(p, temperature):
    """Return weights proportional to ``p ** (1 / temperature)``.

    The power is evaluated in log space and shifted so the largest weight is
    exactly 1. This keeps very small temperatures from underflowing every
    entry to zero, which would make the later normalisation produce NaN.
    If no entry has positive mass, uniform weights are returned.
    """
    p = np.asarray(p, dtype=np.float64)
    with np.errstate(divide="ignore", invalid="ignore"):
        log_p = np.log(p) / temperature  # zero probabilities become -inf
    peak = np.max(log_p)
    if not np.isfinite(peak):
        return np.ones_like(p)
    return np.exp(log_p - peak)


def get_traj_of_state(last_s, transit_p, centroids, centroid_std, sample_steps,
                      top_k=-1, temperature=1, future_action_enc_out=None,
                      embed_dim=768,
                      **kwargs):
    """Sample a latent-state trajectory from a Markov transition matrix.

    Shapes below are taken from the original inline notes:

    Args:
        last_s: Tensor ``(1, embed_dim*bn*nvar)``, the last observed state.
        transit_p: Array ``(num_centroids, num_centroids)`` of row-wise
            transition probabilities.
        centroids: Array ``(num_centroids, embed_dim*bn*nvar)``.
        centroid_std: Mapping from centroid index to the noise scale used
            when emitting that centroid; a missing index emits the centroid
            itself (scale 0).
        sample_steps: Number of steps to generate.
        top_k: If positive, sample only among the ``top_k`` most likely
            next states; otherwise use all states.
        temperature: Sampling temperature, clamped to ``[1e-6, 2]``.
        future_action_enc_out: Optional tensor
            ``(sample_steps, embed_dim*bn*action_nvar)``; when given,
            transitions toward centroids closer to the step's action
            embedding are up-weighted.
        embed_dim: Unused; kept for interface compatibility.
        **kwargs: Ignored; kept for interface compatibility.

    Returns:
        ``(result_embeds, traj_log)``: a float tensor
        ``(1, sample_steps, last_s.shape[-1])`` on ``last_s``'s device, and
        the summed log-probability of the sampled transitions.
    """
    temperature = min(max(1e-6, temperature), 2)

    # Start from the centroid closest to the last observed state.
    prev_ci = np.argmin(np.sqrt(np.sum((centroids - last_s.cpu().numpy())**2, axis=1)))
    result_embeds = torch.zeros(1, sample_steps, last_s.shape[-1])
    traj_log = 0

    # Loop-invariant action data: convert and slice once.
    action_embs = None
    centroids_action_emb = None
    if future_action_enc_out is not None:
        action_embs = future_action_enc_out.cpu().numpy()
        # NOTE(review): assumes the action features occupy the trailing
        # columns of each centroid -- confirm against the caller's flattening
        # order.
        centroids_action_emb = centroids[:, -future_action_enc_out.shape[-1]:]

    for ss in range(sample_steps):
        p = transit_p[prev_ci]

        if action_embs is not None:
            # Min-max normalised distance to this step's action embedding;
            # closer centroids keep more of their transition mass.
            action_distance = np.linalg.norm(
                centroids_action_emb - action_embs[ss][None, :], axis=-1)
            action_distance = (action_distance - action_distance.min()) / (
                action_distance.max() - action_distance.min() + 1e-8)
            p = p * (1 - action_distance)

        # Temperature scaling, stable against underflow.
        p = _temper_transition_probs(p, temperature)

        if top_k > 0:
            topk_idx = np.argsort(p)[-top_k:]
            topk_p = p[topk_idx]
        else:
            topk_idx = np.arange(len(p))
            topk_p = p
        # The largest tempered weight is 1 and is always kept, so the sum is
        # strictly positive.
        topk_p = topk_p / topk_p.sum()

        new_cidx = np.random.choice(np.arange(len(topk_idx)), p=topk_p)
        new_ci = topk_idx[new_cidx]

        traj_log += np.log(topk_p[new_cidx] + 1e-12)

        curr_scale = centroid_std.get(new_ci)
        if curr_scale is None:
            curr_scale = 0
        result_embeds[:, ss, :] = torch.from_numpy(
            np.random.normal(loc=centroids[new_ci], scale=curr_scale)
        )
        prev_ci = new_ci

    return result_embeds.float().to(last_s.device), traj_log
70
+
71
def quantile_traj_of_state(last_s, transit_p, centroids, centroid_std, sample_steps,
                           top_k=-1, temperature=1, num_traj=20, future_action_enc_out=None):
    """Sample several trajectories and fuse them by likelihood.

    ``get_traj_of_state`` is called ``num_traj`` times (clamped to
    ``[1, 100]``). The trajectories are averaged with softmax weights
    computed from their summed log-probabilities.

    Args:
        last_s, transit_p, centroids, centroid_std, sample_steps, top_k,
        temperature, future_action_enc_out: Forwarded unchanged to
            ``get_traj_of_state``.
        num_traj: Requested number of trajectories. Values below 1 are
            raised to 1 so there is always something to fuse.

    Returns:
        The weighted sum of the sampled embedding tensors, with the same
        shape as a single ``get_traj_of_state`` result.
    """
    # At least one trajectory is required; zero left nothing to index or stack.
    num_traj = int(min(100, max(1, num_traj)))

    sampled = [
        get_traj_of_state(last_s, transit_p, centroids, centroid_std,
                          sample_steps, top_k=top_k, temperature=temperature,
                          future_action_enc_out=future_action_enc_out)
        for _ in range(num_traj)
    ]

    # The fused result is a weighted sum, so trajectory order is irrelevant.
    log_probs = torch.tensor([float(traj_log) for _, traj_log in sampled]).float()
    log_probs = log_probs.to(sampled[0][0].device)
    weights = torch.softmax(log_probs, 0)[:, None, None, None]
    trajectories = torch.stack([embeds for embeds, _ in sampled])
    return (trajectories * weights).sum(dim=0)
92
+
93
+ # Helper function to fit new bayesian
94
def fit_observed_bayesian(observed_emebds, num_states=16,
                          original_knowledge=None, post_w=1.0,
                          ):
    """Cluster observed embeddings and estimate a state transition matrix.

    Args:
        observed_emebds: Array ``(N, embed_dim)`` of consecutive embeddings
            (shape per the original note).
        num_states: Number of KMeans clusters.
        original_knowledge: Optional ``(original_transit, original_centroids)``
            prior. Each new centroid is matched to its nearest prior centroid
            and the corresponding prior sub-matrix is blended in.
        post_w: Weight of the observed (posterior) transitions in the blend;
            forced to 1.0 when no prior is given.

    Returns:
        ``(regularized_transit_p, regularized_centroids, centroid_std)``:
        the row-normalised transition matrix, the cluster centres, and a
        dict mapping each observed cluster label to the per-dimension
        root-mean-square deviation of its members from the centre.
    """
    smooth = 1e-8

    kmeans = KMeans(
        n_clusters=num_states,
        random_state=42,
        n_init=1,
        algorithm="elkan",
    ).fit(observed_emebds)
    regularized_centroids = kmeans.cluster_centers_
    labels = kmeans.labels_

    # Prior transitions, re-indexed onto the freshly fitted centroids.
    if original_knowledge is None:
        prior_transit_p, post_w = 0, 1.0
    else:
        original_transit, original_centroids = original_knowledge
        sq_dist = np.sum(
            (regularized_centroids[:, None, :] - original_centroids[None, :, :])**2,
            axis=-1)
        nearest_prior = np.argmin(sq_dist, axis=-1)
        prior_transit = original_transit[nearest_prior, :][:, nearest_prior]
        prior_smoothed = prior_transit + smooth
        prior_transit_p = prior_smoothed / prior_smoothed.sum(axis=1, keepdims=True)

    # The last observation has no outgoing transition, so seed the deviation
    # statistics with it; the loop below covers every earlier observation.
    final_label = labels[-1]
    deviation_stats = {
        final_label: [
            (observed_emebds[-1] - regularized_centroids[final_label])**2,
            1,
        ]
    }

    posterior_transit = np.zeros((num_states, num_states))
    for step, (src, dst) in enumerate(zip(labels[:-1], labels[1:])):
        posterior_transit[src, dst] += 1

        entry = deviation_stats.setdefault(src, [0, 0])
        entry[0] += (observed_emebds[step] - regularized_centroids[src])**2
        entry[1] += 1

    posterior_smoothed = posterior_transit + smooth
    posterior_transit_p = posterior_smoothed / posterior_smoothed.sum(axis=-1, keepdims=True)

    # Accumulated squared deviations -> per-dimension standard deviation.
    centroid_std = {
        label: np.sqrt(sq_sum / count)
        for label, (sq_sum, count) in deviation_stats.items()
    }

    # Blend posterior and prior, then re-normalise the rows.
    blended = (post_w*posterior_transit_p) + ((1-post_w)*prior_transit_p)
    blended_smoothed = blended + smooth
    regularized_transit_p = blended_smoothed / blended_smoothed.sum(axis=-1, keepdims=True)

    return regularized_transit_p, regularized_centroids, centroid_std
160
+
161
+
162
def bayesian_forecast(in_tensor, n_channels, physio_channels,
                      context_length=2048-16, pred_length=2048+16,
                      num_states=16, action_channels=None,
                      condition_bayes=False, num_traj_sampled=1,
                      latent_encoder=None):
    """Forecast future signal by sampling latent states from a fitted Markov chain.

    The context is encoded, its patch embeddings are clustered into states,
    a transition matrix is fitted on them, future states are sampled, and
    the concatenated sequence is decoded back to signal space.

    Args:
        in_tensor: Input of shape ``(1, nvar, length)`` (per the original
            note).
        n_channels: Channel count used to reshape the decoder output.
        physio_channels: Channel indices returned in the forecast.
        context_length: Number of leading samples treated as context.
        pred_length: Number of samples to forecast.
        num_states: Upper bound on the number of latent states.
        action_channels: Optional channel indices whose values are also
            available for the forecast horizon. ``None`` (default) or an
            empty list means the whole ``in_tensor`` is used as context.
        condition_bayes: If True, the action encodings bias the sampling.
        num_traj_sampled: Number of trajectories to sample and fuse.
        latent_encoder: Model exposing ``forward_encoder``,
            ``forward_decoder`` and ``patch_size``. Required.

    Returns:
        Decoder output restricted to ``physio_channels`` and to samples
        ``context_length : context_length + pred_length``.

    Raises:
        ValueError: If ``latent_encoder`` is missing or the encoded context
            is too short to fit any state.
    """
    if latent_encoder is None:
        raise ValueError("bayesian_forecast requires a latent_encoder")
    # Avoid a shared mutable default; None behaves like an empty list.
    action_channels = [] if action_channels is None else action_channels

    # With action channels the tensor also holds future values, so only the
    # context part is encoded as observation.
    end_idx = in_tensor.shape[-1] if len(action_channels) < 1 else context_length
    with torch.no_grad():
        enc_out, ids_restore, masked_patches = latent_encoder.forward_encoder(
            in_tensor.clone()[:, :, :end_idx], masking=False)

    # One row per time step; per the original note the columns are
    # embed_dim*bn*nvar.
    embed_dim = enc_out.shape[-1]
    enc_out = enc_out.permute(1, 2, 0).flatten(start_dim=1)

    # The first row is excluded from fitting, so at most len-1 states exist.
    curr_num_state = min(num_states, len(enc_out)-1)
    if curr_num_state < 1:
        raise ValueError(
            "context is too short to fit a latent state model "
            f"(got {len(enc_out)} encoded steps)")

    regularized_transit_p, regularized_centroids, centroid_std = fit_observed_bayesian(
        enc_out[1:, :].cpu().numpy(),
        num_states=curr_num_state,
        post_w=1.0,
    )[:3]

    sample_steps = (pred_length // latent_encoder.patch_size) + 1

    future_action_enc_out = None
    if len(action_channels) > 0:
        with torch.no_grad():
            future_action_enc_out, _, _ = latent_encoder.forward_encoder(
                in_tensor.clone()[:, action_channels, :], masking=False)
        # Per the original note: sample_steps, embed_dim*bn*action_nvar.
        future_action_enc_out = future_action_enc_out.permute(1, 2, 0).flatten(
            start_dim=1)[1:1+sample_steps, :]

    appended_embeds = quantile_traj_of_state(
        enc_out[-1:, :],
        regularized_transit_p,
        regularized_centroids,
        centroid_std,
        sample_steps,
        top_k=curr_num_state,
        temperature=1.0,
        num_traj=num_traj_sampled,
        future_action_enc_out=future_action_enc_out if condition_bayes else None,
    )

    # Append the sampled states and undo the flattening before decoding.
    enc_with_append = torch.concatenate((enc_out, appended_embeds[0]), dim=0)
    enc_with_append = enc_with_append.reshape(
        enc_with_append.shape[0], embed_dim, -1).permute(2, 0, 1)
    dec_out = latent_encoder.forward_decoder(enc_with_append, ids_restore, masked_patches)
    dec_out = dec_out.reshape(dec_out.shape[0], -1)
    total_L = dec_out.shape[1]

    return dec_out.reshape(1, n_channels, total_L)[
        0, physio_channels, context_length:context_length+pred_length]
226
+
227
+
228
+ ################## Bayesian Functions End ########################################################################
229
+
230
+
231
+
232
+
233
+ ################## Base Models Start ########################################################################
234
def load_normwear2_model(weight_path='../train_results/ckpts/from_k8s/normwear2_fix_pos_checkpoint-19.pth'):
    """Build a NormWear2 model with fixed hyper-parameters and load a checkpoint.

    Args:
        weight_path: Path to a checkpoint file. It may hold the state dict
            directly or wrap it under the ``'model'`` key.

    Returns:
        The ``NormWear2`` instance with weights loaded (``strict=True``).
    """
    model = NormWear2(
        # basics
        patch_size=16,
        mlp_ratio=4.0,
        # encoder configuration
        embed_dim=768,
        num_heads=12,
        depth=12,
        # decoder configuration
        decoder_embed_dim=512,
        decoder_num_head=8,
        decoder_depth=2,
        # position embedding
        trainable_pe=True,
        max_in_length=4096 // 16,
        # others
        mask_prob=0.0,
        use_casual=True,
        token_level_fuse=True,
        use_cls=False,
        jepa=False,
    )

    # map_location="cpu" lets checkpoints saved on a GPU load on hosts
    # without one; load_state_dict then copies into the model's own tensors.
    # SECURITY: weights_only=False unpickles arbitrary objects -- only load
    # checkpoints from a trusted source.
    state_dict = torch.load(weight_path, map_location="cpu", weights_only=False)
    if state_dict.get('model') is not None:
        state_dict = state_dict['model']
    model.load_state_dict(state_dict, strict=True)
    print("Model Load Success!")

    return model
layers.py ADDED
@@ -0,0 +1,540 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from functools import partial
3
+ from typing import Optional, Tuple
4
+
5
+ # import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torch.jit import Final
9
+
10
+ from itertools import repeat
11
+ import collections.abc
12
+
13
+ from .utils import *
14
+
15
def _ntuple(n):
    """Return a converter that expands a scalar into an ``n``-tuple.

    The converter returns ``tuple(x)`` unchanged in length when ``x`` is a
    non-string iterable, and otherwise repeats ``x`` ``n`` times.
    """
    def parse(x):
        already_a_collection = (
            isinstance(x, collections.abc.Iterable) and not isinstance(x, str)
        )
        return tuple(x) if already_a_collection else tuple(repeat(x, n))
    return parse


# Ready-made converters for the common arities.
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
28
+
29
class CheckShape(nn.Module):
    """Debug layer that prints the input shape and optionally transforms it.

    Args:
        remark: Label printed before the shape; ``None`` disables printing.
        key: Optional callable applied to the input to produce the output.
    """

    def __init__(self, remark, key=None):
        super().__init__()
        self.remark = remark
        self.key = key

    def forward(self, x, **kwargs):
        """Print ``x.shape`` when a remark is set, then return ``key(x)`` or ``x``."""
        if self.remark is not None:
            print(self.remark, x.shape)
        return x if self.key is None else self.key(x)
42
+
43
+ # fix time position embedding
44
class tAPE(nn.Module):
    """Sinusoidal absolute position embedding with an optional trainable part.

    A fixed sin/cos table is stored in the buffer ``pe``. When
    ``trainable`` is True, a zero-initialised parameter ``trainable_pe`` of
    the same shape is added on top.

    Args:
        d_model: Embedding size.
        dropout: Dropout probability applied after adding the embeddings.
        max_len: Number of positions in the table.
        scale_factor: Multiplier applied to the fixed table.
        trainable: Whether to create ``trainable_pe``.
    """

    def __init__(self, d_model, dropout=0.1, max_len=2048, scale_factor=1.0, trainable=False):
        super().__init__()
        self.max_len = max_len
        self.trainable = trainable
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin((position * div_term)*(d_model/max_len))
        pe[:, 1::2] = torch.cos((position * div_term)*(d_model/max_len))
        pe = scale_factor * pe.unsqueeze(0)
        # Buffer: saved in the state_dict but not trained.
        self.register_buffer('pe', pe)

        if self.trainable:
            self.trainable_pe = nn.Parameter(torch.zeros(pe.shape))

    def interpolate_pe(self, original_pe, target_len):
        """Resize ``original_pe`` (1, length, emb) to ``target_len`` positions.

        Shorter or equal targets are clipped; longer targets are
        up-sampled with ``'nearest-exact'`` interpolation.
        """
        original_len = original_pe.size(1)
        if target_len <= original_len:
            return original_pe[:, :target_len, :]

        pe_reshaped = original_pe.permute(0, 2, 1)  # 1, emb, length
        pe_interpolated = F.interpolate(
            pe_reshaped,
            size=target_len,
            mode='nearest-exact',
        )
        return pe_interpolated.permute(0, 2, 1)  # 1, target_len, emb

    def cyclic_pe(self, original_pe, target_len):
        """Tile ``original_pe`` (1, length, emb) cyclically to ``target_len`` positions."""
        original_len = original_pe.shape[1]
        # Repeat along the length axis (dim 1) until the target is covered.
        repeats = max(1, -(-target_len // original_len))
        tiled = original_pe.repeat(1, repeats, 1)
        return tiled[:, :target_len, :]

    def duplicate_pretrained_pe(self, pretrained_end_idx=256-16):
        """Copy the first ``pretrained_end_idx`` positions over all later ones.

        Per the original note, this is meant to be called after pretrained
        embeddings are loaded: positions ``[0, pretrained_end_idx)`` are
        tiled over the rest of the table. It acts on ``pe`` and, when the
        module is trainable, on ``trainable_pe``.

        Raises:
            ValueError: If ``pretrained_end_idx`` is not positive.
        """
        if pretrained_end_idx <= 0:
            raise ValueError("pretrained_end_idx must be positive")

        # trainable_pe only exists when the module was built as trainable.
        targets = [self.pe]
        if self.trainable:
            targets.append(self.trainable_pe)

        with torch.no_grad():
            for param in targets:
                remaining = param.shape[1] - pretrained_end_idx
                if remaining <= 0:
                    continue

                pretrained = param[:, :pretrained_end_idx, :].clone()
                repeats = -(-remaining // pretrained_end_idx)  # ceil division
                tiled = pretrained.repeat(1, repeats, 1)
                param[:, pretrained_end_idx:, :] = tiled[:, :remaining, :]

    def forward(self, x):
        """Add position embeddings to ``x`` of shape (N, L, C) or (bn, nvar, L, C)."""
        has_four_dim = False
        if len(x.shape) == 4:
            has_four_dim = True
            bn, nvar, L, C = x.shape
            x = x.reshape(bn*nvar, L, C)

        pe_adjust = self.interpolate_pe

        # NOTE (original): the very first version had a false length, so
        # tables of 1024+ positions only use their first 256-16 entries.
        curr_max_len = self.max_len if self.max_len < 1024 else 256-16

        x = x + pe_adjust(self.pe[:, :curr_max_len, :], x.shape[1])
        if self.trainable:
            x = x + pe_adjust(self.trainable_pe[:, :curr_max_len, :], x.shape[1])
        x = self.dropout(x)

        if has_four_dim:
            x = x.reshape(bn, nvar, L, C)
        return x
150
+
151
class VAE_Latent(nn.Module):
    """Linear head that samples with the reparameterisation trick.

    Two linear maps produce a mean and a Softplus-activated scale. In
    training mode the output is ``mean + scale * noise`` with standard
    normal noise; in eval mode it is just the mean.

    Args:
        emb_size: Input feature size.
        out_size: Output feature size.
        bias: Forwarded to both ``nn.Linear`` layers.
    """

    def __init__(self, emb_size, out_size, bias=None):
        super().__init__()

        self.mu = nn.Linear(emb_size, out_size, bias=bias)
        self.var = nn.Sequential(
            nn.Linear(emb_size, out_size, bias=bias),
            nn.Softplus(),
        )

    def forward(self, x):
        mean = self.mu(x)
        if not self.training:
            # Inference is deterministic: return the mean only.
            return mean

        scale = self.var(x)
        noise = torch.randn_like(scale)
        return mean + scale*noise
173
+
174
class Mlp(nn.Module):
    """Two-layer MLP as used in Vision Transformer, MLP-Mixer and related networks.

    Pipeline: ``fc1 -> act -> drop1 -> norm -> fc2 -> drop2``.

    Args:
        in_features: Input feature size.
        hidden_features: Hidden size; falls back to ``in_features``.
        out_features: Output size; falls back to ``in_features``.
        act_layer: Activation module class.
        norm_layer: Optional normalisation class applied to the hidden
            features; ``None`` uses ``nn.Identity``.
        bias: Bias flag, or a pair of flags for ``fc1`` / ``fc2``.
        drop: Dropout probability, or a pair for ``drop1`` / ``drop2``.
        use_conv: Use 1x1 ``nn.Conv2d`` layers instead of ``nn.Linear``.
        vae_out: Use a ``VAE_Latent`` head as ``fc2``.
    """

    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.GELU,
            norm_layer=None,
            bias=True,
            drop=0.,
            use_conv=False,
            vae_out=False,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias_pair = to_2tuple(bias)
        drop_pair = to_2tuple(drop)
        make_linear = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = make_linear(in_features, hidden_features, bias=bias_pair[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_pair[0])
        self.norm = nn.Identity() if norm_layer is None else norm_layer(hidden_features)

        # Output projection: a sampling head or a plain linear layer.
        if vae_out:
            self.fc2 = VAE_Latent(hidden_features, out_features, bias=bias_pair[1])
        else:
            self.fc2 = make_linear(hidden_features, out_features, bias=bias_pair[1])

        self.drop2 = nn.Dropout(drop_pair[1])

    def forward(self, x):
        hidden = self.drop1(self.act(self.fc1(x)))
        hidden = self.norm(hidden)
        return self.drop2(self.fc2(hidden))
218
+
219
class SwiGLU_Mlp(nn.Module):
    """SwiGLU MLP block used in modern transformers (LLaMA, Qwen).

    Computes ``drop2(fc3(norm(silu(fc1(x)) * fc2(x))))``.

    Args:
        in_features: Input feature size.
        hidden_features: Hidden size; falls back to ``4 * in_features``.
        out_features: Output size; falls back to ``in_features``.
        norm_layer: Optional normalisation class (built with ``eps=1e-06``)
            applied to the gated hidden features.
        act_layer: Accepted for signature parity with ``Mlp``; not used.
        bias: Bias flag, or a pair: first for both input projections,
            second for the output projection.
        drop: Dropout probability or pair; only the second entry is used.
        use_conv: Use 1x1 ``nn.Conv2d`` layers instead of ``nn.Linear``.
        vae_out: Use a ``VAE_Latent`` head as the output projection.
    """

    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            norm_layer=None,
            act_layer=None,
            bias=True,
            drop=0.,
            use_conv=False,
            vae_out=False,
    ):
        super().__init__()

        out_features = out_features or in_features
        hidden_features = hidden_features or int(in_features * 4)

        bias_pair = to_2tuple(bias)
        drop_pair = to_2tuple(drop)

        make_linear = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        # Gate projection (fc1) and value projection (fc2).
        self.fc1 = make_linear(in_features, hidden_features, bias=bias_pair[0])
        self.fc2 = make_linear(in_features, hidden_features, bias=bias_pair[0])

        if norm_layer is None:
            self.norm = nn.Identity()
        else:
            self.norm = norm_layer(hidden_features, eps=1e-06)

        # Output projection.
        if vae_out:
            self.fc3 = VAE_Latent(hidden_features, out_features, bias=bias_pair[1])
        else:
            self.fc3 = make_linear(hidden_features, out_features, bias=bias_pair[1])

        self.drop2 = nn.Dropout(drop_pair[1])

    def forward(self, x):
        gated = F.silu(self.fc1(x)) * self.fc2(x)
        return self.drop2(self.fc3(self.norm(gated)))
274
+
275
class Attention(nn.Module):
    """Multi-head self-attention with optional causal masking and cached keys/values.

    Args:
        dim: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused qkv projection has a bias.
        qk_norm: Whether to normalise queries and keys per head.
        attn_drop: Dropout probability on the attention weights.
        proj_drop: Dropout probability after the output projection.
        norm_layer: Normalisation class used when ``qk_norm`` is True.
        use_casual: Apply a causal mask so a token only attends to itself
            and earlier positions.
    """
    fused_attn: Final[bool]

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            norm_layer: nn.Module = nn.LayerNorm,
            use_casual: bool = False,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.fused_attn = True
        self.use_casual = use_casual

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim, eps=1e-06) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim, eps=1e-06) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # reservoir adjacency matrix (not used by this class)
        self.rc_attn = None

    def forward(
            self,
            x: torch.Tensor,
            past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
    ) -> torch.Tensor:
        """Attend over ``x`` of shape (B, N, C).

        Args:
            x: Input tokens.
            past_kv: Optional cached ``(key, value)`` pair of shape
                (B, heads, past, head_dim), prepended to the new keys and
                values.

        Returns:
            Tensor of shape (B, N, C).
        """
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        if past_kv is not None:
            past_k, past_v = past_kv
            k = torch.cat([past_k, k], dim=2)  # [B, h, past+N, d]
            v = torch.cat([past_v, v], dim=2)

        S = k.shape[2]
        # The built-in is_causal mask is aligned to the top-left corner, which
        # is only right when queries and keys cover the same positions. With
        # cached keys the queries are the LAST N positions, so the diagonal
        # is shifted by S - N. The manual path always needs an explicit mask.
        causal_mask = None
        if self.use_casual and (S != N or not self.fused_attn):
            causal_mask = torch.ones(N, S, dtype=torch.bool, device=x.device).tril(diagonal=S - N)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=causal_mask,
                dropout_p=self.attn_drop.p if self.training else 0.,
                is_causal=self.use_casual and causal_mask is None,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            if causal_mask is not None:
                attn = attn.masked_fill(causal_mask.logical_not(), float("-inf"))
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        # merge heads and project
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
343
+
344
def scaled_dot_product_attention_kvcache(query, key, value, attn_mask=None, dropout_p=0.0,
        is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
    """Reference scaled dot-product attention that tolerates cached keys.

    Args:
        query: Tensor ``(..., L, E)``.
        key: Tensor ``(..., S, E)``.
        value: Tensor ``(..., S, Ev)``.
        attn_mask: Optional mask broadcastable to ``(..., L, S)``. Boolean
            masks mark allowed positions with True; other dtypes are added
            to the attention scores.
        dropout_p: Dropout probability applied to the attention weights
            (always active, matching the original ``train=True``).
        is_causal: Apply a causal mask. When ``S > L`` the queries are
            treated as the last ``L`` positions of the key sequence.
        scale: Score scale; defaults to ``1 / sqrt(E)``.
        enable_gqa: Repeat key/value heads to match the query heads.

    Returns:
        Tensor ``(..., L, Ev)``.
    """
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        # Built on the query's device; the diagonal offset keeps the mask
        # aligned when keys include cached positions (S > L).
        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=S - L)
        attn_bias = attn_bias.masked_fill(temp_mask.logical_not(), float("-inf"))

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            # Out-of-place so masks with batch/head dimensions can broadcast.
            attn_bias = attn_bias.masked_fill(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias = attn_mask + attn_bias

    if enable_gqa:
        key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
        value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)

    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight = attn_weight + attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value
370
+
371
class LayerScale(nn.Module):
    """Learnable per-channel scaling of the last dimension.

    Args:
        dim: Size of the scaled (last) dimension.
        init_values: Initial value of every scale entry.
        inplace: Multiply the input in place instead of returning a new tensor.
    """

    def __init__(
            self,
            dim: int,
            init_values: float = 1e-5,
            inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
384
+
385
+
386
class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual.

    Args:
        dim: Model width.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden size of the MLP as a multiple of ``dim``.
        qkv_bias, qk_norm: Forwarded to ``Attention``.
        proj_drop: Dropout for the attention projection and the MLP.
        attn_drop: Dropout on attention weights.
        init_values: If truthy, enables ``LayerScale`` with this initial value.
        drop_path: Stochastic-depth probability; values > 0 enable ``DropPath``.
        act_layer: Activation class for the MLP.
        norm_layer: Normalisation class (built with ``eps=1e-06``).
        mlp_layer: MLP class.
        use_casual: Forwarded to ``Attention``.
        vae_out: Forwarded to the MLP.
    """

    def __init__(
            self,
            dim: int,
            num_heads: int,
            mlp_ratio: float = 4.,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            proj_drop: float = 0.,
            attn_drop: float = 0.,
            init_values: Optional[float] = None,
            drop_path: float = 0.,
            act_layer: nn.Module = nn.GELU,
            norm_layer: nn.Module = nn.LayerNorm,
            mlp_layer: nn.Module = Mlp,
            use_casual: bool = False,
            vae_out: bool = False,
    ) -> None:
        super().__init__()
        with_layer_scale = bool(init_values)
        with_drop_path = drop_path > 0.

        # Attention branch.
        self.norm1 = norm_layer(dim, eps=1e-06)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
            use_casual=use_casual,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if with_layer_scale else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if with_drop_path else nn.Identity()

        # MLP branch.
        self.norm2 = norm_layer(dim, eps=1e-06)
        self.mlp = mlp_layer(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=proj_drop,
            vae_out=vae_out,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if with_layer_scale else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if with_drop_path else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        attn_branch = self.attn(self.norm1(x))
        x = x + self.drop_path1(self.ls1(attn_branch))
        mlp_branch = self.mlp(self.norm2(x))
        x = x + self.drop_path2(self.ls2(mlp_branch))
        return x
434
+
435
def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    """Drop paths (stochastic depth) per sample, for the main path of residual blocks.

    Each sample along the first dimension is kept with probability
    ``1 - drop_prob`` and zeroed otherwise. Kept samples are divided by the
    keep probability when ``scale_by_keep`` is True. The input is returned
    untouched when not training or when ``drop_prob`` is 0.
    """
    if not training or drop_prob == 0.:
        return x

    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dims.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    keep_mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if scale_by_keep and keep_prob > 0.0:
        keep_mask.div_(keep_prob)
    return x * keep_mask
453
+
454
class DropPath(nn.Module):
    """Module wrapper around ``drop_path`` (stochastic depth per sample).

    Args:
        drop_prob: Probability of dropping a sample's path.
        scale_by_keep: Rescale kept samples by the keep probability.
    """

    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        # Only active in training mode; eval returns the input unchanged.
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        return f'drop_prob={round(self.drop_prob,3):0.3f}'
467
+
468
class PatchTSTKernelEmbeddingLocal(nn.Module):
    """Patch embedding from raw values, random polynomial and random Fourier features.

    For each degree ``d`` in ``2 .. 2 + poly_degrees - 1``,
    ``num_poly_feats`` products of ``d`` randomly chosen patch entries are
    formed. These are concatenated with the raw patch and with the sin/cos
    of a random affine map, then projected linearly to ``d_out``.

    Args:
        poly_degrees: Number of polynomial degrees, starting at 2.
        num_poly_feats: Number of polynomial features per degree.
        patch_length: Length of one patch.
        rff_scale: Standard deviation scale of the Fourier frequencies.
        num_rff: Total number of Fourier features (half sin, half cos).
        rff_trainable: Whether frequencies and biases receive gradients.
        d_feat: Size of the concatenated feature vector; must match what
            ``forward`` builds.
        d_out: Output embedding size.
    """

    def __init__(self, poly_degrees=2, num_poly_feats=120, patch_length=16, rff_scale=1.0, num_rff=256, rff_trainable=False, d_feat=512, d_out=512):
        super().__init__()
        self.num_poly_feats = num_poly_feats

        # NOTE(review): these random indices live in a plain Python list, so
        # they look absent from the state_dict and re-drawn on every
        # construction -- confirm checkpoints stay reproducible.
        self.patch_indices = []
        for degree in range(2, 2 + poly_degrees):
            self.patch_indices.append(
                torch.randint(
                    high=patch_length,
                    size=(self.num_poly_feats, degree),
                    requires_grad=False,
                )
            )

        half_rff = num_rff // 2
        self.freq_weights = nn.Parameter(
            rff_scale * torch.randn(patch_length, half_rff),
            requires_grad=rff_trainable,
        )
        self.freq_biases = nn.Parameter(
            torch.randn(1, 1, 1, half_rff),
            requires_grad=rff_trainable,
        )
        self.projection = nn.Linear(d_feat, d_out, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Parameters:
            x (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`, *required*):
                Patch input for embedding
        return:
            `torch.Tensor` of shape `(batch_size, num_channels, num_patches, d_model)`
        """
        # Products of randomly selected patch entries, one tensor per degree.
        poly_feats = [x[..., indices].prod(dim=-1) for indices in self.patch_indices]

        # Random Fourier features.
        phase = x @ self.freq_weights + self.freq_biases
        fourier_feats = [torch.sin(phase), torch.cos(phase)]

        features = torch.cat([x, *poly_feats, *fourier_feats], dim=-1)
        return self.projection(features)
512
+
513
+
514
class SIGReg(torch.nn.Module):
    """Sketch Isotropic Gaussian Regularizer (single-GPU!)

    Args:
        knots: Number of integration points on ``[0, 3]``.
        num_proj: Number of random unit directions drawn per call.
    """

    def __init__(self, knots=17, num_proj=1024):
        super().__init__()
        self.num_proj = num_proj

        t = torch.linspace(0, 3, knots, dtype=torch.float32)
        step = 3 / (knots - 1)
        # Quadrature weights: 2*step inside, step at both end points.
        weights = torch.full((knots,), 2 * step, dtype=torch.float32)
        weights[[0, -1]] = step
        window = torch.exp(-t.square() / 2.0)

        self.register_buffer("t", t)
        self.register_buffer("phi", window)
        self.register_buffer("weights", weights * window)

    def forward(self, proj):
        """
        proj: (T, B, D)
        """
        # Random directions, normalised to unit length per column.
        directions = torch.randn(proj.size(-1), self.num_proj, device=proj.device)
        directions = directions.div_(directions.norm(p=2, dim=0))

        # Epps-Pulley statistic: compare the mean cos/sin of the projected
        # samples (averaged over dim -3) against the window ``phi``.
        scaled = (proj @ directions).unsqueeze(-1) * self.t
        cos_gap = scaled.cos().mean(-3) - self.phi
        sin_mean = scaled.sin().mean(-3)
        err = cos_gap.square() + sin_mean.square()
        statistic = (err @ self.weights) * proj.size(-2)
        return statistic.mean()  # average over projections and time
modeling_normwear.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from transformers import PreTrainedModel
4
+
5
+ from .configuration_normwear import NormWear2Config
6
+ from .normwear2 import NormWear2
7
+
8
+
9
class NormWear2Model(PreTrainedModel):
    """`PreTrainedModel` wrapper around a `NormWear2` backbone.

    The backbone is stored as `self.normwear`; `forward`, `predict` and
    `simulate` simply delegate to it with the arguments they receive.
    """

    config_class = NormWear2Config
    base_model_prefix = "normwear"

    # Config attributes that are handed to `NormWear2` under the same keyword name.
    _BACKBONE_FIELDS = (
        "patch_size",
        "embed_dim", "decoder_embed_dim",
        "depth", "decoder_depth",
        "num_heads", "decoder_num_head",
        "mlp_ratio", "drop_p",
        "fuse_freq",  # channel attn every 2 block
        # absolute position embedding
        "max_in_length",  # NOTE: actual is total seq_length // patch_size
        "trainable_pe",
        # mechanism wise config
        "token_level_fuse",
        "use_casual",
        "use_cls",
        # jepa
        "jepa", "jepa_post_decoder_train",
    )

    def __init__(self, config: NormWear2Config):
        super().__init__(config)

        backbone_kwargs = {field: getattr(config, field) for field in self._BACKBONE_FIELDS}
        self.normwear = NormWear2(**backbone_kwargs)

        self.post_init()

    def forward(self, *args, **kwargs):
        """Delegate to `NormWear2.forward`."""
        return self.normwear(*args, **kwargs)

    def predict(self, *args, **kwargs):
        """Delegate to `NormWear2.predict`."""
        return self.normwear.predict(*args, **kwargs)

    def simulate(self, *args, **kwargs):
        """Delegate to `NormWear2.simulate`."""
        return self.normwear.simulate(*args, **kwargs)
normwear2.py ADDED
@@ -0,0 +1,706 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) School of Computing, Information, and Data Science, University of California San Diego.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
9
+ # DeiT: https://github.com/facebookresearch/deit
10
+ # --------------------------------------------------------
11
+
12
+ # import torch
13
+ # import torch.nn as nn
14
+ # import torch.nn.functional as F
15
+ import numpy as np
16
+
17
+ from .layers import *
18
+
19
class EncoderLayer(nn.Module):
    """One encoder layer: per-variate attention, optionally followed by cross-variate fusion.

    `variate_encoder` runs on every variate's token sequence independently. When
    `curr_layer % fuse_frequency == 0` a second block, `cls_fusion`, is created; in
    `forward` it mixes information across variates unless `no_fusion` is set.
    `drop_p` is accepted but not used inside this class.
    """

    def __init__(self, embed_dim=768,
                 norm_layer=nn.RMSNorm,
                 mlp_layer=SwiGLU_Mlp,
                 num_heads=12,
                 mlp_ratio=4.0,
                 qkv_bias=True,
                 drop_p=0.0,
                 fuse_frequency=2,
                 curr_layer=0,
                 # fusion scheme
                 no_fusion=False,
                 mean_fuse=False,
                 use_casual=False,
                 prepend_cls=True,
                 token_level_fuse=False,  # True: every token is its own cross-channel liaison (Panda-style) instead of a single [CLS] representative.
                 vae_out=False,
                 ):
        super().__init__()

        self.no_fusion = no_fusion
        self.mean_fuse = mean_fuse
        self.prepend_cls = prepend_cls
        self.token_level_fuse = token_level_fuse

        self.curr_layer = curr_layer
        self.fuse_frequency = fuse_frequency

        # Arguments shared by both attention blocks of this layer.
        shared = dict(
            mlp_layer=mlp_layer,
            dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
        )

        # Attention along the patch axis of each variate.
        self.variate_encoder = Block(
            norm_layer=norm_layer,
            use_casual=use_casual,
            vae_out=vae_out,
            **shared,
        )

        # Attention across variates; only built on fusing layers. It never uses the
        # causal flag and keeps Block's default norm layer.
        if self.curr_layer % self.fuse_frequency == 0:
            self.cls_fusion = Block(use_casual=False, **shared)

    def _fuse_via_cls(self, x_out, nvar, N, E):
        """Exchange information across variates through one summary token per variate."""
        x_out = torch.reshape(x_out, (-1, nvar, N, E))  # bs, nvars, num_patch, E

        # Split off the patch tokens (everything except the [CLS] slot).
        if self.prepend_cls:
            patch_tokens = x_out[:, :, 1:, :]
        else:
            patch_tokens = x_out[:, :, :-1, :]

        # Pick the summary token: sequence mean, or the [CLS] slot itself.
        if self.mean_fuse:
            summary = x_out.mean(dim=2)
        elif self.prepend_cls:
            summary = x_out[:, :, 0, :]  # bs, nvars, E
        else:
            summary = x_out[:, :, -1, :]  # bs, nvars, E

        # Attend across variates and put the result back into the [CLS] slot.
        summary = self.cls_fusion(summary).unsqueeze(2)  # bs, nvars, 1, E
        if self.prepend_cls:
            x_out = torch.cat((summary, patch_tokens), dim=2)
        else:
            x_out = torch.cat((patch_tokens, summary), dim=2)

        bs, n_vars, N, E = x_out.shape
        return torch.reshape(x_out, (bs * n_vars, N, E))

    def _fuse_per_token(self, x_out, nvar, N, E):
        """Exchange information across variates independently at every token position."""
        x_out = torch.reshape(x_out, (-1, nvar, N, E)).permute(0, 2, 1, 3)  # bs, num_patch, nvars, E
        bs, N, n_vars, E = x_out.shape
        # Fold batch and position together so attention runs over the variate axis.
        x_out = torch.reshape(x_out, (bs * N, n_vars, E))

        x_out = self.cls_fusion(x_out)  # bs*num_patch, nvars, E

        x_out = torch.reshape(x_out, (bs, N, n_vars, E)).permute(0, 2, 1, 3)  # bs, nvars, num_patch, E
        return torch.reshape(x_out, (bs * n_vars, N, E))

    def forward(self, x, nvar=5):
        '''
        input: x: bs*n_vars x L+1 x E
        output: same shape as the input
        '''
        _, N, E = x.shape

        x_out = self.variate_encoder(x)  # bs * nvars, L+1, E

        # Non-fusing layers (or fusion disabled) stop here.
        if self.curr_layer % self.fuse_frequency != 0 or self.no_fusion:
            return x_out

        if self.token_level_fuse:
            return self._fuse_per_token(x_out, nvar, N, E)
        return self._fuse_via_cls(x_out, nvar, N, E)
118
+
119
+
120
+ class NormWear2(nn.Module):
121
+ """ Masked Autoencoder
122
+ """
123
def __init__(self, patch_size=16,
             embed_dim=768, decoder_embed_dim=512,
             depth=4, decoder_depth=2,
             num_heads=12, decoder_num_head=8,
             mlp_ratio=4.0, drop_p=0.0,
             fuse_freq=2,  # channel attn every 2 block
             # layer type
             norm_layer=nn.RMSNorm,
             mlp_layer=SwiGLU_Mlp,
             # absolute position embedding
             max_in_length=2048,  # NOTE: actual is total seq_length // patch_size
             trainable_pe=True,
             # mechanism wise config
             token_level_fuse=False,
             use_casual=False,
             use_cls=True,
             # to be deprecated
             mask_prob=0.5,  # 0.4, 0.5, deprecated after leverage dynamic mask ratio
             max_pred_length=64,  # deprecated
             prepend_cls=True,
             vae_out=False,
             # jepa
             jepa=False, jepa_post_decoder_train=False,
             ):
    """Build the patch embedding, encoder stack and reconstruction decoder.

    Sub-modules are created in a fixed order and under fixed attribute names
    (`patch_embed`, `encoder_blocks`, `decoder_embed`, `decoder_blocks`,
    `decoder_out`, ...); both must stay stable for saved weights to load.
    """
    super().__init__()

    self.patch_size = patch_size
    self.use_cls = use_cls
    self.max_in_length = max_in_length

    # Deprecated knobs, stored but only `prepend_cls` is read below.
    self.mask_prob = mask_prob
    self.prepend_cls = prepend_cls
    self.max_pred_length = max_pred_length

    self.jepa = jepa
    self.jepa_post_decoder_train = jepa_post_decoder_train
    if jepa:
        self.SIGReg = SIGReg()

    # ----------------------------------------------------------------------
    # Encoder
    def _add_channel_axis(series):
        return series.unsqueeze(1)  # bn*nvar, 1, L

    def _tokens_last(feature_map):
        return feature_map.permute(0, 2, 1)  # bn*nvar, L//patch_size, embed_dim

    self.init_embed = nn.Sequential(  # in: bn*nvar, L
        CheckShape(None, key=_add_channel_axis),
    )

    # Non-overlapping convolution = one linear embedding per patch.
    self.patch_embed = nn.Sequential(  # in: bn*nvar, 1, L
        nn.Conv1d(in_channels=1, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size),
        CheckShape(None, key=_tokens_last),
    )

    if self.use_cls:
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

    self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
    self.pos_embed = tAPE(embed_dim, max_len=max_in_length, trainable=trainable_pe, dropout=0.1)

    # Arguments every encoder layer has in common.
    layer_kwargs = dict(
        embed_dim=embed_dim,
        norm_layer=norm_layer,
        mlp_layer=mlp_layer,
        num_heads=num_heads,
        mlp_ratio=mlp_ratio,
        drop_p=drop_p,
        fuse_frequency=fuse_freq,
        # fusion scheme
        no_fusion=False,
        mean_fuse=False,
        use_casual=use_casual,
        prepend_cls=prepend_cls,
        token_level_fuse=token_level_fuse,
    )
    layers = [EncoderLayer(curr_layer=i, **layer_kwargs) for i in range(depth - 1)]
    # The final layer is numbered `depth` (not `depth - 1`) and is the only one that
    # receives `vae_out`. Its number decides whether it owns a fusion block, so this
    # is left exactly as is. NOTE(review): presumably existing checkpoints rely on
    # this numbering -- confirm before changing.
    layers.append(EncoderLayer(curr_layer=depth, vae_out=vae_out, **layer_kwargs))
    self.encoder_blocks = nn.ModuleList(layers)

    # ----------------------------------------------------------------------
    # Decoder
    self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
    self.decoder_pos_embed = tAPE(decoder_embed_dim, max_len=max_in_length, trainable=trainable_pe)

    self.decoder_blocks = nn.ModuleList([
        Block(dim=decoder_embed_dim, num_heads=decoder_num_head,
              mlp_ratio=mlp_ratio, norm_layer=norm_layer, use_casual=use_casual)
        for _ in range(decoder_depth)
    ])  # bn*nvar, L//patch_size, decoder_embed_dim

    # Flatten per-patch outputs back into one series, dropping the samples that
    # belong to the [CLS] slot when one is present.
    if not self.use_cls:
        def _tokens_to_series(tokens):
            return tokens.flatten(start_dim=1)
    elif self.prepend_cls:
        def _tokens_to_series(tokens):
            return tokens.flatten(start_dim=1)[:, self.patch_size:]
    else:
        def _tokens_to_series(tokens):
            return tokens.flatten(start_dim=1)[:, :-self.patch_size]

    def _drop_channel_axis(series):
        return series.squeeze(1)  # bn*nvar, L

    # regular output (same kernel for all step)
    self.decoder_out = nn.Sequential(
        nn.Linear(decoder_embed_dim, decoder_embed_dim // 2),
        nn.GELU(),
        nn.Linear(decoder_embed_dim // 2, patch_size),  # bn*nvar, L//patch_size, patch_size
        CheckShape(None, key=_tokens_to_series),  # bn*nvar, L
        # deconvolution/smoothing
        CheckShape(None, key=_add_channel_axis),  # bn*nvar, 1, L
        nn.Conv1d(1, decoder_embed_dim // 2, self.patch_size, padding='same'),
        nn.GELU(),
        nn.Conv1d(decoder_embed_dim // 2, 1, self.patch_size, padding='same'),
        CheckShape(None, key=_drop_channel_axis),  # bn*nvar, L
    )
259
+
260
def forward_encoder(self, x, masking=True, context_length=None, kv_cache=None, all_visible_length=None, non_visible_channel=None):
    '''Embed the series into patch tokens, apply masking, and run the encoder stack.

    Input
        x: bn, nvar, L
        masking: when True, each patch is replaced by the mask token independently
            with a probability drawn once per call from U(0.3, 0.7).
        context_length: if given, every patch from index context_length // patch_size
            onwards is replaced by the mask token (all rows).
        kv_cache: accepted for interface compatibility; not used here.
        all_visible_length / non_visible_channel: if `all_visible_length` is given,
            the listed channels are replaced by the mask token from patch index
            all_visible_length // patch_size onwards.

    Returns
        x: bn*nvar, L//patch_size (+1 with a cls token), embed_dim
        ids_restore: indices of the randomly masked patches of the LAST row only
            (empty when nothing was randomly masked); the decoder does not use it.
        masked_patches: bn*nvar, L//patch_size, patch_size -- 2 where the patch was
            randomly masked, 1 elsewhere (used as a loss weight).
    '''
    # embed patches
    bn, nvar, L = x.shape
    x = self.init_embed(x.flatten(end_dim=-2))  # bn*nvar, 1, L
    x = self.patch_embed(x)  # bn*nvar, L//patch_size, embed_dim
    num_rows, num_patches, embed_dim = x.shape
    mask_token = self.mask_token.to(x.dtype)  # 1, 1, embed_dim

    ####### MASK PART START ########################################################
    # varied mask ratio while training, no random masking otherwise
    mask_prob = np.random.uniform(low=0.3, high=0.7) if masking else 0

    masked_patches = torch.ones(num_rows, num_patches, self.patch_size, device=x.device)
    # Defined up front so the return below is valid even for an empty batch.
    ids_restore = torch.empty(0, dtype=torch.long, device=x.device)

    if mask_prob > 0:
        # One independent Bernoulli(mask_prob) draw per (row, patch). This is the same
        # distribution as picking, per row, the entries of a random permutation whose
        # uniform draw falls below mask_prob -- without a Python loop over rows.
        is_masked = torch.rand(num_rows, num_patches, device=x.device) < mask_prob
        x = torch.where(is_masked.unsqueeze(-1), mask_token, x)
        # scale up the masked positions (for the loss): 1 -> 2
        masked_patches = masked_patches + is_masked.unsqueeze(-1).to(masked_patches.dtype)
        if num_rows > 0:
            ids_restore = is_masked[-1].nonzero().flatten()

    # replace token after context_length as mask token
    if context_length is not None:
        end_patch_idx = context_length // self.patch_size
        x[:, end_patch_idx:, :] = mask_token

    # replace specific channel part with mask token
    if all_visible_length is not None:
        channels = [] if non_visible_channel is None else non_visible_channel
        if len(channels) > 0:
            end_patch_idx = all_visible_length // self.patch_size
            x = x.reshape(bn, nvar, num_patches, embed_dim)  # bn, nvar, L//patch_size, embed_dim
            x[:, channels, end_patch_idx:, :] = mask_token
            x = x.reshape(bn * nvar, num_patches, embed_dim)  # back to bn*nvar, L//patch_size, embed_dim
    ####### MASK PART END ###############################################################

    ##### add position embedding (after masking) #######
    x = self.pos_embed(x)  # bn*nvar, L//patch_size, embed_dim

    ##### attach cls token #######
    if self.use_cls:
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        if self.prepend_cls:
            x = torch.cat((cls_tokens, x), dim=1)
        else:
            x = torch.cat((x, cls_tokens), dim=1)

    # apply Encoder blocks
    for blk in self.encoder_blocks:
        x = blk(x, nvar=nvar)  # bn*nvar, L//patch_size (+ 1), embed_dim

    return x, ids_restore, masked_patches
332
+
333
def forward_decoder(self, x, ids_restore, masked_patches, kv_cache=None):
    """Decode encoder tokens back into a series.

    x: bn*nvar, L//patch_size (+1), embed_dim
    returns: whatever `decoder_out` produces (bn*nvar, L per the layer comments).

    `ids_restore`, `masked_patches` and `kv_cache` are accepted but not read here.
    """
    # project to the decoder width, then add the decoder position embedding
    hidden = self.decoder_embed(x)
    hidden = self.decoder_pos_embed(hidden)  # bn*nvar, L//patch_size, decoder_embed_dim

    # decoder transformer stack
    for layer in self.decoder_blocks:
        hidden = layer(hidden)

    # predictor projection
    return self.decoder_out(hidden)
348
+
349
def forward_loss(self,target_tss, pred, masked_patches=None):
    """
    Mean squared reconstruction error, optionally re-weighted per sample.

    target_tss: bn, nvar, L
    pred: bn, nvar, L
    masked_patches: bn*nvar, L//patch_size, patch_size -- multiplicative weights
        (reshaped to the shape of the element-wise loss before being applied)
    """
    # element-wise squared error, reduced only at the very end
    elementwise = F.mse_loss(pred, target_tss, reduction='none')

    # scale up masked area
    if masked_patches is not None:
        weights = masked_patches.flatten(start_dim=1)  # bn*nvar, L
        elementwise = elementwise * weights.reshape(elementwise.shape)

    return elementwise.mean()
378
+
379
def forward(self, data_pack, output_latent=False, masking=True):
    '''Training forward pass.

    Input
        data_pack['sample']: bn, nvar, L -- reconstruction target, and the model
            input unless a noisy version is supplied
        data_pack['noise_sample'] (optional): bn, nvar, L -- used as the model
            input when present and not None

    Returns the scalar loss, or `(latent, pred, masked_patches, loss)` when
    `output_latent` is True. When `self.jepa` is set, the result of
    `forward_jepa` is returned instead and the two flags are ignored.
    '''
    # de-pack
    target_tss = data_pack['sample']  # bn, nvar, L
    noisy = data_pack.get('noise_sample')
    imgs = target_tss if noisy is None else noisy

    ## ----------- JEPA forward ----------------------
    if self.jepa:
        return self.forward_jepa(imgs, target_tss, lambd=0.1)

    ## ----------- Regular MAE forward ----------------------
    latent, ids_restore, masked_patches = self.forward_encoder(imgs, masking=masking)

    # decode, then fold bs*nvar back into bs, nvar
    pred = self.forward_decoder(latent, ids_restore, masked_patches)  # bs*nvar, L
    pred = pred.reshape(target_tss.shape)  # bs, nvar, L

    loss = self.forward_loss(target_tss, pred, masked_patches=masked_patches)

    return (latent, pred, masked_patches, loss) if output_latent else loss
420
+
421
+
422
+
423
+
424
def forward_jepa(self, in_context, target_context, lambd=0.1):
    '''JEPA-style training loss.

    Input
        in_context: bn, nvar, L -- fed to the encoder (masked and unmasked pass)
        target_context: bn, nvar, L -- target of the raw-signal reconstruction
        lambd: weight of the SIGReg term inside the latent loss

    Returns a scalar loss.

    Two modes, selected by `self.jepa_post_decoder_train`:
      * False: 0.5 * (latent MSE between masked and unmasked encodings
        + lambd * SIGReg) + raw reconstruction loss from the decoder.
      * True: the encoder runs without gradients and without masking; only the
        raw reconstruction loss is returned (trains the decoder only).
    '''
    if self.jepa_post_decoder_train:
        # decoder-only training: encoder is evaluated without gradient tracking
        with torch.no_grad():
            masked_latent, ids_restore, masked_patches = self.forward_encoder(in_context, masking=False)

        pred = self.forward_decoder(masked_latent, ids_restore, masked_patches)  # bs*nvar, L
        pred = pred.reshape(target_context.shape)  # bs, nvar, L
        return self.forward_loss(target_context, pred, masked_patches=masked_patches)

    # Two encoder passes over the same input: one randomly masked, one fully visible.
    # NOTE(review): the unmasked target pass also uses `in_context` (the possibly noisy
    # input), not `target_context` -- confirm this is intended.
    masked_latent, ids_restore, masked_patches = self.forward_encoder(in_context, masking=True)
    target_latent, _, _ = self.forward_encoder(in_context, masking=False)
    # latent shape: bn*nvar, L//patch_size (+ 1 with a cls token), embed_dim
    # masked_patches: bs*nvar, L//patch_size, patch_size

    # latent reconstruction loss
    recon_loss = F.mse_loss(masked_latent, target_latent, reduction='none')

    # scale up masked area
    if masked_patches is not None:
        latent_weights = masked_patches.mean(dim=-1)[:, :, None]  # bs*nvar, L//patch_size, 1
        if self.use_cls:
            # The latent carries one extra [CLS] token that `masked_patches` does not
            # cover; give it the neutral weight 1 at the matching end so the shapes
            # line up instead of failing to broadcast.
            cls_weight = latent_weights.new_ones(latent_weights.shape[0], 1, 1)
            if self.prepend_cls:
                latent_weights = torch.cat((cls_weight, latent_weights), dim=1)
            else:
                latent_weights = torch.cat((latent_weights, cls_weight), dim=1)
        recon_loss = recon_loss * latent_weights

    recon_loss = recon_loss.mean()

    # step-wise sigreg (anti-collapse)
    # NOTE: SIGReg takes proj: (T, B, D) = (seq_length, batch_size, embed_dim)
    sigreg_loss = self.SIGReg(masked_latent.permute(1, 0, 2))  # already averaged

    loss = recon_loss + (lambd * sigreg_loss)

    # add the decoder's raw-signal reconstruction loss
    pred = self.forward_decoder(masked_latent, ids_restore, masked_patches)  # bs*nvar, L
    pred = pred.reshape(target_context.shape)  # bs, nvar, L
    raw_recon_loss = self.forward_loss(target_context, pred, masked_patches=masked_patches)

    return (0.5 * loss) + raw_recon_loss
484
+
485
+
486
+
487
def predict(self, context_tensor, prediction_length, max_pred_length=None,
            lookback_window=None, **kwargs):
    """Forecast `prediction_length` steps by repeatedly calling `generate`.

    context_tensor: 1, L, nvar
    output: 1, prediction_length, nvar (in the original scale of the context)

    max_pred_length: number of steps produced per `generate` call. When None it is
        derived from `mean_centroid` of the context, clamped to [patch_size, 128].
    lookback_window: forwarded unchanged to `generate`.
    """
    # determine the auto-regressive step size
    if max_pred_length is None:
        centroid_step = mean_centroid(context_tensor[0].T, patch_size=self.patch_size)  # takes (nvar, L)
        max_pred_length = min(128, max(self.patch_size, centroid_step))

    # longest history that is ever handed to the model
    history_cap = min(
        context_tensor.shape[1],
        int(2 * (self.max_in_length * self.patch_size)),
    )

    def _keep_recent(series):
        # drop the oldest samples once the history exceeds the cap
        if series.shape[1] > history_cap:
            return series[:, -history_cap:, :]
        return series

    context_tensor = _keep_recent(context_tensor)

    # z-normalize per channel; constant channels get a unit scale
    loc = context_tensor.mean(dim=1, keepdims=True)
    scale = context_tensor.std(dim=1, keepdims=True)
    scale = torch.where(scale == 0, torch.ones_like(scale), scale) + 1e-8
    context_tensor = (context_tensor - loc) / scale

    # first round, then keep feeding each forecast back in as context
    step_out, kv_cache = self.generate(context_tensor, max_pred_length,
                                       kv_cache=None, lookback_window=lookback_window)  # 1, Lf, nvar
    chunks = [step_out]
    produced = step_out.shape[1]
    while produced < prediction_length:
        context_tensor = _keep_recent(torch.concatenate((context_tensor, step_out), dim=1))  # 1, L+Lf, nvar
        step_out, kv_cache = self.generate(context_tensor, max_pred_length,
                                           kv_cache=kv_cache, lookback_window=lookback_window)
        chunks.append(step_out)
        produced += step_out.shape[1]

    # clip to the requested horizon and undo the normalization
    all_forecast = torch.concatenate(chunks, dim=1)[:, :prediction_length, :]
    return (all_forecast * scale) + loc
545
+
546
def generate(self, context_tensor, prediction_length, kv_cache=None,
             lookback_window=None,**kwargs):
    """Single forecasting pass: mask everything after the context and decode it.

    context_tensor: 1, L, nvar (expected to be normalized by the caller)
    output: (1, prediction_length, nvar) forecast, and `kv_cache` passed through

    lookback_window: if given, only the last `lookback_window` context samples
        are used.
    """
    series = context_tensor.permute(0, 2, 1)  # 1, nvar, L

    if lookback_window is not None:
        keep = min(lookback_window, series.shape[2])
        series = series[:, :, -keep:]

    # Zero-pad to a multiple of patch_size. An already aligned length still gains
    # one full extra patch.
    bn, nvar, context_length = series.shape
    horizon_end = context_length + prediction_length
    padded_len = horizon_end + (self.patch_size - (horizon_end % self.patch_size))
    padded = torch.zeros(bn, nvar, padded_len).to(series.device)
    padded[:, :, :context_length] = series

    with torch.no_grad():
        enc_out, ids_restore, masked_patches = self.forward_encoder(
            padded, masking=False, context_length=context_length, kv_cache=kv_cache)
        # enc_out: bn*nvar, L//patch_size (+ 1), embed_dim
        dec_out = self.forward_decoder(enc_out, ids_restore, masked_patches, kv_cache=kv_cache)  # bn*nvar, L

    # keep only the forecast window that follows the context
    _, total_L = dec_out.shape
    forecast = dec_out.reshape(bn, nvar, total_L)[:, :, context_length:horizon_end]
    forecast = forecast.permute(0, 2, 1)  # bn, L, nvar

    return forecast.detach(), kv_cache
587
+
588
def simulate(self, context_tensor, all_visible_length=512,
             non_visible_channel=None, ar_step=None, **kwargs):
    """Auto-regressively fill in channels that are only observed up to a point.

    context_tensor: 1, L, nvar (not modified; a copy is used internally)
    all_visible_length: length where all channels are observed
    non_visible_channel: [ch0, ch1, ...] -- channels hidden after `all_visible_length`
    ar_step: samples added per auto-regressive step; when None it is derived
        from `mean_centroid` of the visible part of the hidden channels and
        clamped to [patch_size, 128].
    output: 1, L, nvar

    Each slice handed to the encoder ends at `all_visible_length + k * ar_step`;
    the decoder output is reshaped to exactly that length, so both values
    presumably need to be multiples of `patch_size` -- TODO confirm against
    `patch_embed`. Samples past the last full step are never written and come
    back as the channel's `loc` (zero in normalized space).
    """
    if non_visible_channel is None:
        non_visible_channel = []

    # Work on a copy so the caller's tensor is not zeroed in place.
    context_tensor = context_tensor.clone()

    # mask
    context_tensor[:, all_visible_length:, non_visible_channel] = 0

    # adjust shape for successive operations
    context_tensor = context_tensor.permute(0, 2, 1)  # 1, nvar, L

    # determine the auto-regressive step size
    if ar_step is None:
        ar_step = min(128, max(
            self.patch_size,
            mean_centroid(context_tensor[0, non_visible_channel, :all_visible_length], patch_size=self.patch_size),  # takes (nvar, L)
        ))
        print(f"{ar_step=}")

    # per-channel statistics over the full length
    loc = context_tensor.mean(dim=2, keepdims=True)
    scale = context_tensor.std(dim=2, keepdims=True)
    scale[scale == 0] = 1.0
    scale += 1e-8

    # hidden channels: statistics of their visible part only
    visible_part = context_tensor[:, non_visible_channel, :all_visible_length]
    visible_scale = visible_part.std(dim=2, keepdims=True)
    # Same guard as above: a constant visible segment must not divide by zero.
    visible_scale[visible_scale == 0] = 1.0
    loc[:, non_visible_channel, :] = visible_part.mean(dim=2, keepdims=True)
    scale[:, non_visible_channel, :] = visible_scale

    # normalize
    context_tensor = (context_tensor - loc) / scale

    # make sure nonvisible part stay 0
    context_tensor[:, non_visible_channel, all_visible_length:] = 0

    # pad context tensor (always adds between 1 and patch_size zeros)
    bn, nvar, context_length = context_tensor.shape
    total_len = context_length + (self.patch_size - (context_length % self.patch_size))
    pad_context_tensor = torch.zeros(bn, nvar, total_len).to(context_tensor.device)
    pad_context_tensor[:, :, :context_length] = context_tensor

    # auto-regressive simulate
    with torch.no_grad():
        for end_idx in range(all_visible_length + ar_step, context_length + 1, ar_step):
            # everything the hidden channels produced so far counts as visible
            enc_out, ids_restore, masked_patches = self.forward_encoder(
                pad_context_tensor[:, :, :end_idx],
                masking=False, all_visible_length=end_idx - ar_step,
                non_visible_channel=non_visible_channel)
            # enc_out: bn*nvar, L//patch_size (+ 1), embed_dim
            dec_out = self.forward_decoder(enc_out, ids_restore, masked_patches).reshape(bn, nvar, end_idx)

            # Write the decoded tail back into the working tensor.
            # NOTE(review): this overwrites ALL channels past `all_visible_length`,
            # including the observed ones -- confirm that is intended.
            curr_max_possible_length = min(end_idx, pad_context_tensor.shape[-1])
            pad_context_tensor[:, :, all_visible_length:curr_max_possible_length] = dec_out[:, :, all_visible_length:curr_max_possible_length]

    # Drop the alignment padding so the output length equals the input length,
    # then de-normalize.
    pred_out = (pad_context_tensor[:, :, :context_length] * scale) + loc  # bn, nvar, L

    return pred_out.detach().permute(0, 2, 1)  # bn, L, nvar
665
+
666
def get_embedding(self, sample_data, criteria='mean'): # default: criteria='mean'
    """Encode a sample without masking and pool the tokens of every channel.

    sample_data: (nvar, L) or (bn, nvar, L); a 2-D input gets a batch axis and
        is cast to float.
    criteria: 'mean' averages over the token axis, 'last' takes the final token.
    returns: bn, nvar*E

    Raises ValueError for any other `criteria` (after the encoder has run).
    """
    if sample_data.dim() == 2:
        sample_data = sample_data.unsqueeze(0).float()
    bn, nvar, _ = sample_data.shape

    # forward
    tokens, _, _ = self.forward_encoder(sample_data, masking=False)  # bn*nvar, P, E
    _, P, E = tokens.shape
    tokens = tokens.reshape(bn, nvar, P, E)

    # aggregate over the token axis
    if criteria == 'mean':
        pooled = tokens.mean(dim=2)  # bn, nvar, E
    elif criteria == 'last':
        pooled = tokens[:, :, -1, :]  # bn, nvar, E
    else:
        raise ValueError("Unsupported aggregation criteria:", criteria)

    return pooled.flatten(start_dim=1)  # bn, nvar*E
686
+
687
+
688
if __name__ == '__main__':
    # Smoke test, run as a module: python3 -m normwear_on_chaotic.normwear_opt
    smoke_config = dict(
        patch_size=16,
        depth=12,
        mask_prob=0.0,
        max_in_length=4096,  # 2048 for all ckpts before
        use_casual=True,  # False for all ckpts before
        prepend_cls=True,
        token_level_fuse=True,
    )
    model = NormWear2(**smoke_config)

    # random context of shape bn, L, nvar; forecast 32 steps, 16 per round
    test_x = torch.rand(2, 64, 3)
    out_y = model.predict(test_x, 32, max_pred_length=16)

    # verbose
    print("Output shape:", out_y.shape)
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8ef2770ccdca3f9c27ba4cc3220501620eeaa4b765e234dcad2882d7530924c9
3
+ size 748287646
utils.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
def mean_centroid(x, sr=1.0, patch_size=16):
    """Return the period of the mean spectral centroid, floored to a multiple of `patch_size`.

    x: nvar, L
    sr: sampling rate used to scale the frequency bins
    patch_size: the result is rounded down to a multiple of this value

    The centroid is the magnitude-weighted mean frequency of each row; the rows
    are averaged and the reciprocal (a period, in samples when sr=1) is floored.
    The final `int(...)` raises if that value is not finite, e.g. for a spectrum
    with only a DC component (centroid 0) or an all-zero row (0/0).
    """
    magnitude = torch.fft.rfft(x, dim=-1).abs()
    bin_freqs = torch.fft.rfftfreq(x.size(-1), 1/sr).to(x.device)

    # magnitude-weighted mean frequency per row, then averaged over rows
    centroid = (magnitude * bin_freqs).sum(-1) / magnitude.sum(-1)
    period = 1 / (centroid.mean())

    return int((period // patch_size) * patch_size)
8
+
9
def generate_reservoir_matrix(n, sparsity=0.05, spectral_radius=0.9, seed=None):
    """Build a random sparse n x n matrix rescaled to a target spectral radius.

    n: matrix size
    sparsity: probability that an entry is kept (non-zero)
    spectral_radius: largest eigenvalue magnitude after rescaling
    seed: if given, seeds torch's global random generator before sampling

    Returns the matrix; it is left unscaled when its largest eigenvalue
    magnitude is zero.
    """
    if seed is not None:
        torch.manual_seed(seed)

    # dense values uniform in [-1, 1], then an independent keep/drop mask
    dense = torch.rand(n, n) * 2 - 1
    keep = (torch.rand(n, n) < sparsity).float()
    W = dense * keep

    # rescale so the largest |eigenvalue| equals the requested radius
    current_radius = torch.linalg.eigvals(W).abs().max()
    if current_radius > 0:
        W = W * (spectral_radius / current_radius)

    return W