Spaces:
Sleeping
Sleeping
Don Rishabh Claude Opus 4.7 (1M context) committed on
Commit ·
1da121e
1
Parent(s): e424cfe
Revert agent loading to TRL + PEFT (Unsloth collides with frozen target)
Browse files
Unsloth patches Qwen2Attention at import time and installs apply_qkv
only on instances created by FastLanguageModel.from_pretrained. Our
env loads the target model (also Qwen2) via vanilla AutoModelForCausalLM,
so its attention instances hit the patched forward without apply_qkv
and crash with AttributeError.
Co-hosting agent (Unsloth) + target (vanilla) in one process isn't
viable without reloading the target through Unsloth too, which adds
complexity we don't need. The plain TRL + PEFT path is proven from
the earlier smoke run (reward climbed 0.293 -> 0.372 in 20 steps).
Kept: padding_side='left', artifact upload, all other fixes.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- training/train_grpo.py +27 -35
training/train_grpo.py
CHANGED
|
@@ -276,49 +276,31 @@ def main() -> None:
|
|
| 276 |
os.environ.setdefault("PROMPT_GOLF_TARGET_MODEL", args.target_model)
|
| 277 |
os.environ.setdefault("PROMPT_GOLF_TARGET_BACKEND", "hf")
|
| 278 |
|
| 279 |
-
#
|
| 280 |
-
# monkey-patches (fused kernels, gradient checkpointing, generation
|
| 281 |
-
# optimizations) take effect on the model classes. -----
|
| 282 |
-
import unsloth # noqa: F401
|
| 283 |
-
from unsloth import FastLanguageModel
|
| 284 |
-
|
| 285 |
-
# Heavy imports — after unsloth patches.
|
| 286 |
import torch
|
| 287 |
-
from
|
|
|
|
| 288 |
from trl import GRPOConfig, GRPOTrainer
|
| 289 |
|
| 290 |
from prompt_golf_env.server.prompt_golf_environment import PromptGolfEnvironment
|
| 291 |
from prompt_golf_env.server.tasks import list_task_ids
|
| 292 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
set_seed(args.seed)
|
| 294 |
|
| 295 |
-
# -----
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
model_name=args.agent_model,
|
| 299 |
-
max_seq_length=max_seq,
|
| 300 |
-
load_in_4bit=False,
|
| 301 |
-
dtype=None, # auto (bf16 on Ampere+, fp16 otherwise)
|
| 302 |
-
)
|
| 303 |
-
# Left-pad for decoder-only generation (fixes the TRL warning and
|
| 304 |
-
# ensures correct token alignment during rollout).
|
| 305 |
-
tokenizer.padding_side = "left"
|
| 306 |
if tokenizer.pad_token is None:
|
| 307 |
tokenizer.pad_token = tokenizer.eos_token
|
| 308 |
|
| 309 |
-
# Wrap with LoRA via Unsloth's helper (fused kernels).
|
| 310 |
-
model = FastLanguageModel.get_peft_model(
|
| 311 |
-
model,
|
| 312 |
-
r=args.lora_r,
|
| 313 |
-
lora_alpha=args.lora_alpha,
|
| 314 |
-
lora_dropout=args.lora_dropout,
|
| 315 |
-
bias="none",
|
| 316 |
-
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
|
| 317 |
-
"gate_proj", "up_proj", "down_proj"],
|
| 318 |
-
use_gradient_checkpointing="unsloth",
|
| 319 |
-
random_state=args.seed,
|
| 320 |
-
)
|
| 321 |
-
|
| 322 |
# ----- env (target loaded lazily on first forward pass) -----
|
| 323 |
env = PromptGolfEnvironment()
|
| 324 |
all_tasks = list_task_ids()
|
|
@@ -356,16 +338,26 @@ def main() -> None:
|
|
| 356 |
remove_unused_columns=False, # keep task_id / seed in batch
|
| 357 |
)
|
| 358 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 359 |
# ----- Train -----
|
| 360 |
-
# NOTE: we pass the Unsloth-wrapped model directly; NO peft_config
|
| 361 |
-
# (LoRA already applied above via FastLanguageModel.get_peft_model).
|
| 362 |
trainer = GRPOTrainer(
|
| 363 |
-
model=
|
| 364 |
processing_class=tokenizer,
|
| 365 |
args=grpo_cfg,
|
| 366 |
reward_funcs=[reward_fn],
|
| 367 |
train_dataset=train_ds,
|
| 368 |
eval_dataset=eval_ds,
|
|
|
|
| 369 |
callbacks=[MetricsCallback()],
|
| 370 |
)
|
| 371 |
|
|
|
|
| 276 |
os.environ.setdefault("PROMPT_GOLF_TARGET_MODEL", args.target_model)
|
| 277 |
os.environ.setdefault("PROMPT_GOLF_TARGET_BACKEND", "hf")
|
| 278 |
|
| 279 |
+
# Heavy imports.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 280 |
import torch
|
| 281 |
+
from peft import LoraConfig
|
| 282 |
+
from transformers import AutoTokenizer, set_seed
|
| 283 |
from trl import GRPOConfig, GRPOTrainer
|
| 284 |
|
| 285 |
from prompt_golf_env.server.prompt_golf_environment import PromptGolfEnvironment
|
| 286 |
from prompt_golf_env.server.tasks import list_task_ids
|
| 287 |
|
| 288 |
+
# NOTE: we deliberately do NOT import Unsloth here. Unsloth patches
|
| 289 |
+
# Qwen2Attention at import time, which breaks the target model
|
| 290 |
+
# (also Qwen2) that we load via vanilla transformers inside the env.
|
| 291 |
+
# See target_model.py — target uses AutoModelForCausalLM. Co-hosting
|
| 292 |
+
# agent (would-be Unsloth) + target (vanilla) in one process requires
|
| 293 |
+
# Unsloth to be absent. The proven-working path is plain TRL + PEFT
|
| 294 |
+
# LoRA; on L40S a 500-step 1.5B run finishes in ~140 min.
|
| 295 |
+
|
| 296 |
set_seed(args.seed)
|
| 297 |
|
| 298 |
+
# ----- tokenizer (agent) -----
|
| 299 |
+
tokenizer = AutoTokenizer.from_pretrained(args.agent_model)
|
| 300 |
+
tokenizer.padding_side = "left" # decoder-only generation
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 301 |
if tokenizer.pad_token is None:
|
| 302 |
tokenizer.pad_token = tokenizer.eos_token
|
| 303 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 304 |
# ----- env (target loaded lazily on first forward pass) -----
|
| 305 |
env = PromptGolfEnvironment()
|
| 306 |
all_tasks = list_task_ids()
|
|
|
|
| 338 |
remove_unused_columns=False, # keep task_id / seed in batch
|
| 339 |
)
|
| 340 |
|
| 341 |
+
# ----- LoRA -----
|
| 342 |
+
peft_cfg = LoraConfig(
|
| 343 |
+
r=args.lora_r,
|
| 344 |
+
lora_alpha=args.lora_alpha,
|
| 345 |
+
lora_dropout=args.lora_dropout,
|
| 346 |
+
bias="none",
|
| 347 |
+
task_type="CAUSAL_LM",
|
| 348 |
+
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
|
| 349 |
+
"gate_proj", "up_proj", "down_proj"],
|
| 350 |
+
)
|
| 351 |
+
|
| 352 |
# ----- Train -----
|
|
|
|
|
|
|
| 353 |
trainer = GRPOTrainer(
|
| 354 |
+
model=args.agent_model,
|
| 355 |
processing_class=tokenizer,
|
| 356 |
args=grpo_cfg,
|
| 357 |
reward_funcs=[reward_fn],
|
| 358 |
train_dataset=train_ds,
|
| 359 |
eval_dataset=eval_ds,
|
| 360 |
+
peft_config=peft_cfg,
|
| 361 |
callbacks=[MetricsCallback()],
|
| 362 |
)
|
| 363 |
|