"""Unsloth + LoRA (Low-Rank Adaptation) GRPO training for CERNenv.



This is the recommended path for Colab / single-GPU runs because Unsloth's

fused kernels and 4-bit loading let us train 2B–8B models with limited VRAM.



Run on Colab:

    !pip install -q unsloth unsloth_zoo trl peft datasets bitsandbytes

    !python -m training.training_unsloth \

        --model_name unsloth/Qwen2.5-3B-Instruct \

        --total_episodes 400 --num_generations 4 --output_dir runs/unsloth-grpo

"""

from __future__ import annotations

import argparse
import logging
from typing import Any, List

from datasets import Dataset


logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger(__name__)


def main() -> None:  # pragma: no cover - heavy GPU path
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", default="unsloth/Qwen2.5-3B-Instruct")
    parser.add_argument("--scenario", default=None)
    parser.add_argument("--difficulty", choices=["easy", "medium", "hard"], default="easy")
    parser.add_argument("--total_episodes", type=int, default=400)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--max_steps", type=int, default=18)
    parser.add_argument("--num_generations", type=int, default=4)
    parser.add_argument("--max_prompt_length", type=int, default=2048)
    parser.add_argument("--max_completion_length", type=int, default=384)
    parser.add_argument("--learning_rate", type=float, default=5e-6)
    parser.add_argument("--load_in_4bit", action="store_true", default=True)
    parser.add_argument("--lora_rank", type=int, default=16)
    parser.add_argument("--lora_alpha", type=int, default=16)
    parser.add_argument("--output_dir", default="training/runs/unsloth-grpo")
    args = parser.parse_args()

    from unsloth import FastLanguageModel
    from trl import GRPOConfig, GRPOTrainer

    from server.environment import CERNCollisionEnvironment
    from training.llm_agent import build_chat
    from training.training_script import EpisodeContext, _format_validity_bonus, _stepwise_reward

    logger.info("Loading Unsloth model: %s", args.model_name)
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=args.model_name,
        max_seq_length=args.max_prompt_length + args.max_completion_length,
        load_in_4bit=args.load_in_4bit,
        fast_inference=True,
    )
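    # Wrap the base model in LoRA adapters: only low-rank update matrices on the
    # attention and MLP projections are trained, keeping optimizer state small.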
    model = FastLanguageModel.get_peft_model(
        model,
        r=args.lora_rank,
        lora_alpha=args.lora_alpha,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        use_gradient_checkpointing="unsloth",
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Build one prompt per episode: each reset() draws a fresh seed/scenario and
    # the initial observation is rendered into a chat-formatted prompt string.
    env = CERNCollisionEnvironment(max_steps=args.max_steps)
    prompts: List[str] = []
    for i in range(args.total_episodes):
        obs = env.reset(seed=args.seed + i, scenario=args.scenario, difficulty=args.difficulty)
        chat = build_chat(obs)
        prompts.append(
            tokenizer.apply_chat_template(chat, add_generation_prompt=True, tokenize=False)
        )
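    # GRPOTrainer consumes a prompt-only dataset; completions are sampled online
    # during training rather than read from the dataset.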
    dataset = Dataset.from_dict({"prompt": prompts})

    ctx = EpisodeContext(
        env=env, seed=args.seed,
        scenario=args.scenario, difficulty=args.difficulty,
    )

    def reward_fn(prompts: List[str], completions: List[str], **kwargs: Any) -> List[float]:
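        """Score each completion: stepwise environment reward computed from the
        completion text plus a format/validity bonus (helpers imported from
        training.training_script).

        TRL also passes ``prompts`` and extra kwargs, which are unused here;
        GRPO later baselines each reward against its group of num_generations
        samples.
        """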
        rewards: List[float] = []
        for completion in completions:
            r = _stepwise_reward(completion_text=completion, ctx=ctx)
            r += _format_validity_bonus(completion)
            rewards.append(float(r))
        return rewards

    cfg = GRPOConfig(
        output_dir=args.output_dir,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        num_generations=args.num_generations,
        learning_rate=args.learning_rate,
        max_prompt_length=args.max_prompt_length,
        max_completion_length=args.max_completion_length,
        logging_steps=5,
        save_steps=50,
        seed=args.seed,
        bf16=True,
        report_to=[],
    )
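    # Each prompt is expanded into num_generations completions whose rewards are
    # baselined against the group mean (the "group-relative" part of GRPO).
    # bf16=True assumes an Ampere-or-newer GPU (e.g. A100/L4); use fp16 on a T4.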

    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=dataset,
        reward_funcs=[reward_fn],
        args=cfg,
    )
    logger.info("Starting Unsloth + LoRA GRPO training")
    trainer.train()
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    logger.info("Saved adapters to %s", args.output_dir)


if __name__ == "__main__":  # pragma: no cover
    main()