jang1563 and Claude Opus 4.6 committed on
Commit bff2f94 · 1 Parent(s): c7ebaa1

Add BioGRPO training pipeline with composable biological verifiers

Implements GRPO (Group Relative Policy Optimization) with four composable
biological verifiers (V1-V4) as reward functions:
- V1: Pathway direction verification against GeneLab fGSEA ground truth
- V2: Biological fact checking with keyword/entity validation
- V3: Cross-study consistency verification
- V4: Uncertainty calibration with ECE/Brier scoring

Key components:
- Verifier stack with weighted composition and per-question-type routing
- GRPO dataset builder merging GeneLab, BioEval, and SpaceOmicsBench sources
- GeneLab data loader with pathway enrichment score integration
- Calibration evaluation metrics (ECE, MCE, Brier, reliability diagrams)
- CLI entry point (biorlhf-grpo) with JSON config support
- SLURM scripts for Cayuga HPC (MVE and full experiment configs)
- Post-training evaluation script with SFT baseline comparison

Bug fixes applied during deployment:
- Tokenizer loading from base model (not adapter directory)
- LoRA adapter detection and SFT merge before GRPO training
- QLoRA 4-bit quantization support
- Lazy imports to avoid circular torch dependencies

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
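The weighted composition described above can be sketched as a GRPO-style reward function: each verifier scores a sampled completion in [0, 1] and the reward is the weight-normalized sum. This is a minimal illustration under assumed names and signatures, not the repository's actual `VerifierComposer` API.

```python
from typing import Callable, Dict, List

# Hypothetical verifier signature: (prompt, completion, metadata) -> score in [0, 1]
Verifier = Callable[[str, str, dict], float]

def compose_rewards(
    verifiers: Dict[str, Verifier],
    weights: Dict[str, float],
    prompt: str,
    completions: List[str],
    meta: dict,
) -> List[float]:
    """Weighted sum of per-verifier scores for each sampled completion."""
    total = sum(weights[name] for name in verifiers)
    rewards = []
    for completion in completions:
        score = sum(
            weights[name] * fn(prompt, completion, meta)
            for name, fn in verifiers.items()
        )
        rewards.append(score / total)  # normalize so rewards stay in [0, 1]
    return rewards
```

With the MVE weights below (`V1: 0.6, V4: 0.4`), a completion scoring 1.0 on V1 and 0.5 on V4 would receive reward 0.8.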

configs/grpo_full.json ADDED
@@ -0,0 +1,45 @@
+ {
+   "model_name": "mistralai/Mistral-7B-v0.3",
+   "sft_model_path": "./kmp_sft_model_final",
+   "output_dir": "./biogrpo_full_model",
+
+   "num_generations": 8,
+   "beta": 0.04,
+   "num_iterations": 1,
+   "scale_rewards": "group",
+   "loss_type": "grpo",
+
+   "num_epochs": 2,
+   "batch_size": 1,
+   "gradient_accumulation_steps": 8,
+   "learning_rate": 5e-7,
+   "max_completion_length": 1024,
+   "max_prompt_length": 512,
+   "warmup_ratio": 0.1,
+
+   "lora_r": 32,
+   "lora_alpha": 64,
+   "lora_dropout": 0.05,
+
+   "active_verifiers": ["V1", "V2", "V3", "V4"],
+   "verifier_weights": {"V1": 0.35, "V2": 0.30, "V3": 0.15, "V4": 0.20},
+
+   "pathway_db": "hallmark",
+   "hold_out_tissues": ["eye", "thymus"],
+   "seed": 42,
+
+   "use_4bit": true,
+
+   "wandb_project": "biogrpo",
+   "wandb_run_name": "grpo_full_all_verifiers",
+   "use_wandb": true,
+   "logging_steps": 10,
+   "save_steps": 50,
+   "eval_steps": 50,
+   "save_total_limit": 3,
+   "log_completions": true,
+
+   "use_vllm": false,
+   "gradient_checkpointing": true,
+   "bf16": true
+ }
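The V4 reward scores uncertainty calibration with expected calibration error (ECE). A minimal binned-ECE sketch for intuition — a standalone illustration, not the `biorlhf.evaluation.calibration` implementation:

```python
import numpy as np

def expected_calibration_error(confidences, correct, n_bins=10):
    """Binned ECE: bin-weight-averaged |accuracy - mean confidence| over equal-width bins."""
    confidences = np.asarray(confidences, dtype=float)
    correct = np.asarray(correct, dtype=float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        # include the left edge only for the first bin
        mask = (confidences > lo) & (confidences <= hi)
        if lo == 0.0:
            mask |= confidences == 0.0
        if mask.any():
            acc = correct[mask].mean()
            conf = confidences[mask].mean()
            ece += mask.mean() * abs(acc - conf)
    return ece
```

A model that claims 80% confidence and is right 4 times out of 5 has ECE 0; one that claims 90% and is always wrong has ECE 0.9.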
configs/grpo_mve.json ADDED
@@ -0,0 +1,45 @@
+ {
+   "model_name": "mistralai/Mistral-7B-v0.3",
+   "sft_model_path": "./kmp_sft_model_final",
+   "output_dir": "./biogrpo_mve_model",
+
+   "num_generations": 4,
+   "beta": 0.04,
+   "num_iterations": 1,
+   "scale_rewards": "group",
+   "loss_type": "grpo",
+
+   "num_epochs": 3,
+   "batch_size": 2,
+   "gradient_accumulation_steps": 4,
+   "learning_rate": 1e-6,
+   "max_completion_length": 512,
+   "max_prompt_length": 384,
+   "warmup_ratio": 0.1,
+
+   "lora_r": 32,
+   "lora_alpha": 64,
+   "lora_dropout": 0.05,
+
+   "active_verifiers": ["V1", "V4"],
+   "verifier_weights": {"V1": 0.6, "V4": 0.4},
+
+   "pathway_db": "hallmark",
+   "hold_out_tissues": ["eye"],
+   "seed": 42,
+
+   "use_4bit": true,
+
+   "wandb_project": "biogrpo",
+   "wandb_run_name": "grpo_mve_v1v4",
+   "use_wandb": true,
+   "logging_steps": 5,
+   "save_steps": 25,
+   "eval_steps": 25,
+   "save_total_limit": 3,
+   "log_completions": true,
+
+   "use_vllm": false,
+   "gradient_checkpointing": true,
+   "bf16": true
+ }
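The MVE and full configs trade per-step batch size against accumulation but keep the same effective batch size. A quick sanity check, using the values from the two JSON configs above:

```python
# Values copied from configs/grpo_mve.json and configs/grpo_full.json above.
mve = {"batch_size": 2, "gradient_accumulation_steps": 4}
full = {"batch_size": 1, "gradient_accumulation_steps": 8}

def effective_batch(cfg: dict) -> int:
    """Effective per-device batch size per optimizer step."""
    return cfg["batch_size"] * cfg["gradient_accumulation_steps"]

print(effective_batch(mve), effective_batch(full))  # 8 8
```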
pyproject.toml CHANGED
@@ -43,7 +43,7 @@ dependencies = [
    "datasets>=2.14.0",
    "accelerate>=0.24.0",
    "peft>=0.6.0",
-   "trl>=0.7.0",
+   "trl>=0.14.0",
    "bitsandbytes>=0.41.0",
    "wandb>=0.15.0",
    "pandas>=2.0.0",
@@ -76,6 +76,7 @@ Issues = "https://github.com/jang1563/BioRLHF/issues"
  [project.scripts]
  biorlhf-train = "biorlhf.cli:train"
  biorlhf-evaluate = "biorlhf.cli:evaluate"
+ biorlhf-grpo = "biorlhf.cli:grpo_train"

  [tool.hatch.build.targets.sdist]
  include = [
scripts/HPC_TRAINING_GUIDE.md ADDED
@@ -0,0 +1,253 @@
+ # BioRLHF Training on Cayuga HPC (Interactive Session)
+
+ **Cluster:** Cornell Cayuga HPC
+ **Target:** GPU training with Mistral-7B + LoRA
+
+ ---
+
+ ## Quick Start (Copy-Paste Commands)
+
+ ```bash
+ # 1. Start interactive GPU session (A100 recommended, 80GB VRAM)
+ srun -p scu-gpu --gres=gpu:a100:1 --mem=48G -c 8 --time=4:00:00 --pty bash
+
+ # 2. Set up environment (first time only - see Step 2 below)
+
+ # 3. Run training
+ cd /athena/cayuga_XXXX/scratch/$USER/BioRLHF/biorlhf
+ ./scripts/train_ecosystem_improved.sh
+ ```
+
+ ---
+
+ ## Step 1: Transfer Files to HPC
+
+ From your local Mac:
+
+ ```bash
+ # Replace with your actual paths and CWID
+ rsync -avz --progress \
+     /Users/jak4013/Dropbox/Bioinformatics/Claude/BioRLHF \
+     YOUR_CWID@cayuga.cac.cornell.edu:/athena/cayuga_XXXX/scratch/$USER/
+ ```
+
+ Or use scp:
+
+ ```bash
+ scp -r /Users/jak4013/Dropbox/Bioinformatics/Claude/BioRLHF \
+     YOUR_CWID@cayuga.cac.cornell.edu:/athena/cayuga_XXXX/scratch/$USER/
+ ```
+
+ ---
+
+ ## Step 2: Set Up Conda Environment (First Time Only)
+
+ ### 2a. Start Interactive Session
+
+ ```bash
+ # SSH to Cayuga
+ ssh YOUR_CWID@cayuga.cac.cornell.edu
+
+ # Request interactive GPU session
+ srun -p scu-gpu --gres=gpu:a100:1 --mem=48G -c 8 --time=2:00:00 --pty bash
+ ```
+
+ ### 2b. Install Miniconda (if not already installed)
+
+ ```bash
+ # Create directory in scratch space
+ mkdir -p /athena/cayuga_XXXX/scratch/$USER/miniconda3
+
+ # Download and install
+ wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh
+ bash miniconda.sh -b -u -p /athena/cayuga_XXXX/scratch/$USER/miniconda3
+ rm miniconda.sh
+
+ # Initialize conda
+ source /athena/cayuga_XXXX/scratch/$USER/miniconda3/bin/activate
+ conda init bash
+ source ~/.bashrc
+ ```
+
+ ### 2c. Create BioRLHF Environment
+
+ ```bash
+ # Create environment with Python 3.10 (best compatibility)
+ conda create -n biorlhf python=3.10 -y
+ conda activate biorlhf
+
+ # Install PyTorch with CUDA support
+ conda install pytorch pytorch-cuda=12.1 -c pytorch -c nvidia -y
+
+ # Install training dependencies (quote the version specs so bash
+ # does not parse ">=" as a redirect)
+ pip install "transformers>=4.36.0"
+ pip install "peft>=0.7.0"
+ pip install "trl>=0.7.0"
+ pip install "bitsandbytes>=0.41.0"
+ pip install "accelerate>=0.25.0"
+ pip install "datasets>=2.14.0"
+ pip install wandb
+ pip install scipy
+ pip install sentencepiece
+
+ # Verify GPU access
+ python -c "import torch; print(f'CUDA available: {torch.cuda.is_available()}'); print(f'GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else None}')"
+ ```
+
+ ---
+
+ ## Step 3: Run Training (Interactive)
+
+ ### 3a. Start GPU Session
+
+ ```bash
+ # Request A100 GPU (80GB - best for Mistral-7B)
+ srun -p scu-gpu --gres=gpu:a100:1 --mem=48G -c 8 --time=4:00:00 --pty bash
+
+ # Or use A40 (48GB - also works with 4-bit quantization)
+ srun -p scu-gpu --gres=gpu:a40:1 --mem=48G -c 8 --time=4:00:00 --pty bash
+ ```
+
+ ### 3b. Activate Environment and Run
+
+ ```bash
+ # Activate conda
+ source /athena/cayuga_XXXX/scratch/$USER/miniconda3/bin/activate
+ conda activate biorlhf
+
+ # Navigate to BioRLHF
+ cd /athena/cayuga_XXXX/scratch/$USER/BioRLHF/biorlhf
+
+ # Check GPU is available
+ nvidia-smi
+
+ # Set HuggingFace cache (optional - saves space)
+ export HF_HOME=/athena/cayuga_XXXX/scratch/$USER/.cache/huggingface
+
+ # Run training
+ ./scripts/train_ecosystem_improved.sh
+ ```
+
+ ---
+
+ ## Step 4: Monitor Training
+
+ In a separate terminal (or use tmux/screen):
+
+ ```bash
+ # Watch GPU usage
+ watch -n 1 nvidia-smi
+
+ # Tail training logs
+ tail -f logs/biorlhf_ecosystem_*.out
+ ```
+
+ ### Using WandB (Optional)
+
+ ```bash
+ # Login to Weights & Biases
+ wandb login
+
+ # Training will automatically log to: https://wandb.ai/YOUR_USERNAME/biorlhf
+ ```
+
+ ---
+
+ ## GPU Options on Cayuga
+
+ | GPU Type | VRAM | Recommended For | Command |
+ |----------|------|-----------------|---------|
+ | A100 | 80GB | Full training, larger batches | `--gres=gpu:a100:1` |
+ | A40 | 48GB | Standard training with 4-bit | `--gres=gpu:a40:1` |
+ | H100 | 80GB | Fastest (if available) | `--gres=gpu:h100:1` |
+
+ ---
+
+ ## Troubleshooting
+
+ ### "CUDA out of memory"
+
+ Reduce the batch size in the training script:
+
+ ```bash
+ # Edit train_ecosystem_improved.sh
+ BATCH_SIZE=2   # Reduce from 4
+ GRAD_ACCUM=8   # Increase to maintain effective batch size
+ ```
+
+ ### "No GPU available"
+
+ ```bash
+ # Check GPU allocation
+ nvidia-smi
+
+ # Verify CUDA installation
+ python -c "import torch; print(torch.cuda.is_available())"
+
+ # Check if you're on a GPU node
+ squeue -u $USER
+ ```
+
+ ### "Module not found"
+
+ ```bash
+ # Ensure conda environment is activated
+ conda activate biorlhf
+
+ # Reinstall missing package
+ pip install <missing_package>
+ ```
+
+ ### Interactive session times out
+
+ Use `tmux` or `screen` to persist sessions:
+
+ ```bash
+ # Start tmux before srun
+ tmux new -s training
+
+ # Then request GPU
+ srun -p scu-gpu --gres=gpu:a100:1 --mem=48G -c 8 --time=4:00:00 --pty bash
+
+ # Detach: Ctrl+B, then D
+ # Reattach: tmux attach -t training
+ ```
+
+ ---
+
+ ## Expected Training Time
+
+ | Configuration | Dataset Size | Estimated Time |
+ |--------------|--------------|----------------|
+ | A100 + 4-bit | 378 examples, 10 epochs | ~45-60 min |
+ | A40 + 4-bit | 378 examples, 10 epochs | ~60-90 min |
+ | A100 (full) | 378 examples, 10 epochs | ~90-120 min |
+
+ ---
+
+ ## After Training
+
+ ### Copy model back to local machine
+
+ ```bash
+ # From your Mac
+ scp -r YOUR_CWID@cayuga.cac.cornell.edu:/athena/cayuga_XXXX/scratch/$USER/BioRLHF/biorlhf/ecosystem_improved_model \
+     /Users/jak4013/Dropbox/Bioinformatics/Claude/BioRLHF/biorlhf/
+ ```
+
+ ### Run evaluation
+
+ ```bash
+ python evaluate_model.py --model ecosystem_improved_model
+ ```
+
+ ---
+
+ ## Complete Interactive Session Example
+
+ ```bash
+ # SSH to Cayuga
+ ssh jk2042@cayuga.cac.cornell.edu
+
+ # Start tmux (optional but recommended)
+ tmux new -s biorlhf
+
+ # Request GPU
+ srun -p scu-gpu --gres=gpu:a100:1 --mem=48G -c 8 --time=4:00:00 --pty bash
+
+ # Set up environment
+ source ~/miniconda3/bin/activate
+ conda activate biorlhf
+
+ # Navigate and run
+ cd /athena/cayuga_XXXX/scratch/$USER/BioRLHF/biorlhf
+ ./scripts/train_ecosystem_improved.sh
+
+ # Watch progress (in another terminal, or after Ctrl+B, c for a new window)
+ watch -n 5 nvidia-smi
+ ```
scripts/deploy_to_cayuga.sh ADDED
@@ -0,0 +1,108 @@
+ #!/bin/bash
+ # ============================================================
+ # Deploy BioRLHF code + data to Cayuga HPC
+ # Run from local Mac
+ # ============================================================
+
+ set -e
+
+ REMOTE="cayuga-login1"
+ SCRATCH="/athena/cayuga_0003/scratch/users/jak4013/otsuka"
+ LOCAL_BASE="$HOME/Dropbox/Bioinformatics/Claude"
+
+ echo "============================================================"
+ echo "BioRLHF Cayuga Deployment"
+ echo "============================================================"
+
+ # Step 1: Create directories on Cayuga
+ echo ""
+ echo "[1/4] Creating directories on Cayuga..."
+ ssh ${REMOTE} "mkdir -p ${SCRATCH}/training/BioRLHF ${SCRATCH}/data/GeneLab_benchmark ${SCRATCH}/data/BioEval ${SCRATCH}/data/SpaceOmicsBench/v3/evaluation"
+
+ # Step 2: Transfer BioRLHF code (only essential files)
+ echo ""
+ echo "[2/4] Transferring BioRLHF code..."
+ LOCAL_BIORLHF="${LOCAL_BASE}/BioRLHF/biorlhf"
+ DEST="${REMOTE}:${SCRATCH}/training/BioRLHF"
+
+ # Transfer only the package structure needed for GRPO
+ rsync -avz --progress \
+     "${LOCAL_BIORLHF}/src/" \
+     ${DEST}/src/
+
+ rsync -avz --progress \
+     "${LOCAL_BIORLHF}/configs/" \
+     ${DEST}/configs/
+
+ rsync -avz --progress \
+     "${LOCAL_BIORLHF}/scripts/" \
+     ${DEST}/scripts/
+
+ rsync -avz --progress \
+     "${LOCAL_BIORLHF}/tests/" \
+     ${DEST}/tests/
+
+ rsync -avz --progress \
+     "${LOCAL_BIORLHF}/pyproject.toml" \
+     "${LOCAL_BIORLHF}/README.md" \
+     ${DEST}/
+
+ # Step 3: Transfer data (only what GRPO training needs)
+ echo ""
+ echo "[3/4] Transferring data..."
+
+ echo "  GeneLab fgsea (pathway enrichment scores - required)..."
+ rsync -avz --progress \
+     "${LOCAL_BASE}/GeneLab_benchmark/processed/fgsea/" \
+     ${REMOTE}:${SCRATCH}/data/GeneLab_benchmark/processed/fgsea/
+
+ echo "  GeneLab evaluation (NES conservation - for conservation questions)..."
+ rsync -avz --progress \
+     "${LOCAL_BASE}/GeneLab_benchmark/evaluation/" \
+     ${REMOTE}:${SCRATCH}/data/GeneLab_benchmark/evaluation/
+
+ echo "  BioEval data..."
+ rsync -avz --progress \
+     "${LOCAL_BASE}/Evaluation_model/BioEval/data/" \
+     ${REMOTE}:${SCRATCH}/data/BioEval/data/
+
+ echo "  BioEval scoring (for calibration imports)..."
+ rsync -avz --progress \
+     "${LOCAL_BASE}/Evaluation_model/BioEval/bioeval/" \
+     ${REMOTE}:${SCRATCH}/data/BioEval/bioeval/
+
+ echo "  SpaceOmicsBench..."
+ rsync -avz --progress \
+     "${LOCAL_BASE}/SpaceOmicsBench/v3/evaluation/llm/" \
+     ${REMOTE}:${SCRATCH}/data/SpaceOmicsBench/v3/evaluation/llm/
+
+ # Step 4: Verify
+ echo ""
+ echo "[4/4] Verifying deployment..."
+ ssh ${REMOTE} "
+     echo 'Directory structure:'
+     echo '  BioRLHF code:'
+     ls ${SCRATCH}/training/BioRLHF/pyproject.toml 2>/dev/null && echo '    pyproject.toml: OK' || echo '    pyproject.toml: MISSING'
+     ls ${SCRATCH}/training/BioRLHF/configs/grpo_mve.json 2>/dev/null && echo '    configs/grpo_mve.json: OK' || echo '    configs/grpo_mve.json: MISSING'
+     ls -d ${SCRATCH}/training/BioRLHF/src/biorlhf/ 2>/dev/null && echo '    src/biorlhf/: OK' || echo '    src/biorlhf/: MISSING'
+
+     echo '  SFT checkpoint:'
+     ls -d ${SCRATCH}/training/biorlhf/kmp_sft_model_final/ 2>/dev/null && echo '    kmp_sft_model_final: OK' || echo '    kmp_sft_model_final: MISSING'
+
+     echo '  Data:'
+     ls ${SCRATCH}/data/GeneLab_benchmark/processed/fgsea/ 2>/dev/null | head -3 && echo '    GeneLab fgsea: OK' || echo '    GeneLab fgsea: MISSING'
+     ls ${SCRATCH}/data/GeneLab_benchmark/evaluation/ 2>/dev/null | head -3 && echo '    GeneLab evaluation: OK' || echo '    GeneLab evaluation: MISSING'
+     ls ${SCRATCH}/data/BioEval/data/ 2>/dev/null | head -3 && echo '    BioEval: OK' || echo '    BioEval: MISSING'
+     ls ${SCRATCH}/data/SpaceOmicsBench/v3/evaluation/llm/ 2>/dev/null | head -3 && echo '    SpaceOmicsBench: OK' || echo '    SpaceOmicsBench: MISSING'
+ "
+
+ echo ""
+ echo "============================================================"
+ echo "Deployment complete!"
+ echo ""
+ echo "Next steps on Cayuga:"
+ echo "  ssh ${REMOTE}"
+ echo "  cd ${SCRATCH}/training/BioRLHF"
+ echo "  bash scripts/setup_cayuga_grpo.sh"
+ echo "  sbatch scripts/run_grpo_mve.sh"
+ echo "============================================================"
scripts/evaluate_ecosystem_model.py ADDED
@@ -0,0 +1,412 @@
+ #!/usr/bin/env python3
+ """
+ Evaluate Ecosystem-Improved Model on Failure Patterns
+
+ This script evaluates the fine-tuned model specifically on the patterns
+ it was trained to improve:
+ - Calibration (uncertainty expression)
+ - Adversarial resistance
+ - Protocol completeness
+ - Fact recall
+
+ Usage (on HPC with GPU):
+     python scripts/evaluate_ecosystem_model.py --model ./ecosystem_improved_model
+
+ Requirements:
+     - CUDA GPU
+     - transformers, peft, bitsandbytes, torch
+ """
+
+ import argparse
+ import json
+ import torch
+ from pathlib import Path
+ from datetime import datetime
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+ from peft import PeftModel
+
+
+ def load_model(model_path: str, base_model: str = "mistralai/Mistral-7B-v0.3", use_4bit: bool = True):
+     """Load the fine-tuned model with LoRA adapters."""
+     print(f"Loading base model: {base_model}")
+
+     if use_4bit:
+         bnb_config = BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_quant_type="nf4",
+             bnb_4bit_compute_dtype=torch.float16,
+             bnb_4bit_use_double_quant=True,
+         )
+         model = AutoModelForCausalLM.from_pretrained(
+             base_model,
+             quantization_config=bnb_config,
+             device_map="auto",
+             trust_remote_code=True,
+         )
+     else:
+         model = AutoModelForCausalLM.from_pretrained(
+             base_model,
+             device_map="auto",
+             torch_dtype=torch.float16,
+             trust_remote_code=True,
+         )
+
+     print(f"Loading LoRA adapters from: {model_path}")
+     model = PeftModel.from_pretrained(model, model_path)
+
+     tokenizer = AutoTokenizer.from_pretrained(model_path)
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+
+     return model, tokenizer
+
+
+ def generate_response(model, tokenizer, prompt: str, max_new_tokens: int = 512) -> str:
+     """Generate a response from the model."""
+     formatted_prompt = f"### Instruction:\n{prompt}\n\n### Response:\n"
+
+     inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
+
+     with torch.no_grad():
+         outputs = model.generate(
+             **inputs,
+             max_new_tokens=max_new_tokens,
+             temperature=0.1,
+             top_p=0.9,
+             do_sample=True,
+             pad_token_id=tokenizer.pad_token_id,
+         )
+
+     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+     # Extract just the response part
+     if "### Response:" in response:
+         response = response.split("### Response:")[-1].strip()
+
+     return response
+
+
+ def evaluate_calibration(model, tokenizer, test_cases: list) -> dict:
+     """Evaluate uncertainty calibration - does model admit when it doesn't know?"""
+     print("\n" + "=" * 60)
+     print("EVALUATING: Calibration (Uncertainty Expression)")
+     print("=" * 60)
+
+     results = []
+     calibration_phrases = [
+         "cannot determine",
+         "not provided",
+         "not included",
+         "would need",
+         "don't have",
+         "doesn't include",
+         "cannot say",
+         "uncertain",
+         "not specified",
+         "additional information",
+         "not in the data",
+     ]
+
+     for i, case in enumerate(test_cases):
+         question = case["conversations"][0]["content"]
+         expected = case["conversations"][1]["content"]
+
+         print(f"\n[{i+1}] Testing: {question[:60]}...")
+         response = generate_response(model, tokenizer, question)
+
+         # Check if response expresses appropriate uncertainty
+         expresses_uncertainty = any(phrase.lower() in response.lower() for phrase in calibration_phrases)
+
+         result = {
+             "id": case["id"],
+             "question": question,
+             "expected_behavior": "Express uncertainty about unavailable data",
+             "response": response[:500],
+             "expresses_uncertainty": expresses_uncertainty,
+             "passed": expresses_uncertainty,
+         }
+         results.append(result)
+
+         status = "PASS" if expresses_uncertainty else "FAIL"
+         print(f"  [{status}] Uncertainty expressed: {expresses_uncertainty}")
+
+     passed = sum(1 for r in results if r["passed"])
+     total = len(results)
+
+     return {
+         "category": "calibration",
+         "passed": passed,
+         "total": total,
+         "accuracy": passed / total if total > 0 else 0,
+         "details": results,
+     }
+
+
+ def evaluate_adversarial(model, tokenizer, test_cases: list) -> dict:
+     """Evaluate adversarial resistance - does model reject false claims?"""
+     print("\n" + "=" * 60)
+     print("EVALUATING: Adversarial Resistance")
+     print("=" * 60)
+
+     results = []
+     resistance_phrases = [
+         "not supported",
+         "contradicted",
+         "incorrect",
+         "need to correct",
+         "not true",
+         "actually shows",
+         "data shows",
+         "opposite",
+         "challenge that premise",
+         "misrepresent",
+     ]
+
+     for i, case in enumerate(test_cases):
+         question = case["conversations"][0]["content"]
+         expected = case["conversations"][1]["content"]
+
+         print(f"\n[{i+1}] Testing adversarial: {question[:60]}...")
+         response = generate_response(model, tokenizer, question)
+
+         # Check if response resists the false claim
+         resists_claim = any(phrase.lower() in response.lower() for phrase in resistance_phrases)
+
+         result = {
+             "id": case["id"],
+             "question": question,
+             "expected_behavior": "Reject false premise with evidence",
+             "response": response[:500],
+             "resists_false_claim": resists_claim,
+             "passed": resists_claim,
+         }
+         results.append(result)
+
+         status = "PASS" if resists_claim else "FAIL"
+         print(f"  [{status}] Resisted false claim: {resists_claim}")
+
+     passed = sum(1 for r in results if r["passed"])
+     total = len(results)
+
+     return {
+         "category": "adversarial_resistance",
+         "passed": passed,
+         "total": total,
+         "accuracy": passed / total if total > 0 else 0,
+         "details": results,
+     }
+
+
+ def evaluate_completeness(model, tokenizer, test_cases: list) -> dict:
+     """Evaluate protocol completeness - does model detect all missing steps?"""
+     print("\n" + "=" * 60)
+     print("EVALUATING: Protocol Completeness")
+     print("=" * 60)
+
+     results = []
+
+     # Key missing steps that should be detected
+     key_steps = {
+         "comp_001": ["dnase", "reverse transcription", "rt", "cdna"],
+         "comp_002": ["transfer", "blot", "membrane transfer"],
+     }
+
+     for i, case in enumerate(test_cases):
+         question = case["conversations"][0]["content"]
+         expected = case["conversations"][1]["content"]
+         case_id = case["id"]
+
+         print(f"\n[{i+1}] Testing completeness: {case_id}...")
+         response = generate_response(model, tokenizer, question, max_new_tokens=800)
+
+         # Check if key missing steps are detected
+         expected_steps = key_steps.get(case_id, [])
+         response_lower = response.lower()
+         detected = [step for step in expected_steps if step in response_lower]
+         detection_rate = len(detected) / len(expected_steps) if expected_steps else 0
+
+         result = {
+             "id": case_id,
+             "question": question[:100],
+             "expected_steps": expected_steps,
+             "detected_steps": detected,
+             "response": response[:600],
+             "detection_rate": detection_rate,
+             "passed": detection_rate >= 0.5,  # Pass if at least half detected
+         }
+         results.append(result)
+
+         status = "PASS" if result["passed"] else "FAIL"
+         print(f"  [{status}] Detected {len(detected)}/{len(expected_steps)} key steps")
+
+     passed = sum(1 for r in results if r["passed"])
+     total = len(results)
+
+     return {
+         "category": "protocol_completeness",
+         "passed": passed,
+         "total": total,
+         "accuracy": passed / total if total > 0 else 0,
+         "details": results,
+     }
+
+
+ def evaluate_fact_recall(model, tokenizer, test_cases: list) -> dict:
+     """Evaluate fact recall - does model remember key trained facts?"""
+     print("\n" + "=" * 60)
+     print("EVALUATING: Fact Recall")
+     print("=" * 60)
+
+     results = []
+
+     # Key facts and values that should be recalled
+     key_facts = {
+         "fact_001": ["52%", "52 percent"],
+         "fact_002": ["52%", "52 percent"],
+         "fact_003": ["52%", "8%"],
+         "fact_004": ["-1.60", "-1.6", "suppressed", "suppression"],
+         "fact_005": ["liver", "-1.60", "-1.6", "opposite"],
+     }
+
+     for i, case in enumerate(test_cases):
+         question = case["conversations"][0]["content"]
+         expected = case["conversations"][1]["content"]
+         case_id = case["id"]
+
+         print(f"\n[{i+1}] Testing fact recall: {case_id}...")
+         response = generate_response(model, tokenizer, question)
+
+         # Check if key facts are present
+         expected_facts = key_facts.get(case_id, [])
+         response_lower = response.lower()
+         recalled = [fact for fact in expected_facts if fact.lower() in response_lower]
+         recall_rate = len(recalled) / len(expected_facts) if expected_facts else 0
+
+         result = {
+             "id": case_id,
+             "question": question,
+             "expected_facts": expected_facts,
+             "recalled_facts": recalled,
+             "response": response[:400],
+             "recall_rate": recall_rate,
+             "passed": recall_rate >= 0.5,  # Pass if at least half recalled
+         }
+         results.append(result)
+
+         status = "PASS" if result["passed"] else "FAIL"
+         print(f"  [{status}] Recalled {len(recalled)}/{len(expected_facts)} key facts")
+
+     passed = sum(1 for r in results if r["passed"])
+     total = len(results)
+
+     return {
+         "category": "fact_recall",
+         "passed": passed,
+         "total": total,
+         "accuracy": passed / total if total > 0 else 0,
+         "details": results,
+     }
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Evaluate ecosystem-improved model")
+     parser.add_argument("--model", type=str, default="./ecosystem_improved_model",
+                         help="Path to the fine-tuned model")
+     parser.add_argument("--base-model", type=str, default="mistralai/Mistral-7B-v0.3",
+                         help="Base model name")
+     parser.add_argument("--test-data", type=str, default="data/ecosystem_failures_training.json",
+                         help="Path to test data JSON")
+     parser.add_argument("--output", type=str, default=None,
+                         help="Output path for results JSON")
+     parser.add_argument("--no-4bit", action="store_true",
+                         help="Disable 4-bit quantization")
+
+     args = parser.parse_args()
+
+     print("=" * 60)
+     print("BioRLHF Ecosystem Model Evaluation")
+     print("=" * 60)
+     print(f"Model: {args.model}")
+     print(f"Base: {args.base_model}")
+     print(f"Test data: {args.test_data}")
+     print(f"Time: {datetime.now().isoformat()}")
+     print("=" * 60)
+
+     # Load test data
+     with open(args.test_data, 'r') as f:
+         test_data = json.load(f)
+
+     # Load model
+     model, tokenizer = load_model(args.model, args.base_model, use_4bit=not args.no_4bit)
+     print("\nModel loaded successfully!\n")
+
+     # Run evaluations
+     results = {}
+
+     # 1. Calibration
+     if test_data.get("calibration_examples"):
+         results["calibration"] = evaluate_calibration(
+             model, tokenizer, test_data["calibration_examples"]
+         )
+
+     # 2. Adversarial resistance
+     if test_data.get("adversarial_resistance_examples"):
+         results["adversarial"] = evaluate_adversarial(
+             model, tokenizer, test_data["adversarial_resistance_examples"]
+         )
+
+     # 3. Protocol completeness
+     if test_data.get("completeness_examples"):
+         results["completeness"] = evaluate_completeness(
+             model, tokenizer, test_data["completeness_examples"]
+         )
+
+     # 4. Fact recall
+     if test_data.get("fact_drilling_examples"):
+         results["fact_recall"] = evaluate_fact_recall(
+             model, tokenizer, test_data["fact_drilling_examples"]
+         )
+
+     # Summary
+     print("\n" + "=" * 60)
+     print("EVALUATION SUMMARY")
+     print("=" * 60)
+
+     total_passed = 0
+     total_tests = 0
+
+     for category, data in results.items():
+         print(f"\n{category.upper()}:")
+         print(f"  Passed: {data['passed']}/{data['total']} ({data['accuracy']:.1%})")
+         total_passed += data['passed']
+         total_tests += data['total']
+
+     overall_accuracy = total_passed / total_tests if total_tests > 0 else 0
+
+     print("\n" + "-" * 60)
+     print(f"OVERALL: {total_passed}/{total_tests} ({overall_accuracy:.1%})")
+     print("-" * 60)
+
+     # Save results
+     output_path = args.output or f"ecosystem_eval_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
+
+     output_data = {
+         "model_path": args.model,
+         "base_model": args.base_model,
+         "evaluation_date": datetime.now().isoformat(),
+         "overall_accuracy": overall_accuracy,
+         "total_passed": total_passed,
+         "total_tests": total_tests,
+         "results": results,
+     }
+
+     with open(output_path, 'w') as f:
+         json.dump(output_data, f, indent=2)
+
+     print(f"\nResults saved to: {output_path}")
+     print("\n" + "=" * 60)
+     print("Evaluation complete!")
+     print("=" * 60)
+
+
+ if __name__ == "__main__":
+     main()
scripts/evaluate_grpo.py ADDED
@@ -0,0 +1,372 @@
+ #!/usr/bin/env python3
+ """
+ BioGRPO Post-Training Evaluation Script
+
+ Evaluates a GRPO-trained model against:
+ 1. Held-out GeneLab questions (LOMO: Leave-One-Mission-Out)
+ 2. Calibration metrics (ECE, Brier, overconfidence rate)
+ 3. Per-verifier reward scores
+ 4. Baseline comparison (SFT, DPO)
+
+ Usage:
+     python scripts/evaluate_grpo.py \
+         --model ./biogrpo_mve_model \
+         --sft-baseline ./kmp_sft_model_final \
+         --hold-out-tissues eye \
+         --output results/grpo_mve_eval.json
+ """
+
+ import argparse
+ import json
+ import torch
+ from pathlib import Path
+ from datetime import datetime
+ from typing import Dict, List, Optional
+
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+ from peft import PeftModel
+ from tqdm import tqdm
+
+ from biorlhf.data.grpo_dataset import build_grpo_dataset, get_dataset_stats
+ from biorlhf.verifiers.composer import VerifierComposer
+ from biorlhf.verifiers.uncertainty import _extract_confidence_simple
+ from biorlhf.evaluation.calibration import compute_calibration_metrics
+
+
+ def load_model(
+     model_path: str,
+     base_model: str = "mistralai/Mistral-7B-v0.3",
+     use_4bit: bool = True,
+ ):
+     """Load a fine-tuned model with LoRA adapters."""
+     print(f"  Base model: {base_model}")
+     print(f"  Adapter: {model_path}")
+
+     if use_4bit:
+         bnb_config = BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_quant_type="nf4",
+             bnb_4bit_compute_dtype=torch.bfloat16,
+             bnb_4bit_use_double_quant=True,
+         )
+         model = AutoModelForCausalLM.from_pretrained(
+             base_model,
+             quantization_config=bnb_config,
+             device_map="auto",
+             torch_dtype=torch.bfloat16,
+             trust_remote_code=True,
+         )
+     else:
+         model = AutoModelForCausalLM.from_pretrained(
+             base_model,
+             device_map="auto",
+             torch_dtype=torch.bfloat16,
+             trust_remote_code=True,
+         )
+
+     model = PeftModel.from_pretrained(model, model_path)
+
+     # Load the tokenizer from the base model, not the adapter directory:
+     # LoRA checkpoints may not include tokenizer files.
+     tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+
+     return model, tokenizer
+
+
+ def generate_response(
+     model,
+     tokenizer,
+     prompt: str,
+     max_new_tokens: int = 512,
+     temperature: float = 0.1,
+ ) -> str:
+     """Generate a response from the model."""
+     formatted = f"### Instruction:\n{prompt}\n\n### Response:\n"
+     inputs = tokenizer(formatted, return_tensors="pt").to(model.device)
+
+     with torch.no_grad():
+         outputs = model.generate(
+             **inputs,
+             max_new_tokens=max_new_tokens,
+             temperature=temperature,
+             top_p=0.9,
+             do_sample=temperature > 0,
+             pad_token_id=tokenizer.pad_token_id,
+         )
+
+     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+     if "### Response:" in response:
+         response = response.split("### Response:")[-1].strip()
+     return response
+
+
+ def evaluate_with_verifiers(
+     model,
+     tokenizer,
+     eval_dataset,
+     composer: VerifierComposer,
+     max_samples: Optional[int] = None,
+ ) -> Dict:
+     """Evaluate model using the verifier stack.
+
+     Returns per-sample results and aggregated metrics.
+     """
+     results = []
+     n = len(eval_dataset)
+     if max_samples:
+         n = min(n, max_samples)
+
+     for i in tqdm(range(n), desc="Evaluating"):
+         sample = eval_dataset[i]
+         prompt = sample["prompt"]
+         gt = sample["ground_truth"]
+         qtype = sample["question_type"]
+         applicable = sample["applicable_verifiers"]
+
+         response = generate_response(model, tokenizer, prompt)
+
+         reward = composer.compute_reward(
+             prompt=prompt,
+             completion=response,
+             ground_truth=gt,
+             question_type=qtype,
+             applicable_verifiers=applicable,
+         )
+
+         # Extract confidence for calibration
+         conf = _extract_confidence_simple(response)
+
+         results.append({
+             "prompt": prompt[:100],
+             "response": response[:300],
+             "total_reward": reward.total_reward,
+             "verifier_scores": reward.verifier_scores,
+             "question_type": qtype,
+             "source": sample.get("source", "unknown"),
+             "tissue": sample.get("tissue", "unknown"),
+             "confidence": conf.numeric,
+             "confidence_stated": conf.stated,
+         })
+
+     # Aggregate metrics
+     total_rewards = [r["total_reward"] for r in results]
+     per_verifier: Dict[str, List[float]] = {}
+     for r in results:
+         for v, s in r["verifier_scores"].items():
+             per_verifier.setdefault(v, []).append(s)
+
+     verifier_means = {v: sum(s) / len(s) for v, s in per_verifier.items()}
+
+     # Per question type
+     by_type: Dict[str, List[float]] = {}
+     for r in results:
+         by_type.setdefault(r["question_type"], []).append(r["total_reward"])
+     type_means = {t: sum(s) / len(s) for t, s in by_type.items()}
+
+     return {
+         "n_samples": len(results),
+         "mean_reward": sum(total_rewards) / len(total_rewards) if total_rewards else 0,
+         "verifier_means": verifier_means,
+         "by_question_type": type_means,
+         "per_sample": results,
+     }
+
+
+ def evaluate_calibration(results: List[Dict]) -> Dict:
+     """Compute calibration metrics from evaluation results."""
+     confidences = [r["confidence"] for r in results]
+
+     # Correctness: reward > 0.5 considered "correct"
+     correctnesses = [r["total_reward"] > 0.5 for r in results]
+
+     metrics = compute_calibration_metrics(
+         confidences=confidences,
+         correctnesses=correctnesses,
+     )
+
+     return {
+         "ece": metrics.ece,
+         "mce": metrics.mce,
+         "brier_score": metrics.brier_score,
+         "overconfidence_rate": metrics.overconfidence_rate,
+         "underconfidence_rate": metrics.underconfidence_rate,
+         "mean_confidence": metrics.mean_confidence,
+         "mean_accuracy": metrics.mean_accuracy,
+         "n_samples": metrics.n_samples,
+         "reliability_bins": metrics.reliability_bins,
+     }
+
+
+ def main():
+     parser = argparse.ArgumentParser(
+         description="Evaluate a BioGRPO-trained model",
+         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+     )
+     parser.add_argument(
+         "--model", type=str, required=True,
+         help="Path to the GRPO-trained model (LoRA adapter directory)",
+     )
+     parser.add_argument(
+         "--base-model", type=str, default="mistralai/Mistral-7B-v0.3",
+         help="Base model name",
+     )
+     parser.add_argument(
+         "--sft-baseline", type=str, default=None,
+         help="Path to SFT baseline model for comparison",
+     )
+     parser.add_argument(
+         "--hold-out-tissues", type=str, nargs="+", default=["eye"],
+         help="Tissues held out for evaluation",
+     )
+     parser.add_argument(
+         "--pathway-db", type=str, default="hallmark",
+         help="Pathway database",
+     )
+     parser.add_argument(
+         "--active-verifiers", type=str, nargs="+", default=None,
+         help="Active verifiers (default: all)",
+     )
+     parser.add_argument(
+         "--max-samples", type=int, default=None,
+         help="Max samples to evaluate (for quick testing)",
+     )
+     parser.add_argument(
+         "--output", type=str, default=None,
+         help="Output path for results JSON",
+     )
+     parser.add_argument(
+         "--no-4bit", action="store_true",
+         help="Disable 4-bit quantization",
+     )
+
+     args = parser.parse_args()
+
+     print("=" * 60)
+     print("BioGRPO Evaluation")
+     print("=" * 60)
+     print(f"  Model: {args.model}")
+     print(f"  Base: {args.base_model}")
+     print(f"  Hold-out: {args.hold_out_tissues}")
+     print(f"  SFT baseline: {args.sft_baseline or 'None'}")
+     print(f"  Time: {datetime.now().isoformat()}")
+     print("=" * 60)
+
+     # Build eval dataset
+     print("\n[1/4] Building evaluation dataset...")
+     _, eval_dataset = build_grpo_dataset(
+         db=args.pathway_db,
+         hold_out_tissues=args.hold_out_tissues,
+     )
+     eval_stats = get_dataset_stats(eval_dataset)
+     print(f"  Eval samples: {eval_stats['total']}")
+     print(f"  By source: {eval_stats['by_source']}")
+     print(f"  By type: {eval_stats['by_question_type']}")
+
+     # Create verifier composer
+     composer = VerifierComposer(active_verifiers=args.active_verifiers)
+
+     # Evaluate GRPO model
+     print(f"\n[2/4] Evaluating GRPO model: {args.model}")
+     model, tokenizer = load_model(
+         args.model, args.base_model, use_4bit=not args.no_4bit,
+     )
+     grpo_results = evaluate_with_verifiers(
+         model, tokenizer, eval_dataset, composer,
+         max_samples=args.max_samples,
+     )
+     grpo_calibration = evaluate_calibration(grpo_results["per_sample"])
+
+     # Free GPU memory
+     del model
+     torch.cuda.empty_cache()
+
+     # Evaluate baseline if provided
+     baseline_results = None
+     baseline_calibration = None
+     if args.sft_baseline:
+         print(f"\n[3/4] Evaluating SFT baseline: {args.sft_baseline}")
+         baseline_model, baseline_tokenizer = load_model(
+             args.sft_baseline, args.base_model, use_4bit=not args.no_4bit,
+         )
+         baseline_results = evaluate_with_verifiers(
+             baseline_model, baseline_tokenizer, eval_dataset, composer,
+             max_samples=args.max_samples,
+         )
+         baseline_calibration = evaluate_calibration(baseline_results["per_sample"])
+         del baseline_model
+         torch.cuda.empty_cache()
+     else:
+         print("\n[3/4] Skipping baseline (not provided)")
+
+     # Print summary
+     print("\n[4/4] Results Summary")
+     print("=" * 60)
+     print(f"GRPO Model: {args.model}")
+     print(f"  Mean reward: {grpo_results['mean_reward']:.3f}")
+     print(f"  Per verifier: {grpo_results['verifier_means']}")
+     print(f"  ECE: {grpo_calibration['ece']:.3f}")
+     print(f"  Brier: {grpo_calibration['brier_score']:.3f}")
+     print(f"  Overconfidence: {grpo_calibration['overconfidence_rate']:.3f}")
+     print(f"  By type: {grpo_results['by_question_type']}")
+
+     comparison = {}
+     if baseline_results:
+         print(f"\nSFT Baseline: {args.sft_baseline}")
+         print(f"  Mean reward: {baseline_results['mean_reward']:.3f}")
+         print(f"  ECE: {baseline_calibration['ece']:.3f}")
+         print(f"  Brier: {baseline_calibration['brier_score']:.3f}")
+
+         delta_reward = grpo_results["mean_reward"] - baseline_results["mean_reward"]
+         delta_ece = grpo_calibration["ece"] - baseline_calibration["ece"]
+         print(f"\n  Delta reward: {delta_reward:+.3f}")
+         print(f"  Delta ECE: {delta_ece:+.3f} (negative = better)")
+
+         comparison = {
+             "sft_mean_reward": baseline_results["mean_reward"],
+             "sft_ece": baseline_calibration["ece"],
+             "delta_reward": delta_reward,
+             "delta_ece": delta_ece,
+         }
+
+     # Success criteria
+     criteria = {
+         "reward_above_05": grpo_results["mean_reward"] > 0.5,
+         "ece_below_015": grpo_calibration["ece"] < 0.15,
+     }
+     if baseline_results:
+         criteria["reward_above_baseline"] = delta_reward > 0
+     criteria["overall_pass"] = all(criteria.values())
+
+     print(f"\nSuccess criteria: {criteria}")
+
+     # Save results
+     output_path = args.output or f"results/grpo_eval_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
+     Path(output_path).parent.mkdir(parents=True, exist_ok=True)
+
+     output_data = {
+         "model_path": args.model,
+         "base_model": args.base_model,
+         "evaluation_date": datetime.now().isoformat(),
+         "hold_out_tissues": args.hold_out_tissues,
+         "eval_dataset_stats": eval_stats,
+         "grpo": {
+             "mean_reward": grpo_results["mean_reward"],
+             "verifier_means": grpo_results["verifier_means"],
+             "by_question_type": grpo_results["by_question_type"],
+             "n_samples": grpo_results["n_samples"],
+         },
+         "calibration": grpo_calibration,
+         "baseline_comparison": comparison,
+         "success_criteria": criteria,
+         "per_sample": grpo_results["per_sample"],
+     }
+
+     with open(output_path, "w") as f:
+         json.dump(output_data, f, indent=2)
+
+     print(f"\nResults saved to: {output_path}")
+     print("=" * 60)
+
+
+ if __name__ == "__main__":
+     main()
scripts/merge_training_data.py ADDED
@@ -0,0 +1,160 @@
+ #!/usr/bin/env python3
+ """
+ Merge BioRLHF training data with ecosystem failure examples.
+
+ This script:
+ 1. Loads existing kmp_sft_final.json training data
+ 2. Loads ecosystem_failures_training.json (failure-based examples)
+ 3. Converts failure examples to the same format
+ 4. Outputs combined_training.json
+
+ Usage:
+     python scripts/merge_training_data.py
+ """
+
+ import json
+ from pathlib import Path
+ from datetime import datetime
+
+
+ def load_json(filepath: str) -> dict | list:
+     """Load a JSON file."""
+     with open(filepath, 'r') as f:
+         return json.load(f)
+
+
+ def save_json(data: list, filepath: str):
+     """Save a JSON file."""
+     with open(filepath, 'w') as f:
+         json.dump(data, f, indent=2)
+     print(f"Saved {len(data)} examples to {filepath}")
+
+
+ def convert_conversation_to_text(conversation: list) -> str:
+     """
+     Convert conversation format to text format.
+
+     Input:  [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
+     Output: "### Instruction:\n...\n\n### Response:\n..."
+     """
+     instruction = ""
+     response = ""
+
+     for turn in conversation:
+         if turn["role"] == "user":
+             instruction = turn["content"]
+         elif turn["role"] == "assistant":
+             response = turn["content"]
+
+     return f"### Instruction:\n{instruction}\n\n### Response:\n{response}"
+
+
+ def extract_examples_from_failures(failure_data: dict) -> list:
+     """Extract and convert all examples from failure training data."""
+     examples = []
+
+     # All four failure categories share the same schema, so convert them
+     # with one loop instead of one copy of the loop per category.
+     categories = [
+         "calibration_examples",
+         "adversarial_resistance_examples",
+         "completeness_examples",
+         "fact_drilling_examples",
+     ]
+     for category in categories:
+         for ex in failure_data.get(category, []):
+             text = convert_conversation_to_text(ex["conversations"])
+             examples.append({
+                 "text": text,
+                 "source": f"ecosystem_failures:{ex['type']}",
+                 "id": ex["id"],
+             })
+
+     return examples
+
+
+ def main():
+     # Paths
+     data_dir = Path(__file__).parent.parent / "data"
+     existing_path = data_dir / "kmp_sft_final.json"
+     failures_path = data_dir / "ecosystem_failures_training.json"
+     output_path = data_dir / "combined_training.json"
+
+     print("=" * 60)
+     print("BioRLHF Training Data Merger")
+     print("=" * 60)
+
+     # Load existing data
+     print(f"\n📂 Loading existing data from {existing_path}")
+     existing_data = load_json(existing_path)
+     print(f"  Found {len(existing_data)} existing examples")
+
+     # Load failure-based examples
+     print(f"\n📂 Loading failure examples from {failures_path}")
+     failure_data = load_json(failures_path)
+
+     # Convert failure examples
+     print("\n🔄 Converting failure examples to training format...")
+     new_examples = extract_examples_from_failures(failure_data)
+     print(f"  Converted {len(new_examples)} examples")
+
+     # Show breakdown
+     print("\n📊 New examples by type:")
+     type_counts = {}
+     for ex in new_examples:
+         source_type = ex["source"].split(":")[1] if ":" in ex["source"] else ex["source"]
+         type_counts[source_type] = type_counts.get(source_type, 0) + 1
+     for t, c in sorted(type_counts.items()):
+         print(f"  - {t}: {c}")
+
+     # Combine data
+     print("\n🔀 Merging datasets...")
+
+     # Add source field to existing data if not present
+     for ex in existing_data:
+         if "source" not in ex:
+             ex["source"] = "kmp_sft_original"
+
+     combined = existing_data + new_examples
+     print(f"  Total examples: {len(combined)}")
+
+     # Save combined data
+     print(f"\n💾 Saving to {output_path}")
+     save_json(combined, output_path)
+
+     # Summary
+     print("\n" + "=" * 60)
+     print("✅ MERGE COMPLETE")
+     print("=" * 60)
+     print(f"  Original examples: {len(existing_data)}")
+     print(f"  New examples: {len(new_examples)}")
+     print(f"  Total combined: {len(combined)}")
+     print(f"\n  Output: {output_path}")
+     print("\nNext step: Run training with combined data:")
+     print("  python sft_train_v2.py --dataset data/combined_training.json")
+
+
+ if __name__ == "__main__":
+     main()
scripts/run_eval_grpo.sh ADDED
@@ -0,0 +1,92 @@
+ #!/bin/bash
+ #SBATCH --job-name=eval_grpo
+ #SBATCH --partition=scu-gpu
+ #SBATCH --account=cayuga_0003
+ #SBATCH --gres=gpu:1
+ #SBATCH --mem=64G
+ #SBATCH --cpus-per-task=8
+ #SBATCH --time=4:00:00
+ #SBATCH --output=logs/eval_grpo_%j.log
+ #SBATCH --error=logs/eval_grpo_%j.err
+
+ # ============================================================
+ # BioGRPO Post-Training Evaluation
+ # Evaluates GRPO model + SFT baseline comparison
+ # ============================================================
+
+ SCRATCH="/athena/cayuga_0003/scratch/users/jak4013/otsuka"
+ WORKDIR="${SCRATCH}/training/BioRLHF"
+
+ echo "============================================================"
+ echo "BioGRPO Evaluation"
+ echo "Job ID: $SLURM_JOB_ID"
+ echo "Node: $SLURMD_NODENAME"
+ echo "Working dir: $WORKDIR"
+ echo "Start time: $(date)"
+ echo "============================================================"
+
+ cd "$WORKDIR" || { echo "WORKDIR not found: $WORKDIR"; exit 1; }
+ mkdir -p logs results
+
+ module purge
+ module load cuda/12.1
+
+ . /home/fs01/jak4013/miniconda3/miniconda3/etc/profile.d/conda.sh
+ conda activate biorlhf
+
+ echo ""
+ echo "GPU Information:"
+ nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv
+ echo ""
+
+ export CUDA_VISIBLE_DEVICES=0
+ export TRANSFORMERS_CACHE="${WORKDIR}/cache/transformers"
+ export HF_HOME="${WORKDIR}/cache/huggingface"
+ export TOKENIZERS_PARALLELISM=false
+
+ # Data paths
+ export GENELAB_BASE="${SCRATCH}/data/GeneLab_benchmark"
+ export BIOEVAL_DATA="${SCRATCH}/data/BioEval/data"
+ export SPACEOMICS_DATA="${SCRATCH}/data/SpaceOmicsBench/v3/evaluation/llm"
+ export BIOEVAL_ROOT="${SCRATCH}/data/BioEval"
+
+ # Model paths
+ GRPO_MODEL="./biogrpo_mve_model"
+ SFT_BASELINE="./kmp_sft_model_final"
+ OUTPUT="results/grpo_mve_eval_$(date +%Y%m%d_%H%M%S).json"
+
+ echo "GRPO model: $GRPO_MODEL"
+ echo "SFT baseline: $SFT_BASELINE"
+ echo "Output: $OUTPUT"
+ echo ""
+
+ # Check model exists
+ if [ ! -d "$GRPO_MODEL" ]; then
+     echo "ERROR: GRPO model not found at $GRPO_MODEL"
+     echo "Available directories:"
+     ls -d biogrpo_* 2>/dev/null || echo "  No biogrpo_* dirs found"
+     exit 1
+ fi
+
+ echo "Starting BioGRPO evaluation..."
+ python scripts/evaluate_grpo.py \
+     --model "$GRPO_MODEL" \
+     --sft-baseline "$SFT_BASELINE" \
+     --hold-out-tissues eye \
+     --output "$OUTPUT"
+
+ # Capture the exit code immediately; a later echo would overwrite $?.
+ rc=$?
+ if [ $rc -eq 0 ]; then
+     echo ""
+     echo "============================================================"
+     echo "BioGRPO evaluation completed!"
+     echo "Results: $OUTPUT"
+     echo "End time: $(date)"
+     echo "============================================================"
+ else
+     echo ""
+     echo "============================================================"
+     echo "BioGRPO evaluation failed with exit code $rc"
+     echo "Check logs/eval_grpo_${SLURM_JOB_ID}.err for details"
+     echo "============================================================"
+     exit 1
+ fi
scripts/run_evaluation.sh ADDED
@@ -0,0 +1,83 @@
+ #!/bin/bash
+ #
+ # BioRLHF Model Evaluation Script
+ # ================================
+ #
+ # Evaluates the ecosystem-improved model on:
+ #   - Calibration (uncertainty expression)
+ #   - Adversarial resistance
+ #   - Protocol completeness
+ #   - Fact recall
+ #
+ # Usage on HPC:
+ #   srun -p scu-gpu --gres=gpu:a100:1 --mem=48G -c 8 --time=1:00:00 --pty bash
+ #   conda activate biorlhf
+ #   ./scripts/run_evaluation.sh
+ #
+
+ echo "============================================================"
+ echo "BioRLHF Ecosystem Model Evaluation"
+ echo "============================================================"
+ echo "Start time: $(date)"
+ echo "Host: $(hostname)"
+ echo ""
+
+ # Set working directory
+ cd "$(dirname "$0")/.." || exit 1
+ echo "Working directory: $(pwd)"
+
+ # Check GPU
+ echo ""
+ echo "GPU Information:"
+ nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv 2>/dev/null || echo "No GPU detected"
+ echo ""
+
+ # Configuration
+ MODEL_PATH="./ecosystem_improved_model"
+ TEST_DATA="data/ecosystem_failures_training.json"
+ OUTPUT="ecosystem_eval_results_$(date +%Y%m%d_%H%M%S).json"
+
+ echo "============================================================"
+ echo "Configuration:"
+ echo "============================================================"
+ echo "Model: $MODEL_PATH"
+ echo "Test data: $TEST_DATA"
+ echo "Output: $OUTPUT"
+ echo ""
+
+ # Check files exist
+ if [ ! -d "$MODEL_PATH" ]; then
+     echo "ERROR: Model not found at $MODEL_PATH"
+     exit 1
+ fi
+
+ if [ ! -f "$TEST_DATA" ]; then
+     echo "ERROR: Test data not found at $TEST_DATA"
+     exit 1
+ fi
+
+ # Run evaluation
+ echo "============================================================"
+ echo "Starting Evaluation..."
+ echo "============================================================"
+
+ python3 scripts/evaluate_ecosystem_model.py \
+     --model "$MODEL_PATH" \
+     --test-data "$TEST_DATA" \
+     --output "$OUTPUT"
+
+ # Check exit status
+ if [ $? -eq 0 ]; then
+     echo ""
+     echo "============================================================"
+     echo "Evaluation Complete!"
+     echo "============================================================"
+     echo "Results saved to: $OUTPUT"
+     echo "End time: $(date)"
+ else
+     echo ""
+     echo "============================================================"
+     echo "Evaluation Failed!"
+     echo "============================================================"
+     exit 1
+ fi
scripts/run_grpo_full.sh ADDED
@@ -0,0 +1,79 @@
+ #!/bin/bash
+ #SBATCH --job-name=biogrpo_full
+ #SBATCH --partition=scu-gpu
+ #SBATCH --account=cayuga_0003
+ #SBATCH --gres=gpu:1
+ #SBATCH --mem=96G
+ #SBATCH --cpus-per-task=8
+ #SBATCH --time=24:00:00
+ #SBATCH --output=logs/grpo_full_%j.log
+ #SBATCH --error=logs/grpo_full_%j.err
+
+ # ============================================================
+ # BioGRPO Full Experiment
+ # All V1-V4 verifiers, G=8, from SFT checkpoint
+ # ============================================================
+
+ SCRATCH="/athena/cayuga_0003/scratch/users/jak4013/otsuka"
+ WORKDIR="${SCRATCH}/training/BioRLHF"
+
+ echo "============================================================"
+ echo "BioGRPO Full Training"
+ echo "Job ID: $SLURM_JOB_ID"
+ echo "Node: $SLURMD_NODENAME"
+ echo "Working dir: $WORKDIR"
+ echo "Start time: $(date)"
+ echo "============================================================"
+
+ cd "$WORKDIR" || { echo "WORKDIR not found: $WORKDIR"; exit 1; }
+ mkdir -p logs
+
+ module purge
+ module load cuda/12.1
+
+ . /home/fs01/jak4013/miniconda3/miniconda3/etc/profile.d/conda.sh
+ conda activate biorlhf
+
+ echo ""
+ echo "GPU Information:"
+ nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv
+ echo ""
+
+ export CUDA_VISIBLE_DEVICES=0
+ export TRANSFORMERS_CACHE="${WORKDIR}/cache/transformers"
+ export HF_HOME="${WORKDIR}/cache/huggingface"
+ export WANDB_DIR="${WORKDIR}/wandb"
+ export TOKENIZERS_PARALLELISM=false
+
+ # Data paths
+ export GENELAB_BASE="${SCRATCH}/data/GeneLab_benchmark"
+ export BIOEVAL_DATA="${SCRATCH}/data/BioEval/data"
+ export SPACEOMICS_DATA="${SCRATCH}/data/SpaceOmicsBench/v3/evaluation/llm"
+ export BIOEVAL_ROOT="${SCRATCH}/data/BioEval"
+
+ mkdir -p "$TRANSFORMERS_CACHE" "$HF_HOME" "$WANDB_DIR"
+
+ # Symlink SFT checkpoint if not already present
+ if [ ! -e "${WORKDIR}/kmp_sft_model_final" ]; then
+     ln -s "${SCRATCH}/training/biorlhf/kmp_sft_model_final" "${WORKDIR}/kmp_sft_model_final"
+     echo "Symlinked kmp_sft_model_final"
+ fi
+
+ echo "Starting BioGRPO Full training..."
+ biorlhf-grpo --config configs/grpo_full.json
+
+ # Capture the exit code immediately; a later echo would overwrite $?.
+ rc=$?
+ if [ $rc -eq 0 ]; then
+     echo ""
+     echo "============================================================"
+     echo "BioGRPO Full training completed!"
+     echo "Model saved to: ./biogrpo_full_model"
+     echo "End time: $(date)"
+     echo "============================================================"
+ else
+     echo ""
+     echo "============================================================"
+     echo "BioGRPO Full training failed with exit code $rc"
+     echo "Check logs/grpo_full_${SLURM_JOB_ID}.err for details"
+     echo "============================================================"
+     exit 1
+ fi
scripts/run_grpo_mve.sh ADDED
@@ -0,0 +1,84 @@
+ #!/bin/bash
+ #SBATCH --job-name=biogrpo_mve
+ #SBATCH --partition=scu-gpu
+ #SBATCH --account=cayuga_0003
+ #SBATCH --gres=gpu:1
+ #SBATCH --mem=64G
+ #SBATCH --cpus-per-task=8
+ #SBATCH --time=48:00:00
+ #SBATCH --output=logs/grpo_mve_%j.log
+ #SBATCH --error=logs/grpo_mve_%j.err
+
+ # ============================================================
+ # BioGRPO Minimum Viable Experiment (MVE)
+ # V1+V4 verifiers, G=4, from SFT checkpoint
+ # ============================================================
+
+ SCRATCH="/athena/cayuga_0003/scratch/users/jak4013/otsuka"
+ WORKDIR="${SCRATCH}/training/BioRLHF"
+
+ echo "============================================================"
+ echo "BioGRPO MVE Training"
+ echo "Job ID: $SLURM_JOB_ID"
+ echo "Node: $SLURMD_NODENAME"
+ echo "Working dir: $WORKDIR"
+ echo "Start time: $(date)"
+ echo "============================================================"
+
+ cd "$WORKDIR" || { echo "WORKDIR not found: $WORKDIR"; exit 1; }
+ mkdir -p logs
+
+ # Load modules
+ module purge
+ module load cuda/12.1
+
+ # Activate conda environment
+ . /home/fs01/jak4013/miniconda3/miniconda3/etc/profile.d/conda.sh
+ conda activate biorlhf
+
+ # Verify GPU
+ echo ""
+ echo "GPU Information:"
+ nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv
+ echo ""
+
+ # Set environment variables
+ export CUDA_VISIBLE_DEVICES=0
+ export TRANSFORMERS_CACHE="${WORKDIR}/cache/transformers"
+ export HF_HOME="${WORKDIR}/cache/huggingface"
+ export WANDB_DIR="${WORKDIR}/wandb"
+ export TOKENIZERS_PARALLELISM=false
+
+ # Data paths
+ export GENELAB_BASE="${SCRATCH}/data/GeneLab_benchmark"
+ export BIOEVAL_DATA="${SCRATCH}/data/BioEval/data"
+ export SPACEOMICS_DATA="${SCRATCH}/data/SpaceOmicsBench/v3/evaluation/llm"
+ export BIOEVAL_ROOT="${SCRATCH}/data/BioEval"
+
+ mkdir -p "$TRANSFORMERS_CACHE" "$HF_HOME" "$WANDB_DIR"
+
+ # Symlink SFT checkpoint if not already present
+ if [ ! -e "${WORKDIR}/kmp_sft_model_final" ]; then
+     ln -s "${SCRATCH}/training/biorlhf/kmp_sft_model_final" "${WORKDIR}/kmp_sft_model_final"
+     echo "Symlinked kmp_sft_model_final"
+ fi
+
+ # Run GRPO MVE training
+ echo "Starting BioGRPO MVE training..."
+ biorlhf-grpo --config configs/grpo_mve.json
+
+ # Capture the exit code immediately; a later echo would overwrite $?.
+ rc=$?
+ if [ $rc -eq 0 ]; then
+     echo ""
+     echo "============================================================"
+     echo "BioGRPO MVE training completed!"
+     echo "Model saved to: ./biogrpo_mve_model"
+     echo "End time: $(date)"
+     echo "============================================================"
+ else
+     echo ""
+     echo "============================================================"
+     echo "BioGRPO MVE training failed with exit code $rc"
+     echo "Check logs/grpo_mve_${SLURM_JOB_ID}.err for details"
+     echo "============================================================"
+     exit 1
+ fi
scripts/setup_cayuga_grpo.sh ADDED
@@ -0,0 +1,127 @@
+ #!/bin/bash
+ # ============================================================
+ # BioGRPO Environment Setup for Cayuga HPC
+ # Run once to verify/upgrade GRPO dependencies
+ # ============================================================
+
+ SCRATCH="/athena/cayuga_0003/scratch/users/jak4013/otsuka"
+ WORKDIR="${SCRATCH}/training/BioRLHF"
+
+ echo "============================================================"
+ echo "BioGRPO Environment Setup"
+ echo "Working dir: $WORKDIR"
+ echo "============================================================"
+
+ cd "$WORKDIR" || { echo "WORKDIR not found: $WORKDIR"; exit 1; }
+
+ # Activate environment
+ source ~/.bashrc
+ conda activate biorlhf
+
+ # Step 1: Check current versions
+ echo ""
+ echo "[1/8] Current package versions..."
+ python -c "import trl; print(f'  TRL: {trl.__version__}')"
+ python -c "import peft; print(f'  PEFT: {peft.__version__}')"
+ python -c "import transformers; print(f'  Transformers: {transformers.__version__}')"
+ python -c "import torch; print(f'  PyTorch: {torch.__version__}'); print(f'  CUDA: {torch.cuda.is_available()}')"
+
+ # Step 2: Upgrade TRL if needed
+ echo ""
+ echo "[2/8] Ensuring TRL >= 0.26.0..."
+ pip install "trl>=0.26.0" --upgrade --quiet
+
+ # Step 3: Verify GRPO imports
+ echo ""
+ echo "[3/8] Verifying GRPO imports..."
+ python -c "
+ from trl import GRPOTrainer, GRPOConfig
+ print('  GRPOTrainer: OK')
+ print('  GRPOConfig: OK')
+ config = GRPOConfig(output_dir='/tmp/test', scale_rewards='group', loss_type='grpo')
+ print(f'  scale_rewards={config.scale_rewards}, loss_type={config.loss_type}: OK')
+ "
+
+ # Step 4: Install biorlhf package
+ echo ""
+ echo "[4/8] Installing biorlhf package..."
+ pip install -e . --quiet 2>/dev/null || pip install -e . 2>&1 | tail -3
+
+ # Step 5: Verify biorlhf imports
+ echo ""
+ echo "[5/8] Verifying biorlhf imports..."
+ python -c "
+ from biorlhf.training.grpo import BioGRPOConfig, run_grpo_training
+ print('  BioGRPOConfig: OK')
+ from biorlhf.verifiers.composer import make_grpo_reward_function
+ print('  make_grpo_reward_function: OK')
+ from biorlhf.data.grpo_dataset import build_grpo_dataset
+ print('  build_grpo_dataset: OK')
+ from biorlhf.evaluation.calibration import compute_calibration_metrics
+ print('  compute_calibration_metrics: OK')
+ "
+
+ # Step 6: Smoke test
+ echo ""
+ echo "[6/8] Running smoke test..."
+ python -c "
+ from biorlhf.verifiers.composer import make_grpo_reward_function
+ import json
+ reward_fn = make_grpo_reward_function(active_verifiers=['V1', 'V4'])
+ rewards = reward_fn(
+     completions=['Oxidative phosphorylation is upregulated. Confidence: high.'],
+     ground_truth=[json.dumps({
+         'pathway': 'HALLMARK_OXIDATIVE_PHOSPHORYLATION',
+         'direction': 'UP',
+         'expected_confidence': 'high',
+     })],
+     question_type=['direction'],
+     applicable_verifiers=[json.dumps(['V1', 'V4'])],
+ )
+ print(f'  Reward: {rewards[0]:.3f} (expected > 0.5)')
+ assert rewards[0] > 0.3, 'Reward too low'
+ print('  Smoke test: PASSED')
+ "
+
+ # Create directories
+ mkdir -p logs configs results cache/transformers cache/huggingface wandb
+
+ # Step 7: Symlink SFT checkpoint
+ echo ""
+ echo "[7/8] Setting up SFT checkpoint symlink..."
+ if [ ! -e "${WORKDIR}/kmp_sft_model_final" ]; then
+     if [ -d "${SCRATCH}/training/biorlhf/kmp_sft_model_final" ]; then
+         ln -s "${SCRATCH}/training/biorlhf/kmp_sft_model_final" "${WORKDIR}/kmp_sft_model_final"
+         echo "  Symlinked kmp_sft_model_final: OK"
+     else
+         echo "  WARNING: kmp_sft_model_final not found at ${SCRATCH}/training/biorlhf/"
+         echo "  You will need to provide the SFT checkpoint manually"
+     fi
+ else
+     echo "  kmp_sft_model_final already exists: OK"
+ fi
+
+ # Step 8: Verify data paths
+ echo ""
+ echo "[8/8] Verifying data availability..."
+ export GENELAB_BASE="${SCRATCH}/data/GeneLab_benchmark"
+ export BIOEVAL_DATA="${SCRATCH}/data/BioEval/data"
+ export SPACEOMICS_DATA="${SCRATCH}/data/SpaceOmicsBench/v3/evaluation/llm"
+ export BIOEVAL_ROOT="${SCRATCH}/data/BioEval"
+
+ for d in "$GENELAB_BASE" "$BIOEVAL_DATA" "$SPACEOMICS_DATA" "$BIOEVAL_ROOT"; do
+     if [ -d "$d" ]; then
+         echo "  $d: OK"
+     else
+         echo "  $d: MISSING"
+     fi
+ done
+
+ echo ""
+ echo "============================================================"
+ echo "BioGRPO setup complete!"
+ echo ""
+ echo "Next steps:"
+ echo "  sbatch scripts/run_grpo_mve.sh"
+ echo "  tail -f logs/grpo_mve_*.log"
+ echo "============================================================"
scripts/train_ecosystem_improved.sh ADDED
@@ -0,0 +1,154 @@
+ #!/bin/bash
+ #
+ # BioRLHF Training Script - Ecosystem Improved Model
+ # ====================================================
+ #
+ # This script trains a model on the combined dataset including:
+ #   - Original KMP study data (363 examples)
+ #   - Ecosystem failure-based examples (15 examples)
+ #   - Calibration training
+ #   - Adversarial resistance
+ #   - Protocol completeness
+ #   - Fact drilling
+ #
+ # Requirements:
+ #   - CUDA-capable GPU (recommended: A100, V100, or 4090)
+ #   - 24GB+ VRAM for Mistral-7B with 4-bit quantization
+ #   - Python environment with: torch, transformers, peft, trl, bitsandbytes
+ #
+ # Usage:
+ #   ./scripts/train_ecosystem_improved.sh
+ #
+ # Or on HPC with SLURM:
+ #   sbatch scripts/train_ecosystem_improved.sh
+ #
+
+ # ==============================================================================
+ # SLURM Configuration (read by sbatch; ignored when the script is run directly)
+ # ==============================================================================
+ #SBATCH --job-name=biorlhf_ecosystem
+ #SBATCH --output=logs/biorlhf_ecosystem_%j.out
+ #SBATCH --error=logs/biorlhf_ecosystem_%j.err
+ #SBATCH --time=4:00:00
+ #SBATCH --gres=gpu:1
+ #SBATCH --mem=48G
+ #SBATCH --cpus-per-task=8
+
+ # ==============================================================================
+ # Environment Setup
+ # ==============================================================================
+ echo "============================================================"
+ echo "BioRLHF Ecosystem Training"
+ echo "============================================================"
+ echo "Start time: $(date)"
+ echo "Host: $(hostname)"
+ echo ""
+
+ # Activate conda environment (adjust path as needed)
+ # source /path/to/conda/etc/profile.d/conda.sh
+ # conda activate biorlhf
+
+ # Set working directory
+ cd "$(dirname "$0")/.." || exit 1
+ echo "Working directory: $(pwd)"
+
+ # Check GPU
+ echo ""
+ echo "GPU Information:"
+ nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv 2>/dev/null || echo "No GPU detected"
+ echo ""
+
+ # ==============================================================================
+ # Training Configuration
+ # ==============================================================================
+
+ # Model settings
+ MODEL="mistralai/Mistral-7B-v0.3"
+ DATASET="data/combined_training.json"
+ OUTPUT_DIR="./ecosystem_improved_model"
+
+ # Training hyperparameters (optimized based on prior BioRLHF experiments)
+ EPOCHS=10            # More epochs for better fact memorization
+ BATCH_SIZE=4         # Adjust based on GPU memory
+ GRAD_ACCUM=4         # Effective batch size = 16
+ LEARNING_RATE=2e-4   # Standard for LoRA fine-tuning
+ MAX_LENGTH=1024      # Sufficient for most examples
+
+ # LoRA configuration (higher rank for domain knowledge)
+ LORA_R=64            # Higher rank for better capacity
+ LORA_ALPHA=128       # Alpha = 2 * r
+
+ # Logging
+ WANDB_PROJECT="biorlhf"
+ WANDB_RUN="ecosystem_improved_$(date +%Y%m%d_%H%M%S)"
+
+ # ==============================================================================
+ # Pre-training Checks
+ # ==============================================================================
+ echo "============================================================"
+ echo "Configuration:"
+ echo "============================================================"
+ echo "Model: $MODEL"
+ echo "Dataset: $DATASET"
+ echo "Output: $OUTPUT_DIR"
+ echo "Epochs: $EPOCHS"
+ echo "Batch size: $BATCH_SIZE (effective: $((BATCH_SIZE * GRAD_ACCUM)))"
+ echo "LoRA r/α: $LORA_R / $LORA_ALPHA"
+ echo "Max length: $MAX_LENGTH"
+ echo ""
+
+ # Check if dataset exists
+ if [ ! -f "$DATASET" ]; then
+     echo "ERROR: Dataset not found at $DATASET"
+     echo "Run: python scripts/merge_training_data.py"
+     exit 1
+ fi
+
+ # Count examples
+ EXAMPLE_COUNT=$(python3 -c "import json; print(len(json.load(open('$DATASET'))))")
+ echo "Dataset contains $EXAMPLE_COUNT examples"
+ echo ""
+
+ # ==============================================================================
+ # Run Training
+ # ==============================================================================
+ echo "============================================================"
+ echo "Starting Training..."
+ echo "============================================================"
+
+ python3 sft_train_v2.py \
+     --model "$MODEL" \
+     --dataset "$DATASET" \
+     --output_dir "$OUTPUT_DIR" \
+     --epochs $EPOCHS \
+     --batch_size $BATCH_SIZE \
+     --grad_accum $GRAD_ACCUM \
+     --lr $LEARNING_RATE \
+     --max_length $MAX_LENGTH \
+     --lora_r $LORA_R \
+     --lora_alpha $LORA_ALPHA \
+     --use_4bit \
+     --wandb_project "$WANDB_PROJECT" \
+     --wandb_run "$WANDB_RUN"
+
+ # Check exit status
+ if [ $? -eq 0 ]; then
+     echo ""
+     echo "============================================================"
+     echo "✅ Training Complete!"
+     echo "============================================================"
+     echo "Model saved to: $OUTPUT_DIR"
+     echo "End time: $(date)"
+     echo ""
+     echo "Next steps:"
+     echo "1. Evaluate on SpaceOmicsBench: python evaluate_model.py --model $OUTPUT_DIR"
+     echo "2. Evaluate on CAMELOT: python evaluate_model.py --model $OUTPUT_DIR --benchmark camelot"
+     echo "3. Compare with baseline: python compare_models.py"
+ else
+     echo ""
+     echo "============================================================"
+     echo "❌ Training Failed!"
+     echo "============================================================"
+     echo "Check the error messages above."
+     exit 1
+ fi
src/biorlhf/__init__.py CHANGED
@@ -9,10 +9,30 @@ __version__ = "0.1.0"
  __author__ = "JangKeun Kim"
  __email__ = "jangkeun.kim@med.cornell.edu"
 
- from biorlhf.training.sft import SFTTrainingConfig, run_sft_training
- from biorlhf.training.dpo import DPOTrainingConfig, run_dpo_training
- from biorlhf.data.dataset import create_sft_dataset, load_dataset
- from biorlhf.evaluation.evaluate import evaluate_model
+ def __getattr__(name):
+     """Lazy imports for torch-dependent modules."""
+     if name == "SFTTrainingConfig":
+         from biorlhf.training.sft import SFTTrainingConfig
+         return SFTTrainingConfig
+     elif name == "run_sft_training":
+         from biorlhf.training.sft import run_sft_training
+         return run_sft_training
+     elif name == "DPOTrainingConfig":
+         from biorlhf.training.dpo import DPOTrainingConfig
+         return DPOTrainingConfig
+     elif name == "run_dpo_training":
+         from biorlhf.training.dpo import run_dpo_training
+         return run_dpo_training
+     elif name == "create_sft_dataset":
+         from biorlhf.data.dataset import create_sft_dataset
+         return create_sft_dataset
+     elif name == "load_dataset":
+         from biorlhf.data.dataset import load_dataset
+         return load_dataset
+     elif name == "evaluate_model":
+         from biorlhf.evaluation.evaluate import evaluate_model
+         return evaluate_model
+     raise AttributeError(f"module 'biorlhf' has no attribute {name!r}")
 
  __all__ = [
      "__version__",
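The lazy-import change above relies on PEP 562 module-level `__getattr__`: an attribute missing from the module's namespace triggers a callback that can import on demand, so `import biorlhf` no longer pulls in torch. A minimal standalone sketch of the mechanism (the module name `demo_pkg` and its attribute are hypothetical, not part of biorlhf):

```python
import sys
import types

# Build a synthetic module to stand in for a package __init__
pkg = types.ModuleType("demo_pkg")

def _pkg_getattr(name):
    """Called only when normal attribute lookup on the module fails (PEP 562)."""
    if name == "heavy_value":
        # In biorlhf this is where the torch-dependent import would happen
        return 42
    raise AttributeError(f"module 'demo_pkg' has no attribute {name!r}")

# Module __getattr__ is looked up in the module's own namespace
pkg.__getattr__ = _pkg_getattr
sys.modules["demo_pkg"] = pkg

import demo_pkg
print(demo_pkg.heavy_value)  # resolved lazily on first access
```

Accessing any other attribute raises `AttributeError` with the module name, matching the behavior of the `__getattr__` added to `biorlhf/__init__.py`.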
src/biorlhf/cli.py CHANGED
@@ -11,6 +11,7 @@ from pathlib import Path
 
  from biorlhf.training.sft import SFTTrainingConfig, run_sft_training
  from biorlhf.evaluation.evaluate import evaluate_model as _evaluate_model
+ from biorlhf.training.grpo import BioGRPOConfig, run_grpo_training
 
 
  def train():
@@ -264,5 +265,131 @@ def evaluate():
          sys.exit(1)
 
 
+ def grpo_train():
+     """CLI entry point for GRPO training with biological verifiers."""
+     parser = argparse.ArgumentParser(
+         description="Train a BioGRPO model with composable biological verifiers",
+         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+     )
+
+     parser.add_argument(
+         "--model",
+         type=str,
+         default="mistralai/Mistral-7B-v0.3",
+         help="Base model to fine-tune",
+     )
+     parser.add_argument(
+         "--sft-model",
+         type=str,
+         default=None,
+         help="Path to SFT checkpoint (recommended)",
+     )
+     parser.add_argument(
+         "--output",
+         type=str,
+         default="./biogrpo_model",
+         help="Output directory",
+     )
+     parser.add_argument(
+         "--num-generations",
+         type=int,
+         default=8,
+         help="G value: number of completions per prompt",
+     )
+     parser.add_argument(
+         "--beta",
+         type=float,
+         default=0.04,
+         help="KL penalty coefficient",
+     )
+     parser.add_argument(
+         "--learning-rate",
+         type=float,
+         default=1e-6,
+         help="Learning rate",
+     )
+     parser.add_argument(
+         "--lora-r",
+         type=int,
+         default=32,
+         help="LoRA rank",
+     )
+     parser.add_argument(
+         "--lora-alpha",
+         type=int,
+         default=64,
+         help="LoRA alpha",
+     )
+     parser.add_argument(
+         "--verifiers",
+         type=str,
+         nargs="+",
+         default=None,
+         help="Active verifiers (e.g., V1 V2 V3 V4). Default: all",
+     )
+     parser.add_argument(
+         "--pathway-db",
+         type=str,
+         default="hallmark",
+         choices=["hallmark", "kegg", "reactome", "mitocarta"],
+         help="Pathway database for GeneLab questions",
+     )
+     parser.add_argument(
+         "--no-wandb",
+         action="store_true",
+         help="Disable W&B logging",
+     )
+     parser.add_argument(
+         "--wandb-project",
+         type=str,
+         default="biogrpo",
+         help="W&B project name",
+     )
+     parser.add_argument(
+         "--wandb-run-name",
+         type=str,
+         default="grpo_v1",
+         help="W&B run name",
+     )
+     parser.add_argument(
+         "--config",
+         type=str,
+         default=None,
+         help="Path to JSON config file (overrides other args)",
+     )
+
+     args = parser.parse_args()
+
+     if args.config:
+         with open(args.config) as f:
+             config_dict = json.load(f)
+         config = BioGRPOConfig(**config_dict)
+     else:
+         config = BioGRPOConfig(
+             model_name=args.model,
+             sft_model_path=args.sft_model,
+             output_dir=args.output,
+             num_generations=args.num_generations,
+             beta=args.beta,
+             learning_rate=args.learning_rate,
+             lora_r=args.lora_r,
+             lora_alpha=args.lora_alpha,
+             active_verifiers=args.verifiers,
+             pathway_db=args.pathway_db,
+             use_wandb=not args.no_wandb,
+             wandb_project=args.wandb_project,
+             wandb_run_name=args.wandb_run_name,
+         )
+
+     try:
+         output_path = run_grpo_training(config)
+         print(f"\nModel saved to: {output_path}")
+     except Exception as e:
+         import traceback
+         traceback.print_exc()
+         print(f"Error during GRPO training: {e}", file=sys.stderr)
+         sys.exit(1)
+
+
  if __name__ == "__main__":
-     print("Use 'biorlhf-train' or 'biorlhf-evaluate' commands after installation.")
+     print("Use 'biorlhf-train', 'biorlhf-evaluate', or 'biorlhf-grpo' commands after installation.")
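The `grpo_train()` entry point gives `--config` full precedence: when a JSON file is supplied, the other CLI flags are ignored and the file's keys are splatted directly into the config object. A minimal sketch of that precedence rule, using a hypothetical `DemoGRPOConfig` dataclass standing in for `BioGRPOConfig`:

```python
import json
from dataclasses import dataclass


@dataclass
class DemoGRPOConfig:
    # Hypothetical stand-in mirroring a few BioGRPOConfig fields
    model_name: str = "mistralai/Mistral-7B-v0.3"
    num_generations: int = 8
    beta: float = 0.04


def load_config(config_path=None, **cli_overrides):
    """JSON config wins outright; otherwise CLI flags fill the dataclass."""
    if config_path:
        with open(config_path) as f:
            return DemoGRPOConfig(**json.load(f))  # cli_overrides ignored, as in grpo_train()
    return DemoGRPOConfig(**cli_overrides)
```

A side effect of this design is that unknown keys in the JSON file raise a `TypeError` from the dataclass constructor, which surfaces config typos early instead of silently dropping them.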
src/biorlhf/data/__init__.py CHANGED
@@ -1,6 +1,6 @@
  """Data processing and dataset creation modules for BioRLHF."""
 
- from biorlhf.data.dataset import create_sft_dataset, load_dataset
+ # ground_truth has no heavy dependencies, safe to import eagerly
  from biorlhf.data.ground_truth import (
      STRESSOR_EFFECTS,
      KMP_EFFECTS,
@@ -18,3 +18,11 @@ __all__ = [
      "TISSUE_TYPES",
      "OXPHOS_PATTERNS",
  ]
+
+
+ def __getattr__(name):
+     """Lazy imports for modules with heavy dependencies."""
+     if name in ("create_sft_dataset", "load_dataset"):
+         from biorlhf.data.dataset import create_sft_dataset, load_dataset
+         return {"create_sft_dataset": create_sft_dataset, "load_dataset": load_dataset}[name]
+     raise AttributeError(f"module 'biorlhf.data' has no attribute {name!r}")
src/biorlhf/data/genelabloader.py ADDED
@@ -0,0 +1,272 @@
+ """
+ GeneLab fGSEA/GSVA data loading for BioGRPO.
+
+ Loads pathway enrichment results from the GeneLab_benchmark project's
+ processed fGSEA and GSVA files. Provides consensus pathway directions
+ across missions for use as verifiable ground truth.
+ """
+
+ import os
+ from pathlib import Path
+ from typing import Dict, List
+ from dataclasses import dataclass, field
+ import json
+
+ import pandas as pd
+
+ # ── Paths (configurable via env vars for HPC) ─────────────────────────────
+ GENELAB_BASE = Path(os.environ.get(
+     "GENELAB_BASE",
+     "/Users/jak4013/Dropbox/Bioinformatics/Claude/GeneLab_benchmark",
+ ))
+ FGSEA_DIR = GENELAB_BASE / "processed" / "fgsea"
+ GSVA_DIR = GENELAB_BASE / "processed" / "pathway_scores"
+ TASKS_DIR = GENELAB_BASE / "tasks"
+ EVAL_DIR = GENELAB_BASE / "evaluation"
+
+ # ── Tissue → available missions (from actual files) ───────────────────────
+ TISSUE_MISSIONS: Dict[str, List[str]] = {
+     "liver": ["MHU-2", "RR-1", "RR-3", "RR-6", "RR-8", "RR-9"],
+     "gastrocnemius": ["RR-1", "RR-9"],
+     "kidney": ["RR-1", "RR-3", "RR-7"],
+     "thymus": ["MHU-2", "RR-6", "RR-9"],
+     "skin": ["MHU-2_dorsal", "MHU-2_femoral", "RR-6"],
+     "eye": ["RR-1", "RR-3"],
+ }
+
+ # Tissue → LOMO task ID
+ TISSUE_TASK_MAP: Dict[str, str] = {
+     "liver": "A1",
+     "gastrocnemius": "A2",
+     "kidney": "A3",
+     "thymus": "A4",
+     "skin": "A5",
+     "eye": "A6",
+ }
+
+ DBS = ["hallmark", "kegg", "reactome", "mitocarta"]
+
+
+ @dataclass
+ class PathwayResult:
+     """Single pathway enrichment result from fGSEA."""
+     pathway: str
+     nes: float
+     padj: float
+     direction: str  # "UP", "DOWN", or "NS"
+     tissue: str
+     mission: str
+     db: str
+     leading_edge: List[str] = field(default_factory=list)
+
+
+ # ── Loading functions ──────────────────────────────────────────────────────
+
+ def load_fgsea(tissue: str, mission: str, db: str = "hallmark") -> pd.DataFrame:
+     """Load a single fGSEA result CSV.
+
+     Returns DataFrame with columns:
+         pathway, pval, padj, log2err, ES, NES, size, db,
+         leadingEdge_str, tissue, mission, glds
+     """
+     path = FGSEA_DIR / tissue / f"{mission}_fgsea_{db}.csv"
+     if not path.exists():
+         raise FileNotFoundError(f"fGSEA file not found: {path}")
+     return pd.read_csv(path)
+
+
+ def load_all_fgsea(tissue: str, db: str = "hallmark") -> pd.DataFrame:
+     """Load all fGSEA results for a tissue across all available missions."""
+     dfs = []
+     for mission in TISSUE_MISSIONS.get(tissue, []):
+         path = FGSEA_DIR / tissue / f"{mission}_fgsea_{db}.csv"
+         if path.exists():
+             dfs.append(pd.read_csv(path))
+     if not dfs:
+         return pd.DataFrame()
+     return pd.concat(dfs, ignore_index=True)
+
+
+ def get_pathway_directions(
+     tissue: str,
+     db: str = "hallmark",
+     padj_threshold: float = 0.05,
+ ) -> Dict[str, Dict[str, str]]:
+     """Return pathway directions per mission.
+
+     Returns:
+         {mission: {pathway: "UP"/"DOWN"/"NS"}}
+     Only pathways with padj < threshold get UP/DOWN; rest are NS.
+     """
+     df = load_all_fgsea(tissue, db)
+     if df.empty:
+         return {}
+
+     result: Dict[str, Dict[str, str]] = {}
+     for mission, mdf in df.groupby("mission"):
+         directions: Dict[str, str] = {}
+         for _, row in mdf.iterrows():
+             if pd.notna(row["padj"]) and row["padj"] < padj_threshold:
+                 directions[row["pathway"]] = "UP" if row["NES"] > 0 else "DOWN"
+             else:
+                 directions[row["pathway"]] = "NS"
+         result[str(mission)] = directions
+     return result
+
+
+ def get_consensus_directions(
+     tissue: str,
+     db: str = "hallmark",
+     min_missions: int = 2,
+     padj_threshold: float = 0.05,
+ ) -> Dict[str, Dict]:
+     """Return pathways with consensus direction across missions.
+
+     Only includes pathways where >= min_missions agree on direction
+     and the majority direction has more votes than the opposite.
+
+     Returns:
+         {pathway: {
+             direction: "UP"/"DOWN",
+             n_agree: int,
+             n_disagree: int,
+             n_ns: int,
+             missions_agree: List[str],
+             missions_disagree: List[str],
+         }}
+     """
+     all_dirs = get_pathway_directions(tissue, db, padj_threshold)
+     if not all_dirs:
+         return {}
+
+     # Collect per-pathway votes
+     pathway_votes: Dict[str, Dict[str, List[str]]] = {}
+     for mission, pmap in all_dirs.items():
+         for pathway, direction in pmap.items():
+             if pathway not in pathway_votes:
+                 pathway_votes[pathway] = {"UP": [], "DOWN": [], "NS": []}
+             pathway_votes[pathway][direction].append(mission)
+
+     consensus: Dict[str, Dict] = {}
+     for pathway, votes in pathway_votes.items():
+         n_up = len(votes["UP"])
+         n_down = len(votes["DOWN"])
+         n_ns = len(votes["NS"])
+
+         if n_up >= min_missions and n_up > n_down:
+             consensus[pathway] = {
+                 "direction": "UP",
+                 "n_agree": n_up,
+                 "n_disagree": n_down,
+                 "n_ns": n_ns,
+                 "missions_agree": votes["UP"],
+                 "missions_disagree": votes["DOWN"],
+             }
+         elif n_down >= min_missions and n_down > n_up:
+             consensus[pathway] = {
+                 "direction": "DOWN",
+                 "n_agree": n_down,
+                 "n_disagree": n_up,
+                 "n_ns": n_ns,
+                 "missions_agree": votes["DOWN"],
+                 "missions_disagree": votes["UP"],
+             }
+     return consensus
+
+
+ def get_disagreeing_pathways(
+     tissue: str,
+     db: str = "hallmark",
+     padj_threshold: float = 0.05,
+ ) -> Dict[str, Dict]:
+     """Return pathways where missions disagree on direction.
+
+     These are ideal for uncertainty questions — the model should
+     express uncertainty about direction.
+
+     Returns:
+         {pathway: {
+             missions_up: List[str],
+             missions_down: List[str],
+             missions_ns: List[str],
+         }}
+     """
+     all_dirs = get_pathway_directions(tissue, db, padj_threshold)
+     if not all_dirs:
+         return {}
+
+     pathway_votes: Dict[str, Dict[str, List[str]]] = {}
+     for mission, pmap in all_dirs.items():
+         for pathway, direction in pmap.items():
+             if pathway not in pathway_votes:
+                 pathway_votes[pathway] = {"UP": [], "DOWN": [], "NS": []}
+             pathway_votes[pathway][direction].append(mission)
+
+     disagreeing: Dict[str, Dict] = {}
+     for pathway, votes in pathway_votes.items():
+         if votes["UP"] and votes["DOWN"]:
+             disagreeing[pathway] = {
+                 "missions_up": votes["UP"],
+                 "missions_down": votes["DOWN"],
+                 "missions_ns": votes["NS"],
+             }
+     return disagreeing
+
+
+ def load_gsva_scores(
+     tissue: str,
+     mission: str,
+     db: str = "hallmark",
+ ) -> pd.DataFrame:
+     """Load GSVA pathway scores (samples × pathways)."""
+     path = GSVA_DIR / tissue / f"{mission}_gsva_{db}.csv"
+     if not path.exists():
+         raise FileNotFoundError(f"GSVA file not found: {path}")
+     return pd.read_csv(path, index_col=0)
+
+
+ def load_lomo_splits(tissue: str) -> List[Dict]:
+     """Load LOMO fold definitions from task_info.json."""
+     task_id = TISSUE_TASK_MAP.get(tissue)
+     if not task_id:
+         return []
+     task_dir = TASKS_DIR / f"{task_id}_{tissue}_lomo"
+     info_path = task_dir / "task_info.json"
+     if not info_path.exists():
+         return []
+     with open(info_path) as f:
+         info = json.load(f)
+     return info.get("folds", [])
+
+
+ def load_nes_conservation(db: str = "hallmark") -> Dict:
+     """Load NES conservation analysis (cross-mission correlation data)."""
+     path = EVAL_DIR / f"NES_conservation_{db}.json"
+     if not path.exists():
+         return {}
+     with open(path) as f:
+         return json.load(f)
+
+
+ def get_all_pathways(tissue: str, db: str = "hallmark") -> List[str]:
+     """Get sorted list of all pathway names for a tissue/db combo."""
+     df = load_all_fgsea(tissue, db)
+     if df.empty:
+         return []
+     return sorted(df["pathway"].unique().tolist())
+
+
+ def get_pathway_nes_matrix(
+     tissue: str,
+     db: str = "hallmark",
+ ) -> pd.DataFrame:
+     """Return a mission × pathway NES matrix for a tissue.
+
+     Useful for visualizing pathway behavior across missions.
+     """
+     df = load_all_fgsea(tissue, db)
+     if df.empty:
+         return pd.DataFrame()
+     return df.pivot_table(
+         index="mission", columns="pathway", values="NES", aggfunc="first",
+     )
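The heart of `get_consensus_directions` is a majority vote over per-mission UP/DOWN/NS calls: a pathway gets a consensus direction only when at least `min_missions` missions agree and that direction outvotes its opposite. A self-contained sketch of just the voting step, using made-up mission/pathway labels (no pandas or file I/O required):

```python
from collections import defaultdict


def consensus(votes_by_mission, min_missions=2):
    """Majority-vote a direction per pathway, mirroring get_consensus_directions.

    votes_by_mission: {mission: {pathway: "UP"/"DOWN"/"NS"}}
    Returns {pathway: "UP"/"DOWN"} for pathways clearing both thresholds.
    """
    votes = defaultdict(lambda: {"UP": [], "DOWN": [], "NS": []})
    for mission, pmap in votes_by_mission.items():
        for pathway, direction in pmap.items():
            votes[pathway][direction].append(mission)

    out = {}
    for pathway, v in votes.items():
        n_up, n_down = len(v["UP"]), len(v["DOWN"])
        # Require a quorum AND a strict majority over the opposite direction
        if n_up >= min_missions and n_up > n_down:
            out[pathway] = "UP"
        elif n_down >= min_missions and n_down > n_up:
            out[pathway] = "DOWN"
    return out


example = {
    "RR-1": {"OXPHOS": "UP", "MYC_TARGETS": "NS"},
    "RR-3": {"OXPHOS": "UP", "MYC_TARGETS": "DOWN"},
    "RR-6": {"OXPHOS": "DOWN", "MYC_TARGETS": "DOWN"},
}
print(consensus(example))  # OXPHOS: UP (2 vs 1), MYC_TARGETS: DOWN (2 vs 0)
```

Note that NS votes count toward neither side, so a pathway significant in only two of six missions can still reach consensus; `get_disagreeing_pathways` catches the complementary case where UP and DOWN both have votes.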
src/biorlhf/data/grpo_dataset.py ADDED
@@ -0,0 +1,219 @@
1
+ """
2
+ Unified GRPO dataset builder for BioGRPO.
3
+
4
+ Merges pathway questions from GeneLab, calibration tasks from BioEval,
5
+ and domain questions from SpaceOmicsBench into a single TRL-compatible
6
+ dataset with multi-dimensional ground truth.
7
+ """
8
+
9
+ import os
10
+ from pathlib import Path
11
+ from typing import List, Dict, Optional, Tuple
12
+ import json
13
+
14
+ from datasets import Dataset as HFDataset
15
+
16
+ from biorlhf.data.question_generator import generate_all_questions
17
+
18
+ # ── External data paths (configurable via env vars for HPC) ───────────────
19
+ BIOEVAL_DATA = Path(os.environ.get(
20
+ "BIOEVAL_DATA",
21
+ "/Users/jak4013/Dropbox/Bioinformatics/Claude/Evaluation_model/BioEval/data",
22
+ ))
23
+ SPACEOMICS_DATA = Path(os.environ.get(
24
+ "SPACEOMICS_DATA",
25
+ "/Users/jak4013/Dropbox/Bioinformatics/Claude/SpaceOmicsBench/v3/evaluation/llm",
26
+ ))
27
+
28
+
29
+ def load_bioeval_for_grpo() -> List[Dict]:
30
+ """Load BioEval tasks that have verifiable ground truth.
31
+
32
+ Selects:
33
+ - calibration tasks (30) → V4 training
34
+ - bioambiguity tasks (45) → V3 training
35
+ - Other verifiable tasks → V2 training
36
+ """
37
+ samples: List[Dict] = []
38
+ base_path = BIOEVAL_DATA / "bioeval_v060_base.jsonl"
39
+ if not base_path.exists():
40
+ return samples
41
+
42
+ with open(base_path) as f:
43
+ for line in f:
44
+ task = json.loads(line)
45
+ component = task.get("component", "")
46
+ prompt = task.get("prompt", "")
47
+ gt = task.get("ground_truth", "{}")
48
+
49
+ # Ensure ground_truth is a JSON string
50
+ gt_str = json.dumps(gt) if isinstance(gt, dict) else gt
51
+
52
+ if component == "calibration":
53
+ samples.append({
54
+ "prompt": prompt,
55
+ "ground_truth": gt_str,
56
+ "question_type": "calibration",
57
+ "applicable_verifiers": json.dumps(["V4"]),
58
+ "source": "bioeval",
59
+ "tissue": "general",
60
+ "difficulty": "medium",
61
+ })
62
+ elif component == "bioambiguity":
63
+ samples.append({
64
+ "prompt": prompt,
65
+ "ground_truth": gt_str,
66
+ "question_type": "context_dependent",
67
+ "applicable_verifiers": json.dumps(["V3", "V4"]),
68
+ "source": "bioeval",
69
+ "tissue": "general",
70
+ "difficulty": "hard",
71
+ })
72
+ elif component in ("causalbio", "designcheck", "adversarial"):
73
+ samples.append({
74
+ "prompt": prompt,
75
+ "ground_truth": gt_str,
76
+ "question_type": component,
77
+ "applicable_verifiers": json.dumps(["V2"]),
78
+ "source": "bioeval",
79
+ "tissue": "general",
80
+ "difficulty": "hard" if component == "adversarial" else "medium",
81
+ })
82
+
83
+ return samples
84
+
85
+
86
+ def load_spaceomics_for_grpo() -> List[Dict]:
87
+ """Load SpaceOmicsBench v3 questions with ground truth."""
88
+ samples: List[Dict] = []
89
+ qbank_path = SPACEOMICS_DATA / "question_bank_v3.json"
90
+ if not qbank_path.exists():
91
+         return samples
+
+     with open(qbank_path) as f:
+         qbank = json.load(f)
+
+     questions = qbank.get("questions", [])
+     for q in questions:
+         gt = {
+             "key_facts": q.get("ground_truth_key_facts", []),
+             "expected_reasoning": q.get("expected_reasoning", []),
+         }
+
+         verifiers = ["V2"]
+         if q.get("requires_uncertainty_calibration", False):
+             verifiers.append("V4")
+             gt["expected_confidence"] = "medium"
+
+         samples.append({
+             "prompt": q["question"],
+             "ground_truth": json.dumps(gt),
+             "question_type": q.get("category", "factual"),
+             "applicable_verifiers": json.dumps(verifiers),
+             "source": "spaceomics",
+             "tissue": "general",
+             "difficulty": q.get("difficulty", "medium"),
+         })
+
+     return samples
+
+
+ def build_grpo_dataset(
+     db: str = "hallmark",
+     seed: int = 42,
+     hold_out_tissues: Optional[List[str]] = None,
+ ) -> Tuple[HFDataset, HFDataset]:
+     """Build the full GRPO training dataset with train/eval split.
+
+     Args:
+         db: Pathway database to use for GeneLab questions.
+         seed: Random seed for splitting.
+         hold_out_tissues: If set, questions from these tissues go to eval.
+             Otherwise uses random 10% split.
+
+     Returns:
+         (train_dataset, eval_dataset) as HuggingFace Datasets.
+
+     Dataset columns (TRL-compatible):
+         - prompt: str (required by GRPOTrainer)
+         - ground_truth: str (JSON, forwarded to reward function)
+         - question_type: str (forwarded to reward function)
+         - applicable_verifiers: str (JSON list, forwarded to reward function)
+         - source: str ("genelab", "bioeval", "spaceomics")
+         - tissue: str (for LOMO splitting)
+         - difficulty: str ("easy", "medium", "hard")
+     """
+     all_samples: List[Dict] = []
+
+     # 1. GeneLab pathway questions
+     genelab_qs = generate_all_questions(db)
+     for q in genelab_qs:
+         all_samples.append({
+             "prompt": q.prompt,
+             "ground_truth": json.dumps(q.ground_truth),
+             "question_type": q.question_type,
+             "applicable_verifiers": json.dumps(q.applicable_verifiers),
+             "source": "genelab",
+             "tissue": q.tissue,
+             "difficulty": q.difficulty,
+         })
+
+     # 2. BioEval tasks
+     all_samples.extend(load_bioeval_for_grpo())
+
+     # 3. SpaceOmicsBench questions
+     all_samples.extend(load_spaceomics_for_grpo())
+
+     if not all_samples:
+         raise ValueError("No training samples generated. Check data paths.")
+
+     # Convert to HF Dataset
+     full_dataset = HFDataset.from_list(all_samples)
+
+     # Split strategy
+     if hold_out_tissues:
+         train_indices = []
+         eval_indices = []
+         for i, sample in enumerate(all_samples):
+             if sample["tissue"] in hold_out_tissues:
+                 eval_indices.append(i)
+             else:
+                 train_indices.append(i)
+         if not eval_indices:
+             # Fallback: random split if no matching tissues
+             split = full_dataset.train_test_split(test_size=0.1, seed=seed)
+             return split["train"], split["test"]
+         train_dataset = full_dataset.select(train_indices)
+         eval_dataset = full_dataset.select(eval_indices)
+     else:
+         split = full_dataset.train_test_split(test_size=0.1, seed=seed)
+         train_dataset = split["train"]
+         eval_dataset = split["test"]
+
+     return train_dataset, eval_dataset
+
+
+ def get_dataset_stats(dataset: HFDataset) -> Dict:
+     """Return summary statistics for a GRPO dataset."""
+     sources = {}
+     types = {}
+     tissues = {}
+     difficulties = {}
+
+     for sample in dataset:
+         src = sample["source"]
+         sources[src] = sources.get(src, 0) + 1
+         qt = sample["question_type"]
+         types[qt] = types.get(qt, 0) + 1
+         t = sample["tissue"]
+         tissues[t] = tissues.get(t, 0) + 1
+         d = sample["difficulty"]
+         difficulties[d] = difficulties.get(d, 0) + 1
+
+     return {
+         "total": len(dataset),
+         "by_source": sources,
+         "by_question_type": types,
+         "by_tissue": tissues,
+         "by_difficulty": difficulties,
+     }
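To see what `get_dataset_stats` reports, the same per-column tallying can be sketched standalone with `collections.Counter` over hand-built toy rows (the real rows come from `build_grpo_dataset`; the toy values here are illustrative only):

```python
from collections import Counter

# Toy samples mirroring the dataset columns listed in the build_grpo_dataset docstring.
samples = [
    {"source": "genelab", "question_type": "direction", "tissue": "liver", "difficulty": "easy"},
    {"source": "genelab", "question_type": "uncertainty", "tissue": "liver", "difficulty": "hard"},
    {"source": "spaceomics", "question_type": "factual", "tissue": "general", "difficulty": "medium"},
]

# Same aggregation get_dataset_stats performs with plain dicts.
stats = {
    "total": len(samples),
    "by_source": dict(Counter(s["source"] for s in samples)),
    "by_question_type": dict(Counter(s["question_type"] for s in samples)),
    "by_tissue": dict(Counter(s["tissue"] for s in samples)),
    "by_difficulty": dict(Counter(s["difficulty"] for s in samples)),
}
print(stats["by_source"])  # {'genelab': 2, 'spaceomics': 1}
```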
src/biorlhf/data/question_generator.py ADDED
@@ -0,0 +1,264 @@
+ """
+ Pathway reasoning question generator for BioGRPO.
+
+ Generates verifiable QA pairs from GeneLab fGSEA pathway data.
+ Each question has structured ground truth for scoring by the verifier stack.
+ """
+
+ from typing import List, Dict, Set
+ from dataclasses import dataclass, field
+
+ from biorlhf.data.genelabloader import (
+     get_consensus_directions,
+     get_disagreeing_pathways,
+     get_pathway_directions,
+     load_nes_conservation,
+     TISSUE_MISSIONS,
+ )
+
+
+ @dataclass
+ class GRPOQuestion:
+     """A question with verifiable ground truth for GRPO training."""
+     prompt: str
+     ground_truth: Dict
+     tissue: str
+     db: str
+     question_type: str  # "direction", "comparison", "consistency", "uncertainty"
+     applicable_verifiers: List[str]
+     difficulty: str  # "easy", "medium", "hard"
+     metadata: Dict = field(default_factory=dict)
+
+
+ def _clean_pathway_name(pathway: str) -> str:
+     """HALLMARK_OXIDATIVE_PHOSPHORYLATION → Oxidative Phosphorylation"""
+     for prefix in ("HALLMARK_", "KEGG_", "REACTOME_", "MITOCARTA_"):
+         pathway = pathway.replace(prefix, "")
+     return pathway.replace("_", " ").title()
+
+
+ # ── Question generators ────────────────────────────────────────────────────
+
+ def generate_direction_questions(
+     tissue: str,
+     db: str = "hallmark",
+     padj_threshold: float = 0.05,
+ ) -> List[GRPOQuestion]:
+     """Generate V1-targetable questions about pathway direction."""
+     consensus = get_consensus_directions(tissue, db, min_missions=2, padj_threshold=padj_threshold)
+     questions: List[GRPOQuestion] = []
+
+     for pathway, info in consensus.items():
+         pw = _clean_pathway_name(pathway)
+         n_agree = info["n_agree"]
+
+         # Type 1: Direct direction question (easy/medium)
+         questions.append(GRPOQuestion(
+             prompt=(
+                 f"In mouse {tissue} tissue during spaceflight, is the "
+                 f"{pw} pathway upregulated or downregulated based on "
+                 f"gene set enrichment analysis? "
+                 f"Provide your confidence level."
+             ),
+             ground_truth={
+                 "pathway": pathway,
+                 "direction": info["direction"],
+                 "n_supporting_missions": n_agree,
+                 "expected_confidence": "high" if n_agree >= 3 else "medium",
+             },
+             tissue=tissue,
+             db=db,
+             question_type="direction",
+             applicable_verifiers=["V1", "V4"],
+             difficulty="easy" if n_agree >= 3 else "medium",
+         ))
+
+         # Type 2: Mechanistic reasoning question (medium/hard)
+         direction_word = "activation" if info["direction"] == "UP" else "suppression"
+         questions.append(GRPOQuestion(
+             prompt=(
+                 f"Explain the biological significance of {pw} pathway "
+                 f"{direction_word} in mouse {tissue} under spaceflight conditions. "
+                 f"What mechanisms might drive this change? "
+                 f"State your confidence in the direction and magnitude."
+             ),
+             ground_truth={
+                 "pathway": pathway,
+                 "direction": info["direction"],
+                 "n_supporting_missions": n_agree,
+                 "requires_mechanism": True,
+                 "expected_confidence": "medium",
+             },
+             tissue=tissue,
+             db=db,
+             question_type="direction",
+             applicable_verifiers=["V1", "V2", "V4"],
+             difficulty="medium" if n_agree >= 3 else "hard",
+         ))
+
+     return questions
+
+
+ def generate_comparison_questions(
+     db: str = "hallmark",
+     padj_threshold: float = 0.05,
+ ) -> List[GRPOQuestion]:
+     """Generate cross-tissue comparison questions (V1 + V3 targetable)."""
+     questions: List[GRPOQuestion] = []
+
+     # Collect consensus directions across tissues (exclude skin subsites for cleaner Qs)
+     tissue_dirs: Dict[str, Dict[str, Dict]] = {}
+     comparison_tissues = ["liver", "gastrocnemius", "kidney", "thymus", "eye"]
+     for tissue in comparison_tissues:
+         consensus = get_consensus_directions(tissue, db, min_missions=2, padj_threshold=padj_threshold)
+         if consensus:
+             tissue_dirs[tissue] = consensus
+
+     if len(tissue_dirs) < 2:
+         return questions
+
+     # Find pathways in 2+ tissues
+     all_pathways: Set[str] = set()
+     for dirs in tissue_dirs.values():
+         all_pathways.update(dirs.keys())
+
+     for pathway in sorted(all_pathways):
+         tissues_with = {
+             t: d[pathway]
+             for t, d in tissue_dirs.items()
+             if pathway in d
+         }
+         if len(tissues_with) < 2:
+             continue
+
+         pw = _clean_pathway_name(pathway)
+         tissue_list = sorted(tissues_with.keys())
+         directions_set = {info["direction"] for info in tissues_with.values()}
+         is_consistent = len(directions_set) == 1
+
+         questions.append(GRPOQuestion(
+             prompt=(
+                 f"Compare the response of the {pw} pathway to spaceflight "
+                 f"across {', '.join(tissue_list)} tissues in mice. "
+                 f"Is the direction of change consistent or tissue-specific? "
+                 f"Explain the biological basis for any differences."
+             ),
+             ground_truth={
+                 "pathway": pathway,
+                 "tissue_directions": {
+                     t: info["direction"] for t, info in tissues_with.items()
+                 },
+                 "is_consistent": is_consistent,
+                 "n_tissues": len(tissues_with),
+             },
+             tissue="multi",
+             db=db,
+             question_type="comparison",
+             applicable_verifiers=["V1", "V3", "V4"],
+             difficulty="hard",
+         ))
+
+     return questions
+
+
+ def generate_uncertainty_questions(
+     tissue: str,
+     db: str = "hallmark",
+     padj_threshold: float = 0.05,
+ ) -> List[GRPOQuestion]:
+     """Generate questions where missions disagree → model should express uncertainty."""
+     disagreeing = get_disagreeing_pathways(tissue, db, padj_threshold)
+     questions: List[GRPOQuestion] = []
+
+     for pathway, info in disagreeing.items():
+         pw = _clean_pathway_name(pathway)
+         questions.append(GRPOQuestion(
+             prompt=(
+                 f"Is the {pw} pathway consistently activated or suppressed "
+                 f"in mouse {tissue} across different spaceflight missions? "
+                 f"How confident are you in the direction of change?"
+             ),
+             ground_truth={
+                 "pathway": pathway,
+                 "missions_up": info["missions_up"],
+                 "missions_down": info["missions_down"],
+                 "missions_ns": info["missions_ns"],
+                 "correct_behavior": "context_dependent",
+                 "expected_confidence": "low",
+             },
+             tissue=tissue,
+             db=db,
+             question_type="uncertainty",
+             applicable_verifiers=["V1", "V4"],
+             difficulty="hard",
+         ))
+
+     return questions
+
+
+ def generate_conservation_questions(
+     db: str = "hallmark",
+ ) -> List[GRPOQuestion]:
+     """Generate questions about NES conservation across missions."""
+     conservation = load_nes_conservation(db)
+     if not conservation:
+         return []
+
+     questions: List[GRPOQuestion] = []
+     data = conservation.get("data", conservation)
+
+     for tissue, info in data.items():
+         if not isinstance(info, dict):
+             continue
+         mean_r = info.get("nes_mean_r")
+         if mean_r is None:
+             continue
+
+         if mean_r > 0.5:
+             conservation_level = "highly conserved"
+             expected_conf = "high"
+         elif mean_r > 0.2:
+             conservation_level = "moderately conserved"
+             expected_conf = "medium"
+         else:
+             conservation_level = "poorly conserved"
+             expected_conf = "medium"
+
+         questions.append(GRPOQuestion(
+             prompt=(
+                 f"How conserved are pathway-level responses to spaceflight "
+                 f"across different missions in mouse {tissue}? "
+                 f"Are the enrichment patterns reproducible?"
+             ),
+             ground_truth={
+                 "tissue": tissue,
+                 "nes_mean_r": mean_r,
+                 "conservation_level": conservation_level,
+                 "expected_confidence": expected_conf,
+                 "key_facts": [
+                     f"Mean pairwise NES correlation across missions is {mean_r:.3f}",
+                     f"Pathway responses in {tissue} are {conservation_level}",
+                 ],
+             },
+             tissue=tissue,
+             db=db,
+             question_type="direction",
+             applicable_verifiers=["V2", "V4"],
+             difficulty="medium",
+         ))
+
+     return questions
+
+
+ def generate_all_questions(db: str = "hallmark") -> List[GRPOQuestion]:
+     """Generate the full question set from GeneLab data."""
+     all_q: List[GRPOQuestion] = []
+
+     for tissue in TISSUE_MISSIONS:
+         all_q.extend(generate_direction_questions(tissue, db))
+         all_q.extend(generate_uncertainty_questions(tissue, db))
+
+     all_q.extend(generate_comparison_questions(db))
+     all_q.extend(generate_conservation_questions(db))
+
+     return all_q
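The prompt-text normalization in `_clean_pathway_name` can be checked standalone (same logic copied out of the module so it runs without the package); note that `str.title()` lowercases acronyms, which matters for pathways like `REACTOME_DNA_REPAIR`:

```python
def clean_pathway_name(pathway: str) -> str:
    # Same logic as _clean_pathway_name: strip known database prefixes,
    # then replace underscores and title-case the remainder.
    for prefix in ("HALLMARK_", "KEGG_", "REACTOME_", "MITOCARTA_"):
        pathway = pathway.replace(prefix, "")
    return pathway.replace("_", " ").title()

print(clean_pathway_name("HALLMARK_OXIDATIVE_PHOSPHORYLATION"))  # Oxidative Phosphorylation
print(clean_pathway_name("REACTOME_DNA_REPAIR"))  # Dna Repair — .title() lowercases the acronym
```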
src/biorlhf/evaluation/__init__.py CHANGED
@@ -1,8 +1,14 @@
  """Evaluation modules for BioRLHF."""
 
- from biorlhf.evaluation.evaluate import evaluate_model, compute_metrics
-
  __all__ = [
      "evaluate_model",
      "compute_metrics",
  ]
+
+
+ def __getattr__(name):
+     """Lazy imports for torch-dependent modules."""
+     if name in ("evaluate_model", "compute_metrics"):
+         from biorlhf.evaluation.evaluate import evaluate_model, compute_metrics
+         return {"evaluate_model": evaluate_model, "compute_metrics": compute_metrics}[name]
+     raise AttributeError(f"module 'biorlhf.evaluation' has no attribute {name!r}")
src/biorlhf/evaluation/calibration.py ADDED
@@ -0,0 +1,184 @@
+ """
+ Calibration evaluation metrics for BioGRPO.
+
+ Implements Expected Calibration Error (ECE), Brier score, overconfidence
+ rate, and reliability diagram data generation.
+ """
+
+ from typing import Dict, List, Tuple
+ from dataclasses import dataclass
+
+
+ @dataclass
+ class CalibrationMetrics:
+     """Aggregated calibration metrics."""
+     ece: float  # Expected Calibration Error
+     mce: float  # Maximum Calibration Error
+     brier_score: float
+     overconfidence_rate: float  # P(wrong | confidence > threshold)
+     underconfidence_rate: float  # P(correct | confidence < threshold)
+     mean_confidence: float
+     mean_accuracy: float
+     n_samples: int
+     reliability_bins: List[Dict]  # For plotting reliability diagrams
+
+
+ def compute_ece(
+     confidences: List[float],
+     correctnesses: List[bool],
+     n_bins: int = 10,
+ ) -> Tuple[float, float, List[Dict]]:
+     """Compute Expected and Maximum Calibration Error.
+
+     Uses equal-width binning.
+
+     Args:
+         confidences: Model's stated confidence for each prediction (0-1).
+         correctnesses: Whether each prediction was correct.
+         n_bins: Number of calibration bins.
+
+     Returns:
+         (ECE, MCE, bin_data) where bin_data is list of dicts for plotting.
+     """
+     if not confidences:
+         return 0.0, 0.0, []
+
+     bin_width = 1.0 / n_bins
+     bins: List[Dict] = []
+
+     for i in range(n_bins):
+         bin_lower = i * bin_width
+         bin_upper = (i + 1) * bin_width
+
+         # Find samples in this bin
+         indices = [
+             j for j, c in enumerate(confidences)
+             if bin_lower <= c < bin_upper or (i == n_bins - 1 and c == 1.0)
+         ]
+
+         if not indices:
+             bins.append({
+                 "bin_lower": bin_lower,
+                 "bin_upper": bin_upper,
+                 "mean_confidence": (bin_lower + bin_upper) / 2,
+                 "mean_accuracy": 0.0,
+                 "count": 0,
+                 "calibration_error": 0.0,
+             })
+             continue
+
+         bin_confs = [confidences[j] for j in indices]
+         bin_accs = [float(correctnesses[j]) for j in indices]
+         mean_conf = sum(bin_confs) / len(bin_confs)
+         mean_acc = sum(bin_accs) / len(bin_accs)
+
+         bins.append({
+             "bin_lower": bin_lower,
+             "bin_upper": bin_upper,
+             "mean_confidence": mean_conf,
+             "mean_accuracy": mean_acc,
+             "count": len(indices),
+             "calibration_error": abs(mean_acc - mean_conf),
+         })
+
+     # ECE: weighted average of calibration errors
+     total_samples = len(confidences)
+     ece = sum(
+         b["count"] / total_samples * b["calibration_error"]
+         for b in bins if b["count"] > 0
+     )
+
+     # MCE: maximum calibration error across non-empty bins
+     non_empty_errors = [b["calibration_error"] for b in bins if b["count"] > 0]
+     mce = max(non_empty_errors) if non_empty_errors else 0.0
+
+     return ece, mce, bins
+
+
+ def compute_brier_score(
+     confidences: List[float],
+     correctnesses: List[bool],
+ ) -> float:
+     """Compute Brier score: mean squared error between confidence and outcome.
+
+     Lower is better. Range [0, 1].
+     """
+     if not confidences:
+         return 0.0
+     n = len(confidences)
+     return sum(
+         (c - float(o)) ** 2 for c, o in zip(confidences, correctnesses)
+     ) / n
+
+
+ def compute_overconfidence_rate(
+     confidences: List[float],
+     correctnesses: List[bool],
+     threshold: float = 0.8,
+ ) -> float:
+     """P(wrong | confidence > threshold).
+
+     A high overconfidence rate indicates the model is unreliably confident.
+     """
+     high_conf = [
+         (c, o) for c, o in zip(confidences, correctnesses) if c > threshold
+     ]
+     if not high_conf:
+         return 0.0
+     wrong = sum(1 for _, o in high_conf if not o)
+     return wrong / len(high_conf)
+
+
+ def compute_underconfidence_rate(
+     confidences: List[float],
+     correctnesses: List[bool],
+     threshold: float = 0.3,
+ ) -> float:
+     """P(correct | confidence < threshold).
+
+     A high underconfidence rate means the model knows more than it admits.
+     """
+     low_conf = [
+         (c, o) for c, o in zip(confidences, correctnesses) if c < threshold
+     ]
+     if not low_conf:
+         return 0.0
+     correct = sum(1 for _, o in low_conf if o)
+     return correct / len(low_conf)
+
+
+ def compute_calibration_metrics(
+     confidences: List[float],
+     correctnesses: List[bool],
+     n_bins: int = 10,
+     overconf_threshold: float = 0.8,
+     underconf_threshold: float = 0.3,
+ ) -> CalibrationMetrics:
+     """Compute full calibration metrics suite."""
+     if not confidences:
+         return CalibrationMetrics(
+             ece=0.0, mce=0.0, brier_score=0.0,
+             overconfidence_rate=0.0, underconfidence_rate=0.0,
+             mean_confidence=0.0, mean_accuracy=0.0,
+             n_samples=0, reliability_bins=[],
+         )
+
+     ece, mce, bins = compute_ece(confidences, correctnesses, n_bins)
+     brier = compute_brier_score(confidences, correctnesses)
+     overconf = compute_overconfidence_rate(confidences, correctnesses, overconf_threshold)
+     underconf = compute_underconfidence_rate(confidences, correctnesses, underconf_threshold)
+
+     mean_conf = sum(confidences) / len(confidences)
+     mean_acc = sum(float(c) for c in correctnesses) / len(correctnesses)
+
+     return CalibrationMetrics(
+         ece=ece,
+         mce=mce,
+         brier_score=brier,
+         overconfidence_rate=overconf,
+         underconfidence_rate=underconf,
+         mean_confidence=mean_conf,
+         mean_accuracy=mean_acc,
+         n_samples=len(confidences),
+         reliability_bins=bins,
+     )
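A quick worked example of what these metrics measure, with the two core formulas inlined so it runs without the package (the toy confidence values below are made up for illustration):

```python
# Five toy predictions: four stated at 0.9 confidence, one at 0.1.
confidences = [0.9, 0.9, 0.9, 0.9, 0.1]
correctnesses = [True, True, True, False, False]

# Brier score: mean squared gap between stated confidence and the 0/1 outcome,
# same formula as compute_brier_score.
brier = sum((c - float(o)) ** 2 for c, o in zip(confidences, correctnesses)) / len(confidences)

# With a single bin, ECE degenerates to |mean accuracy - mean confidence|,
# the quantity each bin contributes (weighted by count) in compute_ece.
mean_conf = sum(confidences) / len(confidences)
mean_acc = sum(map(float, correctnesses)) / len(correctnesses)
gap = abs(mean_acc - mean_conf)

print(round(brier, 3), round(gap, 3))  # 0.17 0.14 — the model is somewhat overconfident
```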
src/biorlhf/training/__init__.py CHANGED
@@ -1,11 +1,24 @@
  """Training modules for BioRLHF."""
 
- from biorlhf.training.sft import SFTTrainingConfig, run_sft_training
- from biorlhf.training.dpo import DPOTrainingConfig, run_dpo_training
-
  __all__ = [
      "SFTTrainingConfig",
      "run_sft_training",
      "DPOTrainingConfig",
      "run_dpo_training",
+     "BioGRPOConfig",
+     "run_grpo_training",
  ]
+
+
+ def __getattr__(name):
+     """Lazy imports for torch-dependent modules."""
+     if name in ("SFTTrainingConfig", "run_sft_training"):
+         from biorlhf.training.sft import SFTTrainingConfig, run_sft_training
+         return {"SFTTrainingConfig": SFTTrainingConfig, "run_sft_training": run_sft_training}[name]
+     elif name in ("DPOTrainingConfig", "run_dpo_training"):
+         from biorlhf.training.dpo import DPOTrainingConfig, run_dpo_training
+         return {"DPOTrainingConfig": DPOTrainingConfig, "run_dpo_training": run_dpo_training}[name]
+     elif name in ("BioGRPOConfig", "run_grpo_training"):
+         from biorlhf.training.grpo import BioGRPOConfig, run_grpo_training
+         return {"BioGRPOConfig": BioGRPOConfig, "run_grpo_training": run_grpo_training}[name]
+     raise AttributeError(f"module 'biorlhf.training' has no attribute {name!r}")
src/biorlhf/training/grpo.py ADDED
@@ -0,0 +1,284 @@
+ """
+ Group Relative Policy Optimization (GRPO) training for BioGRPO.
+
+ Uses TRL's GRPOTrainer with composable biological verifiers as reward functions.
+ Supports configurable G values, verifier weights, and LoRA parameters.
+ """
+
+ import os
+ from dataclasses import dataclass, field
+ from typing import Optional, List, Dict
+
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+ from peft import LoraConfig, PeftModel
+ from trl import GRPOTrainer, GRPOConfig
+
+ from biorlhf.verifiers.composer import make_grpo_reward_function
+ from biorlhf.data.grpo_dataset import build_grpo_dataset, get_dataset_stats
+
+
+ @dataclass
+ class BioGRPOConfig:
+     """Configuration for BioGRPO training."""
+
+     # Model settings
+     model_name: str = "mistralai/Mistral-7B-v0.3"
+     sft_model_path: Optional[str] = None
+     output_dir: str = "./biogrpo_model"
+
+     # GRPO hyperparameters
+     num_generations: int = 8
+     beta: float = 0.04
+     num_iterations: int = 1
+     scale_rewards: str = "group"
+     loss_type: str = "grpo"
+
+     # Training hyperparameters
+     num_epochs: int = 1
+     batch_size: int = 2
+     gradient_accumulation_steps: int = 8
+     learning_rate: float = 1e-6
+     max_completion_length: int = 1024
+     max_prompt_length: int = 512
+     warmup_ratio: float = 0.1
+
+     # LoRA settings
+     lora_r: int = 32
+     lora_alpha: int = 64
+     lora_dropout: float = 0.05
+
+     # Verifier configuration
+     verifier_weights: Optional[Dict[str, float]] = None
+     active_verifiers: Optional[List[str]] = None
+
+     # Data
+     pathway_db: str = "hallmark"
+     hold_out_tissues: Optional[List[str]] = None
+     seed: int = 42
+
+     # Quantization
+     use_4bit: bool = True
+
+     # Logging
+     wandb_project: str = "biogrpo"
+     wandb_run_name: str = "grpo_v1"
+     use_wandb: bool = True
+     logging_steps: int = 5
+     save_steps: int = 50
+     eval_steps: int = 50
+     save_total_limit: int = 3
+     log_completions: bool = True
+
+     # Memory optimization
+     use_vllm: bool = False
+     gradient_checkpointing: bool = True
+     bf16: bool = True
+
+
+ def run_grpo_training(config: Optional[BioGRPOConfig] = None) -> str:
+     """Run BioGRPO training.
+
+     Pipeline:
+         1. Build dataset from GeneLab + BioEval + SpaceOmicsBench
+         2. Create composed reward function from verifier stack
+         3. Load tokenizer and configure GRPOTrainer with LoRA
+         4. Train and save model
+
+     Args:
+         config: Training configuration. Uses defaults if None.
+
+     Returns:
+         Path to the saved model directory.
+     """
+     if config is None:
+         config = BioGRPOConfig()
+
+     print("=" * 60)
+     print("BioGRPO Training")
+     print("=" * 60)
+     print(f"  Model: {config.model_name}")
+     print(f"  SFT checkpoint: {config.sft_model_path or 'None (from base)'}")
+     print(f"  G (generations): {config.num_generations}")
+     print(f"  Beta (KL): {config.beta}")
+     print(f"  Loss type: {config.loss_type}")
+     print(f"  Active verifiers: {config.active_verifiers or 'all (V1-V4)'}")
+     print(f"  Verifier weights: {config.verifier_weights or 'default'}")
+     print(f"  LoRA r/alpha: {config.lora_r}/{config.lora_alpha}")
+     print(f"  Learning rate: {config.learning_rate}")
+     print(f"  QLoRA 4-bit: {config.use_4bit}")
+     print(f"  Output: {config.output_dir}")
+     print("=" * 60)
+
+     # Initialize wandb
+     if config.use_wandb:
+         try:
+             import wandb
+             wandb.init(
+                 project=config.wandb_project,
+                 name=config.wandb_run_name,
+                 config={k: v for k, v in vars(config).items() if not k.startswith("_")},
+             )
+         except ImportError:
+             print("Warning: wandb not installed, disabling logging")
+             config.use_wandb = False
+
+     # 1. Build dataset
+     print("\n[1/6] Building GRPO dataset...")
+     train_dataset, eval_dataset = build_grpo_dataset(
+         db=config.pathway_db,
+         seed=config.seed,
+         hold_out_tissues=config.hold_out_tissues,
+     )
+     train_stats = get_dataset_stats(train_dataset)
+     eval_stats = get_dataset_stats(eval_dataset)
+     print(f"  Train: {train_stats['total']} samples")
+     print(f"    By source: {train_stats['by_source']}")
+     print(f"    By type: {train_stats['by_question_type']}")
+     print(f"  Eval: {eval_stats['total']} samples")
+
+     # 2. Create reward function
+     print("\n[2/6] Initializing verifier stack...")
+     reward_func = make_grpo_reward_function(
+         weights=config.verifier_weights,
+         active_verifiers=config.active_verifiers,
+     )
+     print(f"  Active: {config.active_verifiers or ['V1', 'V2', 'V3', 'V4']}")
+
+     # 3. Load tokenizer (always from base model; adapter dirs lack config.json)
+     print("\n[3/6] Loading tokenizer...")
+     tokenizer_source = config.model_name
+     tokenizer = AutoTokenizer.from_pretrained(tokenizer_source, trust_remote_code=True)
+     tokenizer.pad_token = tokenizer.eos_token
+     tokenizer.padding_side = "left"
+     print(f"  Tokenizer: {tokenizer.__class__.__name__}, vocab={tokenizer.vocab_size}")
+
+     # 4. Configure LoRA
+     peft_config = LoraConfig(
+         r=config.lora_r,
+         lora_alpha=config.lora_alpha,
+         target_modules=[
+             "q_proj", "k_proj", "v_proj", "o_proj",
+             "gate_proj", "up_proj", "down_proj",
+         ],
+         lora_dropout=config.lora_dropout,
+         bias="none",
+         task_type="CAUSAL_LM",
+     )
+
+     # 5. Load model (merge SFT adapter if present)
+     print("\n[4/6] Loading model...")
+
+     # QLoRA quantization config
+     bnb_config = None
+     if config.use_4bit:
+         bnb_config = BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_quant_type="nf4",
+             bnb_4bit_compute_dtype=torch.bfloat16,
+             bnb_4bit_use_double_quant=True,
+         )
+
+     # Check if sft_model_path is a LoRA adapter or a full model
+     sft_is_adapter = (
+         config.sft_model_path
+         and os.path.isdir(config.sft_model_path)
+         and os.path.exists(os.path.join(config.sft_model_path, "adapter_config.json"))
+     )
+
+     if sft_is_adapter:
+         # Load base model, merge SFT adapter, then apply fresh LoRA for GRPO
+         print(f"  Loading base model: {config.model_name}")
+         base_model = AutoModelForCausalLM.from_pretrained(
+             config.model_name,
+             quantization_config=bnb_config,
+             torch_dtype=torch.bfloat16,
+             trust_remote_code=True,
+         )
+         print(f"  Loading SFT LoRA adapter: {config.sft_model_path}")
+         model = PeftModel.from_pretrained(base_model, config.sft_model_path)
+         print("  Merging SFT adapter into base model...")
+         model = model.merge_and_unload()
+         print("  SFT adapter merged successfully")
+     else:
+         # sft_model_path is a full model; otherwise fall back to the base model
+         model_path = config.sft_model_path or config.model_name
+         print(f"  Loading model: {model_path}")
+         model = AutoModelForCausalLM.from_pretrained(
+             model_path,
+             quantization_config=bnb_config,
+             torch_dtype=torch.bfloat16,
+             trust_remote_code=True,
+         )
+
+     # 6. Configure GRPOTrainer
+     print("\n[5/6] Configuring GRPOTrainer...")
+
+     grpo_config = GRPOConfig(
+         output_dir=config.output_dir,
+         num_train_epochs=config.num_epochs,
+         per_device_train_batch_size=config.batch_size,
+         gradient_accumulation_steps=config.gradient_accumulation_steps,
+         learning_rate=config.learning_rate,
+         warmup_ratio=config.warmup_ratio,
+         lr_scheduler_type="cosine",
+
+         # GRPO-specific
+         num_generations=config.num_generations,
+         beta=config.beta,
+         loss_type=config.loss_type,
+         max_completion_length=config.max_completion_length,
+         max_prompt_length=config.max_prompt_length,
+         num_iterations=config.num_iterations,
+         scale_rewards=config.scale_rewards,
+
+         # Memory/compute
+         gradient_checkpointing=config.gradient_checkpointing,
+         bf16=config.bf16,
+         use_vllm=config.use_vllm,
+
+         # Logging
+         logging_steps=config.logging_steps,
+         save_steps=config.save_steps,
+         save_total_limit=config.save_total_limit,
+         report_to="wandb" if config.use_wandb else "none",
+         run_name=config.wandb_run_name,
+
+         # Evaluation
+         eval_strategy="steps",
+         eval_steps=config.eval_steps,
+         log_completions=config.log_completions,
+     )
+
+     trainer = GRPOTrainer(
+         model=model,
+         args=grpo_config,
+         reward_funcs=reward_func,
+         train_dataset=train_dataset,
+         eval_dataset=eval_dataset,
+         peft_config=peft_config,
+         processing_class=tokenizer,
+     )
+
+     # Train
+     print("\n[6/6] Starting GRPO training...")
+     print("=" * 60)
+
+     trainer.train()
+
+     # Save
+     print(f"\nSaving model to {config.output_dir}")
+     trainer.save_model(config.output_dir)
+
+     if config.use_wandb:
+         try:
+             import wandb
+             wandb.finish()
+         except ImportError:
+             pass
+
+     print("\n" + "=" * 60)
+     print("BioGRPO Training complete!")
+     print("=" * 60)
+
+     return config.output_dir
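The commit message mentions a `biorlhf-grpo` CLI that accepts a JSON config such as `configs/grpo_full.json`. That loader is not shown in this diff, but one plausible way to map a JSON file onto the dataclass is to filter the parsed dict to declared field names before constructing it (sketch only; the reduced `BioGRPOConfig` below stands in for the real class, which has many more fields):

```python
import json
from dataclasses import dataclass, fields

# Reduced stand-in for BioGRPOConfig, for illustration.
@dataclass
class BioGRPOConfig:
    model_name: str = "mistralai/Mistral-7B-v0.3"
    num_generations: int = 8
    beta: float = 0.04
    learning_rate: float = 1e-6

raw = json.loads('{"num_generations": 8, "beta": 0.04, "learning_rate": 5e-7, "_comment": "ignored"}')

# Keep only keys the dataclass declares, so unknown JSON keys don't crash the constructor.
known = {f.name for f in fields(BioGRPOConfig)}
config = BioGRPOConfig(**{k: v for k, v in raw.items() if k in known})
print(config.learning_rate)  # 5e-07
```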
src/biorlhf/verifiers/__init__.py ADDED
@@ -0,0 +1,24 @@
+ """Composable biological verifiers for BioGRPO."""
+
+ from biorlhf.verifiers.base import BaseVerifier, VerifierResult
+ from biorlhf.verifiers.pathway import PathwayDirectionVerifier
+ from biorlhf.verifiers.factual import BiologicalFactVerifier
+ from biorlhf.verifiers.consistency import CrossContextConsistencyVerifier
+ from biorlhf.verifiers.uncertainty import UncertaintyVerifier
+ from biorlhf.verifiers.composer import (
+     VerifierComposer,
+     make_grpo_reward_function,
+     make_single_verifier_reward,
+ )
+
+ __all__ = [
+     "BaseVerifier",
+     "VerifierResult",
+     "PathwayDirectionVerifier",
+     "BiologicalFactVerifier",
+     "CrossContextConsistencyVerifier",
+     "UncertaintyVerifier",
+     "VerifierComposer",
+     "make_grpo_reward_function",
+     "make_single_verifier_reward",
+ ]
src/biorlhf/verifiers/base.py ADDED
@@ -0,0 +1,49 @@
+ """Abstract base class for biological verifiers."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import Dict, List
5
+ from dataclasses import dataclass, field
6
+
7
+
8
+ @dataclass
9
+ class VerifierResult:
10
+ """Result from a single verifier."""
11
+ score: float # 0.0 to 1.0
12
+ verifier_name: str
13
+ details: Dict = field(default_factory=dict)
14
+ applicable: bool = True # False if verifier doesn't apply
15
+
16
+
17
+ class BaseVerifier(ABC):
18
+ """Abstract base class for biological verifiers.
19
+
20
+ Each verifier scores a model completion against structured ground truth
21
+ on a specific dimension (pathway direction, factual accuracy, etc.).
22
+ """
23
+
24
+ name: str = "base"
25
+
26
+ @abstractmethod
27
+ def score(
28
+ self,
29
+ prompt: str,
30
+ completion: str,
31
+ ground_truth: Dict,
32
+ question_type: str,
33
+ ) -> VerifierResult:
34
+ """Score a single completion against ground truth.
35
+
36
+ Args:
37
+ prompt: The original question.
38
+ completion: The model's generated response.
39
+ ground_truth: Parsed ground truth dictionary.
40
+ question_type: Type of question for routing logic.
41
+
42
+ Returns:
43
+ VerifierResult with score in [0, 1].
44
+ """
45
+ raise NotImplementedError
46
+
47
+ def is_applicable(self, applicable_verifiers: List[str]) -> bool:
48
+ """Check if this verifier should score this question."""
49
+ return self.name in applicable_verifiers
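
The contract above is easiest to see with a toy subclass. The sketch below re-declares a minimal `VerifierResult` and `BaseVerifier` so it runs outside the `biorlhf` package; `KeywordVerifier` is a hypothetical example, not part of the repo:

```python
# Minimal standalone sketch of the BaseVerifier contract.
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Dict


@dataclass
class VerifierResult:
    score: float  # 0.0 to 1.0
    verifier_name: str
    details: Dict = field(default_factory=dict)
    applicable: bool = True


class BaseVerifier(ABC):
    name: str = "base"

    @abstractmethod
    def score(self, prompt: str, completion: str,
              ground_truth: Dict, question_type: str) -> VerifierResult:
        raise NotImplementedError


class KeywordVerifier(BaseVerifier):
    """Toy verifier: fraction of ground-truth keywords present in the completion."""
    name = "toy"

    def score(self, prompt, completion, ground_truth, question_type):
        kws = ground_truth.get("keywords", [])
        if not kws:
            # Mirror the repo convention: neutral score, applicable=False
            return VerifierResult(0.5, self.name, {"reason": "no_keywords"}, False)
        hits = sum(1 for k in kws if k.lower() in completion.lower())
        return VerifierResult(hits / len(kws), self.name, {"hits": hits})


r = KeywordVerifier().score(
    "q", "Bone loss and osteoclast activity",
    {"keywords": ["osteoclast", "bone", "RANKL"]}, "recall",
)
print(round(r.score, 2))  # → 0.67
```

The `applicable=False` escape hatch is what lets the composer skip a verifier for questions it cannot judge, rather than injecting a misleading 0.0.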
src/biorlhf/verifiers/composer.py ADDED
@@ -0,0 +1,227 @@
+ """
+ Verifier Composer: Weighted composition of V1-V4 into a TRL-compatible reward function.
+
+ This is THE critical integration point between the verifier stack and
+ TRL's GRPOTrainer. The reward function signature must match TRL's expected
+ interface exactly.
+ """
+
+ import json
+ from typing import Callable, Dict, List, Optional, Union
+ from dataclasses import dataclass
+
+ from biorlhf.verifiers.base import BaseVerifier, VerifierResult
+ from biorlhf.verifiers.pathway import PathwayDirectionVerifier
+ from biorlhf.verifiers.factual import BiologicalFactVerifier
+ from biorlhf.verifiers.consistency import CrossContextConsistencyVerifier
+ from biorlhf.verifiers.uncertainty import UncertaintyVerifier
+
+
+ @dataclass
+ class ComposedReward:
+     """Result of composed reward computation."""
+     total_reward: float
+     verifier_scores: Dict[str, float]
+     verifier_details: Dict[str, Dict]
+     weights_used: Dict[str, float]
+
+
+ # Default weights: factual signals dominate
+ DEFAULT_WEIGHTS = {
+     "V1": 0.35,  # Pathway direction (hard signal)
+     "V2": 0.30,  # Biological facts (soft signal)
+     "V3": 0.15,  # Cross-context consistency
+     "V4": 0.20,  # Uncertainty appropriateness
+ }
+
+
+ class VerifierComposer:
+     """Composes V1-V4 verifiers into a unified reward signal."""
+
+     def __init__(
+         self,
+         weights: Optional[Dict[str, float]] = None,
+         active_verifiers: Optional[List[str]] = None,
+     ):
+         all_verifiers: Dict[str, BaseVerifier] = {
+             "V1": PathwayDirectionVerifier(),
+             "V2": BiologicalFactVerifier(),
+             "V3": CrossContextConsistencyVerifier(),
+             "V4": UncertaintyVerifier(),
+         }
+
+         self.weights = dict(weights or DEFAULT_WEIGHTS)
+
+         # Filter to active verifiers if specified
+         if active_verifiers:
+             self.verifiers = {
+                 k: v for k, v in all_verifiers.items() if k in active_verifiers
+             }
+             # Renormalize weights
+             total_w = sum(self.weights.get(k, 0) for k in self.verifiers)
+             if total_w > 0:
+                 self.weights = {
+                     k: self.weights.get(k, 0) / total_w for k in self.verifiers
+                 }
+         else:
+             self.verifiers = all_verifiers
+
+     def compute_reward(
+         self,
+         prompt: str,
+         completion: str,
+         ground_truth: Union[str, Dict],
+         question_type: str,
+         applicable_verifiers: Union[str, List[str]],
+     ) -> ComposedReward:
+         """Compute the composed reward from all applicable verifiers.
+
+         Args:
+             prompt: The question text.
+             completion: Model's generated response.
+             ground_truth: Ground truth as a JSON string or parsed dict.
+             question_type: Question type for routing.
+             applicable_verifiers: Verifier names as a JSON string or list.
+         """
+         gt = json.loads(ground_truth) if isinstance(ground_truth, str) else ground_truth
+         applicable = (
+             json.loads(applicable_verifiers)
+             if isinstance(applicable_verifiers, str)
+             else applicable_verifiers
+         )
+
+         scores: Dict[str, float] = {}
+         details: Dict[str, Dict] = {}
+         weights_used: Dict[str, float] = {}
+
+         for vname, verifier in self.verifiers.items():
+             if not verifier.is_applicable(applicable):
+                 continue
+
+             result = verifier.score(prompt, completion, gt, question_type)
+
+             if not result.applicable:
+                 continue
+
+             scores[vname] = result.score
+             details[vname] = result.details
+             weights_used[vname] = self.weights.get(vname, 0)
+
+         # Compute weighted sum with renormalization
+         if not weights_used:
+             return ComposedReward(
+                 total_reward=0.0,
+                 verifier_scores=scores,
+                 verifier_details=details,
+                 weights_used=weights_used,
+             )
+
+         w_total = sum(weights_used.values())
+         if w_total > 0:
+             normalized = {k: v / w_total for k, v in weights_used.items()}
+         else:
+             normalized = weights_used
+
+         total = sum(scores[k] * normalized.get(k, 0) for k in scores)
+
+         return ComposedReward(
+             total_reward=total,
+             verifier_scores=scores,
+             verifier_details=details,
+             weights_used=normalized,
+         )
+
+
+ def make_grpo_reward_function(
+     weights: Optional[Dict[str, float]] = None,
+     active_verifiers: Optional[List[str]] = None,
+ ) -> Callable:
+     """Create a TRL-compatible reward function from the verifier composer.
+
+     TRL's GRPOTrainer calls reward functions with the signature:
+         reward_func(completions, **kwargs) -> list[float]
+     where kwargs include all dataset columns except "prompt".
+     Completions arrive either as a list of strings or as a list of
+     chat-message lists.
+
+     Note: TRL passes prompts separately. Dataset columns (ground_truth,
+     question_type, applicable_verifiers, etc.) are forwarded as kwargs.
+     """
+     composer = VerifierComposer(weights=weights, active_verifiers=active_verifiers)
+
+     def reward_func(
+         completions: List,
+         ground_truth: Optional[List[str]] = None,
+         question_type: Optional[List[str]] = None,
+         applicable_verifiers: Optional[List[str]] = None,
+         **kwargs,
+     ) -> List[float]:
+         """TRL-compatible reward function using composed biological verifiers.
+
+         Args:
+             completions: List of model completions (strings or chat messages).
+             ground_truth: List of JSON ground truth strings (from dataset).
+             question_type: List of question type strings (from dataset).
+             applicable_verifiers: List of JSON lists of verifier names.
+
+         Returns:
+             List of float rewards, one per completion.
+         """
+         rewards: List[float] = []
+         n = len(completions)
+
+         # Handle missing kwargs gracefully
+         if ground_truth is None:
+             ground_truth = ["{}"] * n
+         if question_type is None:
+             question_type = ["unknown"] * n
+         if applicable_verifiers is None:
+             applicable_verifiers = [json.dumps(["V1", "V2", "V3", "V4"])] * n
+
+         # Extract prompts if available in kwargs
+         prompts = kwargs.get("prompts", kwargs.get("prompt", [""] * n))
+         if isinstance(prompts, str):
+             prompts = [prompts] * n
+
+         for i in range(n):
+             # Extract completion text
+             completion_text = _extract_text(completions[i])
+             prompt_text = _extract_text(prompts[i]) if i < len(prompts) else ""
+
+             result = composer.compute_reward(
+                 prompt=prompt_text,
+                 completion=completion_text,
+                 ground_truth=ground_truth[i],
+                 question_type=question_type[i],
+                 applicable_verifiers=applicable_verifiers[i],
+             )
+             rewards.append(result.total_reward)
+
+         return rewards
+
+     return reward_func
+
+
+ def make_single_verifier_reward(verifier_name: str) -> Callable:
+     """Create a reward function using only one verifier (for ablation)."""
+     return make_grpo_reward_function(active_verifiers=[verifier_name])
+
+
+ def _extract_text(item) -> str:
+     """Extract plain text from various completion formats.
+
+     TRL may pass completions as:
+     - str: plain text
+     - list[dict]: chat messages [{"role": "assistant", "content": "..."}]
+     """
+     if isinstance(item, str):
+         return item
+     elif isinstance(item, list):
+         # Chat format
+         texts = []
+         for msg in item:
+             if isinstance(msg, dict) and "content" in msg:
+                 texts.append(msg["content"])
+         return " ".join(texts)
+     else:
+         return str(item)
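
The renormalization step in `compute_reward` is the part worth checking by hand: when some verifiers return `applicable=False`, the remaining weights are rescaled to sum to 1 so the reward stays in [0, 1]. A self-contained sketch of just that arithmetic (the function name `compose` is illustrative, not from the repo):

```python
# Standalone sketch of the composer's weighted-sum-with-renormalization step,
# for the case where only V1 and V4 produced applicable scores.
def compose(scores, weights):
    used = {k: weights[k] for k in scores}  # weights of applicable verifiers only
    w_total = sum(used.values())
    norm = {k: v / w_total for k, v in used.items()} if w_total else used
    return sum(scores[k] * norm[k] for k in scores)


DEFAULT_WEIGHTS = {"V1": 0.35, "V2": 0.30, "V3": 0.15, "V4": 0.20}

# V2/V3 inapplicable: V1's 0.35 and V4's 0.20 are rescaled over 0.55.
reward = compose({"V1": 1.0, "V4": 0.5}, DEFAULT_WEIGHTS)
print(round(reward, 4))  # → 0.8182
```

Without this rescaling, a question that only V1 and V4 can judge would be capped at 0.55 reward, which would systematically penalize those question types during GRPO.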
src/biorlhf/verifiers/consistency.py ADDED
@@ -0,0 +1,221 @@
+ """
+ V3: Cross-Context Consistency Verifier.
+
+ Scores whether the model appropriately distinguishes or generalizes
+ across biological contexts (tissues, species, doses, timepoints).
+
+ For comparison questions: checks tissue coverage + consistency assessment.
+ For context-dependent questions: checks nuance and hedging.
+ For BioAmbiguity tasks: checks context awareness using BioEval scoring logic.
+ """
+
+ import json
+ import re
+ from typing import Dict, List
+
+ from biorlhf.verifiers.base import BaseVerifier, VerifierResult
+
+
+ # ── Indicator patterns ─────────────────────────────────────────────────────
+ CONSISTENCY_TERMS = [
+     "consistent", "conserved", "similar across", "same direction",
+     "reproducible", "concordant", "shared", "common response",
+     "universal", "uniform",
+ ]
+
+ SPECIFICITY_TERMS = [
+     "tissue-specific", "differs", "different", "opposite", "varies",
+     "divergent", "heterogeneous", "discordant", "unique to",
+     "distinct", "context-dependent", "tissue-dependent",
+ ]
+
+ NUANCE_INDICATORS = [
+     "depends", "context", "varies", "mission-specific",
+     "not consistent", "differs", "some missions", "mixed",
+     "heterogeneous", "variable", "inconsistent",
+ ]
+
+ HEDGING_INDICATORS = [
+     "uncertain", "unclear", "difficult to generalize",
+     "not enough evidence", "conflicting", "limited data",
+     "preliminary", "tentative", "cannot be determined",
+     "more research", "caution",
+ ]
+
+
+ class CrossContextConsistencyVerifier(BaseVerifier):
+     """V3: Scores context-appropriate reasoning."""
+
+     name = "V3"
+
+     def score(
+         self,
+         prompt: str,
+         completion: str,
+         ground_truth: Dict,
+         question_type: str,
+     ) -> VerifierResult:
+         gt = ground_truth if isinstance(ground_truth, dict) else json.loads(ground_truth)
+
+         if question_type == "comparison":
+             return self._score_comparison(completion, gt)
+         elif question_type in ("context_dependent", "uncertainty"):
+             return self._score_context_dependent(completion, gt)
+         elif "contexts" in gt:
+             # BioEval BioAmbiguity format
+             return self._score_bioambiguity(completion, gt)
+         elif "tissue_directions" in gt:
+             return self._score_comparison(completion, gt)
+         else:
+             return VerifierResult(
+                 score=0.5,
+                 verifier_name=self.name,
+                 details={"reason": "not_applicable"},
+                 applicable=False,
+             )
+
+     def _score_comparison(self, completion: str, gt: Dict) -> VerifierResult:
+         """Score cross-tissue comparison questions."""
+         tissue_directions = gt.get("tissue_directions", {})
+         is_consistent = gt.get("is_consistent", False)
+         comp_lower = completion.lower()
+
+         # Check tissue coverage
+         tissues_mentioned = sum(
+             1 for tissue in tissue_directions if tissue.lower() in comp_lower
+         )
+         n_tissues = len(tissue_directions) if tissue_directions else 1
+         tissue_coverage = tissues_mentioned / n_tissues
+
+         # Check consistency/specificity assessment
+         claims_consistent = any(t in comp_lower for t in CONSISTENCY_TERMS)
+         claims_specific = any(t in comp_lower for t in SPECIFICITY_TERMS)
+
+         if is_consistent:
+             consistency_correct = claims_consistent
+         else:
+             consistency_correct = claims_specific
+
+         score = 0.5 * tissue_coverage + 0.5 * (1.0 if consistency_correct else 0.0)
+
+         return VerifierResult(
+             score=score,
+             verifier_name=self.name,
+             details={
+                 "tissues_mentioned": tissues_mentioned,
+                 "total_tissues": n_tissues,
+                 "is_consistent_gt": is_consistent,
+                 "claims_consistent": claims_consistent,
+                 "claims_specific": claims_specific,
+                 "consistency_correct": consistency_correct,
+             },
+         )
+
+     def _score_context_dependent(self, completion: str, gt: Dict) -> VerifierResult:
+         """Score questions where the answer should acknowledge context-dependence."""
+         comp_lower = completion.lower()
+
+         nuance_hits = sum(1 for t in NUANCE_INDICATORS if t in comp_lower)
+         hedging_hits = sum(1 for t in HEDGING_INDICATORS if t in comp_lower)
+
+         # Scale: having 2-3 indicators is ideal
+         nuance_score = min(nuance_hits / 2.0, 1.0)
+         hedging_score = min(hedging_hits / 2.0, 1.0)
+
+         score = 0.6 * nuance_score + 0.4 * hedging_score
+
+         return VerifierResult(
+             score=min(score, 1.0),
+             verifier_name=self.name,
+             details={
+                 "nuance_hits": nuance_hits,
+                 "hedging_hits": hedging_hits,
+                 "nuance_score": nuance_score,
+                 "hedging_score": hedging_score,
+             },
+         )
+
+     def _score_bioambiguity(self, completion: str, gt: Dict) -> VerifierResult:
+         """Score BioEval BioAmbiguity tasks.
+
+         GT format:
+             {"contexts": {context_name: {"key_terms": [...], "role": "..."}},
+              "distinction_key": "..."}
+         """
+         contexts = gt.get("contexts", {})
+         distinction_key = gt.get("distinction_key", "")
+         comp_lower = completion.lower()
+
+         if not contexts:
+             return VerifierResult(
+                 score=0.5, verifier_name=self.name,
+                 details={"reason": "no_contexts"}, applicable=False,
+             )
+
+         # Context awareness: fraction of key terms found across all contexts
+         total_terms = 0
+         found_terms = 0
+         context_scores = {}
+
+         for ctx_name, ctx_info in contexts.items():
+             key_terms = ctx_info.get("key_terms", [])
+             if not key_terms:
+                 continue
+             hits = sum(1 for t in key_terms if t.lower() in comp_lower)
+             total_terms += len(key_terms)
+             found_terms += hits
+             context_scores[ctx_name] = hits / len(key_terms)
+
+         context_awareness = found_terms / total_terms if total_terms > 0 else 0
+
+         # Distinction quality: does the response contain the distinction key words?
+         if distinction_key:
+             dist_terms = _extract_key_terms(distinction_key)
+             dist_hits = sum(1 for t in dist_terms if t.lower() in comp_lower)
+             distinction_quality = dist_hits / len(dist_terms) if dist_terms else 0
+         else:
+             distinction_quality = 0
+
+         # Evidence support: does the response mention the roles?
+         role_hits = 0
+         role_total = 0
+         for ctx_info in contexts.values():
+             role = ctx_info.get("role", "")
+             if role:
+                 role_total += 1
+                 role_terms = _extract_key_terms(role)
+                 if any(t.lower() in comp_lower for t in role_terms):
+                     role_hits += 1
+         evidence_support = role_hits / role_total if role_total > 0 else 0
+
+         # Composite: 40% context + 35% distinction + 25% evidence
+         score = (
+             0.40 * context_awareness
+             + 0.35 * distinction_quality
+             + 0.25 * evidence_support
+         )
+
+         return VerifierResult(
+             score=score,
+             verifier_name=self.name,
+             details={
+                 "context_awareness": context_awareness,
+                 "distinction_quality": distinction_quality,
+                 "evidence_support": evidence_support,
+                 "context_scores": context_scores,
+                 "terms_found": found_terms,
+                 "terms_total": total_terms,
+             },
+         )
+
+
+ def _extract_key_terms(text: str, min_length: int = 4) -> List[str]:
+     """Extract key terms from text for matching."""
+     stopwords = {
+         "the", "and", "for", "that", "this", "with", "from", "are",
+         "was", "were", "been", "have", "has", "had", "will", "would",
+         "could", "should", "may", "might", "can", "does", "between",
+     }
+     words = re.findall(r"\b[a-zA-Z0-9-]+\b", text)
+     return [w for w in words if len(w) >= min_length and w.lower() not in stopwords]
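
The context-dependent branch of V3 reduces to a small formula: indicator hit counts capped at 2, combined 60/40. A standalone sketch with shortened indicator lists (the full lists live in the module above; `context_score` is an illustrative name):

```python
# Standalone sketch of V3's context-dependent scoring rule:
# counts of nuance and hedging indicators, each capped at 2 hits,
# combined as 0.6 * nuance + 0.4 * hedging.
NUANCE = ["depends", "context", "varies", "mixed"]
HEDGING = ["uncertain", "unclear", "limited data", "conflicting"]


def context_score(completion: str) -> float:
    c = completion.lower()
    nuance = min(sum(t in c for t in NUANCE) / 2.0, 1.0)
    hedging = min(sum(t in c for t in HEDGING) / 2.0, 1.0)
    return min(0.6 * nuance + 0.4 * hedging, 1.0)


text = "The response varies by tissue and depends on context; evidence is unclear."
print(round(context_score(text), 2))  # → 0.8
```

Note this is purely lexical: an answer that hedges without naming the right contexts still scores, which is why V3 carries only 15% of the default composite weight.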
src/biorlhf/verifiers/factual.py ADDED
@@ -0,0 +1,143 @@
+ """
+ V2: Biological Fact Verifier.
+
+ Scores model responses based on overlap with known correct facts
+ from curated knowledge bases (SpaceOmicsBench, BioEval, GeneTuring).
+
+ Scoring: proportion of ground-truth key facts found in the response.
+ """
+
+ import re
+ import json
+ from typing import Dict, List
+
+ from biorlhf.verifiers.base import BaseVerifier, VerifierResult
+
+
+ def _extract_key_terms(text: str, min_length: int = 4, max_terms: int = 10) -> List[str]:
+     """Extract important terms from a text string."""
+     # Remove common stopwords and short words
+     stopwords = {
+         "the", "and", "for", "that", "this", "with", "from", "are", "was",
+         "were", "been", "have", "has", "had", "will", "would", "could",
+         "should", "may", "might", "can", "does", "did", "but", "not",
+         "its", "also", "into", "than", "then", "when", "which", "what",
+         "where", "who", "how", "all", "each", "every", "both", "more",
+         "most", "other", "some", "such", "only", "same", "very", "just",
+     }
+     words = re.findall(r"\b[a-zA-Z0-9-]+\b", text)
+     terms = [
+         w for w in words
+         if len(w) >= min_length and w.lower() not in stopwords
+     ]
+     return terms[:max_terms]
+
+
+ def _phrase_match(phrase: str, text: str) -> bool:
+     """Check if a phrase (or its key terms) appears in text."""
+     text_lower = text.lower()
+     phrase_lower = phrase.lower()
+
+     # Direct substring match
+     if phrase_lower in text_lower:
+         return True
+
+     # For multi-word phrases, check if key terms co-occur
+     terms = _extract_key_terms(phrase, min_length=4, max_terms=5)
+     if not terms:
+         return phrase_lower in text_lower
+
+     matches = sum(1 for t in terms if t.lower() in text_lower)
+     # Require at least half of the key terms to match
+     return matches >= max(1, len(terms) // 2)
+
+
+ class BiologicalFactVerifier(BaseVerifier):
+     """V2: Verifies biological factual claims against curated knowledge."""
+
+     name = "V2"
+
+     def score(
+         self,
+         prompt: str,
+         completion: str,
+         ground_truth: Dict,
+         question_type: str,
+     ) -> VerifierResult:
+         """Score based on overlap with ground-truth key facts.
+
+         Handles multiple GT formats:
+         - {"key_facts": ["fact1", "fact2", ...]}
+         - {"ground_truth_key_facts": [...]}
+         - {"expected_answer": "text"}
+         - {"expected_reasoning": [...]}
+         - {"correct_steps": [...]}
+         """
+         gt = ground_truth if isinstance(ground_truth, dict) else json.loads(ground_truth)
+
+         # Extract key facts from various GT formats
+         key_facts = self._extract_facts(gt)
+
+         if not key_facts:
+             return VerifierResult(
+                 score=0.5,
+                 verifier_name=self.name,
+                 details={"reason": "no_key_facts_in_gt"},
+                 applicable=False,
+             )
+
+         # Score: proportion of key facts found in completion
+         matched_facts: List[str] = []
+         for fact in key_facts:
+             if isinstance(fact, str) and _phrase_match(fact, completion):
+                 matched_facts.append(fact)
+
+         total = len(key_facts)
+         matched = len(matched_facts)
+         score = matched / total if total > 0 else 0.0
+
+         return VerifierResult(
+             score=score,
+             verifier_name=self.name,
+             details={
+                 "matched_facts": matched_facts,
+                 "total_facts": total,
+                 "matched_count": matched,
+                 "unmatched": [f for f in key_facts if f not in matched_facts],
+             },
+         )
+
+     def _extract_facts(self, gt: Dict) -> List[str]:
+         """Extract verifiable facts from the ground truth dictionary."""
+         facts: List[str] = []
+
+         # Direct key facts lists
+         for key in ("key_facts", "ground_truth_key_facts"):
+             if key in gt and isinstance(gt[key], list):
+                 facts.extend(str(f) for f in gt[key] if f)
+
+         # Expected reasoning points
+         if "expected_reasoning" in gt and isinstance(gt["expected_reasoning"], list):
+             facts.extend(str(f) for f in gt["expected_reasoning"] if f)
+
+         # Single expected answer
+         if "expected_answer" in gt and isinstance(gt["expected_answer"], str):
+             facts.append(gt["expected_answer"])
+
+         # Protocol steps (BioEval protoreason)
+         if "correct_steps" in gt and isinstance(gt["correct_steps"], list):
+             facts.extend(str(s) for s in gt["correct_steps"] if s)
+
+         # NES conservation facts
+         if "conservation_level" in gt:
+             facts.append(gt["conservation_level"])
+
+         # Deduplicate while preserving order
+         seen = set()
+         unique_facts = []
+         for f in facts:
+             if f not in seen:
+                 seen.add(f)
+                 unique_facts.append(f)
+
+         return unique_facts
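
The fuzzy matching rule in `_phrase_match` is what makes V2 tolerant of paraphrase: a fact counts if it appears verbatim, or if at least half of its key terms occur anywhere in the response. A self-contained sketch with a trimmed stopword set (the module's set is larger):

```python
# Standalone sketch of V2's fuzzy fact-matching rule: verbatim substring,
# or at least half of the fact's key terms (length >= 4, non-stopword,
# first 5 terms) present in the response.
import re

STOP = {"the", "and", "that", "with", "from"}


def phrase_match(phrase: str, text: str) -> bool:
    t, p = text.lower(), phrase.lower()
    if p in t:
        return True
    terms = [w for w in re.findall(r"\b[a-zA-Z0-9-]+\b", p)
             if len(w) >= 4 and w not in STOP][:5]
    if not terms:
        return False
    return sum(w in t for w in terms) >= max(1, len(terms) // 2)


resp = "Microgravity suppresses oxidative phosphorylation in muscle."
print(phrase_match("oxidative phosphorylation downregulated", resp))  # → True
```

Here 2 of the 3 key terms ("oxidative", "phosphorylation") occur, which clears the half-of-terms threshold even though "downregulated" is paraphrased as "suppresses".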
src/biorlhf/verifiers/pathway.py ADDED
@@ -0,0 +1,300 @@
+ """
+ V1: Pathway Direction Verifier.
+
+ Extracts directional claims about biological pathways from model responses
+ and compares them against fGSEA NES direction ground truth.
+
+ Scoring:
+     1.0 - correct direction claimed
+     0.5 - mixed/contradictory claims
+     0.3 - no directional claim extracted
+     0.0 - wrong direction claimed
+ """
+
+ import re
+ from typing import Dict, List, Tuple
+
+ from biorlhf.verifiers.base import BaseVerifier, VerifierResult
+
+ # ── Direction indicator patterns ──────────────────────────────────────────
+ # Note: all patterns are matched against lowercased sentences, so the
+ # NES patterns below are written in lowercase.
+ UP_INDICATORS = [
+     r"\bupregulat\w*\b",
+     r"\bactivat\w*\b",
+     r"\bincreas\w*\b",
+     r"\belevat\w*\b",
+     r"\benhanc\w*\b",
+     r"\binduced?\b",
+     r"\bhigher\b",
+     r"\boverexpress\w*\b",
+     r"\benrich\w*\b",
+     r"\bpositive\s+nes\b",
+     r"\bnes\s*[>=]\s*0\b",
+     r"\bupstream\s+activat\w*\b",
+     r"\bstimulat\w*\b",
+     r"\bpromot\w*\b",
+ ]
+
+ DOWN_INDICATORS = [
+     r"\bdownregulat\w*\b",
+     r"\bsuppress\w*\b",
+     r"\bdecreas\w*\b",
+     r"\breduced?\b",
+     r"\binhibit\w*\b",
+     r"\brepress\w*\b",
+     r"\blower\w*\b",
+     r"\bunderexpress\w*\b",
+     r"\bdepress\w*\b",
+     r"\bnegative\s+nes\b",
+     r"\bnes\s*<\s*0\b",
+     r"\bdiminish\w*\b",
+     r"\battenuat\w*\b",
+     r"\bimpair\w*\b",
+ ]
+
+ # Negation patterns that flip direction
+ NEGATION_PATTERNS = [
+     r"\bnot\s+",
+     r"\bno\s+",
+     r"\bneither\b",
+     r"\bwithout\s+",
+     r"\bfail\w*\s+to\b",
+     r"\bdoes\s+not\b",
+     r"\bdid\s+not\b",
+     r"\bisn'?t\b",
+     r"\bwasn'?t\b",
+     r"\baren'?t\b",
+ ]
+
+ # ── Pathway name abbreviations ────────────────────────────────────────────
+ PATHWAY_ABBREVIATIONS: Dict[str, List[str]] = {
+     "oxidative phosphorylation": ["oxphos", "oxidative phosphorylation", "ox phos"],
+     "tnfa signaling via nfkb": ["tnf-alpha", "nfkb", "nf-kb", "nf-κb", "tnfα"],
+     "mtorc1 signaling": ["mtor", "mtorc1"],
+     "pi3k akt mtor signaling": ["pi3k", "akt", "mtor", "pi3k/akt"],
+     "interferon gamma response": ["ifn-gamma", "ifn-γ", "interferon gamma", "ifnγ"],
+     "interferon alpha response": ["ifn-alpha", "ifn-α", "interferon alpha", "ifnα"],
+     "adipogenesis": ["adipogenesis", "adipogenic"],
+     "myogenesis": ["myogenesis", "myogenic"],
+     "epithelial mesenchymal transition": ["emt", "epithelial-mesenchymal"],
+     "unfolded protein response": ["upr", "unfolded protein"],
+     "reactive oxygen species pathway": ["ros", "reactive oxygen"],
+     "fatty acid metabolism": ["fatty acid", "fat metabolism", "lipid metabolism"],
+     "glycolysis": ["glycolysis", "glycolytic"],
+     "dna repair": ["dna repair", "dna damage response"],
+     "apoptosis": ["apoptosis", "apoptotic", "programmed cell death"],
+     "inflammatory response": ["inflammatory", "inflammation"],
+     "hypoxia": ["hypoxia", "hypoxic"],
+     "angiogenesis": ["angiogenesis", "angiogenic"],
+     "p53 pathway": ["p53", "tp53"],
+     "wnt beta catenin signaling": ["wnt", "beta-catenin", "β-catenin"],
+ }
+
+
+ def _generate_pathway_variants(pathway_name: str) -> List[str]:
+     """Generate matching variants for a pathway name.
+
+     E.g. HALLMARK_OXIDATIVE_PHOSPHORYLATION →
+         ["HALLMARK_OXIDATIVE_PHOSPHORYLATION",
+          "oxidative phosphorylation",
+          "oxidative phosphorylation pathway",
+          "oxphos"]
+     """
+     variants = [pathway_name]
+
+     clean = pathway_name
+     for prefix in ("HALLMARK_", "KEGG_", "REACTOME_", "MITOCARTA_"):
+         clean = clean.replace(prefix, "")
+     human = clean.replace("_", " ").lower()
+
+     variants.append(human)
+     variants.append(human + " pathway")
+
+     # Add known abbreviations
+     for key, abbrevs in PATHWAY_ABBREVIATIONS.items():
+         if key in human:
+             variants.extend(abbrevs)
+
+     return variants
+
+
+ def _extract_sentences_with_term(text: str, term: str) -> List[str]:
+     """Extract sentences containing a term."""
+     sentences = re.split(r"[.!?\n]+", text)
+     return [
+         s.strip()
+         for s in sentences
+         if term.lower() in s.lower() and len(s.strip()) > 10
+     ]
+
+
+ def _has_negation_before(text: str, match_start: int, window: int = 12) -> bool:
+     """Check if a negation word appears shortly before a match position.
+
+     A window of ~12 chars catches "not " plus a short adverb, without
+     reaching across clause boundaries like "not X but rather Y".
+     """
+     start = max(0, match_start - window)
+     preceding = text[start:match_start].lower()
+     return any(re.search(p, preceding) for p in NEGATION_PATTERNS)
+
+
+ def extract_direction_claims(
+     text: str,
+     pathway_name: str,
+ ) -> List[Tuple[str, str]]:
+     """Extract directional claims about a specific pathway from text.
+
+     Returns a list of (pathway_variant, direction) tuples, where
+     direction is "UP", "DOWN", or "AMBIGUOUS".
+     """
+     text_lower = text.lower()
+     pathway_variants = _generate_pathway_variants(pathway_name)
+
+     claims: List[Tuple[str, str]] = []
+     for variant in pathway_variants:
+         if variant.lower() not in text_lower:
+             continue
+
+         sentences = _extract_sentences_with_term(text, variant)
+         for sentence in sentences:
+             sent_lower = sentence.lower()
+             up_count = 0
+             down_count = 0
+
+             for pattern in UP_INDICATORS:
+                 for match in re.finditer(pattern, sent_lower):
+                     if _has_negation_before(sent_lower, match.start()):
+                         down_count += 1  # Negated up = down
+                     else:
+                         up_count += 1
+
+             for pattern in DOWN_INDICATORS:
+                 for match in re.finditer(pattern, sent_lower):
+                     if _has_negation_before(sent_lower, match.start()):
+                         up_count += 1  # Negated down = up
+                     else:
+                         down_count += 1
+
+             if up_count > down_count:
+                 claims.append((variant, "UP"))
+             elif down_count > up_count:
+                 claims.append((variant, "DOWN"))
+             elif up_count > 0:
+                 claims.append((variant, "AMBIGUOUS"))
+
+     return claims
+
+
+ class PathwayDirectionVerifier(BaseVerifier):
+     """V1: Verifies pathway direction claims against fGSEA NES data."""
+
+     name = "V1"
+
+     def score(
+         self,
+         prompt: str,
+         completion: str,
+         ground_truth: Dict,
+         question_type: str,
+     ) -> VerifierResult:
+         if "pathway" not in ground_truth or "direction" not in ground_truth:
+             return VerifierResult(
+                 score=0.5,
+                 verifier_name=self.name,
+                 details={"reason": "no_pathway_in_gt"},
+                 applicable=False,
+             )
+
+         expected_dir = ground_truth["direction"]
+         pathway = ground_truth["pathway"]
+
+         # For comparison questions, check all tissue directions
+         if "tissue_directions" in ground_truth and question_type == "comparison":
+             return self._score_comparison(completion, ground_truth)
+
+         claims = extract_direction_claims(completion, pathway)
+
+         if not claims:
+             return VerifierResult(
+                 score=0.3,
+                 verifier_name=self.name,
+                 details={
+                     "reason": "no_claim_extracted",
+                     "pathway": pathway,
+                     "expected": expected_dir,
+                 },
+             )
+
+         matching = [c for c in claims if c[1] == expected_dir]
+         contradicting = [
+             c for c in claims if c[1] != expected_dir and c[1] != "AMBIGUOUS"
+         ]
+
+         if matching and not contradicting:
+             score = 1.0
+         elif matching and contradicting:
+             score = 0.5
+         elif contradicting:
+             score = 0.0
+         else:
+             score = 0.3  # Only ambiguous claims
+
+         return VerifierResult(
+             score=score,
+             verifier_name=self.name,
+             details={
+                 "pathway": pathway,
+                 "expected": expected_dir,
+                 "claims": [(v, d) for v, d in claims],
+                 "n_matching": len(matching),
+                 "n_contradicting": len(contradicting),
+             },
+         )
+
+     def _score_comparison(
+         self, completion: str, ground_truth: Dict
+     ) -> VerifierResult:
+         """Score cross-tissue comparison: check the claimed direction per tissue."""
+         tissue_dirs = ground_truth.get("tissue_directions", {})
+         pathway = ground_truth.get("pathway", "")
+
+         if not tissue_dirs:
+             return VerifierResult(
+                 score=0.5, verifier_name=self.name,
+                 details={"reason": "no_tissue_directions"}, applicable=False,
+             )
+
+         correct = 0
+         checked = 0
+         details_per_tissue = {}
+
+         for tissue, expected_dir in tissue_dirs.items():
+             # Look for tissue-specific claims in the response
+             tissue_sentences = _extract_sentences_with_term(completion, tissue)
+             if not tissue_sentences:
+                 continue
+
+             tissue_text = " ".join(tissue_sentences)
+             claims = extract_direction_claims(tissue_text, pathway)
+             checked += 1
+
+             if any(c[1] == expected_dir for c in claims):
+                 correct += 1
+                 details_per_tissue[tissue] = "correct"
+             elif claims:
+                 details_per_tissue[tissue] = "wrong"
+             else:
+                 details_per_tissue[tissue] = "no_claim"
+
+         score = correct / checked if checked > 0 else 0.3
+
+         return VerifierResult(
+             score=score,
+             verifier_name=self.name,
+             details={
+                 "pathway": pathway,
+                 "tissues_checked": checked,
+                 "tissues_correct": correct,
+                 "per_tissue": details_per_tissue,
+             },
+         )
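
The negation-window logic in `extract_direction_claims` is the subtlest part of V1, so here is a self-contained sketch of just that mechanism, with shortened indicator lists (the module's lists are longer; `direction` is an illustrative helper name):

```python
# Standalone sketch of V1's negation-aware direction counting: each
# up/down indicator match in a lowercased sentence is flipped if a
# negation pattern occurs within ~12 characters before it.
import re

UP = [r"\bincreas\w*\b", r"\bupregulat\w*\b", r"\bactivat\w*\b"]
DOWN = [r"\bdecreas\w*\b", r"\bdownregulat\w*\b", r"\bsuppress\w*\b"]
NEG = [r"\bnot\s+", r"\bno\s+", r"\bdoes\s+not\b"]


def negated(text: str, start: int, window: int = 12) -> bool:
    before = text[max(0, start - window):start]
    return any(re.search(p, before) for p in NEG)


def direction(sentence: str) -> str:
    s = sentence.lower()
    up = down = 0
    for pat in UP:
        for m in re.finditer(pat, s):
            if negated(s, m.start()):
                down += 1  # negated up reads as down
            else:
                up += 1
    for pat in DOWN:
        for m in re.finditer(pat, s):
            if negated(s, m.start()):
                up += 1  # negated down reads as up
            else:
                down += 1
    if up > down:
        return "UP"
    if down > up:
        return "DOWN"
    return "AMBIGUOUS"


print(direction("OXPHOS was not increased; it was suppressed."))  # → DOWN
```

The "not increased" clause is counted as a DOWN vote rather than an UP vote, which is exactly the case a plain keyword counter would get wrong.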
src/biorlhf/verifiers/uncertainty.py ADDED
@@ -0,0 +1,270 @@
1
+ """
+ V4: Uncertainty Appropriateness Verifier.
+
+ Scores whether a model's stated confidence aligns with the ground-truth
+ expected confidence level. Integrates with BioEval's calibration scoring
+ when available, with a built-in fallback.
+
+ Scoring dimensions:
+ - Confidence level alignment (stated vs. expected)
+ - Calibration task behavior (acknowledge_unknown, overconfidence_trap, etc.)
+ - Default: penalizes extreme overconfidence
+ """
+
+ import json
+ import re
+ from dataclasses import dataclass
+ from typing import Dict
+
+ from biorlhf.verifiers.base import BaseVerifier, VerifierResult
+
+ # ── Try importing BioEval calibration infrastructure ──────────────────────
+ try:
+     import os
+     import sys
+
+     _bioeval_root = os.environ.get(
+         "BIOEVAL_ROOT",
+         "/Users/jak4013/Dropbox/Bioinformatics/Claude/Evaluation_model/BioEval",
+     )
+     sys.path.insert(0, _bioeval_root)
+     from bioeval.scoring.calibration import extract_confidence
+
+     HAS_BIOEVAL = True
+ except ImportError:
+     HAS_BIOEVAL = False
+
+ # ── Built-in confidence extraction (fallback) ─────────────────────────────
+ HIGH_CONFIDENCE_PATTERNS = [
+     r"\bhigh\s*confidence\b", r"\bvery\s+confident\b", r"\bconfident\s+that\b",
+     r"\bcertainly\b", r"\bclearly\b", r"\bdefinitely\b", r"\bwithout\s+doubt\b",
+     r"\bconfidence:\s*high\b", r"\bstrongly\s+suggest\b",
+ ]
+
+ MEDIUM_CONFIDENCE_PATTERNS = [
+     r"\bmoderate\s*confidence\b", r"\breasonably\s+confident\b",
+     r"\blikely\b", r"\bprobably\b", r"\bsuggest\w*\b",
+     r"\bconfidence:\s*medium\b", r"\bconfidence:\s*moderate\b",
+ ]
+
+ LOW_CONFIDENCE_PATTERNS = [
+     r"\blow\s*confidence\b", r"\bnot\s+confident\b", r"\buncertain\b",
+     r"\bunclear\b", r"\bnot\s+sure\b", r"\bdon'?t\s+know\b",
+     r"\bcannot\s+determine\b", r"\binsufficient\s+\w*\s*(?:data|evidence)\b",
+     r"\blimited\s+evidence\b", r"\bspeculat\w*\b",
+     r"\bconfidence:\s*low\b",
+ ]
+
+ # Explicit numeric confidence, e.g. "confidence: 85%"
+ NUMERIC_CONFIDENCE_RE = re.compile(
+     r"(?:confidence|certainty|probability)[:\s]*(\d{1,3})%",
+     re.IGNORECASE,
+ )
+
+ # Expected numeric ranges per stated confidence level
+ CONFIDENCE_RANGES = {
+     "high": (0.70, 1.00),
+     "medium": (0.35, 0.75),
+     "low": (0.00, 0.40),
+ }
+
+ # Expected confidence for calibration task behaviors
+ BEHAVIOR_EXPECTED_CONFIDENCE = {
+     "acknowledge_unknown": 0.15,
+     "high_confidence_correct": 0.90,
+     "partial_knowledge": 0.50,
+     "context_dependent": 0.50,
+     "moderate_confidence": 0.50,
+     "overconfidence_trap": 0.30,
+ }
+
+
+ @dataclass
+ class SimpleConfidence:
+     """Fallback confidence extraction result."""
+
+     stated: str     # "high", "medium", "low"
+     numeric: float  # 0.0 to 1.0
+     source: str     # "explicit", "pattern", "language"
+
+
+ def _extract_confidence_simple(text: str) -> SimpleConfidence:
+     """Simple confidence extraction without BioEval."""
+     text_lower = text.lower()
+
+     # Check for explicit numeric confidence first
+     num_match = NUMERIC_CONFIDENCE_RE.search(text)
+     if num_match:
+         pct = int(num_match.group(1))
+         numeric = pct / 100.0
+         if numeric >= 0.70:
+             stated = "high"
+         elif numeric >= 0.40:
+             stated = "medium"
+         else:
+             stated = "low"
+         return SimpleConfidence(stated=stated, numeric=numeric, source="explicit")
+
+     # Otherwise count pattern matches per level
+     high_count = sum(1 for p in HIGH_CONFIDENCE_PATTERNS if re.search(p, text_lower))
+     med_count = sum(1 for p in MEDIUM_CONFIDENCE_PATTERNS if re.search(p, text_lower))
+     low_count = sum(1 for p in LOW_CONFIDENCE_PATTERNS if re.search(p, text_lower))
+
+     if low_count > high_count and low_count > med_count:
+         return SimpleConfidence(stated="low", numeric=0.25, source="pattern")
+     elif high_count > low_count and high_count > med_count:
+         return SimpleConfidence(stated="high", numeric=0.85, source="pattern")
+     elif med_count > 0:
+         return SimpleConfidence(stated="medium", numeric=0.55, source="pattern")
+     else:
+         # Default: assume moderate confidence
+         return SimpleConfidence(stated="medium", numeric=0.50, source="language")
+
+
+ class UncertaintyVerifier(BaseVerifier):
+     """V4: Verifies that the model's confidence is appropriate for the question."""
+
+     name = "V4"
+
+     def score(
+         self,
+         prompt: str,
+         completion: str,
+         ground_truth: Dict,
+         question_type: str,
+     ) -> VerifierResult:
+         gt = ground_truth if isinstance(ground_truth, dict) else json.loads(ground_truth)
+
+         expected_confidence = gt.get("expected_confidence")
+         correct_behavior = gt.get("correct_behavior")
+
+         # Extract confidence from the completion
+         if HAS_BIOEVAL:
+             conf_extraction = extract_confidence(completion)
+             conf_score = conf_extraction.confidence_score
+             stated = conf_extraction.stated_confidence or "medium"
+         else:
+             simple = _extract_confidence_simple(completion)
+             conf_score = simple.numeric
+             stated = simple.stated
+
+         # Route to the appropriate scoring path
+         if correct_behavior:
+             return self._score_calibration_behavior(
+                 completion, gt, conf_score, stated, correct_behavior,
+             )
+         elif expected_confidence:
+             return self._score_confidence_alignment(
+                 conf_score, stated, expected_confidence,
+             )
+         else:
+             return self._score_default(conf_score, stated)
+
+     def _score_calibration_behavior(
+         self,
+         completion: str,
+         gt: Dict,
+         conf_score: float,
+         stated: str,
+         correct_behavior: str,
+     ) -> VerifierResult:
+         """Score calibration tasks with specific expected behaviors."""
+         expected_conf = BEHAVIOR_EXPECTED_CONFIDENCE.get(correct_behavior, 0.5)
+         conf_error = abs(conf_score - expected_conf)
+
+         # Behavior-specific checks
+         behavior_correct = False
+
+         if correct_behavior == "acknowledge_unknown":
+             # Model should say it doesn't know
+             unknown_markers = [
+                 "don't know", "cannot determine", "not enough",
+                 "unknown", "hypothetical", "no information",
+                 "unable to", "cannot assess", "fictional",
+             ]
+             behavior_correct = any(m in completion.lower() for m in unknown_markers)
+
+         elif correct_behavior == "high_confidence_correct":
+             behavior_correct = conf_score >= 0.65
+
+         elif correct_behavior == "overconfidence_trap":
+             # Should NOT be overconfident
+             behavior_correct = conf_score < 0.70
+
+         elif correct_behavior in ("partial_knowledge", "context_dependent", "moderate_confidence"):
+             behavior_correct = 0.30 <= conf_score <= 0.75
+
+         # Combined score: behavior compliance + confidence alignment
+         behavior_score = 1.0 if behavior_correct else 0.0
+         alignment_score = max(0.0, 1.0 - 2.0 * conf_error)
+         score = 0.6 * behavior_score + 0.4 * alignment_score
+
+         return VerifierResult(
+             score=score,
+             verifier_name=self.name,
+             details={
+                 "correct_behavior": correct_behavior,
+                 "expected_confidence": expected_conf,
+                 "actual_confidence": conf_score,
+                 "stated_confidence": stated,
+                 "confidence_error": conf_error,
+                 "behavior_correct": behavior_correct,
+                 "using_bioeval": HAS_BIOEVAL,
+             },
+         )
+
+     def _score_confidence_alignment(
+         self,
+         conf_score: float,
+         stated: str,
+         expected: str,
+     ) -> VerifierResult:
+         """Score how well stated confidence aligns with the expected level."""
+         if expected not in CONFIDENCE_RANGES:
+             return VerifierResult(
+                 score=0.5, verifier_name=self.name,
+                 details={"reason": "unknown_expected_level"},
+             )
+
+         low, high = CONFIDENCE_RANGES[expected]
+         in_range = low <= conf_score <= high
+
+         if in_range:
+             score = 1.0
+         else:
+             distance = min(abs(conf_score - low), abs(conf_score - high))
+             score = max(0.0, 1.0 - 2.5 * distance)
+
+         return VerifierResult(
+             score=score,
+             verifier_name=self.name,
+             details={
+                 "expected_level": expected,
+                 "expected_range": (low, high),
+                 "actual_confidence": conf_score,
+                 "stated_confidence": stated,
+                 "in_range": in_range,
+                 "using_bioeval": HAS_BIOEVAL,
+             },
+         )
+
+     def _score_default(
+         self,
+         conf_score: float,
+         stated: str,
+     ) -> VerifierResult:
+         """Default scoring: penalize extreme over- and underconfidence."""
+         if conf_score > 0.90:
+             score = 0.4  # Overconfidence penalty
+         elif conf_score < 0.10:
+             score = 0.3  # Extreme underconfidence penalty
+         else:
+             score = 0.7  # Moderate confidence is a good default
+
+         return VerifierResult(
+             score=score,
+             verifier_name=self.name,
+             details={
+                 "actual_confidence": conf_score,
+                 "stated_confidence": stated,
+                 "mode": "default",
+                 "using_bioeval": HAS_BIOEVAL,
+             },
+         )
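The fallback extraction path in `_extract_confidence_simple` can be exercised in isolation. A self-contained sketch with condensed keyword lists (the module uses the fuller regex pattern sets; the function name and texts here are illustrative only):

```python
import re

# Same shape as NUMERIC_CONFIDENCE_RE: explicit "confidence: NN%" statements
NUMERIC_RE = re.compile(
    r"(?:confidence|certainty|probability)[:\s]*(\d{1,3})%", re.IGNORECASE
)

def estimate_confidence(text):
    """Return (stated_level, numeric): explicit percentage first, then keyword cues."""
    m = NUMERIC_RE.search(text)
    if m:
        numeric = min(int(m.group(1)), 100) / 100.0
        level = "high" if numeric >= 0.70 else "medium" if numeric >= 0.40 else "low"
        return level, numeric
    t = text.lower()
    low_hits = sum(k in t for k in ("uncertain", "not sure", "unclear", "speculat"))
    high_hits = sum(k in t for k in ("definitely", "certainly", "high confidence"))
    if low_hits > high_hits:
        return "low", 0.25
    if high_hits > low_hits:
        return "high", 0.85
    return "medium", 0.50  # default: assume moderate confidence

print(estimate_confidence("Confidence: 85%. CDKN1A induction is well replicated."))
# → ('high', 0.85)
print(estimate_confidence("The tissue-level effect remains uncertain and unclear."))
# → ('low', 0.25)
```

The explicit-percentage branch taking priority over keyword counting mirrors the verifier's design: a stated number is a stronger calibration signal than hedging language, so it short-circuits the pattern tally.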