Phase 3: Fix GRPO learning signal with continuous rewards and multi-reward
Browse files- Smooth V1 scoring: continuous claim-strength within discrete bands [0.80-1.0]
- Smooth V4 scoring: continuous confidence-distance function peaking at 0.5
- Add make_individual_reward_functions() for TRL multi-reward support
- Add use_multi_reward config option with reward_weights in GRPOConfig
- Add eval_batch_size config to fix G=16 divisibility requirement
- Create grpo_full_v2.json: G=16, beta=0.02, num_iterations=2, 3 epochs
- Update run_grpo_full.sh to use v2 config
- Update .gitignore with BioGRPO model output dirs
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- .gitignore +7 -0
- configs/grpo_full_v2.json +47 -0
- scripts/run_grpo_full.sh +7 -7
- src/biorlhf/training/grpo.py +38 -8
- src/biorlhf/verifiers/composer.py +87 -0
- src/biorlhf/verifiers/pathway.py +18 -9
- src/biorlhf/verifiers/uncertainty.py +7 -7
.gitignore
CHANGED
|
@@ -186,3 +186,10 @@ logs/
|
|
| 186 |
# Uncomment below if you want to exclude datasets from git:
|
| 187 |
# *.json
|
| 188 |
# !kmp_test_set.json
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
# Uncomment below if you want to exclude datasets from git:
|
| 187 |
# *.json
|
| 188 |
# !kmp_test_set.json
|
| 189 |
+
|
| 190 |
+
# BioGRPO / ecosystem model outputs
|
| 191 |
+
ecosystem_improved_model/
|
| 192 |
+
biogrpo_mve_model/
|
| 193 |
+
biogrpo_full_model/
|
| 194 |
+
biogrpo_full_v2_model/
|
| 195 |
+
data/*.json
|
configs/grpo_full_v2.json
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_name": "mistralai/Mistral-7B-v0.3",
|
| 3 |
+
"sft_model_path": "./kmp_sft_model_final",
|
| 4 |
+
"output_dir": "./biogrpo_full_v2_model",
|
| 5 |
+
|
| 6 |
+
"num_generations": 16,
|
| 7 |
+
"beta": 0.02,
|
| 8 |
+
"num_iterations": 2,
|
| 9 |
+
"scale_rewards": "group",
|
| 10 |
+
"loss_type": "grpo",
|
| 11 |
+
|
| 12 |
+
"num_epochs": 3,
|
| 13 |
+
"batch_size": 1,
|
| 14 |
+
"eval_batch_size": 16,
|
| 15 |
+
"gradient_accumulation_steps": 8,
|
| 16 |
+
"learning_rate": 5e-7,
|
| 17 |
+
"max_completion_length": 1024,
|
| 18 |
+
"max_prompt_length": 512,
|
| 19 |
+
"warmup_ratio": 0.1,
|
| 20 |
+
|
| 21 |
+
"lora_r": 32,
|
| 22 |
+
"lora_alpha": 64,
|
| 23 |
+
"lora_dropout": 0.05,
|
| 24 |
+
|
| 25 |
+
"use_multi_reward": true,
|
| 26 |
+
"active_verifiers": ["V1", "V2", "V3", "V4"],
|
| 27 |
+
"verifier_weights": {"V1": 0.35, "V2": 0.30, "V3": 0.15, "V4": 0.20},
|
| 28 |
+
|
| 29 |
+
"pathway_db": "hallmark",
|
| 30 |
+
"hold_out_tissues": ["eye", "thymus"],
|
| 31 |
+
"seed": 42,
|
| 32 |
+
|
| 33 |
+
"use_4bit": true,
|
| 34 |
+
|
| 35 |
+
"wandb_project": "biogrpo",
|
| 36 |
+
"wandb_run_name": "grpo_full_v2_G16_multireward",
|
| 37 |
+
"use_wandb": true,
|
| 38 |
+
"logging_steps": 10,
|
| 39 |
+
"save_steps": 50,
|
| 40 |
+
"eval_steps": 50,
|
| 41 |
+
"save_total_limit": 3,
|
| 42 |
+
"log_completions": true,
|
| 43 |
+
|
| 44 |
+
"use_vllm": false,
|
| 45 |
+
"gradient_checkpointing": true,
|
| 46 |
+
"bf16": true
|
| 47 |
+
}
|
scripts/run_grpo_full.sh
CHANGED
|
@@ -5,13 +5,13 @@
|
|
| 5 |
#SBATCH --gres=gpu:1
|
| 6 |
#SBATCH --mem=96G
|
| 7 |
#SBATCH --cpus-per-task=8
|
| 8 |
-
#SBATCH --time=
|
| 9 |
#SBATCH --output=logs/grpo_full_%j.log
|
| 10 |
#SBATCH --error=logs/grpo_full_%j.err
|
| 11 |
|
| 12 |
# ============================================================
|
| 13 |
-
# BioGRPO Full Experiment
|
| 14 |
-
# All V1-V4 verifiers, G=
|
| 15 |
# ============================================================
|
| 16 |
|
| 17 |
SCRATCH="/athena/cayuga_0003/scratch/users/jak4013/otsuka"
|
|
@@ -59,14 +59,14 @@ if [ ! -e "${WORKDIR}/kmp_sft_model_final" ]; then
|
|
| 59 |
echo "Symlinked kmp_sft_model_final"
|
| 60 |
fi
|
| 61 |
|
| 62 |
-
echo "Starting BioGRPO Full training..."
|
| 63 |
-
biorlhf-grpo --config configs/
|
| 64 |
|
| 65 |
if [ $? -eq 0 ]; then
|
| 66 |
echo ""
|
| 67 |
echo "============================================================"
|
| 68 |
-
echo "BioGRPO Full training completed!"
|
| 69 |
-
echo "Model saved to: ./
|
| 70 |
echo "End time: $(date)"
|
| 71 |
echo "============================================================"
|
| 72 |
else
|
|
|
|
| 5 |
#SBATCH --gres=gpu:1
|
| 6 |
#SBATCH --mem=96G
|
| 7 |
#SBATCH --cpus-per-task=8
|
| 8 |
+
#SBATCH --time=48:00:00
|
| 9 |
#SBATCH --output=logs/grpo_full_%j.log
|
| 10 |
#SBATCH --error=logs/grpo_full_%j.err
|
| 11 |
|
| 12 |
# ============================================================
|
| 13 |
+
# BioGRPO Full Experiment v2
|
| 14 |
+
# All V1-V4 verifiers, G=16, multi-reward, from SFT checkpoint
|
| 15 |
# ============================================================
|
| 16 |
|
| 17 |
SCRATCH="/athena/cayuga_0003/scratch/users/jak4013/otsuka"
|
|
|
|
| 59 |
echo "Symlinked kmp_sft_model_final"
|
| 60 |
fi
|
| 61 |
|
| 62 |
+
echo "Starting BioGRPO Full v2 training..."
|
| 63 |
+
biorlhf-grpo --config configs/grpo_full_v2.json
|
| 64 |
|
| 65 |
if [ $? -eq 0 ]; then
|
| 66 |
echo ""
|
| 67 |
echo "============================================================"
|
| 68 |
+
echo "BioGRPO Full v2 training completed!"
|
| 69 |
+
echo "Model saved to: ./biogrpo_full_v2_model"
|
| 70 |
echo "End time: $(date)"
|
| 71 |
echo "============================================================"
|
| 72 |
else
|
src/biorlhf/training/grpo.py
CHANGED
|
@@ -14,7 +14,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
|
| 14 |
from peft import LoraConfig, PeftModel
|
| 15 |
from trl import GRPOTrainer, GRPOConfig
|
| 16 |
|
| 17 |
-
from biorlhf.verifiers.composer import make_grpo_reward_function
|
| 18 |
from biorlhf.data.grpo_dataset import build_grpo_dataset, get_dataset_stats
|
| 19 |
|
| 20 |
|
|
@@ -37,6 +37,7 @@ class BioGRPOConfig:
|
|
| 37 |
# Training hyperparameters
|
| 38 |
num_epochs: int = 1
|
| 39 |
batch_size: int = 2
|
|
|
|
| 40 |
gradient_accumulation_steps: int = 8
|
| 41 |
learning_rate: float = 1e-6
|
| 42 |
max_completion_length: int = 1024
|
|
@@ -57,6 +58,9 @@ class BioGRPOConfig:
|
|
| 57 |
hold_out_tissues: Optional[List[str]] = None
|
| 58 |
seed: int = 42
|
| 59 |
|
|
|
|
|
|
|
|
|
|
| 60 |
# Quantization
|
| 61 |
use_4bit: bool = True
|
| 62 |
|
|
@@ -137,12 +141,21 @@ def run_grpo_training(config: Optional[BioGRPOConfig] = None) -> str:
|
|
| 137 |
print(f" By type: {train_stats['by_question_type']}")
|
| 138 |
print(f" Eval: {eval_stats['total']} samples")
|
| 139 |
|
| 140 |
-
# 2. Create reward function
|
| 141 |
print("\n[2/5] Initializing verifier stack...")
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
print(f" Active: {config.active_verifiers or ['V1', 'V2', 'V3', 'V4']}")
|
| 147 |
|
| 148 |
# 3. Load tokenizer (always from base model; adapter dirs lack config.json)
|
|
@@ -214,10 +227,11 @@ def run_grpo_training(config: Optional[BioGRPOConfig] = None) -> str:
|
|
| 214 |
# 6. Configure GRPOTrainer
|
| 215 |
print("\n[5/6] Configuring GRPOTrainer...")
|
| 216 |
|
| 217 |
-
|
| 218 |
output_dir=config.output_dir,
|
| 219 |
num_train_epochs=config.num_epochs,
|
| 220 |
per_device_train_batch_size=config.batch_size,
|
|
|
|
| 221 |
gradient_accumulation_steps=config.gradient_accumulation_steps,
|
| 222 |
learning_rate=config.learning_rate,
|
| 223 |
warmup_ratio=config.warmup_ratio,
|
|
@@ -250,10 +264,26 @@ def run_grpo_training(config: Optional[BioGRPOConfig] = None) -> str:
|
|
| 250 |
log_completions=config.log_completions,
|
| 251 |
)
|
| 252 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 253 |
trainer = GRPOTrainer(
|
| 254 |
model=model,
|
| 255 |
args=grpo_config,
|
| 256 |
-
reward_funcs=
|
| 257 |
train_dataset=train_dataset,
|
| 258 |
eval_dataset=eval_dataset,
|
| 259 |
peft_config=peft_config,
|
|
|
|
| 14 |
from peft import LoraConfig, PeftModel
|
| 15 |
from trl import GRPOTrainer, GRPOConfig
|
| 16 |
|
| 17 |
+
from biorlhf.verifiers.composer import make_grpo_reward_function, make_individual_reward_functions
|
| 18 |
from biorlhf.data.grpo_dataset import build_grpo_dataset, get_dataset_stats
|
| 19 |
|
| 20 |
|
|
|
|
| 37 |
# Training hyperparameters
|
| 38 |
num_epochs: int = 1
|
| 39 |
batch_size: int = 2
|
| 40 |
+
eval_batch_size: Optional[int] = None
|
| 41 |
gradient_accumulation_steps: int = 8
|
| 42 |
learning_rate: float = 1e-6
|
| 43 |
max_completion_length: int = 1024
|
|
|
|
| 58 |
hold_out_tissues: Optional[List[str]] = None
|
| 59 |
seed: int = 42
|
| 60 |
|
| 61 |
+
# Multi-reward (per-verifier TRL reward functions)
|
| 62 |
+
use_multi_reward: bool = False
|
| 63 |
+
|
| 64 |
# Quantization
|
| 65 |
use_4bit: bool = True
|
| 66 |
|
|
|
|
| 141 |
print(f" By type: {train_stats['by_question_type']}")
|
| 142 |
print(f" Eval: {eval_stats['total']} samples")
|
| 143 |
|
| 144 |
+
# 2. Create reward function(s)
|
| 145 |
print("\n[2/5] Initializing verifier stack...")
|
| 146 |
+
reward_weights = None
|
| 147 |
+
if config.use_multi_reward:
|
| 148 |
+
reward_funcs, reward_weights = make_individual_reward_functions(
|
| 149 |
+
active_verifiers=config.active_verifiers,
|
| 150 |
+
weights=config.verifier_weights,
|
| 151 |
+
)
|
| 152 |
+
print(f" Mode: multi-reward ({len(reward_funcs)} per-verifier functions)")
|
| 153 |
+
print(f" Weights: {reward_weights}")
|
| 154 |
+
else:
|
| 155 |
+
reward_funcs = make_grpo_reward_function(
|
| 156 |
+
weights=config.verifier_weights,
|
| 157 |
+
active_verifiers=config.active_verifiers,
|
| 158 |
+
)
|
| 159 |
print(f" Active: {config.active_verifiers or ['V1', 'V2', 'V3', 'V4']}")
|
| 160 |
|
| 161 |
# 3. Load tokenizer (always from base model; adapter dirs lack config.json)
|
|
|
|
| 227 |
# 6. Configure GRPOTrainer
|
| 228 |
print("\n[5/6] Configuring GRPOTrainer...")
|
| 229 |
|
| 230 |
+
grpo_kwargs = dict(
|
| 231 |
output_dir=config.output_dir,
|
| 232 |
num_train_epochs=config.num_epochs,
|
| 233 |
per_device_train_batch_size=config.batch_size,
|
| 234 |
+
per_device_eval_batch_size=config.eval_batch_size or config.batch_size,
|
| 235 |
gradient_accumulation_steps=config.gradient_accumulation_steps,
|
| 236 |
learning_rate=config.learning_rate,
|
| 237 |
warmup_ratio=config.warmup_ratio,
|
|
|
|
| 264 |
log_completions=config.log_completions,
|
| 265 |
)
|
| 266 |
|
| 267 |
+
# Add reward_weights for multi-reward mode if TRL supports it
|
| 268 |
+
if reward_weights is not None:
|
| 269 |
+
try:
|
| 270 |
+
# Check if GRPOConfig accepts reward_weights (TRL >= 0.27)
|
| 271 |
+
import inspect
|
| 272 |
+
if "reward_weights" in inspect.signature(GRPOConfig).parameters:
|
| 273 |
+
grpo_kwargs["reward_weights"] = reward_weights
|
| 274 |
+
print(f" reward_weights set in GRPOConfig: {reward_weights}")
|
| 275 |
+
else:
|
| 276 |
+
print(" Warning: TRL version does not support reward_weights in GRPOConfig")
|
| 277 |
+
print(" Per-verifier functions will still be used, but with equal weights")
|
| 278 |
+
except Exception:
|
| 279 |
+
pass
|
| 280 |
+
|
| 281 |
+
grpo_config = GRPOConfig(**grpo_kwargs)
|
| 282 |
+
|
| 283 |
trainer = GRPOTrainer(
|
| 284 |
model=model,
|
| 285 |
args=grpo_config,
|
| 286 |
+
reward_funcs=reward_funcs,
|
| 287 |
train_dataset=train_dataset,
|
| 288 |
eval_dataset=eval_dataset,
|
| 289 |
peft_config=peft_config,
|
src/biorlhf/verifiers/composer.py
CHANGED
|
@@ -202,6 +202,93 @@ def make_grpo_reward_function(
|
|
| 202 |
return reward_func
|
| 203 |
|
| 204 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
def make_single_verifier_reward(verifier_name: str) -> Callable:
|
| 206 |
"""Create a reward function using only one verifier (for ablation)."""
|
| 207 |
return make_grpo_reward_function(active_verifiers=[verifier_name])
|
|
|
|
| 202 |
return reward_func
|
| 203 |
|
| 204 |
|
| 205 |
+
def make_individual_reward_functions(
|
| 206 |
+
active_verifiers: Optional[List[str]] = None,
|
| 207 |
+
weights: Optional[Dict[str, float]] = None,
|
| 208 |
+
) -> tuple:
|
| 209 |
+
"""Return (list_of_reward_funcs, list_of_weights) for TRL multi-reward.
|
| 210 |
+
|
| 211 |
+
Each reward function wraps a single verifier and returns List[float | None].
|
| 212 |
+
Non-applicable verifiers return None for samples where they don't apply;
|
| 213 |
+
TRL natively excludes None rewards from the GRPO calculation.
|
| 214 |
+
|
| 215 |
+
This enables per-verifier reward normalization in TRL, preventing a single
|
| 216 |
+
low-variance verifier from dominating the gradient signal.
|
| 217 |
+
"""
|
| 218 |
+
all_verifiers = {
|
| 219 |
+
"V1": PathwayDirectionVerifier(),
|
| 220 |
+
"V2": BiologicalFactVerifier(),
|
| 221 |
+
"V3": CrossContextConsistencyVerifier(),
|
| 222 |
+
"V4": UncertaintyVerifier(),
|
| 223 |
+
}
|
| 224 |
+
|
| 225 |
+
if active_verifiers:
|
| 226 |
+
verifiers = {k: v for k, v in all_verifiers.items() if k in active_verifiers}
|
| 227 |
+
else:
|
| 228 |
+
verifiers = all_verifiers
|
| 229 |
+
|
| 230 |
+
w = dict(weights or DEFAULT_WEIGHTS)
|
| 231 |
+
weight_list = [w.get(k, 0) for k in verifiers]
|
| 232 |
+
|
| 233 |
+
def _make_single_reward_fn(verifier: BaseVerifier, vname: str) -> Callable:
|
| 234 |
+
"""Create a closure-safe reward function for a single verifier."""
|
| 235 |
+
|
| 236 |
+
def reward_func(
|
| 237 |
+
completions: List,
|
| 238 |
+
ground_truth: Optional[List[str]] = None,
|
| 239 |
+
question_type: Optional[List[str]] = None,
|
| 240 |
+
applicable_verifiers: Optional[List[str]] = None,
|
| 241 |
+
**kwargs,
|
| 242 |
+
) -> List:
|
| 243 |
+
n = len(completions)
|
| 244 |
+
if ground_truth is None:
|
| 245 |
+
ground_truth = ["{}"] * n
|
| 246 |
+
if question_type is None:
|
| 247 |
+
question_type = ["unknown"] * n
|
| 248 |
+
if applicable_verifiers is None:
|
| 249 |
+
applicable_verifiers = [json.dumps(list(all_verifiers.keys()))] * n
|
| 250 |
+
|
| 251 |
+
prompts = kwargs.get("prompts", kwargs.get("prompt", [""] * n))
|
| 252 |
+
if isinstance(prompts, str):
|
| 253 |
+
prompts = [prompts] * n
|
| 254 |
+
|
| 255 |
+
rewards = []
|
| 256 |
+
for i in range(n):
|
| 257 |
+
app = (
|
| 258 |
+
json.loads(applicable_verifiers[i])
|
| 259 |
+
if isinstance(applicable_verifiers[i], str)
|
| 260 |
+
else applicable_verifiers[i]
|
| 261 |
+
)
|
| 262 |
+
if not verifier.is_applicable(app):
|
| 263 |
+
rewards.append(None)
|
| 264 |
+
continue
|
| 265 |
+
|
| 266 |
+
completion_text = _extract_text(completions[i])
|
| 267 |
+
prompt_text = _extract_text(prompts[i]) if i < len(prompts) else ""
|
| 268 |
+
gt = (
|
| 269 |
+
json.loads(ground_truth[i])
|
| 270 |
+
if isinstance(ground_truth[i], str)
|
| 271 |
+
else ground_truth[i]
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
result = verifier.score(prompt_text, completion_text, gt, question_type[i])
|
| 275 |
+
if not result.applicable:
|
| 276 |
+
rewards.append(None)
|
| 277 |
+
else:
|
| 278 |
+
rewards.append(result.score)
|
| 279 |
+
|
| 280 |
+
return rewards
|
| 281 |
+
|
| 282 |
+
reward_func.__name__ = f"reward_{vname}"
|
| 283 |
+
return reward_func
|
| 284 |
+
|
| 285 |
+
reward_funcs = [
|
| 286 |
+
_make_single_reward_fn(v, name) for name, v in verifiers.items()
|
| 287 |
+
]
|
| 288 |
+
|
| 289 |
+
return reward_funcs, weight_list
|
| 290 |
+
|
| 291 |
+
|
| 292 |
def make_single_verifier_reward(verifier_name: str) -> Callable:
|
| 293 |
"""Create a reward function using only one verifier (for ablation)."""
|
| 294 |
return make_grpo_reward_function(active_verifiers=[verifier_name])
|
src/biorlhf/verifiers/pathway.py
CHANGED
|
@@ -4,11 +4,11 @@ V1: Pathway Direction Verifier.
|
|
| 4 |
Extracts directional claims about biological pathways from model responses
|
| 5 |
and compares them against fGSEA NES direction ground truth.
|
| 6 |
|
| 7 |
-
Scoring:
|
| 8 |
-
1.
|
| 9 |
-
0.
|
| 10 |
-
0.
|
| 11 |
-
0.0 β wrong direction
|
| 12 |
"""
|
| 13 |
|
| 14 |
import re
|
|
@@ -229,15 +229,23 @@ class PathwayDirectionVerifier(BaseVerifier):
|
|
| 229 |
contradicting = [
|
| 230 |
c for c in claims if c[1] != expected_dir and c[1] != "AMBIGUOUS"
|
| 231 |
]
|
|
|
|
| 232 |
|
|
|
|
| 233 |
if matching and not contradicting:
|
| 234 |
-
|
|
|
|
|
|
|
| 235 |
elif matching and contradicting:
|
| 236 |
-
|
|
|
|
|
|
|
| 237 |
elif contradicting:
|
| 238 |
-
|
|
|
|
|
|
|
| 239 |
else:
|
| 240 |
-
score = 0.3 # Only ambiguous claims
|
| 241 |
|
| 242 |
return VerifierResult(
|
| 243 |
score=score,
|
|
@@ -248,6 +256,7 @@ class PathwayDirectionVerifier(BaseVerifier):
|
|
| 248 |
"claims": [(v, d) for v, d in claims],
|
| 249 |
"n_matching": len(matching),
|
| 250 |
"n_contradicting": len(contradicting),
|
|
|
|
| 251 |
},
|
| 252 |
)
|
| 253 |
|
|
|
|
| 4 |
Extracts directional claims about biological pathways from model responses
|
| 5 |
and compares them against fGSEA NES direction ground truth.
|
| 6 |
|
| 7 |
+
Scoring (continuous within bands to provide GRPO gradient signal):
|
| 8 |
+
[0.80, 1.00] β correct direction (strength = n_matching / (n_matching + 1))
|
| 9 |
+
[0.40, 0.60] β mixed/contradictory claims (ratio of matching to total)
|
| 10 |
+
0.30 β no directional claim extracted (fixed)
|
| 11 |
+
[0.00, 0.10] β wrong direction (modulated by ambiguity ratio)
|
| 12 |
"""
|
| 13 |
|
| 14 |
import re
|
|
|
|
| 229 |
contradicting = [
|
| 230 |
c for c in claims if c[1] != expected_dir and c[1] != "AMBIGUOUS"
|
| 231 |
]
|
| 232 |
+
ambiguous = [c for c in claims if c[1] == "AMBIGUOUS"]
|
| 233 |
|
| 234 |
+
# Continuous scoring within discrete bands for GRPO gradient signal
|
| 235 |
if matching and not contradicting:
|
| 236 |
+
# Correct direction: [0.80, 1.00] based on claim strength
|
| 237 |
+
strength = len(matching) / (len(matching) + 1)
|
| 238 |
+
score = 0.8 + 0.2 * strength
|
| 239 |
elif matching and contradicting:
|
| 240 |
+
# Mixed claims: [0.40, 0.60] based on matching ratio
|
| 241 |
+
total = len(matching) + len(contradicting)
|
| 242 |
+
score = 0.4 + 0.2 * (len(matching) / total)
|
| 243 |
elif contradicting:
|
| 244 |
+
# Wrong direction: [0.00, 0.10] modulated by ambiguity
|
| 245 |
+
ambiguity_ratio = len(ambiguous) / len(claims) if claims else 0
|
| 246 |
+
score = 0.1 * ambiguity_ratio
|
| 247 |
else:
|
| 248 |
+
score = 0.3 # Only ambiguous claims β no variance possible
|
| 249 |
|
| 250 |
return VerifierResult(
|
| 251 |
score=score,
|
|
|
|
| 256 |
"claims": [(v, d) for v, d in claims],
|
| 257 |
"n_matching": len(matching),
|
| 258 |
"n_contradicting": len(contradicting),
|
| 259 |
+
"n_ambiguous": len(ambiguous),
|
| 260 |
},
|
| 261 |
)
|
| 262 |
|
src/biorlhf/verifiers/uncertainty.py
CHANGED
|
@@ -250,13 +250,13 @@ class UncertaintyVerifier(BaseVerifier):
|
|
| 250 |
conf_score: float,
|
| 251 |
stated: str,
|
| 252 |
) -> VerifierResult:
|
| 253 |
-
"""Default scoring:
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
|
| 261 |
return VerifierResult(
|
| 262 |
score=score,
|
|
|
|
| 250 |
conf_score: float,
|
| 251 |
stated: str,
|
| 252 |
) -> VerifierResult:
|
| 253 |
+
"""Default scoring: continuous function rewarding moderate confidence.
|
| 254 |
+
|
| 255 |
+
Peaks at conf=0.5 (score=1.0), smoothly penalizes extremes.
|
| 256 |
+
Range: [0.25, 1.0] β provides GRPO gradient signal even when
|
| 257 |
+
generations extract slightly different confidence values.
|
| 258 |
+
"""
|
| 259 |
+
score = max(0.2, 1.0 - abs(conf_score - 0.5) * 1.5)
|
| 260 |
|
| 261 |
return VerifierResult(
|
| 262 |
score=score,
|