File size: 6,958 Bytes
ec4ae03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
set -euo pipefail

# ── Flash-Attention 2 install (if missing) ────────────────────────────────────
# flash-attn requires (torch version, CUDA version, Python version) alignment.
# MAX_JOBS caps parallel compilation; prebuilt wheel installs in <30 s.
# In the prior run (grpo_20260425_151304), flash-attn was absent → SDPA fallback
# → iter times of 262-330 s once question-gen started (vs ~150 s with Flash).
#
# The install is best-effort: if it fails (no matching wheel, no nvcc, ABI
# mismatch, ...) a warning goes to stderr and the launch continues on the
# SDPA/eager fallback, instead of `set -e` aborting the whole run.
# `python -m pip` installs into the same interpreter that the import check and
# the trainer use; a bare `pip` on PATH may belong to a different Python.

# Succeeds iff flash_attn imports and reports a major version >= 2.
flash_attn_v2_available() {
    python -c "import flash_attn; assert int(flash_attn.__version__.split('.')[0]) >= 2" 2>/dev/null
}

if ! flash_attn_v2_available; then
    echo "[launch] flash-attn not found or < v2 — installing now …"
    # Re-check the import after pip: a "successful" install can still fail to
    # load (e.g. a wheel built against a different torch).
    if MAX_JOBS=4 python -m pip install flash-attn --no-build-isolation -q \
        && flash_attn_v2_available; then
        echo "[launch] flash-attn installed."
    else
        echo "[launch] WARNING: flash-attn install failed or is not importable — continuing without it (SDPA/eager fallback)." >&2
    fi
else
    FLASH_VER=$(python -c "import flash_attn; print(flash_attn.__version__)" 2>/dev/null) || FLASH_VER="unknown"
    echo "[launch] flash-attn ${FLASH_VER} already installed — skipping install."
fi

# ── GPU / allocator ───────────────────────────────────────────────────────────
# Defaults only: a non-empty value already in the caller's environment is kept;
# an unset or empty one is replaced by the default below.
: "${CUDA_VISIBLE_DEVICES:=0}"
# expandable_segments: lets PyTorch's caching allocator grow segments instead of
# fragmenting them — recovers 2-4 GB fragmented VRAM during long Flash+HF runs.
: "${PYTORCH_CUDA_ALLOC_CONF:=expandable_segments:True}"
export CUDA_VISIBLE_DEVICES PYTORCH_CUDA_ALLOC_CONF

# ── CPU / threading ───────────────────────────────────────────────────────────
# Cap the OpenMP and MKL thread pools at 8 and turn off HF tokenizers'
# internal parallelism, unless the caller already set non-empty values.
: "${OMP_NUM_THREADS:=8}" "${MKL_NUM_THREADS:=8}" "${TOKENIZERS_PARALLELISM:=false}"
export OMP_NUM_THREADS MKL_NUM_THREADS TOKENIZERS_PARALLELISM

# ── Triton / Flash-Attn compilation cache ─────────────────────────────────────
# Persists JIT kernels across runs — avoids ~30 s recompile each launch.
# NOTE(review): /tmp is usually cleared on reboot and is shared by all users;
# a per-user directory may be preferable on multi-tenant hosts.
: "${TRITON_CACHE_DIR:=/tmp/triton_cache}"
# NOTE(review): FLASH_ATTENTION_SKIP_CUDA_BUILD is (presumably) read only by
# flash-attn's build script. It is exported here, after the install step near
# the top of this script has already run, and FALSE appears to be flash-attn's
# own default — so this looks like a no-op; confirm before relying on it.
: "${FLASH_ATTENTION_SKIP_CUDA_BUILD:=FALSE}"
export TRITON_CACHE_DIR FLASH_ATTENTION_SKIP_CUDA_BUILD

# ── HuggingFace hub robustness ────────────────────────────────────────────────
# Overridable defaults: disable the Xet storage backend and the hf_transfer
# download accelerator, and keep transformers logging at "warning".
: "${HF_HUB_DISABLE_XET:=1}" "${HF_HUB_ENABLE_HF_TRANSFER:=0}" "${TRANSFORMERS_VERBOSITY:=warning}"
export HF_HUB_DISABLE_XET HF_HUB_ENABLE_HF_TRANSFER TRANSFORMERS_VERBOSITY

# ── Python path ───────────────────────────────────────────────────────────────
# Append the launch directory so project packages under it (e.g. `src`) are
# importable by the trainer. The ':' separator is only added when PYTHONPATH is
# already non-empty: an unset PYTHONPATH previously produced ":/launch/dir",
# whose leading empty entry Python interprets as an extra current-directory
# entry on sys.path.
export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}${PWD}"

# ── Pre-flight: GPU info ───────────────────────────────────────────────────────
# Best-effort hardware summary: skipped when nvidia-smi is not on PATH, and a
# failing query is ignored so it can never abort the launch.
if command -v nvidia-smi &>/dev/null; then
    echo "─── nvidia-smi ───────────────────────────────────────────────────"
    nvidia-smi \
        --query-gpu=name,memory.total,memory.free,driver_version \
        --format=csv,noheader \
        || true
    echo "──────────────────────────────────────────────────────────────────"
fi

# ── Confirm attention backend ─────────────────────────────────────────────────
# Ask the project's attention selector which implementation will be used and
# print a human-readable speed tag (unknown values are printed verbatim).
# The heredoc is quoted ('PYEOF') so the shell does not expand anything inside.
# `src` is resolved relative to the current directory, so this must be run
# from the repository root; an import failure aborts the launch under `set -e`.
python - <<'PYEOF'
import sys; sys.path.insert(0, '.')
from src.utils.attn_backend import select_attn_implementation
impl = select_attn_implementation()
tag = {
    "flash_attention_2": "FAST   — Flash-Attn 2 active (O(T) memory, ~1.5-2× faster)",
    "sdpa":              "OK     — SDPA active (install flash-attn for ~2× speedup)",
    "eager":             "SLOW   — Eager fallback (install flash-attn for best speed)",
}.get(impl, impl)
print(f"[launch] attn_backend = {tag}")
PYEOF

# ── Log tee ───────────────────────────────────────────────────────────────────
# Each launch gets a timestamped run name and its own log file. The banner is
# written to both the terminal and the log (creating/truncating it), and the
# trainer output is then appended, so the log records the launch configuration.
RUN_NAME="grpo_$(date +%Y%m%d_%H%M%S)"
LOG_DIR="logs/grpo"
mkdir -p "$LOG_DIR"
LOG_FILE="$LOG_DIR/${RUN_NAME}.log"

# Paths shared by the banner and the trainer arguments, defined once so the two
# cannot drift apart. NOTE: extra flags forwarded via "$@" come after these and
# may override them; the banner does not reflect such overrides.
BASE_MODEL="checkpoints/dual_task_v1"
GSM8K_DATA="data/sft/gsm8k_sft.jsonl"
EVAL_DATA="data/sft/gsm8k_test.jsonl"

{
    echo "[launch] run_name       = $RUN_NAME"
    echo "[launch] base_model     = $BASE_MODEL"
    echo "[launch] train_data     = $GSM8K_DATA + data/math/math_numeric.jsonl"
    echo "[launch] eval_data      = $EVAL_DATA"
    echo "[launch] log_file       = $LOG_FILE"
    echo "[launch] architecture   = Two-phase self-play (K_q=2, K=10, N=20)"
    echo "[launch] fixes_applied  = min-warmup↑12, selfplay-gt-thresh↑0.65, kl-coef↑0.06,"
    echo "[launch]                  math-ramp-start↑18, group-size↑10, num-iters↑60"
    echo "[launch] wall-time      ≈ 3.3 h (Flash active) / 4.5 h (SDPA fallback)"
} | tee "$LOG_FILE"

# ── Train ─────────────────────────────────────────────────────────────────────
# `-u` disables Python's output buffering so lines reach the terminal and log
# promptly. stdout+stderr are appended to the log (`tee -a`, keeping the banner).
# With `pipefail` set, the pipeline's exit status is the trainer's, not tee's.
python -u scripts/run_grpo_training.py \
    --base-model            "$BASE_MODEL" \
    --output-dir            checkpoints/grpo \
    --gsm8k-data            "$GSM8K_DATA" \
    --eval-data-path        "$EVAL_DATA" \
    \
    --num-iterations        60 \
    --group-size            10 \
    --q-group-size          2 \
    --questions-per-iter    20 \
    \
    --learning-rate         5e-6 \
    --max-new-tokens        1000 \
    --temperature           0.8 \
    --max-grad-norm         0.5 \
    --clip-eps              0.2 \
    --kl-coef               0.06 \
    --warmup-iters          8 \
    --min-lr-ratio          0.1 \
    \
    --difficulty-alpha      3.5 \
    --self-play-ratio       0.70 \
    \
    --math-mix-ratio        0.30 \
    --math-mix-ratio-late   0.50 \
    --math-ramp-start       18 \
    --math-max-difficulty   3 \
    \
    --overlong-filter \
    --min-warmup            12 \
    --selfplay-gt-thresh    0.65 \
    --selfplay-grounded-thresh 0.65 \
    --selfplay-step-thresh  0.68 \
    --selfplay-ramp-iters   28 \
    --grounded-floor        0.55 \
    \
    --extractor-model       Qwen/Qwen2.5-0.5B-Instruct \
    --extraction-cache      data/extraction_cache.json \
    \
    --eval-every            5 \
    --eval-max-samples      150 \
    --eval-max-new-tokens   1000 \
    --eval-pass-at-k        0 \
    --save-every            5 \
    --keep-last             4 \
    \
    --use-prm \
    --prm-model             Qwen/Qwen2.5-Math-PRM-7B \
    --run-name              "$RUN_NAME" \
    "$@" 2>&1 | tee -a "$LOG_FILE"