Humanlearning committed on
Commit 1544ce8 · 1 Parent(s): 60f97ab

fix: update README with SFT training configuration details, modify Modal training scripts to disable assistant-only loss and packing for compatibility, and adjust test assertions to reflect these changes
README.md CHANGED
@@ -331,11 +331,14 @@ uv run --extra modal modal run --detach scripts/modal_train_sft.py \
 `scripts/modal_train_sft.py` re-checks the JSONL reward metadata locally before
 upload and again inside Modal before loading the model. It refuses to start SFT
 unless all required curriculum difficulties are represented and the verifier
-reward metadata passes. The default SFT config trains one full epoch
-(`--max-steps -1`) with packed assistant-only loss, bf16/tf32, LoRA rank 32,
-and Modal GPU fallback `H200 -> H100 -> A100-80GB -> L40S`. A warm run for the
-300-episode dataset should usually finish in about 15-45 minutes; first image
-or model-cache builds can push that closer to 35-75 minutes.
+reward metadata passes. The default SFT config trains the full dataset
+(`--max-steps -1`) with bf16/tf32, LoRA rank 32, and Modal GPU fallback
+`H200 -> H100 -> A100-80GB -> L40S`. TRL does not support packing or
+assistant-only loss for the Gemma 4 vision-language loader, so both remain
+disabled for this model. Dataset preprocessing disables multiprocessing because
+the Gemma/Unsloth config is not pickle-safe under TRL dataset workers. A warm run
+for the 300-400 episode dataset should usually finish in about 20-60 minutes;
+first image or model-cache builds can push that closer to 45-90 minutes.

 Continue GRPO from the SFT LoRA:
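Pulled out of the diff above, a minimal sketch of the new defaults as a TRL `SFTConfig`; the `output_dir` and surrounding wiring are placeholder assumptions, not copied from the script:

```python
# Minimal sketch of the SFT defaults described above, assuming TRL's
# SFTConfig field names; output_dir is a placeholder, and the LoRA rank 32
# adapter is configured in the (omitted) Unsloth/PEFT setup, not here.
from trl import SFTConfig

sft_config = SFTConfig(
    output_dir="outputs/sft",    # placeholder path
    max_steps=-1,                # -1 = train the full dataset
    packing=False,               # unsupported for the Gemma 4 VL loader
    assistant_only_loss=False,   # likewise unsupported for this model
    dataset_num_proc=None,       # no worker processes: the Gemma/Unsloth
                                 # config cannot be pickled to subprocesses
    bf16=True,
    tf32=True,
    gradient_checkpointing=True,
)
```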
 
scripts/modal_train_grpo.py CHANGED
@@ -1088,7 +1088,11 @@ def train_cybersecurity_owasp_grpo(
     from peft import PeftModel
     from transformers import TrainerCallback
     from trl import GRPOConfig, GRPOTrainer, clone_chat_template
-    from trl.chat_template_utils import add_response_schema
+    try:
+        from trl.chat_template_utils import add_response_schema
+    except ImportError:
+        def add_response_schema(tokenizer):
+            return tokenizer

     import trackio
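The guarded import above is a standard pattern for optional APIs: newer TRL releases expose `add_response_schema`, older ones do not, and the no-op fallback keeps the module importable either way. The call site below is an illustrative assumption, not taken from the script:

```python
# Guarded import: fall back to a pass-through when the helper is missing.
try:
    from trl.chat_template_utils import add_response_schema
except ImportError:
    def add_response_schema(tokenizer):
        # Older TRL: return the tokenizer unchanged.
        return tokenizer

# Hypothetical call site: behaves identically whether or not the real
# helper was found, since the fallback is a pass-through.
# tokenizer = add_response_schema(tokenizer)
```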
 
scripts/modal_train_sft.py CHANGED
@@ -376,7 +376,11 @@ def train_cybersecurity_owasp_sft(
     from datasets import load_dataset
     from huggingface_hub import snapshot_download
     from trl import SFTConfig, SFTTrainer
-    from trl.chat_template_utils import add_response_schema
+    try:
+        from trl.chat_template_utils import add_response_schema
+    except ImportError:
+        def add_response_schema(tokenizer):
+            return tokenizer
     from unsloth import FastVisionModel

     model_name = _ensure_gemma4_model(model_name)
@@ -478,6 +482,7 @@ def train_cybersecurity_owasp_sft(
         "gradient_accumulation_steps": gradient_accumulation_steps,
         "learning_rate": learning_rate,
         "optim": "adamw_8bit",
+        "dataset_num_proc": None,
         "logging_steps": 1,
         "logging_first_step": True,
         "save_steps": max(10, max_steps) if max_steps > 0 else 100,
@@ -485,9 +490,8 @@
         "project": trackio_project,
         "trackio_space_id": trackio_space_id,
         "run_name": run_name,
-        "assistant_only_loss": True,
-        "packing": True,
-        "packing_strategy": "bfd",
+        "assistant_only_loss": False,
+        "packing": False,
         "bf16": True,
         "tf32": True,
         "gradient_checkpointing": True,
tests/test_modal_scenario_cache_static.py CHANGED
@@ -55,9 +55,10 @@ def test_modal_sft_defaults_match_300_episode_fast_handoff_plan():
     assert source.count("max_steps: int = -1") >= 2
     assert source.count("per_device_train_batch_size: int = 4") >= 2
     assert source.count("gradient_accumulation_steps: int = 4") >= 2
-    assert '"assistant_only_loss": True' in source
-    assert '"packing": True' in source
-    assert '"packing_strategy": "bfd"' in source
+    assert '"assistant_only_loss": False' in source
+    assert '"packing": False' in source
+    assert '"packing_strategy": "bfd"' not in source
+    assert '"dataset_num_proc": None' in source
     assert '"bf16": True' in source
     assert '"tf32": True' in source
     assert '"hub_strategy": "every_save"' in source
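These assertions scan the raw script text rather than executing it, so a drive-by edit to the defaults fails CI without launching a Modal run. A sketch of the pattern, where the test body and the way `source` is loaded are assumptions based on the diff:

```python
# Hypothetical sketch of the static-check pattern this test uses: read the
# training script's source and assert on literal config strings.
from pathlib import Path

def test_sft_defaults_stay_disabled():
    source = Path("scripts/modal_train_sft.py").read_text()  # assumed loading
    assert '"assistant_only_loss": False' in source
    assert '"packing": False' in source
    assert '"packing_strategy": "bfd"' not in source
    assert '"dataset_num_proc": None' in source
```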