|
|
import os
import sys

# Make modules located in the directory above this script's directory
# importable.
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
sys.path.insert(0, parent_dir)

# Put <parent_dir>/ray/python at the very front of the search path.
# NOTE(review): presumably this is a local Ray source checkout that should
# take precedence over an installed Ray -- confirm. A path entry that does
# not exist is simply skipped by the import system.
ray_path = os.path.join(parent_dir, "ray", "python")
sys.path.insert(0, ray_path)

# These imports must come AFTER the sys.path changes above. Python caches
# modules in sys.modules on first import, so importing ray before adjusting
# sys.path (as this file previously did) made the ray_path entry ineffective
# and turned a later re-import into a no-op.
import ray  # noqa: E402
from ray import tune  # noqa: E402
from ray.tune.search.optuna import OptunaSearch  # noqa: E402,F401
|
|
|
def objective(config):
    """Build one kernel variant, simulate it, and report its cycle count.

    Args:
        config: Mapping that must provide the keys ``"active_threshold"`` and
            ``"mask_skip"``; both are forwarded as keyword arguments to
            ``KernelBuilder.build_kernel``.

    Returns:
        ``{"cycles": machine.cycle, "correct": True}`` when the compared
        slice of simulated memory equals the same slice of the reference
        memory image. ``{"cycles": 999999, "correct": False}`` when the
        slices differ or when any ``Exception`` is raised along the way
        (the exception text is printed first).

    NOTE(review): ``Tree``, ``Input``, ``build_mem_image``, ``KernelBuilder``,
    ``Machine``, ``N_CORES`` and ``reference_kernel2`` are not imported in the
    visible part of this file. If they are not in scope at call time, the
    resulting ``NameError`` is swallowed below and every trial reports the
    penalty value -- confirm where these names are supposed to come from.
    """
    try:
        # Fixed problem size used for every trial.
        height, n_rounds, n_batch = 10, 16, 256

        forest = Tree.generate(height)
        inp = Input.generate(forest, n_batch, n_rounds)
        mem = build_mem_image(forest, inp)

        builder = KernelBuilder()
        builder.build_kernel(
            forest.height,
            len(forest.values),
            len(inp.indices),
            n_rounds,
            active_threshold=config["active_threshold"],
            mask_skip=config["mask_skip"],
        )

        value_trace = {}
        machine = Machine(
            mem,
            builder.instrs,
            builder.debug_info(),
            n_cores=N_CORES,
            value_trace=value_trace,
            trace=False,
        )
        machine.prints = False

        # Keep running until core 0 reports state value 3. When a run leaves
        # core 0 at state value 2, put it back to state value 1 and go again;
        # any other state after a run ends the loop.
        # NOTE(review): presumably 1/2/3 are running/paused/stopped --
        # confirm against the simulator's state enum.
        while machine.cores[0].state.value != 3:
            machine.run()
            if machine.cores[0].state.value != 2:
                break
            state_now = machine.cores[0].state
            machine.cores[0].state = state_now.__class__(1)

        machine.enable_pause = False

        # Drain the reference generator; only the last yielded image is used.
        for ref_mem in reference_kernel2(mem, value_trace):
            pass

        # Slot 6 of the reference image supplies the start offset of the
        # region to compare; the region is len(inp.values) entries long.
        start = ref_mem[6]
        stop = start + len(inp.values)
        simulated = machine.mem[start:stop]
        expected = ref_mem[start:stop]
        if simulated != expected:
            return {"cycles": 999999, "correct": False}

        return {"cycles": machine.cycle, "correct": True}

    except Exception as exc:
        # Best effort: a failing configuration is reported with the penalty
        # value instead of aborting the trial.
        print(f"Error: {exc}")
        return {"cycles": 999999, "correct": False}
|
|
|
if __name__ == "__main__":
    # Sweep `active_threshold` over a small grid with `mask_skip` fixed to
    # True, minimising the "cycles" value returned by `objective`.
    ray.init()
    try:
        analysis = tune.run(
            objective,
            config={
                "active_threshold": tune.grid_search([4, 8, 16]),
                "mask_skip": True,
            },
            mode="min",
            metric="cycles",
            num_samples=1,
        )

        best_result = analysis.best_result
        # `objective` reports failed or incorrect runs with a penalty cycle
        # count and correct=False. If even the best trial is such a run,
        # there is no valid configuration to announce.
        if best_result.get("correct"):
            print("Best config: ", analysis.get_best_config(metric="cycles", mode="min"))
            print("Best cycles: ", best_result["cycles"])
        else:
            print("No trial produced a correct result; no best config to report.")
    finally:
        # Release the Ray runtime started above, even if tuning raised.
        ray.shutdown()
|
|
|