Add BioGRPO training pipeline with composable biological verifiers
Implements GRPO (Group Relative Policy Optimization) with four composable
biological verifiers (V1-V4) as reward functions:
- V1: Pathway direction verification against GeneLab fGSEA ground truth
- V2: Biological fact checking with keyword/entity validation
- V3: Cross-study consistency verification
- V4: Uncertainty calibration with ECE/Brier scoring
Key components:
- Verifier stack with weighted composition and per-question-type routing
- GRPO dataset builder merging GeneLab, BioEval, and SpaceOmicsBench sources
- GeneLab data loader with pathway enrichment score integration
- Calibration evaluation metrics (ECE, MCE, Brier, reliability diagrams)
- CLI entry point (biorlhf-grpo) with JSON config support
- SLURM scripts for Cayuga HPC (MVE and full experiment configs)
- Post-training evaluation script with SFT baseline comparison
Bug fixes applied during deployment:
- Tokenizer loading from base model (not adapter directory)
- LoRA adapter detection and SFT merge before GRPO training
- QLoRA 4-bit quantization support
- Lazy imports to avoid circular torch dependencies
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- configs/grpo_full.json +45 -0
- configs/grpo_mve.json +45 -0
- pyproject.toml +2 -1
- scripts/HPC_TRAINING_GUIDE.md +253 -0
- scripts/deploy_to_cayuga.sh +108 -0
- scripts/evaluate_ecosystem_model.py +412 -0
- scripts/evaluate_grpo.py +372 -0
- scripts/merge_training_data.py +160 -0
- scripts/run_eval_grpo.sh +92 -0
- scripts/run_evaluation.sh +83 -0
- scripts/run_grpo_full.sh +79 -0
- scripts/run_grpo_mve.sh +84 -0
- scripts/setup_cayuga_grpo.sh +127 -0
- scripts/train_ecosystem_improved.sh +154 -0
- src/biorlhf/__init__.py +24 -4
- src/biorlhf/cli.py +128 -1
- src/biorlhf/data/__init__.py +9 -1
- src/biorlhf/data/genelabloader.py +272 -0
- src/biorlhf/data/grpo_dataset.py +219 -0
- src/biorlhf/data/question_generator.py +264 -0
- src/biorlhf/evaluation/__init__.py +8 -2
- src/biorlhf/evaluation/calibration.py +184 -0
- src/biorlhf/training/__init__.py +16 -3
- src/biorlhf/training/grpo.py +284 -0
- src/biorlhf/verifiers/__init__.py +24 -0
- src/biorlhf/verifiers/base.py +49 -0
- src/biorlhf/verifiers/composer.py +227 -0
- src/biorlhf/verifiers/consistency.py +221 -0
- src/biorlhf/verifiers/factual.py +143 -0
- src/biorlhf/verifiers/pathway.py +300 -0
- src/biorlhf/verifiers/uncertainty.py +270 -0
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_name": "mistralai/Mistral-7B-v0.3",
|
| 3 |
+
"sft_model_path": "./kmp_sft_model_final",
|
| 4 |
+
"output_dir": "./biogrpo_full_model",
|
| 5 |
+
|
| 6 |
+
"num_generations": 8,
|
| 7 |
+
"beta": 0.04,
|
| 8 |
+
"num_iterations": 1,
|
| 9 |
+
"scale_rewards": "group",
|
| 10 |
+
"loss_type": "grpo",
|
| 11 |
+
|
| 12 |
+
"num_epochs": 2,
|
| 13 |
+
"batch_size": 1,
|
| 14 |
+
"gradient_accumulation_steps": 8,
|
| 15 |
+
"learning_rate": 5e-7,
|
| 16 |
+
"max_completion_length": 1024,
|
| 17 |
+
"max_prompt_length": 512,
|
| 18 |
+
"warmup_ratio": 0.1,
|
| 19 |
+
|
| 20 |
+
"lora_r": 32,
|
| 21 |
+
"lora_alpha": 64,
|
| 22 |
+
"lora_dropout": 0.05,
|
| 23 |
+
|
| 24 |
+
"active_verifiers": ["V1", "V2", "V3", "V4"],
|
| 25 |
+
"verifier_weights": {"V1": 0.35, "V2": 0.30, "V3": 0.15, "V4": 0.20},
|
| 26 |
+
|
| 27 |
+
"pathway_db": "hallmark",
|
| 28 |
+
"hold_out_tissues": ["eye", "thymus"],
|
| 29 |
+
"seed": 42,
|
| 30 |
+
|
| 31 |
+
"use_4bit": true,
|
| 32 |
+
|
| 33 |
+
"wandb_project": "biogrpo",
|
| 34 |
+
"wandb_run_name": "grpo_full_all_verifiers",
|
| 35 |
+
"use_wandb": true,
|
| 36 |
+
"logging_steps": 10,
|
| 37 |
+
"save_steps": 50,
|
| 38 |
+
"eval_steps": 50,
|
| 39 |
+
"save_total_limit": 3,
|
| 40 |
+
"log_completions": true,
|
| 41 |
+
|
| 42 |
+
"use_vllm": false,
|
| 43 |
+
"gradient_checkpointing": true,
|
| 44 |
+
"bf16": true
|
| 45 |
+
}
|
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_name": "mistralai/Mistral-7B-v0.3",
|
| 3 |
+
"sft_model_path": "./kmp_sft_model_final",
|
| 4 |
+
"output_dir": "./biogrpo_mve_model",
|
| 5 |
+
|
| 6 |
+
"num_generations": 4,
|
| 7 |
+
"beta": 0.04,
|
| 8 |
+
"num_iterations": 1,
|
| 9 |
+
"scale_rewards": "group",
|
| 10 |
+
"loss_type": "grpo",
|
| 11 |
+
|
| 12 |
+
"num_epochs": 3,
|
| 13 |
+
"batch_size": 2,
|
| 14 |
+
"gradient_accumulation_steps": 4,
|
| 15 |
+
"learning_rate": 1e-6,
|
| 16 |
+
"max_completion_length": 512,
|
| 17 |
+
"max_prompt_length": 384,
|
| 18 |
+
"warmup_ratio": 0.1,
|
| 19 |
+
|
| 20 |
+
"lora_r": 32,
|
| 21 |
+
"lora_alpha": 64,
|
| 22 |
+
"lora_dropout": 0.05,
|
| 23 |
+
|
| 24 |
+
"active_verifiers": ["V1", "V4"],
|
| 25 |
+
"verifier_weights": {"V1": 0.6, "V4": 0.4},
|
| 26 |
+
|
| 27 |
+
"pathway_db": "hallmark",
|
| 28 |
+
"hold_out_tissues": ["eye"],
|
| 29 |
+
"seed": 42,
|
| 30 |
+
|
| 31 |
+
"use_4bit": true,
|
| 32 |
+
|
| 33 |
+
"wandb_project": "biogrpo",
|
| 34 |
+
"wandb_run_name": "grpo_mve_v1v4",
|
| 35 |
+
"use_wandb": true,
|
| 36 |
+
"logging_steps": 5,
|
| 37 |
+
"save_steps": 25,
|
| 38 |
+
"eval_steps": 25,
|
| 39 |
+
"save_total_limit": 3,
|
| 40 |
+
"log_completions": true,
|
| 41 |
+
|
| 42 |
+
"use_vllm": false,
|
| 43 |
+
"gradient_checkpointing": true,
|
| 44 |
+
"bf16": true
|
| 45 |
+
}
|
|
@@ -43,7 +43,7 @@ dependencies = [
|
|
| 43 |
"datasets>=2.14.0",
|
| 44 |
"accelerate>=0.24.0",
|
| 45 |
"peft>=0.6.0",
|
| 46 |
-
"trl>=0.
|
| 47 |
"bitsandbytes>=0.41.0",
|
| 48 |
"wandb>=0.15.0",
|
| 49 |
"pandas>=2.0.0",
|
|
@@ -76,6 +76,7 @@ Issues = "https://github.com/jang1563/BioRLHF/issues"
|
|
| 76 |
[project.scripts]
|
| 77 |
biorlhf-train = "biorlhf.cli:train"
|
| 78 |
biorlhf-evaluate = "biorlhf.cli:evaluate"
|
|
|
|
| 79 |
|
| 80 |
[tool.hatch.build.targets.sdist]
|
| 81 |
include = [
|
|
|
|
| 43 |
"datasets>=2.14.0",
|
| 44 |
"accelerate>=0.24.0",
|
| 45 |
"peft>=0.6.0",
|
| 46 |
+
"trl>=0.14.0",
|
| 47 |
"bitsandbytes>=0.41.0",
|
| 48 |
"wandb>=0.15.0",
|
| 49 |
"pandas>=2.0.0",
|
|
|
|
| 76 |
[project.scripts]
|
| 77 |
biorlhf-train = "biorlhf.cli:train"
|
| 78 |
biorlhf-evaluate = "biorlhf.cli:evaluate"
|
| 79 |
+
biorlhf-grpo = "biorlhf.cli:grpo_train"
|
| 80 |
|
| 81 |
[tool.hatch.build.targets.sdist]
|
| 82 |
include = [
|
|
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# BioRLHF Training on Cayuga HPC (Interactive Session)
|
| 2 |
+
|
| 3 |
+
**Cluster:** Cornell Cayuga HPC
|
| 4 |
+
**Target:** GPU training with Mistral-7B + LoRA
|
| 5 |
+
|
| 6 |
+
---
|
| 7 |
+
|
| 8 |
+
## Quick Start (Copy-Paste Commands)
|
| 9 |
+
|
| 10 |
+
```bash
|
| 11 |
+
# 1. Start interactive GPU session (A100 recommended, 80GB VRAM)
|
| 12 |
+
srun -p scu-gpu --gres=gpu:a100:1 --mem=48G -c 8 --time=4:00:00 --pty bash
|
| 13 |
+
|
| 14 |
+
# 2. Set up environment (first time only - see Step 2 below)
|
| 15 |
+
|
| 16 |
+
# 3. Run training
|
| 17 |
+
cd /athena/cayuga_XXXX/scratch/$USER/BioRLHF/biorlhf
|
| 18 |
+
./scripts/train_ecosystem_improved.sh
|
| 19 |
+
```
|
| 20 |
+
|
| 21 |
+
---
|
| 22 |
+
|
| 23 |
+
## Step 1: Transfer Files to HPC
|
| 24 |
+
|
| 25 |
+
From your local Mac:
|
| 26 |
+
|
| 27 |
+
```bash
|
| 28 |
+
# Replace with your actual paths and CWID
|
| 29 |
+
rsync -avz --progress \
|
| 30 |
+
/Users/jak4013/Dropbox/Bioinformatics/Claude/BioRLHF \
|
| 31 |
+
YOUR_CWID@cayuga.cac.cornell.edu:/athena/cayuga_XXXX/scratch/$USER/
|
| 32 |
+
```
|
| 33 |
+
|
| 34 |
+
Or use scp:
|
| 35 |
+
```bash
|
| 36 |
+
scp -r /Users/jak4013/Dropbox/Bioinformatics/Claude/BioRLHF \
|
| 37 |
+
YOUR_CWID@cayuga.cac.cornell.edu:/athena/cayuga_XXXX/scratch/$USER/
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
---
|
| 41 |
+
|
| 42 |
+
## Step 2: Set Up Conda Environment (First Time Only)
|
| 43 |
+
|
| 44 |
+
### 2a. Start Interactive Session
|
| 45 |
+
```bash
|
| 46 |
+
# SSH to Cayuga
|
| 47 |
+
ssh YOUR_CWID@cayuga.cac.cornell.edu
|
| 48 |
+
|
| 49 |
+
# Request interactive GPU session
|
| 50 |
+
srun -p scu-gpu --gres=gpu:a100:1 --mem=48G -c 8 --time=2:00:00 --pty bash
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
### 2b. Install Miniconda (if not already installed)
|
| 54 |
+
```bash
|
| 55 |
+
# Create directory in scratch space
|
| 56 |
+
mkdir -p /athena/cayuga_XXXX/scratch/$USER/miniconda3
|
| 57 |
+
|
| 58 |
+
# Download and install
|
| 59 |
+
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh
|
| 60 |
+
bash miniconda.sh -b -u -p /athena/cayuga_XXXX/scratch/$USER/miniconda3
|
| 61 |
+
rm miniconda.sh
|
| 62 |
+
|
| 63 |
+
# Initialize conda
|
| 64 |
+
source /athena/cayuga_XXXX/scratch/$USER/miniconda3/bin/activate
|
| 65 |
+
conda init bash
|
| 66 |
+
source ~/.bashrc
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
### 2c. Create BioRLHF Environment
|
| 70 |
+
```bash
|
| 71 |
+
# Create environment with Python 3.10 (best compatibility)
|
| 72 |
+
conda create -n biorlhf python=3.10 -y
|
| 73 |
+
conda activate biorlhf
|
| 74 |
+
|
| 75 |
+
# Install PyTorch with CUDA support
|
| 76 |
+
conda install pytorch pytorch-cuda=12.1 -c pytorch -c nvidia -y
|
| 77 |
+
|
| 78 |
+
# Install training dependencies
|
| 79 |
+
pip install "transformers>=4.36.0"
pip install "peft>=0.7.0"
pip install "trl>=0.14.0"
pip install "bitsandbytes>=0.41.0"
pip install "accelerate>=0.25.0"
pip install "datasets>=2.14.0"
pip install wandb
pip install scipy
pip install sentencepiece
|
| 88 |
+
|
| 89 |
+
# Verify GPU access
|
| 90 |
+
python -c "import torch; print(f'CUDA available: {torch.cuda.is_available()}'); print(f'GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else None}')"
|
| 91 |
+
```
|
| 92 |
+
|
| 93 |
+
---
|
| 94 |
+
|
| 95 |
+
## Step 3: Run Training (Interactive)
|
| 96 |
+
|
| 97 |
+
### 3a. Start GPU Session
|
| 98 |
+
```bash
|
| 99 |
+
# Request A100 GPU (80GB - best for Mistral-7B)
|
| 100 |
+
srun -p scu-gpu --gres=gpu:a100:1 --mem=48G -c 8 --time=4:00:00 --pty bash
|
| 101 |
+
|
| 102 |
+
# Or use A40 (48GB - also works with 4-bit quantization)
|
| 103 |
+
srun -p scu-gpu --gres=gpu:a40:1 --mem=48G -c 8 --time=4:00:00 --pty bash
|
| 104 |
+
```
|
| 105 |
+
|
| 106 |
+
### 3b. Activate Environment and Run
|
| 107 |
+
```bash
|
| 108 |
+
# Activate conda
|
| 109 |
+
source /athena/cayuga_XXXX/scratch/$USER/miniconda3/bin/activate
|
| 110 |
+
conda activate biorlhf
|
| 111 |
+
|
| 112 |
+
# Navigate to BioRLHF
|
| 113 |
+
cd /athena/cayuga_XXXX/scratch/$USER/BioRLHF/biorlhf
|
| 114 |
+
|
| 115 |
+
# Check GPU is available
|
| 116 |
+
nvidia-smi
|
| 117 |
+
|
| 118 |
+
# Set HuggingFace cache (optional - saves space)
|
| 119 |
+
export HF_HOME=/athena/cayuga_XXXX/scratch/$USER/.cache/huggingface
|
| 120 |
+
|
| 121 |
+
# Run training
|
| 122 |
+
./scripts/train_ecosystem_improved.sh
|
| 123 |
+
```
|
| 124 |
+
|
| 125 |
+
---
|
| 126 |
+
|
| 127 |
+
## Step 4: Monitor Training
|
| 128 |
+
|
| 129 |
+
In a separate terminal (or use tmux/screen):
|
| 130 |
+
|
| 131 |
+
```bash
|
| 132 |
+
# Watch GPU usage
|
| 133 |
+
watch -n 1 nvidia-smi
|
| 134 |
+
|
| 135 |
+
# Tail training logs
|
| 136 |
+
tail -f logs/biorlhf_ecosystem_*.out
|
| 137 |
+
```
|
| 138 |
+
|
| 139 |
+
### Using WandB (Optional)
|
| 140 |
+
```bash
|
| 141 |
+
# Login to Weights & Biases
|
| 142 |
+
wandb login
|
| 143 |
+
|
| 144 |
+
# Training will automatically log to: https://wandb.ai/YOUR_USERNAME/biorlhf
|
| 145 |
+
```
|
| 146 |
+
|
| 147 |
+
---
|
| 148 |
+
|
| 149 |
+
## GPU Options on Cayuga
|
| 150 |
+
|
| 151 |
+
| GPU Type | VRAM | Recommended For | Command |
|
| 152 |
+
|----------|------|-----------------|---------|
|
| 153 |
+
| A100 | 80GB | Full training, larger batches | `--gres=gpu:a100:1` |
|
| 154 |
+
| A40 | 48GB | Standard training with 4-bit | `--gres=gpu:a40:1` |
|
| 155 |
+
| H100 | 80GB | Fastest (if available) | `--gres=gpu:h100:1` |
|
| 156 |
+
|
| 157 |
+
---
|
| 158 |
+
|
| 159 |
+
## Troubleshooting
|
| 160 |
+
|
| 161 |
+
### "CUDA out of memory"
|
| 162 |
+
Reduce batch size in training script:
|
| 163 |
+
```bash
|
| 164 |
+
# Edit train_ecosystem_improved.sh
|
| 165 |
+
BATCH_SIZE=2 # Reduce from 4
|
| 166 |
+
GRAD_ACCUM=8 # Increase to maintain effective batch size
|
| 167 |
+
```
|
| 168 |
+
|
| 169 |
+
### "No GPU available"
|
| 170 |
+
```bash
|
| 171 |
+
# Check GPU allocation
|
| 172 |
+
nvidia-smi
|
| 173 |
+
|
| 174 |
+
# Verify CUDA installation
|
| 175 |
+
python -c "import torch; print(torch.cuda.is_available())"
|
| 176 |
+
|
| 177 |
+
# Check if you're on a GPU node
|
| 178 |
+
squeue -u $USER
|
| 179 |
+
```
|
| 180 |
+
|
| 181 |
+
### "Module not found"
|
| 182 |
+
```bash
|
| 183 |
+
# Ensure conda environment is activated
|
| 184 |
+
conda activate biorlhf
|
| 185 |
+
|
| 186 |
+
# Reinstall missing package
|
| 187 |
+
pip install <missing_package>
|
| 188 |
+
```
|
| 189 |
+
|
| 190 |
+
### Interactive session times out
|
| 191 |
+
Use `tmux` or `screen` to persist sessions:
|
| 192 |
+
```bash
|
| 193 |
+
# Start tmux before srun
|
| 194 |
+
tmux new -s training
|
| 195 |
+
|
| 196 |
+
# Then request GPU
|
| 197 |
+
srun -p scu-gpu --gres=gpu:a100:1 --mem=48G -c 8 --time=4:00:00 --pty bash
|
| 198 |
+
|
| 199 |
+
# Detach: Ctrl+B, then D
|
| 200 |
+
# Reattach: tmux attach -t training
|
| 201 |
+
```
|
| 202 |
+
|
| 203 |
+
---
|
| 204 |
+
|
| 205 |
+
## Expected Training Time
|
| 206 |
+
|
| 207 |
+
| Configuration | Dataset Size | Estimated Time |
|
| 208 |
+
|--------------|--------------|----------------|
|
| 209 |
+
| A100 + 4-bit | 378 examples, 10 epochs | ~45-60 min |
|
| 210 |
+
| A40 + 4-bit | 378 examples, 10 epochs | ~60-90 min |
|
| 211 |
+
| A100 (full) | 378 examples, 10 epochs | ~90-120 min |
|
| 212 |
+
|
| 213 |
+
---
|
| 214 |
+
|
| 215 |
+
## After Training
|
| 216 |
+
|
| 217 |
+
### Copy model back to local machine:
|
| 218 |
+
```bash
|
| 219 |
+
# From your Mac
|
| 220 |
+
scp -r YOUR_CWID@cayuga.cac.cornell.edu:/athena/cayuga_XXXX/scratch/$USER/BioRLHF/biorlhf/ecosystem_improved_model \
|
| 221 |
+
/Users/jak4013/Dropbox/Bioinformatics/Claude/BioRLHF/biorlhf/
|
| 222 |
+
```
|
| 223 |
+
|
| 224 |
+
### Run evaluation:
|
| 225 |
+
```bash
|
| 226 |
+
python scripts/evaluate_ecosystem_model.py --model ecosystem_improved_model
|
| 227 |
+
```
|
| 228 |
+
|
| 229 |
+
---
|
| 230 |
+
|
| 231 |
+
## Complete Interactive Session Example
|
| 232 |
+
|
| 233 |
+
```bash
|
| 234 |
+
# SSH to Cayuga
|
| 235 |
+
ssh jk2042@cayuga.cac.cornell.edu
|
| 236 |
+
|
| 237 |
+
# Start tmux (optional but recommended)
|
| 238 |
+
tmux new -s biorlhf
|
| 239 |
+
|
| 240 |
+
# Request GPU
|
| 241 |
+
srun -p scu-gpu --gres=gpu:a100:1 --mem=48G -c 8 --time=4:00:00 --pty bash
|
| 242 |
+
|
| 243 |
+
# Set up environment
|
| 244 |
+
source ~/miniconda3/bin/activate
|
| 245 |
+
conda activate biorlhf
|
| 246 |
+
|
| 247 |
+
# Navigate and run
|
| 248 |
+
cd /athena/cayuga_XXXX/scratch/$USER/BioRLHF/biorlhf
|
| 249 |
+
./scripts/train_ecosystem_improved.sh
|
| 250 |
+
|
| 251 |
+
# Watch progress (in another terminal or after Ctrl+B, c for new window)
|
| 252 |
+
watch -n 5 nvidia-smi
|
| 253 |
+
```
|
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# ============================================================
# Deploy BioRLHF code + data to Cayuga HPC
# Run from local Mac
#
# Pushes the GRPO training package (src/configs/scripts/tests),
# then the data directories the verifiers read (GeneLab fgsea,
# GeneLab evaluation, BioEval, SpaceOmicsBench), then verifies
# the layout remotely over ssh.
# ============================================================

# Abort on the first failed command so a partial transfer is obvious.
set -e

# SSH host alias (expected in ~/.ssh/config) and remote scratch root.
REMOTE="cayuga-login1"
SCRATCH="/athena/cayuga_0003/scratch/users/jak4013/otsuka"
LOCAL_BASE="$HOME/Dropbox/Bioinformatics/Claude"

echo "============================================================"
echo "BioRLHF Cayuga Deployment"
echo "============================================================"

# Step 1: Create directories on Cayuga
echo ""
echo "[1/4] Creating directories on Cayuga..."
ssh ${REMOTE} "mkdir -p ${SCRATCH}/training/BioRLHF ${SCRATCH}/data/GeneLab_benchmark ${SCRATCH}/data/BioEval ${SCRATCH}/data/SpaceOmicsBench/v3/evaluation"

# Step 2: Transfer BioRLHF code (only essential files)
echo ""
echo "[2/4] Transferring BioRLHF code..."
LOCAL_BIORLHF="${LOCAL_BASE}/BioRLHF/biorlhf"
DEST="${REMOTE}:${SCRATCH}/training/BioRLHF"

# Transfer only the package structure needed for GRPO
rsync -avz --progress \
    "${LOCAL_BIORLHF}/src/" \
    ${DEST}/src/

rsync -avz --progress \
    "${LOCAL_BIORLHF}/configs/" \
    ${DEST}/configs/

rsync -avz --progress \
    "${LOCAL_BIORLHF}/scripts/" \
    ${DEST}/scripts/

rsync -avz --progress \
    "${LOCAL_BIORLHF}/tests/" \
    ${DEST}/tests/

rsync -avz --progress \
    "${LOCAL_BIORLHF}/pyproject.toml" \
    "${LOCAL_BIORLHF}/README.md" \
    ${DEST}/

# Step 3: Transfer data (only what GRPO training needs)
echo ""
echo "[3/4] Transferring data..."

echo " GeneLab fgsea (pathway enrichment scores - required)..."
rsync -avz --progress \
    "${LOCAL_BASE}/GeneLab_benchmark/processed/fgsea/" \
    ${REMOTE}:${SCRATCH}/data/GeneLab_benchmark/processed/fgsea/

echo " GeneLab evaluation (NES conservation - for conservation questions)..."
rsync -avz --progress \
    "${LOCAL_BASE}/GeneLab_benchmark/evaluation/" \
    ${REMOTE}:${SCRATCH}/data/GeneLab_benchmark/evaluation/

echo " BioEval data..."
rsync -avz --progress \
    "${LOCAL_BASE}/Evaluation_model/BioEval/data/" \
    ${REMOTE}:${SCRATCH}/data/BioEval/data/

echo " BioEval scoring (for calibration imports)..."
rsync -avz --progress \
    "${LOCAL_BASE}/Evaluation_model/BioEval/bioeval/" \
    ${REMOTE}:${SCRATCH}/data/BioEval/bioeval/

echo " SpaceOmicsBench..."
rsync -avz --progress \
    "${LOCAL_BASE}/SpaceOmicsBench/v3/evaluation/llm/" \
    ${REMOTE}:${SCRATCH}/data/SpaceOmicsBench/v3/evaluation/llm/

# Step 4: Verify
# Runs a single remote shell; each `ls` prints OK/MISSING per artifact.
echo ""
echo "[4/4] Verifying deployment..."
ssh ${REMOTE} "
echo 'Directory structure:'
echo ' BioRLHF code:'
ls ${SCRATCH}/training/BioRLHF/pyproject.toml 2>/dev/null && echo ' pyproject.toml: OK' || echo ' pyproject.toml: MISSING'
ls ${SCRATCH}/training/BioRLHF/configs/grpo_mve.json 2>/dev/null && echo ' configs/grpo_mve.json: OK' || echo ' configs/grpo_mve.json: MISSING'
ls -d ${SCRATCH}/training/BioRLHF/src/biorlhf/ 2>/dev/null && echo ' src/biorlhf/: OK' || echo ' src/biorlhf/: MISSING'

echo ' SFT checkpoint:'
# NOTE(review): lowercase 'training/biorlhf/' differs from the
# 'training/BioRLHF' code dir created above -- confirm the SFT
# checkpoint really lives under the lowercase path.
ls -d ${SCRATCH}/training/biorlhf/kmp_sft_model_final/ 2>/dev/null && echo ' kmp_sft_model_final: OK' || echo ' kmp_sft_model_final: MISSING'

echo ' Data:'
ls ${SCRATCH}/data/GeneLab_benchmark/processed/fgsea/ 2>/dev/null | head -3 && echo ' GeneLab fgsea: OK' || echo ' GeneLab fgsea: MISSING'
ls ${SCRATCH}/data/GeneLab_benchmark/evaluation/ 2>/dev/null | head -3 && echo ' GeneLab evaluation: OK' || echo ' GeneLab evaluation: MISSING'
ls ${SCRATCH}/data/BioEval/data/ 2>/dev/null | head -3 && echo ' BioEval: OK' || echo ' BioEval: MISSING'
ls ${SCRATCH}/data/SpaceOmicsBench/v3/evaluation/llm/ 2>/dev/null | head -3 && echo ' SpaceOmicsBench: OK' || echo ' SpaceOmicsBench: MISSING'
"

echo ""
echo "============================================================"
echo "Deployment complete!"
echo ""
echo "Next steps on Cayuga:"
echo " ssh ${REMOTE}"
echo " cd ${SCRATCH}/training/BioRLHF"
echo " bash scripts/setup_cayuga_grpo.sh"
echo " sbatch scripts/run_grpo_mve.sh"
echo "============================================================"
|
|
@@ -0,0 +1,412 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Evaluate Ecosystem-Improved Model on Failure Patterns
|
| 4 |
+
|
| 5 |
+
This script evaluates the fine-tuned model specifically on the patterns
|
| 6 |
+
it was trained to improve:
|
| 7 |
+
- Calibration (uncertainty expression)
|
| 8 |
+
- Adversarial resistance
|
| 9 |
+
- Protocol completeness
|
| 10 |
+
- Fact recall
|
| 11 |
+
|
| 12 |
+
Usage (on HPC with GPU):
|
| 13 |
+
python scripts/evaluate_ecosystem_model.py --model ./ecosystem_improved_model
|
| 14 |
+
|
| 15 |
+
Requirements:
|
| 16 |
+
- CUDA GPU
|
| 17 |
+
- transformers, peft, bitsandbytes, torch
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
import argparse
|
| 21 |
+
import json
|
| 22 |
+
import torch
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
from datetime import datetime
|
| 25 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
| 26 |
+
from peft import PeftModel
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def load_model(model_path: str, base_model: str = "mistralai/Mistral-7B-v0.3", use_4bit: bool = True):
    """Load the fine-tuned model with LoRA adapters.

    Args:
        model_path: Directory containing the LoRA adapter weights.
        base_model: HF hub id of the base checkpoint the adapters were trained on.
        use_4bit: If True, load the base weights with NF4 4-bit quantization
            (QLoRA-style) to fit smaller GPUs; otherwise load in fp16.

    Returns:
        Tuple of ``(model, tokenizer)`` ready for generation.
    """
    print(f"Loading base model: {base_model}")

    if use_4bit:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
        )
        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True,
        )
    else:
        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            device_map="auto",
            torch_dtype=torch.float16,
            trust_remote_code=True,
        )

    print(f"Loading LoRA adapters from: {model_path}")
    model = PeftModel.from_pretrained(model, model_path)

    # Bug fix (matches the "tokenizer loading from base model" deployment fix
    # noted for this change set): LoRA adapter directories often do not contain
    # tokenizer files, so fall back to the base model's tokenizer when the
    # adapter directory has none.
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_path)
    except (OSError, ValueError):
        tokenizer = AutoTokenizer.from_pretrained(base_model)
    if tokenizer.pad_token is None:
        # Mistral ships no pad token; reuse EOS so batched generation works.
        tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def generate_response(model, tokenizer, prompt: str, max_new_tokens: int = 512) -> str:
    """Generate a model completion for *prompt* using the Alpaca-style template.

    The prompt is wrapped in ``### Instruction:`` / ``### Response:`` markers,
    sampled with a low temperature for near-deterministic output, and only the
    text after the response marker is returned.
    """
    template = f"### Instruction:\n{prompt}\n\n### Response:\n"
    encoded = tokenizer(template, return_tensors="pt").to(model.device)

    with torch.no_grad():
        generated = model.generate(
            **encoded,
            max_new_tokens=max_new_tokens,
            temperature=0.1,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
        )

    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)

    # Keep only the generated answer; the decode includes the prompt prefix.
    marker = "### Response:"
    if marker in decoded:
        decoded = decoded.split(marker)[-1].strip()

    return decoded
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def evaluate_calibration(model, tokenizer, test_cases: list) -> dict:
    """Evaluate uncertainty calibration - does model admit when it doesn't know?

    A case passes when the generated response contains at least one hedging
    phrase (keyword heuristic). Returns a summary dict with per-case details.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("EVALUATING: Calibration (Uncertainty Expression)")
    print(banner)

    # Phrases that signal the model is flagging missing/unavailable data.
    calibration_phrases = [
        "cannot determine",
        "not provided",
        "not included",
        "would need",
        "don't have",
        "doesn't include",
        "cannot say",
        "uncertain",
        "not specified",
        "additional information",
        "not in the data",
    ]

    results = []
    for idx, case in enumerate(test_cases):
        question = case["conversations"][0]["content"]
        # Reference answer is extracted for parity with the other evaluators,
        # though the pass criterion here is purely phrase-based.
        expected = case["conversations"][1]["content"]

        print(f"\n[{idx+1}] Testing: {question[:60]}...")
        response = generate_response(model, tokenizer, question)

        # Case-insensitive scan for any hedging phrase.
        lowered = response.lower()
        expresses_uncertainty = any(p.lower() in lowered for p in calibration_phrases)

        results.append({
            "id": case["id"],
            "question": question,
            "expected_behavior": "Express uncertainty about unavailable data",
            "response": response[:500],
            "expresses_uncertainty": expresses_uncertainty,
            "passed": expresses_uncertainty,
        })

        status = "PASS" if expresses_uncertainty else "FAIL"
        print(f" [{status}] Uncertainty expressed: {expresses_uncertainty}")

    passed = sum(1 for r in results if r["passed"])
    total = len(results)

    return {
        "category": "calibration",
        "passed": passed,
        "total": total,
        "accuracy": passed / total if total > 0 else 0,
        "details": results,
    }
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def evaluate_adversarial(model, tokenizer, test_cases: list) -> dict:
    """Evaluate adversarial resistance - does model reject false claims?

    A case passes when the response contains at least one pushback phrase
    (keyword heuristic). Returns a summary dict with per-case details.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("EVALUATING: Adversarial Resistance")
    print(banner)

    # Phrases indicating the model is contesting a false premise.
    resistance_phrases = [
        "not supported",
        "contradicted",
        "incorrect",
        "need to correct",
        "not true",
        "actually shows",
        "data shows",
        "opposite",
        "challenge that premise",
        "misrepresent",
    ]

    results = []
    for idx, case in enumerate(test_cases):
        question = case["conversations"][0]["content"]
        # Reference answer kept for parity with the other evaluators.
        expected = case["conversations"][1]["content"]

        print(f"\n[{idx+1}] Testing adversarial: {question[:60]}...")
        response = generate_response(model, tokenizer, question)

        # Case-insensitive scan for any pushback phrase.
        lowered = response.lower()
        resists_claim = any(p.lower() in lowered for p in resistance_phrases)

        results.append({
            "id": case["id"],
            "question": question,
            "expected_behavior": "Reject false premise with evidence",
            "response": response[:500],
            "resists_false_claim": resists_claim,
            "passed": resists_claim,
        })

        status = "PASS" if resists_claim else "FAIL"
        print(f" [{status}] Resisted false claim: {resists_claim}")

    passed = sum(1 for r in results if r["passed"])
    total = len(results)

    return {
        "category": "adversarial_resistance",
        "passed": passed,
        "total": total,
        "accuracy": passed / total if total > 0 else 0,
        "details": results,
    }
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def evaluate_completeness(model, tokenizer, test_cases: list) -> dict:
    """Check protocol completeness: does the model flag all missing steps?

    A case passes when at least half of its expected step keywords appear
    (case-insensitively) in the generated response.
    """
    print("\n" + "="*60)
    print("EVALUATING: Protocol Completeness")
    print("="*60)

    # Per-case keywords that should appear when the model names missing steps.
    key_steps = {
        "comp_001": ["dnase", "reverse transcription", "rt", "cdna"],
        "comp_002": ["transfer", "blot", "membrane transfer"],
    }

    results = []
    for idx, case in enumerate(test_cases, start=1):
        question = case["conversations"][0]["content"]
        expected = case["conversations"][1]["content"]  # kept for parity; not scored
        case_id = case["id"]

        print(f"\n[{idx}] Testing completeness: {case_id}...")
        # Step lists need more room than the default generation budget.
        response = generate_response(model, tokenizer, question, max_new_tokens=800)

        expected_steps = key_steps.get(case_id, [])
        reply_lower = response.lower()
        detected = [step for step in expected_steps if step in reply_lower]
        detection_rate = len(detected) / len(expected_steps) if expected_steps else 0

        entry = {
            "id": case_id,
            "question": question[:100],
            "expected_steps": expected_steps,
            "detected_steps": detected,
            "response": response[:600],
            "detection_rate": detection_rate,
            "passed": detection_rate >= 0.5,  # Pass if at least half detected
        }
        results.append(entry)

        status = "PASS" if entry["passed"] else "FAIL"
        print(f" [{status}] Detected {len(detected)}/{len(expected_steps)} key steps")

    n_passed = sum(r["passed"] for r in results)
    n_total = len(results)

    return {
        "category": "protocol_completeness",
        "passed": n_passed,
        "total": n_total,
        "accuracy": n_passed / n_total if n_total > 0 else 0,
        "details": results,
    }
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def evaluate_fact_recall(model, tokenizer, test_cases: list) -> dict:
    """Check fact recall: does the model reproduce key trained facts/values?

    A case passes when at least half of its expected fact strings appear
    (case-insensitively) in the generated response.
    """
    print("\n" + "="*60)
    print("EVALUATING: Fact Recall")
    print("="*60)

    # Per-case values the response must contain to count as recalled.
    key_facts = {
        "fact_001": ["52%", "52 percent"],
        "fact_002": ["52%", "52 percent"],
        "fact_003": ["52%", "8%"],
        "fact_004": ["-1.60", "-1.6", "suppressed", "suppression"],
        "fact_005": ["liver", "-1.60", "-1.6", "opposite"],
    }

    results = []
    for idx, case in enumerate(test_cases, start=1):
        question = case["conversations"][0]["content"]
        expected = case["conversations"][1]["content"]  # kept for parity; not scored
        case_id = case["id"]

        print(f"\n[{idx}] Testing fact recall: {case_id}...")
        response = generate_response(model, tokenizer, question)

        expected_facts = key_facts.get(case_id, [])
        reply_lower = response.lower()
        recalled = [fact for fact in expected_facts if fact.lower() in reply_lower]
        recall_rate = len(recalled) / len(expected_facts) if expected_facts else 0

        entry = {
            "id": case_id,
            "question": question,
            "expected_facts": expected_facts,
            "recalled_facts": recalled,
            "response": response[:400],
            "recall_rate": recall_rate,
            "passed": recall_rate >= 0.5,  # Pass if at least half recalled
        }
        results.append(entry)

        status = "PASS" if entry["passed"] else "FAIL"
        print(f" [{status}] Recalled {len(recalled)}/{len(expected_facts)} key facts")

    n_passed = sum(r["passed"] for r in results)
    n_total = len(results)

    return {
        "category": "fact_recall",
        "passed": n_passed,
        "total": n_total,
        "accuracy": n_passed / n_total if n_total > 0 else 0,
        "details": results,
    }
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def main():
    """Entry point: parse CLI args, load the fine-tuned model, run every
    evaluation category present in the test data, print a summary, and
    save the full results to a timestamped JSON file."""
    parser = argparse.ArgumentParser(description="Evaluate ecosystem-improved model")
    parser.add_argument("--model", type=str, default="./ecosystem_improved_model",
                        help="Path to the fine-tuned model")
    parser.add_argument("--base-model", type=str, default="mistralai/Mistral-7B-v0.3",
                        help="Base model name")
    parser.add_argument("--test-data", type=str, default="data/ecosystem_failures_training.json",
                        help="Path to test data JSON")
    parser.add_argument("--output", type=str, default=None,
                        help="Output path for results JSON")
    parser.add_argument("--no-4bit", action="store_true",
                        help="Disable 4-bit quantization")

    args = parser.parse_args()

    print("="*60)
    print("BioRLHF Ecosystem Model Evaluation")
    print("="*60)
    print(f"Model: {args.model}")
    print(f"Base: {args.base_model}")
    print(f"Test data: {args.test_data}")
    print(f"Time: {datetime.now().isoformat()}")
    print("="*60)

    # Load test data
    with open(args.test_data, 'r') as f:
        test_data = json.load(f)

    # Load model (4-bit quantized unless --no-4bit was passed)
    model, tokenizer = load_model(args.model, args.base_model, use_4bit=not args.no_4bit)
    print("\nModel loaded successfully!\n")

    # Run evaluations; categories absent (or empty) in the test data are skipped.
    results = {}

    # 1. Calibration
    if test_data.get("calibration_examples"):
        results["calibration"] = evaluate_calibration(
            model, tokenizer, test_data["calibration_examples"]
        )

    # 2. Adversarial resistance
    if test_data.get("adversarial_resistance_examples"):
        results["adversarial"] = evaluate_adversarial(
            model, tokenizer, test_data["adversarial_resistance_examples"]
        )

    # 3. Protocol completeness
    if test_data.get("completeness_examples"):
        results["completeness"] = evaluate_completeness(
            model, tokenizer, test_data["completeness_examples"]
        )

    # 4. Fact recall
    if test_data.get("fact_drilling_examples"):
        results["fact_recall"] = evaluate_fact_recall(
            model, tokenizer, test_data["fact_drilling_examples"]
        )

    # Summary: aggregate pass counts across whichever categories ran.
    print("\n" + "="*60)
    print("EVALUATION SUMMARY")
    print("="*60)

    total_passed = 0
    total_tests = 0

    for category, data in results.items():
        print(f"\n{category.upper()}:")
        print(f" Passed: {data['passed']}/{data['total']} ({data['accuracy']:.1%})")
        total_passed += data['passed']
        total_tests += data['total']

    # Guard against division by zero when no categories were present.
    overall_accuracy = total_passed / total_tests if total_tests > 0 else 0

    print("\n" + "-"*60)
    print(f"OVERALL: {total_passed}/{total_tests} ({overall_accuracy:.1%})")
    print("-"*60)

    # Save results; default filename is timestamped to avoid clobbering runs.
    output_path = args.output or f"ecosystem_eval_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"

    output_data = {
        "model_path": args.model,
        "base_model": args.base_model,
        "evaluation_date": datetime.now().isoformat(),
        "overall_accuracy": overall_accuracy,
        "total_passed": total_passed,
        "total_tests": total_tests,
        "results": results
    }

    with open(output_path, 'w') as f:
        json.dump(output_data, f, indent=2)

    print(f"\nResults saved to: {output_path}")
    print("\n" + "="*60)
    print("Evaluation complete!")
    print("="*60)


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,372 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
BioGRPO Post-Training Evaluation Script
|
| 4 |
+
|
| 5 |
+
Evaluates a GRPO-trained model against:
|
| 6 |
+
1. Held-out GeneLab questions (LOMO: Leave-One-Mission-Out)
|
| 7 |
+
2. Calibration metrics (ECE, Brier, overconfidence rate)
|
| 8 |
+
3. Per-verifier reward scores
|
| 9 |
+
4. Baseline comparison (SFT, DPO)
|
| 10 |
+
|
| 11 |
+
Usage:
|
| 12 |
+
python scripts/evaluate_grpo.py \
|
| 13 |
+
--model ./biogrpo_mve_model \
|
| 14 |
+
--sft-baseline ./kmp_sft_model_final \
|
| 15 |
+
--hold-out-tissues eye \
|
| 16 |
+
--output results/grpo_mve_eval.json
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import argparse
|
| 20 |
+
import json
|
| 21 |
+
import torch
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
from datetime import datetime
|
| 24 |
+
from typing import Dict, List, Optional
|
| 25 |
+
|
| 26 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
| 27 |
+
from peft import PeftModel
|
| 28 |
+
from tqdm import tqdm
|
| 29 |
+
|
| 30 |
+
from biorlhf.data.grpo_dataset import build_grpo_dataset, get_dataset_stats
|
| 31 |
+
from biorlhf.verifiers.composer import VerifierComposer
|
| 32 |
+
from biorlhf.verifiers.uncertainty import _extract_confidence_simple
|
| 33 |
+
from biorlhf.evaluation.calibration import compute_calibration_metrics
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def load_model(
    model_path: str,
    base_model: str = "mistralai/Mistral-7B-v0.3",
    use_4bit: bool = True,
):
    """Load a fine-tuned model with LoRA adapters.

    Args:
        model_path: Directory containing the LoRA adapter weights.
        base_model: Hub id or local path of the frozen base model.
        use_4bit: When True, load the base model with NF4 4-bit quantization
            (QLoRA-style) so it fits on a single GPU.

    Returns:
        Tuple of (model, tokenizer) where model is the base model with the
        adapter attached via PEFT.
    """
    print(f" Base model: {base_model}")
    print(f" Adapter: {model_path}")

    if use_4bit:
        # QLoRA-style quantization: NF4 + double quant, bf16 compute dtype.
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )
        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            quantization_config=bnb_config,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
    else:
        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )

    model = PeftModel.from_pretrained(model, model_path)

    # Fix: LoRA adapter directories frequently do not contain tokenizer
    # files, which makes from_pretrained(model_path) fail. Prefer the
    # adapter directory (it may carry fine-tune-specific tokens), but fall
    # back to the base model's tokenizer when loading from the adapter fails.
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    except (OSError, ValueError):
        tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Mistral ships without a pad token; reuse EOS so padded generation works.
        tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def generate_response(
    model,
    tokenizer,
    prompt: str,
    max_new_tokens: int = 512,
    temperature: float = 0.1,
) -> str:
    """Generate a completion for *prompt* using the Alpaca-style template.

    Sampling is enabled only when temperature is positive; the returned
    string is the portion after the "### Response:" marker, stripped.
    """
    template = f"### Instruction:\n{prompt}\n\n### Response:\n"
    encoded = tokenizer(template, return_tensors="pt").to(model.device)

    with torch.no_grad():
        generated = model.generate(
            **encoded,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=0.9,
            do_sample=temperature > 0,
            pad_token_id=tokenizer.pad_token_id,
        )

    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
    marker = "### Response:"
    if marker in decoded:
        decoded = decoded.split(marker)[-1].strip()
    return decoded
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def evaluate_with_verifiers(
    model,
    tokenizer,
    eval_dataset,
    composer: VerifierComposer,
    max_samples: Optional[int] = None,
) -> Dict:
    """Evaluate model using the verifier stack.

    Generates one completion per eval sample, scores it through the
    composed verifiers, and extracts a stated-confidence value for later
    calibration analysis. Returns per-sample results plus aggregated
    mean rewards (overall, per verifier, per question type).
    """
    limit = len(eval_dataset)
    if max_samples:
        limit = min(limit, max_samples)

    per_sample = []
    for idx in tqdm(range(limit), desc="Evaluating"):
        item = eval_dataset[idx]
        prompt = item["prompt"]

        completion = generate_response(model, tokenizer, prompt)

        reward = composer.compute_reward(
            prompt=prompt,
            completion=completion,
            ground_truth=item["ground_truth"],
            question_type=item["question_type"],
            applicable_verifiers=item["applicable_verifiers"],
        )

        # Stated confidence feeds the calibration metrics downstream.
        conf = _extract_confidence_simple(completion)

        per_sample.append({
            "prompt": prompt[:100],
            "response": completion[:300],
            "total_reward": reward.total_reward,
            "verifier_scores": reward.verifier_scores,
            "question_type": item["question_type"],
            "source": item.get("source", "unknown"),
            "tissue": item.get("tissue", "unknown"),
            "confidence": conf.numeric,
            "confidence_stated": conf.stated,
        })

    # Aggregate: overall mean, per-verifier means, per-question-type means.
    rewards = [r["total_reward"] for r in per_sample]

    per_verifier: Dict[str, List[float]] = {}
    by_type: Dict[str, List[float]] = {}
    for r in per_sample:
        for name, score in r["verifier_scores"].items():
            per_verifier.setdefault(name, []).append(score)
        by_type.setdefault(r["question_type"], []).append(r["total_reward"])

    return {
        "n_samples": len(per_sample),
        "mean_reward": sum(rewards) / len(rewards) if rewards else 0,
        "verifier_means": {name: sum(s) / len(s) for name, s in per_verifier.items()},
        "by_question_type": {t: sum(s) / len(s) for t, s in by_type.items()},
        "per_sample": per_sample,
    }
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def evaluate_calibration(results: List[Dict]) -> Dict:
    """Compute calibration metrics from evaluation results.

    A sample counts as "correct" when its composite verifier reward
    exceeds 0.5; stated confidences come from the per-sample records.
    """
    metrics = compute_calibration_metrics(
        confidences=[r["confidence"] for r in results],
        correctnesses=[r["total_reward"] > 0.5 for r in results],
    )

    # Flatten the metrics object into a JSON-serializable dict.
    return {
        "ece": metrics.ece,
        "mce": metrics.mce,
        "brier_score": metrics.brier_score,
        "overconfidence_rate": metrics.overconfidence_rate,
        "underconfidence_rate": metrics.underconfidence_rate,
        "mean_confidence": metrics.mean_confidence,
        "mean_accuracy": metrics.mean_accuracy,
        "n_samples": metrics.n_samples,
        "reliability_bins": metrics.reliability_bins,
    }
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def main():
    """Entry point: build the held-out eval set, score the GRPO model (and
    an optional SFT baseline) through the verifier stack, compute
    calibration metrics and success criteria, and save everything to JSON."""
    parser = argparse.ArgumentParser(
        description="Evaluate a BioGRPO-trained model",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--model", type=str, required=True,
        help="Path to the GRPO-trained model (LoRA adapter directory)",
    )
    parser.add_argument(
        "--base-model", type=str, default="mistralai/Mistral-7B-v0.3",
        help="Base model name",
    )
    parser.add_argument(
        "--sft-baseline", type=str, default=None,
        help="Path to SFT baseline model for comparison",
    )
    parser.add_argument(
        "--hold-out-tissues", type=str, nargs="+", default=["eye"],
        help="Tissues held out for evaluation",
    )
    parser.add_argument(
        "--pathway-db", type=str, default="hallmark",
        help="Pathway database",
    )
    parser.add_argument(
        "--active-verifiers", type=str, nargs="+", default=None,
        help="Active verifiers (default: all)",
    )
    parser.add_argument(
        "--max-samples", type=int, default=None,
        help="Max samples to evaluate (for quick testing)",
    )
    parser.add_argument(
        "--output", type=str, default=None,
        help="Output path for results JSON",
    )
    parser.add_argument(
        "--no-4bit", action="store_true",
        help="Disable 4-bit quantization",
    )

    args = parser.parse_args()

    print("=" * 60)
    print("BioGRPO Evaluation")
    print("=" * 60)
    print(f" Model: {args.model}")
    print(f" Base: {args.base_model}")
    print(f" Hold-out: {args.hold_out_tissues}")
    print(f" SFT baseline: {args.sft_baseline or 'None'}")
    print(f" Time: {datetime.now().isoformat()}")
    print("=" * 60)

    # Build eval dataset; the train split is discarded, only the held-out
    # (LOMO) portion is evaluated.
    print("\n[1/4] Building evaluation dataset...")
    _, eval_dataset = build_grpo_dataset(
        db=args.pathway_db,
        hold_out_tissues=args.hold_out_tissues,
    )
    eval_stats = get_dataset_stats(eval_dataset)
    print(f" Eval samples: {eval_stats['total']}")
    print(f" By source: {eval_stats['by_source']}")
    print(f" By type: {eval_stats['by_question_type']}")

    # Create verifier composer (None activates all verifiers V1-V4)
    composer = VerifierComposer(active_verifiers=args.active_verifiers)

    # Evaluate GRPO model
    print(f"\n[2/4] Evaluating GRPO model: {args.model}")
    model, tokenizer = load_model(
        args.model, args.base_model, use_4bit=not args.no_4bit,
    )
    grpo_results = evaluate_with_verifiers(
        model, tokenizer, eval_dataset, composer,
        max_samples=args.max_samples,
    )
    grpo_calibration = evaluate_calibration(grpo_results["per_sample"])

    # Free GPU memory before (possibly) loading a second model
    del model
    torch.cuda.empty_cache()

    # Evaluate baseline if provided
    baseline_results = None
    baseline_calibration = None
    if args.sft_baseline:
        print(f"\n[3/4] Evaluating SFT baseline: {args.sft_baseline}")
        baseline_model, baseline_tokenizer = load_model(
            args.sft_baseline, args.base_model, use_4bit=not args.no_4bit,
        )
        baseline_results = evaluate_with_verifiers(
            baseline_model, baseline_tokenizer, eval_dataset, composer,
            max_samples=args.max_samples,
        )
        baseline_calibration = evaluate_calibration(baseline_results["per_sample"])
        del baseline_model
        torch.cuda.empty_cache()
    else:
        print("\n[3/4] Skipping baseline (not provided)")

    # Print summary
    print("\n[4/4] Results Summary")
    print("=" * 60)
    print(f"GRPO Model: {args.model}")
    print(f" Mean reward: {grpo_results['mean_reward']:.3f}")
    print(f" Per verifier: {grpo_results['verifier_means']}")
    print(f" ECE: {grpo_calibration['ece']:.3f}")
    print(f" Brier: {grpo_calibration['brier_score']:.3f}")
    print(f" Overconfidence: {grpo_calibration['overconfidence_rate']:.3f}")
    print(f" By type: {grpo_results['by_question_type']}")

    comparison = {}
    if baseline_results:
        print(f"\nSFT Baseline: {args.sft_baseline}")
        print(f" Mean reward: {baseline_results['mean_reward']:.3f}")
        print(f" ECE: {baseline_calibration['ece']:.3f}")
        print(f" Brier: {baseline_calibration['brier_score']:.3f}")

        # Deltas are GRPO minus baseline, so negative ECE delta = better calibration.
        delta_reward = grpo_results["mean_reward"] - baseline_results["mean_reward"]
        delta_ece = grpo_calibration["ece"] - baseline_calibration["ece"]
        print(f"\n Delta reward: {delta_reward:+.3f}")
        print(f" Delta ECE: {delta_ece:+.3f} (negative = better)")

        comparison = {
            "sft_mean_reward": baseline_results["mean_reward"],
            "sft_ece": baseline_calibration["ece"],
            "delta_reward": delta_reward,
            "delta_ece": delta_ece,
        }

    # Success criteria for the experiment; baseline comparison is included
    # only when a baseline was evaluated.
    criteria = {
        "reward_above_05": grpo_results["mean_reward"] > 0.5,
        "ece_below_015": grpo_calibration["ece"] < 0.15,
    }
    if baseline_results:
        criteria["reward_above_baseline"] = delta_reward > 0
    criteria["overall_pass"] = all(criteria.values())

    print(f"\nSuccess criteria: {criteria}")

    # Save results; default path is timestamped under results/.
    output_path = args.output or f"results/grpo_eval_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)

    output_data = {
        "model_path": args.model,
        "base_model": args.base_model,
        "evaluation_date": datetime.now().isoformat(),
        "hold_out_tissues": args.hold_out_tissues,
        "eval_dataset_stats": eval_stats,
        "grpo": {
            "mean_reward": grpo_results["mean_reward"],
            "verifier_means": grpo_results["verifier_means"],
            "by_question_type": grpo_results["by_question_type"],
            "n_samples": grpo_results["n_samples"],
        },
        "calibration": grpo_calibration,
        "baseline_comparison": comparison,
        "success_criteria": criteria,
        "per_sample": grpo_results["per_sample"],
    }

    with open(output_path, "w") as f:
        json.dump(output_data, f, indent=2)

    print(f"\nResults saved to: {output_path}")
    print("=" * 60)


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Merge BioRLHF training data with ecosystem failure examples.
|
| 4 |
+
|
| 5 |
+
This script:
|
| 6 |
+
1. Loads existing kmp_sft_final.json training data
|
| 7 |
+
2. Loads ecosystem_failures_training.json (failure-based examples)
|
| 8 |
+
3. Converts failure examples to the same format
|
| 9 |
+
4. Outputs combined_training.json
|
| 10 |
+
|
| 11 |
+
Usage:
|
| 12 |
+
python scripts/merge_training_data.py
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import json
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
from datetime import datetime
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def load_json(filepath: str) -> dict | list:
|
| 21 |
+
"""Load JSON file."""
|
| 22 |
+
with open(filepath, 'r') as f:
|
| 23 |
+
return json.load(f)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def save_json(data: list, filepath: str):
    """Write *data* to *filepath* as indented JSON and report the count."""
    with open(filepath, "w") as fh:
        json.dump(data, fh, indent=2)
    print(f"Saved {len(data)} examples to {filepath}")
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def convert_conversation_to_text(conversation: list) -> str:
    """Convert conversation format to Alpaca-style training text.

    Input: [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
    Output: "### Instruction:\n...\n\n### Response:\n..."

    Roles other than "user"/"assistant" are ignored; when a role appears
    more than once, the last occurrence wins.
    """
    parts = {"user": "", "assistant": ""}
    for turn in conversation:
        role = turn["role"]
        if role in parts:
            parts[role] = turn["content"]
    return (
        "### Instruction:\n"
        + parts["user"]
        + "\n\n### Response:\n"
        + parts["assistant"]
    )
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def extract_examples_from_failures(failure_data: dict) -> list:
    """Extract and convert all examples from failure training data.

    All four example categories share the same record structure
    ({"conversations": [...], "type": ..., "id": ...}), so iterate over
    the section names instead of duplicating the loop body four times.
    Missing sections are simply skipped.

    Returns:
        List of {"text", "source", "id"} training records in the same
        order as the sections below.
    """
    sections = (
        "calibration_examples",
        "adversarial_resistance_examples",
        "completeness_examples",
        "fact_drilling_examples",
    )

    examples = []
    for section in sections:
        for ex in failure_data.get(section, []):
            examples.append({
                "text": convert_conversation_to_text(ex["conversations"]),
                # Tag each record with its failure type for later breakdowns.
                "source": f"ecosystem_failures:{ex['type']}",
                "id": ex["id"],
            })

    return examples
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def main():
    """Entry point: load the original SFT data and the ecosystem-failure
    examples, convert the latter to training format, merge both, and write
    data/combined_training.json."""
    # Paths are resolved relative to the repository root (scripts/..)
    data_dir = Path(__file__).parent.parent / "data"
    existing_path = data_dir / "kmp_sft_final.json"
    failures_path = data_dir / "ecosystem_failures_training.json"
    output_path = data_dir / "combined_training.json"

    print("=" * 60)
    print("BioRLHF Training Data Merger")
    print("=" * 60)

    # Load existing data (a flat list of training records)
    print(f"\n📂 Loading existing data from {existing_path}")
    existing_data = load_json(existing_path)
    print(f" Found {len(existing_data)} existing examples")

    # Load failure-based examples (a dict of per-category example lists)
    print(f"\n📂 Loading failure examples from {failures_path}")
    failure_data = load_json(failures_path)

    # Convert failure examples into the flat text-record format
    print("\n🔄 Converting failure examples to training format...")
    new_examples = extract_examples_from_failures(failure_data)
    print(f" Converted {len(new_examples)} examples")

    # Show breakdown by failure type (parsed from the "source" tag)
    print("\n📊 New examples by type:")
    type_counts = {}
    for ex in new_examples:
        source_type = ex["source"].split(":")[1] if ":" in ex["source"] else ex["source"]
        type_counts[source_type] = type_counts.get(source_type, 0) + 1
    for t, c in sorted(type_counts.items()):
        print(f" - {t}: {c}")

    # Combine data
    print("\n🔀 Merging datasets...")

    # Add source field to existing data if not present, so every record
    # in the combined set is attributable to its origin.
    for ex in existing_data:
        if "source" not in ex:
            ex["source"] = "kmp_sft_original"

    # Combine (original examples first, failure-derived examples appended)
    combined = existing_data + new_examples
    print(f" Total examples: {len(combined)}")

    # Save combined data
    print(f"\n💾 Saving to {output_path}")
    save_json(combined, output_path)

    # Summary
    print("\n" + "=" * 60)
    print("✅ MERGE COMPLETE")
    print("=" * 60)
    print(f" Original examples: {len(existing_data)}")
    print(f" New examples: {len(new_examples)}")
    print(f" Total combined: {len(combined)}")
    print(f"\n Output: {output_path}")
    print("\nNext step: Run training with combined data:")
    print(" python sft_train_v2.py --dataset data/combined_training.json")


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
#SBATCH --job-name=eval_grpo
#SBATCH --partition=scu-gpu
#SBATCH --account=cayuga_0003
#SBATCH --gres=gpu:1
#SBATCH --mem=64G
#SBATCH --cpus-per-task=8
#SBATCH --time=4:00:00
#SBATCH --output=logs/eval_grpo_%j.log
#SBATCH --error=logs/eval_grpo_%j.err

# ============================================================
# BioGRPO Post-Training Evaluation
# Evaluates GRPO model + SFT baseline comparison
# ============================================================

SCRATCH="/athena/cayuga_0003/scratch/users/jak4013/otsuka"
WORKDIR="${SCRATCH}/training/BioRLHF"

echo "============================================================"
echo "BioGRPO Evaluation"
echo "Job ID: $SLURM_JOB_ID"
echo "Node: $SLURMD_NODENAME"
echo "Working dir: $WORKDIR"
echo "Start time: $(date)"
echo "============================================================"

cd "$WORKDIR" || { echo "WORKDIR not found: $WORKDIR"; exit 1; }
mkdir -p logs results

module purge
module load cuda/12.1

. /home/fs01/jak4013/miniconda3/miniconda3/etc/profile.d/conda.sh
conda activate biorlhf

echo ""
echo "GPU Information:"
nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv
echo ""

export CUDA_VISIBLE_DEVICES=0
export TRANSFORMERS_CACHE="${WORKDIR}/cache/transformers"
export HF_HOME="${WORKDIR}/cache/huggingface"
export TOKENIZERS_PARALLELISM=false

# Data paths
export GENELAB_BASE="${SCRATCH}/data/GeneLab_benchmark"
export BIOEVAL_DATA="${SCRATCH}/data/BioEval/data"
export SPACEOMICS_DATA="${SCRATCH}/data/SpaceOmicsBench/v3/evaluation/llm"
export BIOEVAL_ROOT="${SCRATCH}/data/BioEval"

# Model paths
GRPO_MODEL="./biogrpo_mve_model"
SFT_BASELINE="./kmp_sft_model_final"
OUTPUT="results/grpo_mve_eval_$(date +%Y%m%d_%H%M%S).json"

echo "GRPO model: $GRPO_MODEL"
echo "SFT baseline: $SFT_BASELINE"
echo "Output: $OUTPUT"
echo ""

# Check model exists
if [ ! -d "$GRPO_MODEL" ]; then
    echo "ERROR: GRPO model not found at $GRPO_MODEL"
    echo "Available directories:"
    ls -d biogrpo_* 2>/dev/null || echo "  No biogrpo_* dirs found"
    exit 1
fi

echo "Starting BioGRPO evaluation..."
python scripts/evaluate_grpo.py \
    --model "$GRPO_MODEL" \
    --sft-baseline "$SFT_BASELINE" \
    --hold-out-tissues eye \
    --output "$OUTPUT"
# BUG FIX: capture the exit status immediately. The previous version read $?
# inside the else branch after several echo commands, so it always printed 0
# instead of the real failure code.
STATUS=$?

if [ "$STATUS" -eq 0 ]; then
    echo ""
    echo "============================================================"
    echo "BioGRPO evaluation completed!"
    echo "Results: $OUTPUT"
    echo "End time: $(date)"
    echo "============================================================"
else
    echo ""
    echo "============================================================"
    echo "BioGRPO evaluation failed with exit code $STATUS"
    echo "Check logs/eval_grpo_${SLURM_JOB_ID}.err for details"
    echo "============================================================"
    exit 1
fi
|
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
#
# BioRLHF Model Evaluation Script
# ================================
#
# Scores the ecosystem-improved model on calibration (uncertainty
# expression), adversarial resistance, protocol completeness, and
# fact recall.
#
# Usage on HPC:
#   srun -p scu-gpu --gres=gpu:a100:1 --mem=48G -c 8 --time=1:00:00 --pty bash
#   conda activate biorlhf
#   ./scripts/run_evaluation.sh
#

# Emits the standard 60-character separator used throughout the log.
hr() {
    echo "============================================================"
}

hr
echo "BioRLHF Ecosystem Model Evaluation"
hr
echo "Start time: $(date)"
echo "Host: $(hostname)"
echo ""

# Always run from the repository root (one level above this script).
cd "$(dirname "$0")/.." || exit 1
echo "Working directory: $(pwd)"

echo ""
echo "GPU Information:"
nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv 2>/dev/null || echo "No GPU detected"
echo ""

# Model under test, held-out test data, and a timestamped results file.
MODEL_PATH="./ecosystem_improved_model"
TEST_DATA="data/ecosystem_failures_training.json"
OUTPUT="ecosystem_eval_results_$(date +%Y%m%d_%H%M%S).json"

hr
echo "Configuration:"
hr
echo "Model: $MODEL_PATH"
echo "Test data: $TEST_DATA"
echo "Output: $OUTPUT"
echo ""

# Bail out early if either input is missing.
[ -d "$MODEL_PATH" ] || { echo "ERROR: Model not found at $MODEL_PATH"; exit 1; }
[ -f "$TEST_DATA" ] || { echo "ERROR: Test data not found at $TEST_DATA"; exit 1; }

hr
echo "Starting Evaluation..."
hr

# Branch directly on the evaluator's exit status.
if python3 scripts/evaluate_ecosystem_model.py \
    --model "$MODEL_PATH" \
    --test-data "$TEST_DATA" \
    --output "$OUTPUT"
then
    echo ""
    hr
    echo "Evaluation Complete!"
    hr
    echo "Results saved to: $OUTPUT"
    echo "End time: $(date)"
else
    echo ""
    hr
    echo "Evaluation Failed!"
    hr
    exit 1
fi
|
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
#SBATCH --job-name=biogrpo_full
#SBATCH --partition=scu-gpu
#SBATCH --account=cayuga_0003
#SBATCH --gres=gpu:1
#SBATCH --mem=96G
#SBATCH --cpus-per-task=8
#SBATCH --time=24:00:00
#SBATCH --output=logs/grpo_full_%j.log
#SBATCH --error=logs/grpo_full_%j.err

# ============================================================
# BioGRPO Full Experiment
# All V1-V4 verifiers, G=8, from SFT checkpoint
# ============================================================

SCRATCH="/athena/cayuga_0003/scratch/users/jak4013/otsuka"
WORKDIR="${SCRATCH}/training/BioRLHF"

echo "============================================================"
echo "BioGRPO Full Training"
echo "Job ID: $SLURM_JOB_ID"
echo "Node: $SLURMD_NODENAME"
echo "Working dir: $WORKDIR"
echo "Start time: $(date)"
echo "============================================================"

cd "$WORKDIR" || { echo "WORKDIR not found: $WORKDIR"; exit 1; }
mkdir -p logs

module purge
module load cuda/12.1

. /home/fs01/jak4013/miniconda3/miniconda3/etc/profile.d/conda.sh
conda activate biorlhf

echo ""
echo "GPU Information:"
nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv
echo ""

export CUDA_VISIBLE_DEVICES=0
export TRANSFORMERS_CACHE="${WORKDIR}/cache/transformers"
export HF_HOME="${WORKDIR}/cache/huggingface"
export WANDB_DIR="${WORKDIR}/wandb"
export TOKENIZERS_PARALLELISM=false

# Data paths
export GENELAB_BASE="${SCRATCH}/data/GeneLab_benchmark"
export BIOEVAL_DATA="${SCRATCH}/data/BioEval/data"
export SPACEOMICS_DATA="${SCRATCH}/data/SpaceOmicsBench/v3/evaluation/llm"
export BIOEVAL_ROOT="${SCRATCH}/data/BioEval"

# Quoted so cache paths with spaces cannot word-split.
mkdir -p "$TRANSFORMERS_CACHE" "$HF_HOME" "$WANDB_DIR"

# Symlink SFT checkpoint if not already present
if [ ! -e "${WORKDIR}/kmp_sft_model_final" ]; then
    ln -s "${SCRATCH}/training/biorlhf/kmp_sft_model_final" "${WORKDIR}/kmp_sft_model_final"
    echo "Symlinked kmp_sft_model_final"
fi

echo "Starting BioGRPO Full training..."
biorlhf-grpo --config configs/grpo_full.json
# BUG FIX: capture the status immediately. The previous version read $? in the
# else branch after an echo, so the reported "exit code" was always 0.
STATUS=$?

if [ "$STATUS" -eq 0 ]; then
    echo ""
    echo "============================================================"
    echo "BioGRPO Full training completed!"
    echo "Model saved to: ./biogrpo_full_model"
    echo "End time: $(date)"
    echo "============================================================"
else
    echo ""
    echo "============================================================"
    echo "BioGRPO Full training failed with exit code $STATUS"
    echo "Check logs/grpo_full_${SLURM_JOB_ID}.err for details"
    echo "============================================================"
    exit 1
fi
|
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
#SBATCH --job-name=biogrpo_mve
#SBATCH --partition=scu-gpu
#SBATCH --account=cayuga_0003
#SBATCH --gres=gpu:1
#SBATCH --mem=64G
#SBATCH --cpus-per-task=8
#SBATCH --time=48:00:00
#SBATCH --output=logs/grpo_mve_%j.log
#SBATCH --error=logs/grpo_mve_%j.err

# ============================================================
# BioGRPO Minimum Viable Experiment (MVE)
# V1+V4 verifiers, G=4, from SFT checkpoint
# ============================================================

SCRATCH="/athena/cayuga_0003/scratch/users/jak4013/otsuka"
WORKDIR="${SCRATCH}/training/BioRLHF"

echo "============================================================"
echo "BioGRPO MVE Training"
echo "Job ID: $SLURM_JOB_ID"
echo "Node: $SLURMD_NODENAME"
echo "Working dir: $WORKDIR"
echo "Start time: $(date)"
echo "============================================================"

cd "$WORKDIR" || { echo "WORKDIR not found: $WORKDIR"; exit 1; }
mkdir -p logs

# Load modules
module purge
module load cuda/12.1

# Activate conda environment
. /home/fs01/jak4013/miniconda3/miniconda3/etc/profile.d/conda.sh
conda activate biorlhf

# Verify GPU
echo ""
echo "GPU Information:"
nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv
echo ""

# Set environment variables
export CUDA_VISIBLE_DEVICES=0
export TRANSFORMERS_CACHE="${WORKDIR}/cache/transformers"
export HF_HOME="${WORKDIR}/cache/huggingface"
export WANDB_DIR="${WORKDIR}/wandb"
export TOKENIZERS_PARALLELISM=false

# Data paths
export GENELAB_BASE="${SCRATCH}/data/GeneLab_benchmark"
export BIOEVAL_DATA="${SCRATCH}/data/BioEval/data"
export SPACEOMICS_DATA="${SCRATCH}/data/SpaceOmicsBench/v3/evaluation/llm"
export BIOEVAL_ROOT="${SCRATCH}/data/BioEval"

# Quoted so cache paths with spaces cannot word-split.
mkdir -p "$TRANSFORMERS_CACHE" "$HF_HOME" "$WANDB_DIR"

# Symlink SFT checkpoint if not already present
if [ ! -e "${WORKDIR}/kmp_sft_model_final" ]; then
    ln -s "${SCRATCH}/training/biorlhf/kmp_sft_model_final" "${WORKDIR}/kmp_sft_model_final"
    echo "Symlinked kmp_sft_model_final"
fi

# Run GRPO MVE training
echo "Starting BioGRPO MVE training..."
biorlhf-grpo --config configs/grpo_mve.json
# BUG FIX: capture the status immediately. The previous version read $? in the
# else branch after an echo, so the reported "exit code" was always 0.
STATUS=$?

if [ "$STATUS" -eq 0 ]; then
    echo ""
    echo "============================================================"
    echo "BioGRPO MVE training completed!"
    echo "Model saved to: ./biogrpo_mve_model"
    echo "End time: $(date)"
    echo "============================================================"
else
    echo ""
    echo "============================================================"
    echo "BioGRPO MVE training failed with exit code $STATUS"
    echo "Check logs/grpo_mve_${SLURM_JOB_ID}.err for details"
    echo "============================================================"
    exit 1
fi
|
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# ============================================================
# BioGRPO Environment Setup for Cayuga HPC
# Run once to verify/upgrade GRPO dependencies
# ============================================================
# Checks package versions, upgrades TRL, installs the biorlhf package,
# verifies all GRPO-related imports, runs a reward-function smoke test,
# links the SFT checkpoint, and checks that the benchmark data dirs exist.
# Intended to be run interactively on a login/compute node, not via sbatch.

SCRATCH="/athena/cayuga_0003/scratch/users/jak4013/otsuka"
WORKDIR="${SCRATCH}/training/BioRLHF"

echo "============================================================"
echo "BioGRPO Environment Setup"
echo "Working dir: $WORKDIR"
echo "============================================================"

cd "$WORKDIR" || { echo "WORKDIR not found: $WORKDIR"; exit 1; }

# Activate environment
# NOTE(review): ~/.bashrc commonly early-returns in non-interactive shells;
# confirm the conda init block runs before that guard, or source conda.sh
# directly as the sbatch scripts do.
source ~/.bashrc
conda activate biorlhf

# Step 1: Check current versions
echo ""
echo "[1/6] Current package versions..."
python -c "import trl; print(f'  TRL: {trl.__version__}')"
python -c "import peft; print(f'  PEFT: {peft.__version__}')"
python -c "import transformers; print(f'  Transformers: {transformers.__version__}')"
python -c "import torch; print(f'  PyTorch: {torch.__version__}'); print(f'  CUDA: {torch.cuda.is_available()}')"

# Step 2: Upgrade TRL if needed
# GRPOTrainer with scale_rewards/loss_type options requires trl >= 0.26.0.
echo ""
echo "[2/6] Ensuring TRL >= 0.26.0..."
pip install "trl>=0.26.0" --upgrade --quiet

# Step 3: Verify GRPO imports
# Also constructs a throwaway GRPOConfig to confirm the new config fields exist.
echo ""
echo "[3/6] Verifying GRPO imports..."
python -c "
from trl import GRPOTrainer, GRPOConfig
print('  GRPOTrainer: OK')
print('  GRPOConfig: OK')
config = GRPOConfig(output_dir='/tmp/test', scale_rewards='group', loss_type='grpo')
print(f'  scale_rewards={config.scale_rewards}, loss_type={config.loss_type}: OK')
"

# Step 4: Install biorlhf package
# Quiet install first; on failure, rerun loudly and show the last 3 lines.
echo ""
echo "[4/6] Installing biorlhf package..."
pip install -e . --quiet 2>/dev/null || pip install -e . 2>&1 | tail -3

# Step 5: Verify biorlhf imports
echo ""
echo "[5/6] Verifying biorlhf imports..."
python -c "
from biorlhf.training.grpo import BioGRPOConfig, run_grpo_training
print('  BioGRPOConfig: OK')
from biorlhf.verifiers.composer import make_grpo_reward_function
print('  make_grpo_reward_function: OK')
from biorlhf.data.grpo_dataset import build_grpo_dataset
print('  build_grpo_dataset: OK')
from biorlhf.evaluation.calibration import compute_calibration_metrics
print('  compute_calibration_metrics: OK')
"

# Step 6: Smoke test
# Exercises the composed V1+V4 reward function on one correct completion and
# asserts the reward clears 0.3 (a correct answer is expected to score > 0.5).
echo ""
echo "[6/6] Running smoke test..."
python -c "
from biorlhf.verifiers.composer import make_grpo_reward_function
import json
reward_fn = make_grpo_reward_function(active_verifiers=['V1', 'V4'])
rewards = reward_fn(
    completions=['Oxidative phosphorylation is upregulated. Confidence: high.'],
    ground_truth=[json.dumps({
        'pathway': 'HALLMARK_OXIDATIVE_PHOSPHORYLATION',
        'direction': 'UP',
        'expected_confidence': 'high',
    })],
    question_type=['direction'],
    applicable_verifiers=[json.dumps(['V1', 'V4'])],
)
print(f'  Reward: {rewards[0]:.3f} (expected > 0.5)')
assert rewards[0] > 0.3, 'Reward too low'
print('  Smoke test: PASSED')
"

# Create directories used by the training/eval jobs.
mkdir -p logs configs results cache/transformers cache/huggingface wandb

# Step 6b: Symlink SFT checkpoint
# -e (not -d) so an existing symlink, even dangling, is not overwritten.
echo ""
echo "[6b/7] Setting up SFT checkpoint symlink..."
if [ ! -e "${WORKDIR}/kmp_sft_model_final" ]; then
    if [ -d "${SCRATCH}/training/biorlhf/kmp_sft_model_final" ]; then
        ln -s "${SCRATCH}/training/biorlhf/kmp_sft_model_final" "${WORKDIR}/kmp_sft_model_final"
        echo "  Symlinked kmp_sft_model_final: OK"
    else
        echo "  WARNING: kmp_sft_model_final not found at ${SCRATCH}/training/biorlhf/"
        echo "  You will need to provide the SFT checkpoint manually"
    fi
else
    echo "  kmp_sft_model_final already exists: OK"
fi

# Step 7: Verify data paths
# Same four env vars the SLURM training scripts export; checked here so
# missing data is caught before a job is queued.
echo ""
echo "[7/7] Verifying data availability..."
export GENELAB_BASE="${SCRATCH}/data/GeneLab_benchmark"
export BIOEVAL_DATA="${SCRATCH}/data/BioEval/data"
export SPACEOMICS_DATA="${SCRATCH}/data/SpaceOmicsBench/v3/evaluation/llm"
export BIOEVAL_ROOT="${SCRATCH}/data/BioEval"

for d in "$GENELAB_BASE" "$BIOEVAL_DATA" "$SPACEOMICS_DATA" "$BIOEVAL_ROOT"; do
    if [ -d "$d" ]; then
        echo "  $d: OK"
    else
        echo "  $d: MISSING"
    fi
done

echo ""
echo "============================================================"
echo "BioGRPO setup complete!"
echo ""
echo "Next steps:"
echo "  sbatch scripts/run_grpo_mve.sh"
echo "  tail -f logs/grpo_mve_*.log"
echo "============================================================"
|
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
#
# BioRLHF Training Script - Ecosystem Improved Model
# ====================================================
#
# Fine-tunes on the combined dataset:
#   * original KMP study data (363 examples)
#   * ecosystem failure-based examples (15 examples)
#   * calibration training
#   * adversarial resistance
#   * protocol completeness
#   * fact drilling
#
# Needs a CUDA-capable GPU (A100, V100, or 4090 class) with 24GB+ VRAM
# for Mistral-7B under 4-bit quantization, and a Python environment with
# torch, transformers, peft, trl, and bitsandbytes.
#
# Run directly:   ./scripts/train_ecosystem_improved.sh
# Or via SLURM:   sbatch scripts/train_ecosystem_improved.sh
#

# ------------------------------------------------------------------
# SLURM directives (honored when submitted with sbatch)
# ------------------------------------------------------------------
#SBATCH --job-name=biorlhf_ecosystem
#SBATCH --output=logs/biorlhf_ecosystem_%j.out
#SBATCH --error=logs/biorlhf_ecosystem_%j.err
#SBATCH --time=4:00:00
#SBATCH --gres=gpu:1
#SBATCH --mem=48G
#SBATCH --cpus-per-task=8

# Prints the standard 60-character separator line used throughout the log.
rule() {
    echo "============================================================"
}

# ------------------------------------------------------------------
# Environment setup
# ------------------------------------------------------------------
rule
echo "BioRLHF Ecosystem Training"
rule
echo "Start time: $(date)"
echo "Host: $(hostname)"
echo ""

# Activate conda environment (adjust path as needed)
# source /path/to/conda/etc/profile.d/conda.sh
# conda activate biorlhf

# Always run from the repository root (one level above this script).
cd "$(dirname "$0")/.." || exit 1
echo "Working directory: $(pwd)"

echo ""
echo "GPU Information:"
nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv 2>/dev/null || echo "No GPU detected"
echo ""

# ------------------------------------------------------------------
# Training configuration
# ------------------------------------------------------------------
MODEL="mistralai/Mistral-7B-v0.3"
DATASET="data/combined_training.json"
OUTPUT_DIR="./ecosystem_improved_model"

# Hyperparameters tuned in prior BioRLHF experiments.
EPOCHS=10           # more epochs -> better fact memorization
BATCH_SIZE=4        # tune to available GPU memory
GRAD_ACCUM=4        # effective batch = BATCH_SIZE * GRAD_ACCUM = 16
LEARNING_RATE=2e-4  # standard LoRA fine-tuning rate
MAX_LENGTH=1024     # sufficient for most examples

# LoRA: higher rank for domain-knowledge capacity; alpha = 2 * r.
LORA_R=64
LORA_ALPHA=128

# Experiment tracking.
WANDB_PROJECT="biorlhf"
WANDB_RUN="ecosystem_improved_$(date +%Y%m%d_%H%M%S)"

# ------------------------------------------------------------------
# Pre-training checks
# ------------------------------------------------------------------
rule
echo "Configuration:"
rule
echo "Model: $MODEL"
echo "Dataset: $DATASET"
echo "Output: $OUTPUT_DIR"
echo "Epochs: $EPOCHS"
echo "Batch size: $BATCH_SIZE (effective: $((BATCH_SIZE * GRAD_ACCUM)))"
echo "LoRA r/α: $LORA_R / $LORA_ALPHA"
echo "Max length: $MAX_LENGTH"
echo ""

# Refuse to start without the merged dataset.
if [ ! -f "$DATASET" ]; then
    echo "ERROR: Dataset not found at $DATASET"
    echo "Run: python scripts/merge_training_data.py"
    exit 1
fi

# Report how many examples we are about to train on.
EXAMPLE_COUNT=$(python3 -c "import json; print(len(json.load(open('$DATASET'))))")
echo "Dataset contains $EXAMPLE_COUNT examples"
echo ""

# ------------------------------------------------------------------
# Training
# ------------------------------------------------------------------
rule
echo "Starting Training..."
rule

# Branch directly on the trainer's exit status.
if python3 sft_train_v2.py \
    --model "$MODEL" \
    --dataset "$DATASET" \
    --output_dir "$OUTPUT_DIR" \
    --epochs $EPOCHS \
    --batch_size $BATCH_SIZE \
    --grad_accum $GRAD_ACCUM \
    --lr $LEARNING_RATE \
    --max_length $MAX_LENGTH \
    --lora_r $LORA_R \
    --lora_alpha $LORA_ALPHA \
    --use_4bit \
    --wandb_project "$WANDB_PROJECT" \
    --wandb_run "$WANDB_RUN"
then
    echo ""
    rule
    echo "✅ Training Complete!"
    rule
    echo "Model saved to: $OUTPUT_DIR"
    echo "End time: $(date)"
    echo ""
    echo "Next steps:"
    echo "1. Evaluate on SpaceOmicsBench: python evaluate_model.py --model $OUTPUT_DIR"
    echo "2. Evaluate on CAMELOT: python evaluate_model.py --model $OUTPUT_DIR --benchmark camelot"
    echo "3. Compare with baseline: python compare_models.py"
else
    echo ""
    rule
    echo "❌ Training Failed!"
    rule
    echo "Check the error messages above."
    exit 1
fi
|
|
@@ -9,10 +9,30 @@ __version__ = "0.1.0"
|
|
| 9 |
__author__ = "JangKeun Kim"
|
| 10 |
__email__ = "jangkeun.kim@med.cornell.edu"
|
| 11 |
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
from biorlhf.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
__all__ = [
|
| 18 |
"__version__",
|
|
|
|
| 9 |
__author__ = "JangKeun Kim"
|
| 10 |
__email__ = "jangkeun.kim@med.cornell.edu"
|
| 11 |
|
| 12 |
+
# Maps each lazily exported attribute name to the module that defines it.
# A data table instead of an if/elif chain: adding a new lazy export is a
# one-line change and there is no per-branch copy/paste to get wrong.
_LAZY_EXPORTS = {
    "SFTTrainingConfig": "biorlhf.training.sft",
    "run_sft_training": "biorlhf.training.sft",
    "DPOTrainingConfig": "biorlhf.training.dpo",
    "run_dpo_training": "biorlhf.training.dpo",
    "create_sft_dataset": "biorlhf.data.dataset",
    "load_dataset": "biorlhf.data.dataset",
    "evaluate_model": "biorlhf.evaluation.evaluate",
}


def __getattr__(name):
    """Lazy imports for torch-dependent modules (PEP 562 module __getattr__).

    Keeps ``import biorlhf`` cheap: heavyweight submodules are imported
    only on first access of one of the names in ``_LAZY_EXPORTS``.

    Args:
        name: Attribute being looked up on the ``biorlhf`` package.

    Returns:
        The requested object, imported from its defining submodule.

    Raises:
        AttributeError: If ``name`` is not a known lazy export.
    """
    module_path = _LAZY_EXPORTS.get(name)
    if module_path is None:
        raise AttributeError(f"module 'biorlhf' has no attribute {name!r}")
    # Imported here (not at module top) to keep package import side-effect free.
    import importlib

    return getattr(importlib.import_module(module_path), name)
|
| 36 |
|
| 37 |
__all__ = [
|
| 38 |
"__version__",
|
|
@@ -11,6 +11,7 @@ from pathlib import Path
|
|
| 11 |
|
| 12 |
from biorlhf.training.sft import SFTTrainingConfig, run_sft_training
|
| 13 |
from biorlhf.evaluation.evaluate import evaluate_model as _evaluate_model
|
|
|
|
| 14 |
|
| 15 |
|
| 16 |
def train():
|
|
@@ -264,5 +265,131 @@ def evaluate():
|
|
| 264 |
sys.exit(1)
|
| 265 |
|
| 266 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 267 |
if __name__ == "__main__":
|
| 268 |
-
print("Use 'biorlhf-train' or 'biorlhf-
|
|
|
|
| 11 |
|
| 12 |
from biorlhf.training.sft import SFTTrainingConfig, run_sft_training
|
| 13 |
from biorlhf.evaluation.evaluate import evaluate_model as _evaluate_model
|
| 14 |
+
from biorlhf.training.grpo import BioGRPOConfig, run_grpo_training
|
| 15 |
|
| 16 |
|
| 17 |
def train():
|
|
|
|
| 265 |
sys.exit(1)
|
| 266 |
|
| 267 |
|
| 268 |
+
def grpo_train():
    """CLI entry point for GRPO training with biological verifiers.

    Parses command-line flags (or a JSON config file given via ``--config``,
    which takes precedence over every other flag), builds a
    :class:`BioGRPOConfig`, and runs :func:`run_grpo_training`.
    Exits with status 1 if training raises.
    """
    parser = argparse.ArgumentParser(
        description="Train a BioGRPO model with composable biological verifiers",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        "--model",
        type=str,
        default="mistralai/Mistral-7B-v0.3",
        help="Base model to fine-tune",
    )
    parser.add_argument(
        "--sft-model",
        type=str,
        default=None,
        help="Path to SFT checkpoint (recommended)",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="./biogrpo_model",
        help="Output directory",
    )
    # GRPO samples G completions per prompt and normalizes rewards within the group.
    parser.add_argument(
        "--num-generations",
        type=int,
        default=8,
        help="G value: number of completions per prompt",
    )
    parser.add_argument(
        "--beta",
        type=float,
        default=0.04,
        help="KL penalty coefficient",
    )
    parser.add_argument(
        "--learning-rate",
        type=float,
        default=1e-6,
        help="Learning rate",
    )
    parser.add_argument(
        "--lora-r",
        type=int,
        default=32,
        help="LoRA rank",
    )
    parser.add_argument(
        "--lora-alpha",
        type=int,
        default=64,
        help="LoRA alpha",
    )
    # None means "enable all verifiers" (interpreted by BioGRPOConfig).
    parser.add_argument(
        "--verifiers",
        type=str,
        nargs="+",
        default=None,
        help="Active verifiers (e.g., V1 V2 V3 V4). Default: all",
    )
    parser.add_argument(
        "--pathway-db",
        type=str,
        default="hallmark",
        choices=["hallmark", "kegg", "reactome", "mitocarta"],
        help="Pathway database for GeneLab questions",
    )
    parser.add_argument(
        "--no-wandb",
        action="store_true",
        help="Disable W&B logging",
    )
    parser.add_argument(
        "--wandb-project",
        type=str,
        default="biogrpo",
        help="W&B project name",
    )
    parser.add_argument(
        "--wandb-run-name",
        type=str,
        default="grpo_v1",
        help="W&B run name",
    )
    parser.add_argument(
        "--config",
        type=str,
        default=None,
        help="Path to JSON config file (overrides other args)",
    )

    args = parser.parse_args()

    # A JSON config file wins over all individual CLI flags.
    if args.config:
        with open(args.config) as f:
            config_dict = json.load(f)
        config = BioGRPOConfig(**config_dict)
    else:
        config = BioGRPOConfig(
            model_name=args.model,
            sft_model_path=args.sft_model,
            output_dir=args.output,
            num_generations=args.num_generations,
            beta=args.beta,
            learning_rate=args.learning_rate,
            lora_r=args.lora_r,
            lora_alpha=args.lora_alpha,
            active_verifiers=args.verifiers,
            pathway_db=args.pathway_db,
            use_wandb=not args.no_wandb,
            wandb_project=args.wandb_project,
            wandb_run_name=args.wandb_run_name,
        )

    try:
        output_path = run_grpo_training(config)
        print(f"\nModel saved to: {output_path}")
    except Exception as e:
        # Dump the full traceback first so HPC job logs show the root cause,
        # then a one-line summary on stderr before the non-zero exit.
        import traceback
        traceback.print_exc()
        print(f"Error during GRPO training: {e}", file=sys.stderr)
        sys.exit(1)
|
| 392 |
+
|
| 393 |
+
|
| 394 |
if __name__ == "__main__":
|
| 395 |
+
print("Use 'biorlhf-train', 'biorlhf-evaluate', or 'biorlhf-grpo' commands after installation.")
|
|
@@ -1,6 +1,6 @@
|
|
| 1 |
"""Data processing and dataset creation modules for BioRLHF."""
|
| 2 |
|
| 3 |
-
|
| 4 |
from biorlhf.data.ground_truth import (
|
| 5 |
STRESSOR_EFFECTS,
|
| 6 |
KMP_EFFECTS,
|
|
@@ -18,3 +18,11 @@ __all__ = [
|
|
| 18 |
"TISSUE_TYPES",
|
| 19 |
"OXPHOS_PATTERNS",
|
| 20 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
"""Data processing and dataset creation modules for BioRLHF."""
|
| 2 |
|
| 3 |
+
# ground_truth has no heavy dependencies, safe to import eagerly
|
| 4 |
from biorlhf.data.ground_truth import (
|
| 5 |
STRESSOR_EFFECTS,
|
| 6 |
KMP_EFFECTS,
|
|
|
|
| 18 |
"TISSUE_TYPES",
|
| 19 |
"OXPHOS_PATTERNS",
|
| 20 |
]
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def __getattr__(name):
|
| 24 |
+
"""Lazy imports for modules with heavy dependencies."""
|
| 25 |
+
if name in ("create_sft_dataset", "load_dataset"):
|
| 26 |
+
from biorlhf.data.dataset import create_sft_dataset, load_dataset
|
| 27 |
+
return {"create_sft_dataset": create_sft_dataset, "load_dataset": load_dataset}[name]
|
| 28 |
+
raise AttributeError(f"module 'biorlhf.data' has no attribute {name!r}")
|
|
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
GeneLab fGSEA/GSVA data loading for BioGRPO.
|
| 3 |
+
|
| 4 |
+
Loads pathway enrichment results from the GeneLab_benchmark project's
|
| 5 |
+
processed fGSEA and GSVA files. Provides consensus pathway directions
|
| 6 |
+
across missions for use as verifiable ground truth.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import os
from pathlib import Path
from typing import Dict, List
from dataclasses import dataclass, field
import json

import pandas as pd

# ── Paths (configurable via env vars for HPC) ─────────────────────────────
# GENELAB_BASE falls back to a developer-machine path; on the cluster, point
# the GENELAB_BASE environment variable at the GeneLab_benchmark checkout.
GENELAB_BASE = Path(os.environ.get(
    "GENELAB_BASE",
    "/Users/jak4013/Dropbox/Bioinformatics/Claude/GeneLab_benchmark",
))
FGSEA_DIR = GENELAB_BASE / "processed" / "fgsea"          # per-tissue fGSEA result CSVs
GSVA_DIR = GENELAB_BASE / "processed" / "pathway_scores"  # per-tissue GSVA score CSVs
TASKS_DIR = GENELAB_BASE / "tasks"                        # LOMO task definitions
EVAL_DIR = GENELAB_BASE / "evaluation"                    # NES conservation JSONs

# ── Tissue → available missions (from actual files) ───────────────────────
# Keys/values must match the on-disk file naming: {mission}_fgsea_{db}.csv.
TISSUE_MISSIONS: Dict[str, List[str]] = {
    "liver": ["MHU-2", "RR-1", "RR-3", "RR-6", "RR-8", "RR-9"],
    "gastrocnemius": ["RR-1", "RR-9"],
    "kidney": ["RR-1", "RR-3", "RR-7"],
    "thymus": ["MHU-2", "RR-6", "RR-9"],
    "skin": ["MHU-2_dorsal", "MHU-2_femoral", "RR-6"],
    "eye": ["RR-1", "RR-3"],
}

# Tissue → LOMO task ID
# Used to locate {task_id}_{tissue}_lomo/task_info.json under TASKS_DIR.
TISSUE_TASK_MAP: Dict[str, str] = {
    "liver": "A1",
    "gastrocnemius": "A2",
    "kidney": "A3",
    "thymus": "A4",
    "skin": "A5",
    "eye": "A6",
}

# Supported pathway databases (suffixes of the processed result files).
DBS = ["hallmark", "kegg", "reactome", "mitocarta"]
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@dataclass
class PathwayResult:
    """Single pathway enrichment result from fGSEA."""
    # Pathway identifier as emitted by fGSEA (e.g. "HALLMARK_...").
    pathway: str
    # Normalized enrichment score; sign encodes direction of change.
    nes: float
    # Adjusted p-value from fGSEA.
    padj: float
    direction: str  # "UP", "DOWN", or "NS"
    # Provenance of the result: tissue, mission, and pathway database.
    tissue: str
    mission: str
    db: str
    # Leading-edge gene symbols driving the enrichment (may be empty).
    leading_edge: List[str] = field(default_factory=list)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# ── Loading functions ──────────────────────────────────────────────────────
|
| 64 |
+
|
| 65 |
+
def load_fgsea(tissue: str, mission: str, db: str = "hallmark") -> pd.DataFrame:
    """Load one fGSEA result CSV for a tissue/mission/database combination.

    Returns a DataFrame with columns:
        pathway, pval, padj, log2err, ES, NES, size, db,
        leadingEdge_str, tissue, mission, glds
    """
    csv_path = FGSEA_DIR / tissue / f"{mission}_fgsea_{db}.csv"
    if csv_path.exists():
        return pd.read_csv(csv_path)
    raise FileNotFoundError(f"fGSEA file not found: {csv_path}")
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def load_all_fgsea(tissue: str, db: str = "hallmark") -> pd.DataFrame:
    """Concatenate fGSEA results for a tissue across every mission on disk.

    Missions listed in TISSUE_MISSIONS whose CSV is missing are skipped
    silently; an empty DataFrame is returned when nothing is found.
    """
    candidate_paths = (
        FGSEA_DIR / tissue / f"{mission}_fgsea_{db}.csv"
        for mission in TISSUE_MISSIONS.get(tissue, [])
    )
    frames = [pd.read_csv(p) for p in candidate_paths if p.exists()]
    return pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def get_pathway_directions(
    tissue: str,
    db: str = "hallmark",
    padj_threshold: float = 0.05,
) -> Dict[str, Dict[str, str]]:
    """Per-mission pathway direction calls for a tissue.

    Returns:
        {mission: {pathway: "UP"/"DOWN"/"NS"}}
    A pathway is called UP/DOWN only when its padj is below the threshold;
    everything else (including NaN padj) is NS.
    """
    df = load_all_fgsea(tissue, db)
    if df.empty:
        return {}

    calls_by_mission: Dict[str, Dict[str, str]] = {}
    for mission, sub in df.groupby("mission"):
        calls: Dict[str, str] = {}
        for _, row in sub.iterrows():
            significant = pd.notna(row["padj"]) and row["padj"] < padj_threshold
            if significant:
                calls[row["pathway"]] = "UP" if row["NES"] > 0 else "DOWN"
            else:
                calls[row["pathway"]] = "NS"
        calls_by_mission[str(mission)] = calls
    return calls_by_mission
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def get_consensus_directions(
    tissue: str,
    db: str = "hallmark",
    min_missions: int = 2,
    padj_threshold: float = 0.05,
) -> Dict[str, Dict]:
    """Pathways whose direction is agreed upon across missions.

    A pathway is included when at least ``min_missions`` missions call the
    same direction AND that direction out-votes the opposite one.

    Returns:
        {pathway: {
            direction: "UP"/"DOWN",
            n_agree: int,
            n_disagree: int,
            n_ns: int,
            missions_agree: List[str],
            missions_disagree: List[str],
        }}
    """
    per_mission = get_pathway_directions(tissue, db, padj_threshold)
    if not per_mission:
        return {}

    # Tally which missions voted UP / DOWN / NS for each pathway.
    votes: Dict[str, Dict[str, List[str]]] = {}
    for mission, calls in per_mission.items():
        for pathway, direction in calls.items():
            bucket = votes.setdefault(pathway, {"UP": [], "DOWN": [], "NS": []})
            bucket[direction].append(mission)

    consensus: Dict[str, Dict] = {}
    for pathway, bucket in votes.items():
        # Check UP-majority first, then DOWN-majority; at most one can win.
        for winner, loser in (("UP", "DOWN"), ("DOWN", "UP")):
            n_win, n_lose = len(bucket[winner]), len(bucket[loser])
            if n_win >= min_missions and n_win > n_lose:
                consensus[pathway] = {
                    "direction": winner,
                    "n_agree": n_win,
                    "n_disagree": n_lose,
                    "n_ns": len(bucket["NS"]),
                    "missions_agree": bucket[winner],
                    "missions_disagree": bucket[loser],
                }
                break
    return consensus
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def get_disagreeing_pathways(
    tissue: str,
    db: str = "hallmark",
    padj_threshold: float = 0.05,
) -> Dict[str, Dict]:
    """Pathways where missions conflict on direction (some UP, some DOWN).

    These make ideal uncertainty questions — the model should express
    uncertainty about the direction.

    Returns:
        {pathway: {
            missions_up: List[str],
            missions_down: List[str],
            missions_ns: List[str],
        }}
    """
    per_mission = get_pathway_directions(tissue, db, padj_threshold)
    if not per_mission:
        return {}

    votes: Dict[str, Dict[str, List[str]]] = {}
    for mission, calls in per_mission.items():
        for pathway, direction in calls.items():
            bucket = votes.setdefault(pathway, {"UP": [], "DOWN": [], "NS": []})
            bucket[direction].append(mission)

    # Keep only pathways with at least one UP vote AND one DOWN vote.
    return {
        pathway: {
            "missions_up": bucket["UP"],
            "missions_down": bucket["DOWN"],
            "missions_ns": bucket["NS"],
        }
        for pathway, bucket in votes.items()
        if bucket["UP"] and bucket["DOWN"]
    }
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def load_gsva_scores(
    tissue: str,
    mission: str,
    db: str = "hallmark",
) -> pd.DataFrame:
    """Load GSVA pathway scores (samples × pathways) for one mission."""
    csv_path = GSVA_DIR / tissue / f"{mission}_gsva_{db}.csv"
    if csv_path.exists():
        # First CSV column holds the sample IDs; use it as the index.
        return pd.read_csv(csv_path, index_col=0)
    raise FileNotFoundError(f"GSVA file not found: {csv_path}")
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def load_lomo_splits(tissue: str) -> List[Dict]:
    """Read LOMO fold definitions from the tissue's task_info.json.

    Returns an empty list when the tissue has no mapped task or the
    task_info.json file does not exist.
    """
    task_id = TISSUE_TASK_MAP.get(tissue)
    if not task_id:
        return []
    info_path = TASKS_DIR / f"{task_id}_{tissue}_lomo" / "task_info.json"
    if not info_path.exists():
        return []
    with open(info_path) as fh:
        return json.load(fh).get("folds", [])
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def load_nes_conservation(db: str = "hallmark") -> Dict:
    """Load NES conservation analysis (cross-mission correlation data).

    Returns an empty dict when no analysis file exists for the database.
    """
    json_path = EVAL_DIR / f"NES_conservation_{db}.json"
    if not json_path.exists():
        return {}
    with open(json_path) as fh:
        return json.load(fh)
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def get_all_pathways(tissue: str, db: str = "hallmark") -> List[str]:
    """Sorted unique pathway names observed for a tissue/db combination."""
    df = load_all_fgsea(tissue, db)
    return [] if df.empty else sorted(df["pathway"].unique())
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def get_pathway_nes_matrix(
    tissue: str,
    db: str = "hallmark",
) -> pd.DataFrame:
    """Build a mission × pathway NES matrix for a tissue.

    Useful for visualizing pathway behavior across missions.
    """
    df = load_all_fgsea(tissue, db)
    if df.empty:
        return pd.DataFrame()
    # "first" keeps the single NES value per (mission, pathway) pair.
    return df.pivot_table(
        index="mission",
        columns="pathway",
        values="NES",
        aggfunc="first",
    )
|
|
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Unified GRPO dataset builder for BioGRPO.
|
| 3 |
+
|
| 4 |
+
Merges pathway questions from GeneLab, calibration tasks from BioEval,
|
| 5 |
+
and domain questions from SpaceOmicsBench into a single TRL-compatible
|
| 6 |
+
dataset with multi-dimensional ground truth.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import os
from pathlib import Path
from typing import List, Dict, Optional, Tuple
import json

from datasets import Dataset as HFDataset

from biorlhf.data.question_generator import generate_all_questions

# ── External data paths (configurable via env vars for HPC) ───────────────
# Defaults are developer-machine paths; on the cluster, set BIOEVAL_DATA and
# SPACEOMICS_DATA to the deployed copies of the respective datasets.
BIOEVAL_DATA = Path(os.environ.get(
    "BIOEVAL_DATA",
    "/Users/jak4013/Dropbox/Bioinformatics/Claude/Evaluation_model/BioEval/data",
))
SPACEOMICS_DATA = Path(os.environ.get(
    "SPACEOMICS_DATA",
    "/Users/jak4013/Dropbox/Bioinformatics/Claude/SpaceOmicsBench/v3/evaluation/llm",
))
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def load_bioeval_for_grpo() -> List[Dict]:
    """Load BioEval tasks that have verifiable ground truth.

    Selects:
        - calibration tasks (30) → V4 training
        - bioambiguity tasks (45) → V3 training
        - Other verifiable tasks (causalbio/designcheck/adversarial) → V2

    Returns:
        List of sample dicts with the GRPO dataset columns
        (prompt, ground_truth, question_type, applicable_verifiers,
        source, tissue, difficulty). Empty when the JSONL file is missing.
    """
    # component → (question_type, verifier list, difficulty).
    # Replaces three near-identical append branches with one routing table.
    routing = {
        "calibration": ("calibration", ["V4"], "medium"),
        "bioambiguity": ("context_dependent", ["V3", "V4"], "hard"),
        "causalbio": ("causalbio", ["V2"], "medium"),
        "designcheck": ("designcheck", ["V2"], "medium"),
        "adversarial": ("adversarial", ["V2"], "hard"),
    }

    samples: List[Dict] = []
    base_path = BIOEVAL_DATA / "bioeval_v060_base.jsonl"
    if not base_path.exists():
        return samples

    with open(base_path) as f:
        for line in f:
            task = json.loads(line)
            component = task.get("component", "")
            route = routing.get(component)
            if route is None:
                # Components without verifiable ground truth are skipped.
                continue
            question_type, verifiers, difficulty = route

            gt = task.get("ground_truth", "{}")
            # Ensure ground_truth is a JSON string (reward fn contract).
            gt_str = json.dumps(gt) if isinstance(gt, dict) else gt

            samples.append({
                "prompt": task.get("prompt", ""),
                "ground_truth": gt_str,
                "question_type": question_type,
                "applicable_verifiers": json.dumps(verifiers),
                "source": "bioeval",
                "tissue": "general",
                "difficulty": difficulty,
            })

    return samples
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def load_spaceomics_for_grpo() -> List[Dict]:
    """Load SpaceOmicsBench v3 questions with ground truth.

    Returns an empty list when the question bank file is absent.
    """
    qbank_path = SPACEOMICS_DATA / "question_bank_v3.json"
    if not qbank_path.exists():
        return []

    with open(qbank_path) as fh:
        qbank = json.load(fh)

    samples: List[Dict] = []
    for question in qbank.get("questions", []):
        gt = {
            "key_facts": question.get("ground_truth_key_facts", []),
            "expected_reasoning": question.get("expected_reasoning", []),
        }

        # All SpaceOmics questions are fact-checked (V2); questions flagged
        # for calibration additionally train the uncertainty verifier (V4).
        verifiers = ["V2"]
        if question.get("requires_uncertainty_calibration", False):
            verifiers.append("V4")
            gt["expected_confidence"] = "medium"

        samples.append({
            "prompt": question["question"],
            "ground_truth": json.dumps(gt),
            "question_type": question.get("category", "factual"),
            "applicable_verifiers": json.dumps(verifiers),
            "source": "spaceomics",
            "tissue": "general",
            "difficulty": question.get("difficulty", "medium"),
        })

    return samples
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def build_grpo_dataset(
    db: str = "hallmark",
    seed: int = 42,
    hold_out_tissues: Optional[List[str]] = None,
) -> Tuple[HFDataset, HFDataset]:
    """Build the full GRPO training dataset with a train/eval split.

    Args:
        db: Pathway database to use for GeneLab questions.
        seed: Random seed for splitting.
        hold_out_tissues: If set, questions from these tissues go to eval.
            Otherwise a random 10% split is used.

    Returns:
        (train_dataset, eval_dataset) as HuggingFace Datasets.

    Dataset columns (TRL-compatible):
        - prompt: str (required by GRPOTrainer)
        - ground_truth: str (JSON, forwarded to reward function)
        - question_type: str (forwarded to reward function)
        - applicable_verifiers: str (JSON list, forwarded to reward function)
        - source: str ("genelab", "bioeval", "spaceomics")
        - tissue: str (for LOMO splitting)
        - difficulty: str ("easy", "medium", "hard")
    """
    # 1. GeneLab pathway questions, flattened into the column schema above.
    samples: List[Dict] = [
        {
            "prompt": q.prompt,
            "ground_truth": json.dumps(q.ground_truth),
            "question_type": q.question_type,
            "applicable_verifiers": json.dumps(q.applicable_verifiers),
            "source": "genelab",
            "tissue": q.tissue,
            "difficulty": q.difficulty,
        }
        for q in generate_all_questions(db)
    ]

    # 2. BioEval tasks + 3. SpaceOmicsBench questions.
    samples.extend(load_bioeval_for_grpo())
    samples.extend(load_spaceomics_for_grpo())

    if not samples:
        raise ValueError("No training samples generated. Check data paths.")

    full_dataset = HFDataset.from_list(samples)

    # Tissue-based hold-out when requested and at least one tissue matches;
    # otherwise fall back to a random 10% split.
    if hold_out_tissues:
        eval_indices = [
            i for i, s in enumerate(samples) if s["tissue"] in hold_out_tissues
        ]
        if eval_indices:
            train_indices = [
                i for i, s in enumerate(samples) if s["tissue"] not in hold_out_tissues
            ]
            return (
                full_dataset.select(train_indices),
                full_dataset.select(eval_indices),
            )

    split = full_dataset.train_test_split(test_size=0.1, seed=seed)
    return split["train"], split["test"]
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def get_dataset_stats(dataset: "HFDataset") -> Dict:
    """Return summary statistics for a GRPO dataset.

    Args:
        dataset: A sized iterable of sample dicts (e.g. a HuggingFace
            Dataset) carrying "source", "question_type", "tissue", and
            "difficulty" keys.

    Returns:
        Dict with the total sample count plus per-source, per-question-type,
        per-tissue, and per-difficulty histograms (plain dicts).
    """
    # Counter replaces four hand-rolled dict.get(...)+1 tallies.
    from collections import Counter

    sources: Counter = Counter()
    types: Counter = Counter()
    tissues: Counter = Counter()
    difficulties: Counter = Counter()

    for sample in dataset:
        sources[sample["source"]] += 1
        types[sample["question_type"]] += 1
        tissues[sample["tissue"]] += 1
        difficulties[sample["difficulty"]] += 1

    return {
        "total": len(dataset),
        "by_source": dict(sources),
        "by_question_type": dict(types),
        "by_tissue": dict(tissues),
        "by_difficulty": dict(difficulties),
    }
|
|
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Pathway reasoning question generator for BioGRPO.
|
| 3 |
+
|
| 4 |
+
Generates verifiable QA pairs from GeneLab fGSEA pathway data.
|
| 5 |
+
Each question has structured ground truth for scoring by the verifier stack.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from typing import List, Dict, Set
|
| 9 |
+
from dataclasses import dataclass, field
|
| 10 |
+
|
| 11 |
+
from biorlhf.data.genelabloader import (
|
| 12 |
+
get_consensus_directions,
|
| 13 |
+
get_disagreeing_pathways,
|
| 14 |
+
get_pathway_directions,
|
| 15 |
+
load_nes_conservation,
|
| 16 |
+
TISSUE_MISSIONS,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@dataclass
class GRPOQuestion:
    """A question with verifiable ground truth for GRPO training."""
    # Natural-language question presented to the model.
    prompt: str
    # Structured answer key consumed by the verifier stack.
    ground_truth: Dict
    # Tissue the question concerns.
    tissue: str
    # Pathway database the ground truth was derived from.
    db: str
    question_type: str  # "direction", "comparison", "consistency", "uncertainty"
    # Verifier IDs (e.g. ["V1", "V4"]) that can score answers to this question.
    applicable_verifiers: List[str]
    difficulty: str  # "easy", "medium", "hard"
    # Free-form supplementary information; not part of the scoring contract.
    metadata: Dict = field(default_factory=dict)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _clean_pathway_name(pathway: str) -> str:
|
| 34 |
+
"""HALLMARK_OXIDATIVE_PHOSPHORYLATION → Oxidative Phosphorylation"""
|
| 35 |
+
for prefix in ("HALLMARK_", "KEGG_", "REACTOME_", "MITOCARTA_"):
|
| 36 |
+
pathway = pathway.replace(prefix, "")
|
| 37 |
+
return pathway.replace("_", " ").title()
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# ── Question generators ────────────────────────────────────────────────────
|
| 41 |
+
|
| 42 |
+
def generate_direction_questions(
    tissue: str,
    db: str = "hallmark",
    padj_threshold: float = 0.05,
) -> List[GRPOQuestion]:
    """Generate V1-targetable questions about pathway direction.

    For every pathway with a consensus direction in *tissue* (agreement
    across at least 2 missions at the given adjusted-p threshold), two
    prompts are emitted: a direct up/down question and a mechanistic
    reasoning question. Ground truth carries the consensus direction so
    V1 can score the answer; V4 scores the stated confidence, and the
    mechanistic variant additionally routes to V2 (fact checking).

    Args:
        tissue: GeneLab tissue name (e.g. "liver").
        db: Pathway database key (e.g. "hallmark").
        padj_threshold: Adjusted p-value cutoff for enrichment calls.

    Returns:
        List of GRPOQuestion objects with question_type "direction".
    """
    consensus = get_consensus_directions(tissue, db, min_missions=2, padj_threshold=padj_threshold)
    questions: List[GRPOQuestion] = []

    for pathway, info in consensus.items():
        pw = _clean_pathway_name(pathway)
        n_agree = info["n_agree"]

        # Type 1: Direct direction question (easy/medium).
        # More supporting missions -> easier question and higher expected confidence.
        questions.append(GRPOQuestion(
            prompt=(
                f"In mouse {tissue} tissue during spaceflight, is the "
                f"{pw} pathway upregulated or downregulated based on "
                f"gene set enrichment analysis? "
                f"Provide your confidence level."
            ),
            ground_truth={
                "pathway": pathway,
                "direction": info["direction"],
                "n_supporting_missions": n_agree,
                "expected_confidence": "high" if n_agree >= 3 else "medium",
            },
            tissue=tissue,
            db=db,
            question_type="direction",
            applicable_verifiers=["V1", "V4"],
            difficulty="easy" if n_agree >= 3 else "medium",
        ))

        # Type 2: Mechanistic reasoning question (medium/hard).
        # Requires an explanation, so V2 (biological fact checking) also applies.
        direction_word = "activation" if info["direction"] == "UP" else "suppression"
        questions.append(GRPOQuestion(
            prompt=(
                f"Explain the biological significance of {pw} pathway "
                f"{direction_word} in mouse {tissue} under spaceflight conditions. "
                f"What mechanisms might drive this change? "
                f"State your confidence in the direction and magnitude."
            ),
            ground_truth={
                "pathway": pathway,
                "direction": info["direction"],
                "n_supporting_missions": n_agree,
                "requires_mechanism": True,
                "expected_confidence": "medium",
            },
            tissue=tissue,
            db=db,
            question_type="direction",
            applicable_verifiers=["V1", "V2", "V4"],
            difficulty="medium" if n_agree >= 3 else "hard",
        ))

    return questions
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def generate_comparison_questions(
    db: str = "hallmark",
    padj_threshold: float = 0.05,
) -> List[GRPOQuestion]:
    """Generate cross-tissue comparison questions (V1 + V3 targetable).

    Pulls consensus pathway directions for a fixed tissue panel and, for
    every pathway observed in at least two of those tissues, asks whether
    the spaceflight response is consistent or tissue-specific.
    """
    out: List[GRPOQuestion] = []

    # Consensus directions per tissue (skin subsites deliberately excluded
    # so the comparisons stay clean).
    panel = ["liver", "gastrocnemius", "kidney", "thymus", "eye"]
    per_tissue: Dict[str, Dict[str, Dict]] = {}
    for name in panel:
        dirs = get_consensus_directions(name, db, min_missions=2, padj_threshold=padj_threshold)
        if dirs:
            per_tissue[name] = dirs

    # Need at least two tissues with data to compare anything.
    if len(per_tissue) < 2:
        return out

    # Union of every pathway seen in any tissue of the panel.
    candidate_pathways: Set[str] = set()
    for dirs in per_tissue.values():
        candidate_pathways |= set(dirs.keys())

    for pathway in sorted(candidate_pathways):
        present = {
            name: dirs[pathway]
            for name, dirs in per_tissue.items()
            if pathway in dirs
        }
        if len(present) < 2:
            continue

        readable = _clean_pathway_name(pathway)
        sorted_names = sorted(present)
        observed_directions = {entry["direction"] for entry in present.values()}

        out.append(GRPOQuestion(
            prompt=(
                f"Compare the response of the {readable} pathway to spaceflight "
                f"across {', '.join(sorted_names)} tissues in mice. "
                f"Is the direction of change consistent or tissue-specific? "
                f"Explain the biological basis for any differences."
            ),
            ground_truth={
                "pathway": pathway,
                "tissue_directions": {
                    name: entry["direction"] for name, entry in present.items()
                },
                "is_consistent": len(observed_directions) == 1,
                "n_tissues": len(present),
            },
            tissue="multi",
            db=db,
            question_type="comparison",
            applicable_verifiers=["V1", "V3", "V4"],
            difficulty="hard",
        ))

    return out
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def generate_uncertainty_questions(
    tissue: str,
    db: str = "hallmark",
    padj_threshold: float = 0.05,
) -> List[GRPOQuestion]:
    """Generate questions where missions disagree → model should express uncertainty.

    Each question targets a pathway whose direction conflicts across
    missions, so the calibrated behavior is a low-confidence,
    context-dependent answer (scored by V1 and V4).
    """
    conflicted = get_disagreeing_pathways(tissue, db, padj_threshold)
    result: List[GRPOQuestion] = []

    for pathway, record in conflicted.items():
        nice_name = _clean_pathway_name(pathway)
        prompt_text = (
            f"Is the {nice_name} pathway consistently activated or suppressed "
            f"in mouse {tissue} across different spaceflight missions? "
            f"How confident are you in the direction of change?"
        )
        truth = {
            "pathway": pathway,
            "missions_up": record["missions_up"],
            "missions_down": record["missions_down"],
            "missions_ns": record["missions_ns"],
            "correct_behavior": "context_dependent",
            "expected_confidence": "low",
        }
        result.append(GRPOQuestion(
            prompt=prompt_text,
            ground_truth=truth,
            tissue=tissue,
            db=db,
            question_type="uncertainty",
            applicable_verifiers=["V1", "V4"],
            difficulty="hard",
        ))

    return result
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def generate_conservation_questions(
    db: str = "hallmark",
) -> List[GRPOQuestion]:
    """Generate questions about NES conservation across missions.

    Reads precomputed normalized-enrichment-score (NES) correlation data
    and asks, per tissue, how reproducible pathway-level responses are
    across missions. Ground truth carries key facts for V2 and an
    expected confidence for V4. Returns an empty list if no conservation
    data is available.
    """
    conservation = load_nes_conservation(db)
    if not conservation:
        return []

    questions: List[GRPOQuestion] = []
    # Payload may be wrapped under a "data" key or be the mapping itself.
    data = conservation.get("data", conservation)

    for tissue, info in data.items():
        # Skip non-dict entries (e.g. top-level metadata fields).
        if not isinstance(info, dict):
            continue
        mean_r = info.get("nes_mean_r")
        if mean_r is None:
            continue

        # Bucket the mean pairwise correlation into a qualitative label.
        # Thresholds 0.5 / 0.2 — presumably chosen empirically; TODO confirm.
        if mean_r > 0.5:
            conservation_level = "highly conserved"
            expected_conf = "high"
        elif mean_r > 0.2:
            conservation_level = "moderately conserved"
            expected_conf = "medium"
        else:
            conservation_level = "poorly conserved"
            expected_conf = "medium"

        questions.append(GRPOQuestion(
            prompt=(
                f"How conserved are pathway-level responses to spaceflight "
                f"across different missions in mouse {tissue}? "
                f"Are the enrichment patterns reproducible?"
            ),
            ground_truth={
                "tissue": tissue,
                "nes_mean_r": mean_r,
                "conservation_level": conservation_level,
                "expected_confidence": expected_conf,
                # Facts V2 can check the completion against verbatim.
                "key_facts": [
                    f"Mean pairwise NES correlation across missions is {mean_r:.3f}",
                    f"Pathway responses in {tissue} are {conservation_level}",
                ],
            },
            tissue=tissue,
            db=db,
            question_type="direction",
            applicable_verifiers=["V2", "V4"],
            difficulty="medium",
        ))

    return questions
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def generate_all_questions(db: str = "hallmark") -> List[GRPOQuestion]:
    """Generate the full question set from GeneLab data.

    Emits per-tissue direction and uncertainty questions for every tissue
    with known missions, followed by the cross-tissue comparison set and
    the NES-conservation set, in that order.
    """
    questions: List[GRPOQuestion] = []
    for tissue in TISSUE_MISSIONS:
        questions += generate_direction_questions(tissue, db)
        questions += generate_uncertainty_questions(tissue, db)
    questions += generate_comparison_questions(db)
    questions += generate_conservation_questions(db)
    return questions
|
|
@@ -1,8 +1,14 @@
|
|
| 1 |
"""Evaluation modules for BioRLHF."""
|
| 2 |
|
| 3 |
-
from biorlhf.evaluation.evaluate import evaluate_model, compute_metrics
|
| 4 |
-
|
| 5 |
__all__ = [
|
| 6 |
"evaluate_model",
|
| 7 |
"compute_metrics",
|
| 8 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
"""Evaluation modules for BioRLHF."""
|
| 2 |
|
|
|
|
|
|
|
| 3 |
__all__ = [
|
| 4 |
"evaluate_model",
|
| 5 |
"compute_metrics",
|
| 6 |
]
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def __getattr__(name):
|
| 10 |
+
"""Lazy imports for torch-dependent modules."""
|
| 11 |
+
if name in ("evaluate_model", "compute_metrics"):
|
| 12 |
+
from biorlhf.evaluation.evaluate import evaluate_model, compute_metrics
|
| 13 |
+
return {"evaluate_model": evaluate_model, "compute_metrics": compute_metrics}[name]
|
| 14 |
+
raise AttributeError(f"module 'biorlhf.evaluation' has no attribute {name!r}")
|
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Calibration evaluation metrics for BioGRPO.
|
| 3 |
+
|
| 4 |
+
Implements Expected Calibration Error (ECE), Brier score, overconfidence
|
| 5 |
+
rate, and reliability diagram data generation.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from typing import Dict, List, Tuple
|
| 9 |
+
from dataclasses import dataclass
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass
class CalibrationMetrics:
    """Aggregated calibration metrics for a set of confidence-scored predictions.

    All rates lie in [0, 1]; lower ECE/MCE/Brier is better.
    """
    ece: float  # Expected Calibration Error (bin-count-weighted |accuracy - confidence|)
    mce: float  # Maximum Calibration Error (worst non-empty bin)
    brier_score: float  # Mean squared error between confidence and the 0/1 outcome
    overconfidence_rate: float  # P(wrong | confidence > threshold)
    underconfidence_rate: float  # P(correct | confidence < threshold)
    mean_confidence: float  # Average stated confidence over all samples
    mean_accuracy: float  # Fraction of correct predictions
    n_samples: int  # Number of (confidence, correctness) pairs scored
    reliability_bins: List[Dict]  # For plotting reliability diagrams
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def compute_ece(
    confidences: List[float],
    correctnesses: List[bool],
    n_bins: int = 10,
) -> Tuple[float, float, List[Dict]]:
    """Compute Expected and Maximum Calibration Error.

    Uses equal-width binning over [0, 1]. Bin membership is assigned by
    index arithmetic in a single O(n) pass (the previous version
    re-scanned the whole sample list once per bin, and its floating-point
    bin bounds could misclassify exact boundary values like 0.8). Values
    outside [0, 1] are clamped into the first/last bin so every sample is
    counted and the ECE weights sum to 1.

    Args:
        confidences: Model's stated confidence for each prediction (0-1).
        correctnesses: Whether each prediction was correct.
        n_bins: Number of calibration bins.

    Returns:
        (ECE, MCE, bin_data) where bin_data is a list of per-bin dicts
        (bounds, mean confidence, mean accuracy, count, |acc - conf|)
        suitable for plotting reliability diagrams.
    """
    if not confidences:
        return 0.0, 0.0, []

    bin_width = 1.0 / n_bins

    # Single pass: map each sample index to its bin. c == 1.0 (and any
    # out-of-range value) is clamped into a valid bin.
    members: List[List[int]] = [[] for _ in range(n_bins)]
    for j, c in enumerate(confidences):
        idx = min(max(int(c * n_bins), 0), n_bins - 1)
        members[idx].append(j)

    bins: List[Dict] = []
    for i, idx_list in enumerate(members):
        bin_lower = i * bin_width
        bin_upper = (i + 1) * bin_width

        if not idx_list:
            # Keep empty bins so reliability diagrams have a full x-axis:
            # midpoint confidence, zero count, zero error.
            bins.append({
                "bin_lower": bin_lower,
                "bin_upper": bin_upper,
                "mean_confidence": (bin_lower + bin_upper) / 2,
                "mean_accuracy": 0.0,
                "count": 0,
                "calibration_error": 0.0,
            })
            continue

        mean_conf = sum(confidences[j] for j in idx_list) / len(idx_list)
        mean_acc = sum(float(correctnesses[j]) for j in idx_list) / len(idx_list)
        bins.append({
            "bin_lower": bin_lower,
            "bin_upper": bin_upper,
            "mean_confidence": mean_conf,
            "mean_accuracy": mean_acc,
            "count": len(idx_list),
            "calibration_error": abs(mean_acc - mean_conf),
        })

    # ECE: calibration error weighted by each bin's share of the samples.
    total_samples = len(confidences)
    ece = sum(
        b["count"] / total_samples * b["calibration_error"]
        for b in bins if b["count"] > 0
    )

    # MCE: worst calibration error across non-empty bins.
    non_empty_errors = [b["calibration_error"] for b in bins if b["count"] > 0]
    mce = max(non_empty_errors) if non_empty_errors else 0.0

    return ece, mce, bins
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def compute_brier_score(
    confidences: List[float],
    correctnesses: List[bool],
) -> float:
    """Compute Brier score: mean squared error between confidence and outcome.

    Lower is better. Range [0, 1].
    """
    if not confidences:
        return 0.0
    squared_errors = [
        (conf - float(hit)) ** 2
        for conf, hit in zip(confidences, correctnesses)
    ]
    # Average over the number of confidences (matches the documented contract).
    return sum(squared_errors) / len(confidences)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def compute_overconfidence_rate(
    confidences: List[float],
    correctnesses: List[bool],
    threshold: float = 0.8,
) -> float:
    """P(wrong | confidence > threshold).

    High overconfidence rate indicates the model is unreliably confident.
    Returns 0.0 when no prediction exceeds the threshold.
    """
    n_high = 0
    n_high_wrong = 0
    for conf, hit in zip(confidences, correctnesses):
        if conf > threshold:
            n_high += 1
            if not hit:
                n_high_wrong += 1
    if n_high == 0:
        return 0.0
    return n_high_wrong / n_high
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def compute_underconfidence_rate(
    confidences: List[float],
    correctnesses: List[bool],
    threshold: float = 0.3,
) -> float:
    """P(correct | confidence < threshold).

    High underconfidence rate means the model knows more than it admits.
    Returns 0.0 when no prediction falls below the threshold.
    """
    low_outcomes = [
        hit for conf, hit in zip(confidences, correctnesses) if conf < threshold
    ]
    if not low_outcomes:
        return 0.0
    n_correct = sum(1 for hit in low_outcomes if hit)
    return n_correct / len(low_outcomes)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def compute_calibration_metrics(
    confidences: List[float],
    correctnesses: List[bool],
    n_bins: int = 10,
    overconf_threshold: float = 0.8,
    underconf_threshold: float = 0.3,
) -> CalibrationMetrics:
    """Compute the full calibration metrics suite.

    Aggregates ECE/MCE, Brier score, over-/under-confidence rates and
    summary means into a single CalibrationMetrics record. Empty input
    yields an all-zero record.
    """
    n = len(confidences)
    if n == 0:
        return CalibrationMetrics(
            ece=0.0,
            mce=0.0,
            brier_score=0.0,
            overconfidence_rate=0.0,
            underconfidence_rate=0.0,
            mean_confidence=0.0,
            mean_accuracy=0.0,
            n_samples=0,
            reliability_bins=[],
        )

    ece, mce, bins = compute_ece(confidences, correctnesses, n_bins)
    return CalibrationMetrics(
        ece=ece,
        mce=mce,
        brier_score=compute_brier_score(confidences, correctnesses),
        overconfidence_rate=compute_overconfidence_rate(
            confidences, correctnesses, overconf_threshold
        ),
        underconfidence_rate=compute_underconfidence_rate(
            confidences, correctnesses, underconf_threshold
        ),
        mean_confidence=sum(confidences) / n,
        mean_accuracy=sum(float(c) for c in correctnesses) / len(correctnesses),
        n_samples=n,
        reliability_bins=bins,
    )
|
|
@@ -1,11 +1,24 @@
|
|
| 1 |
"""Training modules for BioRLHF."""
|
| 2 |
|
| 3 |
-
from biorlhf.training.sft import SFTTrainingConfig, run_sft_training
|
| 4 |
-
from biorlhf.training.dpo import DPOTrainingConfig, run_dpo_training
|
| 5 |
-
|
| 6 |
__all__ = [
|
| 7 |
"SFTTrainingConfig",
|
| 8 |
"run_sft_training",
|
| 9 |
"DPOTrainingConfig",
|
| 10 |
"run_dpo_training",
|
|
|
|
|
|
|
| 11 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
"""Training modules for BioRLHF."""
|
| 2 |
|
|
|
|
|
|
|
|
|
|
| 3 |
__all__ = [
|
| 4 |
"SFTTrainingConfig",
|
| 5 |
"run_sft_training",
|
| 6 |
"DPOTrainingConfig",
|
| 7 |
"run_dpo_training",
|
| 8 |
+
"BioGRPOConfig",
|
| 9 |
+
"run_grpo_training",
|
| 10 |
]
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def __getattr__(name):
|
| 14 |
+
"""Lazy imports for torch-dependent modules."""
|
| 15 |
+
if name in ("SFTTrainingConfig", "run_sft_training"):
|
| 16 |
+
from biorlhf.training.sft import SFTTrainingConfig, run_sft_training
|
| 17 |
+
return {"SFTTrainingConfig": SFTTrainingConfig, "run_sft_training": run_sft_training}[name]
|
| 18 |
+
elif name in ("DPOTrainingConfig", "run_dpo_training"):
|
| 19 |
+
from biorlhf.training.dpo import DPOTrainingConfig, run_dpo_training
|
| 20 |
+
return {"DPOTrainingConfig": DPOTrainingConfig, "run_dpo_training": run_dpo_training}[name]
|
| 21 |
+
elif name in ("BioGRPOConfig", "run_grpo_training"):
|
| 22 |
+
from biorlhf.training.grpo import BioGRPOConfig, run_grpo_training
|
| 23 |
+
return {"BioGRPOConfig": BioGRPOConfig, "run_grpo_training": run_grpo_training}[name]
|
| 24 |
+
raise AttributeError(f"module 'biorlhf.training' has no attribute {name!r}")
|
|
@@ -0,0 +1,284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Group Relative Policy Optimization (GRPO) training for BioGRPO.
|
| 3 |
+
|
| 4 |
+
Uses TRL's GRPOTrainer with composable biological verifiers as reward functions.
|
| 5 |
+
Supports configurable G values, verifier weights, and LoRA parameters.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
from dataclasses import dataclass, field
|
| 10 |
+
from typing import Optional, List, Dict
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
| 14 |
+
from peft import LoraConfig, PeftModel
|
| 15 |
+
from trl import GRPOTrainer, GRPOConfig
|
| 16 |
+
|
| 17 |
+
from biorlhf.verifiers.composer import make_grpo_reward_function
|
| 18 |
+
from biorlhf.data.grpo_dataset import build_grpo_dataset, get_dataset_stats
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@dataclass
class BioGRPOConfig:
    """Configuration for BioGRPO training.

    Groups model/checkpoint selection, GRPO and generic training
    hyperparameters, LoRA/QLoRA settings, verifier routing, data split
    options, and logging/memory knobs consumed by run_grpo_training.
    """

    # Model settings
    model_name: str = "mistralai/Mistral-7B-v0.3"  # HF id of the base model (tokenizer is always loaded from here)
    sft_model_path: Optional[str] = None  # LoRA adapter dir (merged before GRPO) or full model path; None = train from base
    output_dir: str = "./biogrpo_model"

    # GRPO hyperparameters
    num_generations: int = 8  # G: completions sampled per prompt for group-relative advantages
    beta: float = 0.04  # KL penalty coefficient
    num_iterations: int = 1
    scale_rewards: str = "group"
    loss_type: str = "grpo"

    # Training hyperparameters
    num_epochs: int = 1
    batch_size: int = 2  # per-device train batch size
    gradient_accumulation_steps: int = 8
    learning_rate: float = 1e-6
    max_completion_length: int = 1024
    max_prompt_length: int = 512
    warmup_ratio: float = 0.1

    # LoRA settings (fresh adapter applied for the GRPO stage)
    lora_r: int = 32
    lora_alpha: int = 64
    lora_dropout: float = 0.05

    # Verifier configuration
    verifier_weights: Optional[Dict[str, float]] = None  # per-verifier composition weights; None -> composer defaults
    active_verifiers: Optional[List[str]] = None  # subset of verifier ids; None -> all (V1-V4)

    # Data
    pathway_db: str = "hallmark"
    hold_out_tissues: Optional[List[str]] = None  # tissues excluded from training (forwarded to dataset builder)
    seed: int = 42

    # Quantization
    use_4bit: bool = True  # QLoRA: load base model 4-bit (nf4, double quant, bf16 compute)

    # Logging
    wandb_project: str = "biogrpo"
    wandb_run_name: str = "grpo_v1"
    use_wandb: bool = True  # auto-disabled at runtime if wandb is not installed
    logging_steps: int = 5
    save_steps: int = 50
    eval_steps: int = 50
    save_total_limit: int = 3
    log_completions: bool = True

    # Memory optimization
    use_vllm: bool = False
    gradient_checkpointing: bool = True
    bf16: bool = True
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def run_grpo_training(config: Optional[BioGRPOConfig] = None) -> str:
    """Run BioGRPO training.

    Pipeline:
    1. Build dataset from GeneLab + BioEval + SpaceOmicsBench
    2. Create composed reward function from verifier stack
    3. Load tokenizer (always from the base model)
    4. Load model, merging an SFT LoRA adapter first when one is given
    5. Configure TRL's GRPOTrainer (fresh LoRA adapter for this stage)
    6. Train and save model

    Fix applied here: progress labels were inconsistent ("[1/5]" ...
    "[4/5]" followed by "[5/6]" and "[6/6]"); they now consistently
    count six steps.

    Args:
        config: Training configuration. Uses defaults if None.

    Returns:
        Path to the saved model directory.
    """
    if config is None:
        config = BioGRPOConfig()

    # Run banner: echo the configuration so batch-job logs are self-describing.
    print("=" * 60)
    print("BioGRPO Training")
    print("=" * 60)
    print(f" Model: {config.model_name}")
    print(f" SFT checkpoint: {config.sft_model_path or 'None (from base)'}")
    print(f" G (generations): {config.num_generations}")
    print(f" Beta (KL): {config.beta}")
    print(f" Loss type: {config.loss_type}")
    print(f" Active verifiers:{config.active_verifiers or 'all (V1-V4)'}")
    print(f" Verifier weights:{config.verifier_weights or 'default'}")
    print(f" LoRA r/alpha: {config.lora_r}/{config.lora_alpha}")
    print(f" Learning rate: {config.learning_rate}")
    print(f" QLoRA 4-bit: {config.use_4bit}")
    print(f" Output: {config.output_dir}")
    print("=" * 60)

    # Initialize wandb; degrade gracefully to no logging if not installed.
    if config.use_wandb:
        try:
            import wandb
            wandb.init(
                project=config.wandb_project,
                name=config.wandb_run_name,
                # Log all public config fields for reproducibility.
                config={k: v for k, v in vars(config).items() if not k.startswith("_")},
            )
        except ImportError:
            print("Warning: wandb not installed, disabling logging")
            config.use_wandb = False

    # 1. Build dataset
    print("\n[1/6] Building GRPO dataset...")
    train_dataset, eval_dataset = build_grpo_dataset(
        db=config.pathway_db,
        seed=config.seed,
        hold_out_tissues=config.hold_out_tissues,
    )
    train_stats = get_dataset_stats(train_dataset)
    eval_stats = get_dataset_stats(eval_dataset)
    print(f" Train: {train_stats['total']} samples")
    print(f" By source: {train_stats['by_source']}")
    print(f" By type: {train_stats['by_question_type']}")
    print(f" Eval: {eval_stats['total']} samples")

    # 2. Create reward function composed from the active verifier stack.
    print("\n[2/6] Initializing verifier stack...")
    reward_func = make_grpo_reward_function(
        weights=config.verifier_weights,
        active_verifiers=config.active_verifiers,
    )
    print(f" Active: {config.active_verifiers or ['V1', 'V2', 'V3', 'V4']}")

    # 3. Load tokenizer (always from base model; adapter dirs lack config.json)
    print("\n[3/6] Loading tokenizer...")
    tokenizer_source = config.model_name
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_source, trust_remote_code=True)
    # Reuse EOS as pad (base model ships no pad token); left padding for
    # decoder-only generation during rollouts.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    print(f" Tokenizer: {tokenizer.__class__.__name__}, vocab={tokenizer.vocab_size}")

    # 4. Configure a fresh LoRA adapter for the GRPO stage
    #    (attention + MLP projection matrices).
    peft_config = LoraConfig(
        r=config.lora_r,
        lora_alpha=config.lora_alpha,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        lora_dropout=config.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )

    # 5. Load model (merge SFT adapter if present)
    print("\n[4/6] Loading model...")

    # QLoRA quantization config
    bnb_config = None
    if config.use_4bit:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )

    # Check if sft_model_path is a LoRA adapter or a full model
    # (adapter directories are identified by their adapter_config.json).
    sft_is_adapter = (
        config.sft_model_path
        and os.path.isdir(config.sft_model_path)
        and os.path.exists(os.path.join(config.sft_model_path, "adapter_config.json"))
    )

    if sft_is_adapter:
        # Load base model, merge SFT adapter, then apply fresh LoRA for GRPO
        print(f" Loading base model: {config.model_name}")
        base_model = AutoModelForCausalLM.from_pretrained(
            config.model_name,
            quantization_config=bnb_config,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
        print(f" Loading SFT LoRA adapter: {config.sft_model_path}")
        model = PeftModel.from_pretrained(base_model, config.sft_model_path)
        print(" Merging SFT adapter into base model...")
        model = model.merge_and_unload()
        print(" SFT adapter merged successfully")
    else:
        # sft_model_path is a full model, or fall back to the base model
        model_path = config.sft_model_path or config.model_name
        print(f" Loading model: {model_path}")
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            quantization_config=bnb_config,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )

    # 6. Configure GRPOTrainer
    print("\n[5/6] Configuring GRPOTrainer...")

    grpo_config = GRPOConfig(
        output_dir=config.output_dir,
        num_train_epochs=config.num_epochs,
        per_device_train_batch_size=config.batch_size,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        learning_rate=config.learning_rate,
        warmup_ratio=config.warmup_ratio,
        lr_scheduler_type="cosine",

        # GRPO-specific
        num_generations=config.num_generations,
        beta=config.beta,
        loss_type=config.loss_type,
        max_completion_length=config.max_completion_length,
        max_prompt_length=config.max_prompt_length,
        num_iterations=config.num_iterations,
        scale_rewards=config.scale_rewards,

        # Memory/compute
        gradient_checkpointing=config.gradient_checkpointing,
        bf16=config.bf16,
        use_vllm=config.use_vllm,

        # Logging
        logging_steps=config.logging_steps,
        save_steps=config.save_steps,
        save_total_limit=config.save_total_limit,
        report_to="wandb" if config.use_wandb else "none",
        run_name=config.wandb_run_name,

        # Evaluation
        eval_strategy="steps",
        eval_steps=config.eval_steps,
        log_completions=config.log_completions,
    )

    trainer = GRPOTrainer(
        model=model,
        args=grpo_config,
        reward_funcs=reward_func,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        peft_config=peft_config,
        processing_class=tokenizer,
    )

    # Train
    print("\n[6/6] Starting GRPO training...")
    print("=" * 60)

    trainer.train()

    # Save the final (adapter) weights to the configured output directory.
    print(f"\nSaving model to {config.output_dir}")
    trainer.save_model(config.output_dir)

    if config.use_wandb:
        try:
            import wandb
            wandb.finish()
        except ImportError:
            pass

    print("\n" + "=" * 60)
    print("BioGRPO Training complete!")
    print("=" * 60)

    return config.output_dir
|
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Composable biological verifiers for BioGRPO."""
|
| 2 |
+
|
| 3 |
+
from biorlhf.verifiers.base import BaseVerifier, VerifierResult
|
| 4 |
+
from biorlhf.verifiers.pathway import PathwayDirectionVerifier
|
| 5 |
+
from biorlhf.verifiers.factual import BiologicalFactVerifier
|
| 6 |
+
from biorlhf.verifiers.consistency import CrossContextConsistencyVerifier
|
| 7 |
+
from biorlhf.verifiers.uncertainty import UncertaintyVerifier
|
| 8 |
+
from biorlhf.verifiers.composer import (
|
| 9 |
+
VerifierComposer,
|
| 10 |
+
make_grpo_reward_function,
|
| 11 |
+
make_single_verifier_reward,
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
__all__ = [
|
| 15 |
+
"BaseVerifier",
|
| 16 |
+
"VerifierResult",
|
| 17 |
+
"PathwayDirectionVerifier",
|
| 18 |
+
"BiologicalFactVerifier",
|
| 19 |
+
"CrossContextConsistencyVerifier",
|
| 20 |
+
"UncertaintyVerifier",
|
| 21 |
+
"VerifierComposer",
|
| 22 |
+
"make_grpo_reward_function",
|
| 23 |
+
"make_single_verifier_reward",
|
| 24 |
+
]
|
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Abstract base class for biological verifiers."""
|
| 2 |
+
|
| 3 |
+
from abc import ABC, abstractmethod
|
| 4 |
+
from typing import Dict, List
|
| 5 |
+
from dataclasses import dataclass, field
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@dataclass
class VerifierResult:
    """Result from a single verifier.

    Produced by BaseVerifier.score implementations and consumed by the
    composer, which folds `score` into the weighted total reward.
    """
    # Normalized score in [0.0, 1.0]; higher is better.
    score: float
    # Name of the verifier that produced this result (e.g. "V1").
    verifier_name: str
    # Free-form diagnostics explaining how the score was computed.
    details: Dict = field(default_factory=dict)
    # False when the verifier decided it does not apply to this question;
    # composers skip the score for weighting in that case.
    applicable: bool = True
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class BaseVerifier(ABC):
    """Common interface for the composable biological verifiers.

    A concrete verifier inspects one model completion against structured
    ground truth and scores a single dimension of quality (pathway
    direction, factual accuracy, etc.), reporting a VerifierResult.
    """

    # Routing identifier; concrete subclasses override (e.g. "V1".."V4").
    name: str = "base"

    @abstractmethod
    def score(
        self,
        prompt: str,
        completion: str,
        ground_truth: Dict,
        question_type: str,
    ) -> VerifierResult:
        """Evaluate one completion against its ground truth.

        Args:
            prompt: Original question posed to the model.
            completion: Text the model generated.
            ground_truth: Parsed ground-truth dictionary.
            question_type: Question category, used for routing logic.

        Returns:
            A VerifierResult whose score lies in [0, 1].
        """
        raise NotImplementedError

    def is_applicable(self, applicable_verifiers: List[str]) -> bool:
        """Return True when this verifier is listed for the question."""
        return any(v == self.name for v in applicable_verifiers)
|
|
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Verifier Composer: Weighted composition of V1-V4 into a TRL-compatible reward function.
|
| 3 |
+
|
| 4 |
+
This is THE critical integration point between the verifier stack and
|
| 5 |
+
TRL's GRPOTrainer. The reward function signature must match TRL's expected
|
| 6 |
+
interface exactly.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import json
|
| 10 |
+
from typing import Callable, Dict, List, Optional
|
| 11 |
+
from dataclasses import dataclass, field
|
| 12 |
+
|
| 13 |
+
from biorlhf.verifiers.base import BaseVerifier, VerifierResult
|
| 14 |
+
from biorlhf.verifiers.pathway import PathwayDirectionVerifier
|
| 15 |
+
from biorlhf.verifiers.factual import BiologicalFactVerifier
|
| 16 |
+
from biorlhf.verifiers.consistency import CrossContextConsistencyVerifier
|
| 17 |
+
from biorlhf.verifiers.uncertainty import UncertaintyVerifier
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@dataclass
class ComposedReward:
    """Result of composed reward computation.

    Returned by VerifierComposer.compute_reward; bundles the scalar
    reward with per-verifier diagnostics for logging and analysis.
    """
    # Weighted average of the applicable verifier scores, in [0, 1];
    # 0.0 when no verifier applied to the question.
    total_reward: float
    # Raw per-verifier scores, keyed by verifier name ("V1".."V4").
    verifier_scores: Dict[str, float]
    # Per-verifier detail dictionaries explaining each score.
    verifier_details: Dict[str, Dict]
    # Weights actually applied, renormalized over the verifiers that fired.
    weights_used: Dict[str, float]
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# Default weights — factual signals dominate
|
| 30 |
+
# Weights sum to 1.0; VerifierComposer renormalizes over whichever
# verifiers actually apply to a given question, so the composed reward
# stays in [0, 1] even when a subset fires.
DEFAULT_WEIGHTS = {
    "V1": 0.35,  # Pathway direction (hard signal)
    "V2": 0.30,  # Biological facts (soft signal)
    "V3": 0.15,  # Cross-context consistency
    "V4": 0.20,  # Uncertainty appropriateness
}
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class VerifierComposer:
    """Composes V1-V4 verifiers into a unified reward signal.

    Each active verifier scores a completion independently; the scores are
    combined as a weighted average, with weights renormalized over the
    verifiers that actually produced an applicable result.
    """

    def __init__(
        self,
        weights: Optional[Dict[str, float]] = None,
        active_verifiers: Optional[List[str]] = None,
    ):
        """Build the verifier stack.

        Args:
            weights: Mapping of verifier name ("V1".."V4") to weight.
                Defaults to DEFAULT_WEIGHTS.
            active_verifiers: Optional subset of verifier names to enable.
                When given, weights are renormalized over that subset.

        Raises:
            ValueError: If active_verifiers contains an unknown name.
                Previously a typo such as "v1" silently produced an empty
                verifier set and a constant 0.0 reward for every sample.
        """
        all_verifiers: Dict[str, BaseVerifier] = {
            "V1": PathwayDirectionVerifier(),
            "V2": BiologicalFactVerifier(),
            "V3": CrossContextConsistencyVerifier(),
            "V4": UncertaintyVerifier(),
        }

        self.weights = dict(weights or DEFAULT_WEIGHTS)

        if active_verifiers:
            # Fail loudly on typos instead of silently zeroing all rewards.
            unknown = sorted(set(active_verifiers) - set(all_verifiers))
            if unknown:
                raise ValueError(
                    f"Unknown verifier name(s) {unknown}; "
                    f"expected a subset of {sorted(all_verifiers)}"
                )
            # Filter to the requested subset and renormalize its weights.
            self.verifiers = {
                k: v for k, v in all_verifiers.items() if k in active_verifiers
            }
            total_w = sum(self.weights.get(k, 0) for k in self.verifiers)
            if total_w > 0:
                self.weights = {
                    k: self.weights.get(k, 0) / total_w for k in self.verifiers
                }
        else:
            self.verifiers = all_verifiers

    def compute_reward(
        self,
        prompt: str,
        completion: str,
        ground_truth: str,
        question_type: str,
        applicable_verifiers: str,
    ) -> ComposedReward:
        """Compute composed reward from all applicable verifiers.

        Args:
            prompt: The question text.
            completion: Model's generated response.
            ground_truth: JSON string of ground truth; a pre-parsed dict
                is also accepted and used as-is.
            question_type: Question type for routing.
            applicable_verifiers: JSON list of verifier names; a Python
                list is also accepted.

        Returns:
            ComposedReward with the weighted-average total and per-verifier
            diagnostics. total_reward is 0.0 when no verifier applied.
        """
        # Accept both raw dataset columns (JSON strings) and parsed objects.
        gt = json.loads(ground_truth) if isinstance(ground_truth, str) else ground_truth
        applicable = (
            json.loads(applicable_verifiers)
            if isinstance(applicable_verifiers, str)
            else applicable_verifiers
        )

        scores: Dict[str, float] = {}
        details: Dict[str, Dict] = {}
        weights_used: Dict[str, float] = {}

        for vname, verifier in self.verifiers.items():
            # Skip verifiers the dataset marked as not applicable here.
            if not verifier.is_applicable(applicable):
                continue

            result = verifier.score(prompt, completion, gt, question_type)

            # A verifier may also opt out after inspecting the ground truth.
            if not result.applicable:
                continue

            scores[vname] = result.score
            details[vname] = result.details
            weights_used[vname] = self.weights.get(vname, 0)

        # No verifier applied: neutral zero reward, diagnostics preserved.
        if not weights_used:
            return ComposedReward(
                total_reward=0.0,
                verifier_scores=scores,
                verifier_details=details,
                weights_used=weights_used,
            )

        # Renormalize over the verifiers that actually fired so the total
        # stays in [0, 1] regardless of which subset applied.
        w_total = sum(weights_used.values())
        if w_total > 0:
            normalized = {k: v / w_total for k, v in weights_used.items()}
        else:
            normalized = weights_used

        total = sum(scores[k] * normalized.get(k, 0) for k in scores)

        return ComposedReward(
            total_reward=total,
            verifier_scores=scores,
            verifier_details=details,
            weights_used=normalized,
        )
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def make_grpo_reward_function(
    weights: Optional[Dict[str, float]] = None,
    active_verifiers: Optional[List[str]] = None,
) -> Callable:
    """Build a TRL-compatible reward function backed by the verifier stack.

    TRL's GRPOTrainer invokes reward functions as
    ``reward_func(completions, **kwargs) -> list[float]``, forwarding every
    dataset column except "prompt" through kwargs. Completions arrive either
    as plain strings or as chat-format lists of message dicts.

    Args:
        weights: Optional verifier-weight mapping passed to VerifierComposer.
        active_verifiers: Optional subset of verifier names to enable.

    Returns:
        A closure over a VerifierComposer that scores each completion.
    """
    composer = VerifierComposer(weights=weights, active_verifiers=active_verifiers)

    def reward_func(
        completions: List,
        ground_truth: Optional[List[str]] = None,
        question_type: Optional[List[str]] = None,
        applicable_verifiers: Optional[List[str]] = None,
        **kwargs,
    ) -> List[float]:
        """Score a batch of completions with the composed verifiers.

        Args:
            completions: Model completions (strings or chat message lists).
            ground_truth: Per-sample JSON ground-truth strings (dataset column).
            question_type: Per-sample question-type strings (dataset column).
            applicable_verifiers: Per-sample JSON lists of verifier names.

        Returns:
            One float reward per completion.
        """
        n = len(completions)

        # Dataset columns may be absent; substitute permissive defaults.
        gts = ground_truth if ground_truth is not None else ["{}"] * n
        qtypes = question_type if question_type is not None else ["unknown"] * n
        if applicable_verifiers is not None:
            av = applicable_verifiers
        else:
            av = [json.dumps(["V1", "V2", "V3", "V4"])] * n

        # TRL may forward prompts under either key; tolerate a bare string.
        prompts = kwargs.get("prompts", kwargs.get("prompt", [""] * n))
        if isinstance(prompts, str):
            prompts = [prompts] * n

        rewards: List[float] = []
        for idx, completion in enumerate(completions):
            prompt_text = _extract_text(prompts[idx]) if idx < len(prompts) else ""
            composed = composer.compute_reward(
                prompt=prompt_text,
                completion=_extract_text(completion),
                ground_truth=gts[idx],
                question_type=qtypes[idx],
                applicable_verifiers=av[idx],
            )
            rewards.append(composed.total_reward)

        return rewards

    return reward_func
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def make_single_verifier_reward(verifier_name: str) -> Callable:
    """Create a reward function using only one verifier (for ablation).

    Args:
        verifier_name: Name of the verifier to enable (e.g. "V1"); the
            returned function scores with that verifier alone, its weight
            renormalized to 1.

    Returns:
        A TRL-compatible reward function (see make_grpo_reward_function).
    """
    return make_grpo_reward_function(active_verifiers=[verifier_name])
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def _extract_text(item) -> str:
|
| 211 |
+
"""Extract plain text from various completion formats.
|
| 212 |
+
|
| 213 |
+
TRL may pass completions as:
|
| 214 |
+
- str: plain text
|
| 215 |
+
- list[dict]: chat messages [{"role": "assistant", "content": "..."}]
|
| 216 |
+
"""
|
| 217 |
+
if isinstance(item, str):
|
| 218 |
+
return item
|
| 219 |
+
elif isinstance(item, list):
|
| 220 |
+
# Chat format
|
| 221 |
+
texts = []
|
| 222 |
+
for msg in item:
|
| 223 |
+
if isinstance(msg, dict) and "content" in msg:
|
| 224 |
+
texts.append(msg["content"])
|
| 225 |
+
return " ".join(texts)
|
| 226 |
+
else:
|
| 227 |
+
return str(item)
|
|
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
V3: Cross-Context Consistency Verifier.
|
| 3 |
+
|
| 4 |
+
Scores whether the model appropriately distinguishes or generalizes
|
| 5 |
+
across biological contexts (tissues, species, doses, timepoints).
|
| 6 |
+
|
| 7 |
+
For comparison questions: checks tissue coverage + consistency assessment.
|
| 8 |
+
For context-dependent questions: checks nuance and hedging.
|
| 9 |
+
For BioAmbiguity tasks: checks context awareness using BioEval scoring logic.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import json
|
| 13 |
+
import re
|
| 14 |
+
from typing import Dict, List
|
| 15 |
+
|
| 16 |
+
from biorlhf.verifiers.base import BaseVerifier, VerifierResult
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# ── Indicator patterns ─────────────────────────────────────────────────────
|
| 20 |
+
# Lowercase substrings signalling the completion claims a shared/conserved
# response; checked against the lowered completion in _score_comparison.
CONSISTENCY_TERMS = [
    "consistent", "conserved", "similar across", "same direction",
    "reproducible", "concordant", "shared", "common response",
    "universal", "uniform",
]

# Substrings signalling a tissue-specific / divergent response claim;
# the counterpart of CONSISTENCY_TERMS in _score_comparison.
SPECIFICITY_TERMS = [
    "tissue-specific", "differs", "different", "opposite", "varies",
    "divergent", "heterogeneous", "discordant", "unique to",
    "distinct", "context-dependent", "tissue-dependent",
]

# Substrings indicating the answer acknowledges context-dependence;
# counted in _score_context_dependent (2+ hits saturate the sub-score).
NUANCE_INDICATORS = [
    "depends", "context", "varies", "mission-specific",
    "not consistent", "differs", "some missions", "mixed",
    "heterogeneous", "variable", "inconsistent",
]

# Substrings indicating appropriate hedging about evidence strength;
# counted in _score_context_dependent (2+ hits saturate the sub-score).
HEDGING_INDICATORS = [
    "uncertain", "unclear", "difficult to generalize",
    "not enough evidence", "conflicting", "limited data",
    "preliminary", "tentative", "cannot be determined",
    "more research", "caution",
]
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class CrossContextConsistencyVerifier(BaseVerifier):
    """V3: Scores context-appropriate reasoning.

    Routes each question to one of three scorers, chosen from the
    question_type and the keys present in the ground truth:

    - "comparison" (or GT with "tissue_directions"): tissue coverage plus
      a correct consistent-vs-tissue-specific verdict.
    - "context_dependent"/"uncertainty": keyword evidence of nuance and
      hedging in the completion.
    - GT with "contexts" (BioEval BioAmbiguity): key-term coverage,
      distinction quality, and evidence support.

    All matching is case-insensitive substring matching against the
    completion text.
    """

    name = "V3"

    def score(
        self,
        prompt: str,
        completion: str,
        ground_truth: Dict,
        question_type: str,
    ) -> VerifierResult:
        """Route to the appropriate sub-scorer (see class docstring).

        Returns applicable=False with a neutral 0.5 score when neither the
        question type nor the ground-truth keys match a known format.
        """
        # Ground truth may arrive as a JSON string from the dataset.
        gt = ground_truth if isinstance(ground_truth, dict) else json.loads(ground_truth)

        if question_type == "comparison":
            return self._score_comparison(completion, gt)
        elif question_type in ("context_dependent", "uncertainty"):
            return self._score_context_dependent(completion, gt)
        elif "contexts" in gt:
            # BioEval BioAmbiguity format
            return self._score_bioambiguity(completion, gt)
        elif "tissue_directions" in gt:
            # Comparison-style GT even though question_type said otherwise.
            return self._score_comparison(completion, gt)
        else:
            return VerifierResult(
                score=0.5,
                verifier_name=self.name,
                details={"reason": "not_applicable"},
                applicable=False,
            )

    def _score_comparison(self, completion: str, gt: Dict) -> VerifierResult:
        """Score cross-tissue comparison questions.

        Expects GT keys "tissue_directions" (tissue -> direction) and
        "is_consistent" (bool). The score is an even split between
        tissue coverage and a binary consistency-verdict check.
        """
        tissue_directions = gt.get("tissue_directions", {})
        is_consistent = gt.get("is_consistent", False)
        comp_lower = completion.lower()

        # Check tissue coverage: fraction of GT tissues named in the answer.
        tissues_mentioned = sum(
            1 for tissue in tissue_directions if tissue.lower() in comp_lower
        )
        # Guard the divisor so empty GT cannot raise ZeroDivisionError.
        n_tissues = len(tissue_directions) if tissue_directions else 1
        tissue_coverage = tissues_mentioned / n_tissues

        # Check consistency/specificity assessment via the keyword lists.
        claims_consistent = any(t in comp_lower for t in CONSISTENCY_TERMS)
        claims_specific = any(t in comp_lower for t in SPECIFICITY_TERMS)

        # The answer's verdict must agree with the GT consistency flag.
        consistency_correct = False
        if is_consistent:
            consistency_correct = claims_consistent
        else:
            consistency_correct = claims_specific

        # Even weighting: 50% coverage + 50% correct verdict.
        score = 0.5 * tissue_coverage + 0.5 * (1.0 if consistency_correct else 0.0)

        return VerifierResult(
            score=score,
            verifier_name=self.name,
            details={
                "tissues_mentioned": tissues_mentioned,
                "total_tissues": n_tissues,
                "is_consistent_gt": is_consistent,
                "claims_consistent": claims_consistent,
                "claims_specific": claims_specific,
                "consistency_correct": consistency_correct,
            },
        )

    def _score_context_dependent(self, completion: str, gt: Dict) -> VerifierResult:
        """Score questions where answer should acknowledge context-dependence.

        Counts nuance and hedging keyword hits; two hits of a kind
        saturate that sub-score at 1.0. gt is unused here — the signal
        comes entirely from the completion's own language.
        """
        comp_lower = completion.lower()

        nuance_hits = sum(1 for t in NUANCE_INDICATORS if t in comp_lower)
        hedging_hits = sum(1 for t in HEDGING_INDICATORS if t in comp_lower)

        # Scale: having 2-3 indicators is ideal
        nuance_score = min(nuance_hits / 2.0, 1.0)
        hedging_score = min(hedging_hits / 2.0, 1.0)

        # Nuance weighted slightly above hedging.
        score = 0.6 * nuance_score + 0.4 * hedging_score

        return VerifierResult(
            score=min(score, 1.0),
            verifier_name=self.name,
            details={
                "nuance_hits": nuance_hits,
                "hedging_hits": hedging_hits,
                "nuance_score": nuance_score,
                "hedging_score": hedging_score,
            },
        )

    def _score_bioambiguity(self, completion: str, gt: Dict) -> VerifierResult:
        """Score BioEval BioAmbiguity tasks.

        GT format:
        {"contexts": {context_name: {"key_terms": [...], "role": "..."}},
         "distinction_key": "..."}

        Composite of key-term coverage (40%), distinction-key term
        coverage (35%), and per-context role mentions (25%). Returns
        applicable=False when the GT has no contexts.
        """
        contexts = gt.get("contexts", {})
        distinction_key = gt.get("distinction_key", "")
        comp_lower = completion.lower()

        if not contexts:
            return VerifierResult(
                score=0.5, verifier_name=self.name,
                details={"reason": "no_contexts"}, applicable=False,
            )

        # Context awareness: % of key terms found across all contexts
        total_terms = 0
        found_terms = 0
        context_scores = {}

        for ctx_name, ctx_info in contexts.items():
            key_terms = ctx_info.get("key_terms", [])
            if not key_terms:
                continue
            hits = sum(1 for t in key_terms if t.lower() in comp_lower)
            total_terms += len(key_terms)
            found_terms += hits
            context_scores[ctx_name] = hits / len(key_terms) if key_terms else 0

        context_awareness = found_terms / total_terms if total_terms > 0 else 0

        # Distinction quality: does response contain distinction key words?
        if distinction_key:
            dist_terms = _extract_key_terms(distinction_key)
            dist_hits = sum(1 for t in dist_terms if t.lower() in comp_lower)
            distinction_quality = dist_hits / len(dist_terms) if dist_terms else 0
        else:
            distinction_quality = 0

        # Evidence support: does response mention roles?
        role_hits = 0
        role_total = 0
        for ctx_info in contexts.values():
            role = ctx_info.get("role", "")
            if role:
                role_total += 1
                role_terms = _extract_key_terms(role)
                # Any role key term present counts the context as supported.
                if any(t.lower() in comp_lower for t in role_terms):
                    role_hits += 1
        evidence_support = role_hits / role_total if role_total > 0 else 0

        # Composite: 40% context + 35% distinction + 25% evidence
        score = (
            0.40 * context_awareness
            + 0.35 * distinction_quality
            + 0.25 * evidence_support
        )

        return VerifierResult(
            score=score,
            verifier_name=self.name,
            details={
                "context_awareness": context_awareness,
                "distinction_quality": distinction_quality,
                "evidence_support": evidence_support,
                "context_scores": context_scores,
                "terms_found": found_terms,
                "terms_total": total_terms,
            },
        )
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def _extract_key_terms(text: str, min_length: int = 4) -> List[str]:
|
| 214 |
+
"""Extract key terms from text for matching."""
|
| 215 |
+
stopwords = {
|
| 216 |
+
"the", "and", "for", "that", "this", "with", "from", "are",
|
| 217 |
+
"was", "were", "been", "have", "has", "had", "will", "would",
|
| 218 |
+
"could", "should", "may", "might", "can", "does", "between",
|
| 219 |
+
}
|
| 220 |
+
words = re.findall(r"\b[a-zA-Z0-9-]+\b", text)
|
| 221 |
+
return [w for w in words if len(w) >= min_length and w.lower() not in stopwords]
|
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
V2: Biological Fact Verifier.
|
| 3 |
+
|
| 4 |
+
Scores model responses based on overlap with known correct facts
|
| 5 |
+
from curated knowledge bases (SpaceOmicsBench, BioEval, GeneTuring).
|
| 6 |
+
|
| 7 |
+
Scoring: proportion of ground truth key facts found in the response.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import re
|
| 11 |
+
import json
|
| 12 |
+
from typing import Dict, List
|
| 13 |
+
|
| 14 |
+
from biorlhf.verifiers.base import BaseVerifier, VerifierResult
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _extract_key_terms(text: str, min_length: int = 4, max_terms: int = 10) -> List[str]:
|
| 18 |
+
"""Extract important terms from a text string."""
|
| 19 |
+
# Remove common stopwords and short words
|
| 20 |
+
stopwords = {
|
| 21 |
+
"the", "and", "for", "that", "this", "with", "from", "are", "was",
|
| 22 |
+
"were", "been", "have", "has", "had", "will", "would", "could",
|
| 23 |
+
"should", "may", "might", "can", "does", "did", "but", "not",
|
| 24 |
+
"its", "also", "into", "than", "then", "when", "which", "what",
|
| 25 |
+
"where", "who", "how", "all", "each", "every", "both", "more",
|
| 26 |
+
"most", "other", "some", "such", "only", "same", "very", "just",
|
| 27 |
+
}
|
| 28 |
+
words = re.findall(r"\b[a-zA-Z0-9-]+\b", text)
|
| 29 |
+
terms = [
|
| 30 |
+
w for w in words
|
| 31 |
+
if len(w) >= min_length and w.lower() not in stopwords
|
| 32 |
+
]
|
| 33 |
+
return terms[:max_terms]
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _phrase_match(phrase: str, text: str) -> bool:
    """Return True when *phrase* (or enough of its key terms) occurs in *text*.

    First tries an exact case-insensitive substring match; failing that,
    extracts up to five key terms from the phrase and requires at least
    half of them (minimum one) to appear in the text.
    """
    haystack = text.lower()
    needle = phrase.lower()

    # Exact (case-insensitive) substring hit is always a match.
    if needle in haystack:
        return True

    # Fuzzy fallback: co-occurrence of the phrase's key terms.
    key_terms = _extract_key_terms(phrase, min_length=4, max_terms=5)
    if not key_terms:
        # No usable terms (e.g. all stopwords/short words): repeat the
        # substring test, which is necessarily False at this point.
        return needle in haystack

    hit_count = sum(term.lower() in haystack for term in key_terms)
    # At least half of the key terms (and at least one) must appear.
    return hit_count >= max(1, len(key_terms) // 2)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class BiologicalFactVerifier(BaseVerifier):
    """V2: Verifies biological factual claims against curated knowledge.

    Score is the proportion of ground-truth key facts whose text (or a
    majority of whose key terms — see _phrase_match) appears in the
    model completion.
    """

    name = "V2"

    def score(
        self,
        prompt: str,
        completion: str,
        ground_truth: Dict,
        question_type: str,
    ) -> VerifierResult:
        """Score based on overlap with ground truth key facts.

        Handles multiple GT formats:
        - {"key_facts": ["fact1", "fact2", ...]}
        - {"ground_truth_key_facts": [...]}
        - {"expected_answer": "text"}
        - {"expected_reasoning": [...]}
        - {"correct_steps": [...]}

        Returns:
            VerifierResult whose score is the fraction of extracted facts
            matched in the completion, or applicable=False (neutral 0.5)
            when the ground truth carries no verifiable facts.
        """
        # Dataset rows may deliver ground truth as a JSON string.
        gt = ground_truth if isinstance(ground_truth, dict) else json.loads(ground_truth)

        # Extract key facts from various GT formats.
        key_facts = self._extract_facts(gt)

        if not key_facts:
            return VerifierResult(
                score=0.5,
                verifier_name=self.name,
                details={"reason": "no_key_facts_in_gt"},
                applicable=False,
            )

        # Score: proportion of key facts found in completion.
        matched_facts: List[str] = [
            fact for fact in key_facts
            if isinstance(fact, str) and _phrase_match(fact, completion)
        ]

        total = len(key_facts)
        matched = len(matched_facts)
        score = matched / total if total > 0 else 0.0

        return VerifierResult(
            score=score,
            verifier_name=self.name,
            details={
                "matched_facts": matched_facts,
                "total_facts": total,
                "matched_count": matched,
                "unmatched": [f for f in key_facts if f not in matched_facts],
            },
        )

    def _extract_facts(self, gt: Dict) -> List[str]:
        """Collect verifiable fact strings from the ground-truth dict.

        Recognizes several source-specific GT layouts, pools the facts,
        and deduplicates while preserving first-seen order.
        """
        facts: List[str] = []

        # Direct key facts lists.
        for key in ("key_facts", "ground_truth_key_facts"):
            if key in gt and isinstance(gt[key], list):
                facts.extend(str(f) for f in gt[key] if f)

        # Expected reasoning points.
        if "expected_reasoning" in gt and isinstance(gt["expected_reasoning"], list):
            facts.extend(str(f) for f in gt["expected_reasoning"] if f)

        # Single expected answer.
        if "expected_answer" in gt and isinstance(gt["expected_answer"], str):
            facts.append(gt["expected_answer"])

        # Protocol steps (BioEval protoreason).
        if "correct_steps" in gt and isinstance(gt["correct_steps"], list):
            facts.extend(str(s) for s in gt["correct_steps"] if s)

        # NES conservation facts. Cast to str for consistency with every
        # other branch: previously a non-string value here crashed the
        # deduplication below (unhashable) or was silently skipped by the
        # matcher while still inflating the score denominator.
        if "conservation_level" in gt:
            facts.append(str(gt["conservation_level"]))

        # Deduplicate while preserving order.
        seen = set()
        unique_facts = []
        for f in facts:
            if f not in seen:
                seen.add(f)
                unique_facts.append(f)

        return unique_facts
|
|
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
V1: Pathway Direction Verifier.
|
| 3 |
+
|
| 4 |
+
Extracts directional claims about biological pathways from model responses
|
| 5 |
+
and compares them against fGSEA NES direction ground truth.
|
| 6 |
+
|
| 7 |
+
Scoring:
|
| 8 |
+
1.0 — correct direction claimed
|
| 9 |
+
0.5 — mixed/contradictory claims
|
| 10 |
+
0.3 — no directional claim extracted
|
| 11 |
+
0.0 — wrong direction claimed
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import re
|
| 15 |
+
from typing import Dict, List, Tuple
|
| 16 |
+
|
| 17 |
+
from biorlhf.verifiers.base import BaseVerifier, VerifierResult
|
| 18 |
+
|
| 19 |
+
# ── Direction indicator patterns ──────────────────────────────────────────
# Word-stem regexes used to classify a sentence as claiming up- or
# down-regulation. Stems with `\w*` (e.g. "upregulat\w*") match all
# inflections: upregulate, upregulated, upregulation, ...
UP_INDICATORS = [
    r"\bupregulat\w*\b",
    r"\bactivat\w*\b",
    r"\bincreas\w*\b",
    r"\belevat\w*\b",
    r"\benhanc\w*\b",
    r"\binduced?\b",
    r"\bhigher\b",
    r"\boverexpress\w*\b",
    r"\benrich\w*\b",
    r"\bpositive\s+NES\b",
    r"\bNES\s*[>=]\s*0\b",
    r"\bupstream\s+activat\w*\b",
    r"\bstimulat\w*\b",
    r"\bpromot\w*\b",
]

DOWN_INDICATORS = [
    r"\bdownregulat\w*\b",
    r"\bsuppress\w*\b",
    r"\bdecreas\w*\b",
    r"\breduced?\b",
    r"\binhibit\w*\b",
    r"\brepress\w*\b",
    r"\blower\w*\b",
    r"\bunderexpress\w*\b",
    r"\bdepress\w*\b",
    r"\bnegative\s+NES\b",
    r"\bNES\s*<\s*0\b",
    r"\bdiminish\w*\b",
    # BUG FIX: was r"\battenuati\w*\b", which required a literal "i" after
    # the stem and therefore never matched "attenuate"/"attenuated".
    r"\battenuat\w*\b",
    r"\bimpair\w*\b",
]

# Negation patterns that flip direction (checked in a short window before
# a direction hit by _has_negation_before).
NEGATION_PATTERNS = [
    r"\bnot\s+",
    r"\bno\s+",
    r"\bneither\b",
    r"\bwithout\s+",
    r"\bfail\w*\s+to\b",
    r"\bdoes\s+not\b",
    r"\bdid\s+not\b",
    r"\bisn'?t\b",
    r"\bwasn'?t\b",
    r"\baren'?t\b",
]

# ── Pathway name abbreviations ────────────────────────────────────────────
# Maps a humanized pathway name (lowercase, underscores → spaces) to the
# short forms a model response might use instead. Keys are substring-matched
# against the humanized name in _generate_pathway_variants.
PATHWAY_ABBREVIATIONS: Dict[str, List[str]] = {
    "oxidative phosphorylation": ["oxphos", "oxidative phosphorylation", "ox phos"],
    "tnfa signaling via nfkb": ["tnf-alpha", "nfkb", "nf-kb", "nf-κb", "tnfα"],
    "mtorc1 signaling": ["mtor", "mtorc1"],
    "pi3k akt mtor signaling": ["pi3k", "akt", "mtor", "pi3k/akt"],
    "interferon gamma response": ["ifn-gamma", "ifn-γ", "interferon gamma", "ifnγ"],
    "interferon alpha response": ["ifn-alpha", "ifn-α", "interferon alpha", "ifnα"],
    "adipogenesis": ["adipogenesis", "adipogenic"],
    "myogenesis": ["myogenesis", "myogenic"],
    "epithelial mesenchymal transition": ["emt", "epithelial-mesenchymal"],
    "unfolded protein response": ["upr", "unfolded protein"],
    "reactive oxygen species pathway": ["ros", "reactive oxygen"],
    "fatty acid metabolism": ["fatty acid", "fat metabolism", "lipid metabolism"],
    "glycolysis": ["glycolysis", "glycolytic"],
    "dna repair": ["dna repair", "dna damage response"],
    "apoptosis": ["apoptosis", "apoptotic", "programmed cell death"],
    "inflammatory response": ["inflammatory", "inflammation"],
    "hypoxia": ["hypoxia", "hypoxic"],
    "angiogenesis": ["angiogenesis", "angiogenic"],
    "p53 pathway": ["p53", "tp53"],
    "wnt beta catenin signaling": ["wnt", "beta-catenin", "β-catenin"],
}
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def _generate_pathway_variants(pathway_name: str) -> List[str]:
    """Build the list of textual variants to match for *pathway_name*.

    E.g. HALLMARK_OXIDATIVE_PHOSPHORYLATION yields the raw name, the
    humanized form ("oxidative phosphorylation"), that form with a
    " pathway" suffix, and any curated abbreviations ("oxphos", ...).
    """
    # Strip gene-set database prefixes, then humanize the remainder.
    stripped = pathway_name
    for db_prefix in ("HALLMARK_", "KEGG_", "REACTOME_", "MITOCARTA_"):
        stripped = stripped.replace(db_prefix, "")
    readable = stripped.replace("_", " ").lower()

    variants = [pathway_name, readable, readable + " pathway"]

    # Append curated abbreviations whose key occurs in the readable name.
    for abbrev_key, abbrev_list in PATHWAY_ABBREVIATIONS.items():
        if abbrev_key in readable:
            variants.extend(abbrev_list)

    return variants
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def _extract_sentences_with_term(text: str, term: str) -> List[str]:
|
| 121 |
+
"""Extract sentences containing a term."""
|
| 122 |
+
sentences = re.split(r"[.!?\n]+", text)
|
| 123 |
+
return [
|
| 124 |
+
s.strip()
|
| 125 |
+
for s in sentences
|
| 126 |
+
if term.lower() in s.lower() and len(s.strip()) > 10
|
| 127 |
+
]
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def _has_negation_before(text: str, match_start: int, window: int = 12) -> bool:
    """Return True when a negation cue appears within *window* chars before *match_start*.

    A ~12-char lookback catches "not " plus a short adverb of padding
    without reaching across clause boundaries like "not X but rather Y".
    """
    context = text[max(0, match_start - window):match_start].lower()
    for negation in NEGATION_PATTERNS:
        if re.search(negation, context):
            return True
    return False
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def extract_direction_claims(
    text: str,
    pathway_name: str,
) -> List[Tuple[str, str]]:
    """Extract directional claims about a specific pathway from text.

    Returns a list of (pathway_variant, direction) tuples, where the
    direction is "UP", "DOWN", or "AMBIGUOUS".
    """
    lowered = text.lower()
    claims: List[Tuple[str, str]] = []

    for variant in _generate_pathway_variants(pathway_name):
        # Cheap pre-filter before per-sentence regex work.
        if variant.lower() not in lowered:
            continue

        for sentence in _extract_sentences_with_term(text, variant):
            sent = sentence.lower()
            up_votes = 0
            down_votes = 0

            # Tally direction cues; a negated cue votes for the opposite
            # direction ("not increased" counts as DOWN, and vice versa).
            for up_pat in UP_INDICATORS:
                for hit in re.finditer(up_pat, sent):
                    if _has_negation_before(sent, hit.start()):
                        down_votes += 1
                    else:
                        up_votes += 1

            for down_pat in DOWN_INDICATORS:
                for hit in re.finditer(down_pat, sent):
                    if _has_negation_before(sent, hit.start()):
                        up_votes += 1
                    else:
                        down_votes += 1

            if up_votes > down_votes:
                claims.append((variant, "UP"))
            elif down_votes > up_votes:
                claims.append((variant, "DOWN"))
            elif up_votes > 0:
                # Equal non-zero counts: contradictory within one sentence.
                claims.append((variant, "AMBIGUOUS"))

    return claims
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
class PathwayDirectionVerifier(BaseVerifier):
    """V1: Verifies pathway direction claims against fGSEA NES data."""

    name = "V1"

    def score(
        self,
        prompt: str,
        completion: str,
        ground_truth: Dict,
        question_type: str,
    ) -> VerifierResult:
        """Score the completion's directional claim about the ground-truth pathway.

        Returns 1.0 for a correct direction, 0.5 for mixed claims, 0.3 when
        no (or only ambiguous) claims are found, and 0.0 for a wrong direction.
        """
        # Without both pathway and direction in the ground truth there is
        # nothing to verify; report neutral and mark as not applicable.
        if not ("pathway" in ground_truth and "direction" in ground_truth):
            return VerifierResult(
                score=0.5,
                verifier_name=self.name,
                details={"reason": "no_pathway_in_gt"},
                applicable=False,
            )

        pathway = ground_truth["pathway"]
        expected_dir = ground_truth["direction"]

        # Cross-tissue comparison questions are scored per tissue instead.
        if question_type == "comparison" and "tissue_directions" in ground_truth:
            return self._score_comparison(completion, ground_truth)

        claims = extract_direction_claims(completion, pathway)
        if not claims:
            return VerifierResult(
                score=0.3,
                verifier_name=self.name,
                details={
                    "reason": "no_claim_extracted",
                    "pathway": pathway,
                    "expected": expected_dir,
                },
            )

        matching = [c for c in claims if c[1] == expected_dir]
        contradicting = [
            c for c in claims if c[1] not in (expected_dir, "AMBIGUOUS")
        ]

        if matching:
            score = 0.5 if contradicting else 1.0
        elif contradicting:
            score = 0.0
        else:
            score = 0.3  # only AMBIGUOUS claims present

        return VerifierResult(
            score=score,
            verifier_name=self.name,
            details={
                "pathway": pathway,
                "expected": expected_dir,
                "claims": list(claims),
                "n_matching": len(matching),
                "n_contradicting": len(contradicting),
            },
        )

    def _score_comparison(
        self, completion: str, ground_truth: Dict
    ) -> VerifierResult:
        """Score cross-tissue comparison: check the claimed direction per tissue."""
        tissue_dirs = ground_truth.get("tissue_directions", {})
        pathway = ground_truth.get("pathway", "")

        if not tissue_dirs:
            return VerifierResult(
                score=0.5, verifier_name=self.name,
                details={"reason": "no_tissue_directions"}, applicable=False,
            )

        per_tissue: Dict[str, str] = {}
        checked = 0
        correct = 0

        for tissue, expected_dir in tissue_dirs.items():
            # Only sentences that mention this tissue are considered; a
            # tissue that is never mentioned is skipped, not penalized.
            sentences = _extract_sentences_with_term(completion, tissue)
            if not sentences:
                continue

            checked += 1
            claims = extract_direction_claims(" ".join(sentences), pathway)
            if any(direction == expected_dir for _, direction in claims):
                correct += 1
                per_tissue[tissue] = "correct"
            else:
                per_tissue[tissue] = "wrong" if claims else "no_claim"

        return VerifierResult(
            score=correct / checked if checked > 0 else 0.3,
            verifier_name=self.name,
            details={
                "pathway": pathway,
                "tissues_checked": checked,
                "tissues_correct": correct,
                "per_tissue": per_tissue,
            },
        )
|
|
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
V4: Uncertainty Appropriateness Verifier.
|
| 3 |
+
|
| 4 |
+
Scores whether a model's stated confidence aligns with the ground-truth
|
| 5 |
+
expected confidence level. Integrates with BioEval's calibration scoring
|
| 6 |
+
when available, with a built-in fallback.
|
| 7 |
+
|
| 8 |
+
Scoring dimensions:
|
| 9 |
+
- Confidence level alignment (stated vs. expected)
|
| 10 |
+
- Calibration task behavior (acknowledge_unknown, overconfidence_trap, etc.)
|
| 11 |
+
- Default: penalizes extreme overconfidence
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import re
|
| 15 |
+
import json
|
| 16 |
+
from typing import Dict, List
|
| 17 |
+
from dataclasses import dataclass
|
| 18 |
+
|
| 19 |
+
from biorlhf.verifiers.base import BaseVerifier, VerifierResult
|
| 20 |
+
|
| 21 |
+
# ── Try importing BioEval calibration infrastructure ──────────────────────
# Optional dependency: when BioEval imports cleanly, its confidence
# extractor is used; otherwise the regex fallback below is. HAS_BIOEVAL
# records which branch succeeded and is consulted at scoring time.
try:
    import os
    import sys
    # NOTE(review): the default is a machine-specific absolute path; on any
    # other host the BIOEVAL_ROOT environment variable must be set — confirm
    # this is intentional for the deployment environment.
    _bioeval_root = os.environ.get(
        "BIOEVAL_ROOT",
        "/Users/jak4013/Dropbox/Bioinformatics/Claude/Evaluation_model/BioEval",
    )
    # Mutates sys.path so the bioeval package resolves from that root.
    sys.path.insert(0, _bioeval_root)
    from bioeval.scoring.calibration import extract_confidence, ConfidenceExtraction
    HAS_BIOEVAL = True
except ImportError:
    HAS_BIOEVAL = False

# ── Built-in confidence extraction (fallback) ─────────────────────────────
# Regex cues for stated confidence levels, searched against the lowercased
# completion in _extract_confidence_simple.
HIGH_CONFIDENCE_PATTERNS = [
    r"\bhigh\s*confidence\b", r"\bvery\s+confident\b", r"\bconfident\s+that\b",
    r"\bcertainly\b", r"\bclearly\b", r"\bdefinitely\b", r"\bwithout\s+doubt\b",
    r"\bconfidence:\s*high\b", r"\bstrongly\s+suggest\b",
]

MEDIUM_CONFIDENCE_PATTERNS = [
    r"\bmoderate\s*confidence\b", r"\breasonably\s+confident\b",
    r"\blikely\b", r"\bprobably\b", r"\bsuggest\w*\b",
    r"\bconfidence:\s*medium\b", r"\bconfidence:\s*moderate\b",
]

LOW_CONFIDENCE_PATTERNS = [
    r"\blow\s*confidence\b", r"\bnot\s+confident\b", r"\buncertain\b",
    r"\bunclear\b", r"\bnot\s+sure\b", r"\bdon'?t\s+know\b",
    r"\bcannot\s+determine\b", r"\binsufficient\s+\w*\s*(?:data|evidence)\b",
    r"\blimited\s+evidence\b", r"\bspeculat\w*\b",
    r"\bconfidence:\s*low\b",
]

# Explicit numeric confidence, e.g. "confidence: 85%" — captures the digits.
NUMERIC_CONFIDENCE_RE = re.compile(
    r"(?:confidence|certainty|probability)[:\s]*(\d{1,3})%",
    re.IGNORECASE,
)

# Expected numeric (low, high) confidence ranges per stated level.
# Ranges deliberately overlap so borderline scores are not penalized twice.
CONFIDENCE_RANGES = {
    "high": (0.70, 1.00),
    "medium": (0.35, 0.75),
    "low": (0.00, 0.40),
}

# Target confidence per calibration-task behavior label; unlisted
# behaviors fall back to 0.5 at lookup time.
BEHAVIOR_EXPECTED_CONFIDENCE = {
    "acknowledge_unknown": 0.15,
    "high_confidence_correct": 0.90,
    "partial_knowledge": 0.50,
    "context_dependent": 0.50,
    "moderate_confidence": 0.50,
    "overconfidence_trap": 0.30,
}
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
@dataclass
class SimpleConfidence:
    """Fallback confidence extraction result.

    Mirrors the fields of BioEval's extraction that the verifier reads,
    so both code paths can be consumed identically.
    """
    stated: str  # categorical level: "high", "medium", or "low"
    numeric: float  # confidence as a fraction in [0.0, 1.0]
    source: str  # how it was derived: "explicit", "pattern", or "language"
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def _extract_confidence_simple(text: str) -> SimpleConfidence:
    """Heuristic confidence extraction used when BioEval is unavailable.

    Preference order: an explicit numeric statement ("confidence: 85%"),
    then hedging/assertive language cues, then a moderate default.
    """
    lowered = text.lower()

    # An explicit numeric statement wins outright.
    explicit = NUMERIC_CONFIDENCE_RE.search(text)
    if explicit is not None:
        value = int(explicit.group(1)) / 100.0
        if value >= 0.70:
            level = "high"
        elif value >= 0.40:
            level = "medium"
        else:
            level = "low"
        return SimpleConfidence(stated=level, numeric=value, source="explicit")

    n_high = sum(1 for p in HIGH_CONFIDENCE_PATTERNS if re.search(p, lowered))
    n_med = sum(1 for p in MEDIUM_CONFIDENCE_PATTERNS if re.search(p, lowered))
    n_low = sum(1 for p in LOW_CONFIDENCE_PATTERNS if re.search(p, lowered))

    # Strict-majority vote between the extremes; medium wins only when
    # neither extreme dominates and at least one medium cue fired.
    if n_low > n_high and n_low > n_med:
        return SimpleConfidence(stated="low", numeric=0.25, source="pattern")
    if n_high > n_low and n_high > n_med:
        return SimpleConfidence(stated="high", numeric=0.85, source="pattern")
    if n_med > 0:
        return SimpleConfidence(stated="medium", numeric=0.55, source="pattern")
    # No cues at all: assume moderate confidence.
    return SimpleConfidence(stated="medium", numeric=0.50, source="language")
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class UncertaintyVerifier(BaseVerifier):
    """V4: Verifies that model's confidence is appropriate for the question."""

    name = "V4"

    def score(
        self,
        prompt: str,
        completion: str,
        ground_truth: Dict,
        question_type: str,
    ) -> VerifierResult:
        """Score how appropriate the completion's stated confidence is.

        Routes to one of three scorers depending on what the ground truth
        provides: a calibration behavior label, an expected confidence
        level, or neither (default overconfidence check).
        """
        # Ground truth may arrive as a JSON string; normalize to a dict.
        gt = ground_truth if isinstance(ground_truth, dict) else json.loads(ground_truth)

        expected_confidence = gt.get("expected_confidence")
        correct_behavior = gt.get("correct_behavior")

        # Extract confidence from completion — BioEval extractor when
        # importable (HAS_BIOEVAL), else the regex fallback above.
        if HAS_BIOEVAL:
            conf_extraction = extract_confidence(completion)
            conf_score = conf_extraction.confidence_score
            stated = conf_extraction.stated_confidence or "medium"
        else:
            simple = _extract_confidence_simple(completion)
            conf_score = simple.numeric
            stated = simple.stated

        # Route to appropriate scoring; behavior labels take precedence
        # over a bare expected-confidence level.
        if correct_behavior:
            return self._score_calibration_behavior(
                completion, gt, conf_score, stated, correct_behavior,
            )
        elif expected_confidence:
            return self._score_confidence_alignment(
                conf_score, stated, expected_confidence,
            )
        else:
            return self._score_default(conf_score, stated)

    def _score_calibration_behavior(
        self,
        completion: str,
        gt: Dict,
        conf_score: float,
        stated: str,
        correct_behavior: str,
    ) -> VerifierResult:
        """Score calibration tasks with specific expected behaviors."""
        # Unknown behavior labels fall back to an expected confidence of 0.5.
        expected_conf = BEHAVIOR_EXPECTED_CONFIDENCE.get(correct_behavior, 0.5)
        conf_error = abs(conf_score - expected_conf)

        # Behavior-specific checks; unknown labels leave this False, so
        # they can earn at most the alignment portion of the score.
        behavior_correct = False

        if correct_behavior == "acknowledge_unknown":
            # Model should say it doesn't know
            unknown_markers = [
                "don't know", "cannot determine", "not enough",
                "unknown", "hypothetical", "no information",
                "unable to", "cannot assess", "fictional",
            ]
            behavior_correct = any(m in completion.lower() for m in unknown_markers)

        elif correct_behavior == "high_confidence_correct":
            behavior_correct = conf_score >= 0.65

        elif correct_behavior == "overconfidence_trap":
            # Should NOT be overconfident
            behavior_correct = conf_score < 0.70

        elif correct_behavior in ("partial_knowledge", "context_dependent", "moderate_confidence"):
            behavior_correct = 0.30 <= conf_score <= 0.75

        # Combined score: behavior compliance + confidence alignment.
        # Alignment decays linearly, hitting 0 at an error of 0.5.
        behavior_score = 1.0 if behavior_correct else 0.0
        alignment_score = max(0, 1.0 - 2.0 * conf_error)
        score = 0.6 * behavior_score + 0.4 * alignment_score

        return VerifierResult(
            score=score,
            verifier_name=self.name,
            details={
                "correct_behavior": correct_behavior,
                "expected_confidence": expected_conf,
                "actual_confidence": conf_score,
                "stated_confidence": stated,
                "confidence_error": conf_error,
                "behavior_correct": behavior_correct,
                "using_bioeval": HAS_BIOEVAL,
            },
        )

    def _score_confidence_alignment(
        self,
        conf_score: float,
        stated: str,
        expected: str,
    ) -> VerifierResult:
        """Score how well stated confidence aligns with expected level.

        Full credit inside the expected range; outside it, credit decays
        linearly with distance to the nearest bound (0 at distance 0.4).
        """
        if expected not in CONFIDENCE_RANGES:
            return VerifierResult(
                score=0.5, verifier_name=self.name,
                details={"reason": "unknown_expected_level"},
            )

        low, high = CONFIDENCE_RANGES[expected]
        in_range = low <= conf_score <= high

        if in_range:
            score = 1.0
        else:
            distance = min(abs(conf_score - low), abs(conf_score - high))
            score = max(0, 1.0 - 2.5 * distance)

        return VerifierResult(
            score=score,
            verifier_name=self.name,
            details={
                "expected_level": expected,
                "expected_range": (low, high),
                "actual_confidence": conf_score,
                "stated_confidence": stated,
                "in_range": in_range,
                "using_bioeval": HAS_BIOEVAL,
            },
        )

    def _score_default(
        self,
        conf_score: float,
        stated: str,
    ) -> VerifierResult:
        """Default scoring: penalize extreme overconfidence.

        Used when the ground truth specifies neither a behavior label nor
        an expected confidence level.
        """
        if conf_score > 0.90:
            score = 0.4  # Overconfidence penalty
        elif conf_score < 0.10:
            score = 0.3  # Extreme underconfidence penalty
        else:
            score = 0.7  # Moderate confidence is good default

        return VerifierResult(
            score=score,
            verifier_name=self.name,
            details={
                "actual_confidence": conf_score,
                "stated_confidence": stated,
                "mode": "default",
                "using_bioeval": HAS_BIOEVAL,
            },
        )
|