Spaces:
Sleeping
Sleeping
File size: 3,862 Bytes
d63a1ba | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 | """Generate resumable SFT data from deterministic heuristic rollouts."""
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
sys.path.insert(0, str(ROOT))
from training_utils import (
PROMPT_VARIANTS,
append_jsonl,
build_text_example,
generate_heuristic_transitions,
split_for_key,
write_json,
)
def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--output-root", default="artifacts/lora_qwen3_4b")
parser.add_argument("--augmentations", type=int, default=12)
parser.add_argument("--eval-ratio", type=float, default=0.2)
parser.add_argument("--force", action="store_true")
args = parser.parse_args()
output_root = (ROOT / args.output_root).resolve()
data_dir = output_root / "data"
transitions_path = data_dir / "transitions.jsonl"
train_path = data_dir / "train.jsonl"
eval_path = data_dir / "eval.jsonl"
manifest_path = output_root / "run_manifest.json"
if args.force:
for path in (transitions_path, train_path, eval_path):
if path.exists():
path.unlink()
if transitions_path.exists() and train_path.exists() and eval_path.exists():
print(json.dumps({"status": "already_exists", "output_root": str(output_root)}, indent=2))
return
transition_count = 0
train_examples = 0
eval_examples = 0
for transition in generate_heuristic_transitions():
record = {
"task_id": transition.task_id,
"difficulty": transition.difficulty,
"step_index": transition.step_index,
"observation": transition.observation,
"action": transition.action,
"reward_after_action": transition.reward_after_action,
"score_after_action": transition.score_after_action,
"done": transition.done,
}
append_jsonl(transitions_path, record)
transition_count += 1
for augmentation_index in range(args.augmentations):
prompt_variant = PROMPT_VARIANTS[augmentation_index % len(PROMPT_VARIANTS)]
example = build_text_example(
observation=transition.observation,
action=transition.action,
prompt_variant=prompt_variant,
)
example_record = {
"id": f"{transition.task_id}-step{transition.step_index}-aug{augmentation_index}",
"task_id": transition.task_id,
"difficulty": transition.difficulty,
"step_index": transition.step_index,
"prompt_variant": prompt_variant,
**example,
}
split = split_for_key(example_record["id"], args.eval_ratio)
append_jsonl(train_path if split == "train" else eval_path, example_record)
if split == "train":
train_examples += 1
else:
eval_examples += 1
write_json(
manifest_path,
{
"status": "data_ready",
"output_root": str(output_root),
"transition_count": transition_count,
"train_examples": train_examples,
"eval_examples": eval_examples,
"augmentations": args.augmentations,
"eval_ratio": args.eval_ratio,
},
)
print(
json.dumps(
{
"status": "ok",
"output_root": str(output_root),
"transition_count": transition_count,
"train_examples": train_examples,
"eval_examples": eval_examples,
},
indent=2,
)
)
if __name__ == "__main__":
main()
|