aaloksan commited on
Commit
c780f59
·
1 Parent(s): 0108333

kernel_v1

Browse files
README.md CHANGED
@@ -1,13 +1,21 @@
1
- ---
2
- title: Kernel
3
- emoji: 🏆
4
- colorFrom: gray
5
- colorTo: red
6
- sdk: gradio
7
- sdk_version: 6.12.0
8
- app_file: app.py
9
- pinned: false
10
- short_description: optimizes kernel code
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
1
+ # Kernel Writer
2
+
3
+ CUDA kernel optimization
4
+
5
+ ## Run locally
6
+
7
+ ```bash
8
+ pip install -r requirements.txt
9
+ python app.py
10
+ ```
11
+
12
+ ## Hugging Face Space setup
13
+
14
+ Set the OpenAI key in Space **Settings → Variables and secrets** as:
15
+
16
+ - `OPENAI_API_KEY`
17
+
18
+ Optional:
19
+
20
+ - `MODEL_NAME` (default: `gpt-4`)
21
+ - `API_BASE_URL` (default: `https://api.openai.com/v1`)
app.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dotenv import load_dotenv
3
+ from typing import Iterator, Tuple
4
+ from env_server import KernelOptimization_env, TASKS
5
+ from openai import OpenAI
6
+ from models import Action
7
+ import gradio as gr
8
+ import traceback
9
+
10
+ load_dotenv()
11
+
12
def ui(task_id: str, max_steps: int, openai_api_key: str) -> Iterator[Tuple[str, str]]:
    """Run the LLM-driven optimization loop for one task, streaming progress.

    Yields ``(log_text, best_code)`` tuples so Gradio can refresh the UI
    after every environment step.

    Args:
        task_id: Key into ``TASKS`` selecting the kernel task.
        max_steps: Maximum optimization steps. Comes from a Gradio slider,
            which may deliver a float — cast before use.
        openai_api_key: Key typed into the UI; falls back to the
            ``OPENAI_API_KEY`` environment variable when empty.
    """
    log = []
    env = KernelOptimization_env()
    api_key = openai_api_key or os.getenv("OPENAI_API_KEY")
    if not api_key:
        yield "ERROR: Missing OPENAI_API_KEY", ""
        return

    model = os.getenv("MODEL_NAME", "gpt-4")
    client = OpenAI(api_key=api_key, base_url=os.getenv("API_BASE_URL", "https://api.openai.com/v1"))
    obs = env.reset(task_id=task_id)["observation"]
    best_code = obs["current_best_code"]
    log.append(f"Task: {obs['task_name']}")

    # BUG FIX: Gradio sliders can emit floats; range() requires an int.
    for _ in range(int(max_steps)):
        try:
            prompt = f"Optimize CUDA code:\n{obs['current_best_code']}\nPending checks: {obs['pending_checks']}\nReturn code only."
            res = client.chat.completions.create(
                model=model,
                temperature=0.0,
                messages=[
                    {"role": "system", "content": "Return only optimized CUDA code."},
                    {"role": "user", "content": prompt},
                ],
            )
            # Fall back to the current best code when the model returns nothing,
            # so the environment always receives a non-empty submission.
            code = (res.choices[0].message.content or "").strip() or obs["current_best_code"]
            step = env.step(Action(optimized_code=code, strategy="ui_proposed"))
            obs = step.observation.model_dump()
            best_code = obs["current_best_code"]
            log.append(f"step={obs['step_count']} reward={step.reward.value:.3f} speedup={obs['current_best_speedup']:.3f}x")
            yield "\n".join(log), best_code
            if step.done:
                break
        except Exception as e:
            # Surface the traceback in the UI instead of killing the stream.
            joined = "\n".join(log)
            yield f"{joined}\nERROR: {e}\n{traceback.format_exc()}", best_code
            return
48
+
49
# Gradio UI wiring: one button drives the generator `ui`, which streams
# (logs, best_code) updates into the two output widgets.
with gr.Blocks(title="CUDA Kernel Optimizer") as demo:
    gr.Markdown("CUDA Kernel Optimizer - OpenEnv-aligned workflow")
    # Default task must match a key in env_server.TASKS.
    task = gr.Dropdown(choices=list(TASKS.keys()), value="vector_add_easy", label="Task")
    steps = gr.Slider(minimum=1, maximum=12, value=6, step=1, label="Max Steps")
    key = gr.Textbox(label="OpenAI API Key", type="password", placeholder="sk-...")
    run = gr.Button("Run Optimization", variant="primary")
    logs = gr.Textbox(label="Logs", lines=14)
    code = gr.Code(label="Best Code", language="cpp", lines=16)
    # `ui` is a generator, so clicking streams incremental updates.
    run.click(ui, inputs=[task, steps, key], outputs=[logs, code])


if __name__ == "__main__":
    # 0.0.0.0:7860 is the standard Hugging Face Spaces binding.
    demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
env_server.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Dict, Any
2
+ from models import Action, StepResult, ResetRequest, StepRequest, EnvState, Observation, Reward
3
+ from fastapi import FastAPI, HTTPException
4
+ import random
5
+
6
# Task registry. Each entry describes one CUDA optimization exercise:
#   name / difficulty  - display metadata (difficulty must match the
#                        Literal in models.Observation)
#   max_steps          - episode length budget
#   target_speedup     - speedup at which the quality score saturates
#   baseline_code      - the unoptimized kernel shown to the agent
#   checks             - check_id -> human-readable description; check_ids
#                        are graded by the substring heuristics in
#                        check_passed()
TASKS: Dict[str, Dict[str, Any]] ={
    "vector_add_easy": {
        "name": "Vector Addition Kernel Optimization",
        "difficulty": "easy",
        "max_steps": 5,
        "target_speedup": 1.8,
        "baseline_code": """extern "C" __global__ void vector_add(const float* a, const float* b, float* c, int n)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) c[idx] = a[idx] + b[idx];
}""",
        "checks": {
            "coalesced_memory": "Use memory-coalesced indexing",
            "vectorized_loads": "Use vectorized loads/stores (float2/float4)",
            "bounds_safe": "Keep safe boundary checks",
        },

    },
    "matmul_medium": {
        "name": "Matrix Multiplication Kernel Optimization",
        "difficulty": "medium",
        "max_steps": 6,
        "target_speedup": 3.0,
        "baseline_code": """extern "C" __global__ void matmul(const float* A, const float* B, float* C, int N)
{
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (row < N && col < N) {
        float sum = 0.0f;
        for (int k = 0; k < N; k++) sum += A[row * N + k] * B[k * N + col];
        C[row * N + col] = sum;
    }
}""",
        "checks": {
            "shared_tiling": "Use shared-memory tiling",
            "synchronization": "Synchronize tiles with __syncthreads",
            "register_accumulation": "Accumulate partial sums in registers",
        },
    },
    "reduction_hard": {
        "name": "Reduction Kernel Optimization",
        "difficulty": "hard",
        "max_steps":7,
        "target_speedup": 3.5,
        "baseline_code": """extern "C" __global__ void reduce_sum(const float* input, float* output, int n)
{
    extern __shared__ float sdata[];
    int tid = threadIdx.x;
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    sdata[tid] = (i < n) ? input[i] : 0.0f;
    __syncthreads();
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) sdata[tid] += sdata[tid + s];
        __syncthreads();
    }
    if (tid == 0) output[blockIdx.x] = sdata[0];
}""",
        "checks": {
            "warp_primitive": "Use warp-level primitive (e.g., __shfl_down_sync)",
            "bank_conflict_reduction": "Reduce shared-memory bank conflicts",
            "unrolled_reduction": "Use partial unrolling for final reduction",
        },
    }
}
70
+
71
def check_passed(check_id: str, code_lower: str) -> bool:
    """Heuristically decide whether *code_lower* satisfies a review check.

    Grading is pure substring matching on the lowercased source — no
    compilation or execution happens. Unknown check ids are never passed.
    """
    # Two checks need compound conditions and are handled explicitly.
    if check_id == "coalesced_memory":
        return "idx" in code_lower and any(
            token in code_lower for token in ("blockidx.x", "threadidx.x")
        )
    if check_id == "bounds_safe":
        return "if" in code_lower and "< n" in code_lower

    # Everything else passes when any of its marker substrings appears.
    markers = {
        "vectorized_loads": ("float4", "float2"),
        "shared_tiling": ("__shared__",),
        "synchronization": ("__syncthreads",),
        "register_accumulation": ("sum", "acc"),
        "warp_primitive": ("__shfl_down_sync", "__shfl_sync"),
        "bank_conflict_reduction": ("pad", "bank", "+ 1"),
        "unrolled_reduction": ("#pragma unroll", "unroll"),
    }
    return any(token in code_lower for token in markers.get(check_id, ()))
91
+
92
def to_observation(task_id: str, state: EnvState) -> Observation:
    """Project the internal EnvState into the client-facing Observation.

    The episode is done when no checks remain pending or the step budget
    is exhausted.
    """
    task = TASKS[task_id]
    finished_ids = set(state.completed_checks)
    pending = [desc for cid, desc in task["checks"].items() if cid not in finished_ids]
    completed = [task["checks"][cid] for cid in state.completed_checks if cid in task["checks"]]
    episode_over = not pending or state.step_count >= state.max_steps
    return Observation(
        task_id=task_id,
        task_name=task["name"],
        difficulty=task["difficulty"],
        baseline_code=task["baseline_code"],
        current_best_code=state.best_code or task["baseline_code"],
        current_best_speedup=state.best_speedup,
        step_count=state.step_count,
        max_steps=state.max_steps,
        pending_checks=pending,
        completed_checks=completed,
        done=episode_over,
    )
96
+
97
def grade_episode(task_id: str, completed_checks: List[str], best_speedup: float, step_count: int, max_steps: int) -> float:
    """Score a finished episode in [0, 1].

    Weighted blend: 50% check completion, 35% speedup relative to the
    task target (capped at 1.0), 15% step efficiency. Clamped and rounded
    to 4 decimal places.
    """
    spec = TASKS[task_id]
    check_total = max(len(spec["checks"]), 1)
    completion = len(completed_checks) / check_total
    speedup_score = min(best_speedup / spec["target_speedup"], 1.0)
    efficiency = max(0.0, 1.0 - (step_count - 1) / max(max_steps, 1))
    blended = 0.5 * completion + 0.35 * speedup_score + 0.15 * efficiency
    return round(min(1.0, max(0.0, blended)), 4)
103
+
104
class KernelOptimization_env:
    """In-process environment for iterative CUDA kernel optimization.

    Grading is fully deterministic and heuristic (see ``check_passed``);
    no kernel is ever compiled or benchmarked.
    """

    def __init__(self):
        # `initialized=False` until reset() starts an episode.
        self.state = EnvState(initialized=False)
        self.current_task_id: Optional[str] = None

    def _estimate_speedup(self, completed, compile_ok: bool) -> float:
        """Heuristic speedup estimate for the current submission.

        Interpolates linearly from 1.0x toward the task's target_speedup
        as more checks pass; code that fails the structural compile check
        stays at the 1.0x baseline.
        """
        if not compile_ok:
            return 1.0
        task = TASKS[self.current_task_id]
        fraction = len(completed) / max(len(task["checks"]), 1)
        return round(1.0 + fraction * (task["target_speedup"] - 1.0), 4)

    def reset(self, task_id: Optional[str] = None) -> Dict[str, Any]:
        """Start a new episode; picks a random task when task_id is None.

        Raises:
            HTTPException(400): for an unknown task_id.
        """
        if task_id and task_id not in TASKS:
            raise HTTPException(status_code=400, detail=f"unknown task_id: {task_id}")
        self.current_task_id = task_id or random.choice(list(TASKS.keys()))
        task = TASKS[self.current_task_id]
        self.state = EnvState(
            initialized=True,
            task_id=self.current_task_id,
            step_count=0,
            max_steps=task["max_steps"],
            total_reward=0.0,
            best_code=task["baseline_code"],
            best_speedup=1.0,
            completed_checks=[],
            action_history=[],
        )
        return {"observation": to_observation(self.current_task_id, self.state).model_dump()}

    def step(self, action: Action) -> StepResult:
        """Apply one proposed optimization and return the StepResult.

        Raises:
            HTTPException(400): when called before reset().
        """
        if not self.state.initialized or not self.current_task_id:
            raise HTTPException(status_code=400, detail="Environment not initialized. Call /reset first.")

        self.state.step_count += 1
        code = action.optimized_code or ""
        code_lower = code.lower()
        # Cheap structural sanity check standing in for real compilation.
        compile_ok = "__global__" in code_lower and "{" in code and "}" in code

        completed = set(self.state.completed_checks)
        newly_completed = {cid for cid in TASKS[self.current_task_id]["checks"] if cid not in completed and check_passed(cid, code_lower)}
        completed.update(newly_completed)
        self.state.completed_checks = sorted(completed)

        # BUG FIX: this line previously read
        #   est_speedup = self.current_task_id, completed, compile_ok
        # which bound a *tuple*, so the float comparison below raised
        # TypeError on every step. Compute a real estimate instead.
        est_speedup = self._estimate_speedup(completed, compile_ok)
        if est_speedup > self.state.best_speedup:
            self.state.best_speedup = est_speedup
            self.state.best_code = code

        # Reward shaping: progress for new checks, quality for speedup,
        # penalties for non-compiling or stagnant submissions.
        progress = 0.22 * len(newly_completed)
        quality = 0.18 * min(self.state.best_speedup / TASKS[self.current_task_id]["target_speedup"], 1.0)
        penalty = 0.0
        if not compile_ok:
            penalty -= 0.25
        if not newly_completed:
            penalty -= 0.08
        reward_value = max(0.0, min(1.0, progress + quality + penalty))
        self.state.total_reward += reward_value

        self.state.action_history.append(
            {
                "step": self.state.step_count,
                "newly_completed": sorted(newly_completed),
                "compile_ok": compile_ok,
                "estimated_speedup": est_speedup,
                "reward": reward_value,
            }
        )

        obs = to_observation(self.current_task_id, self.state)
        info: Dict[str, Any] = {"compile_ok": compile_ok, "estimated_speedup": est_speedup}
        if obs.done:
            info["final_score"] = grade_episode(
                self.current_task_id, self.state.completed_checks, self.state.best_speedup, self.state.step_count, self.state.max_steps
            )

        return StepResult(
            observation=obs,
            reward=Reward(
                value=round(reward_value, 4),
                components={"progress": round(progress, 4), "quality": round(quality, 4), "penalty": round(penalty, 4)},
            ),
            done=obs.done,
            info=info,
        )

    def state_dict(self) -> Dict[str, Any]:
        """Snapshot of the raw state plus task metadata and current grade."""
        data = self.state.model_dump()
        if self.current_task_id:
            data["task_name"] = TASKS[self.current_task_id]["name"]
            data["difficulty"] = TASKS[self.current_task_id]["difficulty"]
            data["grader_score"] = grade_episode(
                self.current_task_id, self.state.completed_checks, self.state.best_speedup, self.state.step_count, self.state.max_steps
            )
        return data
182
+
183
# Module-level singleton: every HTTP request shares one episode's state,
# so concurrent clients would interleave — acceptable for a demo Space.
env=KernelOptimization_env()
app=FastAPI(title="Kernel Optimization", version="1.0.0")

@app.get("/")
def health_check():
    """Liveness probe for the Space/container."""
    return {"status": "healthy", "service": "kernel-optimization-openenv"}

@app.post("/reset")
def reset(request: ResetRequest = ResetRequest()):
    """Start a new episode; body may name a task_id or omit it for a random task."""
    return env.reset(task_id=request.task_id)


@app.post("/step")
def step(request: StepRequest):
    """Apply one optimization action to the current episode."""
    return env.step(request.action).model_dump()


@app.get("/state")
def state():
    """Return the raw environment state snapshot (plus grade metadata)."""
    return env.state_dict()
inference.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from openai import OpenAI, AuthenticationError
3
+ from typing import Dict
4
+ from env_server import TASKS, KernelOptimization_env, grade_episode
5
+ from models import Action
6
+ import json
7
+ import sys
8
+ from dotenv import load_dotenv
9
+
10
+ load_dotenv()
11
def extract_code(text: str) -> str:
    """Extract the code body from a markdown-fenced LLM reply.

    Handles three shapes: no fences (returned verbatim), a complete
    fenced block (contents returned, language tag stripped), and an
    unterminated fence (everything after the opening fence returned).

    BUG FIX: the original computed text[start+3:rfind] which returned an
    empty string when only one fence was present, and indexed [1] after a
    split that could yield a single element.
    """
    fence = "```"
    if fence not in text:
        return text
    start = text.find(fence) + len(fence)
    end = text.rfind(fence)
    # Single (unterminated) fence: keep everything after it.
    chunk = text[start:end] if end > start else text[start:]
    first_line, sep, rest = chunk.partition("\n")
    # Drop a language tag (e.g. ```cuda / ```cpp) on the fence line.
    if sep and first_line.strip().lower() in {"cuda", "cpp", "c++", "c", "cu"}:
        return rest
    return chunk
20
+
21
def choose_action(client: OpenAI, model: str, observation: Dict) -> Action:
    """Ask the LLM for an improved kernel and wrap the reply as an Action.

    Falls back to the observation's current best code when the model reply
    is empty, so the environment always receives non-empty input.
    """
    prompt = f"""Optimize this CUDA kernel.
Task: {observation['task_name']}
Pending checks: {observation['pending_checks']}
Baseline:
{observation['baseline_code']}
Current best speedup: {observation['current_best_speedup']}x
Return only optimized CUDA code.
"""
    response = client.chat.completions.create(
        model=model,
        temperature=0.0,  # deterministic decoding for a reproducible baseline
        messages=[
            {"role": "system", "content": "You are a CUDA optimization expert. Return code only."},
            {"role": "user", "content": prompt},
        ],
    )
    text = (response.choices[0].message.content or "").strip()
    # Strip markdown fences; empty result falls back to the current best.
    code = extract_code(text).strip() or observation["current_best_code"]
    return Action(optimized_code=code, strategy="llm_proposed")
41
+
42
def run_task(client: OpenAI, model: str, task_id: str) -> float:
    """Play one full episode of *task_id* with the LLM and return its grade."""
    environment = KernelOptimization_env()
    observation = environment.reset(task_id=task_id)["observation"]
    finished = False
    while not finished:
        proposal = choose_action(client, model, observation)
        outcome = environment.step(proposal)
        observation = outcome.observation.model_dump()
        finished = outcome.done
    final = environment.state
    return grade_episode(task_id, final.completed_checks, final.best_speedup, final.step_count, final.max_steps)
52
def main() -> int:
    """Run the rule-graded baseline over every task and print per-task and
    average scores (plus a machine-readable JSON line).

    Returns a process exit code: 0 on success, 1 on missing/invalid key.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    # BUG FIX: the original only printed a warning and fell through,
    # constructing the client without a key and failing confusingly later.
    if not api_key:
        print("ERROR: OPENAI_API_KEY is not set.", file=sys.stderr)
        return 1

    model = os.getenv("MODEL_NAME", "gemma-3-4b")
    client = OpenAI(api_key=api_key, base_url=os.getenv("API_BASE_URL", "https://api.oxlo.ai/v1"))

    scores: Dict[str, float] = {}
    try:
        for task_id in TASKS:
            scores[task_id] = run_task(client, model, task_id)
            print(f"[TASK] {task_id} score={scores[task_id]:.4f}")
    except AuthenticationError:
        print("ERROR: OpenAI authentication failed. Check OPENAI_API_KEY.", file=sys.stderr)
        return 1

    avg = sum(scores.values()) / len(scores)
    print(f"[BASELINE] model={model} average_score={avg:.4f}")
    print(json.dumps({"scores": scores, "average": round(avg, 4)}))
    return 0


if __name__ == "__main__":
    sys.exit(main())
models.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, Field
2
+ from typing import Optional, Dict, Literal, List, Any
3
+
4
+
5
+ class Action(BaseModel):
6
+ optimized_code: str
7
+ strategy: Optional[str] = None
8
+ expected_speedup: Optional[float] = None
9
+
10
+ class Reward(BaseModel):
11
+ value: float = Field(ge=0.0, le=1.0)
12
+ components: Dict[str, float]
13
+
14
+ class Observation(BaseModel):
15
+ task_id: str
16
+ task_name: str
17
+ difficulty: Literal["easy", "medium", "hard"]
18
+ baseline_code: str
19
+ current_best_code: str
20
+ current_best_speedup: float
21
+ step_count: int
22
+ max_steps: int
23
+ pending_checks: List[str]
24
+ completed_checks: List[str]
25
+ done: bool
26
+
27
+ class EnvState(BaseModel):
28
+ initialized: bool
29
+ task_id: Optional[str] =None
30
+ step_count: int = 0
31
+ max_steps: int = 0
32
+ total_reward: float = 0.0
33
+ best_code: str = ""
34
+ best_speedup: float = 1.0
35
+ completed_checks: List[str] = Field(default_factory=list)
36
+ action_history: List[Dict[str, Any]] = Field(default_factory=list)
37
+
38
+ class ResetRequest(BaseModel):
39
+ task_id: Optional[str] = None
40
+
41
+ class StepRequest(BaseModel):
42
+ action: Action
43
+
44
+ class StepResult(BaseModel):
45
+ observation:Observation
46
+ reward: Reward
47
+ done: bool
48
+ info: Dict[str, Any]
49
+
openenv.yaml ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: kernel_writer
2
+ version: 1.0.0
3
+ description: |
4
+ Real world CUDA kernel engineering environment for iterative optimization, code review checks and performance driven reward shaping.
5
+
6
+ environment:
7
+ type: code_optimization
8
+ runtime: python3.12.3
9
+ containerized: true
10
+
11
+ metadata:
12
+ tags:
13
+ - openenv
14
+ - CUDA
15
+ - kernel_optimization
16
+ - reinforcement_learning
17
+ author: aaloksan
18
+
19
+ tasks:
20
  - id: vector_add_easy
21
+ name: "Vector Addition Kernel Optimization"
22
+ difficulty: easy
23
+ objective: "Improve memory throughput while preserving correctness."
24
+ grader: deterministic_rule_based
25
+
26
+ - id: matmul_medium
27
+ name: "Matrix Multiplication Kernel Optimization"
28
+ difficulty: medium
29
+ objective: "Apply shared-memory tiling and synchronization safely."
30
+ grader: deterministic_rule_based
31
+
32
+ - id: reduction_hard
33
+ name: "Reduction Kernel Optimization"
34
+ difficulty: hard
35
+ objective: "Use warp-level optimization and reduce memory conflicts."
36
+ grader: deterministic_rule_based
37
+
38
+ interfaces:
39
+ reset:
40
+ method: POST
41
+ path: /reset
42
+ returns: initial observation and info
43
+ step:
44
+ method: POST
45
+ path: /step
46
+ returns: observation, reward, done, info
47
+ state:
48
+ method: GET
49
+ path: /state
50
+ returns: current environment state
51
+
52
+ baseline:
53
+ script: inference.py
54
+ model_env_var: MODEL_NAME
55
+ api_key_env_var: OPENAI_API_KEY
openenv_train.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from env_server import KernelOptimization_env, TASKS
2
+ from trl import GRPOConfig, GRPOTrainer
3
+ from models import Action
4
+ from typing import List
5
+ from datasets import Dataset
6
+ import os
7
+
8
class KernelOptTool:
    """Tool-style wrapper around KernelOptimization_env for GRPO rollouts."""

    def __init__(self):
        self.env = KernelOptimization_env()
        self.reward = 0.0   # reward of the most recent step
        self.done = False   # True once the episode has finished

    def reset(self, **kwargs) ->str|None:
        """Start a new episode and return the initial prompt text.

        kwargs may carry "task_id"; otherwise the env picks a random task.
        """
        task_id = kwargs.get("task_id")
        result = self.env.reset(task_id=task_id)
        obs = result["observation"]
        self.reward = 0.0
        self.done = False
        return (
            f"Task: {obs['task_name']}\n"
            f"Baseline CUDA kernel:\n{obs['baseline_code']}\n"
            f"Pending checks: {obs['pending_checks']}\n"
            "Use tools to submit improved code."
        )

    def submit_optimization(self, optimized_code: str, strategy: str = "") -> str:
        """Submit one optimized kernel; returns a one-line result summary.

        Raises:
            ValueError: if called after the episode has ended.
        """
        if self.done:
            raise ValueError("Episode is already done.")
        result = self.env.step(Action(optimized_code=optimized_code, strategy=strategy))
        self.reward = result.reward.value
        self.done = result.done
        obs = result.observation
        return (
            f"reward={result.reward.value:.4f}, "
            f"best_speedup={obs.current_best_speedup:.3f}x, "
            f"pending_checks={obs.pending_checks}, done={result.done}"
        )

    # BUG FIX: the method was originally named "submit_optiization" (typo).
    # Keep the misspelled name as a backward-compatible alias in case any
    # tool schema or caller already references it.
    submit_optiization = submit_optimization
39
+
40
def reward_func(environmnets, **kwargs)-> List[float]:
    """Collect the latest per-environment step reward for GRPO.

    Each element is expected to be a KernelOptTool whose `.reward` holds
    the last step's reward.
    NOTE(review): the parameter name "environmnets" is misspelled; it is
    kept as-is because trl may pass it by keyword — confirm the installed
    trl version's reward-function signature before renaming.
    """
    return [env.reward for env in environmnets]
42
+
43
def build_dataset(repeats_per_task: int = 32) -> Dataset:
    """Build a chat-prompt dataset with *repeats_per_task* rows per task."""
    rows = [
        ([{"role": "user", "content": f"Optimize CUDA kernel task: {spec['name']}"}], tid)
        for tid, spec in TASKS.items()
        for _ in range(repeats_per_task)
    ]
    prompts = [prompt for prompt, _ in rows]
    task_ids = [tid for _, tid in rows]
    return Dataset.from_dict({"prompt": prompts, "task_id": task_ids})
50
+
51
def main():
    """Launch GRPO training over the kernel-optimization environment.

    NOTE(review): the `environment_factory` kwarg presumes a trl build
    with environment-based rollouts — confirm against the installed trl
    version before relying on this script.
    """
    # TRAIN_MODEL overrides the default small model.
    model_name =os.getenv("TRAIN_MODEL", "Qwen/Qwen3-0.6B")
    dataset = build_dataset()
    trainer = GRPOTrainer(
        model=model_name,
        train_dataset=dataset,
        reward_funcs=reward_func,
        environment_factory=KernelOptTool,
        args=GRPOConfig(
            chat_template_kwargs={"enable_thinking": False},
            max_completion_length=2048,
            num_generations=4,
            log_completions=True,
        ),
    )
    trainer.train()

if __name__ == "__main__":
    main()
pyproject.toml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools>=68", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "kernel_writer"
7
+ version = "1.0.0"
8
+ description = "OpenEnv-compatible CUDA kernel optimization environment."
9
+ readme = "README.md"
10
+ requires-python = ">=3.10"
11
+ dependencies = [
12
+ "fastapi>=0.110.0",
13
+ "uvicorn>=0.30.0",
14
+ "pydantic>=2.7.0",
15
+ "openai>=1.0.0",
16
+ "openenv-core>=0.2.0",
17
+ "python-dotenv>=1.0.0",
18
+ "gradio>=4.44.0",
19
+ "datasets>=2.20.0",
20
+ "trl>=0.12.0"
21
+ ]
22
+
23
+ [project.scripts]
24
+ server = "server.app:main"
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi>=0.110.0
2
+ uvicorn>=0.30.0
3
+ pydantic>=2.7.0
4
+ openai>=1.0.0
5
+ openenv-core>=0.2.0
6
+ gradio>=4.44.0
7
+ datasets>=2.20.0
8
+ trl>=0.12.0
9
+ pytest>=7.4.0
10
+ pyyaml>=6.0.0
11
+ python-dotenv
runtime.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ python-3.10
server/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .app import app
server/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (173 Bytes). View file
 
server/__pycache__/app.cpython-312.pyc ADDED
Binary file (595 Bytes). View file
 
server/app.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import uvicorn
3
+ from env_server import app
4
+
5
def main():
    """Serve the FastAPI app (re-exported as server.app:app) via uvicorn.

    Binds 0.0.0.0 on $PORT (default 7860, the Hugging Face Spaces port).
    """
    uvicorn.run("server.app:app", host="0.0.0.0", port=int(os.getenv("PORT", "7860")))


if __name__ == "__main__":
    main()