Spaces:
Sleeping
Sleeping
Commit ·
1544ce8
1
Parent(s): 60f97ab
fix: update README with SFT training configuration details, modify modal training scripts to disable assistant-only loss and packing for compatibility, and adjust test assertions to reflect these changes
Browse files- README.md +8 -5
- scripts/modal_train_grpo.py +5 -1
- scripts/modal_train_sft.py +8 -4
- tests/test_modal_scenario_cache_static.py +4 -3
README.md
CHANGED
|
@@ -331,11 +331,14 @@ uv run --extra modal modal run --detach scripts/modal_train_sft.py \
|
|
| 331 |
`scripts/modal_train_sft.py` re-checks the JSONL reward metadata locally before
|
| 332 |
upload and again inside Modal before loading the model. It refuses to start SFT
|
| 333 |
unless all required curriculum difficulties are represented and the verifier
|
| 334 |
-
reward metadata passes. The default SFT config trains
|
| 335 |
-
(`--max-steps -1`) with
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
|
|
|
|
|
|
|
|
|
| 339 |
|
| 340 |
Continue GRPO from the SFT LoRA:
|
| 341 |
|
|
|
|
| 331 |
`scripts/modal_train_sft.py` re-checks the JSONL reward metadata locally before
|
| 332 |
upload and again inside Modal before loading the model. It refuses to start SFT
|
| 333 |
unless all required curriculum difficulties are represented and the verifier
|
| 334 |
+
reward metadata passes. The default SFT config trains the full dataset
|
| 335 |
+
(`--max-steps -1`) with bf16/tf32, LoRA rank 32, and Modal GPU fallback
|
| 336 |
+
`H200 -> H100 -> A100-80GB -> L40S`. TRL does not support packing or
|
| 337 |
+
assistant-only loss for the Gemma 4 vision-language loader, so both remain
|
| 338 |
+
disabled for this model. Dataset preprocessing disables multiprocessing because
|
| 339 |
+
the Gemma/Unsloth config is not pickle-safe under TRL dataset workers. A warm run
|
| 340 |
+
for the 300-400 episode dataset should usually finish in about 20-60 minutes;
|
| 341 |
+
first image or model-cache builds can push that closer to 45-90 minutes.
|
| 342 |
|
| 343 |
Continue GRPO from the SFT LoRA:
|
| 344 |
|
scripts/modal_train_grpo.py
CHANGED
|
@@ -1088,7 +1088,11 @@ def train_cybersecurity_owasp_grpo(
|
|
| 1088 |
from peft import PeftModel
|
| 1089 |
from transformers import TrainerCallback
|
| 1090 |
from trl import GRPOConfig, GRPOTrainer, clone_chat_template
|
| 1091 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1092 |
|
| 1093 |
import trackio
|
| 1094 |
|
|
|
|
| 1088 |
from peft import PeftModel
|
| 1089 |
from transformers import TrainerCallback
|
| 1090 |
from trl import GRPOConfig, GRPOTrainer, clone_chat_template
|
| 1091 |
+
try:
|
| 1092 |
+
from trl.chat_template_utils import add_response_schema
|
| 1093 |
+
except ImportError:
|
| 1094 |
+
def add_response_schema(tokenizer):
|
| 1095 |
+
return tokenizer
|
| 1096 |
|
| 1097 |
import trackio
|
| 1098 |
|
scripts/modal_train_sft.py
CHANGED
|
@@ -376,7 +376,11 @@ def train_cybersecurity_owasp_sft(
|
|
| 376 |
from datasets import load_dataset
|
| 377 |
from huggingface_hub import snapshot_download
|
| 378 |
from trl import SFTConfig, SFTTrainer
|
| 379 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 380 |
from unsloth import FastVisionModel
|
| 381 |
|
| 382 |
model_name = _ensure_gemma4_model(model_name)
|
|
@@ -478,6 +482,7 @@ def train_cybersecurity_owasp_sft(
|
|
| 478 |
"gradient_accumulation_steps": gradient_accumulation_steps,
|
| 479 |
"learning_rate": learning_rate,
|
| 480 |
"optim": "adamw_8bit",
|
|
|
|
| 481 |
"logging_steps": 1,
|
| 482 |
"logging_first_step": True,
|
| 483 |
"save_steps": max(10, max_steps) if max_steps > 0 else 100,
|
|
@@ -485,9 +490,8 @@ def train_cybersecurity_owasp_sft(
|
|
| 485 |
"project": trackio_project,
|
| 486 |
"trackio_space_id": trackio_space_id,
|
| 487 |
"run_name": run_name,
|
| 488 |
-
"assistant_only_loss":
|
| 489 |
-
"packing":
|
| 490 |
-
"packing_strategy": "bfd",
|
| 491 |
"bf16": True,
|
| 492 |
"tf32": True,
|
| 493 |
"gradient_checkpointing": True,
|
|
|
|
| 376 |
from datasets import load_dataset
|
| 377 |
from huggingface_hub import snapshot_download
|
| 378 |
from trl import SFTConfig, SFTTrainer
|
| 379 |
+
try:
|
| 380 |
+
from trl.chat_template_utils import add_response_schema
|
| 381 |
+
except ImportError:
|
| 382 |
+
def add_response_schema(tokenizer):
|
| 383 |
+
return tokenizer
|
| 384 |
from unsloth import FastVisionModel
|
| 385 |
|
| 386 |
model_name = _ensure_gemma4_model(model_name)
|
|
|
|
| 482 |
"gradient_accumulation_steps": gradient_accumulation_steps,
|
| 483 |
"learning_rate": learning_rate,
|
| 484 |
"optim": "adamw_8bit",
|
| 485 |
+
"dataset_num_proc": None,
|
| 486 |
"logging_steps": 1,
|
| 487 |
"logging_first_step": True,
|
| 488 |
"save_steps": max(10, max_steps) if max_steps > 0 else 100,
|
|
|
|
| 490 |
"project": trackio_project,
|
| 491 |
"trackio_space_id": trackio_space_id,
|
| 492 |
"run_name": run_name,
|
| 493 |
+
"assistant_only_loss": False,
|
| 494 |
+
"packing": False,
|
|
|
|
| 495 |
"bf16": True,
|
| 496 |
"tf32": True,
|
| 497 |
"gradient_checkpointing": True,
|
tests/test_modal_scenario_cache_static.py
CHANGED
|
@@ -55,9 +55,10 @@ def test_modal_sft_defaults_match_300_episode_fast_handoff_plan():
|
|
| 55 |
assert source.count("max_steps: int = -1") >= 2
|
| 56 |
assert source.count("per_device_train_batch_size: int = 4") >= 2
|
| 57 |
assert source.count("gradient_accumulation_steps: int = 4") >= 2
|
| 58 |
-
assert '"assistant_only_loss": True' in source
|
| 59 |
-
assert '"packing": True' in source
|
| 60 |
-
assert '"packing_strategy": "bfd"' in source
|
|
|
|
| 61 |
assert '"bf16": True' in source
|
| 62 |
assert '"tf32": True' in source
|
| 63 |
assert '"hub_strategy": "every_save"' in source
|
|
|
|
| 55 |
assert source.count("max_steps: int = -1") >= 2
|
| 56 |
assert source.count("per_device_train_batch_size: int = 4") >= 2
|
| 57 |
assert source.count("gradient_accumulation_steps: int = 4") >= 2
|
| 58 |
+
assert '"assistant_only_loss": False' in source
|
| 59 |
+
assert '"packing": False' in source
|
| 60 |
+
assert '"packing_strategy": "bfd"' not in source
|
| 61 |
+
assert '"dataset_num_proc": None' in source
|
| 62 |
assert '"bf16": True' in source
|
| 63 |
assert '"tf32": True' in source
|
| 64 |
assert '"hub_strategy": "every_save"' in source
|