#!/bin/bash
# SLURM batch script: single-node LIBERO training job for starVLA.
#SBATCH --account=vonneumann1
#SBATCH --partition=vonneumann
#SBATCH --gpus=1
#SBATCH --nodes=1
#SBATCH --time=8:00:00
#SBATCH --job-name=libero_train
# NOTE(review): %j expands to the SLURM job ID. The logs/ directory must exist
# at submission time — SLURM does not create it, and the `mkdir -p ... logs/`
# further down only runs after the job has started; confirm logs/ is committed
# or pre-created, otherwise stdout/stderr may be lost.
#SBATCH --output=logs/train_%j.log
#SBATCH --error=logs/train_%j.err
#
# Usage:
# sbatch examples/LIBERO/train_files/sbatch_libero_train.sh
#
# Override GPU count:
# sbatch --gpus=4 examples/LIBERO/train_files/sbatch_libero_train.sh
#
# Fail fast: -e abort on error, -u abort on unset variable, pipefail so a
# failing stage in a pipeline is not masked by a succeeding final stage.
# (Plain `set -e` alone misses both of the latter cases.)
set -euo pipefail

# === Conda setup ===
source /cm/shared/apps/Anaconda3/2023.09-0/etc/profile.d/conda.sh
conda activate starVLA

# === CUDA setup ===
# Pick the first candidate toolkit dir that has an executable nvcc.
# /usr/local/cuda is tried first — presumably a symlink to the preferred
# version on this cluster; TODO confirm the intended priority order.
for cuda_path in /usr/local/cuda /usr/local/cuda-12 /usr/local/cuda-12.4; do
    if [ -x "${cuda_path}/bin/nvcc" ]; then
        export CUDA_HOME="${cuda_path}"
        export PATH="${cuda_path}/bin:${PATH}"
        # ${LD_LIBRARY_PATH:-} keeps this safe under `set -u` when the var is unset.
        export LD_LIBRARY_PATH="${cuda_path}/lib64:${LD_LIBRARY_PATH:-}"
        break
    fi
done
# nvcc wrapper fallback: some build steps only probe `nvcc --version` to
# discover the CUDA release; if no real toolkit was found above, synthesize a
# stub nvcc that reports the CUDA version torch was built with.
if ! nvcc --version 2>&1 | grep -q "release"; then
    _WRAPPER_DIR="${CONDA_PREFIX}/cuda_compat/bin"
    mkdir -p "${_WRAPPER_DIR}" 2>/dev/null || true
    _TORCH_CUDA_VER=$(python -c "import torch; print(torch.version.cuda)" 2>/dev/null || echo "12.4")
    # torch.version.cuda is None for CPU-only torch builds — the original code
    # would then emit "release None.None"; fall back to 12.4 in that case too.
    if [ -z "${_TORCH_CUDA_VER}" ] || [ "${_TORCH_CUDA_VER}" = "None" ]; then
        _TORCH_CUDA_VER="12.4"
    fi
    _MAJOR=$(echo "${_TORCH_CUDA_VER}" | cut -d. -f1)
    _MINOR=$(echo "${_TORCH_CUDA_VER}" | cut -d. -f2)
    # Unquoted heredoc delimiter on purpose: the version variables are expanded
    # NOW, baking the detected version into the generated stub.
    cat > "${_WRAPPER_DIR}/nvcc" << NVCC_EOF
#!/bin/bash
echo "nvcc: NVIDIA (R) Cuda compiler driver"
echo "Cuda compilation tools, release ${_MAJOR}.${_MINOR}, V${_TORCH_CUDA_VER}"
NVCC_EOF
    chmod +x "${_WRAPPER_DIR}/nvcc"
    export PATH="${_WRAPPER_DIR}:${PATH}"
    export CUDA_HOME="${CONDA_PREFIX}/cuda_compat"
    echo "[INFO] Created nvcc wrapper: CUDA ${_TORCH_CUDA_VER}"
fi
# ${CUDA_HOME:-<unset>}: CUDA_HOME may legitimately still be unset if neither
# a toolkit nor the wrapper branch set it (and would crash under `set -u`).
echo "[INFO] CUDA_HOME=${CUDA_HOME:-<unset>}"
nvcc --version 2>/dev/null || echo "[WARN] nvcc not found"
# === NCCL ===
# Distributed-comms tuning: prefer surfacing collective errors over hanging.
# NOTE(review): NCCL_BLOCKING_WAIT and NCCL_ASYNC_ERROR_HANDLING are read by
# PyTorch's NCCL watchdog and are deprecated aliases of TORCH_NCCL_* in recent
# PyTorch releases — confirm the installed torch still honors these names.
export NCCL_BLOCKING_WAIT=1
export NCCL_ASYNC_ERROR_HANDLING=1
# Generous timeouts — presumably to survive slow first-step dataloader /
# checkpoint warmup; verify units (NCCL_TIMEOUT vs *_MS) against NCCL docs.
export NCCL_TIMEOUT=10000
export NCCL_SOCKET_TIMEOUT_MS=360000
###########################################################################################
# === Training config ===
cd /home/jye624/Projcets/starVLA
Framework_name=CosmoPredict2GR00T
freeze_module_list=''
base_vlm=/home/jye624/Models/Pretrained_models/Qwen3-VL-4B-Instruct
config_yaml=./examples/LIBERO/train_files/starvla_cotrain_libero.yaml
libero_data_root=/home/jye624/Datasets/LIBERO
data_mix=libero_all
run_root_dir=./results/Checkpoints
run_id=0405_libero4in1_${Framework_name}
per_device_batch_size=8
###########################################################################################
# SECURITY: never hardcode the wandb API key in the script (the previous
# revision committed one in plain text). Export WANDB_API_KEY before sbatch.
if [ -z "${WANDB_API_KEY:-}" ]; then
    echo "[WARN] WANDB_API_KEY is not set; wandb logging may fail" >&2
fi
output_dir=${run_root_dir}/${run_id}
mkdir -p "${output_dir}" logs/
# Keep a copy of the submitted script with the checkpoints for reproducibility.
cp "$0" "${output_dir}/"
# Auto-detect GPU count from SLURM allocation (fallback: count visible GPUs)
num_processes=${SLURM_GPUS_ON_NODE:-$(nvidia-smi -L | wc -l)}
attn_implementation=sdpa
accelerate_config_file=starVLA/config/deepseeds/deepspeed_zero2.yaml
main_process_port=${MAIN_PROCESS_PORT:-29501}
# Print a submission summary banner for the job log.
printf '%s\n' \
    "==============================" \
    "Job ID: ${SLURM_JOB_ID}" \
    "Node: ${SLURM_NODELIST}" \
    "GPUs: ${num_processes}" \
    "Batch/GPU: ${per_device_batch_size}" \
    "Framework: ${Framework_name}" \
    "Run ID: ${run_id}" \
    "=============================="
# Launch training under group 'vonneumann1' via sg so files inherit the right
# group ownership/quota. The whole double-quoted string is expanded by THIS
# shell at submit time — every ${var} below is baked into the command before
# the inner shell (which only re-sources conda) ever runs.
# NOTE(review): freeze_module_list='' expands to nothing, leaving
# --trainer.freeze_modules immediately followed by the next flag — confirm the
# trainer CLI parser tolerates a valueless flag, otherwise it may swallow
# --trainer.max_train_steps as the value.
sg vonneumann1 -c "
source /cm/shared/apps/Anaconda3/2023.09-0/etc/profile.d/conda.sh && \
conda activate starVLA && \
accelerate launch \
--config_file ${accelerate_config_file} \
--num_processes ${num_processes} \
--main_process_port ${main_process_port} \
starVLA/training/train_starvla.py \
--config_yaml ${config_yaml} \
--framework.name ${Framework_name} \
--framework.qwenvl.base_vlm ${base_vlm} \
--framework.action_model.future_action_window_size 7 \
--framework.action_model.past_action_window_size 0 \
--datasets.vla_data.data_root_dir ${libero_data_root} \
--datasets.vla_data.data_mix ${data_mix} \
--datasets.vla_data.per_device_batch_size ${per_device_batch_size} \
--trainer.vla_data.video_backend torchvision_av \
--framework.qwenvl.attn_implementation ${attn_implementation} \
--trainer.freeze_modules ${freeze_module_list} \
--trainer.max_train_steps 80000 \
--trainer.save_interval 10000 \
--trainer.logging_frequency 100 \
--trainer.eval_interval 100 \
--run_root_dir ${run_root_dir} \
--run_id ${run_id} \
--wandb_project starVLA_Libero \
--wandb_entity jinhuiye
"