args
Browse files
wfx.py
CHANGED
|
@@ -22,6 +22,8 @@ def parse_args():
|
|
| 22 |
args.add_argument('--model', type=str, required=True)
|
| 23 |
args.add_argument('--custom-pipeline', type=str, default=None)
|
| 24 |
args.add_argument('--compile-mode', default='sfast', type=str, choices=['sfast', 'torch', 'no-compile'])
|
|
|
|
|
|
|
| 25 |
return args.parse_args()
|
| 26 |
|
| 27 |
def quantize_unet(m):
|
|
@@ -68,7 +70,10 @@ class WFX():
|
|
| 68 |
except ImportError:
|
| 69 |
logger.warning('triton not found, disabling triton')
|
| 70 |
|
| 71 |
-
self.compiler_config.enable_cuda_graph =
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
for key in self.compiler_config.__dict__:
|
| 74 |
logger.info(f'cc - {key}: {self.compiler_config.__dict__[key]}')
|
|
|
|
| 22 |
args.add_argument('--model', type=str, required=True)
|
| 23 |
args.add_argument('--custom-pipeline', type=str, default=None)
|
| 24 |
args.add_argument('--compile-mode', default='sfast', type=str, choices=['sfast', 'torch', 'no-compile'])
|
| 25 |
+
args.add_argument('--enable-cuda-graph', action='store_true', default=False)
|
| 26 |
+
args.add_argument('--disable-prefer-lowp-gemm', action='store_true', default=False)
|
| 27 |
return args.parse_args()
|
| 28 |
|
| 29 |
def quantize_unet(m):
|
|
|
|
| 70 |
except ImportError:
|
| 71 |
logger.warning('triton not found, disabling triton')
|
| 72 |
|
| 73 |
+
self.compiler_config.enable_cuda_graph = args.enable_cuda_graph
|
| 74 |
+
|
| 75 |
+
if args.disable_prefer_lowp_gemm:
|
| 76 |
+
self.compiler_config.prefer_lowp_gemm = False
|
| 77 |
|
| 78 |
for key in self.compiler_config.__dict__:
|
| 79 |
logger.info(f'cc - {key}: {self.compiler_config.__dict__[key]}')
|