# perf_takehome.py
"""
# Anthropic's Original Performance Engineering Take-home (Release version)
Copyright Anthropic PBC 2026. Permission is granted to modify and use, but not
to publish or redistribute your solutions so it's hard to find spoilers.
# Task
- Optimize the kernel (in KernelBuilder.build_kernel) as much as possible in the
available time, as measured by test_kernel_cycles on a frozen separate copy
of the simulator.
Validate your results using `python tests/submission_tests.py` without modifying
anything in the tests/ folder.
We recommend you look through problem.py next.
"""
from collections import defaultdict
import random
import unittest
from problem import (
Engine,
DebugInfo,
SLOT_LIMITS,
VLEN,
N_CORES,
SCRATCH_SIZE,
Machine,
Tree,
Input,
HASH_STAGES,
reference_kernel,
build_mem_image,
reference_kernel2,
)
class KernelBuilder:
    """Assembles a VLIW instruction stream for the ``problem.Machine`` simulator.

    Responsibilities:
      * scratch-space allocation with a debug map (``alloc_scratch``),
      * caching of scalar and broadcast vector constants (``scratch_const``,
        ``scratch_vec_const``),
      * greedy packing of independent ops into instruction bundles (``build``),
      * emission of the optimized tree-walk hash kernel (``build_kernel``).
    """

    def __init__(self):
        self.instrs = []  # emitted instruction bundles: dict engine -> list of ops
        self.scratch = {}  # symbolic name -> scratch address
        self.scratch_debug = {}  # scratch address -> (name, length), for tracing
        self.scratch_ptr = 0  # bump-allocator cursor into scratch space
        self.const_map = {}  # value (or (value, "vec")) -> scratch addr of cached const

    def debug_info(self):
        """Return a DebugInfo exposing the scratch layout to the simulator/tracer."""
        return DebugInfo(scratch_map=self.scratch_debug)

    def build(self, slots: list[tuple[Engine, tuple]], vliw: bool = False):
        """Greedily pack a FIFO sequence of (engine, op) slots into bundles.

        Each bundle holds at most SLOT_LIMITS[engine] ops per engine; when an
        engine's slots fill up, the bundle is flushed and a new one starts.
        The packer does NOT analyze data dependencies -- callers must only
        pass ops that are safe to issue in the same bundle.  ``vliw`` is
        accepted for interface compatibility but currently unused.
        """
        instrs = []
        current_instr = defaultdict(list)
        for engine, args in slots:
            if len(current_instr[engine]) < SLOT_LIMITS[engine]:
                current_instr[engine].append(args)
            else:
                # This engine is full for the current bundle: flush and restart.
                instrs.append(dict(current_instr))
                current_instr = defaultdict(list)
                current_instr[engine].append(args)
        if current_instr:
            instrs.append(dict(current_instr))
        return instrs

    def add_instr(self, instr_dict):
        """Append a single pre-formed instruction bundle."""
        self.instrs.append(instr_dict)

    def alloc_scratch(self, name=None, length=1):
        """Bump-allocate ``length`` scratch words and return the base address.

        Named allocations are recorded in the debug map; anonymous ones are
        not.  Asserts if the allocation would exceed SCRATCH_SIZE.
        """
        addr = self.scratch_ptr
        if name is not None:
            self.scratch[name] = addr
            self.scratch_debug[addr] = (name, length)
        self.scratch_ptr += length
        assert self.scratch_ptr <= SCRATCH_SIZE, f"Out of scratch space: {self.scratch_ptr}"
        return addr

    def scratch_const(self, val, name=None):
        """Return the scratch address holding scalar constant ``val``.

        The constant is materialized at most once, via the load engine's
        'const' op; on a cache hit ``name`` is ignored.
        """
        if val not in self.const_map:
            addr = self.alloc_scratch(name)
            self.instrs.append({"load": [("const", addr, val)]})
            self.const_map[val] = addr
        return self.const_map[val]

    def scratch_vec_const(self, val, name=None):
        """Return the scratch address of a VLEN-wide broadcast of constant ``val``.

        Cached under the key ``(val, "vec")`` so scalar and vector constants
        of the same value coexist in the cache.
        """
        key = (val, "vec")
        if key not in self.const_map:
            addr = self.alloc_scratch(name if name else f"vconst_{val}", VLEN)
            scalar_addr = self.scratch_const(val)
            self.add_instr({"valu": [("vbroadcast", addr, scalar_addr)]})
            self.const_map[key] = addr
        return self.const_map[key]

    def build_hash_opt(self, val_vec, tmp1_vec, tmp2_vec):
        """Generate slots for the strength-reduced hash over ``val_vec``.

        Returns a list of stages; each stage is a list of (engine, op)
        tuples.  Ops within a stage are independent, but every stage must
        complete before the next begins.  The 3-op xor/shift/xor rounds are
        split into two stages to respect the tmp1/tmp2 -> val dependency.
        """
        stages = []
        # Stage 0: MAD -- val = val * (1 + 2^12) + c  (strength-reduced add+shift)
        c1 = self.scratch_vec_const(0x7ED55D16, "h0_c")
        m1 = self.scratch_vec_const(1 + (1 << 12), "h0_m")
        stages.append([("valu", ("multiply_add", val_vec, val_vec, m1, c1))])
        # Stage 1: xor/shift (independent), then combining xor (dependent).
        c2 = self.scratch_vec_const(0xC761C23C, "h1_c")
        s2 = self.scratch_vec_const(19, "h1_s")
        stages.append([
            ("valu", ("^", tmp1_vec, val_vec, c2)),
            ("valu", (">>", tmp2_vec, val_vec, s2))
        ])
        stages.append([("valu", ("^", val_vec, tmp1_vec, tmp2_vec))])
        # Stage 2: MAD
        c3 = self.scratch_vec_const(0x165667B1, "h2_c")
        m3 = self.scratch_vec_const(1 + (1 << 5), "h2_m")
        stages.append([("valu", ("multiply_add", val_vec, val_vec, m3, c3))])
        # Stage 3: add/shift (independent), then combining xor.
        c4 = self.scratch_vec_const(0xD3A2646C, "h3_c")
        s4 = self.scratch_vec_const(9, "h3_s")
        stages.append([
            ("valu", ("+", tmp1_vec, val_vec, c4)),
            ("valu", ("<<", tmp2_vec, val_vec, s4))
        ])
        stages.append([("valu", ("^", val_vec, tmp1_vec, tmp2_vec))])
        # Stage 4: MAD
        c5 = self.scratch_vec_const(0xFD7046C5, "h4_c")
        m5 = self.scratch_vec_const(1 + (1 << 3), "h4_m")
        stages.append([("valu", ("multiply_add", val_vec, val_vec, m5, c5))])
        # Stage 5: xor/shift (independent), then combining xor.
        c6 = self.scratch_vec_const(0xB55A4F09, "h5_c")
        s6 = self.scratch_vec_const(16, "h5_s")
        stages.append([
            ("valu", ("^", tmp1_vec, val_vec, c6)),
            ("valu", (">>", tmp2_vec, val_vec, s6))
        ])
        stages.append([("valu", ("^", val_vec, tmp1_vec, tmp2_vec))])
        return stages

    def build_kernel(
        self, forest_height: int, n_nodes: int, batch_size: int, rounds: int
    ):
        """Emit the vectorized wavefront kernel.

        All ``rounds`` iterations are fully unrolled.  Within each round,
        every batch vector advances through the same pipeline stage
        ("wavefront"), so the ops of a stage are independent and can be
        packed into wide bundles by ``build``.

        ``forest_height`` and ``n_nodes`` are not used directly: the kernel
        reads those values from the memory-image header at runtime.
        """
        # --- Memory pointers: mem[0..6] holds run parameters and pointers ---
        init_vars = [
            "rounds", "n_nodes", "batch_size", "forest_height",
            "forest_values_p", "inp_indices_p", "inp_values_p",
        ]
        ptr_map = {}
        tmp_load = self.alloc_scratch("tmp_load")
        for i, v in enumerate(init_vars):
            addr = self.alloc_scratch(v)
            ptr_map[v] = addr
            self.add_instr({"load": [("const", tmp_load, i)]})
            self.add_instr({"load": [("load", addr, tmp_load)]})
        indices_base = self.alloc_scratch("indices_cache", batch_size)
        values_base = self.alloc_scratch("values_cache", batch_size)
        # Scratch reuse: two temp blocks cover all per-round temporaries.
        #   Block X: tmp_addrs -> node_vals -> vtmp1 (lifetimes do not overlap)
        #   Block Y: vtmp2
        block_x = self.alloc_scratch("block_x", batch_size)
        block_y = self.alloc_scratch("block_y", batch_size)
        num_vecs = batch_size // VLEN
        tmp_addrs_base = block_x
        node_vals_base = block_x  # alias safe (load dest same as addr source)
        vtmp1_base = block_x  # alias safe (node_vals dead after the mix step)
        vtmp2_base = block_y
        # Constants
        const_0_vec = self.scratch_vec_const(0)
        const_1_vec = self.scratch_vec_const(1)
        global_n_nodes_vec = self.alloc_scratch("n_nodes_vec", VLEN)
        self.add_instr({"valu": [("vbroadcast", global_n_nodes_vec, ptr_map["n_nodes"])]})
        # --- 1. Load input data into scratch ---
        # Deliberately sequential: tmp_load is a single shared scalar, so the
        # address computes cannot be batched without extra temporaries, and
        # this prologue runs only once.  (An earlier batched variant emitted
        # packed address ops that clobbered tmp_load before use -- pure dead
        # work; it has been removed.)
        for i in range(0, batch_size, VLEN):
            i_const = self.scratch_const(i)
            self.add_instr({"alu": [("+", tmp_load, ptr_map["inp_indices_p"], i_const)]})
            self.add_instr({"load": [("vload", indices_base + i, tmp_load)]})
            self.add_instr({"alu": [("+", tmp_load, ptr_map["inp_values_p"], i_const)]})
            self.add_instr({"load": [("vload", values_base + i, tmp_load)]})
        # --- 2. Main loop (fully unrolled over rounds) ---
        self.add_instr({"flow": [("pause",)]})
        self.add_instr({"debug": [("comment", "Starting Computed Loop")]})
        for r in range(rounds):
            self.add_instr({"debug": [("comment", f"Round {r}")]})
            # Scratch addresses of each vector's working registers this round.
            vecs = []
            for vec_i in range(num_vecs):
                offset = vec_i * VLEN
                vecs.append({
                    'idx': indices_base + offset,
                    'val': values_base + offset,
                    'node': node_vals_base + offset,
                    'tmp1': vtmp1_base + offset,
                    'tmp2': vtmp2_base + offset,
                    'addr': tmp_addrs_base + offset
                })
            if r == 0:
                # Round 0: every lane is at node 0, so one scalar load is
                # broadcast to all vectors.
                scalar_node = self.alloc_scratch("scalar_node_r0")
                self.add_instr({"load": [("load", scalar_node, ptr_map["forest_values_p"])]})
                ops = []
                for vec in vecs:
                    ops.append(("valu", ("vbroadcast", vec['node'], scalar_node)))
                self.instrs.extend(self.build(ops))
            else:
                # Wave A: per-lane gather address calculation (all vectors).
                ops = []
                for vec in vecs:
                    for lane in range(VLEN):
                        ops.append(("alu", ("+", vec['addr'] + lane, ptr_map["forest_values_p"], vec['idx'] + lane)))
                self.instrs.extend(self.build(ops))
                # Wave B: per-lane scalar gather of node values (all vectors).
                ops = []
                for vec in vecs:
                    for lane in range(VLEN):
                        ops.append(("load", ("load", vec['node'] + lane, vec['addr'] + lane)))
                self.instrs.extend(self.build(ops))
            # Wave C: mix the node value into the running hash value.
            ops = []
            for vec in vecs:
                ops.append(("valu", ("^", vec['val'], vec['val'], vec['node'])))
            self.instrs.extend(self.build(ops))
            # Hash stages: interleave the same stage across all vectors so
            # each bundle fills with independent ops.
            all_stages = []  # one stage-list per vector
            for vec in vecs:
                all_stages.append(self.build_hash_opt(vec['val'], vec['tmp1'], vec['tmp2']))
            num_stages = len(all_stages[0])
            for s in range(num_stages):
                wave_ops = []
                for v_stages in all_stages:
                    for op in v_stages[s]:
                        wave_ops.append(op)
                self.instrs.extend(self.build(wave_ops))
            # Wave D: step = (val & 1) + 1; idx = idx * 2 + step.
            ops = []
            for vec in vecs:
                ops.append(("valu", ("&", vec['tmp1'], vec['val'], const_1_vec)))
            self.instrs.extend(self.build(ops))
            ops = []
            for vec in vecs:
                ops.append(("valu", ("+", vec['tmp1'], vec['tmp1'], const_1_vec)))
            self.instrs.extend(self.build(ops))
            ops = []
            for vec in vecs:
                ops.append(("valu", ("+", vec['idx'], vec['idx'], vec['idx'])))
            self.instrs.extend(self.build(ops))
            ops = []
            for vec in vecs:
                ops.append(("valu", ("+", vec['idx'], vec['idx'], vec['tmp1'])))
            self.instrs.extend(self.build(ops))
            # Wave E: wrap -- if idx >= n_nodes, reset it to 0.
            ops = []
            for vec in vecs:
                ops.append(("valu", ("<", vec['tmp1'], vec['idx'], global_n_nodes_vec)))
            self.instrs.extend(self.build(ops))
            ops = []
            for vec in vecs:
                ops.append(("flow", ("vselect", vec['idx'], vec['tmp1'], vec['idx'], const_0_vec)))
            self.instrs.extend(self.build(ops))
        # --- 3. Store results back to main memory ---
        for i in range(0, batch_size, VLEN):
            i_const = self.scratch_const(i)
            self.add_instr({"alu": [("+", tmp_load, ptr_map["inp_indices_p"], i_const)]})
            self.add_instr({"store": [("vstore", tmp_load, indices_base + i)]})
            self.add_instr({"alu": [("+", tmp_load, ptr_map["inp_values_p"], i_const)]})
            self.add_instr({"store": [("vstore", tmp_load, values_base + i)]})
        self.add_instr({"flow": [("pause",)]})
# Cycle count of the unoptimized kernel; used to report the speedup ratio
# in do_kernel_test (BASELINE / machine.cycle).
BASELINE: int = 147734
def do_kernel_test(
    forest_height: int,
    rounds: int,
    batch_size: int,
    seed: int = 123,
    trace: bool = False,
    prints: bool = False,
) -> int:
    """Build the kernel, run it on the simulator, and check its output
    against the reference implementation round by round.

    Args:
        forest_height: height of the generated tree.
        rounds: number of hash/walk rounds.
        batch_size: number of input elements.
        seed: RNG seed for tree/input generation (reproducibility).
        trace: enable the simulator's instruction trace.
        prints: dump machine vs. reference memory slices per round.

    Returns:
        Total simulated cycle count.
    """
    print(f"{forest_height=}, {rounds=}, {batch_size=}")
    random.seed(seed)
    forest = Tree.generate(forest_height)
    inp = Input.generate(forest, batch_size, rounds)
    mem = build_mem_image(forest, inp)
    kb = KernelBuilder()
    kb.build_kernel(forest.height, len(forest.values), len(inp.indices), rounds)
    # print(kb.instrs)
    value_trace = {}
    machine = Machine(
        mem,
        kb.instrs,
        kb.debug_info(),
        n_cores=N_CORES,
        value_trace=value_trace,
        trace=trace,
    )
    machine.prints = prints
    # reference_kernel2 yields its memory image at each checkpoint; each
    # machine.run() presumably advances the simulator to its next "pause"
    # so both sides stay in lockstep -- confirm against problem.Machine.run.
    for i, ref_mem in enumerate(reference_kernel2(mem, value_trace)):
        machine.run()
        inp_values_p = ref_mem[6]  # mem[6] is the inp_values pointer (see build_kernel's init_vars)
        if prints:
            print(machine.mem[inp_values_p : inp_values_p + len(inp.values)])
            print(ref_mem[inp_values_p : inp_values_p + len(inp.values)])
        assert (
            machine.mem[inp_values_p : inp_values_p + len(inp.values)]
            == ref_mem[inp_values_p : inp_values_p + len(inp.values)]
        ), f"Incorrect result on round {i}"
        inp_indices_p = ref_mem[5]  # mem[5] is the inp_indices pointer
        if prints:
            print(machine.mem[inp_indices_p : inp_indices_p + len(inp.indices)])
            print(ref_mem[inp_indices_p : inp_indices_p + len(inp.indices)])
        # Updating these in memory isn't required, but you can enable this check for debugging
        # assert machine.mem[inp_indices_p:inp_indices_p+len(inp.indices)] == ref_mem[inp_indices_p:inp_indices_p+len(inp.indices)]
    print("CYCLES: ", machine.cycle)
    print("Speedup over baseline: ", BASELINE / machine.cycle)
    return machine.cycle
class Tests(unittest.TestCase):
    def test_ref_kernels(self):
        """Cross-check the two reference kernels against each other."""
        random.seed(123)
        for _ in range(10):
            tree = Tree.generate(4)
            inp = Input.generate(tree, 10, 6)
            mem = build_mem_image(tree, inp)
            reference_kernel(tree, inp)
            for _ in reference_kernel2(mem, {}):
                pass
            idx_p, val_p = mem[5], mem[6]
            assert inp.indices == mem[idx_p : idx_p + len(inp.indices)]
            assert inp.values == mem[val_p : val_p + len(inp.values)]

    def test_kernel_trace(self):
        # Full-scale run with tracing enabled, for performance work.
        do_kernel_test(10, 16, 256, trace=True, prints=False)

    # Passing this is not required for submission (submission_tests.py holds
    # the real correctness check); uncomment it if it helps you debug.
    # def test_kernel_correctness(self):
    #     for batch in range(1, 3):
    #         for forest_height in range(3):
    #             do_kernel_test(
    #                 forest_height + 2, forest_height + 4, batch * 16 * VLEN * N_CORES
    #             )

    def test_kernel_cycles(self):
        do_kernel_test(10, 16, 256)
# To run all the tests:
# python perf_takehome.py
# To run a specific test:
# python perf_takehome.py Tests.test_kernel_cycles
# To view a hot-reloading trace of all the instructions: **Recommended debug loop**
# NOTE: The trace hot-reloading only works in Chrome. In the worst case if things aren't working, drag trace.json onto https://ui.perfetto.dev/
# python perf_takehome.py Tests.test_kernel_trace
# Then run `python watch_trace.py` in another tab, it'll open a browser tab, then click "Open Perfetto"
# You can then keep that open and re-run the test to see a new trace.
# To run the proper checks to see which thresholds you pass:
# python tests/submission_tests.py
# Entry point: run the unittest suite directly, e.g.
#   python perf_takehome.py Tests.test_kernel_cycles
if __name__ == "__main__":
    unittest.main()