| from triton.backends.compiler import BaseBackend, GPUTarget |
| from triton._C.libtriton import ir, passes, llvm, amd |
| from dataclasses import dataclass |
| from typing import Any, Tuple |
| import hashlib |
| import tempfile |
| import os |
| import re |
| import subprocess |
| import functools |
| from pathlib import Path |
|
|
|
|
@dataclass(frozen=True)
class HIPOptions:
    """Immutable set of compilation options for the HIP backend.

    After construction two values are derived in ``__post_init__``:

    * ``warp_size`` (not a declared field) is set to 32 when ``arch`` contains
      ``'gfx10'`` or ``'gfx11'`` and to 64 otherwise.
    * ``extern_libs`` is replaced by a tuple of ``(name, path)`` pairs that
      always contains entries for ``ocml`` and ``ockl`` pointing into the
      ``lib`` directory next to this file.

    Raises:
        TypeError: if ``arch`` is not a string (it defaults to ``None`` and
            therefore must be supplied by the caller).
        AssertionError: if ``num_warps`` is not a positive power of two.
    """
    num_warps: int = 4
    waves_per_eu: int = 1
    num_stages: int = 0
    num_ctas: int = 1
    # Mapping of library name -> bitcode path on input; stored as a tuple of
    # (name, path) pairs once __post_init__ has run.
    extern_libs: dict = None
    cluster_dims: tuple = (1, 1, 1)
    debug: bool = False
    # Target architecture string; required in practice (see __post_init__).
    arch: str = None
    allow_fp8e4nv: bool = False
    allow_fp8e4b15: bool = False
    default_dot_input_precision: str = "ieee"
    allowed_dot_input_precisions: Tuple[str, ...] = ("ieee", )
    enable_fp_fusion: bool = True
    matrix_instr_nonkdim: int = 0
    kpack: int = 1
    allow_flush_denorm: bool = False
    max_num_imprecise_acc_default: int = 0
    backend_name: str = 'hip'

    def __post_init__(self):
        """Derive ``warp_size``, normalise ``extern_libs`` and validate ``num_warps``."""
        # Fail with a message that names the offending field instead of the
        # opaque "argument of type 'NoneType' is not iterable" that the
        # substring tests below would otherwise produce.
        if not isinstance(self.arch, str):
            raise TypeError(f"HIPOptions.arch must be a target architecture string, got {self.arch!r}")

        default_libdir = Path(__file__).parent / 'lib'
        # Copy so the caller's mapping is never mutated.
        extern_libs = {} if self.extern_libs is None else dict(self.extern_libs)

        warp_size = 32 if 'gfx10' in self.arch or 'gfx11' in self.arch else 64
        # The dataclass is frozen, so attributes are set through object.__setattr__.
        object.__setattr__(self, 'warp_size', warp_size)

        # The bundled ocml/ockl entries always replace same-named user entries.
        for lib in ("ocml", "ockl"):
            extern_libs[lib] = str(default_libdir / f'{lib}.bc')
        # A tuple of pairs keeps the frozen instance free of mutable state.
        object.__setattr__(self, 'extern_libs', tuple(extern_libs.items()))

        assert self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0, \
            "num_warps must be a power of 2"

    def hash(self):
        """Return the SHA-256 hex digest of all instance attributes.

        The key is built from every entry of ``self.__dict__`` (declared
        fields plus the derived ``warp_size``) as ``name-value`` joined by
        ``_``, in attribute insertion order.
        """
        key = '_'.join(f'{name}-{val}' for name, val in self.__dict__.items())
        return hashlib.sha256(key.encode("utf-8")).hexdigest()
|
|
|
|
class HIPBackend(BaseBackend):
    """Compiler backend that lowers Triton IR to an AMD ``hsaco`` binary.

    The pipeline registered by ``add_stages`` is
    ``ttir -> ttgir -> llir -> amdgcn -> hsaco``.
    """

    @staticmethod
    def supports_target(target: GPUTarget):
        """Return True when ``target.backend`` is ``'hip'``."""
        return target.backend == 'hip'

    def __init__(self, target: GPUTarget) -> None:
        """Bind the backend to ``target``; ``target.arch`` must be a string."""
        super().__init__(target)
        assert isinstance(target.arch, str)
        self.binary_ext = "hsaco"

    def parse_options(self, opts) -> Any:
        """Build a ``HIPOptions`` from the entries of ``opts`` that name an option field.

        ``arch`` defaults to the target's architecture; an ``'arch'`` entry in
        ``opts`` overrides it. Keys that are not ``HIPOptions`` fields are ignored.
        """
        args = {'arch': self.target.arch}
        args.update({k: opts[k] for k in HIPOptions.__dataclass_fields__.keys() if k in opts})
        return HIPOptions(**args)

    def pack_metadata(self, metadata):
        """Return ``(num_warps, num_ctas, shared, cluster_x, cluster_y, cluster_z)`` from ``metadata``."""
        return (
            metadata.num_warps,
            metadata.num_ctas,
            metadata.shared,
            metadata.cluster_dims[0],
            metadata.cluster_dims[1],
            metadata.cluster_dims[2],
        )

    def get_codegen_implementation(self):
        """Return the backend-specific codegen functions (none are defined here)."""
        codegen_fns = dict()
        return codegen_fns

    def load_dialects(self, ctx):
        """Register the AMD dialects with the IR context ``ctx``."""
        amd.load_dialects(ctx)

    @staticmethod
    def path_to_rocm_lld():
        """Locate the ``ld.lld`` linker and return its path.

        Candidates are probed in order: the ``TRITON_HIP_LLD_PATH`` environment
        variable (if set), ``llvm/bin/ld.lld`` next to this file,
        ``/opt/rocm/llvm/bin/ld.lld`` and ``/usr/bin/ld.lld``. The first one
        that is an existing file wins.

        Raises:
            Exception: if none of the candidates is a file.
        """
        candidates = []
        lld_env_path = os.getenv("TRITON_HIP_LLD_PATH")
        if lld_env_path is not None:
            candidates.append(Path(lld_env_path))
        candidates.extend([
            Path(__file__).parent / "llvm/bin/ld.lld",
            Path("/opt/rocm/llvm/bin/ld.lld"),
            Path("/usr/bin/ld.lld"),
        ])
        for lld in candidates:
            if lld.is_file():
                return lld
        # Keep the historical message as a prefix, but list every location
        # that was actually probed so the failure can be diagnosed.
        searched = ", ".join(str(path) for path in candidates)
        raise Exception(f"ROCm linker /opt/rocm/llvm/bin/ld.lld not found (searched: {searched})")

    @staticmethod
    def make_ttir(mod, metadata, options):
        """Run the TTIR pass pipeline on ``mod`` and return it.

        ``metadata`` and ``options`` are accepted for a uniform stage
        signature but are not read here.
        """
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()
        passes.common.add_inliner(pm)
        passes.ttir.add_rewrite_tensor_pointer(pm)
        passes.ttir.add_combine(pm)
        passes.common.add_canonicalizer(pm)
        passes.ttir.add_reorder_broadcast(pm)
        passes.common.add_cse(pm)
        passes.common.add_licm(pm)
        passes.common.add_symbol_dce(pm)
        pm.run(mod)
        return mod

    @staticmethod
    def make_ttgir(mod, metadata, options):
        """Convert ``mod`` from TTIR to TTGIR and run the GPU-level passes.

        The conversion runs in its own pass manager; the optimisation passes
        run in a second one. Returns ``mod``.
        """
        # Stage 1: TTIR -> TTGIR conversion.
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()
        passes.ttir.add_convert_to_ttgpuir(pm, f"hip:{options.arch}", options.num_warps, options.warp_size,
                                           options.num_ctas)
        pm.run(mod)

        # Stage 2: TTGIR optimisation passes.
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()
        passes.ttgpuir.add_coalesce(pm)
        passes.ttgpuir.add_remove_layout_conversions(pm)
        passes.ttgpuir.add_optimize_thread_locality(pm)
        amd.passes.ttgpuir.add_accelerate_matmul(pm, options.arch, options.matrix_instr_nonkdim, options.kpack)
        passes.ttgpuir.add_remove_layout_conversions(pm)
        amd.passes.ttgpuir.add_optimize_epilogue(pm)
        passes.ttgpuir.add_optimize_dot_operands(pm, True)
        # NOTE(review): the extent of the two conditional bodies below was
        # reconstructed from the statement order (stream pipeline followed by a
        # canonicalizer; reorder-instructions alone) -- confirm against upstream.
        if options.num_stages == 0 and amd.has_matrix_core_feature(options.arch):
            amd.passes.ttgpuir.add_stream_pipeline(pm)
            passes.common.add_canonicalizer(pm)
        passes.ttgpuir.add_optimize_dot_operands(pm, True)
        passes.ttgpuir.add_remove_layout_conversions(pm)
        passes.ttgpuir.add_reduce_data_duplication(pm)
        if options.num_stages != 0:
            amd.passes.ttgpuir.add_reorder_instructions(pm)
        passes.common.add_cse(pm)
        passes.common.add_symbol_dce(pm)
        pm.run(mod)
        return mod

    @staticmethod
    def make_llir(src, metadata, options):
        """Lower the TTGIR module ``src`` to LLVM IR and return it as text.

        Side effect: stores the module's ``triton_gpu.shared`` integer
        attribute in ``metadata["shared"]``.
        """
        mod = src

        # TritonGPU -> LLVM-IR (MLIR)
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()
        amd.passes.ttgpuir.add_decompose_unsupported_conversions(pm, options.arch)
        passes.convert.add_scf_to_cf(pm)
        passes.convert.add_index_to_llvmir(pm)

        passes.ttgpuir.add_allocate_shared_memory(pm)

        # Plain local name: a leading double underscore inside a class body
        # would be name-mangled. The flag is hard-coded to True and is not
        # derived from options.allow_flush_denorm.
        hip_ftz = True
        amd.passes.ttgpuir.add_to_llvmir(pm, options.arch, hip_ftz)
        passes.common.add_canonicalizer(pm)
        passes.common.add_cse(pm)

        passes.convert.add_cf_to_llvmir(pm)
        passes.convert.add_arith_to_llvmir(pm)
        passes.common.add_canonicalizer(pm)
        passes.common.add_cse(pm)
        passes.common.add_symbol_dce(pm)
        # Debug scopes are emitted unless TRITON_DISABLE_LINE_INFO is set to
        # something other than "0".
        if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0":
            passes.llvmir.add_di_scope(pm)
        amd.passes.ttgpuir.add_builtin_func_to_llvmir(pm)
        pm.run(mod)

        # LLVM-IR (MLIR) -> LLVM-IR (LLVM)
        llvm.init_targets()
        context = llvm.context()
        llvm_mod = llvm.to_module(mod, context)

        # Module-level target settings and device-library control constants.
        amd.set_isa_version(llvm_mod, options.arch)
        amd.set_abi_version(llvm_mod, 400)
        amd.set_bool_control_constant(llvm_mod, "__oclc_finite_only_opt", False)
        amd.set_bool_control_constant(llvm_mod, "__oclc_correctly_rounded_sqrt32", True)
        amd.set_bool_control_constant(llvm_mod, "__oclc_unsafe_math_opt", False)
        amd.set_bool_control_constant(llvm_mod, "__oclc_wavefrontsize64", options.warp_size == 64)

        # The first function with a body is treated as the kernel entry point.
        fns = [fn for fn in llvm_mod.get_functions() if not fn.is_declaration()]
        kernel = fns[0]
        kernel.set_calling_conv(amd.CALLING_CONV_AMDGPU_KERNEL)
        kernel.add_fn_attr("amdgpu-flat-work-group-size", f"1,{options.num_warps*options.warp_size}")
        kernel.add_fn_attr("amdgpu-waves-per-eu", f"{options.waves_per_eu}")
        denormal_mode = "preserve-sign" if options.allow_flush_denorm else "ieee"
        kernel.add_fn_attr("denormal-fp-math-f32", denormal_mode)

        # Link only the external libraries the module actually needs.
        if options.extern_libs:
            paths = [path for (name, path) in options.extern_libs if amd.need_extern_lib(llvm_mod, name)]
            llvm.link_extern_libs(llvm_mod, paths)

        llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3, amd.TARGET_TRIPLE)

        # Record the shared-memory attribute for later stages.
        metadata["shared"] = src.get_int_attr("triton_gpu.shared")

        amd.cleanup_bitcode_metadata(llvm_mod)
        return str(llvm_mod)

    @staticmethod
    def make_amdgcn(src, metadata, options):
        """Translate LLVM IR text ``src`` to AMDGCN assembly and return it.

        Side effect: stores the kernel's name in ``metadata["name"]``. Exactly
        one ``define amdgpu_kernel void @<name>`` must be present in ``src``.
        When the ``AMDGCN_ENABLE_DUMP`` environment variable is ``"1"`` the
        assembly is also printed to stdout.
        """
        # Find the kernel name from its LLVM definition line.
        names = re.findall(r"define amdgpu_kernel void @([a-zA-Z_][a-zA-Z0-9_]*)", src)
        assert len(names) == 1
        metadata["name"] = names[0]

        amdgcn = llvm.translate_to_asm(src, amd.TARGET_TRIPLE, options.arch, '', [], options.enable_fp_fusion, False)
        if os.environ.get("AMDGCN_ENABLE_DUMP", "0") == "1":
            print("// -----// AMDGCN Dump //----- //")
            print(amdgcn)
        return amdgcn

    @staticmethod
    def make_hsaco(src, metadata, options):
        """Assemble AMDGCN ``src`` and link it with ``ld.lld``; return the result as bytes.

        Raises:
            Exception: if no linker can be located (see ``path_to_rocm_lld``).
            subprocess.CalledProcessError: if the linker exits with a non-zero status.
        """
        hsaco = amd.assemble_amdgcn(src, options.arch, '')

        rocm_path = HIPBackend.path_to_rocm_lld()
        # Both temporary files are exchanged with the linker by name and are
        # removed when their context managers exit.
        with tempfile.NamedTemporaryFile() as tmp_out:
            with tempfile.NamedTemporaryFile() as tmp_in:
                with open(tmp_in.name, 'wb') as fd_in:
                    fd_in.write(hsaco)
                subprocess.check_call([rocm_path, '-flavor', 'gnu', '-shared', tmp_in.name, '-o', tmp_out.name])
            with open(tmp_out.name, 'rb') as fd_out:
                ret = fd_out.read()
        return ret

    def add_stages(self, stages, options):
        """Register the five compilation stages in ``stages``, each bound to ``options``."""
        stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options)
        stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options)
        stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options)
        stages["amdgcn"] = lambda src, metadata: self.make_amdgcn(src, metadata, options)
        stages["hsaco"] = lambda src, metadata: self.make_hsaco(src, metadata, options)

    @staticmethod
    @functools.lru_cache()
    def _lld_version(lld_path):
        """Return the ``--version`` output of the linker at ``lld_path`` (cached per path)."""
        return subprocess.check_output([lld_path, "--version"], encoding='utf-8')

    def hash(self):
        """Return a string that identifies the linker version and the target.

        The linker's ``--version`` output is cached per linker path in a
        static helper rather than by decorating this method, because a cache
        on an instance method keys on ``self`` and keeps every backend
        instance alive.
        """
        version = HIPBackend._lld_version(HIPBackend.path_to_rocm_lld())
        return f'{version}-{self.target}'
|
|