| """Unsloth + LoRA (Low-Rank Adaptation) GRPO training for CERNenv.
|
|
|
| This is the recommended path for Colab / single-GPU runs because Unsloth's
|
| fused kernels and 4-bit loading let us train 2B–8B models with limited VRAM.
|
|
|
| Run on Colab:
|
| !pip install -q unsloth unsloth_zoo trl peft datasets bitsandbytes
|
| !python -m training.training_unsloth \
|
| --model_name unsloth/Qwen2.5-3B-Instruct \
|
| --total_episodes 400 --num_generations 4 --output_dir runs/unsloth-grpo
|
| """
|
|
|
| from __future__ import annotations
|
|
|
| import argparse
|
| import logging
|
| from typing import Any, List, Optional
|
|
|
| from datasets import Dataset
|
|
|
|
|
# Module-level logging: timestamped INFO lines so training progress is visible on the console.
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")

logger = logging.getLogger(__name__)
|
|
|
|
|
def main() -> None:
    """Train a LoRA adapter on CERNenv with GRPO via Unsloth.

    Parses CLI arguments, loads a (by default 4-bit) base model through
    Unsloth, attaches LoRA adapters, builds a one-prompt-per-episode
    dataset by resetting the environment with distinct seeds, and runs
    TRL's ``GRPOTrainer`` with a stepwise environment reward plus a
    format-validity bonus. Trained adapters and the tokenizer are saved
    to ``--output_dir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", default="unsloth/Qwen2.5-3B-Instruct")
    parser.add_argument("--scenario", default=None)
    parser.add_argument("--difficulty", choices=["easy", "medium", "hard"], default="easy")
    parser.add_argument("--total_episodes", type=int, default=400)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--max_steps", type=int, default=18)
    parser.add_argument("--num_generations", type=int, default=4)
    parser.add_argument("--max_prompt_length", type=int, default=2048)
    parser.add_argument("--max_completion_length", type=int, default=384)
    parser.add_argument("--learning_rate", type=float, default=5e-6)
    # BUGFIX: the original used action="store_true" together with default=True,
    # which made --load_in_4bit a no-op and 4-bit loading impossible to turn
    # off from the CLI. Keep 4-bit ON by default (same behavior as before) but
    # add an explicit opt-out flag.
    parser.add_argument("--load_in_4bit", dest="load_in_4bit", action="store_true", default=True)
    parser.add_argument("--no_load_in_4bit", dest="load_in_4bit", action="store_false")
    parser.add_argument("--lora_rank", type=int, default=16)
    parser.add_argument("--lora_alpha", type=int, default=16)
    parser.add_argument("--output_dir", default="training/runs/unsloth-grpo")
    args = parser.parse_args()

    # Heavy imports are deferred into main() so that `--help` and argument
    # errors work even when the GPU training stack is not installed.
    from unsloth import FastLanguageModel
    from trl import GRPOConfig, GRPOTrainer

    from server.environment import CERNCollisionEnvironment
    from training.llm_agent import (  # noqa: F401 — some names currently unused here
        LLMAgentConfig,
        build_chat,
        parse_action,
        safe_default_action,
    )
    from training.training_script import EpisodeContext, _format_validity_bonus, _stepwise_reward

    logger.info("Loading Unsloth model: %s", args.model_name)
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=args.model_name,
        # Context window must fit both the rendered prompt and the completion.
        max_seq_length=args.max_prompt_length + args.max_completion_length,
        load_in_4bit=args.load_in_4bit,
        fast_inference=True,
    )
    model = FastLanguageModel.get_peft_model(
        model,
        r=args.lora_rank,
        lora_alpha=args.lora_alpha,
        # Standard LoRA attachment points for Llama/Qwen-style architectures.
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        use_gradient_checkpointing="unsloth",
    )
    if tokenizer.pad_token is None:
        # Some chat models ship without a pad token; reuse EOS for batching.
        tokenizer.pad_token = tokenizer.eos_token

    # Build one prompt per episode: each reset with a distinct seed yields a
    # fresh observation, rendered through the model's chat template.
    env = CERNCollisionEnvironment(max_steps=args.max_steps)
    prompts: List[str] = []
    for i in range(args.total_episodes):
        obs = env.reset(seed=args.seed + i, scenario=args.scenario, difficulty=args.difficulty)
        chat = build_chat(obs)
        prompts.append(
            tokenizer.apply_chat_template(chat, add_generation_prompt=True, tokenize=False)
        )
    dataset = Dataset.from_dict({"prompt": prompts})

    ctx = EpisodeContext(
        env=env, seed=args.seed,
        scenario=args.scenario, difficulty=args.difficulty,
    )

    # NOTE: TRL inspects reward-function signatures and passes `prompts` /
    # `completions` by keyword, so these parameter names must stay exactly as
    # written — the shadowing of the outer `prompts` list is intentional.
    def reward_fn(prompts: List[str], completions: List[str], **kwargs: Any) -> List[float]:
        """Score each completion: stepwise env reward + format-validity bonus."""
        rewards: List[float] = []
        for completion in completions:
            r = _stepwise_reward(completion_text=completion, ctx=ctx)
            r += _format_validity_bonus(completion)
            rewards.append(float(r))
        return rewards

    cfg = GRPOConfig(
        output_dir=args.output_dir,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        num_generations=args.num_generations,
        learning_rate=args.learning_rate,
        max_prompt_length=args.max_prompt_length,
        max_completion_length=args.max_completion_length,
        logging_steps=5,
        save_steps=50,
        seed=args.seed,
        bf16=True,
        report_to=[],  # disable external experiment trackers (wandb etc.)
    )

    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=dataset,
        reward_funcs=[reward_fn],
        args=cfg,
    )
    logger.info("Starting Unsloth + LoRA GRPO training")
    trainer.train()
    # Save only the LoRA adapters (the base model stays untouched) plus the
    # tokenizer so the output dir is self-contained for inference.
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    logger.info("Saved adapters to %s", args.output_dir)
|
|
|
|
|
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":

    main()
|
|
|