Spaces:
Running on Zero
Running on Zero
| # ruff: noqa: ANN001, ANN201, ERA001, N803, N806 | |
| import triton | |
| import triton.language as tl | |
| def fused_add_round_kernel( | |
| x_ptr, | |
| output_ptr, # contents will be added to the output | |
| seed, | |
| n_elements, | |
| EXPONENT_BIAS, | |
| MANTISSA_BITS, | |
| BLOCK_SIZE: tl.constexpr, | |
| ): | |
| """ | |
| A kernel to upcast 8bit quantized weights to bfloat16 with stochastic rounding | |
| and add them to bfloat16 output weights. Might be used to upcast original model weights | |
| and to further add them to precalculated deltas coming from LoRAs. | |
| """ | |
| # Get program ID and compute offsets | |
| pid = tl.program_id(axis=0) | |
| block_start = pid * BLOCK_SIZE | |
| offsets = block_start + tl.arange(0, BLOCK_SIZE) | |
| mask = offsets < n_elements | |
| # Load data | |
| x = tl.load(x_ptr + offsets, mask=mask) | |
| rand_vals = tl.rand(seed, offsets) - 0.5 | |
| x = tl.cast(x, tl.float16) | |
| delta = tl.load(output_ptr + offsets, mask=mask) | |
| delta = tl.cast(delta, tl.float16) | |
| x = x + delta | |
| x_bits = tl.cast(x, tl.int16, bitcast=True) | |
| # Calculate the exponent. Unbiased fp16 exponent is ((x_bits & 0x7C00) >> 10) - 15 for | |
| # normal numbers and -14 for subnormals. | |
| fp16_exponent_bits = (x_bits & 0x7C00) >> 10 | |
| fp16_normals = fp16_exponent_bits > 0 | |
| fp16_exponent = tl.where(fp16_normals, fp16_exponent_bits - 15, -14) | |
| # Add the target dtype's exponent bias and clamp to the target dtype's exponent range. | |
| exponent = fp16_exponent + EXPONENT_BIAS | |
| MAX_EXPONENT = 2 * EXPONENT_BIAS + 1 | |
| exponent = tl.where(exponent > MAX_EXPONENT, MAX_EXPONENT, exponent) | |
| exponent = tl.where(exponent < 0, 0, exponent) | |
| # Normal ULP exponent, expressed as an fp16 exponent field: | |
| # (exponent - EXPONENT_BIAS - MANTISSA_BITS) + 15 | |
| # Simplifies to: fp16_exponent - MANTISSA_BITS + 15 | |
| # See https://en.wikipedia.org/wiki/Unit_in_the_last_place | |
| eps_exp = tl.maximum(0, tl.minimum(31, exponent - EXPONENT_BIAS - MANTISSA_BITS + 15)) | |
| # Calculate epsilon in the target dtype | |
| eps_normal = tl.cast(tl.cast(eps_exp << 10, tl.int16), tl.float16, bitcast=True) | |
| # Subnormal ULP: 2^(1 - EXPONENT_BIAS - MANTISSA_BITS) -> | |
| # fp16 exponent bits: (1 - EXPONENT_BIAS - MANTISSA_BITS) + 15 = | |
| # 16 - EXPONENT_BIAS - MANTISSA_BITS | |
| eps_subnormal = tl.cast(tl.cast((16 - EXPONENT_BIAS - MANTISSA_BITS) << 10, tl.int16), tl.float16, bitcast=True) | |
| eps = tl.where(exponent > 0, eps_normal, eps_subnormal) | |
| # Apply zero mask to epsilon | |
| eps = tl.where(x == 0, 0.0, eps) | |
| # Apply stochastic rounding | |
| output = tl.cast(x + rand_vals * eps, tl.bfloat16) | |
| # Store the result | |
| tl.store(output_ptr + offsets, output, mask=mask) | |