| import dataclasses |
| from typing import TYPE_CHECKING |
|
|
| import flax.nnx as nnx |
| import jax |
| import jax.numpy as jnp |
| from typing_extensions import override |
| from typing import List, Optional |
| from openpi.models import model as _model |
| import openpi.models.gemma as _gemma |
| from openpi.shared import array_typing as at |
| import openpi.shared.nnx_utils as nnx_utils |
|
|
| if TYPE_CHECKING: |
| from openpi.models.pi0_moh import Pi0Gated |
|
|
|
|
@dataclasses.dataclass(frozen=True)
class Pi0GatedConfig(_model.BaseModelConfig):
    """Configuration for the gated multi-horizon Pi0 model (``Pi0Gated``).

    Extends ``_model.BaseModelConfig`` with backbone/expert variant selection,
    a set of candidate action horizons, and loss weights for the gating
    mechanism. Fields defaulted to ``None`` are sentinels resolved in
    ``__post_init__``.
    """

    # Computation dtype used by the model.
    dtype: str = "bfloat16"
    # Variant of the PaliGemma vision-language backbone.
    paligemma_variant: _gemma.Variant = "gemma_2b"
    # Variant of the action-expert transformer.
    action_expert_variant: _gemma.Variant = "gemma_300m"

    # Dimensionality of a single action vector.
    action_dim: int = 32
    # Number of action steps predicted per chunk.
    action_horizon: int = 30
    # Candidate prediction horizons available to the gating mechanism.
    horizons: List[int] = dataclasses.field(default_factory=lambda: [5, 10, 15, 20, 25, 30])
    # Max tokenized-prompt length; None means "resolve in __post_init__"
    # (200 when pi05, else 48).
    max_token_len: Optional[int] = None

    # Whether to use the pi0.5 variant of the architecture.
    pi05: bool = False
    # Whether the state input is discretized; None means "default to pi05"
    # (resolved in __post_init__).
    discrete_state_input: Optional[bool] = None

    # Weight of the auxiliary loss term.
    aux_weight: float = 1.0
    # Weight of the gating load-balance loss term.
    balance_weight: float = 0.001

    @property
    def max_horizon(self) -> int:
        """Largest configured prediction horizon."""
        return max(self.horizons)

    def __post_init__(self):
        # Resolve None sentinels. object.__setattr__ is required because the
        # dataclass is frozen.
        if self.max_token_len is None:
            object.__setattr__(self, "max_token_len", 200 if self.pi05 else 48)
        if self.discrete_state_input is None:
            object.__setattr__(self, "discrete_state_input", self.pi05)

    @property
    @override
    def model_type(self) -> _model.ModelType:
        """Model type enum corresponding to this config (PI05 or PI0)."""
        if self.pi05:
            return _model.ModelType.PI05
        return _model.ModelType.PI0

    @override
    def create(self, rng: at.KeyArrayLike) -> "Pi0Gated":
        """Instantiate a ``Pi0Gated`` model from this config.

        The import is local to break the circular dependency with
        ``openpi.models.pi0_moh``.
        """
        from openpi.models.pi0_moh import Pi0Gated

        return Pi0Gated(self, rngs=nnx.Rngs(rng))

    @override
    def inputs_spec(self, *, batch_size: int = 1) -> tuple[_model.Observation, _model.Actions]:
        """Return (observation, actions) shape/dtype specs for this model.

        Args:
            batch_size: Leading batch dimension of every spec.

        Returns:
            A tuple of an ``Observation`` of ``jax.ShapeDtypeStruct`` specs and
            an actions spec of shape ``[batch, action_horizon, action_dim]``.
        """
        image_spec = jax.ShapeDtypeStruct([batch_size, *_model.IMAGE_RESOLUTION, 3], jnp.float32)
        image_mask_spec = jax.ShapeDtypeStruct([batch_size], jnp.bool_)

        # Specs are placeholders, not real arrays, so array typechecking is
        # disabled while constructing the Observation.
        with at.disable_typechecking():
            observation_spec = _model.Observation(
                images={
                    "base_0_rgb": image_spec,
                    "left_wrist_0_rgb": image_spec,
                    "right_wrist_0_rgb": image_spec,
                },
                image_masks={
                    "base_0_rgb": image_mask_spec,
                    "left_wrist_0_rgb": image_mask_spec,
                    "right_wrist_0_rgb": image_mask_spec,
                },
                state=jax.ShapeDtypeStruct([batch_size, self.action_dim], jnp.float32),
                tokenized_prompt=jax.ShapeDtypeStruct([batch_size, self.max_token_len], jnp.int32),
                tokenized_prompt_mask=jax.ShapeDtypeStruct([batch_size, self.max_token_len], bool),
            )
        action_spec = jax.ShapeDtypeStruct([batch_size, self.action_horizon, self.action_dim], jnp.float32)

        return observation_spec, action_spec

    def get_freeze_filter(self) -> nnx.filterlib.Filter:
        """Returns the freeze filter based on the model config.

        When a LoRA variant is selected for the backbone and/or the action
        expert, the corresponding base LLM parameters are frozen while the
        LoRA adapter parameters themselves are always excluded from freezing.
        """
        filters = []
        has_lora = False
        gemma_params_filter = nnx_utils.PathRegex(".*llm.*")
        action_expert_params_filter = nnx_utils.PathRegex(".*llm.*_1.*")
        if "lora" in self.paligemma_variant:
            filters.append(gemma_params_filter)
            if "lora" not in self.action_expert_variant:
                # Backbone uses LoRA but the action expert does not: freeze the
                # LLM except for the action-expert parameters.
                filters.append(nnx.Not(action_expert_params_filter))
            has_lora = True
        elif "lora" in self.action_expert_variant:
            filters.append(action_expert_params_filter)
            has_lora = True

        if has_lora:
            # Never freeze the LoRA adapter parameters themselves.
            filters.append(nnx.Not(nnx_utils.PathRegex(".*lora.*")))
        if not filters:
            return nnx.Nothing
        return nnx.All(*filters)
|
|