|
|
import itertools
import os
import random
import sys
|
|
|
|
|
# Prepend the directory that contains this script's folder to the import
# search path, so the `perf_takehome` import below resolves regardless of
# the working directory the script is launched from.
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.split(current_dir)[0]
sys.path[:0] = [parent_dir]
|
|
|
| from perf_takehome import KernelBuilder, do_kernel_test, Tree, Input, build_mem_image, N_CORES, Machine, reference_kernel2
|
|
|
| def objective(active_threshold, mask_skip):
|
| try:
|
| forest_height = 10
|
| rounds = 16
|
| batch_size = 256
|
|
|
| forest = Tree.generate(forest_height)
|
| inp = Input.generate(forest, batch_size, rounds)
|
| mem = build_mem_image(forest, inp)
|
|
|
| kb = KernelBuilder()
|
| kb.build_kernel(
|
| forest.height,
|
| len(forest.values),
|
| len(inp.indices),
|
| rounds,
|
| active_threshold=active_threshold,
|
| mask_skip=mask_skip
|
| )
|
|
|
| value_trace = {}
|
| machine = Machine(
|
| mem,
|
| kb.instrs,
|
| kb.debug_info(),
|
| n_cores=N_CORES,
|
| value_trace=value_trace,
|
| trace=False,
|
| )
|
| machine.prints = False
|
|
|
| while machine.cores[0].state.value != 3:
|
| machine.run()
|
| if machine.cores[0].state.value == 2:
|
| machine.cores[0].state = machine.cores[0].state.__class__(1)
|
| continue
|
| break
|
|
|
| machine.enable_pause = False
|
| for ref_mem in reference_kernel2(mem, value_trace):
|
| pass
|
|
|
| inp_values_p = ref_mem[6]
|
| if machine.mem[inp_values_p : inp_values_p + len(inp.values)] != ref_mem[inp_values_p : inp_values_p + len(inp.values)]:
|
| return 999999
|
|
|
| return machine.cycle
|
|
|
| except Exception as e:
|
| print(f"Error: {e}")
|
| return 999999
|
|
|
| if __name__ == "__main__":
|
| thresholds = [4]
|
| mask_skip = True
|
| scalar_offloads = [0, 2, 4, 6, 8, 10]
|
|
|
| best_cycles = float('inf')
|
| best_config = None
|
|
|
| for ms in [True]:
|
| for th in thresholds:
|
| for so in scalar_offloads:
|
| print(f"Testing active_threshold={th}, mask_skip={ms}, scalar_offload={so}...")
|
|
|
| try:
|
| forest_height = 10
|
| rounds = 16
|
| batch_size = 256
|
|
|
| forest = Tree.generate(forest_height)
|
| inp = Input.generate(forest, batch_size, rounds)
|
| mem = build_mem_image(forest, inp)
|
|
|
| kb = KernelBuilder()
|
| kb.build_kernel(
|
| forest.height,
|
| len(forest.values),
|
| len(inp.indices),
|
| rounds,
|
| active_threshold=th,
|
| mask_skip=ms,
|
| scalar_offload=so
|
| )
|
|
|
| value_trace = {}
|
| machine = Machine(
|
| mem,
|
| kb.instrs,
|
| kb.debug_info(),
|
| n_cores=N_CORES,
|
| value_trace=value_trace,
|
| trace=False,
|
| )
|
| machine.prints = False
|
|
|
| while machine.cores[0].state.value != 3:
|
| machine.run()
|
| if machine.cores[0].state.value == 2:
|
| machine.cores[0].state = machine.cores[0].state.__class__(1)
|
| continue
|
| break
|
|
|
| machine.enable_pause = False
|
| for ref_mem in reference_kernel2(mem, value_trace):
|
| pass
|
|
|
| inp_values_p = ref_mem[6]
|
| cycles = 0
|
| if machine.mem[inp_values_p : inp_values_p + len(inp.values)] != ref_mem[inp_values_p : inp_values_p + len(inp.values)]:
|
| cycles = 999999
|
| else:
|
| cycles = machine.cycle
|
|
|
| print(f" -> Cycles: {cycles}")
|
| if cycles < best_cycles:
|
| best_cycles = cycles
|
| best_config = (th, ms, so)
|
|
|
| except Exception as e:
|
| print(f"Error: {e}")
|
|
|
| print(f"Best Config: th={best_config[0]}, mask={best_config[1]}, offload={best_config[2]}")
|
| print(f"Best Cycles: {best_cycles}")
|
|
|