Spaces:
Running on Zero
Running on Zero
File size: 1,394 Bytes
08c5e28 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 | from dataclasses import dataclass
from ltx_core.loader.module_ops import ModuleOps
from ltx_core.loader.sd_ops import SDOps
from ltx_core.quantization.fp8_cast import TRANSFORMER_LINEAR_DOWNCAST_MAP, UPCAST_DURING_INFERENCE
from ltx_core.quantization.fp8_scaled_mm import FP8_PREPARE_MODULE_OPS, FP8_TRANSPOSE_SD_OPS
@dataclass(frozen=True)
class QuantizationPolicy:
    """Configuration for model quantization during loading.

    Instances are immutable (``frozen=True``). The classmethods below build
    the preset policies.

    Attributes:
        sd_ops: State dict operations for weight transformation, or ``None``
            (the default) for no state dict operations.
        module_ops: Post-load module transformations. Defaults to an empty
            tuple (no transformations).
    """

    sd_ops: SDOps | None = None
    module_ops: tuple[ModuleOps, ...] = ()

    @classmethod
    def fp8_cast(cls) -> "QuantizationPolicy":
        """Create policy using FP8 casting with upcasting during inference.

        Returns:
            A policy whose ``sd_ops`` is ``TRANSFORMER_LINEAR_DOWNCAST_MAP``
            and whose only module op is ``UPCAST_DURING_INFERENCE``.
        """
        return cls(
            sd_ops=TRANSFORMER_LINEAR_DOWNCAST_MAP,
            module_ops=(UPCAST_DURING_INFERENCE,),
        )

    @classmethod
    def fp8_scaled_mm(cls) -> "QuantizationPolicy":
        """Create policy using FP8 scaled matrix multiplication.

        Returns:
            A policy whose ``sd_ops`` is ``FP8_TRANSPOSE_SD_OPS`` and whose
            only module op is ``FP8_PREPARE_MODULE_OPS``.

        Raises:
            ImportError: If ``tensorrt_llm`` cannot be imported. The original
                import error is chained as ``__cause__``.
        """
        # Availability probe only: the module is not used here. It is imported
        # inside the method (not at module top level), so this module can be
        # loaded without tensorrt_llm.
        try:
            import tensorrt_llm  # noqa: F401, PLC0415
        except ImportError as e:
            # This method does not skip anything. It raises, and the caller
            # decides whether to fall back. Keep the "is not installed" prefix
            # stable for callers that match on the message.
            raise ImportError(
                "tensorrt_llm is not installed; it is required for FP8 scaled MM quantization"
            ) from e
        return cls(
            sd_ops=FP8_TRANSPOSE_SD_OPS,
            module_ops=(FP8_PREPARE_MODULE_OPS,),
        )
|