import torch

from ltx_core.loader.module_ops import ModuleOps
from ltx_core.loader.sd_ops import KeyValueOperationResult, SDOps
from ltx_core.model.transformer.model import LTXModel

# Block size passed to fused_add_round_kernel; the launch grid is ceil(n_elements / BLOCK_SIZE) programs.
BLOCK_SIZE = 1024


def _fused_add_round_launch(target_weight: torch.Tensor, original_weight: torch.Tensor, seed: int) -> torch.Tensor:
    """
    Launch the Triton ``fused_add_round_kernel`` on ``original_weight`` and ``target_weight``.

    The kernel is given both tensors, ``seed``, the element count of ``original_weight`` and the
    exponent bias / mantissa width of ``original_weight``'s fp8 format. ``target_weight`` is returned
    after the launch; presumably the kernel writes its result into it in place (the kernel itself
    lives in ``ltx_core.loader.kernels`` — confirm there).

    Args:
        target_weight: bfloat16 tensor handed to the kernel and returned.
        original_weight: float8_e4m3fn or float8_e5m2 tensor with the same number of elements.
        seed: random seed forwarded to the kernel.

    Returns:
        ``target_weight``.

    Raises:
        ValueError: if ``original_weight`` has an unsupported dtype, if ``target_weight`` is not
            bfloat16, or if the two tensors differ in element count.
    """
    # Lazy import triton - only available on CUDA platforms
    import triton  # noqa: PLC0415

    from ltx_core.loader.kernels import fused_add_round_kernel  # noqa: PLC0415

    # The kernel only takes the mantissa width and exponent bias of the source fp8 format.
    if original_weight.dtype == torch.float8_e4m3fn:
        mantissa_bits, exponent_bias = 3, 7
    elif original_weight.dtype == torch.float8_e5m2:
        mantissa_bits, exponent_bias = 2, 15
    else:
        raise ValueError(f"Unsupported dtype: {original_weight.dtype}")

    if target_weight.dtype != torch.bfloat16:
        raise ValueError("target_weight dtype must be bfloat16")

    n_elements = original_weight.numel()
    # n_elements is the only size handed to the kernel, so it bounds the accesses to BOTH tensors;
    # a smaller target_weight would otherwise be indexed past its end.
    if target_weight.numel() != n_elements:
        raise ValueError(
            f"target_weight has {target_weight.numel()} elements but original_weight has {n_elements}"
        )

    # One program per BLOCK_SIZE-sized chunk of elements.
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
    fused_add_round_kernel[grid](
        original_weight,
        target_weight,
        seed,
        n_elements,
        exponent_bias,
        mantissa_bits,
        BLOCK_SIZE,
    )
    return target_weight


def _naive_weight_or_bias_downcast(key: str, value: torch.Tensor) -> list[KeyValueOperationResult]:
    """
    Downcast a weight or bias state-dict entry to the float8_e4m3fn dtype.

    Args:
        key: state-dict key, passed through unchanged.
        value: tensor to downcast.

    Returns:
        A single-element list holding the same key with the downcast tensor.
    """
    return [KeyValueOperationResult(key, value.to(dtype=torch.float8_e4m3fn))]


def _upcast_and_round(
    weight: torch.Tensor, dtype: torch.dtype, with_stochastic_rounding: bool = False, seed: int = 0
) -> torch.Tensor:
    """
    Upcast ``weight`` to ``dtype``, optionally through the stochastic-rounding kernel.

    The kernel path is used only when ``with_stochastic_rounding`` is set AND ``weight`` is
    float8_e4m3fn or float8_e5m2. Any other weight dtype — e.g. a Linear whose parameters were not
    downcast on load, since the upcast is installed on every Linear of the model — gets a plain
    ``.to(dtype)``, exactly like the non-rounding path, instead of failing in the kernel launcher.

    Args:
        weight: tensor to upcast.
        dtype: target dtype (the linear input's dtype at the call site).
        with_stochastic_rounding: request the stochastic-rounding kernel for fp8 tensors.
        seed: random seed forwarded to the kernel.

    Returns:
        A tensor of dtype ``dtype``.

    Raises:
        ValueError: from the kernel launcher when rounding is requested for an fp8 tensor and
            ``dtype`` is not bfloat16.
    """
    if not with_stochastic_rounding or weight.dtype not in (torch.float8_e4m3fn, torch.float8_e5m2):
        return weight.to(dtype)
    return _fused_add_round_launch(torch.zeros_like(weight, dtype=dtype), weight, seed)


class Fp8CastLinear(torch.nn.Linear):
    """nn.Linear whose parameters may be stored in fp8, upcast to the input dtype during forward.

    Applied via ``__class__`` reassignment on existing ``nn.Linear`` instances (rather than by
    constructing new modules) so existing weight tensors are preserved in-place. Forward is defined
    at class level, which is required for torch.compile compatibility — instance-level closure
    monkey-patches cause graph breaks.
    """

    # Per-instance settings normally written by _replace_fwd_with_upcast. The class-level defaults
    # make an instance that never went through it behave as a plain upcast instead of raising
    # AttributeError in forward.
    _with_stochastic_rounding: bool = False
    _seed: int = 0

    def forward(self, input: torch.Tensor) -> torch.Tensor:  # noqa: A002
        """Upcast weight (and bias, if present) to ``input.dtype`` and apply the linear map."""
        # NOTE(review): the same seed is used for weight and bias and on every call.
        w_up = _upcast_and_round(self.weight, input.dtype, self._with_stochastic_rounding, self._seed)
        b_up = (
            _upcast_and_round(self.bias, input.dtype, self._with_stochastic_rounding, self._seed)
            if self.bias is not None
            else None
        )
        return torch.nn.functional.linear(input, w_up, b_up)


def _replace_fwd_with_upcast(layer: torch.nn.Linear, with_stochastic_rounding: bool = False, seed: int = 0) -> None:
    """
    Turn an existing nn.Linear into an Fp8CastLinear in place.

    Uses ``__class__`` reassignment so that the layer's parameter and buffer tensors are preserved
    in-place, avoiding re-instantiation. Forward remains defined at the class level, which is
    required for torch.compile compatibility — instance-level closure monkey-patches cause graph
    breaks.

    Args:
        layer: Linear module to convert (mutated in place).
        with_stochastic_rounding: stored on the layer; read by Fp8CastLinear.forward.
        seed: stored on the layer; forwarded to the rounding kernel.
    """
    layer.__class__ = Fp8CastLinear
    layer._with_stochastic_rounding = with_stochastic_rounding
    layer._seed = seed


def _amend_forward_with_upcast(
    model: torch.nn.Module, with_stochastic_rounding: bool = False, seed: int = 0
) -> torch.nn.Module:
    """
    Convert every nn.Linear in ``model`` to forward with upcast and optional stochastic rounding.

    Every Linear found by ``model.modules()`` is converted, including layers whose parameters were
    not downcast to fp8; _upcast_and_round leaves those on a plain dtype cast.

    Args:
        model: module tree to mutate in place.
        with_stochastic_rounding: enable the stochastic-rounding kernel for fp8 parameters.
        seed: seed stored on every converted layer.

    Returns:
        ``model`` itself, after mutation.
    """
    for module in model.modules():
        if isinstance(module, torch.nn.Linear):
            _replace_fwd_with_upcast(module, with_stochastic_rounding, seed)
    return model


# Key suffixes (under the "transformer_blocks." prefix) of the Linear parameters downcast to
# float8_e4m3fn on load, in registration order. The attention suffixes start with a dot; the
# feed-forward ones do not, so (assuming key_suffix is an ends-with match) they also match module
# names that merely END in "ff". Kept exactly as originally registered.
_TRANSFORMER_LINEAR_DOWNCAST_SUFFIXES = (
    ".to_q.weight",
    ".to_q.bias",
    ".to_k.weight",
    ".to_k.bias",
    ".to_v.weight",
    ".to_v.bias",
    ".to_out.0.weight",
    ".to_out.0.bias",
    "ff.net.0.proj.weight",
    "ff.net.0.proj.bias",
    "ff.net.2.weight",
    "ff.net.2.bias",
)


def _build_transformer_linear_downcast_map() -> SDOps:
    """Register _naive_weight_or_bias_downcast for every suffix in _TRANSFORMER_LINEAR_DOWNCAST_SUFFIXES."""
    sd_ops = SDOps("TRANSFORMER_LINEAR_DOWNCAST_MAP")
    for suffix in _TRANSFORMER_LINEAR_DOWNCAST_SUFFIXES:
        sd_ops = sd_ops.with_kv_operation(
            key_prefix="transformer_blocks.",
            key_suffix=suffix,
            operation=_naive_weight_or_bias_downcast,
        )
    return sd_ops


TRANSFORMER_LINEAR_DOWNCAST_MAP = _build_transformer_linear_downcast_map()

# Upcasts fp8 Linear parameters of an LTXModel to the input dtype at forward time (no rounding).
UPCAST_DURING_INFERENCE = ModuleOps(
    name="upcast_fp8_during_linear_forward",
    matcher=lambda model: isinstance(model, LTXModel),
    mutator=lambda model: _amend_forward_with_upcast(model, with_stochastic_rounding=False),
)


class UpcastWithStochasticRounding(ModuleOps):
    """
    ModuleOps for upcasting the model's float8_e4m3fn weights and biases to the bfloat16 dtype
    and applying stochastic rounding during linear forward.

    Args:
        seed: random seed stored on every converted Linear and forwarded to the rounding kernel.
    """

    def __new__(cls, seed: int = 0):
        return super().__new__(
            cls,
            name="upcast_fp8_during_linear_forward_with_stochastic_rounding",
            matcher=lambda model: isinstance(model, LTXModel),
            mutator=lambda model: _amend_forward_with_upcast(model, True, seed),
        )