mahir-m01 commited on
Commit ·
93e35c4
1
Parent(s): a2ffabc
feat(hf): v6 GRPO with harness correctness, MVP 50 steps, and run scripts
Browse files- Add hf/v6_train.py: qemu harness, ARMGYM_PROFILE/PROFILE (mvp|long) for hub + steps
- Add hf/run_hf_mvp.sh and run_hf_long.sh for Hugging Face Jobs reference
- Add hf/smoke_harness.py for local syntax checks
- Harden arm_gym/compile_baseline mcpu fallbacks (clang+gcc+MCA)
- Add instructions.md for LoRA export and 50 vs 200 flows
- Ignore .ai-workflow/ and .mcp.json in git
Made-with: Cursor
- .gitignore +2 -0
- arm_gym/compile_baseline.py +54 -17
- hf/PROFILE +1 -0
- hf/run_hf_long.sh +6 -0
- hf/run_hf_mvp.sh +32 -0
- hf/run_long_train.sh +4 -0
- hf/smoke_harness.py +208 -0
- hf/v6_train.py +751 -0
.gitignore
CHANGED
|
@@ -37,6 +37,8 @@ adapters/
|
|
| 37 |
*.gguf
|
| 38 |
.next/
|
| 39 |
excalidraw.log
|
|
|
|
|
|
|
| 40 |
|
| 41 |
.omc/
|
| 42 |
.omg/
|
|
|
|
| 37 |
*.gguf
|
| 38 |
.next/
|
| 39 |
excalidraw.log
|
| 40 |
+
.ai-workflow/
|
| 41 |
+
.mcp.json
|
| 42 |
|
| 43 |
.omc/
|
| 44 |
.omg/
|
arm_gym/compile_baseline.py
CHANGED
|
@@ -3,6 +3,11 @@
|
|
| 3 |
Cut 1 fix: prefer LLVM 21 with -mcpu=neoverse-v2 as the Neoverse V3 proxy.
|
| 4 |
neoverse-v3 is not yet in any LLVM release (Olympus is LLVM 22, NVIDIA-specific).
|
| 5 |
Fallback is disclosed via mcpu_disclosed field, not silently swapped.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
"""
|
| 7 |
|
| 8 |
from __future__ import annotations
|
|
@@ -13,9 +18,11 @@ import tempfile
|
|
| 13 |
from dataclasses import dataclass
|
| 14 |
from pathlib import Path
|
| 15 |
|
| 16 |
-
CLANG_CANDIDATES = ["clang-21", "clang-20", "clang"]
|
| 17 |
GCC_AARCH64 = "aarch64-linux-gnu-gcc"
|
| 18 |
-
MCA_CANDIDATES = ["llvm-mca-21", "llvm-mca-20", "llvm-mca"]
|
|
|
|
|
|
|
| 19 |
|
| 20 |
|
| 21 |
def find_tool(candidates: list[str]) -> str | None:
|
|
@@ -31,7 +38,7 @@ class ToolchainInfo:
|
|
| 31 |
gcc_aarch64: str | None
|
| 32 |
mca: str | None
|
| 33 |
mcpu: str # actual -mcpu used
|
| 34 |
-
mcpu_disclosed: str | None # e.g. "
|
| 35 |
|
| 36 |
def ready(self) -> bool:
|
| 37 |
return bool(self.clang or self.gcc_aarch64)
|
|
@@ -41,23 +48,53 @@ def detect_toolchain(preferred_cpu: str = "neoverse-v3") -> ToolchainInfo:
|
|
| 41 |
clang = find_tool(CLANG_CANDIDATES)
|
| 42 |
gcc = shutil.which(GCC_AARCH64)
|
| 43 |
mca = find_tool(MCA_CANDIDATES)
|
| 44 |
-
mcpu, disclosed = _pick_cpu(clang, preferred_cpu)
|
| 45 |
return ToolchainInfo(clang=clang, gcc_aarch64=gcc, mca=mca, mcpu=mcpu, mcpu_disclosed=disclosed)
|
| 46 |
|
| 47 |
|
| 48 |
-
def
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
|
| 63 |
def compile_to_asm(c_source: str, tc: ToolchainInfo, opt: str = "-O3") -> str:
|
|
|
|
| 3 |
Cut 1 fix: prefer LLVM 21 with -mcpu=neoverse-v2 as the Neoverse V3 proxy.
|
| 4 |
neoverse-v3 is not yet in any LLVM release (Olympus is LLVM 22, NVIDIA-specific).
|
| 5 |
Fallback is disclosed via mcpu_disclosed field, not silently swapped.
|
| 6 |
+
|
| 7 |
+
mcpu selection probes both clang and gcc to find the best CPU both compilers
|
| 8 |
+
actually support. Fallback chain: v3 β v2 β v1 β n2 β n1 β generic.
|
| 9 |
+
GCC 12 (Debian Bookworm default) supports up to neoverse-v1.
|
| 10 |
+
LLVM-MCA-15 supports neoverse-v1, so reward signal stays consistent.
|
| 11 |
"""
|
| 12 |
|
| 13 |
from __future__ import annotations
|
|
|
|
| 18 |
from dataclasses import dataclass
|
| 19 |
from pathlib import Path
|
| 20 |
|
| 21 |
+
CLANG_CANDIDATES = ["clang-21", "clang-20", "clang-17", "clang-16", "clang-15", "clang"]
|
| 22 |
GCC_AARCH64 = "aarch64-linux-gnu-gcc"
|
| 23 |
+
MCA_CANDIDATES = ["llvm-mca-21", "llvm-mca-20", "llvm-mca-17", "llvm-mca-16", "llvm-mca-15", "llvm-mca"]
|
| 24 |
+
|
| 25 |
+
_MCPU_CHAIN = ["neoverse-v3", "neoverse-v2", "neoverse-v1", "neoverse-n2", "neoverse-n1", "generic"]
|
| 26 |
|
| 27 |
|
| 28 |
def find_tool(candidates: list[str]) -> str | None:
|
|
|
|
| 38 |
gcc_aarch64: str | None
|
| 39 |
mca: str | None
|
| 40 |
mcpu: str # actual -mcpu used
|
| 41 |
+
mcpu_disclosed: str | None # e.g. "V1 proxy for V3" when fallback taken
|
| 42 |
|
| 43 |
def ready(self) -> bool:
|
| 44 |
return bool(self.clang or self.gcc_aarch64)
|
|
|
|
| 48 |
clang = find_tool(CLANG_CANDIDATES)
|
| 49 |
gcc = shutil.which(GCC_AARCH64)
|
| 50 |
mca = find_tool(MCA_CANDIDATES)
|
| 51 |
+
mcpu, disclosed = _pick_cpu(clang, gcc, preferred_cpu)
|
| 52 |
return ToolchainInfo(clang=clang, gcc_aarch64=gcc, mca=mca, mcpu=mcpu, mcpu_disclosed=disclosed)
|
| 53 |
|
| 54 |
|
| 55 |
+
def _gcc_probe_mcpu(gcc: str, preferred: str) -> tuple[str, str | None]:
    """Probe gcc by test-compiling a trivial TU against each -mcpu candidate.

    Returns (cpu, disclosure): the first CPU in the fallback chain that the
    installed gcc accepts, plus a human-readable note whenever it is not the
    preferred one.  If nothing compiles, returns ("generic", note).
    """
    probe_src = "int f(void){return 0;}"
    # Preferred CPU first, then the rest of the chain in declared order.
    candidates = [preferred, *(c for c in _MCPU_CHAIN if c != preferred)]
    for candidate in candidates:
        cmd = [gcc, f"-mcpu={candidate}", "-S", "-x", "c", "-", "-o", "/dev/null"]
        try:
            result = subprocess.run(
                cmd,
                input=probe_src,
                capture_output=True,
                text=True,
                timeout=5,
            )
        except Exception:
            # Missing binary, timeout, etc. -- move on to the next candidate.
            continue
        if result.returncode != 0:
            continue
        if candidate == preferred:
            return candidate, None
        return candidate, f"{candidate} proxy for {preferred} (gcc limit)"
    return "generic", f"generic fallback for {preferred}"
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def _pick_cpu(clang: str | None, gcc: str | None, preferred: str) -> tuple[str, str | None]:
|
| 74 |
+
"""Probe available compilers to find best mcpu both support."""
|
| 75 |
+
if clang:
|
| 76 |
+
try:
|
| 77 |
+
# Must pass --target=aarch64-linux-gnu β without it, clang lists
|
| 78 |
+
# host (x86) CPUs and neoverse-* names are not present.
|
| 79 |
+
out = subprocess.run(
|
| 80 |
+
[clang, "--target=aarch64-linux-gnu", "--print-supported-cpus"],
|
| 81 |
+
capture_output=True, text=True, timeout=10,
|
| 82 |
+
)
|
| 83 |
+
supported = (out.stdout + out.stderr).lower()
|
| 84 |
+
chain = [preferred] + [c for c in _MCPU_CHAIN if c != preferred]
|
| 85 |
+
for cpu in chain:
|
| 86 |
+
if cpu.lower() in supported:
|
| 87 |
+
disclosed = None if cpu == preferred else f"{cpu} proxy for {preferred} (clang limit)"
|
| 88 |
+
return cpu, disclosed
|
| 89 |
+
except Exception:
|
| 90 |
+
pass
|
| 91 |
+
|
| 92 |
+
# No clang or clang probe failed β probe GCC directly by test-compiling.
|
| 93 |
+
if gcc:
|
| 94 |
+
return _gcc_probe_mcpu(gcc, preferred)
|
| 95 |
+
|
| 96 |
+
# No compiler to probe yet; return preferred and let compile_to_asm fail loudly.
|
| 97 |
+
return preferred, None
|
| 98 |
|
| 99 |
|
| 100 |
def compile_to_asm(c_source: str, tc: ToolchainInfo, opt: str = "-O3") -> str:
|
hf/PROFILE
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
mvp
|
hf/run_hf_long.sh
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# 200-step Hugging Face Job launcher (same deps as run_hf_mvp.sh).
# Sets the "long" profile, then delegates all setup + training to the MVP
# script, which reads ARMGYM_PROFILE to pick steps/output/hub repo.
set -euo pipefail
export ARMGYM_PROFILE=long
# Resolve this script's directory so the delegate works from any CWD.
DIR="$(cd "$(dirname "$0")" && pwd)"
exec bash "$DIR/run_hf_mvp.sh"
|
hf/run_hf_mvp.sh
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Reference launcher for Hugging Face Jobs: MVP (50-step) training on a10g-large (or similar).
# Requires: HF_TOKEN in env for adapter upload. Do not commit tokens.
set -euo pipefail
# Fail fast with a clear message if the upload token is missing.
: "${HF_TOKEN:?set HF_TOKEN}"
echo "=== System deps ==="
apt-get update -qq
# Cross toolchain + QEMU so aarch64 binaries can be assembled and executed.
DEBIAN_FRONTEND=noninteractive apt-get install -y -qq \
  qemu-user-static binutils-aarch64-linux-gnu gcc-aarch64-linux-gnu \
  libc6-dev-arm64-cross ca-certificates curl gnupg lsb-release >/dev/null
# Detect the distro codename; default to bookworm when lsb_release is absent.
CODENAME="${CODENAME:-$(lsb_release -cs 2>/dev/null || echo bookworm)}"
# Add the upstream LLVM apt repo to get clang-21 / llvm-mca-21.
curl -fsSL https://apt.llvm.org/llvm-snapshot.gpg.key | gpg --dearmor -o /usr/share/keyrings/llvm.gpg
echo "deb [signed-by=/usr/share/keyrings/llvm.gpg] http://apt.llvm.org/${CODENAME}/ llvm-toolchain-${CODENAME}-21 main" \
  > /etc/apt/sources.list.d/llvm21.list
apt-get update -qq
# Prefer clang-21; fall back to the distro's default clang/llvm if the
# versioned packages are unavailable for this codename.
DEBIAN_FRONTEND=noninteractive apt-get install -y -qq clang-21 llvm-21 llvm-21-tools >/dev/null \
  || apt-get install -y -qq clang llvm llvm-tools
export PATH="/usr/lib/llvm-21/bin:${PATH}"
echo "=== Python deps ==="
python3 -m pip install -q --upgrade pip
python3 -m pip install -q --no-input \
  'trl==0.20.0' 'transformers>=4.55,<4.58' 'accelerate>=1.0' 'peft>=0.13' 'datasets>=3.0' \
  'bitsandbytes>=0.45' 'torch>=2.3' 'numpy>=1.26' 'pydantic>=2.7' sentencepiece protobuf \
  huggingface_hub
# Training script and the arm_gym wheel are fetched from a dataset repo.
V6_URL="https://huggingface.co/datasets/ZDC-M01/arm-gym-pkg/resolve/main/v6_train.py"
WHEEL_URL="https://huggingface.co/datasets/ZDC-M01/arm-gym-pkg/resolve/main/arm_gym-0.1.0-py3-none-any.whl"
python3 -m pip install -q --force-reinstall "arm_gym @ ${WHEEL_URL}"
# Remove packages known to conflict with this trl/transformers pin.
python3 -m pip uninstall -y torchao vllm unsloth unsloth_zoo xformers 2>/dev/null || true
echo "=== Run GRPO (MVP profile = 50 steps via hf/PROFILE) ==="
curl -sSL "$V6_URL" -o /tmp/v6_train.py
cd /tmp
python3 v6_train.py
|
hf/run_long_train.sh
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# 200-step GRPO + LoRA. Overrides hf/PROFILE when sourced before python3 hf/v6_train.py
# NOTE: this script is meant to be *sourced* so the export reaches the
# subsequent python3 invocation in the caller's shell.
export ARMGYM_PROFILE=long
echo "ARMGYM_PROFILE=long β 200 steps, out runs/v6-200, hub ZDC-M01/arm-gym-train-200"
|
hf/smoke_harness.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Smoke test for generate_test_harness: check generated C is syntactically valid."""
|
| 3 |
+
import os, subprocess, sys, tempfile
|
| 4 |
+
|
| 5 |
+
sys.path.insert(0, os.path.dirname(__file__))
|
| 6 |
+
|
| 7 |
+
# Pull the harness helpers from v5e_train without importing the whole module
|
| 8 |
+
# (which requires torch, etc.)
|
| 9 |
+
import re
|
| 10 |
+
|
| 11 |
+
_SIG_RE = re.compile(r"([\w\s*]+?\bkernel\s*\([^)]*\))", re.DOTALL)
|
| 12 |
+
_ARR_SZ = 8192
|
| 13 |
+
_PRINT_N = 64
|
| 14 |
+
|
| 15 |
+
def _parse_kernel_sig(c_source):
|
| 16 |
+
m = _SIG_RE.search(c_source)
|
| 17 |
+
if not m:
|
| 18 |
+
return "void", [], None
|
| 19 |
+
full_proto = " ".join(m.group(1).split())
|
| 20 |
+
paren = full_proto.index("(")
|
| 21 |
+
ret_and_name = full_proto[:paren].strip()
|
| 22 |
+
ret_type = ret_and_name.rsplit("kernel", 1)[0].strip() or "void"
|
| 23 |
+
raw_params_str = full_proto[paren + 1:].rstrip(")")
|
| 24 |
+
params = []
|
| 25 |
+
for p in raw_params_str.split(","):
|
| 26 |
+
p = p.strip()
|
| 27 |
+
if not p:
|
| 28 |
+
continue
|
| 29 |
+
is_const = "const " in p
|
| 30 |
+
is_ptr = "*" in p
|
| 31 |
+
clean = p.replace("const", "").replace("__restrict__", "").replace("*", "").strip()
|
| 32 |
+
parts = clean.split()
|
| 33 |
+
dtype = parts[0] if parts else "int"
|
| 34 |
+
name = parts[-1] if len(parts) > 1 else f"p{len(params)}"
|
| 35 |
+
params.append({"dtype": dtype, "name": name, "is_ptr": is_ptr,
|
| 36 |
+
"is_const": is_const, "raw": p})
|
| 37 |
+
return ret_type, params, full_proto
|
| 38 |
+
|
| 39 |
+
def generate_test_harness(c_source, with_kernel_def):
    """Emit a complete C translation unit with main() that drives `kernel`.

    with_kernel_def=True inlines the kernel source (reference build);
    False emits only an extern prototype so a separately-assembled candidate
    object can be linked in.  Inputs are filled deterministically, outputs
    and/or the return value are printed so two builds can be diffed on
    stdout.  Buffer sizes come from module constants _ARR_SZ / _PRINT_N.
    """
    ret_type, params, full_proto = _parse_kernel_sig(c_source)
    lines = ['#include <stdio.h>', '#include <stdlib.h>', '#include <string.h>',
             '#include <stddef.h>', '#include <stdint.h>', '#include <math.h>', '']
    if with_kernel_def:
        lines.append(c_source)
    else:
        if full_proto:
            lines.append(f"extern {full_proto};")
    lines += ['', 'int main(void) {']
    call_args = []
    output_arrays = []  # non-const pointer params: printed after the call
    for p in params:
        dt, nm = p["dtype"], p["name"]
        if p["is_ptr"]:
            # static => zero-initialized, off-stack, survives large _ARR_SZ.
            lines.append(f"  static {dt} {nm}[{_ARR_SZ}];")
            if p["is_const"]:
                # const pointers are inputs: fill with a deterministic,
                # dtype-appropriate pattern (small moduli avoid overflow).
                if dt in ("float", "double"):
                    lines.append(
                        f"  for (int i = 0; i < {_ARR_SZ}; i++)"
                        f" {nm}[i] = ({dt})((i % 17) + 1) * ({dt})0.25;")
                elif "uint" in dt:
                    lines.append(
                        f"  for (int i = 0; i < {_ARR_SZ}; i++)"
                        f" {nm}[i] = ({dt})((i % 31) + 1);")
                else:
                    lines.append(
                        f"  for (int i = 0; i < {_ARR_SZ}; i++)"
                        f" {nm}[i] = ({dt})((i % 17) + 1);")
            else:
                # Mutable pointers are outputs: clear them, remember to print.
                lines.append(f"  memset({nm}, 0, sizeof({nm}));")
                output_arrays.append((nm, dt))
            call_args.append(nm)
        else:
            # Scalar-by-value argument with a fixed deterministic value.
            if dt in ("float", "double"):
                lines.append(f"  {dt} {nm} = ({dt})1.5;")
            else:
                lines.append(f"  {dt} {nm} = ({dt})3;")
            call_args.append(nm)
    call_expr = f'kernel({", ".join(call_args)})'
    is_void = ret_type.strip() in ("void", "")
    if not is_void:
        lines.append(f"  {ret_type.strip()} _result = {call_expr};")
        # Print the return value with a dtype-appropriate format.
        if ret_type.strip() in ("float", "double"):
            lines.append('  printf("%.10g\\n", (double)_result);')
        elif "unsigned" in ret_type or "uint" in ret_type:
            lines.append('  printf("%u\\n", _result);')
        else:
            lines.append('  printf("%ld\\n", (long)_result);')
    else:
        lines.append(f"  {call_expr};")
    # Dump a prefix of each output array so candidate and reference stdout
    # can be compared line by line.
    for nm, dt in output_arrays:
        if dt in ("float", "double"):
            lines.append(
                f'  for (int i = 0; i < {_PRINT_N}; i++)'
                f' printf("%.10g\\n", (double){nm}[i]);')
        elif "unsigned" in dt or "uint" in dt:
            lines.append(
                f'  for (int i = 0; i < {_PRINT_N}; i++)'
                f' printf("%u\\n", {nm}[i]);')
        else:
            lines.append(
                f'  for (int i = 0; i < {_PRINT_N}; i++)'
                f' printf("%ld\\n", (long){nm}[i]);')
    lines += ["  return 0;", "}", ""]
    return "\n".join(lines)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
# Six reference kernels used as smoke inputs for the harness generator.
# Together they cover the interesting signature shapes: void return with an
# output array (vec_add), scalar float return (dot), scalar-by-value param
# (saxpy), unsigned return with uint input (popcount), multiple scalar
# params (clip), and libm usage (layernorm).
KERNELS = {
    "vec_add": (
        '#include <stddef.h>\n'
        'void kernel(float * __restrict__ a, const float * __restrict__ b,\n'
        '            const float * __restrict__ c) {\n'
        '  for (size_t i = 0; i < 16; ++i) a[i] = b[i] + c[i];\n'
        '}\n'
    ),
    "dot": (
        '#include <stddef.h>\n'
        'float kernel(const float * __restrict__ a, const float * __restrict__ b) {\n'
        '  float s = 0;\n'
        '  for (size_t i = 0; i < 16; ++i) s += a[i] * b[i];\n'
        '  return s;\n'
        '}\n'
    ),
    "saxpy": (
        '#include <stddef.h>\n'
        'void kernel(float alpha, const float * __restrict__ x,\n'
        '            float * __restrict__ y) {\n'
        '  for (size_t i = 0; i < 32; ++i) y[i] = alpha * x[i] + y[i];\n'
        '}\n'
    ),
    "popcount": (
        '#include <stddef.h>\n'
        '#include <stdint.h>\n'
        'unsigned kernel(const uint32_t * __restrict__ x) {\n'
        '  unsigned c = 0;\n'
        '  for (size_t i = 0; i < 16; i++) {\n'
        '    uint32_t v = x[i]; while (v) { c += v & 1; v >>= 1; }\n'
        '  }\n'
        '  return c;\n'
        '}\n'
    ),
    "clip": (
        '#include <stddef.h>\n'
        'void kernel(float * __restrict__ y, const float * __restrict__ x,\n'
        '            float lo, float hi) {\n'
        '  for (size_t i = 0; i < 16; ++i) {\n'
        '    float v = x[i]; y[i] = v < lo ? lo : (v > hi ? hi : v);\n'
        '  }\n'
        '}\n'
    ),
    "layernorm": (
        '#include <stddef.h>\n'
        '#include <math.h>\n'
        'void kernel(float * __restrict__ y, const float * __restrict__ x,\n'
        '            const float * __restrict__ gamma, const float * __restrict__ beta) {\n'
        '  float mean = 0; for (size_t i = 0; i < 16; ++i) mean += x[i];\n'
        '  mean /= (float)16;\n'
        '  float var = 0; for (size_t i = 0; i < 16; ++i) { float d = x[i] - mean; var += d*d; }\n'
        '  var /= (float)16;\n'
        '  float inv = (float)1.0 / sqrtf(var + (float)1e-5);\n'
        '  for (size_t i = 0; i < 16; ++i) y[i] = (x[i] - mean) * inv * gamma[i] + beta[i];\n'
        '}\n'
    ),
}
|
| 164 |
+
|
| 165 |
+
# Phase 1: syntax-only check of every kernel in both harness modes:
# "ref" inlines the kernel definition, "extern" emits only the prototype.
ok, fail = 0, 0
for name, src in KERNELS.items():
    for mode_name, with_def in [("ref", True), ("extern", False)]:
        code = generate_test_harness(src, with_kernel_def=with_def)
        with tempfile.NamedTemporaryFile(suffix=".c", mode="w", delete=False) as f:
            f.write(code)
            path = f.name
        # -fsyntax-only: no object output; -Wno-everything keeps stderr clean
        # so only hard errors are reported.
        r = subprocess.run(["cc", "-fsyntax-only", "-Wno-everything", path],
                           capture_output=True, text=True)
        os.unlink(path)
        tag = f"{name}/{mode_name}"
        if r.returncode == 0:
            print(f"  OK   {tag}")
            ok += 1
        else:
            print(f"  FAIL {tag}: {r.stderr[:200]}")
            fail += 1

print(f"\n{ok} passed, {fail} failed")
# Abort before the compile+run phase if any syntax check failed.
if fail:
    sys.exit(1)

# Phase 2: compile each reference harness natively and execute it to confirm
# the generated main() actually runs and prints output.
print("\n--- Host compile+run test ---")
for name, src in KERNELS.items():
    code = generate_test_harness(src, with_kernel_def=True)
    with tempfile.NamedTemporaryFile(suffix=".c", mode="w", delete=False) as f:
        f.write(code)
        cpath = f.name
    elf = cpath.replace(".c", "")
    # -lm needed for layernorm's sqrtf on some platforms.
    r = subprocess.run(["cc", "-O3", "-o", elf, cpath, "-lm"],
                       capture_output=True, text=True)
    os.unlink(cpath)
    if r.returncode != 0:
        print(f"  COMPILE FAIL {name}: {r.stderr[:200]}")
        continue
    r = subprocess.run([elf], capture_output=True, text=True, timeout=5)
    os.unlink(elf)
    if r.returncode != 0:
        print(f"  RUN FAIL {name}: exit={r.returncode}")
        continue
    lines_out = r.stdout.strip().split("\n")
    print(f"  OK {name}: {len(lines_out)} output lines, first={lines_out[0]}")

print("\nAll smoke tests passed.")
|
hf/v6_train.py
ADDED
|
@@ -0,0 +1,751 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""ARM-Gym V6 β Qwen2.5-Coder-7B bf16+LoRA GRPO with harness-based correctness.
|
| 3 |
+
|
| 4 |
+
V6 changes over V5e:
|
| 5 |
+
- CRITICAL FIX: correctness_reward now actually works. Generates a C test
|
| 6 |
+
harness with main(), compiles reference + candidate as complete ELFs,
|
| 7 |
+
runs both under QEMU, and compares stdout (float-tolerant).
|
| 8 |
+
- Previously: bare kernel .o had no entry point β crashed under QEMU β 0.0.
|
| 9 |
+
|
| 10 |
+
Expects: torch, transformers, trl, peft, datasets, arm_gym pre-installed.
|
| 11 |
+
System: clang-21, llvm-mca-21, aarch64-linux-gnu-{as,gcc}, qemu-aarch64-static.
|
| 12 |
+
"""
|
| 13 |
+
import csv, hashlib, json, logging, os, re, subprocess, sys, tempfile, threading, time
|
| 14 |
+
from dataclasses import asdict, dataclass
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
|
| 17 |
+
logging.basicConfig(
|
| 18 |
+
level=logging.INFO,
|
| 19 |
+
format="[%(asctime)s] %(levelname)s %(message)s",
|
| 20 |
+
datefmt="%H:%M:%S",
|
| 21 |
+
handlers=[logging.StreamHandler(sys.stdout)],
|
| 22 |
+
)
|
| 23 |
+
log = logging.getLogger("v6")
|
| 24 |
+
log.info("==== ARM-Gym V6 β Qwen2.5 bf16+LoRA GRPO (harness correctness) ====")
|
| 25 |
+
|
| 26 |
+
import torch
|
| 27 |
+
log.info("torch %s CUDA %s GPU %s",
|
| 28 |
+
torch.__version__,
|
| 29 |
+
torch.version.cuda if torch.cuda.is_available() else "N/A",
|
| 30 |
+
torch.cuda.get_device_name(0) if torch.cuda.is_available() else "N/A")
|
| 31 |
+
|
| 32 |
+
import transformers, trl, peft
|
| 33 |
+
log.info("transformers %s trl %s peft %s",
|
| 34 |
+
transformers.__version__, trl.__version__, peft.__version__)
|
| 35 |
+
|
| 36 |
+
from arm_gym.compile_baseline import detect_toolchain, compile_to_asm
|
| 37 |
+
from arm_gym.kernels import TEMPLATES, generate_all, split_train_eval
|
| 38 |
+
tc = detect_toolchain()
|
| 39 |
+
|
| 40 |
+
# Image gcc is too old for neoverse-v3. Force clang-21 (verified working).
|
| 41 |
+
if tc.clang and tc.gcc_aarch64:
|
| 42 |
+
log.info("Disabling gcc (%s) β neoverse-v3 unsupported; using clang-21 instead",
|
| 43 |
+
tc.gcc_aarch64)
|
| 44 |
+
tc.gcc_aarch64 = None
|
| 45 |
+
log.info("ARM toolchain: clang=%s gcc=%s mca=%s mcpu=%s disclosed=%s",
|
| 46 |
+
tc.clang, tc.gcc_aarch64, tc.mca, tc.mcpu, tc.mcpu_disclosed)
|
| 47 |
+
|
| 48 |
+
SMOKE_C = '#include <stddef.h>\nvoid f(int *a, int *b) { for(size_t i=0;i<4;i++) a[i]+=b[i]; }\n'
|
| 49 |
+
try:
|
| 50 |
+
smoke_asm = compile_to_asm(SMOKE_C, tc)
|
| 51 |
+
log.info("Smoke compile OK (%d bytes asm)", len(smoke_asm))
|
| 52 |
+
except Exception as e:
|
| 53 |
+
log.error("SMOKE COMPILE FAILED: %s", e)
|
| 54 |
+
log.error("This means ALL dataset rows will be empty. Aborting early.")
|
| 55 |
+
sys.exit(1)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# ββ CONFIG ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 59 |
+
# ARMGYM_PROFILE=mvp (default on main/dev): 50 steps β hub ZDC-M01/arm-gym-mvp-50
|
| 60 |
+
# ARMGYM_PROFILE=long (200-step run): 200 steps β hub ZDC-M01/arm-gym-train-200
|
| 61 |
+
# On the `training/hf-200` branch, use `source hf/run_long_train.sh` before python.
|
| 62 |
+
@dataclass
class Cfg:
    """Training configuration; field defaults are the MVP (50-step) profile.

    The "long" profile overrides hub_model_id/steps/out_dir/save_steps via
    _apply_profile(); everything else is shared between profiles.
    """
    model_id: str = "Qwen/Qwen2.5-Coder-7B-Instruct"
    hub_model_id: str = "ZDC-M01/arm-gym-mvp-50"  # Hub repo for adapter upload
    steps: int = 50                        # GRPO optimizer steps
    num_generations: int = 6               # completions per prompt (group size)
    gradient_accumulation_steps: int = 8
    per_device_train_batch_size: int = 1
    lora_rank: int = 24
    lora_alpha: int = 48                   # 2x rank
    learning_rate: float = 5e-6
    max_prompt_length: int = 2048
    max_completion_length: int = 640
    temperature: float = 0.7               # rollout sampling temperature
    difficulty_max: int = 1                # keep only easy kernel templates
    max_train: int = 64                    # cap on train variants
    max_eval: int = 16                     # cap on eval variants
    warmup_steps: int = 10
    out_dir: str = "runs/v6-mvp"           # checkpoint/output directory
    save_steps: int = 25                   # checkpoint interval
|
| 82 |
+
|
| 83 |
+
def _default_profile() -> str:
|
| 84 |
+
if p := os.environ.get("ARMGYM_PROFILE"):
|
| 85 |
+
return p.lower().strip()
|
| 86 |
+
pfile = Path(__file__).resolve().with_name("PROFILE")
|
| 87 |
+
if pfile.is_file():
|
| 88 |
+
t = pfile.read_text().strip().lower()
|
| 89 |
+
if t in ("mvp", "long"):
|
| 90 |
+
return t
|
| 91 |
+
return "mvp"
|
| 92 |
+
|
| 93 |
+
def _apply_profile(cfg: Cfg) -> Cfg:
    """Mutate cfg in place for the selected profile and return it.

    Only the "long" profile changes anything: it redirects the hub repo,
    raises the step count to 200, and adjusts output/checkpoint cadence.
    """
    if _default_profile() != "long":
        return cfg
    cfg.hub_model_id = "ZDC-M01/arm-gym-train-200"
    cfg.steps = 200
    cfg.out_dir = "runs/v6-200"
    cfg.save_steps = 50
    return cfg
|
| 101 |
+
|
| 102 |
+
cfg = _apply_profile(Cfg())
|
| 103 |
+
log.info("Profile: %s hub=%s out=%s steps=%d",
|
| 104 |
+
_default_profile(), cfg.hub_model_id, cfg.out_dir, cfg.steps)
|
| 105 |
+
log.info("Config: model=%s steps=%d G=%d temp=%.1f",
|
| 106 |
+
cfg.model_id, cfg.steps, cfg.num_generations, cfg.temperature)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
# ββ DATASET β SuperCoder A.3 prompt ββββββββββββββββββββββββββββββββββββββββββ
|
| 110 |
+
SYSTEM_PROMPT = (
|
| 111 |
+
"You are an expert AArch64 (aarch64-linux-gnu-gcc) assembly writer. "
|
| 112 |
+
"Obey the user block exactly. Output only what is asked in the required tags."
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
def user_prompt(c_source: str, baseline_asm: str) -> str:
    """Build the SuperCoder A.3-style user turn for one kernel.

    Presents the C source and its baseline assembly, then instructs the model
    to emit only optimized AArch64 assembly wrapped in <assembly> tags.
    """
    task = (
        "Given the following C code and assembly code, your task is to generate "
        "highly optimized AArch64 assembly code.\n\n"
    )
    inputs = f"C Code:\n{c_source}\n\nAssembly Code:\n{baseline_asm}\n\n"
    rules = (
        "Only output the optimized assembly code. Do not include any other text. "
        "Do not write any comments in the assembly code. "
        "Wrap the assembly code in <assembly></assembly> tags.\n\n"
        "Optimized Assembly Code:\n"
    )
    return task + inputs + rules
|
| 126 |
+
|
| 127 |
+
def build_dataset(cfg, tok):
    """Build (train, eval) HF Datasets of GRPO prompts from kernel variants.

    Each row carries the chat-templated prompt plus the raw C source and the
    baseline assembly that the reward functions need later.  Variants whose
    baseline compile fails are dropped; only the first few failures are
    logged to avoid spam.
    """
    from datasets import Dataset
    # Keep only templates at or below the configured difficulty ceiling.
    vs = [v for v in generate_all()
          if TEMPLATES[v.template_name].difficulty <= cfg.difficulty_max]
    log.info("Kernel variants (difficulty<=%d): %d", cfg.difficulty_max, len(vs))
    tv, ev = split_train_eval(vs, eval_frac=0.1, seed=0)
    tv, ev = tv[:cfg.max_train], ev[:cfg.max_eval]

    # Boxed counter so the row() closure can mutate it.
    _compile_fails = [0]
    def row(v):
        # Build one dataset row, or None when the baseline compile fails.
        try:
            basm = compile_to_asm(v.c_source, tc)
        except Exception as e:
            _compile_fails[0] += 1
            if _compile_fails[0] <= 3:
                log.warning("compile_to_asm failed [%d]: %s", _compile_fails[0], e)
            return None
        msgs = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_prompt(v.c_source, basm)},
        ]
        # Try template kwargs newest-first: tokenizers that don't know
        # enable_thinking raise TypeError, so fall back without it.
        for kw in (
            {"tokenize": False, "add_generation_prompt": True,
             "enable_thinking": False},
            {"tokenize": False, "add_generation_prompt": True},
        ):
            try:
                prompt = tok.apply_chat_template(msgs, **kw)
                break
            except TypeError:
                continue
        else:
            # No chat template variant worked: plain-text concatenation.
            prompt = f"{SYSTEM_PROMPT}\n\n{msgs[1]['content']}\n"
        return {"prompt": prompt, "variant_id": v.variant_id,
                "baseline_asm": basm, "c_source": v.c_source}

    tr = [r for v in tv if (r := row(v)) is not None]
    er = [r for v in ev if (r := row(v)) is not None]
    log.info("Dataset: train=%d eval=%d", len(tr), len(er))
    if tr:
        # Log prompt size / tail once for sanity-checking the template.
        toks = tok(tr[0]["prompt"], return_tensors="pt")
        log.info("Sample prompt tokens: %d", toks["input_ids"].shape[1])
        log.info("Prompt tail: %r", tr[0]["prompt"][-200:])
    return Dataset.from_list(tr), Dataset.from_list(er)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
# ββ REWARD FUNCTIONS ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 174 |
+
_ASM_RE = re.compile(r"<assembly>(.*?)</assembly>", re.DOTALL | re.IGNORECASE)
|
| 175 |
+
_THINK_RE = re.compile(r"<think>.*?</think>", re.DOTALL | re.IGNORECASE)
|
| 176 |
+
_CACHE: dict = {}
|
| 177 |
+
_BASELINE: dict = {}
|
| 178 |
+
_LOCK = threading.Lock()
|
| 179 |
+
_VCFG = None
|
| 180 |
+
|
| 181 |
+
def _clean_asm_directives(asm: str) -> str:
|
| 182 |
+
"""Strip directives that GNU as rejects but clang emits."""
|
| 183 |
+
lines = []
|
| 184 |
+
for line in asm.splitlines():
|
| 185 |
+
stripped = line.strip()
|
| 186 |
+
if stripped.startswith(".addrsig") or stripped.startswith(".ident"):
|
| 187 |
+
continue
|
| 188 |
+
lines.append(line)
|
| 189 |
+
asm = "\n".join(lines)
|
| 190 |
+
if ".arch " not in asm:
|
| 191 |
+
asm = "\t.arch armv9-a+sve2+crc\n" + asm
|
| 192 |
+
return asm
|
| 193 |
+
|
| 194 |
+
def _extract_asm(text):
    """Pull the assembly payload out of a completion.

    Strips <think> spans, then falls back through: a full <assembly> tag
    pair, a closing tag only, an opening tag only, and finally the raw
    text.  The result is always passed through _clean_asm_directives.
    """
    stripped = _THINK_RE.sub("", text).strip()
    match = _ASM_RE.search(stripped)
    if match is not None:
        return _clean_asm_directives(match.group(1).strip())
    lowered = stripped.lower()
    if "</assembly>" in lowered:
        candidate = re.split(r"</assembly>", stripped, flags=re.IGNORECASE)[0]
        if "<assembly>" in candidate.lower():
            candidate = re.split(r"<assembly>", candidate,
                                 flags=re.IGNORECASE)[-1]
        return _clean_asm_directives(candidate.strip())
    if "<assembly>" in lowered:
        tail = re.split(r"<assembly>", stripped, flags=re.IGNORECASE)[-1]
        return _clean_asm_directives(tail.strip())
    return _clean_asm_directives(stripped)
|
| 208 |
+
|
| 209 |
+
def _vcfg():
    """Return the process-wide VerifierConfig, building it on first call."""
    global _VCFG
    if _VCFG is not None:
        return _VCFG
    from arm_gym.verifier import VerifierConfig
    _VCFG = VerifierConfig(
        mca_bin=tc.mca or "llvm-mca",
        assembler="aarch64-linux-gnu-as",
        linker="aarch64-linux-gnu-ld",
        qemu="qemu-aarch64-static",
        mcpu=tc.mcpu,
    )
    return _VCFG
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
# ── CORRECTNESS HARNESS ──────────────────────────────────────────────────────
# Captures the full `kernel(...)` prototype: return type through closing paren.
_SIG_RE = re.compile(r"([\w\s*]+?\bkernel\s*\([^)]*\))", re.DOTALL)
# Shared scratch directory for harness artifacts, created lazily.
_HARNESS_DIR: Path | None = None
# variant_id -> reference ELF path; a cached None marks a failed build.
_REF_ELF_CACHE: dict[str, Path | None] = {}
# variant_id -> harness object path; a cached None marks a failed build.
_HARNESS_OBJ_CACHE: dict[str, Path | None] = {}
_CORRECTNESS_GCC = "aarch64-linux-gnu-gcc"
_CORRECTNESS_QEMU = "qemu-aarch64-static"
# Element count for pointer-argument arrays in the generated harness.
_ARR_SZ = 8192
# How many leading output-array elements are printed (and thus compared).
_PRINT_N = 64
|
| 229 |
+
|
| 230 |
+
def _harness_dir() -> Path:
    """Create (once) and return the shared harness scratch directory."""
    global _HARNESS_DIR
    if _HARNESS_DIR is not None:
        return _HARNESS_DIR
    _HARNESS_DIR = Path(tempfile.mkdtemp(prefix="armgym_harness_"))
    return _HARNESS_DIR
|
| 235 |
+
|
| 236 |
+
def _parse_kernel_sig(c_source: str):
    """Parse the `kernel(...)` prototype out of a C source string.

    Returns (ret_type, params, full_proto) where params is a list of dicts
    with keys dtype/name/is_ptr/is_const/raw, and full_proto is the
    whitespace-normalized prototype (or None when no kernel() is found, in
    which case ("void", [], None) is returned).
    """
    m = _SIG_RE.search(c_source)
    if not m:
        return "void", [], None
    # Collapse all whitespace runs so downstream string surgery is stable.
    full_proto = " ".join(m.group(1).split())
    paren = full_proto.index("(")
    ret_and_name = full_proto[:paren].strip()
    ret_type = ret_and_name.rsplit("kernel", 1)[0].strip() or "void"
    raw_params_str = full_proto[paren + 1:].rstrip(")")
    params = []
    for p in raw_params_str.split(","):
        p = p.strip()
        if not p:
            continue
        # Word-boundary matching: a plain substring test/replace would
        # mangle identifiers that merely contain "const" (e.g. n_const)
        # and would miss the `restrict`/`__restrict` spellings.
        is_const = re.search(r"\bconst\b", p) is not None
        is_ptr = "*" in p
        clean = re.sub(r"\b(?:const|__restrict__|__restrict|restrict)\b",
                       "", p).replace("*", "").strip()
        parts = clean.split()
        dtype = parts[0] if parts else "int"
        name = parts[-1] if len(parts) > 1 else f"p{len(params)}"
        params.append({"dtype": dtype, "name": name, "is_ptr": is_ptr,
                       "is_const": is_const, "raw": p})
    return ret_type, params, full_proto
|
| 259 |
+
|
| 260 |
+
def generate_test_harness(c_source: str, with_kernel_def: bool) -> str:
    """Emit a C main() that calls kernel() and prints its observable output.

    with_kernel_def=True inlines the original C source (reference build);
    False emits only an extern prototype so the candidate object file
    supplies kernel() at link time.  const pointer params become
    deterministically initialized input arrays; non-const pointers become
    zeroed output arrays whose first _PRINT_N elements are printed;
    scalars get fixed values.  The printed stream is what the correctness
    check compares between reference and candidate runs.
    """
    ret_type, params, full_proto = _parse_kernel_sig(c_source)
    lines = ["#include <stdio.h>", "#include <stdlib.h>", "#include <string.h>",
             "#include <stddef.h>", "#include <stdint.h>", "#include <math.h>", ""]
    if with_kernel_def:
        lines.append(c_source)
    else:
        if full_proto:
            lines.append(f"extern {full_proto};")
    lines += ["", "int main(void) {"]
    call_args = []
    output_arrays = []  # non-const pointer params: (name, dtype)
    for p in params:
        dt, nm = p["dtype"], p["name"]
        if p["is_ptr"]:
            lines.append(f" static {dt} {nm}[{_ARR_SZ}];")
            if p["is_const"]:
                # Inputs: small periodic patterns keep values bounded and
                # identical across reference and candidate executions.
                if dt in ("float", "double"):
                    lines.append(
                        f" for (int i = 0; i < {_ARR_SZ}; i++)"
                        f" {nm}[i] = ({dt})((i % 17) + 1) * ({dt})0.25;")
                elif "uint" in dt:
                    lines.append(
                        f" for (int i = 0; i < {_ARR_SZ}; i++)"
                        f" {nm}[i] = ({dt})((i % 31) + 1);")
                else:
                    lines.append(
                        f" for (int i = 0; i < {_ARR_SZ}; i++)"
                        f" {nm}[i] = ({dt})((i % 17) + 1);")
            else:
                # Output buffer: zero it and remember it for printing.
                lines.append(f" memset({nm}, 0, sizeof({nm}));")
                output_arrays.append((nm, dt))
            call_args.append(nm)
        else:
            # Scalar argument: fixed value (1.5 for floats, 3 otherwise).
            if dt in ("float", "double"):
                lines.append(f" {dt} {nm} = ({dt})1.5;")
            else:
                lines.append(f" {dt} {nm} = ({dt})3;")
            call_args.append(nm)
    call_expr = f'kernel({", ".join(call_args)})'
    is_void = ret_type.strip() in ("void", "")
    if not is_void:
        lines.append(f" {ret_type.strip()} _result = {call_expr};")
        # Print the scalar return with a format matching its type family.
        if ret_type.strip() in ("float", "double"):
            lines.append(' printf("%.10g\\n", (double)_result);')
        elif "unsigned" in ret_type or "uint" in ret_type:
            lines.append(' printf("%u\\n", _result);')
        else:
            lines.append(' printf("%ld\\n", (long)_result);')
    else:
        lines.append(f" {call_expr};")
    # Dump the head of every output array so results can be diffed.
    for nm, dt in output_arrays:
        if dt in ("float", "double"):
            lines.append(
                f' for (int i = 0; i < {_PRINT_N}; i++)'
                f' printf("%.10g\\n", (double){nm}[i]);')
        elif "unsigned" in dt or "uint" in dt:
            lines.append(
                f' for (int i = 0; i < {_PRINT_N}; i++)'
                f' printf("%u\\n", {nm}[i]);')
        else:
            lines.append(
                f' for (int i = 0; i < {_PRINT_N}; i++)'
                f' printf("%ld\\n", (long){nm}[i]);')
    lines += [" return 0;", "}", ""]
    return "\n".join(lines)
|
| 326 |
+
|
| 327 |
+
def _get_reference_elf(vid: str, c_source: str) -> Path | None:
    """Compile (and cache) the reference ELF for a variant.

    Builds harness + original kernel with the cross gcc at -O3.  Results
    are cached per variant_id; a cached None marks a variant whose
    reference build failed, so it is never retried.
    """
    if vid in _REF_ELF_CACHE:
        p = _REF_ELF_CACHE[vid]
        if p and p.exists():
            return p
        if p is None:
            return None
        # Cached path vanished (temp cleanup): fall through and rebuild.
    d = _harness_dir() / f"ref_{vid}"
    d.mkdir(parents=True, exist_ok=True)
    src = d / "combined.c"
    src.write_text(generate_test_harness(c_source, with_kernel_def=True))
    elf = d / "ref.elf"
    try:
        r = subprocess.run(
            [_CORRECTNESS_GCC, "-O3", "-static", "-o", str(elf), str(src), "-lm"],
            capture_output=True, text=True, timeout=30)
    except (subprocess.TimeoutExpired, OSError) as e:
        # A missing cross-gcc or a hung compile must not abort reward
        # scoring; record the failure so it is not re-attempted.
        log.warning("[harness] ref compile error vid=%s: %s", vid, e)
        _REF_ELF_CACHE[vid] = None
        return None
    if r.returncode != 0:
        log.warning("[harness] ref compile failed vid=%s: %s", vid, r.stderr[:300])
        _REF_ELF_CACHE[vid] = None
        return None
    _REF_ELF_CACHE[vid] = elf
    return elf
|
| 348 |
+
|
| 349 |
+
def _get_harness_obj(vid: str, c_source: str) -> Path | None:
    """Compile (and cache) the extern-prototype harness object for a variant.

    The object is later linked against a candidate kernel object.  Cached
    per variant_id; a cached None marks a failed build and is not retried.
    """
    if vid in _HARNESS_OBJ_CACHE:
        p = _HARNESS_OBJ_CACHE[vid]
        if p and p.exists():
            return p
        if p is None:
            return None
        # Cached path vanished (temp cleanup): fall through and rebuild.
    d = _harness_dir() / f"harn_{vid}"
    d.mkdir(parents=True, exist_ok=True)
    src = d / "harness.c"
    src.write_text(generate_test_harness(c_source, with_kernel_def=False))
    obj = d / "harness.o"
    try:
        r = subprocess.run(
            [_CORRECTNESS_GCC, "-c", "-o", str(obj), str(src)],
            capture_output=True, text=True, timeout=30)
    except (subprocess.TimeoutExpired, OSError) as e:
        # Missing toolchain or a hung compile must not abort reward scoring.
        log.warning("[harness] harness compile error vid=%s: %s", vid, e)
        _HARNESS_OBJ_CACHE[vid] = None
        return None
    if r.returncode != 0:
        log.warning("[harness] harness compile failed vid=%s: %s", vid, r.stderr[:300])
        _HARNESS_OBJ_CACHE[vid] = None
        return None
    _HARNESS_OBJ_CACHE[vid] = obj
    return obj
|
| 370 |
+
|
| 371 |
+
def _link_candidate_elf(cand_obj: Path, vid: str, c_source: str) -> Path | None:
    """Link a candidate kernel object against the harness into a static ELF.

    Returns the ELF path, or None when the harness is unavailable or the
    link fails for any reason (bad object, missing linker, timeout).
    """
    harness_obj = _get_harness_obj(vid, c_source)
    if harness_obj is None:
        return None
    # Unique dir per candidate object so parallel evals don't collide.
    tag = hashlib.md5(str(cand_obj).encode()).hexdigest()[:8]
    d = _harness_dir() / f"cand_{vid}_{tag}"
    d.mkdir(parents=True, exist_ok=True)
    elf = d / "cand.elf"
    try:
        r = subprocess.run(
            [_CORRECTNESS_GCC, "-static", "-o", str(elf),
             str(harness_obj), str(cand_obj), "-lm", "-lc"],
            capture_output=True, text=True, timeout=30)
    except (subprocess.TimeoutExpired, OSError):
        # Consistent with _qemu_stdout: any execution failure means "no ELF".
        return None
    if r.returncode != 0:
        return None
    return elf
|
| 386 |
+
|
| 387 |
+
def _qemu_stdout(elf: Path) -> str | None:
    """Run an ELF under user-mode qemu; return stdout, or None on any failure."""
    try:
        proc = subprocess.run([_CORRECTNESS_QEMU, str(elf)],
                              capture_output=True, text=True, timeout=10)
    except (subprocess.TimeoutExpired, OSError):
        return None
    if proc.returncode != 0:
        return None
    return proc.stdout
|
| 395 |
+
|
| 396 |
+
def _outputs_match(out_a: str, out_b: str, rtol: float = 1e-4,
|
| 397 |
+
atol: float = 1e-6) -> bool:
|
| 398 |
+
ta = out_a.strip().split()
|
| 399 |
+
tb = out_b.strip().split()
|
| 400 |
+
if len(ta) != len(tb):
|
| 401 |
+
return False
|
| 402 |
+
for a, b in zip(ta, tb):
|
| 403 |
+
if a == b:
|
| 404 |
+
continue
|
| 405 |
+
try:
|
| 406 |
+
fa, fb = float(a), float(b)
|
| 407 |
+
except ValueError:
|
| 408 |
+
return False
|
| 409 |
+
if fa != fa and fb != fb:
|
| 410 |
+
continue
|
| 411 |
+
diff = abs(fa - fb)
|
| 412 |
+
tol = atol + rtol * max(abs(fa), abs(fb))
|
| 413 |
+
if diff > tol:
|
| 414 |
+
return False
|
| 415 |
+
return True
|
| 416 |
+
|
| 417 |
+
def _bcy(vid, basm):
    """Baseline MCA cycle count for a variant, memoized; 1000.0 on failure."""
    cached = _BASELINE.get(vid)
    if cached is not None:
        return cached
    try:
        from arm_gym.mca import run_mca
        cfg = _vcfg()
        cycles = float(run_mca(basm, cfg.mca_bin, cfg.mcpu).total_cycles)
    except Exception:
        # Fallback keeps training alive when MCA is unavailable or errors.
        cycles = 1000.0
    _BASELINE[vid] = cycles
    return cycles
|
| 426 |
+
|
| 427 |
+
@dataclass
class _E:
    """Evaluation outcome for one (completion, variant) pair (see _eval)."""

    assembles: bool = False  # extracted asm assembled into an object file
    runs: bool = False       # candidate output matched the reference under qemu
    speedup: float = 0.0     # baseline MCA cycles / candidate MCA cycles

# Throttle for verbose reward-side debug logging (first few failures only).
_DBG_N = 0
_DBG_MAX = 5
|
| 435 |
+
|
| 436 |
+
def _eval(text, vid, basm, c_src=""):
    """Evaluate one completion: assemble, run against reference, score speedup.

    Results are memoized in _CACHE keyed by md5("text::vid") so the several
    reward functions that share this evaluation pay for it once.  Never
    raises: any failure yields a default _E (all-zero reward).
    """
    global _DBG_N
    k = hashlib.md5(f"{text}::{vid}".encode()).hexdigest()
    with _LOCK:
        if k in _CACHE:
            return _CACHE[k]
    e = _E()
    asm = _extract_asm(text)
    from arm_gym.verifier import assemble, cleanup_temp_dirs
    try:
        obj, err = assemble(asm, _vcfg())
        if err or obj is None:
            # Log only the first _DBG_MAX assembly failures; the counter is
            # checked again under the lock before incrementing.
            if _DBG_N < _DBG_MAX:
                with _LOCK:
                    if _DBG_N < _DBG_MAX:
                        _DBG_N += 1
                        log.info("[reward-dbg #%d] vid=%s err=%r",
                                 _DBG_N, vid,
                                 err.message[:200] if err else "None")
                        log.info("[reward-dbg] raw[:300]=%r", text[:300])
                        log.info("[reward-dbg] asm[:300]=%r", asm[:300])
            cleanup_temp_dirs()
            with _LOCK:
                _CACHE[k] = e
            return e
        e.assembles = True
        if c_src:
            # Correctness: link the candidate object into the harness and
            # compare qemu stdout against the -O3 reference build.
            ref_elf = _get_reference_elf(vid, c_src)
            cand_elf = _link_candidate_elf(obj, vid, c_src)
            if ref_elf and cand_elf:
                ref_out = _qemu_stdout(ref_elf)
                cand_out = _qemu_stdout(cand_elf)
                if ref_out is not None and cand_out is not None:
                    e.runs = _outputs_match(cand_out, ref_out)
                    if e.runs:
                        log.info("[correctness] PASS vid=%s", vid)
                    elif _DBG_N < _DBG_MAX:
                        with _LOCK:
                            if _DBG_N < _DBG_MAX:
                                _DBG_N += 1
                                log.info("[correctness] FAIL vid=%s "
                                         "ref=%r cand=%r",
                                         vid, ref_out[:200], cand_out[:200])
        if e.runs:
            # Speedup is only scored for provably-correct candidates.
            from arm_gym.mca import run_mca
            bc = _bcy(vid, basm)
            rep = run_mca(asm, _vcfg().mca_bin, _vcfg().mcpu)
            e.speedup = bc / max(rep.total_cycles, 1)
        cleanup_temp_dirs()
    except Exception as ex:
        # Best-effort: a broken candidate must never crash the trainer.
        log.debug("eval err vid=%s: %s", vid, ex)
        try:
            cleanup_temp_dirs()
        except Exception:
            pass
    with _LOCK:
        _CACHE[k] = e
    return e
|
| 494 |
+
|
| 495 |
+
def _prep(completions, kw):
|
| 496 |
+
texts = [c[-1]["content"] if isinstance(c, list) else str(c)
|
| 497 |
+
for c in (completions or [])]
|
| 498 |
+
n = len(texts)
|
| 499 |
+
vids = list(kw.get("variant_id") or [""] * n)
|
| 500 |
+
bs = list(kw.get("baseline_asm") or [""] * n)
|
| 501 |
+
cs = list(kw.get("c_source") or [""] * n)
|
| 502 |
+
if len(vids) == 1 and n > 1:
|
| 503 |
+
vids *= n
|
| 504 |
+
bs *= n
|
| 505 |
+
cs *= n
|
| 506 |
+
return texts, vids, bs, cs
|
| 507 |
+
|
| 508 |
+
# Throttle for one-off completion-sample logging in format_reward.
_FMT_DBG_N = 0
_FMT_DBG_MAX = 8
|
| 510 |
+
|
| 511 |
+
def format_reward(prompts=None, completions=None, **kw):
    """Shape reward: +0.3 per <assembly> tag present, +0.4 for a
    non-trivial extracted body, -0.5 for chatter markers; floored at -0.5.
    """
    global _FMT_DBG_N
    texts, _, _, _ = _prep(completions, kw)
    if texts and _FMT_DBG_N < _FMT_DBG_MAX:
        # Log a handful of raw completions early for eyeballing.
        _FMT_DBG_N += 1
        sample = texts[0]
        log.info("[completion #%d] len=%d first500=%r",
                 _FMT_DBG_N, len(sample), sample[:500])
        log.info("[completion #%d] last200=%r", _FMT_DBG_N, sample[-200:])

    def _score(text):
        lowered = text.lower()
        pts = 0.0
        if "<assembly>" in lowered:
            pts += 0.3
        if "</assembly>" in lowered:
            pts += 0.3
        if len(re.sub(r"\s", "", _extract_asm(text))) >= 20:
            pts += 0.4
        if any(marker in lowered
               for marker in ("```", "<think>", "explain", "analysis")):
            pts -= 0.5
        return max(-0.5, pts)

    return [_score(t) for t in texts]
|
| 534 |
+
|
| 535 |
+
def syntax_reward(prompts=None, completions=None, **kw):
    """3.0 when the extracted assembly assembles cleanly, else 0.0."""
    texts, vids, basms, srcs = _prep(completions, kw)
    scores = []
    for text, vid, basm, src in zip(texts, vids, basms, srcs):
        scores.append(3.0 if _eval(text, vid, basm, src).assembles else 0.0)
    return scores
|
| 539 |
+
|
| 540 |
+
def correctness_reward(prompts=None, completions=None, **kw):
    """5.0 when the candidate matches the reference output, else 0.0."""
    texts, vids, basms, srcs = _prep(completions, kw)
    scores = []
    for text, vid, basm, src in zip(texts, vids, basms, srcs):
        scores.append(5.0 if _eval(text, vid, basm, src).runs else 0.0)
    return scores
|
| 544 |
+
|
| 545 |
+
def speedup_reward(prompts=None, completions=None, **kw):
    """max(0, speedup - 1) for correct completions, else 0.0.

    Speedup is baseline MCA cycles over candidate MCA cycles, so only
    improvements beyond parity earn reward.  _eval is called once per
    item (the original called it twice, paying the hash + lock round-trip
    a second time for the cached result).
    """
    t, v, b, c = _prep(completions, kw)
    scores = []
    for x, vi, bi, ci in zip(t, v, b, c):
        res = _eval(x, vi, bi, ci)
        scores.append(max(0.0, res.speedup - 1.0) if res.runs else 0.0)
    return scores
|
| 550 |
+
|
| 551 |
+
|
| 552 |
+
# ── MODEL LOADING (bf16 + LoRA — no Unsloth) ─────────────────────────────────
# Qwen2.5 chat terminators <|im_end|> (151645) and <|endoftext|> (151643),
# used as a fallback when the tokenizer cannot resolve them by string.
QWEN25_EOS_IDS = (151645, 151643)
|
| 554 |
+
|
| 555 |
+
def load_model(cfg):
    """Load the base model in bf16, wrap it with LoRA, and fix tokenizer EOS.

    Returns (model, tok).  Also patches model.generate to temporarily
    disable gradient checkpointing during generation (see below).
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained(cfg.model_id)
    model = AutoModelForCausalLM.from_pretrained(
        cfg.model_id,
        dtype=torch.bfloat16,
        device_map={"": 0},
        attn_implementation="eager",
    )
    log.info("Model dtype: %s device: %s",
             next(model.parameters()).dtype,
             next(model.parameters()).device)

    # Needed so gradients flow into LoRA layers with checkpointing enabled.
    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()

    from peft import LoraConfig, get_peft_model
    lora = LoraConfig(
        r=cfg.lora_rank, lora_alpha=cfg.lora_alpha, lora_dropout=0.0,
        bias="none", task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"])
    model = get_peft_model(model, lora)
    log.info("bf16+LoRA model loaded")

    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
        tok.pad_token_id = tok.eos_token_id
    # Truncate prompts from the left so the generation cue survives.
    tok.truncation_side = "left"

    # Resolve chat-terminator token ids by string; fall back to the known
    # Qwen2.5 ids when the tokenizer cannot resolve them.
    eos_ids = []
    for tok_str in ("<|im_end|>", "<|endoftext|>"):
        tid = tok.convert_tokens_to_ids(tok_str)
        if isinstance(tid, int) and tid >= 0 and tid != tok.unk_token_id:
            if tid not in eos_ids:
                eos_ids.append(tid)
    if not eos_ids:
        eos_ids = list(QWEN25_EOS_IDS)
    tok.eos_token_id = eos_ids[0]
    model.config.eos_token_id = eos_ids[0]
    gc = getattr(model, "generation_config", None)
    if gc is not None:
        gc.eos_token_id = eos_ids[0] if len(eos_ids) == 1 else eos_ids
    log.info("EOS ids=%s pad=%d trunc_side=%s",
             eos_ids, tok.pad_token_id, tok.truncation_side)

    model.print_trainable_parameters()

    # Gradient checkpointing and generation interact badly; wrap generate
    # so checkpointing is off while sampling and re-enabled afterwards.
    import types
    _base_generate = type(model).generate
    def _gc_safe_generate(self, *args, **kwargs):
        self.gradient_checkpointing_disable()
        try:
            return _base_generate(self, *args, **kwargs)
        finally:
            self.gradient_checkpointing_enable(
                gradient_checkpointing_kwargs={"use_reentrant": False})
    model.generate = types.MethodType(_gc_safe_generate, model)
    log.info("Patched model.generate to toggle gradient_checkpointing off/on")

    return model, tok
|
| 617 |
+
|
| 618 |
+
|
| 619 |
+
# ββ GRPO CONFIG βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 620 |
+
def build_grpo_config(cfg):
    """Build a GRPOConfig, dropping kwargs this TRL version doesn't accept.

    GRPOConfig fields change between TRL releases; on a TypeError naming
    an unexpected keyword, that entry is removed and construction retried.
    Each retry removes exactly one key that is present in the dict, so the
    loop is guaranteed to terminate.
    """
    from trl import GRPOConfig
    gen_kwargs = {
        "eos_token_id": list(QWEN25_EOS_IDS),
        "top_p": 0.9,
        "top_k": 40,
    }
    p = dict(
        output_dir=cfg.out_dir,
        max_steps=cfg.steps,
        learning_rate=cfg.learning_rate,
        warmup_steps=cfg.warmup_steps,
        lr_scheduler_type="constant_with_warmup",
        gradient_accumulation_steps=cfg.gradient_accumulation_steps,
        per_device_train_batch_size=cfg.per_device_train_batch_size,
        num_generations=cfg.num_generations,
        generation_batch_size=cfg.num_generations,
        max_prompt_length=cfg.max_prompt_length,
        max_completion_length=cfg.max_completion_length,
        mask_truncated_completions=True,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        bf16=True,
        max_grad_norm=1.0,
        temperature=cfg.temperature,
        generation_kwargs=gen_kwargs,
        loss_type="grpo",
        beta=0.04,
        epsilon=0.2,
        remove_unused_columns=False,
        logging_steps=1,
        save_steps=cfg.save_steps,
        save_total_limit=1,
        report_to="none",
    )
    while True:
        try:
            return GRPOConfig(**p)
        except TypeError as e:
            m = re.search(r"unexpected keyword argument '(\w+)'", str(e))
            # Infinite-loop guard: previously p.pop(key, None) made no
            # progress when the reported kwarg was not in p (e.g. an error
            # raised from a nested call), retrying the same failure forever.
            if not m or m.group(1) not in p:
                raise
            log.warning("Dropping unsupported GRPOConfig param: %r", m.group(1))
            del p[m.group(1)]
|
| 664 |
+
|
| 665 |
+
|
| 666 |
+
# ββ TRAIN βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 667 |
+
out = Path(cfg.out_dir)
|
| 668 |
+
out.mkdir(parents=True, exist_ok=True)
|
| 669 |
+
(out / "config.json").write_text(json.dumps(asdict(cfg), indent=2))
|
| 670 |
+
|
| 671 |
+
model, tok = load_model(cfg)
|
| 672 |
+
train_ds, eval_ds = build_dataset(cfg, tok)
|
| 673 |
+
|
| 674 |
+
from trl import GRPOTrainer
|
| 675 |
+
trainer = GRPOTrainer(
|
| 676 |
+
model=model,
|
| 677 |
+
reward_funcs=[format_reward, syntax_reward, correctness_reward,
|
| 678 |
+
speedup_reward],
|
| 679 |
+
args=build_grpo_config(cfg),
|
| 680 |
+
train_dataset=train_ds,
|
| 681 |
+
eval_dataset=eval_ds,
|
| 682 |
+
processing_class=tok,
|
| 683 |
+
)
|
| 684 |
+
|
| 685 |
+
# ββ PRE-TRAINING SANITY CHECK βββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 686 |
+
log.info("Running pre-training generation sanity check...")
|
| 687 |
+
with torch.no_grad():
|
| 688 |
+
sample = train_ds[0]["prompt"]
|
| 689 |
+
inputs = tok(sample, return_tensors="pt", truncation=True,
|
| 690 |
+
max_length=cfg.max_prompt_length).to(model.device)
|
| 691 |
+
log.info("Sanity input: %d tokens, device=%s, dtype=%s",
|
| 692 |
+
inputs["input_ids"].shape[1], inputs["input_ids"].device,
|
| 693 |
+
next(model.parameters()).dtype)
|
| 694 |
+
gen_ids = model.generate(
|
| 695 |
+
**inputs, max_new_tokens=128, temperature=0.5, top_p=0.9, top_k=40,
|
| 696 |
+
do_sample=True, eos_token_id=list(QWEN25_EOS_IDS),
|
| 697 |
+
)
|
| 698 |
+
new_ids = gen_ids[0][inputs["input_ids"].shape[1]:]
|
| 699 |
+
gen_text = tok.decode(new_ids, skip_special_tokens=True)
|
| 700 |
+
log.info("SANITY OUTPUT (%d tokens): %r", len(new_ids), gen_text[:500])
|
| 701 |
+
if "stringodzi" in gen_text or len(set(gen_text.split())) < 5:
|
| 702 |
+
log.error("BASE MODEL IS GENERATING GIBBERISH β dtype or device issue!")
|
| 703 |
+
else:
|
| 704 |
+
log.info("Base model generates coherent text. GRPO should learn.")
|
| 705 |
+
|
| 706 |
+
log.info("=" * 60)
|
| 707 |
+
log.info("TRAINING START: %d steps | G=%d | %s",
|
| 708 |
+
cfg.steps, cfg.num_generations, cfg.model_id)
|
| 709 |
+
log.info("=" * 60)
|
| 710 |
+
t0 = time.time()
|
| 711 |
+
try:
|
| 712 |
+
trainer.train()
|
| 713 |
+
finally:
|
| 714 |
+
rows = getattr(trainer.state, "log_history", [])
|
| 715 |
+
all_keys = list(dict.fromkeys(k for r in rows for k in r.keys()))
|
| 716 |
+
with open(out / "log.csv", "w", newline="") as f:
|
| 717 |
+
if all_keys:
|
| 718 |
+
w = csv.DictWriter(f, fieldnames=all_keys, extrasaction="ignore")
|
| 719 |
+
w.writeheader()
|
| 720 |
+
w.writerows(rows)
|
| 721 |
+
log.info("Logged %d rows to log.csv", len(rows))
|
| 722 |
+
|
| 723 |
+
elapsed = time.time() - t0
|
| 724 |
+
log.info("Training done in %.0fs (%.1f min)", elapsed, elapsed / 60)
|
| 725 |
+
|
| 726 |
+
try:
|
| 727 |
+
trainer.save_model(str(out / "lora-adapter"))
|
| 728 |
+
tok.save_pretrained(str(out / "lora-adapter"))
|
| 729 |
+
log.info("LoRA adapter saved to %s", out / "lora-adapter")
|
| 730 |
+
except Exception as e:
|
| 731 |
+
log.error("Failed to save LoRA adapter: %s", e)
|
| 732 |
+
|
| 733 |
+
if hf_tok := os.environ.get("HF_TOKEN"):
|
| 734 |
+
try:
|
| 735 |
+
from huggingface_hub import HfApi
|
| 736 |
+
api = HfApi(token=hf_tok)
|
| 737 |
+
api.create_repo(cfg.hub_model_id, repo_type="model",
|
| 738 |
+
exist_ok=True, private=False)
|
| 739 |
+
_msg = (
|
| 740 |
+
f"v6 GRPO LoRA: profile={os.environ.get('ARMGYM_PROFILE', 'mvp')} "
|
| 741 |
+
f"steps={cfg.steps} (harness correctness)"
|
| 742 |
+
)
|
| 743 |
+
api.upload_folder(
|
| 744 |
+
folder_path=str(out), repo_id=cfg.hub_model_id,
|
| 745 |
+
repo_type="model", commit_message=_msg[:200],
|
| 746 |
+
)
|
| 747 |
+
log.info("Uploaded to https://huggingface.co/%s", cfg.hub_model_id)
|
| 748 |
+
except Exception as e:
|
| 749 |
+
log.error("HF Hub upload failed: %s", e)
|
| 750 |
+
|
| 751 |
+
log.info("==== V6 COMPLETE ====")
|