"""
Extract a feature (task) vector from a fine-tuned OpenVLA-OFT model — the
parameter-wise difference from a reference checkpoint — and interpolate it
into the original OpenVLA base model.
"""
|
|
|
|
|
import copy
import json
import logging
import os
import sys
from collections import deque
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Optional, Union

import draccus
import numpy as np
import torch
import wandb
from PIL import Image
from tqdm import tqdm

REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
    sys.path.append(str(REPO_ROOT))

from experiments.robot.openvla_utils import (
    get_action_head,
    get_noisy_action_projector,
    get_processor,
    get_proprio_projector,
    resize_image_for_policy,
)
from experiments.robot.robot_utils import (
    DATE_TIME,
    get_action,
    get_image_resize_size,
    get_model,
    invert_gripper_action,
    normalize_gripper_action,
    set_seed_everywhere,
)
from experiments.robot.libero.run_libero_eval import TaskSuite, check_unnorm_key
from prismatic.vla.constants import NUM_ACTIONS_CHUNK
|
|
|
|
|
|
|
# Module-level logging: INFO and above, timestamped, to a single stream handler.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[logging.StreamHandler()],
)
logger = logging.getLogger(__name__)
|
|
|
|
|
@dataclass
class GenerateConfig:
    """Configuration for feature-vector extraction and interpolation.

    Populated from the command line via `draccus.wrap()` on `main`. Groups
    checkpoint paths, OpenVLA-OFT architecture flags, LIBERO evaluation
    options, and logging settings.
    """

    # --- Model / checkpoint paths ---
    model_family: str = "openvla"  # only "openvla" gets a processor in initialize_model
    pretrained_checkpoint: Union[str, Path] = "checkpoints/task_model"  # fine-tuned task model
    original_pretrained_checkpoint: Union[str, Path] = "checkpoints/reference_model"  # reference model to diff against
    vector_save_path: Union[str, Path] = "checkpoints/feature_vectors/feature_vector.pth"  # where the param-diff dict is cached
    initialized_pt_vla_path: Union[str, Path] = "checkpoints/initialized_pt_vla"  # output dir for the interpolated model
    pt_ckpt: Union[str, Path] = "checkpoints/openvla_base"  # base VLA that receives the feature vector

    # Scale applied to the feature vector when added into the base model
    feature_vector_weight: float = 1

    # --- OpenVLA-OFT architecture flags (must match how checkpoints were trained) ---
    use_l1_regression: bool = True  # continuous-action L1 head
    use_diffusion: bool = False  # diffusion action head (mutually stylistic with L1)
    num_diffusion_steps_train: int = 50
    num_diffusion_steps_inference: int = 50
    use_film: bool = False
    num_images_in_input: int = 2
    use_proprio: bool = True  # attach a proprio projector in initialize_model

    # --- Inference-time input handling ---
    center_crop: bool = True  # required when the model was trained with image augmentation
    num_open_loop_steps: int = 8

    lora_rank: int = 32

    unnorm_key: Union[str, Path] = ""  # action un-normalization key; checked by check_unnorm_key

    # Quantization (mutually exclusive; enforced in validate_config)
    load_in_8bit: bool = False
    load_in_4bit: bool = False

    # --- LIBERO evaluation options ---
    # NOTE(review): "de" does not look like a LIBERO suite name (e.g. "libero_spatial");
    # possibly a truncated default — verify against the TaskSuite enum values.
    task_suite_name: str = "de"
    num_steps_wait: int = 10
    num_trials_per_task: int = 50
    initial_states_path: str = "DEFAULT"
    env_img_res: int = 256

    # --- Logging / experiment tracking ---
    run_id_note: Optional[str] = None
    local_log_dir: str = "./experiments/logs"

    use_wandb: bool = False
    wandb_entity: str = "your-wandb-entity"
    wandb_project: str = "your-wandb-project"

    seed: int = 7
|
|
|
def validate_config(cfg: GenerateConfig) -> None:
    """Validate configuration parameters.

    Raises AssertionError on an invalid combination of options. Note that
    asserts are stripped under `python -O`.
    """
    assert cfg.pretrained_checkpoint is not None, "pretrained_checkpoint must not be None!"

    # Checkpoints trained with image augmentation expect center-cropped inputs.
    if "image_aug" in str(cfg.pretrained_checkpoint):
        assert cfg.center_crop, "Expecting `center_crop==True` because model was trained with image augmentations!"

    assert not (cfg.load_in_8bit and cfg.load_in_4bit), "Cannot use both 8-bit and 4-bit quantization!"

    # NOTE(review): `TaskSuite` is not imported in this module as written, so this
    # line raises NameError if reached — it presumably comes from
    # experiments.robot.libero.run_libero_eval; confirm the import.
    assert cfg.task_suite_name in [suite.value for suite in TaskSuite], f"Invalid task suite: {cfg.task_suite_name}"
|
|
|
def initialize_model(cfg: GenerateConfig, only_pt: bool = False):
    """Load the VLA model and its optional companion modules.

    Returns a 5-tuple (model, action_head, proprio_projector,
    noisy_action_projector, processor); each optional component is None
    when its corresponding config flag is disabled. When `only_pt` is
    True, processor loading and unnorm-key checking are skipped.
    """
    model = get_model(cfg)

    # Proprioception projector (maps proprio state into the LLM embedding space).
    proprio_projector = (
        get_proprio_projector(cfg, model.llm_dim, proprio_dim=8)
        if cfg.use_proprio
        else None
    )

    # Continuous action head (shared by the L1-regression and diffusion variants).
    needs_action_head = cfg.use_l1_regression or cfg.use_diffusion
    action_head = get_action_head(cfg, model.llm_dim) if needs_action_head else None

    # Projector for noised actions, used only by the diffusion head.
    noisy_action_projector = (
        get_noisy_action_projector(cfg, model.llm_dim) if cfg.use_diffusion else None
    )

    processor = None
    if not only_pt and cfg.model_family == "openvla":
        processor = get_processor(cfg)
        check_unnorm_key(cfg, model)

    return model, action_head, proprio_projector, noisy_action_projector, processor
|
|
|
|
|
def generate_feature_vector(cfg: GenerateConfig):
    """Generate a feature vector (parameter differences) between two task-specific models.

    Loads the fine-tuned model (cfg.pretrained_checkpoint) and the reference
    model (cfg.original_pretrained_checkpoint) and returns a dict mapping
    parameter name -> (finetuned - reference) tensor on CPU.
    """
    set_seed_everywhere(cfg.seed)

    model, action_head, proprio_projector, noisy_action_projector, processor = initialize_model(cfg)

    original_config = GenerateConfig(
        pretrained_checkpoint=cfg.original_pretrained_checkpoint,
        task_suite_name=cfg.task_suite_name,
    )
    original_model, _, _, _, _ = initialize_model(original_config)

    # Hoist state_dict(): the original code rebuilt the full dict on every
    # loop iteration (O(P^2) dict construction over P parameters).
    model_sd = model.state_dict()
    assert len(model_sd) == len(original_model.state_dict())

    # Iterate parameters (not state_dict) so buffers are excluded; size the
    # progress bar from the same collection for an accurate total.
    named_params = list(original_model.named_parameters())
    feature_vector_dict = {}
    for name, original_model_param in tqdm(named_params, total=len(named_params)):
        model_param = model_sd[name]
        feature_vector_dict[name] = (model_param - original_model_param).detach().cpu()

    return feature_vector_dict
|
|
|
|
|
def interpolate_feature_vector(cfg: GenerateConfig):
    """Interpolate the saved feature vector into the base VLA model.

    Loads the parameter-diff dict from cfg.vector_save_path, adds each diff
    (scaled by cfg.feature_vector_weight) in place to the matching parameter
    of the model loaded from cfg.pt_ckpt, prints a sanity-check summary of
    how much each floating-point tensor moved, and returns the modified model.
    """
    # Diffs were saved on CPU in generate_feature_vector; load to CPU explicitly
    # so this does not depend on the device the vector was created on.
    feature_vector_dict = torch.load(cfg.vector_save_path, map_location="cpu")

    pt_vla_config = GenerateConfig(
        pretrained_checkpoint=cfg.pt_ckpt,
        original_pretrained_checkpoint=cfg.original_pretrained_checkpoint,
        vector_save_path=cfg.vector_save_path,
        initialized_pt_vla_path=cfg.initialized_pt_vla_path,
        feature_vector_weight=cfg.feature_vector_weight,
        pt_ckpt=cfg.pt_ckpt,
        task_suite_name=cfg.task_suite_name,
        # Heads/projectors are not needed for plain weight surgery.
        use_proprio=False,
        use_l1_regression=False,
        use_diffusion=False,
    )

    pt_vla, _, _, _, _ = initialize_model(pt_vla_config, only_pt=True)

    # state_dict tensors share storage with the parameters, so model_sd reflects
    # the in-place updates below; before_interp_sd keeps detached pre-update clones.
    model_sd = pt_vla.state_dict()
    before_interp_sd = {k: v.clone() for k, v in model_sd.items() if v.dtype.is_floating_point}

    with torch.no_grad():
        pt_params = dict(pt_vla.named_parameters())
        for name, diff in feature_vector_dict.items():
            if name in pt_params:
                pt_param = pt_params[name]
                diff = diff.to(pt_param.device)
                pt_param.add_(diff, alpha=cfg.feature_vector_weight)

    # Sanity check: per-tensor L2 norm of (after - before).
    diffs_after = []
    for name, before_tensor in before_interp_sd.items():
        after_tensor = model_sd[name]
        difference = (after_tensor - before_tensor).float().norm().item()
        diffs_after.append(difference)

    # Guard: mean/max over an empty list would raise if the model had no
    # floating-point tensors.
    if diffs_after:
        print(f"[DEBUG] post-interp (SF -> interp): mean={sum(diffs_after)/len(diffs_after):.6f}, "
              f"max={max(diffs_after):.6f}, num_tensors={len(diffs_after)}")
    else:
        print("[DEBUG] post-interp: no floating-point tensors to compare")

    return pt_vla
|
|
|
@draccus.wrap()
def main(cfg: GenerateConfig):
    """Build (or reuse) the feature vector, apply it to the base VLA, and save the result.

    The feature vector is cached at cfg.vector_save_path; if it already exists
    it is reused instead of being recomputed.
    """
    if not os.path.exists(cfg.vector_save_path):
        feature_vector_dict = generate_feature_vector(cfg)
        # torch.save does not create intermediate directories — ensure the
        # parent directory of the save path exists first.
        vector_dir = os.path.dirname(str(cfg.vector_save_path))
        if vector_dir:
            os.makedirs(vector_dir, exist_ok=True)
        torch.save(feature_vector_dict, cfg.vector_save_path)
    else:
        print(f"Feature vector already exists at {cfg.vector_save_path}")

    initialized_pt_vla = interpolate_feature_vector(cfg)
    os.makedirs(cfg.initialized_pt_vla_path, exist_ok=True)
    initialized_pt_vla.save_pretrained(cfg.initialized_pt_vla_path)
|
|
|
# Script entry point (draccus parses CLI flags into GenerateConfig inside main).
if __name__ == "__main__":
    main()
|
|
|