#!/bin/bash
#
# Pre-training launcher for ModernProteinLM.
# Every setting below is overridable through an environment variable, e.g.:
#   NUM_GPUS=4 BATCH_SIZE=32 ./pretrain.sh
#
# Strict mode: -e aborts on unhandled errors, -u rejects unset variables
# (safe here because every config variable is given a default below), and
# -o pipefail makes a pipeline fail if any stage fails.
set -euo pipefail
|
|
| |
| |
| |
|
|
| |
# Data locations (overridable via environment).
DATA_DIR="${DATA_DIR:-./data}"
# NOTE(review): UNIREF_PATH is defined here but never forwarded to
# train_pretrain.py below -- confirm the training script reads it from the
# environment, or wire it in as an explicit data-path flag.
UNIREF_PATH="${UNIREF_PATH:-$DATA_DIR/uniref50.fasta}"

# When "1", the --use_streaming flag is passed to the trainer (stream the
# dataset instead of loading it up front).
USE_STREAMING="${USE_STREAMING:-1}"
|
|
| |
# --- Discriminator (main model) architecture ---------------------------------
# Each `: "${VAR:=default}"` assigns the default only when VAR is unset/empty.
: "${HIDDEN_SIZE:=576}"
: "${NUM_LAYERS:=28}"
: "${NUM_HEADS:=9}"
: "${INTERMEDIATE_SIZE:=2304}"
: "${MAX_SEQ_LENGTH:=1024}"

# --- Generator architecture ---------------------------------------------------
# Smaller companion model (presumably the ELECTRA-style token generator,
# given the gen/disc loss weights below -- confirm against train_pretrain.py).
: "${GEN_HIDDEN_SIZE:=320}"
: "${GEN_NUM_LAYERS:=8}"
: "${GEN_NUM_HEADS:=8}"
: "${GEN_INTERMEDIATE:=1280}"
|
|
| |
# --- Optimization schedule -----------------------------------------------------
# Defaults only; any of these can be preset in the environment.
: "${BATCH_SIZE:=64}"
: "${MAX_STEPS:=100000}"
: "${WARMUP_STEPS:=10000}"
: "${LR:=5e-4}"
: "${WEIGHT_DECAY:=0.01}"
: "${GRAD_CLIP:=1.0}"
: "${GEN_WEIGHT:=1.0}"
: "${DISC_WEIGHT:=50.0}"

# --- Masking schedule ----------------------------------------------------------
# Mask rate anneals from MASK_START to MASK_END; spans of SPAN_LENGTH tokens.
: "${MASK_START:=0.30}"
: "${MASK_END:=0.05}"
: "${SPAN_LENGTH:=3}"
|
|
| |
# Run / logging configuration.
OUTPUT_DIR="${OUTPUT_DIR:-./outputs/pretrain}"  # checkpoints and logs land here
NUM_WORKERS="${NUM_WORKERS:-8}"                 # forwarded as --num_workers
LOG_INTERVAL="${LOG_INTERVAL:-100}"             # steps between log lines
EVAL_INTERVAL="${EVAL_INTERVAL:-5000}"          # steps between evaluations
SAVE_INTERVAL="${SAVE_INTERVAL:-5000}"          # steps between checkpoints
NUM_GPUS="${NUM_GPUS:-1}"                       # >1 switches launch to torchrun
MASTER_PORT="${MASTER_PORT:-29500}"             # port used by the torchrun launch

# Performance toggles ("1" enables; forwarded as --use_amp / --use_flash_attn).
USE_AMP="${USE_AMP:-1}"
USE_FLASH_ATTN="${USE_FLASH_ATTN:-1}"

# Resume / memory options.
RESUME_FROM="${RESUME_FROM:-}"                  # checkpoint path; empty = fresh run
GRADIENT_CHECKPOINTING="${GRADIENT_CHECKPOINTING:-0}"  # "1" adds --gradient_checkpointing
|
|
| |
# --- Experiment tracking (trackio, opt-in) ------------------------------------
: "${USE_TRACKIO:=0}"
: "${TRACKIO_PROJECT:=modern-protein-lm}"
: "${TRACKIO_SPACE_ID:=}"
|
|
| |
| |
| |
|
|
# Effective global batch size across all data-parallel workers.
TOTAL_BS=$(( BATCH_SIZE * NUM_GPUS ))

# Print the resolved configuration in one here-doc instead of many echos.
rule='=========================================='
cat <<EOF
$rule
ModernProteinLM Pre-training Configuration
$rule
GPUs: $NUM_GPUS
Per-device BS: $BATCH_SIZE
Total batch size: $TOTAL_BS
Max steps: $MAX_STEPS
Learning rate: $LR
Output dir: $OUTPUT_DIR
FlashAttention: $USE_FLASH_ATTN
AMP: $USE_AMP
$rule
EOF

mkdir -p "$OUTPUT_DIR"
|
|
| |
| |
| |
|
|
# Assemble the training command as an array so values survive word-splitting.
# Order matches the trainer invocation exactly.
PYTHON_ARGS=(train_pretrain.py --output_dir "$OUTPUT_DIR")

# Discriminator architecture.
PYTHON_ARGS+=(
  --hidden_size "$HIDDEN_SIZE"
  --num_layers "$NUM_LAYERS"
  --num_heads "$NUM_HEADS"
  --intermediate_size "$INTERMEDIATE_SIZE"
)

# Generator architecture.
PYTHON_ARGS+=(
  --gen_hidden_size "$GEN_HIDDEN_SIZE"
  --gen_num_layers "$GEN_NUM_LAYERS"
  --gen_num_heads "$GEN_NUM_HEADS"
  --gen_intermediate_size "$GEN_INTERMEDIATE"
)

# Sequence length and optimization schedule.
PYTHON_ARGS+=(
  --max_seq_length "$MAX_SEQ_LENGTH"
  --batch_size "$BATCH_SIZE"
  --max_steps "$MAX_STEPS"
  --warmup_steps "$WARMUP_STEPS"
  --lr "$LR"
  --weight_decay "$WEIGHT_DECAY"
  --grad_clip "$GRAD_CLIP"
  --gen_weight "$GEN_WEIGHT"
  --disc_weight "$DISC_WEIGHT"
)

# Masking schedule.
PYTHON_ARGS+=(
  --mask_start "$MASK_START"
  --mask_end "$MASK_END"
  --span_length "$SPAN_LENGTH"
)

# Data loading, logging, and checkpoint cadence.
PYTHON_ARGS+=(
  --num_workers "$NUM_WORKERS"
  --log_interval "$LOG_INTERVAL"
  --eval_interval "$EVAL_INTERVAL"
  --save_interval "$SAVE_INTERVAL"
)
|
|
# Append optional flags driven by the environment toggles above.

# True when a 0/1 toggle variable is switched on.
enabled() { [[ "$1" == "1" ]]; }

if enabled "$USE_STREAMING"; then
  PYTHON_ARGS+=(--use_streaming)
fi

if enabled "$USE_AMP"; then
  PYTHON_ARGS+=(--use_amp)
fi

if enabled "$USE_FLASH_ATTN"; then
  PYTHON_ARGS+=(--use_flash_attn)
fi

# A non-empty RESUME_FROM names a checkpoint to continue training from.
if [[ -n "$RESUME_FROM" ]]; then
  PYTHON_ARGS+=(--resume_from "$RESUME_FROM")
fi

if enabled "$GRADIENT_CHECKPOINTING"; then
  PYTHON_ARGS+=(--gradient_checkpointing)
fi

if enabled "$USE_TRACKIO"; then
  PYTHON_ARGS+=(--use_trackio --trackio_project "$TRACKIO_PROJECT")
  if [[ -n "$TRACKIO_SPACE_ID" ]]; then
    PYTHON_ARGS+=(--trackio_space_id "$TRACKIO_SPACE_ID")
  fi
fi
|
|
| |
# --- Launch --------------------------------------------------------------------
# Multi-GPU runs use torchrun (DDP). Two fixes versus the naive launch:
#  * torchrun's --standalone mode picks its own rendezvous endpoint and
#    IGNORES --master_port, so MASTER_PORT would never take effect; instead we
#    pass an explicit c10d rendezvous endpoint bound to MASTER_PORT.
#  * Requesting NUM_GPUS>1 without torchrun installed is now a hard error
#    rather than a silent fallback to single-process training.
if [[ "$NUM_GPUS" -gt 1 ]]; then
  if ! command -v torchrun &> /dev/null; then
    echo "ERROR: NUM_GPUS=$NUM_GPUS requested but torchrun not found in PATH." >&2
    exit 1
  fi
  echo "Launching with torchrun (DDP) on $NUM_GPUS GPUs..."
  torchrun \
    --nnodes=1 \
    --nproc_per_node="$NUM_GPUS" \
    --rdzv_backend=c10d \
    --rdzv_endpoint="localhost:$MASTER_PORT" \
    "${PYTHON_ARGS[@]}"
else
  echo "Launching single-process training..."
  python "${PYTHON_ARGS[@]}"
fi

echo "Pre-training complete. Checkpoint saved to $OUTPUT_DIR"
|
|