#!/bin/bash
# ============================================================================
# Pre-training script for ModernProteinLM on private GPU cluster
#
# Usage:
#   Single GPU:  bash run_pretrain.sh
#   Multi-GPU:   NUM_GPUS=4 bash run_pretrain.sh   (the script invokes torchrun itself)
#   SLURM:       sbatch run_pretrain.sh            (see SLURM section below)
# ============================================================================
set -euo pipefail

# ----------------------------------------------------------------------------
# CONFIGURATION - ADJUST FOR YOUR CLUSTER
# ----------------------------------------------------------------------------

# Data
DATA_DIR="${DATA_DIR:-./data}"
UNIREF_PATH="${UNIREF_PATH:-$DATA_DIR/uniref50.fasta}"
# Alternative: use HuggingFace datasets streaming (no local download needed)
USE_STREAMING="${USE_STREAMING:-1}"

# Model architecture
HIDDEN_SIZE="${HIDDEN_SIZE:-576}"
NUM_LAYERS="${NUM_LAYERS:-28}"
NUM_HEADS="${NUM_HEADS:-9}"
INTERMEDIATE_SIZE="${INTERMEDIATE_SIZE:-2304}"
MAX_SEQ_LENGTH="${MAX_SEQ_LENGTH:-1024}"

# Generator (scaled down from the discriminator)
GEN_HIDDEN_SIZE="${GEN_HIDDEN_SIZE:-320}"
GEN_NUM_LAYERS="${GEN_NUM_LAYERS:-8}"
GEN_NUM_HEADS="${GEN_NUM_HEADS:-8}"
GEN_INTERMEDIATE="${GEN_INTERMEDIATE:-1280}"

# Training hyperparameters
BATCH_SIZE="${BATCH_SIZE:-64}"          # Per-device batch size
MAX_STEPS="${MAX_STEPS:-100000}"
WARMUP_STEPS="${WARMUP_STEPS:-10000}"
LR="${LR:-5e-4}"
WEIGHT_DECAY="${WEIGHT_DECAY:-0.01}"
GRAD_CLIP="${GRAD_CLIP:-1.0}"
GEN_WEIGHT="${GEN_WEIGHT:-1.0}"
DISC_WEIGHT="${DISC_WEIGHT:-50.0}"

# Masking curriculum
MASK_START="${MASK_START:-0.30}"
MASK_END="${MASK_END:-0.05}"
SPAN_LENGTH="${SPAN_LENGTH:-3}"

# System
OUTPUT_DIR="${OUTPUT_DIR:-./outputs/pretrain}"
NUM_WORKERS="${NUM_WORKERS:-8}"
LOG_INTERVAL="${LOG_INTERVAL:-100}"
EVAL_INTERVAL="${EVAL_INTERVAL:-5000}"
SAVE_INTERVAL="${SAVE_INTERVAL:-5000}"
NUM_GPUS="${NUM_GPUS:-1}"
MASTER_PORT="${MASTER_PORT:-29500}"

# Precision
USE_AMP="${USE_AMP:-1}"                 # Automatic Mixed Precision (bf16)
USE_FLASH_ATTN="${USE_FLASH_ATTN:-1}"   # FlashAttention (pip install flash-attn)

# Checkpointing
RESUME_FROM="${RESUME_FROM:-}"
GRADIENT_CHECKPOINTING="${GRADIENT_CHECKPOINTING:-0}"

# Tracking
USE_TRACKIO="${USE_TRACKIO:-0}"
TRACKIO_PROJECT="${TRACKIO_PROJECT:-modern-protein-lm}"
TRACKIO_SPACE_ID="${TRACKIO_SPACE_ID:-}"

# ----------------------------------------------------------------------------
# DERIVED SETTINGS
# ----------------------------------------------------------------------------
# Effective batch size per optimizer step (no gradient accumulation here).
TOTAL_BS=$(( BATCH_SIZE * NUM_GPUS ))

echo "=========================================="
echo "ModernProteinLM Pre-training Configuration"
echo "=========================================="
echo "GPUs:             $NUM_GPUS"
echo "Per-device BS:    $BATCH_SIZE"
echo "Total batch size: $TOTAL_BS"
echo "Max steps:        $MAX_STEPS"
echo "Learning rate:    $LR"
echo "Output dir:       $OUTPUT_DIR"
echo "FlashAttention:   $USE_FLASH_ATTN"
echo "AMP:              $USE_AMP"
echo "=========================================="

mkdir -p "$OUTPUT_DIR"
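
# ----------------------------------------------------------------------------
# SANITY CHECK (sketch)
# ----------------------------------------------------------------------------
# Suggested guard, not part of the launch flow proper: fail fast when
# streaming is disabled but the local FASTA is missing, rather than crashing
# in the dataloader minutes into the run. Uses only $USE_STREAMING,
# $UNIREF_PATH, and $DATA_DIR as defined in CONFIGURATION above.
if [[ "$USE_STREAMING" != "1" && ! -f "$UNIREF_PATH" ]]; then
    echo "ERROR: UniRef50 FASTA not found at $UNIREF_PATH" >&2
    echo "Download it to $DATA_DIR or set USE_STREAMING=1 to stream instead." >&2
    exit 1
fi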
# ----------------------------------------------------------------------------
# LAUNCH
# ----------------------------------------------------------------------------
PYTHON_ARGS=(
    train_pretrain.py
    --output_dir "$OUTPUT_DIR"
    --hidden_size "$HIDDEN_SIZE"
    --num_layers "$NUM_LAYERS"
    --num_heads "$NUM_HEADS"
    --intermediate_size "$INTERMEDIATE_SIZE"
    --gen_hidden_size "$GEN_HIDDEN_SIZE"
    --gen_num_layers "$GEN_NUM_LAYERS"
    --gen_num_heads "$GEN_NUM_HEADS"
    --gen_intermediate_size "$GEN_INTERMEDIATE"
    --max_seq_length "$MAX_SEQ_LENGTH"
    --batch_size "$BATCH_SIZE"
    --max_steps "$MAX_STEPS"
    --warmup_steps "$WARMUP_STEPS"
    --lr "$LR"
    --weight_decay "$WEIGHT_DECAY"
    --grad_clip "$GRAD_CLIP"
    --gen_weight "$GEN_WEIGHT"
    --disc_weight "$DISC_WEIGHT"
    --mask_start "$MASK_START"
    --mask_end "$MASK_END"
    --span_length "$SPAN_LENGTH"
    --num_workers "$NUM_WORKERS"
    --log_interval "$LOG_INTERVAL"
    --eval_interval "$EVAL_INTERVAL"
    --save_interval "$SAVE_INTERVAL"
)

if [[ "$USE_STREAMING" == "1" ]]; then
    PYTHON_ARGS+=(--use_streaming)
fi
if [[ "$USE_AMP" == "1" ]]; then
    PYTHON_ARGS+=(--use_amp)
fi
if [[ "$USE_FLASH_ATTN" == "1" ]]; then
    PYTHON_ARGS+=(--use_flash_attn)
fi
if [[ -n "$RESUME_FROM" ]]; then
    PYTHON_ARGS+=(--resume_from "$RESUME_FROM")
fi
if [[ "$GRADIENT_CHECKPOINTING" == "1" ]]; then
    PYTHON_ARGS+=(--gradient_checkpointing)
fi
if [[ "$USE_TRACKIO" == "1" ]]; then
    PYTHON_ARGS+=(--use_trackio --trackio_project "$TRACKIO_PROJECT")
    if [[ -n "$TRACKIO_SPACE_ID" ]]; then
        PYTHON_ARGS+=(--trackio_space_id "$TRACKIO_SPACE_ID")
    fi
fi

# Launch with torchrun (DDP) when more than one GPU is requested; otherwise
# run a plain single-process job. A static single-node rendezvous is used
# instead of --standalone, since --standalone would ignore --master_port.
if command -v torchrun &> /dev/null && [[ "$NUM_GPUS" -gt 1 ]]; then
    echo "Launching with torchrun (DDP) on $NUM_GPUS GPUs..."
    torchrun \
        --nnodes=1 \
        --nproc_per_node="$NUM_GPUS" \
        --master_port="$MASTER_PORT" \
        "${PYTHON_ARGS[@]}"
else
    echo "Launching single-process training..."
    python "${PYTHON_ARGS[@]}"
fi

echo "Pre-training complete. Checkpoints saved to $OUTPUT_DIR"
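
# ----------------------------------------------------------------------------
# SLURM (reference sketch)
# ----------------------------------------------------------------------------
# The usage header points here. SLURM only honors #SBATCH directives at the
# top of the submitted file, so copy lines like the following directly below
# the shebang before running `sbatch run_pretrain.sh`. Job name, GPU count,
# CPU count, and walltime are placeholders; adjust them for your cluster.
#
#   #SBATCH --job-name=protein-lm-pretrain
#   #SBATCH --nodes=1
#   #SBATCH --gres=gpu:4
#   #SBATCH --cpus-per-task=32
#   #SBATCH --time=48:00:00
#   #SBATCH --output=slurm-%j.out
#
# Match NUM_GPUS to the allocation at submit time, e.g.:
#
#   sbatch --export=ALL,NUM_GPUS=4 run_pretrain.sh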