Spaces:
Sleeping
Sleeping
Add SKIP_EVAL flag to sft_on_hf.py for faster training-only runs
Browse files- scripts/sft_on_hf.py +23 -15
scripts/sft_on_hf.py
CHANGED
|
@@ -122,14 +122,18 @@ def main():
|
|
| 122 |
sft_args.extend(["--max_steps", str(MAX_STEPS)])
|
| 123 |
run(sft_args)
|
| 124 |
|
| 125 |
-
# 4. Eval (
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
|
| 134 |
# 5. Upload to HF Hub
|
| 135 |
token = os.environ.get("HF_TOKEN")
|
|
@@ -149,18 +153,22 @@ def main():
|
|
| 149 |
repo_type="model",
|
| 150 |
commit_message=f"SFT prime ({EPOCHS} epochs, lora r={LORA_RANK}) on teacher trajectories",
|
| 151 |
)
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
|
|
|
| 158 |
|
| 159 |
print()
|
| 160 |
print("=" * 60)
|
| 161 |
print("DONE")
|
| 162 |
print(f" SFT'd model: https://huggingface.co/{MODEL_REPO}")
|
| 163 |
-
|
|
|
|
|
|
|
|
|
|
| 164 |
print("=" * 60)
|
| 165 |
|
| 166 |
|
|
|
|
| 122 |
sft_args.extend(["--max_steps", str(MAX_STEPS)])
|
| 123 |
run(sft_args)
|
| 124 |
|
| 125 |
+
# 4. Eval (optional — set SKIP_EVAL=1 to upload faster and run eval separately)
|
| 126 |
+
skip_eval = os.environ.get("SKIP_EVAL", "0") == "1"
|
| 127 |
+
if not skip_eval:
|
| 128 |
+
eval_args = [
|
| 129 |
+
"python", "training/inference_eval.py",
|
| 130 |
+
"--model_path", OUTPUT_DIR,
|
| 131 |
+
"--num_episodes", "5",
|
| 132 |
+
"--output_file", "eval_results.json",
|
| 133 |
+
]
|
| 134 |
+
run(eval_args)
|
| 135 |
+
else:
|
| 136 |
+
print("SKIP_EVAL=1: skipping embedded eval (run scripts/eval_on_hf.py separately)")
|
| 137 |
|
| 138 |
# 5. Upload to HF Hub
|
| 139 |
token = os.environ.get("HF_TOKEN")
|
|
|
|
| 153 |
repo_type="model",
|
| 154 |
commit_message=f"SFT prime ({EPOCHS} epochs, lora r={LORA_RANK}) on teacher trajectories",
|
| 155 |
)
|
| 156 |
+
if not skip_eval and Path("eval_results.json").exists():
|
| 157 |
+
api.upload_file(
|
| 158 |
+
path_or_fileobj="eval_results.json",
|
| 159 |
+
path_in_repo="eval_results.json",
|
| 160 |
+
repo_id=MODEL_REPO,
|
| 161 |
+
repo_type="model",
|
| 162 |
+
)
|
| 163 |
|
| 164 |
print()
|
| 165 |
print("=" * 60)
|
| 166 |
print("DONE")
|
| 167 |
print(f" SFT'd model: https://huggingface.co/{MODEL_REPO}")
|
| 168 |
+
if not skip_eval:
|
| 169 |
+
print(f" Eval JSON: https://huggingface.co/{MODEL_REPO}/blob/main/eval_results.json")
|
| 170 |
+
else:
|
| 171 |
+
print(" Eval skipped — run scripts/eval_on_hf.py separately")
|
| 172 |
print("=" * 60)
|
| 173 |
|
| 174 |
|