Spaces:
Sleeping
Sleeping
Commit ·
448eddd
1
Parent(s): b3ee507
feat: enhance training image setup and add startup notice for Modal execution, improve dependency installation process, and implement training heartbeat for monitoring
Browse files - scripts/modal_train_grpo.py +70 -33
scripts/modal_train_grpo.py
CHANGED
|
@@ -43,6 +43,7 @@ PROJECT_ROOT = pathlib.Path(__file__).resolve().parents[1]
|
|
| 43 |
PUBLIC_REPO_URL = "https://github.com/humandotlearning/CyberSecurity_OWASP.git"
|
| 44 |
PUBLIC_REPO_BRANCH = "master"
|
| 45 |
DEFAULT_GEMMA_MODEL = "unsloth/gemma-4-E2B-it"
|
|
|
|
| 46 |
|
| 47 |
|
| 48 |
def _model_repo_slug(model_name: str) -> str:
|
|
@@ -85,6 +86,22 @@ def _configure_modal_cache_env() -> dict[str, str]:
|
|
| 85 |
return values
|
| 86 |
|
| 87 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
def _load_local_env_file() -> None:
|
| 89 |
env_path = PROJECT_ROOT / ".env.local"
|
| 90 |
if not env_path.exists():
|
|
@@ -136,6 +153,7 @@ def _source_mode() -> str:
|
|
| 136 |
|
| 137 |
|
| 138 |
def _training_image() -> modal.Image:
|
|
|
|
| 139 |
image = (
|
| 140 |
modal.Image.from_registry(
|
| 141 |
"nvidia/cuda:12.8.0-devel-ubuntu22.04",
|
|
@@ -175,7 +193,7 @@ def _training_image() -> modal.Image:
|
|
| 175 |
repo_branch = _cli_arg_value("repo-branch", PUBLIC_REPO_BRANCH)
|
| 176 |
image = image.run_commands(
|
| 177 |
f"git clone --depth 1 --branch {repo_branch} {repo_url} {REMOTE_PROJECT}",
|
| 178 |
-
f"python -m pip install -e {REMOTE_PROJECT}",
|
| 179 |
)
|
| 180 |
else:
|
| 181 |
image = image.add_local_dir(
|
|
@@ -194,7 +212,7 @@ def _training_image() -> modal.Image:
|
|
| 194 |
],
|
| 195 |
)
|
| 196 |
image = image.run_commands(
|
| 197 |
-
f"python -m pip install -e {REMOTE_PROJECT}",
|
| 198 |
)
|
| 199 |
|
| 200 |
return image.run_commands(
|
|
@@ -211,10 +229,11 @@ app = modal.App(APP_NAME)
|
|
| 211 |
volume = modal.Volume.from_name(VOLUME_NAME, create_if_missing=True)
|
| 212 |
cache_volume = modal.Volume.from_name(CACHE_VOLUME_NAME, create_if_missing=True)
|
| 213 |
secrets = _modal_secrets()
|
|
|
|
| 214 |
|
| 215 |
|
| 216 |
@app.function(
|
| 217 |
-
image=
|
| 218 |
gpu="L4",
|
| 219 |
timeout=4 * 60 * 60,
|
| 220 |
volumes={RUNS_DIR: volume, CACHE_DIR: cache_volume},
|
|
@@ -251,7 +270,7 @@ def check_training_imports() -> dict[str, str]:
|
|
| 251 |
|
| 252 |
|
| 253 |
@app.function(
|
| 254 |
-
image=
|
| 255 |
gpu="L4",
|
| 256 |
timeout=4 * 60 * 60,
|
| 257 |
volumes={RUNS_DIR: volume, CACHE_DIR: cache_volume},
|
|
@@ -281,6 +300,8 @@ def train_cybersecurity_owasp_grpo(
|
|
| 281 |
) -> dict[str, str | int | float]:
|
| 282 |
import inspect
|
| 283 |
import statistics
|
|
|
|
|
|
|
| 284 |
|
| 285 |
cache_env = _configure_modal_cache_env()
|
| 286 |
|
|
@@ -737,6 +758,13 @@ def train_cybersecurity_owasp_grpo(
|
|
| 737 |
def on_train_begin(self, args, state, control, **kwargs):
|
| 738 |
try:
|
| 739 |
metrics = log_gpu_metrics(step=int(state.global_step or 0))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 740 |
except Exception as exc:
|
| 741 |
print(f"Trackio GPU metrics initialization skipped: {exc!r}")
|
| 742 |
return control
|
|
@@ -784,34 +812,6 @@ def train_cybersecurity_owasp_grpo(
|
|
| 784 |
print(f"Triton cache: {cache_env['TRITON_CACHE_DIR']}")
|
| 785 |
print(f"Hub push enabled: {push_to_hub}")
|
| 786 |
|
| 787 |
-
trackio.init(
|
| 788 |
-
project=trackio_project,
|
| 789 |
-
name=run_name,
|
| 790 |
-
group="grpo",
|
| 791 |
-
space_id=trackio_space_id,
|
| 792 |
-
auto_log_gpu=True,
|
| 793 |
-
gpu_log_interval=10.0,
|
| 794 |
-
config={
|
| 795 |
-
"environment": "CyberSecurity_OWASP",
|
| 796 |
-
"run_type": "modal_grpo",
|
| 797 |
-
"model_name": model_name,
|
| 798 |
-
"difficulty": difficulty,
|
| 799 |
-
"split": split,
|
| 800 |
-
"dataset_size": dataset_size,
|
| 801 |
-
"max_steps": max_steps,
|
| 802 |
-
"num_generations": num_generations,
|
| 803 |
-
"max_seq_length": max_seq_length,
|
| 804 |
-
"max_completion_length": max_completion_length,
|
| 805 |
-
"lora_rank": lora_rank,
|
| 806 |
-
"gpu_requested": "L4",
|
| 807 |
-
"load_in_4bit": False,
|
| 808 |
-
"fast_inference": False,
|
| 809 |
-
"gradient_checkpointing": "unsloth",
|
| 810 |
-
"optim": "adamw_8bit",
|
| 811 |
-
},
|
| 812 |
-
)
|
| 813 |
-
log_gpu_metrics(step=0)
|
| 814 |
-
|
| 815 |
expected_model_cache = _hf_model_cache_path(model_name)
|
| 816 |
cache_hit = expected_model_cache.exists()
|
| 817 |
print(f"Expected HF model cache path: {expected_model_cache}")
|
|
@@ -919,6 +919,7 @@ def train_cybersecurity_owasp_grpo(
|
|
| 919 |
"max_steps": max_steps,
|
| 920 |
"save_steps": max(10, max_steps),
|
| 921 |
"report_to": "trackio",
|
|
|
|
| 922 |
"trackio_space_id": trackio_space_id,
|
| 923 |
"run_name": run_name,
|
| 924 |
"output_dir": str(output_dir),
|
|
@@ -967,7 +968,30 @@ def train_cybersecurity_owasp_grpo(
|
|
| 967 |
}
|
| 968 |
)
|
| 969 |
print("Starting GRPO trainer.train().")
|
| 970 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 971 |
print("GRPO trainer.train() complete.")
|
| 972 |
if push_to_hub:
|
| 973 |
print(f"Pushing LoRA adapter to Hugging Face Hub: {output_repo_id}")
|
|
@@ -1099,6 +1123,19 @@ def main(
|
|
| 1099 |
)
|
| 1100 |
print(f"Hub push enabled: {push_to_hub}")
|
| 1101 |
print(f"Model cache volume: {CACHE_VOLUME_NAME}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1102 |
|
| 1103 |
kwargs = dict(
|
| 1104 |
env_repo_id=env_repo_id,
|
|
|
|
| 43 |
PUBLIC_REPO_URL = "https://github.com/humandotlearning/CyberSecurity_OWASP.git"
|
| 44 |
PUBLIC_REPO_BRANCH = "master"
|
| 45 |
DEFAULT_GEMMA_MODEL = "unsloth/gemma-4-E2B-it"
|
| 46 |
+
_IMAGE_NOTICE_PRINTED = False
|
| 47 |
|
| 48 |
|
| 49 |
def _model_repo_slug(model_name: str) -> str:
|
|
|
|
| 86 |
return values
|
| 87 |
|
| 88 |
|
| 89 |
+
def _print_image_startup_notice() -> None:
    """Print a one-time notice describing the Modal image build phase.

    The notice is emitted at most once per process: the module-level
    ``_IMAGE_NOTICE_PRINTED`` flag is set on the first call and every later
    call returns without printing.

    Both messages are flushed immediately. The notice exists to explain a
    potentially long silent pause, so it must reach the terminal or log
    before that pause begins, even when stdout is piped or redirected and
    would otherwise be block-buffered.
    """
    global _IMAGE_NOTICE_PRINTED
    if _IMAGE_NOTICE_PRINTED:
        return
    # Mark as printed before emitting anything so repeat calls stay silent.
    _IMAGE_NOTICE_PRINTED = True
    print(
        "Modal startup phase 1/5: building or validating the GPU training image. "
        "If this takes minutes, it is Modal image packaging/dependency cache work, "
        "not model-weight download.",
        flush=True,
    )
    print(
        "Later remote phases will print: cache hit/miss, snapshot_download progress, "
        "Unsloth weight loading, GRPO heartbeat, Trackio upload, and volume commits.",
        flush=True,
    )
|
| 103 |
+
|
| 104 |
+
|
| 105 |
def _load_local_env_file() -> None:
|
| 106 |
env_path = PROJECT_ROOT / ".env.local"
|
| 107 |
if not env_path.exists():
|
|
|
|
| 153 |
|
| 154 |
|
| 155 |
def _training_image() -> modal.Image:
|
| 156 |
+
_print_image_startup_notice()
|
| 157 |
image = (
|
| 158 |
modal.Image.from_registry(
|
| 159 |
"nvidia/cuda:12.8.0-devel-ubuntu22.04",
|
|
|
|
| 193 |
repo_branch = _cli_arg_value("repo-branch", PUBLIC_REPO_BRANCH)
|
| 194 |
image = image.run_commands(
|
| 195 |
f"git clone --depth 1 --branch {repo_branch} {repo_url} {REMOTE_PROJECT}",
|
| 196 |
+
f"python -m pip install --no-deps -e {REMOTE_PROJECT}",
|
| 197 |
)
|
| 198 |
else:
|
| 199 |
image = image.add_local_dir(
|
|
|
|
| 212 |
],
|
| 213 |
)
|
| 214 |
image = image.run_commands(
|
| 215 |
+
f"python -m pip install --no-deps -e {REMOTE_PROJECT}",
|
| 216 |
)
|
| 217 |
|
| 218 |
return image.run_commands(
|
|
|
|
| 229 |
volume = modal.Volume.from_name(VOLUME_NAME, create_if_missing=True)
|
| 230 |
cache_volume = modal.Volume.from_name(CACHE_VOLUME_NAME, create_if_missing=True)
|
| 231 |
secrets = _modal_secrets()
|
| 232 |
+
training_image = _training_image()
|
| 233 |
|
| 234 |
|
| 235 |
@app.function(
|
| 236 |
+
image=training_image,
|
| 237 |
gpu="L4",
|
| 238 |
timeout=4 * 60 * 60,
|
| 239 |
volumes={RUNS_DIR: volume, CACHE_DIR: cache_volume},
|
|
|
|
| 270 |
|
| 271 |
|
| 272 |
@app.function(
|
| 273 |
+
image=training_image,
|
| 274 |
gpu="L4",
|
| 275 |
timeout=4 * 60 * 60,
|
| 276 |
volumes={RUNS_DIR: volume, CACHE_DIR: cache_volume},
|
|
|
|
| 300 |
) -> dict[str, str | int | float]:
|
| 301 |
import inspect
|
| 302 |
import statistics
|
| 303 |
+
import threading
|
| 304 |
+
import time
|
| 305 |
|
| 306 |
cache_env = _configure_modal_cache_env()
|
| 307 |
|
|
|
|
| 758 |
def on_train_begin(self, args, state, control, **kwargs):
|
| 759 |
try:
|
| 760 |
metrics = log_gpu_metrics(step=int(state.global_step or 0))
|
| 761 |
+
log_trackio_metrics(
|
| 762 |
+
{
|
| 763 |
+
"system/model_cache_hit": float(cache_hit),
|
| 764 |
+
"system/hub_push_enabled": float(push_to_hub),
|
| 765 |
+
},
|
| 766 |
+
step=int(state.global_step or 0),
|
| 767 |
+
)
|
| 768 |
except Exception as exc:
|
| 769 |
print(f"Trackio GPU metrics initialization skipped: {exc!r}")
|
| 770 |
return control
|
|
|
|
| 812 |
print(f"Triton cache: {cache_env['TRITON_CACHE_DIR']}")
|
| 813 |
print(f"Hub push enabled: {push_to_hub}")
|
| 814 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 815 |
expected_model_cache = _hf_model_cache_path(model_name)
|
| 816 |
cache_hit = expected_model_cache.exists()
|
| 817 |
print(f"Expected HF model cache path: {expected_model_cache}")
|
|
|
|
| 919 |
"max_steps": max_steps,
|
| 920 |
"save_steps": max(10, max_steps),
|
| 921 |
"report_to": "trackio",
|
| 922 |
+
"project": trackio_project,
|
| 923 |
"trackio_space_id": trackio_space_id,
|
| 924 |
"run_name": run_name,
|
| 925 |
"output_dir": str(output_dir),
|
|
|
|
| 968 |
}
|
| 969 |
)
|
| 970 |
print("Starting GRPO trainer.train().")
|
| 971 |
+
heartbeat_stop = threading.Event()
|
| 972 |
+
|
| 973 |
+
def _training_heartbeat() -> None:
|
| 974 |
+
start_time = time.monotonic()
|
| 975 |
+
while not heartbeat_stop.wait(30):
|
| 976 |
+
elapsed = int(time.monotonic() - start_time)
|
| 977 |
+
print(
|
| 978 |
+
"Training heartbeat: still inside trainer.train() "
|
| 979 |
+
f"after {elapsed}s. For this smoke, the slow part is usually "
|
| 980 |
+
f"Gemma generation/backprop on L4: {num_generations} completions "
|
| 981 |
+
f"up to {max_completion_length} tokens, plus Trackio upload."
|
| 982 |
+
)
|
| 983 |
+
|
| 984 |
+
heartbeat_thread = threading.Thread(
|
| 985 |
+
target=_training_heartbeat,
|
| 986 |
+
name="grpo-training-heartbeat",
|
| 987 |
+
daemon=True,
|
| 988 |
+
)
|
| 989 |
+
heartbeat_thread.start()
|
| 990 |
+
try:
|
| 991 |
+
trainer.train()
|
| 992 |
+
finally:
|
| 993 |
+
heartbeat_stop.set()
|
| 994 |
+
heartbeat_thread.join(timeout=2)
|
| 995 |
print("GRPO trainer.train() complete.")
|
| 996 |
if push_to_hub:
|
| 997 |
print(f"Pushing LoRA adapter to Hugging Face Hub: {output_repo_id}")
|
|
|
|
| 1123 |
)
|
| 1124 |
print(f"Hub push enabled: {push_to_hub}")
|
| 1125 |
print(f"Model cache volume: {CACHE_VOLUME_NAME}")
|
| 1126 |
+
print("Launch phases:")
|
| 1127 |
+
print(
|
| 1128 |
+
"1. Modal image build/validation: happens before remote Python logs; "
|
| 1129 |
+
"slow when local source or dependency layers changed."
|
| 1130 |
+
)
|
| 1131 |
+
print("2. GPU container start on one L4 and persistent volume reload.")
|
| 1132 |
+
print("3. Model cache check in CyberSecurity_OWASP-model-cache.")
|
| 1133 |
+
print("4. Cached snapshot load into GPU RAM with Unsloth progress.")
|
| 1134 |
+
print("5. One GRPO step, Trackio sync, and volume commit.")
|
| 1135 |
+
print(
|
| 1136 |
+
"If there is a long pause after trainer.train() starts, watch for "
|
| 1137 |
+
"Training heartbeat lines every 30 seconds."
|
| 1138 |
+
)
|
| 1139 |
|
| 1140 |
kwargs = dict(
|
| 1141 |
env_repo_id=env_repo_id,
|