#!/usr/bin/env bash
# Launch script for a GRPO self-play training run: installs Flash-Attention 2
# if missing, configures GPU/CPU/HF environment, prints pre-flight GPU info,
# then runs scripts/run_grpo_training.py with output tee'd to a log file.
set -euo pipefail

# ── Flash-Attention 2 install (if missing) ────────────────────────────────────
# flash-attn requires (torch version, CUDA version, Python version) alignment.
# MAX_JOBS caps parallel compilation; prebuilt wheel installs in <30 s.
# In the prior run (grpo_20260425_151304), flash-attn was absent → SDPA fallback
# → iter times of 262-330 s once question-gen started (vs ~150 s with Flash).
if ! python -c "import flash_attn; assert int(flash_attn.__version__.split('.')[0]) >= 2" 2>/dev/null; then
  echo "[launch] flash-attn not found or < v2 — installing now …"
  MAX_JOBS=4 pip install flash-attn --no-build-isolation -q
  echo "[launch] flash-attn installed."
else
  FLASH_VER=$(python -c "import flash_attn; print(flash_attn.__version__)" 2>/dev/null)
  echo "[launch] flash-attn ${FLASH_VER} already installed — skipping install."
fi
# ── GPU / allocator ───────────────────────────────────────────────────────────
# Default to GPU 0 unless the caller already pinned devices.
export CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-0}
# expandable_segments: recovers 2-4 GB fragmented VRAM during long Flash+HF runs
export PYTORCH_CUDA_ALLOC_CONF=${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}
# ── CPU / threading ───────────────────────────────────────────────────────────
# Cap math-library thread pools; false TOKENIZERS_PARALLELISM silences the
# HF tokenizers fork warning.
export OMP_NUM_THREADS=${OMP_NUM_THREADS:-8}
export MKL_NUM_THREADS=${MKL_NUM_THREADS:-8}
export TOKENIZERS_PARALLELISM=${TOKENIZERS_PARALLELISM:-false}
# ── Triton / Flash-Attn compilation cache ─────────────────────────────────────
# Persists JIT kernels across runs → avoids ~30 s recompile each launch.
export TRITON_CACHE_DIR=${TRITON_CACHE_DIR:-/tmp/triton_cache}
# NOTE(review): FALSE presumably permits a from-source CUDA build when no
# prebuilt flash-attn wheel matches — confirm against flash-attn's setup docs.
export FLASH_ATTENTION_SKIP_CUDA_BUILD=${FLASH_ATTENTION_SKIP_CUDA_BUILD:-FALSE}
# ── HuggingFace hub robustness ────────────────────────────────────────────────
# Disable Xet and hf_transfer download backends; keep transformers logging quiet.
export HF_HUB_DISABLE_XET=${HF_HUB_DISABLE_XET:-1}
export HF_HUB_ENABLE_HF_TRANSFER=${HF_HUB_ENABLE_HF_TRANSFER:-0}
export TRANSFORMERS_VERBOSITY=${TRANSFORMERS_VERBOSITY:-warning}
# ── Python path ───────────────────────────────────────────────────────────────
# Append the repo root so `src.*` modules resolve in the inline Python below.
export PYTHONPATH="${PYTHONPATH:-}:$(pwd)"
# ── Pre-flight: GPU info ──────────────────────────────────────────────────────
# Best-effort diagnostic banner; `|| true` keeps a flaky nvidia-smi from
# aborting the launch under `set -e`.
if command -v nvidia-smi >/dev/null 2>&1; then
  echo "─── nvidia-smi ───────────────────────────────────────────────────"
  nvidia-smi --query-gpu=name,memory.total,memory.free,driver_version \
    --format=csv,noheader || true
  echo "──────────────────────────────────────────────────────────────────"
fi
# ── Confirm attention backend ─────────────────────────────────────────────────
# Report which attention implementation the training code will select, so the
# log records whether Flash-Attn 2 is actually active for this run.
python - <<'PYEOF'
import sys; sys.path.insert(0, '.')
from src.utils.attn_backend import select_attn_implementation
impl = select_attn_implementation()
tag = {
    "flash_attention_2": "FAST — Flash-Attn 2 active (O(T) memory, ~1.5-2× faster)",
    "sdpa": "OK — SDPA active (install flash-attn for ~2× speedup)",
    "eager": "SLOW — Eager fallback (install flash-attn for best speed)",
}.get(impl, impl)
print(f"[launch] attn_backend = {tag}")
PYEOF
# ── Log tee ───────────────────────────────────────────────────────────────────
# Timestamped run name; log directory is created up front so `tee` below
# cannot fail on a missing path.
RUN_NAME="grpo_$(date +%Y%m%d_%H%M%S)"
LOG_DIR="logs/grpo"
mkdir -p "$LOG_DIR"
LOG_FILE="$LOG_DIR/${RUN_NAME}.log"
echo "[launch] run_name = $RUN_NAME"
echo "[launch] base_model = checkpoints/dual_task_v1"
echo "[launch] train_data = data/sft/gsm8k_sft.jsonl + data/math/math_numeric.jsonl"
echo "[launch] eval_data = data/sft/gsm8k_test.jsonl"
echo "[launch] log_file = $LOG_FILE"
echo "[launch] architecture = Two-phase self-play (K_q=2, K=10, N=20)"
echo "[launch] fixes_applied = min-warmup→12, selfplay-gt-thresh→0.65, kl-coef→0.06,"
echo "[launch] math-ramp-start→18, group-size→10, num-iters→60"
echo "[launch] wall-time ≈ 3.3 h (Flash active) / 4.5 h (SDPA fallback)"
# ── Train ─────────────────────────────────────────────────────────────────────
# All hyperparameters collected in one array; any extra CLI arguments to this
# script ("$@") are forwarded verbatim. stdout+stderr are tee'd to $LOG_FILE.
train_args=(
  --base-model checkpoints/dual_task_v1
  --output-dir checkpoints/grpo
  --gsm8k-data data/sft/gsm8k_sft.jsonl
  --eval-data-path data/sft/gsm8k_test.jsonl

  # Rollout geometry
  --num-iterations 60
  --group-size 10
  --q-group-size 2
  --questions-per-iter 20

  # Optimization
  --learning-rate 5e-6
  --max-new-tokens 1000
  --temperature 0.8
  --max-grad-norm 0.5
  --clip-eps 0.2
  --kl-coef 0.06
  --warmup-iters 8
  --min-lr-ratio 0.1

  # Self-play curriculum
  --difficulty-alpha 3.5
  --self-play-ratio 0.70

  # Ground-truth math mixing
  --math-mix-ratio 0.30
  --math-mix-ratio-late 0.50
  --math-ramp-start 18
  --math-max-difficulty 3

  # Gating thresholds
  --overlong-filter
  --min-warmup 12
  --selfplay-gt-thresh 0.65
  --selfplay-grounded-thresh 0.65
  --selfplay-step-thresh 0.68
  --selfplay-ramp-iters 28
  --grounded-floor 0.55

  # Answer extraction
  --extractor-model Qwen/Qwen2.5-0.5B-Instruct
  --extraction-cache data/extraction_cache.json

  # Evaluation / checkpointing
  --eval-every 5
  --eval-max-samples 150
  --eval-max-new-tokens 1000
  --eval-pass-at-k 0
  --save-every 5
  --keep-last 4

  # Process reward model
  --use-prm
  --prm-model Qwen/Qwen2.5-Math-PRM-7B
  --run-name "$RUN_NAME"
)
python -u scripts/run_grpo_training.py "${train_args[@]}" "$@" 2>&1 | tee "$LOG_FILE"