Spaces:
Running on Zero
Running on Zero
| from dataclasses import dataclass | |
| from ltx_core.loader.module_ops import ModuleOps | |
| from ltx_core.loader.sd_ops import SDOps | |
| from ltx_core.quantization.fp8_cast import TRANSFORMER_LINEAR_DOWNCAST_MAP, UPCAST_DURING_INFERENCE | |
| from ltx_core.quantization.fp8_scaled_mm import FP8_PREPARE_MODULE_OPS, FP8_TRANSPOSE_SD_OPS | |
@dataclass
class QuantizationPolicy:
    """Configuration for model quantization during loading.

    Instances are normally obtained through one of the alternate
    constructors (``fp8_cast`` or ``fp8_scaled_mm``); constructing the class
    with no arguments yields a policy with no state-dict ops and no module
    ops.

    Attributes:
        sd_ops: State dict operations for weight transformation, or ``None``
            when no state-dict transformation is configured.
        module_ops: Post-load module transformations. Empty by default.
    """

    sd_ops: SDOps | None = None
    module_ops: tuple[ModuleOps, ...] = ()

    @classmethod
    def fp8_cast(cls) -> "QuantizationPolicy":
        """Create policy using FP8 casting with upcasting during inference.

        Returns:
            A policy built from ``TRANSFORMER_LINEAR_DOWNCAST_MAP`` as the
            state-dict ops and ``UPCAST_DURING_INFERENCE`` as the single
            module op.
        """
        return cls(
            sd_ops=TRANSFORMER_LINEAR_DOWNCAST_MAP,
            module_ops=(UPCAST_DURING_INFERENCE,),
        )

    @classmethod
    def fp8_scaled_mm(cls) -> "QuantizationPolicy":
        """Create policy using FP8 scaled matrix multiplication.

        Returns:
            A policy built from ``FP8_TRANSPOSE_SD_OPS`` as the state-dict
            ops and ``FP8_PREPARE_MODULE_OPS`` as the single module op.

        Raises:
            ImportError: If the ``tensorrt_llm`` package cannot be imported.
        """
        # The import is only an availability probe: the module object itself
        # is unused here, hence the F401 suppression. Failing early gives the
        # caller a clear error instead of a failure later on.
        try:
            import tensorrt_llm  # noqa: F401, PLC0415
        except ImportError as e:
            raise ImportError("tensorrt_llm is not installed, skipping FP8 scaled MM quantization") from e

        return cls(
            sd_ops=FP8_TRANSPOSE_SD_OPS,
            module_ops=(FP8_PREPARE_MODULE_OPS,),
        )