mahir-m01 commited on
Commit
93e35c4
Β·
1 Parent(s): a2ffabc

feat(hf): v6 GRPO with harness correctness, MVP 50 steps, and run scripts

Browse files

- Add hf/v6_train.py: qemu harness, ARMGYM_PROFILE/PROFILE (mvp|long) for hub + steps
- Add hf/run_hf_mvp.sh and run_hf_long.sh for Hugging Face Jobs reference
- Add hf/smoke_harness.py for local syntax checks
- Harden arm_gym/compile_baseline mcpu fallbacks (clang+gcc+MCA)
- Add instructions.md for LoRA export and 50 vs 200 flows
- Ignore .ai-workflow/ and .mcp.json in git

Made-with: Cursor

.gitignore CHANGED
@@ -37,6 +37,8 @@ adapters/
37
  *.gguf
38
  .next/
39
  excalidraw.log
 
 
40
 
41
  .omc/
42
  .omg/
 
37
  *.gguf
38
  .next/
39
  excalidraw.log
40
+ .ai-workflow/
41
+ .mcp.json
42
 
43
  .omc/
44
  .omg/
arm_gym/compile_baseline.py CHANGED
@@ -3,6 +3,11 @@
3
  Cut 1 fix: prefer LLVM 21 with -mcpu=neoverse-v2 as the Neoverse V3 proxy.
4
  neoverse-v3 is not yet in any LLVM release (Olympus is LLVM 22, NVIDIA-specific).
5
  Fallback is disclosed via mcpu_disclosed field, not silently swapped.
 
 
 
 
 
6
  """
7
 
8
  from __future__ import annotations
@@ -13,9 +18,11 @@ import tempfile
13
  from dataclasses import dataclass
14
  from pathlib import Path
15
 
16
- CLANG_CANDIDATES = ["clang-21", "clang-20", "clang"]
17
  GCC_AARCH64 = "aarch64-linux-gnu-gcc"
18
- MCA_CANDIDATES = ["llvm-mca-21", "llvm-mca-20", "llvm-mca"]
 
 
19
 
20
 
21
  def find_tool(candidates: list[str]) -> str | None:
@@ -31,7 +38,7 @@ class ToolchainInfo:
31
  gcc_aarch64: str | None
32
  mca: str | None
33
  mcpu: str # actual -mcpu used
34
- mcpu_disclosed: str | None # e.g. "V2 proxy for V3" when fallback taken
35
 
36
  def ready(self) -> bool:
37
  return bool(self.clang or self.gcc_aarch64)
@@ -41,23 +48,53 @@ def detect_toolchain(preferred_cpu: str = "neoverse-v3") -> ToolchainInfo:
41
  clang = find_tool(CLANG_CANDIDATES)
42
  gcc = shutil.which(GCC_AARCH64)
43
  mca = find_tool(MCA_CANDIDATES)
44
- mcpu, disclosed = _pick_cpu(clang, preferred_cpu)
45
  return ToolchainInfo(clang=clang, gcc_aarch64=gcc, mca=mca, mcpu=mcpu, mcpu_disclosed=disclosed)
46
 
47
 
48
- def _pick_cpu(clang: str | None, preferred: str) -> tuple[str, str | None]:
49
- if not clang:
50
- return preferred, None
51
- try:
52
- out = subprocess.run([clang, "--print-supported-cpus"],
53
- capture_output=True, text=True, timeout=10)
54
- supported = (out.stdout + out.stderr).lower()
55
- except Exception:
56
- return "neoverse-v2", f"V2 proxy for {preferred} (clang probe failed)"
57
- if preferred.lower() in supported:
58
- return preferred, None
59
- # Cut 1 disclosure: document the downgrade rather than silently proxy.
60
- return "neoverse-v2", f"V2 proxy for {preferred} (not in clang --print-supported-cpus)"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
 
63
  def compile_to_asm(c_source: str, tc: ToolchainInfo, opt: str = "-O3") -> str:
 
3
  Cut 1 fix: prefer LLVM 21 with -mcpu=neoverse-v2 as the Neoverse V3 proxy.
4
  neoverse-v3 is not yet in any LLVM release (Olympus is LLVM 22, NVIDIA-specific).
5
  Fallback is disclosed via mcpu_disclosed field, not silently swapped.
6
+
7
+ mcpu selection probes both clang and gcc to find the best CPU both compilers
8
+ actually support. Fallback chain: v3 β†’ v2 β†’ v1 β†’ n2 β†’ n1 β†’ generic.
9
+ GCC 12 (Debian Bookworm default) supports up to neoverse-v1.
10
+ LLVM-MCA-15 supports neoverse-v1, so reward signal stays consistent.
11
  """
12
 
13
  from __future__ import annotations
 
18
  from dataclasses import dataclass
19
  from pathlib import Path
20
 
21
+ CLANG_CANDIDATES = ["clang-21", "clang-20", "clang-17", "clang-16", "clang-15", "clang"]
22
  GCC_AARCH64 = "aarch64-linux-gnu-gcc"
23
+ MCA_CANDIDATES = ["llvm-mca-21", "llvm-mca-20", "llvm-mca-17", "llvm-mca-16", "llvm-mca-15", "llvm-mca"]
24
+
25
+ _MCPU_CHAIN = ["neoverse-v3", "neoverse-v2", "neoverse-v1", "neoverse-n2", "neoverse-n1", "generic"]
26
 
27
 
28
  def find_tool(candidates: list[str]) -> str | None:
 
38
  gcc_aarch64: str | None
39
  mca: str | None
40
  mcpu: str # actual -mcpu used
41
+ mcpu_disclosed: str | None # e.g. "V1 proxy for V3" when fallback taken
42
 
43
  def ready(self) -> bool:
44
  return bool(self.clang or self.gcc_aarch64)
 
48
  clang = find_tool(CLANG_CANDIDATES)
49
  gcc = shutil.which(GCC_AARCH64)
50
  mca = find_tool(MCA_CANDIDATES)
51
+ mcpu, disclosed = _pick_cpu(clang, gcc, preferred_cpu)
52
  return ToolchainInfo(clang=clang, gcc_aarch64=gcc, mca=mca, mcpu=mcpu, mcpu_disclosed=disclosed)
53
 
54
 
55
def _gcc_probe_mcpu(gcc: str, preferred: str) -> tuple[str, str | None]:
    """Find best mcpu the installed gcc actually accepts via test-compile.

    Probes by compiling a trivial translation unit for each CPU in the
    fallback chain (preferred first) and returns the first one accepted.

    Returns:
        (mcpu, disclosed) where ``disclosed`` is None when ``preferred``
        itself was accepted, otherwise a human-readable downgrade note.
    """
    # Try the preferred CPU first, then the rest of the module-level chain.
    chain = [preferred] + [c for c in _MCPU_CHAIN if c != preferred]
    for cpu in chain:
        try:
            r = subprocess.run(
                [gcc, f"-mcpu={cpu}", "-S", "-x", "c", "-", "-o", "/dev/null"],
                input="int f(void){return 0;}",
                capture_output=True, text=True, timeout=5,
            )
            if r.returncode == 0:
                disclosed = None if cpu == preferred else f"{cpu} proxy for {preferred} (gcc limit)"
                return cpu, disclosed
        except Exception:
            # Best-effort probe: a missing binary or timeout just means this
            # candidate is unusable; move on to the next one.
            continue
    # Nothing in the chain compiled; disclose the generic fallback.
    return "generic", f"generic fallback for {preferred}"
71
+
72
+
73
def _pick_cpu(clang: str | None, gcc: str | None, preferred: str) -> tuple[str, str | None]:
    """Probe available compilers to find best mcpu both support.

    Prefers clang's ``--print-supported-cpus`` listing; falls back to
    test-compiling with gcc when clang is absent or the probe fails.

    Returns:
        (mcpu, disclosed) — ``disclosed`` is None when ``preferred`` is used
        directly, otherwise a note describing which proxy CPU was chosen.
    """
    if clang:
        try:
            # Must pass --target=aarch64-linux-gnu — without it, clang lists
            # host (x86) CPUs and neoverse-* names are not present.
            out = subprocess.run(
                [clang, "--target=aarch64-linux-gnu", "--print-supported-cpus"],
                capture_output=True, text=True, timeout=10,
            )
            # The listing may land on stdout or stderr; scan both.
            supported = (out.stdout + out.stderr).lower()
            chain = [preferred] + [c for c in _MCPU_CHAIN if c != preferred]
            for cpu in chain:
                if cpu.lower() in supported:
                    disclosed = None if cpu == preferred else f"{cpu} proxy for {preferred} (clang limit)"
                    return cpu, disclosed
        except Exception:
            # Probe failure is non-fatal; fall through to the gcc probe.
            pass

    # No clang or clang probe failed — probe GCC directly by test-compiling.
    if gcc:
        return _gcc_probe_mcpu(gcc, preferred)

    # No compiler to probe yet; return preferred and let compile_to_asm fail loudly.
    return preferred, None
98
 
99
 
100
  def compile_to_asm(c_source: str, tc: ToolchainInfo, opt: str = "-O3") -> str:
hf/PROFILE ADDED
@@ -0,0 +1 @@
 
 
1
+ mvp
hf/run_hf_long.sh ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# 200-step Hugging Face Job launcher (same deps as run_hf_mvp.sh).
set -euo pipefail
# Select the "long" profile before delegating to the MVP launcher.
export ARMGYM_PROFILE=long
# Resolve this script's own directory so delegation works from any CWD.
DIR="$(cd "$(dirname "$0")" && pwd)"
exec bash "$DIR/run_hf_mvp.sh"
hf/run_hf_mvp.sh ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # Reference launcher for Hugging Face Jobs: MVP (50-step) training on a10g-large (or similar).
3
+ # Requires: HF_TOKEN in env for adapter upload. Do not commit tokens.
4
+ set -euo pipefail
5
+ : "${HF_TOKEN:?set HF_TOKEN}"
6
+ echo "=== System deps ==="
7
+ apt-get update -qq
8
+ DEBIAN_FRONTEND=noninteractive apt-get install -y -qq \
9
+ qemu-user-static binutils-aarch64-linux-gnu gcc-aarch64-linux-gnu \
10
+ libc6-dev-arm64-cross ca-certificates curl gnupg lsb-release >/dev/null
11
+ CODENAME="${CODENAME:-$(lsb_release -cs 2>/dev/null || echo bookworm)}"
12
+ curl -fsSL https://apt.llvm.org/llvm-snapshot.gpg.key | gpg --dearmor -o /usr/share/keyrings/llvm.gpg
13
+ echo "deb [signed-by=/usr/share/keyrings/llvm.gpg] http://apt.llvm.org/${CODENAME}/ llvm-toolchain-${CODENAME}-21 main" \
14
+ > /etc/apt/sources.list.d/llvm21.list
15
+ apt-get update -qq
16
+ DEBIAN_FRONTEND=noninteractive apt-get install -y -qq clang-21 llvm-21 llvm-21-tools >/dev/null \
17
+ || apt-get install -y -qq clang llvm llvm-tools
18
+ export PATH="/usr/lib/llvm-21/bin:${PATH}"
19
+ echo "=== Python deps ==="
20
+ python3 -m pip install -q --upgrade pip
21
+ python3 -m pip install -q --no-input \
22
+ 'trl==0.20.0' 'transformers>=4.55,<4.58' 'accelerate>=1.0' 'peft>=0.13' 'datasets>=3.0' \
23
+ 'bitsandbytes>=0.45' 'torch>=2.3' 'numpy>=1.26' 'pydantic>=2.7' sentencepiece protobuf \
24
+ huggingface_hub
25
+ V6_URL="https://huggingface.co/datasets/ZDC-M01/arm-gym-pkg/resolve/main/v6_train.py"
26
+ WHEEL_URL="https://huggingface.co/datasets/ZDC-M01/arm-gym-pkg/resolve/main/arm_gym-0.1.0-py3-none-any.whl"
27
+ python3 -m pip install -q --force-reinstall "arm_gym @ ${WHEEL_URL}"
28
+ python3 -m pip uninstall -y torchao vllm unsloth unsloth_zoo xformers 2>/dev/null || true
29
+ echo "=== Run GRPO (MVP profile = 50 steps via hf/PROFILE) ==="
30
+ curl -sSL "$V6_URL" -o /tmp/v6_train.py
31
+ cd /tmp
32
+ python3 v6_train.py
hf/run_long_train.sh ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
#!/usr/bin/env bash
# 200-step GRPO + LoRA. Overrides hf/PROFILE when sourced before python3 hf/v6_train.py
# NOTE: this script only exports the env var and echoes the effect; it must
# be sourced (not executed) for the export to reach the training process.
export ARMGYM_PROFILE=long
echo "ARMGYM_PROFILE=long β†’ 200 steps, out runs/v6-200, hub ZDC-M01/arm-gym-train-200"
hf/smoke_harness.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Smoke test for generate_test_harness: check generated C is syntactically valid."""
3
+ import os, subprocess, sys, tempfile
4
+
5
+ sys.path.insert(0, os.path.dirname(__file__))
6
+
7
+ # Pull the harness helpers from v5e_train without importing the whole module
8
+ # (which requires torch, etc.)
9
+ import re
10
+
11
+ _SIG_RE = re.compile(r"([\w\s*]+?\bkernel\s*\([^)]*\))", re.DOTALL)
12
+ _ARR_SZ = 8192
13
+ _PRINT_N = 64
14
+
15
def _parse_kernel_sig(c_source):
    """Extract the `kernel` function signature from C source.

    Returns (ret_type, params, full_proto); ``params`` is a list of dicts
    with keys dtype/name/is_ptr/is_const/raw. Falls back to
    ("void", [], None) when no `kernel(...)` signature is found.
    """
    m = _SIG_RE.search(c_source)
    if not m:
        return "void", [], None
    # Collapse all whitespace so the prototype sits on a single line.
    full_proto = " ".join(m.group(1).split())
    paren = full_proto.index("(")
    ret_and_name = full_proto[:paren].strip()
    ret_type = ret_and_name.rsplit("kernel", 1)[0].strip() or "void"
    raw_params_str = full_proto[paren + 1:].rstrip(")")
    params = []
    for p in raw_params_str.split(","):
        p = p.strip()
        if not p:
            continue
        is_const = "const " in p
        is_ptr = "*" in p
        # Strip qualifiers to isolate "<dtype> <name>".
        clean = p.replace("const", "").replace("__restrict__", "").replace("*", "").strip()
        parts = clean.split()
        dtype = parts[0] if parts else "int"
        name = parts[-1] if len(parts) > 1 else f"p{len(params)}"
        params.append({"dtype": dtype, "name": name, "is_ptr": is_ptr,
                       "is_const": is_const, "raw": p})
    return ret_type, params, full_proto
38
+
39
def generate_test_harness(c_source, with_kernel_def):
    """Emit a complete C program (with main) that calls `kernel` and prints results.

    with_kernel_def=True embeds c_source directly (reference build);
    False emits only an extern prototype so the kernel can be linked in
    from a separately assembled object.
    """
    ret_type, params, full_proto = _parse_kernel_sig(c_source)
    lines = ['#include <stdio.h>', '#include <stdlib.h>', '#include <string.h>',
             '#include <stddef.h>', '#include <stdint.h>', '#include <math.h>', '']
    if with_kernel_def:
        lines.append(c_source)
    else:
        if full_proto:
            lines.append(f"extern {full_proto};")
    lines += ['', 'int main(void) {']
    call_args = []
    output_arrays = []
    for p in params:
        dt, nm = p["dtype"], p["name"]
        if p["is_ptr"]:
            # Pointer params become static arrays: const ones get deterministic
            # input data; non-const ones are zeroed and printed afterwards.
            lines.append(f" static {dt} {nm}[{_ARR_SZ}];")
            if p["is_const"]:
                if dt in ("float", "double"):
                    lines.append(
                        f" for (int i = 0; i < {_ARR_SZ}; i++)"
                        f" {nm}[i] = ({dt})((i % 17) + 1) * ({dt})0.25;")
                elif "uint" in dt:
                    lines.append(
                        f" for (int i = 0; i < {_ARR_SZ}; i++)"
                        f" {nm}[i] = ({dt})((i % 31) + 1);")
                else:
                    lines.append(
                        f" for (int i = 0; i < {_ARR_SZ}; i++)"
                        f" {nm}[i] = ({dt})((i % 17) + 1);")
            else:
                lines.append(f" memset({nm}, 0, sizeof({nm}));")
                output_arrays.append((nm, dt))
            call_args.append(nm)
        else:
            # Scalars get fixed constants so runs are reproducible.
            if dt in ("float", "double"):
                lines.append(f" {dt} {nm} = ({dt})1.5;")
            else:
                lines.append(f" {dt} {nm} = ({dt})3;")
            call_args.append(nm)
    call_expr = f'kernel({", ".join(call_args)})'
    is_void = ret_type.strip() in ("void", "")
    if not is_void:
        lines.append(f" {ret_type.strip()} _result = {call_expr};")
        # NOTE(review): %u/%ld assume the return fits unsigned/long — looks
        # adequate for the kernels here; confirm for 64-bit unsigned returns.
        if ret_type.strip() in ("float", "double"):
            lines.append(' printf("%.10g\\n", (double)_result);')
        elif "unsigned" in ret_type or "uint" in ret_type:
            lines.append(' printf("%u\\n", _result);')
        else:
            lines.append(' printf("%ld\\n", (long)_result);')
    else:
        lines.append(f" {call_expr};")
    # Print a prefix (_PRINT_N entries) of each output array for comparison.
    for nm, dt in output_arrays:
        if dt in ("float", "double"):
            lines.append(
                f' for (int i = 0; i < {_PRINT_N}; i++)'
                f' printf("%.10g\\n", (double){nm}[i]);')
        elif "unsigned" in dt or "uint" in dt:
            lines.append(
                f' for (int i = 0; i < {_PRINT_N}; i++)'
                f' printf("%u\\n", {nm}[i]);')
        else:
            lines.append(
                f' for (int i = 0; i < {_PRINT_N}; i++)'
                f' printf("%ld\\n", (long){nm}[i]);')
    lines += [" return 0;", "}", ""]
    return "\n".join(lines)
105
+
106
+
107
# Small representative C kernels used to smoke-test harness generation:
# covers void/non-void returns, const/non-const pointers, scalars, and uints.
KERNELS = {
    "vec_add": (
        '#include <stddef.h>\n'
        'void kernel(float * __restrict__ a, const float * __restrict__ b,\n'
        ' const float * __restrict__ c) {\n'
        ' for (size_t i = 0; i < 16; ++i) a[i] = b[i] + c[i];\n'
        '}\n'
    ),
    "dot": (
        '#include <stddef.h>\n'
        'float kernel(const float * __restrict__ a, const float * __restrict__ b) {\n'
        ' float s = 0;\n'
        ' for (size_t i = 0; i < 16; ++i) s += a[i] * b[i];\n'
        ' return s;\n'
        '}\n'
    ),
    "saxpy": (
        '#include <stddef.h>\n'
        'void kernel(float alpha, const float * __restrict__ x,\n'
        ' float * __restrict__ y) {\n'
        ' for (size_t i = 0; i < 32; ++i) y[i] = alpha * x[i] + y[i];\n'
        '}\n'
    ),
    "popcount": (
        '#include <stddef.h>\n'
        '#include <stdint.h>\n'
        'unsigned kernel(const uint32_t * __restrict__ x) {\n'
        ' unsigned c = 0;\n'
        ' for (size_t i = 0; i < 16; i++) {\n'
        ' uint32_t v = x[i]; while (v) { c += v & 1; v >>= 1; }\n'
        ' }\n'
        ' return c;\n'
        '}\n'
    ),
    "clip": (
        '#include <stddef.h>\n'
        'void kernel(float * __restrict__ y, const float * __restrict__ x,\n'
        ' float lo, float hi) {\n'
        ' for (size_t i = 0; i < 16; ++i) {\n'
        ' float v = x[i]; y[i] = v < lo ? lo : (v > hi ? hi : v);\n'
        ' }\n'
        '}\n'
    ),
    "layernorm": (
        '#include <stddef.h>\n'
        '#include <math.h>\n'
        'void kernel(float * __restrict__ y, const float * __restrict__ x,\n'
        ' const float * __restrict__ gamma, const float * __restrict__ beta) {\n'
        ' float mean = 0; for (size_t i = 0; i < 16; ++i) mean += x[i];\n'
        ' mean /= (float)16;\n'
        ' float var = 0; for (size_t i = 0; i < 16; ++i) { float d = x[i] - mean; var += d*d; }\n'
        ' var /= (float)16;\n'
        ' float inv = (float)1.0 / sqrtf(var + (float)1e-5);\n'
        ' for (size_t i = 0; i < 16; ++i) y[i] = (x[i] - mean) * inv * gamma[i] + beta[i];\n'
        '}\n'
    ),
}
164
+
165
# --- Pass 1: syntax-only check of every kernel in both harness modes. ---
ok, fail = 0, 0
for name, src in KERNELS.items():
    for mode_name, with_def in [("ref", True), ("extern", False)]:
        code = generate_test_harness(src, with_kernel_def=with_def)
        with tempfile.NamedTemporaryFile(suffix=".c", mode="w", delete=False) as f:
            f.write(code)
            path = f.name
        # Host cc suffices for a pure syntax check; suppress all warnings.
        r = subprocess.run(["cc", "-fsyntax-only", "-Wno-everything", path],
                           capture_output=True, text=True)
        os.unlink(path)
        tag = f"{name}/{mode_name}"
        if r.returncode == 0:
            print(f" OK {tag}")
            ok += 1
        else:
            print(f" FAIL {tag}: {r.stderr[:200]}")
            fail += 1

print(f"\n{ok} passed, {fail} failed")
if fail:
    sys.exit(1)

# --- Pass 2: compile with the embedded kernel and actually run on the host. ---
print("\n--- Host compile+run test ---")
for name, src in KERNELS.items():
    code = generate_test_harness(src, with_kernel_def=True)
    with tempfile.NamedTemporaryFile(suffix=".c", mode="w", delete=False) as f:
        f.write(code)
        cpath = f.name
    elf = cpath.replace(".c", "")
    r = subprocess.run(["cc", "-O3", "-o", elf, cpath, "-lm"],
                       capture_output=True, text=True)
    os.unlink(cpath)
    if r.returncode != 0:
        print(f" COMPILE FAIL {name}: {r.stderr[:200]}")
        continue
    r = subprocess.run([elf], capture_output=True, text=True, timeout=5)
    os.unlink(elf)
    if r.returncode != 0:
        print(f" RUN FAIL {name}: exit={r.returncode}")
        continue
    lines_out = r.stdout.strip().split("\n")
    print(f" OK {name}: {len(lines_out)} output lines, first={lines_out[0]}")

print("\nAll smoke tests passed.")
hf/v6_train.py ADDED
@@ -0,0 +1,751 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """ARM-Gym V6 β€” Qwen2.5-Coder-7B bf16+LoRA GRPO with harness-based correctness.
3
+
4
+ V6 changes over V5e:
5
+ - CRITICAL FIX: correctness_reward now actually works. Generates a C test
6
+ harness with main(), compiles reference + candidate as complete ELFs,
7
+ runs both under QEMU, and compares stdout (float-tolerant).
8
+ - Previously: bare kernel .o had no entry point β†’ crashed under QEMU β†’ 0.0.
9
+
10
+ Expects: torch, transformers, trl, peft, datasets, arm_gym pre-installed.
11
+ System: clang-21, llvm-mca-21, aarch64-linux-gnu-{as,gcc}, qemu-aarch64-static.
12
+ """
13
+ import csv, hashlib, json, logging, os, re, subprocess, sys, tempfile, threading, time
14
+ from dataclasses import asdict, dataclass
15
+ from pathlib import Path
16
+
17
+ logging.basicConfig(
18
+ level=logging.INFO,
19
+ format="[%(asctime)s] %(levelname)s %(message)s",
20
+ datefmt="%H:%M:%S",
21
+ handlers=[logging.StreamHandler(sys.stdout)],
22
+ )
23
+ log = logging.getLogger("v6")
24
+ log.info("==== ARM-Gym V6 β€” Qwen2.5 bf16+LoRA GRPO (harness correctness) ====")
25
+
26
+ import torch
27
+ log.info("torch %s CUDA %s GPU %s",
28
+ torch.__version__,
29
+ torch.version.cuda if torch.cuda.is_available() else "N/A",
30
+ torch.cuda.get_device_name(0) if torch.cuda.is_available() else "N/A")
31
+
32
+ import transformers, trl, peft
33
+ log.info("transformers %s trl %s peft %s",
34
+ transformers.__version__, trl.__version__, peft.__version__)
35
+
36
+ from arm_gym.compile_baseline import detect_toolchain, compile_to_asm
37
+ from arm_gym.kernels import TEMPLATES, generate_all, split_train_eval
38
+ tc = detect_toolchain()
39
+
40
+ # Image gcc is too old for neoverse-v3. Force clang-21 (verified working).
41
+ if tc.clang and tc.gcc_aarch64:
42
+ log.info("Disabling gcc (%s) β€” neoverse-v3 unsupported; using clang-21 instead",
43
+ tc.gcc_aarch64)
44
+ tc.gcc_aarch64 = None
45
+ log.info("ARM toolchain: clang=%s gcc=%s mca=%s mcpu=%s disclosed=%s",
46
+ tc.clang, tc.gcc_aarch64, tc.mca, tc.mcpu, tc.mcpu_disclosed)
47
+
48
+ SMOKE_C = '#include <stddef.h>\nvoid f(int *a, int *b) { for(size_t i=0;i<4;i++) a[i]+=b[i]; }\n'
49
+ try:
50
+ smoke_asm = compile_to_asm(SMOKE_C, tc)
51
+ log.info("Smoke compile OK (%d bytes asm)", len(smoke_asm))
52
+ except Exception as e:
53
+ log.error("SMOKE COMPILE FAILED: %s", e)
54
+ log.error("This means ALL dataset rows will be empty. Aborting early.")
55
+ sys.exit(1)
56
+
57
+
58
+ # ── CONFIG ────────────────────────────────────────────────────────────────────
59
+ # ARMGYM_PROFILE=mvp (default on main/dev): 50 steps β†’ hub ZDC-M01/arm-gym-mvp-50
60
+ # ARMGYM_PROFILE=long (200-step run): 200 steps β†’ hub ZDC-M01/arm-gym-train-200
61
+ # On the `training/hf-200` branch, use `source hf/run_long_train.sh` before python.
62
@dataclass
class Cfg:
    """Training configuration; defaults are the MVP (50-step) profile."""
    model_id: str = "Qwen/Qwen2.5-Coder-7B-Instruct"  # base policy model
    hub_model_id: str = "ZDC-M01/arm-gym-mvp-50"      # hub repo for upload
    steps: int = 50                      # optimizer steps
    num_generations: int = 6             # completions per prompt (logged as G)
    gradient_accumulation_steps: int = 8
    per_device_train_batch_size: int = 1
    lora_rank: int = 24
    lora_alpha: int = 48
    learning_rate: float = 5e-6
    max_prompt_length: int = 2048        # tokens
    max_completion_length: int = 640     # tokens
    temperature: float = 0.7             # rollout sampling temperature
    difficulty_max: int = 1              # keep only templates at/below this difficulty
    max_train: int = 64                  # cap on train variants
    max_eval: int = 16                   # cap on eval variants
    warmup_steps: int = 10
    out_dir: str = "runs/v6-mvp"
    save_steps: int = 25                 # checkpoint cadence
82
+
83
def _default_profile() -> str:
    """Resolve the training profile: env var wins, then hf/PROFILE file, else mvp."""
    if p := os.environ.get("ARMGYM_PROFILE"):
        return p.lower().strip()
    # Fall back to a PROFILE file sitting next to this script.
    pfile = Path(__file__).resolve().with_name("PROFILE")
    if pfile.is_file():
        t = pfile.read_text().strip().lower()
        if t in ("mvp", "long"):
            return t
    return "mvp"
92
+
93
def _apply_profile(cfg: Cfg) -> Cfg:
    """Mutate cfg in place for the selected profile and return it."""
    p = _default_profile()
    if p == "long":
        # 200-step run: different hub repo, output dir and checkpoint cadence.
        cfg.hub_model_id = "ZDC-M01/arm-gym-train-200"
        cfg.steps = 200
        cfg.out_dir = "runs/v6-200"
        cfg.save_steps = 50
    return cfg
101
+
102
+ cfg = _apply_profile(Cfg())
103
+ log.info("Profile: %s hub=%s out=%s steps=%d",
104
+ _default_profile(), cfg.hub_model_id, cfg.out_dir, cfg.steps)
105
+ log.info("Config: model=%s steps=%d G=%d temp=%.1f",
106
+ cfg.model_id, cfg.steps, cfg.num_generations, cfg.temperature)
107
+
108
+
109
+ # ── DATASET β€” SuperCoder A.3 prompt ──────────────────────────────────────────
110
+ SYSTEM_PROMPT = (
111
+ "You are an expert AArch64 (aarch64-linux-gnu-gcc) assembly writer. "
112
+ "Obey the user block exactly. Output only what is asked in the required tags."
113
+ )
114
+
115
def user_prompt(c_source: str, baseline_asm: str) -> str:
    """Render the SuperCoder-style user turn for one kernel variant.

    Embeds the C source and its baseline assembly, then the output-format
    instructions requiring <assembly></assembly>-wrapped assembly only.
    """
    segments = [
        "Given the following C code and assembly code, your task is to generate "
        "highly optimized AArch64 assembly code.\n\n",
        f"C Code:\n{c_source}\n\n",
        f"Assembly Code:\n{baseline_asm}\n\n",
        "Only output the optimized assembly code. Do not include any other text. "
        "Do not write any comments in the assembly code. "
        "Wrap the assembly code in <assembly></assembly> tags.\n\n",
        "Optimized Assembly Code:\n",
    ]
    return "".join(segments)
126
+
127
def build_dataset(cfg, tok):
    """Build (train, eval) HF Datasets of chat prompts from kernel variants.

    Each row carries the rendered prompt plus variant_id / baseline_asm /
    c_source for the reward functions. Variants whose baseline compile
    fails are dropped (only the first few failures are logged).
    """
    from datasets import Dataset
    vs = [v for v in generate_all()
          if TEMPLATES[v.template_name].difficulty <= cfg.difficulty_max]
    log.info("Kernel variants (difficulty<=%d): %d", cfg.difficulty_max, len(vs))
    tv, ev = split_train_eval(vs, eval_frac=0.1, seed=0)
    tv, ev = tv[:cfg.max_train], ev[:cfg.max_eval]

    _compile_fails = [0]  # one-element list so the closure can mutate it
    def row(v):
        try:
            basm = compile_to_asm(v.c_source, tc)
        except Exception as e:
            _compile_fails[0] += 1
            if _compile_fails[0] <= 3:
                log.warning("compile_to_asm failed [%d]: %s", _compile_fails[0], e)
            return None
        msgs = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_prompt(v.c_source, basm)},
        ]
        # Some tokenizers reject enable_thinking; retry without it, then
        # fall back to a plain-text prompt when no chat template applies.
        for kw in (
            {"tokenize": False, "add_generation_prompt": True,
             "enable_thinking": False},
            {"tokenize": False, "add_generation_prompt": True},
        ):
            try:
                prompt = tok.apply_chat_template(msgs, **kw)
                break
            except TypeError:
                continue
        else:
            prompt = f"{SYSTEM_PROMPT}\n\n{msgs[1]['content']}\n"
        return {"prompt": prompt, "variant_id": v.variant_id,
                "baseline_asm": basm, "c_source": v.c_source}

    tr = [r for v in tv if (r := row(v)) is not None]
    er = [r for v in ev if (r := row(v)) is not None]
    log.info("Dataset: train=%d eval=%d", len(tr), len(er))
    if tr:
        # Log one sample's token count and tail as a sanity check.
        toks = tok(tr[0]["prompt"], return_tensors="pt")
        log.info("Sample prompt tokens: %d", toks["input_ids"].shape[1])
        log.info("Prompt tail: %r", tr[0]["prompt"][-200:])
    return Dataset.from_list(tr), Dataset.from_list(er)
171
+
172
+
173
+ # ── REWARD FUNCTIONS ──────────────────────────────────────────────────────────
174
+ _ASM_RE = re.compile(r"<assembly>(.*?)</assembly>", re.DOTALL | re.IGNORECASE)
175
+ _THINK_RE = re.compile(r"<think>.*?</think>", re.DOTALL | re.IGNORECASE)
176
+ _CACHE: dict = {}
177
+ _BASELINE: dict = {}
178
+ _LOCK = threading.Lock()
179
+ _VCFG = None
180
+
181
def _clean_asm_directives(asm: str) -> str:
    """Strip directives that GNU as rejects but clang emits."""
    lines = []
    for line in asm.splitlines():
        stripped = line.strip()
        # .addrsig / .ident are clang-emitted directives dropped here.
        if stripped.startswith(".addrsig") or stripped.startswith(".ident"):
            continue
        lines.append(line)
    asm = "\n".join(lines)
    # Prepend an .arch directive (armv9-a+sve2+crc) when none is present.
    if ".arch " not in asm:
        asm = "\t.arch armv9-a+sve2+crc\n" + asm
    return asm
193
+
194
def _extract_asm(text):
    """Pull candidate assembly out of a model completion.

    Removes <think> blocks, prefers a full <assembly>...</assembly> match,
    then degrades through partially-tagged output, and finally treats the
    whole completion as assembly. Result is always directive-cleaned.
    """
    text = _THINK_RE.sub("", text).strip()
    m = _ASM_RE.search(text)
    if m:
        return _clean_asm_directives(m.group(1).strip())
    # Closing tag only: keep what precedes it (and what follows any opener).
    if "</assembly>" in text.lower():
        body = re.split(r"</assembly>", text, flags=re.IGNORECASE)[0]
        if "<assembly>" in body.lower():
            body = re.split(r"<assembly>", body, flags=re.IGNORECASE)[-1]
        return _clean_asm_directives(body.strip())
    # Opening tag only: keep everything after it.
    if "<assembly>" in text.lower():
        return _clean_asm_directives(
            re.split(r"<assembly>", text, flags=re.IGNORECASE)[-1].strip())
    return _clean_asm_directives(text.strip())
208
+
209
def _vcfg():
    """Lazily build (and cache in module global _VCFG) the arm_gym VerifierConfig."""
    global _VCFG
    if _VCFG is None:
        # Import deferred so this module loads without arm_gym.verifier.
        from arm_gym.verifier import VerifierConfig
        _VCFG = VerifierConfig(
            mca_bin=tc.mca or "llvm-mca", assembler="aarch64-linux-gnu-as",
            linker="aarch64-linux-gnu-ld", qemu="qemu-aarch64-static",
            mcpu=tc.mcpu)
    return _VCFG
218
+
219
+
220
+ # ── CORRECTNESS HARNESS ──────────────────────────────────────────────────────
221
+ _SIG_RE = re.compile(r"([\w\s*]+?\bkernel\s*\([^)]*\))", re.DOTALL)
222
+ _HARNESS_DIR: Path | None = None
223
+ _REF_ELF_CACHE: dict[str, Path | None] = {}
224
+ _HARNESS_OBJ_CACHE: dict[str, Path | None] = {}
225
+ _CORRECTNESS_GCC = "aarch64-linux-gnu-gcc"
226
+ _CORRECTNESS_QEMU = "qemu-aarch64-static"
227
+ _ARR_SZ = 8192
228
+ _PRINT_N = 64
229
+
230
def _harness_dir() -> Path:
    """Return (creating on first use) the scratch dir for harness artifacts."""
    global _HARNESS_DIR
    if _HARNESS_DIR is None:
        _HARNESS_DIR = Path(tempfile.mkdtemp(prefix="armgym_harness_"))
    return _HARNESS_DIR
235
+
236
def _parse_kernel_sig(c_source: str):
    """Extract the `kernel` function signature from C source.

    Returns (ret_type, params, full_proto); ``params`` is a list of dicts
    with keys dtype/name/is_ptr/is_const/raw. Falls back to
    ("void", [], None) when no `kernel(...)` signature is found.
    """
    m = _SIG_RE.search(c_source)
    if not m:
        return "void", [], None
    # Collapse all whitespace so the prototype sits on a single line.
    full_proto = " ".join(m.group(1).split())
    paren = full_proto.index("(")
    ret_and_name = full_proto[:paren].strip()
    ret_type = ret_and_name.rsplit("kernel", 1)[0].strip() or "void"
    raw_params_str = full_proto[paren + 1:].rstrip(")")
    params = []
    for p in raw_params_str.split(","):
        p = p.strip()
        if not p:
            continue
        is_const = "const " in p
        is_ptr = "*" in p
        # Strip qualifiers to isolate "<dtype> <name>".
        clean = p.replace("const", "").replace("__restrict__", "").replace("*", "").strip()
        parts = clean.split()
        dtype = parts[0] if parts else "int"
        name = parts[-1] if len(parts) > 1 else f"p{len(params)}"
        params.append({"dtype": dtype, "name": name, "is_ptr": is_ptr,
                       "is_const": is_const, "raw": p})
    return ret_type, params, full_proto
259
+
260
def generate_test_harness(c_source: str, with_kernel_def: bool) -> str:
    """Emit a complete C program (with main) that calls `kernel` and prints results.

    with_kernel_def=True embeds c_source directly (reference build);
    False emits only an extern prototype so the kernel can be linked in
    from a separately assembled candidate object.
    """
    ret_type, params, full_proto = _parse_kernel_sig(c_source)
    lines = ["#include <stdio.h>", "#include <stdlib.h>", "#include <string.h>",
             "#include <stddef.h>", "#include <stdint.h>", "#include <math.h>", ""]
    if with_kernel_def:
        lines.append(c_source)
    else:
        if full_proto:
            lines.append(f"extern {full_proto};")
    lines += ["", "int main(void) {"]
    call_args = []
    output_arrays = []
    for p in params:
        dt, nm = p["dtype"], p["name"]
        if p["is_ptr"]:
            # Pointer params become static arrays: const ones get deterministic
            # input data; non-const ones are zeroed and printed afterwards.
            lines.append(f" static {dt} {nm}[{_ARR_SZ}];")
            if p["is_const"]:
                if dt in ("float", "double"):
                    lines.append(
                        f" for (int i = 0; i < {_ARR_SZ}; i++)"
                        f" {nm}[i] = ({dt})((i % 17) + 1) * ({dt})0.25;")
                elif "uint" in dt:
                    lines.append(
                        f" for (int i = 0; i < {_ARR_SZ}; i++)"
                        f" {nm}[i] = ({dt})((i % 31) + 1);")
                else:
                    lines.append(
                        f" for (int i = 0; i < {_ARR_SZ}; i++)"
                        f" {nm}[i] = ({dt})((i % 17) + 1);")
            else:
                lines.append(f" memset({nm}, 0, sizeof({nm}));")
                output_arrays.append((nm, dt))
            call_args.append(nm)
        else:
            # Scalars get fixed constants so runs are reproducible.
            if dt in ("float", "double"):
                lines.append(f" {dt} {nm} = ({dt})1.5;")
            else:
                lines.append(f" {dt} {nm} = ({dt})3;")
            call_args.append(nm)
    call_expr = f'kernel({", ".join(call_args)})'
    is_void = ret_type.strip() in ("void", "")
    if not is_void:
        lines.append(f" {ret_type.strip()} _result = {call_expr};")
        # NOTE(review): %u/%ld assume the return fits unsigned/long — looks
        # adequate for the kernels here; confirm for 64-bit unsigned returns.
        if ret_type.strip() in ("float", "double"):
            lines.append(' printf("%.10g\\n", (double)_result);')
        elif "unsigned" in ret_type or "uint" in ret_type:
            lines.append(' printf("%u\\n", _result);')
        else:
            lines.append(' printf("%ld\\n", (long)_result);')
    else:
        lines.append(f" {call_expr};")
    # Print a prefix (_PRINT_N entries) of each output array for comparison.
    for nm, dt in output_arrays:
        if dt in ("float", "double"):
            lines.append(
                f' for (int i = 0; i < {_PRINT_N}; i++)'
                f' printf("%.10g\\n", (double){nm}[i]);')
        elif "unsigned" in dt or "uint" in dt:
            lines.append(
                f' for (int i = 0; i < {_PRINT_N}; i++)'
                f' printf("%u\\n", {nm}[i]);')
        else:
            lines.append(
                f' for (int i = 0; i < {_PRINT_N}; i++)'
                f' printf("%ld\\n", (long){nm}[i]);')
    lines += [" return 0;", "}", ""]
    return "\n".join(lines)
326
+
327
def _get_reference_elf(vid: str, c_source: str) -> Path | None:
    """Compile reference kernel+harness to a static aarch64 ELF, cached by vid.

    A cached None means the reference build already failed; don't retry.
    """
    if vid in _REF_ELF_CACHE:
        p = _REF_ELF_CACHE[vid]
        if p and p.exists():
            return p
        if p is None:
            return None
        # Cached path vanished (tmp cleanup); fall through and rebuild.
    d = _harness_dir() / f"ref_{vid}"
    d.mkdir(parents=True, exist_ok=True)
    src = d / "combined.c"
    src.write_text(generate_test_harness(c_source, with_kernel_def=True))
    elf = d / "ref.elf"
    # -static: produce a standalone binary to run under qemu-aarch64-static.
    r = subprocess.run(
        [_CORRECTNESS_GCC, "-O3", "-static", "-o", str(elf), str(src), "-lm"],
        capture_output=True, text=True, timeout=30)
    if r.returncode != 0:
        log.warning("[harness] ref compile failed vid=%s: %s", vid, r.stderr[:300])
        _REF_ELF_CACHE[vid] = None
        return None
    _REF_ELF_CACHE[vid] = elf
    return elf
348
+
349
def _get_harness_obj(vid: str, c_source: str) -> Path | None:
    """Compile the extern-kernel harness to an object file, cached by vid.

    A cached None means the harness build already failed; don't retry.
    """
    if vid in _HARNESS_OBJ_CACHE:
        p = _HARNESS_OBJ_CACHE[vid]
        if p and p.exists():
            return p
        if p is None:
            return None
        # Cached path vanished (tmp cleanup); fall through and rebuild.
    d = _harness_dir() / f"harn_{vid}"
    d.mkdir(parents=True, exist_ok=True)
    src = d / "harness.c"
    src.write_text(generate_test_harness(c_source, with_kernel_def=False))
    obj = d / "harness.o"
    r = subprocess.run(
        [_CORRECTNESS_GCC, "-c", "-o", str(obj), str(src)],
        capture_output=True, text=True, timeout=30)
    if r.returncode != 0:
        log.warning("[harness] harness compile failed vid=%s: %s", vid, r.stderr[:300])
        _HARNESS_OBJ_CACHE[vid] = None
        return None
    _HARNESS_OBJ_CACHE[vid] = obj
    return obj
370
+
371
def _link_candidate_elf(cand_obj: Path, vid: str, c_source: str) -> Path | None:
    """Link a candidate object against the cached harness object.

    Returns the path of the static ELF, or None when the harness is
    unavailable or linking fails.  Output dirs are keyed by a hash of the
    candidate object path so concurrent candidates do not collide.
    """
    harness_obj = _get_harness_obj(vid, c_source)
    if harness_obj is None:
        return None
    tag = hashlib.md5(str(cand_obj).encode()).hexdigest()[:8]
    d = _harness_dir() / f"cand_{vid}_{tag}"
    d.mkdir(parents=True, exist_ok=True)
    elf = d / "cand.elf"
    try:
        r = subprocess.run(
            [_CORRECTNESS_GCC, "-static", "-o", str(elf),
             str(harness_obj), str(cand_obj), "-lm", "-lc"],
            capture_output=True, text=True, timeout=30)
    except (subprocess.TimeoutExpired, OSError):
        # Fix: don't let a hung linker raise out of the reward function.
        return None
    return elf if r.returncode == 0 else None
386
+
387
def _qemu_stdout(elf: Path) -> str | None:
    """Run *elf* under the qemu user-mode binary and capture its stdout.

    Returns None on non-zero exit, a 10s timeout, or a launch failure.
    """
    cmd = [_CORRECTNESS_QEMU, str(elf)]
    try:
        proc = subprocess.run(cmd, capture_output=True, text=True, timeout=10)
    except (subprocess.TimeoutExpired, OSError):
        return None
    if proc.returncode != 0:
        return None
    return proc.stdout
395
+
396
+ def _outputs_match(out_a: str, out_b: str, rtol: float = 1e-4,
397
+ atol: float = 1e-6) -> bool:
398
+ ta = out_a.strip().split()
399
+ tb = out_b.strip().split()
400
+ if len(ta) != len(tb):
401
+ return False
402
+ for a, b in zip(ta, tb):
403
+ if a == b:
404
+ continue
405
+ try:
406
+ fa, fb = float(a), float(b)
407
+ except ValueError:
408
+ return False
409
+ if fa != fa and fb != fb:
410
+ continue
411
+ diff = abs(fa - fb)
412
+ tol = atol + rtol * max(abs(fa), abs(fb))
413
+ if diff > tol:
414
+ return False
415
+ return True
416
+
417
def _bcy(vid, basm):
    """Return the baseline cycle count for *vid*, computed once and cached.

    Runs llvm-mca on the baseline assembly; falls back to 1000.0 cycles
    when MCA analysis fails for any reason.
    """
    if vid in _BASELINE:
        return _BASELINE[vid]
    try:
        from arm_gym.mca import run_mca
        cfg = _vcfg()
        cycles = float(run_mca(basm, cfg.mca_bin, cfg.mcpu).total_cycles)
    except Exception:
        cycles = 1000.0
    _BASELINE[vid] = cycles
    return cycles
426
+
427
@dataclass
class _E:
    # Per-completion evaluation result, memoised in _CACHE by _eval.
    assembles: bool = False  # candidate asm assembled into an object file
    runs: bool = False       # qemu output matched the reference output
    speedup: float = 0.0     # baseline_cycles / candidate_cycles (0.0 if unknown)
432
+
433
# Debug-logging budget for _eval: log at most _DBG_MAX assemble/correctness
# failures per process so the training log stays readable.
_DBG_N = 0
_DBG_MAX = 5
435
+
436
def _eval(text, vid, basm, c_src=""):
    """Evaluate one completion: extract asm, assemble, run harness, score.

    Results are memoised in _CACHE keyed by md5("text::vid") so the
    several reward functions sharing this helper pay for evaluation once.
    Sets .assembles, .runs, and .speedup on the returned _E.
    """
    global _DBG_N
    k = hashlib.md5(f"{text}::{vid}".encode()).hexdigest()
    with _LOCK:
        if k in _CACHE:
            return _CACHE[k]
    e = _E()
    asm = _extract_asm(text)
    from arm_gym.verifier import assemble, cleanup_temp_dirs
    try:
        obj, err = assemble(asm, _vcfg())
        if err or obj is None:
            # Log the first few assemble failures (double-checked under
            # the lock so concurrent workers don't blow the budget).
            if _DBG_N < _DBG_MAX:
                with _LOCK:
                    if _DBG_N < _DBG_MAX:
                        _DBG_N += 1
                        log.info("[reward-dbg #%d] vid=%s err=%r",
                                 _DBG_N, vid,
                                 err.message[:200] if err else "None")
                        log.info("[reward-dbg] raw[:300]=%r", text[:300])
                        log.info("[reward-dbg] asm[:300]=%r", asm[:300])
            cleanup_temp_dirs()
            with _LOCK:
                _CACHE[k] = e
            return e
        e.assembles = True
        if c_src:
            # Correctness: link the candidate against the generated harness
            # and compare qemu stdout with the reference binary's stdout.
            ref_elf = _get_reference_elf(vid, c_src)
            cand_elf = _link_candidate_elf(obj, vid, c_src)
            if ref_elf and cand_elf:
                ref_out = _qemu_stdout(ref_elf)
                cand_out = _qemu_stdout(cand_elf)
                if ref_out is not None and cand_out is not None:
                    e.runs = _outputs_match(cand_out, ref_out)
                    if e.runs:
                        log.info("[correctness] PASS vid=%s", vid)
                    elif _DBG_N < _DBG_MAX:
                        with _LOCK:
                            if _DBG_N < _DBG_MAX:
                                _DBG_N += 1
                                log.info("[correctness] FAIL vid=%s "
                                         "ref=%r cand=%r",
                                         vid, ref_out[:200], cand_out[:200])
        if e.runs:
            # Speedup is only measured for candidates proven correct.
            from arm_gym.mca import run_mca
            bc = _bcy(vid, basm)
            rep = run_mca(asm, _vcfg().mca_bin, _vcfg().mcpu)
            e.speedup = bc / max(rep.total_cycles, 1)
        cleanup_temp_dirs()
    except Exception as ex:
        # Best-effort: any evaluation error leaves e at its zero defaults.
        log.debug("eval err vid=%s: %s", vid, ex)
        try:
            cleanup_temp_dirs()
        except Exception:
            pass
    with _LOCK:
        _CACHE[k] = e
    return e
494
+
495
+ def _prep(completions, kw):
496
+ texts = [c[-1]["content"] if isinstance(c, list) else str(c)
497
+ for c in (completions or [])]
498
+ n = len(texts)
499
+ vids = list(kw.get("variant_id") or [""] * n)
500
+ bs = list(kw.get("baseline_asm") or [""] * n)
501
+ cs = list(kw.get("c_source") or [""] * n)
502
+ if len(vids) == 1 and n > 1:
503
+ vids *= n
504
+ bs *= n
505
+ cs *= n
506
+ return texts, vids, bs, cs
507
+
508
# Budget for logging raw completions in format_reward — helps spot
# formatting drift (missing tags, chatter) early in a run.
_FMT_DBG_N = 0
_FMT_DBG_MAX = 8
510
+
511
def format_reward(prompts=None, completions=None, **kw):
    """Score how well a completion follows the <assembly> output format.

    +0.3 for each of the opening/closing tags, +0.4 for a non-trivial
    extracted body (>= 20 non-whitespace chars), -0.5 for chatter markers
    (code fences, think tags, prose); the score is floored at -0.5.
    """
    global _FMT_DBG_N
    texts, _, _, _ = _prep(completions, kw)
    if texts and _FMT_DBG_N < _FMT_DBG_MAX:
        _FMT_DBG_N += 1
        head = texts[0]
        log.info("[completion #%d] len=%d first500=%r", _FMT_DBG_N, len(head), head[:500])
        log.info("[completion #%d] last200=%r", _FMT_DBG_N, head[-200:])
    results = []
    for text in texts:
        lowered = text.lower()
        score = 0.0
        if "<assembly>" in lowered:
            score += 0.3
        if "</assembly>" in lowered:
            score += 0.3
        if len(re.sub(r"\s", "", _extract_asm(text))) >= 20:
            score += 0.4
        if any(marker in lowered for marker in ("```", "<think>", "explain", "analysis")):
            score -= 0.5
        results.append(max(-0.5, score))
    return results
534
+
535
def syntax_reward(prompts=None, completions=None, **kw):
    """3.0 when the extracted assembly assembles cleanly, else 0.0."""
    texts, vids, basms, srcs = _prep(completions, kw)
    scores = []
    for text, vid, basm, src in zip(texts, vids, basms, srcs):
        scores.append(3.0 if _eval(text, vid, basm, src).assembles else 0.0)
    return scores
539
+
540
def correctness_reward(prompts=None, completions=None, **kw):
    """5.0 when the candidate's qemu output matches the reference, else 0.0."""
    texts, vids, basms, srcs = _prep(completions, kw)
    scores = []
    for text, vid, basm, src in zip(texts, vids, basms, srcs):
        scores.append(5.0 if _eval(text, vid, basm, src).runs else 0.0)
    return scores
544
+
545
def speedup_reward(prompts=None, completions=None, **kw):
    """Reward cycle-count improvement over baseline for correct candidates.

    Returns max(0, speedup - 1) per completion, 0.0 when the candidate is
    not proven correct.  Fix: calls _eval once per completion; the original
    called it twice (cache made the second call cheap, but it still
    doubled lock acquisitions on every item).
    """
    texts, vids, basms, srcs = _prep(completions, kw)
    scores = []
    for text, vid, basm, src in zip(texts, vids, basms, srcs):
        e = _eval(text, vid, basm, src)
        scores.append(max(0.0, e.speedup - 1.0) if e.runs else 0.0)
    return scores
550
+
551
+
552
# ── MODEL LOADING (bf16 + LoRA — no Unsloth) ─────────────────────────────────
# Fallback EOS token ids used when the tokenizer lookup yields nothing
# (presumably Qwen2.5's <|im_end|> / <|endoftext|> ids — confirm per model).
QWEN25_EOS_IDS = (151645, 151643)
554
+
555
def load_model(cfg):
    """Load the base model in bf16, wrap it with LoRA, and normalise EOS.

    Returns (model, tokenizer).  Also patches ``model.generate`` so
    gradient checkpointing is disabled while sampling and re-enabled
    afterwards for the training step.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained(cfg.model_id)
    model = AutoModelForCausalLM.from_pretrained(
        cfg.model_id,
        dtype=torch.bfloat16,
        device_map={"": 0},  # place the whole model on GPU 0
        attn_implementation="eager",
    )
    log.info("Model dtype: %s device: %s",
             next(model.parameters()).dtype,
             next(model.parameters()).device)

    # Make inputs require grads so LoRA params get gradients when
    # checkpointing is enabled on a frozen base model.
    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()

    from peft import LoraConfig, get_peft_model
    lora = LoraConfig(
        r=cfg.lora_rank, lora_alpha=cfg.lora_alpha, lora_dropout=0.0,
        bias="none", task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"])
    model = get_peft_model(model, lora)
    log.info("bf16+LoRA model loaded")

    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
        tok.pad_token_id = tok.eos_token_id
    tok.truncation_side = "left"  # keep the tail of over-long prompts

    # Resolve EOS ids from the tokenizer; fall back to the hard-coded
    # Qwen2.5 ids when lookup fails (unknown / negative / unk ids).
    eos_ids = []
    for tok_str in ("<|im_end|>", "<|endoftext|>"):
        tid = tok.convert_tokens_to_ids(tok_str)
        if isinstance(tid, int) and tid >= 0 and tid != tok.unk_token_id:
            if tid not in eos_ids:
                eos_ids.append(tid)
    if not eos_ids:
        eos_ids = list(QWEN25_EOS_IDS)
    tok.eos_token_id = eos_ids[0]
    model.config.eos_token_id = eos_ids[0]
    gc = getattr(model, "generation_config", None)
    if gc is not None:
        gc.eos_token_id = eos_ids[0] if len(eos_ids) == 1 else eos_ids
    log.info("EOS ids=%s pad=%d trunc_side=%s",
             eos_ids, tok.pad_token_id, tok.truncation_side)

    model.print_trainable_parameters()

    # Wrap generate(): checkpointing is toggled off for sampling and back
    # on afterwards (presumably because checkpointed generate misbehaves
    # with this stack — NOTE(review): confirm against the trainer version).
    import types
    _base_generate = type(model).generate
    def _gc_safe_generate(self, *args, **kwargs):
        self.gradient_checkpointing_disable()
        try:
            return _base_generate(self, *args, **kwargs)
        finally:
            self.gradient_checkpointing_enable(
                gradient_checkpointing_kwargs={"use_reentrant": False})
    model.generate = types.MethodType(_gc_safe_generate, model)
    log.info("Patched model.generate to toggle gradient_checkpointing off/on")

    return model, tok
617
+
618
+
619
+ # ── GRPO CONFIG ───────────────────────────────────────────────────────────────
620
def build_grpo_config(cfg):
    """Build a GRPOConfig, dropping kwargs the installed TRL doesn't know.

    TRL renames/removes GRPOConfig fields across releases; on a TypeError
    naming an unexpected keyword we strip it from the param dict and retry,
    so one script runs against several TRL versions.

    Fix: if the reported keyword is not actually in our param dict,
    re-raise instead of retrying — the old ``p.pop(name, None)`` was a
    no-op in that case and the loop spun forever on the same TypeError.
    """
    from trl import GRPOConfig
    gen_kwargs = {
        "eos_token_id": list(QWEN25_EOS_IDS),
        "top_p": 0.9,
        "top_k": 40,
    }
    p = dict(
        output_dir=cfg.out_dir,
        max_steps=cfg.steps,
        learning_rate=cfg.learning_rate,
        warmup_steps=cfg.warmup_steps,
        lr_scheduler_type="constant_with_warmup",
        gradient_accumulation_steps=cfg.gradient_accumulation_steps,
        per_device_train_batch_size=cfg.per_device_train_batch_size,
        num_generations=cfg.num_generations,
        generation_batch_size=cfg.num_generations,
        max_prompt_length=cfg.max_prompt_length,
        max_completion_length=cfg.max_completion_length,
        mask_truncated_completions=True,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        bf16=True,
        max_grad_norm=1.0,
        temperature=cfg.temperature,
        generation_kwargs=gen_kwargs,
        loss_type="grpo",
        beta=0.04,
        epsilon=0.2,
        remove_unused_columns=False,
        logging_steps=1,
        save_steps=cfg.save_steps,
        save_total_limit=1,
        report_to="none",
    )
    while True:
        try:
            return GRPOConfig(**p)
        except TypeError as e:
            m = re.search(r"unexpected keyword argument '(\w+)'", str(e))
            if not m or m.group(1) not in p:
                raise  # not a droppable kwarg — avoid an infinite retry loop
            log.warning("Dropping unsupported GRPOConfig param: %r", m.group(1))
            del p[m.group(1)]
664
+
665
+
666
# ── TRAIN ─────────────────────────────────────────────────────────────────────
# NOTE(review): this section references `cfg`, `build_dataset`, and the
# reward functions above; it appears to run at script scope after config
# parsing — confirm against the full file.
out = Path(cfg.out_dir)
out.mkdir(parents=True, exist_ok=True)
# Persist the resolved config next to the artifacts for reproducibility.
(out / "config.json").write_text(json.dumps(asdict(cfg), indent=2))

model, tok = load_model(cfg)
train_ds, eval_ds = build_dataset(cfg, tok)

from trl import GRPOTrainer
trainer = GRPOTrainer(
    model=model,
    reward_funcs=[format_reward, syntax_reward, correctness_reward,
                  speedup_reward],
    args=build_grpo_config(cfg),
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    processing_class=tok,
)

# ── PRE-TRAINING SANITY CHECK ─────────────────────────────────────────────────
# One short generation to catch dtype/device/EOS misconfiguration before
# paying for a full GRPO run.
log.info("Running pre-training generation sanity check...")
with torch.no_grad():
    sample = train_ds[0]["prompt"]
    inputs = tok(sample, return_tensors="pt", truncation=True,
                 max_length=cfg.max_prompt_length).to(model.device)
    log.info("Sanity input: %d tokens, device=%s, dtype=%s",
             inputs["input_ids"].shape[1], inputs["input_ids"].device,
             next(model.parameters()).dtype)
    gen_ids = model.generate(
        **inputs, max_new_tokens=128, temperature=0.5, top_p=0.9, top_k=40,
        do_sample=True, eos_token_id=list(QWEN25_EOS_IDS),
    )
    new_ids = gen_ids[0][inputs["input_ids"].shape[1]:]
    gen_text = tok.decode(new_ids, skip_special_tokens=True)
    log.info("SANITY OUTPUT (%d tokens): %r", len(new_ids), gen_text[:500])
    # Heuristic gibberish detector: a known bad token or < 5 distinct words.
    if "stringodzi" in gen_text or len(set(gen_text.split())) < 5:
        log.error("BASE MODEL IS GENERATING GIBBERISH — dtype or device issue!")
    else:
        log.info("Base model generates coherent text. GRPO should learn.")

log.info("=" * 60)
log.info("TRAINING START: %d steps | G=%d | %s",
         cfg.steps, cfg.num_generations, cfg.model_id)
log.info("=" * 60)
t0 = time.time()
try:
    trainer.train()
finally:
    # Always dump the metric history to CSV, even if training crashes.
    rows = getattr(trainer.state, "log_history", [])
    all_keys = list(dict.fromkeys(k for r in rows for k in r.keys()))
    with open(out / "log.csv", "w", newline="") as f:
        if all_keys:
            w = csv.DictWriter(f, fieldnames=all_keys, extrasaction="ignore")
            w.writeheader()
            w.writerows(rows)
    log.info("Logged %d rows to log.csv", len(rows))

elapsed = time.time() - t0
log.info("Training done in %.0fs (%.1f min)", elapsed, elapsed / 60)

# Save the LoRA adapter locally (best-effort; training output is already
# on disk if this fails).
try:
    trainer.save_model(str(out / "lora-adapter"))
    tok.save_pretrained(str(out / "lora-adapter"))
    log.info("LoRA adapter saved to %s", out / "lora-adapter")
except Exception as e:
    log.error("Failed to save LoRA adapter: %s", e)

# Best-effort upload of the whole output dir when an HF token is present.
if hf_tok := os.environ.get("HF_TOKEN"):
    try:
        from huggingface_hub import HfApi
        api = HfApi(token=hf_tok)
        api.create_repo(cfg.hub_model_id, repo_type="model",
                        exist_ok=True, private=False)
        _msg = (
            f"v6 GRPO LoRA: profile={os.environ.get('ARMGYM_PROFILE', 'mvp')} "
            f"steps={cfg.steps} (harness correctness)"
        )
        api.upload_folder(
            folder_path=str(out), repo_id=cfg.hub_model_id,
            repo_type="model", commit_message=_msg[:200],
        )
        log.info("Uploaded to https://huggingface.co/%s", cfg.hub_model_id)
    except Exception as e:
        log.error("HF Hub upload failed: %s", e)

log.info("==== V6 COMPLETE ====")