# kernel/openenv_train.py
# Author: aaloksan — "fix: derives speedup now" (commit d5c6f39)
from env_server import KernelOptimization_env, TASKS
from trl import GRPOConfig, GRPOTrainer
from models import Action
from typing import List
from datasets import Dataset
import os
class KernelOptTool:
    """Tool-style wrapper exposing KernelOptimization_env to GRPO rollouts.

    Holds one environment instance plus the latest episode state
    (`reward`, `done`) so the trainer's reward function can read it.
    """

    def __init__(self):
        # One environment per tool instance; episode state mirrored locally.
        self.env = KernelOptimization_env()
        self.reward = 0.0
        self.done = False

    def reset(self, **kwargs) -> str:
        """Begin a new episode and return the task description prompt.

        `task_id` may arrive directly or nested under a `sample` dict
        (as the trainer passes dataset rows through).
        """
        task_id = kwargs.get("task_id")
        sample = kwargs.get("sample")
        if task_id is None and isinstance(sample, dict):
            task_id = sample.get("task_id")
        obs = self.env.reset(task_id=task_id)["observation"]
        self.reward = 0.0
        self.done = False
        parts = [
            f"Task: {obs['task_name']}",
            f"Baseline CUDA kernel:\n{obs['baseline_code']}",
            f"Pending checks: {obs['pending_checks']}",
            "Use tools to submit improved code.",
        ]
        return "\n".join(parts)

    def submit_optimization(self, optimized_code: str, strategy: str = "", expected_speedup: float | None = None) -> str:
        """Submit a candidate kernel and return a one-line progress summary.

        Raises:
            ValueError: if called after the episode has finished.
        """
        if self.done:
            raise ValueError("Episode is already done.")
        action = Action(
            optimized_code=optimized_code,
            strategy=strategy,
            expected_speedup=expected_speedup,
        )
        result = self.env.step(action)
        obs = result.observation
        # Cache terminal state so reward_func can collect it post-rollout.
        self.reward = result.reward.value
        self.done = result.done
        summary = (
            f"reward={result.reward.value:.4f}, "
            f"best_speedup={obs.current_best_speedup:.3f}x, "
            f"pending_checks={obs.pending_checks}, done={result.done}"
        )
        return summary

    # Backward-compatible alias (historical misspelling kept so existing callers work).
    def submit_optiization(self, optimized_code: str, strategy: str = "") -> str:
        return self.submit_optimization(optimized_code=optimized_code, strategy=strategy)
def reward_func(environments, **kwargs) -> List[float]:
    """Collect the final episode reward stored on each rollout environment."""
    rewards: List[float] = []
    for environment in environments:
        rewards.append(environment.reward)
    return rewards
def build_dataset(repeats_per_task: int = 32) -> Dataset:
    """Build the prompt dataset, repeating every kernel task `repeats_per_task` times.

    Each row carries a single-turn chat prompt plus the `task_id` the
    environment uses to select the kernel on reset.
    """
    rows = [
        (task_id, [{"role": "user", "content": f"Optimize CUDA kernel task: {task['name']}"}])
        for task_id, task in TASKS.items()
        for _ in range(repeats_per_task)
    ]
    task_ids = [task_id for task_id, _ in rows]
    prompts = [prompt for _, prompt in rows]
    return Dataset.from_dict({"prompt": prompts, "task_id": task_ids})
def main():
    """Run GRPO training of a small LLM against the kernel-optimization env.

    The base model is overridable via the TRAIN_MODEL environment variable
    so quick local runs and full jobs share this entry point.
    """
    model_name = os.getenv("TRAIN_MODEL", "Qwen/Qwen3-0.6B")
    dataset = build_dataset()
    trainer = GRPOTrainer(
        model=model_name,
        train_dataset=dataset,
        reward_funcs=reward_func,
        # One fresh KernelOptTool (and thus one env) per rollout.
        environment_factory=KernelOptTool,
        args=GRPOConfig(
            chat_template_kwargs={"enable_thinking": False},
            max_completion_length=2048,
            num_generations=4,
            log_completions=True,
        ),
    )
    trainer.train()


if __name__ == "__main__":
    main()