File size: 3,172 Bytes
b3ee507
3807ea3
b3ee507
 
 
3807ea3
 
 
 
 
 
4e663d8
 
3807ea3
b3ee507
 
 
be8eade
 
 
 
 
 
 
 
 
3807ea3
b3ee507
 
3807ea3
 
be8eade
4e663d8
b3ee507
 
 
 
 
4e663d8
 
 
 
 
3807ea3
 
 
 
4e663d8
3807ea3
 
 
 
 
 
be8eade
3807ea3
 
 
 
 
 
 
 
 
 
 
b3ee507
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be8eade
b3ee507
 
 
3807ea3
b3ee507
3807ea3
b3ee507
 
 
 
 
3807ea3
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
"""Modal-only GRPO config helper for CyberSecurity_OWASP.

This module intentionally does not run local training.
Use `scripts/modal_train_grpo.py` (persistent) or
`scripts/modal_ephemeral_train.py` (smoke) for execution.
"""

from __future__ import annotations

import os

from training.trackio_utils import build_run_name, get_git_sha


# Default model id used when none is supplied explicitly. An exported
# MODEL_NAME environment variable takes precedence; either way the value is
# still checked by ensure_gemma4_model before it is used.
DEFAULT_GEMMA_MODEL = os.environ.get("MODEL_NAME", "unsloth/gemma-4-E2B-it")


def ensure_gemma4_model(model_name: str) -> str:
    """Check that *model_name* is the single model this GRPO setup supports.

    The training recipe is pinned to ``unsloth/gemma-4-E2B-it``, the model
    used by the Unsloth Gemma 4 E2B RL notebook; any other id is rejected.

    Args:
        model_name: Hugging Face model id to validate.

    Returns:
        ``model_name`` unchanged, so the call can be used inline.

    Raises:
        ValueError: If ``model_name`` differs from the pinned id. The message
            includes the rejected value so a misconfigured ``MODEL_NAME`` or
            ``--model-name`` is easy to spot.
    """
    pinned_model = "unsloth/gemma-4-E2B-it"
    if model_name != pinned_model:
        raise ValueError(
            f"CyberSecurity_OWASP GRPO is pinned to {pinned_model}, "
            "matching the Unsloth Gemma 4 E2B RL notebook "
            f"(got {model_name!r})."
        )
    return model_name


def build_grpo_config():
    """Build the TRL GRPOConfig used by the Modal training pipeline.

    Run-specific settings are read from environment variables:

    - ``MODEL_NAME``: model id (default ``DEFAULT_GEMMA_MODEL``); validated by
      ``ensure_gemma4_model``.
    - ``DIFFICULTY``: integer curriculum level used in the generated run name
      (default ``0``).
    - ``OUTPUT_DIR``: output directory; when unset or empty it is derived from
      the model id.
    - ``TRACKIO_SPACE_ID``: Trackio space that receives the logged metrics.
    - ``TRACKIO_PROJECT``: set to ``CyberSecurity_OWASP-grpo`` when not already
      defined. This mutates ``os.environ`` for the current process.
    - ``RUN_NAME``: explicit run name; when unset or empty, one is built from
      the model id, difficulty and current git SHA.

    Returns:
        The configured ``trl.GRPOConfig``.

    Raises:
        ValueError: If ``MODEL_NAME`` is not the pinned model or
            ``DIFFICULTY`` is not an integer.
    """
    # Deferred import; presumably so this module stays importable without TRL.
    from trl import GRPOConfig

    model_name = ensure_gemma4_model(os.getenv("MODEL_NAME", DEFAULT_GEMMA_MODEL))

    raw_difficulty = os.getenv("DIFFICULTY", "0")
    try:
        difficulty = int(raw_difficulty)
    except ValueError as err:
        raise ValueError(
            f"DIFFICULTY must be an integer, got {raw_difficulty!r}"
        ) from err

    # An empty OUTPUT_DIR is treated like an unset one; main() likewise only
    # exports --output-dir when it is non-empty.
    output_dir = os.getenv("OUTPUT_DIR") or (
        f"CyberSecurity_OWASP-{model_name.replace('/', '-')}-grpo"
    )
    trackio_space_id = os.getenv("TRACKIO_SPACE_ID", "Humanlearning/CyberSecurity_OWASP-trackio")
    # Presumably picked up by the trackio reporter from the environment.
    os.environ.setdefault("TRACKIO_PROJECT", "CyberSecurity_OWASP-grpo")
    # Build the fallback name lazily: passing it as os.getenv's default would
    # call get_git_sha() and build_run_name() even when RUN_NAME is set.
    run_name = os.getenv("RUN_NAME") or build_run_name(
        model_name, "grpo", difficulty, git_sha=get_git_sha()
    )
    return GRPOConfig(
        output_dir=output_dir,
        report_to="trackio",
        trackio_space_id=trackio_space_id,
        run_name=run_name,
        logging_steps=1,
        save_steps=25,
        learning_rate=5e-6,
        num_train_epochs=1,
        # NOTE(review): TRL's GRPOConfig requires the generation batch
        # (per_device_train_batch_size x num_processes x steps_per_generation,
        # the latter defaulting to gradient_accumulation_steps) to be divisible
        # by num_generations. 1 x 32 is not divisible by 6 on a single process
        # unless something (e.g. an Unsloth-patched trainer) adjusts the batch
        # size -- confirm against the TRL/Unsloth versions in use.
        per_device_train_batch_size=1,
        gradient_accumulation_steps=32,
        num_generations=6,
        max_prompt_length=4096,
        max_completion_length=768,
        # vLLM runs in the training process ("colocate") with 20% of GPU memory.
        use_vllm=True,
        vllm_mode="colocate",
        vllm_gpu_memory_utilization=0.2,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        push_to_hub=False,
    )


def main() -> None:
    """Print the GRPO config for the requested run and the matching Modal command.

    No training happens here. The CLI choices are validated and exported as
    the environment variables ``build_grpo_config`` reads (``MODEL_NAME``,
    ``DIFFICULTY``, ``OUTPUT_DIR``). The resulting config is printed, followed
    by an example ``modal run`` invocation.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description=(
            "CyberSecurity_OWASP GRPO config helper."
            " Actual GRPO training is executed on Modal only."
        )
    )
    parser.add_argument(
        "--difficulty",
        type=int,
        # None means "not given": fall back to an existing DIFFICULTY env var
        # (or 0) instead of silently overriding it.
        default=None,
        help="Optional curriculum difficulty included in the generated run name.",
    )
    parser.add_argument("--model-name", default=DEFAULT_GEMMA_MODEL)
    parser.add_argument(
        "--output-dir",
        default=None,
        help="Optional GRPO output_dir override.",
    )
    args = parser.parse_args()

    os.environ["MODEL_NAME"] = ensure_gemma4_model(args.model_name)
    if args.difficulty is not None:
        # build_grpo_config reads the difficulty from the environment; without
        # this export the flag never reaches the generated run name.
        os.environ["DIFFICULTY"] = str(args.difficulty)
    if args.output_dir:
        os.environ["OUTPUT_DIR"] = args.output_dir

    config = build_grpo_config()
    # Echo the same difficulty that went into the run name.
    difficulty = os.environ.get("DIFFICULTY", "0")
    print("GRPO config (Modal execution):")
    print(config)
    print(
        "Run on Modal, for example:\n"
        "uv run --extra modal modal run scripts/modal_train_grpo.py "
        f"--model-name {args.model_name} --difficulty {difficulty}"
    )


if __name__ == "__main__":
    main()