Timsty committed on
Commit
464c0e0
·
verified ·
1 Parent(s): 631e5cc

Upload pi0moh_config.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. pi0moh_config.py +116 -0
pi0moh_config.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ from typing import TYPE_CHECKING
3
+
4
+ import flax.nnx as nnx
5
+ import jax
6
+ import jax.numpy as jnp
7
+ from typing_extensions import override
8
+ from typing import List, Optional
9
+ from openpi.models import model as _model
10
+ import openpi.models.gemma as _gemma
11
+ from openpi.shared import array_typing as at
12
+ import openpi.shared.nnx_utils as nnx_utils
13
+
14
+ if TYPE_CHECKING:
15
+ from openpi.models.pi0_moh import Pi0Gated
16
+
17
+
18
+ @dataclasses.dataclass(frozen=True)
19
+ class Pi0GatedConfig(_model.BaseModelConfig):
20
+ dtype: str = "bfloat16"
21
+ paligemma_variant: _gemma.Variant = "gemma_2b"
22
+ action_expert_variant: _gemma.Variant = "gemma_300m"
23
+
24
+ # Set the model specific defaults.
25
+ action_dim: int = 32
26
+ action_horizon: int = 30
27
+ horizons: List[int] = dataclasses.field(default_factory=lambda: [5, 10, 15, 20, 25, 30])
28
+ max_token_len: int = None # type: ignore
29
+ # Pi05 has two differences from Pi0:
30
+ # - the state input is part of the discrete language tokens rather than a continuous input that is part of the suffix
31
+ # - the action expert uses adaRMSNorm to inject the flow matching timestep
32
+ pi05: bool = False
33
+ # This config option is not used directly by the model, but it is read by the ModelTransformFactory.
34
+ discrete_state_input: bool = None # type: ignore
35
+
36
+ # Loss weights from pi0_pytorch_moh.py
37
+ aux_weight: float = 1.0
38
+ balance_weight: float = 0.001
39
+
40
+ @property
41
+ def max_horizon(self) -> int:
42
+ return max(self.horizons)
43
+
44
+ def __post_init__(self):
45
+ if self.max_token_len is None:
46
+ object.__setattr__(self, "max_token_len", 200 if self.pi05 else 48)
47
+ if self.discrete_state_input is None:
48
+ object.__setattr__(self, "discrete_state_input", self.pi05)
49
+
50
+ @property
51
+ @override
52
+ def model_type(self) -> _model.ModelType:
53
+ if self.pi05:
54
+ return _model.ModelType.PI05
55
+ return _model.ModelType.PI0
56
+
57
+ @override
58
+ def create(self, rng: at.KeyArrayLike) -> "Pi0Gated":
59
+ from openpi.models.pi0_moh import Pi0Gated
60
+ return Pi0Gated(self, rngs=nnx.Rngs(rng))
61
+
62
+ @override
63
+ def inputs_spec(self, *, batch_size: int = 1) -> tuple[_model.Observation, _model.Actions]:
64
+ image_spec = jax.ShapeDtypeStruct([batch_size, *_model.IMAGE_RESOLUTION, 3], jnp.float32)
65
+ image_mask_spec = jax.ShapeDtypeStruct([batch_size], jnp.bool_)
66
+
67
+ with at.disable_typechecking():
68
+ observation_spec = _model.Observation(
69
+ images={
70
+ "base_0_rgb": image_spec,
71
+ "left_wrist_0_rgb": image_spec,
72
+ "right_wrist_0_rgb": image_spec,
73
+ },
74
+ image_masks={
75
+ "base_0_rgb": image_mask_spec,
76
+ "left_wrist_0_rgb": image_mask_spec,
77
+ "right_wrist_0_rgb": image_mask_spec,
78
+ },
79
+ state=jax.ShapeDtypeStruct([batch_size, self.action_dim], jnp.float32),
80
+ tokenized_prompt=jax.ShapeDtypeStruct([batch_size, self.max_token_len], jnp.int32),
81
+ tokenized_prompt_mask=jax.ShapeDtypeStruct([batch_size, self.max_token_len], bool),
82
+ )
83
+ action_spec = jax.ShapeDtypeStruct([batch_size, self.action_horizon, self.action_dim], jnp.float32)
84
+
85
+ return observation_spec, action_spec
86
+
87
+ def get_freeze_filter(self) -> nnx.filterlib.Filter:
88
+ """Returns the freeze filter based on the model config."""
89
+ filters = []
90
+ has_lora = False
91
+ gemma_params_filter = nnx_utils.PathRegex(".*llm.*")
92
+ action_expert_params_filter = nnx_utils.PathRegex(".*llm.*_1.*")
93
+ if "lora" in self.paligemma_variant:
94
+ filters.append(
95
+ gemma_params_filter,
96
+ )
97
+ if "lora" not in self.action_expert_variant:
98
+ # If only freeze gemma params, exclude action expert params.
99
+ filters.append(
100
+ nnx.Not(action_expert_params_filter),
101
+ )
102
+ has_lora = True
103
+ elif "lora" in self.action_expert_variant:
104
+ filters.append(
105
+ action_expert_params_filter,
106
+ )
107
+ has_lora = True
108
+
109
+ if has_lora:
110
+ # If any lora is used, exclude all lora params.
111
+ filters.append(
112
+ nnx.Not(nnx_utils.PathRegex(".*lora.*")),
113
+ )
114
+ if not filters:
115
+ return nnx.Nothing
116
+ return nnx.All(*filters)