jasonzhango committed on
Commit 24443be · verified · 1 Parent(s): be1e979

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +1 -0
  2. checkpoint-300000/added_tokens.json +45 -0
  3. checkpoint-300000/chat_template.jinja +6 -0
  4. checkpoint-300000/config.json +28 -0
  5. checkpoint-300000/configuration_eo1_internvl.py +79 -0
  6. checkpoint-300000/merges.txt +0 -0
  7. checkpoint-300000/model.safetensors +3 -0
  8. checkpoint-300000/modeling_eo1_internvl.py +1284 -0
  9. checkpoint-300000/preprocessor_config.json +11 -0
  10. checkpoint-300000/processing_eo1_internvl.py +106 -0
  11. checkpoint-300000/processor_config.json +0 -0
  12. checkpoint-300000/rng_state_0.pth +3 -0
  13. checkpoint-300000/rng_state_1.pth +3 -0
  14. checkpoint-300000/rng_state_10.pth +3 -0
  15. checkpoint-300000/rng_state_11.pth +3 -0
  16. checkpoint-300000/rng_state_12.pth +3 -0
  17. checkpoint-300000/rng_state_13.pth +3 -0
  18. checkpoint-300000/rng_state_14.pth +3 -0
  19. checkpoint-300000/rng_state_15.pth +3 -0
  20. checkpoint-300000/rng_state_16.pth +3 -0
  21. checkpoint-300000/rng_state_17.pth +3 -0
  22. checkpoint-300000/rng_state_18.pth +3 -0
  23. checkpoint-300000/rng_state_19.pth +3 -0
  24. checkpoint-300000/rng_state_2.pth +3 -0
  25. checkpoint-300000/rng_state_20.pth +3 -0
  26. checkpoint-300000/rng_state_21.pth +3 -0
  27. checkpoint-300000/rng_state_22.pth +3 -0
  28. checkpoint-300000/rng_state_23.pth +3 -0
  29. checkpoint-300000/rng_state_24.pth +3 -0
  30. checkpoint-300000/rng_state_25.pth +3 -0
  31. checkpoint-300000/rng_state_26.pth +3 -0
  32. checkpoint-300000/rng_state_27.pth +3 -0
  33. checkpoint-300000/rng_state_28.pth +3 -0
  34. checkpoint-300000/rng_state_29.pth +3 -0
  35. checkpoint-300000/rng_state_3.pth +3 -0
  36. checkpoint-300000/rng_state_30.pth +3 -0
  37. checkpoint-300000/rng_state_31.pth +3 -0
  38. checkpoint-300000/rng_state_32.pth +3 -0
  39. checkpoint-300000/rng_state_33.pth +3 -0
  40. checkpoint-300000/rng_state_34.pth +3 -0
  41. checkpoint-300000/rng_state_35.pth +3 -0
  42. checkpoint-300000/rng_state_36.pth +3 -0
  43. checkpoint-300000/rng_state_37.pth +3 -0
  44. checkpoint-300000/rng_state_38.pth +3 -0
  45. checkpoint-300000/rng_state_39.pth +3 -0
  46. checkpoint-300000/rng_state_4.pth +3 -0
  47. checkpoint-300000/rng_state_40.pth +3 -0
  48. checkpoint-300000/rng_state_41.pth +3 -0
  49. checkpoint-300000/rng_state_42.pth +3 -0
  50. checkpoint-300000/rng_state_43.pth +3 -0
.gitattributes CHANGED
@@ -35,3 +35,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
  checkpoint-200000/trainer_state.json filter=lfs diff=lfs merge=lfs -text
  checkpoint-250000/trainer_state.json filter=lfs diff=lfs merge=lfs -text
+ checkpoint-300000/trainer_state.json filter=lfs diff=lfs merge=lfs -text
checkpoint-300000/added_tokens.json ADDED
@@ -0,0 +1,45 @@
+ {
+   "</box>": 151677,
+   "</img>": 151670,
+   "</quad>": 151673,
+   "</ref>": 151675,
+   "</think>": 151668,
+   "</tool_call>": 151658,
+   "</tool_response>": 151666,
+   "<IMG_CONTEXT>": 151671,
+   "<box>": 151676,
+   "<img>": 151669,
+   "<quad>": 151672,
+   "<ref>": 151674,
+   "<think>": 151667,
+   "<tool_call>": 151657,
+   "<tool_response>": 151665,
+   "<|action_end|>": 151680,
+   "<|action_pad|>": 151679,
+   "<|action_pass|>": 151681,
+   "<|action_start|>": 151678,
+   "<|box_end|>": 151649,
+   "<|box_start|>": 151648,
+   "<|endoftext|>": 151643,
+   "<|file_sep|>": 151664,
+   "<|fim_middle|>": 151660,
+   "<|fim_pad|>": 151662,
+   "<|fim_prefix|>": 151659,
+   "<|fim_suffix|>": 151661,
+   "<|im_end|>": 151645,
+   "<|im_start|>": 151644,
+   "<|image_pad|>": 151655,
+   "<|object_ref_end|>": 151647,
+   "<|object_ref_start|>": 151646,
+   "<|quad_end|>": 151651,
+   "<|quad_start|>": 151650,
+   "<|repo_name|>": 151663,
+   "<|state_end|>": 151684,
+   "<|state_pad|>": 151683,
+   "<|state_start|>": 151682,
+   "<|video_pad|>": 151656,
+   "<|vision_end|>": 151653,
+   "<|vision_pad|>": 151654,
+   "<|vision_start|>": 151652,
+   "<|vla|>": 151685
+ }
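Editor's note: the IDs above extend the Qwen special-token block (which ends at "<|endoftext|>" = 151643) with InternVL image tokens and EO-Robotics action/state markers (151657-151685). A minimal sketch of checking the mapping against the IDs recorded in config.json below, assuming the tokenizer files in this folder load via AutoTokenizer; the local path is a placeholder:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("checkpoint-300000", trust_remote_code=True)
    assert tok.convert_tokens_to_ids("<|action_pad|>") == 151679   # config.action_token_id
    assert tok.convert_tokens_to_ids("<|action_pass|>") == 151681  # config.action_pass_id
    assert tok.convert_tokens_to_ids("<IMG_CONTEXT>") == 151671    # config.img_context_token_id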
checkpoint-300000/chat_template.jinja ADDED
@@ -0,0 +1,6 @@
+ {% for message in messages %}{{'<|im_start|>' + message['role'] + '
+ '}}{% if message['content'] is string %}{{ message['content'] }}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' %}{{ '<image>
+ ' }}{% elif content['type'] == 'video' %}{{ '<video>
+ ' }}{% elif content['type'] == 'text' %}{{ content['text'] }}{% endif %}{% endfor %}{% endif %}{{'<|im_end|>
+ '}}{% endfor %}{% if add_generation_prompt %}{{'<|im_start|>assistant
+ ' }}{% endif %}
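Editor's note: this is a standard ChatML-style template with `<image>`/`<video>` placeholders for multimodal content parts. A minimal rendering sketch (the path is a placeholder; the message layout follows what the template expects):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("checkpoint-300000", trust_remote_code=True)
    messages = [
        {"role": "user", "content": [
            {"type": "image"},
            {"type": "text", "text": "Pick up the red block."},
        ]},
    ]
    prompt = tok.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    # '<|im_start|>user\n<image>\nPick up the red block.<|im_end|>\n<|im_start|>assistant\n'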
checkpoint-300000/config.json ADDED
@@ -0,0 +1,28 @@
+ {
+   "action_chunk_size": 30,
+   "action_pass_id": 151681,
+   "action_token_id": 151679,
+   "architectures": [
+     "EO1InternVLPiFlowMatchingModel"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_eo1_internvl.EO1InternVLPiFlowMatchingConfig",
+     "AutoModel": "modeling_eo1_internvl.EO1InternVLPiFlowMatchingModel"
+   },
+   "backbone_name_or_path": "hugg_model/InternVL3_5-1B",
+   "dtype": "bfloat16",
+   "eos_token_id": 151645,
+   "expert_hidden_size": 1024,
+   "expert_init_from_backbone": false,
+   "expert_intermediate_size": 3072,
+   "expert_layer_mapping": "last",
+   "expert_num_attention_heads": 16,
+   "expert_num_hidden_layers": 18,
+   "ignore_index": -100,
+   "img_context_token_id": 151671,
+   "max_action_dim": 32,
+   "model_type": "eo1_internvl_pi",
+   "num_denoise_steps": 10,
+   "pad_token_id": 151643,
+   "transformers_version": "4.56.0"
+ }
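Editor's note: the `auto_map` entries wire these classes into `transformers` remote-code loading. A minimal sketch of reading the config back; note that `EO1InternVLPiFlowMatchingModel.__init__` also takes an `internvl_backbone` module, so building the full model likely goes through repo-specific loading code rather than a bare `AutoModel` call:

    from transformers import AutoConfig

    cfg = AutoConfig.from_pretrained("checkpoint-300000", trust_remote_code=True)
    print(cfg.model_type)         # eo1_internvl_pi
    print(cfg.action_chunk_size)  # 30
    print(cfg.expert_num_hidden_layers, cfg.expert_layer_mapping)  # 18 last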
checkpoint-300000/configuration_eo1_internvl.py ADDED
@@ -0,0 +1,79 @@
+ # Copyright 2026 EO-Robotics Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import annotations
+
+ from transformers.configuration_utils import PretrainedConfig
+
+
+ class EO1InternVLPiFlowMatchingConfig(PretrainedConfig):
+     """
+     EO1 Flow-Matching wrapper for InternVL backbone + Pi05-style action expert.
+
+     Pi05 key properties (mirrors `openpi.models.pi0` with `pi05=True`):
+     - Prefix uses standard *causal* LM forward (flash-attn friendly) to build a per-layer KV cache.
+     - Action block is bidirectional within itself and can attend to the cached prefix KV.
+     - Flow-matching timestep is injected via AdaRMSNorm in the action expert (not concatenated into embeddings).
+     - Continuous state token in suffix is *disabled* (state should be encoded in text if needed).
+     """
+
+     model_type = "eo1_internvl_pi"
+     keys_to_ignore_at_inference = ["past_key_values"]
+
+     def __init__(
+         self,
+         backbone_name_or_path: str | None = None,
+         # Flow matching
+         action_chunk_size: int = 16,
+         max_action_dim: int = 32,
+         num_denoise_steps: int = 10,
+         rtc_delay: int = 5,
+         # Tokens
+         action_token_id: int | None = None,
+         action_pass_id: int | None = None,
+         img_context_token_id: int | None = None,
+         ignore_index: int = -100,
+         # Expert init
+         expert_init_from_backbone: bool = False,
+         # Expert architecture (Pi05-style: smaller action expert than VLM)
+         expert_num_hidden_layers: int | None = 18,
+         expert_hidden_size: int | None = 1024,
+         expert_intermediate_size: int | None = 3072,
+         expert_num_attention_heads: int | None = 16,
+         expert_layer_mapping: str = "last",
+         **kwargs,
+     ):
+         self.backbone_name_or_path = backbone_name_or_path
+
+         self.action_chunk_size = int(action_chunk_size)
+         self.max_action_dim = int(max_action_dim)
+         self.num_denoise_steps = int(num_denoise_steps)
+         self.rtc_delay = int(rtc_delay)
+
+         self.action_token_id = action_token_id
+         self.action_pass_id = action_pass_id
+         self.img_context_token_id = img_context_token_id
+         self.ignore_index = int(ignore_index)
+
+         self.expert_init_from_backbone = bool(expert_init_from_backbone)
+         self.expert_num_hidden_layers = None if expert_num_hidden_layers is None else int(expert_num_hidden_layers)
+         self.expert_hidden_size = None if expert_hidden_size is None else int(expert_hidden_size)
+         self.expert_intermediate_size = None if expert_intermediate_size is None else int(expert_intermediate_size)
+         self.expert_num_attention_heads = None if expert_num_attention_heads is None else int(expert_num_attention_heads)
+         self.expert_layer_mapping = str(expert_layer_mapping)
+
+         super().__init__(**kwargs)
+
+
+ EO1InternVLPiFlowMatchingConfig.register_for_auto_class()
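Editor's note: a minimal instantiation sketch mirroring the values recorded in config.json above (illustrative only; `out_dir` is a placeholder):

    cfg = EO1InternVLPiFlowMatchingConfig(
        backbone_name_or_path="hugg_model/InternVL3_5-1B",
        action_chunk_size=30,
        expert_num_hidden_layers=18,
        expert_hidden_size=1024,
        expert_intermediate_size=3072,
        expert_num_attention_heads=16,
    )
    cfg.save_pretrained("out_dir")  # writes a config.json with model_type "eo1_internvl_pi"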
checkpoint-300000/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
checkpoint-300000/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:317388e105ef997dcb0f3ec00c76bc4a8a77767d21c93197206b8494ecbbbbd6
+ size 3726487744
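Editor's note: only the Git LFS pointer appears in the diff; the ~3.7 GB weight file lives in LFS. A sketch of verifying a downloaded copy against the recorded size and digest (the local path is a placeholder):

    import hashlib
    import os

    path = "checkpoint-300000/model.safetensors"
    assert os.path.getsize(path) == 3726487744
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    assert h.hexdigest() == "317388e105ef997dcb0f3ec00c76bc4a8a77767d21c93197206b8494ecbbbbd6"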
checkpoint-300000/modeling_eo1_internvl.py ADDED
@@ -0,0 +1,1284 @@
+ # Copyright 2026 EO-Robotics Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import annotations
+
+ import copy
+ import math
+ import os
+ from dataclasses import dataclass
+ from typing import Any
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F  # noqa: N812
+ from torch import Tensor
+ from transformers.modeling_outputs import ModelOutput
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.utils import logging
+
+ from .configuration_eo1_internvl import EO1InternVLPiFlowMatchingConfig
+
+
+ logger = logging.get_logger(__name__)
+
+
+ def create_sinusoidal_pos_embedding(
+     time: torch.Tensor,
+     dimension: int,
+     min_period: float = 4e-3,
+     max_period: float = 4.0,
+     device: str | torch.device = "cpu",
+ ) -> Tensor:
+     """Sine-cosine embedding for scalar time in [0,1]. Matches openpi `posemb_sincos` sensitivity."""
+     if dimension % 2 != 0:
+         raise ValueError(f"dimension ({dimension}) must be divisible by 2")
+     if time.ndim != 1:
+         raise ValueError("The time tensor is expected to be of shape `(batch_size,)`.")
+
+     fraction = torch.linspace(0.0, 1.0, dimension // 2, device=device)
+     period = min_period * (max_period / min_period) ** fraction
+     scaling_factor = 1.0 / period * 2 * math.pi
+     sin_input = scaling_factor[None, :] * time[:, None]
+     return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
+
+
+ def _masked_fill_min(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+     """Fill with dtype-min where mask is False. `mask` is broadcastable to `x`."""
+     return x.masked_fill(~mask, torch.finfo(x.dtype).min)
+
+
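Editor's note: a quick shape check for the time-embedding helper above (a sketch, not part of the file): for a batch of scalar times it returns one sin/cos feature row per sample.

    import torch

    t = torch.tensor([0.25, 0.75])  # (B,) flow-matching times in [0, 1]
    emb = create_sinusoidal_pos_embedding(t, dimension=1024)
    print(emb.shape)                # torch.Size([2, 1024])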
+ class AdaRMSNorm(nn.Module):
+     """
+     Pi05-style AdaRMSNorm (openpi `gemma.RMSNorm` with `cond!=None`):
+     - RMS normalize in float32
+     - per-layer modulation = Linear(cond -> 3*D) initialized to zeros
+     - output = normed * (1 + scale) + shift
+     - returns gate for gated residual.
+     """
+
+     def __init__(self, dim: int, *, eps: float = 1e-6):
+         super().__init__()
+         self.eps = float(eps)
+         self.modulation = nn.Linear(dim, dim * 3, bias=True)
+         nn.init.zeros_(self.modulation.weight)
+         nn.init.zeros_(self.modulation.bias)
+
+     def forward(self, x: torch.Tensor, cond: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+         if cond is None:
+             raise ValueError("AdaRMSNorm requires `cond` (Pi05 mode).")
+         if cond.ndim not in (2, 3):
+             raise ValueError(f"cond must be (B,D) or (B,T,D), got {tuple(cond.shape)}")
+         if x.ndim != 3:
+             raise ValueError(f"x must be (B,T,D), got {tuple(x.shape)}")
+         if x.shape[0] != cond.shape[0]:
+             raise ValueError(f"Batch mismatch: {x.shape[0]=} vs {cond.shape[0]=}")
+         if cond.shape[-1] != x.shape[-1]:
+             raise ValueError(f"Dim mismatch: {x.shape[-1]=} vs {cond.shape[-1]=}")
+         if cond.ndim == 3 and x.shape[:2] != cond.shape[:2]:
+             raise ValueError(f"Token mismatch: {x.shape[:2]=} vs {cond.shape[:2]=}")
+
+         x_dtype = x.dtype
+         x_f32 = x.float()
+         var = x_f32.pow(2).mean(dim=-1, keepdim=True)
+         normed = x_f32 * torch.rsqrt(var + self.eps)
+
+         mod = self.modulation(cond).to(dtype=x_f32.dtype)
+         scale, shift, gate = mod.chunk(3, dim=-1)
+         if cond.ndim == 2:
+             scale = scale[:, None, :]
+             shift = shift[:, None, :]
+             gate = gate[:, None, :]
+         out = normed * (1 + scale) + shift
+         return out.to(dtype=x_dtype), gate.to(dtype=x_dtype)
+
+
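Editor's note: because `modulation` is zero-initialized, at initialization `scale = shift = gate = 0`: the norm returns plain RMS-normalized activations and a zero gate, so the gated residual `residual + x * gate` reduces to the identity and each expert layer starts as a no-op, matching Pi05's zero-init adaptive layers. A quick check (a sketch, not part of the file):

    import torch

    norm = AdaRMSNorm(dim=8)
    x, cond = torch.randn(2, 4, 8), torch.randn(2, 8)
    out, gate = norm(x, cond)
    assert torch.equal(gate, torch.zeros_like(gate))  # zero gate at init -> identity residual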
+ class Qwen2PiSelfAttention(nn.Module):
+     """
+     Qwen2 attention variant for Pi05 action expert:
+     - queries from suffix tokens (action tokens)
+     - keys/values from concat(prefix_kv_cache, suffix_kv)
+     - uses full (non-causal) attention mask provided by caller.
+     """
+
+     def __init__(self, qwen_config: Any, layer_idx: int):
+         super().__init__()
+         try:
+             from transformers.models.qwen2.modeling_qwen2 import apply_rotary_pos_emb, repeat_kv
+         except Exception as e:  # pragma: no cover
+             raise ImportError("transformers qwen2 internals are required for eo_pi_internvl.") from e
+
+         self._apply_rotary_pos_emb = apply_rotary_pos_emb
+         self._repeat_kv = repeat_kv
+
+         self.layer_idx = int(layer_idx)
+         self.hidden_size = int(qwen_config.hidden_size)
+         self.num_heads = int(qwen_config.num_attention_heads)
+         self.num_kv_heads = int(qwen_config.num_key_value_heads)
+         self.num_kv_groups = self.num_heads // self.num_kv_heads
+         self.head_dim = int(getattr(qwen_config, "head_dim", self.hidden_size // self.num_heads))
+         self.scaling = self.head_dim**-0.5
+
+         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
+         self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=True)
+         self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=True)
+         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         *,
+         position_embeddings: tuple[torch.Tensor, torch.Tensor],
+         attention_mask: torch.Tensor | None,
+         prefix_k: torch.Tensor,
+         prefix_v: torch.Tensor,
+     ) -> torch.Tensor:
+         # hidden_states: (B, S, D)
+         bsz, seqlen, _ = hidden_states.shape
+         q = self.q_proj(hidden_states).view(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2)
+         k = self.k_proj(hidden_states).view(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
+         v = self.v_proj(hidden_states).view(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
+
+         cos, sin = position_embeddings
+         q, k = self._apply_rotary_pos_emb(q, k, cos, sin)
+
+         if prefix_k.ndim != 4 or prefix_v.ndim != 4:
+             raise ValueError(f"prefix_k/v must be (B, n_kv, P, hd), got {tuple(prefix_k.shape)}, {tuple(prefix_v.shape)}")
+         if int(prefix_k.shape[0]) != bsz or int(prefix_v.shape[0]) != bsz:
+             raise ValueError("prefix_k/v batch mismatch.")
+         if int(prefix_k.shape[1]) != self.num_kv_heads or int(prefix_v.shape[1]) != self.num_kv_heads:
+             raise ValueError(
+                 "prefix_k/v num_kv_heads mismatch: "
+                 f"{int(prefix_k.shape[1])=} {int(prefix_v.shape[1])=} vs {self.num_kv_heads=}"
+             )
+         if int(prefix_k.shape[-1]) != self.head_dim or int(prefix_v.shape[-1]) != self.head_dim:
+             raise ValueError("prefix_k/v head_dim mismatch.")
+
+         k_all = torch.cat([prefix_k, k], dim=2)  # (B, n_kv, P+S, hd)
+         v_all = torch.cat([prefix_v, v], dim=2)
+
+         k_all = self._repeat_kv(k_all, self.num_kv_groups)  # (B, n_heads, P+S, hd)
+         v_all = self._repeat_kv(v_all, self.num_kv_groups)
+
+         # attention_mask: (B, 1, S, P+S) additive (0 or -inf), broadcast to heads
+         if attention_mask is not None:
+             if attention_mask.ndim != 4:
+                 raise ValueError(f"attention_mask must be 4D (B,1,S,K), got {tuple(attention_mask.shape)}")
+             attn_mask = attention_mask.expand(bsz, self.num_heads, seqlen, k_all.shape[-2])
+         else:
+             attn_mask = None
+
+         attn_out = torch.nn.functional.scaled_dot_product_attention(
+             q, k_all, v_all, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
+         )  # (B, n_heads, S, hd)
+         attn_out = attn_out.transpose(1, 2).contiguous().view(bsz, seqlen, self.num_heads * self.head_dim)
+         return self.o_proj(attn_out)
+
+
+ class Qwen2PiMLP(nn.Module):
+     def __init__(self, qwen_config: Any):
+         super().__init__()
+         hidden = int(qwen_config.hidden_size)
+         inter = int(qwen_config.intermediate_size)
+         self.gate_proj = nn.Linear(hidden, inter, bias=False)
+         self.up_proj = nn.Linear(hidden, inter, bias=False)
+         self.down_proj = nn.Linear(inter, hidden, bias=False)
+         act_name = str(getattr(qwen_config, "hidden_act", "silu"))
+         if act_name != "silu":
+             logger.warning_once("EO Pi action expert: forcing SiLU hidden_act for MLP (got %s).", act_name)
+         self.act = nn.SiLU()
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
+
+
+ class Qwen2PiDecoderLayer(nn.Module):
+     def __init__(self, qwen_config: Any, layer_idx: int):
+         super().__init__()
+         eps = float(getattr(qwen_config, "rms_norm_eps", 1e-6))
+         self.input_layernorm = AdaRMSNorm(int(qwen_config.hidden_size), eps=eps)
+         self.self_attn = Qwen2PiSelfAttention(qwen_config=qwen_config, layer_idx=layer_idx)
+         self.post_attention_layernorm = AdaRMSNorm(int(qwen_config.hidden_size), eps=eps)
+         self.mlp = Qwen2PiMLP(qwen_config=qwen_config)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         *,
+         adarms_cond: torch.Tensor,
+         position_embeddings: tuple[torch.Tensor, torch.Tensor],
+         attention_mask: torch.Tensor | None,
+         prefix_k: torch.Tensor,
+         prefix_v: torch.Tensor,
+     ) -> torch.Tensor:
+         residual = hidden_states
+         x, gate = self.input_layernorm(hidden_states, adarms_cond)
+         x = self.self_attn(
+             x,
+             position_embeddings=position_embeddings,
+             attention_mask=attention_mask,
+             prefix_k=prefix_k,
+             prefix_v=prefix_v,
+         )
+         hidden_states = residual + x * gate
+
+         residual = hidden_states
+         x, gate = self.post_attention_layernorm(hidden_states, adarms_cond)
+         x = self.mlp(x)
+         hidden_states = residual + x * gate
+         return hidden_states
+
+
+ class Qwen2PiActionExpert(nn.Module):
+     def __init__(self, qwen_config: Any, *, init_from_qwen2_lm: nn.Module | None = None):
+         super().__init__()
+         try:
+             from transformers.models.qwen2.modeling_qwen2 import Qwen2RotaryEmbedding
+         except Exception as e:  # pragma: no cover
+             raise ImportError("transformers qwen2 internals are required for eo_pi_internvl.") from e
+
+         self.config = qwen_config
+         self.num_layers = int(qwen_config.num_hidden_layers)
+         self.layers = nn.ModuleList([Qwen2PiDecoderLayer(qwen_config, i) for i in range(self.num_layers)])
+         self.rotary_emb = Qwen2RotaryEmbedding(qwen_config)
+         self.final_norm = AdaRMSNorm(int(qwen_config.hidden_size), eps=float(getattr(qwen_config, "rms_norm_eps", 1e-6)))
+
+         if init_from_qwen2_lm is not None:
+             self._init_from_qwen2_lm(init_from_qwen2_lm)
+
+     def _init_from_qwen2_lm(self, qwen2_lm: nn.Module):
+         """
+         Copy attention/MLP weights from a Qwen2ForCausalLM (or Qwen2Model) into this expert.
+         AdaRMSNorm modulation stays zero-init to match Pi05.
+         """
+         src_layers = None
+         if hasattr(qwen2_lm, "model") and hasattr(qwen2_lm.model, "layers"):
+             src_layers = qwen2_lm.model.layers
+         elif hasattr(qwen2_lm, "layers"):
+             src_layers = qwen2_lm.layers
+         if src_layers is None:
+             raise ValueError("Unsupported qwen2_lm: cannot locate `.model.layers`.")
+
+         if len(src_layers) != len(self.layers):
+             raise ValueError(f"Layer count mismatch: {len(src_layers)=} vs {len(self.layers)=}")
+
+         for dst, src in zip(self.layers, src_layers, strict=True):
+             # attention
+             dst.self_attn.q_proj.load_state_dict(src.self_attn.q_proj.state_dict())
+             dst.self_attn.k_proj.load_state_dict(src.self_attn.k_proj.state_dict())
+             dst.self_attn.v_proj.load_state_dict(src.self_attn.v_proj.state_dict())
+             dst.self_attn.o_proj.load_state_dict(src.self_attn.o_proj.state_dict())
+             # mlp
+             dst.mlp.gate_proj.load_state_dict(src.mlp.gate_proj.state_dict())
+             dst.mlp.up_proj.load_state_dict(src.mlp.up_proj.state_dict())
+             dst.mlp.down_proj.load_state_dict(src.mlp.down_proj.state_dict())
+
+     def forward(
+         self,
+         action_tokens: torch.Tensor,
+         *,
+         prefix_kv_cache: list[tuple[torch.Tensor, torch.Tensor]],
+         prefix_key_mask: torch.Tensor,
+         position_ids: torch.Tensor,
+         adarms_cond: torch.Tensor,
+         suffix_key_mask: torch.Tensor | None = None,
+     ) -> torch.Tensor:
+         """
+         Args:
+             action_tokens: (B, S, D)
+             prefix_kv_cache: list[(k,v)] each (B, n_kv, P, hd) from InternVL prefix expert.
+             prefix_key_mask: (B, P) bool, True = valid prefix token.
+             position_ids: (B, S) positions for action tokens (prefix_len + [0..S-1]).
+             adarms_cond: (B, D) time conditioning vector.
+             suffix_key_mask: (B, S) bool, True = valid suffix token (optional; for padding).
+         """
+         if action_tokens.ndim != 3:
+             raise ValueError(f"action_tokens must be (B,S,D), got {tuple(action_tokens.shape)}")
+         bsz, s_len, _ = action_tokens.shape
+         if prefix_key_mask.ndim != 2 or int(prefix_key_mask.shape[0]) != bsz:
+             raise ValueError(f"prefix_key_mask must be (B,P), got {tuple(prefix_key_mask.shape)}")
+         if position_ids.shape != (bsz, s_len):
+             raise ValueError(f"position_ids must be (B,S)={bsz,s_len}, got {tuple(position_ids.shape)}")
+         if len(prefix_kv_cache) == 0:
+             raise ValueError("prefix_kv_cache is empty.")
+
+         # (cos,sin) for suffix tokens only (RoPE positions already baked into prefix cache).
+         position_embeddings = self.rotary_emb(action_tokens, position_ids)
+
+         if suffix_key_mask is None:
+             suffix_key_mask = torch.ones((bsz, s_len), device=action_tokens.device, dtype=torch.bool)
+         if suffix_key_mask.shape != (bsz, s_len):
+             raise ValueError(f"suffix_key_mask must be (B,S), got {tuple(suffix_key_mask.shape)}")
+
+         # Build Pi05 action-block attention mask: suffix queries attend to (valid prefix keys) + (valid suffix keys) fully.
+         prefix_part = (suffix_key_mask[:, None, :, None] & prefix_key_mask[:, None, None, :])  # (B,1,S,P)
+         suffix_part = (suffix_key_mask[:, None, :, None] & suffix_key_mask[:, None, None, :])  # (B,1,S,S)
+         allow = torch.cat([prefix_part, suffix_part], dim=-1)  # (B,1,S,P+S)
+         attn_mask = torch.zeros(
+             (bsz, 1, s_len, int(prefix_key_mask.shape[1]) + s_len),
+             device=action_tokens.device,
+             dtype=action_tokens.dtype,
+         )
+         attn_mask = _masked_fill_min(attn_mask, allow)
+
+         x = action_tokens
+         for layer_idx, layer in enumerate(self.layers):
+             if layer_idx >= len(prefix_kv_cache):
+                 raise ValueError(
+                     "prefix_kv_cache has fewer layers than action expert. "
+                     f"{len(prefix_kv_cache)=} < {len(self.layers)=}"
+                 )
+             pk, pv = prefix_kv_cache[layer_idx]
+             x = layer(
+                 x,
+                 adarms_cond=adarms_cond,
+                 position_embeddings=position_embeddings,
+                 attention_mask=attn_mask,
+                 prefix_k=pk,
+                 prefix_v=pv,
+             )
+
+         x, _ = self.final_norm(x, adarms_cond)
+         return x
+
+
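Editor's note: the mask built in `forward` lets every action token attend bidirectionally to all valid action tokens and all valid prefix tokens, with padding excluded on both sides. A toy illustration of the boolean `allow` tensor before it is converted to an additive mask (a sketch, not part of the file):

    import torch

    prefix_key_mask = torch.tensor([[True, True, False]])  # B=1, P=3 (last prefix slot is padding)
    suffix_key_mask = torch.tensor([[True, True]])         # B=1, S=2 action tokens

    prefix_part = suffix_key_mask[:, None, :, None] & prefix_key_mask[:, None, None, :]
    suffix_part = suffix_key_mask[:, None, :, None] & suffix_key_mask[:, None, None, :]
    allow = torch.cat([prefix_part, suffix_part], dim=-1)  # (1, 1, 2, 5)
    print(allow[0, 0].int())
    # tensor([[1, 1, 0, 1, 1],
    #         [1, 1, 0, 1, 1]])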
+ class Qwen3HeadRMSNorm(nn.Module):
+     """Qwen3-style RMSNorm used for `q_norm`/`k_norm` on per-head dim."""
+
+     def __init__(self, dim: int, *, eps: float = 1e-6):
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(dim))
+         self.eps = float(eps)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         dtype = x.dtype
+         x_f32 = x.float()
+         var = x_f32.pow(2).mean(dim=-1, keepdim=True)
+         x_norm = x_f32 * torch.rsqrt(var + self.eps)
+         return (self.weight * x_norm).to(dtype=dtype)
+
+
+ class Qwen3PiSelfAttention(nn.Module):
+     """
+     Qwen3 attention variant for Pi05 action expert:
+     - queries from suffix tokens (action tokens)
+     - keys/values from concat(prefix_kv_cache, suffix_kv)
+     - uses full (non-causal) attention mask provided by caller.
+     """
+
+     def __init__(self, qwen_config: Any, layer_idx: int):
+         super().__init__()
+         try:
+             from transformers.models.qwen3.modeling_qwen3 import apply_rotary_pos_emb, repeat_kv
+         except Exception as e:  # pragma: no cover
+             raise ImportError("transformers qwen3 internals are required for eo_pi_internvl.") from e
+
+         self._apply_rotary_pos_emb = apply_rotary_pos_emb
+         self._repeat_kv = repeat_kv
+
+         self.layer_idx = int(layer_idx)
+         self.hidden_size = int(qwen_config.hidden_size)
+         self.num_heads = int(qwen_config.num_attention_heads)
+         self.num_kv_heads = int(qwen_config.num_key_value_heads)
+         self.num_kv_groups = self.num_heads // self.num_kv_heads
+         self.head_dim = int(getattr(qwen_config, "head_dim", self.hidden_size // self.num_heads))
+
+         attn_bias = bool(getattr(qwen_config, "attention_bias", False))
+         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=attn_bias)
+         self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=attn_bias)
+         self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=attn_bias)
+         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=attn_bias)
+
+         eps = float(getattr(qwen_config, "rms_norm_eps", 1e-6))
+         self.q_norm = Qwen3HeadRMSNorm(self.head_dim, eps=eps)
+         self.k_norm = Qwen3HeadRMSNorm(self.head_dim, eps=eps)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         *,
+         position_embeddings: tuple[torch.Tensor, torch.Tensor],
+         attention_mask: torch.Tensor | None,
+         prefix_k: torch.Tensor,
+         prefix_v: torch.Tensor,
+     ) -> torch.Tensor:
+         bsz, seqlen, _ = hidden_states.shape
+         hidden_shape = (bsz, seqlen, -1, self.head_dim)
+
+         q = self.q_proj(hidden_states).view(hidden_shape)
+         k = self.k_proj(hidden_states).view(hidden_shape)
+         v = self.v_proj(hidden_states).view(hidden_shape)
+         q = self.q_norm(q).transpose(1, 2)  # (B,n_heads,S,hd)
+         k = self.k_norm(k).transpose(1, 2)  # (B,n_kv,S,hd)
+         v = v.transpose(1, 2)  # (B,n_kv,S,hd)
+
+         cos, sin = position_embeddings
+         q, k = self._apply_rotary_pos_emb(q, k, cos, sin)
+
+         if prefix_k.ndim != 4 or prefix_v.ndim != 4:
+             raise ValueError(
+                 f"prefix_k/v must be (B, n_kv, P, hd), got {tuple(prefix_k.shape)}, {tuple(prefix_v.shape)}"
+             )
+         if int(prefix_k.shape[0]) != bsz or int(prefix_v.shape[0]) != bsz:
+             raise ValueError("prefix_k/v batch mismatch.")
+         if int(prefix_k.shape[1]) != self.num_kv_heads or int(prefix_v.shape[1]) != self.num_kv_heads:
+             raise ValueError(
+                 "prefix_k/v num_kv_heads mismatch: "
+                 f"{int(prefix_k.shape[1])=} {int(prefix_v.shape[1])=} vs {self.num_kv_heads=}"
+             )
+         if int(prefix_k.shape[-1]) != self.head_dim or int(prefix_v.shape[-1]) != self.head_dim:
+             raise ValueError("prefix_k/v head_dim mismatch.")
+
+         k_all = torch.cat([prefix_k, k], dim=2)  # (B, n_kv, P+S, hd)
+         v_all = torch.cat([prefix_v, v], dim=2)
+         k_all = self._repeat_kv(k_all, self.num_kv_groups)  # (B, n_heads, P+S, hd)
+         v_all = self._repeat_kv(v_all, self.num_kv_groups)
+
+         if attention_mask is not None:
+             if attention_mask.ndim != 4:
+                 raise ValueError(f"attention_mask must be 4D (B,1,S,K), got {tuple(attention_mask.shape)}")
+             attn_mask = attention_mask.expand(bsz, self.num_heads, seqlen, k_all.shape[-2])
+         else:
+             attn_mask = None
+
+         attn_out = torch.nn.functional.scaled_dot_product_attention(
+             q, k_all, v_all, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
+         )
+         attn_out = attn_out.transpose(1, 2).contiguous().view(bsz, seqlen, self.num_heads * self.head_dim)
+         return self.o_proj(attn_out)
+
+
+ class Qwen3PiMLP(nn.Module):
+     def __init__(self, qwen_config: Any):
+         super().__init__()
+         hidden = int(qwen_config.hidden_size)
+         inter = int(qwen_config.intermediate_size)
+         self.gate_proj = nn.Linear(hidden, inter, bias=False)
+         self.up_proj = nn.Linear(hidden, inter, bias=False)
+         self.down_proj = nn.Linear(inter, hidden, bias=False)
+         act_name = str(getattr(qwen_config, "hidden_act", "silu"))
+         if act_name != "silu":
+             logger.warning_once("EO Pi action expert: forcing SiLU hidden_act for MLP (got %s).", act_name)
+         self.act = nn.SiLU()
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
+
+
+ class Qwen3PiDecoderLayer(nn.Module):
+     def __init__(self, qwen_config: Any, layer_idx: int):
+         super().__init__()
+         eps = float(getattr(qwen_config, "rms_norm_eps", 1e-6))
+         self.input_layernorm = AdaRMSNorm(int(qwen_config.hidden_size), eps=eps)
+         self.self_attn = Qwen3PiSelfAttention(qwen_config=qwen_config, layer_idx=layer_idx)
+         self.post_attention_layernorm = AdaRMSNorm(int(qwen_config.hidden_size), eps=eps)
+         self.mlp = Qwen3PiMLP(qwen_config=qwen_config)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         *,
+         adarms_cond: torch.Tensor,
+         position_embeddings: tuple[torch.Tensor, torch.Tensor],
+         attention_mask: torch.Tensor | None,
+         prefix_k: torch.Tensor,
+         prefix_v: torch.Tensor,
+     ) -> torch.Tensor:
+         residual = hidden_states
+         x, gate = self.input_layernorm(hidden_states, adarms_cond)
+         x = self.self_attn(
+             x,
+             position_embeddings=position_embeddings,
+             attention_mask=attention_mask,
+             prefix_k=prefix_k,
+             prefix_v=prefix_v,
+         )
+         hidden_states = residual + x * gate
+
+         residual = hidden_states
+         x, gate = self.post_attention_layernorm(hidden_states, adarms_cond)
+         x = self.mlp(x)
+         hidden_states = residual + x * gate
+         return hidden_states
+
+
+ class Qwen3PiActionExpert(nn.Module):
+     def __init__(self, qwen_config: Any, *, init_from_qwen3_lm: nn.Module | None = None):
+         super().__init__()
+         try:
+             from transformers.models.qwen3.modeling_qwen3 import Qwen3RotaryEmbedding
+         except Exception as e:  # pragma: no cover
+             raise ImportError("transformers qwen3 internals are required for eo_pi_internvl.") from e
+
+         self.config = qwen_config
+         self.num_layers = int(qwen_config.num_hidden_layers)
+         self.layers = nn.ModuleList([Qwen3PiDecoderLayer(qwen_config, i) for i in range(self.num_layers)])
+         self.rotary_emb = Qwen3RotaryEmbedding(qwen_config)
+         self.final_norm = AdaRMSNorm(int(qwen_config.hidden_size), eps=float(getattr(qwen_config, "rms_norm_eps", 1e-6)))
+
+         if init_from_qwen3_lm is not None:
+             self._init_from_qwen3_lm(init_from_qwen3_lm)
+
+     def _init_from_qwen3_lm(self, qwen3_lm: nn.Module):
+         """
+         Copy attention/MLP weights from a Qwen3ForCausalLM (or Qwen3Model) into this expert.
+         AdaRMSNorm modulation stays zero-init to match Pi05.
+         """
+         src_layers = None
+         if hasattr(qwen3_lm, "model") and hasattr(qwen3_lm.model, "layers"):
+             src_layers = qwen3_lm.model.layers
+         elif hasattr(qwen3_lm, "layers"):
+             src_layers = qwen3_lm.layers
+         if src_layers is None:
+             raise ValueError("Unsupported qwen3_lm: cannot locate `.model.layers`.")
+
+         if len(src_layers) != len(self.layers):
+             raise ValueError(f"Layer count mismatch: {len(src_layers)=} vs {len(self.layers)=}")
+
+         for dst, src in zip(self.layers, src_layers, strict=True):
+             dst.self_attn.q_proj.load_state_dict(src.self_attn.q_proj.state_dict())
+             dst.self_attn.k_proj.load_state_dict(src.self_attn.k_proj.state_dict())
+             dst.self_attn.v_proj.load_state_dict(src.self_attn.v_proj.state_dict())
+             dst.self_attn.o_proj.load_state_dict(src.self_attn.o_proj.state_dict())
+             # head norms
+             if hasattr(src.self_attn, "q_norm") and hasattr(dst.self_attn, "q_norm"):
+                 dst.self_attn.q_norm.weight.data.copy_(src.self_attn.q_norm.weight.data)
+             if hasattr(src.self_attn, "k_norm") and hasattr(dst.self_attn, "k_norm"):
+                 dst.self_attn.k_norm.weight.data.copy_(src.self_attn.k_norm.weight.data)
+             # mlp
+             dst.mlp.gate_proj.load_state_dict(src.mlp.gate_proj.state_dict())
+             dst.mlp.up_proj.load_state_dict(src.mlp.up_proj.state_dict())
+             dst.mlp.down_proj.load_state_dict(src.mlp.down_proj.state_dict())
+
+     def forward(
+         self,
+         action_tokens: torch.Tensor,
+         *,
+         prefix_kv_cache: list[tuple[torch.Tensor, torch.Tensor]],
+         prefix_key_mask: torch.Tensor,
+         position_ids: torch.Tensor,
+         adarms_cond: torch.Tensor,
+         suffix_key_mask: torch.Tensor | None = None,
+     ) -> torch.Tensor:
+         if action_tokens.ndim != 3:
+             raise ValueError(f"action_tokens must be (B,S,D), got {tuple(action_tokens.shape)}")
+         bsz, s_len, _ = action_tokens.shape
+         if prefix_key_mask.ndim != 2 or int(prefix_key_mask.shape[0]) != bsz:
+             raise ValueError(f"prefix_key_mask must be (B,P), got {tuple(prefix_key_mask.shape)}")
+         if position_ids.shape != (bsz, s_len):
+             raise ValueError(f"position_ids must be (B,S)={bsz,s_len}, got {tuple(position_ids.shape)}")
+         if len(prefix_kv_cache) == 0:
+             raise ValueError("prefix_kv_cache is empty.")
+
+         position_embeddings = self.rotary_emb(action_tokens, position_ids)
+
+         if suffix_key_mask is None:
+             suffix_key_mask = torch.ones((bsz, s_len), device=action_tokens.device, dtype=torch.bool)
+         if suffix_key_mask.shape != (bsz, s_len):
+             raise ValueError(f"suffix_key_mask must be (B,S), got {tuple(suffix_key_mask.shape)}")
+
+         prefix_part = (suffix_key_mask[:, None, :, None] & prefix_key_mask[:, None, None, :])  # (B,1,S,P)
+         suffix_part = (suffix_key_mask[:, None, :, None] & suffix_key_mask[:, None, None, :])  # (B,1,S,S)
+         allow = torch.cat([prefix_part, suffix_part], dim=-1)  # (B,1,S,P+S)
+         attn_mask = torch.zeros(
+             (bsz, 1, s_len, int(prefix_key_mask.shape[1]) + s_len),
+             device=action_tokens.device,
+             dtype=action_tokens.dtype,
+         )
+         attn_mask = _masked_fill_min(attn_mask, allow)
+
+         x = action_tokens
+         for layer_idx, layer in enumerate(self.layers):
+             if layer_idx >= len(prefix_kv_cache):
+                 raise ValueError(
+                     "prefix_kv_cache has fewer layers than action expert. "
+                     f"{len(prefix_kv_cache)=} < {len(self.layers)=}"
+                 )
+             pk, pv = prefix_kv_cache[layer_idx]
+             x = layer(
+                 x,
+                 adarms_cond=adarms_cond,
+                 position_embeddings=position_embeddings,
+                 attention_mask=attn_mask,
+                 prefix_k=pk,
+                 prefix_v=pv,
+             )
+
+         x, _ = self.final_norm(x, adarms_cond)
+         return x
+
+
+ @dataclass
+ class EO1InternVLPiFlowMatchingOutput(ModelOutput):
+     loss: torch.FloatTensor | None = None
+     fm_loss: torch.FloatTensor | None = None
+     fm_loss_pos: torch.FloatTensor | None = None
+     fm_loss_rot: torch.FloatTensor | None = None
+     fm_loss_gripper: torch.FloatTensor | None = None
+     ar_loss: torch.FloatTensor | None = None
+     actions: torch.FloatTensor | None = None
+
+     logits: torch.FloatTensor | None = None
+     hidden_states: tuple[torch.FloatTensor] | None = None
+     attentions: tuple[torch.FloatTensor] | None = None
+
+
+ def count_params(module, trainable_only=False):
+     ps = module.parameters()
+     if trainable_only:
+         ps = [p for p in ps if p.requires_grad]
+     return sum(p.numel() for p in ps)
+
+
+ class EO1InternVLPiFlowMatchingModel(PreTrainedModel):
+     """EO1 action model with InternVL prefix expert + Pi05-style (Qwen2/Qwen3) action expert (AdaRMSNorm timestep)."""
+
+     config_class = EO1InternVLPiFlowMatchingConfig
+     supports_gradient_checkpointing = True
+
+     def __init__(
+         self,
+         config: EO1InternVLPiFlowMatchingConfig,
+         internvl_backbone: nn.Module,
+         action_expert: nn.Module | None = None,
+     ):
+         super().__init__(config)
+         self.internvl_backbone = internvl_backbone
+
+         # InternVL uses a HF causal LM as `.language_model` (e.g. Qwen2ForCausalLM).
+         if not hasattr(self.internvl_backbone, "language_model"):
+             raise ValueError("internvl_backbone must have `.language_model`.")
+         # Do NOT register an alias module (e.g. `self.prefix_lm = self.internvl_backbone.language_model`).
+         # Registering both creates shared tensors under different state_dict keys, which safetensors refuses
+         # to save unless they are declared as tied weights. Use the property `prefix_lm` instead.
+
+         # ------------------------- Build action expert config (Pi05-style: smaller expert) -------------------------
+         prefix_cfg = self.prefix_lm.config
+         cfg_name = prefix_cfg.__class__.__name__
+
+         expert_cfg = copy.deepcopy(prefix_cfg)
+         if getattr(config, "expert_hidden_size", None) is not None:
+             expert_cfg.hidden_size = int(config.expert_hidden_size)
+         if getattr(config, "expert_intermediate_size", None) is not None:
+             expert_cfg.intermediate_size = int(config.expert_intermediate_size)
+         if getattr(config, "expert_num_attention_heads", None) is not None:
+             expert_cfg.num_attention_heads = int(config.expert_num_attention_heads)
+         if getattr(config, "expert_num_hidden_layers", None) is not None:
+             expert_cfg.num_hidden_layers = int(config.expert_num_hidden_layers)
+         # Keep head geometry aligned with prefix kv-cache.
+         if int(getattr(expert_cfg, "num_key_value_heads", -1)) != int(getattr(prefix_cfg, "num_key_value_heads", -2)):
+             raise ValueError(
+                 "To reuse prefix KV-cache, expert and prefix must share num_key_value_heads. "
+                 f"{int(getattr(prefix_cfg, 'num_key_value_heads'))=} vs {int(getattr(expert_cfg, 'num_key_value_heads'))=}."
+             )
+         if int(getattr(expert_cfg, "head_dim", -1)) != int(getattr(prefix_cfg, "head_dim", -2)):
+             raise ValueError(
+                 "To reuse prefix KV-cache, expert and prefix must share head_dim. "
+                 f"{int(getattr(prefix_cfg, 'head_dim'))=} vs {int(getattr(expert_cfg, 'head_dim'))=}."
+             )
+         if int(expert_cfg.num_attention_heads) % int(expert_cfg.num_key_value_heads) != 0:
+             raise ValueError(
+                 "expert_num_attention_heads must be divisible by num_key_value_heads. "
+                 f"{int(expert_cfg.num_attention_heads)=} {int(expert_cfg.num_key_value_heads)=}."
+             )
+         # Keep layer_types length consistent (Qwen3Config defines it).
+         if hasattr(expert_cfg, "layer_types") and isinstance(getattr(expert_cfg, "layer_types"), list):
+             if len(expert_cfg.layer_types) != int(expert_cfg.num_hidden_layers):
+                 expert_cfg.layer_types = ["full_attention"] * int(expert_cfg.num_hidden_layers)
+
+         self._expert_hidden_size = int(expert_cfg.hidden_size)
+         self._expert_num_layers = int(expert_cfg.num_hidden_layers)
+         self._prefix_num_layers = int(getattr(prefix_cfg, "num_hidden_layers", self._expert_num_layers))
+
+         if self._expert_num_layers > self._prefix_num_layers:
+             raise ValueError(
+                 "expert_num_hidden_layers cannot exceed prefix LM layers when using prefix KV-cache. "
+                 f"{self._expert_num_layers=} > {self._prefix_num_layers=}."
+             )
+
+         mapping = str(getattr(config, "expert_layer_mapping", "last")).strip().lower()
+         if mapping == "last":
+             start = self._prefix_num_layers - self._expert_num_layers
+             self._prefix_kv_layer_indices = list(range(start, self._prefix_num_layers))
+         elif mapping == "first":
+             self._prefix_kv_layer_indices = list(range(self._expert_num_layers))
+         else:
+             raise ValueError(f"Unsupported expert_layer_mapping={mapping!r} (expected 'last' or 'first').")
+
+         max_action_dim = int(config.max_action_dim)
+
+         # Pi05: action embeddings do NOT concatenate timestep.
+         self.action_in_proj = nn.Linear(max_action_dim, self._expert_hidden_size)
+         self.action_out_proj = nn.Linear(self._expert_hidden_size, max_action_dim)
+
+         # Pi05: timestep is injected via AdaRMSNorm in the action expert.
+         self.time_mlp_in = nn.Linear(self._expert_hidden_size, self._expert_hidden_size)
+         self.time_mlp_out = nn.Linear(self._expert_hidden_size, self._expert_hidden_size)
+
+         if action_expert is not None:
+             self.action_expert = action_expert
+         else:
+             # Default: build an action expert (Qwen2/Qwen3) with its own (possibly smaller) config.
+             init_from = self.prefix_lm if bool(getattr(self.config, "expert_init_from_backbone", False)) else None
+             try:
+                 if cfg_name == "Qwen2Config":
+                     self.action_expert = Qwen2PiActionExpert(expert_cfg, init_from_qwen2_lm=init_from)
+                 elif cfg_name == "Qwen3Config":
+                     self.action_expert = Qwen3PiActionExpert(expert_cfg, init_from_qwen3_lm=init_from)
+                 else:
+                     raise NotImplementedError(
+                         "eo_pi_internvl currently supports only Qwen2/Qwen3 LMs for action expert. "
+                         f"Got: {cfg_name}"
+                     )
+             except Exception as e:
+                 raise RuntimeError(
+                     "Failed to build/initialize action expert. If you set `expert_init_from_backbone=True`, "
+                     "make sure expert_* hyperparams exactly match the prefix LM shapes, or set it to False "
+                     "for Pi05-style random init."
+                 ) from e
+
+         n_all = count_params(self.action_expert)
+         n_train = count_params(self.action_expert, trainable_only=True)
+         print(f"action_expert params: {n_all/1e6:.2f}M (trainable {n_train/1e6:.2f}M)")
+         self.post_init()
+
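Editor's note: with this checkpoint's settings (`expert_num_hidden_layers=18`, `expert_layer_mapping="last"`), expert layer i consumes the KV cache of backbone layer `start + i`. A small illustration; the 28-layer prefix depth is an assumption for concreteness, the real value comes from the InternVL backbone config:

    prefix_layers, expert_layers = 28, 18           # prefix depth assumed for illustration
    start = prefix_layers - expert_layers           # 10
    kv_indices = list(range(start, prefix_layers))  # [10, 11, ..., 27]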
+     @property
+     def prefix_lm(self) -> nn.Module:
+         # A convenience accessor for the InternVL backbone LM used as the prefix model.
+         return self.internvl_backbone.language_model
+
+     @staticmethod
+     def _action_group_indices(action_dim: int, *, action_dim_mask: torch.Tensor | None = None) -> dict[str, list[int]]:
+         """
+         Best-effort split of action dims into position/rotation/gripper groups.
+
+         Supports both common layouts:
+         1) Compact (single-arm): [xyz(3), rotvec(3), gripper(1)] -> 7 dims (or bimanual 14 dims).
+         2) EO unified action encoding (see `dataset/action_encoding.py`):
+            left: 0:3 pos, 3:6 rotvec (or 3:9 r6d), 9 gripper
+            right: 10:13 pos, 13:16 rotvec (or 13:19 r6d), 19 gripper
+
+         Rotation repr is controlled via env var `EO_ACTION_ROT_REPR` (default rotvec).
+         """
+         d = int(action_dim)
+         if d <= 0:
+             return {"pos": [], "rot": [], "gripper": [], "other": []}
+
+         rot_repr = os.environ.get("EO_ACTION_ROT_REPR", "rotvec").strip().lower()
+         rot_is_r6d = rot_repr in ("r6d", "rot6d", "6d")
+
+         m_any = None
+         if action_dim_mask is not None and torch.is_tensor(action_dim_mask):
+             m = action_dim_mask.detach()
+             if m.ndim == 1:
+                 m_any = m.to(torch.bool)
+             elif m.ndim == 2:
+                 m_any = m.to(torch.bool).any(dim=0)
+             elif m.ndim == 3 and int(m.shape[1]) == 1:
+                 m_any = m[:, 0, :].to(torch.bool).any(dim=0)
+             else:
+                 m_any = m.reshape(-1, m.shape[-1]).to(torch.bool).any(dim=0)
+             if int(m_any.numel()) != d:
+                 if int(m_any.numel()) > d:
+                     m_any = m_any[:d]
+                 else:
+                     pad = torch.zeros((d - int(m_any.numel()),), dtype=torch.bool, device=m_any.device)
+                     m_any = torch.cat([m_any, pad], dim=0)
+
+         # Infer effective dim span from mask if available (common when original action dim < max_action_dim).
+         eff = d
+         if m_any is not None and bool(m_any.any().item()):
+             eff = int(torch.nonzero(m_any, as_tuple=False).max().item()) + 1
+
+         pos: list[int] = []
+         rot: list[int] = []
+         gripper: list[int] = []
+
+         # Compact layout: 7D single-arm / 14D bimanual.
+         if eff in (7, 14):
+             arm_offsets = [0] if eff == 7 else [0, 7]
+             for off in arm_offsets:
+                 pos.extend([off + i for i in range(0, 3) if off + i < d])
+                 rot.extend([off + i for i in range(3, 6) if off + i < d])
+                 g = off + 6
+                 if g < d:
+                     gripper.append(g)
+             used = set(pos) | set(rot) | set(gripper)
+             other = [i for i in range(d) if i not in used]
+             return {"pos": pos, "rot": rot, "gripper": gripper, "other": other}
+
+         # EO unified layout (10 dims per arm slot, supports bimanual at offset 10).
+         right_active = (d >= 20)
+         if m_any is not None and int(m_any.numel()) >= 20:
+             right_active = bool(m_any[10:20].any().item())
+
+         # Left arm.
+         pos.extend([i for i in range(0, min(3, d))])
+         rot.extend([i for i in range(3, min(6, d))])
+         if rot_is_r6d:
+             rot.extend([i for i in range(6, min(9, d))])
+         if 9 < d:
+             gripper.append(9)
+
+         # Right arm (offset 10) when active.
+         if right_active and d >= 20:
+             pos.extend([i for i in range(10, min(13, d))])
+             rot.extend([i for i in range(13, min(16, d))])
+             if rot_is_r6d:
+                 rot.extend([i for i in range(16, min(19, d))])
+             if 19 < d:
+                 gripper.append(19)
+
+         used = set(pos) | set(rot) | set(gripper)
+         other = [i for i in range(d) if i not in used]
+         return {"pos": pos, "rot": rot, "gripper": gripper, "other": other}
+
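Editor's note: for the compact single-arm layout the split is deterministic; a quick check of the helper above (a sketch, not part of the file):

    groups = EO1InternVLPiFlowMatchingModel._action_group_indices(7)
    print(groups)  # {'pos': [0, 1, 2], 'rot': [3, 4, 5], 'gripper': [6], 'other': []}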
845
+ # ------------------------- EO1 Flow Matching utils -------------------------
846
+ def sample_noise(self, shape, device):
847
+ return torch.normal(mean=0.0, std=1.0, size=shape, dtype=torch.float32, device=device)
848
+
849
+ def sample_time(self, bsize, device):
850
+ beta_dist = torch.distributions.Beta(concentration1=1.5, concentration0=1.0)
851
+ time_beta = beta_dist.sample((bsize,)).to(device=device, dtype=torch.float32)
852
+ return time_beta * 0.999 + 0.001
853
+
854
+ def _embed_time_cond(self, timestep: torch.Tensor, *, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
855
+ hidden = int(getattr(self, "_expert_hidden_size", self.prefix_lm.config.hidden_size))
856
+ if timestep.ndim == 1:
857
+ t_emb = create_sinusoidal_pos_embedding(timestep, hidden, device=device).to(dtype=dtype)
858
+ t_emb = self.time_mlp_in(t_emb)
859
+ t_emb = F.silu(t_emb)
860
+ t_emb = self.time_mlp_out(t_emb)
861
+ t_emb = F.silu(t_emb)
862
+ return t_emb
863
+ if timestep.ndim == 2:
864
+ bsz, t_len = int(timestep.shape[0]), int(timestep.shape[1])
865
+ t_flat = timestep.reshape(-1)
866
+ t_emb = create_sinusoidal_pos_embedding(t_flat, hidden, device=device).to(dtype=dtype)
867
+ t_emb = self.time_mlp_in(t_emb)
868
+ t_emb = F.silu(t_emb)
869
+ t_emb = self.time_mlp_out(t_emb)
870
+ t_emb = F.silu(t_emb)
871
+ return t_emb.view(bsz, t_len, hidden)
872
+ raise ValueError(f"timestep must be (B,) or (B,T), got {tuple(timestep.shape)}")
873
+
874
+ def _select_prefix_kv_cache(self, prefix_kv_cache: list[tuple[torch.Tensor, torch.Tensor]]) -> list[tuple[torch.Tensor, torch.Tensor]]:
875
+ if not hasattr(self, "_prefix_kv_layer_indices"):
876
+ return prefix_kv_cache
877
+ idx = list(getattr(self, "_prefix_kv_layer_indices"))
878
+ if not idx:
879
+ return prefix_kv_cache
880
+ if max(idx) >= len(prefix_kv_cache):
881
+ raise ValueError(
882
+ "Prefix KV cache shorter than expected. "
883
+ f"{len(prefix_kv_cache)=} < {max(idx)+1=}."
884
+ )
885
+ return [prefix_kv_cache[i] for i in idx]
886
+
887
+ def _replace_img_context_embeddings(
888
+ self,
889
+ input_ids: torch.LongTensor,
890
+ inputs_embeds: torch.FloatTensor,
891
+ pixel_values: torch.FloatTensor,
892
+ image_flags: torch.LongTensor | None,
893
+ ) -> torch.FloatTensor:
894
+ img_context_token_id = self.config.img_context_token_id
895
+ if img_context_token_id is None:
896
+ raise ValueError("config.img_context_token_id is None (tokenizer/model not initialized).")
897
+
898
+ try:
899
+ vision_dtype = next(self.internvl_backbone.vision_model.parameters()).dtype
900
+ except Exception:
901
+ vision_dtype = inputs_embeds.dtype
902
+ pixel_values = pixel_values.to(device=inputs_embeds.device, dtype=vision_dtype)
903
+
904
+ vit_embeds = self.internvl_backbone.extract_feature(pixel_values) # (n_img, n_token, hidden)
905
+ if image_flags is not None:
906
+ image_flags = image_flags.squeeze(-1)
907
+ vit_embeds = vit_embeds[image_flags == 1]
908
+
909
+ bsz, _, hidden = inputs_embeds.shape
910
+ selected = input_ids == int(img_context_token_id) # (B,S)
911
+ n_ctx = int(selected.sum().item())
912
+ if n_ctx == 0:
913
+ return inputs_embeds
914
+
915
+ vit_flat = vit_embeds.reshape(-1, hidden)
916
+ if vit_flat.shape[0] < n_ctx:
917
+ raise ValueError(f"IMG_CONTEXT mismatch: need {n_ctx} embeddings, got {vit_flat.shape[0]}.")
918
+
919
+ mask3 = selected.unsqueeze(-1).expand_as(inputs_embeds)
920
+ src = vit_flat[:n_ctx].to(device=inputs_embeds.device, dtype=inputs_embeds.dtype).reshape(-1)
921
+ return inputs_embeds.masked_scatter(mask3, src)
922
+
923
+ @staticmethod
924
+ def _find_suffix_starts(action_mask_token: torch.Tensor, *, expected_horizon: int | None = None) -> torch.Tensor:
925
+ if action_mask_token.ndim != 2:
926
+ raise ValueError(f"action_mask_token must be (B,S), got {tuple(action_mask_token.shape)}")
927
+ bsz = int(action_mask_token.shape[0])
928
+ starts = torch.empty((bsz,), dtype=torch.long, device=action_mask_token.device)
929
+ for b in range(bsz):
930
+ pos = torch.nonzero(action_mask_token[b], as_tuple=False).squeeze(-1)
931
+ if int(pos.numel()) == 0:
932
+ raise ValueError(f"Expected at least 1 action token per sample, got 0 for batch {b}.")
933
+ if expected_horizon is not None and int(pos.numel()) not in (1, int(expected_horizon)):
934
+ raise ValueError(
935
+ f"Expected 1 or {int(expected_horizon)} action tokens per sample, got {int(pos.numel())} for batch {b}."
936
+ )
937
+ starts[b] = pos.min()
938
+ return starts
939
+
940
+ # ------------------------- Forward (train) -------------------------
941
+ def forward(
942
+ self,
943
+ input_ids: torch.LongTensor | None = None,
944
+ attention_mask: torch.Tensor | None = None,
945
+ position_ids: torch.LongTensor | None = None, # noqa: ARG002 - recomputed for pi05 mask
946
+ inputs_embeds: torch.FloatTensor | None = None, # noqa: ARG002 - use InternVL embeddings
947
+ labels: torch.LongTensor | None = None, # noqa: ARG002 - Pi05 does not train AR loss here
948
+ pixel_values: torch.FloatTensor | None = None,
949
+ image_flags: torch.LongTensor | None = None,
950
+ states: torch.Tensor | None = None, # noqa: ARG002 - Pi05: state should be discrete in text
951
+ actions: torch.Tensor | None = None,
952
+ action_is_pad: torch.Tensor | None = None,
953
+ action_dim_mask: torch.Tensor | None = None,
954
+ use_cache: bool | None = None, # noqa: ARG002
955
+ output_attentions: bool | None = None, # noqa: ARG002
956
+ output_hidden_states: bool | None = None, # noqa: ARG002
957
+ **kwargs,
958
+ ) -> EO1InternVLPiFlowMatchingOutput:
959
+ if input_ids is None:
960
+ raise ValueError("Pi model requires `input_ids`.")
961
+ if actions is None:
962
+ raise ValueError("Pi model forward requires `actions` (flow-matching).")
963
+ if attention_mask is None:
964
+ attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=input_ids.device)
965
+
966
+ action_token_id = self.config.action_token_id
967
+ if action_token_id is None:
968
+ raise ValueError("config.action_token_id is None (tokenizer/model not initialized).")
969
+ action_pass_id = self.config.action_pass_id
970
+
971
+ noise_mask = input_ids == int(action_token_id)
972
+ pass_mask = (input_ids == int(action_pass_id)) if action_pass_id is not None else torch.zeros_like(noise_mask)
973
+ action_mask_token = noise_mask | pass_mask # (B, S)
974
+
975
+ bsz, horizon, act_dim = actions.shape
976
+
977
+ suffix_starts = self._find_suffix_starts(action_mask_token, expected_horizon=int(horizon)) # (B,)
978
+ prefix_len = int(suffix_starts.max().item())
979
+
980
+ # ---------------- Prefix expert (InternVL LM) ----------------
981
+ prefix_ids = input_ids[:, :prefix_len]
982
+ prefix_am = attention_mask[:, :prefix_len].to(dtype=torch.bool, device=input_ids.device)
983
+ ar = torch.arange(prefix_len, device=input_ids.device)
984
+ prefix_valid = prefix_am & (ar[None, :] < suffix_starts[:, None])
985
+
986
+ prefix_embeds = self.prefix_lm.get_input_embeddings()(prefix_ids).clone()
987
+ if pixel_values is not None:
988
+ prefix_embeds = self._replace_img_context_embeddings(
989
+ input_ids=prefix_ids,
990
+ inputs_embeds=prefix_embeds,
991
+ pixel_values=pixel_values,
992
+ image_flags=image_flags,
993
+ )
994
+
995
+ prefix_attn = prefix_valid.to(dtype=torch.long)
996
+ prefix_out = self.prefix_lm.model(
997
+ inputs_embeds=prefix_embeds,
998
+ attention_mask=prefix_attn,
999
+ use_cache=True,
1000
+ return_dict=True,
1001
+ )
1002
+ prefix_pkv = prefix_out.past_key_values
1003
+ prefix_kv_cache = [prefix_pkv[i] for i in range(len(prefix_pkv))]
1004
+ prefix_kv_cache = self._select_prefix_kv_cache(prefix_kv_cache)
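+ # Encode the multimodal prefix once and hand its KV cache to the action
+ # expert, which attends to it alongside its own continuous action tokens.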
1005
+
1006
+ # ---------------- Flow Matching ----------------
1007
+ actions_f32 = actions.to(dtype=torch.float32, device=input_ids.device)
1008
+ rtc_delay = int(getattr(self.config, "rtc_delay", 5))
1009
+ delay = max(0, min(rtc_delay, int(horizon)))
1010
+ prefix_mask = torch.zeros((int(bsz), int(horizon)), device=actions_f32.device, dtype=torch.bool)
1011
+ if delay > 0:
1012
+ prefix_mask = torch.arange(int(horizon), device=actions_f32.device)[None, :] < int(delay)
1013
+ prefix_mask = prefix_mask.expand(int(bsz), -1)
1014
+ if action_is_pad is not None:
1015
+ prefix_mask = prefix_mask & (~action_is_pad.to(device=actions_f32.device, dtype=torch.bool))
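+ # RTC-style (`rtc_delay`) training: the first `delay` chunk steps are treated
+ # as already-executed actions. Their time is forced to 0 below (so x_t equals
+ # the clean action there) and they are excluded from the FM loss.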
1016
+
1017
+ time = self.sample_time(int(bsz), input_ids.device) # (B,)
1018
+ noise = self.sample_noise(actions_f32.shape, input_ids.device)
1019
+ time_tokens = time[:, None].expand(int(bsz), int(horizon))
1020
+ if delay > 0:
1021
+ time_tokens = torch.where(prefix_mask, torch.zeros_like(time_tokens), time_tokens)
1022
+ time_expanded = time_tokens[:, :, None]
1023
+ x_t = time_expanded * noise + (1 - time_expanded) * actions_f32
1024
+ u_t = noise - actions_f32
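+ # Linear flow-matching path: x_t = t * noise + (1 - t) * actions, so the
+ # regression target is the constant velocity u_t = dx_t/dt = noise - actions.
+ # Inference integrates the learned field from t=1 back to t=0 to denoise.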
1025
+
1026
+ # Action tokens: no timestep concatenation in Pi05.
1027
+ action_tokens = self.action_in_proj(x_t.to(dtype=self.action_in_proj.weight.dtype)) # (B,H,D)
1028
+
1029
+ # AdaRMSNorm conditioning vector (Pi05).
1030
+ adarms_cond = self._embed_time_cond(time_tokens, dtype=action_tokens.dtype, device=action_tokens.device)
1031
+
1032
+ # Suffix RoPE positions follow the *per-sample* prefix length (suffix_starts).
1033
+ pos_ids = suffix_starts[:, None] + torch.arange(horizon, device=input_ids.device)[None, :]
1034
+
1035
+ suffix_valid = torch.ones((int(bsz), int(horizon)), device=input_ids.device, dtype=torch.bool)
1036
+ if action_is_pad is not None:
1037
+ suffix_valid = suffix_valid & (~action_is_pad.to(device=input_ids.device, dtype=torch.bool))
1038
+
1039
+ expert_h = self.action_expert(
1040
+ action_tokens,
1041
+ prefix_kv_cache=prefix_kv_cache,
1042
+ prefix_key_mask=prefix_valid,
1043
+ position_ids=pos_ids,
1044
+ adarms_cond=adarms_cond,
1045
+ suffix_key_mask=suffix_valid,
1046
+ )
1047
+ v_t = self.action_out_proj(expert_h).to(dtype=torch.float32) # (B,H,A)
1048
+
1049
+ # Loss: average over *valid elements* (step mask + action_dim_mask).
1050
+ target = u_t.to(dtype=v_t.dtype)
1051
+ per_elem = (v_t - target) ** 2 # (B,H,A)
1052
+
1053
+ valid = suffix_valid[:, :, None]  # (B,H,1) step-validity mask; suffix_valid is always set above
1054
+ if delay > 0:
1055
+ valid = valid & (~prefix_mask[:, :, None])
1056
+ adm_for_groups = None
1057
+ if action_dim_mask is not None:
1058
+ adm = action_dim_mask
1059
+ if not torch.is_tensor(adm):
1060
+ adm = torch.as_tensor(adm)
1061
+ adm = adm.to(device=per_elem.device, dtype=torch.bool)
1062
+ if adm.ndim == 1:
1063
+ adm = adm.view(1, -1).expand(int(bsz), -1)
1064
+ elif adm.ndim == 2:
1065
+ pass
1066
+ elif adm.ndim == 3 and int(adm.shape[1]) == 1:
1067
+ adm = adm[:, 0, :]
1068
+ else:
1069
+ adm = adm.reshape(int(bsz), -1)
1070
+
1071
+ if int(adm.shape[-1]) == int(per_elem.shape[-1]):
1072
+ valid = valid & adm[:, None, :]
1073
+ adm_for_groups = adm
1074
+ else:
1075
+ logger.warning_once(
1076
+ "Ignoring action_dim_mask due to shape mismatch in PI FM loss: "
1077
+ f"{tuple(adm.shape)=} vs expected (B,{int(per_elem.shape[-1])})."
1078
+ )
1079
+
1080
+ # Exclude padding/unused action dims ("other") from FM loss.
1081
+ # We only train on {pos, rot, gripper} dims so `fm_loss` matches the meaningful action space.
1082
+ pos_mask = rot_mask = grip_mask = None
1083
+ try:
1084
+ a_dim = int(per_elem.shape[-1])
1085
+ pos_mask = torch.zeros((int(bsz), a_dim), device=per_elem.device, dtype=torch.bool)
1086
+ rot_mask = torch.zeros_like(pos_mask)
1087
+ grip_mask = torch.zeros_like(pos_mask)
1088
+ for bi in range(int(bsz)):
1089
+ g = self._action_group_indices(a_dim, action_dim_mask=(adm_for_groups[bi] if adm_for_groups is not None else None))
1090
+ if g["pos"]:
1091
+ pos_mask[bi, g["pos"]] = True
1092
+ if g["rot"]:
1093
+ rot_mask[bi, g["rot"]] = True
1094
+ if g["gripper"]:
1095
+ grip_mask[bi, g["gripper"]] = True
1096
+ group_mask = pos_mask | rot_mask | grip_mask # (B,A)
1097
+ empty = ~group_mask.any(dim=1)
1098
+ if empty.any():
1099
+ fallback = adm_for_groups if adm_for_groups is not None else torch.ones_like(group_mask)
1100
+ group_mask[empty] = fallback[empty]
1101
+ valid = valid & group_mask[:, None, :]
1102
+ except Exception:
1103
+ pos_mask = rot_mask = grip_mask = None
1104
+
1105
+ weight = valid.to(dtype=per_elem.dtype)
1106
+ denom = weight.sum().clamp_min(1)
1107
+ fm_loss = (per_elem * weight).sum() / denom
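+ # Masked mean over valid (batch, step, dim) elements; clamp_min(1) avoids a
+ # zero division when every element is masked out.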
1108
+
1109
+ fm_loss_pos = None
1110
+ fm_loss_rot = None
1111
+ fm_loss_gripper = None
1112
+ # Decompose the FM loss by action group (auxiliary logging only; these extras must never crash training).
1113
+ try:
1114
+ if pos_mask is not None:
1115
+ pos_w = (weight * pos_mask[:, None, :].to(dtype=weight.dtype)).sum()
1116
+ if bool((pos_w > 0).item()):
1117
+ fm_loss_pos = (per_elem * weight * pos_mask[:, None, :].to(dtype=weight.dtype)).sum() / pos_w.clamp_min(1)
1118
+ if rot_mask is not None:
1119
+ rot_w = (weight * rot_mask[:, None, :].to(dtype=weight.dtype)).sum()
1120
+ if bool((rot_w > 0).item()):
1121
+ fm_loss_rot = (per_elem * weight * rot_mask[:, None, :].to(dtype=weight.dtype)).sum() / rot_w.clamp_min(1)
1122
+ if grip_mask is not None:
1123
+ grip_w = (weight * grip_mask[:, None, :].to(dtype=weight.dtype)).sum()
1124
+ if bool((grip_w > 0).item()):
1125
+ fm_loss_gripper = (per_elem * weight * grip_mask[:, None, :].to(dtype=weight.dtype)).sum() / grip_w.clamp_min(1)
1126
+ except Exception:
1127
+ fm_loss_pos = fm_loss_rot = fm_loss_gripper = None
1128
+
1129
+ return EO1InternVLPiFlowMatchingOutput(
1130
+ loss=fm_loss,
1131
+ fm_loss=fm_loss,
1132
+ fm_loss_pos=fm_loss_pos,
1133
+ fm_loss_rot=fm_loss_rot,
1134
+ fm_loss_gripper=fm_loss_gripper,
1135
+ ar_loss=None,
1136
+ actions=v_t,
1137
+ logits=None,
1138
+ hidden_states=None,
1139
+ attentions=None,
1140
+ )
1141
+
1142
+ @torch.no_grad()
1143
+ def sample_actions(
1144
+ self,
1145
+ input_ids: torch.LongTensor | None = None,
1146
+ attention_mask: torch.Tensor | None = None,
1147
+ position_ids: torch.LongTensor | None = None, # noqa: ARG002
1148
+ pixel_values: torch.FloatTensor | None = None,
1149
+ image_flags: torch.LongTensor | None = None,
1150
+ num_steps: int | None = None,
1151
+ noise: torch.Tensor | None = None,
1152
+ action_prefix: torch.Tensor | None = None,
1153
+ delay: int | None = None,
1154
+ **kwargs,
1155
+ ) -> Tensor:
1156
+ if input_ids is None:
1157
+ raise ValueError("sample_actions requires input_ids.")
1158
+ if attention_mask is None:
1159
+ attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=input_ids.device)
1160
+
1161
+ chunk_size = int(self.config.action_chunk_size)
1162
+ max_action_dim = int(self.config.max_action_dim)
1163
+ steps = int(num_steps) if num_steps is not None else int(self.config.num_denoise_steps)
1164
+ dt = torch.tensor(-1.0 / max(1, steps), device=input_ids.device, dtype=torch.float32)
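+ # dt is negative: sampling integrates from t=1 (pure noise) down to t=0
+ # (clean actions) in `steps` uniform Euler steps.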
1165
+
1166
+ action_token_id = self.config.action_token_id
1167
+ if action_token_id is None:
1168
+ raise ValueError("config.action_token_id is None (tokenizer/model not initialized).")
1169
+ action_pass_id = self.config.action_pass_id
1170
+
1171
+ noise_mask = input_ids == int(action_token_id)
1172
+ pass_mask = (input_ids == int(action_pass_id)) if action_pass_id is not None else torch.zeros_like(noise_mask)
1173
+ action_mask_token = noise_mask | pass_mask
1174
+
1175
+ bsz = int(input_ids.shape[0])
1176
+
1177
+ suffix_starts = self._find_suffix_starts(action_mask_token, expected_horizon=chunk_size)
1178
+ prefix_len = int(suffix_starts.max().item())
1179
+
1180
+ prefix_ids = input_ids[:, :prefix_len]
1181
+ prefix_am = attention_mask[:, :prefix_len].to(dtype=torch.bool, device=input_ids.device)
1182
+ ar = torch.arange(prefix_len, device=input_ids.device)
1183
+ prefix_valid = prefix_am & (ar[None, :] < suffix_starts[:, None])
1184
+
1185
+ prefix_embeds = self.prefix_lm.get_input_embeddings()(prefix_ids).clone()
1186
+ if pixel_values is not None:
1187
+ prefix_embeds = self._replace_img_context_embeddings(
1188
+ input_ids=prefix_ids,
1189
+ inputs_embeds=prefix_embeds,
1190
+ pixel_values=pixel_values,
1191
+ image_flags=image_flags,
1192
+ )
1193
+
1194
+ prefix_attn = prefix_valid.to(dtype=torch.long)
1195
+ prefix_out = self.prefix_lm.model(
1196
+ inputs_embeds=prefix_embeds,
1197
+ attention_mask=prefix_attn,
1198
+ use_cache=True,
1199
+ return_dict=True,
1200
+ )
1201
+ prefix_pkv = prefix_out.past_key_values
1202
+ prefix_kv_cache = [prefix_pkv[i] for i in range(len(prefix_pkv))]
1203
+ prefix_kv_cache = self._select_prefix_kv_cache(prefix_kv_cache)
1204
+
1205
+ device = input_ids.device
1206
+ if noise is None:
1207
+ x_t = self.sample_noise((bsz, chunk_size, max_action_dim), device=device).to(dtype=torch.float32)
1208
+ else:
1209
+ x_t = noise.to(device=device, dtype=torch.float32)
1210
+
1211
+ suffix_valid = torch.ones((bsz, chunk_size), device=device, dtype=torch.bool)
1212
+ pos_ids = suffix_starts[:, None] + torch.arange(chunk_size, device=device)[None, :]
1213
+
1214
+ use_prefix = action_prefix is not None
1215
+ rtc_delay = int(delay) if delay is not None else int(getattr(self.config, "rtc_delay", 0))
1216
+ if use_prefix and rtc_delay <= 0:
1217
+ try:
1218
+ rtc_delay = int(action_prefix.shape[1])
1219
+ except Exception:
1220
+ rtc_delay = 0
1221
+ rtc_delay = max(0, min(int(rtc_delay), int(chunk_size)))
1222
+
1223
+ prefix_mask = None
1224
+ if use_prefix and rtc_delay > 0:
1225
+ prefix = action_prefix
1226
+ if not torch.is_tensor(prefix):
1227
+ prefix = torch.as_tensor(prefix)
1228
+ prefix = prefix.to(device=device, dtype=torch.float32)
1229
+ if prefix.ndim != 3 or int(prefix.shape[0]) != bsz:
1230
+ raise ValueError(
1231
+ f"action_prefix must be (B, T, A) with B={bsz}, got {tuple(prefix.shape)}"
1232
+ )
1233
+
1234
+ # Clamp delay to available prefix length.
1235
+ rtc_delay = min(int(rtc_delay), int(prefix.shape[1]))
1236
+
1237
+ # Pad/trim time and action dims.
1238
+ if int(prefix.shape[-1]) != int(max_action_dim):
1239
+ if int(prefix.shape[-1]) > int(max_action_dim):
1240
+ prefix = prefix[..., : int(max_action_dim)]
1241
+ else:
1242
+ pad_dim = int(max_action_dim) - int(prefix.shape[-1])
1243
+ pad = torch.zeros((bsz, int(prefix.shape[1]), pad_dim), device=device, dtype=prefix.dtype)
1244
+ prefix = torch.cat([prefix, pad], dim=-1)
1245
+ if int(prefix.shape[1]) < int(chunk_size):
1246
+ pad_t = int(chunk_size) - int(prefix.shape[1])
1247
+ pad = torch.zeros((bsz, pad_t, int(max_action_dim)), device=device, dtype=prefix.dtype)
1248
+ prefix = torch.cat([prefix, pad], dim=1)
1249
+ elif int(prefix.shape[1]) > int(chunk_size):
1250
+ prefix = prefix[:, : int(chunk_size), :]
1251
+
1252
+ action_prefix = prefix
1253
+ prefix_mask = torch.arange(int(chunk_size), device=device)[None, :] < int(rtc_delay)
1254
+ prefix_mask = prefix_mask.expand(bsz, -1)
1255
+ else:
1256
+ use_prefix = False
1257
+
1258
+ for s in range(steps):
1259
+ t_scalar = 1.0 + float(s) * float(dt)
1260
+ if use_prefix:
1261
+ x_t = torch.where(prefix_mask[:, :, None], action_prefix, x_t)
1262
+ time_tokens = torch.full((bsz, chunk_size), t_scalar, device=device, dtype=torch.float32)
1263
+ time_tokens = torch.where(prefix_mask, torch.zeros_like(time_tokens), time_tokens)
1264
+ else:
1265
+ time_tokens = torch.full((bsz,), t_scalar, device=device, dtype=torch.float32)
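+ # When RTC is active, the executed prefix is re-imposed every iteration at
+ # time 0, so the integrator only denoises the steps after the delay window.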
1266
+
1267
+ action_tokens = self.action_in_proj(x_t.to(dtype=self.action_in_proj.weight.dtype))
1268
+ adarms_cond = self._embed_time_cond(time_tokens, dtype=action_tokens.dtype, device=action_tokens.device)
1269
+
1270
+ expert_h = self.action_expert(
1271
+ action_tokens,
1272
+ prefix_kv_cache=prefix_kv_cache,
1273
+ prefix_key_mask=prefix_valid,
1274
+ position_ids=pos_ids,
1275
+ adarms_cond=adarms_cond,
1276
+ suffix_key_mask=suffix_valid,
1277
+ )
1278
+ v_t = self.action_out_proj(expert_h).to(dtype=torch.float32)
1279
+ x_t = x_t + dt * v_t
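+ # Forward-Euler update of dx/dt = v_theta(x_t, t); dt < 0 walks t toward 0.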
1280
+
1281
+ return x_t
1282
+
1283
+
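+ # Minimal inference sketch (illustrative only; assumes `processor` is the
+ # EO1VisionProcessor from this checkpoint and `messages`/`states` follow its
+ # chat-template format):
+ #   inputs = processor.apply_chat_template(
+ #       messages, states=states, add_generation_prompt=True,
+ #       tokenize=True, return_dict=True, return_tensors="pt",
+ #   ).to(model.device)
+ #   actions = model.sample_actions(**inputs)  # (B, chunk_size, max_action_dim)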
1284
+ EO1InternVLPiFlowMatchingModel.register_for_auto_class()
checkpoint-300000/preprocessor_config.json ADDED
@@ -0,0 +1,11 @@
1
+ {
2
+ "auto_map": {
3
+ "AutoProcessor": "processing_eo1_internvl.EO1VisionProcessor"
4
+ },
5
+ "image_processor_type": "_InternVLImageProcessor",
6
+ "max_pixels": null,
7
+ "merge_size": 1,
8
+ "min_pixels": null,
9
+ "processor_class": "EO1VisionProcessor",
10
+ "temporal_patch_size": 1
11
+ }
checkpoint-300000/processing_eo1_internvl.py ADDED
@@ -0,0 +1,106 @@
1
+ """
2
+ EO1Vision processor for `eo_pi_internvl`.
3
+
4
+ This is the InternVL-backbone EO1 processor with a Pi05-style action prompt:
5
+ - We keep a *single* `<|action_pad|>` as a placeholder suffix token in text prompts.
6
+ - The action expert consumes *continuous* action tokens (length=`action_chunk_size`) internally, so we do not need to
7
+ repeat `<|action_pad|>` chunk-size times in the text (this also leaves room to add an AR text loss later).
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import inspect
13
+
14
+ import torch
15
+ from transformers.feature_extraction_utils import BatchFeature
16
+ from transformers.image_utils import ImageInput
17
+ from transformers.processing_utils import Unpack
18
+ from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
19
+ from transformers.video_utils import VideoInput
20
+
21
+ from eo_internvl.model.processing_eo1_internvl import (
22
+ ACTION_END_TOKEN,
23
+ ACTION_START_TOKEN,
24
+ DEFAULT_ACTION_TOKEN,
25
+ EO1VisionProcessor as _BaseEO1VisionProcessor,
26
+ EO1VisionProcessorKwargs,
27
+ RobotInput,
28
+ )
29
+
30
+
31
+ class EO1VisionProcessor(_BaseEO1VisionProcessor):
32
+ def expand_action_prompt(self, chunk_size: int) -> str:  # noqa: ARG002 - chunk length handled by the action expert
33
+ # Pi05-style: keep a single placeholder token in text; the model builds the full continuous action block.
34
+ return DEFAULT_ACTION_TOKEN
35
+
36
+ def __call__(
37
+ self,
38
+ images: ImageInput = None,
39
+ text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None,
40
+ videos: VideoInput = None,
41
+ states: RobotInput = None,
42
+ actions: RobotInput = None,
43
+ **kwargs: Unpack[EO1VisionProcessorKwargs],
44
+ ) -> BatchFeature:
45
+ # Force action-token expansion length to 1 (no-op), regardless of robot_config / caller.
46
+ text_kwargs = kwargs.get("text_kwargs") or {}
47
+ text_kwargs = dict(text_kwargs)
48
+ text_kwargs["noise_token_num"] = 1
49
+ kwargs["text_kwargs"] = text_kwargs
50
+ return super().__call__(images=images, text=text, videos=videos, states=states, actions=actions, **kwargs)
51
+
52
+ @torch.no_grad()
53
+ def select_action(self, model, batch: dict, return_raw_actions: bool = False, **kwargs):
54
+ if not hasattr(model, "sample_actions"):
55
+ raise NotImplementedError("InternVL EO1 model does not implement sample_actions yet.")
56
+
57
+ action_prefix = batch.pop("action_prefix", None)
58
+ rtc_delay = batch.pop("rtc_delay", None)
59
+
60
+ batch_messages, batch_states, repo_ids = self._prepare_robot_inputs(batch)
61
+ chunk_size = int(getattr(getattr(model, "config", None), "action_chunk_size", 0) or self.robot_config.get("action_chunk_size") or 0)
62
+ noise_prompt = self.expand_action_prompt(chunk_size) if chunk_size > 0 else f"{ACTION_START_TOKEN}{DEFAULT_ACTION_TOKEN}{ACTION_END_TOKEN}"
63
+
64
+ inputs = self.apply_chat_template(
65
+ batch_messages,
66
+ states=batch_states,
67
+ add_generation_prompt=True,
68
+ noise_prompt=noise_prompt,
69
+ tokenize=True,
70
+ padding=True,
71
+ truncation=True,
72
+ return_dict=True,
73
+ return_tensors="pt",
74
+ ).to(model.device)
75
+
76
+ # Probe whether this model's sample_actions accepts the RTC kwargs below.
77
+ try:
78
+ sig = inspect.signature(model.sample_actions)
79
+ except Exception:
80
+ sig = None
81
+
82
+ if action_prefix is not None:
83
+ if isinstance(action_prefix, (list, tuple)):
84
+ elems = []
85
+ for v in action_prefix:
86
+ if not torch.is_tensor(v):
87
+ v = torch.as_tensor(v)
88
+ elems.append(v)
89
+ action_prefix = torch.stack(elems, dim=0)
90
+ elif not torch.is_tensor(action_prefix):
91
+ action_prefix = torch.as_tensor(action_prefix)
92
+ action_prefix = action_prefix.to(device=model.device, dtype=torch.float32)
93
+
94
+ if sig is not None and "action_prefix" in sig.parameters:
95
+ actions = model.sample_actions(**inputs, action_prefix=action_prefix, delay=rtc_delay).cpu()
96
+ else:
97
+ actions = model.sample_actions(**inputs).cpu()
98
+
99
+ if return_raw_actions:
100
+ return BatchFeature({"action": actions})
101
+
102
+ output_actions = self._process_robot_outputs(repo_ids, actions)
103
+ return BatchFeature({"action": output_actions})
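+ # Control-loop sketch (illustrative; the `batch` layout is whatever the base
+ # class's _prepare_robot_inputs expects, and `action_prefix`/`rtc_delay` are
+ # optional RTC inputs):
+ #   out = processor.select_action(model, batch)
+ #   next_chunk = out["action"]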
104
+
105
+
106
+ EO1VisionProcessor.register_for_auto_class()
checkpoint-300000/processor_config.json ADDED
The diff for this file is too large to render. See raw diff
checkpoint-300000/rng_state_0.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:388d0abae6e0753b56d1b79bd7c7b73fcfccb4957397a5ee57998394e88cece5
3
+ size 16389
checkpoint-300000/rng_state_1.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9f0318f052e090c4c1091ee226138f0b0e38e16bf32f5460d93ae46cd2b62bf3
3
+ size 16325
checkpoint-300000/rng_state_10.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5ffefbc3679191f405a2eb0dde3e69c0f5edfde5f3946bff62ea54998861b001
3
+ size 16340
checkpoint-300000/rng_state_11.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7f2011371e60491f875266fc2da3784324fd2cdf1a4a155dfbbd3d402b44c4a4
3
+ size 16340
checkpoint-300000/rng_state_12.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e46fe4a57d1ec96258b6f77a05ea7d6d30b382f3f4cdd376e27c8ddbcd720f91
3
+ size 16340
checkpoint-300000/rng_state_13.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:88a377276842f2f552fab96dc7f8a7d50ae8d60d11dc3baa2e31776d3e1605b6
3
+ size 16340
checkpoint-300000/rng_state_14.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c56c295a6c7519f2c7035150ed0edf5a670584a8ec2ea0f2117f765fb08f1911
3
+ size 16404
checkpoint-300000/rng_state_15.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ea52e84fcdec4bb49342e67d480b02e28e7d5c4b4200b2b4edadbe3944d181d5
3
+ size 16340
checkpoint-300000/rng_state_16.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bc91185faa8105acfacdf6156b53a344486e0eb204062fc0c8e28a4383c601a5
3
+ size 16404
checkpoint-300000/rng_state_17.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c28cbf4329acc1cfcefc2703b6d3fe7e479ac477825091ca4e8569a1210c79c5
3
+ size 16340
checkpoint-300000/rng_state_18.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f8d8808f51dd1e25529d6269a0c2ea3fa7f9efa104651e99901f0e393584aa2c
3
+ size 16340
checkpoint-300000/rng_state_19.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:073af3a3597db492896416e5eb8b60c43124fffea5e242649aa54dd12f17239f
3
+ size 16340
checkpoint-300000/rng_state_2.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2a147ca63f7fda57d6b3d773e040b21e76b5145c0945d7d06fbc98c38d8d7b6d
3
+ size 16325
checkpoint-300000/rng_state_20.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:262b7889e695735cb78ade23e0cc5c6d599bc0e3e1f87f64ca0f9b4be4d32c7e
3
+ size 16340
checkpoint-300000/rng_state_21.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bb40e63e1866c584a56668679389b8a5d5ee59588037e789f70c862d71506132
3
+ size 16340
checkpoint-300000/rng_state_22.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b38b632787005dd8e69799ae476cbfd6696b15c899a7fbe5e511495c493b199d
3
+ size 16340
checkpoint-300000/rng_state_23.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3a80b0699cc457e8aeaa78b0a842bab60c71bd73a384f23318337bd45a335b27
3
+ size 16340
checkpoint-300000/rng_state_24.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:87b1d6892e58225ce3ca8ec82f1afc68ed59bbba79aa6d7d894377a383a0df66
3
+ size 16404
checkpoint-300000/rng_state_25.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7431ca30048ce99f9f7bd77aee0928144d4ca0ca1cdae12d4872b22861d2b049
3
+ size 16340
checkpoint-300000/rng_state_26.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dfe14873b9ee26f4451262c0f1fa43440abc293ca688d7b00ef1acb210a57cb4
3
+ size 16404
checkpoint-300000/rng_state_27.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a8a2b002e3b43745936ceebe0b5eba0f2e48da322e2e997a0e93655a594ee325
3
+ size 16340
checkpoint-300000/rng_state_28.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:12dd3c283e93897c36859ea8644b6e3207eaac15413764037c11ac541383bee6
3
+ size 16340
checkpoint-300000/rng_state_29.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:50955531e56f1c20e6020a44816df8a3c65a477213af8944daf6622c3731f3eb
3
+ size 16340
checkpoint-300000/rng_state_3.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:60af6b5c5fe98e0c922980e556be3ea50cf5896d83e2b90d54b03072d2dcdaa9
3
+ size 16325
checkpoint-300000/rng_state_30.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:852f8f8511c96525038c130713173d1a3ed5d4669cbb3da73f7f9f4f81cdf360
3
+ size 16340
checkpoint-300000/rng_state_31.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3ce14b79b193abd0d3a1f6c0d7931425808444e031eac78eeec415838809904e
3
+ size 16340
checkpoint-300000/rng_state_32.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0f34b482a8c9800fad7e6f42f2fd1ea479748130bcf9a14c924fce86c3410040
3
+ size 16404
checkpoint-300000/rng_state_33.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:397467ff82ca5291b554f77c9f32c28b60224da61bd3b29c26ff4e0396586809
3
+ size 16404
checkpoint-300000/rng_state_34.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a48d7eff57624f740c4501e1048eeafc4dff1f18126033f2ed9150221e813dc7
3
+ size 16340
checkpoint-300000/rng_state_35.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e34fb6ea0b4d87f7421a15ade140e5daa7f3c7d8542872eb98eed57f4d4a49d4
3
+ size 16340
checkpoint-300000/rng_state_36.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a972908767aa403094d4293403b051bde35700fc20cbbb17e9a0e6a9e705fed4
3
+ size 16404
checkpoint-300000/rng_state_37.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:496ec0fdfb71f611489a7a705e1e5f94812b11630b798003ea8ff1de5d784453
3
+ size 16340
checkpoint-300000/rng_state_38.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:260c419a4339617f2079ac5bf5c20722aacf6ea835aa411e55ad8d05b088e7cb
3
+ size 16404
checkpoint-300000/rng_state_39.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:48dea9a0a2f262d6321230a6028dc948f8807812567a75251ff461be069bd372
3
+ size 16340
checkpoint-300000/rng_state_4.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:757a7eb6dac311c44331320a83496d342f31659775ea5c09f6f15f88880edb1e
3
+ size 16389
checkpoint-300000/rng_state_40.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bd5bff9dc535da7868905861f72eabd55d4d77834cbbdfd7c06f103dfb11d9d0
3
+ size 16340
checkpoint-300000/rng_state_41.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:849cfd652c91de8b4c688efb3bfab0adffd3f0bc9ece272e36cdb844ca48c5d2
3
+ size 16340
checkpoint-300000/rng_state_42.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ed11e4531c429edead3eb1107cf8c4b04e285edf9f4a11deae8da63c4a606dd9
3
+ size 16340
checkpoint-300000/rng_state_43.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ae7f1de5c77824717c3180390598da0e6d0e97d93ac401a422238d9891ff120b
3
+ size 16340