# ModernProteinLM / run_pretrain.sh
# (uploaded by GrimSqueaker with huggingface_hub, commit b388cd5, verified)
# NOTE(review): these header lines were web-page residue above the shebang;
# commented out so the file is valid shell. Invoke via `bash run_pretrain.sh`.
#!/bin/bash
# ============================================================================
# Pre-training script for ModernProteinLM on private GPU cluster
#
# Usage:
#   Single GPU: bash run_pretrain.sh
#   Multi-GPU:  NUM_GPUS=4 bash run_pretrain.sh   (the script invokes torchrun)
#   SLURM:      sbatch run_pretrain.sh
#
# Every setting is an environment variable with a default; override any of
# them at invocation time, e.g.  BATCH_SIZE=32 LR=1e-4 bash run_pretrain.sh
# ============================================================================
# Fail fast: abort on command errors, on unset variables (safe here — every
# variable used below carries a default), and on failures in any pipeline
# stage, not just the last one.
set -euo pipefail
# ----------------------------------------------------------------------------
# CONFIGURATION - ADJUST FOR YOUR CLUSTER
# Each `: "${VAR:=default}"` keeps an externally supplied value and only
# fills in the default when the variable is unset or empty.
# ----------------------------------------------------------------------------
# Data locations
: "${DATA_DIR:=./data}"
: "${UNIREF_PATH:=$DATA_DIR/uniref50.fasta}"
# Alternative: use HuggingFace datasets streaming (no local download needed)
: "${USE_STREAMING:=1}"
# Discriminator architecture
: "${HIDDEN_SIZE:=576}"
: "${NUM_LAYERS:=28}"
: "${NUM_HEADS:=9}"
: "${INTERMEDIATE_SIZE:=2304}"
: "${MAX_SEQ_LENGTH:=1024}"
# Generator architecture (sized to ~25% of the discriminator)
: "${GEN_HIDDEN_SIZE:=320}"
: "${GEN_NUM_LAYERS:=8}"
: "${GEN_NUM_HEADS:=8}"
: "${GEN_INTERMEDIATE:=1280}"
# Optimization hyperparameters
: "${BATCH_SIZE:=64}"        # per-device batch size
: "${MAX_STEPS:=100000}"
: "${WARMUP_STEPS:=10000}"
: "${LR:=5e-4}"
: "${WEIGHT_DECAY:=0.01}"
: "${GRAD_CLIP:=1.0}"
: "${GEN_WEIGHT:=1.0}"
: "${DISC_WEIGHT:=50.0}"
# Masking curriculum (mask rate anneals from MASK_START down to MASK_END)
: "${MASK_START:=0.30}"
: "${MASK_END:=0.05}"
: "${SPAN_LENGTH:=3}"
# System / logging
: "${OUTPUT_DIR:=./outputs/pretrain}"
: "${NUM_WORKERS:=8}"
: "${LOG_INTERVAL:=100}"
: "${EVAL_INTERVAL:=5000}"
: "${SAVE_INTERVAL:=5000}"
: "${NUM_GPUS:=1}"
: "${MASTER_PORT:=29500}"
# Precision toggles (1 = enabled)
: "${USE_AMP:=1}"            # Automatic Mixed Precision (bf16)
: "${USE_FLASH_ATTN:=1}"     # FlashAttention (pip install flash-attn)
# Checkpointing
: "${RESUME_FROM:=}"
: "${GRADIENT_CHECKPOINTING:=0}"
# Experiment tracking
: "${USE_TRACKIO:=0}"
: "${TRACKIO_PROJECT:=modern-protein-lm}"
: "${TRACKIO_SPACE_ID:=}"
# ----------------------------------------------------------------------------
# DERIVED SETTINGS
# ----------------------------------------------------------------------------
# Effective (global) batch size across all data-parallel workers.
TOTAL_BS=$(( BATCH_SIZE * NUM_GPUS ))
# Print the run configuration banner in one printf call (one line per arg).
printf '%s\n' \
  "==========================================" \
  "ModernProteinLM Pre-training Configuration" \
  "==========================================" \
  "GPUs: $NUM_GPUS" \
  "Per-device BS: $BATCH_SIZE" \
  "Total batch size: $TOTAL_BS" \
  "Max steps: $MAX_STEPS" \
  "Learning rate: $LR" \
  "Output dir: $OUTPUT_DIR" \
  "FlashAttention: $USE_FLASH_ATTN" \
  "AMP: $USE_AMP" \
  "=========================================="
mkdir -p "$OUTPUT_DIR"
# ----------------------------------------------------------------------------
# LAUNCH
# ----------------------------------------------------------------------------
# Assemble the trainer's argument vector as a bash array so that values
# containing spaces survive word-splitting when expanded with "${...[@]}".
PYTHON_ARGS=(
  train_pretrain.py
  --output_dir "$OUTPUT_DIR"
  --hidden_size "$HIDDEN_SIZE"
  --num_layers "$NUM_LAYERS"
  --num_heads "$NUM_HEADS"
  --intermediate_size "$INTERMEDIATE_SIZE"
  --gen_hidden_size "$GEN_HIDDEN_SIZE"
  --gen_num_layers "$GEN_NUM_LAYERS"
  --gen_num_heads "$GEN_NUM_HEADS"
  --gen_intermediate_size "$GEN_INTERMEDIATE"
  --max_seq_length "$MAX_SEQ_LENGTH"
  --batch_size "$BATCH_SIZE"
  --max_steps "$MAX_STEPS"
  --warmup_steps "$WARMUP_STEPS"
  --lr "$LR"
  --weight_decay "$WEIGHT_DECAY"
  --grad_clip "$GRAD_CLIP"
  --gen_weight "$GEN_WEIGHT"
  --disc_weight "$DISC_WEIGHT"
  --mask_start "$MASK_START"
  --mask_end "$MASK_END"
  --span_length "$SPAN_LENGTH"
  --num_workers "$NUM_WORKERS"
  --log_interval "$LOG_INTERVAL"
  --eval_interval "$EVAL_INTERVAL"
  --save_interval "$SAVE_INTERVAL"
)
# Optional boolean flags: appended only when the corresponding toggle is "1"
# (a no-match case arm is a no-op and returns success).
case "$USE_STREAMING" in 1) PYTHON_ARGS+=(--use_streaming) ;; esac
case "$USE_AMP" in 1) PYTHON_ARGS+=(--use_amp) ;; esac
case "$USE_FLASH_ATTN" in 1) PYTHON_ARGS+=(--use_flash_attn) ;; esac
# Resume only when a checkpoint path was supplied (?* = any non-empty string).
case "$RESUME_FROM" in ?*) PYTHON_ARGS+=(--resume_from "$RESUME_FROM") ;; esac
case "$GRADIENT_CHECKPOINTING" in 1) PYTHON_ARGS+=(--gradient_checkpointing) ;; esac
# Tracking: project flag always accompanies --use_trackio; space id is optional.
case "$USE_TRACKIO" in
  1)
    PYTHON_ARGS+=(--use_trackio --trackio_project "$TRACKIO_PROJECT")
    case "$TRACKIO_SPACE_ID" in
      ?*) PYTHON_ARGS+=(--trackio_space_id "$TRACKIO_SPACE_ID") ;;
    esac
    ;;
esac
# Launcher selection: distributed (torchrun DDP) when torchrun is on PATH and
# more than one GPU was requested; otherwise a plain single-process run.
if command -v torchrun > /dev/null 2>&1 && [[ "$NUM_GPUS" -gt 1 ]]; then
  echo "Launching with torchrun (DDP) on $NUM_GPUS GPUs..."
  torchrun --standalone --nnodes=1 \
    --nproc_per_node="$NUM_GPUS" --master_port="$MASTER_PORT" \
    "${PYTHON_ARGS[@]}"
else
  echo "Launching single-process training..."
  python "${PYTHON_ARGS[@]}"
fi
echo "Pre-training complete. Checkpoint saved to $OUTPUT_DIR"