"""Unsloth + LoRA (Low-Rank Adaptation) GRPO training for CERNenv.
This is the recommended path for Colab / single-GPU runs because Unsloth's
fused kernels and 4-bit loading let us train 2B–8B models with limited VRAM.
Run on Colab:
!pip install -q unsloth unsloth_zoo trl peft datasets bitsandbytes
!python -m training.training_unsloth \
--model_name unsloth/Qwen2.5-3B-Instruct \
--total_episodes 400 --num_generations 4 --output_dir runs/unsloth-grpo
"""
from __future__ import annotations
import argparse
import logging
from typing import Any, List, Optional
from datasets import Dataset
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger(__name__)
def _parse_args() -> argparse.Namespace:
    """Parse command-line options for the Unsloth + LoRA GRPO run.

    Returns:
        The parsed ``argparse.Namespace`` of training hyperparameters.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--model_name", default="unsloth/Qwen2.5-3B-Instruct")
    parser.add_argument("--scenario", default=None)
    parser.add_argument("--difficulty", choices=["easy", "medium", "hard"], default="easy")
    parser.add_argument("--total_episodes", type=int, default=400)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--max_steps", type=int, default=18)
    parser.add_argument("--num_generations", type=int, default=4)
    parser.add_argument("--max_prompt_length", type=int, default=2048)
    parser.add_argument("--max_completion_length", type=int, default=384)
    parser.add_argument("--learning_rate", type=float, default=5e-6)
    # BooleanOptionalAction (adds --no-load_in_4bit) so 4-bit loading can
    # actually be turned off; the previous `action="store_true", default=True`
    # made the flag a no-op that could never be disabled.
    parser.add_argument(
        "--load_in_4bit",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Load the base model in 4-bit (disable with --no-load_in_4bit).",
    )
    parser.add_argument("--lora_rank", type=int, default=16)
    parser.add_argument("--lora_alpha", type=int, default=16)
    parser.add_argument("--output_dir", default="training/runs/unsloth-grpo")
    return parser.parse_args()


def main() -> None:  # pragma: no cover - heavy GPU path
    """Train a LoRA adapter with GRPO on CERNenv prompts via Unsloth.

    Loads the base model in (optionally) 4-bit, wraps it with LoRA adapters,
    builds one chat prompt per episode from deterministic env seeds, then runs
    TRL's ``GRPOTrainer`` with a stepwise-reward callback and saves the
    resulting adapters plus tokenizer to ``--output_dir``.
    """
    args = _parse_args()

    # Heavy imports are deferred into the function so that `--help` and module
    # import work without the GPU training stack installed.
    from unsloth import FastLanguageModel
    from trl import GRPOConfig, GRPOTrainer
    from server.environment import CERNCollisionEnvironment
    from training.llm_agent import build_chat
    from training.training_script import EpisodeContext, _format_validity_bonus, _stepwise_reward

    logger.info("Loading Unsloth model: %s", args.model_name)
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=args.model_name,
        # Sequence budget covers the longest prompt plus the longest completion.
        max_seq_length=args.max_prompt_length + args.max_completion_length,
        load_in_4bit=args.load_in_4bit,
        fast_inference=True,
    )
    model = FastLanguageModel.get_peft_model(
        model,
        r=args.lora_rank,
        lora_alpha=args.lora_alpha,
        # Standard attention + MLP projection targets for LLaMA/Qwen-style models.
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        use_gradient_checkpointing="unsloth",
    )
    # Some tokenizers ship without a pad token; reuse EOS so batching works.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Build one prompt per episode, each from a distinct deterministic seed so
    # the dataset is reproducible for a given --seed.
    env = CERNCollisionEnvironment(max_steps=args.max_steps)
    prompts: List[str] = []
    for i in range(args.total_episodes):
        obs = env.reset(seed=args.seed + i, scenario=args.scenario, difficulty=args.difficulty)
        chat = build_chat(obs)
        prompts.append(
            tokenizer.apply_chat_template(chat, add_generation_prompt=True, tokenize=False)
        )
    dataset = Dataset.from_dict({"prompt": prompts})

    ctx = EpisodeContext(
        env=env, seed=args.seed,
        scenario=args.scenario, difficulty=args.difficulty,
    )

    def reward_fn(prompts: List[str], completions: List[str], **kwargs: Any) -> List[float]:
        """TRL reward callback: stepwise env reward plus format-validity bonus.

        ``prompts`` is required by the TRL reward-function signature but is
        unused here; rewards depend only on each completion.
        """
        return [
            float(_stepwise_reward(completion_text=c, ctx=ctx) + _format_validity_bonus(c))
            for c in completions
        ]

    cfg = GRPOConfig(
        output_dir=args.output_dir,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        num_generations=args.num_generations,
        learning_rate=args.learning_rate,
        max_prompt_length=args.max_prompt_length,
        max_completion_length=args.max_completion_length,
        logging_steps=5,
        save_steps=50,
        seed=args.seed,
        bf16=True,
        report_to=[],
    )
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=dataset,
        reward_funcs=[reward_fn],
        args=cfg,
    )
    logger.info("Starting Unsloth + LoRA GRPO training")
    trainer.train()
    # Persist both the LoRA adapters and the tokenizer so the output dir is
    # directly loadable for inference.
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    logger.info("Saved adapters to %s", args.output_dir)
if __name__ == "__main__": # pragma: no cover
main()
|