# Cyber_analyst-round1 / tests / test_modal_scenario_cache_static.py
# Commit 1b6d30b (Humanlearning): feat: introduce GRPO GPU fallback support,
# enhance training script with warmstart tagging, and add learning rate
# parameter for improved training flexibility.
from pathlib import Path

# Repository root: this file lives in <root>/tests/, so go two levels up.
ROOT = Path(__file__).resolve().parent.parent
def test_modal_train_uses_persistent_required_scenario_cache():
    """Static check: modal_train_grpo.py wires in the persistent scenario-cache volume."""
    script_text = (ROOT / "scripts" / "modal_train_grpo.py").read_text(encoding="utf-8")
    # Each snippet must appear verbatim in the training script.
    required_snippets = [
        'SCENARIO_CACHE_VOLUME_NAME = "CyberSecurity_OWASP-scenario-cache"',
        'SCENARIO_CACHE_DIR = pathlib.Path("/scenario-cache")',
        "CYBERSECURITY_OWASP_SCENARIO_CACHE_MODE",
        '"require" if required else "fallback"',
        'mode == "prepare-cache"',
        "def verify_modal_scenario_cache_for_training",
        "CPU scenario cache preflight passed",
        "scenario_cache.assert_coverage",
        "volumes={RUNS_DIR: volume, CACHE_DIR: cache_volume, SCENARIO_CACHE_DIR: scenario_cache_volume}",
    ]
    for snippet in required_snippets:
        assert snippet in script_text, snippet
def test_modal_ephemeral_smoke_uses_required_scenario_cache():
    """Static check: modal_ephemeral_train.py requires the scenario cache, no fallback."""
    script_text = (ROOT / "scripts" / "modal_ephemeral_train.py").read_text(encoding="utf-8")
    # Each snippet must appear verbatim in the ephemeral smoke script.
    required_snippets = [
        'SCENARIO_CACHE_VOLUME_NAME = "CyberSecurity_OWASP-scenario-cache"',
        'SCENARIO_CACHE_DIR = Path("/scenario-cache")',
        'mode == "prepare-cache"',
        "_configure_scenario_cache_env(required=True)",
        "ScenarioCache(SCENARIO_CACHE_DIR",
    ]
    for snippet in required_snippets:
        assert snippet in script_text, snippet
def test_modal_training_is_pinned_to_gemma4_e2b():
    """Static check: GRPO training is pinned to the Gemma 4 E2B vision model path."""
    script_text = (ROOT / "scripts" / "modal_train_grpo.py").read_text(encoding="utf-8")
    # Snippets that must be present in the GRPO script.
    required_snippets = [
        'GRPO_GPU_FALLBACK = ["L40S", "L4"]',
        "gpu=GRPO_GPU_FALLBACK",
        'DEFAULT_GEMMA_MODEL = "unsloth/gemma-4-E2B-it"',
        "def _ensure_gemma4_model(model_name: str) -> str:",
        "model_name = _ensure_gemma4_model(model_name)",
        "from unsloth import FastVisionModel",
        "sft-warmstart-grpo",
        "-sft-warmstart",
        "learning_rate: float = 5e-6",
        '"learning_rate": learning_rate',
    ]
    for snippet in required_snippets:
        assert snippet in script_text, snippet
    # Remnants of the previous model family must be gone.
    for forbidden in ("Qwen", "FastLanguageModel"):
        assert forbidden not in script_text, forbidden
def test_modal_sft_defaults_match_300_episode_fast_handoff_plan():
    """Static check: SFT script defaults match the 300-episode fast-handoff plan."""
    script_text = (ROOT / "scripts" / "modal_train_sft.py").read_text(encoding="utf-8")
    # Snippets that must be present verbatim in the SFT script.
    required_snippets = [
        'SFT_GPU_FALLBACK = ["H200", "H100", "A100-80GB", "L40S"]',
        "gpu=SFT_GPU_FALLBACK",
        "DEFAULT_TOTAL_TRAIN_EPISODES = 300",
        "DEFAULT_EPISODES_PER_LEVEL = 75",
        'DEFAULT_CURRICULUM_LEVELS = "0,1,2,3"',
        "output_repo_id = output_repo_id or DEFAULT_SFT_OUTPUT_REPO_ID",
        '"assistant_only_loss": False',
        '"packing": False',
        '"dataset_num_proc": None',
        "Dataset.from_list(tokenized_rows)",
        "tokenizer.apply_chat_template",
        "class CyberSecurityOWASPSFTTrainer(SFTTrainer)",
        "Trainer.compute_loss(self, model, inputs",
        '"bf16": True',
        '"tf32": True',
        '"hub_strategy": "every_save"',
        "trackio_space_id: str = DEFAULT_TRACKIO_SPACE_ID",
        "trackio_project: str = DEFAULT_TRACKIO_PROJECT",
        'os.environ["TRACKIO_SPACE_ID"] = trackio_space_id',
        'os.environ["TRACKIO_PROJECT"] = trackio_project',
    ]
    for snippet in required_snippets:
        assert snippet in script_text, snippet
    # The default output repo id is spelled over two lines in the script.
    repo_default_snippet = (
        'DEFAULT_SFT_OUTPUT_REPO_ID = (\n'
        ' "Humanlearning/CyberSecurity_OWASP-unsloth-gemma-4-e2b-it-sft-lora"'
    )
    assert repo_default_snippet in script_text, repo_default_snippet
    # These defaults must appear in at least two signatures (train + smoke paths).
    for snippet, minimum in (
        ("max_steps: int = -1", 2),
        ("per_device_train_batch_size: int = 4", 2),
        ("gradient_accumulation_steps: int = 4", 2),
    ):
        assert script_text.count(snippet) >= minimum, snippet
    # The bfd packing strategy must have been removed.
    assert '"packing_strategy": "bfd"' not in script_text
def test_modal_grpo_loads_sft_adapter_from_hub_as_trainable_lora():
    """Static check: GRPO script downloads the SFT adapter and loads it into a trainable LoRA."""
    script_text = (ROOT / "scripts" / "modal_train_grpo.py").read_text(encoding="utf-8")
    # Each snippet must appear verbatim in the GRPO script.
    required_snippets = [
        "initial_adapter_repo_id",
        "Downloading initial SFT adapter",
        "snapshot_download(",
        "Attaching Unsloth LoRA before loading SFT weights",
        'load_safetensors_file(str(adapter_weights_path), device="cpu")',
        "set_peft_model_state_dict(",
        "unexpected_adapter_keys",
    ]
    for snippet in required_snippets:
        assert snippet in script_text, snippet