Humanlearning committed on
Commit
e5fe6f5
·
1 Parent(s): 1544ce8

feat: enhance SFT training process with new tokenization method, implement custom trainer class for loss computation, and update README with GRPO launcher details for Unsloth LoRA integration

Browse files
README.md CHANGED
@@ -335,13 +335,20 @@ reward metadata passes. The default SFT config trains the full dataset
335
  (`--max-steps -1`) with bf16/tf32, LoRA rank 32, and Modal GPU fallback
336
  `H200 -> H100 -> A100-80GB -> L40S`. TRL does not support packing or
337
  assistant-only loss for the Gemma 4 vision-language loader, so both remain
338
- disabled for this model. Dataset preprocessing disables multiprocessing because
339
- the Gemma/Unsloth config is not pickle-safe under TRL dataset workers. A warm run
340
- for the 300-400 episode dataset should usually finish in about 20-60 minutes;
341
- first image or model-cache builds can push that closer to 45-90 minutes.
 
 
 
342
 
343
  Continue GRPO from the SFT LoRA:
344
 
 
 
 
 
345
  ```bash
346
  uv run --extra modal modal run --detach scripts/modal_train_grpo.py \
347
  --initial-adapter-repo-id Humanlearning/CyberSecurity_OWASP-unsloth-gemma-4-e2b-it-sft-lora \
 
335
  (`--max-steps -1`) with bf16/tf32, LoRA rank 32, and Modal GPU fallback
336
  `H200 -> H100 -> A100-80GB -> L40S`. TRL does not support packing or
337
  assistant-only loss for the Gemma 4 vision-language loader, so both remain
338
+ disabled for this model. The script pre-tokenizes the small JSONL dataset
339
+ serially before constructing `SFTTrainer`, which avoids TRL multiprocessing
340
+ around the Gemma/Unsloth config object. It also uses the base Transformers loss
341
+ path to avoid a TRL entropy-metric incompatibility with Gemma 4 lazy logits. A
342
+ warm run for the 300-400 episode dataset should usually finish in about 20-60
343
+ minutes; first image or model-cache builds can push that closer to 45-90
344
+ minutes.
345
 
346
  Continue GRPO from the SFT LoRA:
347
 
348
+ The GRPO launcher downloads the Hub adapter, attaches a matching trainable
349
+ Unsloth LoRA to Gemma 4, and then loads the adapter safetensors. This keeps the
350
+ SFT handoff compatible with Gemma 4's Unsloth linear wrappers.
351
+
352
  ```bash
353
  uv run --extra modal modal run --detach scripts/modal_train_grpo.py \
354
  --initial-adapter-repo-id Humanlearning/CyberSecurity_OWASP-unsloth-gemma-4-e2b-it-sft-lora \
scripts/modal_train_grpo.py CHANGED
@@ -1081,11 +1081,12 @@ def train_cybersecurity_owasp_grpo(
1081
  trace_log_every = max(0, int(trace_log_every))
1082
 
1083
  import torch
 
1084
  from unsloth import FastVisionModel
1085
  import transformers.utils.hub as transformers_hub
1086
  from datasets import Dataset
1087
  from huggingface_hub import snapshot_download, whoami
1088
- from peft import PeftModel
1089
  from transformers import TrainerCallback
1090
  from trl import GRPOConfig, GRPOTrainer, clone_chat_template
1091
  try:
@@ -1869,7 +1870,61 @@ def train_cybersecurity_owasp_grpo(
1869
  cache_volume.commit()
1870
  if adapter_source:
1871
  print(f"Loading initial SFT adapter for trainable GRPO continuation: {adapter_source}")
1872
- model = PeftModel.from_pretrained(model, adapter_source, is_trainable=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1873
  if hasattr(model, "print_trainable_parameters"):
1874
  model.print_trainable_parameters()
1875
  else:
 
1081
  trace_log_every = max(0, int(trace_log_every))
1082
 
1083
  import torch
1084
+ from safetensors.torch import load_file as load_safetensors_file
1085
  from unsloth import FastVisionModel
1086
  import transformers.utils.hub as transformers_hub
1087
  from datasets import Dataset
1088
  from huggingface_hub import snapshot_download, whoami
1089
+ from peft import set_peft_model_state_dict
1090
  from transformers import TrainerCallback
1091
  from trl import GRPOConfig, GRPOTrainer, clone_chat_template
1092
  try:
 
1870
  cache_volume.commit()
1871
  if adapter_source:
1872
  print(f"Loading initial SFT adapter for trainable GRPO continuation: {adapter_source}")
1873
# Rebuild a trainable Unsloth LoRA on the freshly loaded Gemma model, then
# copy the saved SFT adapter tensors into it so GRPO continues from the SFT
# state while staying compatible with Unsloth's linear wrappers.
adapter_dir = pathlib.Path(adapter_source)
adapter_config_file = adapter_dir / "adapter_config.json"
if not adapter_config_file.exists():
    raise RuntimeError(f"Initial SFT adapter config not found: {adapter_config_file}")
adapter_config = json.loads(adapter_config_file.read_text(encoding="utf-8"))
# Mirror the saved adapter's hyperparameters; fall back to launcher defaults.
adapter_rank = int(adapter_config.get("r") or lora_rank)
adapter_alpha = int(adapter_config.get("lora_alpha") or adapter_rank * 2)
default_target_modules = [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
]
adapter_target_modules = list(
    adapter_config.get("target_modules") or default_target_modules
)
print(
    "Attaching Unsloth LoRA before loading SFT weights: "
    f"rank={adapter_rank}, alpha={adapter_alpha}, targets={adapter_target_modules}"
)
model = model_api.get_peft_model(
    model,
    r=adapter_rank,
    target_modules=adapter_target_modules,
    lora_alpha=adapter_alpha,
    use_gradient_checkpointing="unsloth",
    random_state=3407,
)
adapter_weights_path = adapter_dir / "adapter_model.safetensors"
if not adapter_weights_path.exists():
    raise RuntimeError(f"Initial SFT adapter weights not found: {adapter_weights_path}")
adapter_state = load_safetensors_file(str(adapter_weights_path), device="cpu")
adapter_load_result = set_peft_model_state_dict(
    model,
    adapter_state,
    adapter_name="default",
)

def _adapter_key_subset(key_names) -> list:
    # Only LoRA/modules_to_save keys matter for the match check.
    return sorted(k for k in key_names if "lora_" in k or "modules_to_save" in k)

unexpected_adapter_keys = _adapter_key_subset(
    getattr(adapter_load_result, "unexpected_keys", [])
)
if unexpected_adapter_keys:
    raise RuntimeError(
        "Initial SFT adapter keys do not match the trainable Unsloth LoRA. "
        f"Unexpected adapter keys: {unexpected_adapter_keys[:10]}"
    )
missing_lora_keys = _adapter_key_subset(
    getattr(adapter_load_result, "missing_keys", [])
)
if missing_lora_keys:
    print(f"Missing LoRA keys while loading SFT adapter: {missing_lora_keys[:10]}")
1928
  if hasattr(model, "print_trainable_parameters"):
1929
  model.print_trainable_parameters()
1930
  else:
scripts/modal_train_sft.py CHANGED
@@ -373,8 +373,9 @@ def train_cybersecurity_owasp_sft(
373
  ) -> dict[str, Any]:
374
  import inspect
375
 
376
- from datasets import load_dataset
377
  from huggingface_hub import snapshot_download
 
378
  from trl import SFTConfig, SFTTrainer
379
  try:
380
  from trl.chat_template_utils import add_response_schema
@@ -454,6 +455,47 @@ def train_cybersecurity_owasp_sft(
454
  except Exception as exc:
455
  print(f"Tokenizer response schema add skipped: {exc!r}")
456
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
457
  model = model_api.get_peft_model(
458
  model,
459
  r=lora_rank,
@@ -522,7 +564,20 @@ def train_cybersecurity_owasp_sft(
522
  )
523
  if skipped_trainer:
524
  print(f"Skipping unsupported SFTTrainer keys: {skipped_trainer}")
525
- trainer = SFTTrainer(
 
 
 
 
 
 
 
 
 
 
 
 
 
526
  **{
527
  key: value
528
  for key, value in trainer_values.items()
 
373
  ) -> dict[str, Any]:
374
  import inspect
375
 
376
+ from datasets import Dataset, load_dataset
377
  from huggingface_hub import snapshot_download
378
+ from transformers import Trainer
379
  from trl import SFTConfig, SFTTrainer
380
  try:
381
  from trl.chat_template_utils import add_response_schema
 
455
  except Exception as exc:
456
  print(f"Tokenizer response schema add skipped: {exc!r}")
457
 
458
def _tokenize_sft_split(split_name: str, split_dataset) -> Dataset:
    """Serially tokenize one chat split into input_ids/labels rows.

    Runs in-process (no dataset workers), which keeps the Gemma/Unsloth
    config away from TRL's multiprocessing pickling.
    """
    tokenized_rows: list[dict[str, list[int]]] = []
    total_rows = len(split_dataset)
    for row_number, example in enumerate(split_dataset, start=1):
        messages = example["messages"]
        if isinstance(messages, str):
            messages = json.loads(messages)
        chat_text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False,
        )
        encode_kwargs = {
            "add_special_tokens": False,
            "truncation": True,
            "max_length": max_seq_length,
        }
        try:
            encoded = tokenizer(chat_text, **encode_kwargs)
        except TypeError:
            # Some processors require the text as a keyword argument.
            encoded = tokenizer(text=chat_text, **encode_kwargs)
        input_ids = encoded["input_ids"]
        if input_ids and isinstance(input_ids[0], list):
            # Unwrap a batched (list-of-lists) encoding.
            input_ids = input_ids[0]
        input_ids = [int(token_id) for token_id in input_ids[:max_seq_length]]
        if not input_ids:
            raise RuntimeError(f"{split_name} row {row_number} produced no tokens.")
        # Full-sequence loss: labels duplicate input_ids (no assistant-only mask).
        tokenized_rows.append({"input_ids": input_ids, "labels": list(input_ids)})
        if row_number % 500 == 0 or row_number == total_rows:
            print(f"Tokenized {split_name} rows: {row_number}/{total_rows}")
    return Dataset.from_list(tokenized_rows)

dataset["train"] = _tokenize_sft_split("train", dataset["train"])
if has_validation:
    dataset["validation"] = _tokenize_sft_split("validation", dataset["validation"])
+
499
  model = model_api.get_peft_model(
500
  model,
501
  r=lora_rank,
 
564
  )
565
  if skipped_trainer:
566
  print(f"Skipping unsupported SFTTrainer keys: {skipped_trainer}")
567
class CyberSecurityOWASPSFTTrainer(SFTTrainer):
    """SFTTrainer variant that routes loss through the base Transformers path.

    Bypasses TRL's SFTTrainer loss override and calls Trainer.compute_loss
    directly (per the README note, this avoids a TRL entropy-metric
    incompatibility with Gemma 4 lazy logits).
    """

    def compute_loss(
        self,
        model,
        inputs,
        return_outputs: bool = False,
        num_items_in_batch=None,
    ):
        # Older transformers releases lack num_items_in_batch; forward it
        # only when the base signature actually accepts it.
        base_params = inspect.signature(Trainer.compute_loss).parameters
        if "num_items_in_batch" in base_params:
            return Trainer.compute_loss(
                self,
                model,
                inputs,
                return_outputs=return_outputs,
                num_items_in_batch=num_items_in_batch,
            )
        return Trainer.compute_loss(self, model, inputs, return_outputs=return_outputs)
579
+
580
+ trainer = CyberSecurityOWASPSFTTrainer(
581
  **{
582
  key: value
583
  for key, value in trainer_values.items()
tests/test_modal_scenario_cache_static.py CHANGED
@@ -59,6 +59,10 @@ def test_modal_sft_defaults_match_300_episode_fast_handoff_plan():
59
  assert '"packing": False' in source
60
  assert '"packing_strategy": "bfd"' not in source
61
  assert '"dataset_num_proc": None' in source
 
 
 
 
62
  assert '"bf16": True' in source
63
  assert '"tf32": True' in source
64
  assert '"hub_strategy": "every_save"' in source
@@ -74,4 +78,7 @@ def test_modal_grpo_loads_sft_adapter_from_hub_as_trainable_lora():
74
  assert "initial_adapter_repo_id" in source
75
  assert "Downloading initial SFT adapter" in source
76
  assert "snapshot_download(" in source
77
- assert "PeftModel.from_pretrained(model, adapter_source, is_trainable=True)" in source
 
 
 
 
59
  assert '"packing": False' in source
60
  assert '"packing_strategy": "bfd"' not in source
61
  assert '"dataset_num_proc": None' in source
62
+ assert "Dataset.from_list(tokenized_rows)" in source
63
+ assert "tokenizer.apply_chat_template" in source
64
+ assert "class CyberSecurityOWASPSFTTrainer(SFTTrainer)" in source
65
+ assert "Trainer.compute_loss(self, model, inputs" in source
66
  assert '"bf16": True' in source
67
  assert '"tf32": True' in source
68
  assert '"hub_strategy": "every_save"' in source
 
78
  assert "initial_adapter_repo_id" in source
79
  assert "Downloading initial SFT adapter" in source
80
  assert "snapshot_download(" in source
81
+ assert "Attaching Unsloth LoRA before loading SFT weights" in source
82
+ assert "load_safetensors_file(str(adapter_weights_path), device=\"cpu\")" in source
83
+ assert "set_peft_model_state_dict(" in source
84
+ assert "unexpected_adapter_keys" in source