Don Rishabh Claude Opus 4.7 (1M context) commited on
Commit
1da121e
·
1 Parent(s): e424cfe

Revert agent loading to TRL + PEFT (Unsloth collides with frozen target)

Browse files

Unsloth patches Qwen2Attention at import time and installs apply_qkv
only on instances created by FastLanguageModel.from_pretrained. Our
env loads the target model (also Qwen2) via vanilla AutoModelForCausalLM,
so its attention instances hit the patched forward without apply_qkv
and crash with AttributeError.

Co-hosting agent (Unsloth) + target (vanilla) in one process isn't
viable without reloading the target through Unsloth too, which adds
complexity we don't need. The plain TRL + PEFT path is proven from
the earlier smoke run (reward climbed 0.293 -> 0.372 in 20 steps).

Kept: padding_side='left', artifact upload, all other fixes.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Files changed (1) hide show
  1. training/train_grpo.py +27 -35
training/train_grpo.py CHANGED
@@ -276,49 +276,31 @@ def main() -> None:
276
  os.environ.setdefault("PROMPT_GOLF_TARGET_MODEL", args.target_model)
277
  os.environ.setdefault("PROMPT_GOLF_TARGET_BACKEND", "hf")
278
 
279
- # ----- Unsloth MUST be imported before transformers/trl so its
280
- # monkey-patches (fused kernels, gradient checkpointing, generation
281
- # optimizations) take effect on the model classes. -----
282
- import unsloth # noqa: F401
283
- from unsloth import FastLanguageModel
284
-
285
- # Heavy imports — after unsloth patches.
286
  import torch
287
- from transformers import set_seed
 
288
  from trl import GRPOConfig, GRPOTrainer
289
 
290
  from prompt_golf_env.server.prompt_golf_environment import PromptGolfEnvironment
291
  from prompt_golf_env.server.tasks import list_task_ids
292
 
 
 
 
 
 
 
 
 
293
  set_seed(args.seed)
294
 
295
- # ----- agent (trainable) via Unsloth -----
296
- max_seq = args.max_prompt_length + args.max_completion_length
297
- model, tokenizer = FastLanguageModel.from_pretrained(
298
- model_name=args.agent_model,
299
- max_seq_length=max_seq,
300
- load_in_4bit=False,
301
- dtype=None, # auto (bf16 on Ampere+, fp16 otherwise)
302
- )
303
- # Left-pad for decoder-only generation (fixes the TRL warning and
304
- # ensures correct token alignment during rollout).
305
- tokenizer.padding_side = "left"
306
  if tokenizer.pad_token is None:
307
  tokenizer.pad_token = tokenizer.eos_token
308
 
309
- # Wrap with LoRA via Unsloth's helper (fused kernels).
310
- model = FastLanguageModel.get_peft_model(
311
- model,
312
- r=args.lora_r,
313
- lora_alpha=args.lora_alpha,
314
- lora_dropout=args.lora_dropout,
315
- bias="none",
316
- target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
317
- "gate_proj", "up_proj", "down_proj"],
318
- use_gradient_checkpointing="unsloth",
319
- random_state=args.seed,
320
- )
321
-
322
  # ----- env (target loaded lazily on first forward pass) -----
323
  env = PromptGolfEnvironment()
324
  all_tasks = list_task_ids()
@@ -356,16 +338,26 @@ def main() -> None:
356
  remove_unused_columns=False, # keep task_id / seed in batch
357
  )
358
 
 
 
 
 
 
 
 
 
 
 
 
359
  # ----- Train -----
360
- # NOTE: we pass the Unsloth-wrapped model directly; NO peft_config
361
- # (LoRA already applied above via FastLanguageModel.get_peft_model).
362
  trainer = GRPOTrainer(
363
- model=model,
364
  processing_class=tokenizer,
365
  args=grpo_cfg,
366
  reward_funcs=[reward_fn],
367
  train_dataset=train_ds,
368
  eval_dataset=eval_ds,
 
369
  callbacks=[MetricsCallback()],
370
  )
371
 
 
276
  os.environ.setdefault("PROMPT_GOLF_TARGET_MODEL", args.target_model)
277
  os.environ.setdefault("PROMPT_GOLF_TARGET_BACKEND", "hf")
278
 
279
+ # Heavy imports.
 
 
 
 
 
 
280
  import torch
281
+ from peft import LoraConfig
282
+ from transformers import AutoTokenizer, set_seed
283
  from trl import GRPOConfig, GRPOTrainer
284
 
285
  from prompt_golf_env.server.prompt_golf_environment import PromptGolfEnvironment
286
  from prompt_golf_env.server.tasks import list_task_ids
287
 
288
+ # NOTE: we deliberately do NOT import Unsloth here. Unsloth patches
289
+ # Qwen2Attention at import time, which breaks the target model
290
+ # (also Qwen2) that we load via vanilla transformers inside the env.
291
+ # See target_model.py — target uses AutoModelForCausalLM. Co-hosting
292
+ # agent (would-be Unsloth) + target (vanilla) in one process requires
293
+ # Unsloth to be absent. The proven-working path is plain TRL + PEFT
294
+ # LoRA; on L40S a 500-step 1.5B run finishes in ~140 min.
295
+
296
  set_seed(args.seed)
297
 
298
+ # ----- tokenizer (agent) -----
299
+ tokenizer = AutoTokenizer.from_pretrained(args.agent_model)
300
+ tokenizer.padding_side = "left" # decoder-only generation
 
 
 
 
 
 
 
 
301
  if tokenizer.pad_token is None:
302
  tokenizer.pad_token = tokenizer.eos_token
303
 
 
 
 
 
 
 
 
 
 
 
 
 
 
304
  # ----- env (target loaded lazily on first forward pass) -----
305
  env = PromptGolfEnvironment()
306
  all_tasks = list_task_ids()
 
338
  remove_unused_columns=False, # keep task_id / seed in batch
339
  )
340
 
341
+ # ----- LoRA -----
342
+ peft_cfg = LoraConfig(
343
+ r=args.lora_r,
344
+ lora_alpha=args.lora_alpha,
345
+ lora_dropout=args.lora_dropout,
346
+ bias="none",
347
+ task_type="CAUSAL_LM",
348
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
349
+ "gate_proj", "up_proj", "down_proj"],
350
+ )
351
+
352
  # ----- Train -----
 
 
353
  trainer = GRPOTrainer(
354
+ model=args.agent_model,
355
  processing_class=tokenizer,
356
  args=grpo_cfg,
357
  reward_funcs=[reward_fn],
358
  train_dataset=train_ds,
359
  eval_dataset=eval_ds,
360
+ peft_config=peft_cfg,
361
  callbacks=[MetricsCallback()],
362
  )
363