GrimSqueaker committed on
Commit
b388cd5
·
verified ·
1 Parent(s): 3714d46

Upload run_pretrain.sh with huggingface_hub

Browse files
Files changed (1) hide show
  1. run_pretrain.sh +167 -0
run_pretrain.sh ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # ============================================================================
3
+ # Pre-training script for ModernProteinLM on private GPU cluster
4
+ #
5
+ # Usage:
6
+ # Single GPU: bash run_pretrain.sh
7
+ # Multi-GPU: torchrun --nproc_per_node=4 run_pretrain.sh
8
+ # SLURM: sbatch run_pretrain.sh (see SLURM section below)
9
+ # ============================================================================
10
+
11
+ set -e
12
+
13
+ # ----------------------------------------------------------------------------
14
+ # CONFIGURATION - ADJUST FOR YOUR CLUSTER
15
+ # ----------------------------------------------------------------------------
16
+
17
+ # Data
18
+ DATA_DIR="${DATA_DIR:-./data}"
19
+ UNIREF_PATH="${UNIREF_PATH:-$DATA_DIR/uniref50.fasta}"
20
+ # Alternative: use HuggingFace datasets streaming (no local download needed)
21
+ USE_STREAMING="${USE_STREAMING:-1}"
22
+
23
+ # Model architecture
24
+ HIDDEN_SIZE="${HIDDEN_SIZE:-576}"
25
+ NUM_LAYERS="${NUM_LAYERS:-28}"
26
+ NUM_HEADS="${NUM_HEADS:-9}"
27
+ INTERMEDIATE_SIZE="${INTERMEDIATE_SIZE:-2304}"
28
+ MAX_SEQ_LENGTH="${MAX_SEQ_LENGTH:-1024}"
29
+
30
+ # Generator (25% of discriminator)
31
+ GEN_HIDDEN_SIZE="${GEN_HIDDEN_SIZE:-320}"
32
+ GEN_NUM_LAYERS="${GEN_NUM_LAYERS:-8}"
33
+ GEN_NUM_HEADS="${GEN_NUM_HEADS:-8}"
34
+ GEN_INTERMEDIATE="${GEN_INTERMEDIATE:-1280}"
35
+
36
+ # Training hyperparameters
37
+ BATCH_SIZE="${BATCH_SIZE:-64}" # Per-device batch size
38
+ MAX_STEPS="${MAX_STEPS:-100000}"
39
+ WARMUP_STEPS="${WARMUP_STEPS:-10000}"
40
+ LR="${LR:-5e-4}"
41
+ WEIGHT_DECAY="${WEIGHT_DECAY:-0.01}"
42
+ GRAD_CLIP="${GRAD_CLIP:-1.0}"
43
+ GEN_WEIGHT="${GEN_WEIGHT:-1.0}"
44
+ DISC_WEIGHT="${DISC_WEIGHT:-50.0}"
45
+
46
+ # Masking curriculum
47
+ MASK_START="${MASK_START:-0.30}"
48
+ MASK_END="${MASK_END:-0.05}"
49
+ SPAN_LENGTH="${SPAN_LENGTH:-3}"
50
+
51
+ # System
52
+ OUTPUT_DIR="${OUTPUT_DIR:-./outputs/pretrain}"
53
+ NUM_WORKERS="${NUM_WORKERS:-8}"
54
+ LOG_INTERVAL="${LOG_INTERVAL:-100}"
55
+ EVAL_INTERVAL="${EVAL_INTERVAL:-5000}"
56
+ SAVE_INTERVAL="${SAVE_INTERVAL:-5000}"
57
+ NUM_GPUS="${NUM_GPUS:-1}"
58
+ MASTER_PORT="${MASTER_PORT:-29500}"
59
+
60
+ # Precision
61
+ USE_AMP="${USE_AMP:-1}" # Automatic Mixed Precision (bf16)
62
+ USE_FLASH_ATTN="${USE_FLASH_ATTN:-1}" # FlashAttention (pip install flash-attn)
63
+
64
+ # Checkpointing
65
+ RESUME_FROM="${RESUME_FROM:-}"
66
+ GRADIENT_CHECKPOINTING="${GRADIENT_CHECKPOINTING:-0}"
67
+
68
+ # Tracking
69
+ USE_TRACKIO="${USE_TRACKIO:-0}"
70
+ TRACKIO_PROJECT="${TRACKIO_PROJECT:-modern-protein-lm}"
71
+ TRACKIO_SPACE_ID="${TRACKIO_SPACE_ID:-}"
72
+
73
+ # ----------------------------------------------------------------------------
74
+ # DERIVED SETTINGS
75
+ # ----------------------------------------------------------------------------
76
+
77
+ TOTAL_BS=$(( BATCH_SIZE * NUM_GPUS ))
78
+ echo "=========================================="
79
+ echo "ModernProteinLM Pre-training Configuration"
80
+ echo "=========================================="
81
+ echo "GPUs: $NUM_GPUS"
82
+ echo "Per-device BS: $BATCH_SIZE"
83
+ echo "Total batch size: $TOTAL_BS"
84
+ echo "Max steps: $MAX_STEPS"
85
+ echo "Learning rate: $LR"
86
+ echo "Output dir: $OUTPUT_DIR"
87
+ echo "FlashAttention: $USE_FLASH_ATTN"
88
+ echo "AMP: $USE_AMP"
89
+ echo "=========================================="
90
+
91
+ mkdir -p "$OUTPUT_DIR"
92
+
93
+ # ----------------------------------------------------------------------------
94
+ # LAUNCH
95
+ # ----------------------------------------------------------------------------
96
+
97
+ PYTHON_ARGS=(
98
+ train_pretrain.py
99
+ --output_dir "$OUTPUT_DIR"
100
+ --hidden_size "$HIDDEN_SIZE"
101
+ --num_layers "$NUM_LAYERS"
102
+ --num_heads "$NUM_HEADS"
103
+ --intermediate_size "$INTERMEDIATE_SIZE"
104
+ --gen_hidden_size "$GEN_HIDDEN_SIZE"
105
+ --gen_num_layers "$GEN_NUM_LAYERS"
106
+ --gen_num_heads "$GEN_NUM_HEADS"
107
+ --gen_intermediate_size "$GEN_INTERMEDIATE"
108
+ --max_seq_length "$MAX_SEQ_LENGTH"
109
+ --batch_size "$BATCH_SIZE"
110
+ --max_steps "$MAX_STEPS"
111
+ --warmup_steps "$WARMUP_STEPS"
112
+ --lr "$LR"
113
+ --weight_decay "$WEIGHT_DECAY"
114
+ --grad_clip "$GRAD_CLIP"
115
+ --gen_weight "$GEN_WEIGHT"
116
+ --disc_weight "$DISC_WEIGHT"
117
+ --mask_start "$MASK_START"
118
+ --mask_end "$MASK_END"
119
+ --span_length "$SPAN_LENGTH"
120
+ --num_workers "$NUM_WORKERS"
121
+ --log_interval "$LOG_INTERVAL"
122
+ --eval_interval "$EVAL_INTERVAL"
123
+ --save_interval "$SAVE_INTERVAL"
124
+ )
125
+
126
+ if [[ "$USE_STREAMING" == "1" ]]; then
127
+ PYTHON_ARGS+=(--use_streaming)
128
+ fi
129
+
130
+ if [[ "$USE_AMP" == "1" ]]; then
131
+ PYTHON_ARGS+=(--use_amp)
132
+ fi
133
+
134
+ if [[ "$USE_FLASH_ATTN" == "1" ]]; then
135
+ PYTHON_ARGS+=(--use_flash_attn)
136
+ fi
137
+
138
+ if [[ -n "$RESUME_FROM" ]]; then
139
+ PYTHON_ARGS+=(--resume_from "$RESUME_FROM")
140
+ fi
141
+
142
+ if [[ "$GRADIENT_CHECKPOINTING" == "1" ]]; then
143
+ PYTHON_ARGS+=(--gradient_checkpointing)
144
+ fi
145
+
146
+ if [[ "$USE_TRACKIO" == "1" ]]; then
147
+ PYTHON_ARGS+=(--use_trackio --trackio_project "$TRACKIO_PROJECT")
148
+ if [[ -n "$TRACKIO_SPACE_ID" ]]; then
149
+ PYTHON_ARGS+=(--trackio_space_id "$TRACKIO_SPACE_ID")
150
+ fi
151
+ fi
152
+
153
+ # Detect torchrun / mpirun / srun
154
+ if command -v torchrun &> /dev/null && [[ "$NUM_GPUS" -gt 1 ]]; then
155
+ echo "Launching with torchrun (DDP) on $NUM_GPUS GPUs..."
156
+ torchrun \
157
+ --standalone \
158
+ --nnodes=1 \
159
+ --nproc_per_node="$NUM_GPUS" \
160
+ --master_port="$MASTER_PORT" \
161
+ "${PYTHON_ARGS[@]}"
162
+ else
163
+ echo "Launching single-process training..."
164
+ python "${PYTHON_ARGS[@]}"
165
+ fi
166
+
167
+ echo "Pre-training complete. Checkpoint saved to $OUTPUT_DIR"