| from __future__ import annotations |
|
|
| import json |
| from pathlib import Path |
|
|
| from app.common.constants import REQUIRED_REWARD_KEYS |
| from app.training.sft_trl import effective_sft_max_steps, effective_sft_save_steps |
| from scripts.generate_hf_training_report import generate_report |
|
|
|
|
def test_sft_max_steps_zero_means_full_epoch() -> None:
    """Pin the step-normalisation helpers used by the SFT trainer.

    ``effective_sft_max_steps`` is expected to return ``-1`` for zero or
    negative requests (the test name indicates this stands for a full epoch)
    and to pass positive values through unchanged.
    ``effective_sft_save_steps`` is expected to return ``500`` for a zero
    request and to pass positive values through unchanged.
    """
    max_step_cases = ((0, -1), (-5, -1), (12, 12))
    for requested, expected in max_step_cases:
        assert effective_sft_max_steps(requested) == expected

    save_step_cases = ((0, 500), (12, 12))
    for requested, expected in save_step_cases:
        assert effective_sft_save_steps(requested) == expected
|
|
|
|
def test_generate_hf_training_report_writes_charts_and_checks(tmp_path: Path) -> None:
    """Run ``generate_report`` over a single-model SFT + GRPO sweep fixture.

    The fixture directory holds the SFT and GRPO run summaries, both
    post-save inference summaries, both training histories and one line of
    per-step GRPO reward components. The test then checks that exactly one
    model is reported as completed, that the anti-hacking report passes, and
    that every chart path listed in the summary exists on disk.
    """
    sweep_root = tmp_path / "sweeps"
    run_dir = sweep_root / "qwen-qwen2-5-0-5b-instruct"
    run_dir.mkdir(parents=True)

    def write_json(filename: str, payload: object, trailer: str = "") -> None:
        # Serialise *payload* into the run directory; *trailer* lets the
        # JSONL fixture end with a newline.
        (run_dir / filename).write_text(json.dumps(payload) + trailer, encoding="utf-8")

    # Every required reward component gets the same placeholder value.
    components = dict.fromkeys(REQUIRED_REWARD_KEYS, 0.500)
    # Shared by the GRPO run summary and the per-step JSONL record.
    primary_channels = {
        "safety_legality": 0.700,
        "clinical_improvement": 0.710,
        "dosing_quality": 0.720,
        "process_integrity": 0.730,
    }

    write_json(
        "sft_trl_run.json",
        {
            "status": "ok",
            "backend": "trl_transformers",
            "model_id": "Qwen/Qwen2.5-0.5B-Instruct",
            "examples_used": 24,
            "train_loss": 0.123,
            "artifact_path": "checkpoints/sweeps/qwen/sft_adapter",
        },
    )
    write_json(
        "grpo_trl_run.json",
        {
            "status": "ok",
            "backend": "trl_transformers",
            "model_id": "Qwen/Qwen2.5-0.5B-Instruct",
            "records": 24,
            "artifact_path": "checkpoints/sweeps/qwen/grpo_adapter",
            "reward_summary": {
                "avg_reward": 0.720,
                "avg_reward_components": components,
                "avg_primary_reward_channels": primary_channels,
            },
        },
    )
    write_json(
        "postsave_inference_sft.json",
        {"status": "ok", "model_source": "adapter", "avg_env_reward": 0.650, "valid_rate": 1.0},
    )
    write_json(
        "postsave_inference_grpo.json",
        {"status": "ok", "model_source": "adapter", "avg_env_reward": 0.710, "valid_rate": 1.0},
    )
    write_json("sft_history.json", [{"loss": 0.5}, {"loss": 0.25}])
    write_json("grpo_history.json", [{"reward": 0.5}, {"reward": 0.72}])
    write_json(
        "grpo_reward_components.jsonl",
        {
            "legal": True,
            "reward": 0.720,
            "selected_candidate_id": "cand_01",
            "reward_breakdown": components,
            "primary_reward_channels": primary_channels,
        },
        trailer="\n",
    )

    summary, anti_hacking = generate_report(
        sweep_dir=sweep_root,
        plot_dir=tmp_path / "plots",
        output_path=tmp_path / "hf_sweep_summary.json",
        anti_hacking_output=tmp_path / "anti_hacking_overfit_report.json",
    )

    assert summary["completed_models"] == 1
    assert anti_hacking["passed"] is True
    # Every chart the summary advertises must exist on disk.
    missing_charts = [
        chart_path
        for chart_path in summary["charts"].values()
        if not Path(chart_path).exists()
    ]
    assert not missing_charts
|
|
|
|
def test_generate_hf_training_report_accepts_sft_baseline_sweep(tmp_path: Path) -> None:
    """Run ``generate_report`` in ``"sft-baseline"`` mode on an SFT-only fixture.

    Only the SFT run summary, the SFT post-save inference summary and the SFT
    loss history are written; no GRPO files exist. The test checks that the
    summary echoes the training mode, counts one completed model, has a truthy
    SFT artifact path and an empty-string GRPO artifact path, that the
    anti-hacking report passes, and that the ``qwen_model_sft_reward`` chart
    is listed.
    """
    sweep_root = tmp_path / "sweeps"
    run_dir = sweep_root / "qwen-qwen2-5-1-5b-instruct"
    run_dir.mkdir(parents=True)

    def write_json(filename: str, payload: object) -> None:
        # Serialise *payload* as UTF-8 JSON inside the run directory.
        (run_dir / filename).write_text(json.dumps(payload), encoding="utf-8")

    write_json(
        "sft_trl_run.json",
        {
            "status": "ok",
            "backend": "trl_transformers",
            "model_id": "Qwen/Qwen2.5-1.5B-Instruct",
            "examples_used": 2000,
            "train_loss": 0.321,
            "artifact_path": "checkpoints/sweeps/qwen/sft_adapter",
        },
    )
    write_json(
        "postsave_inference_sft.json",
        {
            "status": "ok",
            "model_source": "adapter",
            "avg_env_reward": 0.690,
            "valid_rate": 1.0,
            "avg_latency_seconds": 0.42,
        },
    )
    write_json("sft_history.json", [{"loss": 0.8}, {"loss": 0.32}])

    summary, anti_hacking = generate_report(
        sweep_dir=sweep_root,
        plot_dir=tmp_path / "plots",
        output_path=tmp_path / "hf_sft_sweep_summary.json",
        anti_hacking_output=tmp_path / "anti_hacking_sft_report.json",
        mode="sft-baseline",
    )

    assert summary["training_mode"] == "sft-baseline"
    assert summary["completed_models"] == 1

    artifact_paths = summary["models"][0]["artifact_paths"]
    assert artifact_paths["sft"]
    # No GRPO stage ran, so its artifact path is expected to be empty.
    assert artifact_paths["grpo"] == ""

    assert anti_hacking["passed"] is True
    assert "qwen_model_sft_reward" in summary["charts"]
|
|