| |
| """ |
| Load a JAX model and print all parameter keys, with optional conversion to PyTorch. |
| |
| This script loads a JAX model checkpoint using orbax and can either: |
| 1. Print out all the parameter keys in a hierarchical structure for inspection |
| 2. Convert the JAX model to PyTorch format using our PI0Pytorch model |
| |
Usage:
    # Just inspect keys:
    python examples/convert_jax_model_to_pytorch.py --config_name <config> --checkpoint_dir /path/to/checkpoint --inspect_only

    # Convert to PyTorch:
    python examples/convert_jax_model_to_pytorch.py --config_name <config> --checkpoint_dir /path/to/checkpoint --output_path /path/to/output

Example:
    # pi0_droid
    python examples/convert_jax_model_to_pytorch.py --config_name pi0_droid --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid_pytorch

    # pi0_aloha_sim
    python examples/convert_jax_model_to_pytorch.py --config_name pi0_aloha_sim --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim_pytorch

    # pi05_droid
    python examples/convert_jax_model_to_pytorch.py --config_name pi05_droid --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi05_droid --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi05_droid_pytorch
| """ |
|
|
import json
import os
import pathlib
import shutil
from typing import Literal

from flax.nnx import traversals
import numpy as np
import orbax.checkpoint as ocp
import safetensors
import safetensors.torch
import torch
import tyro

import openpi.models.gemma
import openpi.models.model
import openpi.models.pi0_config
import openpi.models_pytorch.pi0_pytorch
from openpi.training import utils
import openpi.training.config as _config
|
|
|
|
def slice_paligemma_state_dict(state_dict, config):
    """Convert PaliGemma JAX parameters to PyTorch format.

    Pops flattened JAX parameter entries (keys like ``img/...`` and
    ``llm/...``) from ``state_dict`` and re-inserts the converted arrays under
    the PyTorch parameter names used by ``PI0Pytorch``
    (``paligemma_with_expert.paligemma.model...``), applying the axis
    permutations / reshapes needed to match PyTorch layer layouts.

    Args:
        state_dict: Flat mapping of JAX parameter path -> numpy array.
            Mutated in place; consumed JAX keys are removed.
        config: Config exposing ``vision_config`` and ``text_config``
            sub-configs with HF-style attributes (``hidden_size``,
            ``num_hidden_layers``, ``num_attention_heads``, ``head_dim``).

    Returns:
        Tuple ``(final_state_dict, expert_dict)``: ``final_state_dict`` maps
        PyTorch parameter names to ``torch.Tensor``; ``expert_dict`` keeps the
        untouched action-expert (``*_1``) JAX arrays for later conversion by
        ``slice_gemma_state_dict``.
    """
    # Some checkpoints nest every leaf under an extra ".../value" level.
    suffix = "/value" if "img/embedding/kernel/value" in state_dict else ""

    # --- Vision tower: patch embedding ---
    # JAX conv kernels are (H, W, in, out); PyTorch Conv2d wants (out, in, H, W).
    jax_key = f"img/embedding/kernel{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.patch_embedding.weight"
    state_dict[pytorch_key] = state_dict.pop(jax_key).transpose(3, 2, 0, 1)

    jax_key = f"img/embedding/bias{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.patch_embedding.bias"
    state_dict[pytorch_key] = state_dict.pop(jax_key)

    # Position embedding: collapse leading axes into (num_positions, hidden).
    jax_key = f"img/pos_embedding{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.position_embedding.weight"
    state_dict[pytorch_key] = state_dict.pop(jax_key).reshape(-1, config.vision_config.hidden_size)

    # --- Vision tower: encoder blocks ---
    # JAX stores all encoder layers stacked along a leading axis; pop each
    # stacked array once, then slice out layer i in the loop below.
    encoderblock_layernorm0_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/scale{suffix}")
    encoderblock_layernorm0_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/bias{suffix}")
    encoderblock_layernorm1_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/scale{suffix}")
    encoderblock_layernorm1_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/bias{suffix}")

    encoderblock_mlp_dense0_kernel = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel{suffix}")
    encoderblock_mlp_dense0_bias = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias{suffix}")
    encoderblock_mlp_dense1_kernel = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel{suffix}")
    encoderblock_mlp_dense1_bias = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias{suffix}")

    encoderblock_attention_0_key_kernel = state_dict.pop(
        f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel{suffix}"
    )
    encoderblock_attention_0_key_bias = state_dict.pop(
        f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias{suffix}"
    )
    encoderblock_attention_0_value_kernel = state_dict.pop(
        f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel{suffix}"
    )
    encoderblock_attention_0_value_bias = state_dict.pop(
        f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias{suffix}"
    )
    encoderblock_attention_0_query_kernel = state_dict.pop(
        f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel{suffix}"
    )
    encoderblock_attention_0_query_bias = state_dict.pop(
        f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias{suffix}"
    )
    encoderblock_attention_0_out_kernel = state_dict.pop(
        f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel{suffix}"
    )
    encoderblock_attention_0_out_bias = state_dict.pop(
        f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias{suffix}"
    )

    for i in range(config.vision_config.num_hidden_layers):
        # LayerNorm scale/bias per layer are 1-D, so .transpose() here is a
        # no-op (kept from the original code).
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight"
        ] = encoderblock_layernorm0_scale[i].transpose()
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias"
        ] = encoderblock_layernorm0_bias[i]
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight"
        ] = encoderblock_layernorm1_scale[i].transpose()
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias"
        ] = encoderblock_layernorm1_bias[i]
        # MLP: JAX Dense kernels are (in, out); PyTorch Linear weights are (out, in).
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight"
        ] = encoderblock_mlp_dense0_kernel[i].transpose()
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias"
        ] = encoderblock_mlp_dense0_bias[i]
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight"
        ] = encoderblock_mlp_dense1_kernel[i].transpose()
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias"
        ] = encoderblock_mlp_dense1_bias[i]
        # Attention: fuse the per-head axes into a single (hidden,) axis via
        # reshape(-1, hidden), then transpose into PyTorch's (out, in) layout.
        # For biases, .reshape(-1, hidden).reshape(-1) simply flattens.
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight"
        ] = encoderblock_attention_0_key_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias"
        ] = encoderblock_attention_0_key_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight"
        ] = encoderblock_attention_0_value_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias"
        ] = encoderblock_attention_0_value_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight"
        ] = encoderblock_attention_0_query_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias"
        ] = encoderblock_attention_0_query_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight"
        ] = encoderblock_attention_0_out_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias"
        ] = encoderblock_attention_0_out_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)

    # Final encoder norm (1-D, transpose is a no-op).
    jax_key = f"img/Transformer/encoder_norm/scale{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.post_layernorm.weight"
    state_dict[pytorch_key] = state_dict.pop(jax_key).transpose()

    jax_key = f"img/Transformer/encoder_norm/bias{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.post_layernorm.bias"
    state_dict[pytorch_key] = state_dict.pop(jax_key)

    # --- Multi-modal projector (vision -> language hidden size) ---
    jax_key = f"img/head/kernel{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.multi_modal_projector.linear.weight"
    state_dict[pytorch_key] = state_dict.pop(jax_key).transpose()

    jax_key = f"img/head/bias{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.multi_modal_projector.linear.bias"
    state_dict[pytorch_key] = state_dict.pop(jax_key)

    # --- Language model: token embedding (no transpose needed) ---
    jax_key = f"llm/embedder/input_embedding{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
    state_dict[pytorch_key] = state_dict.pop(jax_key)

    # --- Language model: decoder layers (stacked along the leading axis) ---
    llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum/w{suffix}")
    llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum/w{suffix}")
    llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum/w{suffix}")

    llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp/gating_einsum{suffix}")
    llm_mlp_linear = state_dict.pop(f"llm/layers/mlp/linear{suffix}")

    llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm/scale{suffix}")
    llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm/scale{suffix}")

    for i in range(config.text_config.num_hidden_layers):
        # q: reorder axes so the reshape yields PyTorch's
        # (num_heads * head_dim, hidden) Linear layout.
        q_proj_weight_reshaped = (
            llm_attention_q_einsum[i]
            .transpose(0, 2, 1)
            .reshape(
                config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size
            )
        )
        state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.q_proj.weight"] = (
            q_proj_weight_reshaped
        )

        # kv_einsum index 1 selects key (0) vs value (1); the trailing 0
        # presumably selects the single KV head (multi-query attention) —
        # TODO confirm against the JAX Gemma parameter layout.
        k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
        state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.k_proj.weight"] = (
            k_proj_weight_reshaped
        )
        v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
        state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.v_proj.weight"] = (
            v_proj_weight_reshaped
        )

        # NOTE(review): transpose(2, 0, 1) moves hidden to the front and the
        # reshape then relies on num_attention_heads * head_dim == hidden_size
        # (both 2048 for this text config) to produce a square matrix —
        # confirm against the reference PaliGemma conversion script.
        o_proj_weight_reshaped = (
            llm_attention_attn_vec_einsum[i]
            .transpose(2, 0, 1)
            .reshape(
                config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size
            )
        )
        state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.o_proj.weight"] = (
            o_proj_weight_reshaped
        )

        # gating_einsum stacks the gate (index 0) and up (index 1) projections.
        gate_proj_weight = llm_mlp_gating_einsum[i, 0]
        state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.gate_proj.weight"] = (
            gate_proj_weight.transpose()
        )
        up_proj_weight = llm_mlp_gating_einsum[i, 1]
        state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.up_proj.weight"] = (
            up_proj_weight.transpose()
        )
        state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.down_proj.weight"] = (
            llm_mlp_linear[i].transpose()
        )
        state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.input_layernorm.weight"] = (
            llm_input_layernorm[i]
        )
        state_dict[
            f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.post_attention_layernorm.weight"
        ] = llm_post_attention_layernorm[i]

    jax_key = f"llm/final_norm/scale{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.language_model.norm.weight"
    state_dict[pytorch_key] = state_dict.pop(jax_key)

    expert_dict = {}
    final_state_dict = {}

    # --- Expert split ---
    # The action expert's ("*_1") parameters are converted separately by
    # slice_gemma_state_dict; everything else becomes a torch.Tensor now.
    expert_keys = [
        f"llm/final_norm_1/scale{suffix}",
        f"llm/final_norm_1/Dense_0/bias{suffix}",
        f"llm/final_norm_1/Dense_0/kernel{suffix}",
        f"llm/layers/attn/attn_vec_einsum_1/w{suffix}",
        f"llm/layers/attn/kv_einsum_1/w{suffix}",
        f"llm/layers/attn/q_einsum_1/w{suffix}",
        f"llm/layers/mlp_1/gating_einsum{suffix}",
        f"llm/layers/mlp_1/linear{suffix}",
        f"llm/layers/pre_attention_norm_1/scale{suffix}",
        f"llm/layers/pre_attention_norm_1/Dense_0/bias{suffix}",
        f"llm/layers/pre_attention_norm_1/Dense_0/kernel{suffix}",
        f"llm/layers/pre_ffw_norm_1/scale{suffix}",
        f"llm/layers/pre_ffw_norm_1/Dense_0/bias{suffix}",
        f"llm/layers/pre_ffw_norm_1/Dense_0/kernel{suffix}",
    ]

    for key, value in state_dict.items():
        if key not in expert_keys:
            final_state_dict[key] = torch.from_numpy(value)
        else:
            expert_dict[key] = value

    return final_state_dict, expert_dict
|
|
|
|
def slice_gemma_state_dict(state_dict, config, *, num_expert, checkpoint_dir, pi05):
    """Convert the action-expert Gemma JAX parameters to PyTorch format.

    Args:
        state_dict: Flat mapping of JAX parameter path -> numpy array for the
            expert weights (keys like "llm/layers/attn/q_einsum_1/w").
            Mutated in place; consumed JAX keys are removed.
        config: Gemma config. Either HF-style attributes (``vocab_size``,
            ``hidden_size``, ``num_hidden_layers``, ``num_attention_heads``)
            or openpi-style ones (``width``, ``depth``, ``num_heads``);
            missing HF-style attributes are filled in from the openpi names.
        num_expert: Integer suffix identifying this expert's parameters in
            the JAX tree (e.g. 1 -> "attn_vec_einsum_1").
        checkpoint_dir: Source checkpoint path. Kept for backward
            compatibility; the model variant is now taken from ``pi05``.
        pi05: True for pi0.5 checkpoints, which use adaRMS norms
            (Dense_0 kernel + bias) instead of plain RMSNorm scale vectors.

    Returns:
        Dict mapping PyTorch parameter names to ``torch.Tensor``.
    """
    # Fill in HF-style attribute names when the config only has openpi-style ones.
    if not hasattr(config, "vocab_size"):
        config.vocab_size = 257152
    if not hasattr(config, "hidden_size"):
        config.hidden_size = config.width
    if not hasattr(config, "num_hidden_layers"):
        config.num_hidden_layers = config.depth
    if not hasattr(config, "num_attention_heads"):
        config.num_attention_heads = config.num_heads

    # Some checkpoints nest every leaf under an extra ".../value" level.
    suffix = "/value" if f"llm/layers/attn/attn_vec_einsum_{num_expert}/w/value" in state_dict else ""

    # All decoder layers are stacked along a leading axis; pop the stacked
    # arrays once and slice out layer i in the loop below.
    llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum_{num_expert}/w{suffix}")
    llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum_{num_expert}/w{suffix}")
    llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum_{num_expert}/w{suffix}")

    llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp_{num_expert}/gating_einsum{suffix}")
    llm_mlp_linear = state_dict.pop(f"llm/layers/mlp_{num_expert}/linear{suffix}")

    # BUGFIX: the model variant was previously inferred from
    # `"pi05" in checkpoint_dir`, ignoring the explicit `pi05` argument. That
    # silently mis-converted (or crashed on) pi0.5 checkpoints stored under
    # paths that do not contain "pi05". Trust the caller-provided flag.
    if pi05:
        # pi0.5: adaRMS norms are small Dense layers (kernel + bias).
        llm_input_layernorm_bias = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/Dense_0/bias{suffix}")
        llm_post_attention_layernorm_bias = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/Dense_0/bias{suffix}")
        llm_input_layernorm_kernel = state_dict.pop(
            f"llm/layers/pre_attention_norm_{num_expert}/Dense_0/kernel{suffix}"
        )
        llm_post_attention_layernorm_kernel = state_dict.pop(
            f"llm/layers/pre_ffw_norm_{num_expert}/Dense_0/kernel{suffix}"
        )
    else:
        # pi0: plain RMSNorm scale vectors.
        llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/scale{suffix}")
        llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/scale{suffix}")

    for i in range(config.num_hidden_layers):
        # q: reorder axes so the reshape yields PyTorch's
        # (num_heads * head_dim, hidden) Linear layout.
        q_proj_weight_reshaped = (
            llm_attention_q_einsum[i]
            .transpose(0, 2, 1)
            .reshape(config.num_attention_heads * config.head_dim, config.hidden_size)
        )
        state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.q_proj.weight"] = (
            q_proj_weight_reshaped
        )

        # kv_einsum index 1 selects key (0) vs value (1); the trailing 0
        # presumably selects the single KV head — TODO confirm against the
        # JAX Gemma parameter layout.
        k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
        state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.k_proj.weight"] = (
            k_proj_weight_reshaped
        )
        v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
        state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.v_proj.weight"] = (
            v_proj_weight_reshaped
        )

        # o: fuse the head axes into one, then transpose into (out, in).
        o_proj_weight_reshaped = (
            llm_attention_attn_vec_einsum[i]
            .reshape(config.num_attention_heads * config.head_dim, config.hidden_size)
            .transpose(1, 0)
        )
        state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.o_proj.weight"] = (
            o_proj_weight_reshaped
        )

        # gating_einsum stacks the gate (index 0) and up (index 1) projections;
        # JAX Dense kernels are (in, out), so transpose for PyTorch Linear.
        gate_proj_weight = llm_mlp_gating_einsum[i, 0]
        state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.gate_proj.weight"] = (
            gate_proj_weight.transpose()
        )
        up_proj_weight = llm_mlp_gating_einsum[i, 1]
        state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.up_proj.weight"] = (
            up_proj_weight.transpose()
        )
        state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[
            i
        ].transpose()

        if pi05:
            # pi0.5: adaRMS norm parameters map onto a nested Dense module.
            state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.dense.bias"] = (
                llm_input_layernorm_bias[i]
            )
            state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.dense.bias"] = (
                llm_post_attention_layernorm_bias[i]
            )
            state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.dense.weight"] = (
                llm_input_layernorm_kernel[i].transpose()
            )
            state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.dense.weight"] = (
                llm_post_attention_layernorm_kernel[i].transpose()
            )
        else:
            state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.weight"] = (
                llm_input_layernorm[i]
            )
            state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.weight"] = (
                llm_post_attention_layernorm[i]
            )

    # Final norm follows the same variant split as the per-layer norms.
    if pi05:
        final_norm_bias = state_dict.pop(f"llm/final_norm_{num_expert}/Dense_0/bias{suffix}")
        final_norm_kernel = state_dict.pop(f"llm/final_norm_{num_expert}/Dense_0/kernel{suffix}")
        state_dict["paligemma_with_expert.gemma_expert.model.norm.dense.bias"] = final_norm_bias
        state_dict["paligemma_with_expert.gemma_expert.model.norm.dense.weight"] = final_norm_kernel.transpose()
    else:
        state_dict["paligemma_with_expert.gemma_expert.model.norm.weight"] = state_dict.pop(
            f"llm/final_norm_{num_expert}/scale{suffix}"
        )

    # Everything left in state_dict is part of the converted expert; ensure
    # every value is a torch.Tensor.
    final_state_dict = {}
    for key, value in state_dict.items():
        if not isinstance(value, torch.Tensor):
            final_state_dict[key] = torch.from_numpy(value)
        else:
            final_state_dict[key] = value

    return final_state_dict
|
|
|
|
def slice_initial_orbax_checkpoint(checkpoint_dir: str, restore_precision: str | None = None):
    """Restore raw params from an orbax checkpoint via the JAX model loader.

    Going through the model loader (rather than reading orbax directly)
    respects any dtype conversions that occur during model restore.

    Args:
        checkpoint_dir: Checkpoint root; params are read from "<dir>/params/".
        restore_precision: Optional dtype name to restore parameters as.

    Returns:
        Dict with "paligemma_params" (the PaliGemma subtree flattened to
        "a/b/c"-style keys) and "projection_params" (the full params tree).
    """
    params_path = f"{checkpoint_dir}/params/"
    restored = openpi.models.model.restore_params(params_path, restore_type=np.ndarray, dtype=restore_precision)
    flat_paligemma = traversals.flatten_mapping(restored["PaliGemma"], sep="/")
    return {"paligemma_params": flat_paligemma, "projection_params": restored}
|
|
|
|
def load_jax_model_and_print_keys(checkpoint_dir: str):
    """Print the parameter-tree structure of a JAX checkpoint.

    Args:
        checkpoint_dir: Local path or gs:// URI of the checkpoint directory.
    """
    # Only local paths get normalized; GCS URIs are passed through verbatim.
    if not checkpoint_dir.startswith("gs://"):
        checkpoint_dir = os.path.abspath(checkpoint_dir)
    metadata = ocp.PyTreeCheckpointer().metadata(f"{checkpoint_dir}/params")
    print(utils.array_tree_to_info(metadata))
|
|
|
|
def convert_pi0_checkpoint(
    checkpoint_dir: str, precision: str, output_path: str, model_config: openpi.models.pi0_config.Pi0Config
):
    """
    Convert PI0 JAX checkpoint to PyTorch format.

    Args:
        checkpoint_dir: Path to the JAX checkpoint
        precision: Model precision (float32, bfloat16, float16)
        output_path: Path to save the converted PyTorch model
        model_config: Model config

    Raises:
        ValueError: If `precision` is not one of float32/bfloat16/float16.
    """
    print(f"Converting PI0 checkpoint from {checkpoint_dir} to {output_path}")
    print(f"Model config: {model_config}")

    # Restore in float32 so all conversions below are lossless; the requested
    # output precision is applied once at the end.
    initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir, restore_precision="float32")

    # Projection layers differ between variants: pi0.5 has no state_proj and
    # names its time MLP "time_mlp_*" instead of "action_time_mlp_*".
    if model_config.pi05:
        keys = [
            "action_in_proj",
            "action_out_proj",
            "time_mlp_in",
            "time_mlp_out",
        ]
    else:
        keys = [
            "state_proj",
            "action_in_proj",
            "action_out_proj",
            "action_time_mlp_in",
            "action_time_mlp_out",
        ]

    projection_params = {}
    for key in keys:
        kernel_params = initial_params["projection_params"][key]["kernel"]
        bias_params = initial_params["projection_params"][key]["bias"]
        # Some checkpoints nest each leaf under an extra {"value": ...} level.
        if isinstance(kernel_params, dict):
            weight = kernel_params["value"]
            bias = bias_params["value"]
        else:
            weight = kernel_params
            bias = bias_params

        # JAX Dense kernels are (in, out); PyTorch Linear weights are (out, in).
        projection_params[f"{key}.weight"] = torch.from_numpy(np.array(weight)).T
        projection_params[f"{key}.bias"] = torch.from_numpy(np.array(bias))

    # Minimal stand-in for the HF PaliGemma config: only the attributes read
    # by slice_paligemma_state_dict are provided.
    class PaliGemmaConfig:
        def __init__(self):
            self.vision_config = type(
                "obj",
                (object,),
                {
                    "hidden_size": 1152,
                    "num_hidden_layers": 27,
                    "num_attention_heads": 16,
                    "intermediate_size": 4304,
                    "patch_size": 14,
                    "projection_dim": 2048,
                },
            )()
            self.text_config = type(
                "obj",
                (object,),
                {
                    "hidden_size": 2048,
                    "num_hidden_layers": 18,
                    "num_attention_heads": 8,
                    "head_dim": 256,
                    "intermediate_size": 16384,
                },
            )()

    paligemma_config = PaliGemmaConfig()
    action_expert_config = openpi.models.gemma.get_config("gemma_300m")

    # Vision tower + language model; the action expert's "*_1" params come
    # back separately in expert_params.
    paligemma_params, expert_params = slice_paligemma_state_dict(initial_params["paligemma_params"], paligemma_config)

    # Action-expert Gemma parameters.
    gemma_params = slice_gemma_state_dict(
        expert_params, action_expert_config, num_expert=1, checkpoint_dir=checkpoint_dir, pi05=model_config.pi05
    )

    pi0_model = openpi.models_pytorch.pi0_pytorch.PI0Pytorch(model_config)

    all_params = {**paligemma_params, **gemma_params, **projection_params}

    # strict=False tolerates module params/buffers the checkpoint does not
    # cover. NOTE(review): it also silences genuinely missing keys — consider
    # logging load_state_dict's returned missing/unexpected keys.
    pi0_model.load_state_dict(all_params, strict=False)

    # Cast to the requested precision. float16 is included to match the CLI's
    # declared choices (the previous version rejected it).
    dtype_by_precision = {
        "float32": torch.float32,
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
    }
    if precision not in dtype_by_precision:
        raise ValueError(f"Invalid precision: {precision}")
    pi0_model = pi0_model.to(dtype_by_precision[precision])

    os.makedirs(output_path, exist_ok=True)

    safetensors.torch.save_model(pi0_model, os.path.join(output_path, "model.safetensors"))

    # Copy norm stats / assets that live next to the checkpoint, replacing any
    # stale copy at the destination.
    assets_source = pathlib.Path(checkpoint_dir).parent / "assets"
    if assets_source.exists():
        assets_dest = pathlib.Path(output_path) / "assets"
        if assets_dest.exists():
            shutil.rmtree(assets_dest)
        shutil.copytree(assets_source, assets_dest)

    # Save a small config next to the weights so the model can be re-instantiated.
    config_dict = {
        "action_dim": model_config.action_dim,
        "action_horizon": model_config.action_horizon,
        "paligemma_variant": model_config.paligemma_variant,
        "action_expert_variant": model_config.action_expert_variant,
        "precision": precision,
    }
    with open(os.path.join(output_path, "config.json"), "w") as f:
        json.dump(config_dict, f, indent=2)

    print("Model conversion completed successfully!")
    print(f"Model saved to {output_path}")
|
|
|
|
def main(
    checkpoint_dir: str,
    config_name: str,
    output_path: str | None = None,
    precision: Literal["float32", "bfloat16", "float16"] = "bfloat16",
    *,
    inspect_only: bool = False,
):
    """Load JAX model and optionally convert to PyTorch.

    Args:
        checkpoint_dir: Path to the JAX checkpoint directory.
        config_name: Name of the openpi training config describing the model.
        output_path: Path to save converted PyTorch model (required for conversion).
        precision: Precision for model conversion.
        inspect_only: Only inspect parameter keys, don't convert.
    """
    model_config = _config.get_config(config_name).model
    if not isinstance(model_config, openpi.models.pi0_config.Pi0Config):
        raise ValueError(f"Config {config_name} is not a Pi0Config")

    # Guard clauses: inspect-only mode short-circuits; conversion needs a destination.
    if inspect_only:
        load_jax_model_and_print_keys(checkpoint_dir)
        return
    if not output_path:
        print("Error: --output_path is required for conversion. Use --inspect_only to only view keys.")
        return
    convert_pi0_checkpoint(checkpoint_dir, precision, output_path, model_config)
|
|
|
|
if __name__ == "__main__":
    # tyro builds the CLI (e.g. --checkpoint_dir, --inspect_only) from main's signature.
    tyro.cli(main)
|
|