Humanlearning committed on
Commit
448eddd
·
1 Parent(s): b3ee507

feat: enhance training image setup and add startup notice for Modal execution, improve dependency installation process, and implement training heartbeat for monitoring

Browse files
Files changed (1) hide show
  1. scripts/modal_train_grpo.py +70 -33
scripts/modal_train_grpo.py CHANGED
@@ -43,6 +43,7 @@ PROJECT_ROOT = pathlib.Path(__file__).resolve().parents[1]
43
  PUBLIC_REPO_URL = "https://github.com/humandotlearning/CyberSecurity_OWASP.git"
44
  PUBLIC_REPO_BRANCH = "master"
45
  DEFAULT_GEMMA_MODEL = "unsloth/gemma-4-E2B-it"
 
46
 
47
 
48
  def _model_repo_slug(model_name: str) -> str:
@@ -85,6 +86,22 @@ def _configure_modal_cache_env() -> dict[str, str]:
85
  return values
86
 
87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  def _load_local_env_file() -> None:
89
  env_path = PROJECT_ROOT / ".env.local"
90
  if not env_path.exists():
@@ -136,6 +153,7 @@ def _source_mode() -> str:
136
 
137
 
138
  def _training_image() -> modal.Image:
 
139
  image = (
140
  modal.Image.from_registry(
141
  "nvidia/cuda:12.8.0-devel-ubuntu22.04",
@@ -175,7 +193,7 @@ def _training_image() -> modal.Image:
175
  repo_branch = _cli_arg_value("repo-branch", PUBLIC_REPO_BRANCH)
176
  image = image.run_commands(
177
  f"git clone --depth 1 --branch {repo_branch} {repo_url} {REMOTE_PROJECT}",
178
- f"python -m pip install -e {REMOTE_PROJECT}",
179
  )
180
  else:
181
  image = image.add_local_dir(
@@ -194,7 +212,7 @@ def _training_image() -> modal.Image:
194
  ],
195
  )
196
  image = image.run_commands(
197
- f"python -m pip install -e {REMOTE_PROJECT}",
198
  )
199
 
200
  return image.run_commands(
@@ -211,10 +229,11 @@ app = modal.App(APP_NAME)
211
  volume = modal.Volume.from_name(VOLUME_NAME, create_if_missing=True)
212
  cache_volume = modal.Volume.from_name(CACHE_VOLUME_NAME, create_if_missing=True)
213
  secrets = _modal_secrets()
 
214
 
215
 
216
  @app.function(
217
- image=_training_image(),
218
  gpu="L4",
219
  timeout=4 * 60 * 60,
220
  volumes={RUNS_DIR: volume, CACHE_DIR: cache_volume},
@@ -251,7 +270,7 @@ def check_training_imports() -> dict[str, str]:
251
 
252
 
253
  @app.function(
254
- image=_training_image(),
255
  gpu="L4",
256
  timeout=4 * 60 * 60,
257
  volumes={RUNS_DIR: volume, CACHE_DIR: cache_volume},
@@ -281,6 +300,8 @@ def train_cybersecurity_owasp_grpo(
281
  ) -> dict[str, str | int | float]:
282
  import inspect
283
  import statistics
 
 
284
 
285
  cache_env = _configure_modal_cache_env()
286
 
@@ -737,6 +758,13 @@ def train_cybersecurity_owasp_grpo(
737
  def on_train_begin(self, args, state, control, **kwargs):
738
  try:
739
  metrics = log_gpu_metrics(step=int(state.global_step or 0))
 
 
 
 
 
 
 
740
  except Exception as exc:
741
  print(f"Trackio GPU metrics initialization skipped: {exc!r}")
742
  return control
@@ -784,34 +812,6 @@ def train_cybersecurity_owasp_grpo(
784
  print(f"Triton cache: {cache_env['TRITON_CACHE_DIR']}")
785
  print(f"Hub push enabled: {push_to_hub}")
786
 
787
- trackio.init(
788
- project=trackio_project,
789
- name=run_name,
790
- group="grpo",
791
- space_id=trackio_space_id,
792
- auto_log_gpu=True,
793
- gpu_log_interval=10.0,
794
- config={
795
- "environment": "CyberSecurity_OWASP",
796
- "run_type": "modal_grpo",
797
- "model_name": model_name,
798
- "difficulty": difficulty,
799
- "split": split,
800
- "dataset_size": dataset_size,
801
- "max_steps": max_steps,
802
- "num_generations": num_generations,
803
- "max_seq_length": max_seq_length,
804
- "max_completion_length": max_completion_length,
805
- "lora_rank": lora_rank,
806
- "gpu_requested": "L4",
807
- "load_in_4bit": False,
808
- "fast_inference": False,
809
- "gradient_checkpointing": "unsloth",
810
- "optim": "adamw_8bit",
811
- },
812
- )
813
- log_gpu_metrics(step=0)
814
-
815
  expected_model_cache = _hf_model_cache_path(model_name)
816
  cache_hit = expected_model_cache.exists()
817
  print(f"Expected HF model cache path: {expected_model_cache}")
@@ -919,6 +919,7 @@ def train_cybersecurity_owasp_grpo(
919
  "max_steps": max_steps,
920
  "save_steps": max(10, max_steps),
921
  "report_to": "trackio",
 
922
  "trackio_space_id": trackio_space_id,
923
  "run_name": run_name,
924
  "output_dir": str(output_dir),
@@ -967,7 +968,30 @@ def train_cybersecurity_owasp_grpo(
967
  }
968
  )
969
  print("Starting GRPO trainer.train().")
970
- trainer.train()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
971
  print("GRPO trainer.train() complete.")
972
  if push_to_hub:
973
  print(f"Pushing LoRA adapter to Hugging Face Hub: {output_repo_id}")
@@ -1099,6 +1123,19 @@ def main(
1099
  )
1100
  print(f"Hub push enabled: {push_to_hub}")
1101
  print(f"Model cache volume: {CACHE_VOLUME_NAME}")
 
 
 
 
 
 
 
 
 
 
 
 
 
1102
 
1103
  kwargs = dict(
1104
  env_repo_id=env_repo_id,
 
43
  PUBLIC_REPO_URL = "https://github.com/humandotlearning/CyberSecurity_OWASP.git"
44
  PUBLIC_REPO_BRANCH = "master"
45
  DEFAULT_GEMMA_MODEL = "unsloth/gemma-4-E2B-it"
46
+ _IMAGE_NOTICE_PRINTED = False
47
 
48
 
49
  def _model_repo_slug(model_name: str) -> str:
 
86
  return values
87
 
88
 
89
+ def _print_image_startup_notice() -> None:
90
+ global _IMAGE_NOTICE_PRINTED
91
+ if _IMAGE_NOTICE_PRINTED:
92
+ return
93
+ _IMAGE_NOTICE_PRINTED = True
94
+ print(
95
+ "Modal startup phase 1/5: building or validating the GPU training image. "
96
+ "If this takes minutes, it is Modal image packaging/dependency cache work, "
97
+ "not model-weight download."
98
+ )
99
+ print(
100
+ "Later remote phases will print: cache hit/miss, snapshot_download progress, "
101
+ "Unsloth weight loading, GRPO heartbeat, Trackio upload, and volume commits."
102
+ )
103
+
104
+
105
  def _load_local_env_file() -> None:
106
  env_path = PROJECT_ROOT / ".env.local"
107
  if not env_path.exists():
 
153
 
154
 
155
  def _training_image() -> modal.Image:
156
+ _print_image_startup_notice()
157
  image = (
158
  modal.Image.from_registry(
159
  "nvidia/cuda:12.8.0-devel-ubuntu22.04",
 
193
  repo_branch = _cli_arg_value("repo-branch", PUBLIC_REPO_BRANCH)
194
  image = image.run_commands(
195
  f"git clone --depth 1 --branch {repo_branch} {repo_url} {REMOTE_PROJECT}",
196
+ f"python -m pip install --no-deps -e {REMOTE_PROJECT}",
197
  )
198
  else:
199
  image = image.add_local_dir(
 
212
  ],
213
  )
214
  image = image.run_commands(
215
+ f"python -m pip install --no-deps -e {REMOTE_PROJECT}",
216
  )
217
 
218
  return image.run_commands(
 
229
  volume = modal.Volume.from_name(VOLUME_NAME, create_if_missing=True)
230
  cache_volume = modal.Volume.from_name(CACHE_VOLUME_NAME, create_if_missing=True)
231
  secrets = _modal_secrets()
232
+ training_image = _training_image()
233
 
234
 
235
  @app.function(
236
+ image=training_image,
237
  gpu="L4",
238
  timeout=4 * 60 * 60,
239
  volumes={RUNS_DIR: volume, CACHE_DIR: cache_volume},
 
270
 
271
 
272
  @app.function(
273
+ image=training_image,
274
  gpu="L4",
275
  timeout=4 * 60 * 60,
276
  volumes={RUNS_DIR: volume, CACHE_DIR: cache_volume},
 
300
  ) -> dict[str, str | int | float]:
301
  import inspect
302
  import statistics
303
+ import threading
304
+ import time
305
 
306
  cache_env = _configure_modal_cache_env()
307
 
 
758
  def on_train_begin(self, args, state, control, **kwargs):
759
  try:
760
  metrics = log_gpu_metrics(step=int(state.global_step or 0))
761
+ log_trackio_metrics(
762
+ {
763
+ "system/model_cache_hit": float(cache_hit),
764
+ "system/hub_push_enabled": float(push_to_hub),
765
+ },
766
+ step=int(state.global_step or 0),
767
+ )
768
  except Exception as exc:
769
  print(f"Trackio GPU metrics initialization skipped: {exc!r}")
770
  return control
 
812
  print(f"Triton cache: {cache_env['TRITON_CACHE_DIR']}")
813
  print(f"Hub push enabled: {push_to_hub}")
814
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
815
  expected_model_cache = _hf_model_cache_path(model_name)
816
  cache_hit = expected_model_cache.exists()
817
  print(f"Expected HF model cache path: {expected_model_cache}")
 
919
  "max_steps": max_steps,
920
  "save_steps": max(10, max_steps),
921
  "report_to": "trackio",
922
+ "project": trackio_project,
923
  "trackio_space_id": trackio_space_id,
924
  "run_name": run_name,
925
  "output_dir": str(output_dir),
 
968
  }
969
  )
970
  print("Starting GRPO trainer.train().")
971
+ heartbeat_stop = threading.Event()
972
+
973
+ def _training_heartbeat() -> None:
974
+ start_time = time.monotonic()
975
+ while not heartbeat_stop.wait(30):
976
+ elapsed = int(time.monotonic() - start_time)
977
+ print(
978
+ "Training heartbeat: still inside trainer.train() "
979
+ f"after {elapsed}s. For this smoke, the slow part is usually "
980
+ f"Gemma generation/backprop on L4: {num_generations} completions "
981
+ f"up to {max_completion_length} tokens, plus Trackio upload."
982
+ )
983
+
984
+ heartbeat_thread = threading.Thread(
985
+ target=_training_heartbeat,
986
+ name="grpo-training-heartbeat",
987
+ daemon=True,
988
+ )
989
+ heartbeat_thread.start()
990
+ try:
991
+ trainer.train()
992
+ finally:
993
+ heartbeat_stop.set()
994
+ heartbeat_thread.join(timeout=2)
995
  print("GRPO trainer.train() complete.")
996
  if push_to_hub:
997
  print(f"Pushing LoRA adapter to Hugging Face Hub: {output_repo_id}")
 
1123
  )
1124
  print(f"Hub push enabled: {push_to_hub}")
1125
  print(f"Model cache volume: {CACHE_VOLUME_NAME}")
1126
+ print("Launch phases:")
1127
+ print(
1128
+ "1. Modal image build/validation: happens before remote Python logs; "
1129
+ "slow when local source or dependency layers changed."
1130
+ )
1131
+ print("2. GPU container start on one L4 and persistent volume reload.")
1132
+ print("3. Model cache check in CyberSecurity_OWASP-model-cache.")
1133
+ print("4. Cached snapshot load into GPU RAM with Unsloth progress.")
1134
+ print("5. One GRPO step, Trackio sync, and volume commit.")
1135
+ print(
1136
+ "If there is a long pause after trainer.train() starts, watch for "
1137
+ "Training heartbeat lines every 30 seconds."
1138
+ )
1139
 
1140
  kwargs = dict(
1141
  env_repo_id=env_repo_id,