File size: 11,171 Bytes
b23769d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
"""

This is for extracting feature vector from the openvla-oft model and interpolating it with the original openvla model.

"""


import os
import json
import logging

import sys
from collections import deque
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Optional, Union
from PIL import Image

import draccus
import numpy as np
from tqdm import tqdm
import torch
import copy

import wandb

REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
    sys.path.append(str(REPO_ROOT))
from experiments.robot.openvla_utils import (
    get_action_head,
    get_noisy_action_projector,
    get_processor,
    get_proprio_projector,
    resize_image_for_policy,
)
from experiments.robot.robot_utils import (
    DATE_TIME,
    get_action,
    get_image_resize_size,
    get_model,
    invert_gripper_action,
    normalize_gripper_action,
    set_seed_everywhere,
)
from experiments.robot.libero.run_libero_eval import check_unnorm_key
from prismatic.vla.constants import NUM_ACTIONS_CHUNK
from prismatic.vla.constants import PROPRIO_DIM

# Set up logging.
# NOTE: basicConfig is a no-op if the root logger was already configured by an
# importing module, so this only takes effect when run as the entry script.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[logging.StreamHandler()],  # defaults to stderr
)
# Module-level logger keyed by module name, per stdlib convention.
logger = logging.getLogger(__name__)


@dataclass
class GenerateConfig:
    """Configuration for feature-vector extraction and interpolation.

    Bundles the three checkpoints involved (task-specific, reference, and
    pretrained base), the interpolation weight, and the OpenVLA-OFT
    model/eval knobs consumed by `initialize_model` and friends.
    """
    # fmt: off

    #################################################################################################################
    # Model-specific parameters
    #################################################################################################################
    model_family: str = "openvla"                    # Model family
    # The task-specific model after SF fine-tuning
    pretrained_checkpoint: Union[str, Path] = "checkpoints/task_model"     # Task-specific checkpoint path
    # The task-specific model after OFT fine-tuning
    original_pretrained_checkpoint: Union[str, Path] = "checkpoints/reference_model"     # Reference checkpoint path
    # The feature vector is the difference between the two models above, which represents the spatial features
    vector_save_path: Union[str, Path] = "checkpoints/feature_vectors/feature_vector.pth"
    # The pretrained VLA model initialized with the feature vector; naming rule:
    # initialized_{pt_ckpt}_with_{task-specific model name}_{task name on LIBERO}
    initialized_pt_vla_path: Union[str, Path] = "checkpoints/initialized_pt_vla"
    # The original pretrained OpenVLA model (interpolation target)
    pt_ckpt: Union[str, Path] = "checkpoints/openvla_base"
    # The weight applied to the feature vector when initializing the pretrained VLA model
    feature_vector_weight: float = 1              # Weight of feature vector for interpolation

    use_l1_regression: bool = True                   # If True, uses continuous action head with L1 regression objective
    use_diffusion: bool = False                      # If True, uses continuous action head with diffusion modeling objective (DDIM)
    num_diffusion_steps_train: int = 50              # (When `diffusion==True`) Number of diffusion steps used for training
    num_diffusion_steps_inference: int = 50          # (When `diffusion==True`) Number of diffusion steps used for inference
    use_film: bool = False                           # If True, uses FiLM to infuse language inputs into visual features
    num_images_in_input: int = 3                    # Number of images in the VLA input (default: 1)
    use_proprio: bool = True                         # Whether to include proprio state in input

    center_crop: bool = True                         # Center crop? (if trained w/ random crop image aug)
    num_open_loop_steps: int = 8                     # Number of actions to execute open-loop before requerying policy

    lora_rank: int = 32                              # Rank of LoRA weight matrix (MAKE SURE THIS MATCHES TRAINING!)

    unnorm_key: Union[str, Path] = ""                # Action un-normalization key

    load_in_8bit: bool = False                       # (For OpenVLA only) Load with 8-bit quantization
    load_in_4bit: bool = False                       # (For OpenVLA only) Load with 4-bit quantization

    #################################################################################################################
    # LIBERO environment-specific parameters
    #################################################################################################################
    task_suite_name: str = "de"  # Task suite; NOTE(review): default "de" looks like a truncated suite name -- confirm
    num_steps_wait: int = 10                         # Number of steps to wait for objects to stabilize in sim
    num_trials_per_task: int = 50                    # Number of rollouts per task
    initial_states_path: str = "DEFAULT"             # "DEFAULT", or path to initial states JSON file
    env_img_res: int = 256                           # Resolution for environment images (not policy input resolution)

    #################################################################################################################
    # Utils
    #################################################################################################################
    run_id_note: Optional[str] = None                # Extra note to add to end of run ID for logging
    local_log_dir: str = "./experiments/logs"        # Local directory for eval logs

    use_wandb: bool = False                          # Whether to also log results in Weights & Biases
    wandb_entity: str = "your-wandb-entity"          # Name of WandB entity
    wandb_project: str = "your-wandb-project"        # Name of WandB project

    seed: int = 7                                    # Random Seed (for reproducibility)

def validate_config(cfg: "GenerateConfig") -> None:
    """Validate configuration parameters.

    Raises:
        ValueError: if the checkpoint is missing, center-crop is inconsistent
            with an image-augmented checkpoint, or both quantization modes
            are enabled.

    Uses explicit raises instead of `assert` so the checks survive `python -O`.
    """
    if cfg.pretrained_checkpoint is None:
        raise ValueError("pretrained_checkpoint must not be None!")

    # Models trained with random-crop augmentation expect center-cropped inputs.
    if "image_aug" in str(cfg.pretrained_checkpoint) and not cfg.center_crop:
        raise ValueError("Expecting `center_crop==True` because model was trained with image augmentations!")

    if cfg.load_in_8bit and cfg.load_in_4bit:
        raise ValueError("Cannot use both 8-bit and 4-bit quantization!")

    # Validate task suite
    # assert cfg.task_suite_name in [suite.value for suite in TaskSuite], f"Invalid task suite: {cfg.task_suite_name}"

def initialize_model(cfg: GenerateConfig, only_pt: bool = False):
    """Load the VLA model together with its optional companion modules.

    Returns a 5-tuple `(model, action_head, proprio_projector,
    noisy_action_projector, processor)`; entries whose feature is disabled
    in `cfg` (or skipped via `only_pt`) come back as None.
    """
    vla = get_model(cfg)

    # Proprioception projector -- only when proprio state is part of the input.
    proprio_projector = (
        get_proprio_projector(cfg, vla.llm_dim, proprio_dim=PROPRIO_DIM)
        if cfg.use_proprio
        else None
    )

    # Continuous action head -- needed for either L1 regression or diffusion.
    needs_action_head = cfg.use_l1_regression or cfg.use_diffusion
    action_head = get_action_head(cfg, vla.llm_dim) if needs_action_head else None

    # Noisy-action projector -- diffusion-only component.
    noisy_action_projector = (
        get_noisy_action_projector(cfg, vla.llm_dim) if cfg.use_diffusion else None
    )

    # Processor is only fetched for the openvla family, and never in only_pt mode.
    processor = None
    if not only_pt and cfg.model_family == "openvla":
        processor = get_processor(cfg)
        # check_unnorm_key(cfg, model)

    return vla, action_head, proprio_projector, noisy_action_projector, processor

# @draccus.wrap()
def generate_feature_vector(cfg: GenerateConfig):
    """Compute the task feature vector: per-parameter differences between the
    task-specific checkpoint (`cfg.pretrained_checkpoint`) and the reference
    checkpoint (`cfg.original_pretrained_checkpoint`).

    Returns:
        dict[str, torch.Tensor]: parameter name -> (task - reference) diff,
        detached and moved to CPU.

    Raises:
        ValueError: if the two models' state dicts have different sizes.
    """
    # Set random seed for reproducibility.
    set_seed_everywhere(cfg.seed)

    # Task-specific model; the auxiliary components are unused here.
    model, _, _, _, _ = initialize_model(cfg)

    # Reference model, loaded against the same task suite.
    original_config = GenerateConfig(
        pretrained_checkpoint=cfg.original_pretrained_checkpoint,
        task_suite_name=cfg.task_suite_name,
    )
    original_model, _, _, _, _ = initialize_model(original_config)

    # Note: action_head / proprio / noisy-action projectors are deliberately
    # NOT part of the diff -- only the backbone parameters are interpolated.
    # Hoist state_dict() out of the loop: it rebuilds its mapping on every call.
    model_sd = model.state_dict()
    if len(model_sd) != len(original_model.state_dict()):
        raise ValueError("Task and reference models have mismatched state dicts!")

    # Iterate named_parameters (excludes buffers), and size tqdm's total to
    # match -- len(state_dict()) would overcount by the number of buffers.
    named_params = list(original_model.named_parameters())
    feature_vector_dict = {}
    for name, original_model_param in tqdm(named_params, total=len(named_params)):
        feature_vector_dict[name] = (model_sd[name] - original_model_param).detach().cpu()

    return feature_vector_dict

# @draccus.wrap()
def interpolate_feature_vector(cfg: GenerateConfig):
    """Interpolate the stored feature vector into the pretrained base VLA.

    Loads the parameter-difference dict from `cfg.vector_save_path`, adds each
    entry in-place to the matching base-model parameter scaled by
    `cfg.feature_vector_weight`, and returns the updated model.
    """
    # map_location="cpu" makes loading robust when the vector was saved from a
    # GPU; each diff is moved to the target parameter's device below anyway.
    feature_vector_dict = torch.load(cfg.vector_save_path, map_location="cpu")

    pt_vla_config = GenerateConfig(
        pretrained_checkpoint=cfg.pt_ckpt,
        original_pretrained_checkpoint=cfg.original_pretrained_checkpoint,
        vector_save_path=cfg.vector_save_path,
        initialized_pt_vla_path=cfg.initialized_pt_vla_path,
        feature_vector_weight=cfg.feature_vector_weight,
        pt_ckpt=cfg.pt_ckpt,
        task_suite_name=cfg.task_suite_name,
        use_proprio=False,
        use_l1_regression=False,
        use_diffusion=False,
    )

    pt_vla, _, _, _, _ = initialize_model(pt_vla_config, only_pt=True)

    # Snapshot the floating-point weights so the change can be checked after
    # interpolation (state_dict tensors alias the live parameters, hence clone).
    model_sd = pt_vla.state_dict()
    before_interp_sd = {k: v.clone() for k, v in model_sd.items() if v.dtype.is_floating_point}

    with torch.no_grad():
        pt_params = dict(pt_vla.named_parameters())
        skipped = []
        for name, diff in feature_vector_dict.items():
            pt_param = pt_params.get(name)
            if pt_param is None:
                # Previously ignored silently -- record so mismatches are visible.
                skipped.append(name)
                continue
            # Match device and dtype before the in-place scaled add.
            diff = diff.to(device=pt_param.device, dtype=pt_param.dtype)
            pt_param.add_(diff, alpha=cfg.feature_vector_weight)
        if skipped:
            logger.warning(
                "Skipped %d feature-vector entries not present in the base model (e.g. %s)",
                len(skipped), skipped[0],
            )

    # Check magnitude of change after interpolation.
    diffs_after = [
        (model_sd[name] - before_tensor).float().norm().item()
        for name, before_tensor in before_interp_sd.items()
    ]

    if diffs_after:  # guard against division by zero on an all-non-float model
        print(f"[DEBUG] post-interp (SF -> interp): mean={sum(diffs_after)/len(diffs_after):.6f}, "
              f"max={max(diffs_after):.6f}, num_tensors={len(diffs_after)}")

    #########################################################
    return pt_vla

@draccus.wrap()
def main(cfg: GenerateConfig):
    """Entry point: build (or reuse) the feature vector, then interpolate it
    into the pretrained base model and save the result with `save_pretrained`.
    """
    # Fail fast on inconsistent configuration (was defined but never called).
    validate_config(cfg)

    if not os.path.exists(cfg.vector_save_path):
        # torch.save does not create directories -- ensure the parent exists.
        os.makedirs(os.path.dirname(str(cfg.vector_save_path)) or ".", exist_ok=True)
        feature_vector_dict = generate_feature_vector(cfg)
        torch.save(feature_vector_dict, cfg.vector_save_path)
    else:
        print(f"Feature vector already exists at {cfg.vector_save_path}")

    initialized_pt_vla = interpolate_feature_vector(cfg)
    os.makedirs(cfg.initialized_pt_vla_path, exist_ok=True)
    initialized_pt_vla.save_pretrained(cfg.initialized_pt_vla_path)

if __name__ == "__main__":
    main()