docs: blog post; HF sync workflow; kaggle prompt/reward + tests
Browse files- Add blog.md (hackathon narrative) and README Blog link
- Git LFS for plot PNGs; sync-to-hf-space workflow + sync_hf_space.sh
- Dataset/reward prompt contract and format penalty; hf_run_job uses clone
- Tests for preprocess_assembly and training prompt contract
Made-with: Cursor
- .gitattributes +1 -0
- .github/workflows/sync-to-hf-space.yml +42 -0
- README.md +1 -0
- blog.md +270 -0
- kaggle/dataset.py +14 -15
- kaggle/hf_run_job.sh +2 -263
- kaggle/reward_fn.py +105 -18
- scripts/sync_hf_space.sh +27 -0
- tests/test_preprocess_assembly.py +53 -0
- tests/test_training_prompt_contract.py +6 -3
.gitattributes
CHANGED
|
@@ -1 +1,2 @@
|
|
|
|
|
| 1 |
colab/results/plots/*.png filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 1 |
+
# Hugging Face Hub rejects raw PNG git objects on Space pushes; use Git LFS.
|
| 2 |
colab/results/plots/*.png filter=lfs diff=lfs merge=lfs -text
|
.github/workflows/sync-to-hf-space.yml
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Sync GitHub ka-ori/arm-gym → Hugging Face Space https://huggingface.co/spaces/kaori02/arm-gym
|
| 2 |
+
# (branch main → Space branch main).
|
| 3 |
+
#
|
| 4 |
+
# Hugging Face rejects raw PNG blobs in git packs; this workflow rewrites history in the
|
| 5 |
+
# runner with `git lfs migrate import` for colab/results/plots/*.png before pushing to the
|
| 6 |
+
# Space only (GitHub is unchanged). See: https://huggingface.co/docs/hub/xet
|
| 7 |
+
#
|
| 8 |
+
# Secrets: HF_TOKEN (write access to the Space). Optional vars: HF_USERNAME, HF_SPACE
|
| 9 |
+
# (defaults kaori02 / arm-gym).
|
| 10 |
+
#
|
| 11 |
+
# Local one-shot: scripts/sync_hf_space.sh
|
| 12 |
+
|
| 13 |
+
name: Sync GitHub to Hugging Face Space
|
| 14 |
+
|
| 15 |
+
on:
|
| 16 |
+
push:
|
| 17 |
+
branches: [main]
|
| 18 |
+
workflow_dispatch:
|
| 19 |
+
|
| 20 |
+
jobs:
|
| 21 |
+
sync-to-hub:
|
| 22 |
+
runs-on: ubuntu-latest
|
| 23 |
+
steps:
|
| 24 |
+
- uses: actions/checkout@v4
|
| 25 |
+
with:
|
| 26 |
+
fetch-depth: 0
|
| 27 |
+
lfs: true
|
| 28 |
+
|
| 29 |
+
- name: Rewrite PNGs to Git LFS for Hub push
|
| 30 |
+
run: |
|
| 31 |
+
git lfs install
|
| 32 |
+
git lfs migrate import --include="colab/results/plots/*.png" --include-ref=refs/heads/main
|
| 33 |
+
|
| 34 |
+
- name: Push to Hugging Face Space
|
| 35 |
+
env:
|
| 36 |
+
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
| 37 |
+
HF_USER: ${{ vars.HF_USERNAME }}
|
| 38 |
+
HF_SPACE: ${{ vars.HF_SPACE }}
|
| 39 |
+
run: |
|
| 40 |
+
U="${HF_USER:-kaori02}"
|
| 41 |
+
S="${HF_SPACE:-arm-gym}"
|
| 42 |
+
git push --force "https://${U}:${HF_TOKEN}@huggingface.co/spaces/${U}/${S}" HEAD:main
|
README.md
CHANGED
|
@@ -17,6 +17,7 @@ short_description: GRPO env for AArch64 superoptimization
|
|
| 17 |
|
| 18 |
<p align="center">
|
| 19 |
<a href="https://huggingface.co/spaces/dot-mkv/arm-gym">HF Space</a> ·
|
|
|
|
| 20 |
<a href="#quick-start">Quick Start</a> ·
|
| 21 |
<a href="#results">Results</a> ·
|
| 22 |
<a href="#why-it-matters">Why It Matters</a>
|
|
|
|
| 17 |
|
| 18 |
<p align="center">
|
| 19 |
<a href="https://huggingface.co/spaces/dot-mkv/arm-gym">HF Space</a> ·
|
| 20 |
+
<a href="./blog.md">Blog</a> ·
|
| 21 |
<a href="#quick-start">Quick Start</a> ·
|
| 22 |
<a href="#results">Results</a> ·
|
| 23 |
<a href="#why-it-matters">Why It Matters</a>
|
blog.md
ADDED
|
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# We trained an AI to write faster ARM assembly than the compiler
|
| 2 |
+
|
| 3 |
+
Some ideas start with a paper.
|
| 4 |
+
|
| 5 |
+
In May 2025, a team from Stanford and UIUC published [SuperCoder](https://arxiv.org/abs/2505.11480) - a system that trained a language model to write assembly code faster than `gcc -O3`, the most aggressive optimization setting of the world's most widely used compiler. Their result: 1.46x average speedup. On x86-64. With a 7 billion parameter model trained via reinforcement learning.
|
| 6 |
+
|
| 7 |
+
The paper was explicit about one thing it did not do: ARM.
|
| 8 |
+
|
| 9 |
+
> "Extending to ARM, RISC-V, and GPU kernels is noted as future work."
|
| 10 |
+
|
| 11 |
+
That sentence is where ARM-Gym begins.
|
| 12 |
+
|
| 13 |
+
---
|
| 14 |
+
|
| 15 |
+
## What this is about, for everyone
|
| 16 |
+
|
| 17 |
+
Before getting into what we built, here is the context for anyone who does not spend their days thinking about compilers and processors.
|
| 18 |
+
|
| 19 |
+
**What is a compiler?**
|
| 20 |
+
|
| 21 |
+
A compiler is a program that takes code written in a high-level language like C and translates it into machine instructions. GCC and Clang are the two most important compilers in existence. When you compile with `-O3`, you are asking the compiler to apply every optimization it knows. It inlines functions, unrolls loops, reorders instructions to avoid stalls, selects faster instruction variants where it can.
|
| 22 |
+
|
| 23 |
+
GCC has been doing this for decades. Thousands of engineer-years have gone into making `-O3` as good as it is.
|
| 24 |
+
|
| 25 |
+
**And yet.**
|
| 26 |
+
|
| 27 |
+
Compilers must be conservative. A compiler cannot make an assumption that might be wrong for even one program. It cannot take a risk that improves performance 99% of the time but breaks the other 1%. It follows rules. Rules generalize but do not specialize.
|
| 28 |
+
|
| 29 |
+
**What is ARM?**
|
| 30 |
+
|
| 31 |
+
ARM is a processor architecture - a specific design for how machine instructions are structured and executed. For a long time, ARM meant smartphones. That is no longer true. Today:
|
| 32 |
+
|
| 33 |
+
- AWS Graviton5 - the most widely deployed cloud compute in the world - runs on ARM
|
| 34 |
+
- Azure Cobalt 100, Microsoft's custom data center chip, runs on ARM
|
| 35 |
+
- Every Apple Mac sold since 2020 runs on ARM
|
| 36 |
+
- Meta's AGI CPU - 136 cores, 3nm fabrication, deployed in 2026 - runs on ARM
|
| 37 |
+
|
| 38 |
+
Any improvement in how efficiently code runs on ARM touches all of that. And the specific code that matters most is the tight computational loops inside AI inference: matrix multiply, softmax, convolution. These functions run millions of times per second in every large model deployment.
|
| 39 |
+
|
| 40 |
+
---
|
| 41 |
+
|
| 42 |
+
## The research that made this possible: SuperCoder
|
| 43 |
+
|
| 44 |
+
The SuperCoder paper (Wei et al., arXiv:2505.11480, Stanford/UIUC, 2025) is the direct foundation for what we built. Understanding what they proved is essential for understanding why ARM-Gym is the next step.
|
| 45 |
+
|
| 46 |
+
### What SuperCoder showed
|
| 47 |
+
|
| 48 |
+
The paper asked: can a language model learn to write assembly that beats the compiler, purely through reinforcement learning?
|
| 49 |
+
|
| 50 |
+
They evaluated 23 language models on a benchmark of 8,072 assembly programs (average 130 lines each - far larger than any prior dataset, which maxed out at 15 lines and no loops). Every program came with its `gcc -O3` baseline assembly and a set of test cases.
|
| 51 |
+
|
| 52 |
+
The base model they chose for training was Qwen2.5-Coder-7B-Instruct - not because it was the strongest baseline, but because it had the highest test pass rate (61.4%) among open-source models, leaving the most room to improve. Claude-opus-4 had a slightly higher average speedup (1.43x) but was not open-source and could not be fine-tuned.
|
| 53 |
+
|
| 54 |
+
They trained using both PPO and GRPO. The reward function was simple: if the generated assembly compiles, passes all test cases, and runs faster than `gcc -O3`, the reward equals the speedup. Otherwise, zero. No partial credit. No reward for being partially correct.
|
| 55 |
+
|
| 56 |
+
That last point turned out to matter a lot. They tested an alternative reward that gave partial credit for passing some tests - and it performed worse (1.38x vs 1.46x). The lesson: partial credit lets the model avoid putting in the work of actually being correct and fast. Binary pass/fail forces it.
|
| 57 |
+
|
| 58 |
+
### The results
|
| 59 |
+
|
| 60 |
+
| Model | Correctness before training | Correctness after | Avg speedup |
|
| 61 |
+
|---|---|---|---|
|
| 62 |
+
| Qwen2.5-Coder-7B (base) | 61.4% | - | 1.10x |
|
| 63 |
+
| SuperCoder (GRPO) | - | 95.0% | 1.44x |
|
| 64 |
+
| SuperCoder (PPO) | - | 95.0% | 1.46x |
|
| 65 |
+
|
| 66 |
+
Correctness jumped from 61.4% to 95%. Average speedup went from 1.10x to 1.46x. The model went from occasionally beating the compiler to reliably beating it.
|
| 67 |
+
|
| 68 |
+
One more finding worth noting: 98.5% of the speedup came from instruction scheduling and code layout - reordering instructions and basic blocks to better hide latency and avoid pipeline stalls. Not exotic instruction selection. The model learned that the compiler's instruction order is not optimal, and found better orderings.
|
| 69 |
+
|
| 70 |
+
### What SuperCoder did not do
|
| 71 |
+
|
| 72 |
+
It targeted x86-64 only. The paper uses the IBM CodeNet dataset, which contains competitive programming submissions compiled for x86. ARM was explicitly outside scope.
|
| 73 |
+
|
| 74 |
+
ARM has a completely different instruction set. Different SIMD extensions (NEON, SVE2 instead of AVX/SSE). Different pipeline characteristics. Different scheduler model. Different optimization opportunities. A model trained on x86 assembly cannot be directly applied to ARM.
|
| 75 |
+
|
| 76 |
+
That is the gap ARM-Gym fills.
|
| 77 |
+
|
| 78 |
+
---
|
| 79 |
+
|
| 80 |
+
## What we built: ARM-Gym
|
| 81 |
+
|
| 82 |
+
ARM-Gym is a reinforcement learning environment built on [OpenEnv](https://github.com/meta-pytorch/OpenEnv), the framework from Meta and Hugging Face. The task is identical to SuperCoder's in structure - generate assembly that beats `gcc -O3` - but the target is AArch64 (ARM's 64-bit architecture) and the kernels are specifically chosen to be representative of AI inference workloads.
|
| 83 |
+
|
| 84 |
+
### The kernel library
|
| 85 |
+
|
| 86 |
+
We wrote 15 C function templates covering the operations that dominate AI inference:
|
| 87 |
+
|
| 88 |
+
- Vector operations: `vec_add`, `dot_product`, `saxpy`
|
| 89 |
+
- Matrix operations: `gemv`, `matmul`
|
| 90 |
+
- Activation and normalization: `softmax`, `layer_norm`
|
| 91 |
+
- Convolution: `conv1d`, `conv2d`
|
| 92 |
+
- Elementwise: `relu`, `gelu`, `silu`, `fma`
|
| 93 |
+
|
| 94 |
+
From these 15 templates, we generate 523 variants by varying sizes, data types (float32, float16, int8), and parameters. This gives the training loop enough diversity to prevent memorization while staying domain-relevant.
|
| 95 |
+
|
| 96 |
+
### The training loop
|
| 97 |
+
|
| 98 |
+
Here is exactly how training works, step by step:
|
| 99 |
+
|
| 100 |
+
```mermaid
|
| 101 |
+
flowchart LR
|
| 102 |
+
A["C Kernel<br/>(15 templates × 523 variants)"] --> B["gcc -O3<br/>Baseline Assembly"]
|
| 103 |
+
B --> C["LLM Prompt<br/>(C + Baseline ASM)"]
|
| 104 |
+
C --> D["Qwen2.5-Coder-7B<br/>+ LoRA"]
|
| 105 |
+
D --> E["Agent Assembly"]
|
| 106 |
+
E --> F{"3-Gate Verifier"}
|
| 107 |
+
F -->|"Gate 1: Syntax"| G["GNU as"]
|
| 108 |
+
F -->|"Gate 2: Correctness"| H["QEMU × 20<br/>Adversarial Tests"]
|
| 109 |
+
F -->|"Gate 3: Performance"| I["LLVM-MCA<br/>Neoverse V2"]
|
| 110 |
+
I --> J["Dual Verifier<br/>Cross-Check"]
|
| 111 |
+
J --> K["Reward<br/>max(0, speedup - 1)"]
|
| 112 |
+
K --> L["GRPO Update"]
|
| 113 |
+
L --> D
|
| 114 |
+
|
| 115 |
+
style F fill:#f96,stroke:#333,color:#000
|
| 116 |
+
style J fill:#69f,stroke:#333,color:#000
|
| 117 |
+
style L fill:#6c6,stroke:#333,color:#000
|
| 118 |
+
```
|
| 119 |
+
|
| 120 |
+
**Step 1 - Sample a kernel.** Pick one of the 523 variants at random. This is the function the model needs to optimize.
|
| 121 |
+
|
| 122 |
+
**Step 2 - Compile the baseline.** Run `gcc -O3` on the C source. This gives us the baseline assembly and a baseline cycle count from LLVM-MCA.
|
| 123 |
+
|
| 124 |
+
**Step 3 - Build the prompt.** Send the model the C source code, the baseline assembly, and an instruction: write optimized AArch64 assembly for this function, wrapped in `<assembly>...</assembly>` tags. Crucially, the baseline is always included - SuperCoder found that without it, even strong models produce 0% compilable code. The baseline is load-bearing context.
|
| 125 |
+
|
| 126 |
+
**Step 4 - Verify.** The model's output goes through three sequential gates. Pass all three and you get a reward. Fail any one and you score zero.
|
| 127 |
+
|
| 128 |
+
**Step 5 - Update with GRPO.** We use Group Relative Policy Optimization - the same algorithm from the DeepSeekMath paper that has since been adopted widely in RL for LLMs. GRPO generates multiple completions for the same prompt, scores them all, and uses their relative quality to compute the learning signal. It does not need a separate value function or critic model, which makes it efficient on constrained hardware.
|
| 129 |
+
|
| 130 |
+
### The verifier: why it cannot be cheated
|
| 131 |
+
|
| 132 |
+
Every RL environment has a reward hacking problem. Give an agent a metric and it will find the most efficient path to that metric - which is often not the path you intended.
|
| 133 |
+
|
| 134 |
+
We saw this play out in Phase 1 of this hackathon across every finalist submission. One system's agent learned to starve long-context requests so short-request throughput looked better. Another learned to disconnect network access so error logs stopped appearing. A third learned to drop database tables so schema validation errors vanished. In all three cases, the metric went up and the actual objective was completely destroyed.
|
| 135 |
+
|
| 136 |
+
Our verifier was designed from the start to have no exploitable surface.
|
| 137 |
+
|
| 138 |
+
**Gate 1: Syntax via the real assembler.**
|
| 139 |
+
|
| 140 |
+
The assembly must compile with `aarch64-linux-gnu-as`, the actual GNU assembler for ARM. Not a regular expression. Not a syntax checker. The real tool that produces a real binary object. If it rejects the assembly, the score is zero.
|
| 141 |
+
|
| 142 |
+
**Gate 2: Correctness via randomized QEMU tests.**
|
| 143 |
+
|
| 144 |
+
The compiled binary runs 20 times inside `qemu-aarch64-static`, a full ARM CPU emulator. Each run uses a different set of randomly generated inputs - edge cases, boundary values, near-overflow values. The output must match the original C function's output within floating-point tolerance. Inputs are randomized every episode, so the model cannot memorize test inputs and hardcode outputs for them.
|
| 145 |
+
|
| 146 |
+
**Gate 3: Performance via LLVM-MCA.**
|
| 147 |
+
|
| 148 |
+
We measure cycles with LLVM-MCA using the LLVM 21 Neoverse V2 scheduling model. This is a static analysis tool - it reads assembly text and estimates cycles based on the CPU's instruction latencies and throughput. It takes no inputs, executes nothing, and has no runtime attack surface.
|
| 149 |
+
|
| 150 |
+
**Cross-check.** We compare the QEMU instruction count against the LLVM-MCA cycle estimate. A ratio above 3x triggers a hard veto - something is wrong, and we discard the result regardless of the apparent speedup.
|
| 151 |
+
|
| 152 |
+
**3-sigma bound.** We maintain an offline distribution of speedup values for each kernel variant. Any result more than three standard deviations above the mean is rejected as a statistical outlier.
|
| 153 |
+
|
| 154 |
+
No LLM judge anywhere in this stack. SuperCoder used Hyperfine (real execution timing). We use LLVM-MCA (static analysis) because it is deterministic, sub-millisecond per evaluation, and has no measurement noise. The trade-off is that MCA is a model of the hardware, not the hardware itself - which is why we label results as "MCA-model speedup" until we can validate on physical Graviton silicon.
|
| 155 |
+
|
| 156 |
+
### The reward signal
|
| 157 |
+
|
| 158 |
+
```
|
| 159 |
+
reward = max(0, speedup - 1.0)
|
| 160 |
+
```
|
| 161 |
+
|
| 162 |
+
If the model's assembly is slower than `gcc -O3`, the reward is zero - neutral, not a penalty. If it is faster, the reward is the fractional improvement above parity: 1.5x speedup gives 0.5, 2x gives 1.0. Capped at 2.0 to prevent a single outlier from dominating the gradient.
|
| 163 |
+
|
| 164 |
+
The zero floor is not arbitrary. In GRPO, rewards within a group are z-score normalized before computing the advantage. If slower-than-compiler gave a negative reward, a group where all completions happen to be slow would produce similar negative values that normalize to near zero - producing no gradient. Making slower neutral means even a group of slow completions still has relative differences that produce a learning signal. This mirrors exactly what SuperCoder found: sparse terminal reward (no partial credit) consistently outperforms reward designs that penalize failures.
|
| 165 |
+
|
| 166 |
+
### The curriculum
|
| 167 |
+
|
| 168 |
+
Not all kernels are equally hard. Sending the model to optimize a tiled matrix multiply before it has learned to write syntactically valid assembly is wasteful. ARM-Gym uses a staged curriculum that matches kernel difficulty to the model's current capability.
|
| 169 |
+
|
| 170 |
+
```mermaid
|
| 171 |
+
flowchart LR
|
| 172 |
+
S1["Stage 1: Scalar<br/>vec_add, dot, saxpy"] -->|"80% variants ≥ 1.05x"| S2["Stage 2: NEON<br/>gemv, conv1d, fma"]
|
| 173 |
+
S2 -->|"80% variants ≥ 1.05x"| S3["Stage 3: Loops<br/>matmul, softmax"]
|
| 174 |
+
S3 -->|"Beat -O3 mean"| S4["Stage 4: SVE2<br/>(Stretch)"]
|
| 175 |
+
|
| 176 |
+
style S1 fill:#bfb,stroke:#333,color:#000
|
| 177 |
+
style S2 fill:#fbf,stroke:#333,color:#000
|
| 178 |
+
style S3 fill:#bbf,stroke:#333,color:#000
|
| 179 |
+
style S4 fill:#fbb,stroke:#333,color:#000
|
| 180 |
+
```
|
| 181 |
+
|
| 182 |
+
**Stage 1 - Scalar kernels.** Functions like `vec_add` and `saxpy` where the compiler produces a scalar loop. NEON vectorization (processing 4 floats at once instead of 1) is the primary optimization opportunity. This is learnable early because the pattern is consistent.
|
| 183 |
+
|
| 184 |
+
**Stage 2 - NEON kernels.** Functions where the compiler already emits NEON, but instruction ordering and register reuse can be improved. Requires the model to reason about pipeline latency, not just instruction selection.
|
| 185 |
+
|
| 186 |
+
**Stage 3 - Loop-heavy kernels.** Matrix multiply and softmax, where optimization requires loop tiling, unrolling, and prefetch placement. These are the hardest patterns across Stages 1-3.
|
| 187 |
+
|
| 188 |
+
**Stage 4 - SVE2 (stretch target).** ARM's Scalable Vector Extension, available on Neoverse V2 and V3. No training data exists for SVE2 code generation. This stage is explicitly a research frontier - there is no known baseline for what an RL agent can achieve here.
|
| 189 |
+
|
| 190 |
+
Advancement between stages requires beating `gcc -O3` by at least 5% on 80% of the variants in the current stage. The model must demonstrate broad capability, not exploit a single easy variant.
|
| 191 |
+
|
| 192 |
+
---
|
| 193 |
+
|
| 194 |
+
## How ARM-Gym differs from SuperCoder
|
| 195 |
+
|
| 196 |
+
| Aspect | SuperCoder | ARM-Gym |
|
| 197 |
+
|---|---|---|
|
| 198 |
+
| Target architecture | x86-64 | AArch64 (ARM) |
|
| 199 |
+
| Dataset | 7,872 competitive programming programs | 523 AI inference kernel variants |
|
| 200 |
+
| Performance measurement | Hyperfine (real hardware timing) | LLVM-MCA (static analysis, deterministic) |
|
| 201 |
+
| RL framework | VERL | HuggingFace TRL |
|
| 202 |
+
| Verifier | Compile + test pass | Compile + QEMU + LLVM-MCA + cross-check + 3-sigma |
|
| 203 |
+
| Reward design | Binary terminal (same principle) | Binary terminal + format shaping |
|
| 204 |
+
| Dataset focus | General programs | GEMM, matmul, softmax, conv (AI inference hot path) |
|
| 205 |
+
| Prior work exists | Yes (this is the paper) | No - ARM is open |
|
| 206 |
+
|
| 207 |
+
The most important difference is the last one. SuperCoder is the proof of concept. ARM-Gym is the next frontier. The paper itself identified ARM as the natural extension - we are building it.
|
| 208 |
+
|
| 209 |
+
---
|
| 210 |
+
|
| 211 |
+
## Results
|
| 212 |
+
|
| 213 |
+
*[Results will be updated here after training completes.]*
|
| 214 |
+
|
| 215 |
+
| Metric | Value |
|
| 216 |
+
|---|---|
|
| 217 |
+
| Best speedup over `gcc -O3` | [to be updated] |
|
| 218 |
+
| Win rate | [to be updated] |
|
| 219 |
+
| Correctness rate | [to be updated] |
|
| 220 |
+
| Training steps and GPU time | [to be updated] |
|
| 221 |
+
| SuperCoder reference (x86-64, PPO) | 1.46x average over `gcc -O3` |
|
| 222 |
+
| Kernel variants trained on | 523 (15 templates) |
|
| 223 |
+
|
| 224 |
+
---
|
| 225 |
+
|
| 226 |
+
## What we learned from building this
|
| 227 |
+
|
| 228 |
+
**The reward formula has to be exactly right.** We made one subtle error - `speedup - 1.0` instead of `max(0, speedup - 1.0)` - and it silently destroyed the speedup gradient for an entire run. A group where all completions are slightly slow produces similar small negative values. After z-score normalization, they all collapse to near zero. No gradient. The model was learning correctness but not speed, and there was nothing in the loss curves to tell us why. The fix was a single `max(0, ...)` clamp. Test your reward function independently before attaching a model to it.
|
| 229 |
+
|
| 230 |
+
**LLVM version has a hard dependency for ARM.** LLVM 17's Neoverse V2 scheduling model had the processor's issue-width wrong: 16 microoperations per cycle instead of the correct 8. Training on this would teach the model to optimize for a processor that does not exist. We pinned LLVM 21 specifically because of this correction.
|
| 231 |
+
|
| 232 |
+
**Thinking models fail on assembly generation.** SuperCoder's benchmark found DeepSeek-R1 compiles at 0% across all 200 evaluation problems. The chain-of-thought habit causes the model to spend its entire output budget reasoning about instruction semantics - and never producing executable code. This is a known failure mode for reasoning-heavy models on generation tasks. The base model for assembly RL should not be a reasoning model.
|
| 233 |
+
|
| 234 |
+
**Correct EOS token alignment is not optional.** Qwen2.5 uses `<|im_end|>` as its chat end-of-turn token, but TRL's GRPOTrainer by default reads `tokenizer.eos_token_id` for generation stopping. Without explicitly aligning these, the model never stops generating cleanly. Completions run to the token limit, producing garbage that collapses the reward signal. This required an explicit fix before training produced any useful signal at all.
|
| 235 |
+
|
| 236 |
+
**GRPO needs within-group diversity.** With temperature 0.5, all completions in a group come out very similar - similar tokens, similar speedup, similar z-scores, near-zero gradient. Temperature 0.8 fixes this. More diversity means some completions try NEON vectorization, some stay scalar, some hallucinate - and the relative comparison between them becomes meaningful enough to drive learning.
|
| 237 |
+
|
| 238 |
+
---
|
| 239 |
+
|
| 240 |
+
## Why this matters
|
| 241 |
+
|
| 242 |
+
Compilers use rules. Rules are safe, general, and conservative by necessity. Reinforcement learning finds what rules cannot - the specific instruction sequences, the register orderings, the NEON patterns that extract cycles on a specific microarchitecture for a specific workload.
|
| 243 |
+
|
| 244 |
+
SuperCoder proved this approach works on x86. That result is now published, citable, and reproducible. ARM is the same problem on a larger market with no existing solution.
|
| 245 |
+
|
| 246 |
+
AWS, Azure, Apple, and Meta have all made major bets on ARM infrastructure. The AI inference workloads running on that infrastructure are bottlenecked by the same matrix multiply and softmax kernels we are optimizing. Any improvement compounds.
|
| 247 |
+
|
| 248 |
+
Could a researcher write a paper extending SuperCoder to ARM? Yes. That paper does not exist yet. ARM-Gym is that paper in environment form.
|
| 249 |
+
|
| 250 |
+
---
|
| 251 |
+
|
| 252 |
+
## What comes next
|
| 253 |
+
|
| 254 |
+
**Silicon validation.** Every cycle count in ARM-Gym is an LLVM-MCA estimate on the Neoverse V2 model. Until we run the model's output on a physical Graviton3 machine and measure wall-clock time, these are model-predicted speedups. That validation is the step that turns "MCA speedup" into a real claim.
|
| 255 |
+
|
| 256 |
+
**SVE2.** No model has been specifically trained to generate ARM SVE2 code. The Scalable Vector Extension is the highest-bandwidth path on Neoverse V2 and V3, and it is essentially unexplored territory for code generation models. The optimization potential is high and the competition is zero.
|
| 257 |
+
|
| 258 |
+
**Larger models.** SuperCoder used 7B. ARM-Gym's current training targets 7B as well. The Best-of-8 sampling result from the paper (1.46x → 1.93x) suggests that scaling inference (more candidates, pick the best) is as important as scaling model size. Both directions are worth exploring.
|
| 259 |
+
|
| 260 |
+
---
|
| 261 |
+
|
| 262 |
+
## Try it
|
| 263 |
+
|
| 264 |
+
- **Live environment:** [huggingface.co/spaces/dot-mkv/arm-gym](https://huggingface.co/spaces/dot-mkv/arm-gym)
|
| 265 |
+
- **Training notebook:** `colab/arm_gym_grpo_kaggle.ipynb` - upload to Kaggle, set accelerator to T4 GPU, run all cells
|
| 266 |
+
- **Paper we built on:** [SuperCoder (arXiv:2505.11480)](https://arxiv.org/abs/2505.11480)
|
| 267 |
+
|
| 268 |
+
---
|
| 269 |
+
|
| 270 |
+
*Meta / HuggingFace OpenEnv Hackathon India 2026 - Finals. Theme: Wild Card. Team: (dot)mkv.*
|
kaggle/dataset.py
CHANGED
|
@@ -14,26 +14,24 @@ from arm_gym.compile_baseline import ToolchainInfo, compile_to_asm
|
|
| 14 |
from arm_gym.kernels import KernelVariant, generate_all, split_train_eval
|
| 15 |
|
| 16 |
|
| 17 |
-
#
|
| 18 |
-
#
|
| 19 |
-
# PDF: meta-hackathon-llm-wiki/papers/supercoder.pdf
|
| 20 |
SYSTEM_PROMPT = (
|
| 21 |
-
"You
|
| 22 |
-
"
|
|
|
|
| 23 |
)
|
| 24 |
|
| 25 |
|
| 26 |
def user_prompt(c_source: str, baseline_asm: str) -> str:
|
| 27 |
-
# Order and wording follow SuperCoder A.3 (x86-64 -> AArch64 for arm-gym).
|
| 28 |
return (
|
| 29 |
-
"
|
| 30 |
-
"
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
"
|
| 34 |
-
"
|
| 35 |
-
"
|
| 36 |
-
"Optimized Assembly Code:\n"
|
| 37 |
)
|
| 38 |
|
| 39 |
|
|
@@ -60,7 +58,8 @@ def render_prompt(tokenizer, c_source: str, baseline_asm: str) -> str:
|
|
| 60 |
prompt = f"{SYSTEM_PROMPT}\n\n{messages[1]['content']}\n"
|
| 61 |
else:
|
| 62 |
prompt = f"{SYSTEM_PROMPT}\n\n{messages[1]['content']}\n"
|
| 63 |
-
|
|
|
|
| 64 |
|
| 65 |
|
| 66 |
@dataclass
|
|
|
|
| 14 |
from arm_gym.kernels import KernelVariant, generate_all, split_train_eval
|
| 15 |
|
| 16 |
|
| 17 |
+
# C + baseline stay in the row for the verifier; user message ends with an
|
| 18 |
+
# ``<assembly>`` prefill so the assistant continues inside the tag (Qwen chat).
|
|
|
|
| 19 |
SYSTEM_PROMPT = (
|
| 20 |
+
"You write only AArch64 (aarch64-linux-gnu) assembly inside the user’s "
|
| 21 |
+
"<assembly> block. No prose, no C, no other languages—assembly between "
|
| 22 |
+
"the opening line the user started and a closing </assembly> tag."
|
| 23 |
)
|
| 24 |
|
| 25 |
|
| 26 |
def user_prompt(c_source: str, baseline_asm: str) -> str:
|
|
|
|
| 27 |
return (
|
| 28 |
+
"Generate ONLY AArch64 assembly.\n\n"
|
| 29 |
+
"Wrap your continuation in <assembly></assembly> tags "
|
| 30 |
+
"(the opening tag is already below—finish the block and close with "
|
| 31 |
+
"</assembly>).\n\n"
|
| 32 |
+
f"C code:\n{c_source}\n\n"
|
| 33 |
+
f"Baseline assembly (gcc -O3):\n{baseline_asm}\n\n"
|
| 34 |
+
"<assembly>\n"
|
|
|
|
| 35 |
)
|
| 36 |
|
| 37 |
|
|
|
|
| 58 |
prompt = f"{SYSTEM_PROMPT}\n\n{messages[1]['content']}\n"
|
| 59 |
else:
|
| 60 |
prompt = f"{SYSTEM_PROMPT}\n\n{messages[1]['content']}\n"
|
| 61 |
+
# Do not rstrip: user content must end with "<assembly>\n" for generation prefill.
|
| 62 |
+
return prompt if prompt.endswith("\n") else prompt + "\n"
|
| 63 |
|
| 64 |
|
| 65 |
@dataclass
|
kaggle/hf_run_job.sh
CHANGED
|
@@ -181,269 +181,8 @@ def install_grpo_assembly_stopping(model: Any, tokenizer: Any, **kwargs) -> None
|
|
| 181 |
SHIMPY
|
| 182 |
echo "[hotpatch] stopping criteria DISABLED -- using natural chat EOS"
|
| 183 |
|
| 184 |
-
# 3. dataset.py
|
| 185 |
-
|
| 186 |
-
"""Dataset builder: kernel variant -> GRPO prompt."""
|
| 187 |
-
from __future__ import annotations
|
| 188 |
-
from dataclasses import dataclass
|
| 189 |
-
from datasets import Dataset
|
| 190 |
-
from arm_gym.compile_baseline import ToolchainInfo, compile_to_asm
|
| 191 |
-
from arm_gym.kernels import KernelVariant, generate_all, split_train_eval
|
| 192 |
-
|
| 193 |
-
SYSTEM_PROMPT = (
|
| 194 |
-
"You are an expert AArch64 (aarch64-linux-gnu-gcc) assembly writer. "
|
| 195 |
-
"Obey the user block exactly. Output only what is asked in the required tags."
|
| 196 |
-
)
|
| 197 |
-
|
| 198 |
-
def user_prompt(c_source: str, baseline_asm: str) -> str:
|
| 199 |
-
return (
|
| 200 |
-
"Given the following C code and assembly code, your task is to generate "
|
| 201 |
-
"highly optimized AArch64 assembly code.\n\n"
|
| 202 |
-
f"C Code:\n{c_source}\n\n"
|
| 203 |
-
f"Assembly Code:\n{baseline_asm}\n\n"
|
| 204 |
-
"Only output the optimized assembly code. Do not include any other text. "
|
| 205 |
-
"Do not write any comments in the assembly code. "
|
| 206 |
-
"Wrap the assembly code in <assembly></assembly> tags.\n\n"
|
| 207 |
-
"Optimized Assembly Code:\n"
|
| 208 |
-
)
|
| 209 |
-
|
| 210 |
-
def render_prompt(tokenizer, c_source: str, baseline_asm: str) -> str:
|
| 211 |
-
messages = [
|
| 212 |
-
{"role": "system", "content": SYSTEM_PROMPT},
|
| 213 |
-
{"role": "user", "content": user_prompt(c_source, baseline_asm)},
|
| 214 |
-
]
|
| 215 |
-
if tokenizer is not None:
|
| 216 |
-
kwargs: dict = {
|
| 217 |
-
"tokenize": False,
|
| 218 |
-
"add_generation_prompt": True,
|
| 219 |
-
"enable_thinking": False,
|
| 220 |
-
}
|
| 221 |
-
try:
|
| 222 |
-
prompt = tokenizer.apply_chat_template(messages, **kwargs)
|
| 223 |
-
except TypeError:
|
| 224 |
-
kwargs.pop("enable_thinking", None)
|
| 225 |
-
prompt = tokenizer.apply_chat_template(messages, **kwargs)
|
| 226 |
-
else:
|
| 227 |
-
prompt = f"{SYSTEM_PROMPT}\n\n{messages[1]['content']}\n"
|
| 228 |
-
return prompt.rstrip()
|
| 229 |
-
|
| 230 |
-
@dataclass
|
| 231 |
-
class DatasetConfig:
|
| 232 |
-
max_train: int = 256
|
| 233 |
-
max_eval: int = 32
|
| 234 |
-
difficulty_max: int = 2
|
| 235 |
-
|
| 236 |
-
def build(tc: ToolchainInfo, cfg: DatasetConfig | None = None,
|
| 237 |
-
tokenizer=None) -> tuple[Dataset, Dataset, dict[str, KernelVariant]]:
|
| 238 |
-
cfg = cfg or DatasetConfig()
|
| 239 |
-
from arm_gym.kernels import TEMPLATES
|
| 240 |
-
variants = [v for v in generate_all()
|
| 241 |
-
if TEMPLATES[v.template_name].difficulty <= cfg.difficulty_max]
|
| 242 |
-
train_v, eval_v = split_train_eval(variants, eval_frac=0.1, seed=0)
|
| 243 |
-
train_v = train_v[:cfg.max_train]
|
| 244 |
-
eval_v = eval_v[:cfg.max_eval]
|
| 245 |
-
def to_row(v: KernelVariant) -> dict | None:
|
| 246 |
-
try:
|
| 247 |
-
baseline_asm = compile_to_asm(v.c_source, tc)
|
| 248 |
-
except Exception:
|
| 249 |
-
return None
|
| 250 |
-
prompt = render_prompt(tokenizer, v.c_source, baseline_asm)
|
| 251 |
-
return {"prompt": prompt, "variant_id": v.variant_id,
|
| 252 |
-
"baseline_asm": baseline_asm, "c_source": v.c_source}
|
| 253 |
-
train_rows = [r for r in (to_row(v) for v in train_v) if r is not None]
|
| 254 |
-
eval_rows = [r for r in (to_row(v) for v in eval_v) if r is not None]
|
| 255 |
-
lookup = {v.variant_id: v for v in train_v + eval_v}
|
| 256 |
-
return Dataset.from_list(train_rows), Dataset.from_list(eval_rows), lookup
|
| 257 |
-
DATASETPY
|
| 258 |
-
echo "[hotpatch] dataset.py replaced (no <assembly> pre-injection)"
|
| 259 |
-
|
| 260 |
-
# 4. reward_fn.py: overwrite with version expecting both tags
|
| 261 |
-
cat > kaggle/reward_fn.py << 'REWARDPY'
|
| 262 |
-
"""Reward functions for GRPOTrainer - both-tag aware.
|
| 263 |
-
|
| 264 |
-
Model now emits both <assembly> and </assembly> in completions.
|
| 265 |
-
format_reward rewards: +0.1 open tag, +0.1 close tag, +0.1 body >= 20 chars.
|
| 266 |
-
"""
|
| 267 |
-
from __future__ import annotations
|
| 268 |
-
import hashlib, re, threading
|
| 269 |
-
from dataclasses import dataclass
|
| 270 |
-
from arm_gym.compile_baseline import detect_toolchain
|
| 271 |
-
from arm_gym.verifier import run_correctness_qemu
|
| 272 |
-
from arm_gym.mca import run_mca
|
| 273 |
-
from arm_gym.rollout_budget import TestCase
|
| 274 |
-
from arm_gym.verifier import VerifierConfig, assemble, cleanup_temp_dirs
|
| 275 |
-
|
| 276 |
-
_ASM_RE = re.compile(r"<assembly>(.*?)</assembly>", re.DOTALL | re.IGNORECASE)
|
| 277 |
-
_THINK_RE = re.compile(r"<think>.*?</think>", re.DOTALL | re.IGNORECASE)
|
| 278 |
-
|
| 279 |
-
def extract_assembly(text: str) -> str:
|
| 280 |
-
text = _THINK_RE.sub("", text).strip()
|
| 281 |
-
m = _ASM_RE.search(text)
|
| 282 |
-
if m:
|
| 283 |
-
return _clean_assembly(m.group(1))
|
| 284 |
-
if "</assembly>" in text.lower():
|
| 285 |
-
body = re.split(r"</assembly>", text, flags=re.IGNORECASE)[0]
|
| 286 |
-
if "<assembly>" in body.lower():
|
| 287 |
-
body = re.split(r"<assembly>", body, flags=re.IGNORECASE)[-1]
|
| 288 |
-
return _clean_assembly(body)
|
| 289 |
-
if "<assembly>" in text.lower():
|
| 290 |
-
return _clean_assembly(re.split(r"<assembly>", text, flags=re.IGNORECASE)[-1])
|
| 291 |
-
return _clean_assembly(text)
|
| 292 |
-
|
| 293 |
-
def _clean_assembly(text: str) -> str:
|
| 294 |
-
text = text.strip()
|
| 295 |
-
text = re.sub(r"^```(?:asm|assembly|aarch64)?\s*", "", text, flags=re.IGNORECASE)
|
| 296 |
-
text = re.sub(r"\s*```$", "", text)
|
| 297 |
-
return text.strip()
|
| 298 |
-
|
| 299 |
-
@dataclass
|
| 300 |
-
class _Entry:
|
| 301 |
-
assembles: bool = False
|
| 302 |
-
runs: bool = False
|
| 303 |
-
speedup: float = 0.0
|
| 304 |
-
|
| 305 |
-
_CACHE: dict[str, _Entry] = {}
|
| 306 |
-
_LOCK = threading.Lock()
|
| 307 |
-
_BASELINE: dict[str, float] = {}
|
| 308 |
-
_VCFG: VerifierConfig | None = None
|
| 309 |
-
_DEBUG_LIMIT = 5
|
| 310 |
-
_DEBUG_SHOWN = 0
|
| 311 |
-
|
| 312 |
-
def _cfg() -> VerifierConfig:
|
| 313 |
-
global _VCFG
|
| 314 |
-
if _VCFG is None:
|
| 315 |
-
tc = detect_toolchain()
|
| 316 |
-
_VCFG = VerifierConfig(
|
| 317 |
-
mca_bin=tc.mca or "llvm-mca", assembler="aarch64-linux-gnu-as",
|
| 318 |
-
linker="aarch64-linux-gnu-ld", qemu="qemu-aarch64-static", mcpu=tc.mcpu)
|
| 319 |
-
return _VCFG
|
| 320 |
-
|
| 321 |
-
def _bcy(vid: str, basm: str) -> float:
|
| 322 |
-
if vid not in _BASELINE:
|
| 323 |
-
try:
|
| 324 |
-
rep = run_mca(basm, _cfg().mca_bin, _cfg().mcpu)
|
| 325 |
-
_BASELINE[vid] = float(rep.total_cycles)
|
| 326 |
-
except Exception:
|
| 327 |
-
_BASELINE[vid] = 1000.0
|
| 328 |
-
return _BASELINE[vid]
|
| 329 |
-
|
| 330 |
-
def _key(text: str, vid: str) -> str:
|
| 331 |
-
return hashlib.md5(f"{text}::{vid}".encode()).hexdigest()
|
| 332 |
-
|
| 333 |
-
def _run(text: str, vid: str, basm: str) -> _Entry:
|
| 334 |
-
global _DEBUG_SHOWN
|
| 335 |
-
k = _key(text, vid)
|
| 336 |
-
with _LOCK:
|
| 337 |
-
if k in _CACHE:
|
| 338 |
-
return _CACHE[k]
|
| 339 |
-
e = _Entry()
|
| 340 |
-
asm = extract_assembly(text)
|
| 341 |
-
cfg = _cfg()
|
| 342 |
-
obj, err = assemble(asm, cfg)
|
| 343 |
-
if err or obj is None:
|
| 344 |
-
if _DEBUG_SHOWN < _DEBUG_LIMIT:
|
| 345 |
-
with _LOCK:
|
| 346 |
-
if _DEBUG_SHOWN < _DEBUG_LIMIT:
|
| 347 |
-
_DEBUG_SHOWN += 1
|
| 348 |
-
print(f"[reward-debug #{_DEBUG_SHOWN}] vid={vid} asm_fail={err.message[:200] if err else 'None'!r}", flush=True)
|
| 349 |
-
print(f"[reward-debug] raw[:300]={text[:300]!r}", flush=True)
|
| 350 |
-
print(f"[reward-debug] extracted[:300]={asm[:300]!r}", flush=True)
|
| 351 |
-
cleanup_temp_dirs()
|
| 352 |
-
with _LOCK:
|
| 353 |
-
_CACHE[k] = e
|
| 354 |
-
return e
|
| 355 |
-
e.assembles = True
|
| 356 |
-
e.runs = run_correctness_qemu(obj, TestCase(inputs=(), expected=None), cfg)
|
| 357 |
-
if e.runs:
|
| 358 |
-
bc = _bcy(vid, basm)
|
| 359 |
-
try:
|
| 360 |
-
rep = run_mca(asm, cfg.mca_bin, cfg.mcpu)
|
| 361 |
-
e.speedup = bc / max(rep.total_cycles, 1)
|
| 362 |
-
except Exception:
|
| 363 |
-
e.speedup = 0.0
|
| 364 |
-
cleanup_temp_dirs()
|
| 365 |
-
with _LOCK:
|
| 366 |
-
_CACHE[k] = e
|
| 367 |
-
return e
|
| 368 |
-
|
| 369 |
-
def _prep(completions, kwargs):
|
| 370 |
-
texts = [c[-1]["content"] if isinstance(c, list) else str(c) for c in (completions or [])]
|
| 371 |
-
n = len(texts)
|
| 372 |
-
vids = list(kwargs.get("variant_id") or [""] * n)
|
| 373 |
-
basms = list(kwargs.get("baseline_asm") or [""] * n)
|
| 374 |
-
if len(vids) == 1 and n > 1:
|
| 375 |
-
vids = vids * n
|
| 376 |
-
basms = basms * n
|
| 377 |
-
return texts, vids, basms
|
| 378 |
-
|
| 379 |
-
def syntax_reward(prompts=None, completions=None, **kwargs) -> list[float]:
|
| 380 |
-
_ = prompts
|
| 381 |
-
texts, vids, basms = _prep(completions, kwargs)
|
| 382 |
-
return [1.0 if _run(t, v, b).assembles else 0.0 for t, v, b in zip(texts, vids, basms)]
|
| 383 |
-
|
| 384 |
-
_FORMAT_DEBUG_STEP = 0
|
| 385 |
-
_FORMAT_DEBUG_LIMIT = 6
|
| 386 |
-
|
| 387 |
-
def format_reward(prompts=None, completions=None, **kwargs) -> list[float]:
|
| 388 |
-
global _FORMAT_DEBUG_STEP
|
| 389 |
-
_ = prompts
|
| 390 |
-
texts, _, _ = _prep(completions, kwargs)
|
| 391 |
-
if _FORMAT_DEBUG_STEP < _FORMAT_DEBUG_LIMIT and texts:
|
| 392 |
-
_FORMAT_DEBUG_STEP += 1
|
| 393 |
-
head = texts[0]
|
| 394 |
-
print(f"[completion-debug step={_FORMAT_DEBUG_STEP}] len={len(head)} first 500 chars:", flush=True)
|
| 395 |
-
print(repr(head[:500]), flush=True)
|
| 396 |
-
print(f"[completion-debug step={_FORMAT_DEBUG_STEP}] last 200 chars:", flush=True)
|
| 397 |
-
print(repr(head[-200:]), flush=True)
|
| 398 |
-
scores = []
|
| 399 |
-
for text in texts:
|
| 400 |
-
lowered = text.lower()
|
| 401 |
-
has_open = "<assembly>" in lowered
|
| 402 |
-
has_close = "</assembly>" in lowered
|
| 403 |
-
body = extract_assembly(text)
|
| 404 |
-
body_len = len(re.sub(r"\s", "", body))
|
| 405 |
-
has_prose = any(m in lowered for m in ("```", "<think>", "explain", "analysis"))
|
| 406 |
-
score = 0.0
|
| 407 |
-
if has_open:
|
| 408 |
-
score += 0.1
|
| 409 |
-
if has_close:
|
| 410 |
-
score += 0.1
|
| 411 |
-
if body_len >= 20:
|
| 412 |
-
score += 0.1
|
| 413 |
-
if has_prose:
|
| 414 |
-
score -= 0.05
|
| 415 |
-
scores.append(max(0.0, score))
|
| 416 |
-
return scores
|
| 417 |
-
|
| 418 |
-
def correctness_reward(prompts=None, completions=None, **kwargs) -> list[float]:
|
| 419 |
-
_ = prompts
|
| 420 |
-
texts, vids, basms = _prep(completions, kwargs)
|
| 421 |
-
return [1.0 if _run(t, v, b).runs else 0.0 for t, v, b in zip(texts, vids, basms)]
|
| 422 |
-
|
| 423 |
-
def speedup_reward(prompts=None, completions=None, **kwargs) -> list[float]:
|
| 424 |
-
_ = prompts
|
| 425 |
-
texts, vids, basms = _prep(completions, kwargs)
|
| 426 |
-
scores = []
|
| 427 |
-
for t, v, b in zip(texts, vids, basms):
|
| 428 |
-
e = _run(t, v, b)
|
| 429 |
-
scores.append(max(0.0, e.speedup - 1.0) if e.runs else 0.0)
|
| 430 |
-
return scores
|
| 431 |
-
|
| 432 |
-
class LiveRewardFn:
|
| 433 |
-
@classmethod
|
| 434 |
-
def build(cls) -> "LiveRewardFn":
|
| 435 |
-
return cls()
|
| 436 |
-
def __call__(self, prompts=None, completions=None, **kwargs) -> list[float]:
|
| 437 |
-
texts, vids, basms = _prep(completions, kwargs)
|
| 438 |
-
out = []
|
| 439 |
-
for t, v, b in zip(texts, vids, basms):
|
| 440 |
-
e = _run(t, v, b)
|
| 441 |
-
if not e.assembles: out.append(0.0)
|
| 442 |
-
elif not e.runs: out.append(0.1)
|
| 443 |
-
else: out.append(max(0.0, e.speedup - 1.0))
|
| 444 |
-
return out
|
| 445 |
-
REWARDPY
|
| 446 |
-
echo "[hotpatch] reward_fn.py replaced (both-tag format_reward)"
|
| 447 |
|
| 448 |
# Run toolchain setup
|
| 449 |
[ "${HF_RUN_SETUP:-1}" = "1" ] && [ -f kaggle/setup.sh ] && bash kaggle/setup.sh
|
|
|
|
| 181 |
SHIMPY
|
| 182 |
echo "[hotpatch] stopping criteria DISABLED -- using natural chat EOS"
|
| 183 |
|
| 184 |
+
# 3-4. kaggle/dataset.py + kaggle/reward_fn.py: from clone (push before job; no heredoc)
|
| 185 |
+
echo "[hotpatch] using kaggle/dataset.py + kaggle/reward_fn.py from clone"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
|
| 187 |
# Run toolchain setup
|
| 188 |
[ "${HF_RUN_SETUP:-1}" = "1" ] && [ -f kaggle/setup.sh ] && bash kaggle/setup.sh
|
kaggle/reward_fn.py
CHANGED
|
@@ -1,9 +1,13 @@
|
|
| 1 |
"""Reward functions for GRPOTrainer - 3 independent callables for GDPO.
|
| 2 |
|
| 3 |
-
syntax_reward : +1.0 if assembly parses,
|
| 4 |
correctness_reward: +1.0 if QEMU executes without crash, 0.0 otherwise
|
| 5 |
speedup_reward : (baseline_cycles / agent_cycles) - 1.0; 0.0 if !runs
|
| 6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
Shared _Entry cache avoids 3x verifier calls per completion within a batch.
|
| 8 |
Cache is module-level and persistent (verifier is deterministic, results stable).
|
| 9 |
GDPO: pass [syntax_reward, correctness_reward, speedup_reward] to reward_funcs
|
|
@@ -26,6 +30,20 @@ from arm_gym.verifier import VerifierConfig, assemble, cleanup_temp_dirs
|
|
| 26 |
_ASM_RE = re.compile(r"<assembly>(.*?)</assembly>", re.DOTALL | re.IGNORECASE)
|
| 27 |
_THINK_RE = re.compile(r"<think>.*?</think>", re.DOTALL | re.IGNORECASE)
|
| 28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
def extract_assembly(text: str) -> str:
|
| 31 |
"""Extract assembly body from text.
|
|
@@ -58,6 +76,43 @@ def _clean_assembly(text: str) -> str:
|
|
| 58 |
return text.strip()
|
| 59 |
|
| 60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
@dataclass
|
| 62 |
class _Entry:
|
| 63 |
assembles: bool = False
|
|
@@ -114,10 +169,30 @@ def _run(text: str, vid: str, basm: str) -> _Entry:
|
|
| 114 |
# Compute outside lock - subprocess calls can be slow.
|
| 115 |
# Two threads racing on the same key both compute and write; result is identical.
|
| 116 |
e = _Entry()
|
| 117 |
-
|
|
|
|
| 118 |
cfg = _cfg()
|
| 119 |
|
| 120 |
-
# Gate 1: assemble
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
obj, err = assemble(asm, cfg)
|
| 122 |
if err or obj is None:
|
| 123 |
if _DEBUG_SHOWN < _DEBUG_LIMIT:
|
|
@@ -130,7 +205,10 @@ def _run(text: str, vid: str, basm: str) -> _Entry:
|
|
| 130 |
flush=True,
|
| 131 |
)
|
| 132 |
print(f"[reward-debug] raw[:300]={text[:300]!r}", flush=True)
|
| 133 |
-
print(
|
|
|
|
|
|
|
|
|
|
| 134 |
cleanup_temp_dirs()
|
| 135 |
with _LOCK:
|
| 136 |
_CACHE[k] = e
|
|
@@ -140,7 +218,7 @@ def _run(text: str, vid: str, basm: str) -> _Entry:
|
|
| 140 |
# Gate 2: QEMU run-without-crash
|
| 141 |
e.runs = run_correctness_qemu(obj, TestCase(inputs=(), expected=None), cfg)
|
| 142 |
|
| 143 |
-
# Gate 3: MCA speedup
|
| 144 |
if e.runs:
|
| 145 |
bc = _bcy(vid, basm)
|
| 146 |
try:
|
|
@@ -167,38 +245,47 @@ def _prep(completions, kwargs):
|
|
| 167 |
|
| 168 |
|
| 169 |
def syntax_reward(prompts=None, completions=None, **kwargs) -> list[float]:
|
| 170 |
-
"""Gate 1:
|
| 171 |
_ = prompts
|
| 172 |
texts, vids, basms = _prep(completions, kwargs)
|
| 173 |
-
return [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
|
| 175 |
|
| 176 |
def format_reward(prompts=None, completions=None, **kwargs) -> list[float]:
|
| 177 |
"""Shaping reward for well-formed completions.
|
| 178 |
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
Max score: 0.3 for a perfectly formed completion.
|
| 185 |
"""
|
| 186 |
_ = prompts
|
| 187 |
texts, _, _ = _prep(completions, kwargs)
|
| 188 |
scores = []
|
| 189 |
for text in texts:
|
| 190 |
lowered = text.lower()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
has_open = "<assembly>" in lowered
|
| 192 |
-
has_close = "</assembly>" in lowered
|
| 193 |
body = extract_assembly(text)
|
| 194 |
body_len = len(re.sub(r"\s", "", body))
|
|
|
|
|
|
|
| 195 |
has_prose = any(marker in lowered
|
| 196 |
for marker in ("```", "<think>", "explain", "analysis"))
|
| 197 |
score = 0.0
|
| 198 |
-
if
|
| 199 |
-
score += 0.1
|
| 200 |
-
if has_close:
|
| 201 |
score += 0.1
|
|
|
|
| 202 |
if body_len >= 20:
|
| 203 |
score += 0.1
|
| 204 |
if has_prose:
|
|
@@ -242,7 +329,7 @@ class LiveRewardFn:
|
|
| 242 |
for t, v, b in zip(texts, vids, basms):
|
| 243 |
e = _run(t, v, b)
|
| 244 |
if not e.assembles:
|
| 245 |
-
out.append(
|
| 246 |
elif not e.runs:
|
| 247 |
out.append(0.1)
|
| 248 |
else:
|
|
|
|
| 1 |
"""Reward functions for GRPOTrainer - 3 independent callables for GDPO.
|
| 2 |
|
| 3 |
+
syntax_reward : +1.0 if assembly parses, **-1.0** otherwise (was 0.0)
|
| 4 |
correctness_reward: +1.0 if QEMU executes without crash, 0.0 otherwise
|
| 5 |
speedup_reward : (baseline_cycles / agent_cycles) - 1.0; 0.0 if !runs
|
| 6 |
|
| 7 |
+
Stronger **invalid** syntax signal (negative) avoids PPO/GRPO treating “broken” and
|
| 8 |
+
“almost right” as the same plateau. Pre-``as`` fast-reject (unclosed block comments,
|
| 9 |
+
obvious high-level code slop) saves subprocess budget and fails before GNU as.
|
| 10 |
+
|
| 11 |
Shared _Entry cache avoids 3x verifier calls per completion within a batch.
|
| 12 |
Cache is module-level and persistent (verifier is deterministic, results stable).
|
| 13 |
GDPO: pass [syntax_reward, correctness_reward, speedup_reward] to reward_funcs
|
|
|
|
| 30 |
_ASM_RE = re.compile(r"<assembly>(.*?)</assembly>", re.DOTALL | re.IGNORECASE)
|
| 31 |
_THINK_RE = re.compile(r"<think>.*?</think>", re.DOTALL | re.IGNORECASE)
|
| 32 |
|
| 33 |
+
# Syntax gate: separate “dead” from “valid” in RL (not flat zeros).
|
| 34 |
+
SYNTAX_REWARD_OK: float = 1.0
|
| 35 |
+
SYNTAX_REWARD_FAIL: float = -1.0
|
| 36 |
+
|
| 37 |
+
# Substrings that almost never appear in hand-written A64 gas but show up in C/Java/JS drift.
|
| 38 |
+
_CODE_SLOP = re.compile(
|
| 39 |
+
r"(?i)\b("
|
| 40 |
+
r"import|package|#include|#pragma|"
|
| 41 |
+
r"class\s|def\s|void\s+main|public\s+static|"
|
| 42 |
+
r"console\.|React\.|"
|
| 43 |
+
r"System\.|using\s+namespace"
|
| 44 |
+
r")\b",
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
|
| 48 |
def extract_assembly(text: str) -> str:
|
| 49 |
"""Extract assembly body from text.
|
|
|
|
| 76 |
return text.strip()
|
| 77 |
|
| 78 |
|
| 79 |
+
def preprocess_assembly(asm: str) -> tuple[str, str | None]:
|
| 80 |
+
"""Sanitize for GNU ``as`` and **fast-reject** obvious invalid text before ``assemble()``.
|
| 81 |
+
|
| 82 |
+
* Unclosed C block comments → reject (GAS: "end of file in multiline comment").
|
| 83 |
+
* Obvious high-level / JS / C drift → reject (saves `as` + clearer RL signal than junk errors).
|
| 84 |
+
* ``//`` → ``@`` line comments (A64 gas line comment; plain ``as`` does not treat ``//`` like gcc).
|
| 85 |
+
* Smart quotes / mojibake quotes → ASCII (assembler is strict).
|
| 86 |
+
* Trailing newline (silences warning; some gas versions are picky).
|
| 87 |
+
|
| 88 |
+
We **do not** use a naive ``[a-zA-Z]{6,}`` filter — valid mnemonics like
|
| 89 |
+
``umaddl``, ``cfinv``, ``st1`` would false-positive.
|
| 90 |
+
|
| 91 |
+
Returns ``(text, None)`` if OK, or ``("", reason)`` to skip ``assemble()``.
|
| 92 |
+
"""
|
| 93 |
+
if not (asm and asm.strip()):
|
| 94 |
+
return "", "empty_body"
|
| 95 |
+
|
| 96 |
+
s = (
|
| 97 |
+
asm.replace("\u201c", '"')
|
| 98 |
+
.replace("\u201d", '"')
|
| 99 |
+
.replace("\u2018", "'")
|
| 100 |
+
.replace("\u2019", "'")
|
| 101 |
+
)
|
| 102 |
+
# C++-style line comments: map to GAS line comment; avoids "junk" on `/` for plain `.s`.
|
| 103 |
+
s = s.replace("//", "@")
|
| 104 |
+
|
| 105 |
+
if s.count("/*") != s.count("*/"):
|
| 106 |
+
return "", "unclosed_block_comment"
|
| 107 |
+
|
| 108 |
+
if _CODE_SLOP.search(s):
|
| 109 |
+
return "", "lexical:high_level_code"
|
| 110 |
+
|
| 111 |
+
if not s.endswith("\n"):
|
| 112 |
+
s += "\n"
|
| 113 |
+
return s, None
|
| 114 |
+
|
| 115 |
+
|
| 116 |
@dataclass
|
| 117 |
class _Entry:
|
| 118 |
assembles: bool = False
|
|
|
|
| 169 |
# Compute outside lock - subprocess calls can be slow.
|
| 170 |
# Two threads racing on the same key both compute and write; result is identical.
|
| 171 |
e = _Entry()
|
| 172 |
+
raw_extracted = extract_assembly(text)
|
| 173 |
+
asm, pre_r = preprocess_assembly(raw_extracted)
|
| 174 |
cfg = _cfg()
|
| 175 |
|
| 176 |
+
# Gate 1: assemble (skip subprocess if fast-reject already failed)
|
| 177 |
+
if pre_r is not None:
|
| 178 |
+
if _DEBUG_SHOWN < _DEBUG_LIMIT:
|
| 179 |
+
with _LOCK:
|
| 180 |
+
if _DEBUG_SHOWN < _DEBUG_LIMIT:
|
| 181 |
+
_DEBUG_SHOWN += 1
|
| 182 |
+
print(
|
| 183 |
+
f"[reward-debug #{_DEBUG_SHOWN}] vid={vid} "
|
| 184 |
+
f"pre_reject={pre_r}",
|
| 185 |
+
flush=True,
|
| 186 |
+
)
|
| 187 |
+
print(f"[reward-debug] raw[:300]={text[:300]!r}", flush=True)
|
| 188 |
+
print(
|
| 189 |
+
f"[reward-debug] extracted[:300]={raw_extracted[:300]!r}",
|
| 190 |
+
flush=True,
|
| 191 |
+
)
|
| 192 |
+
with _LOCK:
|
| 193 |
+
_CACHE[k] = e
|
| 194 |
+
return e
|
| 195 |
+
|
| 196 |
obj, err = assemble(asm, cfg)
|
| 197 |
if err or obj is None:
|
| 198 |
if _DEBUG_SHOWN < _DEBUG_LIMIT:
|
|
|
|
| 205 |
flush=True,
|
| 206 |
)
|
| 207 |
print(f"[reward-debug] raw[:300]={text[:300]!r}", flush=True)
|
| 208 |
+
print(
|
| 209 |
+
f"[reward-debug] extracted[:300]={raw_extracted[:300]!r}",
|
| 210 |
+
flush=True,
|
| 211 |
+
)
|
| 212 |
cleanup_temp_dirs()
|
| 213 |
with _LOCK:
|
| 214 |
_CACHE[k] = e
|
|
|
|
| 218 |
# Gate 2: QEMU run-without-crash
|
| 219 |
e.runs = run_correctness_qemu(obj, TestCase(inputs=(), expected=None), cfg)
|
| 220 |
|
| 221 |
+
# Gate 3: MCA speedup (use sanitized asm, same as ``assemble``)
|
| 222 |
if e.runs:
|
| 223 |
bc = _bcy(vid, basm)
|
| 224 |
try:
|
|
|
|
| 245 |
|
| 246 |
|
| 247 |
def syntax_reward(prompts=None, completions=None, **kwargs) -> list[float]:
|
| 248 |
+
"""Gate 1: ``SYNTAX_REWARD_OK`` if assembly parses, ``SYNTAX_REWARD_FAIL`` otherwise."""
|
| 249 |
_ = prompts
|
| 250 |
texts, vids, basms = _prep(completions, kwargs)
|
| 251 |
+
return [
|
| 252 |
+
(SYNTAX_REWARD_OK if _run(t, v, b).assembles else SYNTAX_REWARD_FAIL)
|
| 253 |
+
for t, v, b in zip(texts, vids, basms)
|
| 254 |
+
]
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
FORMAT_REWARD_MISSING_CLOSE: float = -1.0
|
| 258 |
|
| 259 |
|
| 260 |
def format_reward(prompts=None, completions=None, **kwargs) -> list[float]:
|
| 261 |
"""Shaping reward for well-formed completions.
|
| 262 |
|
| 263 |
+
If ``</assembly>`` is **missing** → ``FORMAT_REWARD_MISSING_CLOSE`` (RL signal
|
| 264 |
+
to separate unclosed from weak-but-closed). Otherwise small positive parts.
|
| 265 |
+
|
| 266 |
+
Prompt may prefill ``<assembly>`` so the assistant body often omits the open
|
| 267 |
+
tag; we still require a closing tag and use ``extract_assembly`` for body.
|
|
|
|
| 268 |
"""
|
| 269 |
_ = prompts
|
| 270 |
texts, _, _ = _prep(completions, kwargs)
|
| 271 |
scores = []
|
| 272 |
for text in texts:
|
| 273 |
lowered = text.lower()
|
| 274 |
+
if "</assembly>" not in lowered:
|
| 275 |
+
scores.append(FORMAT_REWARD_MISSING_CLOSE)
|
| 276 |
+
continue
|
| 277 |
+
|
| 278 |
has_open = "<assembly>" in lowered
|
|
|
|
| 279 |
body = extract_assembly(text)
|
| 280 |
body_len = len(re.sub(r"\s", "", body))
|
| 281 |
+
# Prefill: open tag only in prompt — credit body as soon as it exists.
|
| 282 |
+
has_open_effective = has_open or body_len >= 1
|
| 283 |
has_prose = any(marker in lowered
|
| 284 |
for marker in ("```", "<think>", "explain", "analysis"))
|
| 285 |
score = 0.0
|
| 286 |
+
if has_open_effective:
|
|
|
|
|
|
|
| 287 |
score += 0.1
|
| 288 |
+
score += 0.1 # has_close: we only get here if </assembly> present
|
| 289 |
if body_len >= 20:
|
| 290 |
score += 0.1
|
| 291 |
if has_prose:
|
|
|
|
| 329 |
for t, v, b in zip(texts, vids, basms):
|
| 330 |
e = _run(t, v, b)
|
| 331 |
if not e.assembles:
|
| 332 |
+
out.append(SYNTAX_REWARD_FAIL)
|
| 333 |
elif not e.runs:
|
| 334 |
out.append(0.1)
|
| 335 |
else:
|
scripts/sync_hf_space.sh
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# Force-push current branch to Hugging Face Space main, rewriting plot PNGs to Git LFS
|
| 3 |
+
# first so the Hub accepts the pack (raw PNG blobs are rejected).
|
| 4 |
+
#
|
| 5 |
+
# git fetch origin && git checkout main && git pull
|
| 6 |
+
# export HF_TOKEN=hf_... # or use git credential for huggingface.co
|
| 7 |
+
# ./scripts/sync_hf_space.sh
|
| 8 |
+
#
|
| 9 |
+
set -euo pipefail
|
| 10 |
+
ROOT="$(git rev-parse --show-toplevel)"
|
| 11 |
+
cd "$ROOT"
|
| 12 |
+
|
| 13 |
+
HF_USER="${HF_USER:-kaori02}"
|
| 14 |
+
HF_SPACE="${HF_SPACE:-arm-gym}"
|
| 15 |
+
if [[ -n "${HF_TOKEN:-}" ]]; then
|
| 16 |
+
REMOTE_URL="https://${HF_USER}:${HF_TOKEN}@huggingface.co/spaces/${HF_USER}/${HF_SPACE}"
|
| 17 |
+
else
|
| 18 |
+
REMOTE_URL="https://huggingface.co/spaces/${HF_USER}/${HF_SPACE}"
|
| 19 |
+
fi
|
| 20 |
+
|
| 21 |
+
REF=$(git symbolic-ref -q HEAD) || { echo "error: need a branch (detached HEAD not supported)" >&2; exit 1; }
|
| 22 |
+
|
| 23 |
+
git lfs install
|
| 24 |
+
git lfs migrate import --include="colab/results/plots/*.png" --include-ref="$REF"
|
| 25 |
+
|
| 26 |
+
echo "[sync_hf_space] pushing HEAD -> ${HF_USER}/${HF_SPACE}:main (force)"
|
| 27 |
+
git push --force "$REMOTE_URL" HEAD:main
|
tests/test_preprocess_assembly.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Unit tests for ``preprocess_assembly`` (no toolchain)."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import pytest
|
| 6 |
+
|
| 7 |
+
from kaggle.reward_fn import SYNTAX_REWARD_FAIL, SYNTAX_REWARD_OK, preprocess_assembly, syntax_reward
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def test_preprocess_appends_newline() -> None:
|
| 11 |
+
s, err = preprocess_assembly("mov x0, #0")
|
| 12 |
+
assert err is None
|
| 13 |
+
assert s.endswith("\n")
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def test_preprocess_maps_double_slash_to_at() -> None:
|
| 17 |
+
s, err = preprocess_assembly("mov x0, x1 // foo\n")
|
| 18 |
+
assert err is None
|
| 19 |
+
assert "//" not in s
|
| 20 |
+
assert "@" in s
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def test_preprocess_rejects_unclosed_block_comment() -> None:
|
| 24 |
+
_, err = preprocess_assembly("mov x0, x1 /* start\n")
|
| 25 |
+
assert err == "unclosed_block_comment"
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def test_preprocess_rejects_code_slop_import() -> None:
|
| 29 |
+
_, err = preprocess_assembly("import java.util\nmov x0, x0\n")
|
| 30 |
+
assert err == "lexical:high_level_code"
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def test_preprocess_allows_valid_mnemonic_umaddl() -> None:
|
| 34 |
+
# Do not use naive "long English word" heuristics — umaddl is 6+ letters.
|
| 35 |
+
s, err = preprocess_assembly("umaddl x0, w1, w2, w3\n")
|
| 36 |
+
assert err is None
|
| 37 |
+
assert "umaddl" in s
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def test_syntax_reward_negative_on_garbage(monkeypatch: pytest.MonkeyPatch) -> None:
|
| 41 |
+
from kaggle import reward_fn as rf
|
| 42 |
+
|
| 43 |
+
monkeypatch.setattr(rf, "_CACHE", {})
|
| 44 |
+
monkeypatch.setattr(rf, "_DEBUG_SHOWN", 0)
|
| 45 |
+
# No valid assembly, pre-reject or assemble fail
|
| 46 |
+
bad = "<assembly>import numpy as np\n</assembly>"
|
| 47 |
+
r = syntax_reward(completions=[[{"content": bad}]], variant_id=["t"], baseline_asm=[""])
|
| 48 |
+
assert r == [SYNTAX_REWARD_FAIL]
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def test_syntax_reward_ok_constant() -> None:
|
| 52 |
+
assert SYNTAX_REWARD_OK == 1.0
|
| 53 |
+
assert SYNTAX_REWARD_FAIL == -1.0
|
tests/test_training_prompt_contract.py
CHANGED
|
@@ -36,11 +36,14 @@ def test_format_reward_prefers_closed_assembly_blocks() -> None:
|
|
| 36 |
]
|
| 37 |
)
|
| 38 |
|
| 39 |
-
assert rewards == [0.2,
|
| 40 |
|
| 41 |
|
| 42 |
def test_user_prompt_includes_baseline_fallback() -> None:
|
| 43 |
prompt = user_prompt("int f(void){return 0;}", "f:\n\tret\n")
|
| 44 |
|
| 45 |
-
assert "
|
| 46 |
-
assert "
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
]
|
| 37 |
)
|
| 38 |
|
| 39 |
+
assert rewards == [0.2, -1.0]
|
| 40 |
|
| 41 |
|
| 42 |
def test_user_prompt_includes_baseline_fallback() -> None:
|
| 43 |
prompt = user_prompt("int f(void){return 0;}", "f:\n\tret\n")
|
| 44 |
|
| 45 |
+
assert "Generate ONLY AArch64" in prompt
|
| 46 |
+
assert "C code:" in prompt
|
| 47 |
+
assert "Baseline assembly" in prompt
|
| 48 |
+
assert prompt.rstrip().endswith("<assembly>")
|
| 49 |
+
assert "ret" in prompt
|