| """
|
| finetune.py
|
|
|
| Fine-tunes OpenVLA via LoRA.
|
| """
|
|
|
| import os
|
| import time
|
| from collections import deque
|
| from dataclasses import dataclass
|
| from pathlib import Path
|
| from typing import Dict, Optional, Tuple, Type
|
|
|
| import draccus
|
| import torch
|
| import torch.distributed as dist
|
| import torch.nn as nn
|
| import tqdm
|
| from accelerate import PartialState
|
| from huggingface_hub import HfApi, snapshot_download
|
| from peft import LoraConfig, PeftModel, get_peft_model
|
| from torch.nn.parallel import DistributedDataParallel as DDP
|
| from torch.optim import AdamW
|
| from torch.optim.lr_scheduler import MultiStepLR
|
| from torch.utils.data import DataLoader
|
| from transformers import AutoConfig, AutoImageProcessor, AutoModelForVision2Seq, AutoProcessor
|
| from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
|
| import wandb
|
# Default W&B to offline mode, but respect a WANDB_MODE the user already exported (e.g. "online").
os.environ.setdefault("WANDB_MODE", "offline")
|
|
|
| from experiments.robot.openvla_utils import (
|
| check_model_logic_mismatch,
|
| model_is_on_hf_hub,
|
| update_auto_map,
|
| )
|
|
|
| from prismatic.extern.hf.configuration_prismatic import OpenVLAConfig
|
| from prismatic.extern.hf.modeling_prismatic import OpenVLAForActionPrediction
|
| from prismatic.extern.hf.processing_prismatic import PrismaticImageProcessor, PrismaticProcessor
|
| from prismatic.models.action_heads import DiffusionActionHead, L1RegressionActionHead
|
| from prismatic.models.backbones.llm.prompting import PurePromptBuilder
|
| from prismatic.models.film_vit_wrapper import FiLMedPrismaticVisionBackbone
|
| from prismatic.models.projectors import (
|
| NoisyActionProjector,
|
| ProprioProjector,
|
| )
|
| from prismatic.training.train_utils import (
|
| compute_actions_l1_loss,
|
| compute_token_accuracy,
|
| get_current_action_mask,
|
| get_next_actions_mask,
|
| )
|
| from prismatic.util.data_utils import PaddedCollatorForActionPrediction
|
| from prismatic.vla.action_tokenizer import ActionTokenizer
|
| from prismatic.vla.constants import (
|
| ACTION_DIM,
|
| ACTION_PROPRIO_NORMALIZATION_TYPE,
|
| NUM_ACTIONS_CHUNK,
|
| PROPRIO_DIM,
|
| )
|
| from prismatic.vla.datasets import RLDSBatchTransform, RLDSDataset
|
| from prismatic.vla.datasets.rlds.utils.data_utils import save_dataset_statistics
|
|
|
|
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Optional remote-debugger hook (opt-in). When FINETUNE_DEBUGPY_PORT is set (e.g. to 9501), the process listens on
# localhost:<port> and blocks in `debugpy.wait_for_client()` until a debugger attaches. It is gated behind an
# environment variable because an unconditional wait stalls every training run. Under multi-process launches only
# one process can bind the port; failures in the other processes are reported and ignored (best-effort hook).
_DEBUGPY_PORT = os.environ.get("FINETUNE_DEBUGPY_PORT")
if _DEBUGPY_PORT:
    try:
        import debugpy

        debugpy.listen(("localhost", int(_DEBUGPY_PORT)))
        print("Waiting for debugger attach")
        debugpy.wait_for_client()
    except Exception as e:  # best-effort: never let the debug hook abort training
        print(f"debugpy hook not started: {e}")
|
|
|
|
|
@dataclass
class FinetuneConfig:
    """
    Configuration for `finetune()`.

    `finetune` is decorated with `draccus.wrap()`, which presumably builds this config from command-line flags
    (finetune's assertion message refers to a `--use_lora=True` flag).
    """

    # --- Base model ---
    # Hugging Face Hub repo ID or local directory of the VLA to fine-tune. When `resume=True`, component
    # checkpoints (`<module>--<resume_step>_checkpoint.pt`) are also loaded from this directory.
    vla_path: str = "openvla/openvla-7b"

    # --- Dataset / output directories ---
    data_root_dir: Path = Path("datasets/rlds")    # Root directory passed to RLDSDataset
    dataset_name: str = "aloha_scoop_x_into_bowl"  # Dataset name passed to RLDSDataset; also part of the run ID
    run_root_dir: Path = Path("runs")              # Parent of the run directory `<run_root_dir>/<run_id>`
    shuffle_buffer_size: int = 100_000             # Training shuffle buffer size; validation uses 1/10 of it

    # --- Algorithm & architecture ---
    use_l1_regression: bool = True       # L1RegressionActionHead trained with L1 loss (mutually exclusive with use_diffusion)
    use_diffusion: bool = False          # DiffusionActionHead trained with noise-prediction MSE loss
    num_diffusion_steps_train: int = 50  # Number of diffusion steps given to DiffusionActionHead (diffusion only)
    use_film: bool = False               # Wrap the vision backbone with FiLMedPrismaticVisionBackbone
    num_images_in_input: int = 1         # Images per sample; values > 1 enable the wrist image in RLDSBatchTransform
    use_proprio: bool = False            # Feed proprioceptive state to the VLA through a ProprioProjector

    # --- Training ---
    batch_size: int = 8                    # DataLoader batch size (each process builds its own DataLoader)
    learning_rate: float = 5e-4            # Initial AdamW learning rate
    lr_warmup_steps: int = 0               # Linear warmup from 10% to 100% of learning_rate over this many steps (0 = off)
    num_steps_before_decay: int = 100_000  # MultiStepLR milestone: the LR is multiplied by 0.1 at this scheduler step
    grad_accumulation_steps: int = 1       # Micro-batches accumulated per optimizer step
    max_steps: int = 200_000               # Training stops when the logged gradient step equals this value
    use_val_set: bool = False              # Build a validation RLDSDataset (train=False) and evaluate on it periodically
    val_freq: int = 10_000                 # Validate when the logged gradient step is a multiple of this
    val_time_limit: int = 180              # Wall-clock budget per validation run, in seconds (compared against time.time())
    save_freq: int = 10_000                # Save a checkpoint when the logged gradient step is a multiple of this
    save_latest_checkpoint_only: bool = False  # Overwrite one checkpoint in the run dir instead of one dir per step

    resume: bool = False                 # Resume from the checkpoint at `vla_path` (component weights + run ID)
    resume_step: Optional[int] = None    # Step of the resumed checkpoint; added to gradient-step indices, so set it when resume=True
    image_aug: bool = True               # Passed as `image_aug` to the RLDS datasets; also tagged in the run ID
    diffusion_sample_freq: int = 50      # With diffusion, run full reverse sampling (to log L1 errors) every N training batches

    # --- LoRA ---
    use_lora: bool = True                # Must be True: only LoRA fine-tuning is supported
    lora_rank: int = 32                  # LoRA rank (lora_alpha is set to min(lora_rank, 16))
    lora_dropout: float = 0.0            # LoRA dropout probability
    merge_lora_during_training: bool = True  # At each save, merge the adapter into the base model and save the merged model

    # --- Logging ---
    wandb_entity: str = "your-wandb-entity"    # W&B entity
    wandb_project: str = "your-wandb-project"  # W&B project
    run_id_note: Optional[str] = None          # Extra suffix appended to the auto-generated run ID
    run_id_override: Optional[str] = None      # Use this run ID verbatim instead of generating one
    wandb_log_freq: int = 10                   # Log training metrics every N gradient steps
|
|
|
|
|
|
|
|
|
def remove_ddp_in_checkpoint(state_dict) -> dict:
    """
    Strip a single leading 'module.' from every key of a state dict saved from a DDP-wrapped model.

    DistributedDataParallel stores the wrapped model under the attribute `module`, so its state dict keys look like
    'module.<param>'. Removing that prefix makes the dict loadable into the bare (unwrapped) module.

    Args:
        state_dict (dict): PyTorch model state dictionary.

    Returns:
        dict: New dictionary with the same values (and key order); keys starting with 'module.' lose exactly one such
        prefix, all other keys are kept unchanged.
    """
    ddp_prefix = "module."
    cut = len(ddp_prefix)
    return {
        (name[cut:] if name.startswith(ddp_prefix) else name): tensor
        for name, tensor in state_dict.items()
    }
|
|
|
|
|
def get_run_id(cfg) -> str:
    """
    Determine the experiment run identifier.

    Precedence:
      1. `cfg.run_id_override`, returned verbatim when set.
      2. When resuming, the last path component of `cfg.vla_path`, minus a trailing '--<...chkpt...>' segment so that
         resumed runs keep the original run's name.
      3. Otherwise a descriptive ID built from the model name, dataset, effective batch size, learning rate and the
         LoRA / image-augmentation / note options.

    Args:
        cfg (FinetuneConfig): Training configuration.

    Returns:
        str: Experiment run ID.
    """
    if cfg.run_id_override is not None:
        return cfg.run_id_override

    if cfg.resume:
        resumed_name = cfg.vla_path.split("/")[-1]
        *leading_segments, final_segment = resumed_name.split("--")
        return "--".join(leading_segments) if "chkpt" in final_segment else resumed_name

    model_name = cfg.vla_path.split("/")[-1]
    effective_batch = cfg.batch_size * cfg.grad_accumulation_steps
    pieces = [f"{model_name}+{cfg.dataset_name}+b{effective_batch}+lr-{cfg.learning_rate}"]
    if cfg.use_lora:
        pieces.append(f"+lora-r{cfg.lora_rank}+dropout-{cfg.lora_dropout}")
    if cfg.image_aug:
        pieces.append("--image_aug")
    if cfg.run_id_note is not None:
        pieces.append(f"--{cfg.run_id_note}")
    return "".join(pieces)
|
|
|
|
|
def load_checkpoint(module_name: str, path: str, step: int, device: str = "cpu") -> dict:
    """
    Load the saved state dict of one model component.

    The file read is `<path>/<module_name>--<step>_checkpoint.pt`. It is loaded with `weights_only=True`, and any
    DDP 'module.' key prefixes are stripped so the result fits an unwrapped module.

    Args:
        module_name (str): Name of model component to load checkpoint for.
        path (str): Path to checkpoint directory.
        step (int): Gradient step number of saved checkpoint.
        device (str): `map_location` used to remap storage locations (default = "cpu").

    Returns:
        dict: PyTorch model state dictionary without DDP prefixes.
    """
    filename = f"{module_name}--{step}_checkpoint.pt"
    checkpoint_path = os.path.join(path, filename)
    print(f"Loading checkpoint: {checkpoint_path}")
    raw_state = torch.load(checkpoint_path, weights_only=True, map_location=device)
    return remove_ddp_in_checkpoint(raw_state)
|
|
|
|
|
def wrap_ddp(module: nn.Module, device_id: int, find_unused: bool = False) -> DDP:
    """
    Wrap `module` in DistributedDataParallel bound to a single local device.

    Args:
        module (nn.Module): PyTorch module (callers in this file move it to `device_id` before wrapping).
        device_id (int): Local device index, used as the only entry of `device_ids`.
        find_unused (bool): Whether DDP should detect parameters that receive no gradient in a step.

    Returns:
        DistributedDataParallel: PyTorch module wrapped with DDP. Gradients are stored as views into DDP's
        communication buckets (`gradient_as_bucket_view=True`).
    """
    ddp_options = {
        "device_ids": [device_id],
        "find_unused_parameters": find_unused,
        "gradient_as_bucket_view": True,
    }
    return DDP(module, **ddp_options)
|
|
|
|
|
def count_parameters(module: nn.Module, name: str) -> None:
    """
    Print how many parameters of `module` are trainable (have `requires_grad=True`).

    Args:
        module (nn.Module): PyTorch module.
        name (str): Label for the component, used in the printed message.

    Returns:
        None.
    """
    trainable = filter(lambda param: param.requires_grad, module.parameters())
    total = sum(map(lambda param: param.numel(), trainable))
    print(f"# trainable params in {name}: {total}")
|
|
|
|
|
def init_module(
    module_class: Type[nn.Module],
    module_name: str,
    cfg: FinetuneConfig,
    device_id: int,
    module_args: dict,
    to_bf16: bool = False,
    find_unused_params: bool = False,
) -> DDP:
    """
    Build a training component, optionally restore its weights, place it on a device, and wrap it in DDP.

    Args:
        module_class (Type[nn.Module]): Class of PyTorch module to initialize.
        module_name (str): Component name; used for the parameter-count printout and the checkpoint file name.
        cfg (FinetuneConfig): Training configuration (`resume`, `vla_path`, `resume_step` are read).
        device_id (int): Device ID.
        module_args (dict): Keyword arguments for `module_class`.
        to_bf16 (bool): Whether to cast the module to torch.bfloat16 before moving it to the device.
        find_unused_params (bool): Whether DDP should detect parameters without gradients.

    Returns:
        DistributedDataParallel: PyTorch module wrapped with DDP.
    """
    component = module_class(**module_args)
    count_parameters(component, module_name)

    if cfg.resume:
        restored_weights = load_checkpoint(module_name, cfg.vla_path, cfg.resume_step)
        component.load_state_dict(restored_weights)

    if to_bf16:
        component = component.to(torch.bfloat16)

    return wrap_ddp(component.to(device_id), device_id, find_unused_params)
|
|
|
|
|
def run_forward_pass(
    vla,
    action_head,
    noisy_action_projector,
    proprio_projector,
    batch,
    action_tokenizer,
    device_id,
    use_l1_regression,
    use_diffusion,
    use_proprio,
    use_film,
    num_patches,
    compute_diffusion_l1=False,
    num_diffusion_steps_train=None,
) -> Tuple[torch.Tensor, Dict[str, float]]:
    """
    Compute model forward pass and metrics for both training and validation.

    Three modes, selected by the flags:
      * Neither `use_l1_regression` nor `use_diffusion`: the loss is the VLA's own `output.loss`, and token accuracy
        and L1 metrics are computed from the argmax-decoded token IDs.
      * `use_l1_regression`: hidden states at the action-token positions go through
        `action_head.module.predict_action`, and the loss is the L1 distance to the ground-truth actions.
      * `use_diffusion`: ground-truth actions are noised via `action_head.module.sample_noisy_actions`, the VLA is run
        with the noisy actions, and the loss is the MSE between predicted and true noise. Actions are only produced
        (by full reverse diffusion) when `compute_diffusion_l1` is True, so L1 metrics are only reported then.

    Args:
        vla (OpenVLAForActionPrediction): Vision-language-action policy.
        action_head (nn.Module): DDP-wrapped action head (accessed via `.module`); unused in discrete-token mode.
        noisy_action_projector (nn.Module): Noisy action projector module (only used for diffusion).
        proprio_projector (nn.Module): Proprioceptive state projector module.
        batch (dict): Input batch; keys read: "actions", "input_ids", "attention_mask", "pixel_values", "labels",
            and "proprio" when `use_proprio`.
        action_tokenizer (ActionTokenizer): Action tokenizer (used for L1 metrics in discrete-token mode).
        device_id (str): Device ID.
        use_l1_regression (bool): Whether to use L1 regression.
        use_diffusion (bool): Whether to use diffusion.
        use_proprio (bool): Whether to use proprioceptive state as input.
        use_film (bool): Whether to use FiLM for better language following.
        num_patches (int): Number of leading sequence positions to skip before the text/action positions.
        compute_diffusion_l1 (bool): Whether to sample actions and compute L1 loss for diffusion (do this once every
            diffusion_sample_freq steps during training; do it every batch for validation)
        num_diffusion_steps_train (int): Not read in this function; the sampler uses
            `action_head.module.num_diffusion_steps_train` instead.

    Returns:
        tuple: (loss, metrics_dict)
            loss: The loss tensor (with gradient when called outside `torch.no_grad()`).
            metrics_dict: Python floats (via `.item()`): "loss_value" always; "curr_action_l1_loss" and
                "next_actions_l1_loss" when actions are available; "curr_action_accuracy" and
                "next_actions_accuracy" in discrete-token mode only.
    """
    metrics = {}

    # Ground-truth continuous actions on the target device, in bf16.
    # Assumes shape (batch, NUM_ACTIONS_CHUNK, ACTION_DIM) — TODO confirm against the collator.
    ground_truth_actions = batch["actions"].to(device_id).to(torch.bfloat16)

    # Diffusion only: noise the ground-truth actions; the VLA then consumes the noisy actions plus timestep embeddings.
    if use_diffusion:
        noisy_dict = action_head.module.sample_noisy_actions(ground_truth_actions)
        noise, noisy_actions, diffusion_timestep_embeddings = (
            noisy_dict["noise"],
            noisy_dict["noisy_actions"],
            noisy_dict["diffusion_timestep_embeddings"],
        )
    else:
        noise, noisy_actions, diffusion_timestep_embeddings = None, None, None

    # VLA forward pass under bf16 autocast; hidden states are requested for the continuous action heads.
    # NOTE(review): "labels" and "proprio" are passed without `.to(device_id)`; presumably the model moves them — confirm.
    with torch.autocast("cuda", dtype=torch.bfloat16):
        output: CausalLMOutputWithPast = vla(
            input_ids=batch["input_ids"].to(device_id),
            attention_mask=batch["attention_mask"].to(device_id),
            pixel_values=batch["pixel_values"].to(torch.bfloat16).to(device_id),
            labels=batch["labels"],
            output_hidden_states=True,
            proprio=batch["proprio"] if use_proprio else None,
            proprio_projector=proprio_projector if use_proprio else None,
            noisy_actions=noisy_actions if use_diffusion else None,
            noisy_action_projector=noisy_action_projector if use_diffusion else None,
            diffusion_timestep_embeddings=diffusion_timestep_embeddings if use_diffusion else None,
            use_film=use_film,
        )

    # Labels shifted by one so they line up with the outputs sliced as [:, num_patches:-1] below.
    # The masks select the current-action and next-actions token positions.
    ground_truth_token_ids = batch["labels"][:, 1:].to(device_id)
    current_action_mask = get_current_action_mask(ground_truth_token_ids)
    next_actions_mask = get_next_actions_mask(ground_truth_token_ids)

    if not (use_l1_regression or use_diffusion):
        # Discrete-token mode: the language-model loss is the training loss.
        loss = output.loss
        # Drop the leading `num_patches` positions and the last position, then take argmax token IDs.
        predicted_token_ids = output.logits[:, num_patches:-1].argmax(dim=2)
        curr_action_accuracy = compute_token_accuracy(
            predicted_token_ids, ground_truth_token_ids, mask=current_action_mask
        )
        curr_action_l1_loss = compute_actions_l1_loss(
            action_tokenizer, predicted_token_ids, ground_truth_token_ids, mask=current_action_mask
        )
        next_actions_accuracy = compute_token_accuracy(
            predicted_token_ids, ground_truth_token_ids, mask=next_actions_mask
        )
        next_actions_l1_loss = compute_actions_l1_loss(
            action_tokenizer, predicted_token_ids, ground_truth_token_ids, mask=next_actions_mask
        )
        metrics.update(
            {
                "loss_value": loss.item(),
                "curr_action_accuracy": curr_action_accuracy.item(),
                "curr_action_l1_loss": curr_action_l1_loss.item(),
                "next_actions_accuracy": next_actions_accuracy.item(),
                "next_actions_l1_loss": next_actions_l1_loss.item(),
            }
        )

    else:
        # Continuous heads read the final-layer hidden states.
        last_hidden_states = output.hidden_states[-1]

        # Same slicing as the logits above, so positions align with `ground_truth_token_ids` and the masks.
        text_hidden_states = last_hidden_states[:, num_patches:-1]

        # Gather action-token positions; the reshape requires exactly NUM_ACTIONS_CHUNK * ACTION_DIM of them per sample.
        batch_size = batch["input_ids"].shape[0]
        actions_hidden_states = (
            text_hidden_states[current_action_mask | next_actions_mask]
            .reshape(batch_size, NUM_ACTIONS_CHUNK * ACTION_DIM, -1)
            .to(torch.bfloat16)
        )

        if use_l1_regression:
            # Regress actions directly from the action-position hidden states.
            predicted_actions = action_head.module.predict_action(actions_hidden_states)

            # Mean absolute error against the ground-truth actions.
            loss = torch.nn.L1Loss()(ground_truth_actions, predicted_actions)

        if use_diffusion:
            # Predict the noise that was added to the ground-truth actions.
            noise_pred = action_head.module.predict_noise(actions_hidden_states)

            # Noise-prediction MSE is the diffusion training loss.
            noise_pred = noise_pred.reshape(noise.shape)
            loss = nn.functional.mse_loss(noise_pred, noise, reduction="mean")

            # Full reverse-diffusion sampling is expensive, so actions are only sampled when requested.
            if compute_diffusion_l1:
                with torch.no_grad():
                    predicted_actions = run_diffusion_sampling(
                        vla=vla,
                        action_head=action_head,
                        noisy_action_projector=noisy_action_projector,
                        proprio_projector=proprio_projector,
                        batch=batch,
                        batch_size=batch_size,
                        num_patches=num_patches,
                        actions_shape=ground_truth_actions.shape,
                        device_id=device_id,
                        current_action_mask=current_action_mask,
                        next_actions_mask=next_actions_mask,
                        use_proprio=use_proprio,
                        use_film=use_film,
                    )

        metrics.update(
            {
                "loss_value": loss.item(),
            }
        )

        # `predicted_actions` only exists for L1 regression, or for diffusion when actions were sampled.
        should_log_l1_loss = not use_diffusion or (use_diffusion and compute_diffusion_l1)
        if should_log_l1_loss:
            # Index 0 along dim 1 is the current action; the rest are the future actions of the chunk.
            ground_truth_curr_action = ground_truth_actions[:, 0]
            predicted_curr_action = predicted_actions[:, 0]
            ground_truth_next_actions = ground_truth_actions[:, 1:]
            predicted_next_actions = predicted_actions[:, 1:]
            curr_action_l1_loss = torch.nn.L1Loss()(ground_truth_curr_action, predicted_curr_action)
            next_actions_l1_loss = torch.nn.L1Loss()(ground_truth_next_actions, predicted_next_actions)
            metrics.update(
                {
                    "curr_action_l1_loss": curr_action_l1_loss.item(),
                    "next_actions_l1_loss": next_actions_l1_loss.item(),
                }
            )

    return loss, metrics
|
|
|
|
|
def run_diffusion_sampling(
    vla,
    action_head,
    noisy_action_projector,
    proprio_projector,
    batch,
    batch_size,
    num_patches,
    actions_shape,
    device_id,
    current_action_mask,
    next_actions_mask,
    use_proprio,
    use_film,
) -> torch.Tensor:
    """
    Run diffusion sampling (reverse diffusion) to generate actions.

    Starts from standard-normal noise and, for each scheduler timestep, runs a full VLA forward pass on the current
    noisy actions. It then predicts the noise with the action head and takes one scheduler step. The cost is one VLA
    forward pass per timestep.

    Args:
        vla (OpenVLAForActionPrediction): Vision-language-action policy.
        action_head (nn.Module): DDP-wrapped diffusion action head (accessed via `.module`).
        noisy_action_projector (nn.Module): Noisy action projector module (only used for diffusion).
        proprio_projector (nn.Module): Proprioceptive state projector module.
        batch (dict): Input batch.
        batch_size (int): Batch size.
        num_patches (int): Number of leading sequence positions to skip before the text/action positions.
        actions_shape (tuple): Shape of ground-truth actions; the result is reshaped to it.
        device_id (str): Device ID.
        current_action_mask (torch.Tensor): Mask for current action.
        next_actions_mask (torch.Tensor): Mask for next actions.
        use_proprio (bool): Whether to use proprioceptive state as input.
        use_film (bool): Whether to use FiLM for better language following.

    Returns:
        torch.Tensor: Predicted actions, reshaped to `actions_shape`.
    """
    # Initial sample: standard-normal noise with one (NUM_ACTIONS_CHUNK, ACTION_DIM) chunk per batch element.
    noise = torch.randn(
        size=(batch_size, NUM_ACTIONS_CHUNK, ACTION_DIM),
        device=device_id,
        dtype=torch.bfloat16,
    )

    # Use the same number of denoising steps as configured for training.
    # NOTE(review): this mutates the head's shared scheduler in place; presumably training-time noising does not
    # depend on these inference timesteps — confirm.
    action_head.module.noise_scheduler.set_timesteps(action_head.module.num_diffusion_steps_train)

    # Iteratively denoise over the scheduler's timesteps.
    curr_noisy_actions = noise
    for t in action_head.module.noise_scheduler.timesteps:
        # Same timestep for every batch element, embedded by the head's time encoder and cast to the actions' dtype and device.
        # unsqueeze(1) adds a length-1 dimension; assumes time_encoder returns (batch, dim) — TODO confirm.
        timesteps = torch.Tensor([t]).repeat(batch_size).to(device_id)
        diffusion_timestep_embeddings = (
            action_head.module.time_encoder(timesteps).to(curr_noisy_actions.dtype).to(curr_noisy_actions.device)
        )
        diffusion_timestep_embeddings = diffusion_timestep_embeddings.unsqueeze(1)

        # Full VLA forward pass conditioned on the current noisy actions and timestep embedding.
        with torch.autocast("cuda", dtype=torch.bfloat16):
            output = vla(
                input_ids=batch["input_ids"].to(device_id),
                attention_mask=batch["attention_mask"].to(device_id),
                pixel_values=batch["pixel_values"].to(torch.bfloat16).to(device_id),
                labels=batch["labels"],
                output_hidden_states=True,
                proprio=batch["proprio"] if use_proprio else None,
                proprio_projector=proprio_projector if use_proprio else None,
                noisy_actions=curr_noisy_actions,
                noisy_action_projector=noisy_action_projector,
                diffusion_timestep_embeddings=diffusion_timestep_embeddings,
                use_film=use_film,
            )

        # Final-layer hidden states.
        last_hidden_states = output.hidden_states[-1]

        # Skip the leading `num_patches` positions and the last position (same slicing as in run_forward_pass).
        text_hidden_states = last_hidden_states[:, num_patches:-1]

        # Gather the action-token positions: NUM_ACTIONS_CHUNK * ACTION_DIM per sample.
        actions_hidden_states = text_hidden_states[current_action_mask | next_actions_mask].reshape(
            batch_size, NUM_ACTIONS_CHUNK * ACTION_DIM, -1
        )
        actions_hidden_states = actions_hidden_states.to(torch.bfloat16)

        # Predict the noise present in the current sample.
        noise_pred = action_head.module.predict_noise(actions_hidden_states)

        # One reverse-diffusion step; `prev_sample` is the less-noisy sample for the next timestep.
        curr_noisy_actions = action_head.module.noise_scheduler.step(noise_pred, t, curr_noisy_actions).prev_sample

    return curr_noisy_actions.reshape(actions_shape)
|
|
|
|
|
def compute_smoothened_metrics(metrics_deques) -> dict:
    """
    Average each metric's recent window of values.

    Args:
        metrics_deques (dict): Mapping from metric name to a container (typically a bounded deque) of recent values.

    Returns:
        dict: Mapping from metric name to the arithmetic mean of its window. Metrics whose window is empty are omitted.
    """
    return {
        metric_name: sum(window) / len(window)
        for metric_name, window in metrics_deques.items()
        if window
    }
|
|
|
|
|
def log_metrics_to_wandb(metrics, prefix, step, wandb_entity) -> None:
    """
    Send a dict of metrics to Weights & Biases under a common prefix.

    Key naming: "loss_value" becomes "<prefix>/Loss"; any other key has underscores replaced by spaces and is
    title-cased, e.g. "curr_action_l1_loss" becomes "<prefix>/Curr Action L1 Loss".

    Args:
        metrics (dict): Dictionary of metrics to log.
        prefix (str): Prefix for metric names.
        step (int): Training step passed as `step=` to `.log`.
        wandb_entity: Object exposing `.log(dict, step=...)` (the `wandb` module in this file).

    Returns:
        None.
    """

    def _display_name(metric_name):
        if metric_name == "loss_value":
            return "Loss"
        return metric_name.replace("_", " ").title()

    payload = {f"{prefix}/{_display_name(key)}": value for key, value in metrics.items()}
    wandb_entity.log(payload, step=step)
|
|
|
|
|
def save_training_checkpoint(
    cfg,
    run_dir,
    log_step,
    vla,
    processor,
    proprio_projector,
    noisy_action_projector,
    action_head,
    train_dataset,
    distributed_state,
) -> None:
    """
    Save all training checkpoints including model components, LoRA adapter, and dataset statistics.

    Layout:
      * `cfg.save_latest_checkpoint_only=True`: files go directly into `run_dir` with the suffix
        "latest_checkpoint.pt", overwriting the previous save.
      * otherwise: files go into a new sibling directory `<run_dir>--<log_step>_chkpt` with the suffix
        "<log_step>_checkpoint.pt".
    The LoRA adapter is saved to `<checkpoint_dir>/lora_adapter`. Optional component state dicts are saved when the
    corresponding feature is enabled. When `cfg.merge_lora_during_training` is set, the adapter is merged into a
    freshly loaded base model, which is then saved into the checkpoint directory.

    All file I/O and the LoRA merge happen on the main process only. Every rank must call this function, because it
    contains collective `dist.barrier()` calls.

    Args:
        cfg (FinetuneConfig): Training configuration.
        run_dir (Path): Experiment run directory path.
        log_step (int): Current logging step.
        vla (OpenVLAForActionPrediction): DDP-wrapped vision-language-action policy (accessed via `.module`).
        processor (PrismaticProcessor): OpenVLA inputs processor.
        proprio_projector (nn.Module): Proprioceptive state projector module.
        noisy_action_projector (nn.Module): Noisy action projector module (only used for diffusion).
        action_head (nn.Module): Action head module.
        train_dataset (RLDSDataset): Training dataset (its `dataset_statistics` are saved).
        distributed_state (PartialState): Distributed training state.

    Returns:
        None.
    """
    # Choose between overwriting a single "latest" checkpoint and writing one directory per step.
    if cfg.save_latest_checkpoint_only:
        checkpoint_dir = run_dir
        checkpoint_name_suffix = "latest_checkpoint.pt"
    else:
        checkpoint_dir = Path(str(run_dir) + f"--{log_step}_chkpt")
        checkpoint_name_suffix = f"{log_step}_checkpoint.pt"

    adapter_dir = checkpoint_dir / "lora_adapter"

    # Create directories and save dataset statistics (main process only).
    if distributed_state.is_main_process:
        os.makedirs(checkpoint_dir, exist_ok=True)
        os.makedirs(adapter_dir, exist_ok=True)
        save_dataset_statistics(train_dataset.dataset_statistics, checkpoint_dir)
        print(f"Saving Model Checkpoint for Step {log_step}")

    # Make sure the directories exist before any rank proceeds.
    dist.barrier()

    # Save processor, LoRA adapter, and the enabled auxiliary components (main process only).
    if distributed_state.is_main_process:
        processor.save_pretrained(checkpoint_dir)
        vla.module.save_pretrained(adapter_dir)

        # State dicts are taken from the DDP wrappers, so keys carry a "module." prefix that
        # load_checkpoint() strips again.
        if cfg.use_proprio and proprio_projector is not None:
            torch.save(proprio_projector.state_dict(), checkpoint_dir / f"proprio_projector--{checkpoint_name_suffix}")

        if cfg.use_diffusion and noisy_action_projector is not None:
            torch.save(
                noisy_action_projector.state_dict(), checkpoint_dir / f"noisy_action_projector--{checkpoint_name_suffix}"
            )

        if (cfg.use_l1_regression or cfg.use_diffusion) and action_head is not None:
            torch.save(action_head.state_dict(), checkpoint_dir / f"action_head--{checkpoint_name_suffix}")

        if cfg.use_film:
            # The FiLM-wrapped vision backbone has extra trainable parameters, so it is saved separately.
            torch.save(
                vla.module.vision_backbone.state_dict(), checkpoint_dir / f"vision_backbone--{checkpoint_name_suffix}"
            )

    # Wait until the adapter is on disk; the merge below reads it back.
    dist.barrier()

    # Merge the LoRA adapter into a freshly loaded base model and save the merged model.
    # Only the main process loads and merges: every rank loading its own full-size base model in CPU memory
    # would multiply peak host RAM by the number of processes while only the main process saves the result.
    # Note: this can be slow; other ranks wait in the barrier below, so very slow merges may hit the
    # process-group timeout. If so, disable `merge_lora_during_training` and merge offline.
    if cfg.use_lora and cfg.merge_lora_during_training:
        if distributed_state.is_main_process:
            base_vla = AutoModelForVision2Seq.from_pretrained(
                cfg.vla_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, trust_remote_code=True
            )
            merged_vla = PeftModel.from_pretrained(base_vla, adapter_dir)
            merged_vla = merged_vla.merge_and_unload()
            merged_vla.save_pretrained(checkpoint_dir)
            print(f"Saved merged model for Step {log_step} at: {checkpoint_dir}")
            # Drop references to the full-size model promptly instead of keeping it alive until the next save.
            del merged_vla, base_vla

        # Wait for the merged model to be saved.
        dist.barrier()
|
|
|
|
|
def run_validation(
    vla,
    action_head,
    noisy_action_projector,
    proprio_projector,
    val_dataloader,
    action_tokenizer,
    device_id,
    cfg,
    num_patches,
    log_step,
    distributed_state,
    val_time_limit,
) -> None:
    """
    Compute validation set metrics for logging.

    Puts `vla` in eval mode and runs gradient-free forward passes over `val_dataloader`. It stops when the loader is
    exhausted or when more than `val_time_limit` seconds have elapsed (checked after each batch, so the last batch
    may overrun). Each metric is averaged over the batches processed on this process, and the averages (plus the
    batch count) are logged to W&B from the main process only. For diffusion, full reverse-diffusion sampling runs
    on every validation batch so that L1 action errors are available.

    Notes:
        * Metrics are not reduced across processes.
        * `vla` is left in eval mode; the caller switches it back with `vla.train()`.
        * If the loader yields no batches, nothing is logged.

    Args:
        vla (OpenVLAForActionPrediction): Vision-language-action policy.
        action_head (nn.Module): Action head module.
        noisy_action_projector (nn.Module): Noisy action projector module (only used for diffusion).
        proprio_projector (nn.Module): Proprioceptive state projector module.
        val_dataloader (DataLoader): Validation data loader.
        action_tokenizer (ActionTokenizer): Action tokenizer.
        device_id (str): Device ID.
        cfg (FinetuneConfig): Training configuration.
        num_patches (int): Number of vision patches.
        log_step (int): Current logging step.
        distributed_state (PartialState): Distributed training state.
        val_time_limit (int): Time limit for computing validation metrics, in seconds.

    Returns:
        None.
    """
    val_start_time = time.time()
    vla.eval()
    val_batches_count = 0

    # Per-batch metric dicts, averaged after the loop.
    all_val_metrics = []

    with torch.no_grad():
        for batch in val_dataloader:
            _, metrics = run_forward_pass(
                vla=vla,
                action_head=action_head,
                noisy_action_projector=noisy_action_projector,
                proprio_projector=proprio_projector,
                batch=batch,
                action_tokenizer=action_tokenizer,
                device_id=device_id,
                use_l1_regression=cfg.use_l1_regression,
                use_diffusion=cfg.use_diffusion,
                use_proprio=cfg.use_proprio,
                use_film=cfg.use_film,
                num_patches=num_patches,
                compute_diffusion_l1=True,
                num_diffusion_steps_train=cfg.num_diffusion_steps_train if cfg.use_diffusion else None,
            )

            # Alias under "loss" as well. log_metrics_to_wandb() maps both "loss" and "loss_value" to "<prefix>/Loss".
            metrics["loss"] = metrics["loss_value"]
            all_val_metrics.append(metrics)
            val_batches_count += 1

            # Time budget check (after the batch, so the final batch may exceed the limit).
            if time.time() - val_start_time > val_time_limit:
                break

    # Guard against an empty loader: there is nothing to average, and indexing [0] would raise IndexError.
    if not all_val_metrics:
        if distributed_state.is_main_process:
            print(f"Validation at step {log_step} produced no batches; skipping validation logging.")
        return

    # Average each metric over the batches that reported it; the keys come from the first batch.
    avg_val_metrics = {}
    for metric_name in all_val_metrics[0].keys():
        values = [batch_metrics[metric_name] for batch_metrics in all_val_metrics if metric_name in batch_metrics]
        if values:
            avg_val_metrics[metric_name] = sum(values) / len(values)

    avg_val_metrics["val_batches_count"] = val_batches_count

    if distributed_state.is_main_process:
        log_metrics_to_wandb(avg_val_metrics, "VLA Val", log_step, wandb)
|
|
|
|
|
@draccus.wrap()
def finetune(cfg: FinetuneConfig) -> None:
    """
    Fine-tunes base VLA on demonstration dataset via LoRA.

    Allows toggling different action representations (discrete vs. continuous), different learning objectives
    (next-token prediction vs. L1 regression vs. diffusion), FiLM. Also allows for additional model inputs,
    such as additional camera images and robot proprioceptive state. Assumes parallel action generation with
    action chunking.

    Training-loop conventions:
      * One gradient step consists of `cfg.grad_accumulation_steps` micro-batches. `log_step` is the gradient-step
        index, offset by `cfg.resume_step` when resuming.
      * Checkpointing, validation and the `max_steps` stop check run once per gradient step, on the micro-batch that
        triggers `optimizer.step()`. With the default `grad_accumulation_steps=1`, that is every batch.
      * LR warmup (if enabled) only overrides the learning rate during the first `cfg.lr_warmup_steps` gradient
        steps, so it does not undo the MultiStepLR decay at `cfg.num_steps_before_decay` afterwards.

    Args:
        cfg (FinetuneConfig): Training configuration.

    Returns:
        None.
    """
    assert cfg.use_lora, "Only LoRA fine-tuning is supported. Please set --use_lora=True!"
    assert not (cfg.use_l1_regression and cfg.use_diffusion), (
        "Cannot do both L1 regression and diffusion. Please pick one of them!"
    )

    # Strip a trailing slash so path splitting (e.g. in get_run_id) sees the last real component.
    cfg.vla_path = cfg.vla_path.rstrip("/")
    print(f"Fine-tuning OpenVLA Model `{cfg.vla_path}` on `{cfg.dataset_name}`")

    run_id = get_run_id(cfg)

    # Run directory for dataset statistics and (by default) checkpoints.
    run_dir = cfg.run_root_dir / run_id
    os.makedirs(run_dir, exist_ok=True)

    # Distributed setup: one process per local GPU.
    distributed_state = PartialState()
    device_id = distributed_state.local_process_index
    torch.cuda.set_device(device_id)
    torch.cuda.empty_cache()

    # W&B logging happens from the main process only.
    if distributed_state.is_main_process:
        wandb.init(entity=cfg.wandb_entity, project=cfg.wandb_project, name=run_id)

    print(
        "Detected constants:\n"
        f"\tNUM_ACTIONS_CHUNK: {NUM_ACTIONS_CHUNK}\n"
        f"\tACTION_DIM: {ACTION_DIM}\n"
        f"\tPROPRIO_DIM: {PROPRIO_DIM}\n"
        f"\tACTION_PROPRIO_NORMALIZATION_TYPE: {ACTION_PROPRIO_NORMALIZATION_TYPE}"
    )

    # Either download the base model from the Hugging Face Hub (and point vla_path at the download), or register
    # the OpenVLA classes with the HF Auto classes for a locally stored model.
    if model_is_on_hf_hub(cfg.vla_path):
        vla_download_path = snapshot_download(repo_id=cfg.vla_path)
        cfg.vla_path = vla_download_path
    else:
        AutoConfig.register("openvla", OpenVLAConfig)
        AutoImageProcessor.register(OpenVLAConfig, PrismaticImageProcessor)
        AutoProcessor.register(OpenVLAConfig, PrismaticProcessor)
        AutoModelForVision2Seq.register(OpenVLAConfig, OpenVLAForActionPrediction)

    # Update config.json and check the checkpoint's modeling code (main process only) ...
    if distributed_state.is_main_process:
        update_auto_map(cfg.vla_path)
        check_model_logic_mismatch(cfg.vla_path)

    # ... and make the other ranks wait until that is done before loading.
    dist.barrier()

    # Load processor and VLA in bf16 onto this process's device.
    processor = AutoProcessor.from_pretrained(cfg.vla_path, trust_remote_code=True)
    vla = AutoModelForVision2Seq.from_pretrained(
        cfg.vla_path,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    ).to(device_id)

    vla.vision_backbone.set_num_images_in_input(cfg.num_images_in_input)

    # LoRA adapters on all linear layers.
    if cfg.use_lora:
        lora_config = LoraConfig(
            r=cfg.lora_rank,
            lora_alpha=min(cfg.lora_rank, 16),
            lora_dropout=cfg.lora_dropout,
            target_modules="all-linear",
            init_lora_weights="gaussian",
        )
        vla = get_peft_model(vla, lora_config)
        vla.print_trainable_parameters()

    # FiLM: wrap the vision backbone. The assignment goes through `vla.model` (the model inside the PEFT wrapper),
    # presumably so the wrapper replaces the inner backbone instead of adding a new attribute on the PEFT wrapper.
    if cfg.use_film:
        count_parameters(vla.vision_backbone, "vla.vision_backbone (original)")
        vla.model.vision_backbone = FiLMedPrismaticVisionBackbone(
            vision_backbone=vla.model.vision_backbone,
            llm_dim=vla.llm_dim,
        )
        count_parameters(vla.vision_backbone, "vla.vision_backbone (post-wrap)")
        if cfg.resume:
            state_dict = load_checkpoint("vision_backbone", cfg.vla_path, cfg.resume_step)
            vla.model.vision_backbone.load_state_dict(state_dict)
        vla.model.vision_backbone = vla.model.vision_backbone.to(device_id)

    vla = wrap_ddp(vla, device_id, find_unused=True)

    # Optional components default to None. In discrete-token mode (no L1 regression, no diffusion), `action_head`
    # would otherwise be unbound when passed to run_forward_pass()/run_validation() below (UnboundLocalError).
    proprio_projector = None
    action_head = None
    noisy_action_projector = None

    if cfg.use_proprio:
        proprio_projector = init_module(
            ProprioProjector,
            "proprio_projector",
            cfg,
            device_id,
            {"llm_dim": vla.module.llm_dim, "proprio_dim": PROPRIO_DIM},
        )

    if cfg.use_l1_regression:
        action_head = init_module(
            L1RegressionActionHead,
            "action_head",
            cfg,
            device_id,
            {"input_dim": vla.module.llm_dim, "hidden_dim": vla.module.llm_dim, "action_dim": ACTION_DIM},
            to_bf16=True,
        )

    if cfg.use_diffusion:
        action_head = init_module(
            DiffusionActionHead,
            "action_head",
            cfg,
            device_id,
            {
                "input_dim": vla.module.llm_dim,
                "hidden_dim": vla.module.llm_dim,
                "action_dim": ACTION_DIM,
                "num_diffusion_steps_train": cfg.num_diffusion_steps_train,
            },
            to_bf16=True,
        )
        noisy_action_projector = init_module(
            NoisyActionProjector, "noisy_action_projector", cfg, device_id, {"llm_dim": vla.module.llm_dim}
        )

    # Number of leading sequence positions before the text tokens: vision patches for every input image, plus one
    # extra position each for proprio and for the diffusion timestep embedding when enabled.
    NUM_PATCHES = vla.module.vision_backbone.get_num_patches() * vla.module.vision_backbone.get_num_images_in_input()
    if cfg.use_proprio:
        NUM_PATCHES += 1
    if cfg.use_diffusion:
        NUM_PATCHES += 1

    # Optimizer over every trainable parameter of the VLA and the enabled auxiliary modules.
    trainable_params = [param for param in vla.parameters() if param.requires_grad]
    if cfg.use_l1_regression or cfg.use_diffusion:
        trainable_params += [param for param in action_head.parameters() if param.requires_grad]
    if cfg.use_diffusion:
        trainable_params += [param for param in noisy_action_projector.parameters() if param.requires_grad]
    if cfg.use_proprio:
        trainable_params += [param for param in proprio_projector.parameters() if param.requires_grad]
    print(f"# total trainable params: {sum(p.numel() for p in trainable_params)}")
    optimizer = AdamW(trainable_params, lr=cfg.learning_rate)

    # Base LR, used as the target of the optional linear warmup.
    original_lr = optimizer.param_groups[0]["lr"]

    # Step decay: the LR is multiplied by 0.1 once the scheduler reaches `num_steps_before_decay`.
    scheduler = MultiStepLR(
        optimizer,
        milestones=[cfg.num_steps_before_decay],
        gamma=0.1,
    )

    action_tokenizer = ActionTokenizer(processor.tokenizer)

    # More than one input image means the wrist camera image is used as well.
    use_wrist_image = cfg.num_images_in_input > 1

    # Datasets.
    batch_transform = RLDSBatchTransform(
        action_tokenizer,
        processor.tokenizer,
        image_transform=processor.image_processor.apply_transform,
        prompt_builder_fn=PurePromptBuilder,
        use_wrist_image=use_wrist_image,
        use_proprio=cfg.use_proprio,
    )
    train_dataset = RLDSDataset(
        cfg.data_root_dir,
        cfg.dataset_name,
        batch_transform,
        resize_resolution=tuple(vla.module.config.image_sizes),
        shuffle_buffer_size=cfg.shuffle_buffer_size,
        image_aug=cfg.image_aug,
    )
    if cfg.use_val_set:
        val_dataset = RLDSDataset(
            cfg.data_root_dir,
            cfg.dataset_name,
            batch_transform,
            resize_resolution=tuple(vla.module.config.image_sizes),
            shuffle_buffer_size=cfg.shuffle_buffer_size // 10,
            image_aug=cfg.image_aug,
            train=False,
        )

    # Save dataset statistics in the run directory (main process only).
    if distributed_state.is_main_process:
        save_dataset_statistics(train_dataset.dataset_statistics, run_dir)

    # Data loaders (num_workers=0; batching and padding are handled by the collator).
    collator = PaddedCollatorForActionPrediction(
        processor.tokenizer.model_max_length, processor.tokenizer.pad_token_id, padding_side="right"
    )
    dataloader = DataLoader(
        train_dataset,
        batch_size=cfg.batch_size,
        sampler=None,
        collate_fn=collator,
        num_workers=0,
    )
    if cfg.use_val_set:
        val_batch_size = cfg.batch_size
        val_dataloader = DataLoader(
            val_dataset,
            batch_size=val_batch_size,
            sampler=None,
            collate_fn=collator,
            num_workers=0,
        )

    # Sliding windows (one gradient step's worth of micro-batches) used to smooth logged training metrics.
    recent_metrics = {
        "loss_value": deque(maxlen=cfg.grad_accumulation_steps),
        "curr_action_accuracy": deque(maxlen=cfg.grad_accumulation_steps),
        "curr_action_l1_loss": deque(maxlen=cfg.grad_accumulation_steps),
        "next_actions_accuracy": deque(maxlen=cfg.grad_accumulation_steps),
        "next_actions_l1_loss": deque(maxlen=cfg.grad_accumulation_steps),
    }

    with tqdm.tqdm(total=cfg.max_steps, leave=False) as progress:
        vla.train()
        optimizer.zero_grad()
        for batch_idx, batch in enumerate(dataloader):
            # With diffusion, full action sampling (for L1 logging) only runs every `diffusion_sample_freq` batches.
            compute_diffusion_l1 = cfg.use_diffusion and batch_idx % cfg.diffusion_sample_freq == 0
            loss, metrics = run_forward_pass(
                vla=vla,
                action_head=action_head,
                noisy_action_projector=noisy_action_projector,
                proprio_projector=proprio_projector,
                batch=batch,
                action_tokenizer=action_tokenizer,
                device_id=device_id,
                use_l1_regression=cfg.use_l1_regression,
                use_diffusion=cfg.use_diffusion,
                use_proprio=cfg.use_proprio,
                use_film=cfg.use_film,
                num_patches=NUM_PATCHES,
                compute_diffusion_l1=compute_diffusion_l1,
                num_diffusion_steps_train=cfg.num_diffusion_steps_train if cfg.use_diffusion else None,
            )

            # Scale so the gradients accumulated over the window average (rather than sum) the micro-batch losses.
            normalized_loss = loss / cfg.grad_accumulation_steps
            normalized_loss.backward()

            for metric_name, value in metrics.items():
                if metric_name in recent_metrics:
                    recent_metrics[metric_name].append(value)

            gradient_step_idx = batch_idx // cfg.grad_accumulation_steps
            # True on the last micro-batch of an accumulation window, i.e. when optimizer.step() runs below.
            is_optimizer_step = (batch_idx + 1) % cfg.grad_accumulation_steps == 0

            smoothened_metrics = compute_smoothened_metrics(recent_metrics)

            # Logged step includes the resume offset so resumed runs continue the previous step numbering.
            log_step = gradient_step_idx if not cfg.resume else cfg.resume_step + gradient_step_idx
            if distributed_state.is_main_process and log_step % cfg.wandb_log_freq == 0:
                log_metrics_to_wandb(smoothened_metrics, "VLA Train", log_step, wandb)

            # Linear warmup from 10% to 100% of the base LR. It is applied only inside the warmup window: overwriting
            # the LR on every later step would undo the MultiStepLR decay.
            if gradient_step_idx < cfg.lr_warmup_steps:
                lr_progress = (gradient_step_idx + 1) / cfg.lr_warmup_steps
                current_lr = original_lr * (0.1 + 0.9 * lr_progress)
                for param_group in optimizer.param_groups:
                    param_group["lr"] = current_lr

            # Log the LR the upcoming optimizer.step() will actually use (get_last_lr() lags warmup by one step).
            if distributed_state.is_main_process and log_step % cfg.wandb_log_freq == 0:
                wandb.log(
                    {
                        "VLA Train/Learning Rate": optimizer.param_groups[0]["lr"],
                    },
                    step=log_step,
                )

            if is_optimizer_step:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                progress.update()

            # Save, validate and check the stop condition once per gradient step. Without the `is_optimizer_step`
            # gate, gradient accumulation would trigger these on every micro-batch sharing the same log_step.
            if is_optimizer_step and gradient_step_idx > 0 and log_step % cfg.save_freq == 0:
                save_training_checkpoint(
                    cfg=cfg,
                    run_dir=run_dir,
                    log_step=log_step,
                    vla=vla,
                    processor=processor,
                    proprio_projector=proprio_projector,
                    noisy_action_projector=noisy_action_projector,
                    action_head=action_head,
                    train_dataset=train_dataset,
                    distributed_state=distributed_state,
                )

            if is_optimizer_step and cfg.use_val_set and log_step > 0 and log_step % cfg.val_freq == 0:
                run_validation(
                    vla=vla,
                    action_head=action_head,
                    noisy_action_projector=noisy_action_projector,
                    proprio_projector=proprio_projector,
                    val_dataloader=val_dataloader,
                    action_tokenizer=action_tokenizer,
                    device_id=device_id,
                    cfg=cfg,
                    num_patches=NUM_PATCHES,
                    log_step=log_step,
                    distributed_state=distributed_state,
                    val_time_limit=cfg.val_time_limit,
                )
                # run_validation() leaves the model in eval mode.
                vla.train()

            if is_optimizer_step and log_step == cfg.max_steps:
                print(f"Max step {cfg.max_steps} reached! Stopping training...")
                break
|
|
|
|
|
if __name__ == "__main__":
    # Called without arguments: `finetune` is wrapped by `draccus.wrap()`, which is expected to build the
    # FinetuneConfig (presumably from command-line arguments).
    finetune()
|
|
|