File size: 5,244 Bytes
b388cd5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 | #!/bin/bash
# ============================================================================
# Pre-training script for ModernProteinLM on private GPU cluster
#
# Usage:
# Single GPU: bash run_pretrain.sh
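# Multi-GPU: NUM_GPUS=4 bash run_pretrain.sh (the script invokes torchrun itself; do not wrap it in torchrun)
# SLURM: sbatch run_pretrain.sh (request resources via sbatch flags or add #SBATCH directives; this script defines none)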
# ============================================================================
# Strict mode:
#   -e          abort on the first unhandled command failure
#   -u          treat expansion of an unset variable as an error (catches typos;
#               every setting below is given a default, so none is ever unset)
#   -o pipefail a pipeline fails if any stage fails, not only the last one
set -euo pipefail
# ----------------------------------------------------------------------------
# CONFIGURATION - ADJUST FOR YOUR CLUSTER
# ----------------------------------------------------------------------------
# Every setting below can be overridden from the environment, e.g.
#   BATCH_SIZE=32 NUM_GPUS=4 bash run_pretrain.sh
# A value that is unset or empty falls back to the default shown here.

# --- Data -------------------------------------------------------------------
: "${DATA_DIR:=./data}"
: "${UNIREF_PATH:=$DATA_DIR/uniref50.fasta}"
# NOTE(review): UNIREF_PATH is not referenced again in the visible part of this
# script - confirm how the training entry point finds local data when
# USE_STREAMING is not 1.
# Alternative to a local file: HuggingFace datasets streaming (no local download needed)
: "${USE_STREAMING:=1}"

# --- Model architecture -----------------------------------------------------
: "${HIDDEN_SIZE:=576}"
: "${NUM_LAYERS:=28}"
: "${NUM_HEADS:=9}"
: "${INTERMEDIATE_SIZE:=2304}"
: "${MAX_SEQ_LENGTH:=1024}"

# --- Generator (described by the author as 25% of the discriminator) --------
: "${GEN_HIDDEN_SIZE:=320}"
: "${GEN_NUM_LAYERS:=8}"
: "${GEN_NUM_HEADS:=8}"
: "${GEN_INTERMEDIATE:=1280}"

# --- Training hyperparameters -----------------------------------------------
: "${BATCH_SIZE:=64}" # Per-device batch size
: "${MAX_STEPS:=100000}"
: "${WARMUP_STEPS:=10000}"
: "${LR:=5e-4}"
: "${WEIGHT_DECAY:=0.01}"
: "${GRAD_CLIP:=1.0}"
: "${GEN_WEIGHT:=1.0}"
: "${DISC_WEIGHT:=50.0}"

# --- Masking curriculum -----------------------------------------------------
: "${MASK_START:=0.30}"
: "${MASK_END:=0.05}"
: "${SPAN_LENGTH:=3}"

# --- System -----------------------------------------------------------------
: "${OUTPUT_DIR:=./outputs/pretrain}"
: "${NUM_WORKERS:=8}"
: "${LOG_INTERVAL:=100}"
: "${EVAL_INTERVAL:=5000}"
: "${SAVE_INTERVAL:=5000}"
: "${NUM_GPUS:=1}"
: "${MASTER_PORT:=29500}"

# --- Precision --------------------------------------------------------------
: "${USE_AMP:=1}" # Automatic Mixed Precision (bf16)
: "${USE_FLASH_ATTN:=1}" # FlashAttention (pip install flash-attn)

# --- Checkpointing ----------------------------------------------------------
# Empty RESUME_FROM means "start from scratch" (no --resume_from is passed).
: "${RESUME_FROM:=}"
: "${GRADIENT_CHECKPOINTING:=0}"

# --- Tracking ---------------------------------------------------------------
: "${USE_TRACKIO:=0}"
: "${TRACKIO_PROJECT:=modern-protein-lm}"
: "${TRACKIO_SPACE_ID:=}"
# ----------------------------------------------------------------------------
# DERIVED SETTINGS
# ----------------------------------------------------------------------------
# Validate the two values used in arithmetic before touching them. Bash
# arithmetic expansion evaluates its operands as expressions, so a typo such as
# NUM_GPUS=4x would otherwise surface as a cryptic shell error (or be evaluated
# in a way the user did not intend). Leading zeros are rejected because bash
# would read them as octal.
int_re='^(0|[1-9][0-9]*)$'
for int_setting in BATCH_SIZE NUM_GPUS; do
  if [[ ! "${!int_setting}" =~ $int_re ]]; then
    echo "ERROR: $int_setting must be a non-negative integer, got '${!int_setting}'" >&2
    exit 1
  fi
done

# Global batch size = per-device batch size x number of GPUs.
TOTAL_BS=$(( BATCH_SIZE * NUM_GPUS ))

# Print the effective configuration so it ends up in the job log.
echo "=========================================="
echo "ModernProteinLM Pre-training Configuration"
echo "=========================================="
echo "GPUs: $NUM_GPUS"
echo "Per-device BS: $BATCH_SIZE"
echo "Total batch size: $TOTAL_BS"
echo "Max steps: $MAX_STEPS"
echo "Learning rate: $LR"
echo "Output dir: $OUTPUT_DIR"
echo "FlashAttention: $USE_FLASH_ATTN"
echo "AMP: $USE_AMP"
echo "=========================================="

# Create the output directory up front; "--" keeps a path that starts with "-"
# from being parsed as an option.
mkdir -p -- "$OUTPUT_DIR"
# ----------------------------------------------------------------------------
# LAUNCH
# ----------------------------------------------------------------------------
# Build the train_pretrain.py command line as an array so each value is passed
# as exactly one argument, whatever characters it contains.
PYTHON_ARGS=(train_pretrain.py --output_dir "$OUTPUT_DIR")

# Main model architecture.
PYTHON_ARGS+=(
  --hidden_size "$HIDDEN_SIZE"
  --num_layers "$NUM_LAYERS"
  --num_heads "$NUM_HEADS"
  --intermediate_size "$INTERMEDIATE_SIZE"
)

# Generator architecture.
PYTHON_ARGS+=(
  --gen_hidden_size "$GEN_HIDDEN_SIZE"
  --gen_num_layers "$GEN_NUM_LAYERS"
  --gen_num_heads "$GEN_NUM_HEADS"
  --gen_intermediate_size "$GEN_INTERMEDIATE"
)

# Sequence length, optimisation and loss weighting.
PYTHON_ARGS+=(
  --max_seq_length "$MAX_SEQ_LENGTH"
  --batch_size "$BATCH_SIZE"
  --max_steps "$MAX_STEPS"
  --warmup_steps "$WARMUP_STEPS"
  --lr "$LR"
  --weight_decay "$WEIGHT_DECAY"
  --grad_clip "$GRAD_CLIP"
  --gen_weight "$GEN_WEIGHT"
  --disc_weight "$DISC_WEIGHT"
)

# Masking curriculum.
PYTHON_ARGS+=(
  --mask_start "$MASK_START"
  --mask_end "$MASK_END"
  --span_length "$SPAN_LENGTH"
)

# Data loading and reporting cadence.
PYTHON_ARGS+=(
  --num_workers "$NUM_WORKERS"
  --log_interval "$LOG_INTERVAL"
  --eval_interval "$EVAL_INTERVAL"
  --save_interval "$SAVE_INTERVAL"
)

# Boolean switches: the flag is appended only when the setting is exactly "1".
case "$USE_STREAMING" in
  1) PYTHON_ARGS+=(--use_streaming) ;;
esac
case "$USE_AMP" in
  1) PYTHON_ARGS+=(--use_amp) ;;
esac
case "$USE_FLASH_ATTN" in
  1) PYTHON_ARGS+=(--use_flash_attn) ;;
esac

# Resume only when a checkpoint path was supplied.
[[ -z "$RESUME_FROM" ]] || PYTHON_ARGS+=(--resume_from "$RESUME_FROM")

case "$GRADIENT_CHECKPOINTING" in
  1) PYTHON_ARGS+=(--gradient_checkpointing) ;;
esac

# Trackio: the project name is always sent when tracking is on; the space id
# is sent only if one was configured.
case "$USE_TRACKIO" in
  1)
    PYTHON_ARGS+=(--use_trackio --trackio_project "$TRACKIO_PROJECT")
    [[ -z "$TRACKIO_SPACE_ID" ]] || PYTHON_ARGS+=(--trackio_space_id "$TRACKIO_SPACE_ID")
    ;;
esac
# Choose the launcher: torchrun (single-node DDP) when more than one GPU is
# requested and torchrun is installed, plain python otherwise. mpirun/srun are
# not handled by this script.
if [[ "$NUM_GPUS" -gt 1 ]] && command -v torchrun &> /dev/null; then
  echo "Launching with torchrun (DDP) on $NUM_GPUS GPUs..."
  # --standalone is deliberately not passed: per the torchrun documentation it
  # auto-assigns the rendezvous endpoint and ignores explicitly set values, so
  # MASTER_PORT would have no effect. Without it torchrun uses static
  # rendezvous on --master_port.
  torchrun \
    --nnodes=1 \
    --nproc_per_node="$NUM_GPUS" \
    --master_port="$MASTER_PORT" \
    "${PYTHON_ARGS[@]}"
else
  if [[ "$NUM_GPUS" -gt 1 ]]; then
    # Multi-GPU was requested but cannot be honoured; say so instead of
    # silently training on a single process.
    echo "WARNING: NUM_GPUS=$NUM_GPUS but torchrun was not found in PATH;" \
      "falling back to a single process." >&2
  fi
  echo "Launching single-process training..."
  python "${PYTHON_ARGS[@]}"
fi

# Reached only when the training command exited successfully (errexit is on).
echo "Pre-training complete. Checkpoint saved to $OUTPUT_DIR"
|