"""Tests for the SFT step helpers and the HF training report generator."""

from __future__ import annotations

import json
from pathlib import Path

from app.common.constants import REQUIRED_REWARD_KEYS
from app.training.sft_trl import effective_sft_max_steps, effective_sft_save_steps
from scripts.generate_hf_training_report import generate_report

# Primary reward-channel averages shared by the GRPO run summary and the
# per-record reward-components log in the full-sweep fixture.
_PRIMARY_REWARD_CHANNELS = {
    "safety_legality": 0.700,
    "clinical_improvement": 0.710,
    "dosing_quality": 0.720,
    "process_integrity": 0.730,
}


def _write_json(path: Path, payload: object) -> None:
    """Serialize *payload* as JSON and write it to *path* as UTF-8 text."""
    path.write_text(json.dumps(payload), encoding="utf-8")


def test_sft_max_steps_zero_means_full_epoch() -> None:
    """Non-positive max_steps map to -1; positive values pass through unchanged.

    save_steps falls back to 500 when given 0 and passes positive values through.
    """
    assert effective_sft_max_steps(0) == -1
    assert effective_sft_max_steps(-5) == -1
    assert effective_sft_max_steps(12) == 12
    assert effective_sft_save_steps(0) == 500
    assert effective_sft_save_steps(12) == 12


def test_generate_hf_training_report_writes_charts_and_checks(tmp_path: Path) -> None:
    """A sweep with SFT and GRPO artifacts yields a passing report with charts on disk."""
    sweep_dir = tmp_path / "sweeps"
    run_dir = sweep_dir / "qwen-qwen2-5-0-5b-instruct"
    run_dir.mkdir(parents=True)
    components = {key: 0.500 for key in REQUIRED_REWARD_KEYS}

    _write_json(
        run_dir / "sft_trl_run.json",
        {
            "status": "ok",
            "backend": "trl_transformers",
            "model_id": "Qwen/Qwen2.5-0.5B-Instruct",
            "examples_used": 24,
            "train_loss": 0.123,
            "artifact_path": "checkpoints/sweeps/qwen/sft_adapter",
        },
    )
    _write_json(
        run_dir / "grpo_trl_run.json",
        {
            "status": "ok",
            "backend": "trl_transformers",
            "model_id": "Qwen/Qwen2.5-0.5B-Instruct",
            "records": 24,
            "artifact_path": "checkpoints/sweeps/qwen/grpo_adapter",
            "reward_summary": {
                "avg_reward": 0.720,
                "avg_reward_components": components,
                "avg_primary_reward_channels": _PRIMARY_REWARD_CHANNELS,
            },
        },
    )
    _write_json(
        run_dir / "postsave_inference_sft.json",
        {"status": "ok", "model_source": "adapter", "avg_env_reward": 0.650, "valid_rate": 1.0},
    )
    _write_json(
        run_dir / "postsave_inference_grpo.json",
        {"status": "ok", "model_source": "adapter", "avg_env_reward": 0.710, "valid_rate": 1.0},
    )
    _write_json(run_dir / "sft_history.json", [{"loss": 0.5}, {"loss": 0.25}])
    _write_json(run_dir / "grpo_history.json", [{"reward": 0.5}, {"reward": 0.72}])
    # JSON Lines file: a single record terminated by a newline.
    (run_dir / "grpo_reward_components.jsonl").write_text(
        json.dumps(
            {
                "legal": True,
                "reward": 0.720,
                "selected_candidate_id": "cand_01",
                "reward_breakdown": components,
                "primary_reward_channels": _PRIMARY_REWARD_CHANNELS,
            }
        )
        + "\n",
        encoding="utf-8",
    )

    summary, anti_hacking = generate_report(
        sweep_dir=sweep_dir,
        plot_dir=tmp_path / "plots",
        output_path=tmp_path / "hf_sweep_summary.json",
        anti_hacking_output=tmp_path / "anti_hacking_overfit_report.json",
    )

    assert summary["completed_models"] == 1
    assert anti_hacking["passed"] is True
    # Without this, the existence loop below would pass vacuously on an empty mapping.
    assert summary["charts"]
    for chart_path in summary["charts"].values():
        assert Path(chart_path).exists()


def test_generate_hf_training_report_accepts_sft_baseline_sweep(tmp_path: Path) -> None:
    """In "sft-baseline" mode, a sweep with only SFT artifacts still produces a report."""
    sweep_dir = tmp_path / "sweeps"
    run_dir = sweep_dir / "qwen-qwen2-5-1-5b-instruct"
    run_dir.mkdir(parents=True)

    _write_json(
        run_dir / "sft_trl_run.json",
        {
            "status": "ok",
            "backend": "trl_transformers",
            "model_id": "Qwen/Qwen2.5-1.5B-Instruct",
            "examples_used": 2000,
            "train_loss": 0.321,
            "artifact_path": "checkpoints/sweeps/qwen/sft_adapter",
        },
    )
    _write_json(
        run_dir / "postsave_inference_sft.json",
        {
            "status": "ok",
            "model_source": "adapter",
            "avg_env_reward": 0.690,
            "valid_rate": 1.0,
            "avg_latency_seconds": 0.42,
        },
    )
    _write_json(run_dir / "sft_history.json", [{"loss": 0.8}, {"loss": 0.32}])

    summary, anti_hacking = generate_report(
        sweep_dir=sweep_dir,
        plot_dir=tmp_path / "plots",
        output_path=tmp_path / "hf_sft_sweep_summary.json",
        anti_hacking_output=tmp_path / "anti_hacking_sft_report.json",
        mode="sft-baseline",
    )

    assert summary["training_mode"] == "sft-baseline"
    assert summary["completed_models"] == 1
    model = summary["models"][0]
    assert model["artifact_paths"]["sft"]
    # No GRPO artifacts were written, so the GRPO path is reported as empty.
    assert model["artifact_paths"]["grpo"] == ""
    assert anti_hacking["passed"] is True
    assert "qwen_model_sft_reward" in summary["charts"]
    # Mirror the full-sweep test: every reported chart must actually be on disk.
    for chart_path in summary["charts"].values():
        assert Path(chart_path).exists()