shank committed on
Commit ·
4668456
1
Parent(s): a2fa47a
fix: serialize bug_metadata as JSON to fix pyarrow mixed-type error
Browse files - training/train_grpo.py +3 -2
training/train_grpo.py
CHANGED
|
@@ -386,7 +386,8 @@ def reward_fn(completions: list[str], prompts: list[str], **kwargs) -> list[floa
|
|
| 386 |
GRPO learns from RELATIVE differences within each group.
|
| 387 |
"""
|
| 388 |
rewards = []
|
| 389 |
-
|
|
|
|
| 390 |
|
| 391 |
for completion, bug in zip(completions, bugs):
|
| 392 |
try:
|
|
@@ -452,7 +453,7 @@ model.train()
|
|
| 452 |
# ── Build initial dataset ─────────────────────────────────────────────────────
|
| 453 |
def make_dataset(step: int) -> Dataset:
|
| 454 |
bugs = get_bugs_for_step(step)
|
| 455 |
-
return Dataset.from_list([{"prompt": bug_to_prompt(b), "bug_metadata": b} for b in bugs])
|
| 456 |
|
| 457 |
# ── Training config ───────────────────────────────────────────────────────────
|
| 458 |
config = GRPOConfig(
|
|
|
|
| 386 |
GRPO learns from RELATIVE differences within each group.
|
| 387 |
"""
|
| 388 |
rewards = []
|
| 389 |
+
bugs_raw = kwargs.get("bug_metadata", [{}] * len(completions))
|
| 390 |
+
bugs = [json.loads(b) if isinstance(b, str) else b for b in bugs_raw]
|
| 391 |
|
| 392 |
for completion, bug in zip(completions, bugs):
|
| 393 |
try:
|
|
|
|
| 453 |
# ── Build initial dataset ─────────────────────────────────────────────────────
|
| 454 |
def make_dataset(step: int) -> Dataset:
|
| 455 |
bugs = get_bugs_for_step(step)
|
| 456 |
+
return Dataset.from_list([{"prompt": bug_to_prompt(b), "bug_metadata": json.dumps(b)} for b in bugs])
|
| 457 |
|
| 458 |
# ── Training config ───────────────────────────────────────────────────────────
|
| 459 |
config = GRPOConfig(
|