Alan123 committed on
Commit
c563c71
·
verified ·
1 Parent(s): d1d5bcb

Upload grp_model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. grp_model.py +356 -0
grp_model.py ADDED
@@ -0,0 +1,356 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+
7
def get_patches_fast(images, cfg):
    """Split a batch of channel-last images into flattened patch tokens.

    Args:
        images: tensor of shape (batch, height, width, channels). When
            ``channels > 3`` the channel axis is assumed to hold a history of
            stacked RGB frames laid out as (c hs) with the frame index varying
            fastest, i.e. ``channels == 3 * cfg.policy.obs_stacking``.
        cfg: config providing ``patch_size`` and ``policy.obs_stacking``.

    Returns:
        Tensor of shape (batch, n_tokens, patch_size * patch_size * 3), where
        n_tokens is (H/p)*(W/p) for plain images and (H/p)*(W/p)*obs_stacking
        for history-stacked ones (one token per patch per history frame).

    Note: the original implementation always computed the 3-channel patches
    first and discarded them on the stacked path; each case is now computed
    exactly once, using pure torch ops instead of a per-call einops import.
    """
    batch_size, height, width, channels = images.shape
    p = cfg.patch_size  ## n_patches = 8
    gh, gw = height // p, width // p

    if channels > 3:
        ## History stacking in the channel dimension for observations only,
        ## not goal images. Equivalent to einops
        ## 'b (h p1) (w p2) (c hs) -> b (h w hs) (p1 p2 c)'.
        hs = cfg.policy.obs_stacking
        c = channels // hs
        x = images.reshape(batch_size, gh, p, gw, p, c, hs)
        x = x.permute(0, 1, 3, 6, 2, 4, 5)  # b, h, w, hs, p1, p2, c
        return x.reshape(batch_size, gh * gw * hs, p * p * c)

    ## Plain RGB path. Equivalent to einops
    ## 'b (h p1) (w p2) c -> b (h w) (p1 p2 c)' on the first 3 channels.
    rgb = images[:, :, :, :3]
    c = rgb.shape[-1]
    x = rgb.reshape(batch_size, gh, p, gw, p, c)
    x = x.permute(0, 1, 3, 2, 4, 5)  # b, h, w, p1, p2, c
    return x.reshape(batch_size, gh * gw, p * p * c)
17
+
18
+
19
def calc_positional_embeddings(sequence_length, d):
    """Build the fixed sinusoidal positional-embedding table.

    Entry (i, j) is sin(i / 10000^(j/d)) for even j and
    cos(i / 10000^((j-1)/d)) for odd j — i.e. sin/cos pairs sharing the
    frequency of the even-floored dimension index.

    Args:
        sequence_length: number of positions (rows).
        d: embedding dimension (columns).

    Returns:
        Float32 tensor of shape (sequence_length, d).

    Note: vectorized; the original computed each scalar in a Python double
    loop, which is O(sequence_length * d) interpreter-level work.
    """
    positions = torch.arange(sequence_length, dtype=torch.float32).unsqueeze(1)
    dims = torch.arange(d, dtype=torch.float32)
    # Exponent uses the even-floored index: j for even j, j-1 for odd j.
    angles = positions / torch.pow(10000.0, (dims - dims.remainder(2)) / d)
    table = torch.empty(sequence_length, d)
    table[:, 0::2] = torch.sin(angles[:, 0::2])
    table[:, 1::2] = torch.cos(angles[:, 1::2])
    return table
25
+
26
+
27
class Head(nn.Module):
    """A single self-attention head with optional block masking."""

    def __init__(self, head_size, n_embd, dropout):
        super().__init__()
        # Project the model width down to this head's width; no biases,
        # matching the usual attention formulation.
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Attend over the sequence; score positions where mask == 0 are blocked.

        Args:
            x: (batch, seq, n_embd) input.
            mask: optional (seq, seq) tensor; zeros mark disallowed attention.

        Returns:
            (batch, seq, head_size) attended values.
        """
        _, _, width = x.shape
        queries = self.query(x)
        keys = self.key(x)
        # Scaled dot-product scores. NOTE: scaling uses the input width
        # (n_embd), not head_size — preserved from the original.
        scores = queries @ keys.transpose(-2, -1) * width ** -0.5
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        return attn @ self.value(x)
52
+
53
+
54
class MultiHeadAttention(nn.Module):
    """Several attention heads run in parallel, concatenated and projected."""

    def __init__(self, num_heads, head_size, n_embd, dropout):
        super().__init__()
        self.heads = nn.ModuleList(
            Head(head_size, n_embd=n_embd, dropout=dropout) for _ in range(num_heads)
        )
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run every head on x, concatenate along features, then project."""
        with torch.profiler.record_function("Self-Attention"):
            per_head = [head(x, mask) for head in self.heads]
            combined = torch.cat(per_head, dim=-1)
            out = self.dropout(self.proj(combined))
        return out
66
+
67
+
68
class FeedFoward(nn.Module):
    """Position-wise MLP: expand to 4x width, ReLU, project back, dropout.

    (Name keeps the original spelling — it is the public class name.)
    """

    def __init__(self, n_embd, dropout):
        super().__init__()
        hidden = 4 * n_embd
        self.net = nn.Sequential(
            nn.Linear(n_embd, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Apply the MLP independently at every position."""
        return self.net(x)
80
+
81
+
82
class Block(nn.Module):
    """Pre-norm transformer block: self-attention then MLP, both residual."""

    def __init__(self, n_embd, n_head, dropout):
        super().__init__()
        # Each head gets an equal share of the embedding width.
        self.sa = MultiHeadAttention(n_head, n_embd // n_head, n_embd=n_embd, dropout=dropout)
        self.ffwd = FeedFoward(n_embd, dropout)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x, mask=None):
        """LayerNorm before each sub-layer; residual add after each."""
        x = x + self.sa(self.ln1(x), mask)
        return x + self.ffwd(self.ln2(x))
95
+
96
+
97
class GRP(nn.Module):
    """Goal-conditioned Robot Policy transformer.

    Encodes a sequence of [CLS] + goal-text tokens + goal separator +
    goal-image patches + (history-stacked) observation patches with a stack
    of pre-norm transformer blocks, then regresses a flat action vector from
    the [CLS] position.
    """

    def __init__(self, cfg, mlp_ratio=4):
        # NOTE(review): mlp_ratio is accepted but never used below; the MLP
        # width is fixed at 4x inside FeedFoward — confirm before removing.
        super(GRP, self).__init__()
        self._cfg = cfg
        # Side effect: writes vocab_size back into the shared cfg object.
        chars = cfg.dataset.chars_list
        cfg.vocab_size = len(chars)

        # 1) Embeddings
        # Patch dimension: patch_size * patch_size * 3 (RGB).
        patch_dim = cfg.patch_size * cfg.patch_size * 3
        self.patch_embedding = nn.Linear(patch_dim, cfg.n_embd)

        # Character-level token embedding for text goals (only when the goal
        # text is NOT pre-encoded with T5 — T5 already yields n_embd vectors).
        use_t5 = False
        if hasattr(cfg, 'dataset') and hasattr(cfg.dataset, 'encode_with_t5'):
            use_t5 = cfg.dataset.encode_with_t5

        if not use_t5:
            self.token_embedding_table = nn.Embedding(cfg.vocab_size, cfg.n_embd)

        # Fixed (non-learned) sinusoidal positional table sized to the maximum
        # sequence: CLS + text tokens + separator + goal-image patches +
        # observation patches for every stacked history frame.
        num_patches_per_image = cfg.n_patches * cfg.n_patches
        max_seq_len = 1 + cfg.max_block_size + 1 + num_patches_per_image + num_patches_per_image * cfg.policy.obs_stacking
        # Registered as a buffer so it moves with the module but is not trained.
        pos_emb = calc_positional_embeddings(max_seq_len, cfg.n_embd)
        self.register_buffer('pos_embedding', pos_emb)

        # Learned special tokens: classification token and goal separator.
        self.cls_token = nn.Parameter(torch.randn(1, 1, cfg.n_embd))
        self.goal_token = nn.Parameter(torch.randn(1, 1, cfg.n_embd))

        # Dropout applied to the embedded sequence in forward().
        self.dropout = nn.Dropout(cfg.dropout)

        # 2) Transformer encoder blocks (cfg.n_blocks of them).
        self.blocks = nn.Sequential(*[
            Block(cfg.n_embd, n_head=cfg.n_head, dropout=cfg.dropout)
            for _ in range(cfg.n_blocks)
        ])

        self.ln_f = nn.LayerNorm(cfg.n_embd)  # Final layer norm

        # 3) Action head: one flat vector of action_dim values per stacked
        # action step, regressed from the CLS token.
        output_dim = cfg.action_dim * cfg.policy.action_stacking
        self.lm_head = nn.Linear(cfg.n_embd, output_dim)

        self.apply(self._init_weights)
154
+
155
+ def _init_weights(self, module):
156
+ if isinstance(module, nn.Linear):
157
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
158
+ if module.bias is not None:
159
+ torch.nn.init.zeros_(module.bias)
160
+ elif isinstance(module, nn.Embedding):
161
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
162
+
163
+ def forward(self, images, goals_txt, goal_imgs, targets=None, pose=None, mask_=False):
164
+ #device = images.device
165
+ n, c, h, w = images.shape
166
+ obs_patches = get_patches_fast(images, self._cfg)
167
+ patches_g = get_patches_fast(goal_imgs, self._cfg)
168
+ if self._cfg.dataset.encode_with_t5:
169
+ goals_e = goals_txt
170
+ B, T, E = goals_txt.shape
171
+ else:
172
+ goals_e = self.token_embedding_table(goals_txt)
173
+ B, E = goals_txt.shape
174
+ T = self._cfg.max_block_size
175
+
176
+ # TODO:
177
+ ## Provide the logic to produce the output and loss for the GRP
178
+
179
+ # Map the vector corresponding to each patch to the hidden size dimension
180
+
181
+ # Adding classification and goal_img tokens to the tokens
182
+
183
+ # Adding positional embedding
184
+
185
+ # Compute blocked masks
186
+
187
+ # Transformer Blocks
188
+
189
+ # Getting the classification token only
190
+
191
+ # Compute output and loss
192
+ obs_emb = self.patch_embedding(obs_patches)
193
+ goal_img_emb = self.patch_embedding(patches_g)
194
+ cls_tokens = self.cls_token.expand(B, -1, -1)
195
+ sep_tokens = self.goal_token.expand(B, -1, -1)
196
+
197
+ # Concatenate everything into one sequence:
198
+ # [CLS] + [Goal Text] + [Goal Separator] + [Goal Img] + [Observation Patches]
199
+ x = torch.cat((cls_tokens, goals_e, sep_tokens, goal_img_emb, obs_emb), dim=1)
200
+
201
+ # Adding positional embedding
202
+ # We slice the pre-calculated buffer to the current sequence length
203
+ seq_len = x.shape[1]
204
+ x = x + self.pos_embedding[:seq_len, :].to(self._cfg.device)
205
+ x = self.dropout(x)
206
+
207
+ # Compute blocked masks
208
+ mask = None
209
+ if mask_:
210
+ # Create a causal mask (lower triangular)
211
+ mask = torch.tril(torch.ones(seq_len, seq_len, device=device))
212
+
213
+ # Transformer Blocks
214
+ for block in self.blocks:
215
+ x = block(x, mask)
216
+
217
+ x = self.ln_f(x)
218
+
219
+ # Getting the classification token only (index 0)
220
+ # This token aggregates information from the entire sequence
221
+ cls_out = x[:, 0, :]
222
+
223
+ # Compute output
224
+ logits = self.lm_head(cls_out) # Shape: (B, action_dim * action_stacking)
225
+
226
+ # Compute loss
227
+ loss = None
228
+ if targets is not None:
229
+ # Typically MSE loss for continuous action regression
230
+ loss = F.mse_loss(logits, targets)
231
+
232
+ return (logits, loss)
233
+
234
+ def resize_image(self, image):
235
+ """
236
+ Docstring for resize_image
237
+
238
+ :param self: Description
239
+ :param image: Description
240
+ self._resize_state = lambda sf: cv2.resize(np.array(sf, dtype=np.float32), (cfg.image_shape[0], cfg.image_shape[1])) # resize state
241
+ """
242
+ import cv2
243
+ import numpy as _np
244
+ img = _np.array(image, dtype=_np.float32)
245
+ img = cv2.resize(img, (self._cfg.image_shape[0], self._cfg.image_shape[1]))
246
+ return img
247
+
248
+ def normalize_state(self, image):
249
+ """
250
+ Docstring for preprocess_state
251
+
252
+ :param self: Description
253
+ :param image: Description
254
+ self._encode_state = lambda af: ((af/(255.0)*2.0)-1.0) # encoder: take a float, output an integer
255
+ self._resize_state = lambda sf: cv2.resize(np.array(sf, dtype=np.float32), (cfg.image_shape[0], cfg.image_shape[1])) # resize state
256
+ """
257
+ # img = _np.array(image, dtype=_np.float32)
258
+ # img = cv2.resize(img, (self._cfg.image_shape[0], self._cfg.image_shape[1]))
259
+ enc = ((image / 255.0) * 2.0) - 1.0
260
+ # t = _torch.tensor(enc, dtype=_torch.float32, device=self._cfg.device)
261
+ return enc
262
+
263
+ def preprocess_state(self, image):
264
+ img = self.resize_image(image)
265
+ img = self.normalize_state(img)
266
+ return img
267
+
268
    def preprocess_goal_image(self, image):
        # Goal images use the exact same resize + normalize pipeline as
        # observation images.
        return self.preprocess_state(image)
270
+
271
    def encode_text_goal(self, goal, tokenizer=None, text_model=None):
        """Encode a text goal either as T5 embeddings or as character ids.

        :param goal: goal string.
        :param tokenizer: HF tokenizer, required only on the T5 path.
        :param text_model: text encoder model, required only on the T5 path.
        :return: on the T5 path, an embedding tensor of shape
            (batch, seq_len, hidden); otherwise a (1, max_block_size) long
            tensor of character ids on cfg.device.
        :raises ValueError: if T5 encoding is configured but tokenizer or
            text_model is missing.
        """
        import numpy as _np
        import torch as _torch
        if self._cfg.dataset.encode_with_t5:
            if tokenizer is None or text_model is None:
                raise ValueError("tokenizer and text_model must be provided when using T5 encoding")
            with _torch.no_grad():
                # Tokenize, truncating/padding to the model's block size.
                inputs = tokenizer(goal, return_tensors="pt", padding=True, truncation=True, max_length=self._cfg.max_block_size)
                # Move input tensors to the configured device.
                inputs = {k: v.to(self._cfg.device) for k, v in inputs.items()}
                # Encode and take the last hidden state: (B, seq, hidden).
                outputs = text_model(**inputs)
                embeddings = outputs.last_hidden_state
            return embeddings
        else:
            # Character-level path: right-pad with spaces, then map each
            # character to its index in cfg.dataset.chars_list.
            pad = " " * self._cfg.max_block_size
            goal_ = goal[:self._cfg.max_block_size] + pad[len(goal):self._cfg.max_block_size]
            try:
                stoi = {c: i for i, c in enumerate(self._cfg.dataset.chars_list)}
                # NOTE(review): unknown characters fall back to id 0, which
                # collides with chars_list[0] — confirm this is intended.
                ids = [stoi.get(c, 0) for c in goal_]
            except Exception:
                # Best-effort fallback: an all-zero id sequence.
                ids = [0] * self._cfg.max_block_size
            return _torch.tensor(_np.expand_dims(_np.array(ids, dtype=_np.int64), axis=0), dtype=_torch.long, device=self._cfg.device)
299
+
300
    def process_text_embedding_for_buffer(self, goal, tokenizer=None, text_model=None):
        """Encode a text goal into a fixed-size array for the replay buffer.

        :param goal: goal string.
        :param tokenizer: HF tokenizer (required).
        :param text_model: model exposing a T5-style ``.encoder`` (required).
        :return: numpy float32 array of shape (max_block_size, n_embd) —
            T5 encoder outputs, truncated or zero-padded to max_block_size,
            with no batch dimension.
        :raises ValueError: if tokenizer or text_model is missing.
        """
        import numpy as _np
        if tokenizer is None or text_model is None:
            raise ValueError("tokenizer and text_model must be provided when using T5 encoding")

        # Zero-padded target buffer; rows beyond the encoded length stay 0.
        goal_ = _np.zeros((self._cfg.max_block_size, self._cfg.n_embd), dtype=_np.float32)
        # NOTE(review): input_ids are not moved to a device before encoding —
        # assumes text_model lives on CPU here; confirm against callers.
        input_ids = tokenizer(goal, return_tensors="pt").input_ids
        goal_t = text_model.encoder(input_ids).last_hidden_state.detach().cpu().numpy()
        # Copy up to max_block_size rows (numpy clips the slice if longer).
        goal_[:len(goal_t[0]), :] = goal_t[0][:self._cfg.max_block_size]
        return goal_
314
+
315
+ def decode_action(self, action_tensor):
316
+
317
+ """
318
+ Docstring for decode_action
319
+
320
+ :param self: Description
321
+ :param action_tensor: Description
322
+ self._decode_action = lambda binN: (binN * action_std) + action_mean # Undo mapping to [-1, 1]
323
+ """
324
+ import torch as _torch
325
+ ## The action tensor is of shape (batch_size, action_dim * action_stacking) so we need to repeat the mean and std per action stacking
326
+ action_mean = _torch.tensor(np.repeat(self._cfg.dataset.action_mean, self._cfg.policy.action_stacking), dtype=action_tensor.dtype, device=action_tensor.device)
327
+ action_std = _torch.tensor(np.repeat(self._cfg.dataset.action_std, self._cfg.policy.action_stacking), dtype=action_tensor.dtype, device=action_tensor.device)
328
+ return (action_tensor * action_std) + action_mean
329
+
330
+ def encode_action(self, action_float):
331
+ """
332
+ Docstring for encode_action
333
+
334
+ :param self: Description
335
+ :param action_float: Description
336
+ self._encode_action = lambda af: (af - action_mean)/(action_std) # encoder: take a float, output an integer
337
+ """
338
+ import torch as _torch
339
+ action_mean = _torch.tensor(self._cfg.dataset.action_mean, dtype=action_float.dtype, device=action_float.device)
340
+ action_std = _torch.tensor(self._cfg.dataset.action_std, dtype=action_float.dtype, device=action_float.device)
341
+ return (action_float - action_mean) / action_std
342
+
343
+
344
@torch.no_grad()
def estimate_loss(model, dataset):
    """Estimate the mean loss on the train and val splits.

    Puts the model in eval mode, averages the loss over cfg.eval_iters
    sampled batches per split, then restores train mode.

    :param model: GRP model exposing ``_cfg`` (eval_iters, batch_size).
    :param dataset: provides ``get_batch_grp(split, cfg, batch_size)``.
    :return: dict mapping 'train'/'val' to a scalar mean-loss tensor.
    """
    cfg = model._cfg
    results = {}
    model.eval()
    for split in ('train', 'val'):
        losses = torch.zeros(cfg.eval_iters)
        for i in range(cfg.eval_iters):
            X, x_pose, x_goal, x_goal_img, Y = dataset.get_batch_grp(split, cfg, cfg.batch_size)
            _, loss = model(X, x_goal, x_goal_img, Y, pose=x_pose)
            losses[i] = loss.item()
        results[split] = losses.mean()
    model.train()
    return results