| --- |
| base_model: |
| - Qwen/QwQ-32B |
| tags: |
| - code |
| --- |
| |
| # Model Summary |
| KernelCoder is fine-tuned from QwQ-32B on a curated dataset of reasoning traces paired with the CUDA kernels they produce. |
|
|
| See the [paper](https://lkongam.github.io/ConCuR/) for details. |
|
|
| # Usage |
|
|
| ```python |
| from vllm import LLM, SamplingParams |
| from transformers import AutoTokenizer |
| import torch |
| import re |
| from typing import List, Tuple |
| from string import Template |
# Prompt sent to the model. The instruction text is left for you to fill in;
# keep the ``$code`` placeholder, which ``PROMPT_TEMPLATE.substitute(code=...)``
# replaces with the input program. A template without that placeholder makes
# substitute() silently drop the code, so the model would receive an
# essentially empty prompt.
PROMPT_TEMPLATE = Template('''
$code
''')
| |
class KernelCoder:
    """Thin wrapper around a vLLM engine serving the KernelCoder checkpoint.

    Args:
        model_name: Hugging Face model id or local path passed to ``vllm.LLM``.
        tensor_parallel_size: Number of GPUs to shard the model across.
        gpu_memory_utilization: Fraction of GPU memory vLLM may reserve.
    """

    def __init__(self, model_name="lkongam/KernelCoder", tensor_parallel_size=1,
                 gpu_memory_utilization=0.9):
        self.model_name = model_name

        self.llm = LLM(
            model=model_name,
            tensor_parallel_size=tensor_parallel_size,
            gpu_memory_utilization=gpu_memory_utilization,
            trust_remote_code=True,
            dtype="auto",
        )

        # Reuse the engine's tokenizer so the chat template matches the model.
        self.tokenizer = self.llm.get_tokenizer()
        # Not used by the methods shown here; kept for callers that may read it.
        self.device = torch.device("cuda")

    def generate_raw(self, prompt, temperature=1.0, max_tokens=None):
        """Sample one completion for a single user prompt and return its text.

        The prompt is wrapped as a one-turn chat and rendered with the
        tokenizer's chat template before being sent to vLLM.

        Args:
            prompt: User message content.
            temperature: Sampling temperature forwarded to ``SamplingParams``.
            max_tokens: Cap on generated tokens. ``None`` defers the limit to
                vLLM (in current versions, the remaining context window). Do not
                rely on vLLM's own default of 16 tokens: it would truncate the
                model's reasoning.

        Returns:
            The raw generated text (reasoning plus answer), without the prompt.
        """
        messages = [{"role": "user", "content": prompt}]
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            # Extra template variable; chat templates that do not reference it
            # ignore it.
            enable_thinking=True,
        )
        sampling_params = SamplingParams(temperature=temperature, max_tokens=max_tokens)
        # One prompt in, so take the first request's first sampled output.
        outputs = self.llm.generate([text], sampling_params)
        return outputs[0].outputs[0].text
| |
# Any fenced block: group 1 is the info string (language tag), group 2 the body.
# Matching every fence, not just python ones, keeps opening and closing fences
# paired correctly when the text also contains e.g. ```cpp blocks.
_CODE_FENCE_RE = re.compile(r"```([^\n`]*)\n(.*?)```", re.DOTALL)
_PYTHON_FENCE_TAGS = frozenset({"", "python", "py", "python3"})


def extract_last_code_block(text):
    """Return the most likely final Python program contained in *text*.

    Strategy:
      1. If *text* contains fenced code blocks tagged python (or untagged),
         return the last one, stripped.
      2. Otherwise take the text after the first ``</think>`` (or the whole
         text if there is none) and return it from the first ``import`` keyword
         onward, or in full if there is no ``import``.

    Returns:
        The extracted code, or ``None`` if nothing non-blank remains.
    """
    python_blocks = [
        body
        for tag, body in _CODE_FENCE_RE.findall(text)
        if tag.strip().lower() in _PYTHON_FENCE_TAGS
    ]
    if python_blocks:
        return python_blocks[-1].strip()

    match = re.search(r"</think>(.*)", text, re.S)
    # Strip in both branches so whitespace-only input consistently yields None.
    after_think = (match.group(1) if match else text).strip()
    if not after_think:
        return None
    # Drop any leading prose before the code by cutting at the first import.
    import_match = re.search(r"\bimport\b", after_think)
    if import_match:
        return after_think[import_match.start():].strip()
    return after_think
|
|
# Input program for the model: paste the code you want rewritten between the
# triple quotes (left empty here as a placeholder).
origin_code = """
"""

model = KernelCoder(model_name="lkongam/KernelCoder")

prompt = PROMPT_TEMPLATE.substitute(code=origin_code)
code_output = model.generate_raw(prompt)
code = extract_last_code_block(code_output)

# extract_last_code_block returns None when the completion contains nothing
# usable; report that explicitly instead of printing a bare "None".
if code is None:
    print("No code could be extracted from the model output:")
    print(code_output)
else:
    print(code)
| ``` |
| |
| # Evaluation |
|  |
| |
| Left: Pass@1, Right: Pass@10. |