"""
This is for extracting feature vector from the openvla-oft model and interpolating it with the original openvla model.
"""
import os
import json
import logging
import sys
from collections import deque
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Optional, Union
from PIL import Image
import draccus
import numpy as np
from tqdm import tqdm
import torch
import copy
import wandb
# Make repo-local packages (experiments/, prismatic/) importable when this
# file is run as a script from anywhere in the tree.
REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
    sys.path.append(str(REPO_ROOT))
from experiments.robot.openvla_utils import (
get_action_head,
get_noisy_action_projector,
get_processor,
get_proprio_projector,
resize_image_for_policy,
)
from experiments.robot.robot_utils import (
DATE_TIME,
get_action,
get_image_resize_size,
get_model,
invert_gripper_action,
normalize_gripper_action,
set_seed_everywhere,
)
from experiments.robot.libero.run_libero_eval import check_unnorm_key
from prismatic.vla.constants import NUM_ACTIONS_CHUNK
# Set up logging: INFO-level messages with timestamps, emitted via a stream handler.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[logging.StreamHandler()],
)
logger = logging.getLogger(__name__)
@dataclass
class GenerateConfig:
    """Configuration for extracting a feature vector (parameter difference
    between two fine-tuned models) and interpolating it into a pretrained
    OpenVLA model.
    """
    # fmt: off

    #################################################################################################################
    # Model-specific parameters
    #################################################################################################################
    model_family: str = "openvla"                    # Model family
    # The task-specific model after SF fine-tuning.
    pretrained_checkpoint: Union[str, Path] = "checkpoints/task_model"     # Task-specific checkpoint path
    # The task-specific model after OFT fine-tuning (used as the reference).
    original_pretrained_checkpoint: Union[str, Path] = "checkpoints/reference_model"     # Reference checkpoint path
    # The feature vector is the parameter difference between the two models above,
    # which represents the spatial features.
    vector_save_path: Union[str, Path] = "checkpoints/feature_vectors/feature_vector.pth"
    # The pretrained VLA model initialized with the feature vector.
    # Naming rule: initialized_{pt_ckpt}_with_{task-specific model name}_{task name on LIBERO}.
    initialized_pt_vla_path: Union[str, Path] = "checkpoints/initialized_pt_vla"
    # The original pretrained OpenVLA model.
    pt_ckpt: Union[str, Path] = "checkpoints/openvla_base"
    # The weight applied to the feature vector when initializing the pretrained VLA model.
    feature_vector_weight: float = 1                 # Weight of feature vector for interpolation
    use_l1_regression: bool = True                   # If True, uses continuous action head with L1 regression objective
    use_diffusion: bool = False                      # If True, uses continuous action head with diffusion modeling objective (DDIM)
    num_diffusion_steps_train: int = 50              # (When `diffusion==True`) Number of diffusion steps used for training
    num_diffusion_steps_inference: int = 50          # (When `diffusion==True`) Number of diffusion steps used for inference
    use_film: bool = False                           # If True, uses FiLM to infuse language inputs into visual features
    num_images_in_input: int = 2                     # Number of images in the VLA input (default: 1)
    use_proprio: bool = True                         # Whether to include proprio state in input
    center_crop: bool = True                         # Center crop? (if trained w/ random crop image aug)
    num_open_loop_steps: int = 8                     # Number of actions to execute open-loop before requerying policy
    lora_rank: int = 32                              # Rank of LoRA weight matrix (MAKE SURE THIS MATCHES TRAINING!)
    unnorm_key: Union[str, Path] = ""                # Action un-normalization key
    load_in_8bit: bool = False                       # (For OpenVLA only) Load with 8-bit quantization
    load_in_4bit: bool = False                       # (For OpenVLA only) Load with 4-bit quantization

    #################################################################################################################
    # LIBERO environment-specific parameters
    #################################################################################################################
    # NOTE(review): "de" looks like a truncated LIBERO suite name — confirm the intended default.
    task_suite_name: str = "de"                      # Task suite
    num_steps_wait: int = 10                         # Number of steps to wait for objects to stabilize in sim
    num_trials_per_task: int = 50                    # Number of rollouts per task
    initial_states_path: str = "DEFAULT"             # "DEFAULT", or path to initial states JSON file
    env_img_res: int = 256                           # Resolution for environment images (not policy input resolution)

    #################################################################################################################
    # Utils
    #################################################################################################################
    run_id_note: Optional[str] = None                # Extra note to add to end of run ID for logging
    local_log_dir: str = "./experiments/logs"        # Local directory for eval logs
    use_wandb: bool = False                          # Whether to also log results in Weights & Biases
    wandb_entity: str = "your-wandb-entity"          # Name of WandB entity
    wandb_project: str = "your-wandb-project"        # Name of WandB project
    seed: int = 7                                    # Random Seed (for reproducibility)

    # fmt: on
def validate_config(cfg: GenerateConfig) -> None:
    """Validate configuration parameters before any model is loaded.

    Raises:
        AssertionError: on any invalid parameter combination (asserts are kept
            to preserve the original calling contract).
    """
    assert cfg.pretrained_checkpoint is not None, "pretrained_checkpoint must not be None!"

    if "image_aug" in str(cfg.pretrained_checkpoint):
        assert cfg.center_crop, "Expecting `center_crop==True` because model was trained with image augmentations!"

    assert not (cfg.load_in_8bit and cfg.load_in_4bit), "Cannot use both 8-bit and 4-bit quantization!"

    # Fix: `TaskSuite` was referenced but never imported anywhere in this file,
    # so this check raised NameError instead of validating. Import locally to
    # keep module import time light.
    # TODO(review): confirm `TaskSuite` lives in run_libero_eval in this repo layout.
    from experiments.robot.libero.run_libero_eval import TaskSuite

    # Validate task suite
    assert cfg.task_suite_name in [suite.value for suite in TaskSuite], f"Invalid task suite: {cfg.task_suite_name}"
def initialize_model(cfg: GenerateConfig, only_pt: bool = False):
    """Load the VLA model together with its optional auxiliary modules.

    Args:
        cfg: Generation configuration selecting checkpoint and module options.
        only_pt: When True, skip the processor and the unnorm-key check (used
            when loading the plain pretrained VLA for interpolation).

    Returns:
        Tuple ``(model, action_head, proprio_projector, noisy_action_projector,
        processor)``; entries not applicable under ``cfg`` are ``None``.
    """
    model = get_model(cfg)

    # Optional proprioception projector (LIBERO uses an 8-dimensional proprio state).
    proprio_projector = (
        get_proprio_projector(cfg, model.llm_dim, proprio_dim=8) if cfg.use_proprio else None
    )

    # The continuous action head serves both the L1-regression and diffusion objectives.
    needs_action_head = cfg.use_l1_regression or cfg.use_diffusion
    action_head = get_action_head(cfg, model.llm_dim) if needs_action_head else None

    # The noisy-action projector is diffusion-specific.
    noisy_action_projector = (
        get_noisy_action_projector(cfg, model.llm_dim) if cfg.use_diffusion else None
    )

    processor = None
    if not only_pt:
        if cfg.model_family == "openvla":
            processor = get_processor(cfg)
        check_unnorm_key(cfg, model)

    return model, action_head, proprio_projector, noisy_action_projector, processor
def generate_feature_vector(cfg: GenerateConfig):
    """Compute the feature vector: per-parameter difference between the
    task-specific model and the reference model.

    Not draccus-wrapped — `main` parses the config and calls this directly.

    Returns:
        Dict mapping parameter name -> CPU tensor of ``task - reference`` values.
    """
    # NOTE(review): an old "# Validate configuration" marker suggested calling
    # `validate_config(cfg)` here, but it never was; left uncalled to preserve behavior.
    set_seed_everywhere(cfg.seed)

    # Task-specific model (after SF fine-tuning).
    model, action_head, proprio_projector, noisy_action_projector, processor = initialize_model(cfg)

    # Reference model (after OFT fine-tuning). The action head and projectors
    # are deliberately excluded from the feature vector.
    original_config = GenerateConfig(
        pretrained_checkpoint=cfg.original_pretrained_checkpoint,
        task_suite_name=cfg.task_suite_name,
    )
    original_model, *_ = initialize_model(original_config)

    assert len(model.state_dict()) == len(original_model.state_dict())

    # Fix 1: hoist `model.state_dict()` out of the loop — it rebuilds its dict
    # on every call, which the old code did once per parameter.
    # Fix 2: size the progress bar by named_parameters(), not state_dict() —
    # state_dict() also contains buffers, so the old total overcounted.
    model_sd = model.state_dict()
    named_params = list(original_model.named_parameters())

    feature_vector_dict = {}
    for name, original_model_param in tqdm(named_params, total=len(named_params)):
        feature_vector_dict[name] = (model_sd[name] - original_model_param).detach().cpu()
    return feature_vector_dict
def interpolate_feature_vector(cfg: GenerateConfig):
    """Add the saved feature vector, scaled by ``cfg.feature_vector_weight``,
    into the pretrained VLA's parameters in place.

    Not draccus-wrapped — `main` parses the config and calls this directly.

    Returns:
        The pretrained VLA model with interpolated parameters.
    """
    # Fix: `map_location="cpu"` keeps the diffs on CPU regardless of the device
    # they were saved from (avoids surprise GPU allocation at load time).
    feature_vector_dict = torch.load(cfg.vector_save_path, map_location="cpu")

    pt_vla_config = GenerateConfig(
        pretrained_checkpoint=cfg.pt_ckpt,
        original_pretrained_checkpoint=cfg.original_pretrained_checkpoint,
        vector_save_path=cfg.vector_save_path,
        initialized_pt_vla_path=cfg.initialized_pt_vla_path,
        feature_vector_weight=cfg.feature_vector_weight,
        pt_ckpt=cfg.pt_ckpt,
        task_suite_name=cfg.task_suite_name,
        use_proprio=False,       # the plain pretrained VLA carries no auxiliary modules
        use_l1_regression=False,
        use_diffusion=False,
    )
    pt_vla, _, _, _, _ = initialize_model(pt_vla_config, only_pt=True)

    # Snapshot the floating-point parameters so the change can be verified below.
    model_sd = pt_vla.state_dict()
    before_interp_sd = {k: v.clone() for k, v in model_sd.items() if v.dtype.is_floating_point}

    with torch.no_grad():
        pt_params = dict(pt_vla.named_parameters())
        for name, diff in feature_vector_dict.items():
            if name not in pt_params:
                # Diff entries without a matching parameter are skipped silently,
                # mirroring the original behavior.
                continue
            pt_param = pt_params[name]
            # Fix: match dtype as well as device — an in-place `add_` raises when
            # the saved diff's dtype differs from the pretrained parameter's.
            diff = diff.to(device=pt_param.device, dtype=pt_param.dtype)
            pt_param.add_(diff, alpha=cfg.feature_vector_weight)

    # Sanity check: report how far each floating-point tensor moved.
    # (model_sd holds references to the live tensors, so it reflects the update.)
    diffs_after = []
    for name, before_tensor in before_interp_sd.items():
        after_tensor = model_sd[name]
        diffs_after.append((after_tensor - before_tensor).float().norm().item())
    if diffs_after:  # fix: avoid ZeroDivisionError when no float tensors were snapshotted
        print(f"[DEBUG] post-interp (SF -> interp): mean={sum(diffs_after)/len(diffs_after):.6f}, "
              f"max={max(diffs_after):.6f}, num_tensors={len(diffs_after)}")
    return pt_vla
@draccus.wrap()
def main(cfg: GenerateConfig):
    """Generate (or reuse) the feature vector, interpolate it into the
    pretrained VLA, and save the initialized model to disk.
    """
    if not os.path.exists(cfg.vector_save_path):
        feature_vector_dict = generate_feature_vector(cfg)
        # Fix: create the parent directory first — torch.save fails on a fresh
        # checkout where checkpoints/feature_vectors/ does not exist yet.
        vector_dir = os.path.dirname(str(cfg.vector_save_path))
        if vector_dir:
            os.makedirs(vector_dir, exist_ok=True)
        torch.save(feature_vector_dict, cfg.vector_save_path)
    else:
        print(f"Feature vector already exists at {cfg.vector_save_path}")

    initialized_pt_vla = interpolate_feature_vector(cfg)
    os.makedirs(cfg.initialized_pt_vla_path, exist_ok=True)
    initialized_pt_vla.save_pretrained(cfg.initialized_pt_vla_path)


if __name__ == "__main__":
    main()