kernel_v1
Browse files- README.md +21 -13
- app.py +61 -0
- env_server.py +202 -0
- inference.py +73 -0
- models.py +49 -0
- openenv.yaml +55 -0
- openenv_train.py +70 -0
- pyproject.toml +24 -0
- requirements.txt +11 -0
- runtime.txt +1 -0
- server/__init__.py +1 -0
- server/__pycache__/__init__.cpython-312.pyc +0 -0
- server/__pycache__/app.cpython-312.pyc +0 -0
- server/app.py +10 -0
README.md
CHANGED
|
@@ -1,13 +1,21 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Kernel Writer
|
| 2 |
+
|
| 3 |
+
CUDA kernel optimization
|
| 4 |
+
|
| 5 |
+
## Run locally
|
| 6 |
+
|
| 7 |
+
```bash
|
| 8 |
+
pip install -r requirements.txt
|
| 9 |
+
python app.py
|
| 10 |
+
```
|
| 11 |
+
|
| 12 |
+
## Hugging Face Space setup
|
| 13 |
+
|
| 14 |
+
Set the OpenAI key in Space **Settings → Variables and secrets** as:
|
| 15 |
+
|
| 16 |
+
- `OPENAI_API_KEY`
|
| 17 |
+
|
| 18 |
+
Optional:
|
| 19 |
+
|
| 20 |
+
- `MODEL_NAME` (default: `gpt-4`)
|
| 21 |
+
- `API_BASE_URL` (default: `https://api.openai.com/v1`)
|
app.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from dotenv import load_dotenv
|
| 3 |
+
from typing import Iterator, Tuple
|
| 4 |
+
from env_server import KernelOptimization_env, TASKS
|
| 5 |
+
from openai import OpenAI
|
| 6 |
+
from models import Action
|
| 7 |
+
import gradio as gr
|
| 8 |
+
import traceback
|
| 9 |
+
|
| 10 |
+
load_dotenv()
|
| 11 |
+
|
| 12 |
+
def ui(task_id: str, max_steps: int, openai_api_key: str) -> Iterator[Tuple[str, str]]:
    """Gradio handler: run up to *max_steps* LLM optimization rounds, streaming (log, best_code).

    Uses the key typed into the textbox, falling back to the OPENAI_API_KEY env var.
    Yields after every step so the UI updates incrementally; on error, yields the
    traceback alongside the best code found so far and stops.
    """
    log = []
    env = KernelOptimization_env()
    api_key = openai_api_key or os.getenv("OPENAI_API_KEY")
    if not api_key:
        yield "ERROR: Missing OPENAI_API_KEY", ""
        return

    model = os.getenv("MODEL_NAME", "gpt-4")
    client = OpenAI(api_key=api_key, base_url=os.getenv("API_BASE_URL", "https://api.openai.com/v1"))
    obs = env.reset(task_id=task_id)["observation"]
    best_code = obs["current_best_code"]
    log.append(f"Task: {obs['task_name']}")

    # BUG FIX: gr.Slider may deliver a float; range() requires an int.
    for _ in range(int(max_steps)):
        try:
            prompt = f"Optimize CUDA code:\n{obs['current_best_code']}\nPending checks: {obs['pending_checks']}\nReturn code only."
            res = client.chat.completions.create(
                model=model,
                temperature=0.0,
                messages=[
                    {"role": "system", "content": "Return only optimized CUDA code."},
                    {"role": "user", "content": prompt},
                ],
            )
            # Fall back to the current best code when the model returns nothing.
            code = (res.choices[0].message.content or "").strip() or obs["current_best_code"]
            step = env.step(Action(optimized_code=code, strategy="ui_proposed"))
            obs = step.observation.model_dump()
            best_code = obs["current_best_code"]
            log.append(f"step={obs['step_count']} reward={step.reward.value:.3f} speedup={obs['current_best_speedup']:.3f}x")
            yield "\n".join(log), best_code
            if step.done:
                break
        except Exception as e:
            yield f"{chr(10).join(log)}\nERROR: {e}\n{traceback.format_exc()}", best_code
            return
|
| 48 |
+
|
| 49 |
+
# Gradio UI wiring. `ui` is a generator, so Gradio streams each yielded
# (logs, best_code) pair to the output components as the episode runs.
with gr.Blocks(title="CUDA Kernel Optimizer") as demo:
    gr.Markdown("CUDA Kernel Optimizer - OpenEnv-aligned workflow")
    # Inputs: task selector, optimization round budget, and user-supplied API key.
    task = gr.Dropdown(choices=list(TASKS.keys()), value="vector_add_easy", label="Task")
    steps = gr.Slider(minimum=1, maximum=12, value=6, step=1, label="Max Steps")
    key = gr.Textbox(label="OpenAI API Key", type="password", placeholder="sk-...")
    run = gr.Button("Run Optimization", variant="primary")
    # Outputs: streaming run log and the best kernel found so far.
    logs = gr.Textbox(label="Logs", lines=14)
    code = gr.Code(label="Best Code", language="cpp", lines=16)
    run.click(ui, inputs=[task, steps, key], outputs=[logs, code])


if __name__ == "__main__":
    # 0.0.0.0:7860 is the conventional binding for Hugging Face Spaces.
    demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
|
env_server.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional, Dict, Any
|
| 2 |
+
from models import Action, StepResult, ResetRequest, StepRequest, EnvState, Observation, Reward
|
| 3 |
+
from fastapi import FastAPI, HTTPException
|
| 4 |
+
import random
|
| 5 |
+
|
| 6 |
+
# Task catalogue keyed by task_id. Each entry carries display metadata, the
# unoptimized baseline kernel (runtime data — do not edit), and the named
# review checks the agent must satisfy (graded heuristically by check_passed).
TASKS: Dict[str, Dict[str, Any]] = {
    "vector_add_easy": {
        "name": "Vector Addition Kernel Optimization",
        "difficulty": "easy",
        "max_steps": 5,
        # Speedup at which the quality component of the reward saturates.
        "target_speedup": 1.8,
        "baseline_code": """extern "C" __global__ void vector_add(const float* a, const float* b, float* c, int n)
{
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) c[idx] = a[idx] + b[idx];
}""",
        "checks": {
            "coalesced_memory": "Use memory-coalesced indexing",
            "vectorized_loads": "Use vectorized loads/stores (float2/float4)",
            "bounds_safe": "Keep safe boundary checks",
        },

    },
    "matmul_medium": {
        "name": "Matrix Multiplication Kernel Optimization",
        "difficulty": "medium",
        "max_steps": 6,
        "target_speedup": 3.0,
        "baseline_code": """extern "C" __global__ void matmul(const float* A, const float* B, float* C, int N)
{
int row = blockIdx.y * blockDim.y + threadIdx.y;
int col = blockIdx.x * blockDim.x + threadIdx.x;
if (row < N && col < N) {
float sum = 0.0f;
for (int k = 0; k < N; k++) sum += A[row * N + k] * B[k * N + col];
C[row * N + col] = sum;
}
}""",
        "checks": {
            "shared_tiling": "Use shared-memory tiling",
            "synchronization": "Synchronize tiles with __syncthreads",
            "register_accumulation": "Accumulate partial sums in registers",
        },
    },
    "reduction_hard": {
        "name": "Reduction Kernel Optimization",
        "difficulty": "hard",
        "max_steps": 7,
        "target_speedup": 3.5,
        "baseline_code": """extern "C" __global__ void reduce_sum(const float* input, float* output, int n)
{
extern __shared__ float sdata[];
int tid = threadIdx.x;
int i = blockIdx.x * blockDim.x + threadIdx.x;
sdata[tid] = (i < n) ? input[i] : 0.0f;
__syncthreads();
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
if (tid < s) sdata[tid] += sdata[tid + s];
__syncthreads();
}
if (tid == 0) output[blockIdx.x] = sdata[0];
}""",
        "checks": {
            "warp_primitive": "Use warp-level primitive (e.g., __shfl_down_sync)",
            "bank_conflict_reduction": "Reduce shared-memory bank conflicts",
            "unrolled_reduction": "Use partial unrolling for final reduction",
        },
    }
}
|
| 70 |
+
|
| 71 |
+
def check_passed(check_id: str, code_lower: str) -> bool:
    """Return True when *code_lower* (lower-cased kernel source) satisfies the named review check.

    Purely lexical heuristics — substring matching, no compilation or parsing.
    Unknown check ids fail.
    """
    predicates = {
        "coalesced_memory": lambda s: "idx" in s and ("blockidx.x" in s or "threadidx.x" in s),
        "vectorized_loads": lambda s: "float4" in s or "float2" in s,
        "bounds_safe": lambda s: "if" in s and "< n" in s,
        "shared_tiling": lambda s: "__shared__" in s,
        "synchronization": lambda s: "__syncthreads" in s,
        "register_accumulation": lambda s: "sum" in s or "acc" in s,
        "warp_primitive": lambda s: "__shfl_down_sync" in s or "__shfl_sync" in s,
        "bank_conflict_reduction": lambda s: "pad" in s or "bank" in s or "+ 1" in s,
        "unrolled_reduction": lambda s: "#pragma unroll" in s or "unroll" in s,
    }
    predicate = predicates.get(check_id)
    if predicate is None:
        return False
    return predicate(code_lower)
|
| 91 |
+
|
| 92 |
+
def to_observation(task_id: str, state: EnvState) -> Observation:
    """Build the client-facing Observation for *task_id* from the current EnvState.

    `done` is True once every check is satisfied or the step budget is exhausted.
    """
    task = TASKS[task_id]
    satisfied = set(state.completed_checks)
    pending_descriptions = [desc for cid, desc in task["checks"].items() if cid not in satisfied]
    completed_descriptions = [task["checks"][cid] for cid in state.completed_checks if cid in task["checks"]]
    episode_over = (not pending_descriptions) or state.step_count >= state.max_steps
    return Observation(
        task_id=task_id,
        task_name=task["name"],
        difficulty=task["difficulty"],
        baseline_code=task["baseline_code"],
        current_best_code=state.best_code or task["baseline_code"],
        current_best_speedup=state.best_speedup,
        step_count=state.step_count,
        max_steps=state.max_steps,
        pending_checks=pending_descriptions,
        completed_checks=completed_descriptions,
        done=episode_over,
    )
|
| 96 |
+
|
| 97 |
+
def grade_episode(task_id: str, completed_checks: List[str], best_speedup: float, step_count: int, max_steps: int) -> float:
    """Deterministic episode score in [0, 1].

    Weighted blend: 50% check completion, 35% speedup relative to the task
    target (capped at 1), 15% step efficiency (fewer steps score higher).
    """
    task = TASKS[task_id]
    total_checks = max(len(task["checks"]), 1)
    completion = len(completed_checks) / total_checks
    speedup_score = min(best_speedup / task["target_speedup"], 1.0)
    # First step is "free": efficiency only decays from the second step on.
    steps_fraction = (step_count - 1) / max(max_steps, 1)
    efficiency = max(0.0, 1.0 - steps_fraction)
    blended = 0.5 * completion + 0.35 * speedup_score + 0.15 * efficiency
    return round(min(1.0, max(0.0, blended)), 4)
|
| 103 |
+
|
| 104 |
+
class KernelOptimization_env:
    """In-process OpenEnv-style environment for iterative CUDA kernel optimization.

    Lifecycle: reset() selects a task and seeds state from its baseline kernel;
    step() grades a proposed rewrite with deterministic lexical heuristics and
    shapes a reward; state_dict() exposes internals for the /state endpoint.
    """

    def __init__(self):
        self.state = EnvState(initialized=False)
        self.current_task_id: Optional[str] = None

    def _estimate_speedup(self, completed_checks: set, compile_ok: bool) -> float:
        """Deterministic speedup proxy in [1.0, target_speedup].

        Scales linearly with the fraction of review checks completed; code that
        fails the compile heuristic earns no credit. (Replaces the original
        line, which assigned a 3-tuple to est_speedup instead of calling an
        estimator — the subsequent float comparison raised TypeError.)
        """
        task = TASKS[self.current_task_id]
        total = len(task["checks"])
        if not compile_ok or total == 0:
            return 1.0
        return 1.0 + (task["target_speedup"] - 1.0) * (len(completed_checks) / total)

    def reset(self, task_id: Optional[str] = None) -> Dict[str, Any]:
        """Start a new episode. Unknown task_id -> HTTP 400; None -> random task."""
        if task_id and task_id not in TASKS:
            raise HTTPException(status_code=400, detail=f"unknown task_id: {task_id}")
        self.current_task_id = task_id or random.choice(list(TASKS.keys()))
        task = TASKS[self.current_task_id]
        self.state = EnvState(
            initialized=True,
            task_id=self.current_task_id,
            step_count=0,
            max_steps=task["max_steps"],
            total_reward=0.0,
            best_code=task["baseline_code"],
            best_speedup=1.0,
            completed_checks=[],
            action_history=[],
        )
        return {"observation": to_observation(self.current_task_id, self.state).model_dump()}

    def step(self, action: Action) -> StepResult:
        """Grade one proposed kernel rewrite; returns observation/reward/done/info.

        Raises HTTP 400 when called before reset().
        """
        if not self.state.initialized or not self.current_task_id:
            raise HTTPException(status_code=400, detail="Environment not initialized. Call /reset first.")

        self.state.step_count += 1
        code = action.optimized_code or ""
        code_lower = code.lower()
        # Cheap structural sanity check standing in for a real nvcc compile.
        compile_ok = "__global__" in code_lower and "{" in code and "}" in code

        completed = set(self.state.completed_checks)
        newly_completed = {
            cid
            for cid in TASKS[self.current_task_id]["checks"]
            if cid not in completed and check_passed(cid, code_lower)
        }
        completed.update(newly_completed)
        self.state.completed_checks = sorted(completed)

        # BUG FIX: was `est_speedup = self.current_task_id, completed, compile_ok`
        # (a tuple), which made the comparison below raise TypeError.
        est_speedup = self._estimate_speedup(completed, compile_ok)
        if est_speedup > self.state.best_speedup:
            self.state.best_speedup = est_speedup
            self.state.best_code = code

        # Reward shaping: progress on new checks + speedup quality - penalties.
        progress = 0.22 * len(newly_completed)
        quality = 0.18 * min(self.state.best_speedup / TASKS[self.current_task_id]["target_speedup"], 1.0)
        penalty = 0.0
        if not compile_ok:
            penalty -= 0.25
        if not newly_completed:
            penalty -= 0.08
        reward_value = max(0.0, min(1.0, progress + quality + penalty))
        self.state.total_reward += reward_value

        self.state.action_history.append(
            {
                "step": self.state.step_count,
                "newly_completed": sorted(newly_completed),
                "compile_ok": compile_ok,
                "estimated_speedup": est_speedup,
                "reward": reward_value,
            }
        )

        obs = to_observation(self.current_task_id, self.state)
        info: Dict[str, Any] = {"compile_ok": compile_ok, "estimated_speedup": est_speedup}
        if obs.done:
            info["final_score"] = grade_episode(
                self.current_task_id, self.state.completed_checks, self.state.best_speedup, self.state.step_count, self.state.max_steps
            )

        return StepResult(
            observation=obs,
            reward=Reward(
                value=round(reward_value, 4),
                components={"progress": round(progress, 4), "quality": round(quality, 4), "penalty": round(penalty, 4)},
            ),
            done=obs.done,
            info=info,
        )

    def state_dict(self) -> Dict[str, Any]:
        """Serializable snapshot of internal state plus grading metadata when a task is active."""
        data = self.state.model_dump()
        if self.current_task_id:
            data["task_name"] = TASKS[self.current_task_id]["name"]
            data["difficulty"] = TASKS[self.current_task_id]["difficulty"]
            data["grader_score"] = grade_episode(
                self.current_task_id, self.state.completed_checks, self.state.best_speedup, self.state.step_count, self.state.max_steps
            )
        return data
|
| 182 |
+
|
| 183 |
+
# Module-level singleton environment backing the HTTP endpoints below.
env = KernelOptimization_env()
app = FastAPI(title="Kernel Optimization", version="1.0.0")


@app.get("/")
def health_check():
    """Liveness probe."""
    return {"status": "healthy", "service": "kernel-optimization-openenv"}


@app.post("/reset")
def reset(request: ResetRequest = ResetRequest()):
    """Start a new episode; optional task_id pins the task (random otherwise)."""
    return env.reset(task_id=request.task_id)


@app.post("/step")
def step(request: StepRequest):
    """Apply one optimization action; returns observation/reward/done/info."""
    return env.step(request.action).model_dump()


@app.get("/state")
def state():
    """Expose full internal state plus grading metadata (debug/inspection)."""
    return env.state_dict()
|
inference.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from openai import OpenAI, AuthenticationError
|
| 3 |
+
from typing import Dict
|
| 4 |
+
from env_server import TASKS, KernelOptimization_env, grade_episode
|
| 5 |
+
from models import Action
|
| 6 |
+
import json
|
| 7 |
+
import sys
|
| 8 |
+
from dotenv import load_dotenv
|
| 9 |
+
|
| 10 |
+
load_dotenv()
|
| 11 |
+
def extract_code(text: str) -> str:
    """Extract the body of the outermost ``` fenced block from *text*.

    Strips a leading language tag (cuda/cpp/c++/c) when present. Returns *text*
    unchanged when there is no well-formed fence pair (BUG FIX: the original
    returned "" for a lone fence, and raised IndexError when the chunk was a
    bare language tag with no newline).
    """
    if "```" not in text:
        return text
    start = text.find("```")
    end = text.rfind("```")
    if end <= start:
        # Only one fence — nothing well-formed to extract.
        return text
    chunk = text[start + 3 : end]
    first_line, sep, rest = chunk.partition("\n")
    if sep and first_line.strip().lower() in {"cuda", "cpp", "c++", "c"}:
        return rest
    return chunk
|
| 20 |
+
|
| 21 |
+
def choose_action(client: OpenAI, model: str, observation: Dict) -> Action:
    """Ask the chat model for an improved kernel and wrap the reply in an Action.

    Falls back to the observation's current best code when the model returns
    nothing usable after fence extraction.
    """
    prompt = f"""Optimize this CUDA kernel.
Task: {observation['task_name']}
Pending checks: {observation['pending_checks']}
Baseline:
{observation['baseline_code']}
Current best speedup: {observation['current_best_speedup']}x
Return only optimized CUDA code.
"""
    messages = [
        {"role": "system", "content": "You are a CUDA optimization expert. Return code only."},
        {"role": "user", "content": prompt},
    ]
    completion = client.chat.completions.create(model=model, temperature=0.0, messages=messages)
    raw_reply = (completion.choices[0].message.content or "").strip()
    candidate = extract_code(raw_reply).strip()
    if not candidate:
        candidate = observation["current_best_code"]
    return Action(optimized_code=candidate, strategy="llm_proposed")
|
| 41 |
+
|
| 42 |
+
def run_task(client: OpenAI, model: str, task_id: str) -> float:
    """Roll out one full episode on *task_id* with the LLM policy; return its graded score."""
    environment = KernelOptimization_env()
    observation = environment.reset(task_id=task_id)["observation"]
    while True:
        proposal = choose_action(client, model, observation)
        outcome = environment.step(proposal)
        observation = outcome.observation.model_dump()
        if outcome.done:
            break
    final = environment.state
    return grade_episode(task_id, final.completed_checks, final.best_speedup, final.step_count, final.max_steps)
|
| 52 |
+
def main() -> int:
    """Run the rule-graded baseline over every task; return a process exit code.

    Reads OPENAI_API_KEY (required), MODEL_NAME and API_BASE_URL (optional)
    from the environment. NOTE(review): the defaults here (gemma-3-4b,
    api.oxlo.ai) differ from the README's documented defaults — confirm which
    is intended.
    """
    if not os.getenv("OPENAI_API_KEY"):
        print("openai key not set")
        # BUG FIX: previously fell through without returning and crashed when
        # constructing the client below with a missing key.
        return 1

    model = os.getenv("MODEL_NAME", "gemma-3-4b")
    client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"), base_url=os.getenv("API_BASE_URL", "https://api.oxlo.ai/v1"))

    scores: Dict[str, float] = {}
    try:
        for task_id in TASKS:
            scores[task_id] = run_task(client, model, task_id)
            print(f"[TASK] {task_id} score={scores[task_id]:.4f}")
    except AuthenticationError:
        print("ERROR: OpenAI authentication failed. Check OPENAI_API_KEY.", file=sys.stderr)
        return 1

    avg = sum(scores.values()) / len(scores)
    print(f"[BASELINE] model={model} average_score={avg:.4f}")
    print(json.dumps({"scores": scores, "average": round(avg, 4)}))
    return 0


if __name__ == "__main__":
    sys.exit(main())
|
models.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel, Field
|
| 2 |
+
from typing import Optional, Dict, Literal, List, Any
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class Action(BaseModel):
    """An agent-proposed kernel rewrite submitted to the environment."""
    optimized_code: str
    strategy: Optional[str] = None  # free-text label for the approach taken
    expected_speedup: Optional[float] = None  # agent's own estimate; not used by the grader


class Reward(BaseModel):
    """Shaped step reward, clamped to [0, 1], with its additive components."""
    value: float = Field(ge=0.0, le=1.0)
    components: Dict[str, float]  # e.g. progress / quality / penalty breakdown


class Observation(BaseModel):
    """Client-facing snapshot returned after reset/step."""
    task_id: str
    task_name: str
    difficulty: Literal["easy", "medium", "hard"]
    baseline_code: str  # unoptimized reference kernel
    current_best_code: str  # best-scoring submission so far (baseline if none)
    current_best_speedup: float
    step_count: int
    max_steps: int
    pending_checks: List[str]  # human-readable descriptions of unmet checks
    completed_checks: List[str]  # descriptions of checks already satisfied
    done: bool


class EnvState(BaseModel):
    """Full internal environment state (superset of Observation)."""
    initialized: bool
    task_id: Optional[str] = None
    step_count: int = 0
    max_steps: int = 0
    total_reward: float = 0.0
    best_code: str = ""
    best_speedup: float = 1.0  # baseline speedup is 1.0x by definition
    completed_checks: List[str] = Field(default_factory=list)  # check ids, not descriptions
    action_history: List[Dict[str, Any]] = Field(default_factory=list)  # per-step audit trail


class ResetRequest(BaseModel):
    """POST /reset payload; task_id=None selects a random task."""
    task_id: Optional[str] = None


class StepRequest(BaseModel):
    """POST /step payload."""
    action: Action


class StepResult(BaseModel):
    """Return value of a single environment step."""
    observation: Observation
    reward: Reward
    done: bool
    info: Dict[str, Any]  # extras: compile_ok, estimated_speedup, final_score on done
|
| 49 |
+
|
openenv.yaml
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: kernel_writer
|
| 2 |
+
version: 1.0.0
|
| 3 |
+
description: |
|
| 4 |
+
Real world CUDA kernel engineering environment for iterative optimization, code review checks and performance driven reward shaping.
|
| 5 |
+
|
| 6 |
+
environment:
|
| 7 |
+
type: code_optimization
|
| 8 |
+
runtime: python3.12.3
|
| 9 |
+
containerized: true
|
| 10 |
+
|
| 11 |
+
metadata:
|
| 12 |
+
tags:
|
| 13 |
+
- openenv
|
| 14 |
+
- CUDA
|
| 15 |
+
- kernel_optimization
|
| 16 |
+
- reinforcement_learning
|
| 17 |
+
author: aaloksan
|
| 18 |
+
|
| 19 |
+
tasks:
|
| 20 |
+
- id: vector_add_easy
|
| 21 |
+
name: "Vector Addition Kernel Optimization"
|
| 22 |
+
difficulty: easy
|
| 23 |
+
objective: "Improve memory throughput while preserving correctness."
|
| 24 |
+
grader: deterministic_rule_based
|
| 25 |
+
|
| 26 |
+
- id: matmul_medium
|
| 27 |
+
name: "Matrix Multiplication Kernel Optimization"
|
| 28 |
+
difficulty: medium
|
| 29 |
+
objective: "Apply shared-memory tiling and synchronization safely."
|
| 30 |
+
grader: deterministic_rule_based
|
| 31 |
+
|
| 32 |
+
- id: reduction_hard
|
| 33 |
+
name: "Reduction Kernel Optimization"
|
| 34 |
+
difficulty: hard
|
| 35 |
+
objective: "Use warp-level optimization and reduce memory conflicts."
|
| 36 |
+
grader: deterministic_rule_based
|
| 37 |
+
|
| 38 |
+
interfaces:
|
| 39 |
+
reset:
|
| 40 |
+
method: POST
|
| 41 |
+
path: /reset
|
| 42 |
+
returns: initial observation and info
|
| 43 |
+
step:
|
| 44 |
+
method: POST
|
| 45 |
+
path: /step
|
| 46 |
+
returns: observation, reward, done, info
|
| 47 |
+
state:
|
| 48 |
+
method: GET
|
| 49 |
+
path: /state
|
| 50 |
+
returns: current environment state
|
| 51 |
+
|
| 52 |
+
baseline:
|
| 53 |
+
script: inference.py
|
| 54 |
+
model_env_var: MODEL_NAME
|
| 55 |
+
api_key_env_var: OPENAI_API_KEY
|
openenv_train.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from env_server import KernelOptimization_env, TASKS
|
| 2 |
+
from trl import GRPOConfig, GRPOTrainer
|
| 3 |
+
from models import Action
|
| 4 |
+
from typing import List
|
| 5 |
+
from datasets import Dataset
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
class KernelOptTool:
    """Per-rollout tool wrapper around KernelOptimization_env for agentic training.

    Exposes `reset` and `submit_optimization` as tool calls and records the most
    recent step reward so `reward_func` can harvest it after the rollout.
    """

    def __init__(self):
        self.env = KernelOptimization_env()
        self.reward = 0.0  # reward of the most recent step
        self.done = False  # True once the episode has terminated

    def reset(self, **kwargs) -> str | None:
        """Start a new episode (optionally pinned via kwargs['task_id']); return the task briefing."""
        task_id = kwargs.get("task_id")
        result = self.env.reset(task_id=task_id)
        obs = result["observation"]
        self.reward = 0.0
        self.done = False
        return (
            f"Task: {obs['task_name']}\n"
            f"Baseline CUDA kernel:\n{obs['baseline_code']}\n"
            f"Pending checks: {obs['pending_checks']}\n"
            "Use tools to submit improved code."
        )

    def submit_optimization(self, optimized_code: str, strategy: str = "") -> str:
        """Submit a candidate kernel; returns a one-line progress summary.

        Raises ValueError when called after the episode has finished.
        """
        if self.done:
            raise ValueError("Episode is already done.")
        result = self.env.step(Action(optimized_code=optimized_code, strategy=strategy))
        self.reward = result.reward.value
        self.done = result.done
        obs = result.observation
        return (
            f"reward={result.reward.value:.4f}, "
            f"best_speedup={obs.current_best_speedup:.3f}x, "
            f"pending_checks={obs.pending_checks}, done={result.done}"
        )

    # Backward-compatible alias: the original method name carried a typo
    # ("submit_optiization") and existing tool schemas may still reference it.
    submit_optiization = submit_optimization
|
| 39 |
+
|
| 40 |
+
def reward_func(environments, **kwargs) -> List[float]:
    """Harvest the latest per-environment step reward for the trainer.

    Fixes the original misspelled parameter name (`environmnets`), which would
    break any framework caller passing the environments by keyword; positional
    callers are unaffected.
    """
    return [environment.reward for environment in environments]
|
| 42 |
+
|
| 43 |
+
def build_dataset(repeats_per_task: int = 32) -> Dataset:
    """Build the GRPO prompt dataset: one chat-style row per (task, repeat).

    The extra 'task_id' column lets the environment reset to the matching task.
    """
    rows = []
    for task_id, task in TASKS.items():
        message = [{"role": "user", "content": f"Optimize CUDA kernel task: {task['name']}"}]
        rows.extend((message, task_id) for _ in range(repeats_per_task))
    prompts = [prompt for prompt, _ in rows]
    task_ids = [tid for _, tid in rows]
    return Dataset.from_dict({"prompt": prompts, "task_id": task_ids})
|
| 50 |
+
|
| 51 |
+
def main():
    """Launch GRPO training of TRAIN_MODEL (default Qwen/Qwen3-0.6B) against the kernel env."""
    model_name = os.getenv("TRAIN_MODEL", "Qwen/Qwen3-0.6B")
    dataset = build_dataset()
    trainer = GRPOTrainer(
        model=model_name,
        train_dataset=dataset,
        # Rewards are read back from each rollout's KernelOptTool instance.
        reward_funcs=reward_func,
        # NOTE(review): `environment_factory` and `chat_template_kwargs` assume a
        # trl release with agentic-environment support — confirm against the
        # installed trl version before running.
        environment_factory=KernelOptTool,
        args=GRPOConfig(
            chat_template_kwargs={"enable_thinking": False},
            max_completion_length=2048,
            num_generations=4,
            log_completions=True,
        ),
    )
    trainer.train()


if __name__ == "__main__":
    main()
|
pyproject.toml
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=68", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "kernel_writer"
|
| 7 |
+
version = "1.0.0"
|
| 8 |
+
description = "OpenEnv-compatible CUDA kernel optimization environment."
|
| 9 |
+
readme = "README.md"
|
| 10 |
+
requires-python = ">=3.10"
|
| 11 |
+
dependencies = [
|
| 12 |
+
"fastapi>=0.110.0",
|
| 13 |
+
"uvicorn>=0.30.0",
|
| 14 |
+
"pydantic>=2.7.0",
|
| 15 |
+
"openai>=1.0.0",
|
| 16 |
+
"openenv-core>=0.2.0",
|
| 17 |
+
"python-dotenv>=1.0.0",
|
| 18 |
+
"gradio>=4.44.0",
|
| 19 |
+
"datasets>=2.20.0",
|
| 20 |
+
"trl>=0.12.0"
|
| 21 |
+
]
|
| 22 |
+
|
| 23 |
+
[project.scripts]
|
| 24 |
+
server = "server.app:main"
|
requirements.txt
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi>=0.110.0
|
| 2 |
+
uvicorn>=0.30.0
|
| 3 |
+
pydantic>=2.7.0
|
| 4 |
+
openai>=1.0.0
|
| 5 |
+
openenv-core>=0.2.0
|
| 6 |
+
gradio>=4.44.0
|
| 7 |
+
datasets>=2.20.0
|
| 8 |
+
trl>=0.12.0
|
| 9 |
+
pytest>=7.4.0
|
| 10 |
+
pyyaml>=6.0.0
|
| 11 |
+
python-dotenv
|
runtime.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
python-3.10
|
server/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .app import app
|
server/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (173 Bytes). View file
|
|
|
server/__pycache__/app.cpython-312.pyc
ADDED
|
Binary file (595 Bytes). View file
|
|
|
server/app.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os

import uvicorn

# Re-exported so the "server.app:app" import string below resolves to the
# FastAPI instance defined in env_server.
from env_server import app


def main():
    """Entry point: serve the environment API on 0.0.0.0 ($PORT, default 7860)."""
    port = int(os.getenv("PORT", "7860"))
    uvicorn.run("server.app:app", host="0.0.0.0", port=port)


if __name__ == "__main__":
    main()
|