shank committed on
Commit
4668456
·
1 Parent(s): a2fa47a

fix: serialize bug_metadata as JSON to fix pyarrow mixed-type error

Browse files
Files changed (1) hide show
  1. training/train_grpo.py +3 -2
training/train_grpo.py CHANGED
@@ -386,7 +386,8 @@ def reward_fn(completions: list[str], prompts: list[str], **kwargs) -> list[floa
386
  GRPO learns from RELATIVE differences within each group.
387
  """
388
  rewards = []
389
- bugs = kwargs.get("bug_metadata", [{}] * len(completions))
 
390
 
391
  for completion, bug in zip(completions, bugs):
392
  try:
@@ -452,7 +453,7 @@ model.train()
452
  # ── Build initial dataset ─────────────────────────────────────────────────────
453
  def make_dataset(step: int) -> Dataset:
454
  bugs = get_bugs_for_step(step)
455
- return Dataset.from_list([{"prompt": bug_to_prompt(b), "bug_metadata": b} for b in bugs])
456
 
457
  # ── Training config ───────────────────────────────────────────────────────────
458
  config = GRPOConfig(
 
386
  GRPO learns from RELATIVE differences within each group.
387
  """
388
  rewards = []
389
+ bugs_raw = kwargs.get("bug_metadata", [{}] * len(completions))
390
+ bugs = [json.loads(b) if isinstance(b, str) else b for b in bugs_raw]
391
 
392
  for completion, bug in zip(completions, bugs):
393
  try:
 
453
  # ── Build initial dataset ─────────────────────────────────────────────────────
454
  def make_dataset(step: int) -> Dataset:
455
  bugs = get_bugs_for_step(step)
456
+ return Dataset.from_list([{"prompt": bug_to_prompt(b), "bug_metadata": json.dumps(b)} for b in bugs])
457
 
458
  # ── Training config ───────────────────────────────────────────────────────────
459
  config = GRPOConfig(