# Build manifest for the "megablocks" kernel package.
[general]
name = "megablocks"
# NOTE(review): presumably `false` marks this as a compiled (non pure-Python)
# kernel -- confirm against the build tool's documentation.
universal = false

# Sources of the Torch extension binding layer.
[torch]
src = [
    "torch-ext/torch_binding.cpp",
    "torch-ext/torch_binding.h",
]

# Kernel sources built for the ROCm backend; this section depends on `torch`.
[kernel.megablocks]
backend = "rocm"
# GPU architectures (gfx targets) this kernel is built for.
rocm-archs = [
    "gfx942",
    "gfx1030",
    "gfx1100",
    "gfx1101",
]
depends = ["torch"]
src = [
    "csrc/new_cumsum.h",
    "csrc/new_cumsum.cu",
    "csrc/new_histogram.h",
    "csrc/new_histogram.cu",
    "csrc/new_indices.h",
    "csrc/new_indices.cu",
    "csrc/new_replicate.h",
    "csrc/new_replicate.cu",
    "csrc/new_sort.h",
    "csrc/new_sort.cu",
    # Vendored grouped GEMM sources; currently disabled (commented out).
    # "csrc/grouped_gemm/fill_arguments.cuh",
    # "csrc/grouped_gemm/grouped_gemm.cu",
    # "csrc/grouped_gemm/grouped_gemm.h",
]