#!/bin/bash
#SBATCH --account=vonneumann1
#SBATCH --partition=vonneumann
#SBATCH --gpus=1
#SBATCH --nodes=1
#SBATCH --time=8:00:00
#SBATCH --job-name=libero_train
#SBATCH --output=logs/train_%j.log
#SBATCH --error=logs/train_%j.err
#
# SLURM launcher for starVLA LIBERO training.
#
# Usage:
#   sbatch examples/LIBERO/train_files/sbatch_libero_train.sh
#
# Override GPU count:
#   sbatch --gpus=4 examples/LIBERO/train_files/sbatch_libero_train.sh
#
# Strict mode: abort on command failure and on any failing stage of a
# pipeline (plain `set -e` ignores failures before the last pipe stage).
# `-u` is intentionally NOT enabled: the script legitimately reads
# possibly-unset variables such as $CUDA_HOME and SLURM_* below.
set -eo pipefail
|
|
# === Conda setup ===
# Fail loudly (instead of a cryptic `conda: command not found` later) if the
# cluster-wide Anaconda install or the target environment is missing.
CONDA_SH=/cm/shared/apps/Anaconda3/2023.09-0/etc/profile.d/conda.sh
if [ ! -r "${CONDA_SH}" ]; then
    echo "[ERROR] conda.sh not found at ${CONDA_SH}" >&2
    exit 1
fi
source "${CONDA_SH}"
conda activate starVLA || { echo "[ERROR] failed to activate conda env 'starVLA'" >&2; exit 1; }
|
|
# === CUDA setup ===
# Probe well-known toolkit locations and export the first one that ships a
# runnable nvcc; leave the environment untouched when none is found.
cuda_candidates=(/usr/local/cuda /usr/local/cuda-12 /usr/local/cuda-12.4)
for candidate in "${cuda_candidates[@]}"; do
    [ -x "${candidate}/bin/nvcc" ] || continue
    export CUDA_HOME="${candidate}"
    export PATH="${candidate}/bin:${PATH}"
    export LD_LIBRARY_PATH="${candidate}/lib64:${LD_LIBRARY_PATH:-}"
    break
done
|
|
# nvcc wrapper fallback
# Some nodes expose the CUDA runtime but no nvcc binary; build systems that
# only probe `nvcc --version` get a fake one reporting the CUDA version
# PyTorch was built against.

# Resolve the CUDA version to advertise: use torch's build version, falling
# back to 12.4 when torch is absent or is a CPU-only build (which prints the
# literal string "None" — previously this leaked into "release None.None").
_resolve_cuda_ver() {
    local ver="${1:-}"
    case "${ver}" in
        ""|None) ver="12.4" ;;
    esac
    printf '%s' "${ver}"
}

if ! nvcc --version 2>&1 | grep -q "release"; then
    if [ -z "${CONDA_PREFIX:-}" ]; then
        # Without an active conda env there is nowhere sane to install the
        # wrapper — warn instead of writing into '/cuda_compat'.
        echo "[WARN] CONDA_PREFIX unset; skipping nvcc wrapper" >&2
    else
        _WRAPPER_DIR="${CONDA_PREFIX}/cuda_compat/bin"
        mkdir -p "${_WRAPPER_DIR}" 2>/dev/null || true
        _TORCH_CUDA_VER=$(_resolve_cuda_ver "$(python -c "import torch; print(torch.version.cuda)" 2>/dev/null || true)")
        _MAJOR=$(echo "${_TORCH_CUDA_VER}" | cut -d. -f1)
        _MINOR=$(echo "${_TORCH_CUDA_VER}" | cut -d. -f2)
        # Unquoted delimiter on purpose: the version variables must expand
        # into the generated script.
        cat > "${_WRAPPER_DIR}/nvcc" << NVCC_EOF
#!/bin/bash
echo "nvcc: NVIDIA (R) Cuda compiler driver"
echo "Cuda compilation tools, release ${_MAJOR}.${_MINOR}, V${_TORCH_CUDA_VER}"
NVCC_EOF
        chmod +x "${_WRAPPER_DIR}/nvcc"
        export PATH="${_WRAPPER_DIR}:${PATH}"
        export CUDA_HOME="${CONDA_PREFIX}/cuda_compat"
        echo "[INFO] Created nvcc wrapper: CUDA ${_TORCH_CUDA_VER}"
    fi
fi
|
|
# Report which CUDA toolkit (real or wrapper) ended up selected.
echo "[INFO] CUDA_HOME=$CUDA_HOME"
nvcc --version 2>/dev/null || echo "[WARN] nvcc not found"


# === NCCL ===
# Generous timeouts plus blocking/async error handling so collective
# failures surface instead of hanging the job.
export NCCL_BLOCKING_WAIT=1 NCCL_ASYNC_ERROR_HANDLING=1
export NCCL_TIMEOUT=10000 NCCL_SOCKET_TIMEOUT_MS=360000
|
|
###########################################################################################
# === Training config ===
# NOTE(review): "Projcets" looks like a typo but presumably matches the real
# directory name on the cluster — verify before "fixing" the path.
cd /home/jye624/Projcets/starVLA || { echo "[ERROR] project root not found" >&2; exit 1; }


Framework_name=CosmoPredict2GR00T
freeze_module_list=''                 # empty → no modules frozen
base_vlm=/home/jye624/Models/Pretrained_models/Qwen3-VL-4B-Instruct
config_yaml=./examples/LIBERO/train_files/starvla_cotrain_libero.yaml
libero_data_root=/home/jye624/Datasets/LIBERO
data_mix=libero_all
run_root_dir=./results/Checkpoints
run_id=0405_libero4in1_${Framework_name}
per_device_batch_size=8
###########################################################################################
|
|
# W&B credentials must come from the environment. SECURITY: a real API key
# was previously hardcoded here as the fallback default — it has been removed;
# rotate that key, since it is exposed in version-control history.
if [ -z "${WANDB_API_KEY:-}" ]; then
    echo "[WARN] WANDB_API_KEY is not set; wandb logging may fail" >&2
fi
export WANDB_API_KEY="${WANDB_API_KEY:-}"
|
|
output_dir="${run_root_dir}/${run_id}"
mkdir -p "${output_dir}" logs/
# Snapshot the launch script next to the checkpoints for reproducibility.
cp -- "$0" "${output_dir}/"


# Auto-detect GPU count from the SLURM allocation; fall back to nvidia-smi
# for interactive runs, and to a single process if detection finds nothing
# (otherwise `accelerate --num_processes 0` would fail obscurely).
num_processes=${SLURM_GPUS_ON_NODE:-$(nvidia-smi -L 2>/dev/null | wc -l)}
if ! [ "${num_processes}" -ge 1 ] 2>/dev/null; then
    echo "[WARN] could not detect GPUs; defaulting to 1 process" >&2
    num_processes=1
fi
attn_implementation=sdpa
accelerate_config_file=starVLA/config/deepseeds/deepspeed_zero2.yaml
main_process_port=${MAIN_PROCESS_PORT:-29501}
|
|
# Print a run-summary banner into the job log.
cat <<BANNER
==============================
Job ID: ${SLURM_JOB_ID}
Node: ${SLURM_NODELIST}
GPUs: ${num_processes}
Batch/GPU: ${per_device_batch_size}
Framework: ${Framework_name}
Run ID: ${run_id}
==============================
BANNER
|
|
# Optional flag: omit --trainer.freeze_modules entirely when the list is
# empty — previously the bare flag was always emitted, and with an empty
# value the parser would swallow the next argument as its value.
freeze_flag=""
if [ -n "${freeze_module_list:-}" ]; then
    freeze_flag="--trainer.freeze_modules ${freeze_module_list}"
fi

# Build the command once so it can be inspected/logged, then run it under
# the vonneumann1 group (sg) so output files get the right group ownership.
launch_cmd="
source /cm/shared/apps/Anaconda3/2023.09-0/etc/profile.d/conda.sh && \
conda activate starVLA && \
accelerate launch \
--config_file ${accelerate_config_file} \
--num_processes ${num_processes} \
--main_process_port ${main_process_port} \
starVLA/training/train_starvla.py \
--config_yaml ${config_yaml} \
--framework.name ${Framework_name} \
--framework.qwenvl.base_vlm ${base_vlm} \
--framework.action_model.future_action_window_size 7 \
--framework.action_model.past_action_window_size 0 \
--datasets.vla_data.data_root_dir ${libero_data_root} \
--datasets.vla_data.data_mix ${data_mix} \
--datasets.vla_data.per_device_batch_size ${per_device_batch_size} \
--trainer.vla_data.video_backend torchvision_av \
--framework.qwenvl.attn_implementation ${attn_implementation} \
${freeze_flag} \
--trainer.max_train_steps 80000 \
--trainer.save_interval 10000 \
--trainer.logging_frequency 100 \
--trainer.eval_interval 100 \
--run_root_dir ${run_root_dir} \
--run_id ${run_id} \
--wandb_project starVLA_Libero \
--wandb_entity jinhuiye
"
sg vonneumann1 -c "${launch_cmd}"
|
|