""" This is for extracting feature vector from the openvla-oft model and interpolating it with the original openvla model. """ import os import json import logging import sys from collections import deque from dataclasses import dataclass from enum import Enum from pathlib import Path from typing import Optional, Union from PIL import Image import draccus import numpy as np from tqdm import tqdm import torch import copy import wandb REPO_ROOT = Path(__file__).resolve().parents[1] if str(REPO_ROOT) not in sys.path: sys.path.append(str(REPO_ROOT)) from experiments.robot.openvla_utils import ( get_action_head, get_noisy_action_projector, get_processor, get_proprio_projector, resize_image_for_policy, ) from experiments.robot.robot_utils import ( DATE_TIME, get_action, get_image_resize_size, get_model, invert_gripper_action, normalize_gripper_action, set_seed_everywhere, ) from experiments.robot.libero.run_libero_eval import check_unnorm_key from prismatic.vla.constants import NUM_ACTIONS_CHUNK from prismatic.vla.constants import PROPRIO_DIM # Set up logging logging.basicConfig( level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s", handlers=[logging.StreamHandler()], ) logger = logging.getLogger(__name__) @dataclass class GenerateConfig: # fmt: off ################################################################################################################# # Model-specific parameters ################################################################################################################# model_family: str = "openvla" # Model family #the task-specific model after sf fine-tuning pretrained_checkpoint: Union[str, Path] = "checkpoints/task_model" # Task-specific checkpoint path #the task-specific model after oft fine-tuning original_pretrained_checkpoint: Union[str, Path] = "checkpoints/reference_model" # Reference checkpoint path #feature vector is the difference between the two models, which represents the spatial features vector_save_path: 
Union[str, Path] = "checkpoints/feature_vectors/feature_vector.pth" #the pt vla model initialized with the feature vector, named rule: initailized_{pt_ckpt}_with_{task-specific model name}_${task name on libero} initialized_pt_vla_path: Union[str, Path] = "checkpoints/initialized_pt_vla" #the original pretrained openvla model pt_ckpt: Union[str, Path] = "checkpoints/openvla_base" #the weight of the feature vector when initializing the pt vla model feature_vector_weight: float = 1 # Weight of feature vector for interpolation use_l1_regression: bool = True # If True, uses continuous action head with L1 regression objective use_diffusion: bool = False # If True, uses continuous action head with diffusion modeling objective (DDIM) num_diffusion_steps_train: int = 50 # (When `diffusion==True`) Number of diffusion steps used for training num_diffusion_steps_inference: int = 50 # (When `diffusion==True`) Number of diffusion steps used for inference use_film: bool = False # If True, uses FiLM to infuse language inputs into visual features num_images_in_input: int = 3 # Number of images in the VLA input (default: 1) use_proprio: bool = True # Whether to include proprio state in input center_crop: bool = True # Center crop? (if trained w/ random crop image aug) num_open_loop_steps: int = 8 # Number of actions to execute open-loop before requerying policy lora_rank: int = 32 # Rank of LoRA weight matrix (MAKE SURE THIS MATCHES TRAINING!) 
unnorm_key: Union[str, Path] = "" # Action un-normalization key load_in_8bit: bool = False # (For OpenVLA only) Load with 8-bit quantization load_in_4bit: bool = False # (For OpenVLA only) Load with 4-bit quantization ################################################################################################################# # LIBERO environment-specific parameters ################################################################################################################# task_suite_name: str = "de" # Task suite num_steps_wait: int = 10 # Number of steps to wait for objects to stabilize in sim num_trials_per_task: int = 50 # Number of rollouts per task initial_states_path: str = "DEFAULT" # "DEFAULT", or path to initial states JSON file env_img_res: int = 256 # Resolution for environment images (not policy input resolution) ################################################################################################################# # Utils ################################################################################################################# run_id_note: Optional[str] = None # Extra note to add to end of run ID for logging local_log_dir: str = "./experiments/logs" # Local directory for eval logs use_wandb: bool = False # Whether to also log results in Weights & Biases wandb_entity: str = "your-wandb-entity" # Name of WandB entity wandb_project: str = "your-wandb-project" # Name of WandB project seed: int = 7 # Random Seed (for reproducibility) def validate_config(cfg: GenerateConfig) -> None: """Validate configuration parameters.""" assert cfg.pretrained_checkpoint is not None, "pretrained_checkpoint must not be None!" if "image_aug" in str(cfg.pretrained_checkpoint): assert cfg.center_crop, "Expecting `center_crop==True` because model was trained with image augmentations!" assert not (cfg.load_in_8bit and cfg.load_in_4bit), "Cannot use both 8-bit and 4-bit quantization!" 
def initialize_model(cfg: GenerateConfig, only_pt: bool = False):
    """Initialize a VLA model and its associated components.

    Args:
        cfg: Configuration describing which checkpoint and components to load.
        only_pt: If True, skip loading the OpenVLA processor (used when only
            the raw pretrained backbone is needed for interpolation).

    Returns:
        Tuple of (model, action_head, proprio_projector,
        noisy_action_projector, processor); unused components are None.
    """
    # Load the VLA backbone from cfg.pretrained_checkpoint.
    model = get_model(cfg)

    # Load proprio projector if needed.
    proprio_projector = None
    if cfg.use_proprio:
        proprio_projector = get_proprio_projector(
            cfg,
            model.llm_dim,
            proprio_dim=PROPRIO_DIM,  # set the proprio_dim for different robots
        )

    # Load continuous action head if needed.
    action_head = None
    if cfg.use_l1_regression or cfg.use_diffusion:
        action_head = get_action_head(cfg, model.llm_dim)

    # Load noisy action projector if using diffusion.
    noisy_action_projector = None
    if cfg.use_diffusion:
        noisy_action_projector = get_noisy_action_projector(cfg, model.llm_dim)

    # Get OpenVLA processor if needed.
    processor = None
    if not only_pt and cfg.model_family == "openvla":
        processor = get_processor(cfg)
        # check_unnorm_key(cfg, model)

    return model, action_head, proprio_projector, noisy_action_projector, processor


def generate_feature_vector(cfg: GenerateConfig):
    """Generate a feature vector (parameter differences) between two task-specific models.

    Computes ``task_model_param - reference_model_param`` for every named
    parameter and returns the result as a name -> CPU tensor dict.
    """
    # Validate configuration (was a dangling comment before: the existing
    # validate_config helper was never invoked).
    validate_config(cfg)

    # Set random seed.
    set_seed_everywhere(cfg.seed)

    # Task-specific (SF fine-tuned) model.
    model, action_head, proprio_projector, noisy_action_projector, processor = initialize_model(cfg)

    # Reference (OFT fine-tuned) model; action_head / noisy_action_projector
    # and friends are not interpolated, only the backbone parameters are.
    original_config = GenerateConfig(
        pretrained_checkpoint=cfg.original_pretrained_checkpoint,
        task_suite_name=cfg.task_suite_name,
    )
    original_model, *_ = initialize_model(original_config)

    assert len(model.state_dict()) == len(original_model.state_dict())

    # Hoist state_dict(): rebuilding it inside the loop is O(n) per parameter.
    model_sd = model.state_dict()
    # Iterate named_parameters (trainable weights only); use its own length as
    # the tqdm total -- state_dict() also contains buffers, so its length can differ.
    named_params = list(original_model.named_parameters())
    feature_vector_dict = {}
    for name, original_model_param in tqdm(named_params, total=len(named_params)):
        feature_vector_dict[name] = (model_sd[name] - original_model_param).detach().cpu()
    return feature_vector_dict


def interpolate_feature_vector(cfg: GenerateConfig):
    """Add the saved feature vector into the pretrained base VLA.

    Each matching parameter of the base model is updated in place with
    ``param += feature_vector_weight * diff``.

    Returns:
        The interpolated base model.
    """
    # NOTE(review): torch.load unpickles arbitrary objects -- only load trusted
    # checkpoints.  map_location="cpu" keeps loading robust on CPU-only hosts;
    # each diff is moved to the target parameter's device below anyway.
    feature_vector_dict = torch.load(cfg.vector_save_path, map_location="cpu")

    pt_vla_config = GenerateConfig(
        pretrained_checkpoint=cfg.pt_ckpt,
        original_pretrained_checkpoint=cfg.original_pretrained_checkpoint,
        vector_save_path=cfg.vector_save_path,
        initialized_pt_vla_path=cfg.initialized_pt_vla_path,
        feature_vector_weight=cfg.feature_vector_weight,
        pt_ckpt=cfg.pt_ckpt,
        task_suite_name=cfg.task_suite_name,
        use_proprio=False,
        use_l1_regression=False,
        use_diffusion=False,
    )
    pt_vla, _, _, _, _ = initialize_model(pt_vla_config, only_pt=True)

    # Snapshot the floating-point tensors so we can verify that interpolation
    # actually changed the weights (state_dict tensors alias the live
    # parameters, so model_sd reflects the in-place updates below).
    model_sd = pt_vla.state_dict()
    before_interp_sd = {k: v.clone() for k, v in model_sd.items() if v.dtype.is_floating_point}

    with torch.no_grad():
        pt_params = dict(pt_vla.named_parameters())
        for name, diff in feature_vector_dict.items():
            if name in pt_params:
                pt_param = pt_params[name]
                pt_param.add_(diff.to(pt_param.device), alpha=cfg.feature_vector_weight)

    # Check the change before vs. after interpolation (debug aid).
    diffs_after = [
        (model_sd[name] - before_tensor).float().norm().item()
        for name, before_tensor in before_interp_sd.items()
    ]
    if diffs_after:  # guard against division by zero on an all-integer state dict
        print(f"[DEBUG] post-interp (SF -> interp): mean={sum(diffs_after)/len(diffs_after):.6f}, "
              f"max={max(diffs_after):.6f}, num_tensors={len(diffs_after)}")

    return pt_vla


@draccus.wrap()
def main(cfg: GenerateConfig):
    """Compute (or reuse) the feature vector, then save the interpolated base model."""
    if not os.path.exists(cfg.vector_save_path):
        feature_vector_dict = generate_feature_vector(cfg)
        # Bug fix: create the parent directory before torch.save, otherwise the
        # default "checkpoints/feature_vectors/..." path fails on a fresh run.
        save_dir = os.path.dirname(str(cfg.vector_save_path))
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        torch.save(feature_vector_dict, cfg.vector_save_path)
    else:
        print(f"Feature vector already exists at {cfg.vector_save_path}")

    initialized_pt_vla = interpolate_feature_vector(cfg)
    os.makedirs(cfg.initialized_pt_vla_path, exist_ok=True)
    initialized_pt_vla.save_pretrained(cfg.initialized_pt_vla_path)


if __name__ == "__main__":
    main()