shank committed on
Commit
c325ad7
·
1 Parent(s): 5eea2dd

Replace unsloth with bitsandbytes+peft: fixes CUDA driver incompatibility on HF A100

Browse files
Files changed (1) hide show
  1. training/train_grpo.py +34 -19
training/train_grpo.py CHANGED
@@ -37,18 +37,19 @@ parser.add_argument("--max_steps", type=int, default=500)
37
  args = parser.parse_args()
38
 
39
  # ── Install dependencies (for Colab/HF Spaces) ───────────────────────────────
40
- # If running locally with venv, comment these out
41
  if os.environ.get("COLAB_RELEASE_TAG") or os.environ.get("SPACE_ID"):
42
- os.system("pip install -q unsloth trl wandb datasets")
43
 
44
  # ── GPU/training imports (skipped in --test-local mode) ───────────────────────
45
  if not args.test_local:
46
  import torch
47
  import wandb
48
  from datasets import Dataset
49
- from unsloth import FastLanguageModel
 
 
 
50
  from trl import GRPOTrainer, GRPOConfig
51
- from transformers import TrainerCallback
52
 
53
  sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
54
  from server.reward_calculator import DebugRewardCalculator
@@ -234,23 +235,37 @@ if args.test_local:
234
 
235
  # ── Load model ────────────────────────────────────────────────────────────────
236
  print(f"Loading {MODEL_NAME}...")
237
- model, tokenizer = FastLanguageModel.from_pretrained(
238
- model_name=MODEL_NAME,
239
- max_seq_length=4096,
240
  load_in_4bit=True,
241
- dtype=None,
 
 
 
 
 
 
 
 
 
 
 
 
 
242
  )
243
- model = FastLanguageModel.get_peft_model(
244
- model,
 
245
  r=16,
 
246
  target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
247
  "gate_proj", "up_proj", "down_proj"],
248
- lora_alpha=16,
249
- lora_dropout=0,
250
  bias="none",
251
- use_gradient_checkpointing="unsloth",
252
- random_state=42,
253
  )
 
 
 
254
  print(f"Trainable params: {model.num_parameters(only_trainable=True):,}")
255
 
256
  # ── Runtime device selection ──────────────────────────────────────────────────
@@ -333,7 +348,7 @@ def reward_fn(completions: list[str], prompts: list[str], **kwargs) -> list[floa
333
  # ── Baseline evaluation (run BEFORE training) ─────────────────────────────────
334
  def run_baseline(n: int = 20) -> dict:
335
  print("\nRunning baseline evaluation on UNTRAINED model...")
336
- FastLanguageModel.for_inference(model)
337
  bugs = load_bugs(1)[:n]
338
  rewards = []
339
  solved = 0
@@ -357,7 +372,7 @@ def run_baseline(n: int = 20) -> dict:
357
  return result
358
 
359
  baseline = run_baseline()
360
- FastLanguageModel.for_training(model)
361
 
362
  # ── Build initial dataset ─────────────────────────────────────────────────────
363
  def make_dataset(step: int) -> Dataset:
@@ -407,7 +422,7 @@ print(f"Baseline solve rate: {baseline['solve_rate']:.1%} — target: >60% after
407
  trainer.train(resume_from_checkpoint=args.resume)
408
 
409
  # ── Post-training evaluation ──────────────────────────────────────────────────
410
- FastLanguageModel.for_inference(model)
411
  bugs = load_bugs(1)[:20]
412
  post_rewards = []
413
  post_solved = 0
@@ -439,6 +454,6 @@ model.save_pretrained("./final_model")
439
  tokenizer.save_pretrained("./final_model")
440
  HF_TOKEN = os.environ.get("HF_TOKEN")
441
  if HF_TOKEN and not args.test:
442
- model.push_to_hub(HF_REPO, token=HF_TOKEN)
443
- tokenizer.push_to_hub(HF_REPO, token=HF_TOKEN)
444
  print(f"Pushed to https://huggingface.co/{HF_REPO}")
 
37
  args = parser.parse_args()
38
 
39
  # ── Install dependencies (for Colab/HF Spaces) ───────────────────────────────
 
40
  if os.environ.get("COLAB_RELEASE_TAG") or os.environ.get("SPACE_ID"):
41
+ os.system("pip install -q trl wandb datasets bitsandbytes peft transformers accelerate")
42
 
43
  # ── GPU/training imports (skipped in --test-local mode) ───────────────────────
44
  if not args.test_local:
45
  import torch
46
  import wandb
47
  from datasets import Dataset
48
+ from transformers import (
49
+ AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainerCallback
50
+ )
51
+ from peft import get_peft_model, LoraConfig, TaskType
52
  from trl import GRPOTrainer, GRPOConfig
 
53
 
54
  sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
55
  from server.reward_calculator import DebugRewardCalculator
 
235
 
236
  # ── Load model ────────────────────────────────────────────────────────────────
237
  print(f"Loading {MODEL_NAME}...")
238
+ bnb_config = BitsAndBytesConfig(
 
 
239
  load_in_4bit=True,
240
+ bnb_4bit_quant_type="nf4",
241
+ bnb_4bit_compute_dtype=torch.bfloat16,
242
+ bnb_4bit_use_double_quant=True,
243
+ )
244
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
245
+ tokenizer.pad_token = tokenizer.eos_token
246
+ tokenizer.padding_side = "left"
247
+
248
+ model = AutoModelForCausalLM.from_pretrained(
249
+ MODEL_NAME,
250
+ quantization_config=bnb_config,
251
+ device_map="auto",
252
+ trust_remote_code=True,
253
+ torch_dtype=torch.bfloat16,
254
  )
255
+ model.config.use_cache = False
256
+
257
+ lora_config = LoraConfig(
258
  r=16,
259
+ lora_alpha=16,
260
  target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
261
  "gate_proj", "up_proj", "down_proj"],
262
+ lora_dropout=0.0,
 
263
  bias="none",
264
+ task_type=TaskType.CAUSAL_LM,
 
265
  )
266
+ model = get_peft_model(model, lora_config)
267
+ model.enable_input_require_grads()
268
+ model.gradient_checkpointing_enable()
269
  print(f"Trainable params: {model.num_parameters(only_trainable=True):,}")
270
 
271
  # ── Runtime device selection ──────────────────────────────────────────────────
 
348
  # ── Baseline evaluation (run BEFORE training) ─────────────────────────────────
349
  def run_baseline(n: int = 20) -> dict:
350
  print("\nRunning baseline evaluation on UNTRAINED model...")
351
+ model.eval()
352
  bugs = load_bugs(1)[:n]
353
  rewards = []
354
  solved = 0
 
372
  return result
373
 
374
  baseline = run_baseline()
375
+ model.train()
376
 
377
  # ── Build initial dataset ─────────────────────────────────────────────────────
378
  def make_dataset(step: int) -> Dataset:
 
422
  trainer.train(resume_from_checkpoint=args.resume)
423
 
424
  # ── Post-training evaluation ──────────────────────────────────────────────────
425
+ model.eval()
426
  bugs = load_bugs(1)[:20]
427
  post_rewards = []
428
  post_solved = 0
 
454
  tokenizer.save_pretrained("./final_model")
455
  HF_TOKEN = os.environ.get("HF_TOKEN")
456
  if HF_TOKEN and not args.test:
457
+ model.push_to_hub(HF_REPO, token=HF_TOKEN, private=True)
458
+ tokenizer.push_to_hub(HF_REPO, token=HF_TOKEN, private=True)
459
  print(f"Pushed to https://huggingface.co/{HF_REPO}")