#!/bin/bash
#SBATCH --account=vonneumann1
#SBATCH --partition=vonneumann
#SBATCH --gpus=1
#SBATCH --nodes=1
#SBATCH --time=8:00:00
#SBATCH --job-name=libero_train
#SBATCH --output=logs/train_%j.log
#SBATCH --error=logs/train_%j.err
#
# LIBERO co-training launcher for the starVLA framework.
#
# Usage:
#   sbatch examples/LIBERO/train_files/sbatch_libero_train.sh
#
# Override GPU count:
#   sbatch --gpus=4 examples/LIBERO/train_files/sbatch_libero_train.sh
#
# Required environment:
#   WANDB_API_KEY - Weights & Biases API key (must be exported before sbatch;
#                   it is deliberately NOT hardcoded here, because this script
#                   copies itself into the results directory).
#

# Fail on errors, unset variables, and failures anywhere in a pipeline.
set -euo pipefail

# === Conda setup ===
source /cm/shared/apps/Anaconda3/2023.09-0/etc/profile.d/conda.sh
conda activate starVLA

# === CUDA setup ===
# Use the first CUDA toolkit that actually ships an nvcc binary.
for cuda_path in /usr/local/cuda /usr/local/cuda-12 /usr/local/cuda-12.4; do
  if [ -x "${cuda_path}/bin/nvcc" ]; then
    export CUDA_HOME="${cuda_path}"
    export PATH="${cuda_path}/bin:${PATH}"
    export LD_LIBRARY_PATH="${cuda_path}/lib64:${LD_LIBRARY_PATH:-}"
    break
  fi
done

# nvcc wrapper fallback: if no real nvcc is available, install a stub that
# only answers `nvcc --version` with torch's CUDA version. This satisfies
# build steps that merely parse the version string.
if ! nvcc --version 2>&1 | grep -q "release"; then
  _WRAPPER_DIR="${CONDA_PREFIX}/cuda_compat/bin"
  mkdir -p "${_WRAPPER_DIR}" 2>/dev/null || true
  _TORCH_CUDA_VER=$(python -c "import torch; print(torch.version.cuda)" 2>/dev/null || echo "12.4")
  _MAJOR=$(echo "${_TORCH_CUDA_VER}" | cut -d. -f1)
  _MINOR=$(echo "${_TORCH_CUDA_VER}" | cut -d. -f2)
  cat > "${_WRAPPER_DIR}/nvcc" << NVCC_EOF
#!/bin/bash
echo "nvcc: NVIDIA (R) Cuda compiler driver"
echo "Cuda compilation tools, release ${_MAJOR}.${_MINOR}, V${_TORCH_CUDA_VER}"
NVCC_EOF
  chmod +x "${_WRAPPER_DIR}/nvcc"
  export PATH="${_WRAPPER_DIR}:${PATH}"
  export CUDA_HOME="${CONDA_PREFIX}/cuda_compat"
  echo "[INFO] Created nvcc wrapper: CUDA ${_TORCH_CUDA_VER}"
fi
# CUDA_HOME may still be unset if nvcc was already on PATH but none of the
# probed toolkit paths matched; guard for `set -u`.
echo "[INFO] CUDA_HOME=${CUDA_HOME:-<unset>}"
nvcc --version 2>/dev/null || echo "[WARN] nvcc not found"

# === NCCL ===
# Generous timeouts so transient rendezvous/collective stalls don't kill the job.
export NCCL_BLOCKING_WAIT=1
export NCCL_ASYNC_ERROR_HANDLING=1
export NCCL_TIMEOUT=10000
export NCCL_SOCKET_TIMEOUT_MS=360000

###########################################################################################
# === Training config ===
# NOTE: 'Projcets' and 'deepseeds' below are the literal on-disk names — do not "fix".
cd /home/jye624/Projcets/starVLA

Framework_name=CosmoPredict2GR00T
freeze_module_list=''
base_vlm=/home/jye624/Models/Pretrained_models/Qwen3-VL-4B-Instruct
config_yaml=./examples/LIBERO/train_files/starvla_cotrain_libero.yaml
libero_data_root=/home/jye624/Datasets/LIBERO
data_mix=libero_all
run_root_dir=./results/Checkpoints
run_id=0405_libero4in1_${Framework_name}
per_device_batch_size=8
###########################################################################################

# Security: require the key from the environment instead of a hardcoded fallback
# (the previous literal key was committed to source AND archived via `cp $0` below).
: "${WANDB_API_KEY:?WANDB_API_KEY must be exported before submitting this job}"
export WANDB_API_KEY

output_dir=${run_root_dir}/${run_id}
mkdir -p "${output_dir}" logs/
cp "$0" "${output_dir}/"  # archive this launch script next to the checkpoints

# Auto-detect GPU count from the SLURM allocation; fall back to counting devices.
num_processes=${SLURM_GPUS_ON_NODE:-$(nvidia-smi -L | wc -l)}
attn_implementation=sdpa
accelerate_config_file=starVLA/config/deepseeds/deepspeed_zero2.yaml
main_process_port=${MAIN_PROCESS_PORT:-29501}

echo "=============================="
echo "Job ID: ${SLURM_JOB_ID:-n/a}"
echo "Node: ${SLURM_NODELIST:-n/a}"
echo "GPUs: ${num_processes}"
echo "Batch/GPU: ${per_device_batch_size}"
echo "Framework: ${Framework_name}"
echo "Run ID: ${run_id}"
echo "=============================="

# Run the trainer under the project group so output files inherit its
# quota/permissions. Variables expand in THIS shell before `sg` runs the string.
# NOTE(review): freeze_module_list is empty, so --trainer.freeze_modules is
# passed with no value — verify the config parser tolerates that.
sg vonneumann1 -c " \
source /cm/shared/apps/Anaconda3/2023.09-0/etc/profile.d/conda.sh && \
conda activate starVLA && \
accelerate launch \
  --config_file ${accelerate_config_file} \
  --num_processes ${num_processes} \
  --main_process_port ${main_process_port} \
  starVLA/training/train_starvla.py \
  --config_yaml ${config_yaml} \
  --framework.name ${Framework_name} \
  --framework.qwenvl.base_vlm ${base_vlm} \
  --framework.action_model.future_action_window_size 7 \
  --framework.action_model.past_action_window_size 0 \
  --datasets.vla_data.data_root_dir ${libero_data_root} \
  --datasets.vla_data.data_mix ${data_mix} \
  --datasets.vla_data.per_device_batch_size ${per_device_batch_size} \
  --trainer.vla_data.video_backend torchvision_av \
  --framework.qwenvl.attn_implementation ${attn_implementation} \
  --trainer.freeze_modules ${freeze_module_list} \
  --trainer.max_train_steps 80000 \
  --trainer.save_interval 10000 \
  --trainer.logging_frequency 100 \
  --trainer.eval_interval 100 \
  --run_root_dir ${run_root_dir} \
  --run_id ${run_id} \
  --wandb_project starVLA_Libero \
  --wandb_entity jinhuiye \
"