| """Evaluate the trained LoRA adapter against gcc -O3 baseline. |
| |
| Usage: |
| python scripts/evaluate.py \ |
| --checkpoint colab/results/runs/grpo/checkpoint-200 \ |
| --samples 8 |
| |
| Outputs: |
| - Side-by-side assembly for each kernel |
| - MCA cycles: baseline vs trained |
| - Win rate (how often the model beats gcc -O3) |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import os |
| import re |
| import sys |
| from pathlib import Path |
|
|
# Repository root: this script lives in <root>/scripts/. Putting the root first
# on sys.path lets the project-local packages (arm_gym, kaggle) imported below
# resolve without installing the project.
PROJECT = Path(__file__).parent.parent
sys.path.insert(0, f"{PROJECT}")

# Search Homebrew's LLVM bin directory before anything else already on PATH.
os.environ["PATH"] = f"/opt/homebrew/opt/llvm/bin:{os.environ.get('PATH', '')}"
|
|
| import torch |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
| from peft import PeftModel |
|
|
| from arm_gym.compile_baseline import detect_toolchain, compile_to_asm |
| from arm_gym.mca import run_mca |
| from arm_gym.kernels import generate_all, split_train_eval, TEMPLATES |
| from kaggle.dataset import SYSTEM_PROMPT, user_prompt |
|
|
|
|
def parse_assembly(text: str) -> str | None:
    """Extract the assembly listing from a model completion.

    The body of the first ``<assembly>...</assembly>`` pair wins. A Markdown
    code fence wrapped around the listing inside the tags (e.g. a line of
    ```` ```asm ````) is removed, because MCA cannot parse it.

    Without a closed tag pair, the whole completion is accepted only when it
    looks like raw assembly: it must contain a ``.text``, ``.global`` or
    ``.globl`` directive. Compilers commonly emit ``.globl``, which does not
    contain ``.global`` as a substring.

    Returns:
        The stripped assembly, or ``None`` when nothing usable was found.
        An empty ``<assembly></assembly>`` pair also yields ``None``, so the
        caller reports it as a parse failure rather than scoring an empty
        listing.
    """
    m = re.search(r"<assembly>(.*?)</assembly>", text, re.DOTALL)
    if m:
        body = _strip_code_fence(m.group(1))
        return body or None

    # An unclosed <assembly> tag (e.g. generation cut off by max_new_tokens)
    # is intentionally not salvaged: a truncated listing could score
    # misleadingly few cycles.
    if any(directive in text for directive in (".text", ".global", ".globl")):
        return text.strip()
    return None


def _strip_code_fence(body: str) -> str:
    """Strip *body* and remove a surrounding Markdown code fence, if any."""
    body = body.strip()
    if not body.startswith("```"):
        return body
    lines = body.split("\n")[1:]  # drop the opening ``` / ```asm line
    if lines and lines[-1].strip() == "```":
        lines = lines[:-1]
    return "\n".join(lines).strip()
|
|
|
|
def load_model(checkpoint: Path, device: str):
    """Load the base causal LM named by a LoRA checkpoint and attach the adapter.

    Args:
        checkpoint: Adapter directory. It must contain ``adapter_config.json``,
            whose ``base_model_name_or_path`` names the base model to load.
        device: Torch device string that all weights are placed on
            (e.g. ``"mps"`` or ``"cpu"``).

    Returns:
        ``(model, tokenizer)``, with the model in eval mode.
    """
    import json

    adapter_cfg = json.loads((checkpoint / "adapter_config.json").read_text())
    base_model_id = adapter_cfg["base_model_name_or_path"]
    print(f"Base model: {base_model_id}")

    print("Loading tokenizer...")
    # Prefer the tokenizer saved with the checkpoint. If the checkpoint was
    # saved without tokenizer files, fall back to the base model's tokenizer.
    if (checkpoint / "tokenizer_config.json").exists():
        tok_source = str(checkpoint)
    else:
        tok_source = base_model_id
    tok = AutoTokenizer.from_pretrained(tok_source)

    # Half precision targets accelerators. Many CPU kernels lack float16
    # support or are very slow with it, so use float32 on the CPU.
    if device.startswith("cpu"):
        dtype, dtype_name = torch.float32, "float32"
    else:
        dtype, dtype_name = torch.float16, "float16"

    print(f"Loading base model ({dtype_name})...")
    model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        torch_dtype=dtype,
        device_map={"": device},  # place every module on the requested device
        low_cpu_mem_usage=True,
    )
    print(f"Loading LoRA adapter from {checkpoint}...")
    model = PeftModel.from_pretrained(model, str(checkpoint))
    model.eval()
    print("Model ready.\n")
    return model, tok
|
|
|
|
def generate(model, tok, c_source: str, baseline_asm: str, device: str,
             max_new_tokens: int = 512, temperature: float = 0.5) -> str:
    """Sample one completion from *model* for a single kernel.

    The prompt is the system prompt plus a user turn built from the C source
    and its baseline assembly, rendered with the tokenizer's chat template.

    Args:
        model: Causal LM (here the LoRA-wrapped model) supporting ``generate``.
        tok: Matching tokenizer with a chat template.
        c_source: C source of the kernel.
        baseline_asm: Baseline assembly shown to the model.
        device: Device the input tensors are moved to.
        max_new_tokens: Generation length cap.
        temperature: Sampling temperature (sampling is always enabled).

    Returns:
        The decoded newly generated text, with special tokens removed and
        without the prompt.
    """
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt(c_source, baseline_asm)},
    ]
    prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # The rendered template may already contain the model's special tokens
    # (e.g. BOS). Tokenizing with the default add_special_tokens=True would
    # insert them a second time.
    inputs = tok(prompt, return_tensors="pt", add_special_tokens=False).to(device)
    prompt_len = inputs["input_ids"].shape[1]
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            pad_token_id=tok.eos_token_id,
        )
    # Decode only the tokens produced after the prompt.
    return tok.decode(out[0][prompt_len:], skip_special_tokens=True)
|
|
|
|
# Number of assembly lines shown per listing before the rest is elided.
_ASM_PREVIEW_LINES = 20


def _print_asm(title: str, asm: str, sep: str, limit: int = _ASM_PREVIEW_LINES) -> None:
    """Print *title*, a separator line, then at most *limit* lines of *asm*."""
    print(title)
    print(sep)
    lines = asm.strip().split("\n")
    for line in lines[:limit]:
        print(line)
    # Decide truncation from the same stripped line list that was printed.
    # Counting "\n" in the raw text missed a listing exactly one line over
    # the limit and misfired on leading/trailing blank lines.
    if len(lines) > limit:
        print(" ... (truncated)")


def _evaluate_variant(v, tc, model, tok, device: str, sep: str) -> dict | None:
    """Compare the model against the baseline compiler on one kernel variant.

    Compiles the variant's C source to baseline assembly and runs MCA on it.
    Then generates model assembly, runs MCA on that too, and prints a
    side-by-side report.

    Returns:
        A result row with keys ``variant``, ``base``, ``trained``,
        ``speedup`` and ``status``. Returns ``None`` when the baseline itself
        failed; the caller skips such variants.
    """
    try:
        baseline_asm = compile_to_asm(v.c_source, tc)
        base_cycles = run_mca(baseline_asm, tc.mca, tc.mcpu).total_cycles
    except Exception as e:  # one broken kernel must not stop the whole run
        print(f"Baseline compile failed: {e}")
        return None

    def row(trained, speedup, status):
        return {"variant": v.variant_id, "base": base_cycles,
                "trained": trained, "speedup": speedup, "status": status}

    print("\nC Source:")
    print(v.c_source.strip())

    print(f"\nGenerating assembly (device={device})...")
    raw = generate(model, tok, v.c_source, baseline_asm, device)
    trained_asm = parse_assembly(raw)
    if trained_asm is None:
        print("Model output (no <assembly> tags found):")
        print(raw[:500])
        return row(None, None, "parse_fail")

    try:
        trained_cycles = run_mca(trained_asm, tc.mca, tc.mcpu).total_cycles
        speedup = base_cycles / trained_cycles if trained_cycles > 0 else 0
    except Exception as e:  # invalid model assembly is a scored failure
        print(f"MCA on trained asm failed: {e}")
        print("Trained asm:")
        print(trained_asm[:300])
        return row(None, None, "mca_fail")

    beat = speedup > 1.0
    status = "WIN" if beat else "LOSS"

    _print_asm(f"\nBaseline assembly (gcc -O3, {base_cycles} cycles):", baseline_asm, sep)
    _print_asm(f"\nTrained assembly ({trained_cycles} cycles):", trained_asm, sep)

    marker = "✓ BEATS gcc -O3" if beat else "✗ slower than gcc -O3"
    print(f"\n{sep}")
    print(f"Result: {status} | baseline={base_cycles} trained={trained_cycles} speedup={speedup:.2f}x {marker}")

    return row(trained_cycles, speedup, status)


def _print_summary(results: list[dict]) -> None:
    """Print the per-kernel results table and aggregate statistics."""
    banner = "=" * 72
    print(f"\n{banner}")
    print("SUMMARY")
    print(banner)

    valid = [r for r in results if r["speedup"] is not None]
    wins = [r for r in valid if r["status"] == "WIN"]
    parse_fails = sum(1 for r in results if r["status"] == "parse_fail")
    mca_fails = sum(1 for r in results if r["status"] == "mca_fail")

    print(f"{'Kernel':<35} {'Base':>6} {'Trained':>8} {'Speedup':>8} {'Result':>6}")
    print("-" * 72)
    for r in results:
        # Compare against None explicitly: a 0-cycle result or a 0.00x
        # speedup is a real value and must not be shown as "N/A".
        sp = f"{r['speedup']:.2f}x" if r["speedup"] is not None else "N/A"
        tr = str(r["trained"]) if r["trained"] is not None else "N/A"
        print(f"{r['variant'][:35]:<35} {r['base']:>6} {tr:>8} {sp:>8} {r['status']:>6}")

    if valid:
        print(f"\nWin rate: {len(wins)}/{len(valid)} kernels beat gcc -O3 ({100*len(wins)/len(valid):.0f}%)")
    else:
        print("No valid results")
    if wins:
        best = max(wins, key=lambda r: r["speedup"])
        print(f"Best: {best['speedup']:.2f}x speedup on {best['variant']}")
    if valid:
        avg = sum(r["speedup"] for r in valid) / len(valid)
        print(f"Avg speedup: {avg:.2f}x across all valid kernels")
    if parse_fails:
        print(f"Parse fails: {parse_fails} (model didn't use <assembly> tags)")
    if mca_fails:
        print(f"MCA fails: {mca_fails} (invalid assembly syntax)")


def main():
    """Evaluate a LoRA checkpoint on held-out kernels and print a report."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--checkpoint", type=Path,
                    default=PROJECT / "colab/results/runs/grpo/checkpoint-200")
    ap.add_argument("--samples", type=int, default=8)
    ap.add_argument("--device", default="mps" if torch.backends.mps.is_available() else "cpu")
    ap.add_argument("--difficulty", type=int, default=1)
    args = ap.parse_args()

    print(f"Device: {args.device}")
    print(f"Checkpoint: {args.checkpoint}\n")

    tc = detect_toolchain()
    print(f"Toolchain: clang={tc.clang} mca={tc.mca} mcpu={tc.mcpu}\n")
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not (tc.ready() and tc.mca):
        sys.exit("Run: brew install llvm")

    model, tok = load_model(args.checkpoint, args.device)

    variants = [v for v in generate_all()
                if TEMPLATES[v.template_name].difficulty <= args.difficulty]
    _, eval_variants = split_train_eval(variants, eval_frac=0.1, seed=0)
    eval_variants = eval_variants[:args.samples]

    results = []
    sep = "-" * 72
    banner = "=" * 72
    for i, v in enumerate(eval_variants, start=1):
        print(f"\n{banner}")
        print(f"[{i}/{len(eval_variants)}] {v.variant_id}")
        print(banner)
        result = _evaluate_variant(v, tc, model, tok, args.device, sep)
        if result is not None:
            results.append(result)

    _print_summary(results)

    print("\nHow to improve:")
    print(" - More steps (200 is early; 1000+ shows consistent speedup)")
    print(" - Larger model (7B on Colab with bitsandbytes fixed)")
    print(" - Difficulty curriculum (unlock stage 2+ kernels after 40% win rate)")
|
|
|
|
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, never on import.
    main()
|
|