| from typing import List, Optional, Dict, Any |
| from models import Action, StepResult, ResetRequest, StepRequest, EnvState, Observation, Reward |
| from fastapi import FastAPI, HTTPException |
| import random |
|
|
| TASKS: Dict[str, Dict[str, Any]] ={ |
| "vector_add_easy": { |
| "name": "Vector Addition Kernel Optimization", |
| "difficulty": "easy", |
| "grader": "deterministic_rule_based", |
| "max_steps": 5, |
| "target_speedup": 1.8, |
| "baseline_code": """extern "C" __global__ void vector_add(const float* a, const float* b, float* c, int n) |
| { |
| int idx = blockIdx.x * blockDim.x + threadIdx.x; |
| if (idx < n) c[idx] = a[idx] + b[idx]; |
| }""", |
| "checks": { |
| "coalesced_memory": "Use memory-coalesced indexing", |
| "vectorized_loads": "Use vectorized loads/stores (float2/float4)", |
| "bounds_safe": "Keep safe boundary checks", |
| }, |
|
|
| }, |
| "matmul_medium": { |
| "name": "Matrix Multiplication Kernel Optimization", |
| "difficulty": "medium", |
| "grader": "deterministic_rule_based", |
| "max_steps": 6, |
| "target_speedup": 3.0, |
| "baseline_code": """extern "C" __global__ void matmul(const float* A, const float* B, float* C, int N) |
| { |
| int row = blockIdx.y * blockDim.y + threadIdx.y; |
| int col = blockIdx.x * blockDim.x + threadIdx.x; |
| if (row < N && col < N) { |
| float sum = 0.0f; |
| for (int k = 0; k < N; k++) sum += A[row * N + k] * B[k * N + col]; |
| C[row * N + col] = sum; |
| } |
| }""", |
| "checks": { |
| "shared_tiling": "Use shared-memory tiling", |
| "synchronization": "Synchronize tiles with __syncthreads", |
| "register_accumulation": "Accumulate partial sums in registers", |
| }, |
| }, |
| "reduction_hard": { |
| "name": "Reduction Kernel Optimization", |
| "difficulty": "hard", |
| "grader": "deterministic_rule_based", |
| "max_steps":7, |
| "target_speedup": 3.5, |
| "baseline_code": """extern "C" __global__ void reduce_sum(const float* input, float* output, int n) |
| { |
| extern __shared__ float sdata[]; |
| int tid = threadIdx.x; |
| int i = blockIdx.x * blockDim.x + threadIdx.x; |
| sdata[tid] = (i < n) ? input[i] : 0.0f; |
| __syncthreads(); |
| for (int s = blockDim.x / 2; s > 0; s >>= 1) { |
| if (tid < s) sdata[tid] += sdata[tid + s]; |
| __syncthreads(); |
| } |
| if (tid == 0) output[blockIdx.x] = sdata[0]; |
| }""", |
| "checks": { |
| "warp_primitive": "Use warp-level primitive (e.g., __shfl_down_sync)", |
| "bank_conflict_reduction": "Reduce shared-memory bank conflicts", |
| "unrolled_reduction": "Use partial unrolling for final reduction", |
| }, |
| } |
| } |
|
|
| def check_passed(check_id:str, code_lower:str) ->bool: |
| if check_id =="coalesced_memory": |
| return "idx" in code_lower and ("blockidx.x" in code_lower or "threadidx.x" in code_lower) |
| if check_id == "vectorized_loads": |
| return "float4" in code_lower or "float2" in code_lower |
| if check_id == "bounds_safe": |
| return "if" in code_lower and "< n" in code_lower |
| if check_id == "shared_tiling": |
| return "__shared__" in code_lower |
| if check_id == "synchronization": |
| return "__syncthreads" in code_lower |
| if check_id == "register_accumulation": |
| return "sum" in code_lower or "acc" in code_lower |
| if check_id == "warp_primitive": |
| return "__shfl_down_sync" in code_lower or "__shfl_sync" in code_lower |
| if check_id =="bank_conflict_reduction": |
| return "pad" in code_lower or "bank" in code_lower or "+ 1" in code_lower |
| if check_id == "unrolled_reduction": |
| return "#pragma unroll" in code_lower or "unroll" in code_lower |
| return False |
|
|
| def to_observation(task_id:str, state:EnvState)->Observation: |
| task = TASKS[task_id] |
| pending = [desc for cid, desc in task["checks"].items() if cid not in set(state.completed_checks)] |
| return Observation(task_id=task_id, task_name=task["name"], difficulty=task["difficulty"], baseline_code=task["baseline_code"], current_best_code=state.best_code or task["baseline_code"], current_best_speedup=state.best_speedup, step_count=state.step_count, max_steps=state.max_steps, pending_checks=pending, completed_checks=[task["checks"][cid] for cid in state.completed_checks if cid in task["checks"]], done=(len(pending) == 0 or state.step_count >= state.max_steps)) |
|
|
| def grade_episode(task_id:str, completed_checks:List[str], best_speedup:float, step_count:int, max_steps:int)->float: |
| task=TASKS[task_id] |
| completion =len(completed_checks) / max(len(task["checks"]),1) |
| speedup_score = min(best_speedup /task["target_speedup"],1.0) |
| efficiency = max(0.0, 1.0 - ((step_count - 1) / max(max_steps, 1))) |
| return round(max(0.0, min(1.0, 0.5 * completion + 0.35 * speedup_score + 0.15 * efficiency)), 4) |
|
|
| class KernelOptimization_env: |
| def __init__(self): |
| self.state =EnvState(initialized=False) |
| self.current_task_id: Optional[str]=None |
|
|
| def reset(self, task_id:Optional[str]=None)->Dict[str, Any]: |
| if task_id and task_id not in TASKS: |
| raise HTTPException(status_code=400, detail=f"unknown task_id: {task_id}") |
| self.current_task_id =task_id or random.choice(list(TASKS.keys())) |
| task= TASKS[self.current_task_id] |
| self.state =EnvState(initialized=True, task_id=self.current_task_id, step_count=0, max_steps=task["max_steps"], total_reward=0.0, best_code=task["baseline_code"], best_speedup=1.0, completed_checks=[], action_history=[]) |
| return { |
| "observation": to_observation(self.current_task_id, self.state).model_dump(), |
| "info": { |
| "task_id": self.current_task_id, |
| "task_name": task["name"], |
| "difficulty": task["difficulty"], |
| "grader": task["grader"], |
| "max_steps": task["max_steps"], |
| "target_speedup": task["target_speedup"], |
| "checks": task["checks"], |
| }, |
| } |
| |
| def step(self, action:Action) ->StepResult: |
| if not self.state.initialized or not self.current_task_id: |
| raise HTTPException(status_code=400, detail="Environment not initialized. Call /reset first.") |
|
|
| self.state.step_count += 1 |
| code = action.optimized_code or "" |
| code_lower = code.lower() |
| compile_ok = "__global__" in code_lower and "{" in code and "}" in code |
|
|
| completed = set(self.state.completed_checks) |
| newly_completed = {cid for cid in TASKS[self.current_task_id]["checks"] if cid not in completed and check_passed(cid, code_lower)} |
| completed.update(newly_completed) |
| self.state.completed_checks = sorted(completed) |
|
|
| completion_ratio = len(completed) / max(len(TASKS[self.current_task_id]["checks"]), 1) |
| max_reasonable_speedup = 1.0 + completion_ratio * 3.0 |
| if action.expected_speedup is None: |
| est_speedup = round(max_reasonable_speedup, 3) |
| else: |
| est_speedup = round(max(1.0, min(action.expected_speedup, max_reasonable_speedup)), 3) |
| if est_speedup > self.state.best_speedup: |
| self.state.best_speedup = est_speedup |
| self.state.best_code = code |
|
|
| progress = 0.22 * len(newly_completed) |
| quality = 0.18 * min(self.state.best_speedup / TASKS[self.current_task_id]["target_speedup"], 1.0) |
| penalty = 0.0 |
| if not compile_ok: |
| penalty -= 0.25 |
| if not newly_completed: |
| penalty -= 0.08 |
| reward_value = max(0.0, min(1.0, progress + quality + penalty)) |
| self.state.total_reward += reward_value |
|
|
| self.state.action_history.append( |
| { |
| "step": self.state.step_count, |
| "newly_completed": sorted(newly_completed), |
| "compile_ok": compile_ok, |
| "estimated_speedup": est_speedup, |
| "reward": reward_value, |
| } |
| ) |
|
|
| obs =to_observation(self.current_task_id, self.state) |
| info: Dict[str, Any] = { "compile_ok": compile_ok, "estimated_speedup": est_speedup} |
| if obs.done: |
| info["final_score"] = grade_episode( |
| self.current_task_id, self.state.completed_checks, self.state.best_speedup, self.state.step_count, self.state.max_steps |
| ) |
|
|
| return StepResult( |
| observation=obs, |
| reward=Reward( |
| value=round(reward_value, 4), |
| components={"progress": round(progress, 4), "quality": round(quality, 4), "penalty": round(penalty, 4)}, |
| ), |
| done=obs.done, |
| info=info, |
| ) |
| def state_dict(self)->Dict[str, Any]: |
| data = self.state.model_dump() |
| if self.current_task_id: |
| data["task_name"] = TASKS[self.current_task_id]["name"] |
| data["difficulty"] = TASKS[self.current_task_id]["difficulty"] |
| data["grader_score"] = grade_episode( |
| self.current_task_id, self.state.completed_checks, self.state.best_speedup, self.state.step_count, self.state.max_steps |
| ) |
| return data |
|
|
| env=KernelOptimization_env() |
| app=FastAPI(title="Kernel Optimization", version="1.0.0") |
|
|
| @app.get("/health") |
| def health_check(): |
| return {"status": "healthy", "service": "kernel-optimization-openenv"} |
|
|
| @app.post("/reset") |
| def reset(request: ResetRequest | None = None): |
| return env.reset(task_id=request.task_id if request else None) |
|
|
|
|
| @app.post("/step") |
| def step(request: StepRequest): |
| return env.step(request.action).model_dump() |
|
|
|
|
| @app.get("/state") |
| def state(): |
| return env.state_dict() |
|
|