{ "cells": [ { "cell_type": "markdown", "id": "91b7681f", "metadata": {}, "source": [ "# Training LLMs to Write Fast GPU Kernels with GRPO\n", "\n", "This notebook demonstrates how to train a language model to write optimized CUDA/Triton\n", "kernels using TRL's GRPOTrainer and the kernrl OpenEnv environment.\n", "\n", "**What is kernrl?**\n", "- An RL environment for GPU kernel optimization\n", "- Agents receive PyTorch reference implementations\n", "- Must write faster CUDA/Triton kernels that produce correct outputs\n", "- Rewards based on compilation success, correctness, and speedup\n", "\n", "**What is GRPO?**\n", "- Group Relative Policy Optimization\n", "- Efficient RL algorithm for training LLMs\n", "- Uses multiple generations per prompt to estimate advantages\n", "- Works well with environment-based reward signals" ] }, { "cell_type": "markdown", "id": "1c818c9f", "metadata": {}, "source": [ "## Installation\n", "\n", "First, install the required packages:" ] }, { "cell_type": "code", "execution_count": null, "id": "03a24248", "metadata": {}, "outputs": [], "source": [ "!pip install torch triton trl transformers accelerate\n", "!pip install git+https://github.com/meta-pytorch/OpenEnv.git" ] }, { "cell_type": "markdown", "id": "a6bd7b19", "metadata": {}, "source": [ "## Setup\n", "\n", "Import necessary libraries and configure the environment." 
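,
"\n",
"\n",
"As a quick refresher on the group-relative idea described above: GRPO scores each completion by standardizing its reward within the group of generations sampled for the same prompt. A minimal sketch in plain Python (illustrative only, not TRL's internal API):\n",
"\n",
"```python\n",
"import statistics\n",
"\n",
"def group_relative_advantages(rewards):\n",
"    # Standardize each reward within its prompt's group of generations\n",
"    mean = statistics.mean(rewards)\n",
"    std = statistics.pstdev(rewards) or 1.0  # guard against a zero-variance group\n",
"    return [(r - mean) / std for r in rewards]\n",
"\n",
"# Four generations for one prompt: higher-reward completions get positive advantage\n",
"print(group_relative_advantages([0.0, 0.4, 0.4, 1.2]))\n",
"```"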
] }, { "cell_type": "code", "execution_count": null, "id": "409d8ec7", "metadata": {}, "outputs": [], "source": [ "import torch\n", "from datasets import Dataset\n", "from transformers import AutoTokenizer\n", "from trl import GRPOConfig, GRPOTrainer\n", "from trl.experimental.openenv import generate_rollout_completions\n", "\n", "# Import kernrl environment\n", "from kernrl import kernrl_env, KernelAction, KernelObservation" ] }, { "cell_type": "code", "execution_count": null, "id": "1195d838", "metadata": {}, "outputs": [], "source": [ "# Configuration\n", "MODEL_ID = \"Qwen/Qwen2.5-Coder-1.5B-Instruct\" # Good for code generation\n", "ENV_URL = \"http://localhost:8000\" # kernrl server URL\n", "\n", "# Initialize tokenizer\n", "tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)\n", "if tokenizer.pad_token is None:\n", " tokenizer.pad_token = tokenizer.eos_token" ] }, { "cell_type": "markdown", "id": "0ba43b24", "metadata": {}, "source": [ "## Connect to kernrl Environment\n", "\n", "The kernrl environment evaluates submitted kernels for:\n", "1. **Compilation**: Does the code compile?\n", "2. **Correctness**: Does output match reference (within tolerance)?\n", "3. **Performance**: Is it faster than PyTorch baseline?" 
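,
"\n",
"\n",
"The shape of that three-stage evaluation can be sketched in plain Python (a deliberately simplified stand-in: the real server compiles actual CUDA/Triton code and times it on the GPU, and the function names here are illustrative):\n",
"\n",
"```python\n",
"import math\n",
"import time\n",
"\n",
"def evaluate_submission(code_str, reference_fn, candidate_fn, xs, tol=1e-4):\n",
"    # 1. Compilation: does the submitted source even parse?\n",
"    try:\n",
"        compile(code_str, '<kernel>', 'exec')\n",
"    except SyntaxError:\n",
"        return False, False, None\n",
"\n",
"    # 2. Correctness: outputs must match the reference within tolerance\n",
"    correct = all(\n",
"        math.isclose(reference_fn(x), candidate_fn(x), rel_tol=tol, abs_tol=tol)\n",
"        for x in xs\n",
"    )\n",
"\n",
"    # 3. Performance: crude wall-clock speedup vs. the reference\n",
"    def bench(fn):\n",
"        start = time.perf_counter()\n",
"        for x in xs:\n",
"            fn(x)\n",
"        return time.perf_counter() - start\n",
"\n",
"    return True, correct, bench(reference_fn) / bench(candidate_fn)\n",
"\n",
"square = lambda x: x * x\n",
"print(evaluate_submission('def f(x): return x * x', square, square, [0.5, 1.5, 2.5]))\n",
"```"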
] }, { "cell_type": "code", "execution_count": null, "id": "d72ae756", "metadata": {}, "outputs": [], "source": [ "# Connect to the kernrl server\n", "# Option 1: Connect to running server\n", "env = kernrl_env(base_url=ENV_URL)\n", "\n", "# Option 2: Load from HuggingFace Hub (requires GPU)\n", "# env = kernrl_env.from_hub(\"Infatoshi/kernrl\")\n", "\n", "# Option 3: Local Docker\n", "# env = kernrl_env.from_docker_image(\"kernrl:latest\")\n", "\n", "# Test the connection\n", "obs = env.reset(problem_id=\"L1_23_Softmax\")\n", "print(f\"Problem: {obs.problem_id}\")\n", "print(f\"GPU: {obs.gpu_info}\")\n", "print(f\"Max turns: {obs.max_turns}\")" ] }, { "cell_type": "markdown", "id": "004905fc", "metadata": {}, "source": [ "## Reward Functions\n", "\n", "We define multiple reward signals to guide the model:\n", "- **Compilation reward**: +0.1 for successful compilation\n", "- **Correctness reward**: +0.3 for matching reference output\n", "- **Speedup reward**: Scaled reward for beating baseline performance" ] }, { "cell_type": "code", "execution_count": null, "id": "39237d0e", "metadata": { "lines_to_next_cell": 1 }, "outputs": [], "source": [ "import math\n", "\n", "def reward_compilation(completions: list[str], **kwargs) -> list[float]:\n", " \"\"\"Reward for successful compilation.\"\"\"\n", " compilation_success = kwargs.get(\"compilation_success\", [])\n", " return [0.1 if success else 0.0 for success in compilation_success]\n", "\n", "def reward_correctness(completions: list[str], **kwargs) -> list[float]:\n", " \"\"\"Reward for correct output.\"\"\"\n", " correctness_pass = kwargs.get(\"correctness_pass\", [])\n", " return [0.3 if correct else 0.0 for correct in correctness_pass]\n", "\n", "def reward_speedup(completions: list[str], **kwargs) -> list[float]:\n", " \"\"\"Reward scaled by speedup achieved.\"\"\"\n", " speedups = kwargs.get(\"speedup\", [])\n", " rewards = []\n", " for speedup in speedups:\n", " if speedup is None or speedup <= 0:\n", " 
rewards.append(0.0)\n", " elif speedup <= 1.0:\n", " # Below baseline: small penalty\n", " rewards.append(-0.1)\n", " else:\n", " # Above baseline: 0.3 base plus a bonus that scales with log2(speedup)\n", " # bonus capped at 0.6, so: 2x speedup = 0.6, 4x or more = 0.9\n", " bonus = min(0.3 * math.log2(speedup), 0.6)\n", " rewards.append(0.3 + bonus)\n", " return rewards\n", "\n", "def reward_combined(completions: list[str], **kwargs) -> list[float]:\n", " \"\"\"Combined reward from all signals.\"\"\"\n", " comp_rewards = reward_compilation(completions, **kwargs)\n", " corr_rewards = reward_correctness(completions, **kwargs)\n", " speed_rewards = reward_speedup(completions, **kwargs)\n", " return [c + r + s for c, r, s in zip(comp_rewards, corr_rewards, speed_rewards)]" ] }, { "cell_type": "markdown", "id": "53307241", "metadata": {}, "source": [ "## System Prompt\n", "\n", "The system prompt provides context about the task and expected output format." ] }, { "cell_type": "code", "execution_count": null, "id": "21d75bd3", "metadata": { "lines_to_next_cell": 1 }, "outputs": [], "source": [ "SYSTEM_PROMPT = \"\"\"You are an expert GPU kernel engineer specializing in CUDA and Triton.\n", "\n", "Your task is to optimize PyTorch operations by writing custom GPU kernels.\n", "\n", "Guidelines:\n", "1. Analyze the reference PyTorch implementation carefully\n", "2. Identify optimization opportunities (memory access patterns, parallelism, fusion)\n", "3. Write a Triton or CUDA kernel that computes the same result\n", "4. 
Ensure numerical correctness (outputs must match within tolerance)\n", "\n", "Output format:\n", "- Provide a complete Python file\n", "- Include a Model class with the same interface as the reference\n", "- The Model.forward() method should use your optimized kernel\n", "- Include all necessary imports (torch, triton, triton.language)\n", "\n", "Focus on:\n", "- Coalesced memory access\n", "- Efficient use of shared memory\n", "- Minimizing thread divergence\n", "- Optimal block/grid dimensions\"\"\"" ] }, { "cell_type": "markdown", "id": "607299ce", "metadata": {}, "source": [ "## Rollout Function\n", "\n", "The rollout function generates kernel code and evaluates it in the environment." ] }, { "cell_type": "code", "execution_count": null, "id": "5da951b3", "metadata": { "lines_to_next_cell": 1 }, "outputs": [], "source": [ "def make_prompt(problem_description: str, feedback: str = \"\") -> str:\n", " \"\"\"Create the user prompt for the model.\"\"\"\n", " prompt = f\"{problem_description}\\n\"\n", " if feedback:\n", " prompt += f\"\\n## Previous Attempt Feedback\\n{feedback}\\n\"\n", " prompt += \"\\nProvide your optimized kernel implementation:\"\n", " return prompt\n", "\n", "def extract_code(completion: str) -> str:\n", " \"\"\"Extract code from model completion.\"\"\"\n", " # Handle markdown code blocks\n", " if \"```python\" in completion:\n", " start = completion.find(\"```python\") + 9\n", " end = completion.find(\"```\", start)\n", " if end > start:\n", " return completion[start:end].strip()\n", " if \"```\" in completion:\n", " start = completion.find(\"```\") + 3\n", " end = completion.find(\"```\", start)\n", " if end > start:\n", " return completion[start:end].strip()\n", " # Return as-is if no code blocks\n", " return completion.strip()\n", "\n", "def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]:\n", " \"\"\"\n", " Custom rollout function for kernrl environment.\n", "\n", " Generates kernel code and evaluates it to get 
rewards.\n", " \"\"\"\n", " # Generate completions\n", " outputs = generate_rollout_completions(trainer, prompts)\n", "\n", " completions_text = [\n", " tokenizer.decode(out[\"completion_ids\"], skip_special_tokens=True)\n", " for out in outputs\n", " ]\n", "\n", " # Evaluate each completion in the environment\n", " compilation_success = []\n", " correctness_pass = []\n", " speedups = []\n", "\n", " for completion in completions_text:\n", " # Reset environment for each evaluation\n", " obs = env.reset()\n", "\n", " # Extract code and submit\n", " code = extract_code(completion)\n", " action = KernelAction(code=code)\n", "\n", " try:\n", " result = env.step(action)\n", " obs = result.observation\n", "\n", " compilation_success.append(obs.compilation_success)\n", " correctness_pass.append(obs.correctness_pass or False)\n", " speedups.append(obs.speedup)\n", " except Exception as e:\n", " print(f\"Evaluation error: {e}\")\n", " compilation_success.append(False)\n", " correctness_pass.append(False)\n", " speedups.append(None)\n", "\n", " return {\n", " \"prompt_ids\": [out[\"prompt_ids\"] for out in outputs],\n", " \"completion_ids\": [out[\"completion_ids\"] for out in outputs],\n", " \"logprobs\": [out[\"logprobs\"] for out in outputs],\n", " # Pass reward signals to reward functions\n", " \"compilation_success\": compilation_success,\n", " \"correctness_pass\": correctness_pass,\n", " \"speedup\": speedups,\n", " }" ] }, { "cell_type": "markdown", "id": "dae933f9", "metadata": {}, "source": [ "## Create Training Dataset\n", "\n", "We create a dataset from kernrl problems. Each problem becomes a training prompt." 
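,
"\n",
"\n",
"Problem IDs encode their difficulty as a prefix, e.g. `L1_23_Softmax` is a Level 1 problem. The level filter used below boils down to this parsing step (the L2/L3 IDs in the demo are made up for illustration):\n",
"\n",
"```python\n",
"def problem_level(problem_id):\n",
"    # 'L1_23_Softmax' -> 'L1' -> '1' -> 1\n",
"    return int(problem_id.split('_')[0][1:])\n",
"\n",
"problems = ['L1_23_Softmax', 'L2_5_Example', 'L3_9_Example']\n",
"print([p for p in problems if problem_level(p) in (1, 2)])  # keeps the L1 and L2 entries\n",
"```"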
] }, { "cell_type": "code", "execution_count": null, "id": "36c6f196", "metadata": {}, "outputs": [], "source": [ "def create_dataset(env: kernrl_env, levels: list[int] = [1, 2]) -> Dataset:\n", " \"\"\"Create training dataset from kernrl problems.\"\"\"\n", " prompts = []\n", " problem_ids = []\n", "\n", " # Get all problem IDs\n", " all_problems = env.list_problems()\n", "\n", " for problem_id in all_problems:\n", " # Filter by level\n", " level = int(problem_id.split(\"_\")[0][1:]) # Extract level from \"L1_...\"\n", " if level not in levels:\n", " continue\n", "\n", " # Reset to get problem description\n", " obs = env.reset(problem_id=problem_id)\n", "\n", " # Create prompt\n", " messages = [\n", " {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n", " {\"role\": \"user\", \"content\": make_prompt(obs.problem_description)},\n", " ]\n", " prompt = tokenizer.apply_chat_template(\n", " messages,\n", " add_generation_prompt=True,\n", " tokenize=False,\n", " )\n", "\n", " prompts.append(prompt)\n", " problem_ids.append(problem_id)\n", "\n", " return Dataset.from_dict({\n", " \"prompt\": prompts,\n", " \"problem_id\": problem_ids,\n", " })\n", "\n", "# Create dataset from Level 1 and 2 problems\n", "dataset = create_dataset(env, levels=[1, 2])\n", "print(f\"Created dataset with {len(dataset)} problems\")" ] }, { "cell_type": "markdown", "id": "61dcd8db", "metadata": {}, "source": [ "## Configure Training\n", "\n", "Set up GRPOTrainer with our custom rollout function and reward signals." 
] }, { "cell_type": "code", "execution_count": null, "id": "6fd1d73d", "metadata": {}, "outputs": [], "source": [ "# Training configuration\n", "config = GRPOConfig(\n", " output_dir=\"./kernrl_grpo_output\",\n", "\n", " # vLLM settings\n", " use_vllm=True,\n", " vllm_mode=\"colocate\", # Use \"server\" mode for multi-GPU\n", "\n", " # Generation settings\n", " num_generations=4, # Generations per prompt\n", " max_completion_length=2048, # Kernel code can be long\n", " temperature=0.7,\n", "\n", " # Training settings\n", " num_train_epochs=3,\n", " per_device_train_batch_size=2,\n", " gradient_accumulation_steps=4,\n", " learning_rate=1e-5,\n", "\n", " # Logging\n", " logging_steps=10,\n", " save_steps=100,\n", " report_to=\"wandb\", # Optional: log to Weights & Biases\n", ")" ] }, { "cell_type": "markdown", "id": "36db4292", "metadata": {}, "source": [ "## Initialize Trainer" ] }, { "cell_type": "code", "execution_count": null, "id": "3058bd91", "metadata": {}, "outputs": [], "source": [ "trainer = GRPOTrainer(\n", " model=MODEL_ID,\n", " processing_class=tokenizer,\n", " reward_funcs=[\n", " reward_compilation,\n", " reward_correctness,\n", " reward_speedup,\n", " ],\n", " train_dataset=dataset,\n", " rollout_func=rollout_func,\n", " args=config,\n", ")" ] }, { "cell_type": "markdown", "id": "26d3cb0f", "metadata": {}, "source": [ "## Train!\n", "\n", "Start the training loop. The model will learn to write faster kernels through\n", "environment feedback." ] }, { "cell_type": "code", "execution_count": null, "id": "11157d97", "metadata": { "lines_to_next_cell": 1 }, "outputs": [], "source": [ "# Start training\n", "trainer.train()\n", "\n", "# Save the final model\n", "trainer.save_model(\"./kernrl_trained_model\")" ] }, { "cell_type": "markdown", "id": "4ee87425", "metadata": {}, "source": [ "## Evaluate the Trained Model\n", "\n", "Test the trained model on some problems to see how well it learned." 
] }, { "cell_type": "code", "execution_count": null, "id": "82ed4e39", "metadata": {}, "outputs": [], "source": [ "def evaluate_model(model_path: str, problem_ids: list[str]) -> list[dict]:\n", " \"\"\"Evaluate a trained model on kernel optimization problems.\"\"\"\n", " from transformers import AutoModelForCausalLM\n", "\n", " model = AutoModelForCausalLM.from_pretrained(model_path)\n", " model.to(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", " model.eval()\n", "\n", " results = []\n", "\n", " for problem_id in problem_ids:\n", " obs = env.reset(problem_id=problem_id)\n", "\n", " # Generate kernel code\n", " messages = [\n", " {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n", " {\"role\": \"user\", \"content\": make_prompt(obs.problem_description)},\n", " ]\n", " prompt = tokenizer.apply_chat_template(\n", " messages,\n", " add_generation_prompt=True,\n", " tokenize=False,\n", " )\n", "\n", " inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n", " with torch.no_grad():\n", " outputs = model.generate(\n", " **inputs,\n", " max_new_tokens=2048,\n", " temperature=0.3, # Lower temp for evaluation\n", " do_sample=True,\n", " )\n", "\n", " # Decode only the newly generated tokens; the raw output echoes the\n", " # prompt, and extract_code could otherwise grab the reference code\n", " new_tokens = outputs[0][inputs[\"input_ids\"].shape[1]:]\n", " completion = tokenizer.decode(new_tokens, skip_special_tokens=True)\n", " code = extract_code(completion)\n", "\n", " # Evaluate\n", " result = env.step(KernelAction(code=code))\n", " obs = result.observation\n", "\n", " results.append({\n", " \"problem_id\": problem_id,\n", " \"compilation\": obs.compilation_success,\n", " \"correctness\": obs.correctness_pass,\n", " \"speedup\": obs.speedup,\n", " })\n", "\n", " speedup_str = f\"{obs.speedup:.2f}x\" if obs.speedup else \"n/a\"\n", " print(f\"{problem_id}: compile={obs.compilation_success}, \"\n", " f\"correct={obs.correctness_pass}, speedup={speedup_str}\")\n", "\n", " return results\n", "\n", "# Evaluate on a few problems\n", "# eval_results = evaluate_model(\"./kernrl_trained_model\", [\"L1_23_Softmax\", \"L1_26_GELU_\"])" ] }, { "cell_type": "markdown", "id": "45d94da1", "metadata": {}, "source": [ "## Running with Server Mode (Multi-GPU)\n", "\n", "For larger models or faster training, run vLLM in server mode. Note that the kernrl environment already occupies port 8000, so the vLLM server is started on port 8001 here:\n", "\n", "```bash\n", "# Terminal 1: Start vLLM server\n", "CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-Coder-7B-Instruct --port 8001\n", "\n", "# Terminal 2: Start kernrl environment\n", "CUDA_VISIBLE_DEVICES=1 uvicorn kernrl.server.app:app --host 0.0.0.0 --port 8000\n", "\n", "# Terminal 3: Run training\n", "CUDA_VISIBLE_DEVICES=2 python train_kernrl.py --vllm-mode server --vllm-server-url http://localhost:8001\n", "```\n", "\n", "Update the config:\n", "```python\n", "config = GRPOConfig(\n", " use_vllm=True,\n", " vllm_mode=\"server\",\n", " vllm_server_base_url=\"http://localhost:8001\",\n", " ...\n", ")\n", "```" ] }, { "cell_type": "markdown", "id": "464e71b0", "metadata": {}, "source": [ "## Tips for Better Results\n", "\n", "1. **Start with simpler problems**: Level 1 problems (matmul, softmax) are easier\n", "2. **Use code-focused models**: Qwen2.5-Coder, DeepSeek-Coder work well\n", "3. **Increase generations**: More generations per prompt = better advantage estimates\n", "4. **Multi-turn training**: Let the model iterate based on feedback\n", "5. **Curriculum learning**: Start with L1, add harder problems gradually" ] }, { "cell_type": "markdown", "id": "2a03608e", "metadata": {}, "source": [ "## Resources\n", "\n", "- [kernrl HuggingFace Space](https://huggingface.co/spaces/Infatoshi/kernrl)\n", "- [OpenEnv Repository](https://github.com/meta-pytorch/OpenEnv)\n", "- [TRL Documentation](https://huggingface.co/docs/trl)\n", "- [Triton Tutorial](https://triton-lang.org/main/getting-started/tutorials/)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 5 }