| """
|
| # Anthropic's Original Performance Engineering Take-home (Release version)
|
|
|
| Copyright Anthropic PBC 2026. Permission is granted to modify and use, but not
|
| to publish or redistribute your solutions so it's hard to find spoilers.
|
|
|
| # Task
|
|
|
| - Optimize the kernel (in KernelBuilder.build_kernel) as much as possible in the
|
| available time, as measured by test_kernel_cycles on a frozen separate copy
|
| of the simulator.
|
|
|
| Validate your results using `python tests/submission_tests.py` without modifying
|
| anything in the tests/ folder.
|
|
|
| We recommend you look through problem.py next.
|
| """
|
|
|
| from collections import defaultdict
|
| import random
|
| import unittest
|
|
|
| from problem import (
|
| Engine,
|
| DebugInfo,
|
| SLOT_LIMITS,
|
| VLEN,
|
| N_CORES,
|
| SCRATCH_SIZE,
|
| Machine,
|
| Tree,
|
| Input,
|
| HASH_STAGES,
|
| reference_kernel,
|
| build_mem_image,
|
| reference_kernel2,
|
| )
|
|
|
|
|
class KernelBuilder:
    """Builds the instruction stream for the simulated VLIW machine.

    Instructions are dicts mapping an engine name ("alu", "valu", "load",
    "store", "flow", "debug") to a list of slot tuples; each engine can take
    at most SLOT_LIMITS[engine] slots per instruction.  Scratch memory is
    managed with a simple bump allocator, and scalar/vector constants are
    cached so each distinct value is materialized only once.
    """

    def __init__(self):
        # Emitted instruction bundles, in program order.
        self.instrs = []
        # name -> scratch address, for named allocations.
        self.scratch = {}
        # address -> (name, length); consumed by DebugInfo for tracing.
        self.scratch_debug = {}
        # Bump-allocator watermark into scratch space.
        self.scratch_ptr = 0
        # Cache of already-materialized constants.
        # Keys: scalar value, or (value, "vec") for VLEN-wide broadcasts.
        self.const_map = {}

    def debug_info(self):
        """Return the DebugInfo (scratch symbol map) for the simulator."""
        return DebugInfo(scratch_map=self.scratch_debug)

    def build(self, slots: list[tuple[Engine, tuple]], vliw: bool = False):
        # Greedily pack a flat list of (engine, op) slots into instruction
        # bundles.  A new bundle is started as soon as ANY engine's slot
        # limit would be exceeded; ops for other engines keep filling the
        # current bundle until then.
        # NOTE(review): the `vliw` parameter is never read — presumably a
        # leftover toggle; confirm before removing.

        instrs = []
        current_instr = defaultdict(list)

        for engine, args in slots:

            if len(current_instr[engine]) < SLOT_LIMITS[engine]:
                current_instr[engine].append(args)
            else:
                # This engine is full: flush the whole bundle and start a
                # fresh one seeded with this op.
                instrs.append(dict(current_instr))
                current_instr = defaultdict(list)
                current_instr[engine].append(args)

        if current_instr:
            instrs.append(dict(current_instr))

        return instrs

    def add_instr(self, instr_dict):
        """Append a single pre-formed instruction bundle."""
        self.instrs.append(instr_dict)

    def alloc_scratch(self, name=None, length=1):
        """Bump-allocate `length` scratch words; returns the base address.

        Only named allocations are recorded in the debug symbol map.
        """
        addr = self.scratch_ptr
        if name is not None:
            self.scratch[name] = addr
            self.scratch_debug[addr] = (name, length)
        self.scratch_ptr += length
        assert self.scratch_ptr <= SCRATCH_SIZE, f"Out of scratch space: {self.scratch_ptr}"
        return addr

    def scratch_const(self, val, name=None):
        """Return the scratch address holding scalar `val`, emitting a
        `const` load on first use.  If `val` is already cached, `name` is
        ignored."""
        if val not in self.const_map:
            addr = self.alloc_scratch(name)
            self.instrs.append({"load": [("const", addr, val)]})
            self.const_map[val] = addr
        return self.const_map[val]

    def scratch_vec_const(self, val, name=None):
        """Return the address of a VLEN-wide vector filled with `val`,
        broadcasting from the (cached) scalar constant on first use."""
        key = (val, "vec")
        if key not in self.const_map:
            addr = self.alloc_scratch(name if name else f"vconst_{val}", VLEN)
            scalar_addr = self.scratch_const(val)
            self.add_instr({"valu": [("vbroadcast", addr, scalar_addr)]})
            self.const_map[key] = addr
        return self.const_map[key]

    def build_hash_opt(self, val_vec, tmp1_vec, tmp2_vec):
        """
        Generates slots for the strength-reduced hash function.
        Returns LIST OF LISTS of ops. Each inner list is a stage that must be completed before next.

        Each `multiply_add` stage folds an add+shift pair into one op:
        with m = 1 + (1 << k), val*m + c == (val + (val << k)) + c.
        The xor/shift stages compute the two halves into tmp1/tmp2 in
        parallel and combine them in a dependent follow-up stage.
        (Constants match a well-known 32-bit integer mix — presumably the
        Jenkins/Wang hash; confirm against HASH_STAGES in problem.py.)
        """
        stages = []

        # Stage 0: val = val*(1 + 2^12) + c  ==  (val + (val<<12)) + c
        c1 = self.scratch_vec_const(0x7ED55D16, "h0_c")
        m1 = self.scratch_vec_const(1 + (1<<12), "h0_m")
        stages.append([("valu", ("multiply_add", val_vec, val_vec, m1, c1))])

        # Stage 1-2: val = (val ^ c) ^ (val >> 19), split into two
        # independent ops then a combining xor.
        c2 = self.scratch_vec_const(0xC761C23C, "h1_c")
        s2 = self.scratch_vec_const(19, "h1_s")
        stages.append([
            ("valu", ("^", tmp1_vec, val_vec, c2)),
            ("valu", (">>", tmp2_vec, val_vec, s2))
        ])
        stages.append([("valu", ("^", val_vec, tmp1_vec, tmp2_vec))])

        # Stage 3: val = val*(1 + 2^5) + c
        c3 = self.scratch_vec_const(0x165667B1, "h2_c")
        m3 = self.scratch_vec_const(1 + (1<<5), "h2_m")
        stages.append([("valu", ("multiply_add", val_vec, val_vec, m3, c3))])

        # Stage 4-5: val = (val + c) ^ (val << 9)
        c4 = self.scratch_vec_const(0xD3A2646C, "h3_c")
        s4 = self.scratch_vec_const(9, "h3_s")
        stages.append([
            ("valu", ("+", tmp1_vec, val_vec, c4)),
            ("valu", ("<<", tmp2_vec, val_vec, s4))
        ])
        stages.append([("valu", ("^", val_vec, tmp1_vec, tmp2_vec))])

        # Stage 6: val = val*(1 + 2^3) + c
        c5 = self.scratch_vec_const(0xFD7046C5, "h4_c")
        m5 = self.scratch_vec_const(1 + (1<<3), "h4_m")
        stages.append([("valu", ("multiply_add", val_vec, val_vec, m5, c5))])

        # Stage 7-8: val = (val ^ c) ^ (val >> 16)
        c6 = self.scratch_vec_const(0xB55A4F09, "h5_c")
        s6 = self.scratch_vec_const(16, "h5_s")
        stages.append([
            ("valu", ("^", tmp1_vec, val_vec, c6)),
            ("valu", (">>", tmp2_vec, val_vec, s6))
        ])
        stages.append([("valu", ("^", val_vec, tmp1_vec, tmp2_vec))])

        return stages

    def build_kernel(
        self, forest_height: int, n_nodes: int, batch_size: int, rounds: int
    ):
        """
        Vectorized Wavefront implementation.

        Processes the whole batch as batch_size // VLEN vector lanes per
        round, interleaving all vectors' ops at each pipeline stage so
        `build` can pack them into wide bundles.
        """

        # The first seven memory words hold the kernel parameters/pointers;
        # load each one into a named scratch slot.
        init_vars = [
            "rounds", "n_nodes", "batch_size", "forest_height",
            "forest_values_p", "inp_indices_p", "inp_values_p"
        ]
        ptr_map = {}
        tmp_load = self.alloc_scratch("tmp_load")

        for i, v in enumerate(init_vars):
            addr = self.alloc_scratch(v)
            ptr_map[v] = addr
            self.add_instr({"load": [("const", tmp_load, i)]})
            self.add_instr({"load": [("load", addr, tmp_load)]})

        # Whole-batch scratch caches for the per-item indices and values.
        indices_base = self.alloc_scratch("indices_cache", batch_size)
        values_base = self.alloc_scratch("values_cache", batch_size)

        # Two batch-sized work areas, deliberately aliased below.
        block_x = self.alloc_scratch("block_x", batch_size)
        block_y = self.alloc_scratch("block_y", batch_size)

        num_vecs = batch_size // VLEN

        # Aliasing scheme: gather addresses, gathered node values, and hash
        # tmp1 all share block_x; hash tmp2 uses block_y.  Each use is dead
        # before the next role takes over (addresses are consumed by the
        # gather, node values by the xor, tmp1/tmp2 within each hash stage).
        tmp_addrs_base = block_x
        node_vals_base = block_x
        vtmp1_base = block_x
        vtmp2_base = block_y

        const_0_vec = self.scratch_vec_const(0)
        const_1_vec = self.scratch_vec_const(1)
        global_n_nodes_vec = self.alloc_scratch("n_nodes_vec", VLEN)
        self.add_instr({"valu": [("vbroadcast", global_n_nodes_vec, ptr_map["n_nodes"])]})

        # NOTE(review): this first pass only recomputes tmp_load over and
        # over and its result is never consumed (the real address compute +
        # vload happens in the next loop) — looks like leftover dead code
        # that still costs cycles; confirm and remove.
        ops = []
        for i in range(0, batch_size, VLEN):
            i_const = self.scratch_const(i)
            ops.append(("alu", ("+", tmp_load, ptr_map["inp_indices_p"], i_const)))
        self.instrs.extend(self.build(ops))

        # Load the whole batch of indices and values into scratch, one
        # vector at a time (address compute serialized through tmp_load).
        for i in range(0, batch_size, VLEN):
            i_const = self.scratch_const(i)
            self.add_instr({"alu": [("+", tmp_load, ptr_map["inp_indices_p"], i_const)]})
            self.add_instr({"load": [("vload", indices_base + i, tmp_load)]})
            self.add_instr({"alu": [("+", tmp_load, ptr_map["inp_values_p"], i_const)]})
            self.add_instr({"load": [("vload", values_base + i, tmp_load)]})

        # NOTE(review): pause presumably yields to the harness between
        # machine.run() calls — confirm against the simulator's semantics.
        self.add_instr({"flow": [("pause",)]})
        self.add_instr({"debug": [("comment", "Starting Computed Loop")]})

        for r in range(rounds):
            self.add_instr({"debug": [("comment", f"Round {r}")]})

            # Per-vector scratch addresses for this round.
            vecs = []
            for vec_i in range(num_vecs):
                offset = vec_i * VLEN
                vecs.append({
                    'idx': indices_base + offset,
                    'val': values_base + offset,
                    'node': node_vals_base + offset,
                    'tmp1': vtmp1_base + offset,
                    'tmp2': vtmp2_base + offset,
                    'addr': tmp_addrs_base + offset
                })

            if r == 0:
                # Round 0: every lane starts at the root (index 0), so one
                # scalar load + broadcast replaces the per-lane gather.
                scalar_node = self.alloc_scratch("scalar_node_r0")
                self.add_instr({"load": [("load", scalar_node, ptr_map["forest_values_p"])]})
                ops = []
                for vec in vecs:
                    ops.append(("valu", ("vbroadcast", vec['node'], scalar_node)))
                self.instrs.extend(self.build(ops))

            else:
                # Gather: compute each lane's absolute address, then load.
                ops = []
                for vec in vecs:
                    for lane in range(VLEN):
                        ops.append(("alu", ("+", vec['addr'] + lane, ptr_map["forest_values_p"], vec['idx'] + lane)))
                self.instrs.extend(self.build(ops))

                # NOTE(review): vec['node'] aliases vec['addr'] (both in
                # block_x), so each load overwrites the very cell holding
                # its address — relies on the address being read before the
                # destination is written; confirm simulator semantics.
                ops = []
                for vec in vecs:
                    for lane in range(VLEN):
                        ops.append(("load", ("load", vec['node'] + lane, vec['addr'] + lane)))
                self.instrs.extend(self.build(ops))

            # val ^= node value (node values are dead afterwards, freeing
            # block_x for the hash temporaries).
            ops = []
            for vec in vecs:
                ops.append(("valu", ("^", vec['val'], vec['val'], vec['node'])))
            self.instrs.extend(self.build(ops))

            # Hash all vectors in lockstep: emit stage s for every vector
            # before stage s+1, so independent ops pack into wide bundles.
            all_stages = []
            for vec in vecs:
                all_stages.append(self.build_hash_opt(vec['val'], vec['tmp1'], vec['tmp2']))

            num_stages = len(all_stages[0])
            for s in range(num_stages):
                wave_ops = []
                for v_stages in all_stages:
                    for op in v_stages[s]:
                        wave_ops.append(op)
                self.instrs.extend(self.build(wave_ops))

            # Tree descent: child = 2*idx + 1 + (val & 1).
            # tmp1 = val & 1
            ops = []
            for vec in vecs:
                ops.append(("valu", ("&", vec['tmp1'], vec['val'], const_1_vec)))
            self.instrs.extend(self.build(ops))

            # tmp1 += 1  -> 1 or 2
            ops = []
            for vec in vecs:
                ops.append(("valu", ("+", vec['tmp1'], vec['tmp1'], const_1_vec)))
            self.instrs.extend(self.build(ops))

            # idx *= 2 (add to itself; avoids needing a multiply)
            ops = []
            for vec in vecs:
                ops.append(("valu", ("+", vec['idx'], vec['idx'], vec['idx'])))
            self.instrs.extend(self.build(ops))

            # idx += tmp1
            ops = []
            for vec in vecs:
                ops.append(("valu", ("+", vec['idx'], vec['idx'], vec['tmp1'])))
            self.instrs.extend(self.build(ops))

            # Wrap out-of-range indices back to the root:
            # tmp1 = (idx < n_nodes); idx = tmp1 ? idx : 0
            ops = []
            for vec in vecs:
                ops.append(("valu", ("<", vec['tmp1'], vec['idx'], global_n_nodes_vec)))
            self.instrs.extend(self.build(ops))

            ops = []
            for vec in vecs:
                ops.append(("flow", ("vselect", vec['idx'], vec['tmp1'], vec['idx'], const_0_vec)))
            self.instrs.extend(self.build(ops))

        # Write the final indices and values back to main memory.
        for i in range(0, batch_size, VLEN):
            i_const = self.scratch_const(i)
            self.add_instr({"alu": [("+", tmp_load, ptr_map["inp_indices_p"], i_const)]})
            self.add_instr({"store": [("vstore", tmp_load, indices_base + i)]})
            self.add_instr({"alu": [("+", tmp_load, ptr_map["inp_values_p"], i_const)]})
            self.add_instr({"store": [("vstore", tmp_load, values_base + i)]})

        self.add_instr({"flow": [("pause",)]})
|
|
|
| BASELINE = 147734
|
|
|
def do_kernel_test(
    forest_height: int,
    rounds: int,
    batch_size: int,
    seed: int = 123,
    trace: bool = False,
    prints: bool = False,
):
    """Build the kernel, run it on the simulator, and cross-check every
    round against the reference implementation.

    Returns the total simulated cycle count.
    """
    print(f"{forest_height=}, {rounds=}, {batch_size=}")
    random.seed(seed)

    # Generate the workload and its initial memory image.
    tree = Tree.generate(forest_height)
    batch = Input.generate(tree, batch_size, rounds)
    mem_image = build_mem_image(tree, batch)

    builder = KernelBuilder()
    builder.build_kernel(tree.height, len(tree.values), len(batch.indices), rounds)

    vtrace = {}
    sim = Machine(
        mem_image,
        builder.instrs,
        builder.debug_info(),
        n_cores=N_CORES,
        value_trace=vtrace,
        trace=trace,
    )
    sim.prints = prints

    n_vals = len(batch.values)
    n_idx = len(batch.indices)

    # Advance the simulator once per reference-kernel round and compare the
    # values region exactly (indices are only printed, not asserted).
    for round_idx, expected in enumerate(reference_kernel2(mem_image, vtrace)):
        sim.run()
        values_p = expected[6]
        got_vals = sim.mem[values_p : values_p + n_vals]
        want_vals = expected[values_p : values_p + n_vals]
        if prints:
            print(got_vals)
            print(want_vals)
        assert got_vals == want_vals, f"Incorrect result on round {round_idx}"
        indices_p = expected[5]
        if prints:
            print(sim.mem[indices_p : indices_p + n_idx])
            print(expected[indices_p : indices_p + n_idx])

    print("CYCLES: ", sim.cycle)
    print("Speedup over baseline: ", BASELINE / sim.cycle)
    return sim.cycle
|
|
|
|
|
class Tests(unittest.TestCase):
    def test_ref_kernels(self):
        """Check the two reference kernels agree on small random inputs."""
        random.seed(123)
        for _trial in range(10):
            tree = Tree.generate(4)
            data = Input.generate(tree, 10, 6)
            image = build_mem_image(tree, data)
            reference_kernel(tree, data)
            # Drain the generator so reference_kernel2 runs to completion.
            for _ in reference_kernel2(image, {}):
                pass
            n_idx = len(data.indices)
            n_val = len(data.values)
            assert data.indices == image[image[5] : image[5] + n_idx]
            assert data.values == image[image[6] : image[6] + n_val]

    def test_kernel_trace(self):
        # Same workload as test_kernel_cycles, but with tracing enabled.
        do_kernel_test(10, 16, 256, trace=True, prints=False)

    def test_kernel_cycles(self):
        do_kernel_test(10, 16, 256)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Script entry point: run the unittest suite above.
if __name__ == "__main__":
    unittest.main()
|
|
|