File size: 4,415 Bytes
be8eade
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1b6d30b
 
be8eade
 
 
 
 
 
1b6d30b
 
 
 
60f97ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1544ce8
 
 
 
e5fe6f5
 
 
 
60f97ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e5fe6f5
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from pathlib import Path


# Parent of the directory that contains this (symlink-resolved) file; the
# tests below read the Modal scripts from ROOT / "scripts".
ROOT = Path(__file__).resolve().parents[1]


def test_modal_train_uses_persistent_required_scenario_cache():
    """Assert the GRPO Modal script text contains its scenario-cache snippets.

    Reads ``scripts/modal_train_grpo.py`` as UTF-8 text and checks that every
    snippet listed below occurs verbatim somewhere in it.
    """
    script_path = ROOT / "scripts" / "modal_train_grpo.py"
    source = script_path.read_text(encoding="utf-8")

    expected_snippets = (
        'SCENARIO_CACHE_VOLUME_NAME = "CyberSecurity_OWASP-scenario-cache"',
        'SCENARIO_CACHE_DIR = pathlib.Path("/scenario-cache")',
        "CYBERSECURITY_OWASP_SCENARIO_CACHE_MODE",
        '"require" if required else "fallback"',
        'mode == "prepare-cache"',
        "def verify_modal_scenario_cache_for_training",
        "CPU scenario cache preflight passed",
        "scenario_cache.assert_coverage",
        "volumes={RUNS_DIR: volume, CACHE_DIR: cache_volume, SCENARIO_CACHE_DIR: scenario_cache_volume}",
    )
    for snippet in expected_snippets:
        assert snippet in source


def test_modal_ephemeral_smoke_uses_required_scenario_cache():
    """Assert the ephemeral Modal script text contains its scenario-cache snippets.

    Reads ``scripts/modal_ephemeral_train.py`` as UTF-8 text and checks that
    every snippet listed below occurs verbatim somewhere in it.
    """
    script_path = ROOT / "scripts" / "modal_ephemeral_train.py"
    source = script_path.read_text(encoding="utf-8")

    expected_snippets = (
        'SCENARIO_CACHE_VOLUME_NAME = "CyberSecurity_OWASP-scenario-cache"',
        'SCENARIO_CACHE_DIR = Path("/scenario-cache")',
        'mode == "prepare-cache"',
        "_configure_scenario_cache_env(required=True)",
        "ScenarioCache(SCENARIO_CACHE_DIR",
    )
    for snippet in expected_snippets:
        assert snippet in source


def test_modal_training_is_pinned_to_gemma4_e2b():
    """Assert the GRPO Modal script text matches the expected model pinning.

    Reads ``scripts/modal_train_grpo.py`` as UTF-8 text, then checks that each
    required snippet occurs verbatim and that each forbidden snippet is absent.
    """
    script_path = ROOT / "scripts" / "modal_train_grpo.py"
    source = script_path.read_text(encoding="utf-8")

    required_snippets = (
        'GRPO_GPU_FALLBACK = ["L40S", "L4"]',
        "gpu=GRPO_GPU_FALLBACK",
        'DEFAULT_GEMMA_MODEL = "unsloth/gemma-4-E2B-it"',
        "def _ensure_gemma4_model(model_name: str) -> str:",
        "model_name = _ensure_gemma4_model(model_name)",
        "from unsloth import FastVisionModel",
        "sft-warmstart-grpo",
        "-sft-warmstart",
        "learning_rate: float = 5e-6",
        '"learning_rate": learning_rate',
    )
    forbidden_snippets = (
        "Qwen",
        "FastLanguageModel",
    )

    for snippet in required_snippets:
        assert snippet in source
    for snippet in forbidden_snippets:
        assert snippet not in source


def test_modal_sft_defaults_match_300_episode_fast_handoff_plan():
    """Assert the SFT Modal script text contains the expected default settings.

    Reads ``scripts/modal_train_sft.py`` as UTF-8 text and performs three kinds
    of textual checks: snippets that must occur, snippets that must occur at
    least twice, and snippets that must not occur at all.
    """
    script_path = ROOT / "scripts" / "modal_train_sft.py"
    source = script_path.read_text(encoding="utf-8")

    required_snippets = (
        'SFT_GPU_FALLBACK = ["H200", "H100", "A100-80GB", "L40S"]',
        "gpu=SFT_GPU_FALLBACK",
        "DEFAULT_TOTAL_TRAIN_EPISODES = 300",
        "DEFAULT_EPISODES_PER_LEVEL = 75",
        'DEFAULT_CURRICULUM_LEVELS = "0,1,2,3"',
        # Two adjacent literals: the assignment opener and its indented value.
        (
            "DEFAULT_SFT_OUTPUT_REPO_ID = (\n"
            '    "Humanlearning/CyberSecurity_OWASP-unsloth-gemma-4-e2b-it-sft-lora"'
        ),
        "output_repo_id = output_repo_id or DEFAULT_SFT_OUTPUT_REPO_ID",
        '"assistant_only_loss": False',
        '"packing": False',
        '"dataset_num_proc": None',
        "Dataset.from_list(tokenized_rows)",
        "tokenizer.apply_chat_template",
        "class CyberSecurityOWASPSFTTrainer(SFTTrainer)",
        "Trainer.compute_loss(self, model, inputs",
        '"bf16": True',
        '"tf32": True',
        '"hub_strategy": "every_save"',
        "trackio_space_id: str = DEFAULT_TRACKIO_SPACE_ID",
        "trackio_project: str = DEFAULT_TRACKIO_PROJECT",
        'os.environ["TRACKIO_SPACE_ID"] = trackio_space_id',
        'os.environ["TRACKIO_PROJECT"] = trackio_project',
    )
    # These parameter defaults must be spelled out at least twice in the file.
    repeated_snippets = (
        "max_steps: int = -1",
        "per_device_train_batch_size: int = 4",
        "gradient_accumulation_steps: int = 4",
    )
    forbidden_snippets = ('"packing_strategy": "bfd"',)

    for snippet in required_snippets:
        assert snippet in source
    for snippet in repeated_snippets:
        assert source.count(snippet) >= 2
    for snippet in forbidden_snippets:
        assert snippet not in source


def test_modal_grpo_loads_sft_adapter_from_hub_as_trainable_lora():
    """Assert the GRPO Modal script text contains its SFT-adapter loading snippets.

    Reads ``scripts/modal_train_grpo.py`` as UTF-8 text and checks that every
    snippet listed below occurs verbatim somewhere in it.
    """
    script_path = ROOT / "scripts" / "modal_train_grpo.py"
    source = script_path.read_text(encoding="utf-8")

    expected_snippets = (
        "initial_adapter_repo_id",
        "Downloading initial SFT adapter",
        "snapshot_download(",
        "Attaching Unsloth LoRA before loading SFT weights",
        'load_safetensors_file(str(adapter_weights_path), device="cpu")',
        "set_peft_model_state_dict(",
        "unexpected_adapter_keys",
    )
    for snippet in expected_snippets:
        assert snippet in source