Dramabox / ltx2 /ltx_core /loader /kernels.py
Manmay's picture
DramaBox Space — initial app + vendored ltx2
08c5e28 verified
# ruff: noqa: ANN001, ANN201, ERA001, N803, N806
import triton
import triton.language as tl
@triton.jit
def fused_add_round_kernel(
x_ptr,
output_ptr, # contents will be added to the output
seed,
n_elements,
EXPONENT_BIAS,
MANTISSA_BITS,
BLOCK_SIZE: tl.constexpr,
):
"""
A kernel to upcast 8bit quantized weights to bfloat16 with stochastic rounding
and add them to bfloat16 output weights. Might be used to upcast original model weights
and to further add them to precalculated deltas coming from LoRAs.
"""
# Get program ID and compute offsets
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
# Load data
x = tl.load(x_ptr + offsets, mask=mask)
rand_vals = tl.rand(seed, offsets) - 0.5
x = tl.cast(x, tl.float16)
delta = tl.load(output_ptr + offsets, mask=mask)
delta = tl.cast(delta, tl.float16)
x = x + delta
x_bits = tl.cast(x, tl.int16, bitcast=True)
# Calculate the exponent. Unbiased fp16 exponent is ((x_bits & 0x7C00) >> 10) - 15 for
# normal numbers and -14 for subnormals.
fp16_exponent_bits = (x_bits & 0x7C00) >> 10
fp16_normals = fp16_exponent_bits > 0
fp16_exponent = tl.where(fp16_normals, fp16_exponent_bits - 15, -14)
# Add the target dtype's exponent bias and clamp to the target dtype's exponent range.
exponent = fp16_exponent + EXPONENT_BIAS
MAX_EXPONENT = 2 * EXPONENT_BIAS + 1
exponent = tl.where(exponent > MAX_EXPONENT, MAX_EXPONENT, exponent)
exponent = tl.where(exponent < 0, 0, exponent)
# Normal ULP exponent, expressed as an fp16 exponent field:
# (exponent - EXPONENT_BIAS - MANTISSA_BITS) + 15
# Simplifies to: fp16_exponent - MANTISSA_BITS + 15
# See https://en.wikipedia.org/wiki/Unit_in_the_last_place
eps_exp = tl.maximum(0, tl.minimum(31, exponent - EXPONENT_BIAS - MANTISSA_BITS + 15))
# Calculate epsilon in the target dtype
eps_normal = tl.cast(tl.cast(eps_exp << 10, tl.int16), tl.float16, bitcast=True)
# Subnormal ULP: 2^(1 - EXPONENT_BIAS - MANTISSA_BITS) ->
# fp16 exponent bits: (1 - EXPONENT_BIAS - MANTISSA_BITS) + 15 =
# 16 - EXPONENT_BIAS - MANTISSA_BITS
eps_subnormal = tl.cast(tl.cast((16 - EXPONENT_BIAS - MANTISSA_BITS) << 10, tl.int16), tl.float16, bitcast=True)
eps = tl.where(exponent > 0, eps_normal, eps_subnormal)
# Apply zero mask to epsilon
eps = tl.where(x == 0, 0.0, eps)
# Apply stochastic rounding
output = tl.cast(x + rand_vals * eps, tl.bfloat16)
# Store the result
tl.store(output_ptr + offsets, output, mask=mask)