Timsty commited on
Commit
c58cdb9
·
verified ·
1 Parent(s): d3de13c

Upload pi0_moh.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. pi0_moh.py +578 -0
pi0_moh.py ADDED
@@ -0,0 +1,578 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ import logging
3
+
4
+ import einops
5
+ import flax.nnx as nnx
6
+ import flax.nnx.bridge as nnx_bridge
7
+ import jax
8
+ import jax.numpy as jnp
9
+ from typing_extensions import override
10
+ from typing import List, Optional
11
+
12
+ from openpi.models import model as _model
13
+ import openpi.models.gemma as _gemma
14
+ import openpi.models.siglip as _siglip
15
+ from openpi.shared import array_typing as at
16
+ import openpi.shared.nnx_utils as nnx_utils
17
+ from openpi.models.pi0moh_config import Pi0GatedConfig
18
+
19
+
20
+ def make_attn_mask(input_mask, mask_ar):
21
+ """Adapted from big_vision.
22
+
23
+ Tokens can attend to valid inputs tokens which have a cumulative mask_ar
24
+ smaller or equal to theirs. This way `mask_ar` bool[?B, N] can be used to
25
+ setup several types of attention, for example:
26
+
27
+ [[1 1 1 1 1 1]]: pure causal attention.
28
+
29
+ [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
30
+ themselves and the last 3 tokens have a causal attention. The first
31
+ entry could also be a 1 without changing behaviour.
32
+
33
+ [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
34
+ block can attend all previous blocks and all tokens on the same block.
35
+
36
+ Args:
37
+ input_mask: bool[B, N] true if its part of the input, false if padding.
38
+ mask_ar: bool[?B, N] mask that's true where previous tokens cannot depend on
39
+ it and false where it shares the same attention mask as the previous token.
40
+ """
41
+ mask_ar = jnp.broadcast_to(mask_ar, input_mask.shape)
42
+ cumsum = jnp.cumsum(mask_ar, axis=1)
43
+ attn_mask = cumsum[:, None, :] <= cumsum[:, :, None]
44
+ valid_mask = input_mask[:, None, :] * input_mask[:, :, None]
45
+ return jnp.logical_and(attn_mask, valid_mask)
46
+
47
+
48
+ # Copied from pi0.py
49
+ @at.typecheck
50
+ def posemb_sincos(
51
+ pos: at.Real[at.Array, " b"], embedding_dim: int, min_period: float, max_period: float
52
+ ) -> at.Float[at.Array, "b {embedding_dim}"]:
53
+ """Computes sine-cosine positional embedding vectors for scalar positions."""
54
+ if embedding_dim % 2 != 0:
55
+ raise ValueError(f"embedding_dim ({embedding_dim}) must be divisible by 2")
56
+
57
+ fraction = jnp.linspace(0.0, 1.0, embedding_dim // 2)
58
+ period = min_period * (max_period / min_period) ** fraction
59
+ sinusoid_input = jnp.einsum(
60
+ "i,j->ij",
61
+ pos,
62
+ 1.0 / period * 2 * jnp.pi,
63
+ precision=jax.lax.Precision.HIGHEST,
64
+ )
65
+ return jnp.concatenate([jnp.sin(sinusoid_input), jnp.cos(sinusoid_input)], axis=-1)
66
+
67
+
68
+
69
+ class Pi0Gated(_model.BaseModel):
70
+
71
+ def __init__(self, config: Pi0GatedConfig, rngs: nnx.Rngs):
72
+ # Initialize base model with max_horizon.
73
+ super().__init__(config.action_dim, config.action_horizon, config.max_token_len)
74
+ self.config = config
75
+ self.pi05 = config.pi05
76
+
77
+ paligemma_config = _gemma.get_config(config.paligemma_variant)
78
+ action_expert_config = _gemma.get_config(config.action_expert_variant)
79
+
80
+ # TODO: rewrite gemma in NNX. For now, use bridge.
81
+ llm = nnx_bridge.ToNNX(
82
+ _gemma.Module(
83
+ configs=[paligemma_config, action_expert_config],
84
+ embed_dtype=config.dtype,
85
+ adarms=config.pi05,
86
+ )
87
+ )
88
+ llm.lazy_init(
89
+ rngs=rngs,
90
+ method="init",
91
+ use_adarms=[False, True] if config.pi05 else [False, False],
92
+ )
93
+
94
+ img = nnx_bridge.ToNNX(
95
+ _siglip.Module(
96
+ num_classes=paligemma_config.width,
97
+ variant="So400m/14",
98
+ pool_type="none",
99
+ scan=True,
100
+ dtype_mm=config.dtype,
101
+ )
102
+ )
103
+ img.lazy_init(next(iter(config.fake_obs().images.values())), train=False, rngs=rngs)
104
+
105
+ # Attribute names must match pi0.py for weight loading.
106
+ self.PaliGemma = nnx.Dict(llm=llm, img=img)
107
+
108
+ # Shared action input projection.
109
+ self.action_in_proj = nnx.Linear(config.action_dim, action_expert_config.width, rngs=rngs)
110
+
111
+ if config.pi05:
112
+ # Pi0.5-style: adaRMS conditioning on timestep.
113
+ self.time_mlp_in = nnx.Linear(action_expert_config.width, action_expert_config.width, rngs=rngs)
114
+ self.time_mlp_out = nnx.Linear(action_expert_config.width, action_expert_config.width, rngs=rngs)
115
+ else:
116
+ # Pi0-style: state token + action-time MLP (no adaRMS).
117
+ self.state_proj = nnx.Linear(config.action_dim, action_expert_config.width, rngs=rngs)
118
+ self.action_time_mlp_in = nnx.Linear(
119
+ 2 * action_expert_config.width,
120
+ action_expert_config.width,
121
+ rngs=rngs,
122
+ )
123
+ self.action_time_mlp_out = nnx.Linear(
124
+ action_expert_config.width,
125
+ action_expert_config.width,
126
+ rngs=rngs,
127
+ )
128
+
129
+ self.action_out_proj = nnx.Linear(action_expert_config.width, config.action_dim, rngs=rngs)
130
+
131
+ # Extra gating head for Mixture-of-Horizons.
132
+ self.gate_out_proj = nnx.Linear(action_expert_config.width, 1, rngs=rngs)
133
+
134
+ @at.typecheck
135
+ def embed_prefix(
136
+ self, obs: _model.Observation
137
+ ) -> tuple[
138
+ at.Float[at.Array, "b s emb"],
139
+ at.Bool[at.Array, "b s"],
140
+ at.Bool[at.Array, " s"],
141
+ ]:
142
+ """Unchanged from pi0.py"""
143
+ input_mask = []
144
+ ar_mask = []
145
+ tokens = []
146
+ # embed images
147
+ for name in obs.images:
148
+ image_tokens, _ = self.PaliGemma.img(obs.images[name], train=False)
149
+ tokens.append(image_tokens)
150
+ input_mask.append(
151
+ einops.repeat(
152
+ obs.image_masks[name],
153
+ "b -> b s",
154
+ s=image_tokens.shape[1],
155
+ )
156
+ )
157
+ # image tokens attend to each other
158
+ ar_mask += [False] * image_tokens.shape[1]
159
+
160
+ # add language (aka tokenized inputs)
161
+ if obs.tokenized_prompt is not None:
162
+ tokenized_inputs = self.PaliGemma.llm(obs.tokenized_prompt, method="embed")
163
+ tokens.append(tokenized_inputs)
164
+ input_mask.append(obs.tokenized_prompt_mask)
165
+ # full attention between image and language inputs
166
+ ar_mask += [False] * tokenized_inputs.shape[1]
167
+
168
+ tokens = jnp.concatenate(tokens, axis=1)
169
+ input_mask = jnp.concatenate(input_mask, axis=1)
170
+ ar_mask = jnp.array(ar_mask)
171
+ return tokens, input_mask, ar_mask
172
+
173
+ @at.typecheck
174
+ def embed_suffix(
175
+ self, state,
176
+ noisy_actions: _model.Actions,
177
+ timestep: at.Float[at.Array, " b"],
178
+ action_pad_mask,
179
+ ) -> tuple[
180
+ at.Float[at.Array, "b s emb"],
181
+ at.Bool[at.Array, "b s"],
182
+ at.Bool[at.Array, " s"],
183
+ at.Float[at.Array, "b emb"] | None,
184
+ ]:
185
+ """
186
+ Pi0 / Pi0.5 compatible suffix embedding.
187
+
188
+ Mirrors :class:`Pi0`'s ``embed_suffix`` (including adaRMS conditioning
189
+ when ``pi05=True``) but takes ``state`` and ``action_pad_mask``
190
+ explicitly to support batched horizon processing used by MoH.
191
+ """
192
+ input_mask = []
193
+ ar_mask: list[bool] = []
194
+ tokens = []
195
+
196
+ adarms_cond = None
197
+
198
+ # Optional Pi0-style state token (no state token for Pi0.5 / adaRMS).
199
+ if not self.pi05:
200
+ state_token = self.state_proj(state)[:, None, :]
201
+ tokens.append(state_token)
202
+ input_mask.append(jnp.ones((state.shape[0], 1), dtype=jnp.bool_))
203
+ # image/language inputs do not attend to state or actions
204
+ ar_mask += [True]
205
+
206
+ # Timestep embedding.
207
+ time_emb = posemb_sincos(
208
+ timestep,
209
+ self.action_in_proj.out_features,
210
+ min_period=4e-3,
211
+ max_period=4.0,
212
+ )
213
+
214
+ # Project actions.
215
+ action_tokens = self.action_in_proj(noisy_actions)
216
+
217
+ if self.pi05:
218
+ # Pi0.5: adaRMS on time embedding, actions unchanged.
219
+ time_emb = self.time_mlp_in(time_emb)
220
+ time_emb = nnx.swish(time_emb)
221
+ time_emb = self.time_mlp_out(time_emb)
222
+ time_emb = nnx.swish(time_emb)
223
+ adarms_cond = time_emb
224
+ action_expert_tokens = action_tokens
225
+ else:
226
+ # Pi0: concatenate time + actions and pass through MLP.
227
+ time_tokens = einops.repeat(time_emb, "b emb -> b s emb", s=noisy_actions.shape[1])
228
+ action_time_tokens = jnp.concatenate([action_tokens, time_tokens], axis=-1)
229
+ action_time_tokens = self.action_time_mlp_in(action_time_tokens)
230
+ action_time_tokens = nnx.swish(action_time_tokens)
231
+ action_time_tokens = self.action_time_mlp_out(action_time_tokens)
232
+ action_expert_tokens = action_time_tokens
233
+
234
+ if action_pad_mask is None:
235
+ action_pad_mask = jnp.ones(action_expert_tokens.shape[:2], dtype=jnp.bool_)
236
+ input_mask.append(action_pad_mask)
237
+
238
+ tokens.append(action_expert_tokens)
239
+
240
+ # image/language/state inputs do not attend to action tokens
241
+ ar_mask += [True] + ([False] * (action_expert_tokens.shape[1] - 1))
242
+
243
+ tokens = jnp.concatenate(tokens, axis=1)
244
+ input_mask_arr = jnp.concatenate(input_mask, axis=1)
245
+ ar_mask_arr = jnp.array(ar_mask)
246
+
247
+ return tokens, input_mask_arr, ar_mask_arr, adarms_cond
248
+
249
+ def cv_squared(self, x: at.Array, eps: float = 1e-10) -> at.Array:
250
+ """Computes the squared coefficient of variation. (From pi0_pytorch_moh.py)"""
251
+
252
+ def compute_cv():
253
+ mean = jnp.mean(x, dtype=jnp.float32)
254
+ var = jnp.var(x, dtype=jnp.float32)
255
+ return var / (jnp.square(mean) + eps)
256
+
257
+ # Handle num_experts = 1 case
258
+ return jax.lax.cond(
259
+ x.shape[0] == 1,
260
+ lambda: jnp.array(0.0, dtype=jnp.float32),
261
+ compute_cv
262
+ )
263
+
264
+ @override
265
+ def compute_loss(
266
+ self, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions, *, train: bool = False
267
+ ) -> at.Float[at.Array, "*b ah"]:
268
+ # def compute_loss(
269
+ # self,
270
+ # rng: at.KeyArrayLike,
271
+ # observation: _model.Observation,
272
+ # actions: at.Float[at.Array, "b s action_dim"],
273
+ # ) -> tuple[at.Float[at.Array, ""], dict[str, at.Float[at.Array, ""]]]:
274
+ preprocess_rng, noise_rng, time_rng = jax.random.split(rng, 3)
275
+ observation = _model.preprocess_observation(preprocess_rng, observation, train=train)
276
+
277
+ batch_size, max_horizon, action_dim = actions.shape
278
+ num_horizons = len(self.config.horizons)
279
+ horizons_arr = jnp.array(self.config.horizons)
280
+
281
+ # Sample noise and time
282
+ noise = jax.random.normal(noise_rng, actions.shape)
283
+ time_scalar = jax.random.beta(time_rng, 1.5, 1, (batch_size,)) * 0.999 + 0.001
284
+
285
+ # Expand time and actions for each horizon
286
+ # time shape: (H, B)
287
+ time = einops.repeat(time_scalar, "b -> h b", h=num_horizons)
288
+ # x_t shape: (H, B, max_H, D)
289
+ x_t = time[..., None, None] * noise[None, ...] + (1 - time[..., None, None]) * actions[None, ...]
290
+ # u_t (target) shape: (B, max_H, D)
291
+ u_t = noise - actions
292
+
293
+ # STAGE 1: VLM Prefix Pass (Compute KV cache once)
294
+ prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation)
295
+ prefix_attn_mask = make_attn_mask(prefix_mask, prefix_ar_mask)
296
+ prefix_positions = jnp.cumsum(prefix_mask, axis=1) - 1
297
+ (_, prefix_out), prefix_past_key_values = self.PaliGemma.llm(
298
+ [prefix_tokens, None],
299
+ mask=prefix_attn_mask,
300
+ positions=prefix_positions
301
+ )
302
+
303
+ # STAGE 2: Action Head Suffix Passes (Parallelized via batching)
304
+
305
+ # Repeat prefix masks and KV cache for each horizon
306
+ # New batch size is (B * H)
307
+ batched_prefix_mask = jnp.repeat(prefix_mask, num_horizons, axis=0)
308
+ batched_past_key_values = jax.tree_map(
309
+ lambda x: jnp.repeat(x, num_horizons, axis=1),
310
+ prefix_past_key_values
311
+ )
312
+ batched_state = jnp.repeat(observation.state, num_horizons, axis=0)
313
+
314
+ # Reshape x_t and time to align with the new batch dimension
315
+ # (H, B, max_H, D) -> (B*H, max_H, D)
316
+ batched_x_t = jnp.transpose(x_t, (1, 0, 2, 3)).reshape(batch_size * num_horizons, max_horizon, -1)
317
+ # (H, B) -> (B*H,)
318
+ batched_time = jnp.transpose(time, (1, 0)).reshape(-1)
319
+
320
+ # Create a padding mask for actions based on valid horizon length
321
+ # (H, max_H)
322
+ action_pad_mask = jnp.arange(max_horizon)[None, :] < horizons_arr[:, None]
323
+ # (B*H, max_H)
324
+ action_pad_mask_expanded = jnp.broadcast_to(
325
+ action_pad_mask[None, :, :],
326
+ (batch_size, num_horizons, max_horizon)
327
+ )
328
+ # (B, H, max_H) -> (B*H, max_H)
329
+ batched_action_pad_mask = action_pad_mask_expanded.reshape(batch_size * num_horizons, max_horizon)
330
+
331
+ # Embed the batched suffix inputs
332
+ suffix_tokens, suffix_pad_masks, suffix_ar_mask, adarms_cond = self.embed_suffix(
333
+ batched_state, batched_x_t, batched_time, action_pad_mask=batched_action_pad_mask
334
+ )
335
+
336
+ # Combine prefix and suffix masks for cross-attention
337
+ pad_masks = jnp.concatenate([batched_prefix_mask, suffix_pad_masks], axis=1)
338
+ ar_masks = jnp.concatenate([prefix_ar_mask, suffix_ar_mask], axis=0)
339
+ full_att_2d_masks = make_attn_mask(pad_masks, ar_masks)
340
+
341
+ prefix_len = prefix_mask.shape[1]
342
+ suffix_len = suffix_tokens.shape[1]
343
+
344
+ # Create position IDs and attention mask for the suffix part only
345
+ suffix_position_ids = jnp.arange(prefix_len, prefix_len + suffix_len)[None, :]
346
+ suffix_att_2d_masks = full_att_2d_masks[:, -suffix_len:, :]
347
+
348
+ b = suffix_att_2d_masks.shape[0]
349
+ suffix_position_ids = jnp.broadcast_to(suffix_position_ids, (b, suffix_len))
350
+
351
+ adarms = [None, adarms_cond] if self.pi05 else [None, None]
352
+ (_, suffix_out), _ = self.PaliGemma.llm(
353
+ [None, suffix_tokens],
354
+ mask=suffix_att_2d_masks,
355
+ positions=suffix_position_ids,
356
+ kv_cache=batched_past_key_values,
357
+ adarms_cond=adarms,
358
+ )
359
+
360
+ action_start_index = 0 if self.pi05 else 1 # pi0.5 has no state token
361
+ v_t_batched = self.action_out_proj(suffix_out)
362
+ v_t_actions_padded = v_t_batched[:, action_start_index: action_start_index + max_horizon, :]
363
+ # (H, B, max_H, D_action)
364
+ all_v_t_preds = v_t_actions_padded.reshape(
365
+ batch_size, num_horizons, max_horizon, -1
366
+ ).transpose(1, 0, 2, 3)
367
+
368
+ # 1. Primary Loss: Ensures each expert head is trained well.
369
+ all_head_losses = []
370
+ for i, h in enumerate(self.config.horizons):
371
+ v_t_head = all_v_t_preds[i, :, :h, :]
372
+ target_v_t = u_t[:, :h, :]
373
+ # Mean over batch, horizon, and action dim
374
+ head_loss = jnp.mean(jnp.square(v_t_head - target_v_t))
375
+ all_head_losses.append(head_loss)
376
+
377
+ individual_loss = jnp.sum(jnp.stack(all_head_losses))
378
+
379
+ # 2. Auxiliary Loss: Trains the gating network
380
+ # (B*H, S_suffix, 1)
381
+ gate_logits_batched = self.gate_out_proj(suffix_out)
382
+ # (B*H, max_H, 1)
383
+ gate_logits_padded = gate_logits_batched[:, action_start_index: action_start_index + max_horizon, :]
384
+ # (B, max_H, H)
385
+ # gate_logits = einops.rearrange(gate_logits_padded, "(b h) s 1 -> b s h", b=batch_size, h=num_horizons)
386
+ gate_logits_reshaped = gate_logits_padded.reshape(batch_size, num_horizons, max_horizon, 1)
387
+ gate_logits = jnp.transpose(gate_logits_reshaped, (0, 2, 1, 3)).squeeze(-1)
388
+
389
+ # Create mask for softmax
390
+ # (max_H, H)
391
+ valid_heads_mask = jnp.arange(max_horizon)[:, None] < horizons_arr[None, :]
392
+ # (B, max_H, H) - broadcast batch dim
393
+ valid_heads_mask = jnp.broadcast_to(valid_heads_mask, gate_logits.shape)
394
+
395
+ masked_gate_logits = jnp.where(valid_heads_mask, gate_logits, jnp.finfo(gate_logits.dtype).min)
396
+ gate_weights = nnx.softmax(masked_gate_logits, axis=-1)
397
+
398
+ # Combine predictions using the dynamic weights
399
+ # all_v_t_preds: (H, B, max_H, D) -> (B, H, max_H, D)
400
+ all_v_t_preds_permuted = einops.rearrange(all_v_t_preds, "h b s d -> b h s d")
401
+ # gate_weights: (B, max_H, H) -> (B, H, max_H, 1)
402
+ gate_weights_expanded = jnp.transpose(gate_weights, (0, 2, 1))[:, :, :, None]
403
+
404
+ # (B, H, max_H, D) * (B, H, max_H, 1) -> sum over H -> (B, max_H, D)
405
+ v_t_combined = jnp.sum(all_v_t_preds_permuted * gate_weights_expanded, axis=1)
406
+
407
+ auxiliary_loss = jnp.mean(jnp.square(v_t_combined - u_t)) # Mean over B, H, D
408
+
409
+ # 3. Balance Loss: Encourage the gate layer to output weights flexibly
410
+ loss_components = []
411
+ boundaries = sorted(list(set([0] + self.config.horizons)))
412
+ for i in range(len(boundaries) - 1):
413
+ start_step, end_step = boundaries[i], boundaries[i + 1]
414
+ active_expert_indices = [idx for idx, h in enumerate(self.config.horizons) if h > start_step]
415
+
416
+ if len(active_expert_indices) > 1:
417
+ # (B, S_segment, H_total)
418
+ segment_gate_weights = gate_weights[:, start_step:end_step, :]
419
+ # (B, S_segment, H_active)
420
+ active_expert_weights = segment_gate_weights[:, :, jnp.array(active_expert_indices)]
421
+ # (H_active,)
422
+ avg_expert_prob_in_segment = jnp.mean(active_expert_weights, axis=(0, 1))
423
+ segment_loss = self.cv_squared(avg_expert_prob_in_segment)
424
+ loss_components.append(segment_loss)
425
+
426
+ load_balancing_loss = jnp.mean(jnp.stack(loss_components)) if loss_components else 0.0
427
+
428
+ total_loss = (
429
+ individual_loss +
430
+ self.config.aux_weight * auxiliary_loss +
431
+ self.config.balance_weight * load_balancing_loss
432
+ )
433
+
434
+ return total_loss
435
+
436
+ @override
437
+ def sample_actions(
438
+ self,
439
+ rng: at.KeyArrayLike,
440
+ observation: _model.Observation,
441
+ *,
442
+ num_steps: int | at.Int[at.Array, ""] = 10,
443
+ noise: at.Float[at.Array, "b ah ad"] | None = None,
444
+ ) -> _model.Actions:
445
+ """
446
+ Samples actions using the gated fusion mechanism during denoising.
447
+ """
448
+ observation = _model.preprocess_observation(None, observation, train=False)
449
+ dt = -1.0 / num_steps
450
+ batch_size = observation.state.shape[0]
451
+ max_horizon = self.action_horizon
452
+ num_horizons = len(self.config.horizons)
453
+ horizons_arr = jnp.array(self.config.horizons)
454
+
455
+ noise = jax.random.normal(rng, (batch_size, max_horizon, self.action_dim))
456
+
457
+ # First fill KV cache with a forward pass of the prefix
458
+ prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation)
459
+ prefix_attn_mask = make_attn_mask(prefix_mask, prefix_ar_mask)
460
+ positions = jnp.cumsum(prefix_mask, axis=1) - 1
461
+ (_, prefix_out), kv_cache = self.PaliGemma.llm(
462
+ [prefix_tokens, None],
463
+ mask=prefix_attn_mask,
464
+ positions=positions
465
+ )
466
+
467
+ # Prepare static batched inputs (these don't change in the loop)
468
+ batched_prefix_mask = jnp.repeat(prefix_mask, num_horizons, axis=0)
469
+ batched_kv_cache = jax.tree_map(
470
+ lambda x: jnp.repeat(x, num_horizons, axis=1),
471
+ kv_cache
472
+ )
473
+ batched_state = jnp.repeat(observation.state, num_horizons, axis=0)
474
+
475
+ # Create static masks for padding actions in the loop
476
+ # (H, max_H)
477
+ steps_arr = jnp.arange(max_horizon)
478
+ action_pad_mask_per_h = steps_arr[None, :] < horizons_arr[:, None]
479
+ # (B*H, max_H)
480
+ batched_action_pad_mask = jnp.broadcast_to(
481
+ action_pad_mask_per_h[None, :, :],
482
+ (batch_size, num_horizons, max_horizon)
483
+ )
484
+ batched_action_pad_mask = einops.rearrange(batched_action_pad_mask, "b h s -> (b h) s")
485
+ # batched_action_pad_mask = einops.repeat(action_pad_mask_per_h, "h s -> (b h) s", b=batch_size)
486
+ # (H, max_H, 1)
487
+ action_mask_for_padding_x_t = (steps_arr[None, :, None] < horizons_arr[:, None, None])
488
+
489
+ # Create static mask for gate softmax
490
+ # (max_H, H)
491
+ valid_heads_mask = steps_arr[:, None] < horizons_arr[None, :]
492
+ # (B, max_H, H) - for broadcasting
493
+ valid_heads_mask = valid_heads_mask[None, :, :]
494
+
495
+ action_start_index = 0 if self.pi05 else 1 # pi0.5 has no state token
496
+
497
+ prefix_len = prefix_mask.shape[1]
498
+
499
+ def step_fn(carry):
500
+ x_t, time = carry
501
+
502
+ # --- Prepare Batched Inputs for this step ---
503
+ expanded_time = jnp.broadcast_to(time, (batch_size * num_horizons,))
504
+
505
+ # Pad x_t for each horizon
506
+ # (1, B, max_H, D)
507
+ x_t_expanded = x_t[None, ...]
508
+ # (H, B, max_H, D)
509
+ padded_x_t_batched = jnp.where(action_mask_for_padding_x_t, x_t_expanded, 0.0)
510
+ # (B*H, max_H, D)
511
+ batched_x_t = padded_x_t_batched.transpose(1, 0, 2, 3)
512
+ batched_x_t = einops.rearrange(batched_x_t, "b h s d -> (b h) s d")
513
+
514
+ # --- Run Batched Suffix Pass ---
515
+ suffix_tokens, suffix_pad_masks, suffix_ar_mask, adarms_cond = self.embed_suffix(
516
+ batched_state, batched_x_t, expanded_time, action_pad_mask=batched_action_pad_mask
517
+ )
518
+
519
+ pad_masks = jnp.concatenate([batched_prefix_mask, suffix_pad_masks], axis=1)
520
+ ar_masks = jnp.concatenate([prefix_ar_mask, suffix_ar_mask], axis=0)
521
+ full_att_2d_masks = make_attn_mask(pad_masks, ar_masks)
522
+
523
+ suffix_len = suffix_tokens.shape[1]
524
+ suffix_position_ids = jnp.arange(prefix_len, prefix_len + suffix_len)[None, :]
525
+ suffix_att_2d_masks = full_att_2d_masks[:, -suffix_len:, :]
526
+
527
+ b = suffix_att_2d_masks.shape[0]
528
+ suffix_position_ids = jnp.broadcast_to(suffix_position_ids, (b, suffix_len))
529
+
530
+ adarms = [None, adarms_cond] if self.pi05 else [None, None]
531
+ (_, suffix_out), _ = self.PaliGemma.llm(
532
+ [None, suffix_tokens],
533
+ mask=suffix_att_2d_masks,
534
+ positions=suffix_position_ids,
535
+ kv_cache=batched_kv_cache,
536
+ adarms_cond=adarms,
537
+ )
538
+
539
+ # --- Gating and Fusion ---
540
+ # (B*H, S_suffix, 1)
541
+ gate_logits_batched = self.gate_out_proj(suffix_out)
542
+ # (B*H, max_H, 1)
543
+ gate_logits_padded = gate_logits_batched[:, action_start_index: action_start_index + max_horizon, :]
544
+ # (B, max_H, H)
545
+ gate_logits_reshaped = gate_logits_padded.reshape(batch_size, num_horizons, max_horizon, 1)
546
+ gate_logits = jnp.transpose(gate_logits_reshaped, (0, 2, 1, 3)).squeeze(-1)
547
+ masked_gate_logits = jnp.where(valid_heads_mask, gate_logits, jnp.finfo(gate_logits.dtype).min)
548
+ gate_weights = nnx.softmax(masked_gate_logits, axis=-1)
549
+
550
+ # Get all predictions
551
+ # (B*H, S_suffix, D_action)
552
+ v_t_batched = self.action_out_proj(suffix_out)
553
+ # (B*H, max_H, D_action)
554
+ v_t_actions_padded = v_t_batched[:, action_start_index: action_start_index + max_horizon, :]
555
+ # (B, H, max_H, D_action)
556
+ all_v_t_preds = v_t_actions_padded.reshape(batch_size, num_horizons, max_horizon, -1)
557
+
558
+ # Combine predictions
559
+ # gate_weights: (B, max_H, H) -> (B, H, max_H, 1)
560
+ gate_weights_expanded = jnp.transpose(gate_weights, (0, 2, 1))[:, :, :, None]
561
+
562
+ # (B, H, max_H, D) * (B, H, max_H, 1) -> sum over H -> (B, max_H, D)
563
+ v_t = jnp.sum(all_v_t_preds * gate_weights_expanded, axis=1)
564
+
565
+ # --- Euler Step ---
566
+ x_t_new = x_t + dt * v_t
567
+ time_new = time + dt
568
+
569
+ return (x_t_new, time_new)
570
+
571
+ def cond_fn(carry):
572
+ x_t, time = carry
573
+ # robust to floating-point error
574
+ return time >= -dt / 2
575
+
576
+ x_0, _ = jax.lax.while_loop(cond_fn, step_fn, (noise, 1.0))
577
+ return x_0
578
+