Cyber_analyst-round1 / scripts /modal_ephemeral_train.py
Humanlearning's picture
feat: update training configuration and documentation for Modal execution, including new model integration and enhanced tracking utilities
b3ee507
raw
history blame
13.3 kB
"""Ephemeral Modal Labs launcher for CyberSecurity_OWASP training smoke runs.
Run from the repo root:
modal run scripts/modal_ephemeral_train.py --mode smoke --episodes 4
This intentionally stays separate from ``training/train_grpo.py``. It packages
the local repo into a temporary Modal app and returns compact JSON artifacts to
the local process, so the run disappears when ``modal run`` exits.
"""
from __future__ import annotations
import json
import subprocess
import time
from datetime import datetime
from pathlib import Path
from typing import Any
import modal
APP_NAME = "CyberSecurity_OWASP-ephemeral-training"
SECRET_NAME = "CyberSecurity_OWASP-secrets"
REMOTE_PROJECT = "/root/CyberSecurity_OWASP"
PROJECT_ROOT = Path(__file__).resolve().parents[1]
app = modal.App(APP_NAME)
image = (
modal.Image.debian_slim(python_version="3.11")
.apt_install("git")
.add_local_dir(
PROJECT_ROOT,
remote_path=REMOTE_PROJECT,
copy=True,
ignore=[
".git",
".venv",
".env",
".env.*",
"__pycache__",
".pytest_cache",
"outputs",
"*.pyc",
],
)
.run_commands(f"pip install -e {REMOTE_PROJECT}")
.workdir(REMOTE_PROJECT)
)
class NoopTrainer:
    """Stand-in policy that always answers with a no-op tool call.

    Lets smoke runs exercise the rollout loop deterministically without
    loading a model.
    """

    def generate_rollout_completions(self, prompts: list[str]) -> list[dict[str, Any]]:
        """Return one fresh no-op completion per prompt.

        Token ids and logprobs are empty lists because nothing is generated.
        """
        completions: list[dict[str, Any]] = []
        for _prompt in prompts:
            completions.append(
                dict(
                    text='{"tool_name":"noop","arguments":{}}',
                    prompt_ids=[],
                    completion_ids=[],
                    logprobs=[],
                )
            )
        return completions
@app.function(
    image=image,
    timeout=60 * 30,
    secrets=[modal.Secret.from_name(SECRET_NAME, required_keys=["HF_TOKEN"])],
)
def run_ephemeral_smoke(
    episodes: int = 4,
    seed_start: int = 0,
    trackio_space_id: str = "",
    trackio_project: str = "CyberSecurity_OWASP-smoke",
) -> dict[str, Any]:
    """Run baseline and scripted-oracle episodes remotely and log them to Trackio.

    For each seed in ``range(seed_start, seed_start + episodes)`` this runs:

    * a baseline episode driven by :class:`NoopTrainer` through
      ``rollout_once`` (``max_steps=5``, validation split, difficulty 0);
    * an oracle episode that submits a BOLA/IDOR finding, rewrites
      ``app/routes/invoices.py`` with an explicit tenant + owner/admin check,
      runs the visible tests and submits the fix.

    Args:
        episodes: Number of consecutive seeds to evaluate.
        seed_start: First seed.
        trackio_space_id: Trackio Space, passed through to ``trackio_run``.
        trackio_project: Trackio project name.

    Returns:
        A summary dict containing:

        * mean baseline and oracle rewards and the oracle success rate;
        * aggregated tracking metrics and trace rows;
        * the per-episode records;
        * ``oracle_patch_misses``: seeds whose vulnerable snippet was not
          found, so the oracle submitted the route file unchanged.
    """
    from datetime import timezone

    from CyberSecurity_OWASP.models import CyberSecurityOWASPAction
    from CyberSecurity_OWASP.server.CyberSecurity_OWASP_environment import (
        CybersecurityOwaspEnvironment,
    )
    from training.rollout import rollout_once
    from training.trackio_utils import (
        aggregate_episode_metrics,
        episode_record_from_state,
        log_episode_batch,
        log_trackio_metrics,
        trace_table_rows,
        trackio_run,
    )

    run_context = {
        "algo": "modal_ephemeral_smoke",
        "reward_version": "reward_v1",
        "env_version": "0.1.0",
    }

    # NOTE(review): str.replace needs an exact match, leading whitespace
    # included, against the generated app/routes/invoices.py. If the template
    # drifts, the oracle submits the file unchanged; such seeds are now
    # reported in ``oracle_patch_misses`` instead of failing silently.
    vulnerable_snippet = (
        " # BUG: this only checks that the caller is authenticated. It forgets the\n"
        " # owner/admin and tenant policy checks required by the policy graph.\n"
        " return {\"status\": 200, \"body\": invoice}\n"
    )
    patched_snippet = (
        " if invoice[\"tenant_id\"] != actor[\"tenant_id\"]:\n"
        " return {\"status\": 403, \"body\": {\"detail\": \"forbidden\"}}\n"
        " if invoice[\"owner_user_id\"] != actor[\"user_id\"] and not is_billing_admin(actor):\n"
        " return {\"status\": 403, \"body\": {\"detail\": \"forbidden\"}}\n"
        " return {\"status\": 200, \"body\": invoice}\n"
    )

    def run_baseline(seed: int) -> dict[str, Any]:
        """Play one episode with the no-op policy and return its record."""
        env = CybersecurityOwaspEnvironment()
        rollout = rollout_once(
            NoopTrainer(),
            env,
            max_steps=5,
            reset_kwargs={"seed": seed, "split": "validation", "difficulty": 0},
        )
        record = episode_record_from_state(
            env.state,
            run_context={**run_context, "base_model": "noop"},
        )
        record.update(
            {
                "reward_total": rollout.get("reward_total", 0.0),
                "success": rollout.get("success", False),
                "episode_length": rollout.get("episode_length", 0),
            }
        )
        return record

    def run_oracle(seed: int) -> tuple[dict[str, Any], bool]:
        """Play one scripted-oracle episode.

        Returns the episode record and whether the patch actually changed the
        route source.
        """
        env = CybersecurityOwaspEnvironment()
        env.reset(seed=seed, split="validation")
        # The oracle reads the environment's hidden facts to craft its finding.
        hidden = env.state.hidden_facts
        env.step(
            CyberSecurityOWASPAction(
                tool_name="submit_finding",
                arguments={
                    "summary": "BOLA/IDOR authorization bug in invoice read route.",
                    "evidence": (
                        f"user {hidden['owner_user_id']} can request invoice "
                        f"{hidden['other_invoice_id']} despite the owner/admin policy"
                    ),
                    "policy_rule": "Only owner or billing_admin in same tenant may read invoices.",
                },
            )
        )
        route_path = Path(hidden["workspace"]) / "app/routes/invoices.py"
        source = route_path.read_text(encoding="utf-8")
        fixed = source.replace(vulnerable_snippet, patched_snippet)
        env.step(
            CyberSecurityOWASPAction(
                tool_name="patch_file",
                arguments={"path": "app/routes/invoices.py", "content": fixed},
            )
        )
        env.step(CyberSecurityOWASPAction(tool_name="run_visible_tests"))
        final = env.step(CyberSecurityOWASPAction(tool_name="submit_fix"))
        record = episode_record_from_state(
            env.state,
            run_context={**run_context, "base_model": "oracle"},
            final_observation=final.model_dump(),
        )
        record.update(
            {
                "reward_total": final.reward_breakdown.get("total", 0.0),
                "success": env.state.success,
            }
        )
        return record, fixed != source

    baseline: list[dict[str, Any]] = []
    oracle: list[dict[str, Any]] = []
    oracle_patch_misses: list[int] = []
    for seed in range(seed_start, seed_start + episodes):
        baseline.append(run_baseline(seed))
        oracle_record, patched = run_oracle(seed)
        oracle.append(oracle_record)
        if not patched:
            oracle_patch_misses.append(seed)

    def mean(items: list[dict[str, Any]], key: str) -> float:
        # Missing keys count as 0; an empty list yields 0.0 instead of dividing by zero.
        return sum(float(item.get(key, 0.0)) for item in items) / max(1, len(items))

    # Timezone-aware UTC timestamp with the same format as before; avoids the
    # deprecated naive datetime.utcnow().
    run_name = f"{APP_NAME}-{datetime.now(timezone.utc).strftime('%Y%m%d-%H%M%S')}"
    episode_records = [*baseline, *oracle]
    tracking_metrics = aggregate_episode_metrics(episode_records)
    result = {
        "run_name": run_name,
        "mode": "smoke",
        "episodes": episodes,
        "seed_start": seed_start,
        "baseline_mean_reward": mean(baseline, "reward_total"),
        "oracle_mean_reward": mean(oracle, "reward_total"),
        "oracle_success_rate": mean(oracle, "success"),
        "oracle_patch_misses": oracle_patch_misses,
        "tracking_metrics": tracking_metrics,
        "tracking_trace_rows": trace_table_rows(episode_records),
        "baseline": baseline,
        "oracle": oracle,
    }
    with trackio_run(
        run_name=run_name,
        run_type="modal_ephemeral_smoke",
        project=trackio_project,
        space_id=trackio_space_id,
        config={
            "episodes": episodes,
            "seed_start": seed_start,
            "mode": "smoke",
        },
        group="smoke",
    ):
        logged_metrics = log_episode_batch(episode_records, step=0)
        log_trackio_metrics(
            {
                **logged_metrics,
                "smoke/baseline_mean_reward": result["baseline_mean_reward"],
                "smoke/oracle_mean_reward": result["oracle_mean_reward"],
                "smoke/oracle_success_rate": result["oracle_success_rate"],
                "smoke/episodes": episodes,
            },
            step=0,
        )
    return result
@app.function(image=image, timeout=60 * 10)
def run_grpo_config_check() -> str:
    """Build the GRPO training config inside the container and return it as text.

    This is a cheap check that the training dependencies import and the config
    builds in the remote image.
    """
    from training.train_grpo import build_grpo_config

    config = build_grpo_config()
    return str(config)
@app.function(
    image=image,
    timeout=60 * 10,
    secrets=[modal.Secret.from_name(SECRET_NAME, required_keys=["HF_TOKEN"])],
)
def verify_trackio_run(
    run_name: str,
    trackio_space_id: str = "Humanlearning/CyberSecurity_OWASP-trackio",
    trackio_project: str = "CyberSecurity_OWASP-smoke",
) -> dict[str, Any]:
    """Check that a run reached the Trackio Space with the required smoke items.

    Calls ``trackio get run`` and ``trackio list metrics`` up to three times,
    10 s apart. The metric list is merged into the run payload, which is then
    checked against ``REQUIRED_SMOKE_TRACKIO_ITEMS``.

    Args:
        run_name: Trackio run name (as produced by the smoke mode).
        trackio_space_id: Hugging Face Space hosting Trackio.
        trackio_project: Trackio project the run was logged under.

    Returns:
        On the first attempt where ``trackio get run`` exits 0 with valid JSON:

        * ``ok`` (true when no required item is missing);
        * ``missing_required_items``;
        * the run payload.

        Otherwise ``ok=False`` plus the truncated CLI output of the last
        attempt under ``last_result``.
    """
    import os

    from training.trackio_utils import (
        REQUIRED_SMOKE_TRACKIO_ITEMS,
        missing_required_trackio_items,
    )

    max_attempts = 3
    retry_delay_s = 10
    hf_token = os.environ["HF_TOKEN"]

    def trackio_cmd(*args: str) -> list[str]:
        """Build a ``trackio`` CLI invocation scoped to this project/run/Space."""
        # NOTE(review): the token travels on the command line, so it is visible
        # in the container's process list while the CLI runs.
        return [
            "trackio",
            *args,
            "--project",
            trackio_project,
            "--run",
            run_name,
            "--space",
            trackio_space_id,
            "--hf-token",
            hf_token,
            "--json",
        ]

    cmd = trackio_cmd("get", "run")
    metrics_cmd = trackio_cmd("list", "metrics")
    last_result: dict[str, Any] = {}
    for attempt in range(1, max_attempts + 1):
        if attempt > 1:
            # Give the Space time to sync before retrying; no pointless sleep
            # after the final attempt.
            time.sleep(retry_delay_s)
        completed = subprocess.run(cmd, capture_output=True, text=True)
        metrics_completed = subprocess.run(metrics_cmd, capture_output=True, text=True)
        last_result = {
            "attempt": attempt,
            "returncode": completed.returncode,
            "stdout": completed.stdout[-4000:],
            "stderr": completed.stderr[-4000:],
            "metrics_returncode": metrics_completed.returncode,
            "metrics_stdout": metrics_completed.stdout[-4000:],
            "metrics_stderr": metrics_completed.stderr[-4000:],
        }
        if completed.returncode != 0:
            continue
        try:
            data = json.loads(completed.stdout)
        except json.JSONDecodeError as exc:
            # Unparseable output is treated like a failed call and retried.
            last_result["json_error"] = str(exc)
            continue
        if metrics_completed.returncode == 0 and isinstance(data, dict):
            try:
                metrics_data = json.loads(metrics_completed.stdout)
            except json.JSONDecodeError:
                # Leave metrics out; any required metrics then show up in
                # missing_required_items.
                metrics_data = None
            if isinstance(metrics_data, dict) and isinstance(metrics_data.get("metrics"), list):
                data["metrics"] = metrics_data["metrics"]
        missing = missing_required_trackio_items(data, REQUIRED_SMOKE_TRACKIO_ITEMS)
        return {
            "ok": not missing,
            "trackio_space_id": trackio_space_id,
            "trackio_project": trackio_project,
            "run_name": run_name,
            "required_items": list(REQUIRED_SMOKE_TRACKIO_ITEMS),
            "missing_required_items": missing,
            "run": data,
        }
    return {
        "ok": False,
        "trackio_space_id": trackio_space_id,
        "trackio_project": trackio_project,
        "run_name": run_name,
        "last_result": last_result,
    }
@app.function(
    image=image,
    timeout=60 * 10,
    secrets=[modal.Secret.from_name(SECRET_NAME, required_keys=["HF_TOKEN"])],
)
def inspect_trackio_space(
    trackio_space_id: str = "Humanlearning/CyberSecurity_OWASP-trackio",
) -> dict[str, Any]:
    """List the projects in a Trackio Space and the runs in each project.

    Args:
        trackio_space_id: Hugging Face Space hosting Trackio.

    Returns:
        ``projects``: the raw ``trackio list projects`` result.
        ``runs_by_project``: one ``trackio list runs`` result per listed project.
        Each result carries the return code, truncated stdout/stderr and, when
        parsing succeeded, the decoded JSON.
    """
    import os

    hf_token = os.environ["HF_TOKEN"]

    def run_trackio(args: list[str]) -> dict[str, Any]:
        """Run one ``trackio`` CLI command against the Space and capture its output.

        ``"json"`` is added only when the command exits 0 and prints valid JSON;
        otherwise a parse failure is recorded under ``"json_error"``.
        """
        completed = subprocess.run(
            ["trackio", *args, "--space", trackio_space_id, "--hf-token", hf_token, "--json"],
            capture_output=True,
            text=True,
        )
        result: dict[str, Any] = {
            "returncode": completed.returncode,
            "stdout": completed.stdout[-8000:],
            "stderr": completed.stderr[-4000:],
        }
        if completed.returncode == 0:
            try:
                result["json"] = json.loads(completed.stdout)
            except json.JSONDecodeError as exc:
                result["json_error"] = str(exc)
        return result

    projects_result = run_trackio(["list", "projects"])
    projects_json = projects_result.get("json")
    # Only a {"projects": [...]} object yields projects; anything else
    # (missing, null, non-object JSON) means there is nothing to enumerate.
    projects = (projects_json.get("projects") or []) if isinstance(projects_json, dict) else []
    runs_by_project = {
        project: run_trackio(["list", "runs", "--project", project])
        for project in projects
    }
    return {
        "trackio_space_id": trackio_space_id,
        "projects": projects_result,
        "runs_by_project": runs_by_project,
    }
@app.local_entrypoint()
def main(
    mode: str = "smoke",
    episodes: int = 4,
    seed_start: int = 0,
    trackio_space_id: str = "",
    trackio_project: str = "CyberSecurity_OWASP-smoke",
    run_name: str = "",
) -> None:
    """Local ``modal run`` entry point; dispatches on ``--mode``.

    Modes:
        smoke: run ``run_ephemeral_smoke`` remotely, save the result to
            ``outputs/rollouts/<run_name>.json`` and print it.
        grpo-config: print the GRPO config built in the remote image.
        verify-trackio: check a logged run in Trackio (needs ``--run-name``).
        inspect-trackio: list the projects and runs in the Trackio Space.

    Raises:
        ValueError: for an unknown mode, or verify-trackio without a run name.
    """
    # Space used by the Trackio modes when --trackio-space-id is left empty.
    space_id = trackio_space_id or "Humanlearning/CyberSecurity_OWASP-trackio"

    if mode == "grpo-config":
        print(run_grpo_config_check.remote())
        return

    if mode == "verify-trackio":
        if not run_name:
            raise ValueError("--run-name is required for verify-trackio mode")
        report = verify_trackio_run.remote(
            run_name=run_name,
            trackio_space_id=space_id,
            trackio_project=trackio_project,
        )
        print(json.dumps(report, indent=2, sort_keys=True))
        return

    if mode == "inspect-trackio":
        report = inspect_trackio_space.remote(trackio_space_id=space_id)
        print(json.dumps(report, indent=2, sort_keys=True))
        return

    if mode != "smoke":
        raise ValueError(
            "mode must be 'smoke', 'grpo-config', 'verify-trackio', or 'inspect-trackio'"
        )

    # Smoke mode forwards the Space id unchanged, even when it is empty.
    summary = run_ephemeral_smoke.remote(
        episodes=episodes,
        seed_start=seed_start,
        trackio_space_id=trackio_space_id,
        trackio_project=trackio_project,
    )
    rollouts_dir = PROJECT_ROOT / "outputs" / "rollouts"
    rollouts_dir.mkdir(parents=True, exist_ok=True)
    saved_to = rollouts_dir / f"{summary['run_name']}.json"
    saved_to.write_text(json.dumps(summary, indent=2, sort_keys=True), encoding="utf-8")
    print(json.dumps({"saved": str(saved_to), **summary}, indent=2, sort_keys=True))