"""Evaluate the trained LoRA adapter against gcc -O3 baseline.
Usage:
python scripts/evaluate.py \
--checkpoint colab/results/runs/grpo/checkpoint-200 \
--samples 8
Outputs:
- Side-by-side assembly for each kernel
- MCA cycles: baseline vs trained
- Win rate (how often the model beats gcc -O3)
"""
from __future__ import annotations
import argparse
import os
import re
import sys
from pathlib import Path
PROJECT = Path(__file__).parent.parent
sys.path.insert(0, str(PROJECT))
os.environ["PATH"] = "/opt/homebrew/opt/llvm/bin:" + os.environ.get("PATH", "")
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from arm_gym.compile_baseline import detect_toolchain, compile_to_asm
from arm_gym.mca import run_mca
from arm_gym.kernels import generate_all, split_train_eval, TEMPLATES
from kaggle.dataset import SYSTEM_PROMPT, user_prompt
def parse_assembly(text: str) -> str | None:
    """Extract the assembly payload from a raw model completion.

    Prefers an explicitly tagged ``<asm>...</asm>`` section; falls back to
    treating the whole completion as assembly when it contains obvious
    assembler directives. Returns ``None`` when nothing usable is found.
    """
    # BUG FIX: the previous pattern r"(.*?)" matches the empty string at
    # position 0, so the search always succeeded and the function always
    # returned "" — the fallback below was unreachable and every sample was
    # scored as a parse/MCA failure. Match the tagged section instead.
    # NOTE(review): tag name inferred from the "no tags found" diagnostic and
    # the summary's "model didn't use tags" — confirm against the prompt in
    # kaggle.dataset (SYSTEM_PROMPT / user_prompt).
    m = re.search(r"<asm>(.*?)</asm>", text, re.DOTALL | re.IGNORECASE)
    if m:
        body = m.group(1).strip()
        if body:
            return body
    # fallback: untagged output that still looks like an assembly listing
    if ".text" in text or ".global" in text:
        return text.strip()
    return None
def load_model(checkpoint: Path, device: str):
    """Load the base model named in the adapter config, then attach the LoRA
    adapter stored at *checkpoint*.

    Args:
        checkpoint: Directory containing ``adapter_config.json`` plus the
            saved tokenizer and LoRA weights.
        device: Target device string (e.g. ``"mps"`` or ``"cpu"``).

    Returns:
        ``(model, tokenizer)`` with the PEFT-wrapped model in eval mode.
    """
    import json  # local import, matching the file's lazy-import style

    # The adapter config records which base model the LoRA was trained on.
    cfg_path = checkpoint / "adapter_config.json"
    base_model_id = json.loads(cfg_path.read_text())["base_model_name_or_path"]
    print(f"Base model: {base_model_id}")
    print("Loading tokenizer...")
    tok = AutoTokenizer.from_pretrained(checkpoint)
    print("Loading base model (float16)...")
    model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        torch_dtype=torch.float16,
        device_map={"": device},  # place the whole model on one device
        low_cpu_mem_usage=True,
    )
    print(f"Loading LoRA adapter from {checkpoint}...")
    model = PeftModel.from_pretrained(model, str(checkpoint))
    model.eval()  # inference only: disable dropout etc.
    print("Model ready.\n")
    return model, tok
def generate(model, tok, c_source: str, baseline_asm: str, device: str, max_new_tokens: int = 512) -> str:
    """Sample one completion for *c_source* from the tuned model.

    Renders the same system/user chat prompt used in training, samples up to
    *max_new_tokens* tokens at temperature 0.5, and returns only the newly
    generated text with the prompt tokens stripped off.
    """
    chat = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt(c_source, baseline_asm)},
    ]
    rendered = tok.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    enc = tok(rendered, return_tensors="pt").to(device)
    prompt_len = enc["input_ids"].shape[1]
    with torch.no_grad():
        generated = model.generate(
            **enc,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.5,
            pad_token_id=tok.eos_token_id,
        )
    # Drop the echoed prompt; decode only the completion.
    completion = generated[0][prompt_len:]
    return tok.decode(completion, skip_special_tokens=True)
def main():
    """Evaluate the checkpointed LoRA model against the compiler baseline on
    held-out kernels, printing per-kernel comparisons and a summary table."""
    # CLI: which checkpoint, how many eval kernels, target device, and the
    # maximum kernel difficulty stage to include.
    ap = argparse.ArgumentParser()
    ap.add_argument("--checkpoint", type=Path,
                    default=PROJECT / "colab/results/runs/grpo/checkpoint-200")
    ap.add_argument("--samples", type=int, default=8)
    # Prefer the Apple-Silicon GPU (MPS) when available, else CPU.
    ap.add_argument("--device", default="mps" if torch.backends.mps.is_available() else "cpu")
    ap.add_argument("--difficulty", type=int, default=1)
    args = ap.parse_args()
    print(f"Device: {args.device}")
    print(f"Checkpoint: {args.checkpoint}\n")
    # Locate the local compiler + llvm-mca; MCA is required for cycle scoring.
    tc = detect_toolchain()
    print(f"Toolchain: clang={tc.clang} mca={tc.mca} mcpu={tc.mcpu}\n")
    assert tc.ready() and tc.mca, "Run: brew install llvm"
    model, tok = load_model(args.checkpoint, args.device)
    # Build the eval split: filter by difficulty, then take the same 10% /
    # seed-0 split used at training time so these kernels are held out.
    variants = [v for v in generate_all()
                if TEMPLATES[v.template_name].difficulty <= args.difficulty]
    _, eval_variants = split_train_eval(variants, eval_frac=0.1, seed=0)
    eval_variants = eval_variants[:args.samples]
    results = []  # one dict per kernel: variant/base/trained/speedup/status
    sep = "-" * 72
    for i, v in enumerate(eval_variants):
        print(f"\n{'='*72}")
        print(f"[{i+1}/{len(eval_variants)}] {v.variant_id}")
        print(f"{'='*72}")
        # Baseline: compile the kernel and score the assembly with llvm-mca.
        # NOTE(review): banner below labels the baseline "gcc -O3" but the
        # toolchain exposes tc.clang — confirm which compiler compile_to_asm
        # actually invokes.
        try:
            baseline_asm = compile_to_asm(v.c_source, tc)
            base_rep = run_mca(baseline_asm, tc.mca, tc.mcpu)
            base_cycles = base_rep.total_cycles
        except Exception as e:
            # No baseline means nothing to compare against; skip this kernel
            # entirely (it is excluded from the summary, not counted as a loss).
            print(f"Baseline compile failed: {e}")
            continue
        # C source
        print("\nC Source:")
        print(v.c_source.strip())
        # Generate one completion and extract its assembly payload.
        print(f"\nGenerating assembly (device={args.device})...")
        raw = generate(model, tok, v.c_source, baseline_asm, args.device)
        trained_asm = parse_assembly(raw)
        if trained_asm is None:
            # Model produced no recognizable assembly block; record the
            # failure and show a preview of the raw output for debugging.
            print("Model output (no tags found):")
            print(raw[:500])
            results.append({"variant": v.variant_id, "base": base_cycles,
                            "trained": None, "speedup": None, "status": "parse_fail"})
            continue
        # Score the model's assembly; invalid syntax is recorded as mca_fail
        # rather than aborting the whole evaluation run.
        try:
            tr_rep = run_mca(trained_asm, tc.mca, tc.mcpu)
            trained_cycles = tr_rep.total_cycles
            # speedup > 1.0 means fewer predicted cycles than the baseline.
            speedup = base_cycles / trained_cycles if trained_cycles > 0 else 0
            beat = speedup > 1.0
            status = "WIN" if beat else "LOSS"
        except Exception as e:
            print(f"MCA on trained asm failed: {e}")
            print("Trained asm:")
            print(trained_asm[:300])
            results.append({"variant": v.variant_id, "base": base_cycles,
                            "trained": None, "speedup": None, "status": "mca_fail"})
            continue
        # Side-by-side listing, truncated to the first 20 lines of each.
        print(f"\nBaseline assembly (gcc -O3, {base_cycles} cycles):")
        print(sep)
        for line in baseline_asm.strip().split("\n")[:20]:
            print(line)
        if baseline_asm.count("\n") > 20:
            print(" ... (truncated)")
        print(f"\nTrained assembly ({trained_cycles} cycles):")
        print(sep)
        for line in trained_asm.strip().split("\n")[:20]:
            print(line)
        if trained_asm.count("\n") > 20:
            print(" ... (truncated)")
        marker = "✓ BEATS gcc -O3" if beat else "✗ slower than gcc -O3"
        print(f"\n{sep}")
        print(f"Result: {status} | baseline={base_cycles} trained={trained_cycles} speedup={speedup:.2f}x {marker}")
        results.append({"variant": v.variant_id, "base": base_cycles,
                        "trained": trained_cycles, "speedup": speedup, "status": status})
    # Summary: per-kernel table, win rate, best/average speedup, failure counts.
    print(f"\n{'='*72}")
    print("SUMMARY")
    print(f"{'='*72}")
    valid = [r for r in results if r["speedup"] is not None]
    wins = [r for r in valid if r["status"] == "WIN"]
    parse_fails = sum(1 for r in results if r["status"] == "parse_fail")
    mca_fails = sum(1 for r in results if r["status"] == "mca_fail")
    print(f"{'Kernel':<35} {'Base':>6} {'Trained':>8} {'Speedup':>8} {'Result':>6}")
    print("-" * 72)
    for r in results:
        sp = f"{r['speedup']:.2f}x" if r['speedup'] else "N/A"
        tr = str(r['trained']) if r['trained'] else "N/A"
        print(f"{r['variant'][:35]:<35} {r['base']:>6} {tr:>8} {sp:>8} {r['status']:>6}")
    # Conditional expression guards the division: the f-string (and its
    # len(valid) divisor) is only evaluated when valid is non-empty.
    print(f"\nWin rate: {len(wins)}/{len(valid)} kernels beat gcc -O3 ({100*len(wins)/len(valid):.0f}%)" if valid else "No valid results")
    if wins:
        best = max(wins, key=lambda r: r["speedup"])
        print(f"Best: {best['speedup']:.2f}x speedup on {best['variant']}")
    if valid:
        avg = sum(r["speedup"] for r in valid) / len(valid)
        print(f"Avg speedup: {avg:.2f}x across all valid kernels")
    if parse_fails:
        print(f"Parse fails: {parse_fails} (model didn't use tags)")
    if mca_fails:
        print(f"MCA fails: {mca_fails} (invalid assembly syntax)")
    print(f"\nHow to improve:")
    print(f" - More steps (200 is early; 1000+ shows consistent speedup)")
    print(f" - Larger model (7B on Colab with bitsandbytes fixed)")
    print(f" - Difficulty curriculum (unlock stage 2+ kernels after 40% win rate)")
# Script entry point: run the evaluation when executed directly.
if __name__ == "__main__":
    main()