jang1563 Claude Opus 4.6 committed on
Commit
7dbf475
·
1 Parent(s): bff2f94

Phase 3: Fix GRPO learning signal with continuous rewards and multi-reward

Browse files

- Smooth V1 scoring: continuous claim-strength within discrete bands [0.80-1.0]
- Smooth V4 scoring: continuous confidence-distance function peaking at 0.5
- Add make_individual_reward_functions() for TRL multi-reward support
- Add use_multi_reward config option with reward_weights in GRPOConfig
- Add eval_batch_size config to fix G=16 divisibility requirement
- Create grpo_full_v2.json: G=16, beta=0.02, num_iterations=2, 3 epochs
- Update run_grpo_full.sh to use v2 config
- Update .gitignore with BioGRPO model output dirs

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

.gitignore CHANGED
@@ -186,3 +186,10 @@ logs/
186
  # Uncomment below if you want to exclude datasets from git:
187
  # *.json
188
  # !kmp_test_set.json
 
 
 
 
 
 
 
 
186
  # Uncomment below if you want to exclude datasets from git:
187
  # *.json
188
  # !kmp_test_set.json
189
+
190
+ # BioGRPO / ecosystem model outputs
191
+ ecosystem_improved_model/
192
+ biogrpo_mve_model/
193
+ biogrpo_full_model/
194
+ biogrpo_full_v2_model/
195
+ data/*.json
configs/grpo_full_v2.json ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_name": "mistralai/Mistral-7B-v0.3",
3
+ "sft_model_path": "./kmp_sft_model_final",
4
+ "output_dir": "./biogrpo_full_v2_model",
5
+
6
+ "num_generations": 16,
7
+ "beta": 0.02,
8
+ "num_iterations": 2,
9
+ "scale_rewards": "group",
10
+ "loss_type": "grpo",
11
+
12
+ "num_epochs": 3,
13
+ "batch_size": 1,
14
+ "eval_batch_size": 16,
15
+ "gradient_accumulation_steps": 8,
16
+ "learning_rate": 5e-7,
17
+ "max_completion_length": 1024,
18
+ "max_prompt_length": 512,
19
+ "warmup_ratio": 0.1,
20
+
21
+ "lora_r": 32,
22
+ "lora_alpha": 64,
23
+ "lora_dropout": 0.05,
24
+
25
+ "use_multi_reward": true,
26
+ "active_verifiers": ["V1", "V2", "V3", "V4"],
27
+ "verifier_weights": {"V1": 0.35, "V2": 0.30, "V3": 0.15, "V4": 0.20},
28
+
29
+ "pathway_db": "hallmark",
30
+ "hold_out_tissues": ["eye", "thymus"],
31
+ "seed": 42,
32
+
33
+ "use_4bit": true,
34
+
35
+ "wandb_project": "biogrpo",
36
+ "wandb_run_name": "grpo_full_v2_G16_multireward",
37
+ "use_wandb": true,
38
+ "logging_steps": 10,
39
+ "save_steps": 50,
40
+ "eval_steps": 50,
41
+ "save_total_limit": 3,
42
+ "log_completions": true,
43
+
44
+ "use_vllm": false,
45
+ "gradient_checkpointing": true,
46
+ "bf16": true
47
+ }
scripts/run_grpo_full.sh CHANGED
@@ -5,13 +5,13 @@
5
  #SBATCH --gres=gpu:1
6
  #SBATCH --mem=96G
7
  #SBATCH --cpus-per-task=8
8
- #SBATCH --time=24:00:00
9
  #SBATCH --output=logs/grpo_full_%j.log
10
  #SBATCH --error=logs/grpo_full_%j.err
11
 
12
  # ============================================================
13
- # BioGRPO Full Experiment
14
- # All V1-V4 verifiers, G=8, from SFT checkpoint
15
  # ============================================================
16
 
17
  SCRATCH="/athena/cayuga_0003/scratch/users/jak4013/otsuka"
@@ -59,14 +59,14 @@ if [ ! -e "${WORKDIR}/kmp_sft_model_final" ]; then
59
  echo "Symlinked kmp_sft_model_final"
60
  fi
61
 
62
- echo "Starting BioGRPO Full training..."
63
- biorlhf-grpo --config configs/grpo_full.json
64
 
65
  if [ $? -eq 0 ]; then
66
  echo ""
67
  echo "============================================================"
68
- echo "BioGRPO Full training completed!"
69
- echo "Model saved to: ./biogrpo_full_model"
70
  echo "End time: $(date)"
71
  echo "============================================================"
72
  else
 
5
  #SBATCH --gres=gpu:1
6
  #SBATCH --mem=96G
7
  #SBATCH --cpus-per-task=8
8
+ #SBATCH --time=48:00:00
9
  #SBATCH --output=logs/grpo_full_%j.log
10
  #SBATCH --error=logs/grpo_full_%j.err
11
 
12
  # ============================================================
13
+ # BioGRPO Full Experiment v2
14
+ # All V1-V4 verifiers, G=16, multi-reward, from SFT checkpoint
15
  # ============================================================
16
 
17
  SCRATCH="/athena/cayuga_0003/scratch/users/jak4013/otsuka"
 
59
  echo "Symlinked kmp_sft_model_final"
60
  fi
61
 
62
+ echo "Starting BioGRPO Full v2 training..."
63
+ biorlhf-grpo --config configs/grpo_full_v2.json
64
 
65
  if [ $? -eq 0 ]; then
66
  echo ""
67
  echo "============================================================"
68
+ echo "BioGRPO Full v2 training completed!"
69
+ echo "Model saved to: ./biogrpo_full_v2_model"
70
  echo "End time: $(date)"
71
  echo "============================================================"
72
  else
src/biorlhf/training/grpo.py CHANGED
@@ -14,7 +14,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
14
  from peft import LoraConfig, PeftModel
15
  from trl import GRPOTrainer, GRPOConfig
16
 
17
- from biorlhf.verifiers.composer import make_grpo_reward_function
18
  from biorlhf.data.grpo_dataset import build_grpo_dataset, get_dataset_stats
19
 
20
 
@@ -37,6 +37,7 @@ class BioGRPOConfig:
37
  # Training hyperparameters
38
  num_epochs: int = 1
39
  batch_size: int = 2
 
40
  gradient_accumulation_steps: int = 8
41
  learning_rate: float = 1e-6
42
  max_completion_length: int = 1024
@@ -57,6 +58,9 @@ class BioGRPOConfig:
57
  hold_out_tissues: Optional[List[str]] = None
58
  seed: int = 42
59
 
 
 
 
60
  # Quantization
61
  use_4bit: bool = True
62
 
@@ -137,12 +141,21 @@ def run_grpo_training(config: Optional[BioGRPOConfig] = None) -> str:
137
  print(f" By type: {train_stats['by_question_type']}")
138
  print(f" Eval: {eval_stats['total']} samples")
139
 
140
- # 2. Create reward function
141
  print("\n[2/5] Initializing verifier stack...")
142
- reward_func = make_grpo_reward_function(
143
- weights=config.verifier_weights,
144
- active_verifiers=config.active_verifiers,
145
- )
 
 
 
 
 
 
 
 
 
146
  print(f" Active: {config.active_verifiers or ['V1', 'V2', 'V3', 'V4']}")
147
 
148
  # 3. Load tokenizer (always from base model; adapter dirs lack config.json)
@@ -214,10 +227,11 @@ def run_grpo_training(config: Optional[BioGRPOConfig] = None) -> str:
214
  # 6. Configure GRPOTrainer
215
  print("\n[5/6] Configuring GRPOTrainer...")
216
 
217
- grpo_config = GRPOConfig(
218
  output_dir=config.output_dir,
219
  num_train_epochs=config.num_epochs,
220
  per_device_train_batch_size=config.batch_size,
 
221
  gradient_accumulation_steps=config.gradient_accumulation_steps,
222
  learning_rate=config.learning_rate,
223
  warmup_ratio=config.warmup_ratio,
@@ -250,10 +264,26 @@ def run_grpo_training(config: Optional[BioGRPOConfig] = None) -> str:
250
  log_completions=config.log_completions,
251
  )
252
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
  trainer = GRPOTrainer(
254
  model=model,
255
  args=grpo_config,
256
- reward_funcs=reward_func,
257
  train_dataset=train_dataset,
258
  eval_dataset=eval_dataset,
259
  peft_config=peft_config,
 
14
  from peft import LoraConfig, PeftModel
15
  from trl import GRPOTrainer, GRPOConfig
16
 
17
+ from biorlhf.verifiers.composer import make_grpo_reward_function, make_individual_reward_functions
18
  from biorlhf.data.grpo_dataset import build_grpo_dataset, get_dataset_stats
19
 
20
 
 
37
  # Training hyperparameters
38
  num_epochs: int = 1
39
  batch_size: int = 2
40
+ eval_batch_size: Optional[int] = None
41
  gradient_accumulation_steps: int = 8
42
  learning_rate: float = 1e-6
43
  max_completion_length: int = 1024
 
58
  hold_out_tissues: Optional[List[str]] = None
59
  seed: int = 42
60
 
61
+ # Multi-reward (per-verifier TRL reward functions)
62
+ use_multi_reward: bool = False
63
+
64
  # Quantization
65
  use_4bit: bool = True
66
 
 
141
  print(f" By type: {train_stats['by_question_type']}")
142
  print(f" Eval: {eval_stats['total']} samples")
143
 
144
+ # 2. Create reward function(s)
145
  print("\n[2/5] Initializing verifier stack...")
146
+ reward_weights = None
147
+ if config.use_multi_reward:
148
+ reward_funcs, reward_weights = make_individual_reward_functions(
149
+ active_verifiers=config.active_verifiers,
150
+ weights=config.verifier_weights,
151
+ )
152
+ print(f" Mode: multi-reward ({len(reward_funcs)} per-verifier functions)")
153
+ print(f" Weights: {reward_weights}")
154
+ else:
155
+ reward_funcs = make_grpo_reward_function(
156
+ weights=config.verifier_weights,
157
+ active_verifiers=config.active_verifiers,
158
+ )
159
  print(f" Active: {config.active_verifiers or ['V1', 'V2', 'V3', 'V4']}")
160
 
161
  # 3. Load tokenizer (always from base model; adapter dirs lack config.json)
 
227
  # 6. Configure GRPOTrainer
228
  print("\n[5/6] Configuring GRPOTrainer...")
229
 
230
+ grpo_kwargs = dict(
231
  output_dir=config.output_dir,
232
  num_train_epochs=config.num_epochs,
233
  per_device_train_batch_size=config.batch_size,
234
+ per_device_eval_batch_size=config.eval_batch_size or config.batch_size,
235
  gradient_accumulation_steps=config.gradient_accumulation_steps,
236
  learning_rate=config.learning_rate,
237
  warmup_ratio=config.warmup_ratio,
 
264
  log_completions=config.log_completions,
265
  )
266
 
267
+ # Add reward_weights for multi-reward mode if TRL supports it
268
+ if reward_weights is not None:
269
+ try:
270
+ # Check if GRPOConfig accepts reward_weights (TRL >= 0.27)
271
+ import inspect
272
+ if "reward_weights" in inspect.signature(GRPOConfig).parameters:
273
+ grpo_kwargs["reward_weights"] = reward_weights
274
+ print(f" reward_weights set in GRPOConfig: {reward_weights}")
275
+ else:
276
+ print(" Warning: TRL version does not support reward_weights in GRPOConfig")
277
+ print(" Per-verifier functions will still be used, but with equal weights")
278
+ except Exception:
279
+ pass
280
+
281
+ grpo_config = GRPOConfig(**grpo_kwargs)
282
+
283
  trainer = GRPOTrainer(
284
  model=model,
285
  args=grpo_config,
286
+ reward_funcs=reward_funcs,
287
  train_dataset=train_dataset,
288
  eval_dataset=eval_dataset,
289
  peft_config=peft_config,
src/biorlhf/verifiers/composer.py CHANGED
@@ -202,6 +202,93 @@ def make_grpo_reward_function(
202
  return reward_func
203
 
204
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  def make_single_verifier_reward(verifier_name: str) -> Callable:
206
  """Create a reward function using only one verifier (for ablation)."""
207
  return make_grpo_reward_function(active_verifiers=[verifier_name])
 
202
  return reward_func
203
 
204
 
205
+ def make_individual_reward_functions(
206
+ active_verifiers: Optional[List[str]] = None,
207
+ weights: Optional[Dict[str, float]] = None,
208
+ ) -> tuple:
209
+ """Return (list_of_reward_funcs, list_of_weights) for TRL multi-reward.
210
+
211
+ Each reward function wraps a single verifier and returns List[float | None].
212
+ Non-applicable verifiers return None for samples where they don't apply;
213
+ TRL natively excludes None rewards from the GRPO calculation.
214
+
215
+ This enables per-verifier reward normalization in TRL, preventing a single
216
+ low-variance verifier from dominating the gradient signal.
217
+ """
218
+ all_verifiers = {
219
+ "V1": PathwayDirectionVerifier(),
220
+ "V2": BiologicalFactVerifier(),
221
+ "V3": CrossContextConsistencyVerifier(),
222
+ "V4": UncertaintyVerifier(),
223
+ }
224
+
225
+ if active_verifiers:
226
+ verifiers = {k: v for k, v in all_verifiers.items() if k in active_verifiers}
227
+ else:
228
+ verifiers = all_verifiers
229
+
230
+ w = dict(weights or DEFAULT_WEIGHTS)
231
+ weight_list = [w.get(k, 0) for k in verifiers]
232
+
233
+ def _make_single_reward_fn(verifier: BaseVerifier, vname: str) -> Callable:
234
+ """Create a closure-safe reward function for a single verifier."""
235
+
236
+ def reward_func(
237
+ completions: List,
238
+ ground_truth: Optional[List[str]] = None,
239
+ question_type: Optional[List[str]] = None,
240
+ applicable_verifiers: Optional[List[str]] = None,
241
+ **kwargs,
242
+ ) -> List:
243
+ n = len(completions)
244
+ if ground_truth is None:
245
+ ground_truth = ["{}"] * n
246
+ if question_type is None:
247
+ question_type = ["unknown"] * n
248
+ if applicable_verifiers is None:
249
+ applicable_verifiers = [json.dumps(list(all_verifiers.keys()))] * n
250
+
251
+ prompts = kwargs.get("prompts", kwargs.get("prompt", [""] * n))
252
+ if isinstance(prompts, str):
253
+ prompts = [prompts] * n
254
+
255
+ rewards = []
256
+ for i in range(n):
257
+ app = (
258
+ json.loads(applicable_verifiers[i])
259
+ if isinstance(applicable_verifiers[i], str)
260
+ else applicable_verifiers[i]
261
+ )
262
+ if not verifier.is_applicable(app):
263
+ rewards.append(None)
264
+ continue
265
+
266
+ completion_text = _extract_text(completions[i])
267
+ prompt_text = _extract_text(prompts[i]) if i < len(prompts) else ""
268
+ gt = (
269
+ json.loads(ground_truth[i])
270
+ if isinstance(ground_truth[i], str)
271
+ else ground_truth[i]
272
+ )
273
+
274
+ result = verifier.score(prompt_text, completion_text, gt, question_type[i])
275
+ if not result.applicable:
276
+ rewards.append(None)
277
+ else:
278
+ rewards.append(result.score)
279
+
280
+ return rewards
281
+
282
+ reward_func.__name__ = f"reward_{vname}"
283
+ return reward_func
284
+
285
+ reward_funcs = [
286
+ _make_single_reward_fn(v, name) for name, v in verifiers.items()
287
+ ]
288
+
289
+ return reward_funcs, weight_list
290
+
291
+
292
  def make_single_verifier_reward(verifier_name: str) -> Callable:
293
  """Create a reward function using only one verifier (for ablation)."""
294
  return make_grpo_reward_function(active_verifiers=[verifier_name])
src/biorlhf/verifiers/pathway.py CHANGED
@@ -4,11 +4,11 @@ V1: Pathway Direction Verifier.
4
  Extracts directional claims about biological pathways from model responses
5
  and compares them against fGSEA NES direction ground truth.
6
 
7
- Scoring:
8
- 1.0 β€” correct direction claimed
9
- 0.5 β€” mixed/contradictory claims
10
- 0.3 β€” no directional claim extracted
11
- 0.0 β€” wrong direction claimed
12
  """
13
 
14
  import re
@@ -229,15 +229,23 @@ class PathwayDirectionVerifier(BaseVerifier):
229
  contradicting = [
230
  c for c in claims if c[1] != expected_dir and c[1] != "AMBIGUOUS"
231
  ]
 
232
 
 
233
  if matching and not contradicting:
234
- score = 1.0
 
 
235
  elif matching and contradicting:
236
- score = 0.5
 
 
237
  elif contradicting:
238
- score = 0.0
 
 
239
  else:
240
- score = 0.3 # Only ambiguous claims
241
 
242
  return VerifierResult(
243
  score=score,
@@ -248,6 +256,7 @@ class PathwayDirectionVerifier(BaseVerifier):
248
  "claims": [(v, d) for v, d in claims],
249
  "n_matching": len(matching),
250
  "n_contradicting": len(contradicting),
 
251
  },
252
  )
253
 
 
4
  Extracts directional claims about biological pathways from model responses
5
  and compares them against fGSEA NES direction ground truth.
6
 
7
+ Scoring (continuous within bands to provide GRPO gradient signal):
8
+ [0.80, 1.00] β€” correct direction (strength = n_matching / (n_matching + 1))
9
+ [0.40, 0.60] β€” mixed/contradictory claims (ratio of matching to total)
10
+ 0.30 β€” no directional claim extracted (fixed)
11
+ [0.00, 0.10] β€” wrong direction (modulated by ambiguity ratio)
12
  """
13
 
14
  import re
 
229
  contradicting = [
230
  c for c in claims if c[1] != expected_dir and c[1] != "AMBIGUOUS"
231
  ]
232
+ ambiguous = [c for c in claims if c[1] == "AMBIGUOUS"]
233
 
234
+ # Continuous scoring within discrete bands for GRPO gradient signal
235
  if matching and not contradicting:
236
+ # Correct direction: [0.80, 1.00] based on claim strength
237
+ strength = len(matching) / (len(matching) + 1)
238
+ score = 0.8 + 0.2 * strength
239
  elif matching and contradicting:
240
+ # Mixed claims: [0.40, 0.60] based on matching ratio
241
+ total = len(matching) + len(contradicting)
242
+ score = 0.4 + 0.2 * (len(matching) / total)
243
  elif contradicting:
244
+ # Wrong direction: [0.00, 0.10] modulated by ambiguity
245
+ ambiguity_ratio = len(ambiguous) / len(claims) if claims else 0
246
+ score = 0.1 * ambiguity_ratio
247
  else:
248
+ score = 0.3 # Only ambiguous claims β€” no variance possible
249
 
250
  return VerifierResult(
251
  score=score,
 
256
  "claims": [(v, d) for v, d in claims],
257
  "n_matching": len(matching),
258
  "n_contradicting": len(contradicting),
259
+ "n_ambiguous": len(ambiguous),
260
  },
261
  )
262
 
src/biorlhf/verifiers/uncertainty.py CHANGED
@@ -250,13 +250,13 @@ class UncertaintyVerifier(BaseVerifier):
250
  conf_score: float,
251
  stated: str,
252
  ) -> VerifierResult:
253
- """Default scoring: penalize extreme overconfidence."""
254
- if conf_score > 0.90:
255
- score = 0.4 # Overconfidence penalty
256
- elif conf_score < 0.10:
257
- score = 0.3 # Extreme underconfidence penalty
258
- else:
259
- score = 0.7 # Moderate confidence is good default
260
 
261
  return VerifierResult(
262
  score=score,
 
250
  conf_score: float,
251
  stated: str,
252
  ) -> VerifierResult:
253
+ """Default scoring: continuous function rewarding moderate confidence.
254
+
255
+ Peaks at conf=0.5 (score=1.0), smoothly penalizes extremes.
256
+ Range: [0.25, 1.0] β€” provides GRPO gradient signal even when
257
+ generations extract slightly different confidence values.
258
+ """
259
+ score = max(0.2, 1.0 - abs(conf_score - 0.5) * 1.5)
260
 
261
  return VerifierResult(
262
  score=score,