haofuly's picture
Add files using upload-large-folder tool
b23769d verified
raw
history blame
11.2 kB
"""
This is for extracting feature vector from the openvla-oft model and interpolating it with the original openvla model.
"""
import os
import json
import logging
import sys
from collections import deque
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Optional, Union
from PIL import Image
import draccus
import numpy as np
from tqdm import tqdm
import torch
import copy
import wandb
# Make the repository root importable so the `experiments.*` and `prismatic.*`
# packages below resolve when this script is run directly (not installed).
REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
    sys.path.append(str(REPO_ROOT))
from experiments.robot.openvla_utils import (
get_action_head,
get_noisy_action_projector,
get_processor,
get_proprio_projector,
resize_image_for_policy,
)
from experiments.robot.robot_utils import (
DATE_TIME,
get_action,
get_image_resize_size,
get_model,
invert_gripper_action,
normalize_gripper_action,
set_seed_everywhere,
)
from experiments.robot.libero.run_libero_eval import check_unnorm_key
from prismatic.vla.constants import NUM_ACTIONS_CHUNK
from prismatic.vla.constants import PROPRIO_DIM
# Set up logging: INFO-level messages with timestamps, emitted to the console.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[logging.StreamHandler()],
)
logger = logging.getLogger(__name__)
@dataclass
class GenerateConfig:
    """Configuration for feature-vector extraction and interpolation (parsed from the CLI via draccus)."""
    # fmt: off

    #################################################################################################################
    # Model-specific parameters
    #################################################################################################################
    model_family: str = "openvla"                    # Model family

    # The task-specific model after SF fine-tuning.
    pretrained_checkpoint: Union[str, Path] = "checkpoints/task_model"                # Task-specific checkpoint path
    # The task-specific model after OFT fine-tuning.
    original_pretrained_checkpoint: Union[str, Path] = "checkpoints/reference_model"  # Reference checkpoint path
    # The feature vector is the parameter-wise difference between the two models above,
    # which represents the spatial features.
    vector_save_path: Union[str, Path] = "checkpoints/feature_vectors/feature_vector.pth"
    # The pretrained VLA model initialized with the feature vector.
    # Naming rule: initialized_{pt_ckpt}_with_{task-specific model name}_${task name on libero}
    initialized_pt_vla_path: Union[str, Path] = "checkpoints/initialized_pt_vla"
    # The original pretrained OpenVLA model.
    pt_ckpt: Union[str, Path] = "checkpoints/openvla_base"
    # The weight applied to the feature vector when initializing the pretrained VLA model.
    feature_vector_weight: float = 1                 # Weight of feature vector for interpolation

    use_l1_regression: bool = True                   # If True, uses continuous action head with L1 regression objective
    use_diffusion: bool = False                      # If True, uses continuous action head with diffusion modeling objective (DDIM)
    num_diffusion_steps_train: int = 50              # (When `diffusion==True`) Number of diffusion steps used for training
    num_diffusion_steps_inference: int = 50          # (When `diffusion==True`) Number of diffusion steps used for inference
    use_film: bool = False                           # If True, uses FiLM to infuse language inputs into visual features
    num_images_in_input: int = 3                     # Number of images in the VLA input (default: 1)
    use_proprio: bool = True                         # Whether to include proprio state in input
    center_crop: bool = True                         # Center crop? (if trained w/ random crop image aug)
    num_open_loop_steps: int = 8                     # Number of actions to execute open-loop before requerying policy
    lora_rank: int = 32                              # Rank of LoRA weight matrix (MAKE SURE THIS MATCHES TRAINING!)
    unnorm_key: Union[str, Path] = ""                # Action un-normalization key
    load_in_8bit: bool = False                       # (For OpenVLA only) Load with 8-bit quantization
    load_in_4bit: bool = False                       # (For OpenVLA only) Load with 4-bit quantization

    #################################################################################################################
    # LIBERO environment-specific parameters
    #################################################################################################################
    task_suite_name: str = "de"                      # Task suite
    num_steps_wait: int = 10                         # Number of steps to wait for objects to stabilize in sim
    num_trials_per_task: int = 50                    # Number of rollouts per task
    initial_states_path: str = "DEFAULT"             # "DEFAULT", or path to initial states JSON file
    env_img_res: int = 256                           # Resolution for environment images (not policy input resolution)

    #################################################################################################################
    # Utils
    #################################################################################################################
    run_id_note: Optional[str] = None                # Extra note to add to end of run ID for logging
    local_log_dir: str = "./experiments/logs"        # Local directory for eval logs
    use_wandb: bool = False                          # Whether to also log results in Weights & Biases
    wandb_entity: str = "your-wandb-entity"          # Name of WandB entity
    wandb_project: str = "your-wandb-project"        # Name of WandB project
    seed: int = 7                                    # Random Seed (for reproducibility)
    # fmt: on
def validate_config(cfg: GenerateConfig) -> None:
    """Sanity-check configuration values before any models are loaded."""
    assert cfg.pretrained_checkpoint is not None, "pretrained_checkpoint must not be None!"

    # Checkpoints trained with image augmentation expect center-cropped inputs.
    trained_with_image_aug = "image_aug" in str(cfg.pretrained_checkpoint)
    if trained_with_image_aug:
        assert cfg.center_crop, "Expecting `center_crop==True` because model was trained with image augmentations!"

    # 8-bit and 4-bit quantization are mutually exclusive loading modes.
    assert not cfg.load_in_8bit or not cfg.load_in_4bit, "Cannot use both 8-bit and 4-bit quantization!"

    # Task-suite validation was disabled upstream; kept here for reference.
    # assert cfg.task_suite_name in [suite.value for suite in TaskSuite], f"Invalid task suite: {cfg.task_suite_name}"
def initialize_model(cfg: GenerateConfig, only_pt: bool = False):
    """Load the VLA model and its optional auxiliary modules.

    The action head and noisy-action projector are loaded separately from the
    backbone. Returns a 5-tuple:
    (model, action_head, proprio_projector, noisy_action_projector, processor),
    where unused components are None.
    """
    model = get_model(cfg)

    # Proprioception projector: maps robot state into the LLM embedding space.
    proprio_projector = (
        get_proprio_projector(
            cfg,
            model.llm_dim,
            proprio_dim=PROPRIO_DIM,  # set the proprio_dim for different robots
        )
        if cfg.use_proprio
        else None
    )

    # Continuous action head (used by both the L1-regression and diffusion variants).
    needs_action_head = cfg.use_l1_regression or cfg.use_diffusion
    action_head = get_action_head(cfg, model.llm_dim) if needs_action_head else None

    # Projector for noised actions, required only by the diffusion objective.
    noisy_action_projector = (
        get_noisy_action_projector(cfg, model.llm_dim) if cfg.use_diffusion else None
    )

    # The OpenVLA processor is only needed when the model will be queried
    # (i.e. not when loading a bare pretrained backbone).
    processor = None
    if not only_pt and cfg.model_family == "openvla":
        processor = get_processor(cfg)

    return model, action_head, proprio_projector, noisy_action_projector, processor
# @draccus.wrap()
def generate_feature_vector(cfg: GenerateConfig):
    """Generate a feature vector (parameter differences) between two task-specific models.

    Loads the task-specific model (``cfg.pretrained_checkpoint``) and the
    reference model (``cfg.original_pretrained_checkpoint``), then returns
    ``{param_name: task_param - reference_param}`` with all tensors detached
    and moved to CPU. Buffers are intentionally excluded (only
    ``named_parameters()`` entries are diffed).
    """
    # Validate configuration before any expensive model loading.
    validate_config(cfg)

    # Set random seed for reproducibility.
    set_seed_everywhere(cfg.seed)

    # Initialize both models. The action head / noisy-action projector /
    # proprio projector are not interpolated, so only backbone params matter.
    model, action_head, proprio_projector, noisy_action_projector, processor = initialize_model(cfg)
    original_config = GenerateConfig(
        pretrained_checkpoint=cfg.original_pretrained_checkpoint,
        task_suite_name=cfg.task_suite_name,
    )
    original_model, *_ = initialize_model(original_config)

    # Hoist state_dict() out of the loop: each call rebuilds the whole dict,
    # which made the original loop quadratic in the number of parameters.
    model_sd = model.state_dict()
    original_sd = original_model.state_dict()
    assert len(model_sd) == len(original_sd)

    # Materialize the parameter list so the tqdm total matches what is
    # actually iterated (state_dict also contains buffers, parameters do not).
    named_params = list(original_model.named_parameters())
    feature_vector_dict = {}
    for name, original_model_param in tqdm(named_params, total=len(named_params)):
        feature_vector_dict[name] = (model_sd[name] - original_model_param).detach().cpu()
    return feature_vector_dict
# @draccus.wrap()
def interpolate_feature_vector(cfg: GenerateConfig):
    """Interpolate the saved feature vector into the base pretrained VLA.

    Loads the base model (``cfg.pt_ckpt``) and applies, in place,
    ``param += feature_vector_weight * diff`` for every parameter with a
    matching entry in the feature-vector file. Returns the modified model.
    """
    # map_location="cpu" so GPU-saved tensors load on CPU-only hosts;
    # weights_only=True because the file holds plain tensors (avoids
    # unpickling arbitrary objects). Tensors are moved per-param below.
    feature_vector_dict = torch.load(cfg.vector_save_path, map_location="cpu", weights_only=True)

    pt_vla_config = GenerateConfig(
        pretrained_checkpoint=cfg.pt_ckpt,
        original_pretrained_checkpoint=cfg.original_pretrained_checkpoint,
        vector_save_path=cfg.vector_save_path,
        initialized_pt_vla_path=cfg.initialized_pt_vla_path,
        feature_vector_weight=cfg.feature_vector_weight,
        pt_ckpt=cfg.pt_ckpt,
        task_suite_name=cfg.task_suite_name,
        use_proprio=False,
        use_l1_regression=False,
        use_diffusion=False,
    )
    pt_vla, _, _, _, _ = initialize_model(pt_vla_config, only_pt=True)

    # Snapshot the floating-point tensors so the applied delta can be checked
    # after interpolation (state_dict tensors alias the live parameters).
    model_sd = pt_vla.state_dict()
    before_interp_sd = {k: v.clone() for k, v in model_sd.items() if v.dtype.is_floating_point}

    # In-place interpolation: param += feature_vector_weight * diff.
    skipped = []
    with torch.no_grad():
        pt_params = dict(pt_vla.named_parameters())
        for name, diff in feature_vector_dict.items():
            if name not in pt_params:
                # Previously these were dropped silently; record them so a
                # checkpoint/architecture mismatch is visible in the logs.
                skipped.append(name)
                continue
            pt_param = pt_params[name]
            pt_param.add_(diff.to(pt_param.device), alpha=cfg.feature_vector_weight)
    if skipped:
        logger.warning(
            "Skipped %d feature-vector entries with no matching parameter (e.g. %s)",
            len(skipped), skipped[0],
        )

    # Check the magnitude of the change after interpolation.
    diffs_after = [
        (model_sd[name] - before_tensor).float().norm().item()
        for name, before_tensor in before_interp_sd.items()
    ]
    if diffs_after:  # guard against division by zero on an empty state dict
        logger.info(
            "[DEBUG] post-interp (SF -> interp): mean=%.6f, max=%.6f, num_tensors=%d",
            sum(diffs_after) / len(diffs_after), max(diffs_after), len(diffs_after),
        )
    #########################################################
    return pt_vla
@draccus.wrap()
def main(cfg: GenerateConfig):
    """Entry point: build (or reuse) the feature vector, then interpolate it
    into the base pretrained VLA and save the initialized model."""
    if not os.path.exists(cfg.vector_save_path):
        feature_vector_dict = generate_feature_vector(cfg)
        # torch.save does not create missing directories; with the default
        # path "checkpoints/feature_vectors/..." the parent may not exist yet.
        save_dir = os.path.dirname(str(cfg.vector_save_path))
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        torch.save(feature_vector_dict, cfg.vector_save_path)
    else:
        logger.info("Feature vector already exists at %s", cfg.vector_save_path)
    initialized_pt_vla = interpolate_feature_vector(cfg)
    os.makedirs(cfg.initialized_pt_vla_path, exist_ok=True)
    initialized_pt_vla.save_pretrained(cfg.initialized_pt_vla_path)


if __name__ == "__main__":
    main()