"""Local CUDA build for activation kernels.

Usage:
    pip install -e .                    # editable install
    python setup.py build_ext --inplace  # build only

The built extension is named '_activation' and can be loaded via:
    import _activation
    torch.ops._activation.rms_norm(...)
"""

import os
import re
import subprocess
from pathlib import Path

import torch
from setuptools import setup
from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CUDAExtension

ROOT = Path(__file__).parent

CUDA_SOURCES = [
    "activation/poly_norm.cu",
    "activation/fused_mul_poly_norm.cu",
    "activation/rms_norm.cu",
    "activation/fused_add_rms_norm.cu",
    "activation/grouped_poly_norm.cu",
]

CPP_SOURCES = [
    "torch-ext/torch_binding.cpp",
]

# Include dirs: project root (for registration.h, activation/*.h)
# and torch-ext/ (for torch_binding.h)
INCLUDE_DIRS = [
    str(ROOT),
    str(ROOT / "activation"),
    str(ROOT / "torch-ext"),
]

# CUDA flags matching the existing kernel style
NVCC_FLAGS = [
    "-O3",
    "--use_fast_math",
    "-std=c++17",
    # Generate code for common architectures
    "-gencode=arch=compute_80,code=sm_80",  # A100
    "-gencode=arch=compute_89,code=sm_89",  # L40/4090
    "-gencode=arch=compute_90,code=sm_90",  # H100
]


def _toolkit_cuda_version():
    """Return the (major, minor) CUDA version used to compile the kernels, or None.

    Prefers the nvcc under ``CUDA_HOME``, since that is the compiler
    ``BuildExtension`` invokes and therefore the one that decides which
    ``-gencode`` targets are accepted. torch only requires the toolkit's
    *major* version to match its own, so ``torch.version.cuda`` can report a
    newer minor release than the local nvcc actually supports. Falls back to
    ``torch.version.cuda`` when nvcc cannot be queried. Returns None when
    neither is available (e.g. CPU-only or ROCm torch builds, where
    ``torch.version.cuda`` is None).
    """
    if CUDA_HOME is not None:
        nvcc = Path(CUDA_HOME) / "bin" / "nvcc"
        try:
            output = subprocess.run(
                [str(nvcc), "--version"],
                capture_output=True,
                text=True,
                check=True,
            ).stdout
        except (OSError, subprocess.CalledProcessError):
            output = ""
        # nvcc prints e.g. "Cuda compilation tools, release 12.8, V12.8.61"
        match = re.search(r"release (\d+)\.(\d+)", output)
        if match:
            return int(match.group(1)), int(match.group(2))

    if torch.version.cuda:
        major, minor = torch.version.cuda.split(".")[:2]
        return int(major), int(minor)
    return None


# Check for B200 support (sm_100, requires CUDA 12.8+). An unknown version
# skips the flag instead of crashing here; CUDAExtension then reports the
# missing toolkit itself.
cuda_version = _toolkit_cuda_version()
if cuda_version is not None and cuda_version >= (12, 8):
    NVCC_FLAGS.append("-gencode=arch=compute_100,code=sm_100")

CXX_FLAGS = ["-O3", "-std=c++17"]

ext_modules = [
    CUDAExtension(
        name="_activation",
        sources=[str(ROOT / s) for s in CPP_SOURCES + CUDA_SOURCES],
        include_dirs=INCLUDE_DIRS,
        extra_compile_args={
            "cxx": CXX_FLAGS,
            "nvcc": NVCC_FLAGS,
        },
    ),
]

setup(
    name="activation",
    version="0.1.0",
    description="Custom CUDA normalization kernels for LLM training",
    ext_modules=ext_modules,
    cmdclass={"build_ext": BuildExtension},
    packages=["activation"],
    package_dir={"activation": "torch-ext/activation"},
    python_requires=">=3.10",
    install_requires=["torch>=2.7"],
)