InosLihka commited on
Commit
ff20f02
·
1 Parent(s): b9c9b8f

Add SKIP_EVAL flag to sft_on_hf.py for faster training-only runs

Browse files
Files changed (1) hide show
  1. scripts/sft_on_hf.py +23 -15
scripts/sft_on_hf.py CHANGED
@@ -122,14 +122,18 @@ def main():
122
  sft_args.extend(["--max_steps", str(MAX_STEPS)])
123
  run(sft_args)
124
 
125
- # 4. Eval (3 conditions: discrete-3 / in-dist / OOD)
126
- eval_args = [
127
- "python", "training/inference_eval.py",
128
- "--model_path", OUTPUT_DIR,
129
- "--num_episodes", "5",
130
- "--output_file", "eval_results.json",
131
- ]
132
- run(eval_args)
 
 
 
 
133
 
134
  # 5. Upload to HF Hub
135
  token = os.environ.get("HF_TOKEN")
@@ -149,18 +153,22 @@ def main():
149
  repo_type="model",
150
  commit_message=f"SFT prime ({EPOCHS} epochs, lora r={LORA_RANK}) on teacher trajectories",
151
  )
152
- api.upload_file(
153
- path_or_fileobj="eval_results.json",
154
- path_in_repo="eval_results.json",
155
- repo_id=MODEL_REPO,
156
- repo_type="model",
157
- )
 
158
 
159
  print()
160
  print("=" * 60)
161
  print("DONE")
162
  print(f" SFT'd model: https://huggingface.co/{MODEL_REPO}")
163
- print(f" Eval JSON: https://huggingface.co/{MODEL_REPO}/blob/main/eval_results.json")
 
 
 
164
  print("=" * 60)
165
 
166
 
 
122
  sft_args.extend(["--max_steps", str(MAX_STEPS)])
123
  run(sft_args)
124
 
125
+ # 4. Eval (optional set SKIP_EVAL=1 to upload faster and run eval separately)
126
+ skip_eval = os.environ.get("SKIP_EVAL", "0") == "1"
127
+ if not skip_eval:
128
+ eval_args = [
129
+ "python", "training/inference_eval.py",
130
+ "--model_path", OUTPUT_DIR,
131
+ "--num_episodes", "5",
132
+ "--output_file", "eval_results.json",
133
+ ]
134
+ run(eval_args)
135
+ else:
136
+ print("SKIP_EVAL=1: skipping embedded eval (run scripts/eval_on_hf.py separately)")
137
 
138
  # 5. Upload to HF Hub
139
  token = os.environ.get("HF_TOKEN")
 
153
  repo_type="model",
154
  commit_message=f"SFT prime ({EPOCHS} epochs, lora r={LORA_RANK}) on teacher trajectories",
155
  )
156
+ if not skip_eval and Path("eval_results.json").exists():
157
+ api.upload_file(
158
+ path_or_fileobj="eval_results.json",
159
+ path_in_repo="eval_results.json",
160
+ repo_id=MODEL_REPO,
161
+ repo_type="model",
162
+ )
163
 
164
  print()
165
  print("=" * 60)
166
  print("DONE")
167
  print(f" SFT'd model: https://huggingface.co/{MODEL_REPO}")
168
+ if not skip_eval:
169
+ print(f" Eval JSON: https://huggingface.co/{MODEL_REPO}/blob/main/eval_results.json")
170
+ else:
171
+ print(" Eval skipped — run scripts/eval_on_hf.py separately")
172
  print("=" * 60)
173
 
174