| from env_server import KernelOptimization_env, TASKS |
| from trl import GRPOConfig, GRPOTrainer |
| from models import Action |
| from typing import List |
| from datasets import Dataset |
| import os |
|
|
class KernelOptTool:
    """Per-rollout tool wrapper around ``KernelOptimization_env``.

    One instance is created per generation by the trainer (via
    ``environment_factory``).  It exposes the environment as chat tools and
    caches the most recent step's reward and done flag so ``reward_func``
    can read them after the rollout finishes.
    """

    def __init__(self):
        self.env = KernelOptimization_env()
        self.reward = 0.0  # reward from the most recent step (0.0 before any step)
        self.done = False  # True once the episode has terminated

    def reset(self, **kwargs) -> str:
        """Start a new episode and return a textual observation for the model.

        The task id may arrive either directly as ``task_id=...`` or nested
        inside a ``sample`` dict (the dataset row), so both are checked.

        Returns:
            A prompt string describing the task, the baseline CUDA kernel,
            and the pending correctness/performance checks.
        """
        task_id = kwargs.get("task_id")
        if task_id is None and isinstance(kwargs.get("sample"), dict):
            task_id = kwargs["sample"].get("task_id")
        result = self.env.reset(task_id=task_id)
        obs = result["observation"]
        # Clear episode state from any previous rollout.
        self.reward = 0.0
        self.done = False
        return (
            f"Task: {obs['task_name']}\n"
            f"Baseline CUDA kernel:\n{obs['baseline_code']}\n"
            f"Pending checks: {obs['pending_checks']}\n"
            "Use tools to submit improved code."
        )

    def submit_optimization(self, optimized_code: str, strategy: str = "", expected_speedup: float | None = None) -> str:
        """Submit an optimized kernel to the environment and summarize the step.

        Args:
            optimized_code: Full source of the candidate kernel.
            strategy: Optional free-text description of the optimization.
            expected_speedup: Optional self-predicted speedup multiplier.

        Returns:
            A one-line summary of reward, best speedup so far, remaining
            checks, and whether the episode is done.

        Raises:
            ValueError: If called after the episode has already finished.
        """
        if self.done:
            raise ValueError("Episode is already done.")
        result = self.env.step(
            Action(optimized_code=optimized_code, strategy=strategy, expected_speedup=expected_speedup)
        )
        self.reward = result.reward.value
        self.done = result.done
        obs = result.observation
        return (
            f"reward={result.reward.value:.4f}, "
            f"best_speedup={obs.current_best_speedup:.3f}x, "
            f"pending_checks={obs.pending_checks}, done={result.done}"
        )

    # DEPRECATED: misspelled alias of submit_optimization, kept only so any
    # caller or tool schema registered against the old name keeps working.
    # TODO(review): remove once no callers depend on the misspelling.
    def submit_optiization(self, optimized_code: str, strategy: str = "") -> str:
        return self.submit_optimization(optimized_code=optimized_code, strategy=strategy)
|
|
def reward_func(environments, **kwargs) -> List[float]:
    """Return the latest cached episode reward for each environment.

    The trainer calls this after rollouts complete; each ``KernelOptTool``
    instance holds its final step reward in ``.reward``.
    """
    rewards: List[float] = []
    for environment in environments:
        rewards.append(environment.reward)
    return rewards
|
|
def build_dataset(repeats_per_task: int = 32) -> Dataset:
    """Expand ``TASKS`` into a training Dataset.

    Each task contributes ``repeats_per_task`` identical rows, where a row
    pairs a single-turn chat prompt with the task's id.
    """
    rows = [
        (
            [{"role": "user", "content": f"Optimize CUDA kernel task: {task['name']}"}],
            task_id,
        )
        for task_id, task in TASKS.items()
        for _ in range(repeats_per_task)
    ]
    prompts = [prompt for prompt, _ in rows]
    task_ids = [tid for _, tid in rows]
    return Dataset.from_dict({"prompt": prompts, "task_id": task_ids})
|
|
def main():
    """Run GRPO training against the kernel-optimization environment."""
    model_name = os.getenv("TRAIN_MODEL", "Qwen/Qwen3-0.6B")
    config = GRPOConfig(
        chat_template_kwargs={"enable_thinking": False},
        max_completion_length=2048,
        num_generations=4,
        log_completions=True,
    )
    trainer = GRPOTrainer(
        model=model_name,
        train_dataset=build_dataset(),
        reward_funcs=reward_func,
        environment_factory=KernelOptTool,
        args=config,
    )
    trainer.train()
| |
|
|
| if __name__ == "__main__": |
| main() |
|
|