File size: 5,964 Bytes
3807ea3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
"""Ephemeral Modal Labs launcher for CyberSecurity_OWASP training smoke runs.

Run from the repo root:

    modal run scripts/modal_ephemeral_train.py --mode smoke --episodes 4
    modal run scripts/modal_ephemeral_train.py --mode grpo-config

This intentionally stays separate from ``training/train_grpo.py``. It packages
the local repo into a temporary Modal app and returns compact JSON results to
the local process (smoke results are also written to ``outputs/rollouts/``),
so the remote app disappears when ``modal run`` exits.
"""

from __future__ import annotations

import json
from datetime import datetime
from pathlib import Path
from typing import Any

import modal


APP_NAME = "CyberSecurity_OWASP-ephemeral-training"
REMOTE_PROJECT = "/root/CyberSecurity_OWASP"
# This script lives in <repo>/scripts/, so the repo root is two levels up.
PROJECT_ROOT = Path(__file__).resolve().parent.parent

app = modal.App(APP_NAME)

# Container image, built in explicit stages:
#   1. slim Debian base with Python 3.11,
#   2. git from apt,
#   3. a copy of the local repo baked into the image (VCS metadata, virtualenv,
#      caches, bytecode and local outputs are excluded); ``copy=True`` makes the
#      files part of the image so later build steps can use them,
#   4. an editable pip install of that copy,
#   5. the copied repo as the working directory.
image = modal.Image.debian_slim(python_version="3.11")
image = image.apt_install("git")
image = image.add_local_dir(
    PROJECT_ROOT,
    remote_path=REMOTE_PROJECT,
    copy=True,
    ignore=[
        ".git",
        ".venv",
        "__pycache__",
        ".pytest_cache",
        "outputs",
        "*.pyc",
    ],
)
image = image.run_commands(f"pip install -e {REMOTE_PROJECT}")
image = image.workdir(REMOTE_PROJECT)


class NoopTrainer:
    """Deterministic placeholder policy for cheap Modal smoke runs.

    Every prompt is answered with the same ``noop`` tool call and empty token /
    logprob lists, so a smoke rollout exercises the plumbing without a model.
    """

    def generate_rollout_completions(self, prompts: list[str]) -> list[dict[str, Any]]:
        """Return one fresh ``noop`` completion dict per entry in ``prompts``."""
        completions: list[dict[str, Any]] = []
        for _prompt in prompts:
            # A new dict (and new empty lists) per prompt, so callers may
            # mutate one completion without affecting the others.
            completions.append(
                dict(
                    text='{"tool_name":"noop","arguments":{}}',
                    prompt_ids=[],
                    completion_ids=[],
                    logprobs=[],
                )
            )
        return completions


@app.function(image=image, timeout=60 * 30)
def run_ephemeral_smoke(episodes: int = 4, seed_start: int = 0) -> dict[str, Any]:
    """Run paired baseline/oracle episodes remotely and summarise their rewards.

    For every seed in ``range(seed_start, seed_start + episodes)`` two fresh
    environments are reset on the ``"validation"`` split:

    * baseline -- :class:`NoopTrainer` driven by
      ``training.rollout.rollout_once`` with ``max_steps=5``;
    * oracle -- a scripted episode that submits a BOLA/IDOR finding, rewrites
      ``app/routes/invoices.py`` with tenant and owner/admin checks, runs the
      visible tests and submits the fix.

    Args:
        episodes: Number of seeds to run. ``0`` (or a negative value) runs
            nothing, and every mean is then reported as ``0.0``.
        seed_start: First seed of the contiguous seed range.

    Returns:
        A summary dict with ``run_name`` (app name plus a UTC timestamp),
        ``mode``, ``episodes``, ``seed_start``, the baseline/oracle mean
        rewards, the oracle success rate, and the per-episode ``baseline``
        and ``oracle`` records. Each oracle record carries ``patch_applied``,
        which is ``False`` when the expected bug text was not found in the
        route source, i.e. the submitted "patch" was the unchanged file.
    """
    from datetime import timezone

    from CyberSecurity_OWASP.models import CyberSecurityOWASPAction
    from CyberSecurity_OWASP.server.CyberSecurity_OWASP_environment import (
        CybersecurityOwaspEnvironment,
    )
    from training.rollout import rollout_once

    # Exact text of the seeded bug in the invoice read route, and what the
    # oracle replaces it with. str.replace() is a silent no-op if this text
    # drifts, so each oracle record reports whether the patch took effect.
    buggy_tail = (
        "    # BUG: this only checks that the caller is authenticated. It forgets the\n"
        "    # owner/admin and tenant policy checks required by the policy graph.\n"
        "    return {\"status\": 200, \"body\": invoice}\n"
    )
    fixed_tail = (
        "    if invoice[\"tenant_id\"] != actor[\"tenant_id\"]:\n"
        "        return {\"status\": 403, \"body\": {\"detail\": \"forbidden\"}}\n"
        "    if invoice[\"owner_user_id\"] != actor[\"user_id\"] and not is_billing_admin(actor):\n"
        "        return {\"status\": 403, \"body\": {\"detail\": \"forbidden\"}}\n"
        "    return {\"status\": 200, \"body\": invoice}\n"
    )

    baseline = []
    oracle: list[dict[str, Any]] = []

    for offset in range(episodes):
        seed = seed_start + offset

        baseline_env = CybersecurityOwaspEnvironment()
        baseline_env.reset(seed=seed, split="validation")
        baseline.append(rollout_once(NoopTrainer(), baseline_env, max_steps=5))

        oracle_env = CybersecurityOwaspEnvironment()
        oracle_env.reset(seed=seed, split="validation")
        # The oracle reads ``hidden_facts`` directly to get the owner user id,
        # the other invoice id and the workspace path for this seed.
        hidden = oracle_env.state.hidden_facts
        oracle_env.step(
            CyberSecurityOWASPAction(
                tool_name="submit_finding",
                arguments={
                    "summary": "BOLA/IDOR authorization bug in invoice read route.",
                    "evidence": (
                        f"user {hidden['owner_user_id']} can request invoice "
                        f"{hidden['other_invoice_id']} despite the owner/admin policy"
                    ),
                    "policy_rule": "Only owner or billing_admin in same tenant may read invoices.",
                },
            )
        )
        route_path = Path(hidden["workspace"]) / "app/routes/invoices.py"
        source = route_path.read_text(encoding="utf-8")
        fixed = source.replace(buggy_tail, fixed_tail)
        oracle_env.step(
            CyberSecurityOWASPAction(
                tool_name="patch_file",
                arguments={"path": "app/routes/invoices.py", "content": fixed},
            )
        )
        oracle_env.step(CyberSecurityOWASPAction(tool_name="run_visible_tests"))
        final = oracle_env.step(CyberSecurityOWASPAction(tool_name="submit_fix"))
        oracle.append(
            {
                "seed": seed,
                "success": oracle_env.state.success,
                "reward_total": final.reward_breakdown.get("total", 0.0),
                "reward_breakdown": final.reward_breakdown,
                "patch_applied": fixed != source,
            }
        )

    def mean(items: list[dict[str, Any]], key: str) -> float:
        # Missing keys and ``None`` values both count as 0.0 (float(None) would
        # raise); the max() guard makes an empty list average to 0.0.
        return sum(float(item.get(key) or 0.0) for item in items) / max(1, len(items))

    # Timezone-aware replacement for the deprecated datetime.utcnow(); the
    # formatted timestamp is the same UTC wall-clock time.
    stamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
    return {
        "run_name": f"{APP_NAME}-{stamp}",
        "mode": "smoke",
        "episodes": episodes,
        "seed_start": seed_start,
        "baseline_mean_reward": mean(baseline, "reward_total"),
        "oracle_mean_reward": mean(oracle, "reward_total"),
        "oracle_success_rate": mean(oracle, "success"),
        "baseline": baseline,
        "oracle": oracle,
    }


@app.function(image=image, timeout=60 * 10)
def run_grpo_config_check() -> str:
    """Build the GRPO config inside the Modal image and return its ``str()``.

    Importing ``training.train_grpo`` and calling ``build_grpo_config()``
    remotely shows the training module loads in the container image.
    """
    from training.train_grpo import build_grpo_config as _build_config

    grpo_config = _build_config()
    return str(grpo_config)


@app.local_entrypoint()
def main(mode: str = "smoke", episodes: int = 4, seed_start: int = 0) -> None:
    """Local entry point invoked by ``modal run``.

    ``--mode smoke`` runs :func:`run_ephemeral_smoke` remotely and writes the
    summary to ``<repo>/outputs/rollouts/<run_name>.json``. It then prints the
    summary together with the saved path. ``--mode grpo-config`` prints the
    remotely built GRPO config string.

    Raises:
        ValueError: if ``mode`` is neither ``"smoke"`` nor ``"grpo-config"``.
    """
    if mode == "grpo-config":
        print(run_grpo_config_check.remote())
        return
    if mode != "smoke":
        raise ValueError("mode must be 'smoke' or 'grpo-config'")

    summary = run_ephemeral_smoke.remote(episodes=episodes, seed_start=seed_start)

    rollouts_dir = PROJECT_ROOT.joinpath("outputs", "rollouts")
    rollouts_dir.mkdir(parents=True, exist_ok=True)
    saved_to = rollouts_dir.joinpath(f"{summary['run_name']}.json")
    saved_to.write_text(json.dumps(summary, indent=2, sort_keys=True), encoding="utf-8")

    # Summary keys take precedence over "saved", matching {"saved": ..., **summary}.
    report: dict[str, Any] = {"saved": str(saved_to)}
    report.update(summary)
    print(json.dumps(report, indent=2, sort_keys=True))