polyguard-openenv-workbench / polyguard-rl /tests /test_hf_training_sweep.py
TheJackBright's picture
Deploy GitHub root master to Space
c296d62
from __future__ import annotations
import json
from pathlib import Path
from app.common.constants import REQUIRED_REWARD_KEYS
from app.training.sft_trl import effective_sft_max_steps, effective_sft_save_steps
from scripts.generate_hf_training_report import generate_report
def test_sft_max_steps_zero_means_full_epoch() -> None:
    """Pin the step-count normalisation performed by the SFT helpers.

    Expected mappings asserted here:
      * ``effective_sft_max_steps``: 0 and -5 both yield -1; 12 is returned
        unchanged.
      * ``effective_sft_save_steps``: 0 yields 500; 12 is returned unchanged.
    """
    # NOTE(review): -1 is presumably the trainer's "no step cap, run the full
    # epoch" sentinel (going by the test name) — confirm against sft_trl.
    max_step_expectations = ((0, -1), (-5, -1), (12, 12))
    for requested, expected in max_step_expectations:
        assert effective_sft_max_steps(requested) == expected

    save_step_expectations = ((0, 500), (12, 12))
    for requested, expected in save_step_expectations:
        assert effective_sft_save_steps(requested) == expected
def test_generate_hf_training_report_writes_charts_and_checks(tmp_path: Path) -> None:
    """Run ``generate_report`` over a one-model SFT + GRPO sweep fixture.

    The fixture directory ``sweeps/qwen-qwen2-5-0-5b-instruct`` is populated
    with the run summaries, post-save inference results, training histories
    and a single-line reward-component JSONL file. The test then checks that
    the report counts one completed model, that the anti-hacking report
    passes, and that every path listed under ``summary["charts"]`` exists.

    Args:
        tmp_path: pytest-provided temporary directory holding all inputs and
            outputs of this test.
    """
    sweep_root = tmp_path / "sweeps"
    model_dir = sweep_root / "qwen-qwen2-5-0-5b-instruct"
    model_dir.mkdir(parents=True)

    def write_fixture(filename: str, payload: object, trailer: str = "") -> None:
        # All fixture artifacts are UTF-8 JSON; only the JSONL file passes a
        # trailing newline.
        (model_dir / filename).write_text(json.dumps(payload) + trailer, encoding="utf-8")

    reward_components = {key: 0.500 for key in REQUIRED_REWARD_KEYS}
    # The same four-channel breakdown appears in the GRPO run summary and in
    # the per-step JSONL record.
    primary_channels = {
        "safety_legality": 0.700,
        "clinical_improvement": 0.710,
        "dosing_quality": 0.720,
        "process_integrity": 0.730,
    }

    sft_run = {
        "status": "ok",
        "backend": "trl_transformers",
        "model_id": "Qwen/Qwen2.5-0.5B-Instruct",
        "examples_used": 24,
        "train_loss": 0.123,
        "artifact_path": "checkpoints/sweeps/qwen/sft_adapter",
    }
    write_fixture("sft_trl_run.json", sft_run)

    grpo_run = {
        "status": "ok",
        "backend": "trl_transformers",
        "model_id": "Qwen/Qwen2.5-0.5B-Instruct",
        "records": 24,
        "artifact_path": "checkpoints/sweeps/qwen/grpo_adapter",
        "reward_summary": {
            "avg_reward": 0.720,
            "avg_reward_components": reward_components,
            "avg_primary_reward_channels": primary_channels,
        },
    }
    write_fixture("grpo_trl_run.json", grpo_run)

    write_fixture(
        "postsave_inference_sft.json",
        {"status": "ok", "model_source": "adapter", "avg_env_reward": 0.650, "valid_rate": 1.0},
    )
    write_fixture(
        "postsave_inference_grpo.json",
        {"status": "ok", "model_source": "adapter", "avg_env_reward": 0.710, "valid_rate": 1.0},
    )
    write_fixture("sft_history.json", [{"loss": 0.5}, {"loss": 0.25}])
    write_fixture("grpo_history.json", [{"reward": 0.5}, {"reward": 0.72}])

    component_record = {
        "legal": True,
        "reward": 0.720,
        "selected_candidate_id": "cand_01",
        "reward_breakdown": reward_components,
        "primary_reward_channels": primary_channels,
    }
    write_fixture("grpo_reward_components.jsonl", component_record, trailer="\n")

    summary, anti_hacking = generate_report(
        sweep_dir=sweep_root,
        plot_dir=tmp_path / "plots",
        output_path=tmp_path / "hf_sweep_summary.json",
        anti_hacking_output=tmp_path / "anti_hacking_overfit_report.json",
    )

    assert summary["completed_models"] == 1
    assert anti_hacking["passed"] is True
    # Every chart the summary advertises must actually exist on disk.
    missing_charts = [
        chart for chart in summary["charts"].values() if not Path(chart).exists()
    ]
    assert not missing_charts
def test_generate_hf_training_report_accepts_sft_baseline_sweep(tmp_path: Path) -> None:
    """Run ``generate_report`` in ``"sft-baseline"`` mode on an SFT-only fixture.

    Only the SFT run summary, the SFT post-save inference result and the SFT
    loss history are written; no GRPO artifacts exist. The test checks that
    the summary reports the ``"sft-baseline"`` training mode, one completed
    model, a truthy SFT artifact path, an empty-string GRPO artifact path, a
    passing anti-hacking report, and a ``"qwen_model_sft_reward"`` chart key.

    Args:
        tmp_path: pytest-provided temporary directory holding all inputs and
            outputs of this test.
    """
    sweep_root = tmp_path / "sweeps"
    model_dir = sweep_root / "qwen-qwen2-5-1-5b-instruct"
    model_dir.mkdir(parents=True)

    def write_fixture(filename: str, payload: object) -> None:
        # All fixture artifacts are UTF-8 encoded JSON documents.
        (model_dir / filename).write_text(json.dumps(payload), encoding="utf-8")

    sft_run = {
        "status": "ok",
        "backend": "trl_transformers",
        "model_id": "Qwen/Qwen2.5-1.5B-Instruct",
        "examples_used": 2000,
        "train_loss": 0.321,
        "artifact_path": "checkpoints/sweeps/qwen/sft_adapter",
    }
    write_fixture("sft_trl_run.json", sft_run)

    sft_inference = {
        "status": "ok",
        "model_source": "adapter",
        "avg_env_reward": 0.690,
        "valid_rate": 1.0,
        "avg_latency_seconds": 0.42,
    }
    write_fixture("postsave_inference_sft.json", sft_inference)

    write_fixture("sft_history.json", [{"loss": 0.8}, {"loss": 0.32}])

    summary, anti_hacking = generate_report(
        sweep_dir=sweep_root,
        plot_dir=tmp_path / "plots",
        output_path=tmp_path / "hf_sft_sweep_summary.json",
        anti_hacking_output=tmp_path / "anti_hacking_sft_report.json",
        mode="sft-baseline",
    )

    assert summary["training_mode"] == "sft-baseline"
    assert summary["completed_models"] == 1

    # Looked up only after the two checks above so that failures surface in
    # the same order as before.
    artifact_paths = summary["models"][0]["artifact_paths"]
    assert artifact_paths["sft"]
    assert artifact_paths["grpo"] == ""

    assert anti_hacking["passed"] is True
    assert "qwen_model_sft_reward" in summary["charts"]