| import os |
| import sys |
| from pathlib import Path |
| from setuptools import setup, find_packages |
|
|
|
|
# Package metadata shared by every build flavour. The CUDA/ROCm detection
# section later in this file appends a local version tag ("+cuXYZ" or
# "+rocmX.Y") to "version", and extension settings may be merged in before
# setup() is called.
common_setup_kwargs = dict(
    version="0.4.1",
    name="auto_gptq",
    author="PanQiWei",
    description="An easy-to-use LLMs quantization package with user-friendly apis, based on GPTQ algorithm.",
    # The project README (next to this file) doubles as the PyPI long description.
    long_description=Path(__file__).parent.joinpath("README.md").read_text(encoding="UTF-8"),
    long_description_content_type="text/markdown",
    url="https://github.com/PanQiWei/AutoGPTQ",
    keywords=["gptq", "quantization", "large-language-models", "transformers"],
    platforms=["windows", "linux"],
    classifiers=[
        "Environment :: GPU :: NVIDIA CUDA :: 11.7",
        "Environment :: GPU :: NVIDIA CUDA :: 11.8",
        "Environment :: GPU :: NVIDIA CUDA :: 12.0",
        "License :: OSI Approved :: MIT License",
        "Natural Language :: Chinese (Simplified)",
        "Natural Language :: English",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "Programming Language :: C++",
    ],
)
|
|
|
|
# Set BUILD_CUDA_EXT=0 to skip compiling the native CUDA/ROCm kernels.
BUILD_CUDA_EXT = int(os.environ.get('BUILD_CUDA_EXT', '1')) == 1
if BUILD_CUDA_EXT:
    try:
        import torch
    except ImportError:
        # Only a genuinely missing PyTorch is reported with this hint; other
        # failures (e.g. broken shared libraries) propagate with their real
        # traceback instead of being masked by a bare `except:`.
        print("Building cuda extension requires PyTorch(>=1.13.0) been installed, please install PyTorch first!")
        sys.exit(-1)

    CUDA_VERSION = None
    # A non-empty ROCM_VERSION selects a ROCm build instead of a CUDA build.
    ROCM_VERSION = os.environ.get('ROCM_VERSION', None)
    if ROCM_VERSION and not torch.version.hip:
        print(
            f"Trying to compile auto-gptq for RoCm, but PyTorch {torch.__version__} "
            "is installed without RoCm support."
        )
        sys.exit(-1)

    if not ROCM_VERSION:
        # The CUDA_VERSION env var overrides the CUDA version PyTorch reports.
        # torch.version.cuda is None on builds without CUDA, so coerce a missing
        # value to "" (instead of calling None.split) and let the explicit check
        # below report the problem. "11.8" -> "118" for the "+cu118" tag.
        default_cuda_version = torch.version.cuda
        CUDA_VERSION = "".join(
            (os.environ.get("CUDA_VERSION", default_cuda_version) or "").split(".")
        )

    if ROCM_VERSION:
        common_setup_kwargs['version'] += f"+rocm{ROCM_VERSION}"
    else:
        if not CUDA_VERSION:
            print(
                f"Trying to compile auto-gptq for CUDA, but PyTorch {torch.__version__} "
                "is installed without CUDA support."
            )
            sys.exit(-1)
        common_setup_kwargs['version'] += f"+cu{CUDA_VERSION}"
|
|
|
|
# Runtime dependencies installed together with auto_gptq.
requirements = [
    "accelerate>=0.19.0",
    "datasets",
    "numpy",
    "rouge",
    "torch>=1.13.0",
    "safetensors",
    "transformers>=4.31.0",
    "peft",
]

# Optional dependency groups, e.g. `pip install auto_gptq[triton]`.
extras_require = dict(
    triton=["triton==2.0.0"],
    test=["parameterized"],
)
|
|
# Include directories passed to setup() (default include paths for build_ext).
include_dirs = ["autogptq_cuda"]

# Extra setup() kwargs; only populated when the native extensions are built.
additional_setup_kwargs = dict()
if BUILD_CUDA_EXT:
    from torch.utils import cpp_extension

    if not ROCM_VERSION:
        # stdlib replacement for the deprecated distutils.sysconfig.get_python_lib (PEP 632).
        import sysconfig

        # Look for CUDA runtime headers installed as a Python package under
        # site-packages ("nvidia/cuda_runtime/include"). Depending on how that
        # package was installed it may live under purelib or platlib, and the
        # two differ on some layouts (e.g. lib64 distros), so check both.
        # dict.fromkeys de-duplicates while preserving order (purelib first).
        candidate_site_dirs = dict.fromkeys(
            sysconfig.get_path(scheme_key) for scheme_key in ("purelib", "platlib")
        )
        for site_dir in candidate_site_dirs:
            if not site_dir:
                continue
            conda_cuda_include_dir = os.path.join(site_dir, "nvidia/cuda_runtime/include")
            print("conda_cuda_include_dir", conda_cuda_include_dir)
            if os.path.isdir(conda_cuda_include_dir):
                include_dirs.append(conda_cuda_include_dir)
                print(f"appending conda cuda include dir {conda_cuda_include_dir}")
                break

    # GPTQ kernels compiled in two variants (64 and 256 suffixes).
    extensions = [
        cpp_extension.CUDAExtension(
            "autogptq_cuda_64",
            [
                "autogptq_cuda/autogptq_cuda_64.cpp",
                "autogptq_cuda/autogptq_cuda_kernel_64.cu"
            ]
        ),
        cpp_extension.CUDAExtension(
            "autogptq_cuda_256",
            [
                "autogptq_cuda/autogptq_cuda_256.cpp",
                "autogptq_cuda/autogptq_cuda_kernel_256.cu"
            ]
        )
    ]

    # Set INCLUDE_EXLLAMA_KERNELS to anything other than "1" to skip them.
    if os.environ.get("INCLUDE_EXLLAMA_KERNELS", "1") == "1":
        extensions.append(
            cpp_extension.CUDAExtension(
                "exllama_kernels",
                [
                    "autogptq_cuda/exllama/exllama_ext.cpp",
                    "autogptq_cuda/exllama/cuda_buffers.cu",
                    "autogptq_cuda/exllama/cuda_func/column_remap.cu",
                    "autogptq_cuda/exllama/cuda_func/q4_matmul.cu",
                    "autogptq_cuda/exllama/cuda_func/q4_matrix.cu"
                ]
            )
        )

    # torch's BuildExtension replaces the stock build_ext so the .cu sources
    # are compiled by the GPU toolchain.
    additional_setup_kwargs = {
        "ext_modules": extensions,
        "cmdclass": {'build_ext': cpp_extension.BuildExtension}
    }
# Merge in any extension build settings (empty unless the CUDA extension
# section populated them), then hand the full configuration to setuptools.
common_setup_kwargs.update(additional_setup_kwargs)
setup(
    **common_setup_kwargs,
    python_requires=">=3.8.0",
    include_dirs=include_dirs,
    extras_require=extras_require,
    install_requires=requirements,
    packages=find_packages(),
)
|
|