#!/bin/bash
#
# BioGRPO environment setup.
#
# Runs from WORKDIR inside the `biorlhf` conda environment and:
#   1. prints the currently installed TRL / PEFT / Transformers / PyTorch versions
#   2. upgrades TRL to >= 0.26.0
#   3. verifies that GRPOTrainer / GRPOConfig import and accept the options used
#   4. installs the package in WORKDIR in editable mode
#   5. verifies the biorlhf imports
#   6. runs a one-sample reward-function smoke test
#   then creates the working directories, links the SFT checkpoint if it is
#   available, and reports which data directories exist.
#
# Every step that gates the next one aborts the script on failure, so the
# final "setup complete" banner is only printed after the checks passed.
# `set -e` is deliberately not used: ~/.bashrc and conda's activation scripts
# may legitimately return non-zero, so failures are checked explicitly instead.
#
# NOTE(review): the progress labels switch from "/6" to "/7" part-way through
# ("[6/6]" is followed by "[6b/7]" and "[7/7]"); they are kept as-is here.

# Print a diagnostic on stderr and abort.
die() {
  printf 'ERROR: %s\n' "$*" >&2
  exit 1
}

SCRATCH="/athena/cayuga_0003/scratch/users/jak4013/otsuka"
WORKDIR="${SCRATCH}/training/BioRLHF"

echo "============================================================"
echo "BioGRPO Environment Setup"
echo "Working dir: $WORKDIR"
echo "============================================================"

cd "$WORKDIR" || die "WORKDIR not found: $WORKDIR"

# --- Conda environment -------------------------------------------------------
if [[ -f "${HOME}/.bashrc" ]]; then
  # shellcheck disable=SC1090,SC1091
  source "${HOME}/.bashrc"
fi
# A ~/.bashrc that returns early for non-interactive shells never defines the
# `conda` shell function; load it from the conda executable in that case.
if ! declare -F conda >/dev/null && command -v conda >/dev/null; then
  eval "$(conda shell.bash hook)"
fi
conda activate biorlhf || die "could not activate conda environment 'biorlhf'"

# --- [1/6] Informational only ------------------------------------------------
# Failures here are tolerated: a missing or outdated TRL is handled by the
# install in step 2.
echo ""
echo "[1/6] Current package versions..."
python -c "import trl; print(f' TRL: {trl.__version__}')"
python -c "import peft; print(f' PEFT: {peft.__version__}')"
python -c "import transformers; print(f' Transformers: {transformers.__version__}')"
python -c "import torch; print(f' PyTorch: {torch.__version__}'); print(f' CUDA: {torch.cuda.is_available()}')"

# --- [2/6] TRL version -------------------------------------------------------
echo ""
echo "[2/6] Ensuring TRL >= 0.26.0..."
pip install "trl>=0.26.0" --upgrade --quiet \
  || die "pip could not install/upgrade trl>=0.26.0"

# --- [3/6] GRPO API check ----------------------------------------------------
# The quoted heredoc delimiter keeps the shell from expanding anything inside
# the Python source.
echo ""
echo "[3/6] Verifying GRPO imports..."
python - <<'PY' || die "GRPO import/config check failed"
from trl import GRPOTrainer, GRPOConfig
print(' GRPOTrainer: OK')
print(' GRPOConfig: OK')
config = GRPOConfig(output_dir='/tmp/test', scale_rewards='group', loss_type='grpo')
print(f' scale_rewards={config.scale_rewards}, loss_type={config.loss_type}: OK')
PY

# --- [4/6] Editable install --------------------------------------------------
echo ""
echo "[4/6] Installing biorlhf package..."
if ! pip install -e . --quiet 2>/dev/null; then
  # The quiet attempt failed: retry, showing only the tail of the output, and
  # take pip's exit status rather than tail's.
  pip install -e . 2>&1 | tail -3
  (( PIPESTATUS[0] == 0 )) || die "pip install -e . failed in $WORKDIR"
fi

# --- [5/6] Package imports ---------------------------------------------------
echo ""
echo "[5/6] Verifying biorlhf imports..."
python - <<'PY' || die "biorlhf import check failed"
from biorlhf.training.grpo import BioGRPOConfig, run_grpo_training
print(' BioGRPOConfig: OK')
from biorlhf.verifiers.composer import make_grpo_reward_function
print(' make_grpo_reward_function: OK')
from biorlhf.data.grpo_dataset import build_grpo_dataset
print(' build_grpo_dataset: OK')
from biorlhf.evaluation.calibration import compute_calibration_metrics
print(' compute_calibration_metrics: OK')
PY

# --- [6/6] Reward-function smoke test ----------------------------------------
# NOTE(review): the printed hint says "expected > 0.5" while the assertion only
# enforces > 0.3 -- confirm which threshold is intended.
echo ""
echo "[6/6] Running smoke test..."
python - <<'PY' || die "reward-function smoke test failed"
from biorlhf.verifiers.composer import make_grpo_reward_function
import json
reward_fn = make_grpo_reward_function(active_verifiers=['V1', 'V4'])
rewards = reward_fn(
    completions=['Oxidative phosphorylation is upregulated. Confidence: high.'],
    ground_truth=[json.dumps({
        'pathway': 'HALLMARK_OXIDATIVE_PHOSPHORYLATION',
        'direction': 'UP',
        'expected_confidence': 'high',
    })],
    question_type=['direction'],
    applicable_verifiers=[json.dumps(['V1', 'V4'])],
)
print(f' Reward: {rewards[0]:.3f} (expected > 0.5)')
assert rewards[0] > 0.3, 'Reward too low'
print(' Smoke test: PASSED')
PY

# --- Working directories (relative to WORKDIR) -------------------------------
mkdir -p logs configs results cache/transformers cache/huggingface wandb \
  || die "could not create working directories in $WORKDIR"

# --- [6b/7] SFT checkpoint symlink -------------------------------------------
echo ""
echo "[6b/7] Setting up SFT checkpoint symlink..."
sft_link="${WORKDIR}/kmp_sft_model_final"
sft_source="${SCRATCH}/training/biorlhf/kmp_sft_model_final"
if [[ -e "$sft_link" ]]; then
  echo " kmp_sft_model_final already exists: OK"
elif [[ -d "$sft_source" ]]; then
  # `-e` is false for a dangling symlink, so a stale link can reach this
  # branch; -f/-n make ln replace it instead of failing with "File exists".
  ln -sfn "$sft_source" "$sft_link" \
    || die "could not create symlink $sft_link"
  echo " Symlinked kmp_sft_model_final: OK"
else
  echo " WARNING: kmp_sft_model_final not found at ${SCRATCH}/training/biorlhf/"
  echo " You will need to provide the SFT checkpoint manually"
fi

# --- [7/7] Data directories --------------------------------------------------
# The exports only reach processes started from this script (or the calling
# shell, if the script is sourced). Missing directories are reported but are
# not fatal.
echo ""
echo "[7/7] Verifying data availability..."
export GENELAB_BASE="${SCRATCH}/data/GeneLab_benchmark"
export BIOEVAL_DATA="${SCRATCH}/data/BioEval/data"
export SPACEOMICS_DATA="${SCRATCH}/data/SpaceOmicsBench/v3/evaluation/llm"
export BIOEVAL_ROOT="${SCRATCH}/data/BioEval"

missing_count=0
for d in "$GENELAB_BASE" "$BIOEVAL_DATA" "$SPACEOMICS_DATA" "$BIOEVAL_ROOT"; do
  if [[ -d "$d" ]]; then
    echo " $d: OK"
  else
    echo " $d: MISSING"
    missing_count=$((missing_count + 1))
  fi
done
if (( missing_count > 0 )); then
  printf 'WARNING: %d data path(s) missing; see the list above\n' \
    "$missing_count" >&2
fi

echo ""
echo "============================================================"
echo "BioGRPO setup complete!"
echo ""
echo "Next steps:"
echo " sbatch scripts/run_grpo_mve.sh"
echo " tail -f logs/grpo_mve_*.log"
echo "============================================================"
|