| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
|
|
| |
| |
| |
|
|
| |
| |
| |
| |
|
|
| |
| import torch |
| from datasets import Dataset |
| from transformers import AutoTokenizer |
| from trl import GRPOConfig, GRPOTrainer |
| from trl.experimental.openenv import generate_rollout_completions |
|
|
| |
| from kernrl import kernrl_env, KernelAction, KernelObservation |
|
|
| |
| |
# Model and environment configuration.
MODEL_ID = "Qwen/Qwen2.5-Coder-1.5B-Instruct"
ENV_URL = "http://localhost:8000"

# Load the tokenizer; fall back to EOS as the padding token because many
# causal LMs ship without a dedicated pad token.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
# Connect to the kernrl evaluation server (must already be running at ENV_URL).
env = kernrl_env(base_url=ENV_URL)
|
|
| |
| |
|
|
| |
| |
|
|
| |
# Smoke-test the connection by resetting to a known problem and printing
# the observation fields the environment exposes.
obs = env.reset(problem_id="L1_23_Softmax")
print(f"Problem: {obs.problem_id}")
print(f"GPU: {obs.gpu_info}")
print(f"Max turns: {obs.max_turns}")
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| import math |
|
|
def reward_compilation(completions: list[str], **kwargs) -> list[float]:
    """Give a small fixed bonus (0.1) for each kernel that compiled."""
    flags = kwargs.get("compilation_success", [])
    rewards: list[float] = []
    for flag in flags:
        rewards.append(0.1 if flag else 0.0)
    return rewards
|
|
def reward_correctness(completions: list[str], **kwargs) -> list[float]:
    """Give a fixed bonus (0.3) for each kernel whose output matched the reference."""
    flags = kwargs.get("correctness_pass", [])
    rewards: list[float] = []
    for flag in flags:
        rewards.append(0.3 if flag else 0.0)
    return rewards
|
|
def reward_speedup(completions: list[str], **kwargs) -> list[float]:
    """Reward scaled by the measured speedup over the reference implementation.

    Missing/invalid measurements score 0.0, slowdowns (speedup <= 1x) are
    penalized with -0.1, and genuine speedups earn 0.3 plus a log-scaled
    bonus capped at 0.6.
    """

    def _score(ratio) -> float:
        # No measurement (or nonsensical non-positive ratio): neutral.
        if ratio is None or ratio <= 0:
            return 0.0
        # At or below parity: small penalty for a non-improvement.
        if ratio <= 1.0:
            return -0.1
        # Log-scaled bonus so 2x -> +0.3, 4x -> +0.6 (cap).
        return 0.3 + min(0.3 * math.log2(ratio), 0.6)

    return [_score(ratio) for ratio in kwargs.get("speedup", [])]
|
|
def reward_combined(completions: list[str], **kwargs) -> list[float]:
    """Sum the compilation, correctness and speedup rewards element-wise."""
    parts = (
        reward_compilation(completions, **kwargs),
        reward_correctness(completions, **kwargs),
        reward_speedup(completions, **kwargs),
    )
    return [sum(values) for values in zip(*parts)]
|
|
| |
| |
| |
| |
|
|
| |
# System prompt steering the policy model toward GPU-kernel engineering:
# it fixes the output contract (a complete Python file exposing a Model
# class) that extract_code / the environment rely on downstream.
SYSTEM_PROMPT = """You are an expert GPU kernel engineer specializing in CUDA and Triton.

Your task is to optimize PyTorch operations by writing custom GPU kernels.

Guidelines:
1. Analyze the reference PyTorch implementation carefully
2. Identify optimization opportunities (memory access patterns, parallelism, fusion)
3. Write a Triton or CUDA kernel that computes the same result
4. Ensure numerical correctness (outputs must match within tolerance)

Output format:
- Provide a complete Python file
- Include a Model class with the same interface as the reference
- The Model.forward() method should use your optimized kernel
- Include all necessary imports (torch, triton, triton.language)

Focus on:
- Coalesced memory access
- Efficient use of shared memory
- Minimizing thread divergence
- Optimal block/grid dimensions"""
|
|
| |
| |
| |
| |
|
|
| |
def make_prompt(problem_description: str, feedback: str = "") -> str:
    """Build the user-turn prompt, optionally embedding prior-attempt feedback."""
    sections = [f"{problem_description}\n"]
    if feedback:
        sections.append(f"\n## Previous Attempt Feedback\n{feedback}\n")
    sections.append("\nProvide your optimized kernel implementation:")
    return "".join(sections)
|
|
def extract_code(completion: str) -> str:
    """Pull the code payload out of a fenced markdown block, if one exists.

    Prefers a ```python fence, then any bare ``` fence; falls back to the
    whole (stripped) completion when no properly closed fence is found.
    """
    for fence in ("```python", "```"):
        marker = completion.find(fence)
        if marker == -1:
            continue
        begin = marker + len(fence)
        close = completion.find("```", begin)
        if close > begin:
            return completion[begin:close].strip()
    return completion.strip()
|
|
def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]:
    """
    Custom rollout function for kernrl environment.

    Generates kernel code for each prompt via the trainer's generation
    backend, submits each extracted kernel to the environment, and returns
    the token-level rollout data plus the evaluation signals
    (compilation_success / correctness_pass / speedup) that the reward
    functions consume as kwargs.
    """
    # Generate completions (prompt/completion token ids + logprobs) via TRL.
    outputs = generate_rollout_completions(trainer, prompts)

    completions_text = [
        tokenizer.decode(out["completion_ids"], skip_special_tokens=True)
        for out in outputs
    ]

    # Per-completion evaluation signals, collected in completion order.
    compilation_success = []
    correctness_pass = []
    speedups = []

    for completion in completions_text:
        # NOTE(review): reset() is called without a problem_id, so the
        # environment chooses which problem this completion is evaluated
        # against — confirm it matches the problem the prompt was built from
        # (the training dataset carries a "problem_id" column).
        obs = env.reset()

        # Extract the kernel source from the (possibly markdown-fenced) completion.
        code = extract_code(completion)
        action = KernelAction(code=code)

        try:
            result = env.step(action)
            obs = result.observation

            compilation_success.append(obs.compilation_success)
            # correctness_pass may be falsy/None (e.g. when compilation failed);
            # coerce to a plain bool for the reward function.
            correctness_pass.append(obs.correctness_pass or False)
            speedups.append(obs.speedup)
        except Exception as e:
            # Any evaluation failure counts as a fully failed attempt rather
            # than aborting the whole rollout batch.
            print(f"Evaluation error: {e}")
            compilation_success.append(False)
            correctness_pass.append(False)
            speedups.append(None)

    return {
        # Required rollout fields for the GRPO loss.
        "prompt_ids": [out["prompt_ids"] for out in outputs],
        "completion_ids": [out["completion_ids"] for out in outputs],
        "logprobs": [out["logprobs"] for out in outputs],
        # Extra columns forwarded to the reward functions as kwargs.
        "compilation_success": compilation_success,
        "correctness_pass": correctness_pass,
        "speedup": speedups,
    }
|
|
| |
| |
| |
| |
|
|
| |
def create_dataset(env: kernrl_env, levels: list[int] | None = None) -> Dataset:
    """Create a training dataset of chat-formatted prompts from kernrl problems.

    Args:
        env: Connected kernrl environment used to enumerate problems and
            fetch their reference descriptions.
        levels: Difficulty levels to include, parsed from the "L<level>_"
            prefix of each problem id. Defaults to [1, 2]. (Default is None
            rather than a literal list to avoid the shared mutable-default
            pitfall.)

    Returns:
        Dataset with "prompt" (chat-templated text) and "problem_id" columns.
    """
    if levels is None:
        levels = [1, 2]
    wanted = set(levels)

    prompts = []
    problem_ids = []

    for problem_id in env.list_problems():
        # Problem ids look like "L1_23_Softmax"; skip ids that don't parse
        # instead of letting one malformed id crash the whole dataset build.
        try:
            level = int(problem_id.split("_")[0][1:])
        except (IndexError, ValueError):
            continue
        if level not in wanted:
            continue

        # Reset to this problem to obtain its reference description.
        obs = env.reset(problem_id=problem_id)

        # Render the system + user turns through the model's chat template,
        # keeping the text form so the trainer tokenizes it itself.
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": make_prompt(obs.problem_description)},
        ]
        prompt = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=False,
        )

        prompts.append(prompt)
        problem_ids.append(problem_id)

    return Dataset.from_dict({
        "prompt": prompts,
        "problem_id": problem_ids,
    })
|
|
| |
# Build the prompt dataset from level-1 and level-2 kernrl problems.
dataset = create_dataset(env, levels=[1, 2])
print(f"Created dataset with {len(dataset)} problems")
|
|
| |
| |
| |
| |
|
|
| |
| |
# GRPO training configuration.
config = GRPOConfig(
    output_dir="./kernrl_grpo_output",

    # vLLM generation colocated in the trainer process (no separate
    # inference server needed).
    use_vllm=True,
    vllm_mode="colocate",

    # Rollout sampling: 4 completions per prompt (the GRPO group),
    # up to 2048 new tokens each.
    num_generations=4,
    max_completion_length=2048,
    temperature=0.7,

    # Optimization schedule: 2 prompts per device step, accumulated x4.
    num_train_epochs=3,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    learning_rate=1e-5,

    # Logging / checkpointing.
    logging_steps=10,
    save_steps=100,
    report_to="wandb",
)
|
|
| |
| |
|
|
| |
# Build the GRPO trainer. The reward functions read the extra columns
# ("compilation_success", "correctness_pass", "speedup") that rollout_func
# returns alongside the token ids.
# NOTE(review): reward_combined is defined above but not registered here —
# confirm the intent is for the trainer to combine the three individual
# reward functions rather than use the combined variant.
trainer = GRPOTrainer(
    model=MODEL_ID,
    processing_class=tokenizer,
    reward_funcs=[
        reward_compilation,
        reward_correctness,
        reward_speedup,
    ],
    train_dataset=dataset,
    rollout_func=rollout_func,
    args=config,
)
|
|
| |
| |
| |
| |
| |
|
|
| |
| |
# Run GRPO training.
trainer.train()

# Persist the final policy weights for later evaluation / deployment.
trainer.save_model("./kernrl_trained_model")
|
|
| |
| |
| |
| |
|
|
| |
def evaluate_model(model_path: str, problem_ids: list[str]) -> list[dict]:
    """Evaluate a trained model on kernel optimization problems.

    Loads the checkpoint, generates one kernel per problem at low
    temperature, submits it to the kernrl environment, and collects the
    evaluation metrics.

    Args:
        model_path: Path of the fine-tuned checkpoint to load.
        problem_ids: kernrl problem ids to evaluate on.

    Returns:
        One dict per problem with "problem_id", "compilation",
        "correctness" and "speedup" keys.
    """
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(model_path)
    model.eval()

    results = []

    for problem_id in problem_ids:
        obs = env.reset(problem_id=problem_id)

        # Same prompt construction as training, rendered through the chat template.
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": make_prompt(obs.problem_description)},
        ]
        prompt = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=False,
        )

        # Keep inputs on the same device as the model (fixes a crash when
        # the checkpoint loads onto GPU while the tokens stay on CPU).
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=2048,
                temperature=0.3,
                do_sample=True,
            )

        # Decode only the newly generated tokens. generate() returns
        # prompt + completion for decoder-only models; decoding the full
        # sequence would let extract_code pick up code fences from the
        # prompt text instead of the model's answer.
        prompt_len = inputs["input_ids"].shape[1]
        completion = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        code = extract_code(completion)

        # Submit the kernel for compilation/correctness/benchmark evaluation.
        result = env.step(KernelAction(code=code))
        obs = result.observation

        results.append({
            "problem_id": problem_id,
            "compilation": obs.compilation_success,
            "correctness": obs.correctness_pass,
            "speedup": obs.speedup,
        })

        # Explicit branch (the original relied on a conditional expression
        # spanning implicitly concatenated f-strings, which was fragile).
        if obs.speedup:
            print(f"{problem_id}: compile={obs.compilation_success}, "
                  f"correct={obs.correctness_pass}, speedup={obs.speedup:.2f}x")
        else:
            print(f"{problem_id}: compile={obs.compilation_success}")

    return results
|
|
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
|
|