# Source: cernenv/training/training_unsloth.py
# Revision 2b0bffa ("Update CERNenv Space"), uploaded by anugrah55.
"""Unsloth + LoRA (Low-Rank Adaptation) GRPO training for CERNenv.

This is the recommended path for Colab / single-GPU runs because Unsloth's
fused kernels and 4-bit loading let us train 2B–8B models with limited VRAM.

Run on Colab:

    !pip install -q unsloth unsloth_zoo trl peft datasets bitsandbytes
    !python -m training.training_unsloth \
        --model_name unsloth/Qwen2.5-3B-Instruct \
        --total_episodes 400 --num_generations 4 --output_dir runs/unsloth-grpo
"""
from __future__ import annotations
import argparse
import logging
from typing import Any, List, Optional
from datasets import Dataset
# Configure root logging once at import time (timestamp, level, message) and
# expose a module-scoped logger for the training entry point below.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
def main() -> None:  # pragma: no cover - heavy GPU path
    """Train a LoRA adapter on CERNenv with Unsloth and TRL's GRPO trainer.

    Steps, in order:

    1. Parse command-line options.
    2. Load the base model through Unsloth and attach LoRA adapters.
    3. Build one chat-templated prompt per episode by resetting the
       environment with consecutive seeds (``seed``, ``seed + 1``, ...).
    4. Run ``GRPOTrainer`` with a reward that adds the step-wise environment
       reward and the format-validity bonus for each completion.
    5. Save the adapters and tokenizer to ``--output_dir``.

    ``unsloth``, ``trl`` and the project modules are imported only after the
    arguments have been parsed and validated, so ``--help`` and argument
    errors do not need them.

    Raises:
        SystemExit: raised by ``argparse`` for ``--help`` or invalid arguments
            (including ``--total_episodes`` below 1).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", default="unsloth/Qwen2.5-3B-Instruct")
    parser.add_argument("--scenario", default=None)
    parser.add_argument("--difficulty", choices=["easy", "medium", "hard"], default="easy")
    parser.add_argument("--total_episodes", type=int, default=400)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--max_steps", type=int, default=18)
    parser.add_argument("--num_generations", type=int, default=4)
    parser.add_argument("--max_prompt_length", type=int, default=2048)
    parser.add_argument("--max_completion_length", type=int, default=384)
    parser.add_argument("--learning_rate", type=float, default=5e-6)
    # 4-bit loading stays on by default. ``store_true`` combined with
    # ``default=True`` can never produce False, so a paired negative switch is
    # provided to actually turn quantized loading off.
    parser.add_argument(
        "--load_in_4bit", dest="load_in_4bit", action="store_true", default=True,
    )
    parser.add_argument(
        "--no_load_in_4bit", dest="load_in_4bit", action="store_false",
        help="load the base model without 4-bit quantization",
    )
    parser.add_argument("--lora_rank", type=int, default=16)
    parser.add_argument("--lora_alpha", type=int, default=16)
    parser.add_argument("--output_dir", default="training/runs/unsloth-grpo")
    args = parser.parse_args()

    # Fail before the (slow) model load rather than on an empty dataset later.
    if args.total_episodes < 1:
        parser.error("--total_episodes must be at least 1")

    # Heavy / project imports are deferred until the arguments are known good.
    # unsloth is kept ahead of trl, as in the original import order.
    from unsloth import FastLanguageModel, is_bfloat16_supported
    from trl import GRPOConfig, GRPOTrainer

    from server.environment import CERNCollisionEnvironment
    from training.llm_agent import build_chat
    from training.training_script import (
        EpisodeContext,
        _format_validity_bonus,
        _stepwise_reward,
    )

    logger.info("Loading Unsloth model: %s", args.model_name)
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=args.model_name,
        # The sequence budget must hold a full prompt plus a full completion.
        max_seq_length=args.max_prompt_length + args.max_completion_length,
        load_in_4bit=args.load_in_4bit,
        # NOTE(review): fast_inference presumably needs vLLM installed, which
        # the pip line in the module docstring does not include - confirm.
        fast_inference=True,
    )
    model = FastLanguageModel.get_peft_model(
        model,
        r=args.lora_rank,
        lora_alpha=args.lora_alpha,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        use_gradient_checkpointing="unsloth",
    )
    # Fall back to EOS for padding when the tokenizer defines no pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Build prompts: one environment reset per episode, each with its own seed.
    env = CERNCollisionEnvironment(max_steps=args.max_steps)
    prompts: List[str] = []
    for i in range(args.total_episodes):
        obs = env.reset(seed=args.seed + i, scenario=args.scenario, difficulty=args.difficulty)
        chat = build_chat(obs)
        prompts.append(
            tokenizer.apply_chat_template(chat, add_generation_prompt=True, tokenize=False)
        )
    dataset = Dataset.from_dict({"prompt": prompts})

    # The reward context reuses the environment instance that built the
    # prompts; how ``_stepwise_reward`` re-seeds or resets it is not visible
    # from this file.
    ctx = EpisodeContext(
        env=env, seed=args.seed,
        scenario=args.scenario, difficulty=args.difficulty,
    )

    def reward_fn(prompts: List[str], completions: List[str], **kwargs: Any) -> List[float]:
        """Score each completion: step-wise reward plus format-validity bonus.

        The parameter names are kept as-is because the trainer supplies
        ``prompts`` and ``completions`` to reward functions; ``prompts`` and
        any extra keyword arguments are accepted but unused.
        """
        return [
            float(
                _stepwise_reward(completion_text=completion, ctx=ctx)
                + _format_validity_bonus(completion)
            )
            for completion in completions
        ]

    # bfloat16 is unavailable on some GPUs (for example the Colab T4), so the
    # precision is picked at runtime instead of hard-coding ``bf16=True``.
    use_bf16 = is_bfloat16_supported()
    cfg = GRPOConfig(
        output_dir=args.output_dir,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        num_generations=args.num_generations,
        learning_rate=args.learning_rate,
        max_prompt_length=args.max_prompt_length,
        max_completion_length=args.max_completion_length,
        logging_steps=5,
        save_steps=50,
        seed=args.seed,
        bf16=use_bf16,
        fp16=not use_bf16,
        report_to=[],
    )
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=dataset,
        reward_funcs=[reward_fn],
        args=cfg,
    )
    logger.info("Starting Unsloth + LoRA GRPO training")
    trainer.train()
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    logger.info("Saved adapters to %s", args.output_dir)


if __name__ == "__main__":  # pragma: no cover
    main()