"""Unsloth + LoRA (Low-Rank Adaptation) GRPO training for CERNenv. This is the recommended path for Colab / single-GPU runs because Unsloth's fused kernels and 4-bit loading let us train 2B–8B models with limited VRAM. Run on Colab: !pip install -q unsloth unsloth_zoo trl peft datasets bitsandbytes !python -m training.training_unsloth \ --model_name unsloth/Qwen2.5-3B-Instruct \ --total_episodes 400 --num_generations 4 --output_dir runs/unsloth-grpo """ from __future__ import annotations import argparse import logging from typing import Any, List, Optional from datasets import Dataset logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s") logger = logging.getLogger(__name__) def main() -> None: # pragma: no cover - heavy GPU path parser = argparse.ArgumentParser() parser.add_argument("--model_name", default="unsloth/Qwen2.5-3B-Instruct") parser.add_argument("--scenario", default=None) parser.add_argument("--difficulty", choices=["easy", "medium", "hard"], default="easy") parser.add_argument("--total_episodes", type=int, default=400) parser.add_argument("--seed", type=int, default=42) parser.add_argument("--max_steps", type=int, default=18) parser.add_argument("--num_generations", type=int, default=4) parser.add_argument("--max_prompt_length", type=int, default=2048) parser.add_argument("--max_completion_length", type=int, default=384) parser.add_argument("--learning_rate", type=float, default=5e-6) parser.add_argument("--load_in_4bit", action="store_true", default=True) parser.add_argument("--lora_rank", type=int, default=16) parser.add_argument("--lora_alpha", type=int, default=16) parser.add_argument("--output_dir", default="training/runs/unsloth-grpo") args = parser.parse_args() from unsloth import FastLanguageModel from trl import GRPOConfig, GRPOTrainer from server.environment import CERNCollisionEnvironment from training.llm_agent import ( LLMAgentConfig, build_chat, parse_action, safe_default_action, ) from training.training_script import EpisodeContext, _format_validity_bonus, _stepwise_reward logger.info("Loading Unsloth model: %s", args.model_name) model, tokenizer = FastLanguageModel.from_pretrained( model_name=args.model_name, max_seq_length=args.max_prompt_length + args.max_completion_length, load_in_4bit=args.load_in_4bit, fast_inference=True, ) model = FastLanguageModel.get_peft_model( model, r=args.lora_rank, lora_alpha=args.lora_alpha, target_modules=[ "q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj", ], use_gradient_checkpointing="unsloth", ) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token # Build prompts env = CERNCollisionEnvironment(max_steps=args.max_steps) prompts: List[str] = [] for i in range(args.total_episodes): obs = env.reset(seed=args.seed + i, scenario=args.scenario, difficulty=args.difficulty) chat = build_chat(obs) prompts.append( tokenizer.apply_chat_template(chat, add_generation_prompt=True, tokenize=False) ) dataset = Dataset.from_dict({"prompt": prompts}) ctx = EpisodeContext( env=env, seed=args.seed, scenario=args.scenario, difficulty=args.difficulty, ) def reward_fn(prompts: List[str], completions: List[str], **kwargs: Any) -> List[float]: rewards: List[float] = [] for completion in completions: r = _stepwise_reward(completion_text=completion, ctx=ctx) r += _format_validity_bonus(completion) rewards.append(float(r)) return rewards cfg = GRPOConfig( output_dir=args.output_dir, per_device_train_batch_size=1, gradient_accumulation_steps=4, num_generations=args.num_generations, learning_rate=args.learning_rate, max_prompt_length=args.max_prompt_length, max_completion_length=args.max_completion_length, logging_steps=5, save_steps=50, seed=args.seed, bf16=True, report_to=[], ) trainer = GRPOTrainer( model=model, processing_class=tokenizer, train_dataset=dataset, reward_funcs=[reward_fn], args=cfg, ) logger.info("Starting Unsloth + LoRA GRPO training") trainer.train() trainer.save_model(args.output_dir) tokenizer.save_pretrained(args.output_dir) logger.info("Saved adapters to %s", args.output_dir) if __name__ == "__main__": # pragma: no cover main()